# sparse_masklib.py (forked from NVIDIA/apex)
import sys
import torch
import numpy as np
import collections
from itertools import permutations
""" compute density (helper fn to compute % NNZs in a tensor) """
def fill(x):
    return float(x.nonzero().size(0)) / torch.numel(x)

""" reshape matrix into m-dimensional vectors: (h,w) -> (hw/m, m) """
def reshape_1d(matrix, m):
    # If not a nice multiple of m, fill with zeroes.
    if matrix.shape[1] % m > 0:
        mat = torch.cuda.FloatTensor(matrix.shape[0], matrix.shape[1] + (m - matrix.shape[1] % m)).fill_(0)
        mat[:, :matrix.shape[1]] = matrix
        shape = mat.shape
        return mat.view(-1, m), shape
    else:
        return matrix.view(-1, m), matrix.shape

""" return all possible m:n patterns in a 1d vector """
valid_m4n2_1d_patterns = None
def compute_valid_1d_patterns(m, n):
    # Early exit if the patterns were already created.
    global valid_m4n2_1d_patterns
    if m == 4 and n == 2 and valid_m4n2_1d_patterns is not None: return valid_m4n2_1d_patterns
    patterns = torch.zeros(m)
    patterns[:n] = 1
    valid_patterns = torch.Tensor(list(set(permutations(patterns.tolist()))))
    if m == 4 and n == 2: valid_m4n2_1d_patterns = valid_patterns
    return valid_patterns

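# Illustrative note (not part of the original file): for m=4, n=2 the valid 1d
# patterns are the C(4,2) = 6 distinct permutations of [1,1,0,0], i.e.
#   [1,1,0,0], [1,0,1,0], [1,0,0,1], [0,1,1,0], [0,1,0,1], [0,0,1,1]
# (in no particular order, since they are built from a Python set).
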
""" m:n 1d structured best """
def mn_1d_best(matrix, m, n):
    # Find all possible patterns.
    patterns = compute_valid_1d_patterns(m, n).cuda()

    # Find the best m:n pattern (sum of non-masked weights).
    mask = torch.cuda.IntTensor(matrix.shape).fill_(1).view(-1, m)
    mat, shape = reshape_1d(matrix, m)
    pmax = torch.argmax(torch.matmul(mat.abs(), patterns.t()), dim=1)
    mask[:] = patterns[pmax[:]]
    mask = mask.view(matrix.shape)
    return mask

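# Illustrative note (not part of the original file): for a group [0.9, -0.1, 0.4, -0.7]
# with m=4, n=2, the pattern maximizing the sum of kept magnitudes is [1, 0, 0, 1],
# keeping 0.9 and -0.7.
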
def m4n2_1d(mat, density):
    return mn_1d_best(mat, 4, 2)

"""
Below 2d-masking related code is targeted more for training (from scratch).
2d-pruning of a weight tensor is done to accelerate DGRAD step during backprop
phase of training algorithm. Acceleration comes from using SpMMA instructions in
Tensor Cores of NVIDIA Ampere GPU Architecture
(note: this code does not do the acceleration, GPU kernels are required for this).
1d pruning of weight tensor helps speed up FPROP step by pruning in 2:4 pattern
along the horizontal (logical) direction.
During DGRAD step, weight tensor is transposed. 2d pruning functions below, mask
weight tensor such that their transposed versions are also 2:4 sparse along the
horizontal (logical) direction. Thus, with 2d pruning, weight tensors are
2:4 sparse along row and column directions.
"""
""" m:n 2d structured pruning: greedy method to select mask """
def mn_2d_greedy(matrix, m, n):
    # Convert to numpy
    mat = matrix.cpu().detach().numpy()
    mask = np.ones(mat.shape, dtype=int)

    rowCount = int(mat.shape[0] / m) * m
    colCount = int(mat.shape[1] / m) * m
    for rowStartIdx in range(0, rowCount, m):
        rowEndIdx = rowStartIdx + m
        for colStartIdx in range(0, colCount, m):
            colEndIdx = colStartIdx + m
            matrixSub = np.absolute(np.squeeze(mat[rowStartIdx:rowEndIdx, colStartIdx:colEndIdx]))
            maskSub = np.squeeze(mask[rowStartIdx:rowEndIdx, colStartIdx:colEndIdx])
            maskSub.fill(0.0)
            matrixVecView = matrixSub.reshape(-1)
            maskVecView = maskSub.reshape(-1)
            linearIdx = np.argsort(matrixVecView)
            matrixIdx = [(int(x / m), x % m) for x in linearIdx]
            rowCounter = collections.Counter()
            colCounter = collections.Counter()
            for currIdx in range(len(linearIdx) - 1, -1, -1):
                currMatrixEntry = matrixIdx[currIdx]
                if (rowCounter[currMatrixEntry[0]] == n) or (colCounter[currMatrixEntry[1]] == n):
                    continue
                #end if
                maskSub[currMatrixEntry[0], currMatrixEntry[1]] = 1.0
                rowCounter[currMatrixEntry[0]] += 1
                colCounter[currMatrixEntry[1]] += 1

    # Build the torch mask from the numpy array, then move it to the GPU.
    return torch.tensor(mask).cuda()

def m4n2_2d_greedy(mat, density):
    return mn_2d_greedy(mat, 4, 2)

""" return all possible m:n patterns in a mxn block. """
valid_m4n2_2d_patterns = None
def compute_valid_2d_patterns(m, n):
    # Early exit if the patterns were already created.
    global valid_m4n2_2d_patterns
    if m == 4 and n == 2 and valid_m4n2_2d_patterns is not None: return valid_m4n2_2d_patterns

    patterns = torch.zeros(m)
    patterns[:n] = 1
    patterns = list(set(permutations(patterns.tolist())))
    patterns = patterns + patterns
    patterns = torch.Tensor(list(set(permutations(patterns, m))))

    valid = ((patterns.sum(dim=1) <= n).sum(dim=1) == m).nonzero().view(-1)
    valid_patterns = torch.Tensor(valid.shape[0], m, m)
    valid_patterns[:] = patterns[valid[:]]

    if m == 4 and n == 2: valid_m4n2_2d_patterns = valid_patterns
    return valid_patterns

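# Note: mn_2d_best() below relies on reshape_2d() and reshape_2d_inv(), which are
# not defined in this excerpt. The following is a minimal sketch written to be
# consistent with how mn_2d_best() uses them (tiling an (H, W) matrix into
# (H//m, W//n, m*n) blocks and back); it is an assumption, not the original
# implementation, and it assumes H and W are divisible by the tile size.
def reshape_2d(matrix, m, n):
    # Tile an (H, W) matrix into (H//m, W//n, m*n) row-major blocks.
    H, W = matrix.shape
    return (matrix.view(H // m, m, W // n, n)
                  .permute(0, 2, 1, 3)
                  .contiguous()
                  .view(H // m, W // n, m * n))

def reshape_2d_inv(tiles):
    # Inverse of reshape_2d for tiles of shape (H//m, W//n, m, n): back to (H, W).
    hb, wb, m, n = tiles.shape
    return tiles.permute(0, 2, 1, 3).contiguous().view(hb * m, wb * n)
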
""" m:n 2d structured pruning: exhaustive method to select best mask """
def mn_2d_best(matrix, m, n):
    # Find all possible patterns.
    patterns = compute_valid_2d_patterns(m, n).cuda()

    # Find the best m:n pattern (sum of non-masked weights).
    mask = torch.cuda.IntTensor(matrix.shape).fill_(1)
    mat = reshape_2d(matrix, m, m).abs()
    pmax = torch.argmax(torch.matmul(mat, patterns.view(patterns.shape[0], m*m).t()), dim=2)

    # Copy best m:n patterns into mask.
    mat = mat.view(mat.shape[0]*mat.shape[1], -1)
    pmax = pmax.view(pmax.shape[0]*pmax.shape[1]).unsqueeze(1).expand(-1, mat.shape[1])
    patterns = patterns.view(patterns.shape[0], patterns.shape[1]*patterns.shape[2])
    mat = torch.gather(patterns, 0, pmax)
    mat = reshape_2d_inv(mat.view(matrix.shape[0]//m, matrix.shape[1]//m, m, m))
    mask.copy_(mat.type(mask.type()))
    return mask

def m4n2_2d_best(mat, density):
    return mn_2d_best(mat, 4, 2)

""" returns a sparse mask """
def create_mask(tensor, pattern="m4n2_1d", density=0.5):
    # Reshape tensor and mask.
    shape = tensor.shape
    ttype = tensor.type()
    t = tensor.float().contiguous()

    # 1d-tensor
    if len(shape) == 1:
        t = t.view(1, shape[0])
        func = getattr(sys.modules[__name__], pattern, None)
        mask = func(t, density)
        return mask.view(shape).type(ttype)
    # 2d-tensor (K, C)
    elif len(shape) == 2:
        # linear
        t = t.view(shape[0], shape[1])
        func = getattr(sys.modules[__name__], pattern, None)
        mask = func(t, density)
        return mask.view(shape).type(ttype)
    # 3d-tensor (K, C, R)
    elif len(shape) == 3:
        # 1d convs
        t = t.permute(0, 2, 1).contiguous().view(shape[0]*shape[2], shape[1])
        func = getattr(sys.modules[__name__], pattern, None)
        mask = func(t, density)
        mask = mask.view(shape[0], shape[2], shape[1]).permute(0, 2, 1).contiguous()
        return mask.view(shape).type(ttype)
    # 4d-tensor (K, C, R, S)
    elif len(shape) == 4:
        """
        # transformers (bmm)
        t = t.view(shape[0]*shape[1]*shape[2], shape[3])
        func = getattr(sys.modules[__name__], pattern, None)
        mask = func(t, density)
        return mask.view(shape).type(ttype)
        """
        # 2d convs
        t = t.permute(2, 3, 0, 1).contiguous().view(shape[2]*shape[3]*shape[0], shape[1])
        func = getattr(sys.modules[__name__], pattern, None)
        mask = func(t, density)
        mask = mask.view(shape[2], shape[3], shape[0], shape[1]).permute(2, 3, 0, 1).contiguous()
        return mask.view(shape).type(ttype)
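
# Illustrative usage (not part of the original file); assumes a CUDA device is
# available, since the mask helpers above allocate torch.cuda tensors.
if __name__ == "__main__":
    weight = torch.randn(64, 64).cuda()
    mask = create_mask(weight, pattern="m4n2_1d", density=0.5)
    print(mask.shape)   # same shape as the weight tensor
    print(fill(mask))   # expected density: 0.5 (2 of every 4 weights kept)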