Quantization Support
The PolyBlocks compiler supports executing TensorFlow models quantized with NVIDIA’s tensorflow_quantization library and PyTorch/PolyBlocks models quantized via the PyTorch 2.0 export path. Quantization is not yet supported with JAX/PolyBlocks.
The following example quantizes the complete ResNet50V2 model in TensorFlow and then executes the quantized model using the PolyBlocks JIT decorator.
# Quantized ResNet50_V2 inference.
import numpy as np
import tensorflow as tf
import time
from tensorflow.keras.applications.resnet_v2 import ResNet50V2
from tensorflow_quantization.quantize import quantize_model
from polyblocks.tf_polyblocks_compiler import polyblocks_jit_tf
# 1. Initialize the ResNet50V2 model trained on the ImageNet dataset.
ResNet = ResNet50V2(
    include_top=True,
    weights="imagenet",
    input_tensor=None,
    input_shape=(224, 224, 3),
    pooling=None,
    classes=1000,
    classifier_activation="softmax",
)
# 2. Quantize the model. 8-bit quantization is the default.
quantized_resnet = quantize_model(ResNet)
# 3. Execute using PolyBlocks JIT decorator.
# Generate a random input image to provide as input to the model.
image = np.random.rand(1, 224, 224, 3).astype(np.float32)
# Perform ResNet50_V2 inference via PolyBlocks.
@polyblocks_jit_tf(compile_options={'target': 'nvgpu'})
def polyblocks_resnet_inference(image):
    return quantized_resnet(image)
out = polyblocks_resnet_inference(image)
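To sanity-check the output, the class probabilities can be decoded with Keras’s decode_predictions helper. This is a minimal illustrative sketch: it assumes the value returned by the PolyBlocks-compiled function is convertible to a NumPy array of shape (1, 1000), and a random input will of course not decode to a meaningful class.
from tensorflow.keras.applications.resnet_v2 import decode_predictions
# Decode the top-5 ImageNet classes from the softmax output.
# Assumes `out` is convertible to a NumPy array of shape (1, 1000).
preds = np.asarray(out)
for class_id, name, score in decode_predictions(preds, top=5)[0]:
    print(f"{name} ({class_id}): {score:.4f}")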
Similarly, in the example below, a ResNet50 model written in PyTorch is quantized via the PyTorch 2.0 export path and then compiled and executed with PolyBlocks.
# Quantized PyTorch ResNet50 inference.
import copy
import torch
import torch._dynamo as torchdynamo
from resnet50_quantizer_utils import get_symmetric_quantization_config, BackendQuantizer
import torchvision
from torch.ao.quantization.quantize_pt2e import (
convert_pt2e,
prepare_pt2e,
)
from polyblocks.torch_polyblocks_compiler import polyblocks_jit_torch
from polyblocks.common_utils import validate_outputs
from torchvision.models import resnet50, ResNet50_Weights
if __name__ == "__main__":
    example_inputs = (torch.randn(1, 3, 224, 224),)
    model = resnet50(weights=ResNet50_Weights.DEFAULT, progress=True).eval()
    # Program capture.
    exported_model, guards = torchdynamo.export(
        model,
        *copy.deepcopy(example_inputs),
        aten_graph=True,
    )
    quantizer = BackendQuantizer()
    operator_config = get_symmetric_quantization_config()
    quantizer.set_global(operator_config)
    # Prepare the model for post-training quantization (PTQ).
    prepared_model = prepare_pt2e(exported_model, quantizer)
    # Calibrate/train the model.
    after_prepare_result = prepared_model(*example_inputs)
    # Convert the calibrated/trained model back to a quantized model.
    quantized_model = convert_pt2e(prepared_model, fold_quantize=True)
    print("converted module is: {}".format(quantized_model), flush=True)
    # Perform quantized ResNet50 inference via PolyBlocks.
    polyblocks_quantized_resnet50_inference = polyblocks_jit_torch(
        compile_options={'target': 'nvgpu'}
    )(quantized_model)
    image = torch.randn(1, 3, 224, 224)
    out = polyblocks_quantized_resnet50_inference(image)
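As a quick correctness check, the PolyBlocks result can be compared against running the converted quantized model eagerly. (The example imports validate_outputs from polyblocks.common_utils, which is intended for this purpose; since its signature is not shown here, the sketch below uses torch.testing instead.) This is a minimal sketch, assuming the compiled function returns something convertible to a torch tensor; quantized arithmetic may differ slightly across backends, hence the loose tolerances.
# Reference: run the converted quantized model eagerly (uncompiled).
reference = quantized_model(image)
# Allow small numerical differences between backends.
torch.testing.assert_close(torch.as_tensor(out), reference, rtol=1e-2, atol=1e-2)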
All of the optimizations available to non-quantized models benefit quantized models as well. In addition, quantized models gain further performance from reduced precision, through int8-based tensor core instructions (on NVIDIA GPUs) and, in conjunction, other optimizations that use memory bandwidth more efficiently.
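For a rough sense of the speedup, the two paths can be timed side by side. The harness below is an illustrative sketch, not official PolyBlocks benchmarking guidance: it reuses model, image, and polyblocks_quantized_resnet50_inference from the example above and synchronizes the GPU so asynchronous kernel launches do not skew the measurement.
import time
import torch

def benchmark(fn, x, iters=100, warmup=10):
    # Warm-up runs exclude one-time JIT compilation cost.
    for _ in range(warmup):
        fn(x)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn(x)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters

# For a like-for-like comparison, move the eager model and input to the
# same device first (e.g., model.cuda() and image.cuda()).
with torch.no_grad():
    eager_ms = benchmark(model, image) * 1e3
quant_ms = benchmark(polyblocks_quantized_resnet50_inference, image) * 1e3
print(f"eager fp32: {eager_ms:.2f} ms/iter, PolyBlocks int8: {quant_ms:.2f} ms/iter")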