Skip to content

Commit

Permalink
test: fix SCCNonlinearProblem indexing test
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Feb 28, 2025
1 parent 3d952e4 commit d931d75
Showing 1 changed file with 3 additions and 60 deletions.
63 changes: 3 additions & 60 deletions test/downstream/problem_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -294,8 +294,6 @@ prob = SteadyStateProblem(osys, u0, ps)
getsym(prob, (:X, :X2))(prob) == (0.1, 0.2)

@testset "SCCNonlinearProblem" begin
# TODO: Rewrite this example when the MTK codegen is merged

function fullf!(du, u, p)
du[1] = cos(u[2]) - u[1]
du[2] = sin(u[1] + u[2]) + u[2]
Expand All @@ -311,63 +309,10 @@ prob = SteadyStateProblem(osys, u0, ps)
@parameters p = 1.0
eqs = Any[0 for _ in 1:8]
fullf!(eqs, u, [p])
@named model = NonlinearSystem(0 .~ eqs, [u...], [p])
model = complete(model; split = false)

cache = zeros(4)
cache[1] = 1.0

function f1!(du, u, p)
du[1] = cos(u[2]) - u[1]
du[2] = sin(u[1] + u[2]) + u[2]
end
explicitfun1(cache, sols) = nothing

f1!(eqs, u2[1:2], [p])
@named subsys1 = NonlinearSystem(0 .~ eqs[1:2], [u2[1:2]...], [p])
subsys1 = complete(subsys1; split = false)
prob1 = NonlinearProblem(
NonlinearFunction{true, SciMLBase.NoSpecialize}(f1!; sys = subsys1),
zeros(2), copy(cache))

function f2!(du, u, p)
du[1] = 2u[2] + u[1] + p[1]
du[2] = u[3]^2 + u[2]
du[3] = u[1]^2 + u[3]
end
explicitfun2(cache, sols) = nothing

f2!(eqs, u2[3:5], [p])
@named subsys2 = NonlinearSystem(0 .~ eqs[1:3], [u2[3:5]...], [p])
subsys2 = complete(subsys2; split = false)
prob2 = NonlinearProblem(
NonlinearFunction{true, SciMLBase.NoSpecialize}(f2!; sys = subsys2),
zeros(3), copy(cache))

function f3!(du, u, p)
du[1] = p[2] + 2.0u[1] + 2.5u[2] + 1.5u[3]
du[2] = p[3] + 4.0u[1] - 1.5u[2] + 1.5u[3]
du[3] = p[4] + +u[1] - u[2] - u[3]
end
function explicitfun3(cache, sols)
cache[2] = sols[1][1] + sols[1][2] + sols[2][1] + sols[2][2] + sols[2][3]
cache[3] = sols[1][1] + sols[1][2] + sols[2][1] + 2.0sols[2][2] + sols[2][3]
cache[4] = sols[1][1] + 2.0sols[1][2] + 3.0sols[2][1] + 5.0sols[2][2] +
6.0sols[2][3]
end

@parameters tmpvar[1:3]
f3!(eqs, u2[6:8], [p, tmpvar...])
@named subsys3 = NonlinearSystem(0 .~ eqs[1:3], [u2[6:8]...], [p, tmpvar...])
subsys3 = complete(subsys3; split = false)
prob3 = NonlinearProblem(
NonlinearFunction{true, SciMLBase.NoSpecialize}(f3!; sys = subsys3),
zeros(3), copy(cache))
@mtkbuild model = NonlinearSystem(0 .~ eqs, [u...], [p])

prob = NonlinearProblem(model, [])
sccprob = SciMLBase.SCCNonlinearProblem([prob1, prob2, prob3],
SciMLBase.Void{Any}.([explicitfun1, explicitfun2, explicitfun3]),
copy(cache); sys = model)
sccprob = SCCNonlinearProblem(model, [])

for sym in [u, u..., u[2] + u[3], p * u[1] + u[2]]
@test prob[sym] sccprob[sym]
Expand All @@ -380,12 +325,10 @@ prob = SteadyStateProblem(osys, u0, ps)
for (i, sym) in enumerate([u[1], u[3], u[6]])
sccprob[sym] = 0.5i
@test sccprob[sym] 0.5i
@test sccprob.probs[i].u0[1] 0.5i
end
sccprob.ps[p] = 2.5
@test sccprob.ps[p] 2.5
@test sccprob.p[1] 2.5
for scc in sccprob.probs
@test parameter_values(scc)[1] 2.5
@test scc.ps[p] 2.5
end
end

0 comments on commit d931d75

Please sign in to comment.