7 | 7 | import logging
8 | 8 | import os
9 | 9 | import math
| 10 | +import argparse
10 | 11 |
11 | | -root = 'masks/deit_tiny_lowrank'
12 | | -# root = 'masks/reorder/deit/reorder_att/deit_base'
13 | | -sparse = [0.95]
| 12 | +def get_args_parser():
| 13 | +    parser = argparse.ArgumentParser('ViTCoD attn simulation script', add_help=False)
| 14 | +    parser.add_argument('--root', default='masks/deit_tiny_lowrank', type=str)
| 15 | +    parser.add_argument('--sparse', type=float, default=[0.95], nargs='+', help='the sparsity of the model')
| 16 | +    parser.add_argument('--feature_dim', default=64, type=int, help='the feature dimension of Q/K/V')
| 17 | +    parser.add_argument('--ratio', default=2/3, type=float, help='the compression ratio of the encoder/decoder')
| 18 | +    parser.add_argument('--PE_width', default=64, type=int)
| 19 | +    parser.add_argument('--PE_height', default=8, type=int)
| 20 | +    return parser
14 | 21 |
15 | | -for p in sparse:
| 22 | +parser = argparse.ArgumentParser('ViTCoD attn simulation script', parents=[get_args_parser()])
| 23 | +args = parser.parse_args()
| 24 | +
| 25 | +# args.root = 'masks/deit_tiny_lowrank'
| 26 | +# args.root = 'masks/reorder/deit/reorder_att/deit_base'
| 27 | +# args.sparse = [0.95]
| 28 | +
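For reference, a minimal sketch of how the new flags can be driven; the values below are illustrative only, not from the commit. parse_args also accepts an explicit argument list, which is handy for testing the parser in-process:

    # hypothetical in-process invocation of the new CLI
    test_args = parser.parse_args(['--root', 'masks/deit_tiny_lowrank',
                                   '--sparse', '0.9', '0.95',        # nargs='+' takes several sparsities
                                   '--PE_width', '64', '--PE_height', '8'])
    print(test_args.sparse)  # [0.9, 0.95]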
| 29 | +for p in args.sparse:
16 | 30 | # Logging
17 | 31 | log = logging.getLogger()
18 | | - log_path = os.path.join(root, 'vitcod_atten_'+str(p)+'_wo.txt')
| 32 | + log_path = os.path.join(args.root, 'vitcod_atten_'+str(p)+'_wo.txt')
19 | 33 | handlers = [logging.FileHandler(log_path, mode='a+'),
20 | 34 | logging.StreamHandler()]
21 | 35 | logging.basicConfig(
24 | 38 | level=logging.INFO,
25 | 39 | handlers=handlers)
26 | 40 | # Initialize Q, K, V and attn maps
27 | | - attn_map_mask = np.load(root+'/reodered_info_'+str(p)+'.npy')
28 | | - num_global_tokens = np.load(root+'/global_token_info_'+str(p)+'.npy')
| 41 | + attn_map_mask = np.load(args.root+'/reodered_info_'+str(p)+'.npy')
| 42 | + num_global_tokens = np.load(args.root+'/global_token_info_'+str(p)+'.npy')
29 | 43 |
30 | 44 | # dim of (layer, head, token, features)
31 | | - all_Q = np.random.random((attn_map_mask.shape[0], attn_map_mask.shape[1], attn_map_mask.shape[2], 64))
32 | | - all_K = np.random.random((attn_map_mask.shape[0], attn_map_mask.shape[1], attn_map_mask.shape[2], 64))
33 | | - all_V = np.random.random((attn_map_mask.shape[0], attn_map_mask.shape[1], attn_map_mask.shape[2], 64))
| 45 | + all_Q = np.random.random((attn_map_mask.shape[0], attn_map_mask.shape[1], attn_map_mask.shape[2], args.feature_dim))
| 46 | + all_K = np.random.random((attn_map_mask.shape[0], attn_map_mask.shape[1], attn_map_mask.shape[2], args.feature_dim))
| 47 | + all_V = np.random.random((attn_map_mask.shape[0], attn_map_mask.shape[1], attn_map_mask.shape[2], args.feature_dim))
34 | 48 | log.info('Shape: {}'.format(all_V.shape))
35 | 49 | my_SRAM = SRAM()
36 | 50 | my_PE = PE_array()
43 | 57 |
44 | 58 | head = all_Q.shape[1]
45 | 59 | # the compression ratio of head via the encoder
46 | | - ratio = 2/3
47 | | - PE_width = 64
48 | | - PE_height = 8
| 60 | + # args.ratio = 2/3
| 61 | + # PE_width = 64
| 62 | + # PE_height = 8
49 | 63 |
50 | 64 | total_sparse_ratio = 0
51 | 65 | for _layer in range(all_Q.shape[0]):
67 | 81 | sparse_ratio = 0
68 | 82 | else:
69 | 83 | sparse_ratio = len(sparser)/(mask[:, global_tokens:].shape[0]*mask[:, global_tokens:].shape[1])
| 84 | + # print("sparse_ratio:", sparse_ratio)
| 85 | + # print("reload_ratio:", len(sparser)/(mask[:, global_tokens:].shape[1]))
| 86 | + # break
70 | 87 | total_sparse_ratio += sparse_ratio
71 | 88 | # log.info('number of non-zeros in the sparser region: {}'.format(len(sparser)))
72 | 89 |
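To make the ratio above concrete, here is a minimal self-contained sketch of the same bookkeeping (toy mask and shapes assumed, not from the commit): the density of the reordered attention mask outside its global-token columns.

    import numpy as np

    N, global_tokens = 8, 2                                   # assumed token counts
    mask = (np.random.random((N, N)) < 0.05).astype(np.int8)  # toy reordered attention mask
    region = mask[:, global_tokens:]                          # the sparser (non-global) region
    sparser = np.transpose(np.nonzero(region))                # one entry per non-zero, so len() == nnz
    sparse_ratio = len(sparser) / (region.shape[0] * region.shape[1])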
78 | 95 | for _sta_k in range(global_tokens):
79 | 96 | # ############ k #########
80 | 97 | # ######### Load k and decoder weight
81 | | - preload_cycles += my_SRAM.preload_K(nums=head*ratio*1* K.shape[1], bits=8, bandwidth_ratio=1)
| 98 | + preload_cycles += my_SRAM.preload_K(nums=head*args.ratio*1* K.shape[1], bits=8, bandwidth_ratio=1)
82 | 99 | if _sta_k == 0:
83 | | - preload_cycles += my_SRAM.preload_decoder(nums=head*ratio*1, bits=8, bandwidth_ratio=1/head)
| 100 | + preload_cycles += my_SRAM.preload_decoder(nums=head*args.ratio*1, bits=8, bandwidth_ratio=1/head)
84 | 101 | # ######### Preprocessing
85 | | - for k in range((math.ceil((head*ratio*1* K.shape[1])/int(PE_width*PE_height/head)))):
| 102 | + for k in range((math.ceil((head*args.ratio*1* K.shape[1])/int(args.PE_width*args.PE_height/head)))):
86 | 103 | PRE_cycles += 1
87 | 104 | for _sta_q in range(int(Q.shape[0])):
88 | 105 | if _sta_k == 0:
89 | 106 | # ############ q #########
90 | 107 | # ######### Load q and decoder weight
91 | 108 | # reload_ratio = (Q.shape[0]-(my_SRAM.max_Q/(8*Q.shape[1]*head)))/Q.shape[0]
92 | 109 | reload_ratio = 0
93 | | - preload_cycles += my_SRAM.preload_Q(nums=head*ratio*1* Q.shape[1], bits=8, bandwidth_ratio=1)*(1+reload_ratio)
| 110 | + preload_cycles += my_SRAM.preload_Q(nums=head*args.ratio*1* Q.shape[1], bits=8, bandwidth_ratio=1)*(1+reload_ratio)
94 | 111 | if _sta_q == 0:
95 | | - preload_cycles += my_SRAM.preload_decoder(nums=head*ratio*1, bits=8, bandwidth_ratio=1/head)
| 112 | + preload_cycles += my_SRAM.preload_decoder(nums=head*args.ratio*1, bits=8, bandwidth_ratio=1/head)
96 | 113 | # ######### Preprocessing
97 | | - for q in range(math.ceil((head*ratio*1* Q.shape[1])/int(PE_width*PE_height/head))):
| 114 | + for q in range(math.ceil((head*args.ratio*1* Q.shape[1])/int(args.PE_width*args.PE_height/head))):
98 | 115 | PRE_cycles += 1*(1+reload_ratio)
99 | 116 |
100 | 117 | total_PRE_cycles += PRE_cycles
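The preprocessing loops above all follow the same cycle model: ceil(head * ratio * d / int(PE_width * PE_height / head)), i.e. the compressed per-token payload divided by the slice of the PE array serving one head. A worked instance with assumed numbers (the head count and feature dimension are illustrative; only PE_width and PE_height match the script's defaults):

    import math

    head, ratio, d = 3, 2/3, 64                     # assumed DeiT-Tiny-like values
    PE_width, PE_height = 64, 8                     # the script's defaults
    lanes = int(PE_width * PE_height / head)        # per-head share of the array -> 170
    print(math.ceil(head * ratio * 1 * d / lanes))  # 128 elements / 170 lanes -> 1 cycle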
109 | 126 | # ############ k #########
110 | 127 | # ######### Load K and decoder weights
111 | 128 | for i in range(K.shape[0]-global_tokens):
112 | | - preload_cycles += my_SRAM.preload_K(nums=head*ratio*1* K.shape[1], bits=8, bandwidth_ratio=1)
| 129 | + preload_cycles += my_SRAM.preload_K(nums=head*args.ratio*1* K.shape[1], bits=8, bandwidth_ratio=1)
113 | 130 | if i == 0:
114 | | - preload_cycles += my_SRAM.preload_decoder(nums=head*ratio*1, bits=8, bandwidth_ratio=1/head)
| 131 | + preload_cycles += my_SRAM.preload_decoder(nums=head*args.ratio*1, bits=8, bandwidth_ratio=1/head)
115 | 132 | # ######### Preprocessing
116 | | - for k in range(math.ceil((head*ratio*1* K.shape[1])/int(PE_width*PE_height/head))):
| 133 | + for k in range(math.ceil((head*args.ratio*1* K.shape[1])/int(args.PE_width*args.PE_height/head))):
117 | 134 | PRE_cycles += 1
118 | 135 |
119 | 136 | # ############ Q #########
123 | 140 | if global_tokens==0:
124 | 141 | reload_ratio = len(sparser)/mask[:, global_tokens:].shape[1]
125 | 142 | for i in range(Q.shape[0]):
126 | | - preload_cycles += my_SRAM.preload_Q(nums=head*ratio*1* Q.shape[1], bits=8, bandwidth_ratio=1)*reload_ratio
| 143 | + preload_cycles += my_SRAM.preload_Q(nums=head*args.ratio*1* Q.shape[1], bits=8, bandwidth_ratio=1)*reload_ratio
127 | 144 | if i == 0:
128 | | - preload_cycles += my_SRAM.preload_decoder(nums=head*ratio*1, bits=8, bandwidth_ratio=1/head)
| 145 | + preload_cycles += my_SRAM.preload_decoder(nums=head*args.ratio*1, bits=8, bandwidth_ratio=1/head)
129 | 146 | # ######### Preprocessing
130 | | - for k in range(math.ceil((head*ratio*1* Q.shape[1])/int(PE_width*PE_height/head))):
| 147 | + for k in range(math.ceil((head*args.ratio*1* Q.shape[1])/int(args.PE_width*args.PE_height/head))):
131 | 148 | PRE_cycles += 1*reload_ratio
132 | 149 | total_PRE_cycles += PRE_cycles
133 | 150 | total_preload_cycles += preload_cycles
140 | 157 | # DATA_cycles = 0
141 | 158 | # TODO:
142 | 159 | dense_ratio = global_tokens*Q.shape[0]/(len(sparser) + global_tokens*Q.shape[0])
143 | | - dense_PE_width = int(PE_width*dense_ratio)
144 | | - sparse_PE_width = PE_width - dense_PE_width
| 160 | + dense_PE_width = int(args.PE_width*dense_ratio)
| 161 | + sparse_PE_width = args.PE_width - dense_PE_width
145 | 162 | # ############## dense pattern q*k ##############
146 | 163 | dense_SDDMM_PE_cycles = 0
147 | 164 | for _sta_k in range(global_tokens):
148 | 165 | for _sta_q in range(math.ceil(Q.shape[0]/dense_PE_width)):
149 | | - for _tile_q in range(math.ceil(Q.shape[1] / (PE_height/head))):
| 166 | + for _tile_q in range(math.ceil(Q.shape[1] / (args.PE_height/head))):
150 | 167 | dense_SDDMM_PE_cycles += 1
151 | 168 | log.info('Dense SDDMM PE calculation | cycles: {}'.format(dense_SDDMM_PE_cycles))
152 | 169 | # ############## simultaneous sparse pattern q*k ##############
153 | 170 | sparse_SDDMM_PE_cycles = 0
154 | | - # for _sta_k in range(math.ceil(len(sparser)*Q.shape[1]/int(sparse_PE_width*PE_height/head))):
| 171 | + # for _sta_k in range(math.ceil(len(sparser)*Q.shape[1]/int(sparse_PE_width*args.PE_height/head))):
155 | 172 | for _sta_k in range(math.ceil(len(sparser)/(sparse_PE_width))):
156 | | - for _tile_q in range(math.ceil(Q.shape[1] / (PE_height/head))):
| 173 | + for _tile_q in range(math.ceil(Q.shape[1] / (args.PE_height/head))):
157 | 174 | sparse_SDDMM_PE_cycles += 1
158 | 175 | log.info('Sparse SDDMM PE calculation | cycles: {}'.format(sparse_SDDMM_PE_cycles))
159 | 176 | SDDMM_PE_cycles = max(dense_SDDMM_PE_cycles, sparse_SDDMM_PE_cycles)
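As a concrete instance of the width split feeding both loops above (token and non-zero counts assumed for illustration only):

    PE_width = 64                                  # the script's default
    global_tokens, num_tokens, nnz = 2, 197, 1000  # assumed values
    dense_ratio = global_tokens * num_tokens / (nnz + global_tokens * num_tokens)
    dense_PE_width = int(PE_width * dense_ratio)   # 18 columns work on the dense q*k
    sparse_PE_width = PE_width - dense_PE_width    # 46 columns work on the sparse q*k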
170 | 187 | log.info('Dense SpMM dataloader | cycles: {}'.format(preload_cycles))
171 | 188 | # ############## dense pattern s*v ##############
172 | 189 | dense_SpMM_PE_cycles = 0
173 | | - for _tile_attn in range(math.ceil((V.shape[0]*V.shape[1]*global_tokens) / int(dense_PE_width*PE_height/head))):
| 190 | + for _tile_attn in range(math.ceil((V.shape[0]*V.shape[1]*global_tokens) / int(dense_PE_width*args.PE_height/head))):
174 | 191 | dense_SpMM_PE_cycles += 1
175 | 192 | # total_SpMM_PE_cycles += SpMM_PE_cycles
176 | 193 | log.info('Dense SpMM PE calculation | cycles: {}'.format(dense_SpMM_PE_cycles))
197 | 214 | sparse_SpMM_PE_cycles = 0
198 | 215 | preload_cycles = 0
199 | 216 | for _tile_k in range(attn_map.shape[0]-global_tokens):
200 | | - preload_cycles += my_SRAM.preload_V(nums=head*1* V.shape[1], bits=8)*(1+0.5)
| 217 | + # preload_cycles += my_SRAM.preload_V(nums=head*1* V.shape[1], bits=8)*(1+0.5)
| 218 | + preload_cycles += my_SRAM.preload_V(nums=head*1* V.shape[1], bits=8)
201 | 219 | total_preload_cycles += preload_cycles
202 | 220 | log.info('Sparse SpMM dataloader | cycles: {}'.format(preload_cycles))
203 | 221 | # ############## sparse pattern s*v ##############
204 | 222 | SpMM_PE_cycles = 0
205 | 223 | for row_num in num_list:
206 | 224 | sparse_SpMM_PE_cycles += row_num*V.shape[1]
207 | | - sparse_SpMM_PE_cycles = math.ceil(sparse_SpMM_PE_cycles/int(sparse_PE_width*PE_height/head))
| 225 | + sparse_SpMM_PE_cycles = math.ceil(sparse_SpMM_PE_cycles/int(sparse_PE_width*args.PE_height/head))
208 | 226 | log.info('Sparse SpMM PE calculation | cycles: {}'.format(sparse_SpMM_PE_cycles))
209 | 227 | SpMM_PE_cycles = max(sparse_SpMM_PE_cycles, dense_SpMM_PE_cycles)
210 | 228 | total_SpMM_PE_cycles += SpMM_PE_cycles
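A minimal standalone sketch of the sparse SpMM cycle count above, with toy per-row non-zero counts (all names and values assumed, not from the commit):

    import math

    head, PE_height, sparse_PE_width = 3, 8, 46       # assumed split from the SDDMM stage
    feature_dim = 64                                   # plays the role of V.shape[1]
    num_list = [5, 0, 12, 3]                           # assumed non-zeros per attention row
    macs = sum(row * feature_dim for row in num_list)  # total multiply-accumulates -> 1280
    lanes = int(sparse_PE_width * PE_height / head)    # PEs serving the sparse side -> 122
    print(math.ceil(macs / lanes))                     # 1280 / 122 -> 11 cycles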
|
211 | 229 |
|
212 |
| - # for row_num in range(int(len(num_list)/PE_height)): |
213 |
| - # # for _tile_attn in range(int(row_num / PE_height)): |
214 |
| - # for _tile_v in range(int(V.shape[1] / PE_width)): # do not need to plus one if 64 / 64 == 0 |
| 230 | + # for row_num in range(int(len(num_list)/args.PE_width)): |
| 231 | + # # for _tile_attn in range(int(row_num / args.PE_width)): |
| 232 | + # for _tile_v in range(int(V.shape[1] / args.PE_width)): # do not need to plus one if 64 / 64 == 0 |
215 | 233 | # for _tile_k in range(num_list[row_num]):
|
216 | 234 | # SpMM_PE_cycles += 1
|
217 | 235 |
|
218 | 236 | # ########### linear transformation
|
219 | 237 | # Linear_PE_cycles = 0
|
220 |
| - # for _tile_attn in range(int(attn_map.shape[0] / PE_height)): |
221 |
| - # for _tile_v in range(int(V.shape[1] / PE_width)): |
| 238 | + # for _tile_attn in range(int(attn_map.shape[0] / args.PE_width)): |
| 239 | + # for _tile_v in range(int(V.shape[1] / args.PE_width)): |
222 | 240 | # for _tile_k in range(V.shape[0]):
|
223 | 241 | # Linear_PE_cycles += 1
|
224 | 242 | # print('Linear PE caclulation | cycles: {}'.format(Linear_PE_cycles))
231 | 249 | # K = all_K[_layer, _head]
232 | 250 | # V = all_V[_layer, _head]
233 | 251 | # for h in range(6):
234 | | - # for q in range(int((Q.shape[0]*Q.shape[1])/(PE_width*PE_width))):
| 252 | + # for q in range(int((Q.shape[0]*Q.shape[1])/(args.PE_width*args.PE_width))):
235 | 253 | # for h in range(12):
236 | 254 | # PRE_cycles += 1
237 | 255 |
238 | | - # for k in range(int((K.shape[0]*K.shape[1])/(PE_width*PE_width))):
| 256 | + # for k in range(int((K.shape[0]*K.shape[1])/(args.PE_width*args.PE_width))):
239 | 257 | # for h in range(12):
240 | 258 | # PRE_cycles += 1
241 | 259 |