
can't use masks in multi-head-attention layer #2336

Open
alerem18 opened this issue Sep 15, 2023 · 6 comments

alerem18 commented Sep 15, 2023

Motivation and description

Let's say we have an array of shape (embedding_size, seq_len, batch_size); our padding mask will then have shape (seq_len, batch_size), which can't be used as the multi-head-attention layer's mask. We can only use causal masking, which has shape (seq_len, seq_len).
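A rough sketch of the situation (the sizes and seq_lengths below are just illustrative):

using Flux, NNlib

embedding_size, seq_len, batch_size = 8, 5, 2
x = rand(Float32, embedding_size, seq_len, batch_size)

seq_lengths = [3, 4]   # illustrative per-sequence valid lengths
pad_mask = [t <= seq_lengths[b] for t in 1:seq_len, b in 1:batch_size]  # (seq_len, batch_size)

causal = NNlib.make_causal_mask(x)   # (seq_len, seq_len)

mha = MultiHeadAttention(embedding_size, nheads=2)
mha(x, mask=causal)       # works: (seq_len, seq_len) broadcasts against the scores
# mha(x, mask=pad_mask)   # fails: (seq_len, batch_size) is not broadcastable to
#                         # (kv_len, q_len, nheads, batch_size)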

Possible Implementation

No response

@CarloLucibello (Member)

The layer's documentation for the forward pass says:

     (mha::MultiHeadAttention)(q_in, k_in, v_in, [bias]; [mask])
...
mask: Input array broadcastable to size (kv_len, q_len, nheads, batch_size). 
      The mask is applied to the attention scores just before the softmax. 
      See NNlib.make_causal_mask for creating causal masks. Default nothing.

So I think you should reshape it, as either reshape(mask, (seq_len, 1, 1, batch_size)) or reshape(mask, (1, seq_len, 1, batch_size)); I'm not sure which of the two is correct.
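Something like this (a minimal sketch; pad_mask, seq_lengths and the sizes are illustrative):

seq_len, batch_size = 5, 2
seq_lengths = [3, 4]   # illustrative
pad_mask = [t <= seq_lengths[b] for t in 1:seq_len, b in 1:batch_size]  # (seq_len, batch_size)

# both reshapes are broadcastable to (kv_len, q_len, nheads, batch_size)
mask_keys    = reshape(pad_mask, seq_len, 1, 1, batch_size)  # masks along the key axis
mask_queries = reshape(pad_mask, 1, seq_len, 1, batch_size)  # masks along the query axis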

@alerem18 (Author)

Thanks, now it's working.

@CarloLucibello (Member)

@alerem18 which of the two reshapes is correct in your case?

@alerem18 (Author)

reshape(mask, (seq_len, 1, 1, batch_size))


alerem18 commented Sep 23, 2023

However, the masking is wrong: it should have shape (seq_len, seq_len, 1, batch_size), but with (1, seq_len, 1, batch_size) it returns NaN, so padding masking is not currently supported by the layer. I've already tried that:

using Flux

# two sequences of length 5; token 1 is used as padding in positions 4:5
l = reduce(hcat, [[5, 2, 3, 1, 1], [4, 5, 6, 1, 1]])

# mask of shape (kv_len, q_len, 1, batch_size); false marks padded positions
mask = fill(true, 5, 5, 1, 2)
mask[4:5, :, :, :] .= 0   # mask the padded keys
mask[:, 4:5, :, :] .= 0   # mask the padded queries

emb_layer = Embedding(10, 128)
emb = emb_layer(l)                         # (128, 5, 2)
attn = MultiHeadAttention(128, nheads=2)
attn(emb, mask=mask)[2]                    # second output: the attention scores

Result:

5×5×2×2 Array{Float32, 4}:
[:, :, 1, 1] =
 0.326395   0.362849  0.343025   NaN  NaN
 0.0660359  0.402627  0.0637925  NaN  NaN
 0.60757    0.234524  0.593183   NaN  NaN
 0.0        0.0       0.0        NaN  NaN
 0.0        0.0       0.0        NaN  NaN

[:, :, 2, 1] =
 0.486156  0.144888  0.532702   NaN  NaN
 0.2133    0.422068  0.0270071  NaN  NaN
 0.300544  0.433044  0.440291   NaN  NaN
 0.0       0.0       0.0        NaN  NaN
 0.0       0.0       0.0        NaN  NaN

[:, :, 1, 2] =
 0.0449472  0.396037  0.347837   NaN  NaN
 0.198215   0.455466  0.0415825  NaN  NaN
 0.756838   0.148497  0.610581   NaN  NaN
 0.0        0.0       0.0        NaN  NaN
 0.0        0.0       0.0        NaN  NaN

[:, :, 2, 2] =
 0.778366   0.164352  0.220597   NaN  NaN
 0.0780623  0.445108  0.702782   NaN  NaN
 0.143571   0.39054   0.0766214  NaN  NaN
 0.0        0.0       0.0        NaN  NaN
 0.0        0.0       0.0        NaN  NaN

alerem18 reopened this Sep 23, 2023
@alerem18 (Author)

Masking with shape (seq_len, 1, 1, batch_size) is OK, but with shape (1, seq_len, 1, batch_size) it returns NaN.
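A hedged guess at what's going on, with a workaround sketch (the sizes, seq_lengths, and the post-hoc zeroing are illustrative, not an official recipe): the attention scores have size (kv_len, q_len, nheads, batch_size) and the softmax runs over the key dimension, so a (1, seq_len, 1, batch_size) mask turns every key off for a padded query, and the softmax of an all -Inf column is NaN. A key-only mask of shape (seq_len, 1, 1, batch_size) always leaves the real keys visible, and the padded query positions can simply be zeroed after the layer:

using Flux

d, seq_len, batch_size = 128, 5, 2
seq_lengths = [3, 3]   # illustrative: positions 4:5 are padding
pad_mask = [t <= seq_lengths[b] for t in 1:seq_len, b in 1:batch_size]

x = rand(Float32, d, seq_len, batch_size)
mha = MultiHeadAttention(d, nheads=2)

# mask only the keys: every query still sees the real tokens, so the softmax
# never runs over an all-masked column and no NaN appears
key_mask = reshape(pad_mask, seq_len, 1, 1, batch_size)
y, α = mha(x, mask=key_mask)

# drop whatever was computed at the padded query positions afterwards
y = y .* reshape(pad_mask, 1, seq_len, batch_size)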
