PolyBlocks is a compiler engine. The compiler driver built with the PolyBlocks engine, referred to as the PolyBlocks compiler, works both as a just-in-time (JIT) and as an ahead-of-time (AOT) compiler. The PolyBlocks compiler aims to be fully automatic, driven by analytical models, and fully code-generating, i.e., it does not rely on any vendor or HPC libraries. PolyBlocks is built using the MLIR infrastructure.


PolyBlocks was built with several objectives in mind. First, we would like to exploit the full power of today’s ML/AI hardware automatically, i.e., without resorting to hand-written and hand-tuned libraries, while delivering high performance with a unified compiler infrastructure. Second, we would like to free users from the burden of writing in low-level languages like CUDA while obtaining performance competitive with the best hand-optimized code: in some cases better and, in others, within an acceptable range. Several domains are not well served by libraries, and even in those that are, libraries leave significant performance on the table, especially for emerging patterns of computation or new models. As one example, PolyBlocks can transform code to achieve a complex, non-trivial interleaving of the execution of multiple operators in a compute graph specification. Such transformations are hard to even identify, let alone realize manually.
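To make the idea of interleaving operators more concrete, the following hand-written sketch in plain Python takes the two convolutions of the blur example in the Programming Model section and runs them two ways. It illustrates operator fusion in general; it is not PolyBlocks output and does not reflect how PolyBlocks actually schedules code. The unfused version materializes the whole X-blurred intermediate before starting the Y blur; the interleaved version produces intermediate rows on demand and keeps only the five rows the Y blur currently needs. The function names are hypothetical.

```python
from collections import deque

# 5-tap kernel with the same weights as the blur examples in this document.
KERNEL = [0.0625, 0.25, 0.375, 0.25, 0.0625]


def blur_row(row):
    """X blur of one row in 'valid' mode: the output is 4 elements shorter."""
    return [sum(KERNEL[k] * row[j + k] for k in range(5))
            for j in range(len(row) - 4)]


def blur_unfused(img):
    """Operator by operator: the whole X-blurred image is materialized first."""
    blurx = [blur_row(r) for r in img]      # intermediate holds all H rows
    out = []
    for i in range(len(blurx) - 4):         # Y blur over 5 consecutive rows
        out.append([sum(KERNEL[k] * blurx[i + k][j] for k in range(5))
                    for j in range(len(blurx[0]))])
    return out, len(blurx)                  # (result, live intermediate rows)


def blur_fused(img):
    """Interleaved: each X-blurred row is consumed by the Y blur as soon as
    five rows are available, so at most 5 intermediate rows are ever live."""
    window = deque(maxlen=5)                # oldest row is dropped when full
    out, peak = [], 0
    for r in img:
        window.append(blur_row(r))          # one step of the first operator
        peak = max(peak, len(window))
        if len(window) == 5:                # one step of the second operator
            out.append([sum(KERNEL[k] * window[k][j] for k in range(5))
                        for j in range(len(window[0]))])
    return out, peak


img = [[float((i * 7 + j * 3) % 11) for j in range(12)] for i in range(10)]
out_a, live_a = blur_unfused(img)
out_b, live_b = blur_fused(img)
print(out_a == out_b, live_a, live_b)  # True 10 5
```

Both versions perform the same arithmetic in the same order, so the results are identical; only the execution order and the size of the live intermediate differ. Finding such an interleaving automatically, and then tiling and parallelizing it for real hardware, is the part that becomes difficult by hand.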

In addition, a compiler engine like PolyBlocks is easier to continuously adapt to new and evolving hardware and to drive compilation for it. In contrast, developing and optimizing libraries or specific operators manually or semi-automatically for every new generation of hardware duplicates effort.

Programming Model

The PolyBlocks compiler works with existing Python-based programming models for ML/AI. It currently supports TensorFlow, PyTorch, and JAX. To JIT-compile a function of interest, a user only needs to add a single-line annotation (a decorator) asking PolyBlocks to compile and execute that function instead of running it through the standard TensorFlow, PyTorch, or JAX runtime.

The programming model for unlocking compilation and acceleration with PolyBlocks is thus incremental, requiring minimal changes to an existing model. This is similar in spirit to the OpenMP programming model, familiar to users from the high-performance computing domain.


TensorFlow:

```python
import tensorflow as tf

# polyblocks_jit_tf is provided by PolyBlocks (its import is not shown here).
# height and width (the image dimensions) are assumed to be defined globally.
@polyblocks_jit_tf(compile_options={'target': 'nvgpu'})
def blur_x_blur_y(img):
  # A regular image convolution in the X direction followed by one in the
  # Y direction.
  kernel = tf.constant([0.0625, 0.25, 0.375, 0.25, 0.0625], dtype=tf.float32)
  # Filter layout: [filter_height, filter_width, in_channels, out_channels].
  blur_kernel = tf.reshape(kernel, [1, 5, 1, 1])
  img_reshaped = tf.reshape(img, [3, height, width, 1])
  # "VALID" padding: each convolution shrinks the output by two pixels at each end.
  blurx = tf.nn.conv2d(img_reshaped, blur_kernel, [1, 1, 1, 1], "VALID")
  blur_kernel = tf.reshape(kernel, [5, 1, 1, 1])
  blury = tf.nn.conv2d(blurx, blur_kernel, [1, 1, 1, 1], "VALID")
  output = tf.reshape(blury, [3, height - 4, width - 4])
  return output
```
PyTorch:

```python
import torch

# polyblocks_jit_torch is provided by PolyBlocks (its import is not shown here).
# height and width (the image dimensions) are assumed to be defined globally.
@polyblocks_jit_torch(compile_options={'target': 'nvgpu'})
def blur_x_blur_y(img):
  # A regular image convolution in the X direction followed by one in the
  # Y direction.
  kernel = torch.tensor([0.0625, 0.25, 0.375, 0.25, 0.0625],
                        device=img.device, dtype=torch.float32)
  # Weight layout: [out_channels, in_channels, kernel_height, kernel_width].
  blur_kernel = kernel.view(1, 1, 1, 5)
  img_reshaped = img.view(3, 1, height, width)
  # No padding: each convolution shrinks the output by two pixels at each end.
  blurx = torch.nn.functional.conv2d(img_reshaped, blur_kernel,
                                     stride=1, padding=0)
  blur_kernel = kernel.view(1, 1, 5, 1)
  blury = torch.nn.functional.conv2d(blurx, blur_kernel,
                                     stride=1, padding=0)
  output = blury.view(3, height - 4, width - 4)
  return output
```
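JAX (Flax):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

# height and width (the image dimensions) are assumed to be defined globally.
class CNN(nn.Module):
  @nn.compact
  def __call__(self, img):
    # A regular image convolution in the X direction followed by one in the
    # Y direction.
    img_reshaped = jnp.reshape(img, [3, height, width, 1])
    # Flax kernel layout: [kernel_height, kernel_width, in_features, out_features].
    blur_x_kernel = jnp.array(
        [[[[0.0625]], [[0.25]], [[0.375]], [[0.25]], [[0.0625]]]], dtype=jnp.float32)
    # The fixed blur weights are supplied through kernel_init (one way to do this).
    blurx = nn.Conv(features=1, kernel_size=[1, 5], padding='VALID', use_bias=False,
                    kernel_init=lambda *_: blur_x_kernel)(img_reshaped)
    blur_y_kernel = jnp.reshape(blur_x_kernel, [5, 1, 1, 1])
    blury = nn.Conv(features=1, kernel_size=[5, 1], padding='VALID', use_bias=False,
                    kernel_init=lambda *_: blur_y_kernel)(blurx)
    return jnp.reshape(blury, [3, height - 4, width - 4])

cnn = CNN()
# Initialization creates the parameters from the kernel_init functions above.
variables = cnn.init(jax.random.PRNGKey(0), jnp.zeros([3, height, width]))

# polyblocks_jit_jax is provided by PolyBlocks (its import is not shown here).
@polyblocks_jit_jax(compile_options={'target': 'nvgpu'})
def BlurXBlurY(img):
  out = cnn.apply(variables, img)
  return out
```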
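The examples above share one pattern: a decorator wraps an ordinary Python function so that calls go to a compiler instead of the framework's own runtime. The toy sketch below shows that general mechanism in plain Python. It is not PolyBlocks' implementation: `toy_jit` and `compile_count` are hypothetical names, the "compile" step is a stand-in that simply reuses the Python function, and caching one compiled version per input shape is an assumption borrowed from typical tracing JITs, not documented PolyBlocks behavior.

```python
import functools


def toy_jit(compile_options=None):
    """Hypothetical decorator factory illustrating a decorator-based JIT hook."""
    options = dict(compile_options or {})

    def decorator(fn):
        cache = {}  # one "compiled" artifact per input signature (assumption)

        def signature(args):
            # Specialize on argument "shapes"; Python lists stand in for tensors.
            return tuple(len(a) if isinstance(a, list) else type(a).__name__
                         for a in args)

        @functools.wraps(fn)
        def wrapper(*args):
            key = signature(args)
            if key not in cache:
                # A real compiler would trace fn and generate code for
                # options['target'] here; this toy just reuses fn unchanged.
                cache[key] = fn
                wrapper.compile_count += 1
            return cache[key](*args)

        wrapper.compile_count = 0
        wrapper.options = options
        return wrapper

    return decorator


@toy_jit(compile_options={'target': 'nvgpu'})
def scale(xs):
    return [2 * x for x in xs]


print(scale([1, 2, 3]), scale([4, 5, 6]), scale.compile_count)  # [2, 4, 6] [8, 10, 12] 1
```

Because the user-facing change is only the decorator line, removing it restores the original behavior, which is what makes this style of adoption incremental.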

DL or non-DL Workloads

The PolyBlocks compiler is not restricted to deep learning (DL) or AI workloads. Any specification written using the operators of TensorFlow, PyTorch, or JAX (see the limitations on coverage in State of Support) can be compiled and optimized for the supported hardware. This frees users from writing code in low-level languages such as CUDA or OpenCL while still obtaining high performance. In our experiments, PolyBlocks delivers significantly higher speedups on non-DL workloads, since these are not served by highly tuned vendor libraries and primitives like cuDNN, cuBLAS, CUTLASS, MKL, oneDNN, ZenDNN, etc. On non-DL workloads, we have often observed PolyBlocks to be 5x-15x faster than the standard runtimes of JAX, PyTorch, and TensorFlow, and faster by similar margins than their respective compiler backends (XLA or torch.compile).


The PolyBlocks compiler currently supports CPUs, NVIDIA GPUs, and AMD GPUs. It can exploit the specialized tensor/matmul execution units on these GPUs. One of the key strengths of the PolyBlocks compiler is its ability to perform complex sequences of polyhedral transformations that exploit locality and parallelism and minimize memory usage, in conjunction with well-known transformations that rely on the SSA form.
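Loop tiling is one classic transformation from the polyhedral toolbox for improving locality. The hand-written sketch below (plain Python, not PolyBlocks code) contrasts a naive matrix multiplication with a tiled one. On real hardware, tiling lets each T x T block of A, B, and C be reused while it sits in cache or shared memory, and the (ii, jj) tiles write disjoint parts of C, which makes them candidates for parallel execution, for example as GPU thread blocks. The tile size T = 4 is arbitrary here; choosing it is the kind of decision an analytical model would make.

```python
def matmul_naive(A, B):
    """C = A @ B with the textbook i-j-p loop nest."""
    n, k, m = len(A), len(B), len(B[0])
    C = [[0] * m for _ in range(n)]
    for i in range(n):
        for j in range(m):
            for p in range(k):
                C[i][j] += A[i][p] * B[p][j]
    return C


def matmul_tiled(A, B, T=4):
    """Same computation with every loop tiled by T (tile loops ii, jj, pp)."""
    n, k, m = len(A), len(B), len(B[0])
    C = [[0] * m for _ in range(n)]
    for ii in range(0, n, T):              # (ii, jj) tiles write disjoint parts
        for jj in range(0, m, T):          # of C, so they could run in parallel
            for pp in range(0, k, T):      # reduction tiles, processed in order
                for i in range(ii, min(ii + T, n)):        # intra-tile loops;
                    for j in range(jj, min(jj + T, m)):    # min() handles edge
                        acc = C[i][j]                      # tiles
                        for p in range(pp, min(pp + T, k)):
                            acc += A[i][p] * B[p][j]
                        C[i][j] = acc
    return C


print(matmul_tiled([[1, 2], [3, 4]], [[5, 6], [7, 8]], T=1))  # [[19, 22], [43, 50]]
```

Tiling only reorders the iterations; each C[i][j] still accumulates its products in increasing p order, so both functions return the same result.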

For modern NVIDIA and AMD GPUs, PolyBlocks is expected to generate code that is several times faster on average, for a diverse range of large DL as well as non-DL (traditional dense matrix/image processing) workloads, than state-of-the-art library-based approaches, including the standard runtimes of TensorFlow, PyTorch, and JAX, as well as their standard JITs (TF/XLA, torch.compile, and JAX JIT). That said, PolyBlocks is a work in progress and is being released early; both coverage and performance are expected to improve. More information can be found in the FAQ and Known Issues.