Dear Flax community 😄
I want to scale up `MultiHeadDotProductAttention` and `nn.LayerNorm` as below. The query, key, and value kernels have shape (32, 8, 4), but their sharding names are `(None, 'model')`, which does not have the same length as the shape, and I am not sure why. Moreover, I'm confused about the sharding behavior of the last axis of $Q$, $K$, and $V$: the 4 seems to be the axis that gets sharded over `'model'`. Is this right?

There is one more strange behavior: the result below changes at the 7th decimal place when I use `y = sharding_check(1, (1, 1))`.

Please help me 🙏 and let me know what the right way to do this is.
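For reference, here is a minimal sketch of the kind of setup I mean (my actual snippet is not reproduced above; the `Block` module and the input shapes are only illustrative, chosen to match the numbers in this question). It uses `flax.linen.with_partitioning` to attach the `(None, 'model')` names to the attention kernels:

```python
# Illustrative sketch only -- not the exact snippet referenced above.
# With embed dim 32, 8 heads and qkv_features=32, each Q/K/V projection
# kernel has shape (in_features, num_heads, head_dim) = (32, 8, 4),
# while the partition names attached below are only (None, 'model').
import jax
import jax.numpy as jnp
import flax.linen as nn


class Block(nn.Module):  # hypothetical wrapper module, for illustration
    num_heads: int = 8
    qkv_features: int = 32

    @nn.compact
    def __call__(self, x):
        x = nn.LayerNorm()(x)
        # Recent Flax API: a single input means self-attention (q = k = v).
        x = nn.MultiHeadDotProductAttention(
            num_heads=self.num_heads,
            qkv_features=self.qkv_features,
            kernel_init=nn.with_partitioning(
                nn.initializers.lecun_normal(), (None, 'model')),
        )(x)
        return x


x = jnp.ones((2, 16, 32))  # (batch, seq, embed)
variables = Block().init(jax.random.PRNGKey(0), x)

# The kernels are boxed as nn.Partitioned values carrying the names above.
q_kernel = variables['params']['MultiHeadDotProductAttention_0']['query']['kernel']
print(q_kernel.value.shape, q_kernel.names)   # (32, 8, 4) (None, 'model')
print(nn.get_partition_spec(variables))       # PartitionSpec per parameter
```

Here the names tuple has two entries while each Q/K/V kernel has three axes, which is exactly the mismatch I'm asking about.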
Thanks!