using DiffEqFlux, OrdinaryDiffEq, Flux, Optim, Plots
using DifferentialEquations
using DiffEqSensitivity
using Zygote
using ForwardDiff
using LinearAlgebra
using Random
using Statistics
using ProgressBars, Printf
using Flux.Optimise: update!
using Flux.Losses: mae
using BSON: @save, @load
using LatinHypercubeSampling
using LsqFit

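# Output directories used by the plotting and checkpointing calls further below
# (mkpath is a no-op if they already exist).
mkpath("figs");
mkpath("./checkpoint");

# Training settings: is_restart resumes from the BSON checkpoint written during training,
# n_epoch caps the number of epochs, and n_plot sets how often (in epochs) the callback
# plots predictions and saves a checkpoint.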
is_restart = false;
n_epoch = 2000;
n_plot = 10;

opt = ADAMW(0.005, (0.9, 0.999), 1.f-6);
datasize = 40;
batchsize = 16;
n_exp_train = 20;
n_exp_val = 5;
n_exp = n_exp_train + n_exp_val;
noise = 1.f-4;
ns = 3;   # number of chemical species
nr = 6;   # number of candidate reactions in the CRNN

grad_max = 10 ^ (0.5);
maxiters = 10000;

# alg = AutoTsit5(Rosenbrock23(autodiff=false));
alg = Rosenbrock23(autodiff=false);
atol = [1e-6, 1e-8, 1e-6];
rtol = [1e-3, 1e-3, 1e-3];
lb = 1e-8;   # lower bound used to clamp concentrations before taking the log
ub = 1.f1;

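# Layout of the trainable parameter vector p (np = nr*(2ns + 1) + 1 entries),
# as unpacked by p2vec below:
#   p[1:nr]                       -> w_b,   per-reaction rate-constant terms (log scale)
#   p[nr+1 : nr*(ns+1)]           -> w_out, ns x nr stoichiometric-coefficient block
#   p[nr*(ns+1)+1 : nr*(2ns+1)]   -> w_in,  ns x nr reaction-order block
#   p[end]                        -> slope, a global scaling factor
# The initialization scale sqrt(6 / (ns + nr)) is a Glorot-style heuristic.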
np = nr * (2 * ns + 1) + 1;
p = (rand(Float32, np) .- 0.5) * 2 * sqrt(6 / (ns + nr));
p[end] = 1.e-1;

# Generate datasets: initial conditions for species 1 and 3 are Latin-hypercube
# sampled in (0.5, 1.5]; species 2 starts at the lower bound lb.
u0_list = rand(Float32, (n_exp, ns)) .* 2 .+ 0.5;
u0_list[:, 2:2] .= 0 .+ lb;
u0_list[:, [1, 3]] .= randomLHC(n_exp, 2) ./ n_exp .+ 0.5;

# Log-spaced sample times from t = 1 s to t = 1e5 s.
tsteps = 10 .^ range(0, 5, length=datasize);
tspan = Float32[0, tsteps[end]];
t_end = tsteps[end]

k = [4.f-2, 3.f7, 1.f4];
ode_data_list = zeros(Float32, (n_exp, ns, datasize));
yscale_list = [];

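# Ground-truth mechanism: the classic Robertson stiff-kinetics problem,
#   A -> B            (k1 = 4e-2)
#   B + B -> C + B    (k2 = 3e7)
#   B + C -> A + C    (k3 = 1e4)
# with y = [A, B, C]. The CRNN is trained to recover this mechanism from data.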
function trueODEfunc(dydt, y, k, t)
    r1 = k[1] * y[1]
    r2 = k[2] * y[2] * y[2]
    r3 = k[3] * y[2] * y[3]
    dydt[1] = -r1 + r3
    dydt[2] = r1 - r2 - r3
    dydt[3] = r2
end

u0 = u0_list[1, :];
prob_trueode = ODEProblem(trueODEfunc, u0, tspan, k);

# Range (max - min) of each species over a trajectory, used for loss scaling.
function max_min(ode_data)
    return maximum(ode_data, dims=2) .- minimum(ode_data, dims=2)
end

for i = 1:n_exp
    u0 = u0_list[i, :]
    prob_trueode = ODEProblem(trueODEfunc, u0, tspan, k)
    ode_data = Array(solve(prob_trueode, alg, saveat=tsteps, abstol=atol, reltol=rtol))
    ode_data += randn(size(ode_data)) .* ode_data .* noise   # multiplicative Gaussian noise
    ode_data_list[i, :, :] = ode_data
    push!(yscale_list, max_min(ode_data))
end

yscale = maximum(hcat(yscale_list...), dims=2);
dydt_scale = yscale[:, 1] ./ t_end
show(stdout, "text/plain", round.(yscale', digits=8))

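# p2vec unpacks the flat parameter vector into the CRNN weights:
#   - w_b is scaled by 10*|slope| so the learned ln(rate constants) can span many decades;
#   - w_out is parameterized as -w_in .* 10 .^ (raw block), tying the stoichiometric
#     coefficients to the learned reaction orders with a negative sign;
#   - w_in (reaction orders) is clamped to [0, 2.5] to keep the power-law kinetics physical.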
function p2vec(p)
    slope = abs(p[end])
    w_b = @view(p[1:nr]) .* (10 * slope)

    w_in = reshape(@view(p[nr * (ns + 1) + 1:nr * (2 * ns + 1)]), ns, nr)

    w_out = reshape(@view(p[nr + 1:nr * (ns + 1)]), ns, nr)
    w_out = @. -w_in * (10 ^ w_out)

    w_in = clamp.(w_in, 0, 2.5)
    return w_in, w_b, w_out
end

function display_p(p)
    w_in, w_b, w_out = p2vec(p)
    println("species (column) reaction (row)")
    println("w_in | w_b | w_out")
    display(hcat(w_in', w_b, w_out'))

    println("w_out_scale")
    w_out_ = (w_out .* dydt_scale)' .* exp.(w_b)
    # display(w_out_)
    # display(maximum(abs.(w_out_), dims=2)')
    display(w_out_ ./ maximum(abs.(w_out_), dims=2))
    println("slope = $(p[end])")
end
display_p(p)

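# CRNN right-hand side: each of the nr reactions follows power-law kinetics,
#   rate_j = exp( sum_i w_in[i, j] * log(u_i) + w_b[j] ),
# and the species production rates are du = w_out * rate, rescaled by dydt_scale.
# Concentrations are clamped at lb before the log to avoid log(0).
# Note: w_in, w_b, w_out are globals set by predict_neuralode via p2vec.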
function crnn(du, u, p, t)
    w_in_x = w_in' * @. log(clamp(u, lb, Inf))
    du .= w_out * (@. exp(w_in_x + w_b)) .* dydt_scale
end

u0 = u0_list[1, :]
prob = ODEProblem(crnn, u0, tspan, saveat=tsteps, abstol=atol, reltol=rtol)

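# Adjoint sensitivity configuration for differentiating through the ODE solve.
# (The training loop below uses ForwardDiff, which propagates dual numbers through
# the solver directly; sensealg is only consulted by reverse-mode AD such as Zygote.)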
sense = BacksolveAdjoint(checkpointing=true; autojacvec=false);
function predict_neuralode(u0, p; sample = datasize)
    global w_in, w_b, w_out = p2vec(p)
    _prob = remake(prob, tspan=[0, tsteps[sample]])
    sol = solve(_prob, alg, u0=u0, p=p, saveat=tsteps[1:sample],
                sensealg=sense, verbose=false, maxiters=maxiters)
    pred = Array(sol)

    if sol.retcode != :Success
        println("ode solver failed")
    end
    return pred
end
pred = predict_neuralode(u0, p);

# MAE loss on min-max-scaled trajectories; `sample` truncates to the first
# `sample` time points so trajectories of random length can be used for training.
function loss_neuralode(p, i_exp; sample = datasize)
    pred = predict_neuralode(u0_list[i_exp, :], p; sample)
    ode_data = ode_data_list[i_exp, :, 1:size(pred)[2]]
    loss = mae(ode_data ./ yscale, pred ./ yscale)
    return loss
end
loss_neuralode(p, 1)

# Plot data vs. prediction for experiment i_exp and save to figs/.
cbi = function (p, i_exp)
    ode_data = ode_data_list[i_exp, :, :]
    pred = predict_neuralode(u0_list[i_exp, :], p)
    l_plt = []
    for i = 1:ns
        plt = scatter(tsteps, ode_data[i, :], xscale=:log10,
                      markercolor=:transparent, label=string("data"))
        plot!(plt, tsteps[1:size(pred)[2]], pred[i, :], xscale=:log10, label=string("pred"))
        ylabel!(plt, "y$i")
        if i == ns
            xlabel!(plt, "Time [s]")
            plot!(plt, legend=:topleft)
        else
            plot!(plt, legend=false)
        end
        push!(l_plt, plt)
    end
    plt_all = plot(l_plt..., framestyle=:box, layout = (ns, 1))
    png(plt_all, string("figs/i_exp_", i_exp))

    return false
end

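# Loss and gradient-norm histories, plus a callback that every n_plot epochs prints the
# current CRNN weights, refreshes one randomly chosen experiment's plot, updates the
# loss/grad-norm curves, and writes a BSON checkpoint.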
l_loss_train = []
l_loss_val = []
l_grad = []
iter = 1
cb = function (p, loss_train, loss_val, g_norm)
    global l_loss_train, l_loss_val, l_grad, iter
    push!(l_loss_train, loss_train)
    push!(l_loss_val, loss_val)
    push!(l_grad, g_norm)

    if iter % n_plot == 0
        display_p(p)

        l_exp = randperm(n_exp)[1:1]
        println("update plot for ", l_exp)
        for i_exp in l_exp
            cbi(p, i_exp)
        end

        plt_loss = plot(l_loss_train, xscale=:identity, yscale=:log10, label="train")
        plot!(plt_loss, l_loss_val, xscale=:identity, yscale=:log10, label="val")
        plt_grad = plot(l_grad, xscale=:identity, yscale=:log10, label="grad_norm")
        xlabel!(plt_loss, "Epoch")
        xlabel!(plt_grad, "Epoch")
        ylabel!(plt_loss, "Loss")
        ylabel!(plt_grad, "Grad Norm")
        ylims!(plt_loss, (-Inf, 1))
        plt_all = plot([plt_loss, plt_grad]..., legend=:top)
        png(plt_all, "figs/loss_grad")

        @save "./checkpoint/mymodel.bson" p opt l_loss_train l_loss_val l_grad iter
    end
    iter += 1
end

if is_restart
    @load "./checkpoint/mymodel.bson" p opt l_loss_train l_loss_val l_grad iter
    iter += 1
    # opt = ADAMW(0.001, (0.9, 0.999), 1.f-6)
end

# Vector of per-experiment training losses (residuals for Levenberg-Marquardt below).
function loss_lm(p)
    [loss_neuralode(p, i) for i in 1:n_exp_train]
end
loss_lm(p);

# Jacobian of the residual vector with respect to p, via forward-mode AD.
g = function (p)
    return ForwardDiff.jacobian(x -> loss_lm(x), p)
end
g(p);

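# Training loop: each epoch visits the training experiments in random order, draws a
# random trajectory length in batchsize:datasize, computes the ForwardDiff gradient of
# the scaled-MAE loss, clips its 2-norm to grad_max, and applies an ADAMW update.
# Every epoch, train/validation losses are reported and the callback above is invoked;
# training stops early once the mean training loss drops below 0.1.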
epochs = ProgressBar(iter:n_epoch);
loss_epoch = zeros(Float32, n_exp);
grad_norm = zeros(Float32, n_exp_train);
for epoch in epochs
    global p
    for i_exp in randperm(n_exp_train)
        sample = rand(batchsize:datasize)
        grad = ForwardDiff.gradient(x -> loss_neuralode(x, i_exp; sample), p)
        grad_norm[i_exp] = norm(grad, 2)
        if grad_norm[i_exp] > grad_max
            grad = grad ./ grad_norm[i_exp] .* grad_max
        end
        update!(opt, p, grad)
    end
    for i_exp in 1:n_exp
        loss_epoch[i_exp] = loss_neuralode(p, i_exp)
    end
    loss_train = mean(loss_epoch[1:n_exp_train]);
    loss_val = mean(loss_epoch[n_exp_train + 1:end]);
    g_norm = mean(grad_norm)
    set_description(epochs, string(@sprintf("Loss train %.4e val %.4e gnorm %.4e", loss_train, loss_val, g_norm)))
    cb(p, loss_train, loss_val, g_norm);

    if loss_train < 0.1
        break
    end
end

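# Fine-tuning: refine p with Levenberg-Marquardt (LsqFit.lmfit) using the per-experiment
# loss vector as residuals and the ForwardDiff Jacobian g; the trailing Float64[] is the
# (empty, i.e. unweighted) weight vector.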
fit = LsqFit.lmfit(loss_lm, g, p, Float64[]; show_trace=true, maxIter=2000, x_tol=1e-8)
p_fit = fit.param;
display_p(p_fit)
cbi(p_fit, 1)
cbi(p_fit, n_exp)
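
# Optionally persist the refined parameters alongside the training checkpoint
# (the file name below is illustrative, not part of the original workflow).
@save "./checkpoint/mymodel_fit.bson" p_fit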

# @printf("min loss train %.4e val %.4e\n", minimum(l_loss_train), minimum(l_loss_val))

# for i_exp in 1:n_exp
#     cbi(p, i_exp)
# end