-
Notifications
You must be signed in to change notification settings - Fork 87
/
ai8x.py
2337 lines (1924 loc) · 77.3 KB
/
ai8x.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
###################################################################################################
#
# Copyright (C) 2020-2024 Maxim Integrated Products, Inc. All Rights Reserved.
#
# Maxim Integrated Products, Inc. Default Copyright Notice:
# https://www.maximintegrated.com/en/aboutus/legal/copyrights.html
#
###################################################################################################
# pyright: reportOptionalMemberAccess=false, reportPrivateImportUsage=false
# pyright: reportOptionalCall=false, reportOptionalOperand=false
"""
Contains the limits of the MAX78000/MAX78002 implementations and custom PyTorch modules that take
the limits into account.
"""
import numpy as np
import torch
from torch import nn
from torch.autograd import Function
from torch.fx import symbolic_trace
from tqdm import tqdm
import devices
# Module-level device descriptor, initialized to None. The helpers and modules below read
# attributes from it at call time (e.g. dev.simulate, dev.device, dev.DATA_BITS,
# dev.ACTIVATION_BITS, dev.FULL_ACC_BITS, dev.round_avg), so it must be assigned before
# any of them is used.
# NOTE(review): the assignment happens outside this view — presumably a device-setup
# helper elsewhere in this module; confirm.
dev = None
class normalize:
    """
    Normalize input to either [-128/128, +127/128] or [-128, +127]

    The input is shifted by -0.5, scaled by 256, rounded and clamped to [-128, 127].
    When `args.act_mode_8bit` is true that integer-valued tensor is returned directly;
    otherwise it is additionally divided by 128.
    """
    def __init__(self, args):
        self.args = args

    def __call__(self, img):
        eight_bit_mode = self.args.act_mode_8bit
        as_int8 = img.sub(0.5).mul(256.).round().clamp(min=-128, max=127)
        return as_int8 if eight_bit_mode else as_int8.div(128.)
class fold:
    """
    Fold data to increase the number of channels. An interlaced approach used in this folding
    as explained in [1].

    For a fold ratio `r`, the input is split into r*r interlaced sub-images taken with step
    `r` along dims 1 and 2 (presumably H and W of a C x H x W image — confirm with callers).
    The sub-images are concatenated along dim 0 in row-major (i, j) order, so dim 0 grows
    by a factor of r*r. Dims 1 and 2 are assumed to be divisible by `r`; otherwise the
    sub-images differ in size and the concatenation fails.

    [1] https://arxiv.org/pdf/2203.16528.pdf
    """
    def __init__(self, fold_ratio):
        self.fold_ratio = fold_ratio

    def __call__(self, img):
        ratio = self.fold_ratio
        if ratio == 1:
            return img
        if ratio < 1:
            # Previously this silently returned None (no sub-images were produced).
            raise ValueError(f'fold_ratio must be a positive integer, got {ratio}')
        # Collect every interlaced sub-image first and concatenate once, instead of
        # re-copying an ever-growing tensor on each loop iteration.
        subsamples = [img[:, i::ratio, j::ratio] for i in range(ratio) for j in range(ratio)]
        return torch.cat(subsamples, dim=0)
def unfold_batch(img_batch, fold_ratio):
    """
    Unfold data to reduce the number of channels. An interlaced approach used in this folding
    as explained in [1]. This operation is the reverse of the transformation implemented
    at ai8x.fold class.

    `img_batch` is indexed as (batch, channels, dim2, dim3). Its channels are read as
    fold_ratio**2 consecutive groups. Group k = i*fold_ratio + j is written back to the
    positions [i::fold_ratio, j::fold_ratio] of an output whose last two dims are
    fold_ratio times larger. The output is a new tensor with the input's dtype and device.

    [1] https://arxiv.org/pdf/2203.16528.pdf
    """
    if fold_ratio == 1:
        return img_batch

    group_count = fold_ratio * fold_ratio
    channels_per_group = img_batch.shape[1] // group_count
    out_shape = (
        img_batch.shape[0],
        channels_per_group,
        img_batch.shape[2] * fold_ratio,
        img_batch.shape[3] * fold_ratio,
    )
    unfolded = torch.zeros(out_shape, dtype=img_batch.dtype, device=img_batch.device,
                           requires_grad=False)

    for group in range(group_count):
        row_offset, col_offset = divmod(group, fold_ratio)
        first = group * channels_per_group
        unfolded[:, :, row_offset::fold_ratio, col_offset::fold_ratio] = \
            img_batch[:, first:first + channels_per_group, :, :]
    return unfolded
class QuantizationFunction(Function):
    """
    Custom autograd function

    The forward pass divides by 2**(bits-1) (typically, 128) and rounds the result to the
    nearest integer.
    The backward pass is straight through.

    Simulation mode (`dev.simulate`): x is rescaled by 2**(1 - bits - extra_bit_shift).
    This is done as a division when bits > 1, as a multiplication when bits < 1, and
    not at all when bits == 1. The result is then rounded half up via floor(x + 0.5).
    Otherwise: x is multiplied by 2**(bits - extra_bit_shift - 1), rounded half up, and
    divided by 2**(bits - 1).
    """
    # pylint: disable=abstract-method
    @staticmethod
    def forward(_, x, bits=8, extra_bit_shift=0):  # pylint: disable=arguments-differ
        """Forward prop"""
        if not dev.simulate:
            upscaled = x.mul(2**(bits-extra_bit_shift-1))
            return upscaled.add(.5).floor().div(2**(bits-1))

        if bits > 1:
            rescaled = x.div(2**(bits+extra_bit_shift-1))
        elif bits < 1:
            rescaled = x.mul(2**(1-bits-extra_bit_shift))
        else:
            rescaled = x
        return rescaled.add(.5).floor()

    @staticmethod
    def backward(_, x):  # pylint: disable=arguments-differ
        """Backprop"""
        # Straight-through estimator: the incoming gradient goes to the tensor input
        # unchanged; the non-tensor `bits` and `extra_bit_shift` inputs receive None.
        return x, None, None
