Skip to content

Commit 7e42c90

Browse files
committed
Decay the weight of the skip loss over training epochs
1 parent: a9af2e3 · commit: 7e42c90

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

examples/cifar10/main.jl

+8 -6
Original file line number | Diff line number | Diff line change
@@ -56,7 +56,7 @@ create_model(expt_config, args) = get_model(expt_config; device=gpu, warmup=true
5656

5757
function get_loss_function(args)
5858
if args["model-type"] == "VANILLA"
59-
function loss_function_closure_vanilla(x, y, model, ps, st)
59+
function loss_function_closure_vanilla(x, y, model, ps, st, w_skip=args["w-skip"])
6060
(ŷ, soln), st_ = model(x, ps, st)
6161
celoss = logitcrossentropy(ŷ, y)
6262
skiploss = FastDEQExperiments.mae(soln.u₀, soln.z_star)
@@ -65,11 +65,11 @@ function get_loss_function(args)
6565
end
6666
return loss_function_closure_vanilla
6767
else
68-
function loss_function_closure_skip(x, y, model, ps, st)
68+
function loss_function_closure_skip(x, y, model, ps, st, w_skip=args["w-skip"])
6969
(ŷ, soln), st_ = model(x, ps, st)
7070
celoss = logitcrossentropy(ŷ, y)
7171
skiploss = FastDEQExperiments.mae(soln.u₀, soln.z_star)
72-
loss = celoss + args["w-skip"] * skiploss
72+
loss = celoss + w_skip * skiploss
7373
return loss, st_, (ŷ, soln.nfe, celoss, skiploss, soln.residual)
7474
end
7575
return loss_function_closure_skip
@@ -185,7 +185,7 @@ function validate(val_loader, model, ps, st, loss_function, args)
185185
end
186186

187187
# Training
188-
function train_one_epoch(train_loader, model, ps, st, optimiser_state, epoch, loss_function, args)
188+
function train_one_epoch(train_loader, model, ps, st, optimiser_state, epoch, loss_function, w_skip, args)
189189
batch_time = AverageMeter("Batch Time", "6.3f")
190190
data_time = AverageMeter("Data Time", "6.3f")
191191
forward_pass_time = AverageMeter("Forward Pass Time", "6.3f")
@@ -212,7 +212,7 @@ function train_one_epoch(train_loader, model, ps, st, optimiser_state, epoch, lo
212212
# Gradients and Update
213213
_t = time()
214214
(loss, st, (ŷ, nfe_, celoss, skiploss, resi)), back = Zygote.pullback(
215-
p -> loss_function(x, y, model, p, st), ps
215+
p -> loss_function(x, y, model, p, st, w_skip), ps
216216
)
217217
forward_pass_time(time() - _t, B)
218218
_t = time()
@@ -353,10 +353,12 @@ function main(args)
353353

354354
st = hasproperty(expt_config, :pretrain_epochs) && getproperty(expt_config, :pretrain_epochs) > 0 ? Lux.update_state(st, :fixed_depth, Val(getproperty(expt_config, :num_layers))) : st
355355

356+
wskip_sched = ParameterSchedulers.Exp(args["w-skip"], 0.92f0)
357+
356358
for epoch in args["start-epoch"]:(expt_config.nepochs)
357359
# Train for 1 epoch
358360
ps, st, optimiser_state, train_stats = train_one_epoch(
359-
train_loader, model, ps, st, optimiser_state, epoch, loss_function, args
361+
train_loader, model, ps, st, optimiser_state, epoch, loss_function, wskip_sched(epoch), args
360362
)
361363
train_stats = get_loggable_stats(train_stats)
362364

0 commit comments

Comments (0)