Convert a torch function to jax with just one line of code.
⚠️ If you are running this notebook in Colab, you will have to install
Ivy and some dependencies manually. You can do so by running the cell below ⬇️
If you want to run the notebook locally but don’t have Ivy installed just yet, you can check out the Get Started section of the docs.
!pip install ivy
Using what we learnt in the previous two notebooks, Unify and Compile, the workflow for converting directly from torch to jax would be as follows: first unify to Ivy code, then compile to the jax backend:
import ivy
import torch

ivy.set_backend("jax")

def normalize(x):
    mean = torch.mean(x)
    std = torch.std(x)
    return torch.div(torch.sub(x, mean), std)

# convert the function to Ivy code
ivy_normalize = ivy.unify(normalize)

# compile the Ivy code into jax functions
jax_normalize = ivy.compile(ivy_normalize)
normalize is now compiled to jax, ready to be integrated into your wider JAX project.
This workflow is common, so to avoid repeated calls to
ivy.unify followed by
ivy.compile, there is a convenience function,
ivy.transpile, which acts as a shorthand for this pair of function calls:
jax_normalize = ivy.transpile(normalize, source="torch", to="jax")
normalize is now a jax function, ready to be integrated into your wider JAX project.
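Conceptually, the shorthand is just function composition of the two earlier steps. A toy sketch of that pattern, using hypothetical stand-in functions rather than Ivy's actual implementation:

```python
def unify_step(fn):
    # stand-in for ivy.unify: would return a framework-agnostic version of fn
    return fn

def compile_step(fn):
    # stand-in for ivy.compile: would return fn compiled to the target backend
    return fn

def transpile(fn):
    # the shorthand: unify first, then compile the result
    return compile_step(unify_step(fn))

double = transpile(lambda x: 2 * x)
print(double(21))  # → 42
```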
import jax

jax.config.update('jax_enable_x64', True)

key = jax.random.PRNGKey(42)
x = jax.random.uniform(key, shape=(10,))
print(jax_normalize(x))
[-0.93968587 0.26075466 -0.22723222 -1.06276492 -0.47426987 1.72835908 1.71737559 -0.50411096 -0.65419174 0.15576624]
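As a sanity check, the printed values should have roughly zero mean and unit sample standard deviation, since that is exactly what normalize computes. A quick pure-Python verification of the output above:

```python
import math

out = [-0.93968587, 0.26075466, -0.22723222, -1.06276492, -0.47426987,
       1.72835908, 1.71737559, -0.50411096, -0.65419174, 0.15576624]

mean = sum(out) / len(out)
std = math.sqrt(sum((v - mean) ** 2 for v in out) / (len(out) - 1))
print(round(mean, 6), round(std, 6))  # mean ≈ 0, std ≈ 1
```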
That’s it, you can now transpile code from one framework to another with one line of code! However, there are still other important topics to master before you’re ready to unify ML code like a pro 🥷. In the next notebooks, we’ll learn about the various ways
ivy.transpile can be called, and the implications of each!