diff --git a/Project.toml b/Project.toml index 07a51c0a7..cd4efa061 100644 --- a/Project.toml +++ b/Project.toml @@ -67,6 +67,7 @@ Distributions = "0.25.107" DocStringExtensions = "0.9.3" DomainSets = "0.7" ExplicitImports = "1.10.1" +FastGaussQuadrature = "1.0.2" Flux = "0.14.22" ForwardDiff = "0.10.36" Functors = "0.4.12, 0.5" @@ -116,6 +117,7 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" DiffEqNoiseProcess = "77a26b50-5914-5dd7-bc55-306e6241c503" ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" +FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" Hwloc = "0e44f5e4-bd66-52a0-8798-143a42290a1d" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" @@ -132,4 +134,4 @@ TensorBoardLogger = "899adc3e-224a-11e9-021f-63837185c80f" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Aqua", "CUDA", "DiffEqNoiseProcess", "ExplicitImports", "Flux", "Hwloc", "InteractiveUtils", "LineSearches", "LuxCUDA", "LuxCore", "LuxLib", "MethodOfLines", "OptimizationOptimJL", "OrdinaryDiffEq", "ReTestItems", "StochasticDiffEq", "TensorBoardLogger", "Test"] +test = ["Aqua", "CUDA", "DiffEqNoiseProcess", "ExplicitImports", "FastGaussQuadrature", "Flux", "Hwloc", "InteractiveUtils", "LineSearches", "LuxCUDA", "LuxCore", "LuxLib", "MethodOfLines", "OptimizationOptimJL", "OrdinaryDiffEq", "ReTestItems", "StochasticDiffEq", "TensorBoardLogger", "Test"] \ No newline at end of file diff --git a/src/BPINN_ode.jl b/src/BPINN_ode.jl index 489cd5f6d..25db6701b 100644 --- a/src/BPINN_ode.jl +++ b/src/BPINN_ode.jl @@ -3,12 +3,11 @@ """ BNNODE(chain, kernel = HMC; strategy = nothing, draw_samples = 2000, priorsNNw = (0.0, 2.0), param = [nothing], l2std = [0.05], - phystd = [0.05], phynewstd = [0.05], dataset = [nothing], physdt = 1 / 20.0, + phystd = [0.05], phynewstd = (ode_params)->[0.05], dataset = [], physdt = 1 / 20.0, MCMCargs = (; n_leapfrog=30), nchains = 1, init_params = nothing, Adaptorkwargs = (; Adaptor = StanHMCAdaptor, targetacceptancerate = 0.8, Metric = DiagEuclideanMetric), - Integratorkwargs = (Integrator = Leapfrog,), autodiff = false, - progress = false, verbose = false) + Integratorkwargs = (Integrator = Leapfrog,), autodiff = false, estim_collocate = false, progress = false, verbose = false) Algorithm for solving ordinary differential equations using a Bayesian neural network. 
This is a specialization of the physics-informed neural network which is used as a solver for a @@ -43,7 +42,7 @@ sol = solve(prob, Tsit5(); saveat = 0.05) u = sol.u[1:100] time = sol.t[1:100] x̂ = u .+ (u .* 0.2) .* randn(size(u)) -dataset = [x̂, time] +dataset = [x̂, time, 0.05 .* ones(length(time))] chainlux = Lux.Chain(Lux.Dense(1, 6, tanh), Lux.Dense(6, 6, tanh), Lux.Dense(6, 1)) @@ -86,8 +85,8 @@ Kevin Linka, Amelie Schäfer, Xuhui Meng, Zongren Zou, George Em Karniadakis, El param <: Union{Nothing, Vector{<:Distribution}} l2std::Vector{Float64} phystd::Vector{Float64} - phynewstd::Vector{Float64} - dataset <: Union{Vector{Nothing}, Vector{<:Vector{<:AbstractFloat}}} + phynewstd + dataset <: Union{Vector, Vector{<:Vector{<:AbstractFloat}}} physdt::Float64 MCMCkwargs <: NamedTuple nchains::Int @@ -103,7 +102,7 @@ end function BNNODE(chain, kernel = HMC; strategy = nothing, draw_samples = 1000, priorsNNw = (0.0, 2.0), param = nothing, l2std = [0.05], phystd = [0.05], - phynewstd = [0.05], dataset = [nothing], physdt = 1 / 20.0, + phynewstd = (ode_params) -> [0.05], dataset = [], physdt = 1 / 20.0, MCMCkwargs = (n_leapfrog = 30,), nchains = 1, init_params = nothing, Adaptorkwargs = (Adaptor = StanHMCAdaptor, Metric = DiagEuclideanMetric, targetacceptancerate = 0.8), diff --git a/src/NeuralPDE.jl b/src/NeuralPDE.jl index 54c1ac091..b5735ec71 100644 --- a/src/NeuralPDE.jl +++ b/src/NeuralPDE.jl @@ -49,7 +49,7 @@ using AdvancedHMC: AdvancedHMC, DiagEuclideanMetric, HMC, HMCDA, Hamiltonian, using Distributions: Distributions, Distribution, MvNormal, Normal, dim, logpdf using LogDensityProblems: LogDensityProblems using MCMCChains: MCMCChains, Chains, sample -using MonteCarloMeasurements: Particles, pmean +using MonteCarloMeasurements: Particles import LuxCore: initialparameters, initialstates, parameterlength diff --git a/src/advancedHMC_MCMC.jl b/src/advancedHMC_MCMC.jl index f7f18e09b..b1a27df88 100644 --- a/src/advancedHMC_MCMC.jl +++ b/src/advancedHMC_MCMC.jl @@ -3,10 +3,10 @@ prob <: SciMLBase.ODEProblem smodel <: StatefulLuxLayer strategy <: AbstractTrainingStrategy - dataset <: Union{Vector{Nothing}, Vector{<:Vector{<:AbstractFloat}}} + dataset <: Union{Vector, Vector{<:Vector{<:AbstractFloat}}} priors <: Vector{<:Distribution} phystd::Vector{Float64} - phynewstd::Vector{Float64} + phynewstd::Function l2std::Vector{Float64} autodiff::Bool physdt::Float64 @@ -74,32 +74,37 @@ suggested extra loss function for ODE solver case """ @views function L2loss2(ltd::LogTargetDensity, θ) ltd.extraparams ≤ 0 && return false # XXX: type-stability? - + u0 = ltd.prob.u0 f = ltd.prob.f - t = ltd.dataset[end] - u1 = ltd.dataset[2] - û = ltd.dataset[1] + t = ltd.dataset[end - 1] + û = ltd.dataset[1:(end - 2)] + quadrature_weights = ltd.dataset[end] nnsol = ode_dfdx(ltd, t, θ[1:(length(θ) - ltd.extraparams)], ltd.autodiff) ode_params = ltd.extraparams == 1 ? 
θ[((length(θ) - ltd.extraparams) + 1)] : θ[((length(θ) - ltd.extraparams) + 1):length(θ)] + phynewstd = ltd.phynewstd(ode_params) - physsol = if length(ltd.prob.u0) == 1 - [f(û[i], ode_params, tᵢ) for (i, tᵢ) in enumerate(t)] + physsol = if length(u0) == 1 + [f(û[1][i], ode_params, tᵢ) for (i, tᵢ) in enumerate(t)] else - [f([û[i], u1[i]], ode_params, tᵢ) for (i, tᵢ) in enumerate(t)] + [f([û[j][i] for j in eachindex(u0)], ode_params, tᵢ) + for (i, tᵢ) in enumerate(t)] end # form of NN output matrix output dim x n deri_physsol = reduce(hcat, physsol) T = promote_type(eltype(deri_physsol), eltype(nnsol)) physlogprob = T(0) - for i in 1:length(ltd.prob.u0) + # For BPINNs, quadrature weights are applied to the timewise residuals rather than to the timewise logpdfs, since the logpdfs are not driven to zero. + # For GridTraining/the trapezoidal rule, quadrature_weights is dt .* ones(T, length(t)). + # phynewstd has the same dims as u0, since BNNODE is an out-of-place ODE solver. + for i in eachindex(u0) physlogprob += logpdf( - MvNormal(deri_physsol[i, :], - Diagonal(abs2.(T(ltd.phynewstd[i]) .* ones(T, length(nnsol[i, :]))))), - nnsol[i, :] + MvNormal((nnsol[i, :] .- deri_physsol[i, :]) .* quadrature_weights, + Diagonal(abs2.(T(phynewstd[i]) .* ones(T, length(t))))), + zeros(T, length(t)) ) end return physlogprob @@ -109,10 +114,10 @@ end L2 loss loglikelihood(needed for ODE parameter estimation). """ @views function L2LossData(ltd::LogTargetDensity, θ) - (ltd.dataset isa Vector{Nothing} || ltd.extraparams == 0) && return 0 + (isempty(ltd.dataset) || ltd.extraparams == 0) && return 0 # matrix(each row corresponds to vector u's rows) - nn = ltd(ltd.dataset[end], θ[1:(length(θ) - ltd.extraparams)]) + nn = ltd(ltd.dataset[end - 1], θ[1:(length(θ) - ltd.extraparams)]) T = eltype(nn) L2logprob = zero(T) @@ -150,7 +155,7 @@ end function getlogpdf(strategy::GridTraining, ltd::LogTargetDensity, f, autodiff::Bool, tspan, ode_params, θ) ts = collect(eltype(strategy.dx), tspan[1]:(strategy.dx):tspan[2]) - t = ltd.dataset isa Vector{Nothing} ? ts : vcat(ts, ltd.dataset[end]) + t = isempty(ltd.dataset) ? ts : vcat(ts, ltd.dataset[end - 1]) return sum(innerdiff(ltd, f, autodiff, t, θ, ode_params)) end @@ -158,16 +163,18 @@ function getlogpdf(strategy::StochasticTraining, ltd::LogTargetDensity, f, autodiff::Bool, tspan, ode_params, θ) T = promote_type(eltype(tspan[1]), eltype(tspan[2])) samples = (tspan[2] - tspan[1]) .* rand(T, strategy.points) .+ tspan[1] - t = ltd.dataset isa Vector{Nothing} ? samples : vcat(samples, ltd.dataset[end]) + t = isempty(ltd.dataset) ? samples : vcat(samples, ltd.dataset[end - 1]) return sum(innerdiff(ltd, f, autodiff, t, θ, ode_params)) end function getlogpdf(strategy::QuadratureTraining, ltd::LogTargetDensity, f, autodiff::Bool, tspan, ode_params, θ) + # the integrand has the shape of the NN output integrand(t::Number, θ) = innerdiff(ltd, f, autodiff, [t], θ, ode_params) intprob = IntegralProblem( integrand, (tspan[1], tspan[2]), θ; nout = length(ltd.prob.u0)) sol = solve(intprob, QuadGKJL(); strategy.abstol, strategy.reltol) + # sum the losses over all NN outputs return sum(sol.u) end @@ -185,7 +192,7 @@ function getlogpdf(strategy::WeightedIntervalTraining, ltd::LogTargetDensity, f, append!(ts, temp_data) end - t = ltd.dataset isa Vector{Nothing} ? ts : vcat(ts, ltd.dataset[end]) + t = isempty(ltd.dataset) ?
ts : vcat(ts, ltd.dataset[end - 1]) return sum(innerdiff(ltd, f, autodiff, t, θ, ode_params)) end @@ -202,23 +209,21 @@ MvNormal likelihood at each `ti` in time `t` for ODE collocation residue with NN # this is a vector{vector{dx,dy}}(handle case single u(float passed)) if length(out[:, 1]) == 1 - physsol = [f(out[:, i][1], ode_params, t[i]) for i in 1:length(out[1, :])] + physsol = [f(out[:, i][1], ode_params, t[i]) for i in eachindex(t)] else - physsol = [f(out[:, i], ode_params, t[i]) for i in 1:length(out[1, :])] + physsol = [f(out[:, i], ode_params, t[i]) for i in eachindex(t)] end physsol = reduce(hcat, physsol) nnsol = ode_dfdx(ltd, t, θ[1:(length(θ) - ltd.extraparams)], autodiff) - - vals = nnsol .- physsol - T = eltype(vals) + T = eltype(nnsol) # N dimensional vector if N outputs for NN(each row has logpdf of u[i] where u is vector # of dependent variables) return [logpdf( - MvNormal(vals[i, :], - Diagonal(abs2.(T(ltd.phystd[i]) .* ones(T, length(vals[i, :]))))), - zeros(T, length(vals[i, :])) + MvNormal((nnsol[i, :] .- physsol[i, :]), + Diagonal(abs2.(T(ltd.phystd[i]) .* ones(T, length(t))))), + zeros(T, length(t)) ) for i in 1:length(ltd.prob.u0)] end @@ -262,9 +267,9 @@ function kernelchoice(Kernel, MCMCkwargs) end """ - ahmc_bayesian_pinn_ode(prob, chain; strategy = GridTraining, dataset = [nothing], + ahmc_bayesian_pinn_ode(prob, chain; strategy = GridTraining, dataset = [], init_params = nothing, draw_samples = 1000, physdt = 1 / 20.0f0, - l2std = [0.05], phystd = [0.05], phynewstd = [0.05], priorsNNw = (0.0, 2.0), + l2std = [0.05], phystd = [0.05], phynewstd = (ode_params)->[0.05], priorsNNw = (0.0, 2.0), param = [], nchains = 1, autodiff = false, Kernel = HMC, Adaptorkwargs = (Adaptor = StanHMCAdaptor, Metric = DiagEuclideanMetric, targetacceptancerate = 0.8), @@ -294,7 +299,7 @@ time = sol.t[1:100] ### dataset and BPINN create x̂ = collect(Float64, Array(u) + 0.05 * randn(size(u))) -dataset = [x̂, time] +dataset = [x̂, time, 0.05 .* ones(length(time))] chain1 = Lux.Chain(Lux.Dense(1, 5, tanh), Lux.Dense(5, 5, tanh), Lux.Dense(5, 1) @@ -318,18 +323,22 @@ fh_mcmc_chain2, fhsamples2, fhstats2 = ahmc_bayesian_pinn_ode(prob, chain1, ## NOTES -Dataset is required for accurate Parameter estimation + solving equations -Incase you are only solving the Equations for solution, do not provide dataset +A dataset is required for accurate parameter estimation in inverse problems. +In case you are only solving a non-parametric ODE forward for its solution, do not provide a dataset. ## Positional Arguments -* `prob`: DEProblem(out of place and the function signature should be f(u,p,t). +* `prob`: ODEProblem (out-of-place; the function signature should be f(u,p,t)). * `chain`: Lux Neural Network which would be made the Bayesian PINN. ## Keyword Arguments * `strategy`: The training strategy used to choose the points for the evaluations. By default GridTraining is used with given physdt discretization. +* `dataset`: Either an empty Vector or a nested Vector of the form `[x̂, t, W]` where `x̂` are the dependent variable observations, `t` are the time points and `W` are the quadrature weights for the domain. + The dataset is used to compute the L2 loss against the data and also for the new loss function. + For multiple dependent variables there are multiple observation vectors, with the last two vectors in the dataset still being `t` and `W`. + Empty by default, assuming a forward problem is being solved.
* `init_params`: initial parameter values for BPINN (for multiple chains, different initializations are preferred) * `nchains`: number of chains you want to sample @@ -337,7 +346,9 @@ Incase you are only solving the Equations for solution, do not provide dataset ~2/3 of draw samples) * `l2std`: standard deviation of BPINN prediction against L2 losses/Dataset * `phystd`: standard deviation of BPINN prediction against Chosen Underlying ODE System -* `phynewstd`: standard deviation of new loss func term +* `phynewstd`: A function that gives the standard deviation of the new loss function at each iteration. + It takes the ODE parameters as input and returns a vector of standard deviations. + Defaults to `(ode_params) -> [0.05]`. * `priorsNNw`: Tuple of (mean, std) for BPINN Network parameters. Weights and Biases of BPINN are Normal Distributions by default. * `param`: Vector of chosen ODE parameters Distributions in case of Inverse problems. @@ -357,6 +368,7 @@ Incase you are only solving the Equations for solution, do not provide dataset * `max_depth`: Maximum doubling tree depth (NUTS) * `Δ_max`: Maximum divergence during doubling tree (NUTS) Refer: https://turinglang.org/AdvancedHMC.jl/stable/ +* `estim_collocate`: A boolean indicating whether to use the new loss function. Only relevant for ODE parameter estimation. * `progress`: controls whether to show the progress meter or not. * `verbose`: controls the verbosity. (Sample call args in AHMC) @@ -366,9 +378,10 @@ Incase you are only solving the Equations for solution, do not provide dataset releases. """ function ahmc_bayesian_pinn_ode( - prob::SciMLBase.ODEProblem, chain; strategy = GridTraining, dataset = [nothing], + prob::SciMLBase.ODEProblem, chain; strategy = GridTraining, dataset = [], init_params = nothing, draw_samples = 1000, physdt = 1 / 20.0, l2std = [0.05], - phystd = [0.05], phynewstd = [0.05], priorsNNw = (0.0, 2.0), param = [], nchains = 1, + phystd = [0.05], phynewstd = (ode_params) -> [0.05], + priorsNNw = (0.0, 2.0), param = [], nchains = 1, autodiff = false, Kernel = HMC, Adaptorkwargs = (Adaptor = StanHMCAdaptor, Metric = DiagEuclideanMetric, targetacceptancerate = 0.8), @@ -380,15 +393,15 @@ function ahmc_bayesian_pinn_ode( strategy = strategy == GridTraining ? strategy(physdt) : strategy - if dataset != [nothing] && - (length(dataset) < 2 || !(dataset isa Vector{<:Vector{<:AbstractFloat}})) - error("Invalid dataset. dataset would be timeseries (x̂,t) where type: Vector{Vector{AbstractFloat}") + if !isempty(dataset) && + (length(dataset) < 3 || !(dataset isa Vector{<:Vector{<:AbstractFloat}})) + error("Invalid dataset.
The dataset must be a timeseries (x̂, t, W) of type Vector{Vector{AbstractFloat}}.") end - if dataset != [nothing] && param == [] - println("Dataset is only needed for Parameter Estimation + Forward Problem, not in only Forward Problem case.") - elseif dataset == [nothing] && param != [] - error("Dataset Required for Parameter Estimation.") + if !isempty(dataset) && isempty(param) + println("A dataset is only needed for inverse problems performing parameter estimation, not for forward-only solves.") + elseif isempty(dataset) && !isempty(param) + error("A dataset is required for inverse problems performing parameter estimation.") end initial_nnθ, chain, st = generate_ltd(chain, init_params) @@ -461,7 +474,8 @@ function ahmc_bayesian_pinn_ode( MCMC_alg = kernelchoice(Kernel, MCMCkwargs) Kernel = AdvancedHMC.make_kernel(MCMC_alg, integrator) - samples, stats = sample(hamiltonian, Kernel, initial_θ, draw_samples, adaptor; + samples, + stats = sample(hamiltonian, Kernel, initial_θ, draw_samples, adaptor; progress = progress, verbose = verbose) samplesc[i] = samples @@ -479,7 +493,8 @@ function ahmc_bayesian_pinn_ode( MCMC_alg = kernelchoice(Kernel, MCMCkwargs) Kernel = AdvancedHMC.make_kernel(MCMC_alg, integrator) - samples, stats = sample(hamiltonian, Kernel, initial_θ, draw_samples, + samples, + stats = sample(hamiltonian, Kernel, initial_θ, draw_samples, adaptor; progress = progress, verbose = verbose) if verbose diff --git a/src/ode_solve.jl b/src/ode_solve.jl index f0189c7f2..89cac8b33 100644 --- a/src/ode_solve.jl +++ b/src/ode_solve.jl @@ -1,8 +1,9 @@ abstract type NeuralPDEAlgorithm <: SciMLBase.AbstractODEAlgorithm end """ - NNODE(chain, opt, init_params = nothing; autodiff = false, batch = 0, - additional_loss = nothing, kwargs...) + NNODE(chain, opt, init_params = nothing; strategy = nothing, autodiff = false, + batch = true, param_estim = false, additional_loss = nothing, + dataset = [], estim_collocate = false, kwargs...) Algorithm for solving ordinary differential equations using a neural network. This is a specialization of the physics-informed neural network which is used as a solver for a standard `ODEProblem`. @@ -28,6 +29,10 @@ standard `ODEProblem`. * `additional_loss`: A function additional_loss(phi, θ) where phi are the neural network trial solutions, θ are the weights of the neural network(s). +* `dataset`: Either an empty Vector or a nested Vector of the form `[x̂, t, W]` where `x̂` are the dependent variable observations, `t` are the time points and `W` are the quadrature weights for the domain (a sketch of this layout follows the keyword list below). + The dataset is used to compute a L2 loss against the data and also for the new loss function. + For multiple dependent variables there are multiple observation vectors, with the last two vectors in the dataset still being `t` and `W`. + Empty by default, assuming a forward problem is being solved. * `autodiff`: The switch between automatic and numerical differentiation for the PDE operators. The reverse mode of the loss function is always automatic differentiation (via Zygote), this is only for the derivative @@ -44,6 +49,7 @@ standard `ODEProblem`. * `strategy`: The training strategy used to choose the points for the evaluations. Default of `nothing` means that `QuadratureTraining` with QuadGK is used if no `dt` is given, and `GridTraining` is used with `dt` if given. +* `estim_collocate`: A boolean indicating whether to use the new loss function. Only relevant for ODE parameter estimation. * `kwargs`: Extra keyword arguments are splatted to the Optimization.jl `solve` call.
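+
+A minimal sketch of the expected `dataset` layout for a scalar ODE on a uniform grid (the
+observations `x̂` below are hypothetical; for GridTraining-style weighting the quadrature
+weights are simply `dt .* ones`):
+
+dt = 0.1
+t = collect(0.0:dt:1.0)
+x̂ = sin.(t) .+ 0.05 .* randn(length(t)) # assumed noisy observations of u(t)
+W = dt .* ones(length(t))               # grid/trapezoidal quadrature weights
+dataset = [x̂, t, W]
+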
## Examples @@ -91,14 +97,17 @@ Networks 9, no. 5 (1998): 987-1000. strategy <: Union{Nothing, AbstractTrainingStrategy} param_estim additional_loss <: Union{Nothing, Function} + dataset <: Union{Vector, Vector{<:Vector{<:AbstractFloat}}} + estim_collocate::Bool kwargs end function NNODE(chain, opt, init_params = nothing; strategy = nothing, autodiff = false, - batch = true, param_estim = false, additional_loss = nothing, kwargs...) + batch = true, param_estim = false, additional_loss = nothing, + dataset = [], estim_collocate = false, kwargs...) chain isa AbstractLuxLayer || (chain = FromFluxAdaptor()(chain)) return NNODE(chain, opt, init_params, autodiff, batch, - strategy, param_estim, additional_loss, kwargs) + strategy, param_estim, additional_loss, dataset, estim_collocate, kwargs) end """ @@ -263,6 +272,44 @@ function generate_loss(::QuasiRandomTraining, phi, f, autodiff::Bool, tspan) spaces only. Use StochasticTraining instead.") end +""" +L2 loss against the data (needed for ODE parameter estimation). +""" +function generate_L2lossData(dataset, phi, n_output) + isempty(dataset) && return (θ, _) -> 0 # keep the loss callable even without a dataset + return (θ, _) -> sum(sum(abs2, phi(dataset[end - 1], θ)[i, :] .- dataset[i]) + for i in 1:n_output) +end + +""" +New loss: quadrature-weighted ODE residual at the dataset time points (collocation). +""" +function generate_L2loss2(f, autodiff, dataset, phi, n_output) + isempty(dataset) && return (θ, _) -> 0 + t = dataset[end - 1] + û = dataset[1:(end - 2)] + quadrature_weights = dataset[end] + + function L2loss2(θ, _) + nnsol = ode_dfdx(phi, t, θ, autodiff) + ode_params = θ.p + + physsol = if n_output == 1 + [f(û[1][i], ode_params, tᵢ) for (i, tᵢ) in enumerate(t)] + else + [f([û[j][i] for j in 1:(length(dataset) - 2)], ode_params, tᵢ) + for (i, tᵢ) in enumerate(t)] + end + # form of NN output matrix output dim x n + deri_physsol = reduce(hcat, physsol) + + # quadrature weights are applied to the timewise losses; + # for GridTraining/the trapezoidal rule, quadrature_weights is dt .* ones(length(t)) + return sum(sum(abs2.(nnsol[i, :] .- deri_physsol[i, :]) .* quadrature_weights) + for i in 1:n_output) + end +end + @concrete struct NNODEInterpolation phi <: ODEPhi θ @@ -307,7 +354,9 @@ function SciMLBase.__solve( ) (; u0, tspan, f, p) = prob t0 = tspan[1] - (; param_estim, chain, opt, autodiff, init_params, batch, additional_loss) = alg + # destructure the new `dataset` and `estim_collocate` options along with the rest of NNODE + (; param_estim, estim_collocate, dataset, chain, opt, autodiff, + init_params, batch, additional_loss) = alg phi, init_params = generate_phi_θ(chain, t0, u0, init_params) @@ -336,12 +385,30 @@ function SciMLBase.__solve( inner_f = generate_loss(strategy, phi, f, autodiff, tspan, p, batch, param_estim) - (param_estim && additional_loss === nothing) && - throw(ArgumentError("Please provide `additional_loss` in `NNODE` for parameter estimation (`param_estim` is true).")) + if !isempty(dataset) && + (length(dataset) < 3 || !(dataset isa Vector{<:Vector{<:AbstractFloat}})) + error("Invalid dataset.
The dataset must be a timeseries (x̂, t, W) of type Vector{Vector{AbstractFloat}}.") + end + + if isempty(dataset) && param_estim && isnothing(additional_loss) + error("A dataset or an additional loss is required for inverse problems performing parameter estimation.") + elseif isempty(dataset) && estim_collocate + error("A dataset is required for inverse problems performing parameter estimation with the new loss.") + end + + n_output = length(u0) + L2lossData = generate_L2lossData(dataset, phi, n_output) + L2loss2 = generate_L2loss2(f, autodiff, dataset, phi, n_output) # Creates OptimizationFunction Object from total_loss function total_loss(θ, _) L2_loss = inner_f(θ, phi) + + if param_estim && estim_collocate + L2_loss = L2_loss + L2lossData(θ, phi) + L2loss2(θ, phi) + elseif param_estim + L2_loss = L2_loss + L2lossData(θ, phi) + end if additional_loss !== nothing L2_loss = L2_loss + additional_loss(phi, θ) end diff --git a/test/BPINN_PDE_tests.jl b/test/BPINN_PDE_tests.jl index 4f7caea7e..7cba05259 100644 --- a/test/BPINN_PDE_tests.jl +++ b/test/BPINN_PDE_tests.jl @@ -379,8 +379,9 @@ end end end - @parameters x, t, α - @variables u(..) + @parameters α + @parameters x, t + @variables u(..) Dt = Differential(t) Dx = Differential(x) Dx2 = Differential(x)^2 diff --git a/test/BPINN_tests.jl b/test/BPINN_tests.jl index 37dc4e54d..c975a23f5 100644 --- a/test/BPINN_tests.jl +++ b/test/BPINN_tests.jl @@ -29,7 +29,8 @@ chainlux = Chain(Dense(1, 7, tanh), Dense(7, 1)) θinit, st = Lux.setup(Random.default_rng(), chainlux) - fh_mcmc_chain, fhsamples, fhstats = ahmc_bayesian_pinn_ode( + fh_mcmc_chain, fhsamples, + fhstats = ahmc_bayesian_pinn_ode( prob, chainlux, draw_samples = 2500) alg = BNNODE(chainlux, draw_samples = 2500) @@ -45,7 +46,7 @@ # --------------------- ahmc_bayesian_pinn_ode() call @test mean(abs.(x̂ .- meanscurve)) < 0.05 - @test mean(abs.(physsol1 .- meanscurve)) < 0.005 + @test mean(abs.(physsol1 .- meanscurve)) < 0.006 #--------------------- solve() call @test mean(abs.(x̂1 .- pmean(sol1lux.ensemblesol[1]))) < 0.025 @@ -76,7 +77,7 @@ end u = [linear_analytic(u0, p, ti) for ti in ta] x̂ = collect(Float64, Array(u) + 0.2 * randn(size(u))) time = vec(collect(Float64, ta)) - dataset = [x̂, time] + dataset = [x̂, time, ones(length(time))] physsol1 = [linear_analytic(prob.u0, p, time[i]) for i in eachindex(time)] # testing points for solve call(saveat=1/50.0 ∴ at t = collect(eltype(saveat), prob.tspan[1]:saveat:prob.tspan[2] internally estimates) @@ -89,7 +90,8 @@ end chainlux1 = Chain(Dense(1, 7, tanh), Dense(7, 1)) θinit, st = Lux.setup(Random.default_rng(), chainlux1) - fh_mcmc_chain, fhsamples, fhstats = ahmc_bayesian_pinn_ode( + fh_mcmc_chain, fhsamples, + fhstats = ahmc_bayesian_pinn_ode( prob, chainlux1, dataset = dataset, draw_samples = 2500, physdt = 1 / 50.0, priorsNNw = (0.0, 3.0), param = [LogNormal(9, 0.5)]) @@ -137,8 +139,10 @@ end sol = solve(prob, Tsit5(); saveat = 0.1) u = sol.u time = sol.t + + # Note: this is signal-scaled Gaussian noise, so it is heteroscedastic and the L2 loss implicitly penalizes the high-variance points.
x̂ = u .+ (u .* 0.1) .* randn(size(u)) - dataset = [x̂, time] + dataset = [x̂, time, ones(length(time))] physsol1 = [linear_analytic(prob.u0, p, time[i]) for i in eachindex(time)] # separate set of points for testing the solve() call (it uses saveat 1/50 hence here length 501) @@ -149,10 +153,12 @@ end θinit, st = Lux.setup(Random.default_rng(), chainlux12) # this a forward solve - fh_mcmc_chainlux12, fhsampleslux12, fhstatslux12 = ahmc_bayesian_pinn_ode( + fh_mcmc_chainlux12, fhsampleslux12, + fhstatslux12 = ahmc_bayesian_pinn_ode( prob, chainlux12, draw_samples = 500, phystd = [0.01], priorsNNw = (0.0, 10.0)) - fh_mcmc_chainlux22, fhsampleslux22, fhstatslux22 = ahmc_bayesian_pinn_ode( + fh_mcmc_chainlux22, fhsampleslux22, + fhstatslux22 = ahmc_bayesian_pinn_ode( prob, chainlux12, dataset = dataset, draw_samples = 500, l2std = [0.02], phystd = [0.05], priorsNNw = (0.0, 10.0), param = [Normal(-7, 4)]) @@ -215,56 +221,78 @@ end time1 = vec(collect(Float64, ta0)) physsol0_1 = [linear_analytic(prob.u0, p, time1[i]) for i in eachindex(time1)] chainflux = Flux.Chain(Flux.Dense(1, 7, tanh), Flux.Dense(7, 1)) |> Flux.f64 - fh_mcmc_chain, fhsamples, fhstats = ahmc_bayesian_pinn_ode( + fh_mcmc_chain, fhsamples, + fhstats = ahmc_bayesian_pinn_ode( prob, chainflux, draw_samples = 2500) alg = BNNODE(chainflux, draw_samples = 2500) @test alg.chain isa AbstractLuxLayer end -@testitem "BPINN ODE III: with the new objective" tags=[:odebpinn] begin +@testitem "BPINN ODE III: Inverse solve Improvement" tags=[:odebpinn] begin using MCMCChains, Distributions, OrdinaryDiffEq, OptimizationOptimisers, Lux, AdvancedHMC, Statistics, Random, Functors, ComponentArrays, MonteCarloMeasurements import Flux - + using FastGaussQuadrature Random.seed!(100) - linear = (u, p, t) -> u / p + exp(t / p) * cos(t) + # (the original Improvement tests can be run with 100 training points; see the solve-call tests.) + # the new model is always better (especially with fewer points, more noise, etc.), given the correct std and enough samples. + # the std for the equation is limited by the variance propagated from the data points through the chosen equation (var/phystd). + # for inverse problems, the ratio of observed data points to collocation points is important. + + N = 20 # number of nodes; N-point Gauss-Lobatto integrates polynomials up to degree 2N - 3 exactly + # x, w = gausslegendre(N) # does not include endpoints + x, w = gausslobatto(N) + # x, w = clenshaw_curtis(N) + tspan = (0.0, 10.0) + a = tspan[1] + b = tspan[2] + # transform the roots and weights from [-1, 1] to [a, b] + # (inverse map: x = map((t) -> (2 * (t - a) / (b - a)) - 1, t)) + t = map((x) -> (x * (b - a) + (b + a)) / 2, x) + W = map((x) -> x * (b - a) / 2, w) + + linear = (u, p, t) -> u / p + exp(t / p) * cos(t) u0 = 0.0 p = -5.0 prob = ODEProblem(linear, u0, tspan, p) linear_analytic = (u0, p, t) -> exp(t / p) * (u0 + sin(t)) # SOLUTION AND CREATE DATASET - sol = solve(prob, Tsit5(); saveat = 0.1) - u = sol.u - time = sol.t - x̂ = u .+ (0.1 .* randn(size(u))) - dataset = [x̂, time] - physsol1 = [linear_analytic(prob.u0, p, time[i]) for i in eachindex(time)] + sol = solve(prob, Tsit5(); saveat = t) + u = sol.u # use these points for collocation + ts = sol.t + # the old model finds the less noisy signal easier to learn
(likely due to overfitting). + x̂ = u .+ (0.1 .* randn(size(u))) + dataset = [x̂, ts, W] + physsol1 = [linear_analytic(prob.u0, p, ts[i]) for i in eachindex(ts)] chainlux12 = Lux.Chain(Lux.Dense(1, 6, tanh), Lux.Dense(6, 6, tanh), Lux.Dense(6, 1)) θinit, st = Lux.setup(Random.default_rng(), chainlux12) - fh_mcmc_chainlux22, fhsampleslux22, fhstatslux22 = ahmc_bayesian_pinn_ode( + # one could always fit the model directly to all the data, but that ignores the equation and overfits. + fh_mcmc_chainlux22, fhsampleslux22, + fhstatslux22 = ahmc_bayesian_pinn_ode( prob, chainlux12, dataset = dataset, - draw_samples = 500, + draw_samples = 2500, l2std = [0.1], - phystd = [0.01], - phynewstd = [0.01], + phystd = [0.1], + phynewstd = (p) -> [0.1 / p], priorsNNw = (0.0, 1.0), param = [ Normal(-7, 3) ], estim_collocate = true) - fh_mcmc_chainlux12, fhsampleslux12, fhstatslux12 = ahmc_bayesian_pinn_ode( + fh_mcmc_chainlux12, fhsampleslux12, + fhstatslux12 = ahmc_bayesian_pinn_ode( prob, chainlux12, dataset = dataset, - draw_samples = 500, + draw_samples = 2500, l2std = [0.1], - phystd = [0.01], + phystd = [0.1], priorsNNw = (0.0, 1.0), param = [ @@ -276,32 +304,31 @@ end #------------------------------ ahmc_bayesian_pinn_ode() call # Mean of last 100 sampled parameter's curves(lux chains)[Ensemble predictions] θ = [vector_to_parameters(fhsampleslux12[i][1:(end - 1)], θinit) - for i in 400:length(fhsampleslux12)] + for i in 2400:length(fhsampleslux12)] luxar = [chainlux12(t', θ[i], st)[1] for i in eachindex(θ)] luxmean = [mean(vcat(luxar...)[:, i]) for i in eachindex(t)] meanscurve2_1 = prob.u0 .+ (t .- prob.tspan[1]) .* luxmean θ = [vector_to_parameters(fhsampleslux22[i][1:(end - 1)], θinit) - for i in 400:length(fhsampleslux22)] + for i in 2400:length(fhsampleslux22)] luxar = [chainlux12(t', θ[i], st)[1] for i in eachindex(θ)] luxmean = [mean(vcat(luxar...)[:, i]) for i in eachindex(t)] meanscurve2_2 = prob.u0 .+ (t .- prob.tspan[1]) .* luxmean - @test mean(abs.(sol.u .- meanscurve2_2)) < 1e-2 - @test mean(abs.(physsol1 .- meanscurve2_2)) < 1e-2 + @test mean(abs.(sol.u .- meanscurve2_2)) < 5e-2 + @test mean(abs.(physsol1 .- meanscurve2_2)) < 5e-2 @test mean(abs.(sol.u .- meanscurve2_1)) > mean(abs.(sol.u .- meanscurve2_2)) @test mean(abs.(physsol1 .- meanscurve2_1)) > mean(abs.(physsol1 .- meanscurve2_2)) - # estimated parameters(lux chain) - param2 = mean(i[62] for i in fhsampleslux22[400:length(fhsampleslux22)]) - @test abs(param2 - p) < abs(0.05 * p) + param2 = mean(i[62] for i in fhsampleslux22[2400:length(fhsampleslux22)]) + @test abs(param2 - p) < abs(0.2 * p) - param1 = mean(i[62] for i in fhsampleslux12[400:length(fhsampleslux12)]) + param1 = mean(i[62] for i in fhsampleslux12[2400:length(fhsampleslux12)]) @test abs(param1 - p) > abs(0.5 * p) @test abs(param2 - p) < abs(param1 - p) end -@testitem "BPINN ODE III: new objective solve call" tags=[:odebpinn] begin +@testitem "BPINN ODE III: Inverse solve Improvement solve call" tags=[:odebpinn] begin using MCMCChains, Distributions, OrdinaryDiffEq, OptimizationOptimisers, Lux, AdvancedHMC, Statistics, Random, Functors, ComponentArrays, MonteCarloMeasurements import Flux @@ -320,7 +347,8 @@ end u = sol.u time = sol.t x̂ = u .+ (0.1 .* randn(size(u))) - dataset = [x̂, time] + # GridTraining-style weights (dx = 0.1) for the new loss + dataset = [x̂, time, ones(length(time))] # set of points for testing the solve() call (it uses saveat 1/50 hence here length 501) time1 = vec(collect(Float64, range(tspan[1], tspan[2], length = 501))) @@ -334,7 +362,7 @@ end draw_samples = 1000, l2std =
[0.1], phystd = [0.01], - phynewstd = [0.01], + phynewstd = (p) -> [0.01], priorsNNw = (0.0, 1.0), param = [ @@ -352,12 +380,13 @@ end @test abs(param3 - p) < abs(0.05 * p) end -@testitem "BPINN ODE IV: Improvement" tags=[:odebpinn] begin +@testitem "BPINN ODE IV: Inverse solve Improvement" tags=[:odebpinn] begin using MCMCChains, Distributions, OrdinaryDiffEq, OptimizationOptimisers, Lux, AdvancedHMC, Statistics, Random, Functors, ComponentArrays, MonteCarloMeasurements import Flux - + using FastGaussQuadrature Random.seed!(100) + using NeuralPDE, Test function lotka_volterra(u, p, t) # Model parameters. @@ -375,18 +404,25 @@ end # initial-value problem. u0 = [1.0, 1.0] p = [1.5, 3.0] - tspan = (0.0, 7.0) + tspan = (0.0, 4.0) prob = ODEProblem(lotka_volterra, u0, tspan, p) - # OrdinaryDiffEq.jl solve - dt = 0.1 - solution = solve(prob, Tsit5(); saveat = dt) - + N = 20 + # x, w = gausslegendre(N) # does not include endpoints + x, w = gausslobatto(N) + # x, w = clenshaw_curtis(N) + a = tspan[1] + b = tspan[2] + # transform the roots and weights from [-1, 1] to [a, b] + # (inverse map: x = map((t) -> (2 * (t - a) / (b - a)) - 1, t)) + t = map((x) -> (x * (b - a) + (b + a)) / 2, x) + W = map((x) -> x * (b - a) / 2, w) + solution = solve(prob, Tsit5(); saveat = t) times = solution.t u = hcat(solution.u...) x = u[1, :] + (0.5 .* randn(length(u[1, :]))) y = u[2, :] + (0.5 .* randn(length(u[2, :]))) - dataset = [x, y, times] + dataset = [x, y, times, W] chain = Lux.Chain(Lux.Dense(1, 7, tanh), Lux.Dense(7, 7, tanh), Lux.Dense(7, 2)) @@ -398,32 +434,40 @@ end phystd = [0.5, 0.5], priorsNNw = (0.0, 1.0), param = [ - Normal(2, 2), - Normal(2, 2)]) + Normal(-7, 2), + Normal(-7, 2)]) alg2 = BNNODE(chain; dataset = dataset, draw_samples = 1000, l2std = [0.5, 0.5], phystd = [0.5, 0.5], - phynewstd = [1.0, 1.0], + phynewstd = (p) -> [0.5, 0.5], priorsNNw = (0.0, 1.0), param = [ - Normal(2, 2), - Normal(2, 2)], estim_collocate = true) + Normal(-7, 2), + Normal(-7, 2)], estim_collocate = true) + dt = 0.05 @time sol_pestim1 = solve(prob, alg1; saveat = dt) @time sol_pestim2 = solve(prob, alg2; saveat = dt) + # reference OrdinaryDiffEq.jl solve at the solution time points. + solution = solve(prob, Tsit5(); saveat = dt) + u = hcat(solution.u...)
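+ # MonteCarloMeasurements' `unsafe_comparisons(true)` below allows the ordering comparisons of `Particles` used in the @test lines.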
unsafe_comparisons(true) bitvec = abs.(p .- sol_pestim1.estimated_de_params) .> abs.(p .- sol_pestim2.estimated_de_params) @test bitvec == ones(size(bitvec)) - Loss_1 = mean(abs, u[1, :] .- pmean(sol_pestim1.ensemblesol[1])) + - mean(abs, u[2, :] .- pmean(sol_pestim1.ensemblesol[2])) - Loss_2 = mean(abs, u[1, :] .- pmean(sol_pestim2.ensemblesol[1])) + - mean(abs, u[2, :] .- pmean(sol_pestim2.ensemblesol[2])) + @test mean(abs, u[1, :] .- pmean(sol_pestim1.ensemblesol[1])) > + mean(abs, u[1, :] .- pmean(sol_pestim2.ensemblesol[1])) + @test mean(abs, u[2, :] .- pmean(sol_pestim1.ensemblesol[2])) > + mean(abs, u[2, :] .- pmean(sol_pestim2.ensemblesol[2])) + + @test mean(abs2, u[1, :] .- pmean(sol_pestim2.ensemblesol[1])) < 5e-2 + @test mean(abs2, u[2, :] .- pmean(sol_pestim2.ensemblesol[2])) < 2e-2 - @test Loss_1 > Loss_2 + @test abs(sol_pestim2.estimated_de_params[1] - p[1]) < 0.05p[1] + @test abs(sol_pestim2.estimated_de_params[2] - p[2]) < 0.1p[2] end diff --git a/test/NNODE_tests.jl b/test/NNODE_tests.jl index 8dea6fde0..4b63a8992 100644 --- a/test/NNODE_tests.jl +++ b/test/NNODE_tests.jl @@ -171,7 +171,6 @@ end @testitem "ODE Parameter Estimation" tags=[:nnode] begin using OrdinaryDiffEq, Random, Lux, OptimizationOptimJL, LineSearches - Random.seed!(100) function lorenz(u, p, t) return [p[1] * (u[2] - u[1]), u[1] * (p[2] - u[3]) - u[2], u[1] * u[2] - p[3] * u[3]] end - prob = ODEProblem(lorenz, [1.0, 0.0, 0.0], (0.0, 1.0), [1.0, 1.0, 1.0]) + tspan = (0.0, 1.0) + prob = ODEProblem(lorenz, [1.0, 0.0, 0.0], tspan, [1.0, 1.0, 1.0]) true_p = [2.0, 3.0, 2.0] prob2 = remake(prob, p = true_p) + n = 8 + luxchain = Chain(Dense(1, n, σ), Dense(n, n, σ), Dense(n, 3)) sol = solve(prob2, Tsit5(); saveat = 0.01) t_ = sol.t - u_ = reduce(hcat, sol.u) - function additional_loss(phi, θ) - return sum(abs2, phi(t_, θ) .- u_) / 100 + u_ = sol.u + sol_points = hcat(u_...) + u1_ = [u_[i][1] for i in eachindex(t_)] + u2_ = [u_[i][2] for i in eachindex(t_)] + u3_ = [u_[i][3] for i in eachindex(t_)] + dataset = [u1_, u2_, u3_, t_, ones(length(t_))] + + alg = NNODE(luxchain, BFGS(linesearch = BackTracking()); + strategy = GridTraining(0.01), dataset = dataset, + param_estim = true) + sol = solve(prob, alg; verbose = false, abstol = 1e-8, maxiters = 1000, saveat = t_) + + @test sol.k.u.p≈true_p atol=1e-2 + @test reduce(hcat, sol.u)≈sol_points atol=1e-2 +end + +@testitem "ODE Parameter Estimation Improvement" tags=[:nnode] begin + using OrdinaryDiffEq, Random, Lux, OptimizationOptimJL, LineSearches + using FastGaussQuadrature + Random.seed!(100) + + function lorenz(u, p, t) + return [p[1] * (u[2] - u[1]), + u[1] * (p[2] - u[3]) - u[2], + u[1] * u[2] - p[3] * u[3]] end + tspan = (0.0, 5.0) + prob = ODEProblem(lorenz, [1.0, 0.0, 0.0], tspan, [-10.0, -10.0, -10.0]) + true_p = [2.0, 3.0, 2.0] + prob2 = remake(prob, p = true_p) n = 8 luxchain = Chain(Dense(1, n, σ), Dense(n, n, σ), Dense(n, 3)) - alg = NNODE(luxchain, BFGS(linesearch = BackTracking()); strategy = GridTraining(0.01), - param_estim = true, additional_loss) + # this example is especially easy for the new loss: + # even with ~2 observed data points, the physics parameters can be recovered exactly (even before the solve call).
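+ # Gauss-Lobatto quadrature sketch: nodes x ∈ [-1, 1] are mapped to t ∈ [a, b] via t = ((b - a) * x + (b + a)) / 2,
+ # and the weights are scaled by the Jacobian (b - a) / 2, so that sum(W .* f.(t)) ≈ ∫ f(t) dt over [a, b].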
+ N = 7 + # x, w = gausslegendre(N) # does not include endpoints + x, w = gausslobatto(N) + # x, w = clenshaw_curtis(N) + a = tspan[1] + b = tspan[2] + + # transform the roots and weights from [-1, 1] to [a, b] + # (inverse map: x = map((t) -> (2 * (t - a) / (b - a)) - 1, t)) + t = map((x) -> (x * (b - a) + (b + a)) / 2, x) + W = map((x) -> x * (b - a) / 2, w) + sol = solve(prob2, Tsit5(); saveat = t) + t_ = sol.t + u_ = sol.u + u1_ = [u_[i][1] for i in eachindex(t_)] + u2_ = [u_[i][2] for i in eachindex(t_)] + u3_ = [u_[i][3] for i in eachindex(t_)] + dataset = [u1_, u2_, u3_, t_, W] + + alg_old = NNODE(luxchain, BFGS(linesearch = BackTracking()); + strategy = GridTraining(0.01), dataset = dataset, + param_estim = true) + sol_old = solve( + prob, alg_old; verbose = false, abstol = 1e-12, maxiters = 2000, saveat = 0.01) + + alg_new = NNODE( + luxchain, BFGS(linesearch = BackTracking()); strategy = GridTraining(0.01), + param_estim = true, dataset = dataset, estim_collocate = true) + sol_new = solve( + prob, alg_new; verbose = false, abstol = 1e-12, maxiters = 2000, saveat = 0.01) - sol = solve(prob, alg; verbose = false, abstol = 1e-8, maxiters = 1000, saveat = t_) - @test sol.k.u.p≈true_p atol=1e-2 - @test reduce(hcat, sol.u)≈u_ atol=1e-2 + sol = solve(prob2, Tsit5(); saveat = 0.01) + sol_points = hcat(sol.u...) + sol_old_points = hcat(sol_old.u...) + sol_new_points = hcat(sol_new.u...) + + @test !isapprox(sol_old.k.u.p, true_p; atol = 10) + @test !isapprox(sol_old_points, sol_points; atol = 10) + + @test sol_new.k.u.p≈true_p atol=1e-2 + @test sol_new_points≈sol_points atol=3e-2 end @testitem "ODE Complex Numbers" tags=[:nnode] begin diff --git a/test/direct_function_tests.jl b/test/direct_function_tests.jl index 8aea367b8..723276bc5 100644 --- a/test/direct_function_tests.jl +++ b/test/direct_function_tests.jl @@ -152,4 +152,4 @@ end prob = discretize(pde_system, discretization) @test_throws MethodError solve(prob, Adam(0.05), maxiters = 10) end -end \ No newline at end of file +end
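For reference, a minimal end-to-end sketch of the `dataset`/`phynewstd`/`estim_collocate` API introduced by this patch (the toy problem mirrors the BPINN ODE III test; the network size, priors, and stds below are illustrative choices, not prescribed values):

using NeuralPDE, Lux, OrdinaryDiffEq, Distributions, FastGaussQuadrature, Random

Random.seed!(1)
linear = (u, p, t) -> u / p + exp(t / p) * cos(t)
tspan = (0.0, 10.0)
prob = ODEProblem(linear, 0.0, tspan, -5.0)

# Gauss-Lobatto nodes and weights on [-1, 1], mapped onto [a, b] = tspan
N = 20
x, w = gausslobatto(N)
a, b = tspan
t = map(xᵢ -> (xᵢ * (b - a) + (b + a)) / 2, x)
W = map(wᵢ -> wᵢ * (b - a) / 2, w)

# synthetic noisy observations at the quadrature nodes
sol = solve(prob, Tsit5(); saveat = t)
x̂ = sol.u .+ 0.1 .* randn(length(sol.u))
dataset = [x̂, sol.t, W] # [observations..., time points, quadrature weights]

chain = Lux.Chain(Lux.Dense(1, 6, tanh), Lux.Dense(6, 6, tanh), Lux.Dense(6, 1))
alg = BNNODE(chain; dataset = dataset, draw_samples = 2500,
    l2std = [0.1], phystd = [0.1], phynewstd = ode_params -> [0.05],
    priorsNNw = (0.0, 1.0), param = [Normal(-7, 3)], estim_collocate = true)
sol_est = solve(prob, alg; saveat = 0.05)
sol_est.estimated_de_params # posterior estimate of the ODE parameter p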