Accelerating PyTorch models with JAX

Accelerate your PyTorch models by converting them to JAX for faster inference.

⚠️ If you are running this notebook in Colab, you will have to install Ivy and some dependencies manually. You can do so by running the cell below ⬇️

If you want to run the notebook locally but don’t have Ivy installed just yet, you can check out the Get Started section of the docs.

Make sure you run this demo with GPU enabled!

[1]:
!pip install -q ivy
!pip install -q transformers
!pip install -q dm-haiku

Let’s now import Ivy and the libraries we’ll use in this example:

[2]:
import jax
jax.devices()  # initialise JAX and list the available devices (should include the GPU)
import ivy
ivy.set_default_device("gpu:0")  # place Ivy arrays and operations on the GPU by default
import torch
import requests
import numpy as np
from PIL import Image
from transformers import AutoModel, AutoFeatureExtractor

Now we can load a ResNet model and its corresponding feature extractor from the Hugging Face transformers library.

[3]:
jax.config.update("jax_enable_x64", False)  # keep JAX in 32-bit mode, matching PyTorch's float32 weights

arch_name = "ResNet"
checkpoint_name = "microsoft/resnet-50"

feature_extractor = AutoFeatureExtractor.from_pretrained(checkpoint_name)
model = AutoModel.from_pretrained(checkpoint_name).to('cuda')

We will also need a sample image to pass during tracing, so let’s use the feature extractor to get the corresponding torch tensors.

[4]:
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = feature_extractor(
    images=image, return_tensors="pt"
).to('cuda')

And finally, let’s transpile the model to Haiku!

[5]:
transpiled_graph = ivy.transpile(model, to="haiku", kwargs=inputs)
WARNING:root:To preserve the tracer and transpiler caches across multiple machines, ensure that the relative path of your projects from the .ivy folder is consistent across all machines. You can do this by adding .ivy to your home folder and placing all projects in the same place relative to the home folder on all machines.
WARNING:root:Native Numpy does not support GPU placement, consider using Jax instead
/workspaces/ivy/ivy/utils/exceptions.py:390: UserWarning: The current backend: 'jax' does not support inplace updates natively. Ivy would quietly create new arrays when using inplace updates with this backend, leading to memory overhead (same applies for views). If you want to control your memory management, consider doing ivy.set_inplace_mode('strict') which should raise an error whenever an inplace update is attempted with this backend.
  warnings.warn(
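
As the warning above mentions, the JAX backend emulates inplace updates by quietly creating new arrays. If you want to manage memory explicitly, Ivy lets you opt into strict inplace mode instead; this is optional and not needed for this demo, but a minimal sketch based on the suggestion in the warning would be:

[ ]:
# Optional: make Ivy raise an error whenever an inplace update is attempted
# with a backend (like JAX) that doesn't support it natively, instead of
# silently allocating new arrays.
ivy.set_inplace_mode('strict')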

After transpiling our model, we can see what the improvement in runtime efficiency looks like. To compare, let’s first compile the original PyTorch model using torch.compile

[ ]:
# ref : https://github.com/pytorch/pytorch/issues/107960
!export LC_ALL="en_US.UTF-8"
!export LD_LIBRARY_PATH="/usr/lib64-nvidia"
!export LIBRARY_PATH="/usr/local/cuda/lib64/stubs"
!ldconfig /usr/lib64-nvidia
[6]:
inputs = feature_extractor(
    images=image, return_tensors="pt"
).to("cuda")

def _f(**kwargs):
  return model(**kwargs)

comp_model = torch.compile(_f)
_ = comp_model(**inputs)  # warm-up call: the first invocation triggers compilation

Let’s now do the equivalent transformation in our new Haiku model by using JAX’s just-in-time compilation:

[7]:
inputs_jax = feature_extractor(
    images=image, return_tensors="jax"
)

import haiku as hk

def _forward(**kwargs):
  module = transpiled_graph()
  return module(**kwargs).last_hidden_state

rng_key = jax.random.PRNGKey(42)
jax_forward = hk.transform(_forward)  # wrap the forward pass as a pure Haiku function
params = jax_forward.init(rng=rng_key, **inputs_jax)  # initialise the module parameters
jit_apply = jax.jit(jax_forward.apply)  # JIT-compile the apply function

Now that we have both models optimized, let’s see how their runtime speeds compare to each other!
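
One thing to bear in mind when benchmarking: the first call to a jax.jit-compiled function triggers tracing and compilation, and JAX dispatches work asynchronously. The torch.compile model was already warmed up above, so as a rough equivalent you may also want to warm up and block on the jitted Haiku function before timing it (a minimal sketch reusing the objects defined above):

[ ]:
# Warm up the jitted function so compilation time isn't measured below, and
# block until the result is actually computed (JAX dispatch is asynchronous).
_ = jit_apply(params, None, **inputs_jax).block_until_ready()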

[8]:
%%timeit
_ = comp_model(**inputs)
6.63 ms ± 122 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
[9]:
%%timeit
out = jit_apply(params, None, **inputs_jax)
1.18 ms ± 134 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

As expected, we have made the model significantly faster with just one line of code: in this run, the execution time drops from roughly 6.6 ms to 1.2 ms per call, around a 5x speedup! 🚀

Finally, as a sanity check, let’s load a different image and make sure that the results are the same in both models.

[10]:
url = "http://images.cocodataset.org/train2017/000000283921.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = feature_extractor(
    images=image, return_tensors="pt"
).to("cuda")
inputs_jax = feature_extractor(
    images=image, return_tensors="jax"
)
out_torch = comp_model(**inputs)
out_jax = jit_apply(params, None, **inputs_jax)

np.allclose(out_torch.last_hidden_state.detach().cpu().numpy(), out_jax, atol=1e-4)
[10]:
True
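
The allclose check above only returns a boolean; if you’d like to quantify how close the two outputs actually are, here is a quick optional sketch reusing the arrays computed above:

[ ]:
# Largest absolute elementwise difference between the PyTorch and JAX outputs
np.abs(out_torch.last_hidden_state.detach().cpu().numpy() - np.asarray(out_jax)).max()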

That’s pretty much it! The results from both models are the same, and we have achieved a solid speedup by using Ivy’s transpiler to convert the model to JAX!