Better Transformer TokenBatchSampler #192

Open
@AmitMY

Joey implements a token-based batch sampler, so that every batch contains roughly the same number of tokens.
https://github.com/joeynmt/joeynmt/blob/main/joeynmt/datasets.py#L783-L820
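
For context, here is a minimal sketch of what such a token-budget sampler does (an illustration of the idea only, not the actual joeynmt code; the function name, the padded-token cost model `B * L`, and the `max_tokens` budget parameter are my own simplifications):

```python
from typing import Iterator, List, Sequence


def token_batches(lengths: Sequence[int], max_tokens: int) -> Iterator[List[int]]:
    """Group example indices so each padded batch stays within `max_tokens` tokens.

    Cost model: a batch of B sequences padded to its longest length L costs
    roughly B * L tokens. Sketch only: real samplers also bucket/sort by length,
    shuffle, and handle sequences longer than the budget more carefully.
    """
    batch: List[int] = []
    batch_max_len = 0
    for idx, n_tokens in enumerate(lengths):
        new_max = max(batch_max_len, n_tokens)
        # Would adding this example exceed the padded-token budget?
        if batch and (len(batch) + 1) * new_max > max_tokens:
            yield batch
            batch, new_max = [], n_tokens
        batch.append(idx)
        batch_max_len = new_max
    if batch:
        yield batch
```

With a cost model that is linear in sequence length, each batch contributes roughly the same number of tokens to the loss, which is what the first goal below is about.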

The goals as I understand them are:

  1. Create a consistent loss: on every update, the loss is computed from roughly the same number of tokens.
  2. Maximize parallelism in training by using as much GPU memory as possible.

If I understand them correctly, there is a problem when using an $O(n^2)$ model:
we cannot exhaust all GPU memory, because input sequences vary in length.
A single input sequence of 500 tokens takes $O(500^2 \cdot c) = O(250000\,c)$ memory, while 5 sequences of length 100 take $O(5 \cdot 100^2 \cdot c) = O(50000\,c)$, so one batch can take 5 times as much memory as another.

This can be corrected by changing this line to use `max_tokens**2` instead:
https://github.com/joeynmt/joeynmt/blob/main/joeynmt/datasets.py#L812
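
As a sketch of what that change amounts to (illustrative only, not a patch against `datasets.py`; the function name and `max_cost` budget are hypothetical), the batch is capped by the quadratic attention cost `B * L**2` of the padded batch instead of its padded token count `B * L`:

```python
def attention_cost_batches(lengths, max_cost):
    """Variant of the sketch above: cap B * L**2 (the O(n^2) attention cost of
    a batch of B sequences padded to length L) instead of B * L tokens."""
    batch, batch_max_len = [], 0
    for idx, n_tokens in enumerate(lengths):
        new_max = max(batch_max_len, n_tokens)
        # Would adding this example exceed the quadratic-cost budget?
        if batch and (len(batch) + 1) * new_max ** 2 > max_cost:
            yield batch
            batch, new_max = [], n_tokens
        batch.append(idx)
        batch_max_len = new_max
    if batch:
        yield batch
```

With a budget of 250000 "attention cells", one batch holds a single 500-token sequence (500 loss tokens) while another holds 25 sequences of length 100 (2500 loss tokens).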

However, by doing so we ignore the first goal: having roughly the same number of tokens contribute to the loss on every update, which is what keeps it consistent.


My question is: which goal is more important? Is it really crucial to satisfy goal 1?

Labels: enhancement (New feature or request), help wanted (Extra attention is needed)
