Perplexity AI support (#57)
kamillitman authored Jan 6, 2025
1 parent 5b14ca1 commit b99567e
Showing 8 changed files with 296 additions and 6 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "allms"
version = "0.9.0"
version = "0.10.0"
edition = "2021"
authors = [
"Kamil Litman <[email protected]>",
13 changes: 11 additions & 2 deletions README.md
@@ -2,11 +2,11 @@
[![crates.io](https://img.shields.io/crates/v/allms.svg)](https://crates.io/crates/allms)
[![docs.rs](https://docs.rs/allms/badge.svg)](https://docs.rs/allms)

This Rust library is specialized in providing type-safe interactions with APIs of the following LLM providers: OpenAI, Anthropic, Mistral, Google Gemini. (More providers to be added in the future.) It's designed to simplify the process of experimenting with different models. It de-risks the process of migrating between providers reducing vendor lock-in issues. It also standardizes serialization of sending requests to LLM APIs and interpreting the responses, ensuring that the JSON data is handled in a type-safe manner. With allms you can focus on creating effective prompts and providing LLM with the right context, instead of worrying about differences in API implementations.
This Rust library is specialized in providing type-safe interactions with APIs of the following LLM providers: OpenAI, Anthropic, Mistral, Google Gemini, and Perplexity. (More providers to be added in the future.) It's designed to simplify the process of experimenting with different models. It de-risks the process of migrating between providers, reducing vendor lock-in issues. It also standardizes serialization of sending requests to LLM APIs and interpreting the responses, ensuring that the JSON data is handled in a type-safe manner. With allms you can focus on creating effective prompts and providing the LLM with the right context, instead of worrying about differences in API implementations.

## Features

- Support for various LLM models including OpenAI (GPT-3.5, GPT-4), Anthropic (Claude, Claude Instant), Mistral, or Google GeminiPro.
- Support for various LLM models including OpenAI (GPT-3.5, GPT-4), Anthropic (Claude, Claude Instant), Mistral, Google GeminiPro, and Perplexity.
- Easy-to-use functions for chat/text completions and assistants. Use the same struct and methods regardless of which model you choose.
- Automated response deserialization to custom types.
- Standardized approach to providing context with support of function calling, tools, and file uploads.
@@ -37,13 +37,18 @@ Google Vertex AI / AI Studio:
- APIs: Chat Completions (including streaming)
- Models: Gemini 1.5 Pro, Gemini 1.5 Flash, Gemini 1.0 Pro

Perplexity:
- APIs: Chat Completions
- Models: Llama 3.1 Sonar Small, Llama 3.1 Sonar Large, Llama 3.1 Sonar Huge

### Prerequisites
- OpenAI: API key (passed in model constructor)
- Azure OpenAI: environment variable `OPENAI_API_URL` set to your Azure OpenAI resource endpoint. Endpoint key passed in constructor
- Anthropic: API key (passed in model constructor)
- Mistral: API key (passed in model constructor)
- Google AI Studio: API key (passed in model constructor)
- Google Vertex AI: GCP service account key (used to obtain access token) + GCP project ID (set as environment variable)
- Perplexity: API key (passed in model constructor)

### Examples
Explore the `examples` directory to see more use cases and how to use different LLM providers and endpoint types.
@@ -65,6 +70,10 @@ let mistral_answer = Completions::new(MistralModels::MistralSmall, &API_KEY, Non
let google_answer = Completions::new(GoogleModels::GeminiPro, &API_KEY, None, None)
.get_answer::<T>(instructions)
.await?
let perplexity_answer = Completions::new(PerplexityModels::Llama3_1SonarSmall, &API_KEY, None, None)
.get_answer::<T>(instructions)
.await?
```
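
For orientation, here is a minimal end-to-end sketch of the pattern above applied to Perplexity. The response type, prompt, and supporting crates (`tokio`, `anyhow`, `schemars`) are assumptions for illustration and are not part of this commit; the `allms` calls mirror the snippet above and `examples/use_completions.rs`.

```rust
// Cargo.toml (assumed): allms = "0.10.0", plus tokio, anyhow, serde, schemars
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};

use allms::{llm::PerplexityModels, Completions};

// Hypothetical response type; get_answer deserializes the model's JSON into it.
#[derive(JsonSchema, Serialize, Deserialize, Debug)]
struct CityFacts {
    city: String,
    country: String,
    population: u64,
}

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let api_key = std::env::var("PERPLEXITY_API_KEY")?;

    let facts = Completions::new(PerplexityModels::Llama3_1SonarSmall, &api_key, None, None)
        .get_answer::<CityFacts>("Provide basic facts about Lisbon, Portugal.")
        .await?;

    println!("{:#?}", facts);
    Ok(())
}
```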

Example:
22 changes: 20 additions & 2 deletions examples/use_completions.rs
@@ -3,7 +3,7 @@ use serde::Deserialize;
use serde::Serialize;

use allms::{
llm::{AnthropicModels, GoogleModels, LLMModel, MistralModels, OpenAIModels},
llm::{AnthropicModels, GoogleModels, LLMModel, MistralModels, OpenAIModels, PerplexityModels},
Completions,
};

@@ -25,7 +25,7 @@ async fn main() {

// Get answer using OpenAI
let openai_api_key: String = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
let model = OpenAIModels::try_from_str("o1-preview").unwrap_or(OpenAIModels::O1Preview); // Choose the model
let model = OpenAIModels::try_from_str("gpt-4o-mini").unwrap_or(OpenAIModels::Gpt4oMini); // Choose the model
println!("OpenAI model: {:#?}", model.as_str());

let openai_completion = Completions::new(model, &openai_api_key, None, None);
@@ -89,4 +89,22 @@ async fn main() {
Ok(response) => println!("Gemini response: {:#?}", response),
Err(e) => eprintln!("Error: {:?}", e),
}

// Get answer using Perplexity
let model = PerplexityModels::try_from_str("llama-3.1-sonar-small-128k-online")
.unwrap_or(PerplexityModels::Llama3_1SonarSmall); // Choose the model
println!("Perplexity model: {:#?}", model.as_str());

let perplexity_token_str: String =
std::env::var("PERPLEXITY_API_KEY").expect("PERPLEXITY_API_KEY not set");

let perplexity_completion = Completions::new(model, &perplexity_token_str, None, None);

match perplexity_completion
.get_answer::<TranslationResponse>(instructions)
.await
{
Ok(response) => println!("Perplexity response: {:#?}", response),
Err(e) => eprintln!("Error: {:?}", e),
}
}
5 changes: 5 additions & 0 deletions src/constants.rs
@@ -33,6 +33,11 @@ lazy_static! {
);
}

lazy_static! {
pub(crate) static ref PERPLEXITY_API_URL: String = std::env::var("PERPLEXITY_API_URL")
.unwrap_or("https://api.perplexity.ai/chat/completions".to_string());
}

//Generic OpenAI instructions
pub(crate) const OPENAI_BASE_INSTRUCTIONS: &str = r#"You are a computer function. You are expected to perform the following tasks:
Step 1: Review and understand the 'instructions' from the *Instructions* section.
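
The new constant falls back to the public Perplexity endpoint unless `PERPLEXITY_API_URL` is set. A small sketch of overriding it (the proxy URL is hypothetical), which must happen before the crate first reads the lazy static:

```rust
fn main() {
    // Hypothetical local proxy; the lazy static caches the value on first access,
    // so set the variable before any Perplexity call is made.
    std::env::set_var(
        "PERPLEXITY_API_URL",
        "http://localhost:8080/chat/completions",
    );

    // ...then construct Completions with a PerplexityModels variant as usual.
}
```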
36 changes: 36 additions & 0 deletions src/domain.rs
@@ -299,3 +299,39 @@ pub struct AllmsError {
pub error_message: String,
pub error_detail: String,
}

// Perplexity API response type format for Chat Completions API
#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct PerplexityAPICompletionsResponse {
pub id: Option<String>,
pub model: Option<String>,
pub object: Option<String>,
pub created: Option<usize>,
pub choices: Vec<PerplexityAPICompletionsChoices>,
pub citations: Option<Vec<String>>,
pub usage: Option<PerplexityAPICompletionsUsage>,
}

// Perplexity API response type format for Chat Completions API
#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct PerplexityAPICompletionsChoices {
pub index: usize,
pub message: Option<PerplexityAPICompletionsMessage>,
pub delta: Option<PerplexityAPICompletionsMessage>,
pub finish_reason: String,
}

// Perplexity API response type format for Chat Completions API
#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct PerplexityAPICompletionsMessage {
pub role: Option<String>,
pub content: Option<String>,
}

// Perplexity API response type format for Chat Completions API
#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct PerplexityAPICompletionsUsage {
pub prompt_tokens: usize,
pub completion_tokens: usize,
pub total_tokens: usize,
}
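
A short sketch of how these types map onto a typical Chat Completions payload. The JSON is illustrative rather than a recorded API response, and it is written as a crate-internal test because the diff does not show whether the `domain` module is public:

```rust
#[cfg(test)]
mod perplexity_response_tests {
    use super::PerplexityAPICompletionsResponse;

    #[test]
    fn deserializes_minimal_chat_completion() {
        // Illustrative payload; optional fields (object, created, citations, delta) are omitted.
        let payload = r#"{
            "id": "resp-123",
            "model": "llama-3.1-sonar-small-128k-online",
            "choices": [{
                "index": 0,
                "finish_reason": "stop",
                "message": { "role": "assistant", "content": "{\"answer\": 42}" }
            }],
            "usage": { "prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15 }
        }"#;

        let parsed: PerplexityAPICompletionsResponse =
            serde_json::from_str(payload).expect("payload matches the struct definitions");

        assert_eq!(parsed.choices[0].finish_reason, "stop");
        assert_eq!(
            parsed.choices[0]
                .message
                .as_ref()
                .and_then(|m| m.content.as_deref()),
            Some("{\"answer\": 42}")
        );
    }
}
```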
2 changes: 2 additions & 0 deletions src/llm_models/mod.rs
@@ -3,10 +3,12 @@ pub mod google;
pub mod llm_model;
pub mod mistral;
pub mod openai;
pub mod perplexity;

pub use anthropic::AnthropicModels;
pub use google::GoogleModels;
pub use llm_model::LLMModel;
pub use llm_model::LLMModel as LLM;
pub use mistral::MistralModels;
pub use openai::OpenAIModels;
pub use perplexity::PerplexityModels;
161 changes: 161 additions & 0 deletions src/llm_models/perplexity.rs
@@ -0,0 +1,161 @@
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use log::info;
use reqwest::{header, Client};
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};

use crate::constants::PERPLEXITY_API_URL;
use crate::domain::{PerplexityAPICompletionsResponse, RateLimit};
use crate::llm_models::LLMModel;
use crate::utils::{map_to_range_f32, sanitize_json_response};

#[derive(Deserialize, Serialize, Debug, Clone, Eq, PartialEq)]
//Perplexity docs: https://docs.perplexity.ai/guides/model-cards
pub enum PerplexityModels {
Llama3_1SonarSmall,
Llama3_1SonarLarge,
Llama3_1SonarHuge,
}

#[async_trait(?Send)]
impl LLMModel for PerplexityModels {
fn as_str(&self) -> &str {
match self {
PerplexityModels::Llama3_1SonarSmall => "llama-3.1-sonar-small-128k-online",
PerplexityModels::Llama3_1SonarLarge => "llama-3.1-sonar-large-128k-online",
PerplexityModels::Llama3_1SonarHuge => "llama-3.1-sonar-huge-128k-online",
}
}

fn try_from_str(name: &str) -> Option<Self> {
match name.to_lowercase().as_str() {
"llama-3.1-sonar-small-128k-online" => Some(PerplexityModels::Llama3_1SonarSmall),
"llama-3.1-sonar-large-128k-online" => Some(PerplexityModels::Llama3_1SonarLarge),
"llama-3.1-sonar-huge-128k-online" => Some(PerplexityModels::Llama3_1SonarHuge),
_ => None,
}
}

// https://docs.perplexity.ai/guides/model-cards
fn default_max_tokens(&self) -> usize {
127_072
}

fn get_endpoint(&self) -> String {
PERPLEXITY_API_URL.to_string()
}

//This method prepares the body of the API call for different models
fn get_body(
&self,
instructions: &str,
json_schema: &Value,
function_call: bool,
// The total number of tokens requested in max_tokens plus the number of prompt tokens sent in messages must not exceed the context window token limit of the requested model.
// If left unspecified, the model will generate tokens until it either reaches its stop token or the end of its context window.
_max_tokens: &usize,
temperature: &f32,
) -> serde_json::Value {
//Prepare the 'messages' part of the body
let base_instructions = self.get_base_instructions(Some(function_call));
let system_message = json!({
"role": "system",
"content": base_instructions,
});
let schema_string = serde_json::to_string(json_schema).unwrap_or_default();
let user_message = json!({
"role": "user",
"content": format!(
"Output Json schema:\n
{schema_string}\n\n
{instructions}"
),
});
json!({
"model": self.as_str(),
"temperature": temperature,
"messages": vec![
system_message,
user_message,
],
})
}
///
/// This function leverages the Perplexity API to perform any query as per the provided body.
///
/// It returns the response as a String, which then needs to be parsed based on self.model.
///
async fn call_api(
&self,
api_key: &str,
body: &serde_json::Value,
debug: bool,
) -> Result<String> {
//Get the API url
let model_url = self.get_endpoint();

//Make the API call
let client = Client::new();

//Send request
let response = client
.post(model_url)
.header(header::CONTENT_TYPE, "application/json")
.bearer_auth(api_key)
.json(&body)
.send()
.await?;

let response_status = response.status();
let response_text = response.text().await?;

if debug {
info!(
"[debug] Perplexity API response: [{}] {:#?}",
&response_status, &response_text
);
}

Ok(response_text)
}

//This method attempts to convert the provided API response text into the expected struct and extracts the data from the response
fn get_data(&self, response_text: &str, _function_call: bool) -> Result<String> {
//Convert API response to struct representing expected response format
let completions_response: PerplexityAPICompletionsResponse =
serde_json::from_str(response_text)?;

//Parse the response and return the assistant content
completions_response
.choices
.iter()
.filter_map(|choice| choice.message.as_ref())
.find(|&message| message.role == Some("assistant".to_string()))
.and_then(|message| {
message
.content
.as_ref()
.map(|content| sanitize_json_response(content))
})
.ok_or_else(|| anyhow!("Assistant role content not found"))
}

//This function allows checking the rate limits for different models
fn get_rate_limit(&self) -> RateLimit {
//Perplexity documentation: https://docs.perplexity.ai/guides/rate-limits
RateLimit {
tpm: 50 * 127_072, // 50 requests per minute with a max 127,072-token context length
rpm: 50, // 50 requests per minute
}
}

// Accepts a [0-100] percentage range and returns the target temperature based on model ranges
fn get_normalized_temperature(&self, relative_temp: u32) -> f32 {
// Temperature range documentation: https://docs.perplexity.ai/api-reference/chat-completions
// "The amount of randomness in the response, valued between 0 *inclusive* and 2 *exclusive*."
let min = 0.0f32;
let max = 1.99999f32;
map_to_range_f32(min, max, relative_temp)
}
}
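
A quick sketch exercising the new model enum, assuming the `LLMModel` trait methods shown above are callable from user code:

```rust
use allms::llm::{LLMModel, PerplexityModels};

fn main() {
    // Round-trip the model identifier.
    let model = PerplexityModels::try_from_str("llama-3.1-sonar-large-128k-online")
        .unwrap_or(PerplexityModels::Llama3_1SonarSmall);
    assert_eq!(model.as_str(), "llama-3.1-sonar-large-128k-online");

    // A relative temperature of 50 maps to roughly 1.0 on Perplexity's [0.0, 2.0) scale.
    let temperature = model.get_normalized_temperature(50);
    assert!((temperature - 1.0).abs() < 0.01);
}
```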
61 changes: 60 additions & 1 deletion src/utils.rs
@@ -92,6 +92,17 @@ pub(crate) fn map_to_range(min: u32, max: u32, target: u32) -> f32 {
min as f32 + (range * percentage)
}

//Used internally to pick a number from a range based on its % representation
pub(crate) fn map_to_range_f32(min: f32, max: f32, target: u32) -> f32 {
// Cap the target to the percentage range [0, 100]
let capped_target = target.min(100);

// Calculate the target value in the range [min, max]
let range = max - min;
let percentage = capped_target as f32 / 100.0;
min + (range * percentage)
}

#[cfg(test)]
mod tests {
use schemars::schema::{InstanceType, ObjectValidation, RootSchema, Schema, SchemaObject};
@@ -100,7 +111,9 @@ mod tests {
use serde_json::Value;

use crate::llm_models::OpenAIModels;
use crate::utils::{fix_value_schema, get_tokenizer, get_type_schema, map_to_range};
use crate::utils::{
fix_value_schema, get_tokenizer, get_type_schema, map_to_range, map_to_range_f32,
};

#[derive(JsonSchema, Serialize, Deserialize)]
struct SimpleStruct {
@@ -420,4 +433,50 @@ mod tests {
// Not applicable for unsigned inputs but could test edge cases:
assert_eq!(map_to_range(0, 100, 0), 0.0);
}

#[test]
fn test_target_at_min_f32() {
assert_eq!(map_to_range_f32(0.0, 100.0, 0), 0.0);
assert_eq!(map_to_range_f32(10.0, 20.0, 0), 10.0);
}

#[test]
fn test_target_at_max_f32() {
assert_eq!(map_to_range_f32(0.0, 100.0, 100), 100.0);
assert_eq!(map_to_range_f32(10.0, 20.0, 100), 20.0);
}

#[test]
fn test_target_in_middle_f32() {
assert_eq!(map_to_range_f32(0.0, 100.0, 50), 50.0);
assert_eq!(map_to_range_f32(10.0, 20.0, 50), 15.0);
assert_eq!(map_to_range_f32(0.0, 1.0, 50), 0.5);
}

#[test]
fn test_target_out_of_bounds_f32() {
assert_eq!(map_to_range_f32(0.0, 100.0, 3000), 100.0); // Cap to 100
assert_eq!(map_to_range_f32(0.0, 100.0, 200), 100.0); // Cap to 100
assert_eq!(map_to_range_f32(10.0, 20.0, 200), 20.0); // Cap to 100
}

#[test]
fn test_zero_range_f32() {
assert_eq!(map_to_range_f32(10.0, 10.0, 50), 10.0); // Always return min if min == max
assert_eq!(map_to_range_f32(5.0, 5.0, 100), 5.0); // Even at max target
}

#[test]
fn test_fractional_range_f32() {
assert_eq!(map_to_range_f32(0.0, 0.5, 50), 0.25);
assert_eq!(map_to_range_f32(1.5, 3.0, 25), 1.875);
assert_eq!(map_to_range_f32(-1.0, 1.0, 75), 0.5);
}

#[test]
fn test_large_range_f32() {
assert_eq!(map_to_range_f32(-1000.0, 1000.0, 50), 0.0); // Midpoint of the range
assert_eq!(map_to_range_f32(-500.0, 500.0, 25), -250.0); // Quarter point
assert_eq!(map_to_range_f32(-2000.0, 0.0, 75), -500.0); // Three-quarters
}
}
