Skip to content

Commit

Permalink
Add ITs and blueprints in tutorial
Browse files Browse the repository at this point in the history
Signed-off-by: zane-neo <[email protected]>
  • Loading branch information
zane-neo committed Feb 21, 2025
1 parent c3d50eb commit 0a19fab
Show file tree
Hide file tree
Showing 7 changed files with 294 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,14 @@
public class MLPostProcessFunction {

public static final String COHERE_EMBEDDING = "connector.post_process.cohere.embedding";
public static final String COHERE_V2_EMBEDDING_FLOAT32 = "connector.post_process.cohere_v2.embedding.float";
public static final String COHERE_V2_EMBEDDING_INT8 = "connector.post_process.cohere_v2.embedding.int8";
public static final String COHERE_V2_EMBEDDING_UINT8 = "connector.post_process.cohere_v2.embedding.int8";
public static final String COHERE_V2_EMBEDDING_UINT8 = "connector.post_process.cohere_v2.embedding.uint8";
public static final String COHERE_V2_EMBEDDING_BINARY = "connector.post_process.cohere_v2.embedding.binary";
public static final String COHERE_V2_EMBEDDING_UBINARY = "connector.post_process.cohere_v2.embedding.ubinary";
public static final String OPENAI_EMBEDDING = "connector.post_process.openai.embedding";
public static final String BEDROCK_EMBEDDING = "connector.post_process.bedrock.embedding";
public static final String BEDROCK_V2_EMBEDDING_FLOAT = "connector.post_process.bedrock_v2.embedding.float";
public static final String BEDROCK_V2_EMBEDDING_BINARY = "connector.post_process.bedrock_v2.embedding.binary";
public static final String BEDROCK_BATCH_JOB_ARN = "connector.post_process.bedrock.batch_job_arn";
public static final String COHERE_RERANK = "connector.post_process.cohere.rerank";
Expand All @@ -46,24 +48,29 @@ public class MLPostProcessFunction {
BedrockRerankPostProcessFunction bedrockRerankPostProcessFunction = new BedrockRerankPostProcessFunction();
JSON_PATH_EXPRESSION.put(OPENAI_EMBEDDING, "$.data[*].embedding");
JSON_PATH_EXPRESSION.put(COHERE_EMBEDDING, "$.embeddings");
JSON_PATH_EXPRESSION.put(COHERE_V2_EMBEDDING_FLOAT32, "$.embeddings.float");
JSON_PATH_EXPRESSION.put(COHERE_V2_EMBEDDING_INT8, "$.embeddings.int8");
JSON_PATH_EXPRESSION.put(COHERE_V2_EMBEDDING_UINT8, "$.embeddings.uint8");
JSON_PATH_EXPRESSION.put(COHERE_V2_EMBEDDING_BINARY, "$.embeddings.binary");
JSON_PATH_EXPRESSION.put(COHERE_V2_EMBEDDING_UBINARY, "$.embeddings.ubinary");
JSON_PATH_EXPRESSION.put(DEFAULT_EMBEDDING, "$[*]");
JSON_PATH_EXPRESSION.put(BEDROCK_EMBEDDING, "$.embedding");
JSON_PATH_EXPRESSION.put(BEDROCK_V2_EMBEDDING_FLOAT, "$.embeddingsByType.float");
JSON_PATH_EXPRESSION.put(BEDROCK_V2_EMBEDDING_BINARY, "$.embeddingsByType.binary");
JSON_PATH_EXPRESSION.put(BEDROCK_BATCH_JOB_ARN, "$");
JSON_PATH_EXPRESSION.put(COHERE_RERANK, "$.results");
JSON_PATH_EXPRESSION.put(BEDROCK_RERANK, "$.results");
JSON_PATH_EXPRESSION.put(DEFAULT_RERANK, "$[*]");
POST_PROCESS_FUNCTIONS.put(OPENAI_EMBEDDING, embeddingPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(COHERE_EMBEDDING, embeddingPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(COHERE_V2_EMBEDDING_FLOAT32, embeddingPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(COHERE_V2_EMBEDDING_INT8, embeddingPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(COHERE_V2_EMBEDDING_UINT8, embeddingPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(COHERE_V2_EMBEDDING_BINARY, embeddingPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(COHERE_V2_EMBEDDING_UBINARY, embeddingPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(DEFAULT_EMBEDDING, embeddingPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(BEDROCK_EMBEDDING, bedrockEmbeddingPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(BEDROCK_V2_EMBEDDING_FLOAT, bedrockEmbeddingPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(BEDROCK_V2_EMBEDDING_BINARY, bedrockEmbeddingPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(BEDROCK_BATCH_JOB_ARN, batchJobArnPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(COHERE_RERANK, cohereRerankPostProcessFunction);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,29 @@ POST /_plugins/_ml/connectors/_create
}
```

If you're using BedRock V2 API, you should supply `embeddingTypes` in request body:
```json
POST /_plugins/_ml/connectors/_create
{
...
"parameters": {
...
"model": "amazon.titan-embed-text-v2:0"
},
"actions": [
{
...
"request_body": "{ \"inputText\": \"${parameters.inputText}\", \"embeddingTypes\": [\"float\"] }",
"pre_process_function": "connector.pre_process.bedrock.embedding",
"post_process_function": "onnector.post_process.bedrock_v2.embedding.float"
}
]
}
```
For BedRock v2 embedding API, there are several build-in post_process_function that can extract the embedding result to a list of list of number format:
1. v2 float: connector.post_process.bedrock_v2.embedding.float
2. v2 binary: connector.post_process.bedrock_v2.embedding.binary

If using the AWS Opensearch Service, you can provide an IAM role arn that allows access to the bedrock service.
Refer to this [AWS doc](https://docs.aws.amazon.com/opensearch-service/latest/developerguide/ml-amazon-connector.html)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,29 @@ POST /_plugins/_ml/connectors/_create
]
}
```
If you're using cohere V2 embedding API, you should pass `embedding_types` in the request body
```json
POST /_plugins/_ml/connectors/_create
{
...
"actions": [
{
"action_type": "predict",
"method": "POST",
"url": "https://api.cohere.ai/v2/embed",
"request_body": "{ \"texts\": ${parameters.texts}, \"truncate\": \"END\", \"model\": \"${parameters.model_name}\", \"embedding_types\": [\"float\"], \"input_type\": \"${parameters.input_type}\"}",
"pre_process_function": "connector.pre_process.cohere.embedding",
"post_process_function": "connector.post_process.cohere_v2.embedding.float"
}
]
}
```
For cohere v2 embedding API, there are several build-in post_process_function that can extract the embedding result to a list of list of number format:
1. v2 float: connector.post_process.cohere_v2.embedding.float
2. v2 int8: connector.post_process.cohere_v2.embedding.int8
3. v2 uint8: connector.post_process.cohere_v2.embedding.uint8
4. v2 binary: connector.post_process.cohere_v2.embedding.binary
5. v2 ubinary: connector.post_process.cohere_v2.embedding.ubinary

This request response will return the `connector_id`, note it down.

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.rest;

import lombok.SneakyThrows;
import lombok.extern.log4j.Log4j2;
import org.apache.commons.lang3.StringUtils;
import org.junit.Before;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.input.MLInput;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Map;

@Log4j2
public class RestBedRockV2PostProcessFunctionInferenceIT extends MLCommonsRestTestCase {
private static final String AWS_ACCESS_KEY_ID = System.getenv("AWS_ACCESS_KEY_ID");
private static final String AWS_SECRET_ACCESS_KEY = System.getenv("AWS_SECRET_ACCESS_KEY");
private static final String AWS_SESSION_TOKEN = System.getenv("AWS_SESSION_TOKEN");
private static final String GITHUB_CI_AWS_REGION = "us-west-2";
private static final List<String> POST_PROCESS_FUNCTIONS = List.of(
"connector.post_process.bedrock_v2.embedding.float",
"connector.post_process.bedrock_v2.embedding.binary"
);
private static final Map<String, String> DATA_TYPE = Map.of(
"connector.post_process.bedrock_v2.embedding.float", "FLOAT32",
"connector.post_process.bedrock_v2.embedding.binary", "BINARY"
);

@SneakyThrows
@Before
public void setup() throws IOException, InterruptedException {
RestMLRemoteInferenceIT.disableClusterConnectorAccessControl();
Thread.sleep(20000);
}

public void test_bedrock_embedding_model() throws Exception {
// Skip test if key is null
if (tokenNotSet()) {
return;
}
String templates = Files
.readString(
Path
.of(
RestMLPredictionAction.class
.getClassLoader()
.getResource("org/opensearch/ml/rest/templates/BedRockV2ConnectorBodies.json")
.toURI()
)
);
for (String postProcessFunction : POST_PROCESS_FUNCTIONS) {
String bedrockEmbeddingModelName = "bedrock embedding model: " + postProcessFunction;
String modelId = registerRemoteModel(
String
.format(
templates,
GITHUB_CI_AWS_REGION,
AWS_ACCESS_KEY_ID,
AWS_SECRET_ACCESS_KEY,
AWS_SESSION_TOKEN,
StringUtils.substringAfterLast(postProcessFunction, "."),
postProcessFunction
),
bedrockEmbeddingModelName,
true
);
String errorMsg = String.format("failed to test: %s", postProcessFunction);
TextDocsInputDataSet inputDataSet = TextDocsInputDataSet.builder().docs(List.of("hello", "world")).build();
MLInput mlInput = MLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.TEXT_EMBEDDING).build();
Map inferenceResult = predictTextEmbeddingModel(modelId, mlInput);
assertTrue(errorMsg, inferenceResult.containsKey("inference_results"));
List output = (List) inferenceResult.get("inference_results");
assertEquals(errorMsg, 2, output.size());
assertTrue(errorMsg, output.get(0) instanceof Map);
assertTrue(errorMsg, output.get(1) instanceof Map);
validateOutput(errorMsg, (Map) output.get(0), DATA_TYPE.get(postProcessFunction));
validateOutput(errorMsg, (Map) output.get(1), DATA_TYPE.get(postProcessFunction));
}
}

private void validateOutput(String errorMsg, Map<String, Object> output, String dataType) {
assertTrue(errorMsg, output.containsKey("output"));
assertTrue(errorMsg, output.get("output") instanceof List);
List outputList = (List) output.get("output");
assertEquals(errorMsg, 1, outputList.size());
assertTrue(errorMsg, outputList.get(0) instanceof Map);
assertTrue(errorMsg, ((Map<?, ?>) outputList.get(0)).get("data") instanceof List);
assertEquals(errorMsg, ((Map<?, ?>) outputList.get(0)).get("data_type"), dataType);
}

private boolean tokenNotSet() {
if (AWS_ACCESS_KEY_ID == null || AWS_SECRET_ACCESS_KEY == null || AWS_SESSION_TOKEN == null) {
log.info("#### The AWS credentials are not set. Skipping test. ####");
return true;
}
return false;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package org.opensearch.ml.rest;

import org.apache.commons.lang3.StringUtils;
import org.junit.Before;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.input.MLInput;

import java.io.IOException;
import java.net.URISyntaxException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;
import java.util.Map;
import java.util.Optional;

public class RestCohereInferenceIT extends MLCommonsRestTestCase {
private final String COHERE_KEY = Optional.ofNullable(System.getenv("COHERE_KEY")).orElse("UzRF34a6gj0OKkvHOO6FZxLItv8CNpK5dFdCaUDW");
private final Map<String, String> DATA_TYPE = Map.of(
"connector.post_process.cohere_v2.embedding.float", "FLOAT32",
"connector.post_process.cohere_v2.embedding.int8", "INT8",
"connector.post_process.cohere_v2.embedding.uint8", "UINT8",
"connector.post_process.cohere_v2.embedding.binary", "BINARY",
"connector.post_process.cohere_v2.embedding.ubinary", "UBINARY"
);
private final List<String> POST_PROCESS_FUNCTIONS = List.of(
"connector.post_process.cohere_v2.embedding.float",
"connector.post_process.cohere_v2.embedding.int8",
"connector.post_process.cohere_v2.embedding.uint8",
"connector.post_process.cohere_v2.embedding.binary",
"connector.post_process.cohere_v2.embedding.ubinary");

@Before
public void setup() throws IOException {
updateClusterSettings("plugins.ml_commons.trusted_connector_endpoints_regex", List.of("^.*$"));
}


public void test_cohereInference_withDifferent_postProcessFunction() throws URISyntaxException, IOException, InterruptedException {
String templates = Files
.readString(
Path
.of(
RestMLPredictionAction.class
.getClassLoader()
.getResource("org/opensearch/ml/rest/templates/CohereConnectorBodies.json")
.toURI()
)
);
for (String postProcessFunction : POST_PROCESS_FUNCTIONS) {
String connectorRequestBody = String.format(templates, COHERE_KEY, StringUtils.substringAfterLast(postProcessFunction, "."), postProcessFunction);
String testCaseName = postProcessFunction + "_test";
String modelId = registerRemoteModel(connectorRequestBody, testCaseName, true);
String errorMsg = String.format("failed to run test with test name: %s", testCaseName);
TextDocsInputDataSet inputDataSet = TextDocsInputDataSet.builder().docs(List.of("hello", "world")).build();
MLInput mlInput = MLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.TEXT_EMBEDDING).build();
Map inferenceResult = predictTextEmbeddingModel(modelId, mlInput);
assertTrue(errorMsg, inferenceResult.containsKey("inference_results"));
List output = (List) inferenceResult.get("inference_results");
assertEquals(errorMsg, 1, output.size());
assertTrue(errorMsg, output.get(0) instanceof Map);
validateOutput(errorMsg, (Map) output.get(0), DATA_TYPE.get(postProcessFunction));
}
}

private void validateOutput(String errorMsg, Map<String, Object> output, String dataType) {
assertTrue(errorMsg, output.containsKey("output"));
assertTrue(errorMsg, output.get("output") instanceof List);
List outputList = (List) output.get("output");
assertEquals(errorMsg, 2, outputList.size());
assertTrue(errorMsg, outputList.get(0) instanceof Map);
assertTrue(errorMsg, ((Map<?, ?>) outputList.get(0)).get("data") instanceof List);
assertTrue(errorMsg, ((Map<?, ?>) outputList.get(0)).get("data_type").equals(dataType));
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
{
"name": "Amazon Bedrock Connector: embedding",
"description": "The connector to bedrock Titan embedding model",
"version": 1,
"protocol": "aws_sigv4",
"parameters": {
"region": "%s",
"service_name": "bedrock",
"model_name": "amazon.titan-embed-text-v2:0"
},
"credential": {
"access_key": "%s",
"secret_key": "%s",
"session_token": "%s"
},
"actions": [
{
"action_type": "predict",
"method": "POST",
"url": "https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model_name}/invoke",
"headers": {
"content-type": "application/json",
"x-amz-content-sha256": "required"
},
"request_body": "{ \"inputText\": \"${parameters.inputText}\", \"embeddingTypes\": [\"%s\"] }",
"pre_process_function": "connector.pre_process.bedrock.embedding",
"post_process_function": "%s"
}
]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
{
"name": "Cohere Connector: embedding",
"description": "The connector to cohere embedding model",
"version": 1,
"protocol": "http",
"parameters": {
"model_name": "embed-english-v3.0"
},
"credential": {
"cohere_key": "%s"
},
"actions": [
{
"action_type": "predict",
"method": "POST",
"url": "https://api.cohere.com/v2/embed",
"headers": {
"content-type": "application/json",
"Authorization": "Bearer ${credential.cohere_key}"
},
"request_body": "{ \"texts\": ${parameters.texts}, \"truncate\": \"END\", \"model\": \"${parameters.model_name}\", \"embedding_types\": [\"%s\"], \"input_type\": \"classification\"}",
"pre_process_function": "connector.pre_process.cohere.embedding",
"post_process_function": "%s"
}
]
}

0 comments on commit 0a19fab

Please sign in to comment.