PolyBlocks: A Compiler Engine
PolyBlocks is a high-performance, modular, and fully automatic MLIR-based compiler engine from PolyMage Labs for AI programming frameworks.
A simple incremental approach
The PolyBlocks engine works with existing Python-based programming models such as PyTorch, TensorFlow, and JAX: you only add a single, trivial annotation (a decorator such as `@polyblocks_jit_torch`) to the function you want compiled.
Transformations
The PolyBlocks engine performs long and complex sequences of transformations for parallelization, locality optimization, and memory minimization relying on MLIR and polyhedral techniques.
Run anywhere
PolyBlocks' code generation pipeline is modular and reusable. To run on a different target, simply change a compile option (the `target` entry in `compile_options`).
Preview
@polyblocks_jit_torch(compile_options={'target': 'nvgpu'})
def blur_x_blur_y(img):
    """Blur ``img`` with a separable 5-tap kernel, first along X, then along Y.

    ``height`` and ``width`` are not parameters: they are module-level names
    defined outside this snippet. ``img`` must hold ``3 * height * width``
    elements; presumably it is a planar (3, height, width) image -- confirm
    against the caller.

    Returns a tensor of shape (3, height - 4, width - 4): each unpadded
    (padding=0) 5-tap convolution drops two pixels at both ends of its axis.
    """
    # Weights (1, 4, 6, 4, 1) / 16, built directly as float32 on img's device.
    kernel = torch.tensor([0.0625, 0.25, 0.375, 0.25, 0.0625],
                          dtype=torch.float32, device=img.device)
    # Treat the 3 planes as a batch of single-channel (N, C, H, W) images.
    # reshape (not view) so a non-contiguous input -- e.g. a permuted
    # HWC->CHW tensor -- is accepted instead of raising a RuntimeError.
    img_reshaped = img.reshape(3, 1, height, width)
    # Weight layout (out_ch, in_ch, kH, kW) = (1, 1, 1, 5): convolve along X.
    blur_kernel = kernel.reshape(1, 1, 1, 5)
    blurx = torch.nn.functional.conv2d(img_reshaped, blur_kernel,
                                       stride=1, padding=0)
    # (1, 1, 5, 1): convolve along Y.
    blur_kernel = kernel.reshape(1, 1, 5, 1)
    blury = torch.nn.functional.conv2d(blurx, blur_kernel,
                                       stride=1, padding=0)
    return blury.reshape(3, height - 4, width - 4)
@polyblocks_jit_tf(compile_options={'target': 'nvgpu'})
def blur_x_blur_y(img):
    """Apply a separable 5-tap blur to ``img``: horizontally, then vertically.

    Relies on module-level ``height`` and ``width`` defined outside this
    snippet; ``img`` must contain ``3 * height * width`` elements
    (presumably a planar (3, height, width) image -- confirm with the caller).

    Returns a tensor of shape [3, height - 4, width - 4]; each "VALID"
    convolution trims two pixels from both ends of the axis it runs along.
    """
    taps = tf.constant([0.0625, 0.25, 0.375, 0.25, 0.0625], dtype=tf.float32)
    # tf.nn.conv2d filters are laid out [filter_h, filter_w, in_ch, out_ch].
    horizontal = tf.reshape(taps, [1, 5, 1, 1])
    vertical = tf.reshape(taps, [5, 1, 1, 1])
    unit_strides = [1, 1, 1, 1]

    # NHWC batch of three single-channel planes.
    planes = tf.reshape(img, [3, height, width, 1])
    smoothed_x = tf.nn.conv2d(planes, horizontal,
                              strides=unit_strides, padding="VALID")
    smoothed_xy = tf.nn.conv2d(smoothed_x, vertical,
                               strides=unit_strides, padding="VALID")
    return tf.reshape(smoothed_xy, [3, height - 4, width - 4])
class CNN(nn.Module):
    """Flax module applying a separable 5-tap blur along X, then along Y.

    Both convolutions are bias-free, float32, "VALID"-padded, and initialised
    to fixed blur weights through a constant initializer. ``height`` and
    ``width`` are module-level names defined outside this snippet.
    """

    @nn.compact
    def __call__(self, img):
        """Blur ``img`` and return an array of shape (3, height - 4, width - 4).

        ``img`` must hold ``3 * height * width`` elements; presumably it is a
        planar (3, height, width) image -- TODO confirm with the caller.
        """
        f32 = jnp.dtype("float32")

        def blur_conv(inputs, weights, size):
            # Still constructed inside the compact __call__ and in the same
            # order as before (X layer first), so submodule naming is unchanged.
            layer = nn.Conv(
                features=1,
                kernel_size=size,
                kernel_init=jax.nn.initializers.constant(weights),
                padding="VALID",
                use_bias=False,
                dtype=f32,
                param_dtype=f32,
            )
            return layer(inputs)

        # Conv kernels are (*kernel_size, in_features, features): (1, 5, 1, 1).
        x_weights = np.array(
            [0.0625, 0.25, 0.375, 0.25, 0.0625]
        ).reshape(1, 5, 1, 1)
        y_weights = jnp.reshape(x_weights, [5, 1, 1, 1])

        # NHWC batch of three single-channel planes.
        nhwc = jnp.reshape(img, [3, height, width, 1])
        along_x = blur_conv(nhwc, x_weights, [1, 5])
        along_xy = blur_conv(along_x, y_weights, [5, 1])
        return jnp.reshape(along_xy, [3, height - 4, width - 4])
@polyblocks_jit_jax(compile_options={'target': 'nvgpu'})
def BlurXBlurY(img):
    """Apply the module ``cnn`` with parameters ``variables`` to ``img``.

    ``cnn`` and ``variables`` are module-level names defined outside this
    snippet; presumably ``cnn`` is a ``CNN`` instance and ``variables`` came
    from its ``init`` -- confirm against the setup code.
    """
    return cnn.apply(variables, img)