Ivy Frontends#

Introduction#

On top of the Ivy functional API and backend functional APIs, Ivy has another set of framework-specific frontend functional APIs, which play an important role in code transpilations, as explained in the Ivy as a Transpiler section.

The Frontend Basics#

When using functions and methods of Ivy Frontends, in addition to importing ivy itself (import ivy), please also import the corresponding frontend module. For example, to use Ivy's TensorFlow frontend:

import ivy.functional.frontends.tensorflow as tf_frontend


When testing the frontend functions, we can sometimes call the function directly from the root frontend namespace. For example, we can call tensorflow.tan() rather than tensorflow.math.tan(). In this particular case both are fine; in fact, they are aliases.

However, sometimes an extra namespace path is necessary. Taking JAX as an example, the functions jax.numpy.abs() and jax.lax.abs() both exist, while jax.abs() does not. If we added both of these functions to the root namespace of our JAX frontend, one would arbitrarily overwrite the other. In fact, neither of these should be added to the root namespace, as jax.abs() does not exist in the native JAX framework.

If you accidentally test a function with fn_tree="<func_name>" instead of fn_tree="<lax|numpy>.<func_name>", you will see an error since the wrong frontend function is being tested.

Therefore, in order to avoid this potential conflict:

  • All frontend tests should use the full namespace path when calling the frontend function (see the example after this list). In the case of TensorFlow, this would mean writing fn_tree="math.tan" instead of fn_tree="tan" in the frontend test.

  • The __init__.py file in all frontends should be carefully checked, and you should verify that you are not adding aliases into the frontend which should not exist, such as the case of jax.abs() explained above.

  • You should ensure that the tests are passing before merging any frontend PRs. The only exception to this rule is if the test is failing due to a bug in the Ivy functional API, which does not need to be solved as part of the frontend task.
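
For instance, a frontend test for tf.math.tan() might look something like the following sketch, where the helper signatures are simplified and may differ from the current test suite:

# a simplified sketch of a frontend test; helper signatures are illustrative
import ivy_tests.test_ivy.helpers as helpers
from ivy_tests.test_ivy.helpers import handle_frontend_test


@handle_frontend_test(
    fn_tree="tensorflow.math.tan",  # the full namespace path, prefixed with the frontend
    dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")),
)
def test_tensorflow_tan(*, dtype_and_x, frontend, test_flags, fn_tree, on_device):
    input_dtype, x = dtype_and_x
    helpers.test_frontend_function(
        input_dtypes=input_dtype,
        frontend=frontend,
        test_flags=test_flags,
        fn_tree=fn_tree,
        on_device=on_device,
        x=x[0],
    )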

There will be some implicit discussion of the locations of frontend functions in these examples; however, an explicit explanation of how to place a frontend function can be found in a sub-section of the Frontend APIs open task.

NOTE: Type hints, docstrings, and examples are not required when working on frontend functions.

Frontend Arrays

The native arrays of each framework have their own attributes and instance methods which differ from the attributes and instance methods of ivy.Array. As such, we have implemented framework-specific array classes: tf_frontend.Tensor, torch_frontend.Tensor, numpy_frontend.ndarray, and jax_frontend.DeviceArray. These classes simply wrap an ivy.Array, which is stored in the ivy_array attribute, and behave as closely as possible to the native framework array classes. This is explained further in the Classes and Instance Methods section.

As we aim to replicate the frontend frameworks as closely as possible, all functions accept their frontend array class (as well as ivy.Array and ivy.NativeArray) and return a frontend array. However, since most logic in each function is handled by Ivy, the ivy.Array must be extracted from any frontend array inputs. Therefore we add the wrapper @to_ivy_arrays_and_back to virtually all functions in the frontends.
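
Conceptually, the wrapper behaves something like the following simplified sketch; the real implementations in the frontend func_wrapper.py files handle nested argument structures, ivy.NativeArray inputs, and more, and the _to_frontend_array() helper here is purely hypothetical:

# a simplified conceptual sketch of @to_ivy_arrays_and_back;
# the real wrappers handle nested arguments, native arrays, and more
import functools


def to_ivy_arrays_and_back(fn):
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        # extract the wrapped ivy.Array from any frontend array inputs
        args = [a.ivy_array if hasattr(a, "ivy_array") else a for a in args]
        kwargs = {
            k: (v.ivy_array if hasattr(v, "ivy_array") else v)
            for k, v in kwargs.items()
        }
        # run the Ivy-backed logic, then wrap the result in a frontend array
        return _to_frontend_array(fn(*args, **kwargs))  # hypothetical helper
    return wrapper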

There are more framework-specific classes we support in the frontends, such as the NumPy and TensorFlow Dtype classes, NumPy and JAX scalars, the NumPy matrix class, etc. All of these increase the fidelity of our frontends.

Writing Frontend Functions#

JAX

JAX has two distinct groups of functions, those in the jax.lax namespace and those in the jax.numpy namespace. The former set of functions map very closely to the API for the Accelerated Linear Algebra (XLA) compiler, which is used under the hood to run high performance JAX code. The latter set of functions map very closely to NumPy’s well known API. In general, all functions in the jax.numpy namespace are themselves implemented as a composition of the lower-level functions in the jax.lax namespace.
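
This decomposition can be seen directly in JAX itself; tracing a jax.numpy function prints the jax.lax-level primitives it lowers to (an illustrative snippet, requiring a JAX installation):

# tracing jnp.logaddexp reveals the lax-level primitives it is composed of
import jax
import jax.numpy as jnp

print(jax.make_jaxpr(jnp.logaddexp)(1.0, 2.0))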

When transpiling between frameworks, the first step is to trace a computation graph of low-level Python functions for the source framework using Ivy's tracer, before then replacing these nodes with the associated functions in Ivy's frontend API. Given that all JAX code can be decomposed into jax.lax function calls, when transpiling JAX code it should always be possible to express the computation graph as a composition of only jax.lax functions. Therefore, arguably these are the only functions we should need to implement in the JAX frontend. However, in general we wish to be able to trace a graph in the backend framework with varying levels of dynamism. A graph of only jax.lax functions chained together is in general more static and less dynamic than a graph which chains jax.numpy functions together. We wish to enable varying extents of dynamism when creating a graph with our tracer, and therefore we also implement the functions in the jax.numpy namespace in our frontend API for JAX.

Thus, both lax and numpy modules are created in the JAX frontend API. We start with the function lax.add() as an example.

# in ivy/functional/frontends/jax/lax/operators.py
@to_ivy_arrays_and_back
def add(x, y):
    return ivy.add(x, y)

lax.add() is categorised under operators, as shown in the jax.lax package directory. We organize the functions using the same categorizations as the original framework, and also mimic its importing behaviour with respect to modules, namespaces, etc.

The function arguments must be identical to those of the original function in JAX. In this case, jax.lax.add() has two arguments, and so we have the same two arguments in our JAX frontend lax.add(). The function then simply returns ivy.add(), which in turn links to the backend-specific implementation of ivy.add() according to the framework set in the backend.
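
As a quick illustration, the frontend function can then be called directly with ivy.Array inputs; the backend choice below is arbitrary:

# illustrative usage of the JAX frontend; any supported backend works
import ivy
import ivy.functional.frontends.jax as jax_frontend

ivy.set_backend("numpy")
x, y = ivy.array([1.0, 2.0]), ivy.array([3.0, 4.0])
out = jax_frontend.lax.add(x, y)  # a JAX frontend array wrapping an ivy.Array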

# in ivy/functional/frontends/jax/lax/operators.py
@to_ivy_arrays_and_back
def tan(x):
    return ivy.tan(x)

Using lax.tan() as a second example, we can see that this is placed under operators, again in the jax.lax directory. By referring to the jax.lax.tan documentation, we can see that it has only one argument. In the same manner as our add() function, we simply return ivy.tan(), and again the computation depends on the backend framework.

NumPy

# in ivy/functional/frontends/numpy/mathematical_functions/arithmetic_operations.py
@handle_numpy_out
@handle_numpy_dtype
@to_ivy_arrays_and_back
@handle_numpy_casting
@from_zero_dim_arrays_to_scalar
def _add(
    x1,
    x2,
    /,
    out=None,
    *,
    where=True,
    casting="same_kind",
    order="k",
    dtype=None,
    subok=True,
):
    x1, x2 = promote_types_of_numpy_inputs(x1, x2)
    ret = ivy.add(x1, x2, out=out)
    if ivy.is_array(where):
        ret = ivy.where(where, ret, ivy.default(out, ivy.zeros_like(ret)), out=out)
    return ret

In NumPy, add() is categorised under mathematical_functions with a sub-category of arithmetic_operations, as shown in the numpy mathematical functions directory. It is important to note that add() is a universal function (ufunc) in NumPy, and thus the function is actually an object with instance methods such as .at and .reduce. We deal with this in the NumPy frontend by including a ufunc class and initialising it in the __init__ file:

# in ivy/functional/frontends/numpy/__init__.py
from ivy.functional.frontends.numpy.mathematical_functions.arithmetic_operations import _add
add = ufunc("_add")

As shown, we import the above function _add() and use it to initialise the ufunc object which corresponds to the NumPy add() function. Practically, the add() object calls _add() under the hood, but it has all the extra instance methods of the ufunc class. All other functions which are ufunc objects in NumPy are implemented in the same way. Of course, if the ufunc object and its respective function had the same name, one would overwrite the other; to prevent this, we make the actual function private by prepending an underscore to its name. Since only the ufunc object should be accessible to the user, this approach is sufficient. When adding new NumPy functions which are ufuncs, it's important to implement them in this way in order to properly replicate their functionality. Namely, a private function needs to be created in the respective sub-category, this function needs to be imported in the __init__ file, and a ufunc object needs to be created that shares the name of the function. Functions which are not ufuncs are named normally, without the underscore, and are implemented as any other function.
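
The delegation idea behind the ufunc class can be sketched as follows; the actual class also implements .at(), .reduce(), and the other ufunc attributes and methods:

# a minimal sketch of the frontend ufunc delegation; the real class
# additionally exposes the full set of ufunc attributes and methods
import ivy.functional.frontends.numpy as np_frontend


class ufunc:
    def __init__(self, name):
        self._name = name  # the private function's name, e.g. "_add"

    def __call__(self, *args, **kwargs):
        # look up the private frontend function and delegate the call to it
        return getattr(np_frontend, self._name)(*args, **kwargs)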

The function arguments for this function are slightly more complex due to the extra optional arguments. Additional handling code is added to recover the behaviour described in the numpy.add documentation. For example, @handle_numpy_out is added to functions with an out argument, and it handles the inplace update of the ivy.Array specified by out, or of the ivy.Array wrapped by a frontend ndarray. This wrapper was added because out can be either a positional or a keyword argument in most functions, which requires some additional logic to handle properly. Additionally, casting and dtype are handled in the @handle_numpy_casting wrapper, which casts the input arguments to the desired dtype as specified by dtype and the chosen casting rules. There's an additional wrapper for the dtype argument, @handle_numpy_dtype. This wrapper is included to handle the various formats of the dtype argument which NumPy accepts, such as type strings, numpy.dtype objects, characters, etc. In NumPy, most functions which can return a scalar value return it as a NumPy scalar. To replicate this, we add the wrapper @from_zero_dim_arrays_to_scalar, which converts outputs that would normally be 0-dim arrays from Ivy functions into NumPy scalars. Of course, the returned scalar object is actually an Ivy frontend equivalent object which behaves very similarly to the frontend ndarray. Finally, order is handled in the @to_ivy_arrays_and_back decorator. The returned result is then obtained through ivy.add(), just like the other examples.

The argument subok, however, is left entirely unhandled, because it controls whether subclasses of numpy.ndarray are permitted as inputs to the function. All Ivy functions permit subclasses of ivy.Array by default, and the frontend function operates on ivy.Array instances rather than numpy.ndarray instances, so we omit this argument. It has no bearing on input-output behaviour, and so omitting it is not a problem when transpiling between frameworks.

See the section “Unused Arguments” below for more details.
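
Putting this together, the frontend add() ufunc can be called like any other NumPy function. A minimal illustration, assuming the NumPy backend is set:

# illustrative usage of the frontend ufunc
import ivy
import ivy.functional.frontends.numpy as np_frontend

ivy.set_backend("numpy")
x = np_frontend.array([1.0, 2.0, 3.0])
y = np_frontend.array([4.0, 5.0, 6.0])
print(np_frontend.add(x, y, dtype="float64"))  # calls _add() under the hood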

# in ivy/functional/frontends/numpy/mathematical_functions/trigonometric_functions.py
@handle_numpy_out
@handle_numpy_dtype
@to_ivy_arrays_and_back
@handle_numpy_casting
@from_zero_dim_arrays_to_scalar
def _tan(
    x,
    /,
    out=None,
    *,
    where=True,
    casting="same_kind",
    order="K",
    dtype=None,
    subok=True,
):
    ret = ivy.tan(x, out=out)
    if ivy.is_array(where):
        ret = ivy.where(where, ret, ivy.default(out, ivy.zeros_like(ret)), out=out)
    return ret

For the second example, tan() has a sub-category of trigonometric_functions according to the numpy mathematical functions directory. By referring to the numpy.tan documentation, we can see that it has the same additional arguments as the add() function and is also a ufunc. In the same manner as add(), we handle the arguments out, where, dtype, casting, and order, but we omit support for subok.

TensorFlow

# in ivy/functional/frontends/tensorflow/math.py
@to_ivy_arrays_and_back
def add(x, y, name=None):
    x, y = check_tensorflow_casting(x, y)
    return ivy.add(x, y)

The add() function is categorised under the math folder in the TensorFlow frontend. There are three arguments according to the tf.math.add documentation, which are written accordingly as shown above. Just like the previous examples, the implementation wraps ivy.add(), which itself defers to backend-specific functions depending on which framework is set in Ivy’s backend.

The arguments x and y are both used in the implementation, but the argument name is not used. Similar to the omitted argument in the NumPy example above, the name argument does not change the input-output behaviour of the function. Rather, this argument is added purely for the purpose of operation logging and retrieval, and also graph visualization in TensorFlow. Ivy does not support the unique naming of individual operations, and so we omit support for this particular argument.

Additionally, TensorFlow only allows explicit casting; therefore, there are no promotion rules in the TensorFlow frontend, except in the case of array-like or scalar inputs, which get cast to the dtype of the other argument if it's a Tensor, or to the default dtype if both arguments are array-like or scalar. The function check_tensorflow_casting() is added to functions with multiple arguments, such as add(), and it ensures the second argument is of the same type as the first, just as TensorFlow does.
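
A minimal illustration of this casting behaviour, assuming the TensorFlow backend is set:

# a scalar second argument is cast to the Tensor's dtype, not promoted
import ivy
import ivy.functional.frontends.tensorflow as tf_frontend

ivy.set_backend("tensorflow")
x = tf_frontend.constant([1, 2], dtype="int32")
out = tf_frontend.math.add(x, 3)  # 3 is cast to int32 to match x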

# in ivy/functional/frontends/tensorflow/math.py
@to_ivy_arrays_and_back
def tan(x, name=None):
    return ivy.tan(x)

Likewise, tan is also placed under math. By referring to the tf.math.tan documentation, we add the same arguments, and simply wrap ivy.tan() in this case. Again, we do not support the name argument for the reasons outlined above.

NOTE

Many of the functions in the tf.raw_ops module have identical behaviour to functions in the general TensorFlow namespace, e.g. tf.argmax(). However, these functions are specified to have keyword-only arguments, and in some cases they have different argument names. In order to tackle these variations in behaviour, the map_raw_ops_alias decorator was designed to wrap the functions that already exist in the TensorFlow namespace, thus reducing unnecessary re-implementations.

# in ivy/functional/frontends/tensorflow/math.py
@to_ivy_arrays_and_back
def argmax(input, axis, output_type=None, name=None):
    if output_type in ["uint16", "int16", "int32", "int64"]:
        return ivy.astype(ivy.argmax(input, axis=axis), output_type)
    else:
        return ivy.astype(ivy.argmax(input, axis=axis), "int64")

This function argmax() is implemented in the tf.math module of the TensorFlow framework, and an identical function exists in the tf.raw_ops module as ArgMax(). The two functions have identical behaviour, except that all arguments are passed as keyword-only for tf.raw_ops.ArgMax(). In some corner cases, arguments are also renamed: in tf.raw_ops.ArgMax(), the dimension argument replaces the axis argument of tf.math.argmax(). Let's see how the map_raw_ops_alias decorator can be used to tackle these variations.

# in ivy/functional/frontends/tensorflow/raw_ops.py
ArgMax = to_ivy_arrays_and_back(
    map_raw_ops_alias(
        tf_frontend.math.argmax,
        kwargs_to_update={"dimension": "axis"},
    )
)

The decorator map_raw_ops_alias here takes the existing behaviour of tf_frontend.math.argmax() as its first parameter and changes all its arguments to keyword-only. The argument kwargs_to_update is a dictionary indicating all updates to be made to argument names; in the case of tf.raw_ops.ArgMax(), dimension replaces axis. The map_raw_ops_alias wrapper itself is implemented in the Ivy codebase.
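
Illustratively, the resulting alias is then called with keyword-only arguments, using the renamed parameter:

# illustrative usage of the raw-ops alias; all arguments are keyword-only
import ivy
import ivy.functional.frontends.tensorflow as tf_frontend

ivy.set_backend("tensorflow")
x = tf_frontend.constant([[1.0, 5.0], [3.0, 2.0]])
out = tf_frontend.raw_ops.ArgMax(input=x, dimension=1)  # dimension maps to axis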

PyTorch

# in ivy/functional/frontends/torch/pointwise_ops.py
@to_ivy_arrays_and_back
def add(input, other, *, alpha=1, out=None):
    return ivy.add(input, other, alpha=alpha, out=out)

For PyTorch, add() is categorised under pointwise_ops as is the case in the torch framework.

In this case, the native torch.add() has both positional and keyword arguments, and we therefore use the same arguments for our PyTorch frontend add(). We wrap ivy.add() as usual.

# in ivy/functional/frontends/torch/pointwise_ops.py
@to_ivy_arrays_and_back
def tan(input, *, out=None):
    return ivy.tan(input, out=out)

tan() is also placed under pointwise_ops, as is the case in the torch framework. Looking at the torch.tan documentation, we can mimic the same arguments, and again simply wrap ivy.tan(), also making use of the out argument in this case.

Short Frontend Implementations#

Ideally, all frontend functions should call the equivalent Ivy function and only be one line long. This is mainly because compositional implementations are bound to be slower than direct backend implementation calls.

In case a frontend function is complex and there is no equivalent Ivy function to use, it is strongly advised to add that function to our Experimental API. To do so, you are invited to open a Missing Function Suggestion issue as described in the Open Tasks section. A member of our team will then review your issue, and if the proposed addition is deemed to be timely and sensible, we will add the function to the “Extend Ivy Functional API” ToDo list issue.

If you would rather not wait around for a member of our team to review your suggestion, you can instead go straight ahead and add the frontend function as a heavy composition of the existing Ivy functions, with a #ToDo comment included, explaining that this frontend implementation will be simplified when ivy.func_name() is added.

Examples

The native TensorFlow function tf.reduce_logsumexp() does not have an equivalent function in Ivy, therefore it can be composed of multiple Ivy functions instead.

TensorFlow Frontend

# ivy/functional/frontends/tensorflow/math.py
@to_ivy_arrays_and_back
def reduce_logsumexp(input_tensor, axis=None, keepdims=False, name="reduce_logsumexp"):
    # stable logsumexp trick: subtract the max before exponentiating
    max_input_tensor = ivy.max(input_tensor, axis=axis, keepdims=True)
    ret = ivy.log(
        ivy.sum(
            ivy.exp(input_tensor - max_input_tensor),
            axis=axis,
            keepdims=True,
        )
    ) + max_input_tensor
    if not keepdims:
        # drop the reduced dimension(s) to match tf.reduce_logsumexp's shape
        ret = ivy.squeeze(ret, axis=axis)
    return ret.astype(input_tensor.dtype)

Through compositions, we can easily meet the required input-output behaviour for the TensorFlow frontend function.

The entire workflow for extending the Ivy Frontends as an external contributor is explained in more detail in the Open Tasks section.

Unused Arguments#

As can be seen from the examples above, there are often cases where we do not add support for particular arguments in the frontend function. Generally, we can omit support for a particular argument only if the argument does not fundamentally affect the input-output behaviour of the function in a mathematical sense. The only two exceptions to this rule are arguments related to either the data type or the device on which the returned array(s) should reside. Examples of arguments which can be omitted, on account that they do not change the mathematics of the function, are arguments which relate to:

  • the algorithm or approximations used under the hood, such as precision and preferred_element_type in jax.lax.conv_general_dilated.

  • the specific array class in the original framework, such as subok in numpy.add.

  • the labelling of functions for organizational purposes, such as name in tf.math.add.

There are likely to be many other examples of arguments which do not fundamentally affect the input-output behaviour of the function in a mathematical sense, and so can also be omitted from Ivy’s frontend implementation.

The reason we omit these arguments in Ivy is that Ivy is not designed to provide low-level control of functions beyond the pure mathematics of the function. This is a requirement, because Ivy abstracts the backend framework, and therefore also abstracts everything below the backend framework's functional API, including the backend array class, the low-level language compiled to, the device, etc. Most ML frameworks do not offer per-array control of the memory layout, nor control over the finer details of the algorithmic approximations under the hood, and so we cannot in general offer this level of control at the Ivy API level, nor at the frontend API level as a direct result. As explained above, this is not a problem, as the memory layout has no bearing at all on the input-output behaviour of the function. The algorithmic approximation may have a marginal bearing on the final results in some cases, but Ivy is only designed to unify to within a reasonable numeric approximation in any case, and so omitting these arguments also very much fits within Ivy's design.

Supported Data Types and Devices#

Sometimes, the corresponding function in the original framework might only support a subset of data types. For example, tf.math.logical_and() only supports inputs of type tf.bool. However, Ivy’s implementation is as follows, with direct wrapping around ivy.logical_and():

@to_ivy_arrays_and_back
def logical_and(x, y, name="LogicalAnd"):
    return ivy.logical_and(x, y)

ivy.logical_and() supports all data types, and so ivy.functional.frontends.tensorflow.math.logical_and() can also easily support all data types. However, the primary purpose of these frontend functions is for code transpilations, and in such cases it would never be useful to support extra data types beyond tf.bool, as the TensorFlow code being transpiled would not support this. Additionally, the unit tests for all frontend functions use the original framework function as the ground truth, and so we can only test ivy.functional.frontends.tensorflow.math.logical_and() with boolean inputs anyway.

For these reasons, all frontend functions which correspond to functions with limited data type support in the native framework (in other words, which have even more restrictions than the data type limitations of the framework itself) should be flagged as such in a manner like the following:

@with_unsupported_dtypes({"2.13.0 and below": ("float16", "bfloat16")}, "tensorflow")
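
Applied to the tan() function shown earlier, this might look like the following sketch; the version string and the flagged dtypes are illustrative:

# a sketch of flagging limited dtype support on a frontend function
from ivy.func_wrapper import with_unsupported_dtypes


@with_unsupported_dtypes({"2.13.0 and below": ("float16", "bfloat16")}, "tensorflow")
@to_ivy_arrays_and_back
def tan(x, name=None):
    return ivy.tan(x)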

The same logic applies to unsupported devices. Even if the wrapped Ivy function supports more devices, we should still flag the frontend function's supported devices as being the same as those supported by the function in the native framework. Again, this is only needed if the limitations go beyond those of the framework itself. For example, it is not necessary to uniquely flag every single NumPy function as supporting only CPU, as this is a limitation of the entire framework which is already globally flagged.

It could also be the case that a frontend function supports a data type, but one or more of the backend frameworks do not, and therefore the frontend function may not support the data type due to a backend limitation. For example, the frontend function jax.lax.cumprod() supports all data types, but the PyTorch backend does not support bfloat16 for cumprod(), even though the framework generally supports the bfloat16 data type. In that case, we should flag that the torch backend implementation of cumprod() does not support bfloat16.
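
Such a backend-level flag might look something like the following sketch; the file location, version string, and signature are illustrative:

# in ivy/functional/backends/torch/statistical.py (illustrative sketch)
from ivy.func_wrapper import with_unsupported_dtypes
from . import backend_version


@with_unsupported_dtypes({"2.0.1 and below": ("bfloat16",)}, backend_version)
def cumprod(x, /, *, axis=0, exclusive=False, reverse=False, dtype=None, out=None):
    ...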

Classes and Instance Methods#

Most frameworks include instance methods and special methods on their array class for common array processing functions, such as reshape(), expand_dims() and add(). This simple design choice comes with many advantages, some of which are explained in our Ivy Array section.

Important Note: Before implementing the instance method or special method, make sure that the regular function in the specific frontend is already implemented.

In order to implement Ivy's frontend APIs to the extent that is required for arbitrary code transpilations, it's necessary for us to also implement these instance methods and special methods of the framework-specific array classes (tf.Tensor, torch.Tensor, numpy.ndarray, jax.DeviceArray, etc.).

Instance Method

numpy.ndarray

For an example of how these are implemented, we first show the instance method for np.ndarray.argsort(), which is implemented in the frontend ndarray class:

# ivy/functional/frontends/numpy/ndarray/ndarray.py
def argsort(self, *, axis=-1, kind=None, order=None):
    return np_frontend.argsort(self._ivy_array, axis=axis, kind=kind, order=order)

Under the hood, this simply calls the frontend np_frontend.argsort() function, which itself is implemented as follows:

# ivy/functional/frontends/numpy/sorting_searching_counting/sorting.py
@to_ivy_arrays_and_back
def argsort(
    x,
    /,
    *,
    axis=-1,
    kind=None,
    order=None,
):
    return ivy.argsort(x, axis=axis)
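
A short illustration of the instance method in action, assuming the NumPy backend is set:

# the instance method delegates to the frontend argsort() function
import ivy
import ivy.functional.frontends.numpy as np_frontend

ivy.set_backend("numpy")
a = np_frontend.array([3, 1, 2])
print(a.argsort())  # same result as np_frontend.argsort(a)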

Special Method

Some examples of the special methods should make things clearer. For example, let's take a look at how tf_frontend.Tensor.__add__() is implemented, along with its reverse, tf_frontend.Tensor.__radd__().

# ivy/functional/frontends/tensorflow/tensor.py
def __radd__(self, x, name="radd"):
    return tf_frontend.math.add(x, self._ivy_array, name=name)

def __add__(self, y, name="add"):
    return self.__radd__(y)

Here too, both methods simply call the frontend tf_frontend.math.add() under the hood. Functions with reverse operators should call the same frontend function, as shown in the example above. The implementation of tf_frontend.math.add() is as follows:

# ivy/functional/frontends/tensorflow/math.py
@to_ivy_arrays_and_back
def add(x, y, name=None):
    return ivy.add(x, y)
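
With these special methods in place, ordinary Python operators work on frontend tensors. A small illustration, assuming the TensorFlow backend is set:

# Python operators on frontend tensors route through the special methods
import ivy
import ivy.functional.frontends.tensorflow as tf_frontend

ivy.set_backend("tensorflow")
t = tf_frontend.constant([1.0, 2.0])
print(t + 1.0)  # Tensor.__add__ -> tf_frontend.math.add
print(1.0 + t)  # Tensor.__radd__ -> tf_frontend.math.add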

numpy.matrix

To support special classes and their instance methods, the equivalent classes are created in their respective frontends so that their useful instance methods are supported for transpilation.

For instance, the numpy.matrix class is supported in the Ivy NumPy frontend. Part of the code is shown below as an example:

# ivy/functional/frontends/numpy/matrix/methods.py
class matrix:
    def __init__(self, data, dtype=None, copy=True):
        self._init_data(data, dtype)

    def _init_data(self, data, dtype):
        if isinstance(data, str):
            self._process_str_data(data, dtype)
        elif isinstance(data, (list, ndarray)) or ivy.is_array(data):
            if isinstance(data, ndarray):
                data = data.ivy_array
            if ivy.is_array(data) and dtype is None:
                dtype = data.dtype
            data = ivy.array(data, dtype=dtype)
            self._data = data
        else:
            raise ivy.exceptions.IvyException("data must be an array, list, or str")
        ivy.assertions.check_equal(
            len(ivy.shape(self._data)), 2, message="data must be 2D"
        )
        self._dtype = self._data.dtype
        self._shape = ivy.shape(self._data)

With this class available, the supported instance methods can now be included in the class. For example, numpy.matrix has an any() instance method:

# ivy/functional/frontends/numpy/matrix/methods.py
from ivy.functional.frontends.numpy import any
...
def any(self, axis=None, out=None):
    if ivy.exists(axis):
        return any(self.A, axis=axis, keepdims=True, out=out)
    return any(self.A, axis=axis, out=out)
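
A brief illustration of the frontend matrix class, assuming it is exposed at the frontend root and the NumPy backend is set:

# illustrative usage of the frontend numpy.matrix equivalent
import ivy
import ivy.functional.frontends.numpy as np_frontend

ivy.set_backend("numpy")
m = np_frontend.matrix([[0, 1], [0, 0]])
print(m.any(axis=0))  # calls the frontend any() on the underlying data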

We need to create these frontend array classes and all of their instance methods and also their special methods such that we are able to transpile code which makes use of these methods. As explained in Ivy as a Transpiler, when transpiling code we first extract the computation graph in the source framework. In the case of instance methods, we then replace each of the original instance methods in the extracted computation graph with these new instance methods defined in the Ivy frontend class.

Frontend Data Type Promotion Rules#

Each frontend framework has its own rules governing the common result type for two array operands during an arithmetic operation.

In order to ensure that each frontend framework implemented in Ivy has the same data type promotion behaviour as the native framework, we have implemented data type promotion rules according to the framework-specific data type promotion tables of the frameworks we currently support as frontends. The function can be accessed by calling promote_types_of_<frontend>_inputs() with both array operands.

# ivy/functional/frontends/torch/pointwise_ops.py
@to_ivy_arrays_and_back
def add(input, other, *, alpha=1, out=None):
    input, other = torch_frontend.promote_types_of_torch_inputs(input, other)
    return ivy.add(input, other, alpha=alpha, out=out)

Although in most cases the array operands passed into an arithmetic function should already be of the same data type, applying the data type promotion rules adds a layer of sanity checking which prevents precision loss or exceptions in further arithmetic operations.
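
A small, illustrative sketch of the promotion behaviour, assuming the torch backend is set:

# int32 is promoted to float32, mirroring torch's promotion table
import ivy
import ivy.functional.frontends.torch as torch_frontend

ivy.set_backend("torch")
a = torch_frontend.tensor([1, 2], dtype="int32")
b = torch_frontend.tensor([1.0, 2.0], dtype="float32")
print(torch_frontend.add(a, b).dtype)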

TensorFlow is a framework where casting is completely explicit, except for array-likes and scalars. As such, there are no promotion rules to replicate for the TensorFlow frontend; instead, we check that the two arguments of the function are of the same type using check_tensorflow_casting().

# ivy/functional/frontends/tensorflow/math.py
@to_ivy_arrays_and_back
def add(x, y, name=None):
    x, y = check_tensorflow_casting(x, y)
    return ivy.add(x, y)

NumPy Special Argument - Casting#

NumPy supports an additional, special argument, casting, which allows the user to determine the kind of dtype casting that fits their objectives. The casting rules are explained in the numpy.can_cast documentation. While handling this argument, the dtype argument is used to state the desired return dtype.

To handle this, a decorator, handle_numpy_casting, is used to simplify the handling logic and reduce code redundancy. It is located in ivy/functional/frontends/numpy/func_wrapper.py.

This decorator is then added to the numpy frontend functions with the casting argument. An example of the add() function is shown below.

# ivy/functional/frontends/numpy/mathematical_functions/arithmetic_operations.py
@handle_numpy_out
@handle_numpy_dtype
@to_ivy_arrays_and_back
@handle_numpy_casting
@from_zero_dim_arrays_to_scalar
def _add(
    x1,
    x2,
    /,
    out=None,
    *,
    where=True,
    casting="same_kind",
    order="k",
    dtype=None,
    subok=True,
):
    x1, x2 = promote_types_of_numpy_inputs(x1, x2)
    ret = ivy.add(x1, x2, out=out)
    if ivy.is_array(where):
        ret = ivy.where(where, ret, ivy.default(out, ivy.zeros_like(ret)), out=out)
    return ret

There is a special case for the casting argument where the only allowed dtype is bool; therefore, a handle_numpy_casting_special decorator is included to handle this.

# ivy/functional/frontends/numpy/func_wrapper.py
def handle_numpy_casting_special(fn: Callable) -> Callable:
    @functools.wraps(fn)
    def new_fn(*args, casting="same_kind", dtype=None, **kwargs):
        ivy.assertions.check_elem_in_list(
            casting,
            ["no", "equiv", "safe", "same_kind", "unsafe"],
            message="casting must be one of [no, equiv, safe, same_kind, unsafe]",
        )
        if ivy.exists(dtype):
            ivy.assertions.check_equal(
                ivy.as_ivy_dtype(dtype),
                "bool",
                message="output is compatible with bool only",
            )
        return fn(*args, **kwargs)
    new_fn.handle_numpy_casting_special = True
    return new_fn

An example function using this is the numpy.isfinite() function.

# ivy/functional/frontends/numpy/logic/array_type_testing.py
@handle_numpy_out
@handle_numpy_dtype
@to_ivy_arrays_and_back
@handle_numpy_casting_special
@from_zero_dim_arrays_to_scalar
def _isfinite(
    x,
    /,
    out=None,
    *,
    where=True,
    casting="same_kind",
    order="K",
    dtype=None,
    subok=True,
):
    ret = ivy.isfinite(x, out=out)
    if ivy.is_array(where):
        ret = ivy.where(where, ret, ivy.default(out, ivy.zeros_like(ret)), out=out)
    return ret

Frontends Duplicate Policy#

Some frontend functions appear in multiple namespaces within the original framework that the frontend is replicating. For example, the np.asarray() function appears in Array manipulation routines and also in Array creation routines. This section outlines a policy that should serve as a guide for handling duplicate functions. The following sub-headings outline the policy:

Listing duplicate frontend functions on the ToDo lists

Essentially, there are two types of duplicate functions:

  1. Functions that are listed in multiple namespaces but are callable from the same path; for example, asarray() is listed in the manipulation routines and also in the creation routines, yet it is called from the single path np.asarray().

  2. Functions that are listed in multiple namespaces but are callable from different paths; for example, the functions tf.math.tan() and tf.raw_ops.Tan().

When listing frontend functions, extra care should be taken to keep note of these two types of duplicate functions.

  • For duplicate functions of the first type, we should list the function in just one of the namespaces where it exists and leave it out of all the others.

  • For duplicates of the second type, we should list the function in each namespace where it exists, but with a note highlighting that the function(s) on the list are duplicates and should therefore be implemented as aliases. For example, most of the functions in tf.raw_ops are aliases, and this is made clear when the functions are listed on the ToDo list.

Contributing duplicate frontend functions

Before working on a frontend function, contributors should check whether the function is designated as an alias on the ToDo list. If the function is an alias, check whether there is an existing implementation that can be aliased.

  • If an implementation exists, then simply create an alias of the implementation; for example, many functions in ivy/functional/frontends/tensorflow/raw_ops.py are implemented as aliases of existing frontend functions, as in the sketch below.

  • If there is no implementation to be aliased, then feel free to contribute the implementation first, then go ahead and create the alias.
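
For instance, an alias of the tan() function shown earlier might be created as follows, mirroring the ArgMax pattern above:

# in ivy/functional/frontends/tensorflow/raw_ops.py (illustrative alias)
Tan = to_ivy_arrays_and_back(map_raw_ops_alias(tf_frontend.math.tan))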

Testing duplicate functions

Unit tests should be written for all aliases. This is arguably a duplication, but having a unique test for each alias helps us keep the testing code organised and aligned with the groupings in the frontend API.

Round Up

This should hopefully have given you a better grasp of what the Ivy Frontend APIs are for, how they should be implemented, and the things to watch out for! We also have a short YouTube tutorial series on this, if you prefer a video explanation!

If you have any questions, please feel free to reach out on Discord in the ivy frontends thread!
