ForwardDiff Decapode Gradient Fail Example #182

Open
jClugstor opened this issue Oct 1, 2024 · 2 comments


@jClugstor
Collaborator

jClugstor commented Oct 1, 2024

using Pkg
Pkg.activate(".")
Pkg.instantiate()

# AlgebraicJulia Dependencies
using Catlab
using Catlab.Graphics
using CombinatorialSpaces
using Decapodes
using ComponentArrays
using ForwardDiff
using Zygote

# External Dependencies
using MLStyle
using MultiScaleArrays
using LinearAlgebra
using OrdinaryDiffEq
using SciMLSensitivity # Needed for Zygote.gradient through solve (it appears in the stack trace)
using JLD2
using SparseArrays
using Statistics
#using GLMakie # Just for visualization
using GeometryBasics: Point2, Point3
Point2D = Point2{Float64};
Point3D = Point3{Float64};

using DiagrammaticEquations
using DiagrammaticEquations.Deca

@info("Packages Loaded")

halfar_eq2 = @decapode begin
    h::Form0
    Γ::Form1
    n::Constant

    ḣ == ∂ₜ(h)
    ḣ == ∘(⋆, d, ⋆)(Γ * d(h) * avg₀₁(mag(♯(d(h)))^(n - 1)) * avg₀₁(h^(n + 2)))
end

glens_law = @decapode begin
    Γ::Form1
    (A, ρ, g, n)::Constant

    Γ == (2 / (n + 2)) * A * (ρ * g)^n
end

@info("Decapodes Defined")

ice_dynamics_composition_diagram = @relation () begin
    dynamics(Γ, n)
    stress(Γ, n)
end

ice_dynamics_cospan = oapply(ice_dynamics_composition_diagram,
    [Open(halfar_eq2, [:Γ, :n]),
        Open(glens_law, [:Γ, :n])])
ice_dynamics = apex(ice_dynamics_cospan)
ice_dynamics1D = expand_operators(ice_dynamics)
infer_types!(ice_dynamics1D, op1_inf_rules_1D, op2_inf_rules_1D)
resolve_overloads!(ice_dynamics1D, op1_res_rules_1D, op2_res_rules_1D)

s_prime = EmbeddedDeltaSet1D{Bool,Point2D}()
add_vertices!(s_prime, 100, point=Point2D.(range(-2, 2, length=100), 0))
add_edges!(s_prime, 1:nv(s_prime)-1, 2:nv(s_prime))
orient!(s_prime)
s = EmbeddedDeltaDualComplex1D{Bool,Float64,Point2D}(s_prime)
subdivide_duals!(s, Circumcenter())

@info("Spaces Defined")

function generate(sd, my_symbol; hodge=GeometricHodge())
    op = @match my_symbol begin
        :♯ => x -> begin
            # This is an implementation of the "sharp" operator from the exterior
            # calculus, which takes co-vector fields to vector fields.
            # This could be up-streamed to the CombinatorialSpaces.jl library. (i.e.
            # this operation is not bespoke to this simulation.)
            e_vecs = map(edges(sd)) do e
                point(sd, sd[e, :∂v0]) - point(sd, sd[e, :∂v1])
            end
            neighbors = map(vertices(sd)) do v
                union(incident(sd, v, :∂v0), incident(sd, v, :∂v1))
            end
            n_vecs = map(neighbors) do es
                [e_vecs[e] for e in es]
            end
            map(neighbors, n_vecs) do es, nvs
                sum([nv * norm(nv) * x[e] for (e, nv) in zip(es, nvs)]) / sum(norm.(nvs))
            end
        end
        :mag => x -> norm.(x)
        x => error("Unmatched operator $my_symbol")
    end
    return (args...) -> op(args...)
end

decapode_code = gensim(ice_dynamics1D, dimension=1, preallocate=true)
file = open("ice_sheet1D_alloc.jl", "w")
write(file, string("decapode_f = ", decapode_code))
close(file)
include("ice_sheet1D_alloc.jl")

fₘ = decapode_f(s, generate)


function f(constants_and_parameters)
    prob = ODEProblem{true,SciMLBase.FullSpecialize}(fₘ, u₀, (0, tₑ), constants_and_parameters)
    @info("Solving")
    soln = solve(prob, FBDF(autodiff=false))
    @info("Done")

    # return soln(tₑ)
    sum(last(soln)) # Use last(soln), not soln(tₑ), to avoid interpolation failures when AD fails.
end

h₀ = map(x -> exp(-2*x[1]^2), point(s_prime))

flow_rate, ice_density, u_init_arr = 1e-3, 910.0, h₀
n = 3
ρ = ice_density
g = 9.8101
A = fill(flow_rate, ne(s))
tₑ = 5e9

u₀ = ComponentArray(dynamics_h=u_init_arr)

# Note that this must be a ComponentArray to differentiate
constants_and_parameters = ComponentArray(
    n=n,
    stress_ρ=ρ,
    stress_g=g,
    stress_A=A)

y = f(constants_and_parameters)
zygote_dy = Zygote.gradient(f, constants_and_parameters)
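For reference, this is the composed model that gensim compiles, written in exterior-calculus notation as we read the two Decapodes above. Glen's law supplies Γ, avg₀₁ averages vertex (0-form) values onto edges, and ♯ is the operator implemented in generate. In the smooth setting this roughly corresponds to Halfar's shallow-ice equation ∂ₜh = ∇·(Γ h^(n+2) |∇h|^(n-1) ∇h).

$$
\dot h = \partial_t h, \qquad
\dot h = \star\, d \star\!\Big( \Gamma \; dh \; \operatorname{avg}_{01}\!\big(\lvert \sharp\, dh \rvert^{\,n-1}\big) \; \operatorname{avg}_{01}\!\big(h^{\,n+2}\big) \Big), \qquad
\Gamma = \frac{2}{n+2}\, A\, (\rho g)^{n}
$$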

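ForwardDiff fails if you use FBDF() (i.e., with its default autodiff = true). It also fails if the sensealg uses autodiff = true. Judging from the frames, the trace below comes from the ForwardDiff Jacobian that the InterpolatingAdjoint backward pass computes.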

ERROR: ArgumentError: cannot reinterpret an `ForwardDiff.Dual{nothing, Float64, 11}` array to `ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.UJacobianWrapper{true, ODEFunction{true, SciMLBase.FullSpecialize, var"#f#36"{PreallocationTools.FixedSizeDiffCache{Vector{Float64}, Vector{ForwardDiff.Dual{nothing, Float64, 12}}}, PreallocationTools.FixedSizeDiffCache{Vector{Float64}, Vector{ForwardDiff.Dual{nothing, Float64, 11}}}, PreallocationTools.FixedSizeDiffCache{Vector{Float64}, Vector{ForwardDiff.Dual{nothing, Float64, 11}}}, PreallocationTools.FixedSizeDiffCache{Vector{Float64}, Vector{ForwardDiff.Dual{nothing, Float64, 11}}}, SparseMatrixCSC{Float64, Int64}, var"#48#57"{var"#47#56"}, var"#48#57"{var"#40#49"{EmbeddedDeltaDualComplex1D{Bool, Float64, GeometryBasics.Point{2, Float64}}}}, Decapodes.var"#37#38"{SparseMatrixCSC{Float64, Int32}}, SparseMatrixCSC{Int8, Int32}}, UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing, Nothing, Nothing}, Float64, ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(n = 1, stress_ρ = 2, stress_g = 3, stress_A = 4:102)}}}}, Float64}, Float64, 12}` whose first dimension has size `99`.
The resulting array would have non-integral first dimension.

Stacktrace:
  [1] (::Base.var"#thrownonint#336")(S::Type, T::Type, dim::Int64)
    @ Base ./reinterpretarray.jl:53
  [2] reinterpret
    @ ./reinterpretarray.jl:71 [inlined]
  [3] get_tmp(dc::PreallocationTools.FixedSizeDiffCache{…}, u::ComponentVector{…})
    @ PreallocationTools ~/.julia/packages/PreallocationTools/7dIFh/src/PreallocationTools.jl:59
  [4] (::var"#f#36"{})(du::ComponentVector{…}, u::ComponentVector{…}, p::ComponentVector{…}, t::Float64)
    @ Main ~/Documents/Work/dev/DecapodeCalibrateDemos/GlacialFlow/ice_sheet1D_alloc.jl:48
  [5] ODEFunction
    @ ~/.julia/packages/SciMLBase/EiBzT/src/scimlfunctions.jl:2335 [inlined]
  [6] UJacobianWrapper
    @ ~/.julia/packages/SciMLBase/EiBzT/src/function_wrappers.jl:32 [inlined]
  [7] chunk_mode_jacobian!(result::Matrix{…}, f!::SciMLBase.UJacobianWrapper{…}, y::ComponentVector{…}, x::ComponentVector{…}, cfg::ForwardDiff.JacobianConfig{…})
    @ ForwardDiff ~/.julia/packages/ForwardDiff/PcZ48/src/jacobian.jl:183
  [8] jacobian!
    @ ~/.julia/packages/ForwardDiff/PcZ48/src/jacobian.jl:80 [inlined]
  [9] jacobian!
    @ ~/.julia/packages/ForwardDiff/PcZ48/src/jacobian.jl:76 [inlined]
 [10] jacobian!
    @ ~/.julia/packages/SciMLSensitivity/HRhwU/src/derivative_wrappers.jl:155 [inlined]
 [11] _vecjacobian!(dλ::SubArray{…}, y::ComponentVector{…}, λ::SubArray{…}, p::ComponentVector{…}, t::Float64, S::SciMLSensitivity.ODEInterpolatingAdjointSensitivityFunction{…}, isautojacvec::Bool, dgrad::SubArray{…}, dy::Nothing, W::Nothing)
    @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/HRhwU/src/derivative_wrappers.jl:270
 [12] #vecjacobian!#18
    @ ~/.julia/packages/SciMLSensitivity/HRhwU/src/derivative_wrappers.jl:232 [inlined]
 [13] vecjacobian!
    @ ~/.julia/packages/SciMLSensitivity/HRhwU/src/derivative_wrappers.jl:229 [inlined]
 [14] (::SciMLSensitivity.ODEInterpolatingAdjointSensitivityFunction{…})(du::Vector{…}, u::Vector{…}, p::ComponentVector{…}, t::Float64)
    @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/HRhwU/src/interpolating_adjoint.jl:138
 [15] ODEFunction
    @ ~/.julia/packages/SciMLBase/EiBzT/src/scimlfunctions.jl:2335 [inlined]
 [16] initialize!
    @ ~/.julia/packages/OrdinaryDiffEqBDF/J0IGS/src/bdf_perform_step.jl:1220 [inlined]
 [17] __init(prob::ODEProblem{…}, alg::FBDF{…}, timeseries_init::Tuple{}, ts_init::Tuple{}, ks_init::Tuple{}, recompile::Type{…}; saveat::Vector{…}, tstops::Vector{…}, d_discontinuities::Tuple{}, save_idxs::Nothing, save_everystep::Bool, save_on::Bool, save_start::Bool, save_end::Nothing, callback::CallbackSet{…}, dense::Bool, calck::Bool, dt::Float64, dtmin::Float64, dtmax::Float64, force_dtmin::Bool, adaptive::Bool, gamma::Rational{…}, abstol::Float64, reltol::Float64, qmin::Rational{…}, qmax::Int64, qsteady_min::Rational{…}, qsteady_max::Rational{…}, beta1::Nothing, beta2::Nothing, qoldinit::Rational{…}, controller::Nothing, fullnormalize::Bool, failfactor::Int64, maxiters::Int64, internalnorm::typeof(DiffEqBase.ODE_DEFAULT_NORM), internalopnorm::typeof(opnorm), isoutofdomain::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), unstable_check::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), verbose::Bool, timeseries_errors::Bool, dense_errors::Bool, advance_to_tstop::Bool, stop_at_next_tstop::Bool, initialize_save::Bool, progress::Bool, progress_steps::Int64, progress_name::String, progress_message::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), progress_id::Symbol, userdata::Nothing, allow_extrapolation::Bool, initialize_integrator::Bool, alias_u0::Bool, alias_du0::Bool, initializealg::OrdinaryDiffEqCore.DefaultInit, kwargs::@Kwargs{})
    @ OrdinaryDiffEqCore ~/.julia/packages/OrdinaryDiffEqCore/55UVY/src/solve.jl:525
 [18] __init (repeats 5 times)
    @ ~/.julia/packages/OrdinaryDiffEqCore/55UVY/src/solve.jl:11 [inlined]
 [19] __solve(::ODEProblem{…}, ::FBDF{…}; kwargs::@Kwargs{})
    @ OrdinaryDiffEqCore ~/.julia/packages/OrdinaryDiffEqCore/55UVY/src/solve.jl:6
 [20] __solve
    @ ~/.julia/packages/OrdinaryDiffEqCore/55UVY/src/solve.jl:1 [inlined]
 [21] solve_call(_prob::ODEProblem{…}, args::FBDF{…}; merge_callbacks::Bool, kwargshandle::Nothing, kwargs::@Kwargs{})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/DdIeW/src/solve.jl:612
 [22] solve_call
    @ ~/.julia/packages/DiffEqBase/DdIeW/src/solve.jl:569 [inlined]
 [23] #solve_up#53
    @ ~/.julia/packages/DiffEqBase/DdIeW/src/solve.jl:1080 [inlined]
 [24] solve_up
    @ ~/.julia/packages/DiffEqBase/DdIeW/src/solve.jl:1066 [inlined]
 [25] #solve#51
    @ ~/.julia/packages/DiffEqBase/DdIeW/src/solve.jl:1003 [inlined]
 [26] _adjoint_sensitivities(sol::ODESolution{…}, sensealg::InterpolatingAdjoint{…}, alg::FBDF{…}; t::Vector{…}, dgdu_discrete::Function, dgdp_discrete::Nothing, dgdu_continuous::Nothing, dgdp_continuous::Nothing, g::Nothing, abstol::Float64, reltol::Float64, checkpoints::Vector{…}, corfunc_analytical::Nothing, callback::Nothing, kwargs::@Kwargs{})
    @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/HRhwU/src/sensitivity_interface.jl:448
 [27] _adjoint_sensitivities
    @ ~/.julia/packages/SciMLSensitivity/HRhwU/src/sensitivity_interface.jl:405 [inlined]
 [28] #adjoint_sensitivities#63
    @ ~/.julia/packages/SciMLSensitivity/HRhwU/src/sensitivity_interface.jl:401 [inlined]
 [29] (::SciMLSensitivity.var"#adjoint_sensitivity_backpass#313"{})(Δ::ODESolution{…})
    @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/HRhwU/src/concrete_solve.jl:627
 [30] ZBack
    @ ~/.julia/packages/Zygote/Tt5Gx/src/compiler/chainrules.jl:212 [inlined]
 [31] (::Zygote.var"#294#295"{})(Δ::ODESolution{…})
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/lib/lib.jl:206
 [32] (::Zygote.var"#2169#back#296"{})(Δ::ODESolution{…})
    @ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
 [33] #solve#51
    @ ~/.julia/packages/DiffEqBase/DdIeW/src/solve.jl:1003 [inlined]
 [34] (::Zygote.Pullback{…})(Δ::ODESolution{…})
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0
 [35] #294
    @ ~/.julia/packages/Zygote/Tt5Gx/src/lib/lib.jl:206 [inlined]
 [36] (::Zygote.var"#2169#back#296"{})(Δ::ODESolution{…})
    @ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
 [37] solve
    @ ~/.julia/packages/DiffEqBase/DdIeW/src/solve.jl:993 [inlined]
 [38] (::Zygote.Pullback{…})(Δ::ODESolution{…})
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0
 [39] f
    @ ~/Documents/Work/dev/DecapodeCalibrateDemos/GlacialFlow/glacialflow1D_calibrate_alloc.jl:119 [inlined]
 [40] (::Zygote.Pullback{Tuple{typeof(f), ComponentVector{Float64, Vector{Float64}, Tuple{Axis{}}}}, Any})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0
 [41] (::Zygote.var"#78#79"{Zygote.Pullback{Tuple{typeof(f), ComponentVector{}}, Any}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface.jl:91
 [42] gradient(f::Function, args::ComponentVector{Float64, Vector{Float64}, Tuple{Axis{…}}})
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface.jl:148
 [43] top-level scope
    @ ~/Documents/Work/dev/DecapodeCalibrateDemos/GlacialFlow/glacialflow1D_calibrate_alloc.jl:146
Some type information was truncated. Use `show(err)` to see complete types.
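The numbers in the error are consistent with a chunk-size mismatch, though nothing in the thread confirms this. Suppose PreallocationTools' FixedSizeDiffCache takes its default chunk size from the length of the array it wraps, and ForwardDiff's Jacobian takes its chunk size from the length of u. Then the vertex-sized state (100 entries) gets 12 partials while an edge-sized cache (99 entries, like stress_A = 4:102) gets 11. The Base-only sketch below reproduces that arithmetic and the resulting reinterpret failure. pick_chunk_size is a hypothetical hand-written stand-in for ForwardDiff's chunk heuristic, and FakeDual is a hypothetical stand-in for ForwardDiff.Dual. Neither is the library code.

```julia
# Hypothetical hand-written version of ForwardDiff's default chunk-size
# heuristic (threshold 12); not the library function itself.
function pick_chunk_size(input_length::Integer, threshold::Integer = 12)
    input_length <= threshold && return input_length
    nchunks = cld(input_length, threshold)  # number of chunks, rounded up
    return cld(input_length, nchunks)       # partials per chunk, rounded up
end

# Hypothetical stand-in for ForwardDiff.Dual{Tag,Float64,N}:
# one value plus N partial derivatives, all Float64.
struct FakeDual{N}
    value::Float64
    partials::NTuple{N,Float64}
end

n_vertices, n_edges = 100, 99            # sizes of the 1D mesh in the example
jac_chunk = pick_chunk_size(n_vertices)  # partials per Dual when differentiating w.r.t. u
cache_chunk = pick_chunk_size(n_edges)   # partials an edge-sized cache would be built with

# An edge-sized cache holding duals with `cache_chunk` partials...
cache = Vector{FakeDual{cache_chunk}}(undef, n_edges)

# ...cannot be viewed as duals with `jac_chunk` partials: 99 elements of
# 12 Float64s each (value + 11 partials) do not divide evenly into elements
# of 13 Float64s (value + 12 partials).
reinterpret_ok = try
    reinterpret(FakeDual{jac_chunk}, cache)
    true
catch err
    err isa ArgumentError || rethrow()
    false
end

println((jac_chunk, cache_chunk, reinterpret_ok))  # prints (12, 11, false)
```

If this reading is right, giving every cache the chunk size that the Jacobian uses would avoid this particular reinterpret. That has not been tested here.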
@jClugstor
Collaborator Author

Decapodes has two code generation options: one that uses PreallocationTools and its DiffCaches (gensim with preallocate=true, as in the example above), and one that doesn't.
We can differentiate through the code that doesn't use PreallocationTools with ForwardDiff.

The example above uses PreallocationTools, and ForwardDiff does not work on it.
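The following Base-only sketch shows why the preallocating code path needs DiffCaches in the first place. It uses two hypothetical right-hand-side functions, rhs_prealloc! and rhs_alloc!, which are not the code gensim actually emits. A scratch buffer allocated once as Float64 cannot hold non-Float64 numbers, while a buffer allocated with similar(u) inherits u's element type. Complex numbers stand in for ForwardDiff.Dual so the sketch runs without ForwardDiff. With real Duals the failure is typically a MethodError ("no method matching Float64(::ForwardDiff.Dual...)") rather than an InexactError.

```julia
# Hypothetical right-hand side in the preallocated style: the scratch
# buffer is created once, up front, with a fixed Float64 element type.
const scratch = zeros(Float64, 3)

function rhs_prealloc!(du, u)
    scratch .= 2 .* u   # each entry must convert to Float64
    du .= scratch .+ 1
    return du
end

# Hypothetical right-hand side in the allocating style: the scratch
# buffer is created on every call with the same element type as u.
function rhs_alloc!(du, u)
    scratch_local = similar(u)
    scratch_local .= 2 .* u
    du .= scratch_local .+ 1
    return du
end

# Complex numbers stand in for ForwardDiff.Dual here: a complex value with a
# nonzero imaginary part cannot be stored in a Float64 array.
u = [1.0 + 1.0im, 2.0 + 0.5im, 3.0 + 0.0im]
du = rhs_alloc!(similar(u), u)   # works: du == 2 .* u .+ 1

prealloc_failed = try
    rhs_prealloc!(similar(u), u)
    false
catch err
    err isa InexactError || rethrow()
    true
end
```

PreallocationTools' DiffCache types address this by storing both a plain buffer and a dual-typed buffer, with get_tmp(cache, u) returning the one that matches eltype(u). The failure in this issue is raised inside that get_tmp call (frame [3] in the trace).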

@jClugstor
Collaborator Author

@ChrisRackauckas I think the next steps for the Decapode calibration work are to make sure that ForwardDiff works in situations like this, and then to make sure that sensealg = EnzymeVJP() works.

For context, I have a heat equation Decapode example that uses PreallocationTools in which both ForwardDiff and sensealg = EnzymeVJP() work, but that example has only one parameter to differentiate with respect to.
