diff --git a/common/src/main/java/org/opensearch/ml/common/utils/ModelInterfaceUtils.java b/common/src/main/java/org/opensearch/ml/common/utils/ModelInterfaceUtils.java index 5f05ee73ff..8b053a5206 100644 --- a/common/src/main/java/org/opensearch/ml/common/utils/ModelInterfaceUtils.java +++ b/common/src/main/java/org/opensearch/ml/common/utils/ModelInterfaceUtils.java @@ -8,6 +8,8 @@ import java.util.Map; import org.opensearch.ml.common.connector.Connector; +import org.opensearch.ml.common.connector.ConnectorAction; +import org.opensearch.ml.common.connector.MLPostProcessFunction; import org.opensearch.ml.common.transport.register.MLRegisterModelInput; import lombok.extern.log4j.Log4j2; @@ -179,6 +181,143 @@ public class ModelInterfaceUtils { + " ]\n" + "}"; + private static final String BEDROCK_AI21_J2_MID_V1_RAW_MODEL_INTERFACE_OUTPUT = "{\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"inference_results\": {\n" + + " \"type\": \"array\",\n" + + " \"items\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"output\": {\n" + + " \"type\": \"array\",\n" + + " \"items\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"name\": {\n" + + " \"type\": \"string\"\n" + + " },\n" + + " \"dataAsMap\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"id\": {\n" + + " \"type\": \"number\"\n" + + " },\n" + + " \"prompt\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"text\": {\n" + + " \"type\": \"string\"\n" + + " },\n" + + " \"tokens\": {\n" + + " \"type\": \"array\",\n" + + " \"items\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"generatedToken\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"token\": {\n" + + " \"type\": \"string\"\n" + + " },\n" + + " \"logprob\": {\n" + + " \"type\": \"number\"\n" + + " },\n" + + " \"raw_logprob\": {\n" + + " \"type\": \"number\"\n" + + " }\n" + + " }\n" + + " },\n" + + " \"textRange\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"start\": {\n" + + " \"type\": \"number\"\n" + + " },\n" + + " \"end\": {\n" + + " \"type\": \"number\"\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " },\n" + + " \"completions\": {\n" + + " \"type\": \"array\",\n" + + " \"items\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"data\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"text\": {\n" + + " \"type\": \"string\"\n" + + " },\n" + + " \"tokens\": {\n" + + " \"type\": \"array\",\n" + + " \"items\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"generatedToken\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"token\": {\n" + + " \"type\": \"string\"\n" + + " },\n" + + " \"logprob\": {\n" + + " \"type\": \"number\"\n" + + " },\n" + + " \"raw_logprob\": {\n" + + " \"type\": \"number\"\n" + + " }\n" + + " }\n" + + " },\n" + + " \"textRange\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"start\": {\n" + + " \"type\": \"number\"\n" + + " },\n" + + " \"end\": {\n" + + " \"type\": \"number\"\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " },\n" + + " \"finishReason\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"reason\": {\n" + + " \"type\": \"string\"\n" + + " },\n" + + " \"length\": {\n" + + " \"type\": \"number\"\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " },\n" + + " \"status_code\": {\n" + + " \"type\": \"integer\"\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + "}"; + private static final String BEDROCK_ANTHROPIC_CLAUDE_V2_MODEL_INTERFACE_OUTPUT = "{\n" + " \"type\": \"object\",\n" + " \"properties\": {\n" @@ -297,6 +436,110 @@ public class ModelInterfaceUtils { + " ]\n" + "}"; + private static final String AMAZON_TITAN_EMBEDDING_V1_RAW_INTERFACE_OUTPUT = "{\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"inference_results\": {\n" + + " \"type\": \"array\",\n" + + " \"items\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"output\": {\n" + + " \"type\": \"array\",\n" + + " \"items\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"name\": {\n" + + " \"type\": \"string\"\n" + + " },\n" + + " \"dataAsMap\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"embedding\": {\n" + + " \"type\": \"array\",\n" + + " \"items\": {\n" + + " \"type\": \"number\"\n" + + " }\n" + + " },\n" + + " \"inputTextTokenCount\": {\n" + + " \"type\": \"integer\"\n" + + " }\n" + + " },\n" + + " \"required\": [\"embedding\", \"inputTextTokenCount\"]\n" + + " }\n" + + " },\n" + + " \"required\": [\"name\", \"dataAsMap\"]\n" + + " }\n" + + " },\n" + + " \"status_code\": {\n" + + " \"type\": \"integer\"\n" + + " }\n" + + " },\n" + + " \"required\": [\"output\", \"status_code\"]\n" + + " }\n" + + " }\n" + + " },\n" + + " \"required\": [\"inference_results\"]\n" + + "}"; + + private static final String COHERE_EMBEDDING_V3_RAW_INTERFACE_OUTPUT = "{\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"inference_results\": {\n" + + " \"type\": \"array\",\n" + + " \"items\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"output\": {\n" + + " \"type\": \"array\",\n" + + " \"items\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"name\": {\n" + + " \"type\": \"string\"\n" + + " },\n" + + " \"dataAsMap\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"id\": {\n" + + " \"type\": \"string\"\n" + + " },\n" + + " \"texts\": {\n" + + " \"type\": \"array\",\n" + + " \"items\": {\n" + + " \"type\": \"string\"\n" + + " }\n" + + " },\n" + + " \"embeddings\": {\n" + + " \"type\": \"array\",\n" + + " \"items\": {\n" + + " \"type\": \"array\",\n" + + " \"items\": {\n" + + " \"type\": \"number\"\n" + + " }\n" + + " }\n" + + " },\n" + + " \"response_type\": {\n" + + " \"type\": \"string\"\n" + + " }\n" + + " },\n" + + " \"required\": [\"id\", \"texts\", \"embeddings\", \"response_type\"]\n" + + " }\n" + + " },\n" + + " \"required\": [\"name\", \"dataAsMap\"]\n" + + " }\n" + + " },\n" + + " \"status_code\": {\n" + + " \"type\": \"integer\"\n" + + " }\n" + + " },\n" + + " \"required\": [\"output\", \"status_code\"]\n" + + " }\n" + + " }\n" + + " },\n" + + " \"required\": [\"inference_results\"]\n" + + "}"; + private static final String AMAZON_COMPREHEND_DETECTDOMAINANTLANGUAGE_API_INTERFACE_OUTPUT = "{\n" + " \"type\": \"object\",\n" + " \"properties\": {\n" @@ -484,6 +727,9 @@ public class ModelInterfaceUtils { public static final Map BEDROCK_AI21_LABS_JURASSIC2_MID_V1_MODEL_INTERFACE = Map .of("input", GENERAL_CONVERSATIONAL_MODEL_INTERFACE_INPUT, "output", GENERAL_CONVERSATIONAL_MODEL_INTERFACE_OUTPUT); + public static final Map BEDROCK_AI21_LABS_JURASSIC2_MID_V1_RAW_MODEL_INTERFACE = Map + .of("input", GENERAL_CONVERSATIONAL_MODEL_INTERFACE_INPUT, "output", BEDROCK_AI21_J2_MID_V1_RAW_MODEL_INTERFACE_OUTPUT); + public static final Map BEDROCK_ANTHROPIC_CLAUDE_V3_SONNET_MODEL_INTERFACE = Map .of("input", GENERAL_CONVERSATIONAL_MODEL_INTERFACE_INPUT, "output", GENERAL_CONVERSATIONAL_MODEL_INTERFACE_OUTPUT); @@ -493,15 +739,27 @@ public class ModelInterfaceUtils { public static final Map BEDROCK_COHERE_EMBED_ENGLISH_V3_MODEL_INTERFACE = Map .of("input", GENERAL_EMBEDDING_MODEL_INTERFACE_INPUT, "output", GENERAL_EMBEDDING_MODEL_INTERFACE_OUTPUT); + public static final Map BEDROCK_COHERE_EMBED_ENGLISH_V3_RAW_MODEL_INTERFACE = Map + .of("input", GENERAL_EMBEDDING_MODEL_INTERFACE_INPUT, "output", COHERE_EMBEDDING_V3_RAW_INTERFACE_OUTPUT); + public static final Map BEDROCK_COHERE_EMBED_MULTILINGUAL_V3_MODEL_INTERFACE = Map .of("input", GENERAL_EMBEDDING_MODEL_INTERFACE_INPUT, "output", GENERAL_EMBEDDING_MODEL_INTERFACE_OUTPUT); + public static final Map BEDROCK_COHERE_EMBED_MULTILINGUAL_V3_RAW_MODEL_INTERFACE = Map + .of("input", GENERAL_EMBEDDING_MODEL_INTERFACE_INPUT, "output", COHERE_EMBEDDING_V3_RAW_INTERFACE_OUTPUT); + public static final Map BEDROCK_TITAN_EMBED_TEXT_V1_MODEL_INTERFACE = Map .of("input", TITAN_TEXT_EMBEDDING_MODEL_INTERFACE_INPUT, "output", GENERAL_EMBEDDING_MODEL_INTERFACE_OUTPUT); + public static final Map BEDROCK_TITAN_EMBED_TEXT_V1_RAW_MODEL_INTERFACE = Map + .of("input", TITAN_TEXT_EMBEDDING_MODEL_INTERFACE_INPUT, "output", AMAZON_TITAN_EMBEDDING_V1_RAW_INTERFACE_OUTPUT); + public static final Map BEDROCK_TITAN_EMBED_MULTI_MODAL_V1_MODEL_INTERFACE = Map .of("input", TITAN_MULTI_MODAL_EMBEDDING_MODEL_INTERFACE_INPUT, "output", GENERAL_EMBEDDING_MODEL_INTERFACE_OUTPUT); + public static final Map BEDROCK_TITAN_EMBED_MULTI_MODAL_V1_RAW_MODEL_INTERFACE = Map + .of("input", TITAN_MULTI_MODAL_EMBEDDING_MODEL_INTERFACE_INPUT, "output", AMAZON_TITAN_EMBEDDING_V1_RAW_INTERFACE_OUTPUT); + public static final Map AMAZON_COMPREHEND_DETECTDOMAINANTLANGUAGE_API_INTERFACE = Map .of( "input", @@ -520,17 +778,28 @@ public class ModelInterfaceUtils { private static Map createPresetModelInterfaceByConnector(Connector connector) { if (connector.getParameters() != null) { + ConnectorAction connectorAction = connector.getActions().get(0); switch ((connector.getParameters().get("service_name") != null) ? connector.getParameters().get("service_name") : "null") { case "bedrock": log.debug("Detected Amazon Bedrock model"); switch ((connector.getParameters().get("model") != null) ? connector.getParameters().get("model") : "null") { case "ai21.j2-mid-v1": - log - .debug( - "Creating preset model interface for Amazon Bedrock model: {}", - connector.getParameters().get("model") - ); - return BEDROCK_AI21_LABS_JURASSIC2_MID_V1_MODEL_INTERFACE; + if (connectorAction.getPostProcessFunction() != null && !connectorAction.getPostProcessFunction().isBlank()) { + log + .debug( + "Creating preset model interface for Amazon Bedrock model with post-process function: {}", + connector.getParameters().get("model") + ); + return BEDROCK_AI21_LABS_JURASSIC2_MID_V1_MODEL_INTERFACE; + } else { + + log + .debug( + "Creating preset model interface for Amazon Bedrock model without post-process function: {}", + connector.getParameters().get("model") + ); + return BEDROCK_AI21_LABS_JURASSIC2_MID_V1_RAW_MODEL_INTERFACE; + } case "anthropic.claude-3-sonnet-20240229-v1:0": log .debug( @@ -546,33 +815,73 @@ private static Map createPresetModelInterfaceByConnector(Connect ); return BEDROCK_ANTHROPIC_CLAUDE_V2_MODEL_INTERFACE; case "cohere.embed-english-v3": - log - .debug( - "Creating preset model interface for Amazon Bedrock model: {}", - connector.getParameters().get("model") - ); - return BEDROCK_COHERE_EMBED_ENGLISH_V3_MODEL_INTERFACE; + if (connectorAction.getPostProcessFunction() != null + && connectorAction.getPostProcessFunction().equalsIgnoreCase(MLPostProcessFunction.COHERE_EMBEDDING)) { + log + .debug( + "Creating preset model interface for Amazon Bedrock model with post-process function: {}", + connector.getParameters().get("model") + ); + return BEDROCK_COHERE_EMBED_ENGLISH_V3_MODEL_INTERFACE; + } else { + log + .debug( + "Creating preset model interface for Amazon Bedrock model without post-process function: {}", + connector.getParameters().get("model") + ); + return BEDROCK_COHERE_EMBED_ENGLISH_V3_RAW_MODEL_INTERFACE; + } case "cohere.embed-multilingual-v3": - log - .debug( - "Creating preset model interface for Amazon Bedrock model: {}", - connector.getParameters().get("model") - ); - return BEDROCK_COHERE_EMBED_MULTILINGUAL_V3_MODEL_INTERFACE; + if (connectorAction.getPostProcessFunction() != null + && connectorAction.getPostProcessFunction().equalsIgnoreCase(MLPostProcessFunction.COHERE_EMBEDDING)) { + log + .debug( + "Creating preset model interface for Amazon Bedrock model with post-process function: {}", + connector.getParameters().get("model") + ); + return BEDROCK_COHERE_EMBED_MULTILINGUAL_V3_MODEL_INTERFACE; + } else { + log + .debug( + "Creating preset model interface for Amazon Bedrock model without post-process function: {}", + connector.getParameters().get("model") + ); + return BEDROCK_COHERE_EMBED_MULTILINGUAL_V3_RAW_MODEL_INTERFACE; + } case "amazon.titan-embed-text-v1": - log - .debug( - "Creating preset model interface for Amazon Bedrock model: {}", - connector.getParameters().get("model") - ); - return BEDROCK_TITAN_EMBED_TEXT_V1_MODEL_INTERFACE; + if (connectorAction.getPostProcessFunction() != null + && connectorAction.getPostProcessFunction().equalsIgnoreCase(MLPostProcessFunction.BEDROCK_EMBEDDING)) { + log + .debug( + "Creating preset model interface for Amazon Bedrock model with post-process function: {}", + connector.getParameters().get("model") + ); + return BEDROCK_TITAN_EMBED_TEXT_V1_MODEL_INTERFACE; + } else { + log + .debug( + "Creating preset model interface for Amazon Bedrock model without post-process function: {}", + connector.getParameters().get("model") + ); + return BEDROCK_TITAN_EMBED_TEXT_V1_RAW_MODEL_INTERFACE; + } case "amazon.titan-embed-image-v1": - log - .debug( - "Creating preset model interface for Amazon Bedrock model: {}", - connector.getParameters().get("model") - ); - return BEDROCK_TITAN_EMBED_MULTI_MODAL_V1_MODEL_INTERFACE; + if (connectorAction.getPostProcessFunction() != null + && connectorAction.getPostProcessFunction().equalsIgnoreCase(MLPostProcessFunction.BEDROCK_EMBEDDING)) { + log + .debug( + "Creating preset model interface for Amazon Bedrock model with post-process function: {}", + connector.getParameters().get("model") + ); + return BEDROCK_TITAN_EMBED_MULTI_MODAL_V1_MODEL_INTERFACE; + } else { + log + .debug( + "Creating preset model interface for Amazon Bedrock model without post-process function: {}", + connector.getParameters().get("model") + ); + return BEDROCK_TITAN_EMBED_MULTI_MODAL_V1_RAW_MODEL_INTERFACE; + } default: return null; } diff --git a/common/src/test/java/org/opensearch/ml/common/utils/ModelInterfaceUtilsTest.java b/common/src/test/java/org/opensearch/ml/common/utils/ModelInterfaceUtilsTest.java index 96eb4b3fca..2bab51f2f6 100644 --- a/common/src/test/java/org/opensearch/ml/common/utils/ModelInterfaceUtilsTest.java +++ b/common/src/test/java/org/opensearch/ml/common/utils/ModelInterfaceUtilsTest.java @@ -7,18 +7,25 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNull; +import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT; import static org.opensearch.ml.common.utils.ModelInterfaceUtils.AMAZON_COMPREHEND_DETECTDOMAINANTLANGUAGE_API_INTERFACE; import static org.opensearch.ml.common.utils.ModelInterfaceUtils.AMAZON_TEXTRACT_DETECTDOCUMENTTEXT_API_INTERFACE; import static org.opensearch.ml.common.utils.ModelInterfaceUtils.BEDROCK_AI21_LABS_JURASSIC2_MID_V1_MODEL_INTERFACE; +import static org.opensearch.ml.common.utils.ModelInterfaceUtils.BEDROCK_AI21_LABS_JURASSIC2_MID_V1_RAW_MODEL_INTERFACE; import static org.opensearch.ml.common.utils.ModelInterfaceUtils.BEDROCK_ANTHROPIC_CLAUDE_V2_MODEL_INTERFACE; import static org.opensearch.ml.common.utils.ModelInterfaceUtils.BEDROCK_ANTHROPIC_CLAUDE_V3_SONNET_MODEL_INTERFACE; import static org.opensearch.ml.common.utils.ModelInterfaceUtils.BEDROCK_COHERE_EMBED_ENGLISH_V3_MODEL_INTERFACE; +import static org.opensearch.ml.common.utils.ModelInterfaceUtils.BEDROCK_COHERE_EMBED_ENGLISH_V3_RAW_MODEL_INTERFACE; import static org.opensearch.ml.common.utils.ModelInterfaceUtils.BEDROCK_COHERE_EMBED_MULTILINGUAL_V3_MODEL_INTERFACE; +import static org.opensearch.ml.common.utils.ModelInterfaceUtils.BEDROCK_COHERE_EMBED_MULTILINGUAL_V3_RAW_MODEL_INTERFACE; import static org.opensearch.ml.common.utils.ModelInterfaceUtils.BEDROCK_TITAN_EMBED_MULTI_MODAL_V1_MODEL_INTERFACE; +import static org.opensearch.ml.common.utils.ModelInterfaceUtils.BEDROCK_TITAN_EMBED_MULTI_MODAL_V1_RAW_MODEL_INTERFACE; import static org.opensearch.ml.common.utils.ModelInterfaceUtils.BEDROCK_TITAN_EMBED_TEXT_V1_MODEL_INTERFACE; +import static org.opensearch.ml.common.utils.ModelInterfaceUtils.BEDROCK_TITAN_EMBED_TEXT_V1_RAW_MODEL_INTERFACE; import static org.opensearch.ml.common.utils.ModelInterfaceUtils.updateRegisterModelInputModelInterfaceFieldsByConnector; import java.util.HashMap; +import java.util.List; import java.util.Map; import org.junit.Before; @@ -27,7 +34,9 @@ import org.junit.rules.ExpectedException; import org.mockito.Spy; import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.connector.ConnectorAction; import org.opensearch.ml.common.connector.HttpConnector; +import org.opensearch.ml.common.connector.MLPostProcessFunction; import org.opensearch.ml.common.transport.register.MLRegisterModelInput; public class ModelInterfaceUtilsTest { @@ -40,6 +49,10 @@ public class ModelInterfaceUtilsTest { @Spy public HttpConnector connector; + public ConnectorAction connectorActionWithPostProcessFunction; + + public ConnectorAction connectorActionWithoutPostProcessFunction; + @Rule public ExpectedException exceptionRule = ExpectedException.none(); @@ -56,6 +69,23 @@ public void setUp() throws Exception { .modelName("test-model-with-stand-alone-connector") .functionName(FunctionName.REMOTE) .build(); + + connectorActionWithPostProcessFunction = ConnectorAction + .builder() + .actionType(PREDICT) + .method("POST") + .url("http:///mock") + .requestBody("{\"input\": \"${parameters.input}\"}") + .postProcessFunction("test-post-process-function") + .build(); + + connectorActionWithoutPostProcessFunction = ConnectorAction + .builder() + .actionType(PREDICT) + .method("POST") + .url("http:///mock") + .requestBody("{\"input\": \"${parameters.input}\"}") + .build(); } @Test @@ -63,18 +93,42 @@ public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorBEDROCK_A Map parameters = new HashMap<>(); parameters.put("service_name", "bedrock"); parameters.put("model", "ai21.j2-mid-v1"); - connector = HttpConnector.builder().protocol("http").parameters(parameters).build(); - + connector = HttpConnector + .builder() + .protocol("http") + .parameters(parameters) + .actions(List.of(connectorActionWithPostProcessFunction)) + .build(); updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector); assertEquals(registerModelInputWithStandaloneConnector.getModelInterface(), BEDROCK_AI21_LABS_JURASSIC2_MID_V1_MODEL_INTERFACE); } + @Test + public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorBEDROCK_AI21_LABS_JURASSIC2_MID_V1_RAW_MODEL_INTERFACE() { + Map parameters = new HashMap<>(); + parameters.put("service_name", "bedrock"); + parameters.put("model", "ai21.j2-mid-v1"); + connector = HttpConnector + .builder() + .protocol("http") + .parameters(parameters) + .actions(List.of(connectorActionWithoutPostProcessFunction)) + .build(); + updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector); + assertEquals(registerModelInputWithStandaloneConnector.getModelInterface(), BEDROCK_AI21_LABS_JURASSIC2_MID_V1_RAW_MODEL_INTERFACE); + } + @Test public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorBEDROCK_ANTHROPIC_CLAUDE_V3_SONNET_MODEL_INTERFACE() { Map parameters = new HashMap<>(); parameters.put("service_name", "bedrock"); parameters.put("model", "anthropic.claude-3-sonnet-20240229-v1:0"); - connector = HttpConnector.builder().protocol("http").parameters(parameters).build(); + connector = HttpConnector + .builder() + .protocol("http") + .parameters(parameters) + .actions(List.of(connectorActionWithPostProcessFunction)) + .build(); updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector); assertEquals(registerModelInputWithStandaloneConnector.getModelInterface(), BEDROCK_ANTHROPIC_CLAUDE_V3_SONNET_MODEL_INTERFACE); @@ -85,7 +139,12 @@ public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorBEDROCK_A Map parameters = new HashMap<>(); parameters.put("service_name", "bedrock"); parameters.put("model", "anthropic.claude-v2"); - connector = HttpConnector.builder().protocol("http").parameters(parameters).build(); + connector = HttpConnector + .builder() + .protocol("http") + .parameters(parameters) + .actions(List.of(connectorActionWithPostProcessFunction)) + .build(); updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector); assertEquals(registerModelInputWithStandaloneConnector.getModelInterface(), BEDROCK_ANTHROPIC_CLAUDE_V2_MODEL_INTERFACE); @@ -96,51 +155,179 @@ public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorBEDROCK_C Map parameters = new HashMap<>(); parameters.put("service_name", "bedrock"); parameters.put("model", "cohere.embed-english-v3"); - connector = HttpConnector.builder().protocol("http").parameters(parameters).build(); + + connectorActionWithPostProcessFunction = ConnectorAction + .builder() + .actionType(PREDICT) + .method("POST") + .url("http:///mock") + .requestBody("{\"input\": \"${parameters.input}\"}") + .postProcessFunction(MLPostProcessFunction.COHERE_EMBEDDING) + .build(); + + connector = HttpConnector + .builder() + .protocol("http") + .parameters(parameters) + .actions(List.of(connectorActionWithPostProcessFunction)) + .build(); updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector); assertEquals(registerModelInputWithStandaloneConnector.getModelInterface(), BEDROCK_COHERE_EMBED_ENGLISH_V3_MODEL_INTERFACE); } + @Test + public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorBEDROCK_COHERE_EMBED_ENGLISH_V3_RAW_MODEL_INTERFACE() { + Map parameters = new HashMap<>(); + parameters.put("service_name", "bedrock"); + parameters.put("model", "cohere.embed-english-v3"); + connector = HttpConnector + .builder() + .protocol("http") + .parameters(parameters) + .actions(List.of(connectorActionWithoutPostProcessFunction)) + .build(); + updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector); + assertEquals(registerModelInputWithStandaloneConnector.getModelInterface(), BEDROCK_COHERE_EMBED_ENGLISH_V3_RAW_MODEL_INTERFACE); + } + @Test public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorBEDROCK_COHERE_EMBED_MULTILINGUAL_V3_MODEL_INTERFACE() { Map parameters = new HashMap<>(); parameters.put("service_name", "bedrock"); parameters.put("model", "cohere.embed-multilingual-v3"); - connector = HttpConnector.builder().protocol("http").parameters(parameters).build(); + + connectorActionWithPostProcessFunction = ConnectorAction + .builder() + .actionType(PREDICT) + .method("POST") + .url("http:///mock") + .requestBody("{\"input\": \"${parameters.input}\"}") + .postProcessFunction(MLPostProcessFunction.COHERE_EMBEDDING) + .build(); + + connector = HttpConnector + .builder() + .protocol("http") + .parameters(parameters) + .actions(List.of(connectorActionWithPostProcessFunction)) + .build(); updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector); assertEquals(registerModelInputWithStandaloneConnector.getModelInterface(), BEDROCK_COHERE_EMBED_MULTILINGUAL_V3_MODEL_INTERFACE); } + @Test + public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorBEDROCK_COHERE_EMBED_MULTILINGUAL_V3_RAW_MODEL_INTERFACE() { + Map parameters = new HashMap<>(); + parameters.put("service_name", "bedrock"); + parameters.put("model", "cohere.embed-multilingual-v3"); + connector = HttpConnector + .builder() + .protocol("http") + .parameters(parameters) + .actions(List.of(connectorActionWithoutPostProcessFunction)) + .build(); + updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector); + assertEquals( + registerModelInputWithStandaloneConnector.getModelInterface(), + BEDROCK_COHERE_EMBED_MULTILINGUAL_V3_RAW_MODEL_INTERFACE + ); + } + @Test public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorBEDROCK_TITAN_EMBED_TEXT_V1_MODEL_INTERFACE() { Map parameters = new HashMap<>(); parameters.put("service_name", "bedrock"); parameters.put("model", "amazon.titan-embed-text-v1"); - connector = HttpConnector.builder().protocol("http").parameters(parameters).build(); + + connectorActionWithPostProcessFunction = ConnectorAction + .builder() + .actionType(PREDICT) + .method("POST") + .url("http:///mock") + .requestBody("{\"input\": \"${parameters.input}\"}") + .postProcessFunction(MLPostProcessFunction.BEDROCK_EMBEDDING) + .build(); + + connector = HttpConnector + .builder() + .protocol("http") + .parameters(parameters) + .actions(List.of(connectorActionWithPostProcessFunction)) + .build(); updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector); assertEquals(registerModelInputWithStandaloneConnector.getModelInterface(), BEDROCK_TITAN_EMBED_TEXT_V1_MODEL_INTERFACE); } + @Test + public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorBEDROCK_TITAN_EMBED_TEXT_V1_RAW_MODEL_INTERFACE() { + Map parameters = new HashMap<>(); + parameters.put("service_name", "bedrock"); + parameters.put("model", "amazon.titan-embed-text-v1"); + connector = HttpConnector + .builder() + .protocol("http") + .parameters(parameters) + .actions(List.of(connectorActionWithoutPostProcessFunction)) + .build(); + updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector); + assertEquals(registerModelInputWithStandaloneConnector.getModelInterface(), BEDROCK_TITAN_EMBED_TEXT_V1_RAW_MODEL_INTERFACE); + } + @Test public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorBEDROCK_TITAN_EMBED_MULTI_MODAL_V1_MODEL_INTERFACE() { Map parameters = new HashMap<>(); parameters.put("service_name", "bedrock"); parameters.put("model", "amazon.titan-embed-image-v1"); - connector = HttpConnector.builder().protocol("http").parameters(parameters).build(); + + connectorActionWithPostProcessFunction = ConnectorAction + .builder() + .actionType(PREDICT) + .method("POST") + .url("http:///mock") + .requestBody("{\"input\": \"${parameters.input}\"}") + .postProcessFunction(MLPostProcessFunction.BEDROCK_EMBEDDING) + .build(); + + connector = HttpConnector + .builder() + .protocol("http") + .parameters(parameters) + .actions(List.of(connectorActionWithPostProcessFunction)) + .build(); updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector); assertEquals(registerModelInputWithStandaloneConnector.getModelInterface(), BEDROCK_TITAN_EMBED_MULTI_MODAL_V1_MODEL_INTERFACE); } + @Test + public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorBEDROCK_TITAN_EMBED_MULTI_MODAL_V1_RAW_MODEL_INTERFACE() { + Map parameters = new HashMap<>(); + parameters.put("service_name", "bedrock"); + parameters.put("model", "amazon.titan-embed-image-v1"); + connector = HttpConnector + .builder() + .protocol("http") + .parameters(parameters) + .actions(List.of(connectorActionWithoutPostProcessFunction)) + .build(); + updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector); + assertEquals(registerModelInputWithStandaloneConnector.getModelInterface(), BEDROCK_TITAN_EMBED_MULTI_MODAL_V1_RAW_MODEL_INTERFACE); + } + @Test public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorAMAZON_COMPREHEND_DETECTDOMAINANTLANGUAGE_API_INTERFACE() { Map parameters = new HashMap<>(); parameters.put("service_name", "comprehend"); parameters.put("api_name", "DetectDominantLanguage"); - connector = HttpConnector.builder().protocol("http").parameters(parameters).build(); + connector = HttpConnector + .builder() + .protocol("http") + .parameters(parameters) + .actions(List.of(connectorActionWithPostProcessFunction)) + .build(); updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector); assertEquals( @@ -154,7 +341,12 @@ public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorAMAZON_TE Map parameters = new HashMap<>(); parameters.put("service_name", "textract"); parameters.put("api_name", "DetectDocumentText"); - connector = HttpConnector.builder().protocol("http").parameters(parameters).build(); + connector = HttpConnector + .builder() + .protocol("http") + .parameters(parameters) + .actions(List.of(connectorActionWithPostProcessFunction)) + .build(); updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector); assertEquals(registerModelInputWithStandaloneConnector.getModelInterface(), AMAZON_TEXTRACT_DETECTDOCUMENTTEXT_API_INTERFACE); @@ -163,7 +355,12 @@ public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorAMAZON_TE @Test public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorServiceNameNotFound() { Map parameters = new HashMap<>(); - connector = HttpConnector.builder().protocol("http").parameters(parameters).build(); + connector = HttpConnector + .builder() + .protocol("http") + .parameters(parameters) + .actions(List.of(connectorActionWithPostProcessFunction)) + .build(); updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector); assertNull(registerModelInputWithStandaloneConnector.getModelInterface()); @@ -173,7 +370,12 @@ public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorServiceNa public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorBedrockModelNameNotFound() { Map parameters = new HashMap<>(); parameters.put("service_name", "bedrock"); - connector = HttpConnector.builder().protocol("http").parameters(parameters).build(); + connector = HttpConnector + .builder() + .protocol("http") + .parameters(parameters) + .actions(List.of(connectorActionWithPostProcessFunction)) + .build(); updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector); assertNull(registerModelInputWithStandaloneConnector.getModelInterface()); @@ -183,7 +385,12 @@ public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorBedrockMo public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorAmazonComprehendAPINameNotFound() { Map parameters = new HashMap<>(); parameters.put("service_name", "comprehend"); - connector = HttpConnector.builder().protocol("http").parameters(parameters).build(); + connector = HttpConnector + .builder() + .protocol("http") + .parameters(parameters) + .actions(List.of(connectorActionWithPostProcessFunction)) + .build(); updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector); assertNull(registerModelInputWithStandaloneConnector.getModelInterface()); @@ -204,7 +411,12 @@ public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorNullParam Map parameters = new HashMap<>(); parameters.put("service_name", "bedrock"); parameters.put("model", "ai21.j2-mid-v1"); - connector = HttpConnector.builder().protocol("http").parameters(parameters).build(); + connector = HttpConnector + .builder() + .protocol("http") + .parameters(parameters) + .actions(List.of(connectorActionWithPostProcessFunction)) + .build(); registerModelInputWithInnerConnector.setConnector(connector); updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithInnerConnector); assertEquals(registerModelInputWithInnerConnector.getModelInterface(), BEDROCK_AI21_LABS_JURASSIC2_MID_V1_MODEL_INTERFACE); diff --git a/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java b/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java index a797657e00..423cb1ed71 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java @@ -265,7 +265,11 @@ public void validateInputSchema(String modelId, MLInput mlInput) { String processedInputString = MLNodeUtils.processRemoteInferenceInputDataSetParametersValue(InputString); MLNodeUtils.validateSchema(inputSchemaString, processedInputString); } catch (Exception e) { - throw new OpenSearchStatusException("Error validating input schema: " + e.getMessage(), RestStatus.BAD_REQUEST); + throw new OpenSearchStatusException( + "Error validating input schema, if you think this is expected, please update your 'input' field in the 'interface' field for this model: " + + e.getMessage(), + RestStatus.BAD_REQUEST + ); } } } diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java index 67948b4957..99d903ba2f 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java @@ -596,7 +596,11 @@ public void validateOutputSchema(String modelId, ModelTensorOutput output) { output.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS).toString() ); } catch (Exception e) { - throw new OpenSearchStatusException("Error validating output schema: " + e.getMessage(), RestStatus.BAD_REQUEST); + throw new OpenSearchStatusException( + "Error validating output schema, if you think this is expected, please update your 'output' field in the 'interface' field for this model: " + + e.getMessage(), + RestStatus.BAD_REQUEST + ); } } }