
Unifying Framework Fragmentation with Ivy

Guillermo Sanchez-Brizuela & Ved Patwardhan
November 28, 2023
5 min read

The Framework Fragmentation Problem

Over the last decade, there has been substantial progress in machine learning. This progress has motivated developers to create numerous open-source tools, expediting the processes of researching, developing, and deploying machine learning applications. TensorFlow, PyTorch, and JAX have risen as leading choices for machine learning, with NumPy being widely recognized as the standard framework for numerical computation.

Users exhibit varied preferences depending on their specific use cases. For instance, numerous researchers favor PyTorch for its user-friendly interface, Pythonic design, and dynamic graph capabilities. On the other hand, some opt for TensorFlow, drawn to its efficiently compiled graphs, advanced deployment capabilities, and compatibility with edge and mobile devices. Meanwhile, others lean towards JAX, appreciating its fully functional form, exceptional runtime efficiency on TPUs, and unparalleled flexibility in gradient computation.

Every framework comes with its unique advantages and limitations, making it more suitable for specific use cases. Consequently, various sectors in both industry and academia employ different tools for different purposes, creating barriers to collaboration and hindering the democratization of knowledge. This, in turn, leads to costly re-implementations and suboptimal runtime efficiency during deployment.

In the graph below, we can see how PyTorch's popularity has grown significantly in recent years, while TensorFlow's has clearly declined. Even so, significant fragmentation remains, as other frameworks such as JAX and MindSpore have joined the arena.


What’s Ivy?

In this post, we will explore Ivy, a framework designed to address this issue by fostering interoperability among frameworks. Ivy facilitates the portability and reusability of code, unlocking access to framework-specific infrastructure and support for previously unavailable hardware. This advancement accelerates the development, training, and inference of machine learning models.

At a conceptual level, Ivy functions as a collection of APIs that empower the creation of framework-agnostic code. It acts as an intermediate representation bridging different frameworks, streamlining the conversion of code between them.

Its major components are the Ivy Framework and the Ivy Transpiler:

The Ivy Framework

The Ivy framework strives to offer an interface that ensures consistent behavior across the designated backends. It consists of three distinct APIs, along with a set of backend APIs, each of which implements the Ivy API using one of the supported frameworks.

Functional API

At its core, the framework is centered on a unified functional API, positioned as the lowest level of abstraction and built upon the functional APIs of the respective backends. This functional API is compatible with multiple backends, including NumPy, JAX, TensorFlow, PyTorch, and PaddlePaddle, each of which features a specialized implementation of the Ivy functional API. The functional API includes all functions outlined in the Array API Standard, covering widely used array-processing functions, along with additional functions not explicitly specified by the standard, such as activations, layers, and more.

Stateful API

Additionally, Ivy incorporates a stateful API that furnishes high-level modules and optimizers, facilitating swift prototyping. As this stateful API is constructed atop the functional API, it remains independent of specific backends.
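The relationship between the stateful and functional layers can be sketched in a few lines of plain Python. The `linear`/`Linear` names here are illustrative, not Ivy's API; the point is that the stateful wrapper owns parameters while delegating all math to a functional call, so it inherits the functional layer's backend independence:

```python
# Illustrative sketch: a stateful "layer" built purely on a functional op.
# Scalars stand in for arrays to keep the example self-contained.

def linear(x, w, b):
    """Functional op: y = w*x + b. Backend-agnostic in the real framework."""
    return w * x + b

class Linear:
    """Stateful wrapper: stores parameters, delegates compute to the functional API."""
    def __init__(self, w, b):
        self.w, self.b = w, b

    def __call__(self, x):
        return linear(x, self.w, self.b)

layer = Linear(w=2.0, b=1.0)
y = layer(3.0)   # 2*3 + 1 = 7.0
```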

Frontend API

Frontend APIs have been developed for PyTorch, TensorFlow, NumPy, JAX, and PaddlePaddle. The purpose of a frontend API is to emulate the functionality of each function in the respective framework's functional API, utilizing only functions from Ivy's functional API. Because Ivy's functional API is executable across any of the supported backends, these frontends enable the replication of the original behavior of a framework using any supported backend. In this manner, Ivy's functional API serves as an intermediate representation (IR) bridging different frameworks.

Ivy’s functional API selects the relevant parts of its unified backend implementation
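The frontend idea above can be sketched as follows. The unified-API helpers and the `torch_frontend_clamp` function are hypothetical stand-ins (operating on Python lists rather than arrays); the real frontends mirror each framework's actual signatures, but the principle is the same: the frontend function calls only unified-API operations, so it runs wherever the unified API runs:

```python
# Illustrative sketch: a "frontend" function mimicking another framework's
# API using only unified-API operations.

# Toy unified functional API, operating on Python lists.
def minimum(xs, v):
    return [min(x, v) for x in xs]

def maximum(xs, v):
    return [max(x, v) for x in xs]

# Frontend: reproduces clamp-style semantics via unified-API calls only,
# so it works on any backend the unified API supports.
def torch_frontend_clamp(xs, min=None, max=None):
    if min is not None:
        xs = maximum(xs, min)
    if max is not None:
        xs = minimum(xs, max)
    return xs

out = torch_frontend_clamp([-2, 0, 5], min=-1, max=3)   # [-1, 0, 3]
```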

The Ivy Transpiler

In addition to the framework itself, Ivy incorporates a Transpiler designed to handle arbitrary code. Code transpilation involves converting code from one framework to another or adapting it to a different version of the source framework. Examples of this process include migrating PyTorch code to JAX for enhanced speed, translating JAX models to PyTorch for seamless integration into existing pipelines, or updating code versions when encountering breaking changes.

At the core of the transpiler is the graph tracer, which captures a computational graph of any function or trainable module based on array computing. Internally, Ivy's tracer logs functional-API level calls made during a computation, including essential information such as unique identifiers for inputs and outputs. Once this data is collected, the tracer recursively reconstructs a directed acyclic graph (DAG), starting from the output of the top-level system and connecting inputs and outputs of each function contributing mathematically to the computation.
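The logging-then-reconstruction process described above can be illustrated with a toy tracer. Everything here (the `T` wrapper, the call log, the `backtrace` helper) is an invented sketch, not Ivy's internals: each traced op records the unique ids of its inputs and output, and the DAG is then rebuilt by walking backwards from the final output:

```python
# Illustrative sketch: log functional calls with input/output ids, then
# reconstruct the computational DAG backwards from the output.

_log = []  # entries: (op_name, input_ids, output_id)

class T:
    """Minimal traced value: wraps a number and assigns it a unique id."""
    _next = 0
    def __init__(self, v):
        self.v = v
        self.id = T._next
        T._next += 1

def traced(name, fn):
    """Wrap a primitive so every call is recorded in the log."""
    def wrapper(*args):
        out = T(fn(*[a.v for a in args]))
        _log.append((name, tuple(a.id for a in args), out.id))
        return out
    return wrapper

add = traced("add", lambda a, b: a + b)
mul = traced("mul", lambda a, b: a * b)

x, y = T(2.0), T(3.0)
z = add(mul(x, y), x)   # z = x*y + x = 8.0

# Rebuild the DAG recursively, starting from the output id and following
# each output back to the op that produced it.
producers = {out: (op, ins) for op, ins, out in _log}

def backtrace(out_id, graph=None):
    graph = graph if graph is not None else []
    if out_id in producers:
        op, ins = producers[out_id]
        graph.append((op, ins, out_id))
        for i in ins:
            backtrace(i, graph)
    return graph

dag = backtrace(z.id)   # only ops contributing mathematically to z
```

Side effects that do not feed into the final output never appear in `producers`, which is how this style of tracing naturally prunes computation that does not contribute to the result.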

By harnessing the capabilities of the graph tracer, combined with Ivy's frontend and backend APIs, Ivy performs code conversion in the transpiler, following the process outlined below.

The Transpiler for Research

As mentioned earlier, Ivy's transpiler serves various purposes. This section will delve into its primary value propositions.

Libraries and Tools' Interoperability

Due to the simultaneous existence of various frameworks, distinct ecosystems of libraries and tools have emerged around them. This implies that developers often find themselves confined to utilizing tools available within their specific ecosystem. Ivy offers assistance in two distinct ways in this scenario. Firstly, it can transpile an entire array-based library to a different framework, converting all functions within the library to the target framework for regular use. Secondly, Ivy enables users to harness training tools, pipelines, and overall infrastructure designed for models in a particular framework by seamlessly converting the model from the user's framework to the one required by the library.

Below is a code snippet exemplifying the typical workflow for library transpilation, where we transpile the Kornia library from PyTorch to TensorFlow.

import ivy
import cv2
import kornia
import tensorflow as tf

# transpile module lazily
kornia = ivy.transpile(kornia, to="tensorflow")

# load image into tensorflow
img = tf.constant(cv2.imread("image.png")) / 255
img = tf.expand_dims(tf.transpose(img, (2, 0, 1)), 0)

# custom composite
def dilate_edges(img):
    edges = kornia.filters.canny(img)[1]
    return kornia.morphology.dilation(edges, tf.ones((7, 7)))

# dilate edges
new_img = dilate_edges(img) # slow, transpiling
new_img = dilate_edges(img) # fast, transpiled

Integration of Trainable Modules

Trainable modules, such as specific layers or models, are typically released for a specific framework. This often leads to expensive reimplementations where nuances may be overlooked, hindering the reproducibility of results. In such instances, Ivy can automatically convert these trainable modules, expediting research and making the latest models accessible to every developer. Given that the result of Ivy's transpilation process comprises functions from the target framework's functional API, this implies that the output graph (in this case, encapsulated within a trainable module in the target framework) is entirely differentiable. Consequently, it can be trained or fine-tuned as if it were originally developed in the target framework.

Consider an example where a Haiku module undergoes transpilation to PyTorch and is subsequently employed as the core component for a classifier. This illustrates how users can seamlessly transfer any module or a portion thereof to their preferred framework and then extend upon it, resulting in a trainable module that replicates the computational behaviour of the original system.

import ivy
import jax
import torch

# Get a pretrained haiku model
from deepmind_perceiver_io import key, perceiver_backbone
# Transpile it into a torch.nn.Module with the corresponding parameters
dummy_input = jax.random.uniform(key, shape=(1, 3, 224, 224))
params = perceiver_backbone.init(rng=key, images=dummy_input)
backbone = ivy.transpile(
    perceiver_backbone, to="torch", params_v=params, kwargs={"images": dummy_input}
)

# Build a classifier using the transpiled backbone
class PerceiverIOClassifier(torch.nn.Module):
    def __init__(self, num_classes=20):
        super(PerceiverIOClassifier, self).__init__()
        self.backbone = backbone
        self.max_pool = torch.nn.MaxPool2d((512, 1))
        self.flatten = torch.nn.Flatten()
        self.fc = torch.nn.Linear(1024, num_classes)

    def forward(self, x):
        x = self.backbone(images=x)
        x = self.flatten(self.max_pool(x))
        return self.fc(x)

# Initialize a trainable, customizable, torch.nn.Module
classifier = PerceiverIOClassifier()
ret = classifier(torch.rand((1, 3, 224, 224)))

The Transpiler for Deployment

We have seen how the Transpiler can speed up research and development by enabling the use of existing code and tools originally designed for other frameworks. However, transpiling a trainable module on its own can also be beneficial in terms of computational resources. 

We can tweak a model to take advantage of optimizations that only work in certain compilers or to use hardware that might not be well-supported in the original framework.

We can see an example of this in the table below. The data in the table represents the relative speed-up obtained when transpiling a model to JAX, compared to running the original PyTorch model using torch.compile and its different modes. For the benchmark, we selected a subset of popular models from the Hugging Face Transformers library and averaged the latency measurements over 100 runs on an NVIDIA V100 GPU.

Looking at the results, we can observe that even after the latest improvements in PyTorch's compilation tooling, converting PyTorch models to JAX still brings a mean 1.27x improvement in execution time, which we mainly attribute to the optimizations that happen at the XLA level.
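The measurement methodology described above (averaging latency over repeated runs, after warming up so compilation and caching costs are excluded) can be sketched as follows. The `mean_latency` and `speedup` helpers are illustrative, not part of any benchmarking library:

```python
# Illustrative sketch: average wall-clock latency of a callable over N runs,
# discarding warm-up runs so one-time compilation costs are excluded.
import time

def mean_latency(fn, *args, runs=100, warmup=3):
    for _ in range(warmup):           # warm-up: trigger compilation / caches
        fn(*args)
    start = time.perf_counter()
    for _ in range(runs):
        fn(*args)
    return (time.perf_counter() - start) / runs

def speedup(baseline, optimized, *args, runs=100):
    """Relative speed-up of an optimized variant over a baseline (>1 is faster)."""
    return mean_latency(baseline, *args, runs=runs) / mean_latency(optimized, *args, runs=runs)

t = mean_latency(lambda: sum(range(1000)), runs=10)
```

In the real benchmark, `baseline` would be the torch.compile'd PyTorch model and `optimized` the transpiled, jit-compiled JAX version, with device synchronization added around each call so GPU work is actually measured.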

About the Author
Guillermo Sanchez-Brizuela & Ved Patwardhan
Unify | ML Engineers

Guillermo Sanchez-Brizuela

Led predictive analytics and AI research projects, and earned a Master’s with a focus on Deep Learning, Big Data, and Machine Learning from UVa. His work bridges Deep Learning research and AI deployment.

Ved Patwardhan

Focuses on computer vision and natural language understanding. Graduated from the Pune Institute of Computer Technology.
