multi_head_attention
- ivy.multi_head_attention(x, scale, num_heads, /, *, context=None, mask=None, to_q_fn=None, to_kv_fn=None, to_out_fn=None, to_q_v=None, to_kv_v=None, to_out_v=None, out=None)[source]
Apply multi-head attention to inputs x.
- Parameters:
  - x (Union[Array, NativeArray]) – The array from which to determine the queries [batch_shape, num_queries, query_dim].
  - scale (float) – The value by which to scale the query-key similarity measure before softmax.
  - num_heads (int) – The number of attention heads to use.
  - context (Optional[Union[Array, NativeArray]]) – The array from which to determine the keys and values [batch_shape, num_keys, cont_feat_dim]. Default is None.
  - mask (Optional[Union[Array, NativeArray]]) – The mask to apply to the query-key values [batch_shape, num_queries, num_keys]. Default is None.
  - to_q_fn (Optional[Callable]) – The function to compute queries from input x, returning queries [batch_shape, num_queries, num_heads × head_dim]. Default is None.
  - to_kv_fn (Optional[Callable]) – The function to compute keys and values from the context. Default is None.
  - to_out_fn (Optional[Callable]) – The function to compute the output from the scaled dot-product attention. Default is None.
  - to_q_v (Optional[Union[Array, NativeArray]]) – The variables for function to_q_fn. Default is None.
  - to_kv_v (Optional[Union[Array, NativeArray]]) – The variables for function to_kv_fn. Default is None.
  - to_out_v (Optional[Union[Array, NativeArray]]) – The variables for function to_out_fn. Default is None.
  - out (Optional[Union[Array, NativeArray]]) – Optional output array, for writing the result to. It must have a shape that the inputs broadcast to. Default is None.
- Return type:
  Union[Array, NativeArray]
- Returns:
  ret – The output following application of multi-head attention [batch_shape, num_queries, out_feat_dim].
Both the description and the type hints above assume an array input for simplicity, but this function is nestable, and therefore also accepts ivy.Container instances in place of any of the arguments.
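The core computation can be sketched in plain NumPy. The sketch below assumes the default behaviour suggested by the examples in this section: when to_q_fn is None the queries are taken directly from x, and when to_kv_fn is None with a context supplied, the keys and values come from an even split of the context's last dimension. This is an illustrative reconstruction, not the library implementation.

```python
import numpy as np

def softmax(x, axis=-1):
    # numerically stable softmax
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def mha_sketch(x, context, scale, num_heads):
    # assumed defaults: q = x; keys/values = even split of context's last dim
    q = x
    k, v = np.split(context, 2, axis=-1)

    def split_heads(t):
        b, s, d = t.shape
        # [batch, seq, heads * head_dim] -> [batch, heads, seq, head_dim]
        return t.reshape(b, s, num_heads, d // num_heads).transpose(0, 2, 1, 3)

    qh, kh, vh = (split_heads(t) for t in (q, k, v))
    sim = scale * (qh @ kh.transpose(0, 1, 3, 2))  # [batch, heads, queries, keys]
    out = softmax(sim) @ vh                        # [batch, heads, queries, head_dim]
    b, h, s, hd = out.shape
    # merge the heads back into the feature dimension
    return out.transpose(0, 2, 1, 3).reshape(b, s, h * hd)

x = np.array([[[0.2, 1.], [2.2, 3.], [4.4, 5.6]]])
context = np.array([[[0.2, 1., 1.1, 4.2],
                     [2.2, 3., 0.9, 3.6],
                     [4.4, 5.6, 2.2, 0.4]]])
result = mha_sketch(x, context, scale=1, num_heads=2)
```

Under these assumptions, `result` reproduces the output of `ivy.multi_head_attention(x, 1, 2, context=context)` shown in the first example below.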
Examples
With ivy.Array input:

>>> x = ivy.array([[[0.2, 1.],
...                 [2.2, 3.],
...                 [4.4, 5.6]]])
>>> context = ivy.array([[[0.2, 1., 1.1, 4.2],
...                       [2.2, 3., 0.9, 3.6],
...                       [4.4, 5.6, 2.2, 0.4]]])
>>> result = ivy.multi_head_attention(x, 1, 2, context=context)
>>> print(result)
ivy.array([[[1.5678761 , 0.65441847],
            [2.18969631, 0.40131447],
            [2.19991851, 0.40000153]]])
With ivy.NativeArray input:

>>> x = ivy.native_array([[[0.2, 1.],
...                        [2.2, 3.],
...                        [4.4, 5.6]]])
>>> context = ivy.native_array([[[0.2, 1., 1.1, 4.2],
...                              [2.2, 3., 0.9, 3.6],
...                              [4.4, 5.6, 2.2, 0.4]]])
>>> result = ivy.multi_head_attention(x, 1, 2, context=context)
>>> print(result)
ivy.array([[[1.5678761 , 0.65441847],
            [2.18969631, 0.40131447],
            [2.19991851, 0.40000153]]])
With ivy.Container input:

>>> x = ivy.Container(a=ivy.array([[[0.2, 1.1], [2.2, 3.4], [4.4, 5.6]]]),
...                   b=ivy.array([[[1.4, 0.3], [1.2, 3.9], [0.4, 3.7]]]))
>>> context = ivy.Container(a=ivy.array([[[0.2, 1.8, 1.1, 4.2],
...                                       [2.2, 3.3, 0.9, 3.6],
...                                       [4.4, 5.6, 2.2, 0.4]]]),
...                         b=ivy.array([[[1.4, 0.3, 4.4, 5.6],
...                                       [1.2, 3.9, 4.2, 5.1],
...                                       [0.4, 3.7, 4.3, 5.3]]]))
>>> result = ivy.multi_head_attention(x, 1, 2, context=context)
>>> print(result)
{
    a: ivy.array([[[1.5678761, 0.68589532],
                   [2.18969631, 0.40129396],
                   [2.19991851, 0.40000817]]]),
    b: ivy.array([[[4.31219625, 5.25698996],
                   [4.31022024, 5.16286421],
                   [4.30296469, 5.16460133]]])
}
With a mix of ivy.Container and ivy.Array inputs:

>>> x = ivy.Container(a=ivy.array([[[0.2, 1.1], [2.2, 3.4], [4.4, 5.6]]]),
...                   b=ivy.array([[[1.4, 0.3], [1.2, 3.9], [0.4, 3.7]]]))
>>> context = ivy.array([[[0.2, 1., 1.1, 4.2],
...                       [2.2, 3., 0.9, 3.6],
...                       [4.4, 5.6, 2.2, 0.4]]])
>>> result = ivy.multi_head_attention(x, 1, 2, context=context)
>>> print(result)
{
    a: ivy.array([[[1.5678761, 0.59497029],
                   [2.18969631, 0.40046397],
                   [2.19991851, 0.40000153]]]),
    b: ivy.array([[[2.14009905, 1.81691194],
                   [2.10732293, 0.40012637],
                   [1.73519301, 0.40021262]]])
}
With a mix of ivy.Array and ivy.Container inputs:

>>> x = ivy.array([[[0.2, 1.],
...                 [2.2, 3.],
...                 [4.4, 5.6]]])
>>> context = ivy.Container(a=ivy.array([[[0.2, 1.8, 1.1, 4.2],
...                                       [2.2, 3.3, 0.9, 3.6],
...                                       [4.4, 5.6, 2.2, 0.4]]]),
...                         b=ivy.array([[[1.4, 0.3, 4.4, 5.6],
...                                       [1.2, 3.9, 4.2, 5.1],
...                                       [0.4, 3.7, 4.3, 5.3]]]))
>>> result = ivy.multi_head_attention(x, 1, 2, context=context)
>>> print(result)
{
    a: ivy.array([[[1.5678761, 0.7615059],
                   [2.18969631, 0.40326414],
                   [2.19991851, 0.40000817]]]),
    b: ivy.array([[[4.30141067, 5.19610119],
                   [4.32028484, 5.1708746],
                   [4.34100914, 5.14920235]]])
}
With ivy.Array inputs and ivy.Array mask:

>>> x = ivy.array([[[0.2, 1.],
...                 [2.2, 3.],
...                 [4.4, 5.6]]])
>>> context = ivy.array([[[0.2, 1., 1.1, 4.2],
...                       [2.2, 3., 0.9, 3.6],
...                       [4.4, 5.6, 2.2, 0.4]]])
>>> mask = ivy.array([[[0.0, 0.0, 0.0],
...                    [0.0, 0.0, 0.0],
...                    [0.0, 0.0, 0.0]]])
>>> result = ivy.multi_head_attention(x, 1, 2, context=context, mask=mask)
>>> print(result)
ivy.array([[[1.40000009, 2.73333335],
            [1.40000009, 2.73333335],
            [1.40000009, 2.73333335]]])
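The output above is consistent with the common masking convention in which a mask entry of 0 excludes the corresponding query-key pair: the similarity is replaced by a large negative value before the softmax. When every entry is masked, the attention weights become uniform and each query simply averages the values. A single-head NumPy sketch of that convention (an assumption for illustration, not a verbatim copy of the implementation):

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

q = np.array([[0.2, 1.], [2.2, 3.], [4.4, 5.6]])    # queries (= x above)
k = np.array([[0.2, 1.], [2.2, 3.], [4.4, 5.6]])    # keys (first half of context)
v = np.array([[1.1, 4.2], [0.9, 3.6], [2.2, 0.4]])  # values (second half of context)
mask = np.zeros((3, 3))                             # every query-key pair masked out

sim = q @ k.T                                # scale = 1
sim = np.where(mask != 0, sim, -1e9)         # assumed convention: 0 -> excluded
attn = softmax(sim)                          # all entries masked -> uniform 1/3 weights
out = attn @ v
print(out)  # each row is the mean of v: [1.4, 2.73333333], matching the output above
```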
With ivy.Array inputs and lambda to_q_fn and to_kv_fn functions specified:

>>> x = ivy.array([[[0.2, 1.],
...                 [2.2, 3.],
...                 [4.4, 5.6]]])
>>> context = ivy.array([[[0.2, 1., 1.1, 4.2],
...                       [2.2, 3., 0.9, 3.6],
...                       [4.4, 5.6, 2.2, 0.4]]])
>>> to_q_fn = lambda n, v: n
>>> to_kv_fn = lambda n, v: ivy.split(n, num_or_size_splits=2, axis=-1)
>>> result = ivy.multi_head_attention(x, 1, 2, context=context,
...                                   to_q_fn=to_q_fn, to_kv_fn=to_kv_fn)
>>> print(result)
ivy.array([[[1.5678761 , 0.65441847],
            [2.18969631, 0.40131447],
            [2.19991851, 0.40000153]]])
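The to_kv_fn in the example above splits the context's last dimension into two equal halves, the first supplying the keys and the second the values. In plain NumPy (used here only for illustration), that split looks like:

```python
import numpy as np

# the same context as in the example above
context = np.array([[[0.2, 1., 1.1, 4.2],
                     [2.2, 3., 0.9, 3.6],
                     [4.4, 5.6, 2.2, 0.4]]])

# even split of the last dimension: keys from the first half, values from the second
k, v = np.split(context, 2, axis=-1)
print(k.shape, v.shape)  # (1, 3, 2) (1, 3, 2)
```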
- Array.multi_head_attention(self, scale, num_heads, /, *, context=None, mask=None, to_q_fn=None, to_kv_fn=None, to_out_fn=None, to_q_v=None, to_kv_v=None, to_out_v=None, out=None)
- Return type:
  Array
- Container.multi_head_attention(self, scale, num_heads, /, *, context=None, mask=None, to_q_fn=None, to_kv_fn=None, to_out_fn=None, to_q_v=None, to_kv_v=None, to_out_v=None, key_chains=None, to_apply=True, prune_unapplied=False, map_sequences=False, out=None)
- Return type:
  Union[Array, NativeArray, Container]