Just-in-time Compilation

The @polyblocks_jit_* decorators

The @polyblocks_jit_* decorators (@polyblocks_jit_tf, @polyblocks_jit_torch, and @polyblocks_jit_jax) can be attached to a function to specify that it should be compiled and executed using the PolyBlocks JIT compiler. Compile options can additionally be passed through the compile_options argument. The target option is an important one and can be 'cpu', 'nvgpu', or 'amdgpu'; it defaults to 'cpu' when not provided. TensorFlow, PyTorch (2.0), and JAX are supported.
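The documented target values and their default can be captured in a small helper, for instance to validate a configuration before handing it to one of the decorators. This is a minimal sketch: resolve_compile_options is a hypothetical name, not part of the PolyBlocks API, and it only checks the 'target' key described above.

```python
# Hypothetical helper (not part of PolyBlocks): normalizes a compile_options
# dict according to the documented 'target' values and default.
VALID_TARGETS = ("cpu", "nvgpu", "amdgpu")


def resolve_compile_options(compile_options=None):
    """Return a copy of compile_options with 'target' defaulting to 'cpu'."""
    options = dict(compile_options or {})  # copy so the caller's dict is untouched
    options.setdefault("target", "cpu")
    if options["target"] not in VALID_TARGETS:
        raise ValueError(
            f"unknown target {options['target']!r}; expected one of {VALID_TARGETS}"
        )
    return options


print(resolve_compile_options())                     # {'target': 'cpu'}
print(resolve_compile_options({"target": "nvgpu"}))  # {'target': 'nvgpu'}
```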

Using PolyBlocks with TensorFlow

PolyBlocks works with tf.functions written using the TensorFlow Python API operators.

@polyblocks_jit_tf(input_signature: list, compile_options: dict)
  • input_signature is optional; its syntax and semantics are the same as for a tf.function. When supplied, the code is specialized for the given shapes and element types instead of these being inferred from the inputs. Even without a signature, the compilation is cached, and repeated calls with the same input shapes and element types do not trigger recompilation.

  • compile_options is a dictionary of options/flags passed to the PolyBlocks compiler. This is especially useful to enable or disable optimizations and features, or to specify a target. A list of compile options can be found here.
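The caching behavior described above can be modeled as a cache keyed on the shapes and element types of the inputs, together with the compile options. The sketch below is a conceptual model only, not PolyBlocks' implementation: jit_with_shape_cache is a hypothetical decorator, and "compilation" is simulated by counting how often a new cache entry is created.

```python
import numpy as np


def jit_with_shape_cache(compile_options=None):
    """Hypothetical decorator modeling a shape/dtype-keyed compilation cache."""
    # Sort option items so that equal dicts produce the same key.
    options_key = tuple(sorted((compile_options or {}).items()))

    def decorator(fn):
        cache = {}

        def wrapper(*args):
            # Specialization key: (shape, dtype) of every input, plus options.
            key = (
                tuple((np.shape(a), np.asarray(a).dtype.str) for a in args),
                options_key,
            )
            if key not in cache:
                # A real JIT would compile here; this sketch only records it.
                wrapper.num_compilations += 1
                cache[key] = fn
            return cache[key](*args)

        wrapper.num_compilations = 0
        return wrapper

    return decorator


@jit_with_shape_cache({"target": "cpu"})
def double(x):
    return x * 2


double(np.ones((2, 3), dtype=np.float32))
double(np.zeros((2, 3), dtype=np.float32))  # same shape and dtype: cache hit
double(np.ones((4, 3), dtype=np.float32))   # new shape: "recompiles"
print(double.num_compilations)  # 2
```

Keying on shape and dtype rather than on values is what allows repeated calls with fresh data to reuse the same compiled code.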

The following simple example JIT-compiles the blur_x_blur_y function via PolyBlocks for an NVIDIA GPU on the system.

import tensorflow as tf
# polyblocks_jit_tf is assumed to be imported from the PolyBlocks package, and
# height and width are assumed to be defined at module scope.

@polyblocks_jit_tf(compile_options={'target': 'nvgpu'})
def blur_x_blur_y(img):
    # A separable 5-tap blur: convolve in the X direction, then in the Y
    # direction.
    kernel = tf.constant([0.0625, 0.25, 0.375, 0.25, 0.0625], dtype=tf.float32)
    blur_kernel = tf.reshape(kernel, [1, 5, 1, 1])
    img_reshaped = tf.reshape(img, [3, height, width, 1])
    # With "VALID" padding, each convolution shrinks the output by two pixels
    # at each end of the blurred dimension.
    blurx = tf.nn.conv2d(img_reshaped,
                         blur_kernel, [1, 1, 1, 1], "VALID")
    blur_kernel = tf.reshape(kernel, [5, 1, 1, 1])
    blury = tf.nn.conv2d(blurx, blur_kernel,
                         [1, 1, 1, 1], "VALID")
    output = tf.reshape(blury, [3, height - 4, width - 4])
    return output

The annotated function can now be called as before, passing it either NumPy arrays or tf.Tensor objects. The call triggers JIT compilation with PolyBlocks, followed by execution. Across multiple calls, JIT compilation happens only the first time, unless the input shapes or compile options change on subsequent invocations.

 img = np.random.rand(3, height, width).astype(np.float32)
 out = blur_x_blur_y(img)
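When checking the PolyBlocks output, a plain NumPy reference of the same separable blur can be handy. The sketch below is an assumed reference written for this purpose (blur_x_blur_y_reference is a hypothetical name); like the TensorFlow version, it applies the 5-tap kernel along X and then along Y with "VALID" semantics, so each spatial dimension shrinks by 4. Because the kernel is symmetric, convolution and cross-correlation give the same result here.

```python
import numpy as np

KERNEL = np.array([0.0625, 0.25, 0.375, 0.25, 0.0625], dtype=np.float32)


def blur_x_blur_y_reference(img):
    """NumPy reference: 5-tap blur along X (width), then Y (height), 'VALID'."""
    _, height, width = img.shape
    # Blur along X: weighted sum of 5 horizontally shifted slices.
    blurx = sum(KERNEL[k] * img[:, :, k:k + width - 4] for k in range(5))
    # Blur along Y on the X-blurred result.
    blury = sum(KERNEL[k] * blurx[:, k:k + height - 4, :] for k in range(5))
    return blury.astype(np.float32)


img = np.random.rand(3, 64, 48).astype(np.float32)
ref = blur_x_blur_y_reference(img)
print(ref.shape)  # (3, 60, 44)
```

Running the PolyBlocks-compiled function on the same img and comparing with something like np.allclose(np.asarray(out), ref) should then confirm agreement up to floating-point tolerance.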

Alternatively, the decorator can be applied at the call site, in which case the blur_x_blur_y definition does not need to be annotated. The following example demonstrates this:

img = np.random.rand(3, height, width).astype(np.float32)
blur_x_blur_y_polyblocks_compilable = polyblocks_jit_tf(
    compile_options={'target': 'nvgpu'}
)(blur_x_blur_y)
out = blur_x_blur_y_polyblocks_compilable(img)

When executing on a system with NVIDIA GPUs, the CUDA_VISIBLE_DEVICES environment variable can be used to control which GPU(s) the function executes on.
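For example, to restrict execution to the first GPU, the variable can be set in the shell (CUDA_VISIBLE_DEVICES=0 python script.py) or from Python before the framework initializes CUDA. The sketch below uses the latter; setting the variable after a CUDA context already exists generally has no effect, which is why it comes before the framework import. The device index 0 is illustrative.

```python
import os

# Expose only GPU 0 to this process. This must happen before TensorFlow,
# PyTorch, or JAX initializes CUDA, so set it before importing the framework.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# import tensorflow as tf  # the framework import goes after the variable is set
```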

PolyBlocks-powered PyTorch

The PolyBlocks compiler supports the PyTorch framework. PyTorch/PolyBlocks uses TorchDynamo with the PolyBlocks engine as a backend to generate fast optimized code by compiling the captured Torch FX graphs.

The @polyblocks_jit_torch annotation does not require an input specification. Its syntax is:

@polyblocks_jit_torch(compile_options: dict)

An example with PyTorch, implementing the same blur:

import torch
# polyblocks_jit_torch is assumed to be imported from the PolyBlocks package,
# and height and width are assumed to be defined at module scope.

@polyblocks_jit_torch(compile_options={'target': 'nvgpu'})
def blur_x_blur_y(img):
    # A separable 5-tap blur: convolve in the X direction, then in the Y
    # direction.
    kernel = torch.tensor([0.0625, 0.25, 0.375, 0.25, 0.0625], device=img.device).float()
    blur_kernel = kernel.view(1, 1, 1, 5)
    img_reshaped = img.view(3, 1, height, width)
    # With padding=0, each convolution shrinks the output by two pixels at
    # each end of the blurred dimension.
    blurx = torch.nn.functional.conv2d(img_reshaped,
                                       blur_kernel, stride=1, padding=0)
    blur_kernel = kernel.view(1, 1, 5, 1)
    blury = torch.nn.functional.conv2d(blurx, blur_kernel,
                                       stride=1, padding=0)
    output = blury.view(3, height - 4, width - 4)
    return output

The function can then be called the usual way:

device = torch.device('cuda')  # or 'cpu'
img = torch.rand(3, height, width, device=device).float()
out = blur_x_blur_y(img)

PolyBlocks-compiled PyTorch functions accept torch.Tensor inputs regardless of where the inputs reside. The only restriction is that either all of the inputs are on the host or all of them are on a GPU.
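A caller that might mix host and GPU tensors can catch the problem before invoking the compiled function. The helper below is a hypothetical sketch, not part of PolyBlocks: it relies only on the fact that every torch.Tensor has a .device attribute whose .type is a string such as 'cpu' or 'cuda', and checks that all inputs share the same device type (it does not distinguish between individual GPUs).

```python
def check_same_device_type(*tensors):
    """Hypothetical pre-call check: all inputs on the host or all on a GPU.

    Works with any objects exposing .device.type, such as torch.Tensor.
    Returns the common device type (e.g. 'cpu' or 'cuda'), or None if no
    inputs were given.
    """
    device_types = {t.device.type for t in tensors}
    if len(device_types) > 1:
        raise ValueError(
            f"inputs are spread across devices {sorted(device_types)}; "
            "move them all to the host or all to the GPU"
        )
    return device_types.pop() if device_types else None


# Usage with PyTorch (requires torch and a CUDA device):
#   x = torch.rand(3, 4, device="cuda")
#   y = torch.rand(3, 4, device="cuda")
#   check_same_device_type(x, y)  # returns 'cuda'
```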

PolyBlocks-powered JAX

Functions written in JAX can be compiled via PolyBlocks using the @polyblocks_jit_jax annotation, which also does not require an input specification. Its syntax is:

@polyblocks_jit_jax(compile_options: dict)

An example with JAX, using Flax's linen API:

import jax
import jax.numpy as jnp
import flax.linen as nn
# polyblocks_jit_jax is assumed to be imported from the PolyBlocks package, and
# out_channels and filter_shape are assumed to be defined at module scope.

class CNN(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Conv(
            features=out_channels,
            kernel_size=filter_shape,
            padding="VALID",
            use_bias=True,
            dtype=jnp.dtype("float16"),
            param_dtype=jnp.dtype("float16"),
        )(x)
        x = nn.relu(x)
        return x

# variables can be obtained with, e.g., cnn.init(jax.random.PRNGKey(0), img).
cnn = CNN()

# Always pass the model parameters as arguments to the function being JITted.
# If they are instead captured from the enclosing scope, they are embedded as
# constants in the resulting IR, which hurts both compilation time and
# performance. Passed explicitly, they become function parameters.
@polyblocks_jit_jax(compile_options={'target': 'nvgpu'})
def ConvBiasRelu(variables, img):
    out = cnn.apply(variables, img)
    return out

The JAX JIT annotation also supports passing static arguments (arguments to be treated as compile-time constants) via the static_argnums compile option.

@polyblocks_jit_jax(
    compile_options={
        "target": "nvgpu",
        "debug": False,
        "static_argnums": (2, 3),
    }
)
def gradient_descent(
    X, D, learning_rate=0.0001, num_iterations=1000
):
    ...  # function body omitted

In the above example, the arguments at indices 2 and 3 (learning_rate and num_iterations) are treated as compile-time constants.
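Treating an argument as a compile-time constant means the compiled code is specialized on its value. In a jax.jit-style design, a call with a different learning_rate therefore needs a new compilation, whereas the dynamic arguments affect the cache only through their shapes and element types. The sketch below models that distinction; it is an assumption based on jax.jit's static_argnums semantics rather than documented PolyBlocks behavior, and specialization_key is a hypothetical name.

```python
import numpy as np


def specialization_key(args, static_argnums=()):
    """Hypothetical cache key: static args by value, others by shape/dtype."""
    key = []
    for i, arg in enumerate(args):
        if i in static_argnums:
            # Compile-time constant: the value itself selects the compiled code.
            key.append(("static", arg))
        else:
            # Runtime input: only its shape and element type matter.
            a = np.asarray(arg)
            key.append(("dynamic", a.shape, a.dtype.str))
    return tuple(key)


X = np.ones((100, 10), dtype=np.float32)
D = np.ones((100,), dtype=np.float32)
k1 = specialization_key((X, D, 0.0001, 1000), static_argnums=(2, 3))
k2 = specialization_key((X * 2, D, 0.0001, 1000), static_argnums=(2, 3))
k3 = specialization_key((X, D, 0.001, 1000), static_argnums=(2, 3))
print(k1 == k2)  # True: same shapes/dtypes and same static values
print(k1 == k3)  # False: a different learning_rate needs a new compilation
```

This is also why static arguments must be hashable: their values become part of the lookup key.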

All inputs (except static arguments) to a JAX function being compiled via PolyBlocks must be JAX arrays obtained using jax.device_put().

D_JAX = jax.device_put(D, jax.devices("cpu")[0])

Compare with other backends: polymage_jit_* annotations

PolyBlocks can also run the same function via other backends, which is useful for comparison. This requires a different family of annotations, polymage_jit_jax/tf/torch. For TensorFlow and JAX, three backends can be selected using the execution_backend option:

  • execution_backend can be 'std' for the standard JAX/TensorFlow runtime/executor, 'xla' for the TF/XLA JIT or JAX JIT, or 'polyblocks' for the PolyBlocks backend. As an example:

@polymage_jit_tf(
    [
        tf.TensorSpec(shape=input_shape, dtype=tf.float16),
        tf.TensorSpec(shape=filter_shape, dtype=tf.float16),
        tf.TensorSpec(shape=bias_shape, dtype=tf.float16),
    ],
    execution_backend="xla",
    compile_options={"target": "nvgpu"},
)
def conv_bias_add_relu(img, filters, bias):
    conv = tf.nn.conv2d(
        img,
        filters,
        strides=(1, 1),
        padding="VALID",
        data_format="NHWC",
        dilations=(1, 1),
    )
    conv_with_bias = tf.nn.bias_add(conv, bias)
    output = tf.nn.relu(conv_with_bias)
    return output
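A typical use of these annotations is to run the same function through each backend and compare both results and timings. The harness below is a minimal sketch under that assumption: compare_backends is a hypothetical helper, and the backends are passed in as plain callables. In practice, each entry might be the same function decorated with @polymage_jit_tf using a different execution_backend value; here two NumPy ReLU implementations stand in so the sketch runs on its own.

```python
import time

import numpy as np


def compare_backends(backends, *args, repeats=10, atol=1e-3):
    """Time each callable and check its output against the first one.

    backends: dict mapping a name to a callable.
    Returns {name: average seconds per call}. The first call of each
    callable (which would include any JIT compilation) is not timed.
    """
    names = list(backends)
    reference = np.asarray(backends[names[0]](*args))
    timings = {}
    for name in names:
        fn = backends[name]
        result = np.asarray(fn(*args))  # warm-up call, also used for checking
        if not np.allclose(result, reference, atol=atol):
            raise AssertionError(f"{name} output differs from {names[0]}")
        start = time.perf_counter()
        for _ in range(repeats):
            # np.asarray forces asynchronous (e.g. GPU) work to finish, at the
            # cost of including any device-to-host copy in the measurement.
            np.asarray(fn(*args))
        timings[name] = (time.perf_counter() - start) / repeats
    return timings


def relu_where(x):
    return np.where(x > 0, x, 0)


def relu_maximum(x):
    return np.maximum(x, 0)


x = np.random.randn(256, 256).astype(np.float32)
results = compare_backends({"where": relu_where, "maximum": relu_maximum}, x)
for name, secs in results.items():
    print(f"{name}: {secs * 1e6:.1f} us/call")
```

Excluding the first call from the timing keeps one-time JIT compilation cost separate from steady-state execution time.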

Tooling and developer ecosystem

Standard profiling, performance-monitoring, and memory-checking tools can be run on PolyBlocks-optimized code to obtain useful performance insights. For example, NVIDIA Nsight Systems (nsys) can be used to profile and understand performance, and NVIDIA Compute Sanitizer (compute-sanitizer) can be used to check for memory errors.

As an example, the execution profile from nsys below shows that PolyBlocks was able to fuse all operators into a single GPU kernel.

nsys profile -s none --cpuctxsw=none --trace=cuda -o gpu_ python blur_x_blur_y.py -gpu \
    -skip-tf-standard -skip-tf-xla -debug && nsys stats -q --report gpukernsum --format table gpu_.nsys-rep

  ** CUDA GPU Kernel Summary (gpukernsum):

+----------+-----------------+-----------+-----------+-----------+----------+----------+-------------+----------------+----------------+---------------------------------------+
| Time (%) | Total Time (ns) | Instances | Avg (ns)  | Med (ns)  | Min (ns) | Max (ns) | StdDev (ns) |    GridXYZ     |    BlockXYZ    |                 Name                  |
+----------+-----------------+-----------+-----------+-----------+----------+----------+-------------+----------------+----------------+---------------------------------------+
|    100.0 |         651,070 |         1 | 651,070.0 | 651,070.0 |  651,070 |  651,070 |         0.0 |  682   33    3 |   96    1    1 | __inference_BlurXY_15_Conv2D_1_kernel |
+----------+-----------------+-----------+-----------+-----------+----------+----------+-------------+----------------+----------------+---------------------------------------+