Navigating the Code#
Categorization#
Ivy uses the following categories taken from the Array API Standard:
In addition to these, we also add the following categories, used for additional functions in Ivy that are not in the Array API Standard:
Some functions that you’re considering adding might overlap several of these categorizations, and in such cases you should look at the other functions included in each file, and use your best judgement for which categorization is most suitable.
We can always suggest a more suitable location when reviewing your pull request if needed 🙂
Submodule Design#
Ivy is designed so that all methods are called directly from the ivy namespace, such as ivy.matmul(), and not ivy.some_namespace.matmul().
Therefore, inside any of the folders ivy.functional.ivy, ivy.functional.backends.some_backend, ivy.functional.backends.another_backend, the functions can be moved to different files or folders without breaking anything at all.
This makes it very simple to refactor and re-organize parts of the code structure in an ongoing manner.
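For instance, a minimal sketch of this usage pattern (assuming the numpy backend is installed) might look like:
import ivy

# Functions are called from the top-level ivy namespace, regardless of
# which file they actually live in (a minimal sketch, assuming the numpy
# backend is available):
ivy.set_backend("numpy")

x = ivy.array([[1.0, 2.0], [3.0, 4.0]])
y = ivy.matmul(x, x)  # ivy.matmul(), not ivy.some_namespace.matmul()
print(y)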
The __init__.py files inside each of the subfolders are very similar, importing each function via from .file_name import * and also importing each file as a submodule via from . import file_name.
For example, an extract from ivy/ivy/functional/ivy/__init__.py is given below:
from . import elementwise
from .elementwise import *
from . import general
from .general import *
# etc.
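As a rough illustration of what these imports achieve (not part of the codebase), the same function ends up reachable both at the top level and via its submodule:
import ivy

ivy.set_backend("numpy")
x = ivy.array([1.0, 2.0])

# `from .elementwise import *` exposes the function at the top level,
# while `from . import elementwise` keeps the submodule importable too:
print(ivy.add(x, x))                             # the recommended way to call it
print(ivy.functional.ivy.elementwise.add(x, x))  # same function, via its submodule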
Ivy API#
All function signatures for the Ivy API are defined in the ivy.functional.ivy submodule.
Functions written here look something like the following (explained in much more detail in the following sections):
def my_func(
    x: Union[ivy.Array, ivy.NativeArray],
    /,
    axes: Union[int, Sequence[int]],
    *,
    dtype: Optional[Union[ivy.Dtype, ivy.NativeDtype]] = None,
    device: Optional[Union[ivy.Device, ivy.NativeDevice]] = None,
    out: Optional[ivy.Array] = None
) -> ivy.Array:
    """
    My function does something cool.

    .. note::
        This is an important note.

    **Special Cases**

    For this particular case,

    - If ``x`` is ``NaN``, do something
    - If ``y`` is ``-0``, do something else
    - etc.

    Parameters
    ----------
    x
        input array. Should have a numeric data type.
    axes
        the axes along which to perform the op.
    dtype
        array data type.
    device
        the device on which to place the new array.
    out
        optional output array, for writing the result to. It must have a shape that
        the inputs broadcast to.

    Returns
    -------
    ret
        a cooler array.

    Examples
    --------
    Some cool examples go here
    """
    return ivy.current_backend(x).my_func(x, axes, dtype=dtype, device=device, out=out)
We follow the Array API Standard convention about positional and keyword arguments.
Positional parameters must be positional-only parameters. Positional-only parameters have no externally-usable name. When a method accepting positional-only parameters is called, positional arguments are mapped to these parameters based solely on their order.
Optional parameters must be keyword-only arguments.
This convention makes it easier for us to modify functions in the future. Keyword-only parameters will mandate the use of argument names when calling functions, and this will increase our flexibility for extending function behaviour in future releases without breaking forward compatibility. Similar arguments can be kept together in the argument list, rather than us needing to add these at the very end to ensure positional argument behaviour remains the same.
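A minimal, self-contained sketch of what this convention means in practice, using a dummy stand-in with the same parameter layout as my_func above:
def my_func(x, /, axes, *, dtype=None, device=None, out=None):
    # dummy body, just to illustrate the calling convention
    return (x, axes, dtype, device, out)

my_func([1.0, 2.0], 0, dtype="float32")      # OK: x passed positionally, dtype by keyword
my_func([1.0, 2.0], axes=0)                  # OK: axes may be positional or keyword

# my_func(x=[1.0, 2.0], axes=0)              # TypeError: x is positional-only
# my_func([1.0, 2.0], 0, "float32")          # TypeError: dtype is keyword-only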
The dtype, device and out arguments are always keyword-only.
Arrays always have type hint Union[ivy.Array, ivy.NativeArray] in the input and ivy.Array in the output.
All functions which produce a single array include the out argument.
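A rough sketch of how the out argument is typically used, with ivy.add standing in for any single-output function (assuming the numpy backend):
import ivy

ivy.set_backend("numpy")

x = ivy.array([1.0, 2.0, 3.0])
y = ivy.array([4.0, 5.0, 6.0])
z = ivy.zeros((3,))

# The result is written into `z`, and also returned:
ret = ivy.add(x, y, out=z)
print(z)  # z now holds [5., 7., 9.]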
The reasons for each of these features are explained in the following sections.
Backend API#
Code in the backend submodules such as ivy.functional.backends.torch should then look something like:
def my_func(
    x: torch.Tensor,
    /,
    axes: Union[int, Sequence[int]],
    *,
    dtype: torch.dtype,
    device: torch.device,
    out: Optional[torch.Tensor] = None
) -> torch.Tensor:
    return torch.something_cool(x, axes, dtype, device, out)
The dtype, device and out arguments are again all keyword-only, but dtype and device are now required arguments, rather than optional as they were in the Ivy API.
All arrays also now have the same type hint torch.Tensor, rather than Union[ivy.Array, ivy.NativeArray] in the input and ivy.Array in the output.
The backend methods should also not add a docstring.
Again, the reasons for these features are explained in the following sections.
Submodule Helper Functions#
At times, helper functions specific to a submodule are required to:
- keep the code clean and readable
- be imported in their respective backend implementations
To get a better idea of this, let's look at an example!
Helper in Ivy
# in ivy/functional/ivy/creation.py
def _assert_fill_value_and_dtype_are_compatible(dtype, fill_value):
    assert (
        (ivy.is_int_dtype(dtype) or ivy.is_uint_dtype(dtype))
        and isinstance(fill_value, int)
    ) or (
        ivy.is_float_dtype(dtype)
        and isinstance(fill_value, float)
        or (isinstance(fill_value, bool))
    ), "the fill_value and data type are not compatible"
In the full_like() function in creation.py, the types of fill_value and dtype have to be verified to avoid errors.
This check has to be applied to all backends, which means the related code is common and identical.
In this case, we can extract the code into a helper function of its own, placed in its related submodule (creation.py here).
In this example, the helper function is named _assert_fill_value_and_dtype_are_compatible().
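To make the intent of the helper concrete, here is a rough illustration of what it accepts and rejects (assuming the numpy backend; exact dtype handling may vary between Ivy versions):
import ivy
from ivy.functional.ivy.creation import _assert_fill_value_and_dtype_are_compatible

ivy.set_backend("numpy")

# An integer fill_value with an integer dtype is compatible:
_assert_fill_value_and_dtype_are_compatible(ivy.int32, 3)      # passes silently

# A float fill_value with an integer dtype is not:
# _assert_fill_value_and_dtype_are_compatible(ivy.int32, 3.5)  # AssertionError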
Then, we import this submodule-specific helper function into the respective backends, where examples for each backend are shown below.
Jax
# in ivy/functional/backends/jax/creation.py
from ivy.functional.ivy.creation import _assert_fill_value_and_dtype_are_compatible

def full_like(
    x: JaxArray,
    /,
    fill_value: Union[int, float],
    *,
    dtype: jnp.dtype,
    device: jaxlib.xla_extension.Device,
    out: Optional[JaxArray] = None
) -> JaxArray:
    _assert_fill_value_and_dtype_are_compatible(dtype, fill_value)
    return _to_device(
        jnp.full_like(x, fill_value, dtype=dtype),
        device=device,
    )
NumPy
# in ivy/functional/backends/numpy/creation.py
from ivy.functional.ivy.creation import _assert_fill_value_and_dtype_are_compatible

def full_like(
    x: np.ndarray,
    /,
    fill_value: Union[int, float],
    *,
    dtype: np.dtype,
    device: str,
    out: Optional[np.ndarray] = None
) -> np.ndarray:
    _assert_fill_value_and_dtype_are_compatible(dtype, fill_value)
    return _to_device(np.full_like(x, fill_value, dtype=dtype), device=device)
TensorFlow
# in ivy/functional/backends/tensorflow/creation.py
from ivy.functional.ivy.creation import _assert_fill_value_and_dtype_are_compatible

def full_like(
    x: Union[tf.Tensor, tf.Variable],
    /,
    fill_value: Union[int, float],
    *,
    dtype: tf.DType,
    device: str,
    out: Optional[Union[tf.Tensor, tf.Variable]] = None
) -> Union[tf.Tensor, tf.Variable]:
    _assert_fill_value_and_dtype_are_compatible(dtype, fill_value)
    with tf.device(device):
        return tf.experimental.numpy.full_like(x, fill_value, dtype=dtype)
Note
We shouldn't enable NumPy behaviour in TensorFlow, as it leads to issues with the bfloat16 data type in the TensorFlow implementations.
Torch
# in ivy/functional/backends/torch/creation.py
from ivy.functional.ivy.creation import _assert_fill_value_and_dtype_are_compatible

def full_like(
    x: torch.Tensor,
    /,
    fill_value: Union[int, float],
    *,
    dtype: torch.dtype,
    device: torch.device,
    out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    _assert_fill_value_and_dtype_are_compatible(dtype, fill_value)
    return torch.full_like(x, fill_value, dtype=dtype, device=device)
Version Pinning#
At any point in time, Ivy’s development will be predominantly focused around a particular version (and all prior versions) for each of the backend frameworks. These are the pinned versions shown in the optional.txt file.
At the time of pinning, these will be the most up-to-date versions for each framework, but new releases of the backend frameworks will then of course be made and there will sometimes be a short period of time in which we are working towards the next Ivy release, and we opt to keep the repo pinned to the older version until the next release is out. This helps to prevent our work growing in an unbounded manner, as we work towards getting all tests passing and everything in good shape before making the release. If we always pulled the latest version of every framework into master, we might end up constantly battling new subtle bugs, without knowing whether the bugs come from the change in version or our own incremental changes to the code. Therefore, when working towards an Ivy release, keeping the backends temporarily pinned essentially ensures that our development target remains fixed for this period of time.
As an example, at the time of writing the latest version of PyTorch is 2.0.1, whereas Ivy is pinned to version 1.12.1.
Therefore, all frontend functions (see Ivy Frontends section) added to Ivy should not include any arguments or behaviours which are exclusive to PyTorch version 2.0.1.
Round Up
This should have hopefully given you a good feel for how to navigate the Ivy codebase.
If you have any questions, please feel free to reach out on Discord in the navigating the code channel or in the navigating the code forum!