From b99567ecde38a091dbf8fe009cb016bf5d7a204d Mon Sep 17 00:00:00 2001
From: Kamil Litman <97270937+kamillitman@users.noreply.github.com>
Date: Mon, 6 Jan 2025 09:55:48 -0500
Subject: [PATCH] Perplexity AI support (#57)

---
 Cargo.toml                   |   2 +-
 README.md                    |  13 ++-
 examples/use_completions.rs  |  22 ++++-
 src/constants.rs             |   5 ++
 src/domain.rs                |  36 ++++++++
 src/llm_models/mod.rs        |   2 +
 src/llm_models/perplexity.rs | 161 +++++++++++++++++++++++++++++++++++
 src/utils.rs                 |  61 ++++++++++++-
 8 files changed, 296 insertions(+), 6 deletions(-)
 create mode 100644 src/llm_models/perplexity.rs

diff --git a/Cargo.toml b/Cargo.toml
index 6ff4854..557c0e4 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "allms"
-version = "0.9.0"
+version = "0.10.0"
 edition = "2021"
 authors = [
     "Kamil Litman ",
diff --git a/README.md b/README.md
index 995ab49..25b0f7f 100644
--- a/README.md
+++ b/README.md
@@ -2,11 +2,11 @@
 [![crates.io](https://img.shields.io/crates/v/allms.svg)](https://crates.io/crates/allms)
 [![docs.rs](https://docs.rs/allms/badge.svg)](https://docs.rs/allms)
 
-This Rust library is specialized in providing type-safe interactions with APIs of the following LLM providers: OpenAI, Anthropic, Mistral, Google Gemini. (More providers to be added in the future.) It's designed to simplify the process of experimenting with different models. It de-risks the process of migrating between providers reducing vendor lock-in issues. It also standardizes serialization of sending requests to LLM APIs and interpreting the responses, ensuring that the JSON data is handled in a type-safe manner. With allms you can focus on creating effective prompts and providing LLM with the right context, instead of worrying about differences in API implementations.
+This Rust library is specialized in providing type-safe interactions with APIs of the following LLM providers: OpenAI, Anthropic, Mistral, Google Gemini, Perplexity. (More providers to be added in the future.) It's designed to simplify the process of experimenting with different models. It de-risks the process of migrating between providers reducing vendor lock-in issues. It also standardizes serialization of sending requests to LLM APIs and interpreting the responses, ensuring that the JSON data is handled in a type-safe manner. With allms you can focus on creating effective prompts and providing LLM with the right context, instead of worrying about differences in API implementations.
 
 ## Features
 
-- Support for various LLM models including OpenAI (GPT-3.5, GPT-4), Anthropic (Claude, Claude Instant), Mistral, or Google GeminiPro.
+- Support for various LLM models including OpenAI (GPT-3.5, GPT-4), Anthropic (Claude, Claude Instant), Mistral, Google GeminiPro, and Perplexity.
 - Easy-to-use functions for chat/text completions and assistants. Use the same struct and methods regardless of which model you choose.
 - Automated response deserialization to custom types.
 - Standardized approach to providing context with support of function calling, tools, and file uploads.
@@ -37,6 +37,10 @@ Google Vertex AI / AI Studio:
 - APIs: Chat Completions (including streaming)
 - Models: Gemini 1.5 Pro, Gemini 1.5 Flash, Gemini 1.0 Pro
 
+Perplexity:
+- APIs: Chat Completions
+- Models: Llama 3.1 Sonar Small, Llama 3.1 Sonar Large, Llama 3.1 Sonar Huge
+
 ### Prerequisites
 - OpenAI: API key (passed in model constructor)
 - Azure OpenAI: environment variable `OPENAI_API_URL` set to your Azure OpenAI resource endpoint. Endpoint key passed in constructor
@@ -44,6 +48,7 @@ Google Vertex AI / AI Studio:
 - Mistral: API key (passed in model constructor)
 - Google AI Studio: API key (passed in model constructor)
 - Google Vertex AI: GCP service account key (used to obtain access token) + GCP project ID (set as environment variable)
+- Perplexity: API key (passed in model constructor)
 
 ### Examples
 Explore the `examples` directory to see more use cases and how to use different LLM providers and endpoint types.
@@ -65,6 +70,10 @@ let mistral_answer = Completions::new(MistralModels::MistralSmall, &API_KEY, Non
 let google_answer = Completions::new(GoogleModels::GeminiPro, &API_KEY, None, None)
     .get_answer::(instructions)
     .await?
+
+let perplexity_answer = Completions::new(PerplexityModels::Llama3_1SonarSmall, &API_KEY, None, None)
+    .get_answer::(instructions)
+    .await?
 ```
 
 Example:
diff --git a/examples/use_completions.rs b/examples/use_completions.rs
index fbb4be8..2aa8dc5 100644
--- a/examples/use_completions.rs
+++ b/examples/use_completions.rs
@@ -3,7 +3,7 @@ use serde::Deserialize;
 use serde::Serialize;
 
 use allms::{
-    llm::{AnthropicModels, GoogleModels, LLMModel, MistralModels, OpenAIModels},
+    llm::{AnthropicModels, GoogleModels, LLMModel, MistralModels, OpenAIModels, PerplexityModels},
     Completions,
 };
 
@@ -25,7 +25,7 @@ async fn main() {
     // Get answer using OpenAI
     let openai_api_key: String = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
 
-    let model = OpenAIModels::try_from_str("o1-preview").unwrap_or(OpenAIModels::O1Preview); // Choose the model
+    let model = OpenAIModels::try_from_str("gpt-4o-mini").unwrap_or(OpenAIModels::Gpt4oMini); // Choose the model
     println!("OpenAI model: {:#?}", model.as_str());
 
     let openai_completion = Completions::new(model, &openai_api_key, None, None);
@@ -89,4 +89,22 @@ async fn main() {
         Ok(response) => println!("Gemini response: {:#?}", response),
         Err(e) => eprintln!("Error: {:?}", e),
     }
+
+    // Get answer using Perplexity
+    let model = PerplexityModels::try_from_str("llama-3.1-sonar-small-128k-online")
+        .unwrap_or(PerplexityModels::Llama3_1SonarSmall); // Choose the model
+    println!("Perplexity model: {:#?}", model.as_str());
+
+    let perplexity_token_str: String =
+        std::env::var("PERPLEXITY_API_KEY").expect("PERPLEXITY_API_KEY not set");
+
+    let perplexity_completion = Completions::new(model, &perplexity_token_str, None, None);
+
+    match perplexity_completion
+        .get_answer::(instructions)
+        .await
+    {
+        Ok(response) => println!("Perplexity response: {:#?}", response),
+        Err(e) => eprintln!("Error: {:?}", e),
+    }
 }
diff --git a/src/constants.rs b/src/constants.rs
index dbb4087..572230b 100644
--- a/src/constants.rs
+++ b/src/constants.rs
@@ -33,6 +33,11 @@ lazy_static! {
     );
 }
 
+lazy_static! {
+    pub(crate) static ref PERPLEXITY_API_URL: String = std::env::var("PERPLEXITY_API_URL")
+        .unwrap_or("https://api.perplexity.ai/chat/completions".to_string());
+}
+
 //Generic OpenAI instructions
 pub(crate) const OPENAI_BASE_INSTRUCTIONS: &str = r#"You are a computer function. You are expected to perform the following tasks:
 Step 1: Review and understand the 'instructions' from the *Instructions* section.
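For a concrete sense of how the new provider slots into the `Completions` interface shown in the README hunks above, below is a minimal, hypothetical sketch of an end-to-end Perplexity call. The `CityFacts` struct, its fields, the prompt text, and the use of `tokio` and `anyhow` are illustrative assumptions rather than part of this patch; the constructor arguments and the `get_answer` call shape mirror the README quickstart and `examples/use_completions.rs`.

```rust
use allms::{llm::PerplexityModels, Completions};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};

// Hypothetical output type; the deserialization target for the model's JSON answer.
#[derive(JsonSchema, Serialize, Deserialize, Debug)]
struct CityFacts {
    name: String,
    country: String,
    population: u64,
}

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // Assumed to be set in the environment, as in examples/use_completions.rs.
    let api_key = std::env::var("PERPLEXITY_API_KEY")?;
    let instructions = "Provide basic facts about Tokyo."; // illustrative prompt

    // Same constructor shape as the other providers; the two trailing `None`s
    // follow the README quickstart above.
    let answer = Completions::new(PerplexityModels::Llama3_1SonarSmall, &api_key, None, None)
        .get_answer::<CityFacts>(instructions)
        .await?;

    println!("{:#?}", answer);
    Ok(())
}
```

Because every provider goes through the same `Completions` struct, switching this sketch to another backend (for example `MistralModels::MistralSmall`) would only change the model argument and the API key.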
diff --git a/src/domain.rs b/src/domain.rs
index ef3ab63..a212289 100644
--- a/src/domain.rs
+++ b/src/domain.rs
@@ -299,3 +299,39 @@ pub struct AllmsError {
     pub error_message: String,
     pub error_detail: String,
 }
+
+// Perplexity API response type format for Chat Completions API
+#[derive(Deserialize, Serialize, Debug, Clone)]
+pub struct PerplexityAPICompletionsResponse {
+    pub id: Option<String>,
+    pub model: Option<String>,
+    pub object: Option<String>,
+    pub created: Option<usize>,
+    pub choices: Vec<PerplexityAPICompletionsChoices>,
+    pub citations: Option<Vec<String>>,
+    pub usage: Option<PerplexityAPICompletionsUsage>,
+}
+
+// Perplexity API response type format for Chat Completions API
+#[derive(Deserialize, Serialize, Debug, Clone)]
+pub struct PerplexityAPICompletionsChoices {
+    pub index: usize,
+    pub message: Option<PerplexityAPICompletionsMessage>,
+    pub delta: Option<PerplexityAPICompletionsMessage>,
+    pub finish_reason: String,
+}
+
+// Perplexity API response type format for Chat Completions API
+#[derive(Deserialize, Serialize, Debug, Clone)]
+pub struct PerplexityAPICompletionsMessage {
+    pub role: Option<String>,
+    pub content: Option<String>,
+}
+
+// Perplexity API response type format for Chat Completions API
+#[derive(Deserialize, Serialize, Debug, Clone)]
+pub struct PerplexityAPICompletionsUsage {
+    pub prompt_tokens: usize,
+    pub completion_tokens: usize,
+    pub total_tokens: usize,
+}
diff --git a/src/llm_models/mod.rs b/src/llm_models/mod.rs
index fdc178e..575ad96 100644
--- a/src/llm_models/mod.rs
+++ b/src/llm_models/mod.rs
@@ -3,6 +3,7 @@ pub mod google;
 pub mod llm_model;
 pub mod mistral;
 pub mod openai;
+pub mod perplexity;
 
 pub use anthropic::AnthropicModels;
 pub use google::GoogleModels;
@@ -10,3 +11,4 @@ pub use llm_model::LLMModel;
 pub use llm_model::LLMModel as LLM;
 pub use mistral::MistralModels;
 pub use openai::OpenAIModels;
+pub use perplexity::PerplexityModels;
diff --git a/src/llm_models/perplexity.rs b/src/llm_models/perplexity.rs
new file mode 100644
index 0000000..59987a7
--- /dev/null
+++ b/src/llm_models/perplexity.rs
@@ -0,0 +1,161 @@
+use anyhow::{anyhow, Result};
+use async_trait::async_trait;
+use log::info;
+use reqwest::{header, Client};
+use serde::{Deserialize, Serialize};
+use serde_json::{json, Value};
+
+use crate::constants::PERPLEXITY_API_URL;
+use crate::domain::{PerplexityAPICompletionsResponse, RateLimit};
+use crate::llm_models::LLMModel;
+use crate::utils::{map_to_range_f32, sanitize_json_response};
+
+#[derive(Deserialize, Serialize, Debug, Clone, Eq, PartialEq)]
+//Perplexity docs: https://docs.perplexity.ai/guides/model-cards
+pub enum PerplexityModels {
+    Llama3_1SonarSmall,
+    Llama3_1SonarLarge,
+    Llama3_1SonarHuge,
+}
+
+#[async_trait(?Send)]
+impl LLMModel for PerplexityModels {
+    fn as_str(&self) -> &str {
+        match self {
+            PerplexityModels::Llama3_1SonarSmall => "llama-3.1-sonar-small-128k-online",
+            PerplexityModels::Llama3_1SonarLarge => "llama-3.1-sonar-large-128k-online",
+            PerplexityModels::Llama3_1SonarHuge => "llama-3.1-sonar-huge-128k-online",
+        }
+    }
+
+    fn try_from_str(name: &str) -> Option<Self> {
+        match name.to_lowercase().as_str() {
+            "llama-3.1-sonar-small-128k-online" => Some(PerplexityModels::Llama3_1SonarSmall),
+            "llama-3.1-sonar-large-128k-online" => Some(PerplexityModels::Llama3_1SonarLarge),
+            "llama-3.1-sonar-huge-128k-online" => Some(PerplexityModels::Llama3_1SonarHuge),
+            _ => None,
+        }
+    }
+
+    // https://docs.perplexity.ai/guides/model-cards
+    fn default_max_tokens(&self) -> usize {
+        127_072
+    }
+
+    fn get_endpoint(&self) -> String {
+        PERPLEXITY_API_URL.to_string()
+    }
+
+    //This method prepares the body of the API call for different models
+    fn get_body(
+        &self,
+        instructions: &str,
+        json_schema: &Value,
+        function_call: bool,
+        // The total number of tokens requested in max_tokens plus the number of prompt tokens sent in messages must not exceed the context window token limit of the model requested.
+        // If left unspecified, then the model will generate tokens until either it reaches its stop token or the end of its context window.
+        _max_tokens: &usize,
+        temperature: &f32,
+    ) -> serde_json::Value {
+        //Prepare the 'messages' part of the body
+        let base_instructions = self.get_base_instructions(Some(function_call));
+        let system_message = json!({
+            "role": "system",
+            "content": base_instructions,
+        });
+        let schema_string = serde_json::to_string(json_schema).unwrap_or_default();
+        let user_message = json!({
+            "role": "user",
+            "content": format!(
+                "Output Json schema:\n
+                {schema_string}\n\n
+                {instructions}"
+            ),
+        });
+        json!({
+            "model": self.as_str(),
+            "temperature": temperature,
+            "messages": vec![
+                system_message,
+                user_message,
+            ],
+        })
+    }
+    ///
+    /// This function leverages the Perplexity API to perform any query as per the provided body.
+    ///
+    /// It returns a String of the Response object that needs to be parsed based on self.model.
+    ///
+    async fn call_api(
+        &self,
+        api_key: &str,
+        body: &serde_json::Value,
+        debug: bool,
+    ) -> Result<String> {
+        //Get the API url
+        let model_url = self.get_endpoint();
+
+        //Make the API call
+        let client = Client::new();
+
+        //Send request
+        let response = client
+            .post(model_url)
+            .header(header::CONTENT_TYPE, "application/json")
+            .bearer_auth(api_key)
+            .json(&body)
+            .send()
+            .await?;
+
+        let response_status = response.status();
+        let response_text = response.text().await?;
+
+        if debug {
+            info!(
+                "[debug] Perplexity API response: [{}] {:#?}",
+                &response_status, &response_text
+            );
+        }
+
+        Ok(response_text)
+    }
+
+    //This method attempts to convert the provided API response text into the expected struct and extracts the data from the response
+    fn get_data(&self, response_text: &str, _function_call: bool) -> Result<String> {
+        //Convert API response to struct representing expected response format
+        let completions_response: PerplexityAPICompletionsResponse =
+            serde_json::from_str(response_text)?;
+
+        //Parse the response and return the assistant content
+        completions_response
+            .choices
+            .iter()
+            .filter_map(|choice| choice.message.as_ref())
+            .find(|&message| message.role == Some("assistant".to_string()))
+            .and_then(|message| {
+                message
+                    .content
+                    .as_ref()
+                    .map(|content| sanitize_json_response(content))
+            })
+            .ok_or_else(|| anyhow!("Assistant role content not found"))
+    }
+
+    //This function allows checking the rate limits for different models
+    fn get_rate_limit(&self) -> RateLimit {
+        //Perplexity documentation: https://docs.perplexity.ai/guides/rate-limits
+        RateLimit {
+            tpm: 50 * 127_072, // 50 requests per minute with max 127,072 context length
+            rpm: 50,           // 50 requests per minute
+        }
+    }
+
+    // Accepts a [0-100] percentage range and returns the target temperature based on model ranges
+    fn get_normalized_temperature(&self, relative_temp: u32) -> f32 {
+        // Temperature range documentation: https://docs.perplexity.ai/api-reference/chat-completions
+        // "The amount of randomness in the response, valued between 0 *inclusive* and 2 *exclusive*."
+        let min = 0.0f32;
+        let max = 1.99999f32;
+        map_to_range_f32(min, max, relative_temp)
+    }
+}
diff --git a/src/utils.rs b/src/utils.rs
index e20a556..c28673d 100644
--- a/src/utils.rs
+++ b/src/utils.rs
@@ -92,6 +92,17 @@ pub(crate) fn map_to_range(min: u32, max: u32, target: u32) -> f32 {
     min as f32 + (range * percentage)
 }
 
+//Used internally to pick a number from a range based on its % representation
+pub(crate) fn map_to_range_f32(min: f32, max: f32, target: u32) -> f32 {
+    // Cap the target to the percentage range [0, 100]
+    let capped_target = target.min(100);
+
+    // Calculate the target value in the range [min, max]
+    let range = max - min;
+    let percentage = capped_target as f32 / 100.0;
+    min + (range * percentage)
+}
+
 #[cfg(test)]
 mod tests {
     use schemars::schema::{InstanceType, ObjectValidation, RootSchema, Schema, SchemaObject};
@@ -100,7 +111,9 @@ mod tests {
     use serde_json::Value;
 
     use crate::llm_models::OpenAIModels;
-    use crate::utils::{fix_value_schema, get_tokenizer, get_type_schema, map_to_range};
+    use crate::utils::{
+        fix_value_schema, get_tokenizer, get_type_schema, map_to_range, map_to_range_f32,
+    };
 
     #[derive(JsonSchema, Serialize, Deserialize)]
     struct SimpleStruct {
@@ -420,4 +433,50 @@ mod tests {
         // Not applicable for unsigned inputs but could test edge cases:
         assert_eq!(map_to_range(0, 100, 0), 0.0);
     }
+
+    #[test]
+    fn test_target_at_min_f32() {
+        assert_eq!(map_to_range_f32(0.0, 100.0, 0), 0.0);
+        assert_eq!(map_to_range_f32(10.0, 20.0, 0), 10.0);
+    }
+
+    #[test]
+    fn test_target_at_max_f32() {
+        assert_eq!(map_to_range_f32(0.0, 100.0, 100), 100.0);
+        assert_eq!(map_to_range_f32(10.0, 20.0, 100), 20.0);
+    }
+
+    #[test]
+    fn test_target_in_middle_f32() {
+        assert_eq!(map_to_range_f32(0.0, 100.0, 50), 50.0);
+        assert_eq!(map_to_range_f32(10.0, 20.0, 50), 15.0);
+        assert_eq!(map_to_range_f32(0.0, 1.0, 50), 0.5);
+    }
+
+    #[test]
+    fn test_target_out_of_bounds_f32() {
+        assert_eq!(map_to_range_f32(0.0, 100.0, 3000), 100.0); // Cap to 100
+        assert_eq!(map_to_range_f32(0.0, 100.0, 200), 100.0); // Cap to 100
+        assert_eq!(map_to_range_f32(10.0, 20.0, 200), 20.0); // Cap to 100
+    }
+
+    #[test]
+    fn test_zero_range_f32() {
+        assert_eq!(map_to_range_f32(10.0, 10.0, 50), 10.0); // Always return min if min == max
+        assert_eq!(map_to_range_f32(5.0, 5.0, 100), 5.0); // Even at max target
+    }
+
+    #[test]
+    fn test_fractional_range_f32() {
+        assert_eq!(map_to_range_f32(0.0, 0.5, 50), 0.25);
+        assert_eq!(map_to_range_f32(1.5, 3.0, 25), 1.875);
+        assert_eq!(map_to_range_f32(-1.0, 1.0, 75), 0.5);
+    }
+
+    #[test]
+    fn test_large_range_f32() {
+        assert_eq!(map_to_range_f32(-1000.0, 1000.0, 50), 0.0); // Midpoint of the range
+        assert_eq!(map_to_range_f32(-500.0, 500.0, 25), -250.0); // Quarter point
+        assert_eq!(map_to_range_f32(-2000.0, 0.0, 75), -500.0); // Three-quarters
+    }
 }
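To make the new `domain.rs` types and the parsing in `get_data()` more tangible, here is a hypothetical unit test that deserializes an illustrative, Perplexity-style JSON payload and extracts the assistant message the same way `get_data()` does. It assumes placement inside the crate (for example next to the structs in `src/domain.rs`) so that `crate::domain` is reachable; all field values in the payload are made up for the example.

```rust
#[cfg(test)]
mod perplexity_response_tests {
    use crate::domain::PerplexityAPICompletionsResponse;

    #[test]
    fn parses_assistant_content_from_sample_payload() {
        // Illustrative payload shaped like a Perplexity Chat Completions response.
        let raw = r#"{
            "id": "00000000-0000-0000-0000-000000000000",
            "model": "llama-3.1-sonar-small-128k-online",
            "object": "chat.completion",
            "created": 1736175348,
            "citations": ["https://example.com/source"],
            "choices": [{
                "index": 0,
                "finish_reason": "stop",
                "message": { "role": "assistant", "content": "{\"answer\": 42}" },
                "delta": null
            }],
            "usage": { "prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15 }
        }"#;

        let parsed: PerplexityAPICompletionsResponse =
            serde_json::from_str(raw).expect("sample payload should deserialize");

        // Mirror the extraction logic in PerplexityModels::get_data().
        let content = parsed
            .choices
            .iter()
            .filter_map(|choice| choice.message.as_ref())
            .find(|message| message.role.as_deref() == Some("assistant"))
            .and_then(|message| message.content.clone());

        assert_eq!(content.as_deref(), Some("{\"answer\": 42}"));
    }
}
```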
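And for the temperature handling: `PerplexityModels::get_normalized_temperature` maps the crate's 0-100 relative scale onto Perplexity's documented [0, 2) range through the new `map_to_range_f32` helper. Below is a small sketch of that behavior, written as if it were one more test inside the `src/utils.rs` test module above; the chosen inputs are arbitrary.

```rust
#[test]
fn perplexity_temperature_normalization_sketch() {
    // PerplexityModels::get_normalized_temperature uses min = 0.0 and max = 1.99999,
    // so the relative 0-100 scale lands just inside Perplexity's [0, 2) range.
    let min = 0.0f32;
    let max = 1.99999f32;

    assert_eq!(map_to_range_f32(min, max, 0), 0.0); // 0% -> lower bound
    assert_eq!(map_to_range_f32(min, max, 100), 1.99999); // 100% -> just under 2.0
    assert!((map_to_range_f32(min, max, 50) - 0.999995).abs() < 1e-6); // 50% -> mid-range
}
```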