Just-in-time Compilation#
The @polyblocks_jit_* decorators#
The @polyblocks_jit_tf decorator can be attached to a function to specify that it should be compiled and executed using the PolyBlocks JIT compiler. Compile options can additionally be specified as a trailing argument. The target option is an important one: it can be ‘cpu’, ‘nvgpu’, or ‘amdgpu’, and defaults to ‘cpu’ when not provided. TensorFlow, PyTorch (2.0), and JAX are supported.
Using PolyBlocks with TensorFlow#
PolyBlocks works with tf.functions written using the operators available in the TensorFlow Python API.
@polyblocks_jit_tf(input_signature: list, compile_options: dict)
input_signature is optional and can be provided by the user; its syntax and semantics are the same as for a tf.function. It ensures that the code is specialized for the supplied shapes and element types instead of having them inferred from the inputs. Even when a signature is not supplied, the compilation is cached, and repeated calls with the same input shapes and element types do not trigger recompilation.
compile_options is a dictionary used to supply any options/flags to the PolyBlocks compiler. This is especially useful to enable/disable optimizations and features or to specify a target. A list of compile options can be found here.
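For instance, a signature of tf.TensorSpec entries and a compile_options dictionary can be supplied together. The following is a minimal sketch; the shape, dtype, and function body are illustrative:
import tensorflow as tf

# A minimal sketch: specialize the compiled code for a fixed input shape and
# element type up front. The shape, dtype, and body here are illustrative.
@polyblocks_jit_tf(
    [tf.TensorSpec(shape=(3, 1080, 1920), dtype=tf.float32)],
    compile_options={'target': 'cpu'},
)
def scale_by_two(img):
    return img * 2.0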
Here is a simple example that leads to the blur_x_blur_y function being JIT-compiled via PolyBlocks to an NVIDIA GPU on the system.
@polyblocks_jit_tf(compile_options={'target': 'nvgpu'})
def blur_x_blur_y(img):
    # We are doing a regular image convolution in the X direction and then in
    # the Y direction.
    kernel = tf.constant([0.0625, 0.25, 0.375, 0.25, 0.0625], dtype=tf.float32)
    blur_kernel = tf.reshape(kernel, [1, 5, 1, 1])
    img_reshaped = tf.reshape(img, [3, height, width, 1])
    # The output is two elements shorter at each end.
    blurx = tf.nn.conv2d(img_reshaped, blur_kernel, [1, 1, 1, 1], "VALID")
    blur_kernel = tf.reshape(kernel, [5, 1, 1, 1])
    blury = tf.nn.conv2d(blurx, blur_kernel, [1, 1, 1, 1], "VALID")
    output = tf.reshape(blury, [3, height - 4, width - 4])
    return output
The annotated function can now be called just like before, passing it either NumPy arrays or tf.Tensor objects. Calling it triggers JIT compilation with PolyBlocks followed by execution. When making multiple calls, JIT compilation happens only the first time, unless the input shapes or compile options change on subsequent invocations.
img = np.random.rand(3, height, width).astype(np.float32)
out = blur_x_blur_y(img)
Another way is to decorate the function at the call site. In this case, there is no need to annotate the blur_x_blur_y definition. The following example demonstrates this:
img = np.random.rand(3, height, width).astype(np.float32)
blur_x_blur_y_polyblocks_compilable = polyblocks_jit_tf(
    compile_options={'target': 'nvgpu'}
)(blur_x_blur_y)
out = blur_x_blur_y_polyblocks_compilable(img)
When executing on a system with NVIDIA GPUs, the CUDA_VISIBLE_DEVICES environment variable can be used to control which GPU(s) the function can execute on.
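For example, restricting execution to the first GPU could be done as follows (a minimal sketch; the variable must be set before any CUDA context is created in the process):
import os

# Make only GPU 0 visible to the process; PolyBlocks-compiled functions will
# then execute on that device.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"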
PolyBlocks-powered PyTorch#
The PolyBlocks compiler supports the PyTorch framework. PyTorch/PolyBlocks uses TorchDynamo with the PolyBlocks engine as a backend to generate fast, optimized code by compiling the captured Torch FX graphs. One can either use an annotation or use PolyBlocks as the backend compiler with torch.compile.
@polyblocks_jit_torch(compile_options: dict)
An example with PyTorch.
@polyblocks_jit_torch(compile_options={'target': 'nvgpu'})
def blur_x_blur_y(img):
    # We are doing a regular image convolution in the X direction and then in
    # the Y direction.
    kernel = torch.tensor([0.0625, 0.25, 0.375, 0.25, 0.0625], device=img.device).float()
    blur_kernel = kernel.view(1, 1, 1, 5)
    img_reshaped = img.view(3, 1, height, width)
    # The output is two elements shorter at each end.
    blurx = torch.nn.functional.conv2d(img_reshaped, blur_kernel, stride=1, padding=0)
    blur_kernel = kernel.view(1, 1, 5, 1)
    blury = torch.nn.functional.conv2d(blurx, blur_kernel, stride=1, padding=0)
    output = blury.view(3, height - 4, width - 4)
    return output
The function can then be called the usual way:
img = torch.rand(3, height, width, device=device).float()
out = blur_x_blur_y(img)
Use torch.compile with the “polyblocks” backend#
A more Torch-native way of using PolyBlocks is via torch.compile; the following is functionally equivalent to the PolyBlocks annotation-based approach:
img = torch.rand(3, height, width, device=device).float()
polyblocks_compiled_func = torch.compile(blur_x_blur_y, backend="polyblocks")
out = polyblocks_compiled_func(img)
PolyBlocks-compiled PyTorch functions support inputs of type torch.Tensor irrespective of where the inputs reside. The only restriction is that all of the inputs should be on the host, or all of them should be on a GPU.
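For instance, the blur example above accepts either a host tensor or a GPU tensor (a minimal sketch):
# Both calls below are valid; what matters is that all inputs to a single call
# reside on the same kind of device (all on the host, or all on a GPU).
img_cpu = torch.rand(3, height, width).float()
out_cpu = blur_x_blur_y(img_cpu)

img_gpu = torch.rand(3, height, width, device="cuda").float()
out_gpu = blur_x_blur_y(img_gpu)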
For optimal end-to-end performance, the PolyBlocks/PyTorch GPU runtime does not synchronize unless data is brought back to the CPU. However, the cuda_synchronize function from polyblocks.device_helpers can be used to make sure that all the PolyBlocks-accelerated work running on the GPU has finished.
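A minimal sketch, assuming cuda_synchronize takes no arguments:
from polyblocks.device_helpers import cuda_synchronize

out = blur_x_blur_y(img)  # GPU work may still be in flight here
cuda_synchronize()        # wait for all PolyBlocks-accelerated GPU work to finish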
PolyBlocks-powered JAX#
Functions written in JAX can be compiled via PolyBlocks using the @polyblocks_jit_jax annotation. @polyblocks_jit_jax has the following syntax. (The annotation does not require an input specification.)
@polyblocks_jit_jax(compile_options: dict)
An example with JAX.
class CNN(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Conv(
            features=out_channels,
            kernel_size=filter_shape,
            padding="VALID",
            use_bias=True,
            dtype=jnp.dtype("float16"),
            param_dtype=jnp.dtype("float16"),
        )(x)
        x = nn.relu(x)
        return x
# Care must be taken when passing the model parameters. They should always be
# passed as arguments to the function that is being JITted. If we do not pass
# them in that manner and instead obtain them from the enclosing scope, all of
# those parameters will be captured as constants in the resulting IR, which
# negatively impacts compilation time as well as performance. When they are
# explicitly passed, they are captured as function parameters, which is better
# for both compilation time and performance.
@polyblocks_jit_jax(compile_options={'target': 'nvgpu'})
def ConvBiasRelu(variables, img):
    out = cnn.apply(variables, img)
    return out
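A usage sketch is shown below; the input shape and random key are illustrative, and out_channels and filter_shape are assumed to be defined as in the model above:
import jax
import jax.numpy as jnp

cnn = CNN()
img = jnp.ones((1, 224, 224, 3), dtype=jnp.float16)
# Initialize the model parameters once, outside the JITted function.
variables = cnn.init(jax.random.PRNGKey(0), img)

# The parameters are passed explicitly as a function argument, as recommended
# above, rather than being captured from the enclosing scope.
out = ConvBiasRelu(variables, img)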
The JAX JIT annotation also supports passing static arguments (arguments to be treated as compile-time constants) via the static_argnums compile option.
@polyblocks_jit_jax(
    compile_options={
        "target": "nvgpu",
        "debug": False,
        "static_argnums": (2, 3),
    }
)
def gradient_descent(
    X, D, learning_rate=0.0001, num_iterations=1000
):
    ...
In the above example, the arguments at indices 2 and 3 (learning_rate and num_iterations) are treated as compile-time constants.
All inputs (except static arguments) to a JAX function being compiled via PolyBlocks must be JAX arrays obtained using jax.device_put().
D_JAX = jax.device_put(D, jax.devices("cpu")[0])
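At the call site, the array arguments are JAX arrays placed with jax.device_put(), while the static arguments are passed as plain Python values. The following sketch continues the gradient_descent example above; the data shapes are illustrative:
import jax
import numpy as np

X = np.random.rand(128, 16).astype(np.float32)
D = np.random.rand(128, 1).astype(np.float32)
X_JAX = jax.device_put(X, jax.devices("cpu")[0])
D_JAX = jax.device_put(D, jax.devices("cpu")[0])

# Arguments at indices 2 and 3 are static and are passed as plain Python
# values; they are treated as compile-time constants.
result = gradient_descent(X_JAX, D_JAX, 0.0001, 1000)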
Compare with other backends: polymage_jit_* annotations#
The flexibility to execute via separate backends also exists. A different annotation, polymage_jit_jax/tf/torch, must be used to enable execution via other backends. For TensorFlow and JAX, there are three backends that can be selected using the execution_backend option: std for the standard JAX/TensorFlow runtime/executor, xla for the TF/XLA JIT or JAX JIT, and polyblocks for the PolyBlocks backend. As an example:
@polymage_jit_tf(
    [
        tf.TensorSpec(shape=input_shape, dtype=tf.float16),
        tf.TensorSpec(shape=filter_shape, dtype=tf.float16),
        tf.TensorSpec(shape=bias_shape, dtype=tf.float16),
    ],
    execution_backend="xla",
    compile_options={"target": "nvgpu"},
)
def conv_bias_add_relu(img, filters, bias):
    conv = tf.nn.conv2d(
        img,
        filters,
        strides=(1, 1),
        padding="VALID",
        data_format="NHWC",
        dilations=(1, 1),
    )
    conv_with_bias = tf.nn.bias_add(conv, bias)
    output = tf.nn.relu(conv_with_bias)
    return output
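The same function, left undecorated, can also be wrapped at the call site once per backend to compare them on identical inputs, in the same way as the earlier call-site example. The following is a sketch; the input setup is illustrative:
# A sketch: build one callable per execution backend for an undecorated
# definition of conv_bias_add_relu and compare them on the same inputs.
signature = [
    tf.TensorSpec(shape=input_shape, dtype=tf.float16),
    tf.TensorSpec(shape=filter_shape, dtype=tf.float16),
    tf.TensorSpec(shape=bias_shape, dtype=tf.float16),
]
run_xla = polymage_jit_tf(
    signature, execution_backend="xla", compile_options={"target": "nvgpu"}
)(conv_bias_add_relu)
run_polyblocks = polymage_jit_tf(
    signature, execution_backend="polyblocks", compile_options={"target": "nvgpu"}
)(conv_bias_add_relu)

out_xla = run_xla(img, filters, bias)
out_polyblocks = run_polyblocks(img, filters, bias)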
Tooling and developer ecosystem#
Well-known profiling, performance-monitoring, and memory sanity-checking tools run on PolyBlocks-optimized code and provide useful performance insights. For example, NVIDIA Nsight Systems (nsys) can be used to profile and understand performance, and the CUDA Compute Sanitizer can be used to check for memory errors.
As an example, the execution profile from nsys below shows that PolyBlocks was able to fuse all operators into a single GPU kernel.
nsys profile -s none --cpuctxsw=none --trace=cuda -o gpu_ python blur_x_blur_y.py -gpu \
-skip-tf-standard -skip-tf-xla -debug && nsys stats -q --report gpukernsum --format table gpu_.nsys-rep
** CUDA GPU Kernel Summary (gpukernsum):
+----------+-----------------+-----------+-----------+-----------+----------+----------+-------------+----------------+----------------+---------------------------------------+
| Time (%) | Total Time (ns) | Instances | Avg (ns) | Med (ns) | Min (ns) | Max (ns) | StdDev (ns) | GridXYZ | BlockXYZ | Name |
+----------+-----------------+-----------+-----------+-----------+----------+----------+-------------+----------------+----------------+---------------------------------------+
| 100.0 | 651,070 | 1 | 651,070.0 | 651,070.0 | 651,070 | 651,070 | 0.0 | 682 33 3 | 96 1 1 | __inference_BlurXY_15_Conv2D_1_kernel |
+----------+-----------------+-----------+-----------+-----------+----------+----------+-------------+----------------+----------------+---------------------------------------+