Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Guardrails for Amazon Bedrock #520

Merged
merged 38 commits into from
Oct 10, 2024
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
da11a09
wip
fsatsuki Aug 30, 2024
fd7812e
Merge remote-tracking branch 'remotes/github/feature/frontend-kb' int…
fsatsuki Aug 30, 2024
79518a1
Merge branch 'fronend-kb' into feature/guardrails
fsatsuki Aug 30, 2024
7299eed
fontend merge
fsatsuki Aug 30, 2024
4380822
wip
fsatsuki Aug 30, 2024
e634b34
Merge branch 'v1' of https://github.com/aws-samples/bedrock-claude-ch…
fsatsuki Aug 30, 2024
4744425
guardrailの一部機能をデプロイ
fsatsuki Sep 4, 2024
17920cb
bugfix
fsatsuki Sep 4, 2024
5902e3d
mypy, blackの適用
fsatsuki Sep 4, 2024
04733f7
cdkのテストコード修正
fsatsuki Sep 5, 2024
c222809
python formatter
fsatsuki Sep 5, 2024
c27ec0e
fix frontend ci
fsatsuki Sep 5, 2024
243164f
fix migration guide
statefb Sep 17, 2024
82a6912
wip
fsatsuki Sep 30, 2024
6ad2d33
bug fix
fsatsuki Oct 1, 2024
b9b734b
fix migration guide
statefb Oct 3, 2024
b9d400a
add migration guide arch img
statefb Oct 3, 2024
b2cc36b
add unit tests
statefb Oct 3, 2024
41b7427
add: ja helpers
statefb Oct 3, 2024
e486f15
fix: simplify find_public_bot_by_id
statefb Oct 3, 2024
b37015f
nits: simplify websocket.py
statefb Oct 3, 2024
aea3c21
nits: explanation why BedrockRegionResourcesStack needed
statefb Oct 3, 2024
f66255a
fix: textAttachement -> attachment
statefb Oct 3, 2024
c2c7509
refactoring example for compose_args
statefb Oct 3, 2024
bec977f
merge v1
statefb Oct 4, 2024
bf63ea1
fix: bedrock client import error on agent
statefb Oct 4, 2024
dfa63d8
integrate compose_args_for_converse_api
statefb Oct 4, 2024
ad9d7cc
remove unused code
statefb Oct 4, 2024
e088809
sync chat impl
statefb Oct 4, 2024
5a0dd52
chore: mypy
statefb Oct 7, 2024
b1103df
fix: regex for s3 uri
statefb Oct 7, 2024
d4f58ea
fix: not work for multi turn conversaition
statefb Oct 7, 2024
0eb4361
chore: support topK on streaming
statefb Oct 7, 2024
e8a9bfb
fix: unittests
statefb Oct 7, 2024
4efbac8
Merge pull request #2 from fsatsuki/feature/guardrails-chore
fsatsuki Oct 9, 2024
23af30d
lint: black
statefb Oct 9, 2024
a883943
doc: add AWS Backup
statefb Oct 9, 2024
5b2c76d
doc: add v2 notification
statefb Oct 9, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 148 additions & 12 deletions backend/app/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from app.repositories.models.conversation import MessageModel
from app.repositories.models.custom_bot import GenerationParamsModel
from app.routes.schemas.conversation import type_model_name
from app.utils import convert_dict_keys_to_camel_case, get_bedrock_client
from app.utils import convert_dict_keys_to_camel_case, get_bedrock_runtime_client
statefb marked this conversation as resolved.
Show resolved Hide resolved

logger = logging.getLogger(__name__)

Expand All @@ -24,7 +24,14 @@
else DEFAULT_CLAUDE_GENERATION_CONFIG
)

client = get_bedrock_client()
client = get_bedrock_runtime_client()


class GuardrailConfig(TypedDict):
    """Guardrail settings carried in a Converse API request.

    All four keys are plain strings; the request composer in this module
    fills them from a bot's guardrail settings.
    """

    # Identifier of the guardrail (the composer passes the guardrail ARN).
    guardrailIdentifier: str
    # Version of the guardrail to apply.
    guardrailVersion: str
    # Trace switch, e.g. "enabled".
    trace: str
    # Stream processing mode, e.g. "async".
    streamProcessingMode: str


class ConverseApiRequest(TypedDict):
Expand All @@ -34,6 +41,7 @@ class ConverseApiRequest(TypedDict):
messages: list[dict]
stream: bool
system: list[dict]
guardrailConfig: GuardrailConfig | None


