Better Transformer TokenBatchSampler #192

Open · AmitMY opened this issue Sep 1, 2022 · 4 comments
Labels: enhancement, help wanted

@AmitMY (Contributor) commented Sep 1, 2022

Joey implements a token-based batch sampler, such that every batch contains roughly the same number of tokens.
https://github.com/joeynmt/joeynmt/blob/main/joeynmt/datasets.py#L783-L820

The goals as I understand them are:

  1. Create a consistent loss: on every update, the loss is computed over roughly the same number of tokens.
  2. Make training highly parallel, using as much GPU memory as possible.

If I understand them correctly, there is a problem when using an $O(n^2)$ model:
we cannot reliably fill the GPU memory, because input sequences vary in length.
A single input sequence of 500 tokens takes $O(500^2 c) = O(250000c)$ memory, while 5 sequences of length 100 take $O(5 \cdot 100^2 c) = O(50000c)$, so one batch can take 5 times as much memory as another.

This can be corrected for by changing this line to calculate max_tokens**2.
https://github.com/joeynmt/joeynmt/blob/main/joeynmt/datasets.py#L812
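
For illustration, here is a minimal sketch (not the actual JoeyNMT code; the function and parameter names are made up) of a sampler that budgets the sum of squared sequence lengths instead of the raw token count, so that batch sizes track the $O(n^2)$ attention memory:

```python
# Sketch only: group example indices so that sum(len_i ** 2) stays under a budget.
# `max_sq_tokens` plays the role of max_tokens ** 2 in the proposal above.
from typing import Iterator, List, Sequence


def quadratic_token_batches(
    lengths: Sequence[int], max_sq_tokens: int
) -> Iterator[List[int]]:
    batch: List[int] = []
    budget = 0
    for idx, length in enumerate(lengths):
        cost = length ** 2  # quadratic attention footprint of this example
        if batch and budget + cost > max_sq_tokens:
            yield batch
            batch, budget = [], 0
        batch.append(idx)
        budget += cost
    if batch:
        yield batch


# One 500-token sentence costs as much as twenty-five 100-token ones:
for batch in quadratic_token_batches([500, 100, 100, 100, 100, 100], max_sq_tokens=250_000):
    print(batch)  # [0], then [1, 2, 3, 4, 5]
```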

However, by doing so we ignore the first goal, which is to have roughly the same number of tokens in the loss, making it consistent.


My question is: which goal is more important? Is it really crucial to satisfy goal 1?

@may- (Collaborator) commented Sep 2, 2022

I think this is related to bucketing during minibatch creation. (If not, please correct me.)

Let me explain a bit of the historical background:
We used to depend on the "Bucket" sampler implemented in torchtext, but we decided to get rid of torchtext. When we migrated away from torchtext, we discussed whether or not to keep "bucketing" in our new sampler. At that time, I observed that several MT experiments without bucketing reached even better BLEU than the experiments with bucketing, an improvement of around 3-4 BLEU. We hypothesized that bucketing could break the i.i.d. condition and therefore worsen performance. (Maybe this is related to the loss consistency issue you mentioned above, as a consequence of a non-i.i.d. example distribution, I guess.) That's why we simply omitted bucketing.
But I'm aware that bucketing is important, especially for TPUs or distributed data parallelism. I was considering mimicking fairseq's implementation: https://github.com/facebookresearch/fairseq/blob/main/fairseq/data/bucket_pad_length_dataset.py
We just took an "easy-first" path at that time and postponed further consideration.

So, we are happy and open to receiving new ideas. The important criterion would be no large degradation in BLEU on the standard benchmarks we currently have, i.e. IWSLT, WMT, etc. Could you please confirm that your change doesn't worsen the results and does improve memory usage on at least one standard MT benchmark, before you raise a pull request?

Note: we might not integrate more complicated acceleration or scaling solutions such as DeepSpeed or FairScale into the official main repo. (I personally think bucketing is simple enough and nice to have, though.)

Is the point clear? ... or did I miss your point?

@may- added the "enhancement" and "help wanted" labels on Sep 2, 2022
@AmitMY (Contributor, Author) commented Sep 2, 2022

You got my point indeed.
Before I try this, a quick consultation if I may:
How does the memory of a transformer encoder-decoder model behave?

  1. $O(\max(|src|, |trg|)^2)$
  2. $O(\max(|src|^2, |trg| \cdot (|trg| + |src|)))$

It seems to me like the second, because the transformer decoder attends to both the decoder states and the encoder states - but does it do so all at once, or in separate phases?

@may- (Collaborator) commented Sep 2, 2022

I thought $\mathcal{O}(\max(|src|^2, |trg|^2, |src| \cdot |trg|))$, which corresponds to src-src attention, trg-trg attention, and src-trg attention, respectively. So it is essentially the same as the first one: $\mathcal{O}(\max(|src|, |trg|)^2)$. (This ignores the factor of $d$, the hidden dimension size in the key-value-query computation of attention.)

We have auto-regressive decoding, which means we repeat the computation $|trg|$ times; it doesn't mean the attention layer has $\mathcal{O}(|trg| \cdot (|trg|+|src|))$ computation, does it? An $\mathcal{O}(|trg|+|src|)$ term (tensor concatenation of src and trg) doesn't occur anywhere, I thought.

If you mean $\mathcal{O}(|trg| \cdot (|trg|+|src|)) = \mathcal{O}(|trg|^2 + |src| \cdot |trg|)$, i.e. trg-trg attention plus src-trg attention, then maybe you are right. I'm not sure, though...

If we focus on "memory" complexity, maybe we should consider the sum: src-src attention + trg-trg attention + src-trg attention. If we focus on computational speed, the dominant one of these three would be the bottleneck as the size increases.
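
To make the three terms concrete, here is a small sketch (with hypothetical lengths, ignoring batch size, the number of heads, and the hidden dimension $d$) that counts the entries of the three attention score matrices:

```python
# Sketch: sizes of the attention score matrices for |src| = 500, |trg| = 100.
src_len, trg_len = 500, 100

attention_scores = {
    "encoder self-attention (src-src)": src_len * src_len,   # 250,000 entries
    "decoder self-attention (trg-trg)": trg_len * trg_len,   #  10,000 entries
    "decoder cross-attention (src-trg)": trg_len * src_len,  #  50,000 entries
}

for name, size in attention_scores.items():
    print(f"{name}: {size:,} score entries")

# Memory grows roughly with the sum of the three; speed is dominated by the largest.
print("sum:", sum(attention_scores.values()))
print("max:", max(attention_scores.values()))
```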

@AmitMY (Contributor, Author) commented Sep 5, 2022

Just an update:
I wanted to take a more data-driven approach: run every datum through the model and profile the memory used by the GPU. This means the approach could generalize to any model.

This is the output: output.txt

Then, I fit a linear regression model to it (memory in bytes):
Memory = 218125183.50 + 268733.14 * |S| + 292696.09 * |T| + 488.47 * |S|^2 + 504.68 * |T|^2 + 481.32 * |S|*|T|

The test-set coefficient of determination is 1, meaning perfect held-out prediction.

I'll continue, assuming that the intercept is the model's fixed size in memory, that only the remaining terms scale with the batch, and that I can keep adding to a batch until I hit the GPU memory limit.
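
For reference, a minimal sketch of this fitting step, assuming the profiling run produces (|S|, |T|, peak bytes) triples; the variable names and the numbers below are illustrative dummies, not values taken from output.txt:

```python
# Sketch: fit peak GPU memory as a function of source/target lengths and use
# the fitted model to decide whether a candidate batch still fits in memory.
import numpy as np
from sklearn.linear_model import LinearRegression

# (|S|, |T|, measured peak GPU memory in bytes) -- dummy values for illustration.
profile = [
    (10, 12, 225_000_000),
    (30, 25, 232_000_000),
    (50, 40, 241_000_000),
    (120, 100, 275_000_000),
    (200, 180, 330_000_000),
    (500, 450, 660_000_000),
]

S = np.array([p[0] for p in profile], dtype=float)
T = np.array([p[1] for p in profile], dtype=float)
y = np.array([p[2] for p in profile], dtype=float)

# Features matching the fitted formula: |S|, |T|, |S|^2, |T|^2, |S|*|T|
X = np.column_stack([S, T, S**2, T**2, S * T])
reg = LinearRegression().fit(X, y)  # intercept ~ fixed model footprint


def marginal_bytes(src_len: int, trg_len: int) -> float:
    """Per-example memory on top of the fixed footprint (the intercept)."""
    feats = np.array(
        [src_len, trg_len, src_len**2, trg_len**2, src_len * trg_len], dtype=float
    )
    return float(feats @ reg.coef_)


# A batch fits if the fixed footprint plus the per-example costs stays under the limit.
GPU_LIMIT_BYTES = 16 * 1024**3
candidate_batch = [(300, 250), (120, 100), (50, 40)]  # (|S|, |T|) pairs
total = reg.intercept_ + sum(marginal_bytes(s, t) for s, t in candidate_batch)
print(total < GPU_LIMIT_BYTES)
```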
