From 13e87151ee08c157caa8729b43346b676d1670ad Mon Sep 17 00:00:00 2001 From: Josh Day Date: Wed, 19 Jul 2023 14:22:55 -0400 Subject: [PATCH] WIP Ensemble and some other fixes (#103) Co-authored-by: Christopher Rackauckas --- Project.toml | 2 +- src/SimulationService.jl | 97 ++++++++++++++++++----------- src/operations.jl | 131 ++++++++++++++++++++++++++++++++------- test/runtests.jl | 36 ++++++++++- 4 files changed, 202 insertions(+), 64 deletions(-) diff --git a/Project.toml b/Project.toml index 516f5cd..f918b51 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SimulationService" uuid = "e66378d9-a322-4933-8764-0ce0bcab4993" authors = ["Five Grant <5@fivegrant.com>"] -version = "0.14.5" +version = "0.15.0" [deps] AMQPClient = "79c8b4cd-a41a-55fa-907c-fab5288e1383" diff --git a/src/SimulationService.jl b/src/SimulationService.jl index ff58bc6..717a075 100644 --- a/src/SimulationService.jl +++ b/src/SimulationService.jl @@ -59,9 +59,8 @@ function __init__() if Threads.nthreads() == 1 @warn "SimulationService.jl expects `Threads.nthreads() > 1`. Use e.g. `julia --threads=auto`." end - simulation_api_spec_main = "https://raw.githubusercontent.com/DARPA-ASKEM/simulation-api-spec/main" - openapi_spec[] = YAML.load_file(download("$simulation_api_spec_main/openapi.yaml")) - simulation_schema[] = get_json("$simulation_api_spec_main/schemas/simulation.json") + openapi_spec[] = YAML.load_file(download("https://raw.githubusercontent.com/DARPA-ASKEM/simulation-api-spec/main/openapi.yaml")) + simulation_schema[] = get_json("https://raw.githubusercontent.com/DARPA-ASKEM/simulation-api-spec/main/schemas/simulation.json") petrinet_schema[] = get_json("https://raw.githubusercontent.com/DARPA-ASKEM/Model-Representations/main/petrinet/petrinet_schema.json") HOST[] = get(ENV, "SIMSERVICE_HOST", "0.0.0.0") @@ -97,11 +96,16 @@ function start!(; host=HOST[], port=PORT[], kw...) JobSchedulers.scheduler_start() JobSchedulers.set_scheduler(max_cpu=JobSchedulers.SCHEDULER_MAX_CPU, max_mem=0.5, update_second=0.05, max_job=5000) Oxygen.resetstate() - Oxygen.@get "/" health - Oxygen.@get "/status/{id}" job_status - Oxygen.@post "/{operation_name}" operation + + Oxygen.@get "/" health + Oxygen.@get "/status/{id}" job_status Oxygen.@post "/kill/{id}" job_kill + Oxygen.@post "/simulate" req -> operation(req, "simulate") + Oxygen.@post "/calibrate" req -> operation(req, "calibrate") + Oxygen.@post "/ensemble-simulate" req -> operation(req, "ensemble-simulate") + Oxygen.@post "/ensemble-calibrate" req -> operation(req, "ensemble-calibrate") + # For /docs Oxygen.mergeschema(openapi_spec[]) @@ -126,17 +130,6 @@ get_json(url::String) = JSON3.read(HTTP.get(url, json_header).body) timestamp() = Dates.format(now(), "yyyy-mm-ddTHH:MM:SS") -# Dump all the info we can get about a simulation `id` -function debug_data(id::String) - @assert ENABLE_TDS[] - data_service_model = DataServiceModel(id::String) - request_json = data_service_model.execution_payload - amr = get_model(data_service_model.execution_payload.model_config_id) - route = String(Dict(v => k for (k,v) in operation_to_dsm_type)[data_service_model.type]) - operation_request = OperationRequest(HTTP.Request("POST", "", [], JSON3.write(request_json)), route) - job_status = get_job_status(get_job(id)) - return (; request_json, amr, data_service_model, operation_request, job_status) -end #-----------------------------------------------------------------------------# job endpoints get_job(id::String) = JobSchedulers.job_query(jobhash(id)) @@ -170,6 +163,7 @@ function job_kill(::HTTP.Request, id::String) if isnothing(job) return NO_JOB else + # TODO: update simulation model in TDS with status="cancelled" JobSchedulers.cancel!(job) return HTTP.Response(200) end @@ -187,7 +181,7 @@ health(::HTTP.Request) = ( Base.@kwdef mutable struct OperationRequest obj::JSON3.Object = JSON3.Object() # untouched JSON from request sent by HMI id::String = "sciml-$(UUIDs.uuid4())" # matches DataServiceModel :id - operation::Symbol = :unknown # :simulate, :calibrate, etc. + route::String = "unknown" # :simulate, :calibrate, etc. model::Union{Nothing, JSON3.Object} = nothing # ASKEM Model Representation (AMR) models::Union{Nothing, Vector{JSON3.Object}} = nothing # Multiple models (in AMR) timespan::Union{Nothing, NTuple{2, Float64}} = nothing # (start, end) @@ -196,16 +190,16 @@ Base.@kwdef mutable struct OperationRequest end function Base.show(io::IO, o::OperationRequest) - println(io, "OperationRequest(id=$(repr(o.id)), operation=$(repr(o.operation)))") + println(io, "OperationRequest(id=$(repr(o.id)), route=$(repr(o.route)))") end -function OperationRequest(req::HTTP.Request, operation_name::String) +function OperationRequest(req::HTTP.Request, route::String) o = OperationRequest() - @info "[$(o.id)] OperationRequest recieved to route /$operation_name: $(String(copy(req.body)))" + @info "[$(o.id)] OperationRequest recieved to route /$route: $(String(copy(req.body)))" o.obj = JSON3.read(req.body) - o.operation = Symbol(operation_name) + o.route = route for (k,v) in o.obj - if !ENABLE_TDS[] && k in [:model_config_id, :model_config_ids, :dataset] + if !ENABLE_TDS[] && k in [:model_config_id, :model_config_ids, :dataset, :model_configs] @warn "TDS Disabled - ignoring key `$k` from request with id: $(repr(o.id))" continue end @@ -215,6 +209,9 @@ function OperationRequest(req::HTTP.Request, operation_name::String) k == :dataset ? (o.df = get_dataset(v)) : k == :model ? (o.model = v) : + # For ensemble, we get objects with {id, solution_mappings, weight} + k == :model_configs ? (o.models = [get_model(m.id) for m in v]) : + # For testing only: k == :local_model_configuration_file ? (o.model = JSON.read(v).configuration) : k == :local_model_file ? (o.model = JSON3.read(v)) : @@ -224,9 +221,11 @@ function OperationRequest(req::HTTP.Request, operation_name::String) return o end + + function solve(o::OperationRequest) callback = get_callback(o) - T = operations2type[o.operation] + T = route2operation_type[o.route] op = T(o) o.result = solve(op; callback) end @@ -259,23 +258,47 @@ end StructTypes.StructType(::Type{DataServiceModel}) = StructTypes.Mutable() # PIRACY to fix JSON3.read(str, DataServiceModel) +# TODO make upstream issue in JSON3 JSON3.Object(x::AbstractDict) = JSON3.read(JSON3.write(x)) -# Initialize a DataServiceModel -operation_to_dsm_type = Dict( - :simulate => "simulation", - :calibrate => "calibration_simulation", - :ensemble => "ensemble" +# translate route (in OpenAPI spec) to type (in TDS) +# route: https://raw.githubusercontent.com/DARPA-ASKEM/simulation-api-spec/main/openapi.yaml +# type: https://raw.githubusercontent.com/DARPA-ASKEM/simulation-api-spec/main/schemas/simulation.json +route2type = Dict( + "simulate" => "simulation", + "calibrate" => "calibration", + "ensemble-simulate" => "ensemble", + "ensemble-calibrate" => "ensemble" ) +# Initialize a DataServiceModel function DataServiceModel(o::OperationRequest) m = DataServiceModel() m.id = o.id - m.type = operation_to_dsm_type[o.operation] + m.type = route2type[o.route] m.execution_payload = o.obj return m end +#-----------------------------------------------------------------------------# debug_data +# For debugging: Try to recreate the OperationRequest from TDS data +# NOTE: TDS does not save which route was used... +function OperationRequest(m::DataServiceModel) + req = HTTP.Request("POST", "", [], JSON3.write(m.execution_payload)) + return OperationRequest(req, "DataServiceModel: $(m.type)") +end + +# Dump all the info we can get about a simulation `id` +function debug_data(id::String) + @assert ENABLE_TDS[] + data_service_model = DataServiceModel(id::String) + request_json = data_service_model.execution_payload + amr = get_model(data_service_model.execution_payload.model_config_id) + operation_request = OperationRequest(data_service_model) + job = get_job(id) + return (; request_json, amr, data_service_model, operation_request, job) +end + #-----------------------------------------------------------------------------# publish_to_rabbitmq # published as JSON3.write(content) function publish_to_rabbitmq(content) @@ -318,6 +341,8 @@ function get_dataset(obj::JSON3.Object) @info "`get_dataset` (dataset id=$(repr(obj.id))) rename! $k => $v" rename!(df, k => v) end + "timestep" in names(df) && rename!(df, "timestep" => "timestep") # hack to get df in our "schema" + @info "get_dataset (id=$(repr(obj.id))) with names: $(names(df))" return df end @@ -385,7 +410,7 @@ function complete(o::OperationRequest) tds_url = "$(TDS_URL[])/simulations/$(o.id)/upload-url?filename=$filename" s3_url = get_json(tds_url).url HTTP.put(s3_url, header; body=body) - update(o; status = "complete", completed_time = timestamp(), result_files = [s3_url]) + update(o; status = "complete", completed_time = timestamp(), result_files = [filename]) end @@ -404,9 +429,9 @@ last_job = Ref{JobSchedulers.Job}() # 5) Job finishes: We get url from TDS where we can store results: GET /simulations/$sim_id/upload-url?filename=result.csv # 6) We upload results to S3: PUT $url # 7) We update simulation in TDS(status="complete", complete_time=): PUT /simulations/$id -function operation(request::HTTP.Request, operation_name::String) - @info "Creating OperationRequest from POST to route $operation_name" - o = OperationRequest(request, operation_name) # 1, 3 +function operation(request::HTTP.Request, route::String) + @info "Creating OperationRequest from POST to route $route" + o = OperationRequest(request, route) # 1, 3 create(o) # 2 job = JobSchedulers.Job( @task begin @@ -423,7 +448,7 @@ function operation(request::HTTP.Request, operation_name::String) job.id = jobhash(o.id) last_operation[] = o last_job[] = job - @info "Submitting job..." + @info "Submitting job $(o.id)" JobSchedulers.submit!(job) body = JSON3.write((; simulation_id = o.id)) @@ -433,6 +458,6 @@ end #-----------------------------------------------------------------------------# operations.jl include("operations.jl") -include("precompile.jl") +get(ENV, "SIMSERVICE_PRECOMPILE", "true") == "true" && include("precompile.jl") end # module diff --git a/src/operations.jl b/src/operations.jl index 897ff82..2447d4c 100644 --- a/src/operations.jl +++ b/src/operations.jl @@ -3,6 +3,7 @@ # Things that extract info from AMR JSON # The AMR is the `model` field of an OperationRequest + # Get `ModelingToolkit.ODESystem` from AMR function amr_get(amr::JSON3.Object, ::Type{ODESystem}) @info "amr_get ODESystem" @@ -71,7 +72,21 @@ function amr_get(amr::JSON3.Object, sys::ODESystem, ::Val{:priors}) @info "Invalid distribution type! Distribution type was $(p.distribution.type)" end - dist = EasyModelAnalysis.Distributions.Uniform(p.distribution.parameters.minimum, p.distribution.parameters.maximum) + minval = if p.distribution.parameters.minimum isa Number + p.distribution.parameters.minimum + elseif p.distribution.parameters.minimum isa AbstractString + @info "String in distribution minimum: $(p.distribution.parameters.minimum)" + parse(Float64, p.distribution.parameters.minimum) + end + + maxval = if p.distribution.parameters.maximum isa Number + p.distribution.parameters.maximum + elseif p.distribution.parameters.maximum isa AbstractString + @info "String in distribution maximum: $(p.distribution.parameters.maximum)" + parse(Float64, p.distribution.parameters.maximum) + end + + dist = EasyModelAnalysis.Distributions.Uniform(minval, maxval) paramlist[findfirst(x->x==Symbol(p.id),namelist)] => dist end end @@ -84,7 +99,8 @@ function amr_get(df::DataFrame, sys::ODESystem, ::Val{:data}) statelist = states(sys) statenames = string.(statelist) statenames = [replace(nm, "(t)" => "") for nm in statenames] - tvals = df[:,"timestamp"] + + tvals = df[:, "timestamp"] map(statelist, statenames) do s,n s => (tvals,df[:,n]) @@ -111,7 +127,7 @@ function (o::IntermediateResults)(integrator) EasyModelAnalysis.DifferentialEquations.u_modified!(integrator, false) end -get_callback(o::OperationRequest) = DiscreteCallback((args...) -> true, IntermediateResults(o.id), +get_callback(o::OperationRequest) = DiscreteCallback((args...) -> true, IntermediateResults(o.id), save_positions = (false,false)) @@ -209,7 +225,7 @@ function solve(o::Calibrate; callback) init_params = Pair.(EasyModelAnalysis.ModelingToolkit.Num.(first.(o.priors)), Statistics.mean.(last.(o.priors))) fit = EasyModelAnalysis.datafit(prob, init_params, o.data) else - + init_params = Pair.(EasyModelAnalysis.ModelingToolkit.Num.(first.(o.priors)), tuple.(minimum.(last.(o.priors)), maximum.(last.(o.priors)))) fit = EasyModelAnalysis.global_datafit(prob, init_params, o.data) end @@ -229,37 +245,104 @@ function solve(o::Calibrate; callback) end #-----------------------------------------------------------------------------# Ensemble -struct Ensemble <: Operation - sys::Vector{ODESystem} - priors::Vector{Pair{Num,Any}} # Any = Distribution - train_datas::Any - ensem_datas::Any - t_forecast::Vector{Float64} - quantiles::Vector{Float64} +# joshday: What is different between simulate and calibrate for ensemble? + +struct Ensemble{T<:Operation} <: Operation + model_ids::Vector{String} + operations::Vector{T} + weights::Vector{Float64} + sol_mappings::Vector{JSON3.Object} +end + +function Ensemble{T}(o::OperationRequest) where {T} + model_ids = map(x -> x.id, o.obj.model_configs) + weights = map(x -> x.weight, o.obj.model_configs) + sol_mappings = map(x -> x.solution_mappings, o.obj.model_configs) + operations = map(o.models) do model + temp = OperationRequest() + temp.df = o.df + temp.timespan = o.timespan + temp.model = model + temp.obj = o.obj + T(temp) + end + Ensemble{T}(model_ids, operations, weights, sol_mappings) end -function Ensemble(o::OperationRequest) - sys = amr_get.(o.models, ODESystem) +function solve(o::Ensemble{Simulate}; callback) + systems = [sim.sys for sim in o.operations] + probs = ODEProblem.(systems, Ref([]), Ref(o.operations[1].timespan)) + enprob = EMA.EnsembleProblem(probs) + sol = solve(enprob; saveat = 1); + weights = [0.2, 0.5, 0.3] + data = [x => vec(sum(stack(o.weights .* sol[:,x]), dims = 2)) for x in error("What goes here?")] end -function solve(o::Ensemble; callback) - probs = [ODEProblem(s, [], o.timespan) for s in sys] - ps = [[β => Uniform(0.01, 10.0), γ => Uniform(0.01, 10.0)] for i in 1:3] - datas = [data_train,data_train,data_train] - enprobs = bayesian_ensemble(probs, ps, datas) - ensem_weights = ensemble_weights(sol, data_ensem) - forecast_probs = [remake(enprobs.prob[i]; tspan = (t_train[1],t_forecast[end])) for i in 1:length(enprobs.prob)] - fit_enprob = EnsembleProblem(forecast_probs) - sol = solve(fit_enprob; saveat = o.t_forecast); +function solve(o::Ensemble{Calibrate}; callback) + EMA = EasyModelAnalysis + probs = [ODEProblem(cal.sys, [], o.timespan) for cal in o.operations] + error("TODO") + - soldata = DataFrame([sol.t;Matrix(sol[names])']) + # probs = [ODEProblem(s, [], o.timespan) for s in sys] + # ps = [[β => Uniform(0.01, 10.0), γ => Uniform(0.01, 10.0)] for i in 1:3] + # datas = [data_train,data_train,data_train] + # enprobs = EMA.bayesian_ensemble(probs, ps, datas) + # ensem_weights = EMA.ensemble_weights(sol, data_ensem) + + # forecast_probs = [EMA.remake(enprobs.prob[i]; tspan = (t_train[1],t_forecast[end])) for i in 1:length(enprobs.prob)] + # fit_enprob = EMA.EnsembleProblem(forecast_probs) + # sol = solve(fit_enprob; saveat = o.t_forecast); + + # soldata = DataFrame([sol.t; Matrix(sol[names])']) # Requires https://github.com/SciML/SciMLBase.jl/pull/467 # weighted_ensem = WeightedEnsembleSolution(sol, ensem_weights; quantiles = o.quantiles) # df = DataFrame(weighted_ensem) # df, soldata end + + + +# struct Ensemble <: Operation +# sys::Vector{ODESystem} +# priors::Vector{Pair{Num,Any}} # Any = Distribution +# train_datas::Any +# ensem_datas::Any +# t_forecast::Vector{Float64} +# quantiles::Vector{Float64} +# end + +# function Ensemble(o::OperationRequest) +# sys = amr_get.(o.models, ODESystem) +# end + +# function solve(o::Ensemble; callback) +# EMA = EasyModelAnalysis +# probs = [ODEProblem(s, [], o.timespan) for s in sys] +# ps = [[β => Uniform(0.01, 10.0), γ => Uniform(0.01, 10.0)] for i in 1:3] +# datas = [data_train,data_train,data_train] +# enprobs = EMA.bayesian_ensemble(probs, ps, datas) +# ensem_weights = EMA.ensemble_weights(sol, data_ensem) + +# forecast_probs = [EMA.remake(enprobs.prob[i]; tspan = (t_train[1],t_forecast[end])) for i in 1:length(enprobs.prob)] +# fit_enprob = EMA.EnsembleProblem(forecast_probs) +# sol = solve(fit_enprob; saveat = o.t_forecast); + +# soldata = DataFrame([sol.t; Matrix(sol[names])']) + +# # Requires https://github.com/SciML/SciMLBase.jl/pull/467 +# # weighted_ensem = WeightedEnsembleSolution(sol, ensem_weights; quantiles = o.quantiles) +# # df = DataFrame(weighted_ensem) +# # df, soldata +# end + #-----------------------------------------------------------------------------# All operations # :simulate => Simulate, etc. -const operations2type = Dict(Symbol(lowercase(string(T.name.name))) => T for T in subtypes(Operation)) +const route2operation_type = Dict( + "simulate" => Simulate, + "calibrate" => Calibrate, + "ensemble-simulate" => Ensemble{Simulate}, + "ensemble-calibrate" => Ensemble{Calibrate} +) diff --git a/test/runtests.jl b/test/runtests.jl index 23b44e6..c7fd4ef 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -103,7 +103,7 @@ end @test m.engine == "sciml" @test m.id == "" - o = OperationRequest(operation = :simulate) + o = OperationRequest(route = "simulate") m2 = DataServiceModel(o) # OperationRequest constructor with dummy HTTP.Request @@ -159,6 +159,36 @@ end @test names(dfsim) == vcat("timestamp",string.(statenames)) @test names(dfparam) == string.(parameters(sys)) end + @testset "Ensemble" begin + json_url = "https://raw.githubusercontent.com/DARPA-ASKEM/Model-Representations/main/petrinet/examples/sir.json" + amr = SimulationService.get_json(json_url) + + obj = ( + model_configs = map(1:4) do i + (id="model_config_id_$i", weight = i / sum(1:4), solution_mappings = (any_generic = "I", name = "R", s = "S")) + end, + models = [amr for _ in 1:4], + timespan = (start = 0, var"end" = 40), + engine = "sciml", + extra = (; num_samples = 40) + ) + + body = JSON3.write(obj) + + # create ensemble-simulte + o = OperationRequest() + o.route = "ensemble-simulate" + o.obj = JSON3.read(JSON3.write(obj)) + o.models = [amr for _ in 1:4] + o.timespan = (0, 30) + en = Ensemble{Simulate}(o) + + # create ensemble-calibrate + # o = OperationRequest() + # o.route = "ensemble-calibrate" + # json = JSON3.read(here("examples", "sir_calibrate", "sir_calibrate_request"), Dict) + # delete!(json, "modelConfigId") + end @testset "Real Calibrate Payload" begin file = here("examples", "sir_calibrate", "sir.json") @@ -171,10 +201,10 @@ end num_iterations = 100 calibrate_method = "global" ode_method = nothing - + o = SimulationService.Calibrate(sys, (0.0, 89.0), priors, data, num_chains, num_iterations, calibrate_method, ode_method) dfsim, dfparam = SimulationService.solve(o; callback = nothing) - + statenames = [states(o.sys);getproperty.(observed(o.sys), :lhs)] @test names(dfsim) == vcat("timestamp",string.(statenames)) @test names(dfparam) == ["beta", "gamma"]