class ConverseApiResponseMessageContent(TypedDict):
Expand Down Expand Up @@ -188,23 +196,151 @@ def compose_args_for_converse_api(
return args


def compose_args_for_converse_api_with_guardrail(
    messages: list[MessageModel],
    model: type_model_name,
    instruction: str | None = None,
    stream: bool = False,
    generation_params: GenerationParamsModel | None = None,
    grounding_source: dict | None = None,
    guardrail: dict | None = None,
) -> ConverseApiRequest:
    """Compose arguments for the Converse API with guardrail content blocks.

    Ref: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock-runtime/client/converse_stream.html

    User text is wrapped in ``guardContent`` blocks qualified as ``query``.
    When contextual grounding is active, ``grounding_source`` is added as an
    extra ``guardContent`` block in front of each user text block.

    Args:
        messages: Conversation messages. Messages whose role is ``system`` or
            ``instruction`` are skipped.
        model: Model name, resolved to a model id with ``get_model_id``.
        instruction: Optional system prompt text.
        stream: Stored as-is under the ``stream`` key of the result.
        generation_params: Optional overrides for the default generation
            config (max tokens, temperature, top-p, stop sequences).
        grounding_source: ``guardContent`` payload used as grounding source.
            Only used when ``guardrail`` has a positive
            ``grounding_threshold``.
        guardrail: Guardrail settings. The keys read here are
            ``grounding_threshold``, ``guardrail_arn`` and
            ``guardrail_version``; all are treated as optional.

    Returns:
        The request arguments. ``guardrailConfig`` is ``None`` unless both
        ``guardrail_arn`` and ``guardrail_version`` are present.

    Raises:
        NotImplementedError: If a message contains an unsupported content type.
    """
    # Grounding needs a positive threshold AND an actual source block;
    # appending `{"guardContent": None}` would produce an invalid content block.
    use_grounding = (
        bool(guardrail)
        and (guardrail.get("grounding_threshold") or 0) > 0
        and grounding_source is not None
    )

    arg_messages = []
    for message in messages:
        if message.role in ["system", "instruction"]:
            continue

        content_blocks: list[dict] = []
        for c in message.content:
            if c.content_type == "text":
                if message.role == "user":
                    if use_grounding:
                        content_blocks.append({"guardContent": grounding_source})
                    content_blocks.append(
                        {
                            "guardContent": {
                                "text": {"text": c.body, "qualifiers": ["query"]},
                            }
                        }
                    )
                elif message.role == "assistant":
                    # The `text` member of a content block is a plain string.
                    content_blocks.append({"text": c.body})

            elif c.content_type == "image":
                # e.g. "image/png" -> "png"
                image_format = (
                    c.media_type.split("/")[1]
                    if c.media_type is not None
                    else "unknown"
                )
                content_blocks.append(
                    {
                        "image": {
                            "format": image_format,
                            # decode base64 encoded image
                            "source": {"bytes": base64.b64decode(c.body)},
                        }
                    }
                )

            elif c.content_type == "textAttachment":
                attachment_path = Path(c.file_name)
                content_blocks.append(
                    {
                        "document": {
                            # e.g. "document.txt" -> "txt"
                            "format": _get_converse_supported_format(
                                attachment_path.suffix[1:]
                            ),
                            # e.g. "document.txt" -> "document"
                            "name": attachment_path.stem,
                            # encode text attachment body
                            "source": {"bytes": c.body.encode("utf-8")},
                        }
                    }
                )
            else:
                raise NotImplementedError(
                    f"Unsupported content type: {c.content_type}"
                )
        arg_messages.append({"role": message.role, "content": content_blocks})

    inference_config = {
        **DEFAULT_GENERATION_CONFIG,
        **(
            {
                "maxTokens": generation_params.max_tokens,
                "temperature": generation_params.temperature,
                "topP": generation_params.top_p,
                "stopSequences": generation_params.stop_sequences,
            }
            if generation_params
            else {}
        ),
    }

    # `top_k` is configured in `additional_model_request_fields` instead of `inference_config`
    additional_model_request_fields = {"top_k": inference_config.pop("top_k")}

    args: ConverseApiRequest = {
        "inference_config": convert_dict_keys_to_camel_case(inference_config),
        "additional_model_request_fields": additional_model_request_fields,
        "model_id": get_model_id(model),
        "messages": arg_messages,
        "stream": stream,
        "system": [],
        "guardrailConfig": None,  # replaced below when a guardrail is configured
    }

    if instruction:
        args["system"].append({"text": instruction})

    if guardrail and "guardrail_arn" in guardrail and "guardrail_version" in guardrail:
        args["guardrailConfig"] = {
            "guardrailIdentifier": guardrail["guardrail_arn"],
            "guardrailVersion": guardrail["guardrail_version"],
            "trace": "enabled",
            "streamProcessingMode": "async",
        }

    return args


def call_converse_api(args: ConverseApiRequest) -> ConverseApiResponse:
    """Invoke the (non-streaming) Bedrock Converse API.

    Args:
        args: Request arguments as produced by the ``compose_args_*`` helpers.
            ``guardrailConfig`` may be missing or ``None``; in both cases the
            call is made without a guardrail.

    Returns:
        The response returned by ``client.converse``.
    """
    client = get_bedrock_runtime_client()

    request: dict = {
        "modelId": args["model_id"],
        "messages": args["messages"],
        "inferenceConfig": args["inference_config"],
        "system": args["system"],
        "additionalModelRequestFields": args["additional_model_request_fields"],
    }

    # The key can be present with a `None` value, so test the value itself
    # rather than key membership.
    guardrail_config = args.get("guardrailConfig")
    if guardrail_config:
        # `streamProcessingMode` belongs to the streaming API's guardrail
        # configuration; the non-streaming `converse` shape does not list it,
        # so it is dropped here.
        request["guardrailConfig"] = {
            key: value
            for key, value in guardrail_config.items()
            if key != "streamProcessingMode"
        }

    return client.converse(**request)


def calculate_price(
Expand Down
3 changes: 2 additions & 1 deletion backend/app/bot_remove.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@

DB_SECRETS_ARN = os.environ.get("DB_SECRETS_ARN", "")
DOCUMENT_BUCKET = os.environ.get("DOCUMENT_BUCKET", "documents")
BEDROCK_REGION = os.environ.get("BEDROCK_REGION", "us-east-1")

s3_client = boto3.client("s3")
s3_client = boto3.client("s3", BEDROCK_REGION)


def delete_from_postgres(bot_id: str):
Expand Down
41 changes: 41 additions & 0 deletions backend/app/guardrails.py
statefb marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from botocore.exceptions import ClientError
from app.repositories.common import _get_table_client, compose_bot_id
from app.utils import get_bedrock_client


class BotNotFoundException(Exception):
    """Signals that the requested bot could not be found."""


class GuardrailArnRetrievalError(Exception):
    """Signals a failure while reading a bot's guardrail ARN."""


def get_guardrail_arn(user_id: str, bot_id: str) -> str:
    """Return the guardrail ARN stored on a bot item, or "" when there is none.

    Args:
        user_id: Owner of the bot; used as the partition key.
        bot_id: Bot identifier; combined with ``user_id`` to form the sort key.

    Returns:
        The ``guardrail_arn`` value under the item's ``GuardrailsParams``, or
        an empty string when the item, the params, or the ARN is absent.

    Raises:
        BotNotFoundException: If the read fails with
            ``ConditionalCheckFailedException``.
        GuardrailArnRetrievalError: If the read fails with any other
            ``ClientError``.
    """
    table = _get_table_client(user_id)
    # Keep the try body limited to the call that can raise ClientError.
    try:
        response = table.get_item(
            Key={"PK": user_id, "SK": compose_bot_id(user_id, bot_id)},
            ConsistentRead=True,
        )
    except ClientError as e:
        if e.response["Error"]["Code"] == "ConditionalCheckFailedException":
            raise BotNotFoundException(f"Bot with id {bot_id} not found") from e
        raise GuardrailArnRetrievalError(
            f"Error getting guardrail_arn for bot: {bot_id}: {e}"
        ) from e

    # `or {}` also covers an item whose GuardrailsParams is stored as None.
    guardrails_params = response.get("Item", {}).get("GuardrailsParams") or {}
    return guardrails_params.get("guardrail_arn", "")


def delete_guardrail(guardrail_id):
    """Delete the guardrail identified by ``guardrail_id`` via the Bedrock client."""
    bedrock = get_bedrock_client()
    bedrock.delete_guardrail(guardrailIdentifier=guardrail_id)
4 changes: 3 additions & 1 deletion backend/app/repositories/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,13 @@

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
s3_client = boto3.client("s3")

THRESHOLD_LARGE_MESSAGE = 300 * 1024 # 300KB
LARGE_MESSAGE_BUCKET = os.environ.get("LARGE_MESSAGE_BUCKET")

BEDROCK_REGION = os.environ.get("BEDROCK_REGION", "us-east-1")
s3_client = boto3.client("s3", BEDROCK_REGION)


def store_conversation(
user_id: str, conversation: ConversationModel, threshold=THRESHOLD_LARGE_MESSAGE
Expand Down
Loading
Loading