Transpile any model

Transpile a Keras model into a PyTorch module.

⚠️ If you are running this notebook in Colab, you will have to install Ivy and some dependencies manually. You can do so by running the cell below ⬇️

If you want to run the notebook locally but don’t have Ivy installed just yet, you can check out the Get Started section of the docs.

[1]:
!pip install ivy
Requirement already satisfied: ivy in /workspaces/ivy (0.0.4.0)
Requirement already satisfied: numpy in /opt/fw/mxnet (from ivy) (1.26.1)
Requirement already satisfied: einops in /opt/miniconda/envs/multienv/lib/python3.10/site-packages (from ivy) (0.7.0)
Requirement already satisfied: psutil in /opt/miniconda/envs/multienv/lib/python3.10/site-packages (from ivy) (5.9.6)
Requirement already satisfied: termcolor in /opt/fw/tensorflow (from ivy) (2.3.0)
Requirement already satisfied: colorama in /opt/miniconda/envs/multienv/lib/python3.10/site-packages (from ivy) (0.4.6)
Requirement already satisfied: packaging in /opt/fw/tensorflow (from ivy) (23.2)
Requirement already satisfied: nvidia-ml-py in /opt/miniconda/envs/multienv/lib/python3.10/site-packages (from ivy) (12.535.108)
Requirement already satisfied: diskcache in /opt/miniconda/envs/multienv/lib/python3.10/site-packages (from ivy) (5.6.3)
Requirement already satisfied: google-auth in /opt/fw/tensorflow (from ivy) (2.23.3)
Requirement already satisfied: urllib3<2.0 in /root/.local/lib/python3.10/site-packages (from ivy) (1.26.18)
Requirement already satisfied: requests in /opt/fw/mxnet (from ivy) (2.31.0)
Requirement already satisfied: pyvis in /opt/miniconda/envs/multienv/lib/python3.10/site-packages (from ivy) (0.3.2)
Requirement already satisfied: dill in /opt/miniconda/envs/multienv/lib/python3.10/site-packages (from ivy) (0.3.7)
Requirement already satisfied: astunparse in /opt/fw/tensorflow (from ivy) (1.6.3)
Requirement already satisfied: ml-dtypes in /opt/fw/tensorflow (from ivy) (0.2.0)
Requirement already satisfied: cloudpickle in /opt/fw/tensorflow (from ivy) (3.0.0)
Requirement already satisfied: gast in /opt/fw/tensorflow (from ivy) (0.5.4)
Requirement already satisfied: tqdm in /opt/miniconda/envs/multienv/lib/python3.10/site-packages (from ivy) (4.66.1)
Requirement already satisfied: wheel<1.0,>=0.23.0 in /opt/fw/tensorflow (from astunparse->ivy) (0.41.3)
Requirement already satisfied: six<2.0,>=1.6.1 in /opt/fw/tensorflow (from astunparse->ivy) (1.16.0)
Requirement already satisfied: cachetools<6.0,>=2.0.0 in /opt/fw/tensorflow (from google-auth->ivy) (5.3.2)
Requirement already satisfied: pyasn1-modules>=0.2.1 in /opt/fw/tensorflow (from google-auth->ivy) (0.3.0)
Requirement already satisfied: rsa<5,>=3.1.4 in /opt/fw/tensorflow (from google-auth->ivy) (4.9)
Requirement already satisfied: ipython>=5.3.0 in /opt/miniconda/envs/multienv/lib/python3.10/site-packages (from pyvis->ivy) (8.17.1)
Requirement already satisfied: jinja2>=2.9.6 in /opt/fw/torch (from pyvis->ivy) (3.1.2)
Requirement already satisfied: jsonpickle>=1.4.1 in /opt/miniconda/envs/multienv/lib/python3.10/site-packages (from pyvis->ivy) (3.0.2)
Requirement already satisfied: networkx>=1.11 in /opt/fw/torch (from pyvis->ivy) (3.2.1)
Requirement already satisfied: charset-normalizer<4,>=2 in /opt/fw/mxnet (from requests->ivy) (3.3.1)
Requirement already satisfied: idna<4,>=2.5 in /opt/fw/mxnet (from requests->ivy) (3.4)
Requirement already satisfied: certifi>=2017.4.17 in /opt/fw/mxnet (from requests->ivy) (2023.7.22)
Requirement already satisfied: decorator in /opt/fw/tensorflow (from ipython>=5.3.0->pyvis->ivy) (5.1.1)
Requirement already satisfied: jedi>=0.16 in /opt/miniconda/envs/multienv/lib/python3.10/site-packages (from ipython>=5.3.0->pyvis->ivy) (0.19.1)
Requirement already satisfied: matplotlib-inline in /opt/miniconda/envs/multienv/lib/python3.10/site-packages (from ipython>=5.3.0->pyvis->ivy) (0.1.6)
Requirement already satisfied: prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30 in /opt/miniconda/envs/multienv/lib/python3.10/site-packages (from ipython>=5.3.0->pyvis->ivy) (3.0.39)
Requirement already satisfied: pygments>=2.4.0 in /opt/fw/jax (from ipython>=5.3.0->pyvis->ivy) (2.16.1)
Requirement already satisfied: stack-data in /opt/miniconda/envs/multienv/lib/python3.10/site-packages (from ipython>=5.3.0->pyvis->ivy) (0.6.3)
Requirement already satisfied: traitlets>=5 in /opt/miniconda/envs/multienv/lib/python3.10/site-packages (from ipython>=5.3.0->pyvis->ivy) (5.13.0)
Requirement already satisfied: exceptiongroup in /opt/fw/paddle (from ipython>=5.3.0->pyvis->ivy) (1.1.3)
Requirement already satisfied: pexpect>4.3 in /opt/miniconda/envs/multienv/lib/python3.10/site-packages (from ipython>=5.3.0->pyvis->ivy) (4.8.0)
Requirement already satisfied: MarkupSafe>=2.0 in /opt/fw/tensorflow (from jinja2>=2.9.6->pyvis->ivy) (2.1.3)
Requirement already satisfied: pyasn1<0.6.0,>=0.4.6 in /opt/fw/tensorflow (from pyasn1-modules>=0.2.1->google-auth->ivy) (0.5.0)
Requirement already satisfied: parso<0.9.0,>=0.8.3 in /opt/miniconda/envs/multienv/lib/python3.10/site-packages (from jedi>=0.16->ipython>=5.3.0->pyvis->ivy) (0.8.3)
Requirement already satisfied: ptyprocess>=0.5 in /opt/miniconda/envs/multienv/lib/python3.10/site-packages (from pexpect>4.3->ipython>=5.3.0->pyvis->ivy) (0.7.0)
Requirement already satisfied: wcwidth in /opt/miniconda/envs/multienv/lib/python3.10/site-packages (from prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30->ipython>=5.3.0->pyvis->ivy) (0.2.9)
Requirement already satisfied: executing>=1.2.0 in /opt/miniconda/envs/multienv/lib/python3.10/site-packages (from stack-data->ipython>=5.3.0->pyvis->ivy) (2.0.1)
Requirement already satisfied: asttokens>=2.1.0 in /opt/miniconda/envs/multienv/lib/python3.10/site-packages (from stack-data->ipython>=5.3.0->pyvis->ivy) (2.4.1)
Requirement already satisfied: pure-eval in /opt/miniconda/envs/multienv/lib/python3.10/site-packages (from stack-data->ipython>=5.3.0->pyvis->ivy) (0.2.2)
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv

As we’ve already seen, ivy.transpile can convert functions and whole libraries from one framework to another. However, in machine learning and deep learning, much of the focus is on trainable modules. Fortunately, Ivy can manage the parameters of these modules and ensure that the transpiled module is fully compatible with the target framework. This allows you to take full advantage of the training utilities provided by any framework and to build complex models on top of the transpiled ones. Let’s see how this works!

Let’s start by importing the necessary libraries:

[2]:
import ivy
import torch
import numpy as np
import tensorflow as tf
tf.keras.utils.set_random_seed(0)
torch.manual_seed(0)
2023-11-01 16:31:42.301436: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
[2]:
<torch._C.Generator at 0x7f252c392390>

The Guides section has examples with more involved models, but to keep things simple, let’s define a basic convolutional network with Keras’ Sequential API and use it as the starting point!

[3]:
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=(28, 28, 3)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10, activation='softmax')
])
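As a quick sanity check on the architecture, the number of trainable parameters can be worked out by hand from the layer definitions. The sketch below is plain Python arithmetic, not an Ivy or Keras call, and it assumes the Keras defaults for Conv2D (padding "valid", stride 1); the variable names are only illustrative.

```python
# Expected trainable-parameter count for the Sequential model above,
# derived by hand from the layer definitions (no framework needed).

height, width, channels = 28, 28, 3   # input_shape=(28, 28, 3)
filters, kernel = 32, 3               # Conv2D(32, kernel_size=(3, 3))
classes = 10                          # Dense(10)

# Conv2D: one (kernel x kernel x channels) weight block per filter,
# plus one bias per filter.
conv_params = kernel * kernel * channels * filters + filters

# Assumed Keras defaults: padding="valid" and stride 1, so each spatial
# dimension shrinks by kernel - 1.
out_h, out_w = height - kernel + 1, width - kernel + 1
flat_features = out_h * out_w * filters

# Dense: one weight per (input feature, class) pair, plus one bias per class.
dense_params = flat_features * classes + classes

total_params = conv_params + dense_params
print(conv_params, flat_features, dense_params, total_params)
# 896 21632 216330 217226
```

These numbers give you something concrete to compare against when inspecting a model before and after transpilation.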

Now, we can use ivy.transpile to convert this Keras model to PyTorch. Since we are passing a framework-specific object to the transpile function, there is no need to specify the source keyword argument this time; the call below still passes it, which simply makes the source framework explicit.

[4]:
input_array = tf.random.normal((1, 28, 28, 3))
torch_model = ivy.transpile(model, source="tensorflow", to="torch", args=(input_array,))
WARNING:root:To preserve the tracer and transpiler caches across multiple machines, ensure that the relative path of your projects from the .ivy folder is consistent across all machines. You can do this by adding .ivy to your home folder and placing all projects in the same place relative to the home folder on all machines.
/workspaces/ivy/ivy/utils/exceptions.py:390: UserWarning: The current backend: 'tensorflow' does not support inplace updates natively. Ivy would quietly create new arrays when using inplace updates with this backend, leading to memory overhead (same applies for views). If you want to control your memory management, consider doing ivy.set_inplace_mode('strict') which should raise an error whenever an inplace update is attempted with this backend.
  warnings.warn(
/workspaces/ivy/ivy/compiler/compiler.py:159: UserWarning: `Model.state_updates` will be removed in a future version. This property should not be used in TensorFlow 2.0, as `updates` are applied automatically.
  return _transpile(
/workspaces/ivy/ivy/compiler/compiler.py:159: UserWarning: `layer.updates` will be removed in a future version. This property should not be used in TensorFlow 2.0, as `updates` are applied automatically.
  return _transpile(

Thanks to (eager) transpilation, we now have a fully-fledged torch.nn.Module 🚀

[5]:
isinstance(torch_model, torch.nn.Module)
[5]:
True
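Beyond the type check, it can be useful to confirm that a converted model agrees numerically with the original on the same input. The snippet below is only a sketch of that idea: outputs_match is a hypothetical helper (not part of Ivy), and two NumPy softmax implementations stand in for the original and the transpiled model so that the example runs without any deep learning framework.

```python
import numpy as np


def outputs_match(reference_fn, candidate_fn, x, atol=1e-5):
    """Return True if both callables give numerically close outputs for x.

    Hypothetical helper: both callables are assumed to accept and return
    NumPy arrays, so framework tensors must be converted by the caller.
    """
    ref = np.asarray(reference_fn(x))
    out = np.asarray(candidate_fn(x))
    return ref.shape == out.shape and bool(np.allclose(ref, out, atol=atol))


# Stand-ins for the original and the transpiled model:
# two softmax implementations that should agree.
def softmax_a(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


def softmax_b(x):
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)


x = np.random.default_rng(0).normal(size=(1, 10))
print(outputs_match(softmax_a, softmax_b, x))  # True
```

In this notebook, the two callables would wrap model and torch_model and convert to and from NumPy on each side, for instance with .numpy() on the TensorFlow output and .detach().numpy() on the PyTorch output. That wiring is left out here because it depends on your setup.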

This means that we can pass PyTorch inputs (keeping the channels-last format of Keras, as the new computational graph is identical to the original one!) and get PyTorch tensors as the output:

[6]:
input_array = torch.rand((1, 28, 28, 3))
output_array = torch_model(input_array)
print(output_array)
tensor([[0.0893, 0.1504, 0.1372, 0.0991, 0.0867, 0.0851, 0.0911, 0.0804, 0.0926,
         0.0881]], grad_fn=<SoftmaxBackward0>)
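Because the transpiled model keeps the channels-last layout, data stored in PyTorch’s usual channels-first layout has to be rearranged before being passed in. Here is a minimal sketch of that conversion using NumPy; the array names are illustrative, and the same axis order works with a tensor’s permute method in PyTorch.

```python
import numpy as np

# A batch in the channels-first layout: (batch, channels, height, width).
nchw = np.zeros((1, 3, 28, 28), dtype=np.float32)

# Move the channel axis to the end to get the channels-last layout
# that the Keras model was defined with: (batch, height, width, channels).
nhwc = np.transpose(nchw, (0, 2, 3, 1))

print(nhwc.shape)  # (1, 28, 28, 3)
```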

Furthermore, having a torch.nn.Module lets you train the model with ordinary PyTorch training code and use the transpiled model as a building block for more complex torch models, as shown in the “Transpiling a haiku model to build on top” guide!
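To give a feel for what PyTorch training code looks like here, the following is a minimal sketch of a training loop. It makes several assumptions: stand_in_model is a hypothetical placeholder with the same input and output shapes as the transpiled model (so the snippet runs on its own), the data is random, and the optimizer settings are arbitrary. In the notebook you would try torch_model in its place, assuming its parameters are exposed through parameters() as for any torch.nn.Module.

```python
import torch

torch.manual_seed(0)

# Stand-in for the transpiled module: any torch.nn.Module that maps
# (batch, 28, 28, 3) inputs to 10 class probabilities could slot in here.
stand_in_model = torch.nn.Sequential(
    torch.nn.Flatten(),
    torch.nn.Linear(28 * 28 * 3, 10),
    torch.nn.Softmax(dim=-1),
)

# Random inputs and labels, purely for illustration.
inputs = torch.rand((16, 28, 28, 3))
labels = torch.randint(0, 10, (16,))

optimizer = torch.optim.SGD(stand_in_model.parameters(), lr=0.001)
loss_fn = torch.nn.NLLLoss()  # expects log-probabilities

losses = []
for step in range(20):
    optimizer.zero_grad()
    probs = stand_in_model(inputs)
    # The model already applies softmax, so take the log before NLLLoss.
    loss = loss_fn(torch.log(probs + 1e-8), labels)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(f"first loss: {losses[0]:.4f}, last loss: {losses[-1]:.4f}")
```

Since the model ends in a softmax, the loop takes the log of the probabilities and uses NLLLoss rather than CrossEntropyLoss, which would expect raw logits.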

Round up

This is the last tutorial related to the tracer/transpiler in Learn the Basics. In the next tutorial, we’ll go over an introduction to building models directly using Ivy 👨‍💻. If you are interested in continuing to learn about transpilation, you can check out the more complex tutorials in the Guides section!