class Quantize(nn.Module):
    """
    Post-activation integer quantization module

    Thin nn.Module wrapper around QuantizationFunction. The bit width and the extra bit
    shift are fixed at construction time.
    """
    def __init__(self, num_bits=8, num_extra_bit_shift=0):
        super().__init__()
        self.num_bits, self.num_extra_bit_shift = num_bits, num_extra_bit_shift

    def forward(self, x):  # pylint: disable=arguments-differ
        """Forward prop"""
        bits, shift = self.num_bits, self.num_extra_bit_shift
        return QuantizationFunction.apply(x, bits, shift)
class FloorFunction(Function):
    """
    Custom MAX78000/MAX78002 autograd function

    Forward: element-wise floor (round toward negative infinity).
    Backward: straight-through estimator; the incoming gradient is passed on unchanged.
    """
    # pylint: disable=abstract-method
    @staticmethod
    def forward(_, x):  # pylint: disable=arguments-differ
        """Forward prop"""
        return torch.floor(x)

    @staticmethod
    def backward(_, x):  # pylint: disable=arguments-differ
        """Backprop"""
        # Straight through: forward has a single tensor input, so exactly one gradient
        # is returned.
        return x
class AvgPoolFloorFunction(Function):
    """
    Custom MAX78000/MAX78002 autograd function

    Forward: rounds toward zero, i.e. floor for positive values and ceil for negative
    values. torch.trunc implements exactly this rule.
    Backward: straight-through estimator; the incoming gradient is passed on unchanged.
    """
    # pylint: disable=abstract-method
    @staticmethod
    def forward(_, x):  # pylint: disable=arguments-differ
        """Forward prop"""
        return torch.trunc(x)

    @staticmethod
    def backward(_, x):  # pylint: disable=arguments-differ
        """Backprop"""
        # Straight through: one tensor input, so one gradient.
        return x
class Floor(nn.Module):
    """
    Post-pooling integer quantization module

    nn.Module wrapper that applies FloorFunction (floor forward, straight-through
    backward).
    """
    def forward(self, x):  # pylint: disable=arguments-differ
        """Forward prop"""
        floored = FloorFunction.apply(x)
        return floored
class AvgPoolFloor(nn.Module):
    """
    Post-pooling integer quantization module

    nn.Module wrapper that applies AvgPoolFloorFunction (round toward zero forward,
    straight-through backward).
    """
    def forward(self, x):  # pylint: disable=arguments-differ
        """Forward prop"""
        truncated = AvgPoolFloorFunction.apply(x)
        return truncated
class FloorONNX(nn.Module):
    """
    Post-pooling integer quantization module

    Export-friendly variant of Floor: applies a plain element-wise floor without a
    custom autograd function.
    """
    def forward(self, x):  # pylint: disable=arguments-differ
        """Forward prop"""
        return torch.floor(x)
class RoundFunction(Function):
    """
    Custom MAX78000/MAX78002 autograd function

    Forward: element-wise torch.round (ties are rounded to the nearest even integer).
    Backward: straight-through estimator; the incoming gradient is passed on unchanged.
    """
    # pylint: disable=abstract-method
    @staticmethod
    def forward(_, x):  # pylint: disable=arguments-differ
        """Forward prop"""
        return torch.round(x)

    @staticmethod
    def backward(_, x):  # pylint: disable=arguments-differ
        """Backprop"""
        # Straight through: one tensor input, so one gradient.
        return x
class Round(nn.Module):
    """
    Post-pooling integer quantization module

    nn.Module wrapper that applies RoundFunction (round forward, straight-through
    backward).
    """
    def forward(self, x):  # pylint: disable=arguments-differ
        """Forward prop"""
        rounded = RoundFunction.apply(x)
        return rounded
class Clamp(nn.Module):
    """
    Post-Activation Clamping Module
    Clamp the output to the given range (typically, [-128, +127])

    Either bound may be None, in which case that side is left unbounded. With both bounds
    None, the input is returned unchanged.
    """
    def __init__(self, min_val=None, max_val=None):
        super().__init__()
        self.min_val = min_val
        self.max_val = max_val

    def forward(self, x):  # pylint: disable=arguments-differ
        """Forward prop"""
        # torch.clamp() raises when both min and max are None, so `x.clamp(min=None)`
        # made the default Clamp() fail on forward. Apply each bound only when it is set.
        if self.min_val is not None:
            x = x.clamp(min=self.min_val)
        if self.max_val is not None:
            x = x.clamp(max=self.max_val)
        return x
class Scaler(nn.Module):
    """
    Scaler module that considers integer quantization

    Multiplies `x` by the scale `s`. In simulation mode (`dev.simulate`) the product is
    also floored through FloorFunction, which has a straight-through gradient.
    """
    def forward(self, x, s):  # pylint: disable=arguments-differ
        """Forward prop"""
        if not dev.simulate:
            return x.mul(s)
        return FloorFunction.apply(x.mul(s))
class ScalerONNX(nn.Module):
    """
    Scaler module that considers integer quantization

    Export-friendly variant of Scaler. It multiplies `x` by `s` and, in simulation mode
    (`dev.simulate`), applies a plain floor instead of the custom autograd function.
    """
    def forward(self, x, s):  # pylint: disable=arguments-differ
        """Forward prop"""
        if not dev.simulate:
            return x.mul(s)
        return torch.floor(x.mul(s))
class ID3(nn.Module):
    """
    ID forward function with 3 arguments

    Identity module whose forward accepts (and ignores) a second argument, so it can
    stand in for two-input modules such as Scaler.
    """
    def forward(self, x, _):  # pylint: disable=arguments-differ
        """Forward prop"""
        identity = x
        return identity
