diff --git a/train_gpt2.py b/train_gpt2.py index 403f213..8f46bb3 100644 --- a/train_gpt2.py +++ b/train_gpt2.py @@ -245,7 +245,7 @@ def next_batch(self): # advance the position in the tensor self.current_position += B * T * self.num_processes # if loading the next batch would be out of bounds, advance to next shard - if self.current_position + (B * T * self.num_processes + 1) > len(self.tokens): + if self.current_position + (B * T + 1) > len(self.tokens): self.current_shard = (self.current_shard + 1) % len(self.shards) self.tokens = load_tokens(self.shards[self.current_shard]) self.current_position = B * T * self.process_rank