Accelerating PyTorch models with JAX
Accelerate your PyTorch models by converting them to JAX for faster inference.
⚠️ If you are running this notebook in Colab, you will have to install Ivy and some dependencies manually. You can do so by running the cell below ⬇️
!git clone https://github.com/unifyai/ivy.git
!cd ivy && git checkout 04a9f61f875a892574993a9b4e032aac02ab21f2 && python3 -m pip install --user -e .
!pip install transformers
!pip install dm-haiku
If you want to run the notebook locally but don’t have Ivy installed just yet, you can check out the Setting Up section of the docs.
For the installed packages to be available you will have to restart your kernel. In Colab, you can do this by clicking on “Runtime > Restart Runtime”. Once the runtime has been restarted you should skip the previous cell 😄
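If you want to confirm that the packages are importable after the restart, an optional sanity check like the one below should do (a minimal sketch; the exact version strings will depend on your environment):
import ivy
import transformers
import haiku
print(ivy.__version__, transformers.__version__, haiku.__version__)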
To use the compiler and the transpiler, you will need an API Key. If you already have one, you should replace the string in the next cell.
= "PASTE_YOUR_KEY_HERE" API_KEY
!mkdir -p .ivy
!echo -n $API_KEY > .ivy/key.pem
Let’s now import Ivy and the libraries we’ll use in this example:
import jax
import ivy
import torch
import requests
import numpy as np
from PIL import Image
from transformers import AutoModel, AutoFeatureExtractor
Now we can load a vision model and its corresponding feature extractor from the Hugging Face transformers library:
"jax_enable_x64", True)
jax.config.update(
= "Vision Transformer (ViT)"
arch_name = "google/vit-base-patch16-224"
checkpoint_name
= AutoFeatureExtractor.from_pretrained(checkpoint_name)
feature_extractor = AutoModel.from_pretrained(checkpoint_name) model
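If you are curious about the architecture we just loaded, its config exposes the usual ViT hyperparameters. This is only an optional inspection, not part of the workflow:
# Optional: for google/vit-base-patch16-224 this prints 768, 12 and 12
print(model.config.hidden_size, model.config.num_hidden_layers, model.config.num_attention_heads)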
We will also need a sample image to pass during tracing, so let’s use the feature extractor to get the corresponding torch tensors.
= "http://images.cocodataset.org/val2017/000000039769.jpg"
url = Image.open(requests.get(url, stream=True).raw)
image = feature_extractor(
inputs =image, return_tensors="pt"
images )
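The feature extractor returns a dict-like object whose pixel_values tensor is what the model actually consumes; if you want to inspect it (optional), something like this works:
# Optional: for this checkpoint the shape should be something like torch.Size([1, 3, 224, 224])
print(inputs["pixel_values"].shape)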
And finally, let’s transpile the model to haiku!
transpiled_graph = ivy.transpile(model, to="haiku", kwargs=inputs)
After transpiling our model, let's see how much the runtime efficiency improves. To do so, let's first compile the original PyTorch model using torch.compile:
inputs = feature_extractor(
    images=image, return_tensors="pt"
).to("cuda")

model.to("cuda")

def _f(**kwargs):
    return model(**kwargs)

comp_model = torch.compile(_f)
_ = comp_model(**inputs)
Let’s now apply the equivalent optimization to our new haiku model by using JAX just-in-time compilation:
inputs_jax = feature_extractor(
    images=image, return_tensors="jax"
)

import haiku as hk

def _forward(**kwargs):
    module = transpiled_graph()
    return module(**kwargs).last_hidden_state

_forward = jax.jit(_forward)
rng_key = jax.random.PRNGKey(42)
jax_mlp_forward = hk.transform(_forward)
params = jax_mlp_forward.init(rng=rng_key, **inputs_jax)
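Since hk.transform gives us an explicit parameter pytree, we can optionally sanity-check its size before benchmarking (a quick sketch; the total should roughly match the ~86M parameters of ViT-base):
# Optional: count the parameters in the transpiled haiku model
num_params = sum(p.size for p in jax.tree_util.tree_leaves(params))
print(f"{num_params:,} parameters")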
Now that we have both models optimized, let’s see how their runtime speeds compare to each other!
%%timeit
_ = comp_model(**inputs)
%%timeit
out = jax_mlp_forward.apply(params, None, **inputs_jax)
As expected, we have made the model significantly faster with just one line of code, getting a ~3x increase in its execution speed! 🚀
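Both PyTorch on CUDA and JAX dispatch work asynchronously, so if you want to be extra careful when timing, you can force synchronization before stopping the clock. The cell below is just an illustrative sketch of that idea, not part of the original benchmark:
import time

# Warm-up calls so compilation time isn't included in the measurement
_ = comp_model(**inputs)
jax_mlp_forward.apply(params, None, **inputs_jax).block_until_ready()

# Time the compiled PyTorch model, synchronizing the GPU before and after
torch.cuda.synchronize()
start = time.perf_counter()
_ = comp_model(**inputs)
torch.cuda.synchronize()
print(f"PyTorch: {time.perf_counter() - start:.4f}s")

# Time the jitted haiku model, blocking until the result is ready
start = time.perf_counter()
jax_mlp_forward.apply(params, None, **inputs_jax).block_until_ready()
print(f"JAX: {time.perf_counter() - start:.4f}s")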
Finally, as a sanity check, let’s load a different image and make sure that the results are the same in both models:
= "http://images.cocodataset.org/train2017/000000283921.jpg"
url = Image.open(requests.get(url, stream=True).raw)
image = feature_extractor(
inputs =image, return_tensors="pt"
images"cuda")
).to(= feature_extractor(
inputs_jax =image, return_tensors="jax"
images
)= comp_model(**inputs)
out_torch = jax_mlp_forward.apply(params, None, **inputs_jax)
out_jax
=1e-4) np.allclose(out_torch.last_hidden_state.detach().cpu().numpy(), out_jax, atol
That’s pretty much it! The results from both models are the same, but we have achieved a solid speed up by using Ivy’s transpiler to convert the model to JAX!
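If you'd like a more granular view than a single boolean, you could also look at the largest element-wise difference between the two outputs (again, just a quick optional sketch):
# Optional: maximum absolute difference between the PyTorch and JAX outputs
diff = np.max(np.abs(out_torch.last_hidden_state.detach().cpu().numpy() - np.asarray(out_jax)))
print(diff)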