
New loss #937


Open · wants to merge 32 commits into base: master
Changes from all commits · 32 commits
29e123e
Too Easy.
AstitvaAggarwal Apr 13, 2025
6bf8c49
remove Integrals from test deps
AstitvaAggarwal Apr 13, 2025
45c1e39
tests. (new is better in more noise)
AstitvaAggarwal Apr 13, 2025
83a642d
minor change
AstitvaAggarwal Apr 13, 2025
5cfc6e7
.
AstitvaAggarwal Apr 13, 2025
d8f602b
.
AstitvaAggarwal Apr 13, 2025
30fa615
Likelihood probabilities are not driven to 0.
AstitvaAggarwal Apr 14, 2025
85b350c
.
AstitvaAggarwal Apr 14, 2025
3c59fde
more samples
AstitvaAggarwal Apr 14, 2025
f2aafdb
.
AstitvaAggarwal Apr 14, 2025
70e956d
fixed tests
AstitvaAggarwal Apr 14, 2025
41691eb
tests.
AstitvaAggarwal Apr 16, 2025
f21969c
.
AstitvaAggarwal Apr 16, 2025
a968cc8
std for new loss is parametric
AstitvaAggarwal Apr 26, 2025
c66772a
Changes to API
AstitvaAggarwal Apr 26, 2025
2fcf75e
tests.
AstitvaAggarwal Apr 27, 2025
cb85411
tests-2
AstitvaAggarwal Apr 27, 2025
2423224
tests-3
AstitvaAggarwal Apr 27, 2025
cc1348f
Update BPINN_tests.jl
AstitvaAggarwal Apr 27, 2025
4a1ca0e
BPINN_PDE loss corrected
AstitvaAggarwal Apr 27, 2025
107165e
NNODE improvements & L2Data!=additional_loss
AstitvaAggarwal May 3, 2025
a4d1fb7
spelling check
AstitvaAggarwal May 3, 2025
d47c19c
tests
AstitvaAggarwal May 3, 2025
0bc0ec1
tests-1
AstitvaAggarwal May 3, 2025
65c4b08
tests-3
AstitvaAggarwal May 5, 2025
8ed4e18
format
AstitvaAggarwal May 6, 2025
c242419
Update src/BPINN_ode.jl
AstitvaAggarwal May 6, 2025
95140dd
cubature over L2 instead of L1
AstitvaAggarwal May 6, 2025
6cb24c5
Merge branch 'sdepinn' of https://github.com/AstitvaAggarwal/NeuralPD…
AstitvaAggarwal May 6, 2025
f24df29
bpinn remains in non squared within logpdf?
AstitvaAggarwal May 6, 2025
e8dfd9a
changes from reviews.
AstitvaAggarwal May 16, 2025
0b7123a
docstrings, support preexisting tutorials
AstitvaAggarwal May 16, 2025
4 changes: 3 additions & 1 deletion Project.toml
@@ -67,6 +67,7 @@ Distributions = "0.25.107"
DocStringExtensions = "0.9.3"
DomainSets = "0.7"
ExplicitImports = "1.10.1"
FastGaussQuadrature = "1.0.2"
Flux = "0.14.22"
ForwardDiff = "0.10.36"
Functors = "0.4.12, 0.5"
@@ -116,6 +117,7 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
DiffEqNoiseProcess = "77a26b50-5914-5dd7-bc55-306e6241c503"
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Hwloc = "0e44f5e4-bd66-52a0-8798-143a42290a1d"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
@@ -132,4 +134,4 @@ TensorBoardLogger = "899adc3e-224a-11e9-021f-63837185c80f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Aqua", "CUDA", "DiffEqNoiseProcess", "ExplicitImports", "Flux", "Hwloc", "InteractiveUtils", "LineSearches", "LuxCUDA", "LuxCore", "LuxLib", "MethodOfLines", "OptimizationOptimJL", "OrdinaryDiffEq", "ReTestItems", "StochasticDiffEq", "TensorBoardLogger", "Test"]
test = ["Aqua", "CUDA", "DiffEqNoiseProcess", "ExplicitImports", "FastGaussQuadrature", "Flux", "Hwloc", "InteractiveUtils", "LineSearches", "LuxCUDA", "LuxCore", "LuxLib", "MethodOfLines", "OptimizationOptimJL", "OrdinaryDiffEq", "ReTestItems", "StochasticDiffEq", "TensorBoardLogger", "Test"]
6 changes: 3 additions & 3 deletions src/BPINN_ode.jl
@@ -3,7 +3,7 @@
"""
BNNODE(chain, kernel = HMC; strategy = nothing, draw_samples = 2000,
priorsNNw = (0.0, 2.0), param = [nothing], l2std = [0.05],
phystd = [0.05], phynewstd = [0.05], dataset = [nothing], physdt = 1 / 20.0,
phystd = [0.05], phynewstd = (ode_params)->[0.05], dataset = [nothing], physdt = 1 / 20.0,
MCMCargs = (; n_leapfrog=30), nchains = 1, init_params = nothing,
Adaptorkwargs = (; Adaptor = StanHMCAdaptor, targetacceptancerate = 0.8,
Metric = DiagEuclideanMetric),
@@ -86,7 +86,7 @@ Kevin Linka, Amelie Schäfer, Xuhui Meng, Zongren Zou, George Em Karniadakis, El
param <: Union{Nothing, Vector{<:Distribution}}
l2std::Vector{Float64}
phystd::Vector{Float64}
phynewstd::Vector{Float64}
phynewstd
dataset <: Union{Vector{Nothing}, Vector{<:Vector{<:AbstractFloat}}}
physdt::Float64
MCMCkwargs <: NamedTuple
@@ -103,7 +103,7 @@ end

function BNNODE(chain, kernel = HMC; strategy = nothing, draw_samples = 1000,
priorsNNw = (0.0, 2.0), param = nothing, l2std = [0.05], phystd = [0.05],
phynewstd = [0.05], dataset = [nothing], physdt = 1 / 20.0,
phynewstd = (ode_params) -> [0.05], dataset = [nothing], physdt = 1 / 20.0,
MCMCkwargs = (n_leapfrog = 30,), nchains = 1, init_params = nothing,
Adaptorkwargs = (Adaptor = StanHMCAdaptor,
Metric = DiagEuclideanMetric, targetacceptancerate = 0.8),
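
With this change `phynewstd` is a function of the ODE parameters rather than a fixed vector, so the std of the new loss can track the current parameter draw. A hedged sketch (the scaling rule below is purely illustrative, not part of this PR):

phynewstd_parametric(ode_params) = [0.05 * (1 + abs(first(ode_params)))]
# used as, e.g.: BNNODE(chain; phynewstd = phynewstd_parametric)  # other kwargs omitted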
2 changes: 1 addition & 1 deletion src/NeuralPDE.jl
@@ -49,7 +49,7 @@ using AdvancedHMC: AdvancedHMC, DiagEuclideanMetric, HMC, HMCDA, Hamiltonian,
using Distributions: Distributions, Distribution, MvNormal, Normal, dim, logpdf
using LogDensityProblems: LogDensityProblems
using MCMCChains: MCMCChains, Chains, sample
using MonteCarloMeasurements: Particles, pmean
using MonteCarloMeasurements: Particles

import LuxCore: initialparameters, initialstates, parameterlength

103 changes: 59 additions & 44 deletions src/advancedHMC_MCMC.jl
@@ -3,10 +3,10 @@
prob <: SciMLBase.ODEProblem
smodel <: StatefulLuxLayer
strategy <: AbstractTrainingStrategy
dataset <: Union{Vector{Nothing}, Vector{<:Vector{<:AbstractFloat}}}
dataset <: Union{Vector, Vector{<:Vector{<:AbstractFloat}}}
priors <: Vector{<:Distribution}
phystd::Vector{Float64}
phynewstd::Vector{Float64}
phynewstd::Function
Member: Specialize?

Contributor Author: I'm not sure how we can specialize functions...
(I'm keeping a function for the std in BPINNs, as selecting the right std can be tricky and usually depends on the problem.)
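
On the "Specialize?" point: in Julia, a function-valued field can be specialized by parameterizing the struct on the function's type, so each concrete std function gets its own compiled code. A minimal sketch with a hypothetical struct name:

struct SpecializedLTD{F <: Function}   # hypothetical, not the PR's struct
    phynewstd::F
end

ltd = SpecializedLTD(p -> [0.05])
ltd.phynewstd([1.0, 2.0])              # type-stable call; returns [0.05]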

l2std::Vector{Float64}
autodiff::Bool
physdt::Float64
@@ -74,32 +74,37 @@ suggested extra loss function for ODE solver case
"""
@views function L2loss2(ltd::LogTargetDensity, θ)
ltd.extraparams ≤ 0 && return false # XXX: type-stability?

u0 = ltd.prob.u0
f = ltd.prob.f
t = ltd.dataset[end]
u1 = ltd.dataset[2]
û = ltd.dataset[1]
t = ltd.dataset[end - 1]
û = ltd.dataset[1:(end - 2)]
quadrature_weights = ltd.dataset[end]

nnsol = ode_dfdx(ltd, t, θ[1:(length(θ) - ltd.extraparams)], ltd.autodiff)

ode_params = ltd.extraparams == 1 ? θ[((length(θ) - ltd.extraparams) + 1)] :
θ[((length(θ) - ltd.extraparams) + 1):length(θ)]
phynewstd = ltd.phynewstd(ode_params)

physsol = if length(ltd.prob.u0) == 1
[f(û[i], ode_params, tᵢ) for (i, tᵢ) in enumerate(t)]
physsol = if length(u0) == 1
[f(û[1][i], ode_params, tᵢ) for (i, tᵢ) in enumerate(t)]
else
[f([û[i], u1[i]], ode_params, tᵢ) for (i, tᵢ) in enumerate(t)]
[f([û[j][i] for j in eachindex(u0)], ode_params, tᵢ)
for (i, tᵢ) in enumerate(t)]
end
# form of NN output: matrix of size output dim x n
deri_physsol = reduce(hcat, physsol)
T = promote_type(eltype(deri_physsol), eltype(nnsol))

physlogprob = T(0)
for i in 1:length(ltd.prob.u0)
# for BPINNs, quadrature is NOT applied on the timewise logpdfs, so the likelihood isn't driven to zero.
# for GridTraining/the trapezoidal rule, quadrature_weights is dt .* ones(T, length(t)).
# phynewstd has the same dims as u0, since BNNODE is an out-of-place ODE solver.
for i in eachindex(u0)
physlogprob += logpdf(
MvNormal(deri_physsol[i, :],
Diagonal(abs2.(T(ltd.phynewstd[i]) .* ones(T, length(nnsol[i, :]))))),
nnsol[i, :]
MvNormal((nnsol[i, :] .- deri_physsol[i, :]) .* quadrature_weights,
Diagonal(abs2.(T(phynewstd[i]) .* ones(T, length(t))))),
zeros(length(t))
)
end
return physlogprob
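
As a self-contained numeric sketch of the likelihood above (all values illustrative): the residual is scaled by the quadrature weights and scored against a zero-mean Gaussian, so the weights set how strongly each time point contributes without the likelihood itself being driven to zero.

using Distributions, LinearAlgebra

t = collect(0.0:0.05:1.0)              # hypothetical time grid
residual = 0.01 .* randn(length(t))    # stands in for nnsol[i, :] .- deri_physsol[i, :]
W = 0.05 .* ones(length(t))            # GridTraining-style weights (dt)
σ = 0.05                               # stands in for phynewstd(ode_params)[i]
lp = logpdf(MvNormal(residual .* W, Diagonal(abs2.(σ .* ones(length(t))))), zeros(length(t)))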
@@ -109,10 +114,10 @@ end
L2 loss log-likelihood (needed for ODE parameter estimation).
"""
@views function L2LossData(ltd::LogTargetDensity, θ)
(ltd.dataset isa Vector{Nothing} || ltd.extraparams == 0) && return 0
(isempty(ltd.dataset) || ltd.extraparams == 0) && return 0

# matrix (each row corresponds to one of vector u's rows)
nn = ltd(ltd.dataset[end], θ[1:(length(θ) - ltd.extraparams)])
nn = ltd(ltd.dataset[end - 1], θ[1:(length(θ) - ltd.extraparams)])
T = eltype(nn)

L2logprob = zero(T)
@@ -150,24 +155,26 @@ end
function getlogpdf(strategy::GridTraining, ltd::LogTargetDensity, f, autodiff::Bool,
tspan, ode_params, θ)
ts = collect(eltype(strategy.dx), tspan[1]:(strategy.dx):tspan[2])
t = ltd.dataset isa Vector{Nothing} ? ts : vcat(ts, ltd.dataset[end])
t = isempty(ltd.dataset) ? ts : vcat(ts, ltd.dataset[end - 1])
return sum(innerdiff(ltd, f, autodiff, t, θ, ode_params))
end

function getlogpdf(strategy::StochasticTraining, ltd::LogTargetDensity,
f, autodiff::Bool, tspan, ode_params, θ)
T = promote_type(eltype(tspan[1]), eltype(tspan[2]))
samples = (tspan[2] - tspan[1]) .* rand(T, strategy.points) .+ tspan[1]
t = ltd.dataset isa Vector{Nothing} ? samples : vcat(samples, ltd.dataset[end])
t = isempty(ltd.dataset) ? samples : vcat(samples, ltd.dataset[end - 1])
return sum(innerdiff(ltd, f, autodiff, t, θ, ode_params))
end

function getlogpdf(strategy::QuadratureTraining, ltd::LogTargetDensity, f, autodiff::Bool,
tspan, ode_params, θ)
# integrand has the shape of the NN output
integrand(t::Number, θ) = innerdiff(ltd, f, autodiff, [t], θ, ode_params)
intprob = IntegralProblem(
integrand, (tspan[1], tspan[2]), θ; nout = length(ltd.prob.u0))
sol = solve(intprob, QuadGKJL(); strategy.abstol, strategy.reltol)
# sum over losses for all NN outputs
return sum(sol.u)
end
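
The same pattern as a standalone sketch, assuming the Integrals.jl interface used above; the integrand here is a scalar stand-in for `innerdiff` (the PR's version is vector-valued, with `nout` set to the number of NN outputs):

using Integrals

loss_density(t, p) = abs2(sin(p[1] * t))            # stand-in loss density
intprob = IntegralProblem(loss_density, (0.0, 1.0), [2.0])
sol = solve(intprob, QuadGKJL(); abstol = 1e-6, reltol = 1e-6)
sol.u                                               # accumulated loss over the span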

@@ -185,7 +192,7 @@ function getlogpdf(strategy::WeightedIntervalTraining, ltd::LogTargetDensity, f,
append!(ts, temp_data)
end

t = ltd.dataset isa Vector{Nothing} ? ts : vcat(ts, ltd.dataset[end])
t = isempty(ltd.dataset) ? ts : vcat(ts, ltd.dataset[end - 1])
return sum(innerdiff(ltd, f, autodiff, t, θ, ode_params))
end

@@ -202,23 +209,21 @@ MvNormal likelihood at each `ti` in time `t` for ODE collocation residue with NN

# this is a Vector{Vector{dx, dy}} (handles the case where a single u (Float) is passed)
if length(out[:, 1]) == 1
physsol = [f(out[:, i][1], ode_params, t[i]) for i in 1:length(out[1, :])]
physsol = [f(out[:, i][1], ode_params, t[i]) for i in eachindex(t)]
else
physsol = [f(out[:, i], ode_params, t[i]) for i in 1:length(out[1, :])]
physsol = [f(out[:, i], ode_params, t[i]) for i in eachindex(t)]
end
physsol = reduce(hcat, physsol)

nnsol = ode_dfdx(ltd, t, θ[1:(length(θ) - ltd.extraparams)], autodiff)

vals = nnsol .- physsol
T = eltype(vals)
T = eltype(nnsol)

# N-dimensional vector if the NN has N outputs (each row has the logpdf of u[i], where u is the vector
# of dependent variables)
return [logpdf(
MvNormal(vals[i, :],
Diagonal(abs2.(T(ltd.phystd[i]) .* ones(T, length(vals[i, :]))))),
zeros(T, length(vals[i, :]))
MvNormal((nnsol[i, :] .- physsol[i, :]),
Diagonal(abs2.(T(ltd.phystd[i]) .* ones(T, length(t))))),
zeros(T, length(t))
) for i in 1:length(ltd.prob.u0)]
end

@@ -262,9 +267,9 @@ function kernelchoice(Kernel, MCMCkwargs)
end

"""
ahmc_bayesian_pinn_ode(prob, chain; strategy = GridTraining, dataset = [nothing],
ahmc_bayesian_pinn_ode(prob, chain; strategy = GridTraining, dataset = [],
init_params = nothing, draw_samples = 1000, physdt = 1 / 20.0f0,
l2std = [0.05], phystd = [0.05], phynewstd = [0.05], priorsNNw = (0.0, 2.0),
l2std = [0.05], phystd = [0.05], phynewstd = (ode_params)->[0.05], priorsNNw = (0.0, 2.0),
param = [], nchains = 1, autodiff = false, Kernel = HMC,
Adaptorkwargs = (Adaptor = StanHMCAdaptor,
Metric = DiagEuclideanMetric, targetacceptancerate = 0.8),
Expand Down Expand Up @@ -294,7 +299,7 @@ time = sol.t[1:100]

### dataset and BPINN create
x̂ = collect(Float64, Array(u) + 0.05 * randn(size(u)))
dataset = [x̂, time]
dataset = [x̂, time, 0.05 .* ones(length(time))]

chain1 = Lux.Chain(Lux.Dense(1, 5, tanh), Lux.Dense(5, 5, tanh), Lux.Dense(5, 1))

@@ -318,26 +323,32 @@ fh_mcmc_chain2, fhsamples2, fhstats2 = ahmc_bayesian_pinn_ode(prob, chain1,

## NOTES

Dataset is required for accurate Parameter estimation + solving equations
Incase you are only solving the Equations for solution, do not provide dataset
Dataset is required for accurate parameter estimation in inverse problems.
In case you are only solving non-parametric ODE equations for a solution, do not provide a dataset.

## Positional Arguments

* `prob`: DEProblem(out of place and the function signature should be f(u,p,t).
* `prob`: ODEProblem (out-of-place; the function signature should be f(u,p,t)).
* `chain`: Lux Neural Network which would be made the Bayesian PINN.

## Keyword Arguments

* `strategy`: The training strategy used to choose the points for the evaluations. By
default GridTraining is used with given physdt discretization.
* `dataset`: either an empty Vector or a nested Vector of the form `[x̂, t, W]` where `x̂` are dependent-variable observations, `t` are time points and `W` are quadrature weights for the domain.
The dataset is used to compute the L2 loss against the data and also for the new loss function.
For multiple dependent variables there will be multiple vectors, with the last two vectors in the dataset still being `t`, `W`.
Empty by default, assuming a forward problem is being solved.
* `init_params`: initial parameter values for the BPINN (ideally, different initializations are
preferred for multiple chains)
* `nchains`: number of chains you want to sample
* `draw_samples`: number of samples to be drawn in the MCMC algorithms (warmup samples are
~2/3 of draw samples)
* `l2std`: standard deviation of the BPINN prediction against L2 losses/dataset
* `phystd`: standard deviation of the BPINN prediction against the chosen underlying ODE system
* `phynewstd`: standard deviation of new loss func term
* `phynewstd`: a function that gives the standard deviation of the new loss-function term at each iteration.
It takes the ODE parameters as input and returns a vector of standard deviations.
Defaults to `(ode_params) -> [0.05]`.
* `priorsNNw`: Tuple of (mean, std) for BPINN Network parameters. Weights and Biases of
BPINN are Normal Distributions by default.
* `param`: Vector of chosen ODE parameters Distributions in case of Inverse problems.
@@ -357,6 +368,7 @@ Incase you are only solving the Equations for solution, do not provide dataset
* `max_depth`: Maximum doubling tree depth (NUTS)
* `Δ_max`: Maximum divergence during doubling tree (NUTS)
Refer: https://turinglang.org/AdvancedHMC.jl/stable/
* `estim_collocate`: A boolean value to indicate whether to use the new loss function or not. This is only relevant for ODE parameter estimation.
* `progress`: controls whether to show the progress meter or not.
* `verbose`: controls the verbosity. (Sample call args in AHMC)

@@ -366,9 +378,10 @@ Incase you are only solving the Equations for solution, do not provide dataset
releases.
"""
function ahmc_bayesian_pinn_ode(
prob::SciMLBase.ODEProblem, chain; strategy = GridTraining, dataset = [nothing],
prob::SciMLBase.ODEProblem, chain; strategy = GridTraining, dataset = [],
init_params = nothing, draw_samples = 1000, physdt = 1 / 20.0, l2std = [0.05],
phystd = [0.05], phynewstd = [0.05], priorsNNw = (0.0, 2.0), param = [], nchains = 1,
phystd = [0.05], phynewstd = (ode_params) -> [0.05],
priorsNNw = (0.0, 2.0), param = [], nchains = 1,
autodiff = false, Kernel = HMC,
Adaptorkwargs = (Adaptor = StanHMCAdaptor,
Metric = DiagEuclideanMetric, targetacceptancerate = 0.8),
Expand All @@ -380,15 +393,15 @@ function ahmc_bayesian_pinn_ode(

strategy = strategy == GridTraining ? strategy(physdt) : strategy

if dataset != [nothing] &&
(length(dataset) < 2 || !(dataset isa Vector{<:Vector{<:AbstractFloat}}))
error("Invalid dataset. dataset would be timeseries (x̂,t) where type: Vector{Vector{AbstractFloat}")
if !isempty(dataset) &&
(length(dataset) < 3 || !(dataset isa Vector{<:Vector{<:AbstractFloat}}))
error("Invalid dataset. The dataset would be a timeseries (x̂,t,W) with type: Vector{Vector{AbstractFloat}}")
end

if dataset != [nothing] && param == []
println("Dataset is only needed for Parameter Estimation + Forward Problem, not in only Forward Problem case.")
elseif dataset == [nothing] && param != []
error("Dataset Required for Parameter Estimation.")
if !isempty(dataset) && isempty(param)
println("Dataset is only needed for Inverse problems performing Parameter Estimation, not in only Forward Problem case.")
elseif isempty(dataset) && !isempty(param)
error("Dataset Required for Inverse problems performing Parameter Estimation.")
end
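
A sketch of a dataset that passes the checks above for one dependent variable (illustrative values; trapezoid-style weights as noted in `L2loss2`):

t = collect(0.0:0.05:1.0)
x̂ = sin.(t) .+ 0.05 .* randn(length(t))    # noisy observations of the dependent variable
W = (t[2] - t[1]) .* ones(length(t))        # grid/trapezoidal quadrature weights
dataset = [x̂, t, W]                         # length ≥ 3, Vector{Vector{Float64}}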

initial_nnθ, chain, st = generate_ltd(chain, init_params)
Expand Down Expand Up @@ -461,7 +474,8 @@ function ahmc_bayesian_pinn_ode(

MCMC_alg = kernelchoice(Kernel, MCMCkwargs)
Kernel = AdvancedHMC.make_kernel(MCMC_alg, integrator)
samples, stats = sample(hamiltonian, Kernel, initial_θ, draw_samples, adaptor;
samples,
stats = sample(hamiltonian, Kernel, initial_θ, draw_samples, adaptor;
progress = progress, verbose = verbose)

samplesc[i] = samples
@@ -479,7 +493,8 @@

MCMC_alg = kernelchoice(Kernel, MCMCkwargs)
Kernel = AdvancedHMC.make_kernel(MCMC_alg, integrator)
samples, stats = sample(hamiltonian, Kernel, initial_θ, draw_samples,
samples,
stats = sample(hamiltonian, Kernel, initial_θ, draw_samples,
adaptor; progress = progress, verbose = verbose)

if verbose