Allowing a function to be called multiple times with different inputs #627

Draft

wants to merge 49 commits into base: master

Changes from 39 commits (49 commits total)

Commits
ed280a6
Modified 1D approx test to show get_argument bug
nicholaskl97 Sep 2, 2022
1c0f0d0
Updated get_argument for eval with multiple inputs
nicholaskl97 Oct 27, 2022
63e0ddc
Forced get_argument when strategy != Quadrature
nicholaskl97 Oct 27, 2022
4e7b1b8
Test file for fixing get_argument
nicholaskl97 Oct 27, 2022
8a612dc
Test file for debugging symbolic_discretize
nicholaskl97 Oct 27, 2022
83e2475
transform_expression uses indvars now
nicholaskl97 Dec 31, 2022
13df657
Some test files
nicholaskl97 Dec 31, 2022
b17f92b
Merge branch 'master' into get_argument-fix
nicholaskl97 Dec 31, 2022
e885f45
Reverted get_argument to original state
nicholaskl97 Dec 31, 2022
74a2749
Removed temporary debug files
nicholaskl97 Dec 31, 2022
d0df2a3
Updated _vcat to accept multiple arguments
nicholaskl97 Jan 1, 2023
41a75f6
get_argument returns all args no just first per eq
nicholaskl97 Jan 12, 2023
c5d9960
Added implicit 1D and another 2D test case
nicholaskl97 Jan 12, 2023
64b56de
generate gridtrain trainsets based of pde vars
nicholaskl97 Jan 12, 2023
55fa847
added OptimJL and OptimOptimisers
nicholaskl97 Jan 12, 2023
b7e3d7a
get_bounds works with new transform_expression
nicholaskl97 Jan 12, 2023
fb199e4
Added test of ODE with hard constraint ic
nicholaskl97 Jan 12, 2023
2572dbf
_vcat now fills out scalar inputs to match batches
nicholaskl97 Jan 24, 2023
3e36fbe
cord now only has variables that show up in the eq
nicholaskl97 Jan 26, 2023
d115eae
GridTraining train_sets now work on the GPU
nicholaskl97 Feb 7, 2023
abb85a8
_vcat maintains Array types when filling
nicholaskl97 Feb 7, 2023
c7d3dc5
Formatting change
nicholaskl97 Feb 7, 2023
d9da546
StochasticTraining now actually uses bcs_points
nicholaskl97 Feb 17, 2023
18338d3
get_bounds uses bcs_points
nicholaskl97 Feb 17, 2023
cee31db
get_bounds uses get_variables
nicholaskl97 Feb 17, 2023
ea1c3b0
Merge branch 'master' into master
nicholaskl97 Feb 17, 2023
be3abf1
Increased test number of points
nicholaskl97 Feb 20, 2023
308454c
get_bounds is now okay with eqs with no variables
nicholaskl97 Feb 20, 2023
09b6cf6
symbolic_utilities doesn't need LinearAlgebra
nicholaskl97 Feb 20, 2023
6e4206b
Merge remote-tracking branch 'origin/master' into get_argument-fix
nicholaskl97 Feb 21, 2023
55d142a
Can now handle Ix(u(x,1)) and not just Ix(u(x,y))
nicholaskl97 Feb 21, 2023
a9b6b47
import ComponentArrays used in training_strategies
nicholaskl97 Feb 21, 2023
f815469
Added import ComponentArrays statements
nicholaskl97 Feb 22, 2023
5889a1b
Revert "Added import ComponentArrays statements"
nicholaskl97 Feb 22, 2023
424a7ef
Revert "import ComponentArrays used in training_strategies"
nicholaskl97 Feb 22, 2023
d581889
Revert "added OptimJL and OptimOptimisers"
nicholaskl97 Feb 22, 2023
edcb1a7
Replaced Lux.ComponentArray with using Co...Arrays
nicholaskl97 Feb 22, 2023
b07ae13
Formatted with JuliaFormtter
nicholaskl97 Feb 23, 2023
7a1e0b5
Docstrings were counting against code coverage
nicholaskl97 Mar 7, 2023
7f527c7
Improperly used docstrings changed to comments
nicholaskl97 Mar 8, 2023
530d50e
Added comments for _vcat
nicholaskl97 Mar 8, 2023
48c8b04
Merge remote-tracking branch 'origin/master' into get_argument-fix
nicholaskl97 Mar 8, 2023
e4f1536
Updated docstring for build_symbolic_loss_function
nicholaskl97 Mar 9, 2023
238b315
Reductions needed inits for cases like u(0)=0
nicholaskl97 Mar 10, 2023
44f3a28
Formatted with JuliaFormatter
nicholaskl97 Mar 10, 2023
fc7d36c
Added a new integral test
nicholaskl97 Apr 3, 2023
550ab40
Merge remote-tracking branch 'origin/master' into get_argument-fix
nicholaskl97 Apr 3, 2023
4dcf2a8
Merge remote-tracking branch 'origin/master'
nicholaskl97 May 29, 2023
00f07fc
Merge remote-tracking branch 'origin/master'
nicholaskl97 Jul 13, 2023
57 changes: 33 additions & 24 deletions src/discretize.jl
@@ -61,7 +61,9 @@ function build_symbolic_loss_function(pinnrep::PINNRepresentation, eqs;
this_eq_pair = pair(eqs, depvars, dict_depvars, dict_depvar_input)
this_eq_indvars = unique(vcat(values(this_eq_pair)...))
else
Member: Update the docstring above. What does the code look like now?

this_eq_pair = Dict(map(intvars -> dict_depvars[intvars] => dict_depvar_input[intvars],
this_eq_pair = Dict(map(intvars -> dict_depvars[intvars] => filter(arg -> !isempty(find_thing_in_expr(integrand,
arg)),
dict_depvar_input[intvars]),
integrating_depvars))
this_eq_indvars = transformation_vars isa Nothing ?
unique(vcat(values(this_eq_pair)...)) : transformation_vars
@@ -142,17 +144,10 @@ function build_symbolic_loss_function(pinnrep::PINNRepresentation, eqs;
vcat_expr = Expr(:block, :($(eq_pair_expr...)))
vcat_expr_loss_functions = Expr(:block, vcat_expr, loss_function) # TODO rename

if strategy isa QuadratureTraining
indvars_ex = get_indvars_ex(bc_indvars)
left_arg_pairs, right_arg_pairs = this_eq_indvars, indvars_ex
vars_eq = Expr(:(=), build_expr(:tuple, left_arg_pairs),
build_expr(:tuple, right_arg_pairs))
else
indvars_ex = [:($:cord[[$i], :]) for (i, x) in enumerate(this_eq_indvars)]
left_arg_pairs, right_arg_pairs = this_eq_indvars, indvars_ex
vars_eq = Expr(:(=), build_expr(:tuple, left_arg_pairs),
build_expr(:tuple, right_arg_pairs))
end
indvars_ex = [:($:cord[[$i], :]) for (i, x) in enumerate(this_eq_indvars)]
left_arg_pairs, right_arg_pairs = this_eq_indvars, indvars_ex
vars_eq = Expr(:(=), build_expr(:tuple, left_arg_pairs),
build_expr(:tuple, right_arg_pairs))

if !(dict_transformation_vars isa Nothing)
transformation_expr_ = Expr[]
@@ -256,7 +251,7 @@ function generate_training_sets(domains, dx, eqs, bcs, eltypeθ, dict_indvars::D
hcat(vec(map(points -> collect(points),
Iterators.product(bc_data...)))...))

pde_train_sets = map(pde_args) do bt
pde_train_sets = map(pde_vars) do bt
span = map(b -> get(dict_var_span_, b, b), bt)
_set = adapt(eltypeθ,
hcat(vec(map(points -> collect(points), Iterators.product(span...)))...))
@@ -292,7 +287,7 @@ function get_bounds(domains, eqs, bcs, eltypeθ, dict_indvars, dict_depvars,
dict_lower_bound = Dict([Symbol(d.variables) => infimum(d.domain) for d in domains])
dict_upper_bound = Dict([Symbol(d.variables) => supremum(d.domain) for d in domains])

pde_args = get_argument(eqs, dict_indvars, dict_depvars)
pde_args = get_variables(eqs, dict_indvars, dict_depvars)

pde_lower_bounds = map(pde_args) do pd
span = map(p -> get(dict_lower_bound, p, p), pd)
@@ -325,19 +320,33 @@ function get_bounds(domains, eqs, bcs, eltypeθ, dict_indvars, dict_depvars, str
] for d in domains])

# pde_bounds = [[infimum(d.domain),supremum(d.domain)] for d in domains]
pde_args = get_argument(eqs, dict_indvars, dict_depvars)
pde_bounds = map(pde_args) do pde_arg
bds = mapreduce(s -> get(dict_span, s, fill(s, 2)), hcat, pde_arg)
bds = eltypeθ.(bds)
bds[1, :], bds[2, :]
pde_vars = get_variables(eqs, dict_indvars, dict_depvars)
pde_bounds = map(pde_vars) do pde_var
if !isempty(pde_var)
bds = mapreduce(s -> get(dict_span, s, fill(s, 2)), hcat, pde_var)
bds = eltypeθ.(bds)
bds[1, :], bds[2, :]
else
[eltypeθ(0.0)], [eltypeθ(0.0)]
Member: what case is this handling?

Author: If I remember correctly, it was the case of something like u(0) = 0. The parser will now turn that into something like u(_vcat(0)) .- 0 instead of u(cord1) .- 0. Since that expression contains no variables, we don't bother generating training data for it (it can be evaluated without a training set). However, passing empty arrays along would cause an error, so instead we give it 0 as both the lower and upper bound; those bounds have no real meaning, since there are no variables ranging between them.

end
end

bound_args = get_argument(bcs, dict_indvars, dict_depvars)
bcs_bounds = map(bound_args) do bound_arg
bds = mapreduce(s -> get(dict_span, s, fill(s, 2)), hcat, bound_arg)
bds = eltypeθ.(bds)
bds[1, :], bds[2, :]
dx_bcs = 1 / strategy.bcs_points
dict_span_bcs = Dict([Symbol(d.variables) => [
infimum(d.domain) + dx_bcs,
supremum(d.domain) - dx_bcs,
] for d in domains])
bound_vars = get_variables(bcs, dict_indvars, dict_depvars)
bcs_bounds = map(bound_vars) do bound_var
if !isempty(bound_var)
bds = mapreduce(s -> get(dict_span_bcs, s, fill(s, 2)), hcat, bound_var)
bds = eltypeθ.(bds)
bds[1, :], bds[2, :]
else
[eltypeθ(0.0)], [eltypeθ(0.0)]
end
end

return pde_bounds, bcs_bounds
end

78 changes: 62 additions & 16 deletions src/symbolic_utilities.jl
@@ -19,10 +19,25 @@ julia> _dot_(e)
dottable_(x) = Broadcast.dottable(x)
dottable_(x::Function) = true

_vcat(x::Number...) = vcat(x...)
_vcat(x::AbstractArray{<:Number}...) = vcat(x...)
# If the arguments are a mix of numbers and matrices/vectors/arrays,
# the numbers need to be copied for the dimensions to match
function _vcat(x::Union{Number, AbstractArray{<:Number}}...)
example = first(Iterators.filter(e -> !(e isa Number), x))
dims = (1, size(example)[2:end]...)
x = map(el -> el isa Number ? (typeof(example))(fill(el, dims)) : el, x)
_vcat(x...)
end
_vcat(x...) = vcat(x...)
dottable_(x::typeof(_vcat)) = false
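The scalar-filling behavior of `_vcat` can be sketched standalone. The definitions below are copied from this diff; the example call at the end is illustrative, not from the PR:

```julia
# Definitions from the diff: plain vcat for all-scalar or all-array inputs.
_vcat(x::Number...) = vcat(x...)
_vcat(x::AbstractArray{<:Number}...) = vcat(x...)
# Mixed inputs: broadcast each scalar to a 1×(batch) block matching the first
# array argument, so the batch dimensions line up before vcat-ing.
function _vcat(x::Union{Number, AbstractArray{<:Number}}...)
    example = first(Iterators.filter(e -> !(e isa Number), x))
    dims = (1, size(example)[2:end]...)
    x = map(el -> el isa Number ? (typeof(example))(fill(el, dims)) : el, x)
    _vcat(x...)
end
_vcat(x...) = vcat(x...)

# A scalar argument is filled out to match the batch of coordinate columns,
# which is what lets a call like u(x, 1) share machinery with u(x, y):
_vcat(1.0, [0.0 0.5 1.0])  # 2×3 matrix: a row of ones over the coordinate row
```

The fill keeps the array type of the batch (via `typeof(example)`), which is what the later commit "_vcat maintains Array types when filling" refers to.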

_dot_(x) = x
function _dot_(x::Expr)
dotargs = Base.mapany(_dot_, x.args)
if x.head === :call && dottable_(x.args[1])
if x.head === :call && x.args[1] === :_vcat
Expr(x.head, dotargs...)
elseif x.head === :call && dottable_(x.args[1])
Expr(:., dotargs[1], Expr(:tuple, dotargs[2:end]...))
elseif x.head === :comparison
Expr(:comparison,
@@ -128,14 +143,15 @@ function _transform_expression(pinnrep::PINNRepresentation, ex; is_integral = fa
if e in keys(dict_depvars)
depvar = _args[1]
num_depvar = dict_depvars[depvar]
indvars = _args[2:end]
indvars = map((indvar_) -> transform_expression(pinnrep, indvar_),
_args[2:end])
var_ = is_integral ? :(u) : :($(Expr(:$, :u)))
ex.args = if !multioutput
[var_, Symbol(:cord, num_depvar), :($θ), :phi]
[var_, :((_vcat)($(indvars...))), :($θ), :phi]
else
[
var_,
Symbol(:cord, num_depvar),
:((_vcat)($(indvars...))),
Symbol(:($θ), num_depvar),
Symbol(:phi, num_depvar),
]
@@ -151,7 +167,8 @@ function _transform_expression(pinnrep::PINNRepresentation, ex; is_integral = fa
end
depvar = _args[1]
num_depvar = dict_depvars[depvar]
indvars = _args[2:end]
indvars = map((indvar_) -> transform_expression(pinnrep, indvar_),
_args[2:end])
dict_interior_indvars = Dict([indvar .=> j
for (j, indvar) in enumerate(dict_depvar_input[depvar])])
dim_l = length(dict_interior_indvars)
@@ -162,13 +179,13 @@ function _transform_expression(pinnrep::PINNRepresentation, ex; is_integral = fa
εs_dnv = [εs[d] for d in undv]

ex.args = if !multioutput
[var_, :phi, :u, Symbol(:cord, num_depvar), εs_dnv, order, :($θ)]
[var_, :phi, :u, :((_vcat)($(indvars...))), εs_dnv, order, :($θ)]
else
[
var_,
Symbol(:phi, num_depvar),
:u,
Symbol(:cord, num_depvar),
:((_vcat)($(indvars...))),
εs_dnv,
order,
Symbol(:($θ), num_depvar),
@@ -336,7 +353,8 @@ function pair(eq, depvars, dict_depvars, dict_depvar_input)
expr = toexpr(eq)
pair_ = map(depvars) do depvar
if !isempty(find_thing_in_expr(expr, depvar))
dict_depvars[depvar] => dict_depvar_input[depvar]
dict_depvars[depvar] => filter(arg -> !isempty(find_thing_in_expr(expr, arg)),
dict_depvar_input[depvar])
end
end
Dict(filter(p -> p !== nothing, pair_))
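The filtering added to `pair` drops dependent-variable inputs that never appear in the equation. A standalone sketch with toy inputs (`appears` is an illustrative stand-in for `find_thing_in_expr`, not the package code):

```julia
# Toy inputs: one dependent variable u(x, y), one equation u(x, 1) ~ 0.
dict_depvars = Dict(:u => 1)
dict_depvar_input = Dict(:u => [:x, :y])
expr = :(u(x, 1) ~ 0)

# Stand-in presence check: does symbol s occur anywhere in expression ex?
appears(ex, s) = ex == s || (ex isa Expr && any(a -> appears(a, s), ex.args))

pair_ = Dict(map(collect(keys(dict_depvars))) do depvar
    dict_depvars[depvar] => filter(arg -> appears(expr, arg), dict_depvar_input[depvar])
end)
# :y is dropped because u is only ever called as u(x, 1) in this equation,
# so pair_ is Dict(1 => [:x])
```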
@@ -419,6 +437,13 @@ function find_thing_in_expr(ex::Expr, thing; ans = [])
return collect(Set(ans))
end

function find_thing_in_expr(ex::Symbol, thing::Symbol; ans = [])
if thing == ex
push!(ans, ex)
end
return ans
end
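The new `Symbol` method makes the presence check usable when an argument is a bare symbol rather than a call. With a simplified stand-in for the `Expr` method (illustrative only, not NeuralPDE's exact implementation), the behavior is:

```julia
# Simplified stand-in for the recursive Expr method: collect enclosing
# expressions that mention `thing` anywhere in their arguments.
function find_thing_in_expr(ex::Expr, thing; ans = [])
    for arg in ex.args
        arg == thing && push!(ans, ex)
        arg isa Expr && find_thing_in_expr(arg, thing; ans = ans)
    end
    return collect(Set(ans))
end

# The method added in this diff: a bare Symbol can now be searched directly
# instead of raising a MethodError.
function find_thing_in_expr(ex::Symbol, thing::Symbol; ans = [])
    if thing == ex
        push!(ans, ex)
    end
    return ans
end

find_thing_in_expr(:x, :x)           # [:x]
isempty(find_thing_in_expr(:y, :x))  # true
```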

"""
```julia
get_argument(eqs,_indvars::Array,_depvars::Array)
@@ -435,27 +460,48 @@ function get_argument(eqs, _indvars::Array, _depvars::Array)
get_argument(eqs, dict_indvars, dict_depvars)
end
function get_argument(eqs, dict_indvars, dict_depvars)
"Equations, as expressions"
exprs = toexpr.(eqs)
vars = map(exprs) do expr
"Instances of each dependent variable that appears in the expression, by dependent variable, by equation"
vars = map(exprs) do expr # For each equation,...
"Arrays of instances of each dependent variable, by dependent variable"
_vars = map(depvar -> find_thing_in_expr(expr, depvar), collect(keys(dict_depvars)))
"Arrays of instances of each dependent variable that appears in the expression, by dependent variable"
f_vars = filter(x -> !isempty(x), _vars)
map(x -> first(x), f_vars)
end
# vars = [depvar for expr in vars for depvar in expr ]
args_ = map(vars) do _vars
ind_args_ = map(var -> var.args[2:end], _vars)
"Arguments of all instances of dependent variable, by instance, by dependent variable"
ind_args_ = map.(var -> var.args[2:end], _vars)
syms = Set{Symbol}()
filter(vcat(ind_args_...)) do ind_arg
"All arguments in any instance of a dependent variable"
all_ind_args = vcat((ind_args_...)...)
Member: reduce(vcat, ...)

# Add any independent variables from expression dependent variable calls
for ind_arg in all_ind_args
if ind_arg isa Expr
for ind_var in collect(keys(dict_indvars))
if !isempty(NeuralPDE.find_thing_in_expr(ind_arg, ind_var))
push!(all_ind_args, ind_var)
end
end
end
end
Member: Ahh, this looks like the key.

filter(all_ind_args) do ind_arg # For each argument
if ind_arg isa Symbol
if ind_arg ∈ syms
false
false # remove symbols that have already occurred
else
push!(syms, ind_arg)
true
true # keep symbols that haven't occurred yet, but note their occurrence
end
elseif ind_arg isa Expr # we've already taken what we wanted from the expressions
false
else
true
true # keep all non-symbols
end
end
end
return args_ # TODO for all arguments
return args_
end
15 changes: 9 additions & 6 deletions src/training_strategies.jl
@@ -22,17 +22,19 @@ function merge_strategy_with_loss_function(pinnrep::PINNRepresentation,
datafree_bc_loss_function)
@unpack domains, eqs, bcs, dict_indvars, dict_depvars, flat_init_params = pinnrep
dx = strategy.dx
eltypeθ = eltype(pinnrep.flat_init_params)
eltypeθ = eltype(flat_init_params)

train_sets = generate_training_sets(domains, dx, eqs, bcs, eltypeθ,
dict_indvars, dict_depvars)

# the points in the domain and on the boundary
pde_train_sets, bcs_train_sets = train_sets

pde_train_sets = adapt.(parameterless_type(ComponentArrays.getdata(flat_init_params)),
pde_train_sets)
bcs_train_sets = adapt.(parameterless_type(ComponentArrays.getdata(flat_init_params)),
bcs_train_sets)

pde_loss_functions = [get_loss_function(_loss, _set, eltypeθ, strategy)
for (_loss, _set) in zip(datafree_pde_loss_function,
pde_train_sets)]
@@ -88,19 +90,20 @@ function merge_strategy_with_loss_function(pinnrep::PINNRepresentation,
strategy)
pde_bounds, bcs_bounds = bounds

pde_loss_functions = [get_loss_function(_loss, bound, eltypeθ, strategy)
pde_loss_functions = [get_loss_function(_loss, bound, eltypeθ, strategy,
strategy.points)
for (_loss, bound) in zip(datafree_pde_loss_function, pde_bounds)]

bc_loss_functions = [get_loss_function(_loss, bound, eltypeθ, strategy)
bc_loss_functions = [get_loss_function(_loss, bound, eltypeθ, strategy,
strategy.bcs_points)
for (_loss, bound) in zip(datafree_bc_loss_function, bcs_bounds)]

pde_loss_functions, bc_loss_functions
end

function get_loss_function(loss_function, bound, eltypeθ, strategy::StochasticTraining;
function get_loss_function(loss_function, bound, eltypeθ, strategy::StochasticTraining,
points = strategy.points;
τ = nothing)
points = strategy.points

loss = (θ) -> begin
sets = generate_random_points(points, bound, eltypeθ)
sets_ = adapt(parameterless_type(ComponentArrays.getdata(θ)), sets)
34 changes: 34 additions & 0 deletions test/IDE_tests.jl
@@ -69,6 +69,40 @@ u_real = [x^2 / cos(x) for x in xs]
# plot(xs,u_real)
# plot!(xs,u_predict)

## Simple Integral Test 2
println("Simple Integral Test 2")

@parameters x y
@variables u(..)
Ix = Integral(x in DomainSets.ClosedInterval(0, x))
# eq = Ix(u(x, y) * cos(x)) ~ y * (x^3) / 3 # This is the same, but we're testing the parsing of the version below
eqs = [Ix(u(x, 1) * cos(x)) ~ (x^3) / 3,
u(x, y) ~ y * u(x, 1.0)]

bcs = [u(0.0, y) ~ 0.0]
domains = [x ∈ Interval(0.0, 1.00),
y ∈ Interval(0.5, 2.00)]
# chain = Chain(Dense(1,15,Flux.σ),Dense(15,1))
chain = Lux.Chain(Lux.Dense(2, 15, Flux.σ), Lux.Dense(15, 1))
strategy_ = NeuralPDE.GridTraining(0.1)
discretization = NeuralPDE.PhysicsInformedNN(chain,
strategy_)
@named pde_system = PDESystem(eqs, bcs, domains, [x, y], [u(x, y)])
prob = NeuralPDE.discretize(pde_system, discretization)
sym_prob = NeuralPDE.symbolic_discretize(pde_system, discretization)
res = Optimization.solve(prob, OptimizationOptimJL.BFGS(); callback = callback,
maxiters = 500)
xys = [infimum(d.domain):0.01:supremum(d.domain) for d in domains]
xs = Iterators.product(xys...)
phi = discretization.phi
u_predict = [first(phi([x...], res.minimizer)) for x in xs]
u_real = [x[1]^2 / cos(x[1]) * x[2] for x in xs]
@test Flux.mse(u_real, u_predict) < 0.001

# p1 = plot(xys..., u_real', linetype=:contourf, title="Analytic");
# p2 = plot(xys..., u_predict', linetype=:contourf, title="Predicted");
# plot(p1,p2)

# simple multidimensional integral test
println("simple multidimensional integral test")
