feat: improve logic for compute units calculation #407

Merged · 2 commits · Feb 12, 2025
315 changes: 314 additions & 1 deletion atoma-service/src/handlers/chat_completions.rs
@@ -24,6 +24,7 @@ use axum::{
use opentelemetry::KeyValue;
use reqwest::Client;
use serde_json::{json, Value};
use tokenizers::Tokenizer;
use tracing::{info, instrument};
use utoipa::OpenApi;

@@ -43,20 +44,35 @@ use crate::{
middleware::RequestMetadata,
};

use super::{handle_confidential_compute_encryption_response, handle_status_code_error};
use super::{
handle_confidential_compute_encryption_response, handle_status_code_error,
request_model::RequestModel, DEFAULT_MAX_TOKENS,
};

/// The path for confidential chat completions requests
pub const CONFIDENTIAL_CHAT_COMPLETIONS_PATH: &str = "/v1/confidential/chat/completions";

/// The key for the content parameter in the request body
pub const CONTENT_KEY: &str = "content";

/// The path for chat completions requests
pub const CHAT_COMPLETIONS_PATH: &str = "/v1/chat/completions";

/// The keep-alive interval in seconds
const STREAM_KEEP_ALIVE_INTERVAL_IN_SECONDS: u64 = 15;

/// The key for the max_completion_tokens parameter in the request body
const MAX_COMPLETION_TOKENS_KEY: &str = "max_completion_tokens";

/// The key for the max_tokens parameter in the request body
const MAX_TOKENS_KEY: &str = "max_tokens";

/// The key for the model parameter in the request body
const MODEL_KEY: &str = "model";

/// The key for the messages parameter in the request body
const MESSAGES_KEY: &str = "messages";

/// The key for the stream parameter in the request body
const STREAM_KEY: &str = "stream";

@@ -734,6 +750,119 @@ async fn handle_streaming_response(
Ok(stream.into_response())
}

/// Represents a chat completion request model following the OpenAI API format
pub struct RequestModelChatCompletions {
/// Array of message objects that represent the conversation history.
/// Each message should contain a "role" (system/user/assistant) and "content".
/// The content can be a string or an array of content parts.
messages: Vec<Value>,

/// The maximum number of tokens to generate in the completion
/// This limits the length of the model's response
max_completion_tokens: u64,
}

impl RequestModel for RequestModelChatCompletions {
fn new(request: &Value) -> Result<Self, AtomaServiceError> {
let messages = request
.get(MESSAGES_KEY)
.and_then(|m| m.as_array())
.ok_or_else(|| AtomaServiceError::InvalidBody {
message: "Missing or invalid 'messages' field".to_string(),
endpoint: CHAT_COMPLETIONS_PATH.to_string(),
})?;

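// Prefer `max_completion_tokens`; fall back to `max_tokens`, and use `DEFAULT_MAX_TOKENS`
// when neither field is present.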
let max_completion_tokens = request
.get(MAX_COMPLETION_TOKENS_KEY)
.or_else(|| request.get(MAX_TOKENS_KEY))
.and_then(serde_json::Value::as_u64)
.unwrap_or(DEFAULT_MAX_TOKENS);

Ok(Self {
messages: messages.clone(),
max_completion_tokens,
})
}

/// Computes the total number of tokens for the chat completion request.
///
/// This is used to estimate the cost of the chat completion request on the proxy side.
/// The content of a message can be either a string or an array of content parts. We assume
/// that all messages share the same preceding conversation history, and that content parts
/// supplied as an array are concatenated and treated as a single message, both by the model
/// and for the purposes of this estimate.
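///
/// For example (see the tests below), a single text message that tokenizes to 8 tokens with
/// `max_completion_tokens = 10` is estimated as 8 + 3 (per-message overhead) + 10 = 21 compute units.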
fn get_compute_units_estimate(
&self,
tokenizer: Option<&Tokenizer>,
) -> Result<u64, AtomaServiceError> {
// Add a small per-message overhead to the total token count to account for additional
// special tokens that the tokenizer may not include in its output.
const MESSAGE_OVERHEAD_TOKENS: u64 = 3;
let Some(tokenizer) = tokenizer else {
return Err(AtomaServiceError::InternalError {
message: "Tokenizer is required for the current model, but is not currently available"
.to_string(),
endpoint: CHAT_COMPLETIONS_PATH.to_string(),
});
};
// Helper function to count tokens for a text string
let count_text_tokens = |text: &str| -> Result<u64, AtomaServiceError> {
Ok(tokenizer
.encode(text, true)
.map_err(|err| AtomaServiceError::InternalError {
message: format!("Failed to encode message: {err:?}"),
endpoint: CHAT_COMPLETIONS_PATH.to_string(),
})?
.get_ids()
.len() as u64)
};

let mut total_num_tokens = 0;

for message in &self.messages {
let content = message
.get(CONTENT_KEY)
.and_then(|content| MessageContent::deserialize(content).ok())
.ok_or_else(|| AtomaServiceError::InvalidBody {
message: "Missing or invalid message content".to_string(),
endpoint: CHAT_COMPLETIONS_PATH.to_string(),
})?;

match content {
MessageContent::Text(text) => {
let num_tokens = count_text_tokens(&text)?;
total_num_tokens += num_tokens + MESSAGE_OVERHEAD_TOKENS;
}
MessageContent::Array(parts) => {
if parts.is_empty() {
tracing::error!(
"Received empty array of message parts for chat completion request"
);
return Err(AtomaServiceError::InvalidBody {
message: "Missing or invalid message content".to_string(),
endpoint: CHAT_COMPLETIONS_PATH.to_string(),
});
}
for part in parts {
match part {
MessageContentPart::Text { text, .. } => {
let num_tokens = count_text_tokens(&text)?;
total_num_tokens += num_tokens + MESSAGE_OVERHEAD_TOKENS;
}
MessageContentPart::Image { .. } => {
// TODO: Ensure that for image content parts, we have a way to estimate the number of tokens,
// which can depend on the size of the image and the output description.
continue;
}
}
}
}
}
}
// Add the max completion tokens to account for the response.
Ok(total_num_tokens + self.max_completion_tokens)
}
}

#[derive(Debug, PartialEq, Serialize, Deserialize, ToSchema)]
#[serde(rename(serialize = "requestBody", deserialize = "RequestBody"))]
pub struct ChatCompletionsRequest {
@@ -1541,3 +1670,187 @@ pub mod utils {
Ok(response_body)
}
}

#[cfg(test)]
mod tests {
use super::*;

use serde_json::json;
use std::str::FromStr;
use tokenizers::Tokenizer;

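// Downloads the TinyLlama tokenizer from the Hugging Face Hub; these tests require network access.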
async fn load_tokenizer() -> Tokenizer {
let url =
"https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0/raw/main/tokenizer.json";
let tokenizer_json = reqwest::get(url).await.unwrap().text().await.unwrap();

Tokenizer::from_str(&tokenizer_json).unwrap()
}

#[tokio::test]
async fn test_get_compute_units_estimate() {
let request = RequestModelChatCompletions {
messages: vec![json!({
"role": "user",
"content": "Hello from the other side of Mars"
})],
max_completion_tokens: 10,
};
let tokenizer = load_tokenizer().await;
let result = request.get_compute_units_estimate(Some(&tokenizer));
assert!(result.is_ok());
assert_eq!(result.unwrap(), 21); // 8 tokens + 3 overhead + 10 completion
}

#[tokio::test]
async fn test_get_compute_units_estimate_multiple_messages() {
let request = RequestModelChatCompletions {
messages: vec![
json!({
"role": "user",
"content": "Hello from the other side of Mars"
}),
json!({
"role": "assistant",
"content": "Hello from the other side of Mars"
}),
],
max_completion_tokens: 10,
};
let tokenizer = load_tokenizer().await;
let result = request.get_compute_units_estimate(Some(&tokenizer));
assert!(result.is_ok());
assert_eq!(result.unwrap(), 32); // (8+8) tokens + (3+3) overhead + 10 completion
}

#[tokio::test]
async fn test_get_compute_units_estimate_array_content() {
let request = RequestModelChatCompletions {
messages: vec![json!({
"role": "user",
"content": [
{
"type": "text",
"text": "Hello from the other side of Mars"
},
{
"type": "text",
"text": "Hello from the other side of Mars"
}
]
})],
max_completion_tokens: 10,
};

let tokenizer = load_tokenizer().await;
let result = request.get_compute_units_estimate(Some(&tokenizer));
assert!(result.is_ok());
assert_eq!(result.unwrap(), 32); // (8+8) tokens + (3+3) overhead + 10 completion
}

#[tokio::test]
async fn test_get_compute_units_estimate_empty_message() {
let request = RequestModelChatCompletions {
messages: vec![json!({
"role": "user",
"content": ""
})],
max_completion_tokens: 10,
};
let tokenizer = load_tokenizer().await;
let result = request.get_compute_units_estimate(Some(&tokenizer));
assert!(result.is_ok());
assert_eq!(result.unwrap(), 14); // 1 token (special token) + 3 overhead + 10 completion
}

#[tokio::test]
async fn test_get_compute_units_estimate_mixed_content() {
let request = RequestModelChatCompletions {
messages: vec![
json!({
"role": "system",
"content": "Hello from the other side of Mars"
}),
json!({
"role": "user",
"content": [
{
"type": "text",
"text": "Hello from the other side of Mars"
},
{
"type": "image",
"image_url": {
"url": "http://example.com/image.jpg"
}
},
{
"type": "text",
"text": "Hello from the other side of Mars"
}
]
}),
],
max_completion_tokens: 15,
};
let tokenizer = load_tokenizer().await;
let result = request.get_compute_units_estimate(Some(&tokenizer));
assert!(result.is_ok());
// Three text segments (system message + two text parts): 3 * 8 tokens + 3 * 3 overhead,
// plus 15 completion tokens counted once; the image part currently adds no tokens.
let tokens = result.unwrap();
assert_eq!(tokens, 48); // 3 * 8 + 3 * 3 overhead + 15
}

#[tokio::test]
async fn test_get_compute_units_estimate_invalid_content() {
let request = RequestModelChatCompletions {
messages: vec![json!({
"role": "user",
// Missing "content" field
})],
max_completion_tokens: 10,
};
let tokenizer = load_tokenizer().await;
let result = request.get_compute_units_estimate(Some(&tokenizer));
assert!(result.is_err());
assert!(matches!(
result.unwrap_err(),
AtomaServiceError::InvalidBody { .. }
));
}

#[tokio::test]
async fn test_get_compute_units_estimate_empty_array_content() {
let request = RequestModelChatCompletions {
messages: vec![json!({
"role": "user",
"content": []
})],
max_completion_tokens: 10,
};
let tokenizer = load_tokenizer().await;
let result = request.get_compute_units_estimate(Some(&tokenizer));
assert!(result.is_err());
assert!(matches!(
result.unwrap_err(),
AtomaServiceError::InvalidBody { .. }
));
}

#[tokio::test]
async fn test_get_compute_units_estimate_special_characters() {
let request = RequestModelChatCompletions {
messages: vec![json!({
"role": "user",
"content": "Hello! 👋 🌍 \n\t Special chars: &*#@"
})],
max_completion_tokens: 10,
};
let tokenizer = load_tokenizer().await;
let result = request.get_compute_units_estimate(Some(&tokenizer));
assert!(result.is_ok());
let tokens = result.unwrap();
assert!(tokens > 13); // Should be more than minimum (3 overhead + 10 completion)
}
}