forked from NVIDIA/apex
-
Notifications
You must be signed in to change notification settings - Fork 0
/
distributed_fused_lamb.py
1061 lines (952 loc) · 55.6 KB
/
distributed_fused_lamb.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import os
import math
import inspect
import torch
import importlib
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier
import torch.distributed.distributed_c10d as c10d
# Resolve ``get_process_group_ranks``: prefer the public helper shipped by newer
# PyTorch releases and fall back to the private ``_pg_group_ranks`` mapping on
# older versions.
# NOTE: this must be ``from ... import``; the previous
# ``import torch.distributed.distributed_c10d.get_process_group_ranks`` form can
# never succeed (a function is not a module), so the private fallback was
# always used even when the public API existed.
try:
    from torch.distributed.distributed_c10d import get_process_group_ranks
except ImportError:
    def get_process_group_ranks(group):
        """Return the list of global ranks in ``group`` (fallback for older PyTorch)."""
        return list(c10d._pg_group_ranks[group].keys())

# Factory for the reduce op passed as ``op=_make_nccl_premul_sum(scale)`` to the
# collectives below (an NCCL "pre-multiplied sum"). Newer PyTorch exposes it
# under the private name, older releases under the public one.
# If neither exists it stays None and those call sites would fail.
_make_nccl_premul_sum = getattr(torch.distributed, "_make_nccl_premul_sum", None)
# Ref: https://github.com/pytorch/pytorch/pull/81272
if _make_nccl_premul_sum is None:
    if hasattr(torch.distributed, "make_nccl_premul_sum"):
        _make_nccl_premul_sum = torch.distributed.make_nccl_premul_sum
class DistributedFusedLAMB(torch.optim.Optimizer):
"""Implements LAMB algorithm.
Currently GPU-only. Requires Apex to be installed via
``pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./``.
This version of fused LAMB implements 2 fusions.
* Fusion of the LAMB update's elementwise operations
* A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches.
:class:`apex.optimizers.FusedLAMB`'s usage is identical to any ordinary Pytorch optimizer::
opt = apex.optimizers.FusedLAMB(model.parameters(), lr = ....)
...
opt.step()
:class:`apex.optimizers.FusedLAMB` may be used with or without Amp. If you wish to use :class:`FusedLAMB` with Amp,
you may choose any ``opt_level``::
opt = apex.optimizers.FusedLAMB(model.parameters(), lr = ....)
model, opt = amp.initialize(model, opt, opt_level="O0" or "O1" or "O2")
...
opt.step()
In general, ``opt_level="O1"`` is recommended.
LAMB was proposed in `Large Batch Optimization for Deep Learning - Training BERT in 76 minutes`_.
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups.
lr (float, optional): learning rate. (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its norm. (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability. (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
NOT SUPPORTED now! (default: False)
adam_w_mode (boolean, optional): choose between L2 regularization (False)
and decoupled weight decay, also known as AdamW (True). (default: True)
grad_averaging (bool, optional): whether to apply (1-beta1) to grad when
calculating running averages of gradient. (default: True)
set_grad_none (bool, optional): whether to set grad to None when zero_grad()
method is called. NOTE: not accepted by this class's constructor. (default: True)
max_grad_norm (float, optional): value used to clip global grad norm
(default: 0.0)
use_nvlamb (boolean, optional): Apply adaptive learning rate to 0.0
weight decay parameter (default: False)
step_supports_amp_scaling(boolean, optional): whether to use customized
gradient unscaling logic (default: True)
.. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes:
https://arxiv.org/abs/1904.00962
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
"""
class AtomicCounter:
    """Lock-protected, append-only log of indices plus a running count.

    ``order`` holds every index passed to :meth:`add` in call order and
    ``value`` counts the calls; both are updated under one lock.
    """

    def __init__(self):
        import threading
        self._lock = threading.Lock()
        self.order = []
        self.value = 0

    def add(self, idx):
        """Append ``idx`` to ``order`` and increment ``value`` atomically."""
        with self._lock:
            self.order.append(idx)
            self.value += 1
