Skip to content

Commit

Permalink
Merge pull request #645 from AayushSabharwal/as/discrete-save
Browse files Browse the repository at this point in the history
feat: add discrete saving feature to ODESolution
  • Loading branch information
ChrisRackauckas authored Jul 23, 2024
2 parents 71a578d + fe68aaa commit d7edbe7
Show file tree
Hide file tree
Showing 10 changed files with 763 additions and 102 deletions.
6 changes: 4 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
Expronicon = "6b7a57c9-7cc1-4fdf-b7f5-e857abae3636"
FunctionWrappersWrappers = "77dc65aa-8811-40c2-897b-53d922fa7daf"
IteratorInterfaceExtensions = "82899510-4779-5014-852e-03e436cf321d"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down Expand Up @@ -62,6 +63,7 @@ DataFrames = "1.6"
Distributed = "1.10"
DocStringExtensions = "0.9"
EnumX = "1"
Expronicon = "0.8"
ForwardDiff = "0.10.36"
FunctionWrappersWrappers = "0.1.3"
IteratorInterfaceExtensions = "^1"
Expand All @@ -78,7 +80,7 @@ PyCall = "1.96"
PythonCall = "0.9.15"
RCall = "0.14.0"
RecipesBase = "1.3.4"
RecursiveArrayTools = "3.22.0"
RecursiveArrayTools = "3.26.0"
Reexport = "1"
RuntimeGeneratedFunctions = "0.5.12"
SciMLOperators = "0.3.7"
Expand All @@ -87,7 +89,7 @@ StableRNGs = "1.0"
StaticArrays = "1.7"
StaticArraysCore = "1.4"
Statistics = "1.10"
SymbolicIndexingInterface = "0.3.20"
SymbolicIndexingInterface = "0.3.26"
Tables = "1.11"
Zygote = "0.6.67"
julia = "1.10"
Expand Down
20 changes: 11 additions & 9 deletions ext/SciMLBaseZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ import SciMLStructures
N = length((size(dprob.u0)..., length(du)))
end
Δ′ = ODESolution{T, N}(du, nothing, nothing,
VA.t, VA.k, dprob, VA.alg, VA.interp, VA.dense, 0, VA.stats,
VA.t, VA.k, VA.discretes, dprob, VA.alg, VA.interp, VA.dense, 0, VA.stats,
VA.alg_choice, VA.retcode)
(Δ′, nothing, nothing)
end
Expand All @@ -60,7 +60,7 @@ end
T = eltype(eltype(VA.u))
N = ndims(VA)
Δ′ = ODESolution{T, N}(du, nothing, nothing,
VA.t, VA.k, dprob, VA.alg, VA.interp, VA.dense, 0, VA.stats,
VA.t, VA.k, VA.discretes, dprob, VA.alg, VA.interp, VA.dense, 0, VA.stats,
VA.alg_choice, VA.retcode)
(Δ′, nothing, nothing)
end
Expand Down Expand Up @@ -117,9 +117,11 @@ end
elseif i === nothing
throw(error("Zygote AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated."))
else
Δ′ = [[i == k ? Δ[j] : zero(x[1]) for k in 1:length(x)]
for (x, j) in zip(VA.u, 1:length(VA))]
(Δ′, nothing)
VA = recursivecopy(VA)
recursivefill!(VA, zero(eltype(VA)))
v = view(VA, i, ntuple(_ -> :, ndims(VA) - 1)...)
copyto!(v, Δ)
(VA, nothing)
end
end
VA[sym], ODESolution_getindex_pullback
Expand Down Expand Up @@ -172,15 +174,15 @@ end
VA[sym], ODESolution_getindex_pullback
end

@adjoint function ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14
}(u,
@adjoint function ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13
, T14, T15}(u,
args...) where {T1, T2, T3, T4, T5, T6, T7, T8,
T9, T10, T11, T12, T13, T14}
T9, T10, T11, T12, T13, T14, T15}
function ODESolutionAdjoint(ȳ)
(ȳ, ntuple(_ -> nothing, length(args))...)
end

ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14}(u, args...),
ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15}(u, args...),
ODESolutionAdjoint
end

