Skip to content

Commit b2f7fb5

Browse files
author
shh
committed
add config to simulator
1 parent 7dbdc66 commit b2f7fb5

File tree

5 files changed

+977
-82
lines changed

5 files changed

+977
-82
lines changed

Hardware/Simulator/README.md

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,13 @@ which correspondingly generates a numpy file (e.g., ./masks/deit_tiny_lowrank/gl
2121

2222
To simulate the latency of attention computation, run
2323
````
24-
python ViTCoD.py
24+
python ViTCoD.py \
25+
--root 'masks/deit_tiny_lowrank' \
26+
--sparse 0.95 \
27+
--feature_dim 64 \
28+
--ratio 0.667 \
29+
--PE_width 64 \
30+
--PE_height 8
2531
````
2632
where we adopt a ***dynamic*** *PE allocation* between the ***denser*** and ***sparser engines*** to balance the workload of processing the denser and sparser patterns of different attention heads, and leverage the on-chip ***decoder*** to reconstruct Q and K that are compressed by the on-chip ***encoder*** for saving data access costs.
2733

@@ -31,7 +37,14 @@ where we adopt a ***dynamic*** *PE allocation* between the ***denser*** and ***s
3137
To simulate the end-to-end latency,
3238
* first run
3339
````
34-
python ViT_FFN.py
40+
python ViT_FFN.py \
41+
--root 'masks/deit_tiny_lowrank' \
42+
--sparse 0.95 \
43+
--feature_dim 64 \
44+
--embedding 192 \
45+
--ratio 0.667 \
46+
--PE_width 64 \
47+
--PE_height 8
3548
````
3649
which simulates the latency consumed by the remaining ***linear projections*** and ***MLPs***,
3750
* then add the simulated latency with the previously simulated attention latency.

Hardware/Simulator/ViTCoD.py

Lines changed: 58 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,29 @@
77
import logging
88
import os
99
import math
10+
import argparse
1011

11-
root = 'masks/deit_tiny_lowrank'
12-
# root = 'masks/reorder/deit/reorder_att/deit_base'
13-
sparse = [0.95]
12+
def get_args_parser():
    """Build the command-line argument parser for the ViTCoD attention simulator.

    Returns:
        argparse.ArgumentParser: parser exposing the simulator knobs —
        mask directory, sparsity level(s), Q/K/V feature dimension,
        encoder/decoder compression ratio, and the PE array shape.
    """
    # add_help=False because this parser is meant to be used via parents=[...]
    # by the top-level parser, which supplies -h/--help itself.
    parser = argparse.ArgumentParser('ViTCoD attn simulation script', add_help=False)
    parser.add_argument('--root', default='masks/deit_tiny_lowrank', type=str,
                        help='directory holding the reordered mask / global-token .npy files')
    # nargs='+' allows sweeping several sparsity levels in a single run.
    parser.add_argument('--sparse', type=float, default=[0.95], nargs='+',
                        help='the sparsity of the model')
    parser.add_argument('--feature_dim', default=64, type=int,
                        help='the feature dimension of Q/K/V')
    parser.add_argument('--ratio', default=2/3, type=float,
                        help='the compression ratio of encoder/decoder')
    parser.add_argument('--PE_width', default=64, type=int,
                        help='width of the processing-element array')
    parser.add_argument('--PE_height', default=8, type=int,
                        help='height of the processing-element array')
    return parser
1421

15-
for p in sparse:
22+
parser = argparse.ArgumentParser('ViTCoD attn similation script', parents=[get_args_parser()])
23+
args = parser.parse_args()
24+
25+
# args.root = 'masks/deit_tiny_lowrank'
26+
# args.root = 'masks/reorder/deit/reorder_att/deit_base'
27+
# args.sparse = [0.95]
28+
29+
for p in args.sparse:
1630
# Logging
1731
log = logging.getLogger()
18-
log_path = os.path.join(root, 'vitcod_atten_'+str(p)+'_wo.txt')
32+
log_path = os.path.join(args.root, 'vitcod_atten_'+str(p)+'_wo.txt')
1933
handlers = [logging.FileHandler(log_path, mode='a+'),
2034
logging.StreamHandler()]
2135
logging.basicConfig(
@@ -24,13 +38,13 @@
2438
level=logging.INFO,
2539
handlers=handlers)
2640
# Initialize Q, K, V and attn maps
27-
attn_map_mask = np.load(root+'/reodered_info_'+str(p)+'.npy')
28-
num_global_tokens = np.load(root+'/global_token_info_'+str(p)+'.npy')
41+
attn_map_mask = np.load(args.root+'/reodered_info_'+str(p)+'.npy')
42+
num_global_tokens = np.load(args.root+'/global_token_info_'+str(p)+'.npy')
2943

3044
# dim of (layer, head, token, features)
31-
all_Q = np.random.random((attn_map_mask.shape[0], attn_map_mask.shape[1], attn_map_mask.shape[2], 64))
32-
all_K = np.random.random((attn_map_mask.shape[0], attn_map_mask.shape[1], attn_map_mask.shape[2], 64))
33-
all_V = np.random.random((attn_map_mask.shape[0], attn_map_mask.shape[1], attn_map_mask.shape[2], 64))
45+
all_Q = np.random.random((attn_map_mask.shape[0], attn_map_mask.shape[1], attn_map_mask.shape[2], args.feature_dim))
46+
all_K = np.random.random((attn_map_mask.shape[0], attn_map_mask.shape[1], attn_map_mask.shape[2], args.feature_dim))
47+
all_V = np.random.random((attn_map_mask.shape[0], attn_map_mask.shape[1], attn_map_mask.shape[2], args.feature_dim))
3448
log.info('Shape: {}'.format(all_V.shape))
3549
my_SRAM = SRAM()
3650
my_PE = PE_array()
@@ -43,9 +57,9 @@
4357

4458
head = all_Q.shape[1]
4559
# the compression ratio of head via the encoder
46-
ratio = 2/3
47-
PE_width = 64
48-
PE_height = 8
60+
# args.ratio = 2/3
61+
# PE_width = 64
62+
# PE_height = 8
4963

5064
total_sparse_ratio = 0
5165
for _layer in range(all_Q.shape[0]):
@@ -67,6 +81,9 @@
6781
sparse_ratio = 0
6882
else:
6983
sparse_ratio = len(sparser)/(mask[:, global_tokens:].shape[0]*mask[:, global_tokens:].shape[1])
84+
# print("sparse_ratio:", sparse_ratio)
85+
# print("reload_ratio:", len(sparser)/(mask[:, global_tokens:].shape[1]))
86+
# break
7087
total_sparse_ratio += sparse_ratio
7188
# log.info('number of non-zeros in the sparser region: {}'.format(len(sparser)))
7289

@@ -78,23 +95,23 @@
7895
for _sta_k in range(global_tokens):
7996
# ############ k #########
8097
# ######### Load k and decoder weight
81-
preload_cycles += my_SRAM.preload_K(nums=head*ratio*1* K.shape[1], bits=8, bandwidth_ratio=1)
98+
preload_cycles += my_SRAM.preload_K(nums=head*args.ratio*1* K.shape[1], bits=8, bandwidth_ratio=1)
8299
if _sta_k == 0:
83-
preload_cycles += my_SRAM.preload_decoder(nums=head*ratio*1, bits=8, bandwidth_ratio=1/head)
100+
preload_cycles += my_SRAM.preload_decoder(nums=head*args.ratio*1, bits=8, bandwidth_ratio=1/head)
84101
# ######### Preprocessing
85-
for k in range((math.ceil((head*ratio*1* K.shape[1])/int(PE_width*PE_height/head)))):
102+
for k in range((math.ceil((head*args.ratio*1* K.shape[1])/int(args.PE_width*args.PE_width/head)))):
86103
PRE_cycles += 1
87104
for _sta_q in range(int(Q.shape[0])):
88105
if _sta_k == 0:
89106
# ############ q #########
90107
# ######### Load q and decoder weight
91108
# reload_ratio = (Q.shape[0]-(my_SRAM.max_Q/(8*Q.shape[1]*head)))/Q.shape[0]
92109
reload_ratio = 0
93-
preload_cycles += my_SRAM.preload_Q(nums=head*ratio*1* Q.shape[1], bits=8, bandwidth_ratio=1)*(1+reload_ratio)
110+
preload_cycles += my_SRAM.preload_Q(nums=head*args.ratio*1* Q.shape[1], bits=8, bandwidth_ratio=1)*(1+reload_ratio)
94111
if _sta_q == 0:
95-
preload_cycles += my_SRAM.preload_decoder(nums=head*ratio*1, bits=8, bandwidth_ratio=1/head)
112+
preload_cycles += my_SRAM.preload_decoder(nums=head*args.ratio*1, bits=8, bandwidth_ratio=1/head)
96113
# ######### Preprocessing
97-
for q in range(math.ceil((head*ratio*1* Q.shape[1])/int(PE_width*PE_height/head))):
114+
for q in range(math.ceil((head*args.ratio*1* Q.shape[1])/int(args.PE_width*args.PE_width/head))):
98115
PRE_cycles += 1*(1+reload_ratio)
99116

100117
total_PRE_cycles += PRE_cycles
@@ -109,11 +126,11 @@
109126
# ############ k #########
110127
# ######### Load K and decoder weights
111128
for i in range(K.shape[0]-global_tokens):
112-
preload_cycles += my_SRAM.preload_K(nums=head*ratio*1* K.shape[1], bits=8, bandwidth_ratio=1)
129+
preload_cycles += my_SRAM.preload_K(nums=head*args.ratio*1* K.shape[1], bits=8, bandwidth_ratio=1)
113130
if i == 0:
114-
preload_cycles += my_SRAM.preload_decoder(nums=head*ratio*1, bits=8, bandwidth_ratio=1/head)
131+
preload_cycles += my_SRAM.preload_decoder(nums=head*args.ratio*1, bits=8, bandwidth_ratio=1/head)
115132
# ######### Preprocessing
116-
for k in range(math.ceil((head*ratio*1* K.shape[1])/int(PE_width*PE_height/head))):
133+
for k in range(math.ceil((head*args.ratio*1* K.shape[1])/int(args.PE_width*args.PE_width/head))):
117134
PRE_cycles += 1
118135

119136
# ############ Q #########
@@ -123,11 +140,11 @@
123140
if global_tokens==0:
124141
reload_ratio = len(sparser)/mask[:, global_tokens:].shape[1]
125142
for i in range(Q.shape[0]):
126-
preload_cycles += my_SRAM.preload_Q(nums=head*ratio*1* Q.shape[1], bits=8, bandwidth_ratio=1)*reload_ratio
143+
preload_cycles += my_SRAM.preload_Q(nums=head*args.ratio*1* Q.shape[1], bits=8, bandwidth_ratio=1)*reload_ratio
127144
if i == 0:
128-
preload_cycles += my_SRAM.preload_decoder(nums=head*ratio*1, bits=8, bandwidth_ratio=1/head)
145+
preload_cycles += my_SRAM.preload_decoder(nums=head*args.ratio*1, bits=8, bandwidth_ratio=1/head)
129146
# ######### Preprocessing
130-
for k in range(math.ceil((head*ratio*1* Q.shape[1])/int(PE_width*PE_height/head))):
147+
for k in range(math.ceil((head*args.ratio*1* Q.shape[1])/int(args.PE_width*args.PE_width/head))):
131148
PRE_cycles += 1*reload_ratio
132149
total_PRE_cycles += PRE_cycles
133150
total_preload_cycles += preload_cycles
@@ -140,20 +157,20 @@
140157
# DATA_cycles = 0
141158
# TODO:
142159
dense_ratio = global_tokens*Q.shape[0]/(len(sparser) + global_tokens*Q.shape[0])
143-
dense_PE_width = int(PE_width*dense_ratio)
144-
sparse_PE_width = PE_width - dense_PE_width
160+
dense_PE_width = int(args.PE_width*dense_ratio)
161+
sparse_PE_width = args.PE_width - dense_PE_width
145162
# ############## dense pattern q*k ##############
146163
dense_SDDMM_PE_cycles = 0
147164
for _sta_k in range(global_tokens):
148165
for _sta_q in range(math.ceil(Q.shape[0]/dense_PE_width)):
149-
for _tile_q in range(math.ceil(Q.shape[1] / (PE_height/head))):
166+
for _tile_q in range(math.ceil(Q.shape[1] / (args.PE_width/head))):
150167
dense_SDDMM_PE_cycles += 1
151168
log.info('Dense SDMM PE caclulation | cycles: {}'.format(dense_SDDMM_PE_cycles))
152169
# ############## simultaneous sparse pattern q*k ##############
153170
sparse_SDDMM_PE_cycles = 0
154-
# for _sta_k in range(math.ceil(len(sparser)*Q.shape[1]/int(sparse_PE_width*PE_height/head))):
171+
# for _sta_k in range(math.ceil(len(sparser)*Q.shape[1]/int(sparse_PE_width*args.PE_width/head))):
155172
for _sta_k in range(math.ceil(len(sparser)/(sparse_PE_width))):
156-
for _tile_q in range(math.ceil(Q.shape[1] / (PE_height/head))):
173+
for _tile_q in range(math.ceil(Q.shape[1] / (args.PE_width/head))):
157174
sparse_SDDMM_PE_cycles += 1
158175
log.info('Sparse SDMM PE caclulation | cycles: {}'.format(sparse_SDDMM_PE_cycles))
159176
SDDMM_PE_cycles = max(dense_SDDMM_PE_cycles, sparse_SDDMM_PE_cycles)
@@ -170,7 +187,7 @@
170187
log.info('Dense SpMM dataloader | cycles: {}'.format(preload_cycles))
171188
# ############## dense pattern s*v ##############
172189
dense_SpMM_PE_cycles = 0
173-
for _tile_attn in range(math.ceil((V.shape[0]*V.shape[1]*global_tokens) / int(dense_PE_width*PE_height/head))):
190+
for _tile_attn in range(math.ceil((V.shape[0]*V.shape[1]*global_tokens) / int(dense_PE_width*args.PE_width/head))):
174191
dense_SpMM_PE_cycles += 1
175192
# total_SpMM_PE_cycles += SpMM_PE_cycles
176193
log.info('Dense SpMM PE caclulation | cycles: {}'.format(dense_SpMM_PE_cycles))
@@ -197,28 +214,29 @@
197214
sparse_SpMM_PE_cycles = 0
198215
preload_cycles = 0
199216
for _tile_k in range(attn_map.shape[0]-global_tokens):
200-
preload_cycles += my_SRAM.preload_V(nums=head*1* V.shape[1], bits=8)*(1+0.5)
217+
# preload_cycles += my_SRAM.preload_V(nums=head*1* V.shape[1], bits=8)*(1+0.5)
218+
preload_cycles += my_SRAM.preload_V(nums=head*1* V.shape[1], bits=8)
201219
total_preload_cycles += preload_cycles
202220
log.info('Sparse SpMM dataloader | cycles: {}'.format(preload_cycles))
203221
# ############## sparse pattern s*v ##############
204222
SpMM_PE_cycles = 0
205223
for row_num in num_list:
206224
sparse_SpMM_PE_cycles += row_num*V.shape[1]
207-
sparse_SpMM_PE_cycles = math.ceil(sparse_SpMM_PE_cycles/int(sparse_PE_width*PE_height/head))
225+
sparse_SpMM_PE_cycles = math.ceil(sparse_SpMM_PE_cycles/int(sparse_PE_width*args.PE_width/head))
208226
log.info('Sparse SpMM PE caclulation | cycles: {}'.format(sparse_SpMM_PE_cycles))
209227
SpMM_PE_cycles = max(sparse_SpMM_PE_cycles, dense_SpMM_PE_cycles)
210228
total_SpMM_PE_cycles += SpMM_PE_cycles
211229

212-
# for row_num in range(int(len(num_list)/PE_height)):
213-
# # for _tile_attn in range(int(row_num / PE_height)):
214-
# for _tile_v in range(int(V.shape[1] / PE_width)): # do not need to plus one if 64 / 64 == 0
230+
# for row_num in range(int(len(num_list)/args.PE_width)):
231+
# # for _tile_attn in range(int(row_num / args.PE_width)):
232+
# for _tile_v in range(int(V.shape[1] / args.PE_width)): # do not need to plus one if 64 / 64 == 0
215233
# for _tile_k in range(num_list[row_num]):
216234
# SpMM_PE_cycles += 1
217235

218236
# ########### linear transformation
219237
# Linear_PE_cycles = 0
220-
# for _tile_attn in range(int(attn_map.shape[0] / PE_height)):
221-
# for _tile_v in range(int(V.shape[1] / PE_width)):
238+
# for _tile_attn in range(int(attn_map.shape[0] / args.PE_width)):
239+
# for _tile_v in range(int(V.shape[1] / args.PE_width)):
222240
# for _tile_k in range(V.shape[0]):
223241
# Linear_PE_cycles += 1
224242
# print('Linear PE caclulation | cycles: {}'.format(Linear_PE_cycles))
@@ -231,11 +249,11 @@
231249
# K = all_K[_layer, _head]
232250
# V = all_V[_layer, _head]
233251
# for h in range(6):
234-
# for q in range(int((Q.shape[0]*Q.shape[1])/(PE_width*PE_width))):
252+
# for q in range(int((Q.shape[0]*Q.shape[1])/(args.PE_width*args.PE_width))):
235253
# for h in range(12):
236254
# PRE_cycles += 1
237255

238-
# for k in range(int((K.shape[0]*K.shape[1])/(PE_width*PE_width))):
256+
# for k in range(int((K.shape[0]*K.shape[1])/(args.PE_width*args.PE_width))):
239257
# for h in range(12):
240258
# PRE_cycles += 1
241259

0 commit comments

Comments
 (0)