
Commit 1e2cab7

trial of lm for ROBER
1 parent 7d914c8 commit 1e2cab7

File tree

6 files changed: +259 −0 lines changed

robertson/checkpoint/mymodel.bson (binary, 2.12 MB; not shown)

robertson/figs/i_exp_1.png (binary, 45.7 KB; not shown)

robertson/figs/i_exp_10.png (binary, 37.1 KB; not shown)

robertson/figs/i_exp_25.png (binary, 45.4 KB; not shown)

robertson/figs/loss_grad.png (image, 3.55 KB)

robertson/rober_crnn_lm.jl

Lines changed: 259 additions & 0 deletions
@@ -0,0 +1,259 @@
using DiffEqFlux, OrdinaryDiffEq, Flux, Optim, Plots
using DifferentialEquations
using DiffEqSensitivity
using Zygote
using ForwardDiff
using LinearAlgebra
using Random
using Statistics
using ProgressBars, Printf
using Flux.Optimise: update!
using Flux.Losses: mae
using BSON: @save, @load
using LatinHypercubeSampling
using LsqFit

is_restart = false;
n_epoch = 2000;
n_plot = 10;

opt = ADAMW(0.005, (0.9, 0.999), 1.f-6);
datasize = 40;
batchsize = 16;
n_exp_train = 20;
n_exp_val = 5;
n_exp = n_exp_train + n_exp_val;
noise = 1.f-4;
ns = 3;  # number of species
nr = 6;  # number of reactions

grad_max = 10 ^ (0.5);
maxiters = 10000;
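# grad_max = 10^0.5 ≈ 3.16 is the L2-norm clipping threshold applied to each
# per-experiment gradient in the training loop below.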

# alg = AutoTsit5(Rosenbrock23(autodiff=false));
alg = Rosenbrock23(autodiff=false);
atol = [1e-6, 1e-8, 1e-6];  # per-species absolute tolerances
rtol = [1e-3, 1e-3, 1e-3];
lb = 1e-8;
ub = 1.f1;

np = nr * (2 * ns + 1) + 1;
p = (rand(Float32, np) .- 0.5) * 2 * sqrt(6 / (ns + nr));
p[end] = 1.e-1;
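# Parameter vector layout (np = nr * (2ns + 1) + 1 = 6 * 7 + 1 = 43), unpacked
# by p2vec below: nr rate biases w_b, nr * ns exponents for w_out, nr * ns
# reaction orders w_in, plus one global slope stored in p[end].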

# Generate datasets
u0_list = rand(Float32, (n_exp, ns)) .* 2 .+ 0.5;
u0_list[:, 2] .= lb;
u0_list[:, [1, 3]] .= randomLHC(n_exp, 2) ./ n_exp .+ 0.5;

tsteps = 10 .^ range(0, 5, length=datasize);
tspan = Float32[0, tsteps[end]];
t_end = tsteps[end];
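# Initial conditions: y2(0) = lb (essentially zero), while y1(0) and y3(0) are
# Latin-hypercube sampled in (0.5, 1.5]; snapshots are log-spaced from t = 1 s
# to t = 1e5 s.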

k = [4.f-2, 3.f7, 1.f4];  # true Robertson rate constants
ode_data_list = zeros(Float32, (n_exp, ns, datasize));
yscale_list = [];

function trueODEfunc(dydt, y, k, t)
    r1 = k[1] * y[1]
    r2 = k[2] * y[2] * y[2]
    r3 = k[3] * y[2] * y[3]
    dydt[1] = -r1 + r3
    dydt[2] = r1 - r2 - r3
    dydt[3] = r2
end
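# Ground truth is the classic stiff Robertson (ROBER) mechanism:
#   A -> B          (r1 = k1 * y1)
#   B + B -> C + B  (r2 = k2 * y2^2)
#   B + C -> A + C  (r3 = k3 * y2 * y3)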

u0 = u0_list[1, :];
prob_trueode = ODEProblem(trueODEfunc, u0, tspan, k);

function max_min(ode_data)
    return maximum(ode_data, dims=2) .- minimum(ode_data, dims=2)
end

for i = 1:n_exp
    u0 = u0_list[i, :]
    prob_trueode = ODEProblem(trueODEfunc, u0, tspan, k)
    ode_data = Array(solve(prob_trueode, alg, saveat=tsteps, abstol=atol, reltol=rtol))
    ode_data += randn(size(ode_data)) .* ode_data .* noise  # multiplicative noise
    ode_data_list[i, :, :] = ode_data
    push!(yscale_list, max_min(ode_data))
end

yscale = maximum(hcat(yscale_list...), dims=2);
dydt_scale = yscale[:, 1] ./ t_end;
show(stdout, "text/plain", round.(yscale', digits=8))
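# yscale records each species' largest trajectory range across experiments; it
# normalizes the MAE loss so all species contribute comparably, and dydt_scale
# rescales the CRNN output back to physical production rates.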

function p2vec(p)
    slope = abs(p[end])
    w_b = @view(p[1:nr]) .* (10 * slope)

    w_in = reshape(@view(p[nr * (ns + 1) + 1:nr * (2 * ns + 1)]), ns, nr)

    w_out = reshape(@view(p[nr + 1:nr * (ns + 1)]), ns, nr)
    w_out = @. -w_in * (10 ^ w_out)

    w_in = clamp.(w_in, 0, 2.5)
    return w_in, w_b, w_out
end
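# Together with crnn() below, this realizes the CRNN rate law
#   r_j = exp(w_b[j] + sum_i w_in[i, j] * log(u_i)),
# a mass-action power law with learned reaction orders w_in (clamped to
# [0, 2.5]) and learned log rate constants w_b; w_out, whose sign is opposite
# to w_in, maps each rate r_j onto species production and consumption.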

function display_p(p)
    w_in, w_b, w_out = p2vec(p)
    println("species (column) reaction (row)")
    println("w_in | w_b | w_out")
    display(hcat(w_in', w_b, w_out'))

    println("w_out_scale")
    w_out_ = (w_out .* dydt_scale)' .* exp.(w_b)
    # display(w_out_)
    # display(maximum(abs.(w_out_), dims=2)')
    display(w_out_ ./ maximum(abs.(w_out_), dims=2))
    println("slope = $(p[end])")
end
display_p(p)

function crnn(du, u, p, t)
    w_in_x = w_in' * @. log(clamp(u, lb, Inf))
    du .= w_out * (@. exp(w_in_x + w_b)) .* dydt_scale
end

u0 = u0_list[1, :]
prob = ODEProblem(crnn, u0, tspan, saveat=tsteps, abstol=atol, reltol=rtol)

sense = BacksolveAdjoint(checkpointing=true; autojacvec=false);
function predict_neuralode(u0, p; sample = datasize)
    global w_in, w_b, w_out = p2vec(p)
    _prob = remake(prob, tspan=[0, tsteps[sample]])
    sol = solve(_prob, alg, u0=u0, p=p, saveat=tsteps[1:sample],
                sensealg=sense, verbose=false, maxiters=maxiters)
    pred = Array(sol)

    if sol.retcode != :Success
        println("ode solver failed")
    end
    return pred
end
pred = predict_neuralode(u0, p);
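# p2vec writes w_in, w_b, w_out as globals so that crnn() closes over them.
# Note: the training loop below differentiates through the solver with
# ForwardDiff, so the BacksolveAdjoint sensealg only takes effect if a
# reverse-mode AD such as Zygote is used instead.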

function loss_neuralode(p, i_exp; sample = datasize)
    pred = predict_neuralode(u0_list[i_exp, :], p; sample)
    ode_data = ode_data_list[i_exp, :, 1:size(pred)[2]]
    loss = mae(ode_data ./ yscale, pred ./ yscale)
    return loss
end
loss_neuralode(p, 1)

cbi = function (p, i_exp)
    ode_data = ode_data_list[i_exp, :, :]
    pred = predict_neuralode(u0_list[i_exp, :], p)
    l_plt = []
    for i = 1:ns
        plt = scatter(tsteps, ode_data[i, :], xscale=:log10,
                      markercolor=:transparent, label="data")
        plot!(plt, tsteps[1:size(pred)[2]], pred[i, :], xscale=:log10, label="pred")
        ylabel!(plt, "y$i")
        if i == ns
            xlabel!(plt, "Time [s]")
            plot!(plt, legend=:topleft)
        else
            plot!(plt, legend=false)
        end
        push!(l_plt, plt)
    end
    plt_all = plot(l_plt..., framestyle=:box, layout=(ns, 1))
    png(plt_all, string("figs/i_exp_", i_exp))

    return false
end

l_loss_train = []
l_loss_val = []
l_grad = []
iter = 1
cb = function (p, loss_train, loss_val, g_norm)
    global l_loss_train, l_loss_val, l_grad, iter
    push!(l_loss_train, loss_train)
    push!(l_loss_val, loss_val)
    push!(l_grad, g_norm)

    if iter % n_plot == 0
        display_p(p)

        l_exp = randperm(n_exp)[1:1]
        println("update plot for ", l_exp)
        for i_exp in l_exp
            cbi(p, i_exp)
        end

        plt_loss = plot(l_loss_train, xscale=:identity, yscale=:log10, label="train")
        plot!(plt_loss, l_loss_val, xscale=:identity, yscale=:log10, label="val")
        plt_grad = plot(l_grad, xscale=:identity, yscale=:log10, label="grad_norm")
        xlabel!(plt_loss, "Epoch")
        xlabel!(plt_grad, "Epoch")
        ylabel!(plt_loss, "Loss")
        ylabel!(plt_grad, "Grad Norm")
        ylims!(plt_loss, (-Inf, 1))
        plt_all = plot([plt_loss, plt_grad]..., legend=:top)
        png(plt_all, "figs/loss_grad")

        @save "./checkpoint/mymodel.bson" p opt l_loss_train l_loss_val l_grad iter
    end
    iter += 1
end

if is_restart
    @load "./checkpoint/mymodel.bson" p opt l_loss_train l_loss_val l_grad iter
    iter += 1
    # opt = ADAMW(0.001, (0.9, 0.999), 1.f-6)
end

function loss_lm(p)
    [loss_neuralode(p, i) for i in 1:n_exp_train]
end
loss_lm(p);

g = function (p)
    return ForwardDiff.jacobian(x -> loss_lm(x), p)
end
g(p);
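# Levenberg-Marquardt setup: loss_lm returns one residual per training
# experiment (length n_exp_train) and g its Jacobian with respect to the np
# parameters via forward-mode AD; both feed LsqFit.lmfit after pre-training.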

epochs = ProgressBar(iter:n_epoch);
loss_epoch = zeros(Float32, n_exp);
grad_norm = zeros(Float32, n_exp_train);
for epoch in epochs
    global p
    for i_exp in randperm(n_exp_train)
        sample = rand(batchsize:datasize)  # random truncation of the trajectory
        grad = ForwardDiff.gradient(x -> loss_neuralode(x, i_exp; sample), p)
        grad_norm[i_exp] = norm(grad, 2)
        if grad_norm[i_exp] > grad_max  # clip the gradient norm at grad_max
            grad = grad ./ grad_norm[i_exp] .* grad_max
        end
        update!(opt, p, grad)
    end
    for i_exp in 1:n_exp
        loss_epoch[i_exp] = loss_neuralode(p, i_exp)
    end
    loss_train = mean(loss_epoch[1:n_exp_train]);
    loss_val = mean(loss_epoch[n_exp_train + 1:end]);
    g_norm = mean(grad_norm)
    set_description(epochs, string(@sprintf("Loss train %.4e val %.4e gnorm %.4e",
                                            loss_train, loss_val, g_norm)))
    cb(p, loss_train, loss_val, g_norm);

    if loss_train < 0.1
        break
    end
end
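# Two-stage optimization: ADAMW with gradient clipping pre-trains p until the
# training loss falls below 0.1, then Levenberg-Marquardt refines the fit; the
# empty Float64[] is the weight vector, i.e. an unweighted least-squares fit.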

fit = LsqFit.lmfit(loss_lm, g, p, Float64[]; show_trace=true, maxIter=2000, x_tol=1e-8)
p_fit = fit.param;
display_p(p_fit)
cbi(p_fit, 1)
cbi(p_fit, n_exp)

# @printf("min loss train %.4e val %.4e\n", minimum(l_loss_train), minimum(l_loss_val))

# for i_exp in 1:n_exp
#     cbi(p, i_exp)
# end
