Enable more flexible inputs to the TransformerDecoder #968
Currently, the `forward` method of the `TransformerDecoder` class requires a tokens tensor of shape `[b, s]` to be passed as an argument, which is then passed to `self.tok_embeddings`. But the capabilities of transformers go far beyond working with text, and sometimes you want to use them with data that is more complex than sequences of integers.

Perhaps it would be worth relaxing the `TransformerDecoder` implementation to make it easier to use in such cases? Specifically: allow the input data to be of any shape `[b, s, ...]`, and change the type of `tok_embeddings` from `nn.Embedding` to any model that inherits from `nn.Module` and returns a tensor of shape `[b, s, d]`.

Alternatively, do it the way the Hugging Face library does, which allows `inputs_embeds` to be passed directly instead of `input_ids`.

Comments

Thanks for opening this issue! This is a good suggestion. We've been discussing redesigning the transformer decoder task to output more information (e.g., hidden states from intermediate layers). I think making the embedding layer more generic can be part of this change. I'll put up an RFC some time next week and share it here.