Quantization Support#

The PolyBlocks compiler supports the compilation and execution of quantized PyTorch and TensorFlow models. Quantization is not yet supported with JAX/PolyBlocks. All of the optimizations available to non-quantized models benefit quantized models as well. In addition, reduced precision delivers improved performance through int8-based tensor core instructions (on NVIDIA GPUs) and other optimizations that use memory bandwidth more efficiently.
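
For intuition, 8-bit quantization represents floating-point tensors as int8/uint8 values together with a scale (and, for asymmetric schemes, a zero point). The sketch below is illustrative only and assumes a simple affine scheme; the actual quantization schemes are selected by the libraries used in the examples that follow.

import numpy as np

# Illustrative 8-bit affine quantization (not a PolyBlocks API): map a
# float tensor onto uint8 values using a scale and zero point.
x = np.array([-1.0, 0.0, 0.5, 1.0], dtype=np.float32)
scale = (x.max() - x.min()) / 255.0
zero_point = int(np.round(-x.min() / scale))
q = np.clip(np.round(x / scale) + zero_point, 0, 255).astype(np.uint8)
x_hat = (q.astype(np.float32) - zero_point) * scale  # dequantized approximation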

With TensorFlow#

The PolyBlocks compiler supports the execution of TensorFlow models quantized using NVIDIA’s tensorflow_quantization library. In the following example, the ResNet50V2 model is quantized and then executed using the PolyBlocks JIT decorator.

# Quantized ResNet50V2 inference.

import numpy as np
import tensorflow as tf
import time
from tensorflow.keras.applications.resnet_v2 import ResNet50V2
from tensorflow_quantization.quantize import quantize_model
from polyblocks.tf_polyblocks_compiler import polyblocks_jit_tf


# 1. Initialize the ResNet50V2 model trained on the ImageNet dataset.
ResNet = ResNet50V2(
    include_top=True,
    weights="imagenet",
    input_tensor=None,
    input_shape=(224, 224, 3),
    pooling=None,
    classes=1000,
    classifier_activation="softmax",
)

# 2. Quantize the model. 8-bit quantization is the default.
quantized_resnet = quantize_model(ResNet)

# 3. Execute using PolyBlocks JIT decorator.

# Generate a random input image to provide as input to the model.
image = np.random.rand(1, 224, 224, 3).astype(np.float32)

# Perform ResNet50V2 inference via PolyBlocks.
@polyblocks_jit_tf(compile_options={'target': 'nvgpu'})
def polyblocks_resnet_inference(image):
    return quantized_resnet(image)

out = polyblocks_resnet_inference(image)
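
To sanity-check and time the compiled model, its output can be compared against eager TensorFlow execution. The snippet below (continuing the script above) is a minimal sketch: it assumes the decorated function returns a NumPy-convertible array, and it uses a loose tolerance because quantized kernels need not match eager execution bit-for-bit.

# Compare the PolyBlocks result against eager TensorFlow execution.
reference = quantized_resnet(image).numpy()
assert np.allclose(np.asarray(out), reference, atol=1e-2)

# Time steady-state inference; the first call above already triggered
# JIT compilation, so this loop measures execution only.
start = time.perf_counter()
for _ in range(100):
    polyblocks_resnet_inference(image)
elapsed_ms = (time.perf_counter() - start) / 100 * 1000
print("Average inference time: {:.2f} ms".format(elapsed_ms))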

With PyTorch#

The PolyBlocks compiler supports the execution of quantized PyTorch models using the PyTorch 2.0 export path. In the example below, the torchvision ResNet50 model is quantized via post-training quantization and then compiled and executed with PolyBlocks.

# Quantized PyTorch ResNet50 inference.
import copy
import torch
import torch._dynamo as torchdynamo
from resnet50_quantizer_utils import get_symmetric_quantization_config, BackendQuantizer
import torchvision
from torch.ao.quantization.quantize_pt2e import (
    convert_pt2e,
    prepare_pt2e,
)
from polyblocks.torch_polyblocks_compiler import polyblocks_jit_torch
from torchvision.models import resnet50, ResNet50_Weights

if __name__ == "__main__":
    example_inputs = (torch.randn(1, 3, 224, 224),)
    model = resnet50(weights=ResNet50_Weights.DEFAULT, progress=True).eval()
    # Program capture.
    exported_model, guards = torchdynamo.export(
        model,
        *copy.deepcopy(example_inputs),
        aten_graph=True,
    )
    quantizer = BackendQuantizer()
    operator_config = get_symmetric_quantization_config()
    quantizer.set_global(operator_config)

    # Prepare the model for PTQ.
    prepared_model = prepare_pt2e(exported_model, quantizer)
    # Calibrate/train the model.
    after_prepare_result = prepared_model(*example_inputs)
    # Convert back the calibrated/trained model to the quantized model.
    quantized_model = convert_pt2e(prepared_model, fold_quantize=True)
    print("converted module is: {}".format(quantized_model), flush=True)

    # Perform quantized ResNet50 inference via PolyBlocks.
    polyblocks_quantized_resnet50_inference = polyblocks_jit_torch(
        compile_options={'target': 'nvgpu'}
    )(quantized_model)

    image = torch.randn(1, 3, 224, 224)
    out = polyblocks_quantized_resnet50_inference(image)
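
As a final sanity check (continuing inside the __main__ block above), the PolyBlocks result can be compared against eager execution of the quantized model. This is a minimal sketch: it assumes the compiled callable returns a torch.Tensor, and the tolerances are loose because the two execution paths need not agree bit-for-bit.

    # Compare against eager execution of the quantized model.
    reference = quantized_model(image)
    torch.testing.assert_close(out, reference, rtol=1e-2, atol=1e-2)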

A performance evaluation of PolyBlocks-compiled quantized models can be found here.