def __init__(self, params,
             lr=1e-3, bias_correction = True, grad_averaging=True,
             betas=(0.9, 0.999), eps=1e-8,
             weight_decay=0., max_grad_norm=0.,
             adam_w_mode=True, use_nvlamb=False,
             step_supports_amp_scaling=True, overlap_reductions=True,
             dwu_group_size=0, dwu_num_blocks=4, dwu_num_chunks=4,
             dwu_num_rs_pg=1, dwu_num_ar_pg=4, dwu_num_ag_pg=0, fused_norm=False,
             e5m2_allgather=False, verbose=False, clip_after_ar=True,
             full_ar=False, set_param_views_to_flat_buffer=False, skip_allgather=False,
             fuse_scale=False, param_order=None, nccl_allgather_channels=0):
    """Set up hyper-parameters, process groups/streams and the flat-buffer layout.

    The LAMB hyper-parameters are described on the class. The distributed
    options are used here as follows:

    * ``dwu_group_size``: number of ranks that shard the optimizer state
      among themselves; ``<= 0`` means ``torch.cuda.device_count()``.
      ``world_size // group_size`` groups of consecutive ranks are formed.
    * ``dwu_num_blocks`` / ``dwu_num_chunks``: how the padded flat buffer is
      split (blocks of chunks, each chunk sharded over the group).
    * ``dwu_num_rs_pg`` / ``dwu_num_ar_pg`` / ``dwu_num_ag_pg``: how many
      process groups (and CUDA streams) to create per collective kind. An
      all-gather count of 0 reuses the all-reduce groups (``full_ar``) or the
      reduce-scatter groups (otherwise).
    * ``full_ar``: all-reduce over all ranks instead of reduce-scatter within
      a group followed by all-reduce across groups.
    * ``clip_after_ar``: when True, ``fused_norm`` is forced to False.
    * ``e5m2_allgather``: the gathered-parameter buffer is ``uint8`` instead
      of ``float16``.
    * ``set_param_views_to_flat_buffer`` / ``param_order``: with the former
      set, the parameter order is recorded here (``param_order[p]`` when a
      mapping is given, otherwise declaration order). Otherwise it is
      recorded by the gradient hooks during the first backward pass.
    * ``nccl_allgather_channels``: in ``full_ar`` mode, a value > 0 is
      exported as ``NCCL_MAX_NCHANNELS`` through ``os.putenv``.
    * ``skip_allgather``, ``fuse_scale``, ``step_supports_amp_scaling``,
      ``overlap_reductions``, ``verbose``: stored on ``self``. ``verbose``
      additionally prints process-group creation messages.

    Per-rank buffers (flat gradients, fp32 master weights and moments) are
    allocated later, in ``_lazy_init_stage1``.
    """
    defaults = dict(lr=lr, bias_correction=bias_correction,
                    betas=betas, eps=eps, weight_decay=weight_decay,
                    grad_averaging=grad_averaging,
                    max_grad_norm=max_grad_norm)

    super(DistributedFusedLAMB, self).__init__(params, defaults)

    # Load the CUDA extensions at construction time and publish them as
    # module-level globals for the rest of this module.
    global fused_adam_cuda, distributed_lamb_cuda
    fused_adam_cuda = importlib.import_module("fused_adam_cuda")
    distributed_lamb_cuda = importlib.import_module("distributed_lamb_cuda")

    self._overflow_buf = torch.cuda.IntTensor([0])
    self._has_overflow = False
    self.multi_tensor_lamb_compute_update_term = distributed_lamb_cuda.multi_tensor_lamb_compute_update_term
    self.multi_tensor_lamb_update_weights = distributed_lamb_cuda.multi_tensor_lamb_update_weights
    # NOTE: amp_C is also imported at module level; this local import is redundant.
    import amp_C
    self.multi_tensor_l2norm = amp_C.multi_tensor_l2norm

    self._grad_averaging = grad_averaging
    # Passed to the CUDA kernels as an int flag.
    self._adam_w_mode = 1 if adam_w_mode else 0
    self._use_nvlamb = use_nvlamb
    self._step_supports_amp_scaling = step_supports_amp_scaling
    self._is_accumulation_step = False
    self._last_step = False
    self._overlap_reductions = overlap_reductions
    self._global_scale = None
    self._num_blocks = dwu_num_blocks
    self._num_chunks = dwu_num_chunks
    self._e5m2_allgather = e5m2_allgather
    self._verbose = verbose
    self._clip_after_ar = clip_after_ar
    self._full_ar = full_ar
    self._fuse_scale = fuse_scale
    self._L2_grad_norm = None
    self._set_flat_param_view = set_param_views_to_flat_buffer
    self._skip_ag = skip_allgather
    # fused_norm is forced off whenever clip_after_ar is True.
    self._fused_norm = fused_norm if not clip_after_ar else False
    self._current_process_group = c10d._get_default_group()
    self._available_ranks = get_process_group_ranks(self._current_process_group)
    self._group_size = torch.cuda.device_count() if dwu_group_size <= 0 else dwu_group_size
    self._world_size = torch.distributed.get_world_size()
    self._num_groups = self._world_size // self._group_size
    # Position of this rank inside its group of consecutive global ranks.
    self._rank_in_group = torch.distributed.get_rank() % self._group_size

    self._lr = torch.tensor(0.0, dtype=torch.float32, device='cuda')

    self._resume_from_checkpoint = False
    self._step = torch.cuda.IntTensor([0])

    # Master weight, moment, gradient buffers
    # (allocated in _lazy_init_stage1 unless populated before it runs)
    self._fp32_p, self._fp32_m, self._fp32_v, self._fp16_p, self._fp16_g = None, None, None, None, None

    # Check if collectives have no_copy option
    # (feature-detected from the function signature rather than a version check)
    self._reduce_scatter_no_copy = (
        'no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args
    )
    self._all_gather_no_copy = (
        'no_copy' in inspect.getfullargspec(torch.distributed.all_gather).args
    )

    # Compatibility shim for older PyTorch: alias the new collective names to
    # the private pre-rename implementations.
    # NOTE: this mutates the torch.distributed module for the whole process.
    if "reduce_scatter_tensor" not in dir(torch.distributed):
        torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base
    if "all_gather_into_tensor" not in dir(torch.distributed):
        torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base

    self._num_rs_pg = dwu_num_rs_pg
    self._num_ar_pg = dwu_num_ar_pg
    self._num_ag_pg = dwu_num_ag_pg

    # Process groups: every rank calls new_group for every group (unconditionally)
    # and keeps only the ones it is a member of, so the loop structure below
    # must not depend on the rank.
    if self._full_ar: # full all reduce, only need AR and AG groups
        # l2_grad_norm may be reduced within a node to limit from memory reads
        for group_i in range(self._num_groups):
            ranks = [group_i*self._group_size+j for j in range(self._group_size)]
            l2_grad_norm_pg = torch.distributed.new_group(ranks=ranks)
            if torch.distributed.get_rank() in ranks:
                self._l2_grad_norm_pg = l2_grad_norm_pg

        self._ar_pg = []
        # consider all the ranks
        ranks = list(range(0, self._world_size))
        for i in range(self._num_ar_pg):
            if self._verbose:
                print(f"creating new AR group {i}: {ranks}")
            grp = torch.distributed.new_group(ranks=ranks)
            if grp != torch.distributed.GroupMember.NON_GROUP_MEMBER:
                if self._verbose:
                    print(f"group {i}: init barrier (device: {torch.cuda.current_device()})")
                # Exercise the new group once on the current device right after creation.
                torch.distributed.barrier(group=grp, device_ids=[torch.cuda.current_device()])
            if self._verbose:
                print(f"created new AR group {i}: {ranks}")

            if torch.distributed.get_rank() in ranks:
                self._ar_pg.append(grp)
        self._ar_st = [torch.cuda.Stream() for _ in range(self._num_ar_pg)]
        if nccl_allgather_channels > 0:
            # NOTE: os.putenv does not update os.environ; it only changes the
            # C-level environment (presumably read by NCCL when communicators
            # are created — TODO confirm timing).
            os.putenv('NCCL_MAX_NCHANNELS', str(nccl_allgather_channels))
        if self._num_ag_pg == 0:
            # Reuse the all-reduce groups/streams for all-gather.
            self._ag_pg = self._ar_pg
            self._ag_st = self._ar_st
            self._num_ag_pg = self._num_ar_pg
        else:
            self._ag_pg = []
            ranks = []
            # AG groups are blocks of device_count() consecutive ranks.
            stride = torch.cuda.device_count()
            for i in range(self._num_groups):
                rs = list(range(i*stride, (i+1)*stride))
                ranks.append(rs)
            for rs in ranks:
                for i in range(self._num_ag_pg):
                    grp = torch.distributed.new_group(ranks=rs)
                    if torch.distributed.get_rank() in rs:
                        if self._verbose:
                            print(f"creating AG group {i}: {rs}")
                        self._ag_pg.append(grp)

            self._ag_st = [torch.cuda.Stream() for _ in range(self._num_ag_pg)]
    else: # reduce-scatter + all-reduce, need RS, AR, AG groups
        if self._num_groups > 1:
            # Cross-group all-reduce: one set of groups per in-group position
            # dev_i, linking rank dev_i of every group.
            # NOTE(review): with a single group, _ar_pg/_ar_st are never set.
            self._ar_pg = []
            for dev_i in range(self._group_size):
                ranks = [dev_i+j*self._group_size for j in range(self._num_groups)]
                for i in range(self._num_ar_pg):
                    if self._verbose:
                        print(f"creating new AR group {i}: {ranks}")
                    grp = torch.distributed.new_group(ranks=ranks)
                    if grp != torch.distributed.GroupMember.NON_GROUP_MEMBER:
                        if self._verbose:
                            print(f"group {i}: init barrier (device: {torch.cuda.current_device()})")
                        torch.distributed.barrier(group=grp, device_ids=[torch.cuda.current_device()])
                    if self._verbose:
                        print(f"created new AR group {i}: {ranks}")

                    if torch.distributed.get_rank() in ranks:
                        self._ar_pg.append(grp)
            self._ar_st = [torch.cuda.Stream() for _ in range(self._num_ar_pg)]
        # In-group reduce-scatter over consecutive ranks, plus one l2-norm group per group.
        rs_ranks = []
        for group_i in range(self._num_groups):
            rs_ranks.append([group_i*self._group_size+j for j in range(self._group_size)])
        self._rs_pg = []
        for group_i in range(self._num_groups):
            ranks = rs_ranks[group_i]
            for i in range(self._num_rs_pg):
                grp = torch.distributed.new_group(ranks=ranks)
                if torch.distributed.get_rank() in ranks:
                    self._rs_pg.append(grp)
                    if self._verbose:
                        print(f"creating RS group : {ranks}")
            l2_grad_norm_pg = torch.distributed.new_group(ranks=ranks)
            if torch.distributed.get_rank() in ranks:
                self._l2_grad_norm_pg = l2_grad_norm_pg
        self._rs_st = [torch.cuda.Stream() for _ in range(self._num_rs_pg)]
        if self._num_ag_pg == 0:
            # Reuse the reduce-scatter groups/streams for all-gather.
            self._ag_pg = self._rs_pg
            self._ag_st = self._rs_st
            self._num_ag_pg = self._num_rs_pg
        else:
            self._ag_pg = []
            for group_i in range(self._num_groups):
                ranks = rs_ranks[group_i]
                for i in range(self._num_ag_pg):
                    grp = torch.distributed.new_group(ranks=ranks)
                    if torch.distributed.get_rank() in ranks:
                        self._ag_pg.append(grp)
                        if self._verbose:
                            print(f"creating AG group : {ranks}")
            self._ag_st = [torch.cuda.Stream() for _ in range(self._num_ag_pg)]
        for ag_pg in self._ag_pg:
            torch.distributed.barrier(group=ag_pg)

    self._l2_grad_norm_st = torch.cuda.Stream()
    self._completion_st = torch.cuda.Stream()
    # Mark _step as also used on _completion_st for the CUDA caching allocator.
    self._step.record_stream(self._completion_st)

    # One slot of async work handles per block.
    self._reductions_works = [None]*self._num_blocks
    self._allgather_works = [None]*self._num_blocks

    self._one = torch.cuda.IntTensor([1])

    self._first_step = True
    self._lazy_init_stage1_done, self._lazy_init_stage2_done = False, False
    self._param_order = self.AtomicCounter()

    # Walk the trainable parameters to size the flat buffer and cache the
    # per-parameter hyper-parameters. _lazy_init_stage2 recomputes the offsets
    # with the same alignment rule after reordering.
    p_offset = 0
    p_i = 0
    self._model_params = []
    self._grad_accs = []
    self._group_properties = []
    for group in self.param_groups:
        prev = None
        beta1, beta2 = group['betas']
        # Weight applied to the gradient in the first-moment update.
        beta3 = 1.0 - beta1 if self._grad_averaging else 1.0
        bias_correction = 1 if group['bias_correction'] else 0
        eps = group['eps']
        weight_decay = group['weight_decay']
        for p in group['params']:
            if not p.requires_grad:
                continue
            self._model_params.append(p)
            self._group_properties.append((
                weight_decay,
                bias_correction,
                beta1,
                beta2,
                beta3,
                eps
            ))
            p_grads_size = p.numel()
            if self._set_flat_param_view:
                if param_order:
                    # this is executed when param_order is specified by the user
                    self._param_order.add(param_order[p])
                else:
                    self._param_order.add(p_i)
            p_offset += p_grads_size
            # Only enforce 128b alignment (64 * fp16) for non-consecutive parameters
            # RNN is one example of consecutive parameters:
            # (weight_ih, weight_hh, bias_ih, bias_hh)
            # NOTE: the rounding is applied after p has been added, based on
            # whether p directly followed prev in memory.
            if prev is not None and (prev.data_ptr() + prev.numel() * prev.element_size() != p.data_ptr()):
                p_offset = ((p_offset + 63) // 64) * 64
            prev = p
            p_i += 1
    if param_order:
        # Turn the user-supplied positions into a permutation:
        # order[k] = index of the parameter with the k-th smallest position.
        self._param_order.order = torch.argsort(torch.tensor(self._param_order.order)).tolist()
    self._grads_generated = [False]*len(self._model_params)
    # NOTE(review): _grads_fp16/_grads_fp32 are not read anywhere in this view.
    self._grads_fp16, self._grads_fp32 = [], []
    if self._overlap_reductions:
        # Blocks are flushed from the last one downward (see _get_flush_block).
        self._current_block = self._num_blocks

    self._net_total_param_size = p_offset
    self._total_param_size = p_offset
    # Pad the total so it divides evenly into blocks x chunks x shards with at
    # least 256 elements per shard.
    dwu_min_page_size = 256 * self._num_blocks * self._num_chunks * self._group_size
    self._total_param_size = ((self._total_param_size + dwu_min_page_size - 1) // dwu_min_page_size) * dwu_min_page_size
    # Receive buffer for the gathered updated parameters (uint8 in e5m2 mode).
    self._new_params = torch.zeros([self._total_param_size], dtype=torch.uint8 if self._e5m2_allgather else torch.float16, device='cuda')
def _lazy_init_stage1(self):
    """First half of the deferred initialization; a no-op after the first call.

    * Broadcasts every parameter from rank 0, so all ranks start from the
      same values.
    * Registers a hook on each trainable parameter's gradient-accumulator
      node. Without a flat param view, the hook records gradient arrival
      order on the first step; on later steps it calls
      ``_do_overlapped_reduction``.
    * Allocates the flat fp16 gradient buffer and this rank's shard buffers.
      The fp32 master weights/moments are skipped when already set
      (e.g. from a checkpoint).
    * Builds the block/chunk/shard views used by the collectives and by
      the update kernels.
    """
    if self._lazy_init_stage1_done: return

    p_i = 0
    #self._model_params = []
    #self._grad_accs = []
    #self._group_properties = []
    for group in self.param_groups:
        for p in group['params']:
            # Note: broadcast happens for frozen parameters too.
            torch.distributed.broadcast(p, 0)
            if not p.requires_grad:
                continue
            # wrapper() binds param/param_i per iteration so each hook closes
            # over its own values instead of the loop variables.
            def wrapper(param, param_i):
                # expand_as(...).grad_fn.next_functions[0][0] reaches the
                # parameter's gradient-accumulator node
                # (presumably AccumulateGrad — TODO confirm).
                param_tmp = param.expand_as(param)
                grad_acc = param_tmp.grad_fn.next_functions[0][0]
                def allreduce_hook(*unused):
                    if not self._set_flat_param_view:
                        if self._first_step:
                            # first time
                            self._param_order.add(param_i)
                        else:
                            # NOTE(review): list.index is O(n) per hook, i.e.
                            # O(n^2) over one backward pass.
                            idx = self._param_order.order.index(param_i)
                            self._do_overlapped_reduction(idx, param)
                    else:
                        # Order was fixed in __init__; nothing to record.
                        if not self._first_step:
                            idx = self._param_order.order.index(param_i)
                            self._do_overlapped_reduction(idx, param)
                grad_acc.register_hook(allreduce_hook)
                # Keep a reference to the node alongside the hook.
                self._grad_accs.append(grad_acc)
            wrapper(p, p_i)
            p_i += 1

    # _total_param_size was padded in __init__ so these divisions are exact.
    self._block_size = self._total_param_size // self._num_blocks
    self._chunk_size = self._block_size // self._num_chunks
    self._shard_size = self._chunk_size // self._group_size
    self._flat_grads = torch.zeros([self._total_param_size], dtype=torch.float16, device='cuda')
    # Number of elements this rank owns: one shard from every (block, chunk).
    self._mega_shard_size = self._num_blocks * self._num_chunks * self._shard_size
    # initialize master weights, moments buffers if not loaded from checkpoint
    if self._fp32_p is None:
        self._fp32_p = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
        self._fp32_m = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
        self._fp32_v = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
        # NOTE(review): _fp32_u is only allocated here, so whoever pre-populates
        # _fp32_p must also provide _fp32_u — TODO confirm.
        self._fp32_u = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
    # FIXME: Rethink fp16 label since it's either uint8 or fp16
    self._fp16_p = torch.zeros([self._mega_shard_size], dtype=torch.uint8 if self._e5m2_allgather else torch.float16, device='cuda')
    self._fp16_g = torch.zeros([self._mega_shard_size], dtype=torch.float16, device='cuda')

    # Views of a full-size flat buffer laid out as [block][chunk][shard].
    def _flat_split(p):
        def __blockify(p):
            return [p[block_id*self._block_size:(block_id+1)*self._block_size] for block_id in range(self._num_blocks)]
        def __chunkify(p):
            return [p[chunk_id*self._chunk_size:(chunk_id+1)*self._chunk_size] for chunk_id in range(self._num_chunks)]
        def __shardify(p):
            return [p[shard_id*self._shard_size:(shard_id+1)*self._shard_size] for shard_id in range(self._group_size)]
        list_of_blocks = __blockify(p)
        list_of_list_of_chunks = [__chunkify(block) for block in list_of_blocks]
        list_of_list_of_list_of_shards = [[__shardify(chunk) for chunk in chunks] for chunks in list_of_list_of_chunks]
        return list_of_blocks, list_of_list_of_chunks, list_of_list_of_list_of_shards
    # note(crcrpar): the function below doesn't seem to be used at all.
    # def _flat_split_no_shards(p):
    #     def __blockify(p):
    #         return [p[block_id*self._block_size:(block_id+1)*self._block_size] for block_id in range(self._num_blocks)]
    #     def __chunkify(p):
    #         return [p[chunk_id*self._chunk_size:(chunk_id+1)*self._chunk_size] for chunk_id in range(self._num_chunks)]
    #     list_of_blocks = __blockify(self._flat_grads)
    #     list_of_list_of_chunks = [__chunkify(block) for block in list_of_blocks]
    #     return list_of_blocks, list_of_list_of_chunks
    # Views of a full-size buffer laid out as [rank (mega shard)][block][chunk],
    # each innermost slice being shard_size long.
    def _full_packed_split(p):
        def __shardify(p):
            return [p[mega_shard*self._mega_shard_size:(mega_shard+1)*self._mega_shard_size] for mega_shard in range(self._group_size)]
        def __blockify(p):
            return [p[block_id*self._num_chunks*self._shard_size:(block_id+1)*self._num_chunks*self._shard_size] for block_id in range(self._num_blocks)]
        def __chunkify(p):
            return [p[chunk_id*self._shard_size:(chunk_id+1)*self._shard_size] for chunk_id in range(self._num_chunks)]
        list_of_mega_shards = __shardify(p)
        list_of_list_of_mega_blocks = [__blockify(mega_shard) for mega_shard in list_of_mega_shards]
        list_of_list_of_list_of_mega_chunks = [[__chunkify(mega_block) for mega_block in mega_blocks] for mega_blocks in list_of_list_of_mega_blocks]
        return list_of_mega_shards, list_of_list_of_mega_blocks, list_of_list_of_list_of_mega_chunks
    # Views of a per-rank (mega_shard_size) buffer laid out as [block][chunk].
    def _packed_split(p):
        def __packed_blockify(p):
            packed_block_size = self._num_chunks*self._shard_size
            return [p[block_id*packed_block_size:(block_id+1)*packed_block_size] for block_id in range(self._num_blocks)]
        def __packed_chunkify(p):
            # in the packed format, each chunk contains one shard, so packed_chunk_size == self._shard_size
            return [p[chunk_id*self._shard_size:(chunk_id+1)*self._shard_size] for chunk_id in range(self._num_chunks)]
        list_of_blocks = __packed_blockify(p)
        list_of_list_of_chunks = [__packed_chunkify(block) for block in list_of_blocks]
        return list_of_blocks, list_of_list_of_chunks
    # Pick this rank's shard out of every [block][chunk] of a _flat_split result,
    # giving [block][chunk] views directly into that buffer.
    def _split_assign(shards):
        # NOTE: packed_block_size is only used by the commented-out line below.
        packed_block_size = self._num_chunks*self._shard_size
        list_of_list_of_chunks=[]
        for block_id in range(self._num_blocks):
            list_of_chunks=[]
            for chunk_id in range(self._num_chunks):
                #self._fp16_g[block_id*packed_block_size+chunk_id*self._shard_size:block_id*packed_block_size+(chunk_id+1)*self._shard_size] = shards[block_id][chunk_id][self._rank_in_group]
                list_of_chunks.append( shards[block_id][chunk_id][self._rank_in_group])
            list_of_list_of_chunks.append(list_of_chunks)
        return list_of_list_of_chunks

    self._new_params_mega_shards, self._new_params_mega_blocks, self._new_params_mega_chunks = _full_packed_split(self._new_params)
    # this splitting scheme is needed when allgather needs to be split into multiple chunks in a contiguous way
    self._new_params2_blocks, self._new_params2_chunks, self._new_params2_shards = _flat_split(self._new_params)

    self._fp32_p_blocks, self._fp32_p_chunks = _packed_split(self._fp32_p)
    self._fp32_m_blocks, self._fp32_m_chunks = _packed_split(self._fp32_m)
    self._fp32_v_blocks, self._fp32_v_chunks = _packed_split(self._fp32_v)
    self._fp32_u_blocks, self._fp32_u_chunks = _packed_split(self._fp32_u)
    self._fp16_p_blocks, self._fp16_p_chunks = _packed_split(self._fp16_p)

    if self._full_ar:
        # for gradient all-reduce
        self._flat_grads_blocks, self._flat_grads_chunks, self._flat_grads_shards = _flat_split(self._flat_grads)
        # for weight update
        # (views into _flat_grads itself; the separate _fp16_g buffer is not used here)
        self._fp16_g_chunks = _split_assign(self._flat_grads_shards)
    else:
        self._flat_grads_blocks, self._flat_grads_chunks, self._flat_grads_shards = _flat_split(self._flat_grads)
        self._fp16_g_blocks, self._fp16_g_chunks = _packed_split(self._fp16_g)

    self._lazy_init_stage1_done = True
def _lazy_init_stage2(self):
    """Second half of the deferred initialization; a no-op after the first call.

    Reorders the per-parameter lists by the recorded parameter order and
    recomputes each parameter's offset in the flat gradient buffer. It also
    issues the reductions deferred from the first iteration. Then it maps
    every (shard, block, chunk) slice of the flat layout onto parameter
    fragments:

    * all ranks record where gathered new parameters are copied back into
      the model parameters (split by fp16/fp32);
    * the rank owning a shard also records its master-weight and
      optimizer-state fragments, and copies the model weights into the
      master buffer unless resuming from a checkpoint.

    Finishes by calling ``complete_reductions()`` and clearing ``_first_step``.
    """
    if self._lazy_init_stage2_done: return
    if not self._set_flat_param_view:
        # reversing is needed for overlapping allreduce and backprop, but currently not supported for flat param view
        self._param_order.order.reverse()

    # re-order model_params, grad_accs, group_properties lists
    self._model_params = [self._model_params[i] for i in self._param_order.order]
    self._grad_accs = [self._grad_accs[i] for i in self._param_order.order]
    self._group_properties = [self._group_properties[i] for i in self._param_order.order]

    # 1-D view of a parameter in its physical memory order: channels_last
    # (4-D) and channels_last_3d (5-D) tensors are first re-viewed with
    # channel-innermost strides.
    def _get_flat_view(param):
        if param.is_contiguous(memory_format=torch.channels_last):
            K, C, H, W = param.shape
            pv = param.as_strided(size=(K,H,W,C), stride=(H*W*C, W*C, C, 1))
        elif param.is_contiguous(memory_format=torch.channels_last_3d):
            K, C, D, H, W = param.shape
            pv = param.as_strided(size=(K,D,H,W,C), stride=(D*H*W*C, H*W*C, W*C, C, 1))
        else:
            pv = param
        return pv.view(-1)

    # re-collect grads info (size, offset) after ordering
    prev = None
    p_offset = 0
    self._grads_info = []
    self._individual_flat_grads = []
    for i, p in enumerate(self._model_params):
        p_grads_size = p.numel()
        self._grads_info.append({"param_grads_size":p_grads_size, "param_offset":p_offset})
        self._individual_flat_grads.append(self._flat_grads[p_offset:p_offset+p_grads_size].view_as(p))
        # for the first iteration
        # (the hooks did not reduce during the first step, so do it here)
        self._do_overlapped_reduction(i, p)
        p_offset += p_grads_size
        # Only enforce 128b alignment (64 * fp16) for non-consecutive parameters
        # RNN is one example of consecutive parameters:
        # (weight_ih, weight_hh, bias_ih, bias_hh)
        if prev is not None and (prev.data_ptr() + prev.numel() * prev.element_size() != p.data_ptr()):
            p_offset = ((p_offset + 63) // 64) * 64
        prev = p

    # For each block: the largest parameter index whose offset is <= the
    # block's start (0 if none); _get_flush_block checks this param's grad.
    self._low_param_i = [0]*self._num_blocks
    for block_id in range(self._num_blocks-1,-1,-1):
        p_i = len(self._grads_info)-1
        while p_i > 0 and self._grads_info[p_i]["param_offset"] > block_id*self._block_size:
            p_i -= 1
        self._low_param_i[block_id] = p_i
    #print("self._low_param_i", self._low_param_i)

    # This paragraph does two things:
    # 1) Copy model parameters into master buffer
    # 2) Create tensor lists for unpacking new parameter tensor after all-gather
    self._packed_flat_to_model_params_fp16 = []
    self._packed_flat_to_model_params_fp32 = []
    self._model_params_num = len(self._model_params)
    self._contrib_tensor_list = []
    self._contrib_min_param_i, self._contrib_max_param_i = -1, -1
    self._contrib_update_frag_for_norm = []
    self._contrib_model_param_for_norm_fp16 = []
    self._contrib_model_param_for_norm_fp32 = []
    self._contrib_model_param_for_norm_is_fp16 = []
    self._model_param_is_contrib = []
    self._contrib_group_properties = []
    for shard_id in range(self._group_size):
        for block_id in range(self._num_blocks):
            for chunk_id in range(self._num_chunks):
                # Position of this shard in the flat [block][chunk][shard] layout.
                flat_shard_start = (((block_id * self._num_chunks + chunk_id) * self._group_size) + shard_id) * self._shard_size
                flat_shard_end = flat_shard_start + self._shard_size
                for param_i, (p, grads_info, group_props) in enumerate(zip(self._model_params, self._grads_info, self._group_properties)):
                    flat_grad_start = grads_info["param_offset"]
                    flat_grad_end = flat_grad_start + grads_info["param_grads_size"]
                    # Intersection of the parameter's range with the shard's range.
                    clipped_start = (lambda a,b: a if a > b else b)(flat_grad_start, flat_shard_start)
                    clipped_end = (lambda a,b: a if a < b else b)(flat_grad_end, flat_shard_end)
                    if clipped_start < clipped_end:
                        grad_offset = clipped_start - flat_grad_start
                        grad_length = clipped_end - clipped_start
                        shard_offset = clipped_start - flat_shard_start
                        pf = _get_flat_view(p)
                        model_param_fragment = pf[grad_offset:grad_offset+grad_length]
                        # Gathered params use the packed [rank][block][chunk] layout.
                        new_param_packed_fragment = self._new_params_mega_chunks[shard_id][block_id][chunk_id][shard_offset:shard_offset+grad_length]
                        if model_param_fragment.dtype == torch.float16:
                            self._packed_flat_to_model_params_fp16.append( (new_param_packed_fragment, model_param_fragment) )
                        else:
                            self._packed_flat_to_model_params_fp32.append( (new_param_packed_fragment, model_param_fragment) )
                        if shard_id == self._rank_in_group:
                            # This rank owns the fragment: wire up its optimizer state.
                            self._model_param_is_contrib.append(param_i)
                            # copy model parameters into master buffer
                            master_param_fragment = self._fp32_p_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
                            opti_state_m_fragment = self._fp32_m_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
                            opti_state_v_fragment = self._fp32_v_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
                            opti_state_u_fragment = self._fp32_u_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
                            opti_state_g_fragment = self._fp16_g_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
                            opti_state_p_fragment = self._fp16_p_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
                            #print("model_param_fragment.size()=%s, new_param_packed_fragment.size()=%s, master_param_fragment.size()=%s" % (str(model_param_fragment.size()), str(new_param_packed_fragment.size()), str(master_param_fragment.size())))
                            if not self._resume_from_checkpoint:
                                master_param_fragment.copy_(model_param_fragment)
                            self._contrib_group_properties.append(group_props)
                            self._contrib_tensor_list.append((master_param_fragment, opti_state_m_fragment, opti_state_v_fragment, opti_state_u_fragment, opti_state_g_fragment, opti_state_p_fragment)) # p, m, v, u, g, p_copy
                            self._contrib_update_frag_for_norm.append(opti_state_u_fragment)
                            if p.dtype == torch.float16:
                                self._contrib_model_param_for_norm_fp16.append(p)
                            else:
                                self._contrib_model_param_for_norm_fp32.append(p)
                            self._contrib_model_param_for_norm_is_fp16.append(True if p.dtype == torch.float16 else False)
                            if self._contrib_min_param_i < 0: self._contrib_min_param_i = param_i
                            self._contrib_max_param_i = param_i
    self._contrib_model_param_for_norm_num = len(self._contrib_model_param_for_norm_is_fp16)
    if len(self._contrib_model_param_for_norm_fp16) == 0: self._contrib_model_param_for_norm_fp16 = None
    if len(self._contrib_model_param_for_norm_fp32) == 0: self._contrib_model_param_for_norm_fp32 = None
    # Order matters: is_fp32 is derived from the Python list before is_fp16 is
    # replaced by a tensor on the next line.
    self._contrib_model_param_for_norm_is_fp32 = torch.tensor([not is_fp16 for is_fp16 in self._contrib_model_param_for_norm_is_fp16], dtype=torch.bool, device='cuda')
    self._contrib_model_param_for_norm_is_fp16 = torch.tensor([is_fp16 for is_fp16 in self._contrib_model_param_for_norm_is_fp16], dtype=torch.bool, device='cuda')
    # Parameter index of each contributed fragment, in contribution order.
    self._offsets = torch.tensor(self._model_param_is_contrib, dtype=torch.int64, device='cuda')

    # Transpose the per-fragment tuples into per-role lists for the multi-tensor
    # kernels. NOTE(review): fails to unpack if this rank contributes no fragments.
    p, m, v, u, g, p_copy = list(zip(*self._contrib_tensor_list))
    self._contrib_compute_update_term_tensor_list = [g, p, m, v, u]
    self._contrib_update_weights_tensor_list = [u, p, p_copy]

    # Per-fragment hyper-parameters as device tensors.
    math_type = self._fp32_u.dtype
    decay, bias_correction, beta1, beta2, beta3, epsilon = list(zip(*self._contrib_group_properties))
    self._contrib_beta1 = torch.tensor(beta1, dtype=math_type, device='cuda')
    self._contrib_beta2 = torch.tensor(beta2, dtype=math_type, device='cuda')
    self._contrib_beta3 = torch.tensor(beta3, dtype=math_type, device='cuda')
    self._contrib_bias_correction = torch.tensor(bias_correction, dtype=torch.int, device='cuda')
    self._contrib_epsilon = torch.tensor(epsilon, dtype=math_type, device='cuda')
    self._contrib_weight_decay = torch.tensor(decay, dtype=math_type, device='cuda')

    # (src list, dst list) pairs for unpacking the gathered params, or None if empty.
    self._packed_flat_to_model_params_fp16 = list(zip(*self._packed_flat_to_model_params_fp16)) if len(self._packed_flat_to_model_params_fp16) > 0 else None
    self._packed_flat_to_model_params_fp32 = list(zip(*self._packed_flat_to_model_params_fp32)) if len(self._packed_flat_to_model_params_fp32) > 0 else None

    self._lazy_init_stage2_done = True

    self.complete_reductions()
    self._first_step = False
def set_is_accumulation_step(self, is_accumulation_step):
self._is_accumulation_step = is_accumulation_step
def set_last_step(self, last_step):
self._last_step = last_step
def _get_flush_block(self):
flush_block = []
if self._current_block > 0 and self._grads_generated[self._low_param_i[self._current_block-1]]:
num_grads = len(self._grads_generated)
contiguous_idx = num_grads
while contiguous_idx > 0 and self._grads_generated[contiguous_idx-1]:
contiguous_idx -= 1
if contiguous_idx < num_grads and self._grads_info[contiguous_idx]["param_offset"] <= (self._current_block-1)*self._block_size:
self._current_block -= 1
start = self._current_block * self._block_size
end = (self._current_block+1) * self._block_size
flush_block = [start, end]
return flush_block
def _full_all_reduce_scale(self, block_id, scale):
    """Launch asynchronous, pre-scaled all-reduces for gradient block ``block_id``.

    With ``self._clip_after_ar`` set, each chunk of the block is all-reduced
    separately. Chunks go round-robin over the all-reduce streams/process
    groups by their global chunk index. Otherwise the whole block is
    all-reduced in one call, and that single work handle fills every chunk
    slot.

    Every reduction uses ``op=_make_nccl_premul_sum(scale)``. Each stream
    first waits on the current stream, so the collective is ordered after
    work already queued there. The handles are stored in
    ``self._reductions_works[block_id]``.
    """
    n_ar = self._num_ar_pg
    if not self._clip_after_ar:
        stream = self._ar_st[block_id % n_ar]
        stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(stream):
            handle = torch.distributed.all_reduce(
                self._flat_grads_blocks[block_id],
                group=self._ar_pg[block_id % n_ar],
                async_op=True,
                op=_make_nccl_premul_sum(scale),
            )
        self._reductions_works[block_id] = [handle] * self._num_chunks
        return

    handles = [None] * self._num_chunks
    first_glob = block_id * self._num_chunks
    for chunk_id in range(self._num_chunks):
        pg_idx = (first_glob + chunk_id) % n_ar
        stream = self._ar_st[pg_idx]
        stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(stream):
            handles[chunk_id] = torch.distributed.all_reduce(
                self._flat_grads_chunks[block_id][chunk_id],
                group=self._ar_pg[pg_idx],
                async_op=True,
                op=_make_nccl_premul_sum(scale),
            )
    self._reductions_works[block_id] = handles
def _full_all_reduce(self, block_id):
    """Start an async sum all-reduce for every chunk of flat-grad block ``block_id``.

    Chunk ``c`` has global index ``block_id * num_chunks + c``.  That index, modulo
    ``self._num_ar_pg``, selects the all-reduce stream and process group.  Each stream is
    first ordered after the current stream.  The work handles are stored, in chunk
    order, in ``self._reductions_works[block_id]``.
    """
    handles = []
    for chunk_id in range(self._num_chunks):
        slot = (block_id * self._num_chunks + chunk_id) % self._num_ar_pg
        stream = self._ar_st[slot]
        stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(stream):
            handle = torch.distributed.all_reduce(
                self._flat_grads_chunks[block_id][chunk_id],
                group=self._ar_pg[slot],
                async_op=True,
            )
        handles.append(handle)
    self._reductions_works[block_id] = handles
def _reduce_scatter_and_all_reduce_scale(self, block_id, scale):
    """Reduce block ``block_id`` into this rank's grad shards, scaling via the NCCL op.

    Stage 1, intra-node, per chunk: an async reduce-scatter writes this rank's shard of
    the chunk into ``self._fp16_g_chunks[block_id][chunk]``.
    - Its op is ``_make_nccl_premul_sum(scale)``.
    - This turns the [block * chunk * shard] layout into [shard * block * chunk], the
      layout of the fp32 master parameters.
    - The reduce-scatter stream waits on both the current stream and
      ``self._l2_grad_norm_st``.

    Stage 2, inter-node, only if ``self._num_groups > 1``: on the chunk's all-reduce
    stream, wait for stage 1, then sum all-reduce the shard over ``self._ar_pg``.

    The final handles are stored in ``self._reductions_works[block_id]``.
    """
    num_chunks = self._num_chunks
    handles = []
    for chunk_id in range(num_chunks):
        rs_slot = (block_id * num_chunks + chunk_id) % self._num_rs_pg
        stream = self._rs_st[rs_slot]
        stream.wait_stream(torch.cuda.current_stream())
        stream.wait_stream(self._l2_grad_norm_st)
        with torch.cuda.stream(stream):
            shard_out = self._fp16_g_chunks[block_id][chunk_id]
            rs_group = self._rs_pg[rs_slot]
            if not self._reduce_scatter_no_copy:
                handle = torch.distributed.reduce_scatter_tensor(
                    output=shard_out,
                    input=self._flat_grads_chunks[block_id][chunk_id],
                    group=rs_group,
                    async_op=True,
                    op=_make_nccl_premul_sum(scale),
                )
            else:
                handle = torch.distributed.reduce_scatter(
                    output=shard_out,
                    input_list=self._flat_grads_shards[block_id][chunk_id],
                    group=rs_group,
                    async_op=True,
                    no_copy=True,
                    op=_make_nccl_premul_sum(scale),
                )
        handles.append(handle)

    if self._num_groups > 1:
        for chunk_id in range(num_chunks):
            ar_slot = (block_id * num_chunks + chunk_id) % self._num_ar_pg
            with torch.cuda.stream(self._ar_st[ar_slot]):
                handles[chunk_id].wait()
                handles[chunk_id] = torch.distributed.all_reduce(
                    self._fp16_g_chunks[block_id][chunk_id],
                    group=self._ar_pg[ar_slot],
                    async_op=True,
                )
    self._reductions_works[block_id] = handles
def _reduce_scatter_and_all_reduce(self, block_id):
    """Reduce flat-grad block ``block_id`` into this rank's gradient shards (sum).

    Stage 1, intra-node, per chunk: an async reduce-scatter writes this rank's shard of
    the chunk into ``self._fp16_g_chunks[block_id][chunk]``.
    - This turns the [block * chunk * shard] layout into [shard * block * chunk], the
      layout of the fp32 master parameters.
    - The no_copy variant passes the chunk's shard list instead of the chunk tensor.

    Stage 2, inter-node, only if ``self._num_groups > 1``: on the chunk's all-reduce
    stream, wait for stage 1, then all-reduce the shard over ``self._ar_pg``.

    The final handles are stored in ``self._reductions_works[block_id]``.
    """
    num_chunks = self._num_chunks
    handles = []
    for chunk_id in range(num_chunks):
        rs_slot = (block_id * num_chunks + chunk_id) % self._num_rs_pg
        stream = self._rs_st[rs_slot]
        stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(stream):
            shard_out = self._fp16_g_chunks[block_id][chunk_id]
            rs_group = self._rs_pg[rs_slot]
            if not self._reduce_scatter_no_copy:
                handle = torch.distributed.reduce_scatter_tensor(
                    output=shard_out,
                    input=self._flat_grads_chunks[block_id][chunk_id],
                    group=rs_group,
                    async_op=True,
                )
            else:
                handle = torch.distributed.reduce_scatter(
                    output=shard_out,
                    input_list=self._flat_grads_shards[block_id][chunk_id],
                    group=rs_group,
                    async_op=True,
                    no_copy=True,
                )
        handles.append(handle)

    if self._num_groups > 1:
        for chunk_id in range(num_chunks):
            ar_slot = (block_id * num_chunks + chunk_id) % self._num_ar_pg
            with torch.cuda.stream(self._ar_st[ar_slot]):
                handles[chunk_id].wait()
                handles[chunk_id] = torch.distributed.all_reduce(
                    self._fp16_g_chunks[block_id][chunk_id],
                    group=self._ar_pg[ar_slot],
                    async_op=True,
                )
    self._reductions_works[block_id] = handles
def _pipeline_block_reductions(self, block_id):
    """Flatten pending grads, launch the reduction of block ``block_id``, track the grad norm.

    ``self._clip_after_ar`` selects the mode.

    * True: grads are copied into the flat buffer pre-scaled by ``1/world_size`` and
      reduced unclipped.  At ``block_id == 0`` (the last block flushed, since
      ``complete_reductions`` and ``_get_flush_block`` walk blocks from the highest
      index down):
        - every outstanding reduction is waited on ``self._l2_grad_norm_st``;
        - the L2 norm of the reduced grads is stored in ``self._L2_grad_norm``.
    * False: grads are copied unscaled.
        - Their L2 norm is taken before communication: here, or in ``_flatten_grad_mt``
          when ``self._fused_norm`` is set.
        - A clipping coefficient combined with ``1/world_size`` is applied, either in
          place on the flat buffer or fused into the NCCL op (``self._fuse_scale``).
        - At ``block_id == 0`` every outstanding reduction work is waited on.

    Args:
        block_id: index of the flat-grad block to reduce.
    """
    if self._clip_after_ar:
        self._flatten_grad_mt(1.0/self._world_size)

        if self._full_ar:
            self._full_all_reduce(block_id)
        else:
            self._reduce_scatter_and_all_reduce(block_id)

        # Compute L2 grad norm once every block has been launched.
        if block_id == 0:
            with torch.cuda.stream(self._l2_grad_norm_st):
                # Separate loop names so the block_id parameter is not shadowed.
                for blk in range(self._num_blocks):
                    for chunk in range(self._num_chunks):
                        self._reductions_works[blk][chunk].wait()
                # Since the packed format is contiguous after reductions, only one norm is needed
                if self._full_ar:
                    # this flattening of lists is to keep multi_tensor_apply function happy, it wants depth=1 for l2 norm computation
                    # NOTE(review): full_ar all-reduces _flat_grads_chunks but the norm reads
                    # _fp16_g_chunks; presumably they alias in full_ar mode -- confirm.
                    flat_list = [item for sublist in self._fp16_g_chunks for item in sublist]
                    l2_grad_norm_sq = multi_tensor_applier(self.multi_tensor_l2norm, self._overflow_buf, [flat_list], False)[0]**2
                else:
                    # _fp16_g is assumed to back this rank's reduce-scatter output shards, so
                    # partial squared norms are summed across ranks.
                    l2_grad_norm_sq = self._fp16_g.norm(dtype=torch.float32, p=2)**2
                    torch.distributed.all_reduce(l2_grad_norm_sq, group=self._l2_grad_norm_pg)
                self._L2_grad_norm = l2_grad_norm_sq.sqrt()
    else:
        # Copy model grads to flat grads buffer (with fused_norm this also sets _L2_grad_norm)
        self._flatten_grad_mt(1.0)

        # Compute L2 grad norm
        self._l2_grad_norm_st.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(self._l2_grad_norm_st):
            if not self._fused_norm:
                # NOTE(review): the norm is produced in fp16 before widening, so a norm beyond
                # the fp16 range becomes inf; _pipeline_step treats a non-finite norm as overflow.
                self._L2_grad_norm = self._flat_grads.norm(dtype=torch.float16, p=2).float()
        torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)

        # Apply clipping & pre-reduction scaling on grads
        loss_scale = self.global_scale
        max_grad_norm = loss_scale*self.defaults['max_grad_norm']
        coeff = max_grad_norm /(1e-6+self.L2_grad_norm)
        # Clamp coeff to at most 1 using tensor arithmetic (no host-side comparison).
        coeff = (coeff>1) * self._one + (coeff<=1) * coeff
        # Pick self._one (index 0) when coeff+1 > coeff fails (non-finite coeff), else coeff (index 1).
        tmp = torch.cat(((self._one), (coeff)))
        index = (coeff+1>coeff).int()
        # Fold the 1/world_size averaging into the same multiplier.
        scale = tmp.index_select(0, index).half()/self._world_size
        if not self._fuse_scale:
            self._flat_grads.mul_(scale)

        if self._full_ar:
            if self._fuse_scale:
                self._full_all_reduce_scale(block_id, scale)
            else:
                self._full_all_reduce(block_id)
        else:
            if self._fuse_scale:
                self._reduce_scatter_and_all_reduce_scale(block_id, scale)
            else:
                self._reduce_scatter_and_all_reduce(block_id)

        if block_id == 0:
            for blk in range(self._num_blocks):
                for chunk in range(self._num_chunks):
                    self._reductions_works[blk][chunk].wait()
def __compute_contrib_param_norm(self):
    """Return the per-tensor L2 norms of this rank's contributing model parameters.

    The fp16 and fp32 parameter lists are normed separately with
    ``self.multi_tensor_l2norm``; per-tensor results are element ``[1]`` of its output.
    When both lists exist, the two norm vectors are merged into one vector of length
    ``self._contrib_model_param_for_norm_num`` through the masks
    ``self._contrib_model_param_for_norm_is_fp16`` / ``..._is_fp32``.

    Returns:
        A tensor of per-parameter norms.

    Raises:
        RuntimeError: if neither an fp16 nor an fp32 parameter list is set.
    """
    params_fp16 = self._contrib_model_param_for_norm_fp16
    params_fp32 = self._contrib_model_param_for_norm_fp32
    if params_fp16 is not None and params_fp32 is not None:
        gnorm_fp16 = multi_tensor_applier(self.multi_tensor_l2norm, self._overflow_buf, [params_fp16], True)[1]
        gnorm_fp32 = multi_tensor_applier(self.multi_tensor_l2norm, self._overflow_buf, [params_fp32], True)[1]
        # The merge buffer must have the norms' floating dtype.  masked_scatter_ requires
        # self and source dtypes to match, and a bool buffer could not hold norm values.
        gnorm = torch.empty(size=[self._contrib_model_param_for_norm_num], dtype=gnorm_fp32.dtype, device='cuda')
        gnorm.masked_scatter_(self._contrib_model_param_for_norm_is_fp16, gnorm_fp16)
        gnorm.masked_scatter_(self._contrib_model_param_for_norm_is_fp32, gnorm_fp32)
        return gnorm
    if params_fp16 is not None:
        return multi_tensor_applier(self.multi_tensor_l2norm, self._overflow_buf, [params_fp16], True)[1]
    if params_fp32 is not None:
        return multi_tensor_applier(self.multi_tensor_l2norm, self._overflow_buf, [params_fp32], True)[1]
    # Previously this fell through to an UnboundLocalError on `gnorm`.
    raise RuntimeError("no contributing fp16 or fp32 model parameters to compute norms for")
def __compute_contrib_update_norm(self):
    """Return per-parameter L2 norms of the LAMB update term, combined across ranks.

    1. Squared norms of this rank's update fragments
       (``self._contrib_update_frag_for_norm``) are scattered to positions
       ``self._offsets`` of a float32 zero buffer of size ``self._model_params_num``.
    2. The buffer is sum all-reduced over ``self._ag_pg[0]``.
    3. The result is square-rooted.

    NOTE(review): ``scatter_`` overwrites rather than accumulates.  If one parameter can
    have more than one local fragment, ``scatter_add_`` would be needed -- confirm.
    """
    norm_sq = torch.zeros(size=[self._model_params_num], dtype=torch.float32, device='cuda')
    frag_norms = multi_tensor_applier(
        self.multi_tensor_l2norm,
        self._overflow_buf,
        [self._contrib_update_frag_for_norm],
        True,
    )[1]
    norm_sq.scatter_(dim=0, index=self._offsets, src=frag_norms ** 2)
    torch.distributed.all_reduce(norm_sq, group=self._ag_pg[0])
    return norm_sq.sqrt()
def _pipeline_step(self):
    """Apply one fused LAMB update to this rank's shard and all-gather the new parameters.

    1. Read the global grad norm and derive ``is_finite`` / ``self._overflow_buf`` from
       it.  In clip-before-all-reduce mode both are all-reduced (MIN / MAX) over
       ``self._current_process_group`` so every rank agrees.
    2. Add ``is_finite`` to ``self._step``, so the counter advances only without overflow.
    3. On ``self._completion_st``:
       - wait for all gradient reductions;
       - run the update-term kernel;
       - compute the per-parameter param and update norms;
       - run the weight-update kernel.
    4. Unless ``self._skip_ag`` is set, all-gather the updated fp16 parameter shards.
    """
    global_scale = self.global_scale
    # if clip before ar, set max_grad_norm to 0
    max_grad_norm = self.defaults['max_grad_norm'] * self._clip_after_ar
    self._completion_st.wait_stream(self._l2_grad_norm_st)
    global_grad_norm = self.L2_grad_norm

    # check global_grad_norm and fill overflow_buf (x + 1 > x fails for inf and NaN)
    is_finite = (global_grad_norm + 1 > global_grad_norm).int()
    self._overflow_buf = self._one * (is_finite ^ self._one) # toggle between 0 and 1

    if not self._clip_after_ar:
        torch.distributed.all_reduce(is_finite,
                                     op=torch.distributed.ReduceOp.MIN,
                                     group=self._current_process_group)
        torch.distributed.all_reduce(self._overflow_buf,
                                     op=torch.distributed.ReduceOp.MAX,
                                     group=self._current_process_group)

    # increment step counter if no overflow
    self._step += is_finite
    self._completion_st.wait_stream(torch.cuda.current_stream())
    self._completion_st.wait_stream(self._l2_grad_norm_st)

    # Call step kernel once per step
    # Call all-gather once per step
    with torch.cuda.stream(self._completion_st):
        for block_id in range(self._num_blocks):
            for chunk_id in range(self._num_chunks):
                self._reductions_works[block_id][chunk_id].wait()
        param_norm = self.__compute_contrib_param_norm()
        multi_tensor_applier(self.multi_tensor_lamb_compute_update_term,
                             self._overflow_buf,
                             self._contrib_compute_update_term_tensor_list, # g, p, m, v, u
                             self._contrib_beta1,
                             self._contrib_beta2,
                             self._contrib_beta3,
                             self._contrib_bias_correction,
                             self._step,
                             self._contrib_epsilon,
                             self._adam_w_mode,
                             self._contrib_weight_decay,
                             global_scale,
                             global_grad_norm,
                             max_grad_norm)
        upd_norm = self.__compute_contrib_update_norm()
        multi_tensor_applier(self.multi_tensor_lamb_update_weights,
                             self._overflow_buf,
                             self._contrib_update_weights_tensor_list, # u, p, p_copy
                             param_norm,
                             upd_norm,
                             self._offsets,
                             self._lr,
                             self._contrib_weight_decay,
                             global_grad_norm,
                             self._use_nvlamb)

        if not self._skip_ag:
            # allgather chunking is currently not supported for clip after allreduce
            if not self._clip_after_ar:
                for block in range(self._num_blocks):
                    for chunk in range(self._num_chunks):
                        if self._all_gather_no_copy:
                            torch.distributed.all_gather(
                                tensor_list = self._new_params2_shards[block][chunk],
                                tensor = self._fp16_p_chunks[block][chunk],
                                group = self._ag_pg[0],
                                no_copy = True,
                            )
                        else:
                            # Gather into this chunk's 1/num_chunks slice of the block, which
                            # matches the [block * chunk * shard] layout.  The output must be
                            # exactly group-size times the input; the whole block only meets
                            # that when num_chunks == 1.
                            torch.distributed.all_gather_into_tensor(
                                output_tensor = self._new_params2_blocks[block].chunk(self._num_chunks)[chunk],
                                input_tensor = self._fp16_p_chunks[block][chunk],
                                group = self._ag_pg[0],
                            )
            else:
                if self._all_gather_no_copy:
                    torch.distributed.all_gather(
                        tensor_list = self._new_params_mega_shards,
                        tensor = self._fp16_p,
                        group = self._ag_pg[0],
                        no_copy = True,
                    )
                else:
                    torch.distributed.all_gather_into_tensor(
                        output_tensor = self._new_params,
                        input_tensor = self._fp16_p,
                        group = self._ag_pg[0],
                    )
def _flatten_grad_mt(self, scale):
    """Copy pending model grads into their flat-buffer slots, multiplied by ``scale``.

    Pending ``(model_grad, flat_grad_slot)`` pairs come from ``self._grads_fp16`` and
    ``self._grads_fp32``; each non-empty list is handled by one multi-tensor launch and
    then emptied.  ``self._overflow_buf`` is zeroed before each launch.

    With ``self._fused_norm``, each launch also returns a total L2 norm (element ``[0]``).
    ``self._L2_grad_norm`` is then set to the norm over every grad copied by this call.
    When nothing is pending it is left unchanged.

    Args:
        scale: multiplier applied while copying (the callers in this file pass
            ``1/world_size`` or ``1.0``).
    """
    def _launch(grads):
        # One fused scale-copy over a list of (src, dst) pairs; returns the norm or None.
        self._overflow_buf.zero_()
        tensor_lists = list(zip(*grads))
        if not self._fused_norm:
            multi_tensor_applier(
                amp_C.multi_tensor_scale,
                self._overflow_buf,
                tensor_lists,
                scale)
            return None
        return multi_tensor_applier(
            amp_C.multi_tensor_l2norm_scale,
            self._overflow_buf,
            tensor_lists,
            scale, False)[0].float()

    norms = []
    if self._grads_fp16:
        norms.append(_launch(self._grads_fp16))
        self._grads_fp16 = []
    if self._grads_fp32:
        norms.append(_launch(self._grads_fp32))
        self._grads_fp32 = []

    if self._fused_norm and norms:
        # Combine per-dtype norms; assigning each in turn would keep only the fp32 one.
        if len(norms) == 1:
            self._L2_grad_norm = norms[0]
        else:
            self._L2_grad_norm = torch.hypot(norms[0], norms[1])
def _do_overlapped_reduction(self, param_i, param):
    """Register the grad of parameter ``param_i`` and flush any blocks now ready.

    Does nothing on accumulation steps.  Otherwise:
    - the pair ``(param.grad, self._individual_flat_grads[param_i])`` is queued on the
      fp16 or fp32 pending list, according to ``param.dtype``;
    - the parameter is marked as generated.

    Outside the first and last step, and only with ``self._overlap_reductions``, every
    block reported by ``_get_flush_block`` is reduced immediately.
    """
    if self._is_accumulation_step:
        return

    # handle overlapped reductions
    pending = self._grads_fp16 if param.dtype == torch.float16 else self._grads_fp32
    pending.append((param.grad, self._individual_flat_grads[param_i]))
    self._grads_generated[param_i] = True

    if self._first_step or self._last_step or not self._overlap_reductions:
        return

    ready = self._get_flush_block()
    while ready:
        self._pipeline_block_reductions(ready[0] // self._block_size)
        ready = self._get_flush_block()
def set_global_scale(self, global_scale):
    """Store the scale later read back through the ``global_scale`` property.

    ``_pipeline_step`` and ``_pipeline_block_reductions`` read it as the loss scale.
    """
    new_scale = global_scale
    self._global_scale = new_scale
@property
def global_scale(self):
    """The value last stored by ``set_global_scale`` (``self._global_scale``)."""
    current = self._global_scale
    return current
@property
def L2_grad_norm(self):
    """Return ``self._L2_grad_norm`` after ordering the current stream behind the norm stream.

    The current CUDA stream is made to wait on ``self._l2_grad_norm_st`` first, so work
    queued on it sees any norm computed on that stream.
    """
    current = torch.cuda.current_stream()
    current.wait_stream(self._l2_grad_norm_st)
    return self._L2_grad_norm
def complete_reductions(self):
    """Finish this iteration's gradient reductions and reset the flush bookkeeping.

    1. On the last step, flat-grad slots of parameters whose grads were never produced
       are zeroed and marked as generated.
    2. If reductions were not overlapped with backward (first step, last step, or
       ``self._overlap_reductions`` off), every block is reduced now, from the highest
       block index down to 0.
    3. The current CUDA stream is ordered after ``self._l2_grad_norm_st``.
    4. ``self._current_block`` and ``self._grads_generated`` are reset for the next
       iteration.
    """
    if self._last_step:
        # zero out gradients that have not been completed yet
        for idx, done in enumerate(self._grads_generated):
            if done:
                continue
            info = self._grads_info[idx]
            begin = info["param_offset"]
            self._flat_grads[begin:begin + info["param_grads_size"]].zero_()
            self._grads_generated[idx] = True

    reduce_everything = self._first_step or self._last_step or not self._overlap_reductions
    if reduce_everything:
        # nothing done so far, run full pipeline after reductions
        for block_id in reversed(range(self._num_blocks)):
            self._pipeline_block_reductions(block_id)

    torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)

    self._current_block = self._num_blocks
    self._grads_generated = [False] * len(self._grads_info)
def step(self, closure=None, grad_scaler=None):
loss = None
if closure is not None:
loss = closure()
self._pipeline_step()
if grad_scaler is not None:
found_inf = self._overflow_buf.float()
optimizer_state = grad_scaler._per_optimizer_states[id(self)]
current_device = torch.device('cuda', torch.cuda.current_device())
optimizer_state["found_inf_per_device"][current_device] = found_inf
self._completion_st.wait_stream(torch.cuda.current_stream())