PolyBlocks: A Compiler Engine

PolyBlocks, from PolyMage Labs, is a high-performance, modular, and fully automatic MLIR-based compiler engine for deep learning (DL) and non-DL computations.

A simple incremental approach

The PolyBlocks engine can be used with existing Python-based programming models like PyTorch, TensorFlow, and JAX, through the addition of a single trivial annotation.

Transformations

The PolyBlocks engine performs long and complex sequences of transformations for parallelization, locality optimization, and memory minimization, relying on MLIR and polyhedral techniques.

Run anywhere

The PolyBlocks code generation pipeline is modular and reusable. Simply change a compile option to run on a different target.

Using PolyBlocks
Playground
FAQ

Preview

@polyblocks_jit_tf(compile_options={'target': 'nvgpu'})
def blur_x_blur_y(img):
  """Blur `img` with a 5-tap kernel along X, then along Y.

  `img` is reshaped to [3, height, width, 1]; `height` and `width` are not
  parameters and are read from the enclosing scope. Both convolutions use
  "VALID" padding, and the result is reshaped to
  [3, height - 4, width - 4] before being returned.
  """
  taps = tf.constant([0.0625, 0.25, 0.375, 0.25, 0.0625], dtype=tf.float32)
  unit_strides = [1, 1, 1, 1]
  planes = tf.reshape(img, [3, height, width, 1])

  # X-direction pass. With "VALID" padding the 5-tap filter leaves the
  # output two samples short at each end.
  filter_x = tf.reshape(taps, [1, 5, 1, 1])
  blurred_x = tf.nn.conv2d(planes, filter_x, unit_strides, "VALID")

  # Y-direction pass: the same taps, reshaped to run along the other axis.
  filter_y = tf.reshape(taps, [5, 1, 1, 1])
  blurred_xy = tf.nn.conv2d(blurred_x, filter_y, unit_strides, "VALID")

  return tf.reshape(blurred_xy, [3, height - 4, width - 4])
@polyblocks_jit_torch(compile_options={'target': 'nvgpu'})
def blur_x_blur_y(img):
  """Blur `img` with a 5-tap kernel along X, then along Y.

  `img` is viewed as (3, 1, height, width); `height` and `width` are not
  parameters and are read from the enclosing scope. Both convolutions use
  stride 1 and no padding, and the result is viewed as
  (3, height - 4, width - 4) before being returned.
  """
  # Build the taps on the same device as the input image.
  taps = torch.tensor(
      [0.0625, 0.25, 0.375, 0.25, 0.0625], device=img.device
  ).float()
  planes = img.view(3, 1, height, width)

  # X-direction pass. Without padding the 5-tap filter leaves the output
  # two samples short at each end.
  filter_x = taps.view(1, 1, 1, 5)
  blurred_x = torch.nn.functional.conv2d(
      planes, filter_x, stride=1, padding=0
  )

  # Y-direction pass: the same taps, viewed along the other spatial axis.
  filter_y = taps.view(1, 1, 5, 1)
  blurred_xy = torch.nn.functional.conv2d(
      blurred_x, filter_y, stride=1, padding=0
  )

  return blurred_xy.view(3, height - 4, width - 4)
class CNN(nn.Module):
  """Module that blurs an image with a 5-tap kernel along X, then along Y.

  `height` and `width` are not attributes of the module; they are read from
  the enclosing scope.
  """

  @nn.compact
  def __call__(self, img):
    """Return the blurred image reshaped to [3, height - 4, width - 4].

    Args:
      img: Input image; it is reshaped to [3, height, width, 1] before the
        convolutions are applied.
    """
    # We are doing a regular image convolution in the X direction and then
    # in the Y direction.
    img_reshaped = jnp.reshape(img, [3, height, width, 1])
    blur_x_kernel = np.array(
        [[[[0.0625]], [[0.25]], [[0.375]], [[0.25]], [[0.0625]]]]
    )
    # The kernels are fixed through a constant initializer and the bias is
    # disabled, so each layer applies exactly the taps above.
    blurx = nn.Conv(
        features=1,
        kernel_size=[1, 5],
        kernel_init=jax.nn.initializers.constant(blur_x_kernel),
        padding="VALID",
        use_bias=False,
        dtype=jnp.dtype("float32"),
        param_dtype=jnp.dtype("float32"),
    )(img_reshaped)
    # Same taps, reshaped so that they run along the other spatial axis.
    blur_y_kernel = jnp.reshape(blur_x_kernel, [5, 1, 1, 1])
    blury = nn.Conv(
        features=1,
        kernel_size=[5, 1],
        kernel_init=jax.nn.initializers.constant(blur_y_kernel),
        padding="VALID",
        use_bias=False,
        dtype=jnp.dtype("float32"),
        param_dtype=jnp.dtype("float32"),
    )(blurx)
    return jnp.reshape(blury, [3, height - 4, width - 4])

@polyblocks_jit_jax(compile_options={'target': 'nvgpu'})
def BlurXBlurY(img):
    """Apply `cnn` to `img` with `variables` and return the result.

    `cnn` and `variables` are not parameters; they are read from the
    enclosing scope.
    """
    return cnn.apply(variables, img)