class RoundQat(nn.Module):
    """
    Round function for AvgPool in QAT mode

    Scales by 2**(dev.ACTIVATION_BITS - 1), rounds via RoundFunction (straight-through
    gradient) and scales back, i.e. rounds to the activation quantization grid.
    """
    def forward(self, x):  # pylint: disable=arguments-differ
        """Forward prop"""
        grid = 2**(dev.ACTIVATION_BITS - 1)
        on_grid = x.mul(grid)
        return RoundFunction.apply(on_grid).div(grid)
class RoundQatONNX(nn.Module):
    """
    Round function for AvgPool in QAT mode

    Export-friendly variant of RoundQat that uses plain tensor ops: scale by
    2**(dev.ACTIVATION_BITS - 1), round, scale back.
    """
    def forward(self, x):  # pylint: disable=arguments-differ
        """Forward prop"""
        grid = 2**(dev.ACTIVATION_BITS - 1)
        on_grid = x.mul(grid)
        return on_grid.round().div(grid)
class FloorQat(nn.Module):
    """
    Floor function for AvgPool in QAT mode

    Scales by 2**(dev.ACTIVATION_BITS - 1), rounds toward zero via AvgPoolFloorFunction
    (straight-through gradient) and scales back.
    """
    def forward(self, x):  # pylint: disable=arguments-differ
        """Forward prop"""
        grid = 2**(dev.ACTIVATION_BITS - 1)
        on_grid = x.mul(grid)
        return AvgPoolFloorFunction.apply(on_grid).div(grid)
class FloorQatONNX(nn.Module):
    """
    Floor function for AvgPool in QAT mode

    Export-friendly counterpart of FloorQat. It scales by 2**(dev.ACTIVATION_BITS - 1),
    rounds toward zero (floor for positive values, ceil otherwise — the same rule as
    AvgPoolFloorFunction) and scales back. It uses only plain tensor ops (no custom
    autograd Function).
    """
    def forward(self, x):  # pylint: disable=arguments-differ
        """Forward prop"""
        factor = 2**(dev.ACTIVATION_BITS - 1)
        x = x.mul(factor)
        # Match FloorQat/AvgPoolFloorFunction. A plain floor() rounds negative averages
        # down (e.g. -1.5 -> -2) instead of toward zero (-1.5 -> -1), so this replacement
        # would disagree with the module it stands in for.
        return torch.where(x > 0, torch.floor(x), torch.ceil(x)).div(factor)
def quantize_clamp(wide, quantize_activation=False, clamp_activation=False, weight_bits=8):
    """
    Return new Quantization and Clamp objects.

    Args:
        wide: select the wide-output variants.
        quantize_activation: outside simulation mode, quantize the output (else Empty()).
        clamp_activation: outside simulation mode, clamp the output (else Empty()).
        weight_bits: only used for wide outputs in simulation mode.

    Returns:
        A (quantize, clamp) module pair. In simulation mode (`dev.simulate`) both modules
        are always real Quantize/Clamp objects and the two activation flags are ignored.
    """
    if dev.simulate:
        if wide:
            acc_limit = 2**(dev.FULL_ACC_BITS-1)
            return (
                Quantize(num_bits=dev.DATA_BITS - weight_bits + 1),
                Clamp(min_val=-acc_limit, max_val=acc_limit-1),
            )
        act_limit = 2**(dev.ACTIVATION_BITS-1)
        return (
            Quantize(num_bits=dev.DATA_BITS),
            Clamp(min_val=-act_limit, max_val=act_limit-1),
        )

    if not quantize_activation:
        quantize = Empty()
    elif wide:
        quantize = Quantize(num_bits=dev.WIDE_LAYER_RESOLUTION_BITS)
    else:
        quantize = Quantize(num_bits=dev.ACTIVATION_BITS)

    if not clamp_activation:
        clamp = Empty()
    elif wide:
        wide_exponent = (dev.FULL_ACC_BITS-2*(dev.DATA_BITS-1))-1
        clamp = Clamp(min_val=-(2.**wide_exponent), max_val=2.**wide_exponent)
    else:
        act_scale = 2.**(dev.ACTIVATION_BITS-1)
        clamp = Clamp(  # Do not combine with ReLU
            min_val=-1.,
            max_val=(act_scale-1)/act_scale,
        )

    return quantize, clamp
def quantize_clamp_pool(pooling, quantize_activation=False, clamp_activation=False):
    """
    Return new Quantization and Clamp objects for pooling.

    Only average pooling ('Avg') gets non-trivial modules. For any other `pooling` value
    (e.g. 'Max' or None) both returned modules are Empty().

    * Simulation mode (`dev.simulate`): results are rounded (Round() when
      `dev.round_avg`, otherwise AvgPoolFloor()) and clamped to the signed
      `dev.DATA_BITS` integer range. The two flags are ignored.
    * Otherwise: RoundQat()/FloorQat() are used only when `quantize_activation` is set,
      and the output is clamped to [-1, 127/128] only when `clamp_activation` is set.

    Returns:
        A (quantize, clamp) module pair.
    """
    simulate = dev.simulate
    if pooling != 'Avg':  # Max, None
        return Empty(), Empty()

    if simulate:
        quantize = Round() if dev.round_avg else AvgPoolFloor()
        data_limit = 2**(dev.DATA_BITS-1)
        return quantize, Clamp(min_val=-data_limit, max_val=data_limit-1)

    if quantize_activation:
        quantize = RoundQat() if dev.round_avg else FloorQat()
    else:
        quantize = Empty()
    clamp = Clamp(min_val=-1., max_val=127./128.) if clamp_activation else Empty()
    return quantize, clamp
