AWS Bedrock Converse API (#58)
kamillitman authored Jan 22, 2025
1 parent b99567e commit afa5cc7
Showing 10 changed files with 266 additions and 22 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
@@ -18,7 +18,7 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - uses: actions/checkout@v3
-      - uses: dtolnay/rust-toolchain@1.73
+      - uses: dtolnay/rust-toolchain@1.82
         with:
           components: clippy, rustfmt
       - run: cargo clippy -- --deny warnings
3 changes: 2 additions & 1 deletion .gitignore
@@ -11,4 +11,5 @@ __pycache__
 target/
 .idea
 examples/secrets
-examples/data
+examples/data
+.env
4 changes: 3 additions & 1 deletion Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "allms"
-version = "0.10.0"
+version = "0.11.0"
 edition = "2021"
 authors = [
     "Kamil Litman <[email protected]>",
@@ -16,6 +16,8 @@ categories = ["api-bindings", "development-tools", "parsing", "science", "text-p

 [dependencies]
 anyhow = "1.0.60"
+aws-config = "1.5.4"
+aws-sdk-bedrockruntime = "1.40.0"
 env_logger = "0.9.0"
 jsonschema = "=0.15.2"
 log = "0.4.0"
41 changes: 25 additions & 16 deletions README.md
@@ -6,7 +6,7 @@ This Rust library is specialized in providing type-safe interactions with APIs o

## Features

-- Support for various LLM models including OpenAI (GPT-3.5, GPT-4), Anthropic (Claude, Claude Instant), Mistral, Google GeminiPro, and Perplexity.
+- Support for various foundational LLM providers including Anthropic, AWS Bedrock, Azure, Google Gemini, OpenAI, Mistral, and Perplexity.
- Easy-to-use functions for chat/text completions and assistants. Use the same struct and methods regardless of which model you choose.
- Automated response deserialization to custom types.
- Standardized approach to providing context with support of function calling, tools, and file uploads.
@@ -15,59 +15,68 @@ This Rust library is specialized in providing type-safe interactions with APIs o
 - Asynchronous support using Tokio.

 ### Foundational Models
-OpenAI:
-- APIs: Chat Completions, Function Calling, Assistants (v1 & v2), Files, Vector Stores, Tools (file_search)
-- Models: o1 Preview, o1 Mini (Chat Completions only), GPT-4o, GPT-4, GPT-4 32k, GPT-4 Turbo, GPT-3.5 Turbo, GPT-3.5 Turbo 16k, fine-tuned models (via `Custom` variant)
+Anthropic:
+- APIs: Messages, Text Completions
+- Models: Claude 3.5 Sonnet, Claude 3 Opus, Claude 3 Sonnet, Claude 3 Haiku, Claude 2.0, Claude Instant 1.2

 Azure OpenAI:
 - APIs: Assistants, Files, Vector Stores, Tools
 - API version can be set using `AzureVersion` variant
 - Models: as per model deployments in Azure OpenAI Studio
 - If using custom model deployment names please use the `Custom` variant of `OpenAIModels`

-Anthropic:
-- APIs: Messages, Text Completions
-- Models: Claude 3.5 Sonnet, Claude 3 Opus, Claude 3 Sonnet, Claude 3 Haiku, Claude 2.0, Claude Instant 1.2
+AWS Bedrock:
+- APIs: Converse
+- Models: Nova Micro, Nova Lite, Nova Pro (additional models to be added)
+
+Google Vertex AI / AI Studio:
+- APIs: Chat Completions (including streaming)
+- Models: Gemini 1.5 Pro, Gemini 1.5 Flash, Gemini 1.0 Pro

 Mistral:
 - APIs: Chat Completions
 - Models: Mistral Large, Mistral Nemo, Mistral 7B, Mixtral 8x7B, Mixtral 8x22B, Mistral Medium, Mistral Small, Mistral Tiny

-Google Vertex AI / AI Studio:
-- APIs: Chat Completions (including streaming)
-- Models: Gemini 1.5 Pro, Gemini 1.5 Flash, Gemini 1.0 Pro
+OpenAI:
+- APIs: Chat Completions, Function Calling, Assistants (v1 & v2), Files, Vector Stores, Tools (file_search)
+- Models: o1 Preview, o1 Mini (Chat Completions only), GPT-4o, GPT-4, GPT-4 32k, GPT-4 Turbo, GPT-3.5 Turbo, GPT-3.5 Turbo 16k, fine-tuned models (via `Custom` variant)

 Perplexity:
 - APIs: Chat Completions
 - Models: Llama 3.1 Sonar Small, Llama 3.1 Sonar Large, Llama 3.1 Sonar Huge

### Prerequisites
-- OpenAI: API key (passed in model constructor)
-- Azure OpenAI: environment variable `OPENAI_API_URL` set to your Azure OpenAI resource endpoint. Endpoint key passed in constructor
 - Anthropic: API key (passed in model constructor)
-- Mistral: API key (passed in model constructor)
+- Azure OpenAI: environment variable `OPENAI_API_URL` set to your Azure OpenAI resource endpoint. Endpoint key passed in constructor
+- AWS Bedrock: environment variables `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY` and `AWS_REGION` set as per AWS settings.
 - Google AI Studio: API key (passed in model constructor)
 - Google Vertex AI: GCP service account key (used to obtain access token) + GCP project ID (set as environment variable)
+- Mistral: API key (passed in model constructor)
+- OpenAI: API key (passed in model constructor)
 - Perplexity: API key (passed in model constructor)

### Examples
Explore the `examples` directory to see more use cases and how to use different LLM providers and endpoint types.

Using `Completions` API with different foundational models:
```
-let openai_answer = Completions::new(OpenAIModels::Gpt4o, &API_KEY, None, None)
+let anthropic_answer = Completions::new(AnthropicModels::Claude2, &API_KEY, None, None)
     .get_answer::<T>(instructions)
     .await?;
-let anthropic_answer = Completions::new(AnthropicModels::Claude2, &API_KEY, None, None)
+let aws_bedrock_answer = Completions::new(AwsBedrockModels::NovaLite, "", None, None)
     .get_answer::<T>(instructions)
     .await?;
+let google_answer = Completions::new(GoogleModels::GeminiPro, &API_KEY, None, None)
+    .get_answer::<T>(instructions)
+    .await?;
 let mistral_answer = Completions::new(MistralModels::MistralSmall, &API_KEY, None, None)
     .get_answer::<T>(instructions)
     .await?;
-let google_answer = Completions::new(GoogleModels::GeminiPro, &API_KEY, None, None)
+let openai_answer = Completions::new(OpenAIModels::Gpt4o, &API_KEY, None, None)
     .get_answer::<T>(instructions)
     .await?;
```
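The `T` in `get_answer::<T>` is the custom type the answer is deserialized into. As a minimal sketch, assuming a translation use case like the one in `examples/use_completions.rs` (the struct name matches that example; the field names are illustrative):

```
use serde::{Deserialize, Serialize};

// Illustrative response type: allms instructs the model with this type's
// JSON schema and deserializes the reply into it.
#[derive(Serialize, Deserialize, Debug)]
struct TranslationResponse {
    spanish: String,
    french: String,
    german: String,
}
```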
21 changes: 20 additions & 1 deletion examples/use_completions.rs
@@ -3,7 +3,10 @@ use serde::Deserialize;
 use serde::Serialize;

 use allms::{
-    llm::{AnthropicModels, GoogleModels, LLMModel, MistralModels, OpenAIModels, PerplexityModels},
+    llm::{
+        AnthropicModels, AwsBedrockModels, GoogleModels, LLMModel, MistralModels, OpenAIModels,
+        PerplexityModels,
+    },
     Completions,
 };

@@ -23,6 +26,22 @@ async fn main() {
     let instructions =
         "Translate the following English sentence to all the languages in the response type: Rust is best for working with LLMs";

+    // Get answer using AWS Bedrock Converse
+    // The AWS Bedrock SDK requires the `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY` environment variables to be defined and match your AWS account
+    let model = AwsBedrockModels::try_from_str("amazon.nova-lite-v1:0")
+        .unwrap_or(AwsBedrockModels::NovaLite); // Choose the model
+    println!("AWS Bedrock model: {:#?}", model.as_str());
+
+    let aws_completion = Completions::new(model, "", None, None);
+
+    match aws_completion
+        .get_answer::<TranslationResponse>(instructions)
+        .await
+    {
+        Ok(response) => println!("AWS Bedrock response: {:#?}", response),
+        Err(e) => eprintln!("Error: {:?}", e),
+    }
+
     // Get answer using OpenAI
     let openai_api_key: String = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
     let model = OpenAIModels::try_from_str("gpt-4o-mini").unwrap_or(OpenAIModels::Gpt4oMini); // Choose the model
15 changes: 15 additions & 0 deletions src/constants.rs
@@ -38,6 +38,21 @@ lazy_static! {
.unwrap_or("https://api.perplexity.ai/chat/completions".to_string());
}

lazy_static! {
/// Docs: https://docs.aws.amazon.com/general/latest/gr/bedrock.html
pub(crate) static ref AWS_REGION: String = std::env::var("AWS_REGION").unwrap_or("us-east-1".to_string());
pub(crate) static ref AWS_BEDROCK_API_URL: String = {
format!("https://bedrock.{}.amazonaws.com", &*AWS_REGION)
};
}

lazy_static! {
pub(crate) static ref AWS_ACCESS_KEY_ID: String =
std::env::var("AWS_ACCESS_KEY_ID").expect("AWS_ACCESS_KEY_ID not set");
pub(crate) static ref AWS_SECRET_ACCESS_KEY: String =
std::env::var("AWS_SECRET_ACCESS_KEY").expect("AWS_SECRET_ACCESS_KEY not set");
}

//Generic OpenAI instructions
pub(crate) const OPENAI_BASE_INSTRUCTIONS: &str = r#"You are a computer function. You are expected to perform the following tasks:
Step 1: Review and understand the 'instructions' from the *Instructions* section.
Expand Down
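For example, with `AWS_REGION=eu-west-1` the `AWS_BEDROCK_API_URL` above resolves to `https://bedrock.eu-west-1.amazonaws.com`; when the variable is unset it falls back to `us-east-1`.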
1 change: 0 additions & 1 deletion src/deprecated/openai_completions_deprecated.rs
@@ -376,7 +376,6 @@ pub struct OpenAI {
 }

 impl OpenAI {
-    ///
     pub fn new(
         open_ai_key: &str,
         model: OpenAIModels,
197 changes: 197 additions & 0 deletions src/llm_models/aws.rs
@@ -0,0 +1,197 @@
use crate::utils::sanitize_json_response;
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use aws_config::BehaviorVersion;
use aws_sdk_bedrockruntime::{
    types::{ContentBlock, ConversationRole, InferenceConfiguration, Message, SystemContentBlock},
    Client,
};
use log::info;
use serde::{Deserialize, Serialize};
use serde_json::Value;

use crate::constants::{AWS_BEDROCK_API_URL, AWS_REGION};
use crate::domain::RateLimit;
use crate::llm_models::LLMModel;

#[derive(Serialize, Deserialize)]
struct AwsBedrockRequestBody {
    instructions: String,
    json_schema: Value,
    max_tokens: i32,
    temperature: f32,
}

#[derive(Deserialize, Serialize, Debug, Clone, Eq, PartialEq)]
// AWS Bedrock docs: https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html
pub enum AwsBedrockModels {
    NovaPro,
    NovaLite,
    NovaMicro,
}

#[async_trait(?Send)]
impl LLMModel for AwsBedrockModels {
    fn as_str(&self) -> &str {
        match self {
            AwsBedrockModels::NovaPro => "amazon.nova-pro-v1:0",
            AwsBedrockModels::NovaLite => "amazon.nova-lite-v1:0",
            AwsBedrockModels::NovaMicro => "amazon.nova-micro-v1:0",
        }
    }

    fn try_from_str(name: &str) -> Option<Self> {
        match name.to_lowercase().as_str() {
            "amazon.nova-pro-v1:0" => Some(AwsBedrockModels::NovaPro),
            "amazon.nova-lite-v1:0" => Some(AwsBedrockModels::NovaLite),
            "amazon.nova-micro-v1:0" => Some(AwsBedrockModels::NovaMicro),
            _ => None,
        }
    }

    fn default_max_tokens(&self) -> usize {
        match self {
            AwsBedrockModels::NovaPro => 5_120,
            AwsBedrockModels::NovaLite => 5_120,
            AwsBedrockModels::NovaMicro => 5_120,
        }
    }

    fn get_endpoint(&self) -> String {
        format!("{}/model/{}/converse", &*AWS_BEDROCK_API_URL, self.as_str())
    }

    /// The AWS Bedrock implementation leverages the AWS Bedrock SDK, therefore data is only passed by this method to the `call_api` method, where the actual logic is implemented
    fn get_body(
        &self,
        instructions: &str,
        json_schema: &Value,
        _function_call: bool,
        max_tokens: &usize,
        temperature: &f32,
    ) -> serde_json::Value {
        let body = AwsBedrockRequestBody {
            instructions: instructions.to_string(),
            json_schema: json_schema.clone(),
            max_tokens: *max_tokens as i32,
            temperature: *temperature,
        };

        // Return the body serialized as a JSON value
        serde_json::to_value(body).unwrap()
    }

    /// This function leverages the AWS Bedrock SDK to perform any query as per the provided body.
    async fn call_api(
        &self,
        // The AWS Bedrock SDK utilizes the `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY` environment variables for request authentication
        // Docs: https://docs.aws.amazon.com/sdk-for-rust/latest/dg/credproviders.html
        _api_key: &str,
        body: &serde_json::Value,
        debug: bool,
    ) -> Result<String> {
        let sdk_config = aws_config::defaults(BehaviorVersion::latest())
            .region(&**AWS_REGION)
            .load()
            .await;
        let client = Client::new(&sdk_config);

        // Get request info from body
        let request_body_opt: Option<AwsBedrockRequestBody> =
            serde_json::from_value(body.clone()).ok();
        let (instructions_opt, json_schema_opt, max_tokens_opt, temperature_opt) = request_body_opt
            .map_or_else(
                || (None, None, None, None),
                |request_body| {
                    (
                        Some(request_body.instructions),
                        Some(request_body.json_schema),
                        Some(request_body.max_tokens),
                        Some(request_body.temperature),
                    )
                },
            );

        // Get base instructions
        let base_instructions = self.get_base_instructions(None);

        let converse_builder = client
            .converse()
            .model_id(self.as_str())
            .system(SystemContentBlock::Text(base_instructions));

        // Add user instructions including the expected output schema if specified
        let instructions = instructions_opt.unwrap_or_default();
        let user_instructions = json_schema_opt
            .map(|schema| {
                format!(
                    "Output Json schema:\n
                    {schema}\n\n
                    {instructions}"
                )
            })
            .unwrap_or(instructions);
        let converse_builder = converse_builder.messages(
            Message::builder()
                .role(ConversationRole::User)
                .content(ContentBlock::Text(user_instructions))
                .build()
                .map_err(|_| anyhow!("failed to build message"))?,
        );

        // If specified, add the inference config
        let converse_builder = if max_tokens_opt.is_some() || temperature_opt.is_some() {
            let inference_config = InferenceConfiguration::builder()
                .set_max_tokens(max_tokens_opt)
                .set_temperature(temperature_opt)
                .build();
            converse_builder.set_inference_config(Some(inference_config))
        } else {
            converse_builder
        };

        // Send request
        let converse_response = converse_builder.send().await?;

        if debug {
            info!(
                "[debug] AWS Bedrock API response: {:#?}",
                &converse_response
            );
        }

        // Parse the response and return the assistant content
        let text = converse_response
            .output()
            .ok_or(anyhow!("no output"))?
            .as_message()
            .map_err(|_| anyhow!("output not a message"))?
            .content()
            .first()
            .ok_or(anyhow!("no content in message"))?
            .as_text()
            .map_err(|_| anyhow!("content is not text"))?
            .to_string();
        Ok(sanitize_json_response(&text))
    }

    /// The AWS Bedrock implementation leverages the AWS Bedrock SDK, therefore data extraction is implemented directly in the `call_api` method and this method only passes the data on
    fn get_data(&self, response_text: &str, _function_call: bool) -> Result<String> {
        Ok(response_text.to_string())
    }

    // This function allows checking the rate limits of different models
    fn get_rate_limit(&self) -> RateLimit {
        // Docs: https://docs.aws.amazon.com/general/latest/gr/bedrock.html
        match self {
            AwsBedrockModels::NovaPro => RateLimit {
                tpm: 400_000,
                rpm: 100,
            },
            AwsBedrockModels::NovaLite | AwsBedrockModels::NovaMicro => RateLimit {
                tpm: 2_000_000,
                rpm: 1_000,
            },
        }
    }
}
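Putting the new enum to work, a minimal usage sketch mirroring `examples/use_completions.rs` (a fragment for an async context returning `anyhow::Result`; `TranslationResponse` is the response type from that example):

```
// The api_key argument is ignored for AWS Bedrock: the SDK reads
// AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY (and AWS_REGION) from the environment.
let model = AwsBedrockModels::try_from_str("amazon.nova-lite-v1:0")
    .unwrap_or(AwsBedrockModels::NovaLite);
let answer = Completions::new(model, "", None, None)
    .get_answer::<TranslationResponse>(
        "Translate the following English sentence to all the languages in the response type: Rust is best for working with LLMs",
    )
    .await?;
```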
2 changes: 1 addition & 1 deletion src/llm_models/google.rs
@@ -146,7 +146,7 @@ impl LLMModel for GoogleModels {
             .send()
             .await?;

-        //For Vertex we are streaming that data spo we need to deserialize each chunk separately
+        //For Vertex we are streaming that data so we need to deserialize each chunk separately
         // Check if the API uses streaming
         if response.status().is_success() {
             let mut stream = response.bytes_stream();