Pruning crash at iteration 592. #32

Open
lippman1125 opened this issue Dec 6, 2023 · 6 comments

lippman1125 commented Dec 6, 2023

@xiamengzhou
[batch=592/3200]
Train time/batch: 591
Train time/sample: 18912
Train time/batch_in_epoch: 591
Train time/sample_in_epoch: 18912
Train time/token: 77463552
Train time/token_in_epoch: 77463552
Train metrics/train/cc_weight: 0.2292
Train metrics/train/github_weight: 0.0121
Train metrics/train/book_weight: 0.0220
Train metrics/train/stackexchange_weight: 0.0059
Train metrics/train/wiki_weight: 0.5933
Train metrics/train/arxiv_weight: 0.0038
Train metrics/train/c4-rp_weight: 0.1336
Train memory/current_allocated_mem: 14.6140
Train memory/current_active_mem: 14.6140
Train memory/current_inactive_mem: 1.9265
Train memory/current_reserved_mem: 43.4220
Train memory/peak_allocated_mem: 28.0710
Train memory/peak_active_mem: 28.0710
Train memory/peak_inactive_mem: 11.7290
Train memory/peak_reserved_mem: 43.4220
Train memory/alloc_retries: 0
Train metrics/train/expected_head_sparsity: 0.3583
Train metrics/train/target_head_sparsity: 0.3463
Train metrics/train/expected_intermediate_sparsity: 0.3196
Train metrics/train/target_intermediate_sparsity: 0.3436
Train metrics/train/expected_layer_sparsity: 0.0039
Train metrics/train/target_layer_sparsity: 0.0000
Train metrics/train/expected_hidden_sparsity: 0.4266
Train metrics/train/target_hidden_sparsity: 0.3463
Train metrics/train/expected_sparsity: 0.6188
Train metrics/train/target_sparsity: 0.5616
Train trainer/device_train_microbatch_size: 4
Train loss/train/total: 3.5578
Train loss/train/ce_loss: 2.8953
Train loss/train/lag_loss: 0.6625
Train metrics/train/LanguageCrossEntropy: 2.8953
Train metrics/train/Perplexity: 18.0886
Train metrics/train/cc_LanguageCrossEntropy: 3.0387
Train metrics/train/cc_count: 9884
Train metrics/train/github_LanguageCrossEntropy: nan
Train metrics/train/github_count: 652
Train metrics/train/book_LanguageCrossEntropy: nan
Train metrics/train/book_count: 712
Train metrics/train/stackexchange_LanguageCrossEntropy: nan
Train metrics/train/stackexchange_count: 236
Train metrics/train/wiki_LanguageCrossEntropy: 2.7964
Train metrics/train/wiki_count: 4011
Train metrics/train/arxiv_LanguageCrossEntropy: nan
Train metrics/train/arxiv_count: 267
Train metrics/train/c4-rp_LanguageCrossEntropy: 3.1243
Train metrics/train/c4-rp_count: 3182
Train throughput/batches_per_sec: 0.1329
Train throughput/samples_per_sec: 4.2523
Train throughput/device/batches_per_sec: 0.0166
Train throughput/device/samples_per_sec: 0.5315
Train throughput/tokens_per_sec: 17417.3748
Train throughput/device/tokens_per_sec: 2177.1719
Train throughput/flops_per_sec: 816440956730026.0000
Train throughput/device/flops_per_sec: 102055119591253.2500
Train time/train: 1.2715
Train time/val: 0.6538
Train time/total: 1.9253
Traceback (most recent call last):
  File "/llm-shearing//llmshearing/train.py", line 317, in <module>
    main(cfg)
  File "/llm-shearing//llmshearing/train.py", line 301, in main
    trainer.fit()
  File "/pyenv/py310-shear/lib/python3.10/site-packages/composer/trainer/trainer.py", line 1876, in fit
    self._train_loop()
  File "/pyenv/py310-shear/lib/python3.10/site-packages/composer/trainer/trainer.py", line 2018, in _train_loop
    for batch_idx, self.state.batch in enumerate(self._iter_dataloader(TrainerMode.TRAIN)):
  File "/pyenv/py310-shear/lib/python3.10/site-packages/composer/trainer/trainer.py", line 3024, in _iter_dataloader
    batch = next(dataloader_iter)
  File "/pyenv/py310-shear/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 630, in __next__
    data = self._next_data()
  File "/pyenv/py310-shear/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 674, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/pyenv/py310-shear/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 32, in fetch
    data.append(next(self.dataset_iter))
  File "/llm-shearing/llmshearing/datasets/streaming_dataset.py", line 392, in __iter__
    domain_sample_id = domain_sample_id[self.used_num_samples_per_stream[stream_id]
IndexError: index 552 is out of bounds for axis 0 with size 552

To reproduce:
1. Prepare the data for pruning as described in the paper.
2. Run the following command:
   /bin/bash llmshearing/scripts/prunning.sh

@lippman1125 (Author)

Line 392 of llmshearing/datasets/streaming_dataset.py

        sample_ids_per_stream = self._get_work(world, epoch, used_sample_ids)
        # Currently only supports dynamically loading data from each domain for once. 
        # Issues could occur if one domain of data is used up. 
        while True:
            proportion = self.proportion
            stream_id = np.random.choice(range(self.num_streams), 1, p=proportion)[0].item()
            domain_sample_id = sample_ids_per_stream[stream_id]
            domain_sample_id = domain_sample_id[self.used_num_samples_per_stream[stream_id] \
                                % self.samples_per_stream[stream_id]]
            self.used_num_samples_per_stream[stream_id] += 1
            yield self[domain_sample_id]

I think

domain_sample_id = domain_sample_id[self.used_num_samples_per_stream[stream_id] \
                                % self.samples_per_stream[stream_id]]

should be changed to

domain_sample_id = domain_sample_id[self.used_num_samples_per_stream[stream_id] \
                                % sample_ids_per_stream[stream_id]]

because the size of sample_ids_per_stream[stream_id] is smaller than self.samples_per_stream[stream_id], so the modulo by the global count never wraps the index back into the range of the per-rank array.
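
A minimal standalone sketch of this failure mode (the sizes are made up to match the error message; samples_per_stream and domain_sample_id mimic the variables above):

    import numpy as np

    # Hypothetical sizes: the stream holds 2208 samples globally, but this
    # rank/worker was assigned only 552 of them.
    samples_per_stream = 2208           # global count, the current denominator
    domain_sample_id = np.arange(552)   # per-rank slice of sample IDs

    used = 552                          # the 553rd draw from this stream
    idx = used % samples_per_stream     # 552 % 2208 == 552, so no wrap-around
    domain_sample_id[idx]               # IndexError: index 552 is out of bounds
                                        # for axis 0 with size 552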


lippman1125 commented Dec 6, 2023

This code:

    domain_sample_id = sample_ids_per_stream[stream_id]

always returns the same sample ID list.

Shouldn't len(sample_ids_per_stream[stream_id]) equal self.samples_per_stream[stream_id] / gpu_num / worker_num? (A rough check follows the debug log below.)

I also found that it is always rank 0 that fetches data. Debug info as follows:

rank=0, num_ranks=4, ranks_per_node=4
proportion=[0.67, 0.045, 0.045, 0.02, 0.045, 0.025, 0.15], stream_id=0
self.used_num_samples_per_stream[0]=200
self.samples_per_stream[0]=65431
1 domain_sample_id size=16360
@ domain_sample_id =[  257    95   275 ... 32631    -1    -1]
2 domain_sample_id=642
rank=0, num_ranks=4, ranks_per_node=4
proportion=[0.67, 0.045, 0.045, 0.02, 0.045, 0.025, 0.15], stream_id=0
self.used_num_samples_per_stream[0]=201
self.samples_per_stream[0]=65431
1 domain_sample_id size=16360
@ domain_sample_id =[  257    95   275 ... 32631    -1    -1]
2 domain_sample_id=1012
rank=0, num_ranks=4, ranks_per_node=4
proportion=[0.67, 0.045, 0.045, 0.02, 0.045, 0.025, 0.15], stream_id=4
self.used_num_samples_per_stream[4]=16
self.samples_per_stream[4]=4394
1 domain_sample_id size=1104
@ domain_sample_id =[76176 76205 76178 ...    -1    -1    -1]
2 domain_sample_id=76208
rank=0, num_ranks=4, ranks_per_node=4
proportion=[0.67, 0.045, 0.045, 0.02, 0.045, 0.025, 0.15], stream_id=0
self.used_num_samples_per_stream[0]=202
self.samples_per_stream[0]=65431
1 domain_sample_id size=16360
@ domain_sample_id =[  257    95   275 ... 32631    -1    -1]
2 domain_sample_id=717
rank=0, num_ranks=4, ranks_per_node=4
proportion=[0.67, 0.045, 0.045, 0.02, 0.045, 0.025, 0.15], stream_id=0
self.used_num_samples_per_stream[0]=203
self.samples_per_stream[0]=65431
1 domain_sample_id size=16360
@ domain_sample_id =[  257    95   275 ... 32631    -1    -1]
2 domain_sample_id=987
rank=0, num_ranks=4, ranks_per_node=4
proportion=[0.67, 0.045, 0.045, 0.02, 0.045, 0.025, 0.15], stream_id=0
self.used_num_samples_per_stream[0]=204
self.samples_per_stream[0]=65431
1 domain_sample_id size=16360
@ domain_sample_id =[  257    95   275 ... 32631    -1    -1]
2 domain_sample_id=729
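
A rough check of the sizes in this log (the assumed relationship is my guess, not confirmed from the code):

    import math

    # Numbers taken from the debug log above for stream 0.
    samples_per_stream = 65431
    num_ranks = 4
    print(math.ceil(samples_per_stream / num_ranks))  # 16358

The observed per-rank size is 16360, slightly larger than 16358, so the partitioner appears to pad each rank's array up to a common size (the trailing -1 entries in the log look like that padding), and samples_per_stream / gpu_num / worker_num only approximates the real size.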

@lippman1125 (Author)

@xiamengzhou Can you help me look at this issue?

@xiamengzhou (Contributor)

Hi -- thanks for bringing this to our attention. I think you are correct! However:

  • sample_ids_per_stream[stream_id] returns a numpy array, so I don't think it should be used as the denominator.
  • The current implementation only supports single-worker logic, because shared memory for the used IDs and proportion updates across workers are not implemented.
  • The current implementation is also slightly wrong in that the denominator should be the size of the data assigned to the current rank. But it should not be problematic as long as no stream's data is exhausted.

Let me know if it helps!
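
To illustrate the first bullet, a tiny sketch (values made up) of why the numpy array itself can't serve as the denominator:

    import numpy as np

    # Python's % broadcasts elementwise over a numpy array instead of
    # producing the single scalar index the modulo is meant to compute.
    ids = np.array([257, 95, 275])
    print(552 % ids)       # array([38, 77,  2]) -- an array, not one index
    print(552 % len(ids))  # 0 -- a scalar that safely wraps into range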

@PengWenChen

Hi @xiamengzhou,
I also encountered the same issue at the same index:
IndexError: index 552 is out of bounds for axis 0 with size 552

As the code comment says, "Issues could occur if one domain of data is used up."
However, I used the same amount of data (0.4B) and the same initial proportion as the paper.
Why does the data get exhausted, and how can I avoid exhausting any stream's data?
Thank you.


PengWenChen commented Jan 9, 2024

Thanks to @lippman1125's advice, I agree with the modification of domain_sample_id.
I changed it from

    domain_sample_id = domain_sample_id[self.used_num_samples_per_stream[stream_id] % self.samples_per_stream[stream_id]]

to

    domain_sample_id = domain_sample_id[self.used_num_samples_per_stream[stream_id] % len(sample_ids_per_stream[stream_id])]

and it works now.
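
For reference, a self-contained sketch of the fixed sampling loop (the stream sizes, seed, and driver at the bottom are hypothetical; the names mirror streaming_dataset.py):

    import numpy as np

    rng = np.random.default_rng(0)
    proportion = [0.67, 0.045, 0.045, 0.02, 0.045, 0.025, 0.15]
    per_rank_sizes = [16360, 1104, 1112, 500, 1104, 620, 3680]  # hypothetical
    sample_ids_per_stream = [rng.permutation(n) for n in per_rank_sizes]
    used_num_samples_per_stream = [0] * len(proportion)

    def sample_ids():
        while True:
            stream_id = rng.choice(len(proportion), p=proportion)
            ids = sample_ids_per_stream[stream_id]
            # The fix: wrap by the length of the per-rank array actually
            # being indexed, not by the global samples_per_stream count.
            idx = used_num_samples_per_stream[stream_id] % len(ids)
            used_num_samples_per_stream[stream_id] += 1
            yield ids[idx]

    # Drawing far more samples than any stream holds no longer raises IndexError.
    it = sample_ids()
    for _ in range(100_000):
        next(it)
    print("ok")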
