Skip to content

Commit

Permalink
feat: model configurations can be accepted into stratify and model ed…
Browse files Browse the repository at this point in the history
…it (#3426)
  • Loading branch information
shawnyama authored Apr 24, 2024
1 parent cb93027 commit 3c24594
Show file tree
Hide file tree
Showing 8 changed files with 56 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ export const ModelEditOperation: Operation = {
displayName: 'Edit model',
description: 'Edit a model',
isRunnable: false,
inputs: [{ type: 'modelId', label: 'Model', acceptMultiple: false }],
inputs: [
{ type: 'modelId|modelConfigId', label: 'Model or Model configuration', acceptMultiple: false }
],
outputs: [{ type: 'modelId' }],
action: async () => ({}),
initState: () => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ const props = defineProps<{
node: WorkflowNode<ModelEditOperationState>;
}>();
const model = ref(null as Model | null);
const model = ref<Model | null>(null);
const updateModel = async () => {
const modelId = operator.getActiveOutput(props.node)?.value?.[0];
if (modelId && modelId !== model?.value?.id) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ import TeraModelTemplateEditor from '@/components/model-template/tera-model-temp
import TeraNotebookJupyterInput from '@/components/llm/tera-notebook-jupyter-input.vue';
import { KernelSessionManager } from '@/services/jupyter';
import { getModelIdFromModelConfigurationId } from '@/services/model-configurations';
import { ModelEditOperationState } from './model-edit-operation';
const props = defineProps<{
Expand Down Expand Up @@ -150,7 +151,7 @@ const activeOutput = ref<WorkflowOutput<ModelEditOperationState> | null>(null);
const kernelManager = new KernelSessionManager();
const isKernelReady = ref(false);
const amr = ref<Model | null>(null);
const modelId = props.node.inputs[0].value?.[0];
const newModelName = ref('');
let editor: VAceEditorInstance['_editor'] | null;
const sampleAgentQuestions = [
Expand Down Expand Up @@ -281,6 +282,15 @@ const buildJupyterContext = () => {
};
const inputChangeHandler = async () => {
const input = props.node.inputs[0];
if (!input) return;
let modelId: string | null = null;
if (input.type === 'modelId') {
modelId = input.value?.[0];
} else if (input.type === 'modelConfigId') {
modelId = await getModelIdFromModelConfigurationId(input.value?.[0]);
}
if (!modelId) return;
amr.value = await getModel(modelId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ export const StratifyMiraOperation: Operation = {
name: WorkflowOperationTypes.STRATIFY_MIRA,
displayName: 'Stratify model',
description: 'Stratify a model',
inputs: [{ type: 'modelId', label: 'Model', acceptMultiple: false }],
inputs: [
{ type: 'modelId|modelConfigId', label: 'Model or Model configuration', acceptMultiple: false }
],
outputs: [{ type: 'model' }],
isRunnable: false,
action: () => {},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ import '@/ace-config';
import TeraNotebookError from '@/components/drilldown/tera-notebook-error.vue';
import type { Model } from '@/types/Types';
import { AMRSchemaNames } from '@/types/common';
import { getModelIdFromModelConfigurationId } from '@/services/model-configurations';
/* Jupyter imports */
import { KernelSessionManager } from '@/services/jupyter';
Expand Down Expand Up @@ -358,7 +359,15 @@ const getStatesAndParameters = (amrModel: Model) => {
};
const inputChangeHandler = async () => {
const modelId = props.node.inputs[0].value?.[0];
const input = props.node.inputs[0];
if (!input) return;
let modelId: string | null = null;
if (input.type === 'modelId') {
modelId = input.value?.[0];
} else if (input.type === 'modelConfigId') {
modelId = await getModelIdFromModelConfigurationId(input.value?.[0]);
}
if (!modelId) return;
amr.value = await getModel(modelId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ export const getModelConfigurationById = async (id: string) => {
return (response?.data as ModelConfiguration) ?? null;
};

export const getModelIdFromModelConfigurationId = async (id: string) => {
const modelConfiguration = await getModelConfigurationById(id);
return modelConfiguration?.model_id ?? null;
};

export const createModelConfiguration = async (
model_id: string | undefined,
name: string,
Expand Down
34 changes: 22 additions & 12 deletions packages/client/hmi-client/src/services/workflow.ts
Original file line number Diff line number Diff line change
Expand Up @@ -114,14 +114,11 @@ export const addEdge = (
) => {
const sourceNode = wf.nodes.find((d) => d.id === sourceId);
const targetNode = wf.nodes.find((d) => d.id === targetId);
if (!sourceNode) return;
if (!targetNode) return;
if (!sourceNode || !targetNode) return;

const sourceOutputPort = sourceNode.outputs.find((d) => d.id === sourcePortId);
const targetInputPort = targetNode.inputs.find((d) => d.id === targetPortId);

if (!sourceOutputPort) return;
if (!targetInputPort) return;
if (!sourceOutputPort || !targetInputPort) return;

// Check if edge already exist
const existingEdge = wf.edges.find(
Expand All @@ -131,16 +128,15 @@ export const addEdge = (
d.target === targetId &&
d.targetPortId === targetPortId
);

if (existingEdge) return;

// Check if type is compatible
if (sourceOutputPort.value === null) return;

const allowedTypes = targetInputPort.type.split('|');
if (!allowedTypes.includes(sourceOutputPort.type)) return;

if (!targetInputPort.acceptMultiple && targetInputPort.status === WorkflowPortStatus.CONNECTED) {
if (
!allowedTypes.includes(sourceOutputPort.type) ||
(!targetInputPort.acceptMultiple && targetInputPort.status === WorkflowPortStatus.CONNECTED)
) {
return;
}

Expand All @@ -153,9 +149,10 @@ export const addEdge = (
targetInputPort.value = sourceOutputPort.value;
}

// Transfer concrete type information where it can accept multiple types
// Note this will lock in the typing, even after unlink
// Transfer concrete type to the input type to match the output type
// Saves the original type in case we want to revert when we unlink the edge
if (allowedTypes.length > 1) {
targetInputPort.originalType = targetInputPort.type;
targetInputPort.type = sourceOutputPort.type;
}

Expand All @@ -181,12 +178,19 @@ export const removeEdge = (wf: Workflow, id: string) => {
// Remove the data reference at the targetPort
const targetNode = wf.nodes.find((d) => d.id === edgeToRemove.target);
if (!targetNode) return;

const targetPort = targetNode.inputs.find((d) => d.id === edgeToRemove.targetPortId);
if (!targetPort) return;

targetPort.value = null;
targetPort.status = WorkflowPortStatus.NOT_CONNECTED;
delete targetPort.label;

// Resets the type to the original type (used when multiple types for a port are allowed)
if (targetPort?.originalType) {
targetPort.type = targetPort.originalType;
}

// Edge re-assignment
wf.edges = wf.edges.filter((edge) => edge.id !== id);

Expand Down Expand Up @@ -229,6 +233,7 @@ const defaultPortLabels = {
modelId: 'Model',
modelConfigId: 'Model configuration',
datasetId: 'Dataset',
simulationId: 'Simulation',
codeAssetId: 'Code asset'
};

Expand All @@ -241,6 +246,11 @@ export function getPortLabel({ label, type, isOptional }: WorkflowPort) {
else if (defaultPortLabels[type]) {
portLabel = defaultPortLabels[type];
}
// Create name if there are multiple types
else if (type.includes('|')) {
const types = type.split('|');
portLabel = types.map((t) => defaultPortLabels[t] ?? t).join(' or ');
}

if (isOptional) portLabel = portLabel.concat(' (optional)');

Expand Down
1 change: 1 addition & 0 deletions packages/client/hmi-client/src/types/workflow.ts
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ export interface Operation {
export interface WorkflowPort {
id: string;
type: string;
originalType?: string;
status: WorkflowPortStatus;
label?: string;
value?: any[] | null;
Expand Down

0 comments on commit 3c24594

Please sign in to comment.