Transpile code

Convert a torch function to jax with just one line of code.

⚠️ If you are running this notebook in Colab, you will have to install Ivy and some dependencies manually. You can do so by running the cell below ⬇️

If you want to run the notebook locally but don’t have Ivy installed just yet, you can check out the Setting Up section of the docs.

!git clone https://github.com/unifyai/ivy.git
!cd ivy && git checkout d6bc18c64a47a135fe18404d9f83f98d9f3b63cf && python3 -m pip install --user -e .

For the installed packages to be available you will have to restart your kernel. In Colab, you can do this by clicking on “Runtime > Restart Runtime”. Once the runtime has been restarted you should skip the previous cell 😄

To use the compiler and the transpiler you will now need an API key. If you already have one, replace the placeholder string in the next cell with it.

API_KEY = "PASTE_YOUR_KEY_HERE"
!mkdir -p .ivy
!echo -n $API_KEY > .ivy/key.pem
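
The cell above relies on notebook shell escapes (the leading !). If you are running the code as a plain Python script instead, the same file can be written with the standard library. This is a minimal sketch that assumes, as the shell commands do, that the key is read from .ivy/key.pem relative to the current working directory:

```python
from pathlib import Path

API_KEY = "PASTE_YOUR_KEY_HERE"

# equivalent of `mkdir -p .ivy`
key_dir = Path(".ivy")
key_dir.mkdir(parents=True, exist_ok=True)

# equivalent of `echo -n $API_KEY > .ivy/key.pem` (no trailing newline)
key_file = key_dir / "key.pem"
key_file.write_text(API_KEY)
```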

Using what we learnt in the previous two notebooks on Unify and Compile, the workflow for converting directly from torch to jax has two steps: first unify the function to ivy code, then compile it to the jax backend:

import ivy
import torch
ivy.set_backend("jax")

def normalize(x):
    mean = torch.mean(x)
    std = torch.std(x)
    return torch.div(torch.sub(x, mean), std)

# convert the function to Ivy code
ivy_normalize = ivy.unify(normalize)

# compile the Ivy code into jax functions
jax_normalize = ivy.compile(ivy_normalize)

jax_normalize is the version of normalize compiled to jax, ready to be integrated into your wider jax project.

This workflow is common, so to avoid repeated calls to ivy.unify followed by ivy.compile, there is a convenience function, ivy.transpile, which acts as a shorthand for this pair of calls:

jax_normalize = ivy.transpile(normalize, source="torch", to="jax")

Again, jax_normalize is now a jax function, ready to be integrated into your jax project.

import jax

# enable 64-bit precision before creating any arrays
jax.config.update('jax_enable_x64', True)

key = jax.random.PRNGKey(42)
x = jax.random.uniform(key, shape=(10,))

print(jax_normalize(x))
[-0.93968587  0.26075466 -0.22723222 -1.06276492 -0.47426987  1.72835908
  1.71737559 -0.50411096 -0.65419174  0.15576624]
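
As a sanity check, the transpiled function can be compared against a hand-written jax version. The sketch below is not produced by Ivy; normalize_reference is a hypothetical name introduced here for illustration. One detail matters when writing it by hand: torch.std applies Bessel's correction by default, while jnp.std does not, so ddof=1 is needed to match the torch behaviour.

```python
import jax
import jax.numpy as jnp

def normalize_reference(x):
    # hand-written counterpart of the torch `normalize` function
    mean = jnp.mean(x)
    # torch.std is unbiased by default, so use ddof=1 to match it
    std = jnp.std(x, ddof=1)
    return (x - mean) / std

key = jax.random.PRNGKey(42)
x = jax.random.uniform(key, shape=(10,))
y = normalize_reference(x)

# a normalized vector has (approximately) zero mean and unit sample std
print(jnp.mean(y), jnp.std(y, ddof=1))

# with Ivy available, the two versions could then be compared like so:
# jnp.allclose(jax_normalize(x), normalize_reference(x))
```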

Round Up

That’s it, you can now transpile code from one framework to another with one line of code! However, there are still other important topics to master before you’re ready to unify ML code like a pro 🥷. In the next notebooks we’ll learn about the different ways that ivy.unify, ivy.compile and ivy.transpile can be called, and what implications each of these has!