Expand Down
4 changes: 4 additions & 0 deletions src/SciMLBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import RuntimeGeneratedFunctions
import EnumX
import ADTypes: AbstractADType
import Accessors: @set, @reset
using Expronicon.ADT: @match

using Reexport
using SciMLOperators
Expand Down Expand Up @@ -717,6 +718,7 @@ include("problems/problem_traits.jl")
include("problems/problem_interface.jl")
include("problems/optimization_problems.jl")

include("clock.jl")
include("solutions/basic_solutions.jl")
include("solutions/nonlinear_solutions.jl")
include("solutions/ode_solutions.jl")
Expand Down Expand Up @@ -835,4 +837,6 @@ export step!, deleteat!, addat!, get_tmp_cache,

export ContinuousCallback, DiscreteCallback, CallbackSet, VectorContinuousCallback

export Clocks, TimeDomain, is_discrete_time_domain, isclock, issolverstepclock, iscontinuous

end
92 changes: 92 additions & 0 deletions src/clock.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
module Clocks

export TimeDomain

using Expronicon.ADT: @adt, @match

@adt TimeDomain begin
Continuous
struct PeriodicClock
dt::Union{Nothing, Float64, Rational{Int}}
phase::Float64 = 0.0
end
SolverStepClock
end

Base.Broadcast.broadcastable(d::TimeDomain) = Ref(d)

end

using .Clocks

"""
Clock(dt)
Clock()
The default periodic clock with tick interval `dt`. If `dt` is left unspecified, it will
be inferred (if possible).
"""
Clock(dt::Union{<:Rational, Float64}; phase = 0.0) = PeriodicClock(dt, phase)
Clock(dt; phase = 0.0) = PeriodicClock(convert(Float64, dt), phase)
Clock(; phase = 0.0) = PeriodicClock(nothing, phase)

@doc """
SolverStepClock
A clock that ticks at each solver step (sometimes referred to as "continuous sample time").
This clock **does generally not have equidistant tick intervals**, instead, the tick
interval depends on the adaptive step-size selection of the continuous solver, as well as
any continuous event handling. If adaptivity of the solver is turned off and there are no
continuous events, the tick interval will be given by the fixed solver time step `dt`.
Due to possibly non-equidistant tick intervals, this clock should typically not be used with
discrete-time systems that assume a fixed sample time, such as PID controllers and digital
filters.
""" SolverStepClock

isclock(c) = @match c begin
PeriodicClock(_...) => true
_ => false
end

issolverstepclock(c) = @match c begin
&SolverStepClock => true
_ => false
end

iscontinuous(c) = @match c begin
&Continuous => true
_ => false
end

is_discrete_time_domain(c) = !iscontinuous(c)

function first_clock_tick_time(c, t0)
@match c begin
PeriodicClock(dt, _...) => ceil(t0 / dt) * dt
&SolverStepClock => t0
&Continuous => error("Continuous is not a discrete clock")
end
end

struct IndexedClock{I}
clock::TimeDomain
idx::I
end

Base.getindex(c::TimeDomain, idx) = IndexedClock(c, idx)

function canonicalize_indexed_clock(ic::IndexedClock, sol::AbstractTimeseriesSolution)
c = ic.clock

return @match c begin
PeriodicClock(dt, _...) => ceil(sol.prob.tspan[1] / dt) * dt .+ (ic.idx .- 1) .* dt
&SolverStepClock => begin
ssc_idx = findfirst(eachindex(sol.discretes)) do i
!isa(sol.discretes[i].t, AbstractRange)
end
sol.discretes[ssc_idx].t[ic.idx]
end
&Continuous => sol.t[ic.idx]
end
end
4 changes: 4 additions & 0 deletions src/remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,10 @@ end
anydict(d) = Dict{Any, Any}(d)
anydict() = Dict{Any, Any}()

function _updated_u0_p_internal(
prob, ::Missing, ::Missing; interpret_symbolicmap = true, use_defaults = false)
return state_values(prob), parameter_values(prob)
end
function _updated_u0_p_internal(
prob, ::Missing, p; interpret_symbolicmap = true, use_defaults = false)
u0 = state_values(prob)
Expand Down
Loading

0 comments on commit d7edbe7

Please sign in to comment.