def quantize_clamp_parameters(weight_bits, bias_bits):
    """
    Return new Quantization and Clamp objects for weight and bias parameters

    Returns:
        (quantize_weight, quantize_bias, clamp_weight, clamp_bias)

    * Simulation mode (`dev.simulate`): weights and bias are only rescaled/rounded via
      Quantize; the clamps are Empty().
      NOTE(review): the bias bit width is derived from `weight_bits`, not `bias_bits`.
      This is presumably intentional; confirm.
    * weight_bits == 0 and bias_bits == 0: quantization disabled, all four are Empty().
    * Otherwise: Quantize with the given widths, and clamps to
      [-1, (2**(n-1) - 1) / 2**(n-1)] for the respective bit width n.
    """
    if dev.simulate:
        return (
            Quantize(num_bits=weight_bits-dev.DATA_BITS+1),
            Quantize(num_bits=2*(weight_bits-dev.DATA_BITS)+1),
            Empty(),
            Empty(),
        )

    if weight_bits == 0 and bias_bits == 0:
        return Empty(), Empty(), Empty(), Empty()

    weight_levels = 2.**(weight_bits-1)
    bias_levels = 2.**(bias_bits-1)
    return (
        Quantize(num_bits=weight_bits),
        Quantize(num_bits=bias_bits),
        Clamp(min_val=-1., max_val=(weight_levels-1)/weight_levels),
        Clamp(min_val=-1., max_val=(bias_levels-1)/bias_levels),
    )
class OutputShiftPassthrough(nn.Module):
    """
    Return output_shift when not using quantization-aware training.

    The first argument (the flattened parameters) is ignored; the second argument, the
    stored output shift, is returned as-is.
    """
    def forward(self, _, x):  # pylint: disable=arguments-differ
        """Forward prop"""
        stored_shift = x
        return stored_shift
class OutputShiftLimit(nn.Module):
    """
    Calculate the clamped output shift when adjusting during quantization-aware training.

    The limit is the `shift_quantile` quantile of |x|. The returned shift is
    -floor(log2(1 / limit)), clamped to [-15, 15]. The second forward argument is ignored.
    """
    def __init__(self, shift_quantile=1.0):
        super().__init__()
        self.shift_quantile = shift_quantile

    def forward(self, x, _):  # pylint: disable=arguments-differ
        """Forward prop"""
        limit = torch.quantile(x.abs(), self.shift_quantile)
        exponent = (1./limit).log2().floor()
        return -exponent.clamp(min=-15., max=15.)
class OutputShiftONNX(nn.Module):
    """
    Calculate the clamped output shift when adjusting during quantization-aware training.

    Export-friendly variant of OutputShiftLimit that always uses max(|x|) as the limit
    (no quantile). The second forward argument is ignored.
    """
    def forward(self, x, _):  # pylint: disable=arguments-differ
        """Forward prop"""
        peak = x.abs().max()
        exponent = (1./peak).log2().floor()
        return -exponent.clamp(min=-15., max=15.)
class One(nn.Module):
    """
    Return 1.

    The result is a one-element tensor of ones (default dtype) on the same device as `x`;
    the values of `x` are not used.
    """
    def forward(self, x):  # pylint: disable=arguments-differ
        """Forward prop"""
        return torch.ones(1, device=x.device)
class WeightScale(nn.Module):
    """
    Calculate the weight scale (reciprocal of 2 to the power of the output shift)

    Computes 2**(-x) element-wise with torch.exp2.
    """
    def forward(self, x):  # pylint: disable=arguments-differ
        """Forward prop"""
        negated_shift = torch.neg(x)
        return torch.exp2(negated_shift)
class WeightScaleONNX(nn.Module):
    """
    Calculate the weight scale (reciprocal of 2 to the power of the output shift)

    Export-friendly variant of WeightScale: 2**(-x) via a power op instead of exp2.
    """
    def forward(self, x):  # pylint: disable=arguments-differ
        """Forward prop"""
        return torch.pow(2., -x)
class OutputScale(nn.Module):
    """
    Calculate the output scale (2 to the power of the output shift)

    Computes 2**x element-wise with torch.exp2.
    """
    def forward(self, x):  # pylint: disable=arguments-differ
        """Forward prop"""
        scale = torch.exp2(x)
        return scale
class OutputScaleONNX(nn.Module):
    """
    Calculate the output scale (2 to the power of the output shift)

    Export-friendly variant of OutputScale: 2**x via a power op instead of exp2.
    """
    def forward(self, x):  # pylint: disable=arguments-differ
        """Forward prop"""
        return torch.pow(2., x)
class Abs(nn.Module):
    """
    Return abs(x)

    Note: the absolute value is taken in place, so the input tensor itself is modified
    and returned.
    """
    def forward(self, x):  # pylint: disable=arguments-differ
        """Forward prop"""
        return x.abs_()  # in-place, same as torch.abs_(x)
class Empty(nn.Module):
    """
    Do nothing

    Identity module used as a placeholder wherever an optional stage is disabled.
    """
    def forward(self, x):  # pylint: disable=arguments-differ
        """Forward prop"""
        unchanged = x
        return unchanged
def get_activation(activation=None):
    """
    Return the selected `activation` class ('ReLU', 'Abs', None)

    'ReLU' gives an in-place nn.ReLU and 'Abs' gives Abs() (asserting dev.device != 84).
    Any other value, including None, gives Empty().
    """
    if activation == 'ReLU':
        return nn.ReLU(inplace=True)
    if activation != 'Abs':
        return Empty()
    assert dev.device != 84
    return Abs()
def histogram(inp, bins):
    """
    CUDA compatible histogram calculation

    Returns (counts, boundaries). `counts` holds `bins` equal-width counts between
    inp.min() and inp.max(), moved to the CPU. `boundaries` holds the `bins + 1` edges
    from torch.linspace, created on the default device.
    """
    lowest = inp.min()
    highest = inp.max()
    counts_cpu = torch.histc(inp, bins, min=lowest, max=highest).cpu()
    edges = torch.linspace(lowest, highest, bins + 1)
    return counts_cpu, edges
