@@ -56,7 +56,7 @@ create_model(expt_config, args) = get_model(expt_config; device=gpu, warmup=true

 function get_loss_function(args)
     if args["model-type"] == "VANILLA"
-        function loss_function_closure_vanilla(x, y, model, ps, st)
+        function loss_function_closure_vanilla(x, y, model, ps, st, w_skip=args["w-skip"])
             (ŷ, soln), st_ = model(x, ps, st)
             celoss = logitcrossentropy(ŷ, y)
             skiploss = FastDEQExperiments.mae(soln.u₀, soln.z_star)
@@ -65,11 +65,11 @@ function get_loss_function(args)
         end
         return loss_function_closure_vanilla
     else
-        function loss_function_closure_skip(x, y, model, ps, st)
+        function loss_function_closure_skip(x, y, model, ps, st, w_skip=args["w-skip"])
             (ŷ, soln), st_ = model(x, ps, st)
             celoss = logitcrossentropy(ŷ, y)
             skiploss = FastDEQExperiments.mae(soln.u₀, soln.z_star)
-            loss = celoss + args["w-skip"] * skiploss
+            loss = celoss + w_skip * skiploss
             return loss, st_, (ŷ, soln.nfe, celoss, skiploss, soln.residual)
         end
         return loss_function_closure_skip
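For context, a minimal usage sketch of the modified closures (the call-site variables `x`, `y`, `model`, `ps`, `st` are the same names used in the diff; the literal `0.5f0` is purely illustrative): the default keeps the old behaviour, while callers can now pass an epoch-dependent weight explicitly.

```julia
# Sketch only: the closure still defaults to args["w-skip"], so existing call
# sites keep working, while the training loop can override the weight per epoch.
loss_fn = get_loss_function(args)

loss, st_, stats = loss_fn(x, y, model, ps, st)          # uses args["w-skip"]
loss, st_, stats = loss_fn(x, y, model, ps, st, 0.5f0)   # explicit / scheduled weight
```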
@@ -185,7 +185,7 @@ function validate(val_loader, model, ps, st, loss_function, args)
 end

 # Training
-function train_one_epoch(train_loader, model, ps, st, optimiser_state, epoch, loss_function, args)
+function train_one_epoch(train_loader, model, ps, st, optimiser_state, epoch, loss_function, w_skip, args)
     batch_time = AverageMeter("Batch Time", "6.3f")
     data_time = AverageMeter("Data Time", "6.3f")
     forward_pass_time = AverageMeter("Forward Pass Time", "6.3f")
@@ -212,7 +212,7 @@ function train_one_epoch(train_loader, model, ps, st, optimiser_state, epoch, lo
         # Gradients and Update
         _t = time()
         (loss, st, (ŷ, nfe_, celoss, skiploss, resi)), back = Zygote.pullback(
-            p -> loss_function(x, y, model, p, st), ps
+            p -> loss_function(x, y, model, p, st, w_skip), ps
         )
         forward_pass_time(time() - _t, B)
         _t = time()
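Not part of this diff, but for readers unfamiliar with the pattern, a hedged sketch of how a pullback over a multi-output loss function like this one is typically consumed, assuming Optimisers.jl for the parameter update:

```julia
# Only the scalar loss gets a cotangent; the auxiliary outputs
# (state and the logging tuple) receive `nothing`.
gs = back((one(loss), nothing, nothing))[1]

# Apply the gradient to the parameters (Optimisers.jl API).
optimiser_state, ps = Optimisers.update!(optimiser_state, ps, gs)
```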
@@ -353,10 +353,12 @@ function main(args)

     st = hasproperty(expt_config, :pretrain_epochs) && getproperty(expt_config, :pretrain_epochs) > 0 ? Lux.update_state(st, :fixed_depth, Val(getproperty(expt_config, :num_layers))) : st

+    wskip_sched = ParameterSchedulers.Exp(args["w-skip"], 0.92f0)
+
     for epoch in args["start-epoch"]:(expt_config.nepochs)
         # Train for 1 epoch
         ps, st, optimiser_state, train_stats = train_one_epoch(
-            train_loader, model, ps, st, optimiser_state, epoch, loss_function, args
+            train_loader, model, ps, st, optimiser_state, epoch, loss_function, wskip_sched(epoch), args
         )
         train_stats = get_loggable_stats(train_stats)
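An illustrative sketch of the schedule introduced above (the starting value `1.0f0` is an assumption standing in for `args["w-skip"]`): `ParameterSchedulers.Exp(start, decay)` evaluates to `start * decay^(epoch - 1)`, so the skip-loss weight decays geometrically over training.

```julia
using ParameterSchedulers

# Same positional constructor as in the diff, with an assumed start value.
wskip_sched = ParameterSchedulers.Exp(1.0f0, 0.92f0)

wskip_sched(1)   # 1.0
wskip_sched(10)  # ≈ 0.92^9  ≈ 0.47
wskip_sched(50)  # ≈ 0.92^49 ≈ 0.017
```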