
Accelerating MMPreTrain models with JAX

Accelerate your MMPreTrain models by converting them to JAX for faster inference.


Make sure you run this demo with GPU enabled!
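
If you are running on Colab, you can quickly confirm that a GPU is actually attached before installing anything, for example by querying nvidia-smi:

[ ]:
# Optional check: this should list an NVIDIA GPU; if it errors out, no GPU is attached
!nvidia-smi

Once that’s confirmed, we can install the dependencies: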

[ ]:
!pip install -U -q openmim && mim install -q "mmpretrain>=1.0.0rc8"
!pip install -q ivy
!pip install -q dm-haiku

Let’s now import Ivy and the libraries we’ll use in this example:

import jax
import ivy
import torch
import requests
import numpy as np
from PIL import Image
import time

import torchvision
from mmpretrain import get_model, list_models
from mmengine import ConfigDict

First, a quick sanity check to make sure the checkpoint name we want to use actually exists in mmpretrain’s model zoo:

checkpoint_name = "convnext-tiny_32xb128-noema_in1k"
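
A minimal way to perform this check, assuming list_models() with no arguments returns the names of every model in the zoo:

[ ]:
# This should raise if the checkpoint name is not part of mmpretrain's model zoo
assert checkpoint_name in list_models(), f"{checkpoint_name} not found in mmpretrain's model zoo"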

Now we can load the ConvNeXt model from OpenMMLab’s mmpretrain library:

[ ]:
jax.config.update("jax_enable_x64", True)

model = get_model(checkpoint_name, pretrained=True, device='cuda')
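
The returned object behaves like a regular torch.nn.Module, so before tracing we can optionally switch it to eval mode and check its size (purely illustrative, not required for the transpilation):

[ ]:
# Optional: inspect the loaded model like any other torch.nn.Module
model.eval()
num_params = sum(p.numel() for p in model.parameters())
print(f"Loaded {checkpoint_name} with {num_params / 1e6:.1f}M parameters")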

We will also need a sample image to pass during tracing, so let’s use the appropriate transforms to get the corresponding torch tensors.

def get_scale(cfg):
    # Recursively search the (possibly nested) config for the first 'scale' entry,
    # which corresponds to the input resolution the model was trained with
    if isinstance(cfg, ConfigDict):
        if cfg.get('type', False) and cfg.get('scale', False):
            return cfg['scale']
        for k in cfg.keys():
            input_shape = get_scale(cfg[k])
            if input_shape:
                return input_shape
    elif isinstance(cfg, list):
        for block in cfg:
            input_shape = get_scale(block)
            if input_shape:
                return input_shape
    return None

url = ""
image = Image.open(requests.get(url, stream=True).raw)
input_shape = get_scale(model._config.train_pipeline)
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((input_shape, input_shape)),
    torchvision.transforms.ToTensor(),  # convert the PIL image to a torch tensor
])
tensor_image = transform(image).unsqueeze(0).to("cuda")
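
Just to be safe, we can also confirm that the preprocessed tensor has the shape the model expects, i.e. a batch of one RGB image at the resolution we extracted from the config:

[ ]:
# Expected: torch.Size([1, 3, input_shape, input_shape])
print(tensor_image.shape)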

And finally, let’s transpile the model to haiku!

[ ]:
transpiled_graph = ivy.transpile(model, to="haiku", args=(tensor_image,))

After transpiling our model, let’s see how much the runtime efficiency improves. For a fair comparison, we’ll first compile the original PyTorch model using torch.compile:

[ ]:
# ref :
# Make the CUDA libraries visible to torch.compile's backend (needed on some Colab runtimes)
!export LC_ALL="en_US.UTF-8"
!export LD_LIBRARY_PATH="/usr/lib64-nvidia"
!export LIBRARY_PATH="/usr/local/cuda/lib64/stubs"
!ldconfig /usr/lib64-nvidia
tensor_image = transform(image).unsqueeze(0).to("cuda")

def _f(args):
  return model(args)

# Compile the wrapped forward pass and trigger compilation with a warm-up call
comp_model = torch.compile(_f)
_ = comp_model(tensor_image)

Let’s now do the equivalent transformation on our new Haiku model by using JAX’s just-in-time compilation:

tensor_image = transform(image).unsqueeze(0).to("cuda")
np_image = tensor_image.detach().cpu().numpy()
jax_image = jax.device_put(jax.numpy.asarray(np_image), device=jax.devices()[0])

import haiku as hk

def _forward(args):
  module = transpiled_graph()
  return module(args)

rng_key = jax.random.PRNGKey(42)

# Wrap the transpiled graph in a Haiku transform, initialise its parameters,
# and JIT-compile the apply function (the first call triggers compilation)
jax_mlp_forward = hk.transform(_forward)
params = jax_mlp_forward.init(rng=rng_key, args=jax_image)
apply = jax.jit(jax_mlp_forward.apply)
_ = apply(params, None, jax_image)

Now that we have both models optimized, let’s see how their runtime speeds compare to each other!

%timeit comp_model(tensor_image)
8.06 ms ± 2.7 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit apply(params, None, jax_image).block_until_ready()
6.08 ms ± 13.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
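
Note that we call block_until_ready() on the JAX output because JAX dispatches computations asynchronously; without it, %timeit would only measure the time taken to launch the computation rather than to complete it.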

As expected, we have made the model significantly faster with just one line of code! Latency gets even better on a V100 GPU, where we can get up to a 2-3x increase in execution speed! 🚀

Finally, as a sanity check, let’s load a different image and make sure that the results are the same for both models:

url = ""
image = Image.open(requests.get(url, stream=True).raw)
tensor_image = transform(image).unsqueeze(0).to("cuda")
np_image = tensor_image.detach().cpu().numpy()
jax_image = jax.device_put(jax.numpy.asarray(np_image), device=jax.devices()[0])

st = time.perf_counter()
out_torch = comp_model(tensor_image)
et = time.perf_counter()
print(f'Torch call took: {(et - st) * 1000:.2f}ms')

st = time.perf_counter()
out_jax = apply(params, None, jax_image)
et = time.perf_counter()
print(f'Jax call took: {(et - st) * 1000:.2f}ms')

print(np.allclose(out_torch.detach().cpu().numpy(), out_jax, atol=1e-4))
Torch call took: 6.66ms
Jax call took: 2.53ms

That’s pretty much it! The results from both models are the same, but we have achieved a solid speedup by using Ivy’s transpiler to convert the model to JAX!