Skip to content

Commit 56b1e3d

Browse files
committed
Add regularization term
1 parent c5e5c32 commit 56b1e3d

File tree

4 files changed

+15
-7
lines changed

4 files changed

+15
-7
lines changed

configs/bck_snell/att_1.yaml

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -25,7 +25,7 @@ seeds:
2525
prior: 345
2626
post: 456
2727
closure:
28-
name: "attention"
28+
name: "attention_reg"
2929
type: attentioncnn
3030
radii: [2, 2, 2, 2, 2]
3131
channels: [24, 24, 24, 24, 2]
@@ -42,19 +42,21 @@ closure:
4242
rng: "Xoshiro(seeds.θ_start)"
4343
priori:
4444
dotrain: true
45-
nepoch: 50000
45+
nepoch: 5000
4646
batchsize: 64
4747
opt: "OptimiserChain(Adam(T(1.0e-3)), ClipGrad(0.1))"
4848
do_plot: false
4949
plot_train: false
50+
lambda: 0.001
5051
posteriori:
5152
dotrain: true
5253
projectorders: "(ProjectOrder.Last, )"
5354
nepoch: 1500
54-
opt: "OptimizationCMAEvolutionStrategy.CMAEvolutionStrategyOpt()"
55+
opt: "OptimiserChain(Adam(T(1.0e-4)), ClipGrad(0.01))"
5556
nunroll: 5
5657
nunroll_valid: 10
5758
dt: 0.0001
5859
nsamples: 1
5960
do_plot: false
6061
plot_train: false
62+
lambda: 0.001

configs/bck_snell/cnn_noproj.yaml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -25,7 +25,7 @@ seeds:
2525
prior: 345
2626
post: 456
2727
closure:
28-
name: "cnn_noproj"
28+
name: "NoProjection"
2929
type: cnn
3030
radii: [2, 2, 2, 2, 2]
3131
channels: [24, 24, 24, 24, 2]

extra_model_workflow.jl

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -299,6 +299,7 @@ let
299299
plot_train = conf["priori"]["plot_train"],
300300
nepoch,
301301
dataproj = conf["dataproj"],
302+
λ = conf["priori"]["lambda"],
302303
)
303304
end
304305
end
@@ -411,6 +412,7 @@ let
411412
sensealg = sensealg,
412413
sciml_solver = sciml_solver,
413414
dataproj = conf["dataproj"],
415+
λ = conf["posteriori"]["lambda"],
414416
)
415417
end
416418
end

src/train.jl

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -110,7 +110,8 @@ function trainprior(;
110110
do_plot = false,
111111
plot_train = false,
112112
nepoch,
113-
dataproj
113+
dataproj,
114+
λ = nothing,
114115
)
115116
device(x) = adapt(params.backend, x)
116117
itotal = 0
@@ -159,8 +160,9 @@ function trainprior(;
159160
)
160161
train_data_priori = dataloader_prior()
161162

162-
loss_priori_lux(closure, θ, st, train_data_priori)
163-
loss = loss_priori_lux
163+
# Trigger the loss once and wrap it for the expected Lux interface
164+
loss_priori_lux(closure, θ, st, train_data_priori, λ)
165+
loss(model, param, state, data) = loss_priori_lux(model, param, state, data, λ)
164166

165167
if loadcheckpoint && isfile(checkfile)
166168
callbackstate, trainstate, epochs_trained =
@@ -272,6 +274,7 @@ function trainpost(;
272274
sensealg = nothing,
273275
sciml_solver = nothing,
274276
dataproj,
277+
λ = nothing,
275278
)
276279
device(x) = adapt(params.backend, x)
277280
itotal = 0
@@ -325,6 +328,7 @@ function trainpost(;
325328
griddims,
326329
inside,
327330
dt;
331+
λ = λ,
328332
ensemble = nsamples > 1,
329333
sciml_solver = sciml_solver,
330334
sensealg = sensealg,

0 commit comments

Comments
 (0)