-
Notifications
You must be signed in to change notification settings - Fork 272
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
268 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
from abc import abstractmethod | ||
from copy import deepcopy | ||
import json | ||
import os | ||
from typing import Any, Dict, List, Mapping, Optional | ||
|
||
from helm.common.cache import CacheConfig | ||
from helm.proxy.clients.client import CachingClient, truncate_and_tokenize_response_text | ||
from helm.common.request import Request, RequestResult, Sequence, wrap_request_time | ||
from helm.proxy.clients.bedrock_utils import get_bedrock_client | ||
from helm.proxy.tokenizers.tokenizer import Tokenizer | ||
|
||
|
||
JSON_CONTENT_TYPE = "application/json" | ||
|
||
|
||
class BedrockClient(CachingClient):
    """Base client for models invoked through the Amazon Bedrock runtime.

    Subclasses provide the model-specific translation between a ``Request`` and the
    JSON body sent to ``invoke_model``, and between the JSON response and completions.
    """

    def __init__(
        self,
        cache_config: CacheConfig,
        tokenizer: Tokenizer,
        tokenizer_name: str,
        bedrock_model_id: Optional[str] = None,
        assumed_role: Optional[str] = None,
        region: Optional[str] = None,
    ):
        """Store the tokenizer settings and create the underlying Bedrock client.

        Args:
            cache_config: Passed through to ``CachingClient``.
            tokenizer: Tokenizer made available to subclasses as ``self.tokenizer``.
            tokenizer_name: Name of the tokenizer, stored as ``self.tokenizer_name``.
            bedrock_model_id: Explicit Bedrock model ID. When unset, the ID is derived
                from ``request.model`` on each request.
            assumed_role: IAM role to assume; falls back to the ``BEDROCK_ASSUME_ROLE``
                environment variable.
            region: AWS region; falls back to the ``AWS_DEFAULT_REGION`` environment
                variable.
        """
        super().__init__(cache_config=cache_config)
        self.tokenizer = tokenizer
        self.tokenizer_name = tokenizer_name
        self.bedrock_model_id = bedrock_model_id
        role_to_assume = assumed_role or os.environ.get("BEDROCK_ASSUME_ROLE", None)
        aws_region = region or os.environ.get("AWS_DEFAULT_REGION", None)
        self.bedrock_client = get_bedrock_client(assumed_role=role_to_assume, region=aws_region)

    @abstractmethod
    def convert_request_to_raw_request(self, request: Request) -> Dict:
        """Build the JSON-serializable body to send to Bedrock for `request`."""
        raise NotImplementedError()

    @abstractmethod
    def convert_raw_response_to_completions(self, response: Dict, request: Request) -> List[Sequence]:
        """Turn the decoded JSON response from Bedrock into completions."""
        raise NotImplementedError()

    def make_request(self, request: Request) -> RequestResult:
        """Send `request` to Bedrock through the cache and wrap the outcome.

        Any exception raised while fetching the response is reported as an
        unsuccessful ``RequestResult`` carrying the error text instead of propagating.
        """
        # model_id should be something like "amazon.titan-tg1-large"
        model_id = self.bedrock_model_id or request.model.replace("/", ".")
        raw_request = self.convert_request_to_raw_request(request)

        # modelId is passed to invoke_model separately and is not part of raw_request,
        # so it is added explicitly to the dict from which the cache key is built.
        cache_payload: Dict = {"modelId": model_id, **deepcopy(raw_request)}
        cache_key: Mapping = CachingClient.make_cache_key(cache_payload, request)

        def invoke() -> Dict[Any, Any]:
            raw_response = self.bedrock_client.invoke_model(
                body=json.dumps(raw_request),
                modelId=model_id,
                accept=JSON_CONTENT_TYPE,
                contentType=JSON_CONTENT_TYPE,
            )
            # The response body is a stream that holds the JSON document.
            return json.loads(raw_response.get("body").read())

        try:
            # Presumably wrap_request_time adds the "request_time" and "request_datetime"
            # entries read below -- confirm against its definition.
            response, cached = self.cache.get(cache_key, wrap_request_time(invoke))
        except Exception as error:
            return RequestResult(
                success=False,
                cached=False,
                error=str(error),
                completions=[],
                embedding=[],
            )

        completions = self.convert_raw_response_to_completions(response, request)

        return RequestResult(
            success=True,
            cached=cached,
            request_time=response["request_time"],
            request_datetime=response["request_datetime"],
            completions=completions,
            embedding=[],
        )
