From 3e96119c949396dc0a5135e55c2bbefa8590981f Mon Sep 17 00:00:00 2001
From: Steve Bako
Date: Tue, 16 Jul 2024 19:12:40 -0700
Subject: [PATCH] enable torch compile with ddp

---
 train_gpt2.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/train_gpt2.py b/train_gpt2.py
index 403f213..f5f08f7 100644
--- a/train_gpt2.py
+++ b/train_gpt2.py
@@ -339,9 +339,8 @@ def get_most_likely_row(tokens, mask, logits):
 model = GPT(GPTConfig(vocab_size=50304))
 # model = GPT.from_pretrained("gpt2") # or init from OpenAI GPT-2
 model.to(device)
-use_compile = False # torch.compile interferes with HellaSwag eval and Generation. TODO fix
-if use_compile:
-    model = torch.compile(model)
+model = torch.compile(model)
+torch._dynamo.config.optimize_ddp = False
 if ddp:
     model = DDP(model, device_ids=[ddp_local_rank])
 raw_model = model.module if ddp else model # always contains the "raw" unwrapped model
@@ -411,7 +410,7 @@ def get_lr(it):
                 torch.save(checkpoint, checkpoint_path)
 
     # once in a while evaluate hellaswag
-    if (step % 250 == 0 or last_step) and (not use_compile):
+    if (step % 250 == 0 or last_step):
         num_correct_norm = 0
         num_total = 0
         for i, example in enumerate(iterate_examples("val")):
@@ -444,7 +443,7 @@ def get_lr(it):
                 f.write(f"{step} hella {acc_norm:.4f}\n")
 
     # once in a while generate from the model (except step 0, which is noise)
-    if ((step > 0 and step % 250 == 0) or last_step) and (not use_compile):
+    if ((step > 0 and step % 250 == 0) or last_step):
         model.eval()
         num_return_sequences = 4
         max_length = 32
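
Note: below is a minimal, self-contained sketch of the pattern this patch applies: compile the model with torch.compile, set torch._dynamo.config.optimize_ddp = False to turn off Dynamo's bucket-aware DDP graph splitting (DDPOptimizer), then wrap the compiled model in DDP. The tiny Linear module, gloo backend, and single-process group are illustrative stand-ins for the GPT model and the multi-GPU torchrun launch, not part of the patch.

import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

# Single-process group so the sketch runs standalone; real training launches
# one process per GPU with torchrun and uses the nccl backend.
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

model = nn.Linear(8, 8)          # illustrative stand-in for the GPT model
model = torch.compile(model)     # compile first, as in the patch
# Disable Dynamo's DDPOptimizer, which otherwise splits the compiled graph
# at DDP gradient-bucket boundaries.
torch._dynamo.config.optimize_ddp = False
model = DDP(model)               # CPU/gloo here; pass device_ids on CUDA

out = model(torch.randn(2, 8))   # first forward triggers compilation
dist.destroy_process_group()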