Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Apply metadata from the pre-MIRA model to the post-MIRA model if the name of variables matches #5181

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -94,29 +94,29 @@
</template>

<script setup lang="ts">
import { onMounted, onUnmounted, ref, watch } from 'vue';
import { cloneDeep, debounce, isEqual, last } from 'lodash';
import '@/ace-config';
import TeraDrilldownPreview from '@/components/drilldown/tera-drilldown-preview.vue';
import TeraDrilldownSection from '@/components/drilldown/tera-drilldown-section.vue';
import TeraDrilldown from '@/components/drilldown/tera-drilldown.vue';
import TeraNotebookError from '@/components/drilldown/tera-notebook-error.vue';
import TeraNotebookJupyterInput from '@/components/llm/tera-notebook-jupyter-input.vue';
import TeraModel from '@/components/model/tera-model.vue';
import TeraOperatorPlaceholder from '@/components/operator/tera-operator-placeholder.vue';
import TeraProgressSpinner from '@/components/widgets/tera-progress-spinner.vue';
import TeraStratificationGroupForm from '@/components/workflow/ops/stratify-mira/tera-stratification-group-form.vue';
import TeraNotebookJupyterInput from '@/components/llm/tera-notebook-jupyter-input.vue';
import { createModel, getModel } from '@/services/model';
import { WorkflowNode, OperatorStatus } from '@/types/workflow';
import { KernelSessionManager } from '@/services/jupyter';
import { createModelFromOld, getModel } from '@/services/model';
import { getModelIdFromModelConfigurationId } from '@/services/model-configurations';
import type { Model } from '@/types/Types';
import { AMRSchemaNames } from '@/types/common';
import { OperatorStatus, WorkflowNode } from '@/types/workflow';
import { logger } from '@/utils/logger';
import { cloneDeep, debounce, isEqual, last } from 'lodash';
import Button from 'primevue/button';
import { v4 as uuidv4 } from 'uuid';
import { onMounted, onUnmounted, ref, watch } from 'vue';
import { VAceEditor } from 'vue3-ace-editor';
import { VAceEditorInstance } from 'vue3-ace-editor/types';
import '@/ace-config';
import TeraNotebookError from '@/components/drilldown/tera-notebook-error.vue';
import type { Model } from '@/types/Types';
import { AMRSchemaNames } from '@/types/common';
import { getModelIdFromModelConfigurationId } from '@/services/model-configurations';
import TeraProgressSpinner from '@/components/widgets/tera-progress-spinner.vue';
import { KernelSessionManager } from '@/services/jupyter';
import TeraModel from '@/components/model/tera-model.vue';
import { blankStratifyGroup, StratifyGroup, StratifyOperationStateMira } from './stratify-mira-operation';

