Quantization Support


The PolyBlocks compiler can execute TensorFlow models quantized with NVIDIA’s tensorflow_quantization library, as well as PyTorch models quantized through the PyTorch 2.0 export path. Quantization is not yet supported with JAX/PolyBlocks.

The following example quantizes the full ResNet50V2 model with PolyBlocks/TensorFlow and then executes the quantized model using the PolyBlocks JIT decorator.

# Quantized ResNet50_V2 inference.

import numpy as np
import tensorflow as tf
import time
from tensorflow.keras.applications.resnet_v2 import ResNet50V2
from tensorflow_quantization.quantize import quantize_model
from polyblocks.tf_polyblocks_compiler import polyblocks_jit_tf

# 1. Initialize the ResNet50V2 model trained on the ImageNet dataset.
ResNet = ResNet50V2(
    input_shape=(224, 224, 3),
    weights="imagenet",
)

# 2. Quantize the model. 8-bit quantization is the default.
quantized_resnet = quantize_model(ResNet)

# 3. Execute using PolyBlocks JIT decorator.

# Generate a random input image to provide as input to the model.
image = np.random.rand(1, 224, 224, 3).astype(np.float32)

# Perform ResNet50_V2 inference via PolyBlocks.
@polyblocks_jit_tf(compile_options={'target': 'nvgpu'})
def polyblocks_resnet_inference(image):
    return quantized_resnet(image)

out = polyblocks_resnet_inference(image)
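
The example above imports `time` without using it, so timing the compiled function seems a natural next step. The following is a minimal sketch of a timing helper, not part of the PolyBlocks API: `benchmark` and `dummy_inference` are hypothetical names introduced here for illustration. It assumes that the first call to a JIT-compiled function may include compilation time, which is why warm-up calls are excluded from the measurement.

```python
import time


def benchmark(fn, *args, warmup=1, iters=10):
    """Return the mean wall-clock time per call in seconds, excluding warm-up calls."""
    # Warm-up calls absorb one-time costs such as JIT compilation.
    for _ in range(warmup):
        fn(*args)
    start = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    return (time.perf_counter() - start) / iters


# Stand-in workload so the sketch runs on its own (hypothetical).
def dummy_inference(n):
    return sum(i * i for i in range(n))


mean_seconds = benchmark(dummy_inference, 10_000, warmup=2, iters=5)
print(f"{mean_seconds * 1e3:.3f} ms per call")
```

With the names from the example above, the call would be `benchmark(polyblocks_resnet_inference, image)`. If the compiled function returns before the GPU work has finished, the result would also need to be materialized (for example, converted to a NumPy array) inside the timed loop for the numbers to be meaningful.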

Similarly, the example below quantizes a ResNet50 model written in PyTorch through the PyTorch 2.0 export path and then compiles it with PolyBlocks.

# Quantized PyTorch ResNet50 inference.
import copy
import torch
import torch._dynamo as torchdynamo
from resnet50_quantizer_utils import get_symmetric_quantization_config, BackendQuantizer
import torchvision
from torch.ao.quantization.quantize_pt2e import (
    convert_pt2e,
    prepare_pt2e,
)
from polyblocks.torch_polyblocks_compiler import polyblocks_jit_torch
from polyblocks.common_utils import validate_outputs
from torchvision.models import resnet50, ResNet50_Weights

if __name__ == "__main__":
    example_inputs = (torch.randn(1, 3, 224, 224),)
    model = resnet50(weights=ResNet50_Weights.DEFAULT, progress=True).eval()
    # Program capture.
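    # The export arguments are reconstructed here, assuming the usual
    # PyTorch 2.0 export-based quantization flow.
    exported_model, guards = torchdynamo.export(
        model,
        *copy.deepcopy(example_inputs),
        aten_graph=True,
    )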
    quantizer = BackendQuantizer()
    operator_config = get_symmetric_quantization_config()

    # Prepare the model for PTQ.
    prepared_model = prepare_pt2e(exported_model, quantizer)
    # Calibrate/train the model.
    after_prepare_result = prepared_model(*example_inputs)
    # Convert back the calibrated/trained model to the quantized model.
    quantized_model = convert_pt2e(prepared_model, fold_quantize=True)
    print("converted module is: {}".format(quantized_model), flush=True)

    # Perform quantized ResNet50 inference via PolyBlocks.
    polyblocks_quantized_resnet50_inference = polyblocks_jit_torch(
        compile_options={'target': 'nvgpu'}
    )(quantized_model)

    image = torch.randn(1, 3, 224, 224)
    out = polyblocks_quantized_resnet50_inference(image)

All optimizations available to non-quantized models also apply to quantized models. In addition, reduced precision improves the performance of quantized models through int8 tensor core instructions (on NVIDIA GPUs), together with other optimizations that use memory bandwidth more efficiently.
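
To make the memory-bandwidth point concrete, here is a minimal, library-independent sketch of symmetric per-tensor int8 quantization. It is an illustration under stated assumptions, not the scheme used by tensorflow_quantization, PyTorch, or PolyBlocks: the helper names `quantize_symmetric_int8` and `dequantize` are hypothetical, and a single scale is applied to the whole tensor.

```python
from array import array


def quantize_symmetric_int8(values):
    """Map floats in [-max_abs, max_abs] onto integers in [-127, 127] with one scale."""
    max_abs = max(abs(v) for v in values)
    scale = max_abs / 127.0 if max_abs > 0 else 1.0
    # Round to the nearest integer and clamp to the symmetric int8 range.
    q = array("b", (max(-127, min(127, round(v / scale))) for v in values))
    return q, scale


def dequantize(q, scale):
    """Recover approximate float values from the int8 representation."""
    return [x * scale for x in q]


weights = array("f", [0.5, -1.0, 0.25, 0.75])  # 4 bytes per element
q, scale = quantize_symmetric_int8(weights)    # 1 byte per element
restored = dequantize(q, scale)
max_error = max(abs(a - b) for a, b in zip(weights, restored))

print("bytes per element:", weights.itemsize, "->", q.itemsize)
print("max round-trip error:", max_error)
```

Each int8 element occupies a quarter of the storage of a 32-bit float, so the same memory traffic moves four times as many values. The cost is a rounding error of at most half the scale per element.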