multi_head_attention#

ivy.multi_head_attention(x, scale, num_heads, /, *, context=None, mask=None, to_q_fn=None, to_kv_fn=None, to_out_fn=None, to_q_v=None, to_kv_v=None, to_out_v=None, out=None)[source]#

Apply multi-head attention to inputs x.

Parameters:
  • x (Union[Array, NativeArray]) – The array from which to determine the queries, of shape [batch_shape, num_queries, query_dim].

  • scale (float) – The value by which to scale the query-key similarity measure before the softmax.

  • num_heads (int) – The number of attention heads to use.

  • context (Optional[Union[Array, NativeArray]]) – The array from which to determine the keys and values, of shape [batch_shape, num_keys, cont_feat_dim]. Default is None.

  • mask (Optional[Union[Array, NativeArray]]) – The mask to apply to the query-key similarities, of shape [batch_shape, num_queries, num_keys]. Default is None.

  • to_q_fn (Optional[Callable]) – The function to compute the queries from the input x, returning queries of shape [batch_shape, num_queries, num_heads × head_dim]. Default is None.

  • to_kv_fn (Optional[Callable]) – The function to compute the keys and values from the context. Default is None.

  • to_out_fn (Optional[Callable]) – The function to compute the output from the scaled dot-product attention. Default is None.

  • to_q_v (Optional[Union[Array, NativeArray]]) – The variables for the function to_q_fn. Default is None.

  • to_kv_v (Optional[Union[Array, NativeArray]]) – The variables for the function to_kv_fn. Default is None.

  • to_out_v (Optional[Union[Array, NativeArray]]) – The variables for the function to_out_fn. Default is None.

  • out (Optional[Union[Array, NativeArray]]) – Optional output array, for writing the result to. It must have a shape that the inputs broadcast to. Default is None.

Return type:

Union[Array, NativeArray]

Returns:

  • ret – The output following application of multi-head attention, of shape [batch_shape, num_queries, out_feat_dim].

Both the description and the type hints above assume an array input for simplicity, but this function is nestable, and therefore also accepts ivy.Container instances in place of any of the arguments. A sketch of the underlying computation follows.
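Conceptually, the computation reduces to scaled dot-product attention applied independently per head. The following is a minimal sketch of the default behaviour (when no projection functions are passed), not ivy's actual implementation: it assumes cont_feat_dim == 2 × query_dim, that query_dim is divisible by num_heads, and the helper name _mha_sketch is hypothetical.

import ivy

def _mha_sketch(x, scale, num_heads, context=None):
    # with no context, fall back to self-attention on x (hypothetical sketch)
    context = x if context is None else context
    # default projections: the queries are x itself, and the context
    # features are split in half into keys and values
    q = x
    k, v = ivy.split(context, num_or_size_splits=2, axis=-1)

    # fold the feature dim into per-head chunks:
    # [batch, seq, heads * head_dim] -> [batch, heads, seq, head_dim]
    def split_heads(t):
        b, s, d = t.shape
        return ivy.swapaxes(ivy.reshape(t, (b, s, num_heads, d // num_heads)), 1, 2)

    q, k, v = map(split_heads, (q, k, v))
    # scale the query-key similarities, softmax, then weight the values
    sim = ivy.matmul(q, ivy.swapaxes(k, -2, -1)) * scale
    attn = ivy.softmax(sim, axis=-1)
    o = ivy.matmul(attn, v)
    # merge the heads back into a single feature dimension
    b, h, s, hd = o.shape
    return ivy.reshape(ivy.swapaxes(o, 1, 2), (b, s, h * hd))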

Examples

With ivy.Array input:

>>> x = ivy.array([[[0.2, 1.],
...                 [2.2, 3.],
...                 [4.4, 5.6]]])
>>> context = ivy.array([[[0.2, 1., 1.1, 4.2],
...                       [2.2, 3., 0.9, 3.6],
...                       [4.4, 5.6, 2.2, 0.4]]])
>>> result = ivy.multi_head_attention(x, 1, 2, context=context)
>>> print(result)
ivy.array([[[1.5678761 , 0.65441847],
            [2.18969631, 0.40131447],
            [2.19991851, 0.40000153]]])

With ivy.NativeArray input:

>>> x = ivy.native_array([[[0.2, 1.],
...                        [2.2, 3.],
...                        [4.4, 5.6]]])
>>> context = ivy.native_array([[[0.2, 1., 1.1, 4.2],
...                              [2.2, 3., 0.9, 3.6],
...                              [4.4, 5.6, 2.2, 0.4]]])
>>> result = ivy.multi_head_attention(x, 1, 2, context=context)
>>> print(result)
ivy.array([[[1.5678761 , 0.65441847],
            [2.18969631, 0.40131447],
            [2.19991851, 0.40000153]]])

With ivy.Container input:

>>> x = ivy.Container(a=ivy.array([[[0.2, 1.1], [2.2, 3.4], [4.4, 5.6]]]),
...                   b=ivy.array([[[1.4, 0.3], [1.2, 3.9], [0.4, 3.7]]]))
>>> context = ivy.Container(a=ivy.array([[[0.2, 1.8, 1.1, 4.2],
...                                       [2.2, 3.3, 0.9, 3.6],
...                                       [4.4, 5.6, 2.2, 0.4]]]),
...                         b=ivy.array([[[1.4, 0.3, 4.4, 5.6],
...                                       [1.2, 3.9, 4.2, 5.1],
...                                       [0.4, 3.7, 4.3, 5.3]]]))
>>> result = ivy.multi_head_attention(x, 1, 2, context=context)
>>> print(result)
{
    a: ivy.array([[[1.5678761, 0.68589532],
                   [2.18969631, 0.40129396],
                   [2.19991851, 0.40000817]]]),
    b: ivy.array([[[4.31219625, 5.25698996],
                   [4.31022024, 5.16286421],
                   [4.30296469, 5.16460133]]])
}

With a mix of ivy.Container and ivy.Array inputs:

>>> x = ivy.Container(a=ivy.array([[[0.2, 1.1], [2.2, 3.4], [4.4, 5.6]]]),
...                   b=ivy.array([[[1.4, 0.3], [1.2, 3.9], [0.4, 3.7]]]))
>>> context = ivy.array([[[0.2, 1., 1.1, 4.2],
...                       [2.2, 3., 0.9, 3.6],
...                       [4.4, 5.6, 2.2, 0.4]]])
>>> result = ivy.multi_head_attention(x, 1, 2, context=context)
>>> print(result)
{
    a: ivy.array([[[1.5678761, 0.59497029],
                   [2.18969631, 0.40046397],
                   [2.19991851, 0.40000153]]]),
    b: ivy.array([[[2.14009905, 1.81691194],
                   [2.10732293, 0.40012637],
                   [1.73519301, 0.40021262]]])
}

With a mix of ivy.Array and ivy.Container inputs:

>>> x = ivy.array([[[0.2, 1.],
...                 [2.2, 3.],
...                 [4.4, 5.6]]])
>>> context = ivy.Container(a=ivy.array([[[0.2, 1.8, 1.1, 4.2],
...                                       [2.2, 3.3, 0.9, 3.6],
...                                       [4.4, 5.6, 2.2, 0.4]]]),
...                         b=ivy.array([[[1.4, 0.3, 4.4, 5.6],
...                                       [1.2, 3.9, 4.2, 5.1],
...                                       [0.4, 3.7, 4.3, 5.3]]]))
>>> result = ivy.multi_head_attention(x, 1, 2, context=context)
>>> print(result)
{
    a: ivy.array([[[1.5678761, 0.7615059],
                   [2.18969631, 0.40326414],
                   [2.19991851, 0.40000817]]]),
    b: ivy.array([[[4.30141067, 5.19610119],
                   [4.32028484, 5.1708746],
                   [4.34100914, 5.14920235]]])
}

With ivy.Array inputs and ivy.Array mask (an all-zero mask excludes every key, so the softmax weights degenerate to uniform and each query returns the column-wise mean of the values):

>>> x = ivy.array([[[0.2, 1.],
...                 [2.2, 3.],
...                 [4.4, 5.6]]])
>>> context = ivy.array([[[0.2, 1., 1.1, 4.2],
...                       [2.2, 3., 0.9, 3.6],
...                       [4.4, 5.6, 2.2, 0.4]]])
>>> mask = ivy.array([[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]])
>>> result = ivy.multi_head_attention(x, 1, 2, context=context, mask=mask)
>>> print(result)
ivy.array([[[1.40000009, 2.73333335],
            [1.40000009, 2.73333335],
            [1.40000009, 2.73333335]]])

With ivy.Array inputs and lambda to_q_fn and to_kv_fn functions specified (the second argument of each function receives the corresponding variables, to_q_v and to_kv_v, unused here):

>>> x = ivy.array([[[0.2, 1.],
...                 [2.2, 3.],
...                 [4.4, 5.6]]])
>>> context = ivy.array([[[0.2, 1., 1.1, 4.2],
...                       [2.2, 3., 0.9, 3.6],
...                       [4.4, 5.6, 2.2, 0.4]]])
>>> to_q_fn = lambda n, v: n
>>> to_kv_fn = lambda n, v: ivy.split(n, num_or_size_splits=2, axis=-1)
>>> result = ivy.multi_head_attention(x, 1, 2, context=context,
...                                   to_q_fn=to_q_fn, to_kv_fn=to_kv_fn)
>>> print(result)
ivy.array([[[1.5678761 , 0.65441847],
            [2.18969631, 0.40131447],
            [2.19991851, 0.40000153]]])
Array.multi_head_attention(self, scale, num_heads, /, *, context=None, mask=None, to_q_fn=None, to_kv_fn=None, to_out_fn=None, to_q_v=None, to_kv_v=None, to_out_v=None, out=None)#

ivy.Array instance method variant of ivy.multi_head_attention. This method simply wraps the function, and so the docstring for ivy.multi_head_attention also applies to this method with minimal changes.
Return type:

Array
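As a minimal usage sketch, assuming the instance method wraps ivy.multi_head_attention directly, the following mirrors the first functional example above:

>>> x = ivy.array([[[0.2, 1.],
...                 [2.2, 3.],
...                 [4.4, 5.6]]])
>>> context = ivy.array([[[0.2, 1., 1.1, 4.2],
...                       [2.2, 3., 0.9, 3.6],
...                       [4.4, 5.6, 2.2, 0.4]]])
>>> result = x.multi_head_attention(1, 2, context=context)
>>> print(result)
ivy.array([[[1.5678761 , 0.65441847],
            [2.18969631, 0.40131447],
            [2.19991851, 0.40000153]]])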

Container.multi_head_attention(self, scale, num_heads, /, *, context=None, mask=None, to_q_fn=None, to_kv_fn=None, to_out_fn=None, to_q_v=None, to_kv_v=None, to_out_v=None, key_chains=None, to_apply=True, prune_unapplied=False, map_sequences=False, out=None)#

ivy.Container instance method variant of ivy.multi_head_attention. This method simply wraps the function, and so the docstring for ivy.multi_head_attention also applies to this method with minimal changes.
Return type:

Union[Array, NativeArray, Container]
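As a minimal usage sketch, assuming the Container method maps ivy.multi_head_attention across the container's leaves, the following mirrors the ivy.Container example above:

>>> x = ivy.Container(a=ivy.array([[[0.2, 1.1], [2.2, 3.4], [4.4, 5.6]]]),
...                   b=ivy.array([[[1.4, 0.3], [1.2, 3.9], [0.4, 3.7]]]))
>>> context = ivy.Container(a=ivy.array([[[0.2, 1.8, 1.1, 4.2],
...                                       [2.2, 3.3, 0.9, 3.6],
...                                       [4.4, 5.6, 2.2, 0.4]]]),
...                         b=ivy.array([[[1.4, 0.3, 4.4, 5.6],
...                                       [1.2, 3.9, 4.2, 5.1],
...                                       [0.4, 3.7, 4.3, 5.3]]]))
>>> result = x.multi_head_attention(1, 2, context=context)
>>> print(result)
{
    a: ivy.array([[[1.5678761, 0.68589532],
                   [2.18969631, 0.40129396],
                   [2.19991851, 0.40000817]]]),
    b: ivy.array([[[4.31219625, 5.25698996],
                   [4.31022024, 5.16286421],
                   [4.30296469, 5.16460133]]])
}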