From 0a19fabc18f3cc094ff8aaa22ed4baeb3b81a945 Mon Sep 17 00:00:00 2001 From: zane-neo Date: Fri, 21 Feb 2025 16:58:05 +0800 Subject: [PATCH] Add ITs and blueprints in tutorial Signed-off-by: zane-neo --- .../connector/MLPostProcessFunction.java | 9 +- ...ock_connector_titan_embedding_blueprint.md | 23 ++++ .../cohere_connector_embedding_blueprint.md | 23 ++++ ...dRockV2PostProcessFunctionInferenceIT.java | 108 ++++++++++++++++++ .../ml/rest/RestCohereInferenceIT.java | 76 ++++++++++++ .../templates/BedRockV2ConnectorBodies.json | 30 +++++ .../rest/templates/CohereConnectorBodies.json | 26 +++++ 7 files changed, 294 insertions(+), 1 deletion(-) create mode 100644 plugin/src/test/java/org/opensearch/ml/rest/RestBedRockV2PostProcessFunctionInferenceIT.java create mode 100644 plugin/src/test/java/org/opensearch/ml/rest/RestCohereInferenceIT.java create mode 100644 plugin/src/test/resources/org/opensearch/ml/rest/templates/BedRockV2ConnectorBodies.json create mode 100644 plugin/src/test/resources/org/opensearch/ml/rest/templates/CohereConnectorBodies.json diff --git a/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java index 3140de52eb..efda9c4743 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java @@ -21,12 +21,14 @@ public class MLPostProcessFunction { public static final String COHERE_EMBEDDING = "connector.post_process.cohere.embedding"; + public static final String COHERE_V2_EMBEDDING_FLOAT32 = "connector.post_process.cohere_v2.embedding.float"; public static final String COHERE_V2_EMBEDDING_INT8 = "connector.post_process.cohere_v2.embedding.int8"; - public static final String COHERE_V2_EMBEDDING_UINT8 = "connector.post_process.cohere_v2.embedding.int8"; + public static final String COHERE_V2_EMBEDDING_UINT8 = "connector.post_process.cohere_v2.embedding.uint8"; public static final String COHERE_V2_EMBEDDING_BINARY = "connector.post_process.cohere_v2.embedding.binary"; public static final String COHERE_V2_EMBEDDING_UBINARY = "connector.post_process.cohere_v2.embedding.ubinary"; public static final String OPENAI_EMBEDDING = "connector.post_process.openai.embedding"; public static final String BEDROCK_EMBEDDING = "connector.post_process.bedrock.embedding"; + public static final String BEDROCK_V2_EMBEDDING_FLOAT = "connector.post_process.bedrock_v2.embedding.float"; public static final String BEDROCK_V2_EMBEDDING_BINARY = "connector.post_process.bedrock_v2.embedding.binary"; public static final String BEDROCK_BATCH_JOB_ARN = "connector.post_process.bedrock.batch_job_arn"; public static final String COHERE_RERANK = "connector.post_process.cohere.rerank"; @@ -46,24 +48,29 @@ public class MLPostProcessFunction { BedrockRerankPostProcessFunction bedrockRerankPostProcessFunction = new BedrockRerankPostProcessFunction(); JSON_PATH_EXPRESSION.put(OPENAI_EMBEDDING, "$.data[*].embedding"); JSON_PATH_EXPRESSION.put(COHERE_EMBEDDING, "$.embeddings"); + JSON_PATH_EXPRESSION.put(COHERE_V2_EMBEDDING_FLOAT32, "$.embeddings.float"); JSON_PATH_EXPRESSION.put(COHERE_V2_EMBEDDING_INT8, "$.embeddings.int8"); JSON_PATH_EXPRESSION.put(COHERE_V2_EMBEDDING_UINT8, "$.embeddings.uint8"); JSON_PATH_EXPRESSION.put(COHERE_V2_EMBEDDING_BINARY, "$.embeddings.binary"); JSON_PATH_EXPRESSION.put(COHERE_V2_EMBEDDING_UBINARY, "$.embeddings.ubinary"); JSON_PATH_EXPRESSION.put(DEFAULT_EMBEDDING, "$[*]"); JSON_PATH_EXPRESSION.put(BEDROCK_EMBEDDING, "$.embedding"); + JSON_PATH_EXPRESSION.put(BEDROCK_V2_EMBEDDING_FLOAT, "$.embeddingsByType.float"); + JSON_PATH_EXPRESSION.put(BEDROCK_V2_EMBEDDING_BINARY, "$.embeddingsByType.binary"); JSON_PATH_EXPRESSION.put(BEDROCK_BATCH_JOB_ARN, "$"); JSON_PATH_EXPRESSION.put(COHERE_RERANK, "$.results"); JSON_PATH_EXPRESSION.put(BEDROCK_RERANK, "$.results"); JSON_PATH_EXPRESSION.put(DEFAULT_RERANK, "$[*]"); POST_PROCESS_FUNCTIONS.put(OPENAI_EMBEDDING, embeddingPostProcessFunction); POST_PROCESS_FUNCTIONS.put(COHERE_EMBEDDING, embeddingPostProcessFunction); + POST_PROCESS_FUNCTIONS.put(COHERE_V2_EMBEDDING_FLOAT32, embeddingPostProcessFunction); POST_PROCESS_FUNCTIONS.put(COHERE_V2_EMBEDDING_INT8, embeddingPostProcessFunction); POST_PROCESS_FUNCTIONS.put(COHERE_V2_EMBEDDING_UINT8, embeddingPostProcessFunction); POST_PROCESS_FUNCTIONS.put(COHERE_V2_EMBEDDING_BINARY, embeddingPostProcessFunction); POST_PROCESS_FUNCTIONS.put(COHERE_V2_EMBEDDING_UBINARY, embeddingPostProcessFunction); POST_PROCESS_FUNCTIONS.put(DEFAULT_EMBEDDING, embeddingPostProcessFunction); POST_PROCESS_FUNCTIONS.put(BEDROCK_EMBEDDING, bedrockEmbeddingPostProcessFunction); + POST_PROCESS_FUNCTIONS.put(BEDROCK_V2_EMBEDDING_FLOAT, bedrockEmbeddingPostProcessFunction); POST_PROCESS_FUNCTIONS.put(BEDROCK_V2_EMBEDDING_BINARY, bedrockEmbeddingPostProcessFunction); POST_PROCESS_FUNCTIONS.put(BEDROCK_BATCH_JOB_ARN, batchJobArnPostProcessFunction); POST_PROCESS_FUNCTIONS.put(COHERE_RERANK, cohereRerankPostProcessFunction); diff --git a/docs/remote_inference_blueprints/bedrock_connector_titan_embedding_blueprint.md b/docs/remote_inference_blueprints/bedrock_connector_titan_embedding_blueprint.md index 9c96f60c92..c7fe394005 100644 --- a/docs/remote_inference_blueprints/bedrock_connector_titan_embedding_blueprint.md +++ b/docs/remote_inference_blueprints/bedrock_connector_titan_embedding_blueprint.md @@ -53,6 +53,29 @@ POST /_plugins/_ml/connectors/_create } ``` +If you're using BedRock V2 API, you should supply `embeddingTypes` in request body: +```json +POST /_plugins/_ml/connectors/_create +{ + ... + "parameters": { + ... + "model": "amazon.titan-embed-text-v2:0" + }, + "actions": [ + { + ... + "request_body": "{ \"inputText\": \"${parameters.inputText}\", \"embeddingTypes\": [\"float\"] }", + "pre_process_function": "connector.pre_process.bedrock.embedding", + "post_process_function": "onnector.post_process.bedrock_v2.embedding.float" + } + ] +} +``` +For BedRock v2 embedding API, there are several build-in post_process_function that can extract the embedding result to a list of list of number format: +1. v2 float: connector.post_process.bedrock_v2.embedding.float +2. v2 binary: connector.post_process.bedrock_v2.embedding.binary + If using the AWS Opensearch Service, you can provide an IAM role arn that allows access to the bedrock service. Refer to this [AWS doc](https://docs.aws.amazon.com/opensearch-service/latest/developerguide/ml-amazon-connector.html) diff --git a/docs/remote_inference_blueprints/cohere_connector_embedding_blueprint.md b/docs/remote_inference_blueprints/cohere_connector_embedding_blueprint.md index 4386251c00..c60091c7be 100644 --- a/docs/remote_inference_blueprints/cohere_connector_embedding_blueprint.md +++ b/docs/remote_inference_blueprints/cohere_connector_embedding_blueprint.md @@ -59,6 +59,29 @@ POST /_plugins/_ml/connectors/_create ] } ``` +If you're using cohere V2 embedding API, you should pass `embedding_types` in the request body +```json +POST /_plugins/_ml/connectors/_create +{ + ... + "actions": [ + { + "action_type": "predict", + "method": "POST", + "url": "https://api.cohere.ai/v2/embed", + "request_body": "{ \"texts\": ${parameters.texts}, \"truncate\": \"END\", \"model\": \"${parameters.model_name}\", \"embedding_types\": [\"float\"], \"input_type\": \"${parameters.input_type}\"}", + "pre_process_function": "connector.pre_process.cohere.embedding", + "post_process_function": "connector.post_process.cohere_v2.embedding.float" + } + ] +} +``` +For cohere v2 embedding API, there are several build-in post_process_function that can extract the embedding result to a list of list of number format: +1. v2 float: connector.post_process.cohere_v2.embedding.float +2. v2 int8: connector.post_process.cohere_v2.embedding.int8 +3. v2 uint8: connector.post_process.cohere_v2.embedding.uint8 +4. v2 binary: connector.post_process.cohere_v2.embedding.binary +5. v2 ubinary: connector.post_process.cohere_v2.embedding.ubinary This request response will return the `connector_id`, note it down. diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockV2PostProcessFunctionInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockV2PostProcessFunctionInferenceIT.java new file mode 100644 index 0000000000..37c2410c94 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockV2PostProcessFunctionInferenceIT.java @@ -0,0 +1,108 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import lombok.SneakyThrows; +import lombok.extern.log4j.Log4j2; +import org.apache.commons.lang3.StringUtils; +import org.junit.Before; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.dataset.TextDocsInputDataSet; +import org.opensearch.ml.common.input.MLInput; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.List; +import java.util.Locale; +import java.util.Map; + +@Log4j2 +public class RestBedRockV2PostProcessFunctionInferenceIT extends MLCommonsRestTestCase { + private static final String AWS_ACCESS_KEY_ID = System.getenv("AWS_ACCESS_KEY_ID"); + private static final String AWS_SECRET_ACCESS_KEY = System.getenv("AWS_SECRET_ACCESS_KEY"); + private static final String AWS_SESSION_TOKEN = System.getenv("AWS_SESSION_TOKEN"); + private static final String GITHUB_CI_AWS_REGION = "us-west-2"; + private static final List POST_PROCESS_FUNCTIONS = List.of( + "connector.post_process.bedrock_v2.embedding.float", + "connector.post_process.bedrock_v2.embedding.binary" + ); + private static final Map DATA_TYPE = Map.of( + "connector.post_process.bedrock_v2.embedding.float", "FLOAT32", + "connector.post_process.bedrock_v2.embedding.binary", "BINARY" + ); + + @SneakyThrows + @Before + public void setup() throws IOException, InterruptedException { + RestMLRemoteInferenceIT.disableClusterConnectorAccessControl(); + Thread.sleep(20000); + } + + public void test_bedrock_embedding_model() throws Exception { + // Skip test if key is null + if (tokenNotSet()) { + return; + } + String templates = Files + .readString( + Path + .of( + RestMLPredictionAction.class + .getClassLoader() + .getResource("org/opensearch/ml/rest/templates/BedRockV2ConnectorBodies.json") + .toURI() + ) + ); + for (String postProcessFunction : POST_PROCESS_FUNCTIONS) { + String bedrockEmbeddingModelName = "bedrock embedding model: " + postProcessFunction; + String modelId = registerRemoteModel( + String + .format( + templates, + GITHUB_CI_AWS_REGION, + AWS_ACCESS_KEY_ID, + AWS_SECRET_ACCESS_KEY, + AWS_SESSION_TOKEN, + StringUtils.substringAfterLast(postProcessFunction, "."), + postProcessFunction + ), + bedrockEmbeddingModelName, + true + ); + String errorMsg = String.format("failed to test: %s", postProcessFunction); + TextDocsInputDataSet inputDataSet = TextDocsInputDataSet.builder().docs(List.of("hello", "world")).build(); + MLInput mlInput = MLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.TEXT_EMBEDDING).build(); + Map inferenceResult = predictTextEmbeddingModel(modelId, mlInput); + assertTrue(errorMsg, inferenceResult.containsKey("inference_results")); + List output = (List) inferenceResult.get("inference_results"); + assertEquals(errorMsg, 2, output.size()); + assertTrue(errorMsg, output.get(0) instanceof Map); + assertTrue(errorMsg, output.get(1) instanceof Map); + validateOutput(errorMsg, (Map) output.get(0), DATA_TYPE.get(postProcessFunction)); + validateOutput(errorMsg, (Map) output.get(1), DATA_TYPE.get(postProcessFunction)); + } + } + + private void validateOutput(String errorMsg, Map output, String dataType) { + assertTrue(errorMsg, output.containsKey("output")); + assertTrue(errorMsg, output.get("output") instanceof List); + List outputList = (List) output.get("output"); + assertEquals(errorMsg, 1, outputList.size()); + assertTrue(errorMsg, outputList.get(0) instanceof Map); + assertTrue(errorMsg, ((Map) outputList.get(0)).get("data") instanceof List); + assertEquals(errorMsg, ((Map) outputList.get(0)).get("data_type"), dataType); + } + + private boolean tokenNotSet() { + if (AWS_ACCESS_KEY_ID == null || AWS_SECRET_ACCESS_KEY == null || AWS_SESSION_TOKEN == null) { + log.info("#### The AWS credentials are not set. Skipping test. ####"); + return true; + } + return false; + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestCohereInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestCohereInferenceIT.java new file mode 100644 index 0000000000..2e8e8f4266 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestCohereInferenceIT.java @@ -0,0 +1,76 @@ +package org.opensearch.ml.rest; + +import org.apache.commons.lang3.StringUtils; +import org.junit.Before; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.dataset.TextDocsInputDataSet; +import org.opensearch.ml.common.input.MLInput; + +import java.io.IOException; +import java.net.URISyntaxException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +public class RestCohereInferenceIT extends MLCommonsRestTestCase { + private final String COHERE_KEY = Optional.ofNullable(System.getenv("COHERE_KEY")).orElse("UzRF34a6gj0OKkvHOO6FZxLItv8CNpK5dFdCaUDW"); + private final Map DATA_TYPE = Map.of( + "connector.post_process.cohere_v2.embedding.float", "FLOAT32", + "connector.post_process.cohere_v2.embedding.int8", "INT8", + "connector.post_process.cohere_v2.embedding.uint8", "UINT8", + "connector.post_process.cohere_v2.embedding.binary", "BINARY", + "connector.post_process.cohere_v2.embedding.ubinary", "UBINARY" + ); + private final List POST_PROCESS_FUNCTIONS = List.of( + "connector.post_process.cohere_v2.embedding.float", + "connector.post_process.cohere_v2.embedding.int8", + "connector.post_process.cohere_v2.embedding.uint8", + "connector.post_process.cohere_v2.embedding.binary", + "connector.post_process.cohere_v2.embedding.ubinary"); + + @Before + public void setup() throws IOException { + updateClusterSettings("plugins.ml_commons.trusted_connector_endpoints_regex", List.of("^.*$")); + } + + + public void test_cohereInference_withDifferent_postProcessFunction() throws URISyntaxException, IOException, InterruptedException { + String templates = Files + .readString( + Path + .of( + RestMLPredictionAction.class + .getClassLoader() + .getResource("org/opensearch/ml/rest/templates/CohereConnectorBodies.json") + .toURI() + ) + ); + for (String postProcessFunction : POST_PROCESS_FUNCTIONS) { + String connectorRequestBody = String.format(templates, COHERE_KEY, StringUtils.substringAfterLast(postProcessFunction, "."), postProcessFunction); + String testCaseName = postProcessFunction + "_test"; + String modelId = registerRemoteModel(connectorRequestBody, testCaseName, true); + String errorMsg = String.format("failed to run test with test name: %s", testCaseName); + TextDocsInputDataSet inputDataSet = TextDocsInputDataSet.builder().docs(List.of("hello", "world")).build(); + MLInput mlInput = MLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.TEXT_EMBEDDING).build(); + Map inferenceResult = predictTextEmbeddingModel(modelId, mlInput); + assertTrue(errorMsg, inferenceResult.containsKey("inference_results")); + List output = (List) inferenceResult.get("inference_results"); + assertEquals(errorMsg, 1, output.size()); + assertTrue(errorMsg, output.get(0) instanceof Map); + validateOutput(errorMsg, (Map) output.get(0), DATA_TYPE.get(postProcessFunction)); + } + } + + private void validateOutput(String errorMsg, Map output, String dataType) { + assertTrue(errorMsg, output.containsKey("output")); + assertTrue(errorMsg, output.get("output") instanceof List); + List outputList = (List) output.get("output"); + assertEquals(errorMsg, 2, outputList.size()); + assertTrue(errorMsg, outputList.get(0) instanceof Map); + assertTrue(errorMsg, ((Map) outputList.get(0)).get("data") instanceof List); + assertTrue(errorMsg, ((Map) outputList.get(0)).get("data_type").equals(dataType)); + } + +} diff --git a/plugin/src/test/resources/org/opensearch/ml/rest/templates/BedRockV2ConnectorBodies.json b/plugin/src/test/resources/org/opensearch/ml/rest/templates/BedRockV2ConnectorBodies.json new file mode 100644 index 0000000000..fc020af2b0 --- /dev/null +++ b/plugin/src/test/resources/org/opensearch/ml/rest/templates/BedRockV2ConnectorBodies.json @@ -0,0 +1,30 @@ +{ + "name": "Amazon Bedrock Connector: embedding", + "description": "The connector to bedrock Titan embedding model", + "version": 1, + "protocol": "aws_sigv4", + "parameters": { + "region": "%s", + "service_name": "bedrock", + "model_name": "amazon.titan-embed-text-v2:0" + }, + "credential": { + "access_key": "%s", + "secret_key": "%s", + "session_token": "%s" + }, + "actions": [ + { + "action_type": "predict", + "method": "POST", + "url": "https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model_name}/invoke", + "headers": { + "content-type": "application/json", + "x-amz-content-sha256": "required" + }, + "request_body": "{ \"inputText\": \"${parameters.inputText}\", \"embeddingTypes\": [\"%s\"] }", + "pre_process_function": "connector.pre_process.bedrock.embedding", + "post_process_function": "%s" + } + ] +} diff --git a/plugin/src/test/resources/org/opensearch/ml/rest/templates/CohereConnectorBodies.json b/plugin/src/test/resources/org/opensearch/ml/rest/templates/CohereConnectorBodies.json new file mode 100644 index 0000000000..eed074fa85 --- /dev/null +++ b/plugin/src/test/resources/org/opensearch/ml/rest/templates/CohereConnectorBodies.json @@ -0,0 +1,26 @@ +{ + "name": "Cohere Connector: embedding", + "description": "The connector to cohere embedding model", + "version": 1, + "protocol": "http", + "parameters": { + "model_name": "embed-english-v3.0" + }, + "credential": { + "cohere_key": "%s" + }, + "actions": [ + { + "action_type": "predict", + "method": "POST", + "url": "https://api.cohere.com/v2/embed", + "headers": { + "content-type": "application/json", + "Authorization": "Bearer ${credential.cohere_key}" + }, + "request_body": "{ \"texts\": ${parameters.texts}, \"truncate\": \"END\", \"model\": \"${parameters.model_name}\", \"embedding_types\": [\"%s\"], \"input_type\": \"classification\"}", + "pre_process_function": "connector.pre_process.cohere.embedding", + "post_process_function": "%s" + } + ] +}