-
-
Notifications
You must be signed in to change notification settings - Fork 200
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Allowing a function to be called multiple times with different inputs #627
base: master
Are you sure you want to change the base?
Changes from 39 commits
ed280a6
1c0f0d0
63e0ddc
4e7b1b8
8a612dc
83e2475
13df657
b17f92b
e885f45
74a2749
d0df2a3
41a75f6
c5d9960
64b56de
55fa847
b7e3d7a
fb199e4
2572dbf
3e36fbe
d115eae
abb85a8
c7d3dc5
d9da546
18338d3
cee31db
ea1c3b0
be3abf1
308454c
09b6cf6
6e4206b
55d142a
a9b6b47
f815469
5889a1b
424a7ef
d581889
edcb1a7
b07ae13
7a1e0b5
7f527c7
530d50e
48c8b04
e4f1536
238b315
44f3a28
fc7d36c
550ab40
4dcf2a8
00f07fc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -61,7 +61,9 @@ function build_symbolic_loss_function(pinnrep::PINNRepresentation, eqs; | |
this_eq_pair = pair(eqs, depvars, dict_depvars, dict_depvar_input) | ||
this_eq_indvars = unique(vcat(values(this_eq_pair)...)) | ||
else | ||
this_eq_pair = Dict(map(intvars -> dict_depvars[intvars] => dict_depvar_input[intvars], | ||
this_eq_pair = Dict(map(intvars -> dict_depvars[intvars] => filter(arg -> !isempty(find_thing_in_expr(integrand, | ||
arg)), | ||
dict_depvar_input[intvars]), | ||
integrating_depvars)) | ||
this_eq_indvars = transformation_vars isa Nothing ? | ||
unique(vcat(values(this_eq_pair)...)) : transformation_vars | ||
|
@@ -142,17 +144,10 @@ function build_symbolic_loss_function(pinnrep::PINNRepresentation, eqs; | |
vcat_expr = Expr(:block, :($(eq_pair_expr...))) | ||
vcat_expr_loss_functions = Expr(:block, vcat_expr, loss_function) # TODO rename | ||
|
||
if strategy isa QuadratureTraining | ||
indvars_ex = get_indvars_ex(bc_indvars) | ||
left_arg_pairs, right_arg_pairs = this_eq_indvars, indvars_ex | ||
vars_eq = Expr(:(=), build_expr(:tuple, left_arg_pairs), | ||
build_expr(:tuple, right_arg_pairs)) | ||
else | ||
indvars_ex = [:($:cord[[$i], :]) for (i, x) in enumerate(this_eq_indvars)] | ||
left_arg_pairs, right_arg_pairs = this_eq_indvars, indvars_ex | ||
vars_eq = Expr(:(=), build_expr(:tuple, left_arg_pairs), | ||
build_expr(:tuple, right_arg_pairs)) | ||
end | ||
indvars_ex = [:($:cord[[$i], :]) for (i, x) in enumerate(this_eq_indvars)] | ||
left_arg_pairs, right_arg_pairs = this_eq_indvars, indvars_ex | ||
vars_eq = Expr(:(=), build_expr(:tuple, left_arg_pairs), | ||
build_expr(:tuple, right_arg_pairs)) | ||
|
||
if !(dict_transformation_vars isa Nothing) | ||
transformation_expr_ = Expr[] | ||
|
@@ -256,7 +251,7 @@ function generate_training_sets(domains, dx, eqs, bcs, eltypeθ, dict_indvars::D | |
hcat(vec(map(points -> collect(points), | ||
Iterators.product(bc_data...)))...)) | ||
|
||
pde_train_sets = map(pde_args) do bt | ||
pde_train_sets = map(pde_vars) do bt | ||
span = map(b -> get(dict_var_span_, b, b), bt) | ||
_set = adapt(eltypeθ, | ||
hcat(vec(map(points -> collect(points), Iterators.product(span...)))...)) | ||
|
@@ -292,7 +287,7 @@ function get_bounds(domains, eqs, bcs, eltypeθ, dict_indvars, dict_depvars, | |
dict_lower_bound = Dict([Symbol(d.variables) => infimum(d.domain) for d in domains]) | ||
dict_upper_bound = Dict([Symbol(d.variables) => supremum(d.domain) for d in domains]) | ||
|
||
pde_args = get_argument(eqs, dict_indvars, dict_depvars) | ||
pde_args = get_variables(eqs, dict_indvars, dict_depvars) | ||
|
||
pde_lower_bounds = map(pde_args) do pd | ||
span = map(p -> get(dict_lower_bound, p, p), pd) | ||
|
@@ -325,19 +320,33 @@ function get_bounds(domains, eqs, bcs, eltypeθ, dict_indvars, dict_depvars, str | |
] for d in domains]) | ||
|
||
# pde_bounds = [[infimum(d.domain),supremum(d.domain)] for d in domains] | ||
pde_args = get_argument(eqs, dict_indvars, dict_depvars) | ||
pde_bounds = map(pde_args) do pde_arg | ||
bds = mapreduce(s -> get(dict_span, s, fill(s, 2)), hcat, pde_arg) | ||
bds = eltypeθ.(bds) | ||
bds[1, :], bds[2, :] | ||
pde_vars = get_variables(eqs, dict_indvars, dict_depvars) | ||
pde_bounds = map(pde_vars) do pde_var | ||
if !isempty(pde_var) | ||
bds = mapreduce(s -> get(dict_span, s, fill(s, 2)), hcat, pde_var) | ||
bds = eltypeθ.(bds) | ||
bds[1, :], bds[2, :] | ||
else | ||
[eltypeθ(0.0)], [eltypeθ(0.0)] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what case is this handling? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If I remember correctly, it was the case of something like |
||
end | ||
end | ||
|
||
bound_args = get_argument(bcs, dict_indvars, dict_depvars) | ||
bcs_bounds = map(bound_args) do bound_arg | ||
bds = mapreduce(s -> get(dict_span, s, fill(s, 2)), hcat, bound_arg) | ||
bds = eltypeθ.(bds) | ||
bds[1, :], bds[2, :] | ||
dx_bcs = 1 / strategy.bcs_points | ||
dict_span_bcs = Dict([Symbol(d.variables) => [ | ||
infimum(d.domain) + dx_bcs, | ||
supremum(d.domain) - dx_bcs, | ||
] for d in domains]) | ||
bound_vars = get_variables(bcs, dict_indvars, dict_depvars) | ||
bcs_bounds = map(bound_vars) do bound_var | ||
if !isempty(bound_var) | ||
bds = mapreduce(s -> get(dict_span_bcs, s, fill(s, 2)), hcat, bound_var) | ||
bds = eltypeθ.(bds) | ||
bds[1, :], bds[2, :] | ||
else | ||
[eltypeθ(0.0)], [eltypeθ(0.0)] | ||
end | ||
end | ||
|
||
return pde_bounds, bcs_bounds | ||
end | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,10 +19,25 @@ julia> _dot_(e) | |
dottable_(x) = Broadcast.dottable(x) | ||
dottable_(x::Function) = true | ||
|
||
_vcat(x::Number...) = vcat(x...) | ||
_vcat(x::AbstractArray{<:Number}...) = vcat(x...) | ||
# If the arguments are a mix of numbers and matrices/vectors/arrays, | ||
# the numbers need to be copied for the dimensions to match | ||
function _vcat(x::Union{Number, AbstractArray{<:Number}}...) | ||
example = first(Iterators.filter(e -> !(e isa Number), x)) | ||
dims = (1, size(example)[2:end]...) | ||
x = map(el -> el isa Number ? (typeof(example))(fill(el, dims)) : el, x) | ||
_vcat(x...) | ||
end | ||
_vcat(x...) = vcat(x...) | ||
dottable_(x::typeof(_vcat)) = false | ||
|
||
_dot_(x) = x | ||
function _dot_(x::Expr) | ||
dotargs = Base.mapany(_dot_, x.args) | ||
if x.head === :call && dottable_(x.args[1]) | ||
if x.head === :call && x.args[1] === :_vcat | ||
Expr(x.head, dotargs...) | ||
elseif x.head === :call && dottable_(x.args[1]) | ||
Expr(:., dotargs[1], Expr(:tuple, dotargs[2:end]...)) | ||
elseif x.head === :comparison | ||
Expr(:comparison, | ||
|
@@ -128,14 +143,15 @@ function _transform_expression(pinnrep::PINNRepresentation, ex; is_integral = fa | |
if e in keys(dict_depvars) | ||
depvar = _args[1] | ||
num_depvar = dict_depvars[depvar] | ||
indvars = _args[2:end] | ||
indvars = map((indvar_) -> transform_expression(pinnrep, indvar_), | ||
_args[2:end]) | ||
var_ = is_integral ? :(u) : :($(Expr(:$, :u))) | ||
ex.args = if !multioutput | ||
[var_, Symbol(:cord, num_depvar), :($θ), :phi] | ||
[var_, :((_vcat)($(indvars...))), :($θ), :phi] | ||
nicholaskl97 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
else | ||
[ | ||
var_, | ||
Symbol(:cord, num_depvar), | ||
:((_vcat)($(indvars...))), | ||
Symbol(:($θ), num_depvar), | ||
Symbol(:phi, num_depvar), | ||
] | ||
|
@@ -151,7 +167,8 @@ function _transform_expression(pinnrep::PINNRepresentation, ex; is_integral = fa | |
end | ||
depvar = _args[1] | ||
num_depvar = dict_depvars[depvar] | ||
indvars = _args[2:end] | ||
indvars = map((indvar_) -> transform_expression(pinnrep, indvar_), | ||
_args[2:end]) | ||
dict_interior_indvars = Dict([indvar .=> j | ||
for (j, indvar) in enumerate(dict_depvar_input[depvar])]) | ||
dim_l = length(dict_interior_indvars) | ||
|
@@ -162,13 +179,13 @@ function _transform_expression(pinnrep::PINNRepresentation, ex; is_integral = fa | |
εs_dnv = [εs[d] for d in undv] | ||
|
||
ex.args = if !multioutput | ||
[var_, :phi, :u, Symbol(:cord, num_depvar), εs_dnv, order, :($θ)] | ||
[var_, :phi, :u, :((_vcat)($(indvars...))), εs_dnv, order, :($θ)] | ||
else | ||
[ | ||
var_, | ||
Symbol(:phi, num_depvar), | ||
:u, | ||
Symbol(:cord, num_depvar), | ||
:((_vcat)($(indvars...))), | ||
εs_dnv, | ||
order, | ||
Symbol(:($θ), num_depvar), | ||
|
@@ -336,7 +353,8 @@ function pair(eq, depvars, dict_depvars, dict_depvar_input) | |
expr = toexpr(eq) | ||
pair_ = map(depvars) do depvar | ||
if !isempty(find_thing_in_expr(expr, depvar)) | ||
dict_depvars[depvar] => dict_depvar_input[depvar] | ||
dict_depvars[depvar] => filter(arg -> !isempty(find_thing_in_expr(expr, arg)), | ||
dict_depvar_input[depvar]) | ||
end | ||
end | ||
Dict(filter(p -> p !== nothing, pair_)) | ||
|
@@ -419,6 +437,13 @@ function find_thing_in_expr(ex::Expr, thing; ans = []) | |
return collect(Set(ans)) | ||
end | ||
|
||
function find_thing_in_expr(ex::Symbol, thing::Symbol; ans = []) | ||
if thing == ex | ||
push!(ans, ex) | ||
end | ||
return ans | ||
end | ||
|
||
""" | ||
```julia | ||
get_argument(eqs,_indvars::Array,_depvars::Array) | ||
|
@@ -435,27 +460,48 @@ function get_argument(eqs, _indvars::Array, _depvars::Array) | |
get_argument(eqs, dict_indvars, dict_depvars) | ||
end | ||
function get_argument(eqs, dict_indvars, dict_depvars) | ||
"Equations, as expressions" | ||
nicholaskl97 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
exprs = toexpr.(eqs) | ||
vars = map(exprs) do expr | ||
"Instances of each dependent variable that appears in the expression, by dependent variable, by equation" | ||
vars = map(exprs) do expr # For each equation,... | ||
"Arrays of instances of each dependent variable, by dependent variable" | ||
_vars = map(depvar -> find_thing_in_expr(expr, depvar), collect(keys(dict_depvars))) | ||
"Arrays of instances of each dependent variable that appears in the expression, by dependent variable" | ||
f_vars = filter(x -> !isempty(x), _vars) | ||
map(x -> first(x), f_vars) | ||
end | ||
# vars = [depvar for expr in vars for depvar in expr ] | ||
args_ = map(vars) do _vars | ||
ind_args_ = map(var -> var.args[2:end], _vars) | ||
"Arguments of all instances of dependent variable, by instance, by dependent variable" | ||
ind_args_ = map.(var -> var.args[2:end], _vars) | ||
syms = Set{Symbol}() | ||
filter(vcat(ind_args_...)) do ind_arg | ||
"All arguments in any instance of a dependent variable" | ||
nicholaskl97 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
all_ind_args = vcat((ind_args_...)...) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
|
||
# Add any independent variables from expression dependent variable calls | ||
for ind_arg in all_ind_args | ||
if ind_arg isa Expr | ||
for ind_var in collect(keys(dict_indvars)) | ||
if !isempty(NeuralPDE.find_thing_in_expr(ind_arg, ind_var)) | ||
push!(all_ind_args, ind_var) | ||
end | ||
end | ||
end | ||
end | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ahh this looks like the key. |
||
|
||
filter(all_ind_args) do ind_arg # For each argument | ||
if ind_arg isa Symbol | ||
if ind_arg ∈ syms | ||
false | ||
false # remove symbols that have already occurred | ||
else | ||
push!(syms, ind_arg) | ||
true | ||
true # keep symbols that haven't occurred yet, but note their occurance | ||
end | ||
elseif ind_arg isa Expr # we've already taken what we wanted from the expressions | ||
false | ||
else | ||
true | ||
true # keep all non-symbols | ||
end | ||
end | ||
end | ||
return args_ # TODO for all arguments | ||
return args_ | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Update the docstring above. What does the code look like now?