def calc_q_error(module, threshold, bits, eps=1e-9):
    """
    Activation quantization error calculation

    Each edge in module.hist[1] is quantized to a signed `bits`-bit grid spanning
    [-(threshold + eps), threshold + eps) and then dequantized. The result is the
    count-weighted (module.hist[0]) mean squared difference between the dequantized
    and the original edges.
    """
    counts = module.hist[0]
    edges = module.hist[1]
    levels = 2**(bits-1)
    step = threshold + eps

    codes = torch.round((edges / step) * levels)
    codes = torch.clamp(codes, -levels, levels-1)
    dequantized = codes * step / levels

    squared_error = (dequantized - edges)**2
    return torch.sum(squared_error * counts) / torch.sum(counts)
def _merge_hist(module):
    """
    Merge histograms of activations

    `module.hist` is the list of (counts, edges) pairs appended by _hist_hook(). Every
    pair is re-binned onto one common set of 2049 equally spaced edges, which span the
    smallest to the largest edge seen. The re-binned counts are summed, and
    `module.hist` is replaced by the single (counts, edges) pair.
    """
    all_edges = torch.stack([entry[1] for entry in module.hist])
    # 2048 is the number of bins and 2049 is the number of edges
    common_edges = torch.linspace(all_edges.min().item(), all_edges.max().item(), 2049)

    resampled = [_interpolate_hist(entry[0], entry[1], common_edges) for entry in module.hist]
    total = resampled[0]
    for extra in resampled[1:]:
        total += extra
    module.hist = (total, common_edges)
def _interpolate_hist(counts, bins, new_bins):
"""
Helper function for interpolating histograms to new bins
"""
cumulative_hist = torch.cumsum(counts, dim=0).to(device=bins.device)
cumulative_hist = torch.cat((torch.tensor([0]), cumulative_hist))
cumulative_interp_hist = torch.from_numpy(np.interp(new_bins.numpy(), bins.numpy(),
cumulative_hist.numpy()))
interp_counts = torch.diff(cumulative_interp_hist, prepend=torch.tensor([0]))
return interp_counts
# pylint: disable=unused-argument
def _hist_hook(module, inp, output):
    """
    Hook to collect histogram of activations

    Forward hook: lazily creates `module.hist` as a list, then appends a 2048-bin
    (counts, edges) histogram of a flattened, detached copy of `output`. `inp` is unused.
    """
    if not hasattr(module, 'hist'):
        module.hist = []
    # dynamic histogram collection
    snapshot = output.clone().detach().flatten()
    module.hist.append(histogram(snapshot, bins=2048))
def register_hist_hooks(module):
    """
    Register hooks for histogram collection

    Attaches _hist_hook as a forward hook with always_call=True and stores the handle in
    `module.handle`, where release_hist_hooks() expects to find it.
    """
    hook_handle = module.register_forward_hook(_hist_hook, always_call=True)
    module.handle = hook_handle
def release_hist_hooks(module):
    """
    Release hooks after histogram collection

    Removes the forward hook whose handle was stored in `module.handle`.
    """
    hook_handle = module.handle
    hook_handle.remove()
def _remove_outliers(module, outlier_removal_z_score=8.0):
"""
Remove outliers from histogram
"""
# Get mean and std of histogram
hist_count = module.hist[0]
hist_bins = module.hist[1]
hist_bins_middle = []
for i in range(len(hist_bins) - 1):
hist_bins_middle.append((hist_bins[i] + hist_bins[i+1])/2)
hist_bins_middle = torch.tensor(hist_bins_middle)
mean = torch.sum(hist_count[1:] * hist_bins_middle) / torch.sum(hist_count[1:])
std = torch.sqrt(torch.sum(hist_count[1:] * (hist_bins_middle - mean)**2)
/ torch.sum(hist_count[1:]))
# When activations are very small, std ends up being 0 due to rounding.
# In this case, we set std to a very small value to prevent zero element histogram.
if std == 0:
std = 1e-9
# Calculate bounds according to z-score
upper_bound = mean + outlier_removal_z_score * std
lower_bound = mean - outlier_removal_z_score * std
hist_bins_middle = torch.cat((torch.tensor([0]), hist_bins_middle))
# Remove outliers according to bounds
hist_count[hist_bins_middle > upper_bound] = 0
hist_count[hist_bins_middle < lower_bound] = 0
non_zero_bins = hist_count != 0
hist_count = hist_count[non_zero_bins]
hist_bins = hist_bins[non_zero_bins]
module.hist = (hist_count, hist_bins)
def init_threshold_module(module, outlier_removal_z_score):
    """
    Initialize activation threshold

    Merges the collected histograms and drops outliers beyond `outlier_removal_z_score`
    standard deviations. It then sets `module.activation_threshold` (non-trainable) to
    2**ceil(log2(m)), the smallest power of two >= m, where m is the largest remaining
    absolute bin edge.
    """
    _merge_hist(module)
    _remove_outliers(module, outlier_removal_z_score)
    largest_abs_edge = module.hist[1].abs().max()
    power_of_two_bound = largest_abs_edge.log2().ceil().exp2()
    module.activation_threshold = nn.Parameter(power_of_two_bound, requires_grad=False)
def calc_threshold(module, iterations=5, bits=8):
    """
    Iteratively calculate threshold for activation quantization

    Starting from t = module.activation_threshold, the candidates t / 2**i for
    i in range(iterations) are scored with calc_q_error(), and the one with the
    smallest error is kept. log2 of the chosen threshold is stored back into
    `module.activation_threshold` as a non-trainable nn.Parameter.

    If no candidate's error compares below infinity (e.g. `iterations` is 0, or every
    error is NaN because the histogram is empty), the starting threshold is kept. The
    code no longer fails on torch.log2(None) in that case.
    """
    e_min = torch.inf
    t_nc = module.activation_threshold
    # Fallback: keep the unclipped starting threshold if no candidate improves on e_min.
    t = t_nc
    for i in range(iterations):
        t_i = t_nc / (2**i)
        e_i = calc_q_error(module, t_i, bits)
        if e_i < e_min:
            e_min = e_i
            t = t_i
    module.activation_threshold = nn.Parameter(torch.log2(t), requires_grad=False)
