Imax: Making Image Augmentations Fast with JAX

[Figure: sample transformations]

Image augmentations make all the difference when training neural networks, and by now that should be common knowledge. No matter what you're trying to train, if it involves images, you should be using heavy and fancy augmentations! The only downside is that heavy augmentations can slow down your training significantly if they are not implemented efficiently. With Imax, the goal was to solve that problem while also getting better at JAX.
And you can try the results today!

Augmentations matter because, unless you have literally infinite data covering the entire sample space the model will see in production, the model will always end up encountering samples that fall outside the range it saw during training. By applying lots of augmentations, you can create many different samples from every single image in your dataset. This helps the model become more robust to random perturbations and less reliant on specific features.
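To make the "many samples from one image" idea concrete, here is a minimal JAX sketch. It is not part of Imax: it splits one PRNG key into several keys and uses `jax.vmap` to apply a random horizontal flip plus brightness jitter once per key. The function name `random_flip_and_brightness` and the jitter range are illustrative assumptions.

```python
import jax
import jax.numpy as jnp


def random_flip_and_brightness(image, key):
    """Hypothetical augmentation (not an Imax API): random flip + brightness jitter.

    `image` is a float32 array of shape (H, W, C) with values in [0, 1].
    """
    flip_key, bright_key = jax.random.split(key)
    # Flip left-right with probability 0.5; jnp.where keeps this traceable under jit/vmap.
    do_flip = jax.random.bernoulli(flip_key, 0.5)
    image = jnp.where(do_flip, image[:, ::-1, :], image)
    # Scale brightness by a random factor in [0.7, 1.3] (range chosen for illustration).
    factor = jax.random.uniform(bright_key, minval=0.7, maxval=1.3)
    return jnp.clip(image * factor, 0.0, 1.0)


# One image, eight keys -> eight differently augmented samples in a single call.
image = jnp.linspace(0.0, 1.0, 4 * 4 * 3).reshape(4, 4, 3)
keys = jax.random.split(jax.random.PRNGKey(0), 8)
samples = jax.vmap(random_flip_and_brightness, in_axes=(None, 0))(image, keys)
print(samples.shape)  # (8, 4, 4, 3)
```

Because each sample gets its own key, re-running with fresh keys every epoch gives the model new variations of the same image.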

With Imax, I first focused on fast and efficient linear and perspective transforms in 2D and 3D (yes, it even accepts depth maps for the transformations!). Once that was done, I added some of the PIL color transforms that are also used in the RandAugment paper.
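As a rough illustration of the idea behind matrix-based transforms, here is a minimal JAX sketch (not Imax's actual code). It builds a 3x3 homogeneous rotation matrix about the image center, warps an image by inverse-mapping every output pixel with nearest-neighbor sampling, and checks that composing three rotations with `@` equals a single rotation by the summed angle, so the image only has to be resampled once. The helper names `rotation_matrix` and `warp_nearest` are hypothetical; a real implementation would more likely use bilinear sampling.

```python
import jax.numpy as jnp


def rotation_matrix(rad, height, width):
    """Hypothetical helper: 3x3 homogeneous matrix rotating by `rad` about the image center."""
    cy, cx = (height - 1) / 2.0, (width - 1) / 2.0
    c, s = jnp.cos(rad), jnp.sin(rad)
    to_origin = jnp.array([[1.0, 0.0, -cx], [0.0, 1.0, -cy], [0.0, 0.0, 1.0]])
    rotate = jnp.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])
    back = jnp.array([[1.0, 0.0, cx], [0.0, 1.0, cy], [0.0, 0.0, 1.0]])
    return back @ rotate @ to_origin


def warp_nearest(image, matrix, fill=0.0):
    """Hypothetical helper: apply `matrix` (input (x, y, 1) -> output) via inverse mapping."""
    height, width = image.shape[:2]
    ys, xs = jnp.meshgrid(jnp.arange(height), jnp.arange(width), indexing="ij")
    # Homogeneous output coordinates, shape (3, H*W).
    out_coords = jnp.stack([xs.ravel(), ys.ravel(), jnp.ones(height * width)]).astype(jnp.float32)
    # For every output pixel, find the input pixel that maps onto it.
    src = jnp.linalg.inv(matrix) @ out_coords
    src = src[:2] / src[2]  # perspective divide (a no-op for affine matrices)
    sx = jnp.round(src[0]).astype(jnp.int32)
    sy = jnp.round(src[1]).astype(jnp.int32)
    valid = (sx >= 0) & (sx < width) & (sy >= 0) & (sy < height)
    # Clip indices so the gather stays in bounds, then fill pixels that came from outside.
    gathered = image[jnp.clip(sy, 0, height - 1), jnp.clip(sx, 0, width - 1)]
    valid = valid.reshape((-1,) + (1,) * (image.ndim - 2))
    return jnp.where(valid, gathered, fill).reshape(image.shape)


image = jnp.arange(9 * 9 * 3, dtype=jnp.float32).reshape(9, 9, 3)
step = rotation_matrix(0.3, 9, 9)
combined = step @ step @ step  # three rotations folded into one matrix
once = warp_nearest(image, combined)  # a single resampling pass instead of three
print(bool(jnp.allclose(combined, rotation_matrix(0.9, 9, 9), atol=1e-4)))  # True
```

Folding the matrices first also avoids accumulating interpolation and border losses from three separate resampling passes.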

Finally, I implemented a fully jittable version of RandAugment (trust me, that's not as easy as it sounds 😅). It can run on CPU or GPU or be vmapped over a batch of images, and after the first call, which triggers JIT compilation, it is super fast.
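Part of what makes a jittable RandAugment tricky is that the randomly chosen operation becomes a traced value under `jax.jit`, so a plain Python `if` or list lookup on it fails. A common workaround, shown here as a minimal sketch rather than Imax's actual implementation, is to pick the op with `jax.lax.switch` and draw all random choices from a PRNG key, which then composes with `jax.vmap` over a batch. The op set (identity, horizontal flip, invert, 4-bit posterize) and the names `random_augment` and `NUM_LAYERS` are illustrative assumptions; every op works on `uint8` images and preserves shape and dtype, which `lax.switch` requires.

```python
import jax
import jax.numpy as jnp


# Hypothetical op set: every branch must return the same shape and dtype.
def identity(image):
    return image


def flip_lr(image):
    return image[:, ::-1]


def invert(image):
    return 255 - image


def posterize_4bits(image):
    # Keep the top 4 bits of each channel, like PIL's ImageOps.posterize(image, 4).
    return image & jnp.uint8(0xF0)


OPS = (identity, flip_lr, invert, posterize_4bits)
NUM_LAYERS = 2  # number of ops applied in sequence (illustrative value)


@jax.jit
def random_augment(image, key):
    """Apply NUM_LAYERS randomly chosen ops; fully traceable, so it can be jitted and vmapped."""
    keys = jax.random.split(key, NUM_LAYERS)
    for i in range(NUM_LAYERS):
        op_index = jax.random.randint(keys[i], (), 0, len(OPS))
        # lax.switch dispatches on a traced integer; a Python if/else would fail here.
        image = jax.lax.switch(op_index, OPS, image)
    return image


batch = (jnp.arange(4 * 8 * 8 * 3) % 256).astype(jnp.uint8).reshape(4, 8, 8, 3)
keys = jax.random.split(jax.random.PRNGKey(0), batch.shape[0])
# The first call traces and compiles; later calls with the same shapes reuse the compiled code.
augmented = jax.vmap(random_augment)(batch, keys)
print(augmented.shape, augmented.dtype)  # (4, 8, 8, 3) uint8
```

Under `vmap`, `lax.switch` with a batched index evaluates all branches and selects per example, which is one reason to keep the individual ops cheap.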

If you're interested in the code, check it out on my GitHub.

You can try it out on Colab, or locally like this:

pip install imax

from jax import random
import jax.numpy as jnp
from PIL import Image
from matplotlib import pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid

from imax import transforms, color_transforms, randaugment

# Setup
random_key = random.PRNGKey(42)
random_key, split_key = random.split(random_key)
image = jnp.asarray(Image.open('test.png').convert('RGBA')).astype('uint8')  # load as RGBA uint8

# Geometric transforms:
transform = transforms.rotate(rad=0.7)  # create transformation matrix
transformed_image = transforms.apply_transform(image,      # apply transformation
                                               transform,
                                               mask_value=jnp.array([0, 0, 0, 255]))

# multiple transformations can be combined through matrix multiplication
# this makes multiple sequential transforms much faster
multi_transform = transform @ transform @ transform
multi_transformed_image = transforms.apply_transform(image,
                                                     multi_transform,
                                                     mask_value=jnp.array([0, 0, 0, 255]))

# Color transforms:
adjusted_image = color_transforms.posterize(image, bits=2)

# Randaugment:
randomized_image = randaugment.distort_image_with_randaugment(
    image,
    num_layers=3,   # number of random augmentations in sequence
    magnitude=15,   # magnitude of random augmentations
    random_key=split_key)  # PRNG key created in the setup above

# Show results:
results = [transformed_image, multi_transformed_image, adjusted_image, randomized_image]
fig = plt.figure(figsize=(10., 10.))
grid = ImageGrid(fig, 111,
                 nrows_ncols=(2, 2))

for ax, im in zip(grid, results):
    ax.imshow(im)   # draw each result in its own grid cell
    ax.axis('off')
plt.show()