|
||
|
||
class BedrockTitanClient(BedrockClient):
    """Bedrock client that speaks the Amazon Titan text request/response format."""

    # Translates Titan's "completionReason" values into finish-reason strings.
    # Values not listed here are passed through lower-cased.
    _COMPLETION_REASON_TO_FINISH_REASON = {
        "LENGTH": "length",
        "FINISH": "endoftext",
    }

    def convert_request_to_raw_request(self, request: Request) -> Dict:
        """Build the Titan request body from the prompt and sampling settings."""
        # TODO: Support the following:
        # - top_k_per_token
        # - echo_prompt
        # - num_completions
        raw_request: Dict = {"inputText": request.prompt}
        raw_request["textGenerationConfig"] = {
            "maxTokenCount": request.max_tokens,
            # We ignore stop sequences in the request and always set stop sequences to the empty list.
            # This is because:
            #
            # 1. The only permitted stop sequences are "|" and "User:"
            #     - https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-text.html
            #     - https://github.com/boto/boto3/issues/3993
            #     - https://github.com/aws/aws-sdk/issues/692
            #
            # 2. Titan has the tendency to emit "\n" as the first token in the generated text output,
            #    which would cause the output to stop immediately if "\n" is in the stop_sequences.
            "stopSequences": [],
            "temperature": request.temperature,
            "topP": request.top_p,
        }
        return raw_request

    def convert_raw_response_to_completions(self, response: Dict, request: Request) -> List[Sequence]:
        """Convert each entry of ``response["results"]`` into a completion."""
        # TODO: Support the following:
        # - tokens
        # - logprob
        reason_map = BedrockTitanClient._COMPLETION_REASON_TO_FINISH_REASON
        sequences: List[Sequence] = []
        for result in response["results"]:
            generated_text = result["outputText"]
            raw_reason = result["completionReason"]
            finish_reason = reason_map.get(raw_reason, raw_reason.lower())
            # Leading whitespace is stripped because Titan has the tendency to emit "\n"
            # as the first token in the generated text output.
            sequence = truncate_and_tokenize_response_text(
                generated_text.lstrip(), request, self.tokenizer, self.tokenizer_name, finish_reason
            )
            sequences.append(sequence)
        return sequences
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
"""Helper utilities for working with Amazon Bedrock.""" | ||
|
||
import os | ||
from typing import Optional | ||
|
||
from helm.common.hierarchical_logger import hlog | ||
from helm.common.optional_dependencies import handle_module_not_found_error | ||
|
||
try: | ||
import boto3 | ||
from botocore.config import Config | ||
except ModuleNotFoundError as e: | ||
handle_module_not_found_error(e, ["aws"]) | ||
|
||
|
||
# From https://github.com/aws-samples/amazon-bedrock-workshop/blob/main/01_Generation/00_generate_w_bedrock.ipynb | ||
# MIT-0 Licensed | ||
def get_bedrock_client(
    assumed_role: Optional[str] = None,
    region: Optional[str] = None,
    runtime: Optional[bool] = True,
):
    """Create a boto3 client for Amazon Bedrock, with optional configuration overrides.

    Parameters
    ----------
    assumed_role :
        Optional ARN of an AWS IAM role to assume for calling the Bedrock service. If not
        specified, the current active credentials will be used.
    region :
        Optional name of the AWS Region in which the service should be called (e.g. "us-east-1").
        If not specified, AWS_REGION or AWS_DEFAULT_REGION environment variable will be used.
    runtime :
        If truthy (the default), a "bedrock-runtime" client is created; otherwise a
        "bedrock" client is created.

    Returns
    -------
    The boto3 client for the selected Bedrock service.
    """
    if region is None:
        # AWS_REGION takes precedence; AWS_DEFAULT_REGION is only the fallback.
        target_region = os.environ.get("AWS_REGION", os.environ.get("AWS_DEFAULT_REGION"))
    else:
        target_region = region

    session_kwargs = {"region_name": target_region}
    # Copied before profile_name is added, so the profile is passed only to
    # boto3.Session and not to session.client().
    client_kwargs = {**session_kwargs}

    profile_name = os.environ.get("AWS_PROFILE")
    if profile_name:
        session_kwargs["profile_name"] = profile_name

    retry_config = Config(
        region_name=target_region,
        retries={
            "max_attempts": 10,
            "mode": "standard",
        },
    )
    session = boto3.Session(**session_kwargs)

    if assumed_role:
        # Exchange the session's credentials for those of the assumed role and hand
        # them to the Bedrock client explicitly.
        # NOTE(review): the credentials are copied once here; presumably they are
        # temporary and are not refreshed for the lifetime of the client -- confirm
        # this is acceptable for long-running jobs.
        sts = session.client("sts")
        response = sts.assume_role(RoleArn=str(assumed_role), RoleSessionName="crfm-helm")
        credentials = response["Credentials"]
        client_kwargs["aws_access_key_id"] = credentials["AccessKeyId"]
        client_kwargs["aws_secret_access_key"] = credentials["SecretAccessKey"]
        client_kwargs["aws_session_token"] = credentials["SessionToken"]

    service_name = "bedrock-runtime" if runtime else "bedrock"

    bedrock_client = session.client(service_name=service_name, config=retry_config, **client_kwargs)

    # Use the public client metadata rather than the private `_endpoint` attribute.
    hlog(f"Amazon Bedrock client successfully created with endpoint {bedrock_client.meta.endpoint_url}")
    return bedrock_client
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters