-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtraining.py
1946 lines (1696 loc) · 85.1 KB
/
training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Pretrain utilities."""
import dataclasses
from datetime import datetime
import functools
import gc
import logging
import math
import os
import sys
from .log_handler import CustomHandler
# Make default logging level INFO, but filter out all log messages not from MCore.
logging.basicConfig(handlers=[CustomHandler()], level=logging.INFO)
from .theoretical_memory_usage import report_theoretical_memory
import time
# The earliest we can measure the start time.
_TRAIN_START_TIME = time.time()
import torch
from megatron.core import mpu, tensor_parallel
from megatron.core.utils import (
check_param_hashes_across_dp_replicas,
get_model_config,
StragglerDetector,
is_float8tensor,
)
from megatron.training.checkpointing import load_checkpoint
from megatron.training.checkpointing import save_checkpoint
from megatron.training.checkpointing import checkpoint_exists
from megatron.legacy.model import Float16Module
from megatron.core.distributed import DistributedDataParallelConfig
from megatron.core.distributed import DistributedDataParallel as DDP
try:
from megatron.core.distributed import TorchFullyShardedDataParallel as torch_FSDP
HAVE_FSDP2 = True
except ImportError:
HAVE_FSDP2 = False
from megatron.core.distributed import finalize_model_grads
from megatron.core.enums import ModelType
from megatron.core.optimizer import get_megatron_optimizer, OptimizerConfig
from megatron.core.rerun_state_machine import (
get_rerun_state_machine,
destroy_rerun_state_machine,
RerunDataIterator,
RerunMode,
)
from megatron.training.initialize import initialize_megatron
from megatron.training.initialize import write_args_to_tensorboard
from megatron.training.initialize import set_jit_fusion_options
from megatron.legacy.data.data_samplers import build_pretraining_data_loader
from megatron.core.optimizer_param_scheduler import OptimizerParamScheduler
from megatron.core.transformer.moe import upcycling_utils
from megatron.core.transformer.moe.moe_utils import track_moe_metrics
from megatron.core.parallel_state import (
destroy_global_memory_buffer,
destroy_model_parallel,
)
from megatron.core.pipeline_parallel import get_forward_backward_func
from megatron.core.num_microbatches_calculator import (
destroy_num_microbatches_calculator,
get_current_global_batch_size,
get_current_running_global_batch_size,
get_num_microbatches,
update_num_microbatches)
from .async_utils import maybe_finalize_async_save
from .utils import (
calc_params_l2_norm,
check_adlr_autoresume_termination,
is_last_rank,
print_rank_0,
print_rank_last,
report_memory,
unwrap_model,
append_to_progress_log,
update_use_dist_ckpt,
)
from .global_vars import (
destroy_global_vars,
get_args,
get_signal_handler,
get_timers,
get_tensorboard_writer,
get_wandb_writer,
get_one_logger)
from . import one_logger_utils
from . import ft_integration
stimer = StragglerDetector()
def destroy_global_state():
    """Tear down the process-wide training state.

    Invokes each teardown hook exactly once, in this fixed order: global
    variables, the microbatch calculator, the global memory buffer, the
    model-parallel state, and the rerun state machine.
    """
    teardown_hooks = (
        destroy_global_vars,
        destroy_num_microbatches_calculator,
        destroy_global_memory_buffer,
        destroy_model_parallel,
        destroy_rerun_state_machine,
    )
    for hook in teardown_hooks:
        hook()
def print_datetime(string):
    """Report ``string`` together with the current local date and time.

    Note that this call will sync across all ranks: it starts with a
    ``torch.distributed.barrier()``. The message itself is emitted through
    ``print_rank_0``.

    Args:
        string: label placed in square brackets at the start of the message.
    """
    torch.distributed.barrier()
    timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    print_rank_0(f'[{string}] datetime: {timestamp} ')
def num_floating_point_operations(args, batch_size):
    """Estimate the floating-point operations needed to train on one batch.

    The estimate covers the attention, MLP (including routed and shared
    experts) and output-logit GEMMs for ``batch_size`` sequences of length
    ``args.seq_length``.

    Side effect: when ``args.group_query_attention`` is false,
    ``args.num_query_groups`` is overwritten with ``args.num_attention_heads``.

    Args:
        args: object exposing the model hyper-parameters read below
            (``kv_channels``, ``num_attention_heads``, ``hidden_size``,
            ``seq_length``, ``num_layers``, ``ffn_hidden_size``,
            ``padded_vocab_size``, ``swiglu``, ``num_experts``,
            ``moe_router_topk``, ``moe_shared_expert_intermediate_size``,
            ``group_query_attention``, ``num_query_groups``).
        batch_size: number of sequences in the batch.

    Returns:
        The estimated number of floating-point operations.
    """
    hidden = args.hidden_size

    # Attention projection size, relative to the hidden size.
    query_projection_to_hidden_size_ratio = (
        args.kv_channels * args.num_attention_heads
    ) / hidden

    # Without group query attention every head is its own query group.
    if not args.group_query_attention:
        args.num_query_groups = args.num_attention_heads

    # MoE: a dense model behaves like a single routed expert.
    if args.num_experts is None:
        num_experts_routed_to = 1
    else:
        num_experts_routed_to = args.moe_router_topk

    # SwiGLU adds a gate projection, i.e. 3 GEMMs in the MLP instead of 2.
    if args.swiglu:
        gated_linear_multiplier = 3 / 2
    else:
        gated_linear_multiplier = 1

    shared_expert_ffn_hidden_size = args.moe_shared_expert_intermediate_size
    if shared_expert_ffn_hidden_size is None:
        shared_expert_ffn_hidden_size = 0

    # The 12x term below comes from the following factors; for more details, see
    # "APPENDIX: FLOATING-POINT OPERATIONS" in https://arxiv.org/abs/2104.04473.
    # - 3x: Each GEMM in the model needs to be performed 3 times (forward pass,
    #       backward wgrad [weight gradient], backward dgrad [data gradient]).
    # - 2x: GEMMs of a particular size are stacked twice in the standard Transformer model
    #       architectures implemented in this codebase (e.g., h->ffn_h GEMM and ffn_h->h GEMM
    #       in MLP layer).
    # - 2x: A GEMM of a m*n tensor with a n*k tensor requires 2mnk floating-point operations.
    expansion_factor = 3 * 2 * 2

    # Each term below is expressed relative to one hidden*hidden GEMM.
    attention_term = (
        1
        + (args.num_query_groups / args.num_attention_heads)
        + (args.seq_length / hidden)
    ) * query_projection_to_hidden_size_ratio
    mlp_term = (
        (args.ffn_hidden_size / hidden)
        * num_experts_routed_to
        * gated_linear_multiplier
    )
    shared_expert_term = (
        (shared_expert_ffn_hidden_size / hidden) * gated_linear_multiplier
    )
    logit_term = args.padded_vocab_size / (2 * args.num_layers * hidden)

    hidden_gemm_flops = (
        expansion_factor
        * batch_size
        * args.seq_length
        * args.num_layers
        * hidden
        * hidden
    )
    return hidden_gemm_flops * (
        attention_term + mlp_term + shared_expert_term + logit_term
    )
def get_start_time_from_progress_log():
    """
    Gets start time of earliest job with same world size. Also returns the number
    of floating-point operations completed in last saved checkpoint.

    The log is read from ``<args.save>/progress.txt``. Each line is split on
    tabs: column 0 holds the timestamp, column 2 a ``label: value`` world-size
    field, column 3 the event text, and, on "Saved checkpoint" lines, column 7
    a ``label: value`` floating-point-operation count.

    Returns:
        A ``(start_datetime, start_num_floating_point_operations)`` tuple.
    """
    args = get_args()
    assert args.save is not None
    log_path = os.path.join(args.save, "progress.txt")

    def _parse_value(field, cast):
        # Fields look like "<label>: <value>"; keep and convert the value.
        return cast(field.split(': ')[1])

    # job_start: timestamp of the first "Starting job" entry in the current
    #   run of lines that share this job's world size.
    # flops_at_start: floating-point operations completed when that job started.
    # flops_at_last_save: floating-point operations recorded by the most
    #   recent "Saved checkpoint" entry seen so far.
    job_start = None
    flops_at_start = None
    flops_at_last_save = 0

    with open(log_path, 'r') as log_file:
        for raw_line in log_file:
            columns = raw_line.strip().split('\t')
            line_world_size = _parse_value(columns[2], int)
            event = columns[3]
            if event == "Saved checkpoint":
                flops_at_last_save = _parse_value(columns[7], float)
            if line_world_size != args.world_size:
                # Re-start search if we see a different world size.
                job_start = None
                flops_at_start = None
                continue
            if event == "Starting job" and job_start is None:
                job_start = columns[0]
                flops_at_start = flops_at_last_save

    assert job_start is not None and flops_at_start is not None, \
        "Should have seen at least one 'Starting job' entry with same world_size"
    start_datetime = datetime.strptime(job_start, '%Y-%m-%d %H:%M:%S')
    return start_datetime, flops_at_start
def preprocess_common_state_dict(common_state_dict):
    """Return a copy of ``common_state_dict`` with rank-specific args removed.

    The input is deep-copied and left untouched. In the copy, the ``'args'``
    namespace is replaced by its attribute dictionary, from which ``rank`` and
    ``local_rank`` are dropped if present, since they are expected to differ
    between ranks.

    Args:
        common_state_dict: mapping that contains an ``'args'`` namespace.

    Returns:
        The preprocessed deep copy.
    """
    import copy

    snapshot = copy.deepcopy(common_state_dict)
    # Convert the namespace stored under 'args' into a plain dictionary.
    args_as_dict = vars(snapshot['args'])
    snapshot['args'] = args_as_dict
    for rank_specific_key in ('local_rank', 'rank'):
        args_as_dict.pop(rank_specific_key, None)
    return snapshot
def pretrain(
    train_valid_test_dataset_provider,
    model_provider,
    model_type,
    forward_step_func,
    process_non_loss_data_func=None,
    extra_args_provider=None,
    args_defaults={},
    get_embedding_ranks=None,
    get_position_embedding_ranks=None,
    non_loss_data_func=None,
):
    """Main training program.

    This function will run the following in the order provided:
        1) initialize Megatron.
        2) setup model, optimizer and lr schedule using the model_provider.
        3) call train_valid_test_dataset_provider to get train/val/test datasets.
        4) train the model using the forward_step_func.
        5) optionally evaluate on the validation and test sets.

    Args:
        train_valid_test_dataset_provider: a function that takes the size of
            train/valid/test dataset and returns `train, valid, test` datasets.
        model_provider: a function that returns a vanilla version of the
            model. By vanilla we mean a simple model on cpu with no fp16 or ddp.
        model_type: an enum that specifies the type of model being trained.
        forward_step_func: a function that takes a `data iterator` and `model`,
            and returns a `loss` scalar with a dictionary with key:values being
            the info we would like to monitor during training, for example
            `lm-loss: value`. We also require that this function add
            `batch generator` to the timers class.
        process_non_loss_data_func: a function to post process outputs of the
            network. It can be used for dumping output tensors (e.g images) to
            tensorboard. It takes `collected data`(list of tensors),
            `current iteration index` and `tensorboard writer` as arguments.
        extra_args_provider: a function that takes a parser and adds arguments
            to it. It is used for programs to add their own arguments.
        args_defaults: a dictionary from argument-name to argument-value,
            forwarded unchanged to `initialize_megatron`. The default dict is
            shared between calls; this function never mutates it.
        get_embedding_ranks (TODO): forwarded to `initialize_megatron`.
        get_position_embedding_ranks (TODO): forwarded to `initialize_megatron`.
        non_loss_data_func (callable): A custom function to call during evaluation.
            It can run e.g. benchmarks.

    Raises:
        RuntimeError: if `args.non_persistent_ckpt_type` is `'local'`, since
            local checkpoint managers are not integrated here.
    """

    # Initialize and get arguments, timers, and Tensorboard writer.
    initialize_megatron(
        extra_args_provider=extra_args_provider,
        args_defaults=args_defaults,
        get_embedding_ranks=get_embedding_ranks,
        get_position_embedding_ranks=get_position_embedding_ranks
    )

    args = get_args()
    timers = get_timers()

    if args.log_progress:
        append_to_progress_log("Starting job")

    # Set pytorch JIT layer fusion options and warmup JIT functions.
    set_jit_fusion_options()

    # Use the earliest start time across all ranks (MIN all-reduce), so the
    # reported initialization time reflects the longest elapsed time.
    global _TRAIN_START_TIME
    start_time_tensor = torch.tensor([_TRAIN_START_TIME],
                                     dtype=torch.double,
                                     device='cuda')
    torch.distributed.all_reduce(start_time_tensor,
                                 op=torch.distributed.ReduceOp.MIN)
    _TRAIN_START_TIME = start_time_tensor.item()

    # Timestamps (in milliseconds) reported to the one-logger after setup.
    app_metrics = {}
    app_metrics['app_start_time'] = round(_TRAIN_START_TIME * 1000.0)
    app_metrics['app_model_init_start_time'] = round(_TRAIN_START_TIME * 1000.0)

    print_rank_0('time to initialize megatron (seconds): {:.3f}'.format(
        time.time() - _TRAIN_START_TIME))
    print_datetime('after megatron is initialized')
    app_metrics['app_model_init_finish_time'] = one_logger_utils.get_timestamp_in_ms()

    # Track E2E metrics on pretrain start
    one_logger_utils.on_pretrain_start()

    # Context used for persisting some state between checkpoint saves.
    # Local checkpoint managers are not integrated, so that mode is rejected
    # up front instead of building a context after the raise.
    if args.non_persistent_ckpt_type == 'local':
        raise RuntimeError('LocalCheckpointManagers are not yet integrated')
    checkpointing_context = {}

    # Model, optimizer, and learning rate.
    timers('model-and-optimizer-setup', log_level=0).start(barrier=True)
    app_metrics['app_build_optimizer_start_time'] = one_logger_utils.get_timestamp_in_ms()
    model, optimizer, opt_param_scheduler = setup_model_and_optimizer(
        model_provider, model_type, checkpointing_context=checkpointing_context)

    timers('model-and-optimizer-setup').stop()
    print_datetime('after model, optimizer, and learning rate '
                   'scheduler are built')
    app_metrics['app_build_optimizer_finish_time'] = one_logger_utils.get_timestamp_in_ms()
    config = get_model_config(model[0])

    # Data stuff.
    app_metrics['app_build_dataiters_start_time'] = one_logger_utils.get_timestamp_in_ms()
    timers('train/valid/test-data-iterators-setup', log_level=0).start(
        barrier=True)
    if args.virtual_pipeline_model_parallel_size is not None:
        # One set of iterators per model chunk; the virtual pipeline rank is
        # set before each set of iterators is built.
        train_data_iterator = []
        valid_data_iterator = []
        test_data_iterator = []
        for i in range(len(model)):
            mpu.set_virtual_pipeline_model_parallel_rank(i)
            iterators = build_train_valid_test_data_iterators(
                train_valid_test_dataset_provider)
            train_data_iterator.append(iterators[0])
            valid_data_iterator.append(iterators[1])
            test_data_iterator.append(iterators[2])
    else:
        train_data_iterator, valid_data_iterator, test_data_iterator \
            = build_train_valid_test_data_iterators(
                train_valid_test_dataset_provider)
    timers('train/valid/test-data-iterators-setup').stop()
    print_datetime('after dataloaders are built')
    app_metrics['app_build_dataiters_finish_time'] = one_logger_utils.get_timestamp_in_ms()

    # Track if training is enabled. Can only be done once args.do_train is assigned after dataloader is built.
    one_logger_utils.track_config_flags(args.train_iters, args.skip_train, args.do_train,
                                        args.do_valid, args.do_test, args.dataloader_type,
                                        args.retro_project_dir, args.retro_cyclic_train_iters)

    if args.enable_ft_package and ft_integration.get_rank_monitor_client() is not None:
        ft_integration.get_rank_monitor_client().init_workload_monitoring()
        ft_timeouts = ft_integration.get_rank_monitor_client().timeouts
        print_rank_0(f"Fault tolerance client initialized. Timeouts: {ft_timeouts}")

    # Print setup timing.
    print_rank_0('done with setup ...')
    timers.log(['model-and-optimizer-setup',
                'train/valid/test-data-iterators-setup'], barrier=True)

    one_logger = get_one_logger()
    if one_logger:
        one_logger.log_metrics(app_metrics)

    if not args.skip_train:
        print_rank_0('training ...')

        if args.dataloader_type == 'cyclic' and args.retro_project_dir:
            assert args.retro_cyclic_train_iters is not None
            args.train_iters = args.retro_cyclic_train_iters
            print_rank_0("retro cyclic train iters : %d" % args.train_iters)

        iteration = 0
        if args.do_train and args.train_iters > 0:
            iteration, num_floating_point_operations_so_far = train(
                forward_step_func,
                model, optimizer, opt_param_scheduler,
                train_data_iterator, valid_data_iterator,
                process_non_loss_data_func, config, checkpointing_context,
                non_loss_data_func)

        print_datetime('after training is done')

        # Save a final checkpoint only when training ran (iteration != 0) and
        # the last iteration is not a multiple of the save interval.
        if args.save and iteration != 0 and iteration % args.save_interval != 0:
            save_checkpoint(iteration, model, optimizer, opt_param_scheduler,
                            num_floating_point_operations_so_far, checkpointing_context,
                            train_data_iterator=train_data_iterator,
                            ft_client=ft_integration.get_rank_monitor_client(
                                ft_integration.StateMachineActions.SAVE_CHECKPOINT),
                            preprocess_common_state_dict_fn=preprocess_common_state_dict)

        if one_logger:
            one_logger.log_metrics({
                'app_train_loop_finish_time': one_logger_utils.get_timestamp_in_ms()
            })

    else:
        print_rank_0('skipping training (--skip-train is on) ...')

        iteration = args.iteration

    if args.do_valid:
        prefix = f'iteration {iteration} on validation set'
        evaluate_and_print_results(prefix, forward_step_func,
                                   valid_data_iterator, model,
                                   iteration, process_non_loss_data_func, config,
                                   verbose=True, write_to_tensorboard=not args.skip_train,
                                   non_loss_data_func=non_loss_data_func)

    if args.do_test:
        prefix = f'iteration {iteration} on test set'
        evaluate_and_print_results(prefix, forward_step_func,
                                   test_data_iterator, model,
                                   iteration, process_non_loss_data_func, config,
                                   verbose=True, write_to_tensorboard=not args.skip_train,
                                   non_loss_data_func=non_loss_data_func)

    wandb_writer = get_wandb_writer()
    if wandb_writer:
        wandb_writer.finish()
    maybe_finalize_async_save(blocking=True)

    if one_logger:
        one_logger.log_metrics({
            'app_finish_time': one_logger_utils.get_timestamp_in_ms()
        })
    one_logger_utils.finish()
def update_train_iters(args):
    """Derive ``args.train_iters`` for sample-based training.

    Does nothing when ``args.train_iters`` is already set (iteration-based
    training). Otherwise the iteration count is computed from
    ``args.train_samples`` and stored on ``args``, and the result is reported
    through ``print_rank_0``.

    Args:
        args: argument namespace; ``train_iters`` is written in place.
    """
    # For iteration-based training, we don't need to do anything
    if args.train_iters:
        return

    if args.rampup_batch_size is not None:
        # Sample based training with rampup batch size.
        num_iters = 0
        samples_seen = 0
        # Rampup phase: step through the batch-size schedule until either the
        # bound in rampup_batch_size[2] or the training-sample budget is passed.
        while not (samples_seen > int(args.rampup_batch_size[2])
                   or samples_seen > args.train_samples):
            update_num_microbatches(samples_seen, consistency_check=False)
            samples_seen += get_current_global_batch_size()
            num_iters += 1
        # Reset
        update_num_microbatches(0, consistency_check=False)
        # Constant phase
        # Note that we throw away any partial last batch.
        remaining_samples = args.train_samples - samples_seen
        if remaining_samples > 0:
            num_iters += remaining_samples // args.global_batch_size
        args.train_iters = num_iters
    else:
        # Constant batch size with sample-based training.
        args.train_iters = args.train_samples // args.global_batch_size

    print_rank_0(f'setting training iterations to {args.train_iters}')
def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap_with_ddp=True):
    """Build the model.

    Builds one model chunk (or one chunk per virtual pipeline rank) with
    ``model_provider_func``, moves every chunk to the current CUDA device,
    wraps the chunks in ``Float16Module`` when fp16/bf16 is enabled and,
    if requested, in a data-parallel wrapper.

    Args:
        model_provider_func: callable invoked with ``pre_process`` and
            ``post_process`` keyword arguments (plus ``add_encoder`` and
            ``add_decoder`` for ``ModelType.encoder_and_decoder``) that
            returns a model chunk.
        model_type: ``ModelType`` value; stored on ``args.model_type`` and on
            each chunk built here.
        wrap_with_ddp: when true, every chunk is wrapped in ``DDP``, or in
            ``torch_FSDP`` if ``args.use_torch_fsdp2`` is set.

    Returns:
        A list of model chunks.
    """
    args = get_args()
    args.model_type = model_type

    # Build model.
    if mpu.get_pipeline_model_parallel_world_size() > 1 and \
       args.virtual_pipeline_model_parallel_size is not None:
        # Interleaved schedule: one model chunk per virtual pipeline rank.
        assert model_type != ModelType.encoder_and_decoder, \
            "Interleaved schedule not supported for model with both encoder and decoder"
        model = []
        for i in range(args.virtual_pipeline_model_parallel_size):
            mpu.set_virtual_pipeline_model_parallel_rank(i)
            # Set pre_process and post_process only after virtual rank is set.
            pre_process = mpu.is_pipeline_first_stage()
            post_process = mpu.is_pipeline_last_stage()
            this_model = model_provider_func(
                pre_process=pre_process,
                post_process=post_process
            )
            this_model.model_type = model_type
            model.append(this_model)
    else:
        pre_process = mpu.is_pipeline_first_stage()
        post_process = mpu.is_pipeline_last_stage()
        add_encoder = True
        add_decoder = True
        if model_type == ModelType.encoder_and_decoder:
            if mpu.get_pipeline_model_parallel_world_size() > 1:
                # With pipeline parallelism the decoder starts at pipeline rank
                # args.encoder_pipeline_model_parallel_size, so both the first
                # encoder rank and the first decoder rank pre-process, and the
                # rank before the first decoder rank and the last rank
                # post-process.
                rank = mpu.get_pipeline_model_parallel_rank()
                first_decoder_rank = args.encoder_pipeline_model_parallel_size
                world_size = mpu.get_pipeline_model_parallel_world_size()
                pre_process = rank == 0 or rank == first_decoder_rank
                post_process = (rank == (first_decoder_rank - 1)) or (rank == (world_size - 1))
                add_encoder = mpu.is_inside_encoder(rank)
                add_decoder = mpu.is_inside_decoder(rank)
            model = model_provider_func(
                pre_process=pre_process,
                post_process=post_process,
                add_encoder=add_encoder,
                add_decoder=add_decoder)
        else:
            model = model_provider_func(
                pre_process=pre_process,
                post_process=post_process
            )
        model.model_type = model_type

    # From here on the model is always handled as a list of chunks.
    if not isinstance(model, list):
        model = [model]

    # Set tensor model parallel attributes if not set.
    # Only parameters that are already tensor model parallel have these
    # attributes set for them. We should make sure the default attributes
    # are set for all params so the optimizer can use them.
    for model_module in model:
        for param in model_module.parameters():
            tensor_parallel.set_defaults_if_not_set_tensor_model_parallel_attributes(param)

    # Print number of parameters (only on data parallel rank 0).
    if mpu.get_data_parallel_rank() == 0:
        print(' > number of parameters on (tensor, pipeline) '
              'model parallel rank ({}, {}): {}'.format(
            mpu.get_tensor_model_parallel_rank(),
            mpu.get_pipeline_model_parallel_rank(),
            sum([sum([p.nelement() for p in model_module.parameters()])
                 for model_module in model])), flush=True)

    # GPU allocation.
    for model_module in model:
        model_module.cuda(torch.cuda.current_device())

    # Fp16 conversion.
    if args.fp16 or args.bf16:
        model = [Float16Module(model_module, args) for model_module in model]

    # The model_module.bfloat16()/model_module.half() above will call the inplace copy of TE's
    # Float8Tensor, which will write an unwanted value (amax calculated from the current fp8
    # param) to its amax_history. The following logic will correct the amax_history back.
    for model_module in model:
        for param in model_module.parameters():
            if is_float8tensor(param) and param._fp8_meta is not None:
                fp8_meta = param._fp8_meta['scaling_fwd']
                fp8_meta_index = param._fp8_meta_index
                if hasattr(param, 'get_high_precision_init_val'):
                    # Restore the amax from the high-precision initial value.
                    fp8_meta.amax_history[0][fp8_meta_index].copy_(
                        param.get_high_precision_init_val().abs().max()
                    )
                else:
                    fp8_meta.amax_history[0][fp8_meta_index] = 0

    if wrap_with_ddp:
        # Pick the data-parallel wrapper class.
        if getattr(args, "use_torch_fsdp2", False):
            assert HAVE_FSDP2, "Torch FSDP2 requires torch>=2.4.0"
            DP = torch_FSDP
        else:
            DP = DDP

        config = get_model_config(model[0])

        # Copy every DistributedDataParallelConfig field that exists on args,
        # then set the fields whose argument names differ from the field names.
        kwargs = {}
        for f in dataclasses.fields(DistributedDataParallelConfig):
            if hasattr(args, f.name):
                kwargs[f.name] = getattr(args, f.name)
        kwargs['grad_reduce_in_fp32'] = args.accumulate_allreduce_grads_in_fp32
        kwargs['check_for_nan_in_grad'] = args.check_for_nan_in_loss_and_grad
        kwargs['bucket_size'] = args.ddp_bucket_size
        kwargs['average_in_collective'] = args.ddp_average_in_collective
        ddp_config = DistributedDataParallelConfig(**kwargs)

        overlap_param_gather_with_optimizer_step = getattr(args, 'overlap_param_gather_with_optimizer_step', False)
        model = [DP(config=config,
                    ddp_config=ddp_config,
                    module=model_chunk,
                    # Turn off bucketing for model_chunk 2 onwards, since communication for these
                    # model chunks is overlapped with compute anyway.
                    disable_bucketing=(model_chunk_idx > 0) or overlap_param_gather_with_optimizer_step)
                 for (model_chunk_idx, model_chunk) in enumerate(model)]

        # Broadcast params from data parallel src rank to other data parallel ranks.
        if args.data_parallel_random_init:
            for model_module in model:
                model_module.broadcast_params()

    return model
def get_optimizer_param_scheduler(optimizer):
    """Build the learning rate scheduler.

    The warmup, decay and weight-decay-increment horizons are taken from the
    global args. For iteration-based training the iteration counts are
    multiplied by ``args.global_batch_size``; for sample-based training the
    sample counts are used directly.

    Side effects: may fill in ``args.lr_decay_iters`` or
    ``args.lr_decay_samples`` when they are ``None``, and calls
    ``update_train_iters(args)`` for sample-based training.

    Args:
        optimizer: the optimizer handed to ``OptimizerParamScheduler``.

    Returns:
        The constructed ``OptimizerParamScheduler``.

    Raises:
        Exception: if neither ``args.train_iters`` nor ``args.train_samples``
            is set.
    """
    args = get_args()

    if not args.train_iters and not args.train_samples:
        raise Exception(
            'either train-iters or train-samples should be provided.')

    if args.train_iters:
        # Iteration-based training.
        if args.lr_decay_iters is None:
            args.lr_decay_iters = args.train_iters
        lr_decay_steps = args.lr_decay_iters * args.global_batch_size
        wd_incr_steps = args.train_iters * args.global_batch_size
        wsd_decay_steps = (
            args.lr_wsd_decay_iters * args.global_batch_size
            if args.lr_wsd_decay_iters is not None
            else None
        )
        # A warmup fraction, when given, takes precedence over warmup iters.
        lr_warmup_steps = (
            args.lr_warmup_fraction * lr_decay_steps
            if args.lr_warmup_fraction is not None
            else args.lr_warmup_iters * args.global_batch_size
        )
    else:
        # Sample-based training.
        # We need to set training iters for later use. Technically
        # we need to adjust the training samples too (due to last
        # batch being incomplete) but we leave it as is for now.
        update_train_iters(args)
        if args.lr_decay_samples is None:
            args.lr_decay_samples = args.train_samples
        lr_decay_steps = args.lr_decay_samples
        wd_incr_steps = args.train_samples
        wsd_decay_steps = args.lr_wsd_decay_samples
        # A warmup fraction, when given, takes precedence over warmup samples.
        lr_warmup_steps = (
            args.lr_warmup_fraction * lr_decay_steps
            if args.lr_warmup_fraction is not None
            else args.lr_warmup_samples
        )

    return OptimizerParamScheduler(
        optimizer,
        init_lr=args.lr_warmup_init,
        max_lr=args.lr,
        min_lr=args.min_lr,
        lr_warmup_steps=lr_warmup_steps,
        lr_decay_steps=lr_decay_steps,
        lr_decay_style=args.lr_decay_style,
        start_wd=args.start_weight_decay,
        end_wd=args.end_weight_decay,
        wd_incr_steps=wd_incr_steps,
        wd_incr_style=args.weight_decay_incr_style,
        use_checkpoint_opt_param_scheduler=args.use_checkpoint_opt_param_scheduler,
        override_opt_param_scheduler=args.override_opt_param_scheduler,
        wsd_decay_steps=wsd_decay_steps,
        lr_wsd_decay_style=args.lr_wsd_decay_style,
    )
def setup_model_and_optimizer(model_provider_func,
                              model_type,
                              no_wd_decay_cond=None,
                              scale_lr_cond=None,
                              lr_mult=1.0,
                              checkpointing_context=None):
    """Setup model and optimizer.

    Builds the model, the optimizer and the optimizer parameter scheduler,
    then either upcycles a dense checkpoint (``args.moe_use_upcycling``),
    loads an existing checkpoint, or starts from iteration 0. Sets
    ``args.iteration`` and ``args.num_floating_point_operations_so_far``.

    Args:
        model_provider_func: forwarded to ``get_model``.
        model_type: forwarded to ``get_model``.
        no_wd_decay_cond: forwarded to ``get_megatron_optimizer``.
        scale_lr_cond: forwarded to ``get_megatron_optimizer``.
        lr_mult: forwarded to ``get_megatron_optimizer``.
        checkpointing_context: forwarded to ``load_checkpoint``.

    Returns:
        A ``(model, optimizer, opt_param_scheduler)`` tuple.

    Raises:
        SystemExit: after saving the converted checkpoint when
            ``args.ckpt_convert_format`` is set.
    """
    args = get_args()
    timers = get_timers()
    one_logger = get_one_logger()

    model = get_model(model_provider_func, model_type)
    unwrapped_model = unwrap_model(model)

    # Copy every OptimizerConfig field that is present on args.
    kwargs = {}
    for f in dataclasses.fields(OptimizerConfig):
        if hasattr(args, f.name):
            kwargs[f.name] = getattr(args, f.name)
    config = OptimizerConfig(**kwargs)
    config.timers = timers
    optimizer = get_megatron_optimizer(config, model, no_wd_decay_cond,
                                       scale_lr_cond, lr_mult)
    opt_param_scheduler = get_optimizer_param_scheduler(optimizer)

    if args.moe_use_upcycling:
        torch.distributed.barrier()
        assert not checkpoint_exists(
            args.save
        ), ("The upcycling destination directory already exists. "
            "Please check if --moe-use-upcycling is mistakenly enabled. "
            "Upcycling should only be set for the first run when converting the dense model. "
            "All subsequent runs should remove this flag. ")
        # Temporarily switch the args to a dense configuration so that
        # get_model builds the dense model, then restore the MoE settings.
        num_experts = args.num_experts
        args.num_experts = None
        expert_model_parallel_size = args.expert_model_parallel_size
        args.expert_model_parallel_size = 1
        dense_model_for_upcycling = get_model(model_provider_func, model_type)
        args.num_experts = num_experts
        args.expert_model_parallel_size = expert_model_parallel_size
        _, args.num_floating_point_operations_so_far = upcycling_utils.load_and_upcycle_model(
            load_checkpoint,
            unwrapped_model,
            dense_model_for_upcycling,
            load_kwargs={'model': dense_model_for_upcycling, 'optimizer': None, 'opt_param_scheduler': None}
        )
        args.iteration = 1
        save_checkpoint(args.iteration, model, None, None, args.num_floating_point_operations_so_far)
        torch.distributed.barrier()
        del dense_model_for_upcycling
        if (args.fp16 or args.bf16) and optimizer is not None:
            optimizer.reload_model_params()
        print_rank_0(f'Upcycled checkpoint saved to {args.save}')

    if (args.load is not None or args.pretrained_checkpoint is not None) and not args.moe_use_upcycling:
        if one_logger:
            one_logger.log_metrics({
                'load_checkpoint_start_time': one_logger_utils.get_timestamp_in_ms()
            })
        timers('load-checkpoint', log_level=0).start(barrier=True)

        args.iteration, args.num_floating_point_operations_so_far = load_checkpoint(
            model, optimizer, opt_param_scheduler,
            ft_client=ft_integration.get_rank_monitor_client(), checkpointing_context=checkpointing_context,
            skip_load_to_model_and_opt=HAVE_FSDP2 and getattr(args, "use_torch_fsdp2", False))
        timers('load-checkpoint').stop(barrier=True)
        timers.log(['load-checkpoint'])
        if one_logger:
            one_logger.log_metrics({
                'load_checkpoint_finish_time': one_logger_utils.get_timestamp_in_ms(),
                'load_checkpoint_time': timers('load-checkpoint').active_time()
            })
    else:
        # No checkpoint is loaded on this path (this includes the upcycling
        # path above): start counting from zero.
        args.iteration = 0
        args.num_floating_point_operations_so_far = 0

    # get model without FP16 and/or DDP wrappers
    if args.iteration == 0 and len(unwrapped_model) == 1 \
            and hasattr(unwrapped_model[0], 'init_state_dict_from_bert'):
        print_rank_0("Initializing ICT from pretrained BERT model")
        unwrapped_model[0].init_state_dict_from_bert()
        if args.fp16:
            optimizer.reload_model_params()

    # Convert checkpoint format.
    if args.ckpt_convert_format is not None:
        load_ckpt_format = args.ckpt_format
        args.ckpt_format = args.ckpt_convert_format
        args.save = os.path.join(args.ckpt_convert_save, args.ckpt_convert_format)
        update_use_dist_ckpt(args)

        save_checkpoint(args.iteration, model, optimizer, opt_param_scheduler,
                        args.num_floating_point_operations_so_far,
                        preprocess_common_state_dict_fn=preprocess_common_state_dict)

        print_rank_0("> converted checkpoint: %s -> %s." % (load_ckpt_format, args.ckpt_format))
        torch.distributed.barrier()
        # Conversion-only run: stop here. sys.exit() is used because the bare
        # exit() helper is injected by the site module and is not guaranteed
        # to exist in every interpreter configuration.
        sys.exit()

    return model, optimizer, opt_param_scheduler
def train_step(forward_step_func, data_iterator,
               model, optimizer, opt_param_scheduler, config):
    """Single training step.

    Runs forward/backward for as long as the rerun state machine asks for it,
    steps the optimizer and, if the update succeeded, the parameter
    scheduler, then averages the per-microbatch losses on the last pipeline
    stage.

    Args:
        forward_step_func: forwarded to the forward/backward function.
        data_iterator: forwarded to the forward/backward function and to the
            rerun state machine.
        model: list of model chunks.
        optimizer: optimizer exposing ``zero_grad()`` and ``step()``.
        opt_param_scheduler: scheduler stepped by the number of samples
            processed in this step.
        config: not used by this function.

    Returns:
        A 7-tuple ``(loss_reduced, skipped_iter, should_checkpoint,
        should_exit, exit_code, grad_norm, num_zeros_in_grad)``.
        ``loss_reduced`` is an empty dict on ranks that are not on the last
        pipeline stage. When the rerun state machine requests an exit, the
        tuple is ``({}, True, should_checkpoint, should_exit, exit_code,
        None, None)`` and the optimizer is not stepped.
    """
    args = get_args()
    timers = get_timers()
    state_machine = get_rerun_state_machine()

    while state_machine.should_run_forward_backward(data_iterator):
        # Set grad to zero.
        for chunk in model:
            chunk.zero_grad_buffer()
        optimizer.zero_grad()

        # Forward and backward pass over all microbatches.
        run_forward_backward = get_forward_backward_func()
        losses_reduced = run_forward_backward(
            forward_step_func=forward_step_func,
            data_iterator=data_iterator,
            model=model,
            num_microbatches=get_num_microbatches(),
            seq_length=args.seq_length,
            micro_batch_size=args.micro_batch_size,
            decoder_seq_length=args.decoder_seq_length,
            forward_only=False,
        )

    should_checkpoint, should_exit, exit_code = state_machine.should_checkpoint_and_exit()
    if should_exit:
        return {}, True, should_checkpoint, should_exit, exit_code, None, None

    # Empty unused memory.
    if args.empty_unused_memory_level >= 1:
        torch.cuda.empty_cache()

    # Vision gradients.
    if getattr(args, 'vision_pretraining', False) and args.vision_pretraining_type == "dino":
        unwrap_model(model[0]).cancel_gradients_last_layer(args.curr_iteration)

    # Update parameters.
    timers('optimizer', log_level=1).start(barrier=args.barrier_with_L1_time)
    update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
    timers('optimizer').stop()

    # Vision momentum.
    if getattr(args, 'vision_pretraining', False) and args.vision_pretraining_type == "dino":
        unwrap_model(model[0]).update_momentum(args.curr_iteration)

    # Update learning rate; an unsuccessful update counts as a skipped iteration.
    if not update_successful:
        skipped_iter = 1
    else:
        samples_this_step = (
            get_num_microbatches()
            * args.micro_batch_size
            * args.data_parallel_size
        )
        opt_param_scheduler.step(increment=samples_this_step)
        skipped_iter = 0

    # Empty unused memory.
    if args.empty_unused_memory_level >= 2:
        torch.cuda.empty_cache()

    step_status = (skipped_iter, should_checkpoint, should_exit, exit_code,
                   grad_norm, num_zeros_in_grad)

    # Only the last pipeline stage reports losses.
    if not mpu.is_pipeline_last_stage(ignore_virtual=True):
        return ({}, *step_status)

    # Average loss across microbatches.
    loss_reduced = {}
    for loss_name in losses_reduced[0].keys():
        total = 0
        count = 0
        for microbatch_losses in losses_reduced:
            entry = microbatch_losses[loss_name]
            if isinstance(entry, (tuple, list)):
                # there is one dict per microbatch. in new reporting, we average
                # over the total number of tokens across the global batch.
                total += entry[0]
                count += entry[1]
            else:
                # legacy behavior. we average over the number of microbatches,
                # and so the denominator is 1.
                total += entry
                count += 1
        loss_reduced[loss_name] = total / count
    return (loss_reduced, *step_status)
def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_rate, iteration,
loss_scale, report_memory_flag, skipped_iter,
grad_norm, params_norm, num_zeros_in_grad, moe_relu_sparsity, moe_relu_l1_reg_coeff):
"""Log training information such as losses, timing, ...."""
args = get_args()
timers = get_timers()
writer = get_tensorboard_writer()
wandb_writer = get_wandb_writer()
one_logger = get_one_logger()
# Advanced, skipped, and Nan iterations.
advanced_iters_key = 'advanced iterations'
skipped_iters_key = 'skipped iterations'
nan_iters_key = 'nan iterations'
# Advanced iterations.
if not skipped_iter:
total_loss_dict[advanced_iters_key] = total_loss_dict.get(
advanced_iters_key, 0) + 1
else:
if advanced_iters_key not in total_loss_dict:
total_loss_dict[advanced_iters_key] = 0
# Skipped iterations.
total_loss_dict[skipped_iters_key] = total_loss_dict.get(
skipped_iters_key, 0) + skipped_iter
# Update losses and set nan iterations
got_nan = False
for key in loss_dict:
if not skipped_iter:
total_loss_dict[key] = total_loss_dict.get(
key, torch.tensor([0.0], dtype=torch.float, device='cuda')) + loss_dict[key]
else:
value = loss_dict[key].float().sum().item()
is_nan = value == float('inf') or \
value == -float('inf') or \
value != value
got_nan = got_nan or is_nan
total_loss_dict[nan_iters_key] = total_loss_dict.get(
nan_iters_key, 0) + int(got_nan)
# Logging.
timers_to_log = [
'forward-backward',
'forward-compute',
'backward-compute',
'batch-generator',
'forward-recv',
'forward-send',
'backward-recv',
'backward-send',
'forward-send-forward-recv',
'forward-send-backward-recv',
'backward-send-forward-recv',
'backward-send-backward-recv',
'forward-backward-send-forward-backward-recv',
'layernorm-grads-all-reduce',
'embedding-grads-all-reduce',
'all-grads-sync',
'params-all-gather',
'optimizer-copy-to-main-grad',
'optimizer-unscale-and-check-inf',
'optimizer-clip-main-grad',
'optimizer-count-zeros',
'optimizer-inner-step',
'optimizer-copy-main-to-model-params',
'optimizer']
# Calculate batch size.
batch_size = args.micro_batch_size * args.data_parallel_size * \
get_num_microbatches()
# Track app tag & app tag ID
one_logger_utils.track_app_tag(batch_size, args.world_size, args.seq_length)
total_iterations = total_loss_dict[advanced_iters_key] + \
total_loss_dict[skipped_iters_key]
# Tensorboard values.
# Timer requires all the ranks to call.
if args.log_timers_to_tensorboard and \
(iteration % args.tensorboard_log_interval == 0):
timers.write(timers_to_log, writer, iteration,
normalizer=total_iterations)
if writer and (iteration % args.tensorboard_log_interval == 0):
if args.record_memory_history and is_last_rank():
snapshot = torch.cuda.memory._snapshot()
from pickle import dump
with open(args.memory_snapshot_path , 'wb') as f:
dump(snapshot, f)
if wandb_writer:
wandb_writer.log({'samples vs steps': args.consumed_train_samples},
iteration)
writer.add_scalar('learning-rate', learning_rate, iteration)
if args.decoupled_lr is not None:
writer.add_scalar('decoupled-learning-rate', decoupled_learning_rate, iteration)
writer.add_scalar('learning-rate vs samples', learning_rate,
args.consumed_train_samples)
if wandb_writer:
wandb_writer.log({'learning-rate': learning_rate}, iteration)
if args.skipped_train_samples > 0:
writer.add_scalar('skipped-train-samples', args.skipped_train_samples, iteration)
if wandb_writer:
wandb_writer.log({'skipped-train-samples': args.skipped_train_samples}, iteration)
writer.add_scalar('batch-size', batch_size, iteration)
writer.add_scalar('batch-size vs samples', batch_size,
args.consumed_train_samples)
if wandb_writer:
wandb_writer.log({'batch-size': batch_size}, iteration)
for key in loss_dict:
writer.add_scalar(key , loss_dict[key], iteration)
writer.add_scalar(key + ' vs samples', loss_dict[key],
args.consumed_train_samples)
if wandb_writer:
wandb_writer.log({key: loss_dict[key]}, iteration)
if args.log_loss_scale_to_tensorboard:
writer.add_scalar('loss-scale', loss_scale, iteration)
writer.add_scalar('loss-scale vs samples', loss_scale,
args.consumed_train_samples)
if wandb_writer:
wandb_writer.log({'loss-scale': loss_scale}, iteration)
if args.log_world_size_to_tensorboard:
writer.add_scalar('world-size', args.world_size, iteration)
writer.add_scalar('world-size vs samples', args.world_size,
args.consumed_train_samples)
if wandb_writer:
wandb_writer.log({'world-size': args.world_size}, iteration)
if grad_norm is not None:
writer.add_scalar('grad-norm', grad_norm, iteration)
writer.add_scalar('grad-norm vs samples', grad_norm,
args.consumed_train_samples)
if wandb_writer:
wandb_writer.log({'grad-norm': grad_norm}, iteration)
if num_zeros_in_grad is not None:
writer.add_scalar('num-zeros', num_zeros_in_grad, iteration)
writer.add_scalar('num-zeros vs samples', num_zeros_in_grad,
args.consumed_train_samples)
if wandb_writer:
wandb_writer.log({'num-zeros': num_zeros_in_grad}, iteration)
if params_norm is not None:
writer.add_scalar('params-norm', params_norm, iteration)
writer.add_scalar('params-norm vs samples', params_norm,
args.consumed_train_samples)
if wandb_writer:
wandb_writer.log({'params-norm': params_norm}, iteration)
if moe_relu_sparsity is not None:
writer.add_scalar('moe_relu_sparsity', moe_relu_sparsity, iteration)
writer.add_scalar('moe_relu_sparsity vs samples', moe_relu_sparsity,
args.consumed_train_samples)
if wandb_writer:
wandb_writer.log({'moe_relu_sparsity': moe_relu_sparsity}, iteration)
if moe_relu_l1_reg_coeff is not None:
writer.add_scalar('moe_relu_l1_reg_coeff', moe_relu_l1_reg_coeff, iteration)
writer.add_scalar('moe_relu_l1_reg_coeff vs samples', moe_relu_l1_reg_coeff,
args.consumed_train_samples)
if wandb_writer:
wandb_writer.log({'moe_relu_l1_reg_coeff': moe_relu_l1_reg_coeff}, iteration)
if args.log_memory_to_tensorboard:
mem_stats = torch.cuda.memory_stats()
writer.add_scalar(
"mem-reserved-bytes",
mem_stats["reserved_bytes.all.current"],
iteration,
)
writer.add_scalar(
"mem-allocated-bytes",
mem_stats["allocated_bytes.all.current"],
iteration,
)