diff --git a/app/bedrock.py b/app/bedrock.py
index 280632931..5933864b9 100644
--- a/app/bedrock.py
+++ b/app/bedrock.py
@@ -36,10 +36,20 @@ def model_dump(self, *args, **kwargs):
 
 # Main client class for interacting with Amazon Bedrock
 class BedrockClient:
-    def __init__(self):
+    def __init__(
+        self,
+        aws_access_key_id: str | None = None,
+        aws_secret_access_key: str | None = None,
+        aws_region_name: str | None = None,
+    ):
         # Initialize Bedrock client, you need to configure AWS env first
         try:
-            self.client = boto3.client("bedrock-runtime")
+            self.client = boto3.client(
+                "bedrock-runtime",
+                aws_access_key_id=aws_access_key_id,
+                aws_secret_access_key=aws_secret_access_key,
+                region_name=aws_region_name,
+            )
             self.chat = Chat(self.client)
         except Exception as e:
             print(f"Error initializing Bedrock client: {e}")
@@ -173,9 +183,9 @@ def _convert_bedrock_response_to_openai_format(self, bedrock_response):
                         "role": bedrock_response.get("output", {})
                         .get("message", {})
                         .get("role", "assistant"),
-                        "tool_calls": openai_tool_calls
-                        if openai_tool_calls != []
-                        else None,
+                        "tool_calls": (
+                            openai_tool_calls if openai_tool_calls != [] else None
+                        ),
                         "function_call": None,
                     },
                 }
diff --git a/app/config.py b/app/config.py
index eb91cf432..186089f95 100644
--- a/app/config.py
+++ b/app/config.py
@@ -25,8 +25,13 @@ class LLMSettings(BaseModel):
         description="Maximum input tokens to use across all requests (None for unlimited)",
     )
     temperature: float = Field(1.0, description="Sampling temperature")
-    api_type: str = Field(..., description="Azure, Openai, or Ollama")
+    api_type: str = Field(..., description="Azure, Openai, Ollama or Bedrock")
     api_version: str = Field(..., description="Azure Openai version if AzureOpenai")
+    aws_access_key_id: str = Field(..., description="Aws access key id if Bedrock")
+    aws_secret_access_key: str = Field(
+        ..., description="Aws secret access key if Bedrock"
+    )
+    aws_region_name: str = Field(..., description="Aws region name if Bedrock")
 
 
 class ProxySettings(BaseModel):
@@ -175,6 +180,9 @@ def _load_initial_config(self):
             "temperature": base_llm.get("temperature", 1.0),
             "api_type": base_llm.get("api_type", ""),
             "api_version": base_llm.get("api_version", ""),
+            "aws_access_key_id": base_llm.get("aws_access_key_id", ""),
+            "aws_secret_access_key": base_llm.get("aws_secret_access_key", ""),
+            "aws_region_name": base_llm.get("aws_region_name", ""),
         }
 
         # handle browser config.
diff --git a/app/llm.py b/app/llm.py
index 37d493b76..ebf1ba1d8 100644
--- a/app/llm.py
+++ b/app/llm.py
@@ -203,6 +203,9 @@ def __init__(
             self.api_key = llm_config.api_key
             self.api_version = llm_config.api_version
             self.base_url = llm_config.base_url
+            self.aws_access_key_id = llm_config.aws_access_key_id
+            self.aws_secret_access_key = llm_config.aws_secret_access_key
+            self.aws_region_name = llm_config.aws_region_name
 
             # Add token counting related attributes
             self.total_input_tokens = 0
@@ -227,7 +230,11 @@ def __init__(
                     api_version=self.api_version,
                 )
             elif self.api_type == "aws":
-                self.client = BedrockClient()
+                self.client = BedrockClient(
+                    aws_access_key_id=self.aws_access_key_id,
+                    aws_secret_access_key=self.aws_secret_access_key,
+                    aws_region_name=self.aws_region_name,
+                )
             else:
                 self.client = AsyncOpenAI(api_key=self.api_key, base_url=self.base_url)
 
diff --git a/config/config.example.toml b/config/config.example.toml
index e4560513a..2f6cdff11 100644
--- a/config/config.example.toml
+++ b/config/config.example.toml
@@ -13,6 +13,9 @@ temperature = 0.0 # Controls randomness
 # max_tokens = 8192
 # temperature = 1.0
 # api_key = "bear" # Required but not used for Bedrock
+# aws_access_key_id = "" # Required
+# aws_secret_access_key = "" # Required
+# aws_region_name = "us-west-2" # Required
 
 # [llm] #AZURE OPENAI:
 # api_type= 'azure'
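Reviewer note: a minimal usage sketch of the new constructor, assuming only the names introduced in this diff. The placeholder credential strings and region are illustrative, not part of the change.

```python
# Sketch (not part of this diff): exercising the new BedrockClient constructor
# directly, outside of the LLM class. Placeholder credentials are illustrative.
from app.bedrock import BedrockClient

# Explicit credentials, as now plumbed through from LLMSettings in app/config.py:
client = BedrockClient(
    aws_access_key_id="AKIA...",    # llm_config.aws_access_key_id
    aws_secret_access_key="...",    # llm_config.aws_secret_access_key
    aws_region_name="us-west-2",    # llm_config.aws_region_name
)

# Omitting the arguments (all default to None) defers to boto3's standard
# credential chain -- env vars, ~/.aws/credentials, or an instance profile --
# matching the previous boto3.client("bedrock-runtime") behavior.
legacy_style = BedrockClient()
```

Because boto3 treats `None` credential/region keyword arguments as unset, existing environment-based setups keep working unchanged.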