diff --git a/Cargo.lock b/Cargo.lock
index 9aeedcc..ad50fc0 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1313,7 +1313,6 @@ dependencies = [
  "signal-hook-registry",
  "socket2",
  "tokio-macros",
- "tracing",
  "windows-sys 0.52.0",
 ]
 
@@ -1376,7 +1375,6 @@ name = "tosic-llm"
 version = "0.1.0"
 dependencies = [
  "async-trait",
- "base64",
  "bytes",
  "derive_more",
  "futures-util",
@@ -1386,11 +1384,11 @@ dependencies = [
  "thiserror 2.0.11",
  "tokio",
  "tokio-stream",
+ "tosic-llm",
  "tosic-utils",
  "tracing",
  "url",
  "utoipa",
- "validator",
 ]
 
 [[package]]
@@ -1617,21 +1615,6 @@ dependencies = [
  "serde",
 ]
 
-[[package]]
-name = "validator"
-version = "0.19.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d0b4a29d8709210980a09379f27ee31549b73292c87ab9899beee1c0d3be6303"
-dependencies = [
- "idna",
- "once_cell",
- "regex",
- "serde",
- "serde_derive",
- "serde_json",
- "url",
-]
-
 [[package]]
 name = "valuable"
 version = "0.1.1"
diff --git a/Cargo.toml b/Cargo.toml
index e1185e1..0049428 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -3,22 +3,25 @@ name = "tosic-llm"
 version = "0.1.0"
 edition = "2024"
 
+[package.metadata.docs.rs]
+features = ["doc-utils"]
+all-features = true
+cargo-args = ["-Zunstable-options", "-Zrustdoc-scrape-examples"]
+
+[lib]
+test = true
+doctest = true
+
 [workspace.dependencies]
 thiserror = "2.0.7"
 tosic-utils = { version = "0.2.3", features = ["env", "dotenv", "tracing"], registry = "gitea" }
-# tosic-utils = { version = "0.2.3", features = ["env", "dotenv", "tracing"], path = "../tosic-utils" }
-tokio = { version = "1.42", features = ["full", "macros", "rt-multi-thread", "tracing"] }
 tracing = { version = "0.1.41", features = ["log"] }
-serde = { version = "1.0.216", features = ["derive", "alloc", "rc"] }
 serde_json = "1.0.133"
-futures = "0.3.31"
 utoipa = { version = "5.2.0", features = ["actix_extras", "debug", "rc_schema", "non_strict_integers", "chrono", "uuid", "url"] }
-validator ="0.19.0"
 
 [dependencies]
 derive_more = { version = "2.0.1", features = ["full"] }
 reqwest = { version = "0.12.12", default-features = false, features = ["json", "stream", "rustls-tls", "charset", "http2"] }
-tokio = { workspace = true, features = ["full"] }
 serde = { version = "1.0.217", features = ["derive"] }
 futures-util = "0.3.31"
 tokio-stream = "0.1.17"
@@ -29,6 +32,12 @@ serde_json.workspace = true
 thiserror.workspace = true
 url = { version = "2.5.4", features = ["serde"] }
 utoipa.workspace = true
-validator.workspace = true
 async-trait = "0.1.86"
-base64 = "0.22.1"
+
+[dev-dependencies]
+tosic-llm = { path = ".", features = ["doc-utils"] }
+tokio = { version = "1.43.0", features = ["full"] }
+
+[features]
+default = []
+doc-utils = []
diff --git a/examples/gemini.rs b/examples/gemini.rs
new file mode 100644
index 0000000..3d3280d
--- /dev/null
+++ b/examples/gemini.rs
@@ -0,0 +1,110 @@
+use std::io::Write;
+use thiserror::Error;
+use tokio_stream::StreamExt;
+use tosic_llm::gemini::{GeminiClient, GeminiContent, GeminiModel};
+use tosic_llm::{ensure, LlmProvider};
+use tosic_llm::error::{LlmError, WithContext};
+use tosic_llm::types::Role;
+
+#[derive(Debug, Error)]
+enum Error {
+    #[error(transparent)]
+    Llm(#[from] LlmError),
+    #[error("{0}")]
+    Generic(String),
+}
+
+use serde_json::{Result as JsonResult, Value};
+
+#[derive(Debug)]
+pub struct GeminiResponseParser {
+    has_started: bool,
+    has_finished: bool,
+}
+
+impl GeminiResponseParser {
+    pub fn new() -> Self {
+        Self {
+            has_started: false,
+            has_finished: false,
+        }
+    }
+
+    pub fn parse_chunk(&mut self, chunk: &[u8]) -> JsonResult<Option<String>> {
+        // Convert bytes to string
+        let chunk_str = String::from_utf8_lossy(chunk);
+
+        // Handle the start and end markers
+        if chunk_str == "[" {
+            self.has_started = true;
+            return Ok(None);
+        } else if chunk_str == "]" || chunk_str.is_empty() {
+            self.has_finished = true;
+            return Ok(None);
+        }
+
+        // Remove leading comma if present (subsequent chunks start with ,\r\n)
+        let cleaned_chunk = if chunk_str.starts_with(",") {
+            chunk_str.trim_start_matches(",").trim_start()
+        } else {
+            &chunk_str
+        };
+
+        // Parse the JSON object
+        let v: Value = serde_json::from_str(cleaned_chunk)?;
+
+        // Extract the text from the nested structure
+        let text = v
+            .get("candidates")
+            .and_then(|c| c.get(0))
+            .and_then(|c| c.get("content"))
+            .and_then(|c| c.get("parts"))
+            .and_then(|p| p.get(0))
+            .and_then(|p| p.get("text"))
+            .and_then(|t| t.as_str())
+            .map(String::from);
+
+        Ok(text)
+    }
+}
+
+async fn ask_ai() -> Result<(), Error> {
+    let client = GeminiClient::new(GeminiModel::Gemini2Flash).context("Failed to create the Gemini Client")?;
+    let provider = LlmProvider::new(client);
+
+    let req = GeminiContent::new(Some(Role::User), "Hi my name is Emil and i like to write complex rust libraries".to_string());
+
+    let res = provider.generate(vec![req], true).await.context("Failed to get response from LLM")?;
+
+    ensure!(res.is_stream(), Error::Generic("Response is not stream".into()));
+
+    let mut stream = res.unwrap_stream();
+
+    // Stream to STDOUT
+    let stdout = std::io::stdout();
+    let mut stdout = stdout.lock();
+
+    let mut parser = GeminiResponseParser::new();
+    let mut written_len = 0;
+
+    while let Some(chunk) = stream.next().await {
+        let chunk = chunk.context("Failed to read response from LLM")?;
+
+        if let Ok(Some(text)) = parser.parse_chunk(&chunk) {
+            let bytes = text.as_bytes();
+            written_len += bytes.len();
+            stdout.write_all(bytes).context("Failed to write to stdout")?;
+            stdout.flush().context("Failed to flush stdout")?;
+        }
+    }
+
+
+    println!("\nWrote {} bytes", written_len);
+
+    Ok(())
+}
+
+#[tokio::main]
+async fn main() -> Result<(), Error> {
+    ask_ai().await
+}
\ No newline at end of file
diff --git a/src/error.rs b/src/error.rs
index 1b51b37..7621304 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -1,16 +1,242 @@
 // tosic_llm/src/error.rs
 
+use std::fmt::{Debug, Display, Formatter};
+use derive_more::IsVariant;
 use thiserror::Error;
 use url::ParseError;
 
-#[derive(Debug, Error)]
-pub enum LlmError {
+#[derive(Debug, Error, IsVariant)]
+pub enum LlmErrorType {
     #[error(transparent)]
     Reqwest(#[from] reqwest::Error),
     #[error(transparent)]
     Parse(#[from] ParseError),
     #[error(transparent)]
     Json(#[from] serde_json::Error),
+    #[error(transparent)]
+    Io(#[from] std::io::Error),
     #[error("An error occurred: {0}")]
     Generic(#[from] Box<dyn std::error::Error + Send + Sync>),
+    #[error("{message}")]
+    Context {
+        message: String,
+        #[source]
+        source: Box<LlmErrorType>,
+    },
+    #[error("Value was empty. {value}, expected one of: {expected:?}")]
+    Empty {
+        value: String,
+        expected: &'static [&'static str],
+    },
+}
+
+#[derive(Debug)]
+pub struct ErrorContext {
+    pub message: String,
+    pub file: &'static str,
+    pub line: u32,
+}
+
+impl Display for ErrorContext {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        write!(f, "{} (at {}:{})", self.message, self.file, self.line)
+    }
+}
+
+pub struct LlmError {
+    pub error_type: LlmErrorType,
+}
+
+impl Debug for LlmError {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        match &self.error_type {
+            LlmErrorType::Context { .. } => write!(f, "{}", self), // Avoid using debug print for context to avoid hard to read message due to recursive type.
+            _ => write!(f, "{:?}", self.error_type),
+        }
+    }
 }
+
+impl LlmError {
+    pub fn new(error_type: LlmErrorType) -> Self {
+        Self { error_type }
+    }
+
+    pub fn add_context(self, message: impl Display) -> Self {
+        let source = Box::new(self.error_type);
+
+        Self::new(LlmErrorType::Context { message: message.to_string(), source })
+    }
+
+    pub fn new_context(message: impl Display, source: Box<LlmErrorType>) -> Self {
+        let source = source;
+
+        Self::new(LlmErrorType::Context {
+            message: message.to_string(),
+            source
+        })
+    }
+}
+
+impl<E: Into<LlmErrorType>> From<E> for LlmError {
+    fn from(value: E) -> Self {
+        Self {
+            error_type: value.into(),
+        }
+    }
+}
+
+impl LlmErrorType {
+    pub(crate) fn fmt_context(&self, writer: &mut Formatter<'_>) -> std::fmt::Result {
+        match self {
+            LlmErrorType::Context { message, source } => {
+                writeln!(writer, "\t- {message}")?;
+                source.fmt_context(writer)
+            }
+            _ => Ok(())
+        }
+    }
+}
+
+impl Display for LlmError {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        writeln!(f, "Error: {}", self.error_type)?;
+        match &self.error_type {
+            LlmErrorType::Context { message, source } => {
+                writeln!(f, "\nContext:")?;
+                writeln!(f, "\t- {}", message)?;
+                source.fmt_context(f)
+            }
+            _ => Ok(())
+        }
+    }
+}
+
+impl std::error::Error for LlmError {
+    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
+        match &self.error_type {
+            LlmErrorType::Context { source, .. } => Some(source),
+            _ => {
+                Some(&self.error_type)
+            }
+        }
+    }
+}
+
+pub trait WithContext<T> {
+    fn context<C: Display>(self, context: C) -> Result<T, LlmError>;
+    fn with_context<F, C>(self, context_fn: F) -> Result<T, LlmError>
+    where
+        F: FnOnce() -> C,
+        C: Display;
+}
+
+impl<T> WithContext<T> for Result<T, LlmError> {
+    fn context<C: Display>(self, context: C) -> Result<T, LlmError> {
+        self.map_err(|error| {
+            error.add_context(context)
+        })
+    }
+
+    fn with_context<F, C>(self, context_fn: F) -> Result<T, LlmError>
+    where
+        F: FnOnce() -> C,
+        C: Display,
+    {
+        self.map_err(|error| {
+            error.add_context(context_fn())
+        })
+    }
+}
+
+impl<T, E> WithContext<T> for Result<T, E>
+where
+    E: Into<LlmErrorType>,
+{
+    fn context<C: Display>(self, context: C) -> Result<T, LlmError> {
+        self.map_err(|error| {
+            let llm_error = error.into();
+
+            LlmError::new_context(context, Box::new(llm_error))
+        })
+    }
+
+    fn with_context<F, C>(self, context_fn: F) -> Result<T, LlmError>
+    where
+        F: FnOnce() -> C,
+        C: Display,
+    {
+        self.map_err(|error| {
+            let llm_error = error.into();
+
+            LlmError::new_context(context_fn(), Box::new(llm_error))
+        })
+    }
+}
+
+impl<T: Debug> WithContext<T> for Option<T> {
+    fn context<C: Display>(self, context: C) -> Result<T, LlmError> {
+        match self {
+            Some(val) => Ok(val),
+            None => {
+                let llm_error = LlmErrorType::Empty {
+                    value: format!("{self:?}"),
+                    expected: &[],
+                };
+
+                Err(LlmError::new_context(context, Box::new(llm_error)))
+            }
+        }
+    }
+
+    fn with_context<F, C>(self, context_fn: F) -> Result<T, LlmError>
+    where
+        F: FnOnce() -> C,
+        C: Display
+    {
+        match self {
+            Some(val) => Ok(val),
+            None => {
+                let llm_error = LlmErrorType::Empty {
+                    value: format!("{self:?}"),
+                    expected: &[],
+                };
+
+                Err(LlmError::new_context(context_fn(), Box::new(llm_error)))
+            }
+        }
+    }
+}
+
+#[macro_export]
+macro_rules! error_context {
+    ($msg:expr) => {
+        $crate::error::ErrorContext {
+            message: $msg.to_string(),
+            file: file!(),
+            line: line!(),
+        }
+    };
+    ($fmt:expr, $($arg:tt)*) => {
+        $crate::error::ErrorContext {
+            message: format!($fmt, $($arg)*),
+            file: file!(),
+            line: line!(),
+        }
+    };
+}
+
+#[macro_export]
+macro_rules! bail {
+    ($err:expr) => {
+        return Err($err.into());
+    }
+}
+
+#[macro_export]
+macro_rules! ensure {
+    ($cond:expr, $expr:expr) => {
+        if !$cond {
+            return Err($expr.into());
+        }
+    }
+}
\ No newline at end of file
diff --git a/src/gemini/mod.rs b/src/gemini/mod.rs
index 244cdb3..cd06de5 100644
--- a/src/gemini/mod.rs
+++ b/src/gemini/mod.rs
@@ -3,7 +3,7 @@
 mod impls;
 mod types;
 
-use crate::error::LlmError;
+use crate::error::{WithContext, LlmError};
 use crate::traits::LlmClient;
 use crate::utils::SingleOrMultiple;
 use bytes::Bytes;
@@ -16,6 +16,7 @@ use std::sync::LazyLock;
 use tosic_utils::env::env_util;
 pub use types::*;
 use url::Url;
+use crate::error_context;
 
 pub const GEMINI_BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta";
 pub const GEMINI_STREAM_ENDPOINT: &str = ":streamGenerateContent";
@@ -67,8 +68,7 @@ impl GeminiClient {
             "{GEMINI_BASE_URL}/{}{}{query}",
             self.model,
             endpoint.as_ref()
-        ))
-        .map_err(Into::into)
+        )).context(error_context!("Failed to parse endpoint"))
     }
 
     #[tracing::instrument(skip(request, endpoint))]