const props = defineProps<{
Expand Down Expand Up @@ -234,6 +234,8 @@ const stratifyModel = () => {
};

const handleModelPreview = async (data: any) => {
if (!amr.value) return;

const amrResponse = data.content['application/json'] as Model;
isStratifyInProgress.value = false;
if (!amrResponse) {
Expand All @@ -250,7 +252,7 @@ const handleModelPreview = async (data: any) => {
amrResponse.header.name = newName;

// Create output
const modelData = await createModel(amrResponse);
const modelData = await createModelFromOld(amr.value, amrResponse);
if (!modelData) return;
outputAmr.value = modelData;

Expand Down
9 changes: 9 additions & 0 deletions packages/client/hmi-client/src/services/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,15 @@ export async function createModel(model: Model): Promise<Model | null> {
return response?.data ?? null;
}

export async function createModelFromOld(oldModel: Model, newModel: Model): Promise<Model | null> {
delete newModel.id;
const response = await API.post(`/models/new-from-old`, {
newModel,
oldModel
});
return response?.data ?? null;
}

export async function createModelAndModelConfig(file: File, progress?: Ref<number>): Promise<Model | null> {
const formData = new FormData();
formData.append('file', file);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import lombok.Data;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.data.domain.PageRequest;
Expand Down Expand Up @@ -190,7 +191,8 @@ ResponseEntity<Model> getModel(
new TypeReference<>() {}
);

// Append the Document extractions to the Model extractions, just for the front-end.
// Append the Document extractions to the Model extractions, just for the
// front-end.
Comment on lines +194 to +195
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dgauldie Do we still do this? Maybe in another PR when we remove SKEMA-TR

// Those are NOT to be saved back to the data-service.
if (extractions != null) {
model.get().getMetadata().getAttributes().addAll(extractions);
Expand Down Expand Up @@ -473,6 +475,74 @@ ResponseEntity<Model> createModel(
}
}

@Data
public static class CreateModelFromOldRequest {

Model oldModel;
Model newModel;
}

@PostMapping("/new-from-old")
@Secured(Roles.USER)
@Operation(summary = "Create a new model from an old model")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "201",
description = "Model created.",
content = @Content(
mediaType = "application/json",
schema = @io.swagger.v3.oas.annotations.media.Schema(implementation = Model.class)
)
),
@ApiResponse(responseCode = "500", description = "There was an issue creating the model", content = @Content)
}
)
ResponseEntity<Model> createModelFromOld(
@RequestBody final CreateModelFromOldRequest req,
@RequestParam(name = "project-id", required = false) final UUID projectId
) {
final Schema.Permission permission = projectService.checkPermissionCanWrite(
currentUserService.get().getId(),
projectId
);

try {
req.newModel.retainMetadataFields(req.oldModel);

// Set the model name from the AMR header name.
// TerariumAsset have a name field, but it's not used for the model name outside
// the front-end.
req.newModel.setName(req.newModel.getHeader().getName());
final Model created = modelService.createAsset(req.newModel, projectId, permission);

// create default configuration
final ModelConfiguration modelConfiguration = ModelConfigurationService.modelConfigurationFromAMR(
created,
null,
null
);
modelConfigurationService.createAsset(modelConfiguration, projectId, permission);

// add default model configuration to project
final Optional<Project> project = projectService.getProject(projectId);
if (project.isPresent()) {
projectAssetService.createProjectAsset(
project.get(),
AssetType.MODEL_CONFIGURATION,
modelConfiguration,
permission
);
}

return ResponseEntity.status(HttpStatus.CREATED).body(created);
} catch (final IOException e) {
final String error = "Unable to create model";
log.error(error, e);
throw new ResponseStatusException(org.springframework.http.HttpStatus.INTERNAL_SERVER_ERROR, error);
}
}

@GetMapping("/{id}/model-configurations")
@Secured(Roles.USER)
@Operation(summary = "Gets all model configurations for a model")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ArrayNode;
import io.hypersistence.utils.hibernate.type.json.JsonType;
import jakarta.persistence.Column;
import jakarta.persistence.Entity;
Expand All @@ -26,6 +27,7 @@
import software.uncharted.terarium.hmiserver.models.dataservice.modelparts.ModelSemantics;
import software.uncharted.terarium.hmiserver.models.dataservice.modelparts.semantics.Initial;
import software.uncharted.terarium.hmiserver.models.dataservice.modelparts.semantics.Observable;
import software.uncharted.terarium.hmiserver.utils.JsonUtil;

@EqualsAndHashCode(callSuper = true)
@Data
Expand Down Expand Up @@ -64,6 +66,59 @@ public class Model extends TerariumAssetThatSupportsAdditionalProperties {
@Column(columnDefinition = "json")
private ModelMetadata metadata;

public void retainMetadataFields(final Model other) {
final Map<String, JsonNode> props = getAdditionalProperties();
final Map<String, JsonNode> otherProps = other.getAdditionalProperties();

if (metadata == null) {
metadata = other.getMetadata();
} else {
metadata.retainMetadataFields(other.getMetadata());
}

if (getDescription() == null) {
setDescription(other.getDescription());
}

final List<String> propertiesToPreserve = List.of(
"states",
"metadata",
"units",
"vertices",
"edges",
"parameters",
"initials"
);

for (final String property : propertiesToPreserve) {
if (!otherProps.containsKey(property) || !otherProps.get(property).isArray()) {
continue;
}

if (!props.containsKey(property) || !props.get(property).isArray()) {
continue;
}

final ArrayNode otherProperty = (ArrayNode) otherProps.get(property);
final ArrayNode thisProperty = (ArrayNode) props.get(property);

for (final JsonNode element : otherProperty) {
final JsonNode matching = JsonUtil.getFirstByPredicate(thisProperty, (final JsonNode node) -> {
// Check if the 'state' object has a 'name' field
return node.has("id");
});

if (matching == null) {
// does not exist in current model, we can add the old one, otherwise keep the
// new one
thisProperty.add(element);
}
}

props.put(property, thisProperty);
}
}

public ModelMetadata getMetadata() {
if (metadata == null) {
return new ModelMetadata();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,51 @@ public class ModelMetadata extends SupportAdditionalProperties implements Serial
@JdbcTypeCode(Types.BINARY)
private byte[] description;

public void retainMetadataFields(final ModelMetadata other) {
if (description == null) {
description = other.description;
}
if (gollmCard == null) {
gollmCard = other.gollmCard;
}
if (gollmExtractions == null) {
gollmExtractions = other.gollmExtractions;
}
if (provenance == null) {
provenance = other.provenance;
}
if (templateCard == null) {
templateCard = other.templateCard;
}
if (codeId == null) {
codeId = other.codeId;
}
if (source == null) {
source = other.source;
}
if (processedBy == null) {
processedBy = other.processedBy;
}
if (variableStatements == null) {
variableStatements = other.variableStatements;
}
if (annotations == null) {
annotations = other.annotations;
}
if (attributes == null) {
attributes = other.attributes;
}
if (initials == null) {
initials = other.initials;
}
if (parameters == null) {
parameters = other.parameters;
}
if (card == null) {
card = other.card;
}
}

@Override
public ModelMetadata clone() {
final ModelMetadata clone = (ModelMetadata) super.clone();
Expand All @@ -92,7 +137,7 @@ public ModelMetadata clone() {

if (this.variableStatements != null) {
clone.variableStatements = new ArrayList<>();
for (VariableStatement variableStatement : this.variableStatements) {
for (final VariableStatement variableStatement : this.variableStatements) {
clone.variableStatements.add(variableStatement.clone());
}
}
Expand All @@ -103,7 +148,7 @@ public ModelMetadata clone() {

if (this.attributes != null) {
clone.attributes = new ArrayList<>();
for (JsonNode attribute : this.attributes) {
for (final JsonNode attribute : this.attributes) {
clone.attributes.add(attribute.deepCopy());
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package software.uncharted.terarium.hmiserver.utils;

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.node.ArrayNode;
import com.fasterxml.jackson.databind.node.ObjectNode;
import java.util.Iterator;
import java.util.Map;
import java.util.function.Predicate;

public class JsonUtil {

Expand All @@ -17,6 +19,15 @@ public static void setAll(final ObjectNode dest, final JsonNode src) {
});
}

public static JsonNode getFirstByPredicate(final ArrayNode node, final Predicate<JsonNode> predicate) {
for (final JsonNode element : node) {
if (predicate.test(element)) {
return element;
}
}
return null; // Return null if no element matches the predicate
}

public static void recursiveSetAll(final ObjectNode dest, final JsonNode src) {
if (src.isObject()) {
final Iterator<Map.Entry<String, JsonNode>> fields = src.fields();
Expand Down
Loading