
Sample packing for map datasets with correct RoPE encoding and no cross-contamination #875

Merged
merged 24 commits into pytorch:main from the packing branch on May 15, 2024

Conversation

@RdoubleA (Contributor) commented Apr 26, 2024

The Problem

Packing multiple samples into a single context window means the model may accidentally attend to other samples it should not attend to. If the packed sequences are completely orthogonal in topic/content, this cross-contamination adds a lot of noise for each sample. Additionally, if position ids are not correctly adjusted for each individual sequence within a packed sample, then tokens later in the pack get an unwarranted later-position bias. So, we need to add two things (sketched below):

  • Lower triangular causal block mask for attention to prevent cross-contamination
  • Adjusted RoPE embeddings to handle position ids for each sequence in a packed sample
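For a pack containing two samples of lengths 3 and 2, the mask and position ids would look like this (a minimal sketch, not the actual PackedDataset code):

import torch

seq_lens = [3, 2]  # lengths of the two samples packed into one sequence

# Lower triangular causal mask per sample, combined block-diagonally so each token
# attends only to earlier tokens within its own sample.
mask = torch.block_diag(
    *[torch.tril(torch.ones(n, n, dtype=torch.int)) for n in seq_lens]
)

# Position ids restart at 0 for every sample, so RoPE treats each sample independently.
input_pos = torch.cat([torch.arange(n) for n in seq_lens])

print(mask)
# tensor([[1, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0],
#         [1, 1, 1, 0, 0],
#         [0, 0, 0, 1, 0],
#         [0, 0, 0, 1, 1]], dtype=torch.int32)
print(input_pos)
# tensor([0, 1, 2, 0, 1])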

This is a highly requested feature in the community, as seen in many discussions in TRL, Hugging Face, and llama-recipes.

However, PyTorch Core's SDPA does not support flash attention with a non-causal mask (there have been several discussions on this). This means that to do sample packing properly, we need to either turn to other implementations or use PyTorch's memory-efficient SDPA.

The Approach

Mem-efficient SDPA is not as fast as flash v2 (2x slower), so we could turn to Dao's original flash attention implementation, which supports varied sample lengths. Since torchtune does not use fused QKV (to support MQA and GQA), we would use flash_attn_varlen_func as an alternative to SDPA. This enables between-sample masking via cumulative sequence lengths in the batch.
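For reference, a rough sketch of how that API would be used (assuming the third-party flash-attn package; this is not what this PR ships): per-sample lengths in a pack become cumulative sequence lengths, and causal=True then applies only within each sample.

import torch
from flash_attn import flash_attn_varlen_func  # third-party dependency, not added in this PR

# One pack holding three samples of lengths 3, 5, and 2 (10 tokens total).
seq_lens = torch.tensor([3, 5, 2], dtype=torch.int32, device="cuda")
cu_seqlens = torch.nn.functional.pad(torch.cumsum(seq_lens, dim=0), (1, 0)).to(torch.int32)
# cu_seqlens == tensor([0, 3, 8, 10]) -> sample boundaries within the pack

num_heads, head_dim = 8, 64
q = torch.randn(10, num_heads, head_dim, dtype=torch.float16, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)

# Causal attention is applied within each sample; the boundaries in cu_seqlens
# prevent attention across samples.
out = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
    max_seqlen_q=int(seq_lens.max()), max_seqlen_k=int(seq_lens.max()),
    causal=True,
)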

However, adding a third-party dependency, especially in such a core module, should not be taken lightly: it invites risk of breakages, lack of support, lack of control, etc. We need to evaluate whether the improved performance significantly outweighs this tradeoff.

Changelog

In this update, to unblock sample packing, I decided to stick with mem-efficient SDPA + a proper packing mask. More benchmarking of Dao's flash attention against mem-efficient attention needs to be done to understand the tradeoffs, which I will do as a follow-up. Here is a summary of the updates:

  • For proper masking and position encoding, I took advantage of the mask and input_pos kwargs in the transformer layers. This was enough to cover both without requiring any additional conversion or mask calculations, while still maintaining minimal changes to model forward signature.
  • Leaned on PackedDataset to create the lower triangular block mask and position ids for each subsample within a pack, since it's much more efficient to create them while packing, alongside the tokens and labels. Inferring the mask and position ids batchwise is expensive and requires iterating through the whole batch again.
  • This means PackedDataset now returns tokens, labels, mask, and input_pos. To generalize this for all datasets, all dataset classes now return a dictionary. All of them will still just have "tokens" and "labels" keys, but PackedDataset will additionally return "mask" and "input_pos". This also gives new datasets the flexibility to return mask and input_pos if needed.
  • Because the dataset return signature changes, all recipes needed to be updated to key into the batch accordingly (just a few lines)
  • TransformerDecoder forward can now take an optional mask. This cannot be used during inference, when the causal mask is used. The mask is propagated down to the attention module, so the packing mask can be used in SDPA.
  • RoPE positional embeddings can use input_pos as a 2D tensor that holds position ids for each sample in each pack in a batch, with a very tiny change (see the sketch after this list)
  • Added a positional embedding test for the 2D input_pos case
  • Collating happens at the sample level for packed datasets with _padded_collate_packed, so we do not use a collator in the dataloader if the dataset is packed
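As a rough illustration of the RoPE change (cache shape assumed for illustration, not the exact torchtune module), a 2D input_pos simply indexes the precomputed rotary cache per token in each pack:

import torch

max_seq_len, head_dim = 16, 8
rope_cache = torch.randn(max_seq_len, head_dim // 2, 2)  # precomputed cos/sin pairs (assumed layout)

# Batch of 2 packs, each of length 5, with position ids restarting per sample.
input_pos = torch.tensor([
    [0, 1, 2, 0, 1],   # pack 1: samples of length 3 and 2
    [0, 1, 0, 1, 2],   # pack 2: samples of length 2 and 3
])

# Indexing with a 2D tensor returns [bsz, seq_len, head_dim // 2, 2]: the correct
# rotary angles for every token, regardless of where its sample sits in the pack.
selected = rope_cache[input_pos]
print(selected.shape)  # torch.Size([2, 5, 4, 2])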

Test plan

tune run full_finetune_single_device --config llama3/8B_full_single_device epochs=1 dataset.packed=True dataset.max_seq_len=4096

Sample packing improves performance in proportion to max sequence length - at 4096, tokens/sec were at about 1800, which is nearly 6x the unpacked version at a higher max sequence length of 8192. This effect may be exaggerated the more long-tailed the dataset's sequence length distribution is. The longer sequence length (4096) also brought run time down to about 1 hour per epoch, compared to 1 hr 45 min at 512 when packed and nearly 5-6 hours at 512 when not packed. Packing the 52k Alpaca samples takes only 40 seconds.

Loss curves across all packed and unpacked runs align.


Evals are nearly identical for packed runs vs non-packed runs.


run truthfulqa stderr hellaswag stderr
llama3_alpaca_unpacked 0.4589353557 0.01488350256 0.5846444931 0.004917761182
llama3_alpaca_packed 0.490953071 0.01462141724 0.5734913364 0.00493558773
llama3_alpaca_packed_no_mask 0.4783502665 0.01494382987 0.5749850627 0.004933349622
llama3_alpaca_packed_no_pos 0.4909587427 0.01462954265 0.5723959371 0.00493719976
llama3_alpaca_packed_no_mask_no_pos 0.4798960754 0.014649153 0.5780720972 0.004928578106

Huge shoutout to @calmitchell617 for initiating this in #827, I've incorporated some of your logic here for a basic sample packing feature.

What is sample packing?

Sample packing maximizes use of the sequence length / context window by fitting in as many samples as possible, so we don't waste compute on padding tokens. This leads to much faster training because the model can process more data with fewer forward passes. However, there's usually slower startup if you perform packing on your dataset prior to training.

There are ways to do this offline with a bin packing algorithm, or on-the-fly, but for now I take the naive approach of greedily packing samples as a preprocessing step when initializing the dataset (see the sketch below). This will need to evolve further once we support IterableDatasets and streaming.
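A minimal sketch of that greedy preprocessing step (illustrative only; names are simplified and this is not the actual PackedDataset code):

from typing import Dict, List

def greedy_pack(
    samples: List[Dict[str, List[int]]], max_seq_len: int
) -> List[Dict[str, List[int]]]:
    """Greedily concatenate tokenized samples into fixed-size packs."""
    packs: List[Dict[str, List[int]]] = []
    buffer: Dict[str, List[int]] = {"tokens": [], "labels": []}
    for sample in samples:
        for key in buffer:
            buffer[key].extend(sample[key])
        # Whenever the buffer fills a full pack, emit it and carry the overflow
        # into the next pack (i.e. samples may be split across pack boundaries).
        while len(buffer["tokens"]) >= max_seq_len:
            packs.append({k: v[:max_seq_len] for k, v in buffer.items()})
            buffer = {k: v[max_seq_len:] for k, v in buffer.items()}
    if buffer["tokens"]:
        packs.append(buffer)  # last, partially filled pack (padded later)
    return packs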

How is shuffling handled?

The DistributedSampler will still handle shuffling in the recipe layer, so for a packed dataset only the packs are shuffled and not within a pack. We may need to add within-pack shuffling later, although it will be simpler with IterableDatasets.

How can I configure packing?

Simply set packed=True in the config. This gets routed to the builder, so all dataset builders are updated with this flag. This is preferable to having a recipe parameter, because that would require updating all configs.

dataset:
  _component_: torchtune.datasets.alpaca_dataset
  packed: True
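Or equivalently in Python (a sketch assuming the Llama3 tokenizer builder; the packed flag is the one added in this PR):

from torchtune.datasets import alpaca_dataset
from torchtune.models.llama3 import llama3_tokenizer  # any torchtune tokenizer works; llama3 assumed here

tokenizer = llama3_tokenizer("/path/to/tokenizer.model")
# packed=True enables packing; max_seq_len bounds the size of each pack
ds = alpaca_dataset(tokenizer=tokenizer, max_seq_len=4096, packed=True)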

pytorch-bot bot commented Apr 26, 2024

🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/875

✅ No Failures

As of commit 97e69f4 with merge base f3611e5:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the "CLA Signed" label Apr 26, 2024
@xingyaoww (Contributor) left a comment

Super happy to see this project :) I wanted to find a codebase that's simple, usable, with reasonable performance.
I think torchtune pretty much satisfies most of my needs except sentence packing (which you already did in this PR!) & integrating more models (e.g., Mixtral). Would be excited to see this PR go through!

self.ds, desc="Packing dataset", dynamic_ncols=True
):
buffer["input_ids"].extend(input_ids)
buffer["labels"].extend(labels)
Contributor

I wonder if we want to optionally allow the user to add a separator (e.g., <eod>) for packing (e.g., via an argument)?

Contributor Author

The EOS tokens should serve as inherent separators, no? And your other comment about creating a sentence mask is a preferable approach imo

Contributor

sentence_mask is the ultimate solution, I agree! With that we probably don't need this separator at all.

# If buffer has reached max_seq_len, append packed sample
while len(buffer["input_ids"]) > self.max_seq_len:
    self.samples.append(
        {k: v[: self.max_seq_len] for k, v in buffer.items()}
Contributor

Maybe it would be good to support packing examples together only when all of them can fit into ONE window.
If adding an example would exceed the context window, maybe we should NOT add that example to the current pack, pad the rest of the current pack with <PAD>, and move that example to the next pack? This would be very helpful for people doing SFT (fewer examples, but they still get the benefit of packing without the need to truncate any sentences).

I did something like I described above here - feel free to re-use some part of the logic if interested: https://github.com/xingyaoww/Megatron-LLM/blob/main/tools/preprocess_instruct_data.py#L148-L194

Contributor Author

Excellent point. I agree, I think we should do this by default for finetune. Although, for pretraining datasets / unstructured text data maybe we don't need to perform padding and it's ok to split samples?

Contributor

Agree! For pre-training, we can just go with the current implementation!

Contributor Author

I'll add a flag then to control this!

buffer["labels"].extend(labels)

# If buffer has reached max_seq_len, append packed sample
while len(buffer["input_ids"]) > self.max_seq_len:


A relatively unimportant point, but I believe this > could be >=, which may end up creating an extra sample.

Contributor Author

Hmm let me double check this

@calmitchell617

@RdoubleA this looks great. Your implementation of the packing logic is much cleaner than mine :-)

One thing you might want to add is a test that validates how individual tokens are packed. This was easy for me to do because my code had complete control over initialization. It might be more complex now because PackedDataset is integrated with other dataset classes, but should still be manageable.

This kind of test might be overkill now, but further down the line when someone else is working on this class (or other dataset classes that integrate with it) it might be a helpful check.

Here is how I did it:

from unittest import mock

from datasets import Dataset

# Note: the module paths below are assumed from the patched target and torchtune's
# test utilities; adjust to wherever ConcatDataset/DummyTokenizer live in your branch.
from tests.test_utils import DummyTokenizer
from torchtune.datasets._concat import ConcatDataset


class TestInstructDataset:
    expected_tokenized_prompts = [
        [0, 4, 2, 1, 7, 4, -1, 0, 1, 9],
        [5, 2, 6, 4, 3, 8, -1, 0, 4, 3],
    ]

    def get_samples(self):
        samples_list = [
            "This is a packing test",
            "A fantastic test. It should pack two samples.",
            "This one will not be fully packed.",
        ]

        samples_dict = {"content": samples_list}

        return Dataset.from_dict(samples_dict)

    @mock.patch("torchtune.datasets._concat.load_dataset")
    def test_get_item(self, mock_load_dataset):
        mock_load_dataset.return_value = self.get_samples()
        dataset = ConcatDataset(
            tokenizer=DummyTokenizer(),
            source="were/going/jellyfishing",
            text_column="content",
            max_seq_len=10,
        )
        assert len(dataset) == 2
        mock_load_dataset.assert_called_once()

        for i in range(len(dataset)):
            prompt, label = dataset[i]
            assert prompt == self.expected_tokenized_prompts[i]
            assert label == self.expected_tokenized_prompts[i]

@RdoubleA (Contributor Author) commented Apr 26, 2024

@calmitchell617 great suggestion, I was actually planning to borrow some of your testing logic so I'll try to include this. Also I'm glad someone else caught my hidden SpongeBob reference 👌

Overall, does this unblock what you wanted to achieve with your earlier PR? Or you need streaming first?

@calmitchell617

I do not need streaming, this unblocks my use case.

@RdoubleA RdoubleA changed the title [WIP] Sample packing for map datasets Sample packing for map datasets Apr 27, 2024
@ebsmothers (Contributor) left a comment

This is awesome! Left a handful of comments but no huge concerns from my side

if split_samples:
    # If we split samples, we'll know how many samples by taking the
    # full length and dividing by sample size
    last_index, remainder = divmod(max_rows * max_seq_len, sample_size)
Contributor

Did not even know divmod was a thing. Nice

Contributor

Actually could also use math.ceil(max_rows * max_seq_len / sample_size) here I think? Maybe that's clearer tbh

Comment on lines 75 to 78
raise ValueError(
    f"Dataset sample is too long ({len(input_ids)} > {self.max_seq_len}). "
    "Please set `split_samples=True` or increase `max_seq_len`."
)
Contributor

We should think about whether this is the right thing to do here. Without sample packing we wouldn't error out here, right? We would just truncate. I wonder if it makes sense to do the same here.

Contributor Author

hm yeah that makes sense to me

current_pack["input_ids"].extend(input_ids)
current_pack["labels"].extend(labels)

if len(current_pack["input_ids"]) > self.max_seq_len:
Contributor

So where did we land on the > vs >= thing here? I still don't understand why this isn't >= personally

@kartikayk (Contributor) Apr 27, 2024

Maybe this is just me, but I find the logic a bit roundabout. Why not just check whether the current length of the pack + the length of the incoming sample is > max length? If it is, you write the current pack out; if not, you just add it. Maybe I'm missing some complexity?

Contributor Author

The nuance is when you split the sample: if you check current length + length of incoming, then you either write the current pack out or split the incoming sample up to max seq len and write the pack out. IMO that logic and the logic here are almost identical.

@kartikayk (Contributor) left a comment

As discussed offline, let's figure out masking correctly before landing this PR.

inputs and labels.
max_seq_len (int): Maximum number of tokens to pack
max_rows (Optional[int]): maximum number of samples to pack. Default is None, which will pack as many samples as possible.
split_samples (bool): if the last sample in a pack does not fit in ``max_seq_len``,
Contributor

I feel like split_across_instances or split_across_boundary would be more intuitive. If I just found this flag in a config, I'd interpret it as splitting all samples across instances i.e. no packing. Ignore this comment if this is an accepted convention in the community. If not, I'd suggest thinking a bit about this flag.

Contributor Author

I'll change it to split_across_pack

current_pack["input_ids"].extend(input_ids)
current_pack["labels"].extend(labels)

if len(current_pack["input_ids"]) > self.max_seq_len:
Copy link
Contributor

@kartikayk kartikayk Apr 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe this is just me, but I find the logic a bit round about. Why not just see if the current length of the pack + length of incoming sample is > max length or not? If it is, you write the current pack out, if not just add it. Maybe I'm missing some complexity?

)

previous_sample_boundary = len(current_pack["input_ids"])
if self.max_rows is not None and len(self.samples) >= self.max_rows:
Contributor

So max_rows is only respected if we have < max_seq_len tokens right? Then why have that option at all? User can just reduce the seq len to get fewer samples? Or why would they want fewer samples?

Contributor Author

Not exactly - max_rows lets users limit how many total packs are returned in the dataset. I imagine this will be a lot more relevant for streamed / iterable datasets. max_seq_len controls the size of each individual pack; reducing max_seq_len while keeping the dataset size the same will actually result in more samples.


Yes, I agree with @RdoubleA. Max rows is how many packed rows you end up with. Most users will choose None for datasets that fit in RAM, but will choose some lesser value for datasets that would otherwise OOM. A very nice helper arg might be to calculate max_rows for you. Not needed for now.

Contributor

Oh hmm, then I'm not sure this is clear in the doc string

max_rows (Optional[int]): maximum number of samples to pack. Default is None, which will pack as many samples as possible.

This leads me to believe this param controls the number of samples read in, not the number of packs returned? Or am I misunderstanding?


I wrote that comment.

The intent is to signal that this arg controls the total number of packed samples you will end up with at the end of the process. However, if you don't understand it, then others won't either, so the phrasing should probably be revised for clarity.

@RdoubleA RdoubleA changed the title Sample packing for map datasets [WIP] Sample packing for map datasets May 1, 2024
@RdoubleA RdoubleA changed the title [WIP] Sample packing for map datasets Sample packing for map datasets, correct RoPE encoding and attention masking May 6, 2024
@RdoubleA (Contributor Author) commented May 7, 2024

Do we handle serialization of the packed dataset? Running the packing only on rank 0?

These are great suggestions for further optimizations that I will leave as follow ups :)

shouldn't we compute the block triu on the fly so we don't store this?

This is a good point, but where would you suggest creating this mask? This would require keeping track of sequence lengths if we create it outside PackedDataset.

@tcapelle (Contributor) commented May 7, 2024

There is a tradeoff: pre-computing them is faster, but if we serialize them it would be very memory expensive. We could compute them from the position ids, maybe storing them as ids like [1, 1, 1, 2, 2, 2, 3] and then constructing the block triu.
We are already touching the TransformerDecoderLayer, maybe there? Or when creating the batch with a DataCollator?
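Something like this minimal sketch (names assumed) would rebuild the mask from stored ids:

import torch

# Per-token sample ids for one pack, e.g. three samples of lengths 3, 3, 1.
sample_ids = torch.tensor([1, 1, 1, 2, 2, 2, 3])

same_sample = sample_ids.unsqueeze(0) == sample_ids.unsqueeze(1)  # block-diagonal
causal = torch.tril(torch.ones(len(sample_ids), len(sample_ids), dtype=torch.bool))
mask = same_sample & causal  # lower-triangular blocks, one per sample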

I may be missing something, maybe @winglian can suggest otherwise.

Also, maybe renaming to position_ids instead of input_pos ❤️?

  • Another thing is that we need a way of running this on rank 0 only and then propagating to the other ranks. Sometimes packing a big dataset can take quite a lot of time, and we are also tokenizing; ideally we don't waste that time on subsequent runs.
  • Can we still train with packing but without the positional masking, i.e. attending to other samples? I am still not convinced that this is bad, even if it makes sense theoretically.

@RdoubleA (Contributor Author) commented May 7, 2024

Or when creating the batch with a DataCollator?

I debated this exact thing when trying to figure out where to create the mask and position ids; I was mainly focused on simplicity and speed. If we have to recreate the mask in TransformerDecoderLayer or the collator, we need to loop through the entire batch and the seqlens for each pack, which is quite inefficient, though it might be worth benchmarking the time difference. But you're right that this would be much cheaper memory-wise.

Also, maybe renaming positon_ids instead of input_pos ❤️ ?

This was the existing name, so I'd rather not add a refactor on top of all the changes here :)

Can we still train with packing and without the positional masking? I mean attending to other samples?

My opinion is that this is incorrect and should not be supported. Samples in an SFT dataset are typically highly correlated and should not cross-attend. But others are welcome to add their thoughts on this.

@tcapelle (Contributor) commented May 8, 2024

This is more regarding the naming of the padding functions

Collating happens at the sample level for packed datasets with _padded_collate_packed, so we do not use a collator in the dataloader if dataset is packed

When I read this, I got the idea that PackedDataset is already returning batches of data that don't need collation. I feel these functions should be named packed_padding or something like that, since collation, for me at least, means constructing a batch, i.e. cat([sample1, sample2]).

I am probably not seeing something, so take this comment with a grain of salt

@RdoubleA (Contributor Author) commented May 8, 2024

When I read this, I got the idea that the PackedDataset is already returning batches of data that don't need collation

Hm, I see what you mean; technically this is not a collator, _padded_collate_packed is just handling padding. Maybe this should be moved to a utils file.

@winglian (Collaborator) commented May 8, 2024

I debated this exact thing when trying to figure out where to create the mask and position ids; I was mainly focused on simplicity and speed. If we have to recreate the mask in TransformerDecoderLayer or the collator, we need to loop through the entire batch and the seqlens for each pack, which is quite inefficient, though it might be worth benchmarking the time difference. But you're right that this would be much cheaper memory-wise.

Another option is building it in the model forward rather than each decoder layer. This way it only needs to be generated once.

mask = torch.block_diag(mask, mask_pad)
# For position ids, continue to increment for pad tokens
input_pos_pad = torch.arange(
input_pos[-1] + 1, max_seq_len - len(input_pos) + input_pos[-1] + 1
Contributor

nit: a bit confusing; maybe just define next_pos = input_pos[-1] + 1 on a separate line to make the arange clearer

input_pos[-1] + 1, max_seq_len - len(input_pos) + input_pos[-1] + 1
)
# Do not go beyond max_seq_len - 1
input_pos_pad = input_pos_pad.clamp(max=max_seq_len - 1)
Contributor

I don't fully follow this. Does this only happen when split_across_pack=True or something? Also in that case where are we truncating the tokens and labels? Is it in the base dataset? If so it's a bit confusing to have it spread out across three places (though admittedly I don't really see a better way since base dataset obviously shouldn't care about input_pos)

Contributor Author

We are not truncating - the creation of the packs is what "truncates" to a set max_seq_len. If we split across packs, there is an edge case for the last sample when it bleeds over into the next pack and the position ids continue to the end of that sample (say the sample had length 5 and max_seq_len is 8, [..., 0, 1, 2] ended up in the previous pack, and the rest of the sample was split into the next pack as [3, 4]; we still have 6 slots left, so how do we pad position ids without going beyond max_seq_len?). So I just enforced the clamp; these position ids should be ignored by the loss anyway. But the better solution is a padding mask.

Comment on lines 205 to 207
# shape: [b, 1, s, s]
if mask is not None:
    mask = mask[:, None, :, :]
Contributor

Thoughts on doing this in attention.py? Cause (a) that's where we actually need it to be this shape, and (b) then we do not have separate contracts on mask shape in TransformerDecoder and TransformerDecoderLayer

Contributor Author

yes, I think that makes sense

@RdoubleA RdoubleA changed the title Sample packing for map datasets, correct RoPE encoding and attention masking Sample packing for map datasets with correct RoPE encoding and no cross-contamination May 9, 2024
@codecov-commenter

Codecov Report

Attention: Patch coverage is 21.50943% with 208 lines in your changes missing coverage. Please review.

Project coverage is 26.85%. Comparing base (a978956) to head (3be7369).
Report is 32 commits behind head on main.

Files Patch % Lines
tests/torchtune/datasets/test_packed_dataset.py 28.75% 57 Missing ⚠️
torchtune/datasets/_packed.py 21.53% 51 Missing ⚠️
...ests/torchtune/modules/test_position_embeddings.py 25.00% 12 Missing ⚠️
recipes/full_finetune_single_device.py 0.00% 11 Missing ⚠️
recipes/lora_finetune_distributed.py 0.00% 11 Missing ⚠️
recipes/full_finetune_distributed.py 0.00% 10 Missing ⚠️
recipes/lora_finetune_single_device.py 0.00% 10 Missing ⚠️
torchtune/modules/transformer.py 16.66% 5 Missing ⚠️
tests/torchtune/modules/test_attention.py 0.00% 4 Missing ⚠️
torchtune/modules/attention.py 0.00% 4 Missing ⚠️
... and 19 more
Additional details and impacted files
@@             Coverage Diff             @@
##             main     #875       +/-   ##
===========================================
- Coverage   66.39%   26.85%   -39.54%     
===========================================
  Files         155      176       +21     
  Lines        6484     7532     +1048     
===========================================
- Hits         4305     2023     -2282     
- Misses       2179     5509     +3330     

☔ View full report in Codecov by Sentry.

from torchtune.modules.tokenizers import Tokenizer


def alpaca_dataset(
    tokenizer: Tokenizer,
    *,
Contributor

nit: be consistent here. In some builders you put the kw-only delineator as the first arg, and in others you put it after tokenizer. Just pick one; either works.

Contributor Author

Hm, for the instruct and chat dataset builders there are many more required positional arguments, which is why I put the asterisk at the beginning. I can just put it at the beginning for all of them.

@joecummings (Contributor) left a comment

One nit regarding consistency where we place the kw-only delineator in the dataset builders, but otherwise this looks amazing!

@RdoubleA is the hero we don't deserve 🫡

@rohan-varma (Member) left a comment

Amazing work, this is probably the best PR I've seen in torchtune!


# If the current pack is long enough, add it to self.packs and retain
# any truncated samples for next pack, if splitting samples
if len(current_pack["tokens"]) > self.max_seq_len:
Member

Curious, why not >=?

ds (Dataset): dataset to sample pack. This should return a dictionary with field
"tokens" and "labels" containing the tokenized and label samples.
max_seq_len (int): Maximum number of tokens to pack
max_packs (Optional[int]): maximum number of packs. Default is None, which will create as many
Member

Do we have the appropriate error checking here, and do we really need to expose this? What if I specify max_packs as something silly like 1 or 2, but packing the dataset into 2 packs isn't possible because it exceeds max_seq_len?

@RdoubleA (Contributor Author) May 11, 2024

It won't pack the entire dataset into 2 packs; it will just create 2 packs with whatever samples fit and then drop the rest.

max_packs may become more relevant for iterable datasets, but maybe it's not as useful right now. Will revisit this. This was originally added by @calmitchell617 so it may have been needed for his use case?


# Add the last pack with remaining samples that did not fit in previous
if len(current_pack["tokens"]) > 0 and (
self.max_packs is None or len(self.packs) < self.max_packs
Member

Don't fully understand why we need max_packs and if we don't need it, this code might become simpler.

@rohan-varma (Member) left a comment

General concern about data parallelism -

  1. When running sample packing, will each rank read different data, to ensure the generated packs are still correctly sharded across ranks and ranks don't see duplicated data (and no data is dropped)?
  2. Assuming each rank is indeed packing different data as expected, there seems to be a potential for self.packs to end up with different sizes across ranks in some situations. For example, if rank 0 encounters 2 max_seq_len samples, it might pack them into 2 packs, while rank 1 encounters a bunch of smaller samples and packs them into 1. In general I don't see any guarantee that the # of packs is the same across ranks.

if rank == 0:
    pbar = tqdm(total=len(self.ds), desc="Packing dataset", dynamic_ncols=True)

for batch in self.ds:
Member

Is this iteration rank / distributed aware? i.e. will each data parallel rank read different data to pack?

Contributor Author

as discussed offline, you're right that it may be worth using DistributedSampler here so that packing is partitioned across ranks (cc @tcapelle who also suggested this)

@RdoubleA (Contributor Author)

There seems to be a potential for self.packs to end up with different sizes across ranks in some situations. For example, if rank 0 encounters 2 max_seq_len samples, it might pack them into 2 packs, while rank 1 encounters a bunch of smaller samples and packs them into 1. In general I don't see any guarantee that the # of packs is the same across ranks.

This is a great point and something we'll need to consider anyway when we move to iterable datasets. I need to think about this more, @gokulavasan had some suggestions on this so let me connect with him offline

@kartikayk (Contributor) left a comment

Thanks for pushing this through!

@RdoubleA RdoubleA merged commit 1e8d081 into pytorch:main May 15, 2024
29 checks passed
@RdoubleA RdoubleA deleted the packing branch May 15, 2024 04:43