@@ -83,8 +83,7 @@
             .post(url)
             .json(&request)
             .send()
-            .await
-            .map_err(Into::into)
+            .await.context(error_context!("Failed to send request"))
     }
 
     async fn stream_generate_content_inner<T: Into<GeminiRequest>>(
diff --git a/src/lib.rs b/src/lib.rs
index d26a625..0f35246 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -9,4 +9,4 @@ pub mod traits;
 pub mod types;
 mod utils;
 
-type Result<T, E = LlmError> = core::result::Result<T, E>;
+pub type Result<T, E = LlmError> = core::result::Result<T, E>;
diff --git a/src/provider.rs b/src/provider.rs
index 8db5036..689f67d 100644
--- a/src/provider.rs
+++ b/src/provider.rs
@@ -49,6 +49,7 @@ pub struct LlmProvider<T: LlmClient> {
 }
 impl<T: LlmClient> LlmProvider<T> {
+    /// Creates a new provider given the inner [LlmClient]
     #[inline(always)]
     pub fn new(inner: T) -> Self {
         Self { inner }
     }
@@ -74,49 +75,8 @@
     /// # Examples
     ///
     /// ```
-    /// # use std::error::Error;
-    /// # use derive_more::{Display, Error};
-    /// # use futures_util::Stream;
-    /// # use tosic_llm::{LlmProvider, traits::LlmClient};
-    /// # use serde::{Serialize, Deserialize};
-    /// # use tosic_llm::types::LlmMessages;
-    /// #
-    /// # // Example minimal LlmClient implementation
-    /// # struct SimpleClient;
-    /// # #[derive(Debug, Serialize)]
-    /// # struct SimpleInput(String);
-    ///
-    /// # impl From<LlmMessages> for SimpleInput {
-    /// #
-    /// #     fn from(value: LlmMessages) -> Self {
-    /// #         todo!()
-    /// #     }
-    /// # }
-    /// #
-    /// # #[derive(Debug, Deserialize)]
-    /// # struct SimpleOutput(String);
-    /// # #[derive(Debug, Error, Display)]
-    /// # struct SimpleError;
-    /// # struct SimpleConfig;
-    /// #
-    /// # #[async_trait::async_trait]
-    /// # impl LlmClient for SimpleClient {
-    /// #     type Error = SimpleError;
-    /// #     type Input = SimpleInput;
-    /// #     type Output = SimpleOutput;
-    /// #     type StreamedOutput = String;
-    /// #     type Config = SimpleConfig;
-    /// #
-    /// #     async fn chat_completion(&self, input: Self::Input) -> Result<Self::Output, Self::Error> {
-    /// #         Ok(SimpleOutput("response".to_string()))
-    /// #     }
-    /// #
-    /// #     async fn stream_chat_completion(&self, input: Self::Input)
-    /// #         -> Result<impl Stream<Item = Result<Self::StreamedOutput, Self::Error>>, Self::Error> {
-    /// #         Ok(futures_util::stream::empty())
-    /// #     }
-    /// # }
-    /// #
+    /// # use tosic_llm::LlmProvider;
+    /// # tosic_llm::mocked_llm_client!();
     /// # async fn example() -> Result<(), SimpleError> {
     /// let client = SimpleClient; // Any type implementing LlmClient
     /// let provider = LlmProvider::new(client);
diff --git a/src/utils/doc.rs b/src/utils/doc.rs
new file mode 100644
index 0000000..e730a4b
--- /dev/null
+++ b/src/utils/doc.rs
@@ -0,0 +1,42 @@
+#![cfg(any(doc, doctest, test, feature = "doc-utils"))]
+#![doc(hidden)]
+
+#[doc(hidden)]
+#[macro_export]
+macro_rules! mocked_llm_client {
+    () => {
+        struct SimpleClient;
+        #[derive(Debug, ::serde::Serialize)]
+        struct SimpleInput(String);
+
+        impl From<$crate::types::LlmMessages> for SimpleInput {
+            fn from(value: $crate::types::LlmMessages) -> Self {
+                todo!()
+            }
+        }
+
+        #[derive(Debug, ::serde::Deserialize)]
+        struct SimpleOutput(String);
+        #[derive(Debug, ::derive_more::Error, ::derive_more::Display)]
+        struct SimpleError;
+        struct SimpleConfig;
+
+        #[::async_trait::async_trait]
+        impl $crate::traits::LlmClient for SimpleClient {
+            type Error = SimpleError;
+            type Input = SimpleInput;
+            type Output = SimpleOutput;
+            type StreamedOutput = String;
+            type Config = SimpleConfig;
+
+            async fn chat_completion(&self, input: Self::Input) -> Result<Self::Output, Self::Error> {
+                Ok(SimpleOutput("response".to_string()))
+            }
+
+            async fn stream_chat_completion(&self, input: Self::Input)
+                -> Result<impl ::futures_util::Stream<Item = Result<Self::StreamedOutput, Self::Error>>, Self::Error> {
+                Ok(::futures_util::stream::empty())
+            }
+        }
+    };
+}
\ No newline at end of file
diff --git a/src/utils.rs b/src/utils/mod.rs
similarity index 98%
rename from src/utils.rs
rename to src/utils/mod.rs
index 8d02e30..f27745b 100644
--- a/src/utils.rs
+++ b/src/utils/mod.rs
@@ -1,5 +1,7 @@
 // tosic_llm/src/utils.rs
 
+pub mod doc;
+
 use std::vec::IntoIter;
 
 pub enum SingleOrMultiple<T> {