Currently, the call signature of the MultiHeadDotProductAttention layer is MultiHeadDotProductAttention.__call__(inputs_q, inputs_kv, mask=None, deterministic=None). As discussed in #1737, some use cases need separate inputs for the keys and the values, which the current API does not allow. PR #3379 adds two arguments, inputs_k and inputs_v, and changes the signature to MultiHeadDotProductAttention.__call__(inputs_q, inputs_k=None, inputs_v=None, *, inputs_kv=None, mask=None, deterministic=None). Note that inputs_kv, mask and deterministic are now keyword-only arguments. The key and value inputs are resolved as follows:
If inputs_k and inputs_v are both None, they both copy the value of inputs_q (i.e. self-attention).
If only inputs_v is None, it copies the value of inputs_k (the same behavior as the previous API, i.e. module.apply(inputs_q=query, inputs_k=key_value, ...) is equivalent to module.apply(inputs_q=query, inputs_kv=key_value, ...)).
If inputs_kv is not None, both inputs_k and inputs_v copy the value of inputs_kv.
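The fallback rules above can be sketched in plain Python. This is an illustration only, not the code from PR #3379: `resolve_attention_inputs` is a hypothetical helper name, and the two `ValueError` branches (mixing `inputs_kv` with the new arguments, or passing `inputs_v` without `inputs_k`) are assumptions about cases this post does not describe.

```python
import warnings


def resolve_attention_inputs(inputs_q, inputs_k=None, inputs_v=None, *, inputs_kv=None):
    """Hypothetical helper: return (query, key, value) per the rules above."""
    if inputs_kv is not None:
        # Assumption: combining the deprecated argument with the new ones is an error.
        if inputs_k is not None or inputs_v is not None:
            raise ValueError("pass either inputs_kv or inputs_k/inputs_v, not both")
        warnings.warn(
            "inputs_kv is deprecated; use inputs_k and inputs_v instead",
            DeprecationWarning,
            stacklevel=2,
        )
        # Both key and value copy the deprecated argument.
        inputs_k = inputs_v = inputs_kv
    if inputs_k is None:
        # Assumption: a value without a key is rejected rather than guessed at.
        if inputs_v is not None:
            raise ValueError("inputs_v was given without inputs_k")
        # Neither key nor value given: self-attention, the key copies the query.
        inputs_k = inputs_q
    if inputs_v is None:
        # The value falls back to the key (the old inputs_kv behavior).
        inputs_v = inputs_k
    return inputs_q, inputs_k, inputs_v
```

Strings or arrays both work as stand-in inputs here, since the helper only decides which object is used as the key and the value.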
Users can still pass inputs_kv, but doing so emits a DeprecationWarning, and inputs_kv will be removed in the future.
Because self-attention can be expressed with the new API, the SelfAttention layer also emits a DeprecationWarning and will be removed in the future.
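Python's default warning filters hide a DeprecationWarning unless it is triggered by code running in `__main__`, so deprecated call sites inside a package can go unnoticed. The sketch below shows two standard-library ways to surface them while migrating. `legacy_attention_call` is a hypothetical stand-in for a call site that still passes `inputs_kv`. It is not a Flax function, and the warning text is made up.

```python
import warnings


def legacy_attention_call(inputs_q, *, inputs_kv=None):
    """Hypothetical stand-in for code that still uses the deprecated argument."""
    if inputs_kv is not None:
        warnings.warn(
            "inputs_kv is deprecated; use inputs_k and inputs_v instead",
            DeprecationWarning,
            stacklevel=2,
        )
    return inputs_q


# Option 1: record every warning so the deprecated call sites can be listed.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    legacy_attention_call("q", inputs_kv="kv")
deprecations = [w for w in caught if issubclass(w.category, DeprecationWarning)]


# Option 2: turn DeprecationWarning into an exception, e.g. in a test suite,
# so leftover uses of the old argument fail loudly.
def call_is_rejected():
    with warnings.catch_warnings():
        warnings.simplefilter("error", DeprecationWarning)
        try:
            legacy_attention_call("q", inputs_kv="kv")
        except DeprecationWarning:
            return True
    return False
```

Running Python with `-W error::DeprecationWarning` applies the second option to the whole process without code changes.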
Some examples of porting your code to the new method signature (each old call is followed by its replacement):
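# MultiHeadDotProductAttention, positional call: old, then new
module.apply(query, key_value, mask, deterministic)
module.apply(query, key_value, mask=mask, deterministic=deterministic)
# MultiHeadDotProductAttention, keyword call: old, then new
module.apply(inputs_q=query, inputs_kv=key_value, mask=mask, deterministic=deterministic)
module.apply(inputs_q=query, inputs_k=key_value, mask=mask, deterministic=deterministic)
# SelfAttention (sa_module), positional call: old, then new using MultiHeadDotProductAttention (module)
sa_module.apply(query, mask, deterministic)
module.apply(query, mask=mask, deterministic=deterministic)
# SelfAttention (sa_module), keyword call: old, then new using MultiHeadDotProductAttention (module)
sa_module.apply(inputs_q=query, mask=mask, deterministic=deterministic)
module.apply(inputs_q=query, mask=mask, deterministic=deterministic)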
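One pitfall when porting positional calls is worth checking. With the new signature, a third positional argument is no longer the mask: it binds to inputs_v. The sketch below uses a stub with the same parameter list as the new `__call__` (minus `self`) and `inspect.signature` to show how arguments bind. `new_call` and `binds` are hypothetical names, and no attention is computed.

```python
import inspect


def new_call(inputs_q, inputs_k=None, inputs_v=None, *,
             inputs_kv=None, mask=None, deterministic=None):
    """Stub with the parameter list quoted in this post; does nothing."""


signature = inspect.signature(new_call)


def binds(*args, **kwargs):
    """Return the name-to-argument mapping, or None if the call would raise TypeError."""
    try:
        return dict(signature.bind(*args, **kwargs).arguments)
    except TypeError:
        return None


# Old-style call with four positional arguments: rejected, only three are positional.
four_positional = binds("query", "key_value", "mask", "deterministic")

# Three positional arguments are accepted, but the mask silently becomes inputs_v.
three_positional = binds("query", "key_value", "mask")

# Ported call: mask and deterministic passed by keyword.
ported = binds("query", "key_value", mask="mask", deterministic="deterministic")
```

Here `four_positional` is `None`, `three_positional` maps `inputs_v` to `"mask"`, and `ported` maps each value to its intended parameter. The same reasoning applies to an old `SelfAttention` call ported positionally, where the mask would bind to inputs_k.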
For additional context, check out the PR #3379 and the discussion thread #1737.