Skip to content

Fix GaussAdjoint with callbacks #1060

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
May 21, 2025
Merged

Conversation

ChrisRackauckas
Copy link
Member

No description provided.

@Vaibhavdixit02
Copy link
Member

Ah, this was what #1034 was trying to address lol

@ChrisRackauckas
Copy link
Member Author

That would've been good to know 😅

@ChrisRackauckas ChrisRackauckas force-pushed the gaussadjoint_callbacks branch from 07d6bdf to 33a57d2 Compare June 7, 2024 01:51
@ChrisRackauckas
Copy link
Member Author

Bump?

@frankschae frankschae force-pushed the gaussadjoint_callbacks branch from f248d52 to ed0f98f Compare June 29, 2024 05:34
@frankschae
Copy link
Member

hmm ... it's probably close but it currently fails for the callback where the correction should be 0:

g(sol) = sum(sol)
function dg!(out, u, p, t, i)
    (out .= 1)
end
@testset "callbacks with no effect" begin
    condition(u, t, integrator) = t == 5
    affect!(integrator) = integrator.u[1] += 0.0
    cb = DiscreteCallback(condition, affect!, save_positions = (false, false))
    tstops = [5.0]
    test_discrete_callback(cb, tstops, g, dg!)
end

I think it's the line:

if sensealg isa GaussAdjoint
            @assert integrator.f.f isa ODEGaussAdjointSensitivityFunction
            @show integrator.f.f.integrating_cb.affect!.integrand_values.integrand dgrad
            integrator.f.f.integrating_cb.affect!.integrand_values.integrand .= dgrad

in callback_tracking.jl, whch for the example above prints:

integrator.f.f.integrating_cb.affect!.integrand_values.integrand = [4.318834073956617, -23.50595577533703, 3.507883353046969, -26.70725798446173]
dgrad = [-0.0, -0.0, -0.0, -0.0]

@ChrisRackauckas
Copy link
Member Author

Is that with the VJPs as Enzyme or ReverseDiff? IIUC it always defaults to ReverseDiff right now?

@ChrisRackauckas
Copy link
Member Author

The incorrect values I think stem from missing make_zero!s, but it's currently waiting on an Enzyme tag from @wsmoses before #1067 finishes and then this can get retested.

@ChrisRackauckas
Copy link
Member Author

@jClugstor did your PR look at GaussAdjoint?

@jClugstor
Copy link
Member

No, but there is a '@test_broken' for 'GaussAdjoint' with a callback in that PR, so if callbacks get fixed that will need to change.

@ChrisRackauckas ChrisRackauckas force-pushed the gaussadjoint_callbacks branch 4 times, most recently from 11e61fb to 957fc6b Compare May 19, 2025 02:40
@ChrisRackauckas ChrisRackauckas force-pushed the gaussadjoint_callbacks branch 2 times, most recently from cfea3e5 to 8433226 Compare May 20, 2025 16:00
@ChrisRackauckas ChrisRackauckas force-pushed the gaussadjoint_callbacks branch from 4b84fa2 to 7904c77 Compare May 21, 2025 08:50
@ChrisRackauckas ChrisRackauckas merged commit 83e4da5 into master May 21, 2025
26 of 30 checks passed
@ChrisRackauckas ChrisRackauckas deleted the gaussadjoint_callbacks branch May 21, 2025 10:17
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants