Skip to content

Commit

Permalink
feat: save funman model configuration (#5050)
Browse files Browse the repository at this point in the history
  • Loading branch information
shawnyama authored Oct 7, 2024
1 parent d78faf6 commit 019c2fc
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,11 @@ import InputSwitch from 'primevue/inputswitch';
import type { FunmanInterval, FunmanPostQueriesRequest, Model, ModelParameter, TimeSpan } from '@/types/Types';
import { makeQueries } from '@/services/models/funman-service';
import { WorkflowNode, WorkflowOutput } from '@/types/workflow';
import { getAsConfiguredModel, getModelConfigurationById } from '@/services/model-configurations';
import {
getAsConfiguredModel,
getModelConfigurationById,
getModelIdFromModelConfigurationId
} from '@/services/model-configurations';
import { useToastService } from '@/services/toast';
import { pythonInstance } from '@/python/PyodideController';
import TeraFunmanOutput from '@/components/workflow/ops/funman/tera-funman-output.vue';
Expand Down Expand Up @@ -213,6 +217,7 @@ const MAX = 99999999999;
const toast = useToastService();
const validateParametersToolTip =
'Validate the configuration of the model using functional model analysis (FUNMAN). \n \n The parameter space regions defined by the model configuration are evaluated to satisfactory or unsatisfactory depending on whether they generate model outputs that are within a given set of time-dependent constraints';
let originalModelId = '';
const showSpinner = ref(false);
const isSliderOpen = ref(true);
Expand All @@ -223,7 +228,7 @@ const requestStepList = computed(() => getStepList());
const requestStepListString = computed(() => requestStepList.value.join()); // Just used to display. dont like this but need to be quick
const requestParameters = ref<any[]>([]);
const model = ref<Model | null>();
const configuredModel = ref<Model | null>();
const stateIds = ref<string[]>([]);
const parameterIds = ref<string[]>([]);
Expand Down Expand Up @@ -261,7 +266,7 @@ const onToggleVariableOfInterest = (vals: string[]) => {
};
const runMakeQuery = async () => {
if (!model.value) {
if (!configuredModel.value) {
toast.error('', 'No Model provided for request');
return;
}
Expand Down Expand Up @@ -323,7 +328,7 @@ const runMakeQuery = async () => {
.filter(Boolean); // Removes falsey values
const request: FunmanPostQueriesRequest = {
model: model.value,
model: configuredModel.value,
request: {
constraints,
parameters: requestParameters.value,
Expand All @@ -336,14 +341,14 @@ const runMakeQuery = async () => {
config: {
use_compartmental_constraints: knobs.value.compartmentalConstraint.isActive,
normalization_constant:
knobs.value.compartmentalConstraint.isActive && model.value.semantics ? parseFloat(mass.value) : 1,
knobs.value.compartmentalConstraint.isActive && configuredModel.value.semantics ? parseFloat(mass.value) : 1,
normalize: false,
tolerance: knobs.value.tolerance
}
}
};
const response = await makeQueries(request);
const response = await makeQueries(request, originalModelId);
// Setup the in-progress id
const state = _.cloneDeep(props.node.state);
Expand Down Expand Up @@ -414,11 +419,12 @@ const initialize = async () => {
const modelConfigurationId = props.node.inputs[0].value?.[0];
if (!modelConfigurationId) return;
const modelConfiguration = await getModelConfigurationById(modelConfigurationId);
model.value = await getAsConfiguredModel(modelConfiguration);
configuredModel.value = await getAsConfiguredModel(modelConfiguration);
originalModelId = await getModelIdFromModelConfigurationId(modelConfigurationId);
};
const setModelOptions = async () => {
if (!model.value) return;
if (!configuredModel.value) return;
const renameReserved = (v: string) => {
const reserved = ['lambda'];
Expand All @@ -427,7 +433,7 @@ const setModelOptions = async () => {
};
// Calculate mass
const semantics = model.value.semantics;
const semantics = configuredModel.value.semantics;
const modelInitials = semantics?.ode.initials;
const modelMassExpression = modelInitials?.map((d) => renameReserved(d.expression)).join(' + ');
Expand All @@ -439,7 +445,7 @@ const setModelOptions = async () => {
mass.value = await pythonInstance.evaluateExpression(modelMassExpression as string, parametersMap);
const ode = model.value.semantics?.ode;
const ode = configuredModel.value.semantics?.ode;
if (ode) {
if (ode.initials) stateIds.value = ode.initials.map((s) => s.target);
if (ode.parameters) parameterIds.value = ode.parameters.map((d) => d.id);
Expand All @@ -452,8 +458,8 @@ const setModelOptions = async () => {
knobs.value.tolerance = state.tolerance;
knobs.value.compartmentalConstraint = state.compartmentalConstraint;
if (model.value.semantics?.ode.parameters) {
setRequestParameters(model.value.semantics?.ode.parameters);
if (configuredModel.value.semantics?.ode.parameters) {
setRequestParameters(configuredModel.value.semantics?.ode.parameters);
variablesOfInterest.value = requestParameters.value.filter((d: any) => d.label === 'all').map((d: any) => d.name);
} else {
toast.error('', 'Provided model has no parameters');
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ export interface RenderOptions {
click?: Function;
}

export async function makeQueries(body: FunmanPostQueriesRequest) {
export async function makeQueries(body: FunmanPostQueriesRequest, modelId: string) {
try {
const resp = await API.post('/funman/queries', body);
const resp = await API.post('/funman/queries', body, { params: { 'model-id': modelId } });
const output = resp.data;
return output;
} catch (err) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ void init() {
)
public ResponseEntity<Simulation> createValidationRequest(
@RequestBody final JsonNode input,
@RequestParam(name = "model-id", required = true) final UUID modelId,
@RequestParam(name = "project-id", required = false) final UUID projectId
) {
final Schema.Permission permission = projectService.checkPermissionCanWrite(
Expand Down Expand Up @@ -113,6 +114,7 @@ public ResponseEntity<Simulation> createValidationRequest(

final ValidateModelConfigHandler.Properties props = new ValidateModelConfigHandler.Properties();
props.setProjectId(projectId);
props.setModelId(modelId);
props.setSimulationId(newSimulation.getId());
taskRequest.setAdditionalProperties(props);

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package software.uncharted.terarium.hmiserver.models.dataservice.modelparts;

import com.fasterxml.jackson.annotation.JsonAlias;
import com.fasterxml.jackson.annotation.JsonProperty;
import java.io.Serial;
import java.io.Serializable;
Expand All @@ -22,6 +23,7 @@ public class ModelHeader extends SupportAdditionalProperties implements Serializ
private String name;

@JsonProperty("schema")
@JsonAlias("schema_") // Funman returns schema_ instead of schema
private String modelSchema;

@TSOptional
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package software.uncharted.terarium.hmiserver.service.tasks;

import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ObjectNode;
import java.util.ArrayList;
import java.util.Optional;
import java.util.UUID;
Expand All @@ -12,9 +14,13 @@
import org.apache.http.entity.ByteArrayEntity;
import org.apache.http.entity.ContentType;
import org.springframework.stereotype.Component;
import software.uncharted.terarium.hmiserver.models.dataservice.model.Model;
import software.uncharted.terarium.hmiserver.models.dataservice.model.configurations.ModelConfiguration;
import software.uncharted.terarium.hmiserver.models.dataservice.simulation.ProgressState;
import software.uncharted.terarium.hmiserver.models.dataservice.simulation.Simulation;
import software.uncharted.terarium.hmiserver.models.task.TaskResponse;
import software.uncharted.terarium.hmiserver.service.data.ModelConfigurationService;
import software.uncharted.terarium.hmiserver.service.data.ModelService;
import software.uncharted.terarium.hmiserver.service.data.SimulationService;

@Component
Expand All @@ -26,6 +32,8 @@ public class ValidateModelConfigHandler extends TaskResponseHandler {

private final ObjectMapper objectMapper;
private final SimulationService simulationService;
private final ModelService modelService;
private final ModelConfigurationService modelConfigurationService;

@Override
public String getName() {
Expand All @@ -36,6 +44,7 @@ public String getName() {
public static class Properties {

UUID projectId;
UUID modelId;
UUID simulationId;
}

Expand Down Expand Up @@ -86,6 +95,34 @@ public TaskResponse onSuccess(final TaskResponse resp) {
// Retrive final result json
final JsonNode result = objectMapper.readValue(resp.getOutput(), JsonNode.class);

// Save contracted model/model configuration
// The response is stringified JSON so convert it to an object to access the contracted_model and clean it up
final String responseString = result.get("response").asText();
ObjectNode contractedModelObject = (ObjectNode) objectMapper.readTree(responseString).get("contracted_model");
// Only use contracted model to create model configuration, no need to save it
final Model contractedModel = objectMapper.convertValue(contractedModelObject, Model.class);

final ModelConfiguration contractedModelConfiguration = ModelConfigurationService.modelConfigurationFromAMR(
contractedModel,
"Validated " + contractedModel.getName(),
contractedModel.getDescription()
);
contractedModelConfiguration.setModelId(props.modelId); // Config should be linked to the original model

// Save validated model configuration
final ModelConfiguration createdModelConfiguration = modelConfigurationService.createAsset(
contractedModelConfiguration,
props.projectId,
ASSUME_WRITE_PERMISSION_ON_BEHALF_OF_USER
);

// Add model configuration to the response
JsonNode responseNode = objectMapper.readTree(responseString);
JsonNode modelConfigNode = objectMapper.valueToTree(createdModelConfiguration);
((ObjectNode) responseNode).set("modelConfiguration", modelConfigNode);
String updatedResponseString = objectMapper.writeValueAsString(responseNode);
((ObjectNode) result).put("response", updatedResponseString);

// Upload final result into S3
final byte[] bytes = objectMapper.writeValueAsBytes(result.get("response"));
final HttpEntity fileEntity = new ByteArrayEntity(bytes, ContentType.TEXT_PLAIN);
Expand Down

0 comments on commit 019c2fc

Please sign in to comment.