# Imax: Making Image Augmentations fast with JAX

Image augmentations make all the difference when working with neural networks. Everybody should know that by now.
No matter what you're trying to train, if it involves images you should be using heavy and fancy augmentations!
The only downside of these heavy augmentations is that they might slow down your training significantly if they are not implemented in a fast and efficient way.
With *Imax*, the goal was to solve that while getting better at JAX.

And you can try the results today!

Augmentations matter because, unless you have literally infinite data covering the entire sample space the model will see in production, the model will always encounter samples that fall outside the range it has seen. By applying lots of augmentations, you can create many samples from every single image in your dataset. This helps the model become more robust to random perturbations and less reliant on specific features.

With Imax, I first focused on fast and efficient linear and perspective transforms in 2D and 3D (yes, it even accepts depth maps for the transformations!). After that, I added some of the PIL color transforms that are also used in the RandAugment paper.
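To give an idea of what a linear or perspective transform involves, here is a minimal pure-`jax.numpy` sketch of a 2D warp by a 3x3 homogeneous matrix. It is not Imax's implementation: the helpers `rotation_matrix` and `warp_nearest` are hypothetical, the rotation is about the top-left corner, and sampling is nearest-neighbour only. Each output pixel is mapped back through the inverse matrix, so every output pixel receives exactly one source pixel, and pixels whose source lands outside the input are filled with a mask value.

```python
import jax
import jax.numpy as jnp

def rotation_matrix(rad):
    """Hypothetical helper: 3x3 homogeneous 2D rotation about the top-left corner."""
    c, s = jnp.cos(rad), jnp.sin(rad)
    return jnp.array([[c, -s, 0.0],
                      [s, c, 0.0],
                      [0.0, 0.0, 1.0]])

def warp_nearest(image, matrix, mask_value):
    """Warp an HxWxC image by a 3x3 homogeneous matrix (affine or perspective).

    Each output pixel is mapped back through the inverse matrix and takes the
    nearest input pixel; pixels whose source falls outside the image get
    `mask_value`.
    """
    h, w = image.shape[:2]
    ys, xs = jnp.meshgrid(jnp.arange(h), jnp.arange(w), indexing='ij')
    # Homogeneous (x, y, 1) coordinates of every output pixel, shape (3, h*w)
    coords = jnp.stack([xs.ravel(), ys.ravel(), jnp.ones(h * w)]).astype(jnp.float32)
    src = jnp.linalg.inv(matrix) @ coords
    # Perspective divide (a no-op when the last row is [0, 0, 1])
    src_x = jnp.round(src[0] / src[2]).astype(jnp.int32)
    src_y = jnp.round(src[1] / src[2]).astype(jnp.int32)
    inside = (src_x >= 0) & (src_x < w) & (src_y >= 0) & (src_y < h)
    # Clip so the gather stays in bounds, then overwrite the outside pixels
    sampled = image[jnp.clip(src_y, 0, h - 1), jnp.clip(src_x, 0, w - 1)]
    out = jnp.where(inside[:, None], sampled, mask_value)
    return out.reshape(image.shape).astype(image.dtype)

# All shapes come from image.shape, so the warp can be jitted unchanged
fast_warp = jax.jit(warp_nearest)
```

Dividing by the third homogeneous coordinate is what lets the same code handle perspective matrices as well as affine ones.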

Finally, I implemented a fully jittable version of RandAugment (trust me, that's not that easy 😅). It can run on CPU or GPU or be vmapped over a batch of images, and after the first call, which triggers JIT compilation, it is super fast.

If you're interested in the code, check it out on my GitHub.

You can try it out on Colab or like this:

```bash
pip install imax
```

```python
from jax import random
import jax.numpy as jnp
from PIL import Image
from matplotlib import pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid

from imax import transforms, color_transforms, randaugment

# Setup
random_key = random.PRNGKey(42)
random_key, split_key = random.split(random_key)
image = jnp.asarray(Image.open('test.png').convert('RGBA')).astype('uint8')

# Geometric transforms:
transform = transforms.rotate(rad=0.7)  # create transformation matrix
transformed_image = transforms.apply_transform(image,  # apply transformation
                                               transform,
                                               mask_value=jnp.array([0, 0, 0, 255]))

# multiple transformations can be combined through matrix multiplication
# this makes multiple sequential transforms much faster
multi_transform = transform @ transform @ transform
multi_transformed_image = transforms.apply_transform(image,
                                                     multi_transform,
                                                     mask_value=-1)

# Color transforms:
adjusted_image = color_transforms.posterize(image, bits=2)

# Randaugment:
randomized_image = randaugment.distort_image_with_randaugment(
    image,
    num_layers=3,  # number of random augmentations in sequence
    magnitude=15,  # magnitude of random augmentations
    random_key=split_key
)

# Show results:
results = [transformed_image, multi_transformed_image, adjusted_image, randomized_image]
fig = plt.figure(figsize=(10., 10.))
grid = ImageGrid(fig, 111, nrows_ncols=(2, 2), axes_pad=0.1)
for ax, im in zip(grid, results):
    ax.axis('off')
    ax.imshow(im)
plt.show()
```
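The claim that combining transforms through matrix multiplication speeds things up can be checked with plain matrices. The sketch below assumes 3x3 homogeneous 2D matrices (Imax also handles 3D, so its own matrices may have a different shape), and the helpers `rotation_matrix` and `translation_matrix` are hypothetical. Three rotations by 0.7 rad collapse into a single rotation by 2.1 rad, so the image only has to be resampled once instead of three times.

```python
import jax.numpy as jnp

def rotation_matrix(rad):
    """Hypothetical helper: 3x3 homogeneous 2D rotation by `rad` radians about the origin."""
    c, s = jnp.cos(rad), jnp.sin(rad)
    return jnp.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])

def translation_matrix(tx, ty):
    """Hypothetical helper: 3x3 homogeneous 2D translation by (tx, ty)."""
    return jnp.array([[1.0, 0.0, tx], [0.0, 1.0, ty], [0.0, 0.0, 1.0]])

# Three sequential rotations by 0.7 rad collapse into a single matrix,
# equal (up to float error) to one rotation by 2.1 rad.
combined = rotation_matrix(0.7) @ rotation_matrix(0.7) @ rotation_matrix(0.7)

# Order matters: the right-most matrix acts on the point first.
point = jnp.array([1.0, 0.0, 1.0])  # (x, y, 1)
rotate_then_shift = translation_matrix(5.0, 0.0) @ rotation_matrix(jnp.pi / 2) @ point
shift_then_rotate = rotation_matrix(jnp.pi / 2) @ translation_matrix(5.0, 0.0) @ point
```

Besides saving work, resampling once also avoids compounding the interpolation error that each separate resampling step would add.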