Better Transformer TokenBatchSampler
#192
I think it's related to bucketing in the minibatch creation. (If not, please correct me.) Let me explain a bit of the historical background: so, we are happy and open to receive new ideas. The important criterion would be no large degradation in BLEU on the standard benchmarks we currently have, i.e. IWSLT, WMT, etc. Could you please confirm that your change doesn't worsen the results and does improve memory usage on at least one standard MT benchmark before you raise a pull request? Note: we might not integrate more complicated acceleration or scaling solutions such as DeepSpeed or Fairscale into the official main repo. (I personally think bucketing is simple enough and nice to have, though.) Is the point clear? ... or did I miss your point?
You got my point indeed.
Seems to me like the second, because the transformer decoder attends to both the decoder states and the encoder states. But does it do so all at once, or in separate phases?
We have auto-regressive decoding, which means we repeat the computation for every generated target token. If we focus on "memory" complexity, maybe we should consider the sum: src-src attention + trg-trg attention + src-trg attention. If we focus on computational speed, the dominant one of these three would be the bottleneck as the size increases.
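To make that sum concrete (a rough accounting, folding the hidden size and the layer/head counts into a constant $c$; which term dominates depends on the length ratio), for source length $n_{\text{src}}$ and target length $n_{\text{trg}}$ the attention score matrices cost:

$$c \left( n_{\text{src}}^2 + n_{\text{trg}}^2 + n_{\text{src}} \cdot n_{\text{trg}} \right)$$

where the first term is encoder self-attention, the second decoder self-attention, and the third the encoder-decoder cross-attention.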
Just an update: this is the output: output.txt. I then fit a linear regression model to it (memory in bytes). With the test set ... I'll continue, assuming that the free (intercept) term is the model size in memory, that only the rest scales with the batch, and that I can batch until I hit the GPU memory limit.
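A minimal sketch of what such a fit might look like, assuming a log of per-batch measurements; the column layout of output.txt and the chosen features are hypothetical, not taken from the thread:

```python
# Hypothetical sketch: fit peak GPU memory as a linear function of
# batch statistics, so the intercept estimates the fixed model size.
import numpy as np

# Assumed columns: (batch_size * max_len**2, total_tokens, peak_bytes).
# Adapt to whatever was actually logged in output.txt.
log = np.loadtxt("output.txt")

X = log[:, :2]   # batch-dependent features
y = log[:, 2]    # measured peak memory in bytes

# Least-squares fit with an intercept column (the "free" term,
# read here as the model's fixed memory footprint).
A = np.column_stack([X, np.ones(len(y))])
coef, *_ = np.linalg.lstsq(A, y, rcond=None)

print("per-feature slopes (bytes):", coef[:-1])
print("intercept ~ model size (bytes):", coef[-1])
```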
Joey implements a token-based batch sampler, such that every batch has roughly the same number of tokens.
https://github.com/joeynmt/joeynmt/blob/main/joeynmt/datasets.py#L783-L820
The goals as I understand them are:
1. Every batch contributes roughly the same number of tokens to the loss, keeping it consistent.
2. We do not exhaust GPU memory, even though input sequences can be of varying lengths.

If I do understand them correctly, when using an $O(n^2)$ model there is a problem: a single input sequence of 500 tokens takes $O(500^2 c) = O(250000c)$ memory, while 5 sequences of length 100 take $O(5 \cdot 100^2 c) = O(50000c)$, and so one batch takes 5 times the memory of another.
This can be corrected for by changing this line to calculate `max_tokens**2`:
https://github.com/joeynmt/joeynmt/blob/main/joeynmt/datasets.py#L812
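A minimal sketch of the two budgeting rules, using a simplified standalone sampler rather than joeynmt's actual implementation (the names `budget` and `quadratic` are illustrative):

```python
from typing import Iterable, Iterator, List

def token_batches(
    lengths: Iterable[int],
    budget: int,
    quadratic: bool = False,
) -> Iterator[List[int]]:
    """Group sample indices into batches under a cost budget.

    With quadratic=False the cost is the padded token count,
    len(batch) * max_len, roughly the token-count rule described above.
    With quadratic=True it is len(batch) * max_len**2, tracking the
    O(n^2) attention memory instead.
    """
    batch: List[int] = []
    max_len = 0
    for idx, n in enumerate(lengths):
        new_max = max(max_len, n)
        cost = (len(batch) + 1) * (new_max ** 2 if quadratic else new_max)
        if batch and cost > budget:
            # Adding this sample would overshoot the budget: emit the
            # current batch and start a new one with this sample.
            yield batch
            batch, max_len = [], 0
            new_max = n
        batch.append(idx)
        max_len = new_max
    if batch:
        yield batch
```

With `quadratic=True`, the example above is charged $250000c$ for the single 500-token sequence and $50000c$ for the five 100-token sequences, so batches containing long sequences shrink accordingly.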
However, by doing so we ignore the first goal, which is to have roughly the same number of tokens in the loss, making it consistent.
My question is: which goal is more important? Is it really crucial to satisfy goal 1?