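# ViT.yaml: training configurations for a Vision Transformer (ViT)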
base: &base
# Model config
embed_dim: 384
depth: 12
dropout: 0.0
patch_size: 8
num_heads: 8
# Training config
img_size: [360, 720]
dt: 1
global_batch_size: 16 # samples per training step, summed across all ranks
num_iters: 30000
amp_mode: none
enable_jit: false
expdir: '/logs'
lr_schedule: 'cosine'
lr: 5e-4
warmup: 0
optimizer: 'Adam'
# Data
num_data_workers: 0 # number of dataloader worker processes per rank
n_in_channels: 20
n_out_channels: 20
train_data_path: '/data/train'
valid_data_path: '/data/valid'
inf_data_path: '/data/test'
time_means_path: '/data/stats/time_means.npy'
global_means_path: '/data/stats/global_means.npy'
global_stds_path: '/data/stats/global_stds.npy'
limit_nsamples: None # None = no limit, use all training samples
limit_nsamples_val: None # None = no limit, use all validation samples
# Comms
wireup_info: env
wireup_store: tcp
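# The named configs below inherit every key from `base` via the YAML merge
# key (`<<: *base`); any key set after the merge overrides the inherited value.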
# limit the number of samples
short: &short_ls
<<: *base
limit_nsamples: 512
limit_nsamples_val: 128
num_iters: 128
# add optimization flags
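# (fp16 mixed precision, JIT compilation, and 8 dataloader workers)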
short_opt:
<<: *short_ls
num_data_workers: 8
amp_mode: fp16
enable_jit: true
# no sample limits, full training run
opt: &opt
<<: *base
num_data_workers: 8
amp_mode: fp16
num_iters: 30000
enable_jit: true
# ----- Data parallel scaling configs
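# The learning rates below follow square-root scaling from the bs=16 baseline:
#   lr = 5e-4 * sqrt(global_batch_size / 16)
# e.g. bs=64: 5e-4 * sqrt(4) = 1e-3; bs=2048: 5e-4 * sqrt(128) ~= 5.66e-3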
bs16_opt:
<<: *opt
global_batch_size: 16
lr: 5e-4
bs32_opt:
<<: *opt
global_batch_size: 32
lr: 7.07e-4
bs64_opt:
<<: *opt
global_batch_size: 64
lr: 1e-3
bs128_opt:
<<: *opt
global_batch_size: 128
lr: 1.41e-3
bs256_opt:
<<: *opt
global_batch_size: 256
lr: 2e-3
bs512_opt:
<<: *opt
global_batch_size: 512
lr: 2.83e-3
bs1024_opt:
<<: *opt
global_batch_size: 1024
lr: 4e-3
bs2048_opt:
<<: *opt
global_batch_size: 2048
lr: 5.66e-3
# ----- Model parallel configs
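# These configs use a 4x wider model (embed_dim 1536 vs 384 in `base`);
# the lr values follow the same square-root scaling rule as above.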
mp: &mp
<<: *base
num_iters: 30000
global_batch_size: 64
lr: 1e-3
num_data_workers: 8
embed_dim: 1536 # larger model: 4x the base embed_dim of 384
amp_mode: fp16
enable_jit: true
mp_bs16:
<<: *mp
global_batch_size: 16
lr: 5e-4
mp_bs32:
<<: *mp
global_batch_size: 32
lr: 7.07e-4
# larger sequence length (use a local batch size of 1 here)
mp_patch4:
<<: *mp
patch_size: 4
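# Halving patch_size from 8 to 4 quadruples the sequence length:
# (360/4)*(720/4) = 16200 tokens vs (360/8)*(720/8) = 4050 with patch_size 8,
# hence the local batch size of 1 suggested above.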