Accelerating PyTorch models with JAX

Accelerate your PyTorch models by converting them to JAX for faster inference.

⚠️ If you are running this notebook in Colab, you will have to install Ivy and some dependencies manually. You can do so by running the cell below ⬇️

If you want to run the notebook locally but don’t have Ivy installed just yet, you can check out the Setting Up section of the docs.

!git clone https://github.com/unifyai/ivy.git
!cd ivy && git checkout 04a9f61f875a892574993a9b4e032aac02ab21f2 && python3 -m pip install --user -e .
!pip install transformers
!pip install dm-haiku

For the installed packages to be available you will have to restart your kernel. In Colab, you can do this by clicking on “Runtime > Restart Runtime”. Once the runtime has been restarted you should skip the previous cell 😄
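After restarting, you can quickly sanity-check that the installation worked by importing Ivy and printing its version:

import ivy
print(ivy.__version__)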

To use the compiler and the transpiler, you will need an API key. If you already have one, replace the placeholder string in the next cell.

API_KEY = "PASTE_YOUR_KEY_HERE"
!mkdir -p .ivy
!echo -n $API_KEY > .ivy/key.pem

Let’s now import Ivy and the libraries we’ll use in this example:

import jax
import ivy
import torch
import requests
import numpy as np
from PIL import Image
from transformers import AutoModel, AutoFeatureExtractor

Now we can load a vision model and its corresponding feature extractor from the Hugging Face transformers library:

# enable 64-bit types in JAX
jax.config.update("jax_enable_x64", True)

arch_name = "Vision Transformer (ViT)"
checkpoint_name = "google/vit-base-patch16-224"

# download the pretrained weights and the matching preprocessing pipeline
feature_extractor = AutoFeatureExtractor.from_pretrained(checkpoint_name)
model = AutoModel.from_pretrained(checkpoint_name)
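Since we only need the model for inference, it's also a good idea to switch it to eval mode; while we're at it, we can check its size with standard PyTorch calls:

model.eval()  # disable dropout and other training-only behaviour
num_params = sum(p.numel() for p in model.parameters())
print(f"{arch_name}: {num_params / 1e6:.1f}M parameters")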

We will also need a sample image to pass during tracing, so let’s use the feature extractor to get the corresponding torch tensors.

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = feature_extractor(
    images=image, return_tensors="pt"
)
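If you're curious, you can inspect what the feature extractor produced; for this checkpoint the pixel values should be a single normalised 224x224 RGB image:

# the feature extractor resizes and normalises the image for the model
print(inputs["pixel_values"].shape)  # expected: torch.Size([1, 3, 224, 224])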

And finally, let's transpile the model to Haiku!

transpiled_graph = ivy.transpile(model, to="haiku", kwargs=inputs)
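We'll instantiate this transpiled graph inside a Haiku transform in a moment. First, as a reference point, let's look at what the original PyTorch model produces for this input; for ViT-Base the last hidden state should have shape (1, 197, 768), i.e. 196 patch tokens plus the [CLS] token:

with torch.no_grad():
    print(model(**inputs).last_hidden_state.shape)  # expected: torch.Size([1, 197, 768])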

After transpiling our model, we can check how much the runtime efficiency improves. For a fair comparison, let's first compile the original PyTorch model using torch.compile:

# move the inputs and the model to the GPU
inputs = feature_extractor(
    images=image, return_tensors="pt"
).to("cuda")

model.to("cuda")

def _f(**kwargs):
  return model(**kwargs)

# compile the model and call it once, so the compilation cost
# is not included in the benchmark below
comp_model = torch.compile(_f)
_ = comp_model(**inputs)

Let's now do the equivalent transformation in our new Haiku model using JAX's just-in-time compilation:

# get the inputs as JAX arrays this time
inputs_jax = feature_extractor(
    images=image, return_tensors="jax"
)

import haiku as hk

# wrap the transpiled graph in a forward function that Haiku can transform
def _forward(**kwargs):
  module = transpiled_graph()
  return module(**kwargs).last_hidden_state

# jit-compile the forward pass, then initialise the Haiku parameters
_forward = jax.jit(_forward)
rng_key = jax.random.PRNGKey(42)
jax_mlp_forward = hk.transform(_forward)
params = jax_mlp_forward.init(rng=rng_key, **inputs_jax)
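Just like with torch.compile, the first call to the jitted forward pass pays the compilation cost, so let's run the Haiku model once before timing it:

# warm-up call so that JIT compilation isn't included in the benchmark below
_ = jax_mlp_forward.apply(params, None, **inputs_jax)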

Now that we have both models optimized, let’s see how their runtime speeds compare to each other!

%%timeit
# synchronize so we measure the full GPU execution, not just the kernel launch
_ = comp_model(**inputs)
torch.cuda.synchronize()

%%timeit
# block_until_ready waits for JAX's asynchronous dispatch to finish
out = jax_mlp_forward.apply(params, None, **inputs_jax).block_until_ready()

As expected, we have made the model significantly faster with just one line of code, getting a ~3x increase in its execution speed! 🚀

Finally, as a sanity check, let's load a different image and make sure that both models produce the same results:

url = "http://images.cocodataset.org/train2017/000000283921.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = feature_extractor(
    images=image, return_tensors="pt"
).to("cuda")
inputs_jax = feature_extractor(
    images=image, return_tensors="jax"
)
out_torch = comp_model(**inputs)
out_jax = jax_mlp_forward.apply(params, None, **inputs_jax)

np.allclose(out_torch.last_hidden_state.detach().cpu().numpy(), out_jax, atol=1e-4)
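np.allclose only gives a yes/no answer; to see how close the outputs actually are, you can also look at the largest elementwise difference:

diff = np.abs(out_torch.last_hidden_state.detach().cpu().numpy() - np.asarray(out_jax))
print(f"Max absolute difference: {diff.max():.2e}")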

That's pretty much it! The results from both models are the same, and we have achieved a solid speed-up by using Ivy's transpiler to convert the model to JAX!