PolyBlocks: A Compiler Engine
PolyBlocks, from PolyMage Labs, is a high-performance, modular, and fully automatic MLIR-based compiler engine for deep learning (DL) and non-DL computations.
A Simple Incremental Programming Model
The PolyBlocks engine can be used with existing Python-based programming models like PyTorch, TensorFlow, and JAX by incrementally adding simple annotations to existing code.
Transformations
The PolyBlocks engine performs long and complex sequences of transformations for parallelization, locality optimization, and memory minimization relying on MLIR and polyhedral techniques.
Run Anywhere
PolyBlocks’ code generation pipeline is modular and reusable: to run on a different target, simply change a compile option (for example, the `'target'` entry in `compile_options`).
Preview
@polyblocks_jit_tf(compile_options={'target': 'nvgpu'})
def blur_x_blur_y(img):
    """Separable 5-tap blur of three image planes: first along X, then along Y.

    JIT-compiled through the PolyBlocks TensorFlow decorator for the 'nvgpu'
    target. Relies on module-level ``height`` and ``width``; ``img`` must hold
    exactly ``3 * height * width`` elements for the reshape to succeed.

    Returns a tensor of shape ``[3, height - 4, width - 4]``.
    """
    # Weights are (1, 4, 6, 4, 1) / 16, shared by both passes.
    taps = tf.constant([0.0625, 0.25, 0.375, 0.25, 0.0625], dtype=tf.float32)
    # NHWC layout: the 3 planes become the batch axis, each with one channel.
    planes = tf.reshape(img, [3, height, width, 1])
    unit_strides = [1, 1, 1, 1]
    # conv2d filters are [filter_height, filter_width, in_channels, out_channels]:
    # a 1x5 filter blurs along X, a 5x1 filter blurs along Y.
    row_filter = tf.reshape(taps, [1, 5, 1, 1])
    col_filter = tf.reshape(taps, [5, 1, 1, 1])
    # With VALID padding, each 5-tap pass drops two pixels from both ends of
    # the axis it blurs, hence the "- 4" in the final shape.
    blurred_x = tf.nn.conv2d(planes, row_filter, strides=unit_strides, padding="VALID")
    blurred_xy = tf.nn.conv2d(blurred_x, col_filter, strides=unit_strides, padding="VALID")
    return tf.reshape(blurred_xy, [3, height - 4, width - 4])
@polyblocks_jit_torch(compile_options={'target': 'nvgpu'})
def blur_x_blur_y_pytorch(img):
    """Separable 5-tap blur of three image planes: first along X, then along Y.

    JIT-compiled through the PolyBlocks PyTorch decorator for the 'nvgpu'
    target. Relies on module-level ``height`` and ``width``; ``img`` must hold
    exactly ``3 * height * width`` elements.

    NOTE(review): the kernel is float32, so ``img`` presumably needs to be
    float32 too for ``conv2d`` to accept it — confirm against callers.

    Returns a tensor of shape ``(3, height - 4, width - 4)``.
    """
    # Weights are (1, 4, 6, 4, 1) / 16, created directly as float32 on img's device.
    kernel = torch.tensor([0.0625, 0.25, 0.375, 0.25, 0.0625],
                          dtype=torch.float32, device=img.device)
    # NCHW: the 3 planes become the batch axis with a single channel.
    # reshape (not view) so non-contiguous inputs, e.g. permuted or sliced
    # images, are accepted; it still returns a view when memory allows.
    img_reshaped = img.reshape(3, 1, height, width)
    # Weight layout is (out_channels, in_channels, kH, kW): 1x5 blurs along X.
    blur_kernel = kernel.view(1, 1, 1, 5)
    # padding=0: each 5-tap pass drops two pixels from both ends of the
    # axis it blurs, hence the "- 4" in the final shape.
    blurx = torch.nn.functional.conv2d(img_reshaped, blur_kernel,
                                       stride=1, padding=0)
    # A 5x1 kernel blurs along Y.
    blur_kernel = kernel.view(1, 1, 5, 1)
    blury = torch.nn.functional.conv2d(blurx, blur_kernel,
                                       stride=1, padding=0)
    return blury.reshape(3, height - 4, width - 4)
class CNN(nn.Module):
    """Separable 5-tap blur expressed as two bias-free Flax convolutions.

    Both convolution kernels are initialized to fixed blur taps via constant
    initializers.
    """

    @nn.compact
    def __call__(self, img):
        """Blur ``img`` along X, then along Y.

        Relies on module-level ``height`` and ``width``; ``img`` must hold
        exactly ``3 * height * width`` elements. Returns an array of shape
        ``[3, height - 4, width - 4]``.
        """
        f32 = jnp.dtype("float32")
        # Options shared by both passes: one output feature, no padding, no bias.
        conv_options = dict(
            features=1,
            padding="VALID",
            use_bias=False,
            dtype=f32,
            param_dtype=f32,
        )
        # NHWC: the 3 planes become the batch axis, each with one channel.
        planes = jnp.reshape(img, [3, height, width, 1])
        # Kernel arrays are (kernel_h, kernel_w, 1, 1), matching kernel_size.
        # Weights are (1, 4, 6, 4, 1) / 16.
        x_taps = np.array([0.0625, 0.25, 0.375, 0.25, 0.0625]).reshape(1, 5, 1, 1)
        y_taps = jnp.reshape(x_taps, [5, 1, 1, 1])
        # NOTE(review): the X convolution is deliberately created first.
        # Inline Flax submodules are presumably auto-named by creation order,
        # so reordering could mismatch previously initialized variables.
        # VALID padding: each 5-tap pass drops two pixels from both ends.
        blurred_x = nn.Conv(
            kernel_size=[1, 5],
            kernel_init=jax.nn.initializers.constant(x_taps),
            **conv_options,
        )(planes)
        blurred_xy = nn.Conv(
            kernel_size=[5, 1],
            kernel_init=jax.nn.initializers.constant(y_taps),
            **conv_options,
        )(blurred_x)
        return jnp.reshape(blurred_xy, [3, height - 4, width - 4])
@polyblocks_jit_jax(compile_options={'target': 'nvgpu'})
def BlurXBlurY(img):
    """Apply the module-level ``cnn`` with its module-level ``variables`` to ``img``.

    JIT-compiled through the PolyBlocks JAX decorator for the 'nvgpu' target.
    """
    return cnn.apply(variables, img)