From 458854a3eca285b63bb209b8142b33acda32c5d9 Mon Sep 17 00:00:00 2001 From: Jadon Clugston <34165782+jClugstor@users.noreply.github.com> Date: Tue, 24 Sep 2024 17:27:38 -0400 Subject: [PATCH] Rudimentary progress tracking for simulate (#175) --- src/SimulationService.jl | 2 +- src/operations.jl | 30 ++++++++++++++++-------------- test/runtests.jl | 4 +++- 3 files changed, 20 insertions(+), 16 deletions(-) diff --git a/src/SimulationService.jl b/src/SimulationService.jl index 17a44bf..6f6e84d 100644 --- a/src/SimulationService.jl +++ b/src/SimulationService.jl @@ -431,7 +431,7 @@ end function publish_to_rabbitmq(content) if !RABBITMQ_ENABLED[] # stop printing content for now, getting to be too much - @warn "RabbitMQ disabled - `publish_to_rabbitmq`" # with content $(JSON3.write(content))" + @warn "RabbitMQ disabled - `publish_to_rabbitmq`" #with content $(JSON3.write(content))" return content end json = Vector{UInt8}(codeunits(JSON3.write(content))) diff --git a/src/operations.jl b/src/operations.jl index 354857c..fba0f16 100644 --- a/src/operations.jl +++ b/src/operations.jl @@ -183,19 +183,6 @@ mutable struct IntermediateResults new(0, every, id, 0) end end - -function (o::IntermediateResults)(integrator) - (; iter, f, t, u, p) = integrator - if o.last_callback + o.every == iter - o.last_callback = iter - state_dict = Dict(unknowns(f.sys) .=> u) - param_dict = Dict(parameters(f.sys) .=> p) - publish_to_rabbitmq(; iter=iter, state=state_dict, params = param_dict, id=o.id, - retcode=SciMLBase.check_error(integrator)) - end - EasyModelAnalysis.DifferentialEquations.u_modified!(integrator, false) -end - # Intermediate results functor for calibrate function (o::IntermediateResults)(state,loss_val, ode_sol, ts) if o.last_callback + o.every == o.iter @@ -234,10 +221,25 @@ function get_callback(o::OperationRequest, ::Type{Simulate}) DiscreteCallback((args...) -> true, IntermediateResults(o.id,every = 10)) end +function (o::IntermediateResults)(integrator) + (; iter, f, t, u, p, sol) = integrator + t_end = sol.prob.tspan[2] + percent = round((t/t_end)*100.0, digits = 2) + if o.last_callback + o.every == iter + o.last_callback = iter + #state_dict = Dict(states(f.sys) .=> u) + #param_dict = Dict(parameters(f.sys) .=> p) + publish_to_rabbitmq(;id=o.id, + retcode=SciMLBase.check_error(integrator), percent = percent) + end + EasyModelAnalysis.DifferentialEquations.u_modified!(integrator, false) +end + + # callback for Simulate requests function solve(op::Simulate; callback) prob = ODEProblem(op.sys, [], op.timespan) - sol = solve(prob; progress = true, progress_steps = 1, saveat=1, callback = nothing) + sol = solve(prob; saveat=1, callback = callback) @info "Timesteps returned are: $(sol.t)" dataframe_with_observables(sol) end diff --git a/test/runtests.jl b/test/runtests.jl index 1eb5a2e..ab8d403 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -206,7 +206,9 @@ end obj = SimulationService.get_json(json_url).configuration sys = SimulationService.amr_get(obj, ODESystem) op = Simulate(sys, (0.0, 99.0)) - df = solve(op; callback = nothing) + call_op = OperationRequest() # to test callback + call_op.id = "1" + df = solve(op; callback = SimulationService.get_callback(call_op,SimulationService.Simulate)) @test df isa DataFrame @test extrema(df.timestamp) == (0.0, 99.0) end