class QuantizationAwareModule(nn.Module):
    """
    Common code for Quantization-Aware Training

    Base module whose forward() applies the following stages in order:
    1. an optional pooling module (`pool`) followed by its pooling quantize/clamp;
    2. an optional weight op (`op`), run through `_conv_forward` with quantized/clamped
       weights and bias;
    3. an optional batchnorm (`bn`);
    4. output scaling (skipped when `wide`);
    5. the activation, then activation quantize/clamp;
    6. a multiplication by 2**final_scale.

    `op` is expected to provide `weight`, `bias` and `_conv_forward`. Which
    quantize/clamp/scale helper modules are used is decided by init_module() and
    set_functions(), from non-trainable configuration parameters and the global `dev`.
    """
    def __init__(
            self,
            pooling=None,
            activation=None,
            wide=False,
            weight_bits=None,
            bias_bits=None,
            quantize_activation=False,
            pool=None,
            op=None,
            bn=None,
            shift_quantile=1.0,
            clamp_activation=False,
    ):
        super().__init__()

        assert weight_bits in [None, 1, 2, 4, 8], f'Weight bits cannot be {weight_bits}'
        assert bias_bits in [None, 1, 2, 4, 8], f'Bias bits cannot be {bias_bits}'

        # Helper modules; all of them are (re)assigned by set_functions(), which
        # init_module() calls at the end of this constructor.
        self.quantize = None
        self.clamp = None
        self.quantize_bias = None
        self.clamp_bias = None
        self.calc_out_shift = None
        self.scale = None
        self.calc_weight_scale = None
        self.calc_out_scale = None
        self.quantize_weight = None
        self.clamp_weight = None
        self.quantize_pool = None
        self.clamp_pool = None

        self.activate = get_activation(activation)
        self.wide = wide

        self.pool = pool
        self.op = op
        # Reuse the wrapped op's convolution routine, so forward() can call it with the
        # quantized weights/bias instead of op.weight/op.bias. An instance that already
        # has `_conv_forward` keeps its own.
        if op is not None and not hasattr(self, '_conv_forward'):
            self._conv_forward = op._conv_forward  # pylint: disable=protected-access
        self.bn = bn
        self.pooling = pooling

        # Stored output shift; forward() passes it to calc_out_shift.
        self.output_shift = nn.Parameter(torch.tensor([0.]), requires_grad=False)

        # Activation threshold determined during QAT, used in quantization
        # It determines the range of quantization
        self.activation_threshold = nn.Parameter(torch.tensor(0.), requires_grad=False)

        # Exponent of the final output scale; forward() multiplies by 2**final_scale.
        self.final_scale = nn.Parameter(torch.tensor(0.), requires_grad=False)

        self.init_module(weight_bits, bias_bits, quantize_activation,
                         clamp_activation, shift_quantile)

    def init_module(
            self,
            weight_bits,
            bias_bits,
            quantize_activation,
            clamp_activation,
            shift_quantile,
            export=False,
    ):
        """
        Initialize model parameters

        Two configurations are accepted; anything else fails the assertion:
        * no quantization (weight_bits and bias_bits None, quantize_activation False):
          bit widths 0, activation quantization off, clamping as requested, output
          shift not adjusted;
        * QAT (weight_bits and bias_bits in {1, 2, 4, 8}, quantize_activation True):
          activation quantization and clamping on, output shift adjusted unless
          `dev.simulate` is set.
        With `export=True`, most stored configuration parameters are left as they are
        (see the comments below). set_functions() then rebuilds the helper modules.
        """
        if weight_bits is None and bias_bits is None and not quantize_activation:
            if not export:
                self.weight_bits = nn.Parameter(torch.tensor([0]), requires_grad=False)
                self.bias_bits = nn.Parameter(torch.tensor([0]), requires_grad=False)
                self.quantize_activation = nn.Parameter(torch.tensor([False]), requires_grad=False)
                self.clamp_activation = nn.Parameter(torch.tensor([clamp_activation]),
                                                     requires_grad=False)
            self.adjust_output_shift = nn.Parameter(torch.tensor([False]), requires_grad=False)
        elif weight_bits in [1, 2, 4, 8] and bias_bits in [1, 2, 4, 8] and quantize_activation:
            # NOTE(review): unlike the branch above, weight_bits is overwritten even when
            # export=True, while bias_bits/quantize_activation/clamp_activation are not.
            # Confirm this asymmetry is intentional.
            self.weight_bits = nn.Parameter(torch.tensor([weight_bits]), requires_grad=False)
            if not export:
                self.bias_bits = nn.Parameter(torch.tensor([bias_bits]), requires_grad=False)
                self.quantize_activation = nn.Parameter(torch.tensor([True]), requires_grad=False)
                self.clamp_activation = nn.Parameter(torch.tensor([True]), requires_grad=False)
            # The output shift is derived from the weights during QAT, but not in simulation
            # mode.
            self.adjust_output_shift = nn.Parameter(torch.tensor([not dev.simulate]),
                                                    requires_grad=False)
        else:
            assert False, f'Undefined mode with weight_bits: {weight_bits}, ' \
                          f'bias_bits: {bias_bits}, ' \
                          f'quantize_activation: {quantize_activation}'

        if not export:
            self.shift_quantile = nn.Parameter(torch.tensor([shift_quantile]), requires_grad=False)

        self.set_functions()

    def set_functions(self):
        """
        Set functions to be used wrt the model parameters

        When `adjust_output_shift` is set, the output shift is computed from the weights
        (OutputShiftLimit) and the weights are pre-scaled by 2**-shift (WeightScale).
        Otherwise the stored shift is passed through and the weight scale is 1. The
        weight/bias, activation and pooling quantize/clamp modules are rebuilt from the
        stored bit widths and flags.
        """
        if self.adjust_output_shift.detach():
            self.calc_out_shift = OutputShiftLimit(self.shift_quantile.detach().item())
            self.calc_weight_scale = WeightScale()
        else:
            self.calc_out_shift = OutputShiftPassthrough()
            self.calc_weight_scale = One()

        self.scale = Scaler()
        self.calc_out_scale = OutputScale()

        self.quantize_weight, self.quantize_bias, self.clamp_weight, self.clamp_bias = \
            quantize_clamp_parameters(self.weight_bits.detach().item(),
                                      self.bias_bits.detach().item())
        self.quantize, self.clamp = \
            quantize_clamp(self.wide, bool(self.quantize_activation.detach().item()),
                           bool(self.clamp_activation.detach().item()),
                           int(self.weight_bits.detach().item()))
        self.quantize_pool, self.clamp_pool = \
            quantize_clamp_pool(self.pooling, bool(self.quantize_activation.detach().item()),
                                bool(self.clamp_activation.detach().item()))

    def forward(self, x):  # pylint: disable=arguments-differ
        """
        Forward prop

        Pool (if configured), then the weight op with quantized/clamped weights and
        bias, batchnorm (if any), output scaling (not in wide mode), activation,
        activation quantize/clamp, and finally multiplication by 2**final_scale.
        When `op` is None, only the pooling stage runs.
        """
        if self.pool is not None:
            x = self.clamp_pool(self.quantize_pool(self.pool(x)))

        if self.op is not None:
            # The flattened weights (and bias, if present) are passed to calc_out_shift,
            # which either derives the shift from their magnitude (OutputShiftLimit) or
            # returns the stored output_shift (OutputShiftPassthrough).
            if self.op.bias is not None:
                bias_r = torch.flatten(self.op.bias.detach())
                weight_r = torch.flatten(self.op.weight.detach())
                params_r = torch.cat((weight_r, bias_r))
            else:
                params_r = torch.flatten(self.op.weight.detach())
            out_shift = self.calc_out_shift(params_r, self.output_shift.detach())
            weight_scale = self.calc_weight_scale(out_shift)

            # Quantized checkpoint will have subtracted threshold from output shift
            # Therefore, it shouldn't be done again in simulate mode
            if not dev.simulate:
                out_shift = (out_shift - self.activation_threshold).clamp(min=-15., max=15.)

            out_scale = self.calc_out_scale(out_shift)

            x = self._conv_forward(  # pylint: disable=protected-access
                x,
                self.clamp_weight(self.quantize_weight(self.op.weight.mul(weight_scale))),
                None if self.op.bias is None
                else self.clamp_bias(self.quantize_bias(self.op.bias.mul(weight_scale))),
            )

            if self.bn is not None:
                x = self.bn(x)

            if not self.wide:
                # The device does not apply output shift in wide mode
                x = self.scale(x, out_scale)
            x = self.clamp(self.quantize(self.activate(x)))

            # This is the final scale for the output, in the device it will be realized in SW
            x = x.mul(2.**(self.final_scale))

        return x
