diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java index 61c7c2ebea..9402db1d71 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java @@ -740,120 +740,6 @@ public void testCohereGenerateTextModel() throws IOException, InterruptedExcepti assertFalse(((String) responseMap.get("text")).isEmpty()); } - public void testCohereClassifyModel() throws IOException, InterruptedException { - // Skip test if key is null - if (COHERE_KEY == null) { - return; - } - String entity = "{\n" - + " \"name\": \"Cohere classify model Connector\",\n" - + " \"description\": \"The connector to public Cohere classify model service\",\n" - + " \"version\": 1,\n" - + " \"client_config\": {\n" - + " \"max_connection\": 20,\n" - + " \"connection_timeout\": 50000,\n" - + " \"read_timeout\": 50000\n" - + " },\n" - + " \"protocol\": \"http\",\n" - + " \"parameters\": {\n" - + " \"endpoint\": \"api.cohere.ai\",\n" - + " \"auth\": \"API_Key\",\n" - + " \"content_type\": \"application/json\",\n" - + " \"max_tokens\": \"20\"\n" - + " },\n" - + " \"credential\": {\n" - + " \"cohere_key\": \"" - + COHERE_KEY - + "\"\n" - + " },\n" - + " \"actions\": [\n" - + " {\n" - + " \"action_type\": \"predict\",\n" - + " \"method\": \"POST\",\n" - + " \"url\": \"https://${parameters.endpoint}/v1/classify\",\n" - + " \"headers\": { \n" - + " \"Authorization\": \"Bearer ${credential.cohere_key}\"\n" - + " },\n" - + " \"request_body\": \"{ \\\"inputs\\\": ${parameters.inputs}, \\\"examples\\\": ${parameters.examples}, \\\"truncate\\\": \\\"END\\\" }\"\n" - + " }\n" - + " ]\n" - + "}"; - Response response = createConnector(entity); - Map responseMap = parseResponseToMap(response); - String connectorId = (String) responseMap.get("connector_id"); - response = registerRemoteModel("cohere classify model", connectorId); - responseMap = parseResponseToMap(response); - String taskId = (String) responseMap.get("task_id"); - waitForTask(taskId, MLTaskState.COMPLETED); - response = getTask(taskId); - responseMap = parseResponseToMap(response); - String modelId = (String) responseMap.get("model_id"); - response = deployRemoteModel(modelId); - responseMap = parseResponseToMap(response); - taskId = (String) responseMap.get("task_id"); - waitForTask(taskId, MLTaskState.COMPLETED); - String predictInput = "{\n" - + " \"parameters\": {\n" - + " \"inputs\": [\n" - + " \"Confirm your email address\",\n" - + " \"hey i need u to send some $\"\n" - + " ],\n" - + " \"examples\": [\n" - + " {\n" - + " \"text\": \"Dermatologists don't like her!\",\n" - + " \"label\": \"Spam\"\n" - + " },\n" - + " {\n" - + " \"text\": \"Hello, open to this?\",\n" - + " \"label\": \"Spam\"\n" - + " },\n" - + " {\n" - + " \"text\": \"I need help please wire me $1000 right now\",\n" - + " \"label\": \"Spam\"\n" - + " },\n" - + " {\n" - + " \"text\": \"Nice to know you ;)\",\n" - + " \"label\": \"Spam\"\n" - + " },\n" - + " {\n" - + " \"text\": \"Please help me?\",\n" - + " \"label\": \"Spam\"\n" - + " },\n" - + " {\n" - + " \"text\": \"Your parcel will be delivered today\",\n" - + " \"label\": \"Not spam\"\n" - + " },\n" - + " {\n" - + " \"text\": \"Review changes to our Terms and Conditions\",\n" - + " \"label\": \"Not spam\"\n" - + " },\n" - + " {\n" - + " \"text\": \"Weekly sync notes\",\n" - + " \"label\": \"Not spam\"\n" - + " },\n" - + " {\n" - + " \"text\": \"Re: Follow up from todays meeting\",\n" - + " \"label\": \"Not spam\"\n" - + " },\n" - + " {\n" - + " \"text\": \"Pre-read for tomorrow\",\n" - + " \"label\": \"Not spam\"\n" - + " }\n" - + " ]\n" - + " }\n" - + "}"; - - response = predictRemoteModel(modelId, predictInput); - responseMap = parseResponseToMap(response); - List responseList = (List) responseMap.get("inference_results"); - responseMap = (Map) responseList.get(0); - responseList = (List) responseMap.get("output"); - responseMap = (Map) responseList.get(0); - responseMap = (Map) responseMap.get("dataAsMap"); - responseList = (List) responseMap.get("classifications"); - assertFalse(responseList.isEmpty()); - } - public static Response createConnector(String input) throws IOException { try { return TestHelper.makeRequest(client(), "POST", "/_plugins/_ml/connectors/_create", null, TestHelper.toHttpEntity(input), null);