-
Notifications
You must be signed in to change notification settings - Fork 1
/
FLOPS.py
23 lines (22 loc) · 1.16 KB
/
FLOPS.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import nets, utils
import torch
from thop import profile,clever_format
#%%
# Build the Vision Transformer to be profiled. Sizes, depth, head count and the
# two `p` / `attn_p` rates are read from `utils`; the MLP ratio and QKV bias
# are fixed here.
# NOTE(review): `p` / `attn_p` look like dropout probabilities -- confirm in nets.
model = nets.VisionTransformer(
    img_size=utils.input_image_size,
    patch_size=utils.patch_size,
    in_chans=utils.in_chanel,
    n_classes=utils.num_classes,
    embed_dim=utils.embed_dim,
    depth=utils.depth,
    n_heads=utils.n_heads,
    mlp_ratio=4.0,
    qkv_bias=True,
    p=utils.p,
    attn_p=utils.attention_p,
)
# Move the model onto the device configured in `utils`.
model = model.to(utils.device)
#%%
# Single random image (batch size 1) matching the model's expected input
# layout, allocated directly on the target device. Named `dummy_input` rather
# than `input` so the builtin `input()` is not shadowed in later cells.
dummy_input = torch.randn(
    1,
    utils.in_chanel,
    utils.input_image_size,
    utils.input_image_size,
    device=utils.device,
)

# thop runs one forward pass and reports the operation count (MACs) and the
# parameter count for the model.
macs, params = profile(model, inputs=(dummy_input,))

# Convert the raw counts into human-readable strings with three decimals.
macs, params = clever_format([macs, params], "%.3f")
print(macs, params)