class Conv2d(QuantizationAwareModule):
    """
    2D pooling ('Avg', 'Max' or None) optionally followed by
    2D convolution/transposed 2D convolution and activation ('ReLU', 'Abs', None)

    The constructor first checks the arguments against the hardware limits with
    assertions; several limits depend on `dev.device`. It then builds the pooling layer,
    the optional BatchNorm2d, and the nn.Conv2d / nn.ConvTranspose2d op, and passes them
    to QuantizationAwareModule. With `kernel_size=None` no convolution is created and
    the module only pools.

    Args (selected):
        kernel_size: 3, or 1 when dev.device != 84. A tuple must be square.
        op: 'Conv2d' or 'ConvTranspose2d'. ConvTranspose2d always needs stride 2 and
            is rejected when dev.device == 84.
        pooling: 'Max', 'Avg' or None. pool_size/pool_stride/pool_dilation are only
            validated and used when pooling is set, and pool_stride=None means "same as
            pool_size". Pool dilation > 1 is only allowed for 'Max' with
            dev.device == 87, and dilation is not passed to AvgPool2d.
        stride: with pooling, must be 1 for Conv2d; without pooling, 1..3 for Conv2d.
        padding: 0..2. dilation: must be 1.
        batchnorm: 'Affine', 'NoAffine' or None. Any batchnorm requires bias=True.
        groups: values other than 1 require dev.device == 87 (MAX78002); this is checked
            when a convolution is created.
        eps, momentum: passed to BatchNorm2d.

    Raises:
        AssertionError: when a hardware limit is violated.
        ValueError: for an unsupported pool_size/pool_stride/pool_dilation type or `op`.
    """
    def __init__(  # pylint: disable=too-many-arguments
            self,
            in_channels,
            out_channels,
            kernel_size,
            op='Conv2d',
            pooling=None,
            pool_size=2,
            pool_stride=2,
            pool_dilation=1,
            stride=1,
            padding=0,
            dilation=1,
            bias=True,
            activation=None,
            wide=False,
            batchnorm=None,
            weight_bits=None,
            bias_bits=None,
            quantize_activation=False,
            groups=1,
            eps=1e-05,
            momentum=0.05,
    ):
        # Wide outputs do not support an activation.
        assert not wide or activation is None

        if pooling is not None:
            if pool_stride is None:
                pool_stride = pool_size

            # Pool size: at most 16. On device 84 it must be even, and at most 4 unless
            # max pooling is used.
            if isinstance(pool_size, int):
                assert dev.device != 84 or pool_size & 1 == 0
                assert pool_size <= 16 \
                    and (dev.device != 84 or pool_size <= 4 or pooling == 'Max')
            elif isinstance(pool_size, tuple):
                assert len(pool_size) == 2
                assert dev.device != 84 or pool_size[0] & 1 == 0
                assert pool_size[0] <= 16 \
                    and (dev.device != 84 or pool_size[0] <= 4 or pooling == 'Max')
                assert dev.device != 84 or pool_size[1] & 1 == 0
                assert pool_size[1] <= 16 \
                    and (dev.device != 84 or pool_size[1] <= 4 or pooling == 'Max')
            else:
                raise ValueError('pool_size must be int or tuple')

            # Pool stride: 1..16 (at most 4 on device 84 unless max pooling); a tuple
            # stride must be square.
            if isinstance(pool_stride, int):
                assert pool_stride > 0
                assert pool_stride <= 16 \
                    and (dev.device != 84 or pool_stride <= 4 or pooling == 'Max')
            elif isinstance(pool_stride, tuple):
                assert len(pool_stride) == 2
                # NOTE(review): this device-84-only check is subsumed by the unconditional
                # square-stride assertion a few lines below.
                assert dev.device != 84 or pool_stride[0] == pool_stride[1]
                assert 0 < pool_stride[0] <= 16 \
                    and (dev.device != 84 or pool_stride[0] <= 4 or pooling == 'Max')
                assert 0 < pool_stride[1] <= 16 \
                    and (dev.device != 84 or pool_stride[1] <= 4 or pooling == 'Max')
                assert pool_stride[0] == pool_stride[1]
            else:
                raise ValueError('pool_stride must be int or tuple')

            # Pool dilation > 1 is only allowed for max pooling on device 87 (up to 16).
            if isinstance(pool_dilation, int):
                assert pool_dilation > 0
                assert pool_dilation <= 1 \
                    or dev.device == 87 and pool_dilation <= 16 and pooling == 'Max'
            elif isinstance(pool_dilation, tuple):
                assert len(pool_dilation) == 2
                assert pool_dilation[0] > 0
                assert pool_dilation[0] <= 1 \
                    or dev.device == 87 and pool_dilation[0] <= 16 and pooling == 'Max'
                assert pool_dilation[1] > 0
                assert pool_dilation[1] <= 1 \
                    or dev.device == 87 and pool_dilation[1] <= 16 and pooling == 'Max'
            else:
                raise ValueError('pool_dilation must be int or tuple')

            # Fused pooling + convolution: the convolution itself must not stride
            # (except the fixed stride of 2 for transposed convolution).
            if op == 'ConvTranspose2d':
                assert stride == 2
            else:
                assert stride == 1
        else:
            if op == 'ConvTranspose2d':
                assert stride == 2
            else:
                assert 0 < stride <= 3

        assert 0 <= padding <= 2
        assert dilation == 1

        if pooling == 'Max':
            pool = nn.MaxPool2d(kernel_size=pool_size, stride=pool_stride,
                                dilation=pool_dilation, padding=0)
        elif pooling == 'Avg':
            pool = nn.AvgPool2d(kernel_size=pool_size, stride=pool_stride, padding=0)
        else:
            pool = None

        if batchnorm == 'Affine':
            bn = nn.BatchNorm2d(out_channels, eps=eps, momentum=momentum, affine=True)
            assert bias, '`bias` must be set (enable --use-bias for models where bias is optional)'
        elif batchnorm == 'NoAffine':
            bn = nn.BatchNorm2d(out_channels, eps=eps, momentum=momentum, affine=False)
            assert bias, '`bias` must be set (enable --use-bias for models where bias is optional)'
        else:
            bn = None

        if kernel_size is not None:
            # Only square kernels are supported; a tuple is reduced to its single size.
            if isinstance(kernel_size, tuple):
                assert len(kernel_size) == 2 and kernel_size[0] == kernel_size[1]
                kernel_size = kernel_size[0]

            assert kernel_size == 3 or dev.device != 84 and kernel_size == 1
            assert groups == 1 or dev.device == 87, 'Set device to MAX78002 for depthwise support'

            if op == 'Conv2d':
                opn = nn.Conv2d(in_channels, out_channels,
                                kernel_size=kernel_size, stride=stride,
                                padding=padding, dilation=dilation, bias=bias, groups=groups)
            elif op == 'ConvTranspose2d':
                assert dev.device != 84
                opn = nn.ConvTranspose2d(in_channels, out_channels,
                                         kernel_size=kernel_size, stride=stride,
                                         output_padding=1, padding=padding,
                                         dilation=dilation, bias=bias, groups=groups)
            else:
                raise ValueError('Unsupported operation')
        else:
            # Pooling-only layer
            opn = None

        super().__init__(
            pooling,
            activation,
            wide,
            weight_bits,
            bias_bits,
            quantize_activation,
            pool,
            opn,
            bn,
        )
class FusedMaxPoolConv2d(Conv2d):
    """
    Fused 2D Max Pool, 2D Convolution and Activation ('ReLU', 'Abs', None)

    Convenience subclass of Conv2d that always passes pooling='Max'. Every other
    positional and keyword argument is forwarded unchanged, so passing `pooling` again
    is an error.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, pooling='Max', **kwargs)
class FusedMaxPoolConv2dBN(FusedMaxPoolConv2d):
    """
    Fused 2D Max Pool, 2D Convolution, BatchNorm and Activation ('ReLU', 'Abs', None)

    Same as FusedMaxPoolConv2d, except that `batchnorm` defaults to 'Affine' when it is
    not passed as a keyword argument.
    """
    def __init__(self, *args, **kwargs):
        kwargs.setdefault('batchnorm', 'Affine')
        super().__init__(*args, **kwargs)