diff --git a/atoma-service/src/handlers/chat_completions.rs b/atoma-service/src/handlers/chat_completions.rs
index 4940e744..e9da3639 100644
--- a/atoma-service/src/handlers/chat_completions.rs
+++ b/atoma-service/src/handlers/chat_completions.rs
@@ -24,6 +24,7 @@ use axum::{
 use opentelemetry::KeyValue;
 use reqwest::Client;
 use serde_json::{json, Value};
+use tokenizers::Tokenizer;
 use tracing::{info, instrument};
 use utoipa::OpenApi;

@@ -43,20 +44,35 @@ use crate::{
     middleware::RequestMetadata,
 };

-use super::{handle_confidential_compute_encryption_response, handle_status_code_error};
+use super::{
+    handle_confidential_compute_encryption_response, handle_status_code_error,
+    request_model::RequestModel, DEFAULT_MAX_TOKENS,
+};

 /// The path for confidential chat completions requests
 pub const CONFIDENTIAL_CHAT_COMPLETIONS_PATH: &str = "/v1/confidential/chat/completions";

+/// The key for the content parameter in the request body
+pub const CONTENT_KEY: &str = "content";
+
 /// The path for chat completions requests
 pub const CHAT_COMPLETIONS_PATH: &str = "/v1/chat/completions";

 /// The keep-alive interval in seconds
 const STREAM_KEEP_ALIVE_INTERVAL_IN_SECONDS: u64 = 15;

+/// The key for the max_completion_tokens parameter in the request body
+const MAX_COMPLETION_TOKENS_KEY: &str = "max_completion_tokens";
+
+/// The key for the max_tokens parameter in the request body
+const MAX_TOKENS_KEY: &str = "max_tokens";
+
 /// The key for the model parameter in the request body
 const MODEL_KEY: &str = "model";

+/// The key for the messages parameter in the request body
+const MESSAGES_KEY: &str = "messages";
+
 /// The key for the stream parameter in the request body
 const STREAM_KEY: &str = "stream";

@@ -734,6 +750,119 @@ async fn handle_streaming_response(
     Ok(stream.into_response())
 }

+/// Represents a chat completion request model following the OpenAI API format
+pub struct RequestModelChatCompletions {
+    /// Array of message objects that represent the conversation history.
+    /// Each message should contain a "role" (system/user/assistant) and "content";
+    /// the content can be a string or an array of content parts.
+    messages: Vec<Value>,
+
+    /// The maximum number of tokens to generate in the completion.
+    /// This limits the length of the model's response.
+    max_completion_tokens: u64,
+}
+
+impl RequestModel for RequestModelChatCompletions {
+    fn new(request: &Value) -> Result<Self, AtomaServiceError> {
+        let messages = request
+            .get(MESSAGES_KEY)
+            .and_then(|m| m.as_array())
+            .ok_or_else(|| AtomaServiceError::InvalidBody {
+                message: "Missing or invalid 'messages' field".to_string(),
+                endpoint: CHAT_COMPLETIONS_PATH.to_string(),
+            })?;
+
+        let max_completion_tokens = request
+            .get(MAX_COMPLETION_TOKENS_KEY)
+            .or_else(|| request.get(MAX_TOKENS_KEY))
+            .and_then(serde_json::Value::as_u64)
+            .unwrap_or(DEFAULT_MAX_TOKENS);
+
+        Ok(Self {
+            messages: messages.clone(),
+            max_completion_tokens,
+        })
+    }
+
+    /// Computes the total number of tokens for the chat completion request.
+    ///
+    /// This is used to estimate the cost of the chat completion request on the proxy side.
+    /// Message content may be either a string or an array of content parts, and all messages
+    /// are assumed to share the same conversation history. Content parts supplied as an array
+    /// are concatenated and treated as a single message, both by the model and for the
+    /// purposes of this estimate.
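+    ///
+    /// For example, with the tokenizer used in the tests below, a single user message that
+    /// encodes to 8 tokens with `max_completion_tokens = 10` is estimated at
+    /// 8 + 3 (per-message overhead) + 10 = 21 compute units.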
+    fn get_compute_units_estimate(
+        &self,
+        tokenizer: Option<&Tokenizer>,
+    ) -> Result<u64, AtomaServiceError> {
+        // To account for additional special tokens that the tokenizer may not emit,
+        // we add a small per-message overhead to the total number of tokens.
+        const MESSAGE_OVERHEAD_TOKENS: u64 = 3;
+        let Some(tokenizer) = tokenizer else {
+            return Err(AtomaServiceError::InternalError {
+                message: "Tokenizer is required for current model, but is not currently available"
+                    .to_string(),
+                endpoint: CHAT_COMPLETIONS_PATH.to_string(),
+            });
+        };
+        // Helper function to count tokens for a text string
+        let count_text_tokens = |text: &str| -> Result<u64, AtomaServiceError> {
+            Ok(tokenizer
+                .encode(text, true)
+                .map_err(|err| AtomaServiceError::InternalError {
+                    message: format!("Failed to encode message: {err:?}"),
+                    endpoint: CHAT_COMPLETIONS_PATH.to_string(),
+                })?
+                .get_ids()
+                .len() as u64)
+        };
+
+        let mut total_num_tokens = 0;
+
+        for message in &self.messages {
+            let content = message
+                .get(CONTENT_KEY)
+                .and_then(|content| MessageContent::deserialize(content).ok())
+                .ok_or_else(|| AtomaServiceError::InvalidBody {
+                    message: "Missing or invalid message content".to_string(),
+                    endpoint: CHAT_COMPLETIONS_PATH.to_string(),
+                })?;
+
+            match content {
+                MessageContent::Text(text) => {
+                    let num_tokens = count_text_tokens(&text)?;
+                    total_num_tokens += num_tokens + MESSAGE_OVERHEAD_TOKENS;
+                }
+                MessageContent::Array(parts) => {
+                    if parts.is_empty() {
+                        tracing::error!(
+                            "Received empty array of message parts for chat completion request"
+                        );
+                        return Err(AtomaServiceError::InvalidBody {
+                            message: "Missing or invalid message content".to_string(),
+                            endpoint: CHAT_COMPLETIONS_PATH.to_string(),
+                        });
+                    }
+                    for part in parts {
+                        match part {
+                            MessageContentPart::Text { text, .. } => {
+                                let num_tokens = count_text_tokens(&text)?;
+                                total_num_tokens += num_tokens + MESSAGE_OVERHEAD_TOKENS;
+                            }
+                            MessageContentPart::Image { .. } => {
+                                // TODO: Ensure that for image content parts, we have a way to
+                                // estimate the number of tokens, which can depend on the size
+                                // of the image and the output description.
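+                                // One possible approach (hypothetical, not implemented here)
+                                // would be a pixel-based cost: read the image dimensions and
+                                // add `width * height / PIXELS_PER_TOKEN` to the running
+                                // total, with `PIXELS_PER_TOKEN` a model-specific constant.
+                                // For now, image parts contribute zero tokens to the estimate.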
+                                continue;
+                            }
+                        }
+                    }
+                }
+            }
+        }
+        // Add the max completion tokens, to account for the response
+        Ok(total_num_tokens + self.max_completion_tokens)
+    }
+}
+
 #[derive(Debug, PartialEq, Serialize, Deserialize, ToSchema)]
 #[serde(rename(serialize = "requestBody", deserialize = "RequestBody"))]
 pub struct ChatCompletionsRequest {
@@ -1541,3 +1670,187 @@ pub mod utils {
         Ok(response_body)
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    use serde_json::json;
+    use std::str::FromStr;
+    use tokenizers::Tokenizer;
+
+    async fn load_tokenizer() -> Tokenizer {
+        let url =
+            "https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0/raw/main/tokenizer.json";
+        let tokenizer_json = reqwest::get(url).await.unwrap().text().await.unwrap();
+
+        Tokenizer::from_str(&tokenizer_json).unwrap()
+    }
+
+    #[tokio::test]
+    async fn test_get_compute_units_estimate() {
+        let request = RequestModelChatCompletions {
+            messages: vec![json!({
+                "role": "user",
+                "content": "Hello from the other side of Mars"
+            })],
+            max_completion_tokens: 10,
+        };
+        let tokenizer = load_tokenizer().await;
+        let result = request.get_compute_units_estimate(Some(&tokenizer));
+        assert!(result.is_ok());
+        assert_eq!(result.unwrap(), 21); // 8 tokens + 3 overhead + 10 completion
+    }
+
+    #[tokio::test]
+    async fn test_get_compute_units_estimate_multiple_messages() {
+        let request = RequestModelChatCompletions {
+            messages: vec![
+                json!({
+                    "role": "user",
+                    "content": "Hello from the other side of Mars"
+                }),
+                json!({
+                    "role": "assistant",
+                    "content": "Hello from the other side of Mars"
+                }),
+            ],
+            max_completion_tokens: 10,
+        };
+        let tokenizer = load_tokenizer().await;
+        let result = request.get_compute_units_estimate(Some(&tokenizer));
+        assert!(result.is_ok());
+        assert_eq!(result.unwrap(), 32); // (8+8) tokens + (3+3) overhead + 10 completion
+    }
+
+    #[tokio::test]
+    async fn test_get_compute_units_estimate_array_content() {
+        let request = RequestModelChatCompletions {
+            messages: vec![json!({
+                "role": "user",
+                "content": [
+                    {
+                        "type": "text",
+                        "text": "Hello from the other side of Mars"
+                    },
+                    {
+                        "type": "text",
+                        "text": "Hello from the other side of Mars"
+                    }
+                ]
+            })],
+            max_completion_tokens: 10,
+        };
+
+        let tokenizer = load_tokenizer().await;
+        let result = request.get_compute_units_estimate(Some(&tokenizer));
+        assert!(result.is_ok());
+        assert_eq!(result.unwrap(), 32); // (8+8) tokens + (3+3) overhead + 10 completion
+    }
+
+    #[tokio::test]
+    async fn test_get_compute_units_estimate_empty_message() {
+        let request = RequestModelChatCompletions {
+            messages: vec![json!({
+                "role": "user",
+                "content": ""
+            })],
+            max_completion_tokens: 10,
+        };
+        let tokenizer = load_tokenizer().await;
+        let result = request.get_compute_units_estimate(Some(&tokenizer));
+        assert!(result.is_ok());
+        assert_eq!(result.unwrap(), 14); // 1 token (special token) + 3 overhead + 10 completion
+    }
+
+    #[tokio::test]
+    async fn test_get_compute_units_estimate_mixed_content() {
+        let request = RequestModelChatCompletions {
+            messages: vec![
+                json!({
+                    "role": "system",
+                    "content": "Hello from the other side of Mars"
+                }),
+                json!({
+                    "role": "user",
+                    "content": [
+                        {
+                            "type": "text",
+                            "text": "Hello from the other side of Mars"
+                        },
+                        {
+                            "type": "image",
+                            "image_url": {
+                                "url": "http://example.com/image.jpg"
+                            }
+                        },
+                        {
+                            "type": "text",
+                            "text": "Hello from the other side of Mars"
+                        }
+                    ]
+                }),
+            ],
+            max_completion_tokens: 15,
+        };
+        let tokenizer = load_tokenizer().await;
+        let result = request.get_compute_units_estimate(Some(&tokenizer));
+        assert!(result.is_ok());
+        // System message: 8 tokens + 3 overhead.
+        // User message array: two 8-token text parts + 3 overhead each; the image part is skipped.
+        // Plus the 15 max completion tokens.
+        let tokens = result.unwrap();
+        assert_eq!(tokens, 48); // 3 * 8 tokens + 3 * 3 overhead + 15 completion
+    }
+
+    #[tokio::test]
+    async fn test_get_compute_units_estimate_invalid_content() {
+        let request = RequestModelChatCompletions {
+            messages: vec![json!({
+                "role": "user",
+                // Missing "content" field
+            })],
+            max_completion_tokens: 10,
+        };
+        let tokenizer = load_tokenizer().await;
+        let result = request.get_compute_units_estimate(Some(&tokenizer));
+        assert!(result.is_err());
+        assert!(matches!(
+            result.unwrap_err(),
+            AtomaServiceError::InvalidBody { .. }
+        ));
+    }
+
+    #[tokio::test]
+    async fn test_get_compute_units_estimate_empty_array_content() {
+        let request = RequestModelChatCompletions {
+            messages: vec![json!({
+                "role": "user",
+                "content": []
+            })],
+            max_completion_tokens: 10,
+        };
+        let tokenizer = load_tokenizer().await;
+        let result = request.get_compute_units_estimate(Some(&tokenizer));
+        assert!(result.is_err());
+        assert!(matches!(
+            result.unwrap_err(),
+            AtomaServiceError::InvalidBody { .. }
+        ));
+    }
+
+    #[tokio::test]
+    async fn test_get_compute_units_estimate_special_characters() {
+        let request = RequestModelChatCompletions {
+            messages: vec![json!({
+                "role": "user",
+                "content": "Hello! 👋 🌍 \n\t Special chars: &*#@"
+            })],
+            max_completion_tokens: 10,
+        };
+        let tokenizer = load_tokenizer().await;
+        let result = request.get_compute_units_estimate(Some(&tokenizer));
+        assert!(result.is_ok());
+        let tokens = result.unwrap();
+        assert!(tokens > 13); // Should be more than the minimum (3 overhead + 10 completion)
+    }
+}
diff --git a/atoma-service/src/handlers/embeddings.rs b/atoma-service/src/handlers/embeddings.rs
index 96055674..fb6b4468 100644
--- a/atoma-service/src/handlers/embeddings.rs
+++ b/atoma-service/src/handlers/embeddings.rs
@@ -20,10 +20,11 @@ use axum::{extract::State, Extension, Json};
 use opentelemetry::KeyValue;
 use reqwest::Client;
 use serde_json::Value;
+use tokenizers::Tokenizer;
 use tracing::{info, instrument};
 use utoipa::OpenApi;

-use super::handle_status_code_error;
+use super::{handle_status_code_error, request_model::RequestModel};

 /// The path for confidential embeddings requests
 pub const CONFIDENTIAL_EMBEDDINGS_PATH: &str = "/v1/confidential/embeddings";
@@ -34,6 +35,9 @@ pub const EMBEDDINGS_PATH: &str = "/v1/embeddings";
 /// The key for the model parameter in the request body
 pub const MODEL_KEY: &str = "model";

+/// The key for the input parameter in the request body
+pub const INPUT_KEY: &str = "input";
+
 /// OpenAPI documentation structure for the embeddings endpoint.
 ///
 /// This struct defines the OpenAPI (Swagger) documentation for the embeddings API,
@@ -386,3 +390,71 @@ async fn handle_embeddings_response(
         }),
     }
 }
+
+/// A model representing an embeddings request payload.
+///
+/// This struct encapsulates the necessary fields for processing an embeddings request
+/// following the OpenAI API format.
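+///
+/// The `input` field may be a single string, e.g. `{"input": "text to embed"}`,
+/// or an array of strings, e.g. `{"input": ["text one", "text two"]}`.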
+pub struct RequestModelEmbeddings {
+    /// The input text to generate embeddings for
+    input: Value,
+}
+
+impl RequestModel for RequestModelEmbeddings {
+    fn new(request: &Value) -> Result<Self, AtomaServiceError> {
+        let input = request
+            .get(INPUT_KEY)
+            .ok_or_else(|| AtomaServiceError::InvalidBody {
+                message: "Input field is required".to_string(),
+                endpoint: EMBEDDINGS_PATH.to_string(),
+            })?;
+
+        Ok(Self {
+            input: input.clone(),
+        })
+    }
+
+    fn get_compute_units_estimate(
+        &self,
+        tokenizer: Option<&Tokenizer>,
+    ) -> Result<u64, AtomaServiceError> {
+        let Some(tokenizer) = tokenizer else {
+            return Err(AtomaServiceError::InternalError {
+                message: "Tokenizer is required for current model, but is not currently available"
+                    .to_string(),
+                endpoint: EMBEDDINGS_PATH.to_string(),
+            });
+        };
+
+        // The input can be a string or an array of strings
+        let total_units = match &self.input {
+            Value::String(text) => tokenizer
+                .encode(text.as_str(), true)
+                .map_err(|_| AtomaServiceError::InvalidBody {
+                    message: "Failed to encode input text".to_string(),
+                    endpoint: EMBEDDINGS_PATH.to_string(),
+                })?
+                .get_ids()
+                .len() as u64,
+            Value::Array(texts) => texts
+                .iter()
+                .map(|v| {
+                    v.as_str().map_or(0, |s| {
+                        tokenizer
+                            .encode(s, true)
+                            .map(|tokens| tokens.get_ids().len() as u64)
+                            .unwrap_or(0)
+                    })
+                })
+                .sum(),
+            _ => {
+                return Err(AtomaServiceError::InvalidBody {
+                    message: "Invalid input format".to_string(),
+                    endpoint: EMBEDDINGS_PATH.to_string(),
+                });
+            }
+        };
+
+        Ok(total_units)
+    }
+}
diff --git a/atoma-service/src/handlers/image_generations.rs b/atoma-service/src/handlers/image_generations.rs
index 7b994715..a55def96 100644
--- a/atoma-service/src/handlers/image_generations.rs
+++ b/atoma-service/src/handlers/image_generations.rs
@@ -18,12 +18,13 @@ use axum::{extract::State, Extension, Json};
 use opentelemetry::KeyValue;
 use reqwest::Client;
 use serde_json::Value;
+use tokenizers::Tokenizer;
 use tracing::{info, instrument};
 use utoipa::OpenApi;

 use super::{
     handle_confidential_compute_encryption_response, handle_status_code_error,
-    sign_response_and_update_stack_hash,
+    request_model::RequestModel, sign_response_and_update_stack_hash,
 };

 /// The path for confidential image generations requests
@@ -35,6 +36,12 @@ pub const IMAGE_GENERATIONS_PATH: &str = "/v1/images/generations";
 /// The key for the model parameter in the request body
 pub const MODEL_KEY: &str = "model";

+/// The key for the n parameter in the request body
+pub const N_KEY: &str = "n";
+
+/// The key for the size parameter in the request body
+pub const SIZE_KEY: &str = "size";
+
 /// OpenAPI documentation structure for the image generations endpoint.
 ///
 /// This struct defines the OpenAPI (Swagger) documentation for the image generations API,
@@ -384,3 +391,67 @@ async fn handle_image_generations_response(
         }
     }
 }
+
+/// A model representing the parameters for an image generation request.
+///
+/// This struct encapsulates the required parameters for generating images through
+/// the API endpoint.
+pub struct RequestModelImageGenerations {
+    /// The number of images to generate for this request
+    n: u64,
+    /// The desired dimensions of the generated images in the format "WIDTHxHEIGHT"
+    /// (e.g., "1024x1024")
+    size: String,
+}
+
+impl RequestModel for RequestModelImageGenerations {
+    fn new(request: &Value) -> Result<Self, AtomaServiceError> {
+        let n = request
+            .get(N_KEY)
+            .and_then(serde_json::Value::as_u64)
+            .ok_or_else(|| AtomaServiceError::InvalidBody {
+                message: "N field is required".to_string(),
+                endpoint: IMAGE_GENERATIONS_PATH.to_string(),
+            })?;
+        let size = request
+            .get(SIZE_KEY)
+            .and_then(|s| s.as_str())
+            .ok_or_else(|| AtomaServiceError::InvalidBody {
+                message: "Size field is required".to_string(),
+                endpoint: IMAGE_GENERATIONS_PATH.to_string(),
+            })?;
+
+        Ok(Self {
+            n,
+            size: size.to_string(),
+        })
+    }
+
+    fn get_compute_units_estimate(
+        &self,
+        _tokenizer: Option<&Tokenizer>,
+    ) -> Result<u64, AtomaServiceError> {
+        // Parse dimensions from the size string (e.g., "1024x1024")
+        let dimensions: Vec<u64> = self
+            .size
+            .split('x')
+            .filter_map(|s| s.parse().ok())
+            .collect();
+
+        if dimensions.len() != 2 {
+            return Err(AtomaServiceError::InvalidBody {
+                message: format!(
+                    "Invalid size format, expected a two-dimensional size (e.g., \"1024x1024\"), but got: {}",
+                    self.size
+                ),
+                endpoint: IMAGE_GENERATIONS_PATH.to_string(),
+            });
+        }
+
+        let width = dimensions[0];
+        let height = dimensions[1];
+
+        // Calculate compute units based on the number of images and the pixel count
+        Ok(self.n * width * height)
+    }
+}
diff --git a/atoma-service/src/handlers/mod.rs b/atoma-service/src/handlers/mod.rs
index d5d1a0dc..1d6f51c9 100644
--- a/atoma-service/src/handlers/mod.rs
+++ b/atoma-service/src/handlers/mod.rs
@@ -2,6 +2,7 @@ pub mod chat_completions;
 pub mod embeddings;
 pub mod image_generations;
 pub mod metrics;
+pub mod request_model;

 use atoma_confidential::types::{
     ConfidentialComputeEncryptionRequest, ConfidentialComputeEncryptionResponse,
@@ -24,6 +25,9 @@ use atoma_state::types::AtomaAtomaStateManagerEvent;
 /// Key for the ciphertext in the response body
 const CIPHERTEXT_KEY: &str = "ciphertext";

+/// The default max tokens for a chat completion request
+const DEFAULT_MAX_TOKENS: u64 = 8_192;
+
 /// Key for the nonce in the response body
 const NONCE_KEY: &str = "nonce";

diff --git a/atoma-service/src/handlers/request_model.rs b/atoma-service/src/handlers/request_model.rs
new file mode 100644
index 00000000..f828e134
--- /dev/null
+++ b/atoma-service/src/handlers/request_model.rs
@@ -0,0 +1,38 @@
+use serde_json::Value;
+use tokenizers::Tokenizer;
+
+use crate::error::AtomaServiceError;
+
+/// A trait for parsing and handling AI model requests across different endpoints (chat, embeddings, images).
+/// This trait provides a common interface for processing various types of AI model requests
+/// and estimating their computational costs.
+pub trait RequestModel {
+    /// Constructs a new request model instance by parsing the provided JSON request.
+    ///
+    /// # Arguments
+    /// * `request` - The JSON payload containing the request parameters
+    ///
+    /// # Returns
+    /// * `Ok(Self)` - Successfully parsed request model
+    /// * `Err(AtomaServiceError)` - If the request is invalid or malformed
+    fn new(request: &Value) -> Result<Self, AtomaServiceError>
+    where
+        Self: Sized;
+
+    /// Calculates the estimated computational resources required for this request.
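+    ///
+    /// For chat completions this is the tokenized prompt length plus the requested
+    /// completion budget; for embeddings it is the tokenized input length; for image
+    /// generations it is `n * width * height`.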
+    ///
+    /// # Arguments
+    /// * `tokenizer` - The tokenizer to use for the request
+    ///
+    /// # Returns
+    /// * `Ok(u64)` - The estimated compute units needed
+    /// * `Err(AtomaServiceError)` - If the estimation fails or parameters are invalid
+    ///
+    /// # Warning
+    /// This method assumes that the tokenizer has been correctly retrieved from the
+    /// application state for the associated model.
+    fn get_compute_units_estimate(
+        &self,
+        tokenizer: Option<&Tokenizer>,
+    ) -> Result<u64, AtomaServiceError>;
+}
diff --git a/atoma-service/src/middleware.rs b/atoma-service/src/middleware.rs
index e31da51d..58ff4e73 100644
--- a/atoma-service/src/middleware.rs
+++ b/atoma-service/src/middleware.rs
@@ -39,27 +39,6 @@ const MAX_BODY_SIZE: usize = 1024 * 1024; // 1MB
 /// The key for the model in the request body
 const MODEL: &str = "model";

-/// The key for the max tokens in the request body (currently deprecated, as per OpenAI API spec)
-const MAX_TOKENS: &str = "max_tokens";
-
-/// The key for max completion tokens in the request body
-const MAX_COMPLETION_TOKENS: &str = "max_completion_tokens";
-
-/// The default value for the max tokens for chat completions
-const DEFAULT_MAX_TOKENS_CHAT_COMPLETIONS: i64 = 8192;
-
-/// The key for the messages in the request body
-const MESSAGES: &str = "messages";
-
-/// The key for the input tokens in the request body
-const INPUT: &str = "input";
-
-/// The key for the image size in the request body
-const IMAGE_SIZE: &str = "size";
-
-/// The key for the number of images in the request body
-const IMAGE_N: &str = "n";
-
 /// Metadata for confidential compute decryption requests
 pub struct DecryptionMetadata {
     /// The plaintext body
@@ -438,7 +417,7 @@ pub async fn verify_stack_permissions(
     }

     let total_num_compute_units =
-        utils::calculate_compute_units(&body_json, request_type, &state, model, endpoint.clone())?;
+        utils::calculate_compute_units(&body_json, request_type, &state, model, &endpoint)?;

     let (result_sender, result_receiver) = oneshot::channel();
     state
@@ -647,17 +626,33 @@ pub async fn confidential_compute_middleware(
     }
 }

-pub(crate) mod utils {
+pub mod utils {
     use hyper::HeaderMap;

+    use crate::handlers::{
+        chat_completions::RequestModelChatCompletions, embeddings::RequestModelEmbeddings,
+        image_generations::RequestModelImageGenerations, request_model::RequestModel,
+    };
+
     use super::{
         blake2b_hash, instrument, oneshot, verify_signature, AppState, AtomaServiceError,
         ConfidentialComputeDecryptionRequest, ConfidentialComputeRequest, DecryptionMetadata,
-        Engine, RequestType, TransactionDigest, Value, DEFAULT_MAX_TOKENS_CHAT_COMPLETIONS,
-        DH_PUBLIC_KEY_SIZE, IMAGE_N, IMAGE_SIZE, INPUT, MAX_COMPLETION_TOKENS, MAX_TOKENS,
-        MESSAGES, NONCE_SIZE, PAYLOAD_HASH_SIZE, SALT_SIZE, STANDARD,
+        Engine, RequestType, TransactionDigest, Value, DH_PUBLIC_KEY_SIZE, NONCE_SIZE,
+        PAYLOAD_HASH_SIZE, SALT_SIZE, STANDARD,
     };

+    /// Default max completion tokens for chat completions
+    const DEFAULT_MAX_TOKENS_CHAT_COMPLETIONS: i64 = 8192;
+
+    /// The key for the max tokens in the request body (currently deprecated, as per OpenAI API spec)
+    const MAX_TOKENS: &str = "max_tokens";
+
+    /// The key for max completion tokens in the request body
+    const MAX_COMPLETION_TOKENS: &str = "max_completion_tokens";
+
+    /// The key for the messages in the request body
+    const MESSAGES: &str = "messages";
+
     /// Requests and verifies stack information from the blockchain for a given transaction.
/// /// This function communicates with a blockchain service to verify the existence and validity @@ -766,6 +761,11 @@ pub(crate) mod utils { /// * `Ok(i64)` - The total number of compute units required /// * `Err(AtomaServiceError)` - If there's an error calculating the units, returns an appropriate HTTP status code /// + /// # Errors + /// - `InvalidBody` - If the request body is invalid + /// - `InvalidHeader` - If the request headers are invalid + /// - `InternalError` - If there's an error calculating the units + /// /// # Compute Unit Calculation /// The calculation varies by request type: /// - ChatCompletions: Based on input tokens + max output tokens @@ -774,25 +774,52 @@ pub(crate) mod utils { /// - NonInference: Returns 0 (no compute units required) /// /// This function delegates to specific calculators based on the request type: - /// - `calculate_chat_completion_compute_units` - /// - `calculate_embedding_compute_units` - /// - `calculate_image_generation_compute_units` + /// - `RequestModelChatCompletions` + /// - `RequestModelEmbeddings` + /// - `RequestModelImageGenerations` pub fn calculate_compute_units( body_json: &Value, request_type: RequestType, state: &AppState, model: &str, - endpoint: String, + endpoint: &str, ) -> Result { match request_type { RequestType::ChatCompletions => { - calculate_chat_completion_compute_units(body_json, state, model, endpoint) + let request_model = RequestModelChatCompletions::new(body_json)?; + let tokenizer_index = + state + .models + .iter() + .position(|m| m == model) + .ok_or_else(|| AtomaServiceError::InvalidBody { + message: "Model not supported".to_string(), + endpoint: endpoint.to_string(), + })?; + request_model + .get_compute_units_estimate(Some(&state.tokenizers[tokenizer_index])) + .map(|i| i as i64) } RequestType::Embeddings => { - calculate_embedding_compute_units(body_json, state, model, endpoint) + let request_model = RequestModelEmbeddings::new(body_json)?; + let tokenizer_index = + state + .models + .iter() + .position(|m| m == model) + .ok_or_else(|| AtomaServiceError::InvalidBody { + message: "Model not supported".to_string(), + endpoint: endpoint.to_string(), + })?; + request_model + .get_compute_units_estimate(Some(&state.tokenizers[tokenizer_index])) + .map(|i| i as i64) } RequestType::ImageGenerations => { - calculate_image_generation_compute_units(body_json, endpoint) + let request_model = RequestModelImageGenerations::new(body_json)?; + request_model + .get_compute_units_estimate(None) + .map(|i| i as i64) } RequestType::NonInference => Ok(0), } @@ -903,167 +930,6 @@ pub(crate) mod utils { Ok(total_num_compute_units) } - /// Calculates the total number of compute units required for an embedding request. - /// - /// This function analyzes the request body to determine the computational cost by counting - /// the number of tokens in the input text(s) that will be embedded. 
- /// - /// # Arguments - /// * `body_json` - The parsed JSON body of the request containing: - /// - `input`: Either a single string or an array of strings to be embedded - /// * `state` - Application state containing model configurations and tokenizers - /// * `model` - The name of the AI model being used - /// - /// # Returns - /// * `Ok(i64)` - The total number of compute units required - /// * `Err(AtomaServiceError)` - AtomaServiceError::InvalidBody if: - /// - The model is not supported - /// - The input field is missing - /// - The input format is invalid - /// - Tokenization fails - /// - /// # Input Formats - /// The function supports two input formats: - /// 1. Single string: - /// ```json - /// { - /// "input": "text to embed" - /// } - /// ``` - /// - /// 2. Array of strings: - /// ```json - /// { - /// "input": ["text one", "text two"] - /// } - /// ``` - /// - /// # Computation - /// The total compute units is calculated as the sum of tokens across all input texts. - /// For array inputs, each string is tokenized separately and the results are summed. - #[instrument(level = "trace", skip_all)] - fn calculate_embedding_compute_units( - body_json: &Value, - state: &AppState, - model: &str, - endpoint: String, - ) -> Result { - let tokenizer_index = state - .models - .iter() - .position(|m| m == model) - .ok_or_else(|| AtomaServiceError::InvalidBody { - message: "Model not supported".to_string(), - endpoint: endpoint.clone(), - })?; - - let input = body_json - .get(INPUT) - .ok_or_else(|| AtomaServiceError::InvalidBody { - message: "Input not found in body".to_string(), - endpoint: endpoint.clone(), - })?; - - // input can be a string or an array of strings - let total_units = match input { - Value::String(text) => state.tokenizers[tokenizer_index] - .encode(text.as_str(), true) - .map_err(|_| AtomaServiceError::InvalidBody { - message: "Failed to encode input text".to_string(), - endpoint: endpoint.clone(), - })? - .get_ids() - .len() as i64, - Value::Array(texts) => texts - .iter() - .map(|v| { - v.as_str().map_or(0, |s| { - state.tokenizers[tokenizer_index] - .encode(s, true) - .map(|tokens| tokens.get_ids().len() as i64) - .unwrap_or(0) - }) - }) - .sum(), - _ => { - return Err(AtomaServiceError::InvalidBody { - message: "Invalid input format".to_string(), - endpoint: endpoint.clone(), - }); - } - }; - - Ok(total_units) - } - - /// Calculates the total number of compute units required for an image generation request. 
- /// - /// This function analyzes the request body to determine the computational cost based on: - /// - The dimensions of the requested image(s) - /// - The number of images to generate - /// - /// # Arguments - /// * `body_json` - The parsed JSON body of the request containing: - /// - `size`: String in format "WxH" (e.g., "1024x1024") - /// - `n`: Number of images to generate - /// - /// # Returns - /// * `Ok(i64)` - The total number of compute units required (width * height * n) - /// * `Err(AtomaServiceError)` - AtomaServiceError::InvalidBody if: - /// - The size field is missing or invalid - /// - The dimensions cannot be parsed - /// - The number of images is missing or invalid - /// - /// # Example JSON Structure - /// ```json - /// { - /// "size": "1024x1024", - /// "n": 1 - /// } - /// ``` - #[instrument(level = "trace", skip_all)] - fn calculate_image_generation_compute_units( - body_json: &Value, - endpoint: String, - ) -> Result { - let size = body_json - .get(IMAGE_SIZE) - .ok_or_else(|| AtomaServiceError::InvalidBody { - message: "Image size not found in body".to_string(), - endpoint: endpoint.clone(), - })? - .as_str() - .ok_or_else(|| AtomaServiceError::InvalidBody { - message: "Image size is not a string".to_string(), - endpoint: endpoint.clone(), - })?; - - // width and height are the dimensions of the image to generate - let (width, height) = size - .split_once('x') - .and_then(|(w, h)| { - let width = w.parse::().ok()?; - let height = h.parse::().ok()?; - Some((width, height)) - }) - .ok_or_else(|| AtomaServiceError::InvalidBody { - message: "Invalid image size format".to_string(), - endpoint: endpoint.clone(), - })?; - - // n is the number of images to generate - let n = body_json - .get(IMAGE_N) - .and_then(serde_json::Value::as_u64) - .ok_or_else(|| AtomaServiceError::InvalidBody { - message: "Invalid or missing image count (n)".to_string(), - endpoint: endpoint.clone(), - })? as i64; - - // Calculate total pixels - Ok(width * height * n) - } - /// Verifies a plaintext body hash against a provided signature. /// /// This function performs signature verification for confidential compute requests by: