Transpile code

Convert a torch function to jax with just one line of code.

⚠️ If you are running this notebook in Colab, you will have to install Ivy and some dependencies manually. You can do so by running the cell below ⬇️

If you want to run the notebook locally but don’t have Ivy installed just yet, you can check out the Get Started section of the docs.

!pip install ivy

Using what we learnt in the previous two notebooks on Unify and Trace, converting directly from torch to jax takes two steps: first unify the function to ivy code, then trace that code with the jax backend:

import ivy
import torch

# set jax as the backend that the Ivy code will be traced to
ivy.set_backend("jax")

def normalize(x):
    mean = torch.mean(x)
    std = torch.std(x)
    return torch.div(torch.sub(x, mean), std)

# convert the function to Ivy code
ivy_normalize = ivy.unify(normalize)

# trace the Ivy code into jax functions
jax_normalize = ivy.trace_graph(ivy_normalize)

jax_normalize is now a jax version of normalize, ready to be integrated into your wider jax project.

This workflow is common, so to avoid calling ivy.unify followed by ivy.trace_graph every time, Ivy provides the convenience function ivy.transpile, which acts as a shorthand for this pair of calls:

jax_normalize = ivy.transpile(normalize, source="torch", to="jax")

Again, jax_normalize is a jax function, ready to be integrated into your jax project.

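import jax

# enable 64-bit precision before creating any arrays
jax.config.update('jax_enable_x64', True)

key = jax.random.PRNGKey(42)
x = jax.random.uniform(key, shape=(10,))

# call the transpiled function on a jax array
print(jax_normalize(x))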

[-0.93968587  0.26075466 -0.22723222 -1.06276492 -0.47426987  1.72835908
  1.71737559 -0.50411096 -0.65419174  0.15576624]
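
As a quick sanity check on the values shown above, the sketch below re-implements the same normalization in plain NumPy and verifies that the printed output has zero mean and unit standard deviation. It assumes that torch.std applies Bessel's correction by default (dividing by n - 1), which is why ddof=1 is passed to np.std. The names normalize_reference and printed_output are illustrative only and are not part of Ivy.

```python
import numpy as np

def normalize_reference(x):
    """Plain NumPy stand-in for the torch `normalize` function (illustrative)."""
    mean = np.mean(x)
    # torch.std uses the unbiased estimator by default, so match it with ddof=1
    std = np.std(x, ddof=1)
    return (x - mean) / std

# Values copied from the output shown above
printed_output = np.array([
    -0.93968587, 0.26075466, -0.22723222, -1.06276492, -0.47426987,
    1.72835908, 1.71737559, -0.50411096, -0.65419174, 0.15576624,
])

# A correctly normalized vector has mean ~0 and unbiased std ~1
assert abs(printed_output.mean()) < 1e-6
assert abs(printed_output.std(ddof=1) - 1.0) < 1e-6

# Normalizing an already normalized vector should leave it unchanged
assert np.allclose(normalize_reference(printed_output), printed_output, atol=1e-6)
```

With ddof=0 (the default in NumPy and jax.numpy) the second assertion would fail, which is a common source of small mismatches when torch code is ported by hand rather than transpiled.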

Round Up

That’s it, you can now transpile code from one framework to another with one line of code! However, there are still other important topics to master before you’re ready to unify ML code like a pro 🥷. In the next notebooks we’ll learn about the different ways that ivy.unify, ivy.trace_graph and ivy.transpile can be called, and what implications each of these has!