From 31b1a7ab5e47cb6dc9432b5b83fc818845ce7f16 Mon Sep 17 00:00:00 2001
From: noahbald
Date: Sat, 28 Oct 2023 23:28:04 +1100
Subject: [PATCH 01/12] feat: allow custom input key and arbitrary data from
 client

---
 crates/llm-ls/src/main.rs | 49 +++++++++++++++++++++++++--------------
 1 file changed, 32 insertions(+), 17 deletions(-)

diff --git a/crates/llm-ls/src/main.rs b/crates/llm-ls/src/main.rs
index 1b8ed40..fbc9d1b 100644
--- a/crates/llm-ls/src/main.rs
+++ b/crates/llm-ls/src/main.rs
@@ -84,7 +84,7 @@ fn should_complete(document: &Document, position: Position) -> CompletionType {
     }
 }
 
-#[derive(Debug, Deserialize, Serialize)]
+#[derive(Debug, Deserialize, Serialize, Clone)]
 #[serde(untagged)]
 enum TokenizerConfig {
     Local { path: PathBuf },
@@ -235,6 +235,8 @@ struct CompletionParams {
     tokenizer_config: Option<TokenizerConfig>,
     context_window: usize,
     tls_skip_verify_insecure: bool,
+    request_body: Option<HashMap<String, serde_json::Value>>,
+    inputs_key: Option<String>,
 }
 
 #[derive(Debug, Deserialize, Serialize)]
@@ -358,20 +360,33 @@ fn build_prompt(
 
 async fn request_completion(
     http_client: &reqwest::Client,
-    ide: Ide,
-    model: &str,
-    request_params: RequestParams,
-    api_token: Option<&String>,
     prompt: String,
+    params: CompletionParams,
 ) -> Result<Vec<Generation>> {
     let t = Instant::now();
+    let CompletionParams {
+        model,
+        request_body,
+        inputs_key,
+        request_params,
+        api_token,
+        ide,
+        ..
+    } = params;
+    let mut body: HashMap<String, serde_json::Value> = HashMap::new();
+    body.extend(request_body.unwrap_or_default());
+    body.insert(
+        inputs_key.unwrap_or("input".to_string()),
+        serde_json::Value::String(prompt),
+    );
+    body.insert(
+        "parameters".to_string(),
+        serde_json::to_value(request_params).expect("Failed to serialize request_params"),
+    );
     let res = http_client
-        .post(build_url(model))
-        .json(&APIRequest {
-            inputs: prompt,
-            parameters: request_params.into(),
-        })
-        .headers(build_headers(api_token, ide)?)
+        .post(build_url(&model))
+        .json(&body)
+        .headers(build_headers(api_token.as_ref(), ide)?)
         .send()
         .await
         .map_err(internal_error)?;
@@ -584,10 +599,11 @@ impl Backend {
             return Ok(CompletionResult { request_id, completions: vec![]});
         }
 
+        let tokenizer_config = params.tokenizer_config.clone();
         let tokenizer = get_tokenizer(
             &params.model,
             &mut *self.tokenizer_map.write().await,
-            params.tokenizer_config,
+            tokenizer_config,
             &self.http_client,
             &self.cache_dir,
             params.api_token.as_ref(),
@@ -608,17 +624,15 @@ impl Backend {
         } else {
             &self.http_client
        };
+        let tokens_to_clear = params.tokens_to_clear.clone();
         let result = request_completion(
             http_client,
-            params.ide,
-            &params.model,
-            params.request_params,
-            params.api_token.as_ref(),
             prompt,
+            params,
         )
         .await?;
 
-        let completions = parse_generations(result, &params.tokens_to_clear, completion_type);
+        let completions = parse_generations(result, &tokens_to_clear, completion_type);
         Ok(CompletionResult { request_id, completions })
     }.instrument(span).await
 }
@@ -803,3 +817,4 @@ async fn main() {
 
     Server::new(stdin, stdout, socket).serve(service).await;
 }
+
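
Note: a minimal standalone sketch (not part of the patch) of the body `request_completion` now assembles. The field names and values below are illustrative assumptions — the real contents come from whatever the editor plugin sends as `requestBody` and `inputsKey`:

    // Depends only on serde_json (already a dependency of this crate).
    use std::collections::HashMap;

    fn main() {
        let mut body: HashMap<String, serde_json::Value> = HashMap::new();
        // 1. Arbitrary client-supplied fields arrive first, via request_body.
        body.insert("model".to_string(), serde_json::json!("my-model"));
        // 2. The prompt goes under the client's inputs_key ("input" by default).
        body.insert("input".to_string(), serde_json::json!("def hello():"));
        // 3. The request parameters are serialized under a fixed "parameters" key.
        body.insert(
            "parameters".to_string(),
            serde_json::json!({ "max_new_tokens": 60 }),
        );
        println!("{}", serde_json::to_string_pretty(&body).unwrap());
    }
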
From eeadc3dac3d1943ab4855086252f9057b170f188 Mon Sep 17 00:00:00 2001
From: noahbald
Date: Tue, 21 Nov 2023 21:10:43 +1100
Subject: [PATCH 02/12] feat: Implement adaptors for huggingface and ollama

---
 crates/llm-ls/src/adaptors.rs | 142 ++++++++++++++++++++++++++++++++++
 crates/llm-ls/src/main.rs     |  60 ++++++--------
 2 files changed, 165 insertions(+), 37 deletions(-)
 create mode 100644 crates/llm-ls/src/adaptors.rs

diff --git a/crates/llm-ls/src/adaptors.rs b/crates/llm-ls/src/adaptors.rs
new file mode 100644
index 0000000..240379e
--- /dev/null
+++ b/crates/llm-ls/src/adaptors.rs
@@ -0,0 +1,142 @@
+use super::{
+    internal_error, APIError, APIResponse, CompletionParams, Generation, Ide, NAME, VERSION,
+};
+use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, USER_AGENT};
+use serde::{Deserialize, Serialize};
+use serde_json::Value;
+use tower_lsp::jsonrpc;
+
+struct AdaptHuggingFaceRequest;
+impl AdaptHuggingFaceRequest {
+    fn adapt_body(&self, prompt: String, params: CompletionParams) -> Value {
+        return serde_json::json!({
+            "inputs": prompt,
+            "parameters": params.request_params,
+        });
+    }
+    fn adapt_headers(
+        &self,
+        api_token: Option<&String>,
+        ide: Ide,
+    ) -> Result<HeaderMap, jsonrpc::Error> {
+        let mut headers = HeaderMap::new();
+        let user_agent = format!("{NAME}/{VERSION}; rust/unknown; ide/{ide:?}");
+        headers.insert(
+            USER_AGENT,
+            HeaderValue::from_str(&user_agent).map_err(internal_error)?,
+        );
+
+        if let Some(api_token) = api_token {
+            headers.insert(
+                AUTHORIZATION,
+                HeaderValue::from_str(&format!("Bearer {api_token}")).map_err(internal_error)?,
+            );
+        }
+
+        Ok(headers)
+    }
+}
+
+struct AdaptHuggingFaceResponse;
+impl AdaptHuggingFaceResponse {
+    fn adapt_blob(&self, text: reqwest::Result<String>) -> Result<Vec<Generation>, jsonrpc::Error> {
+        let generations =
+            match serde_json::from_str(&text.unwrap_or_default()).map_err(internal_error)? {
+                APIResponse::Generation(gen) => vec![gen],
+                APIResponse::Generations(gens) => gens,
+                APIResponse::Error(err) => return Err(internal_error(err)),
+            };
+        Ok(generations)
+    }
+}
+
+struct AdaptOllamaRequest;
+impl AdaptOllamaRequest {
+    fn adapt_body(&self, prompt: String, params: CompletionParams) -> Value {
+        let request_body = params.request_body.unwrap_or_default();
+        let body = serde_json::json!({
+            "prompt": prompt,
+            "model": request_body.get("model"),
+        });
+        body
+    }
+    fn adapt_headers(&self) -> Result<HeaderMap, jsonrpc::Error> {
+        let headers = HeaderMap::new();
+        Ok(headers)
+    }
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct OllamaGeneration {
+    response: String,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(untagged)]
+pub enum OllamaAPIResponse {
+    Generation(OllamaGeneration),
+    Error(APIError),
+}
+
+struct AdaptOllamaResponse;
+impl AdaptOllamaResponse {
+    fn adapt_blob(
+        &self,
+        text: Result<String, reqwest::Error>,
+    ) -> Result<Vec<Generation>, jsonrpc::Error> {
+        match text {
+            Ok(text) => {
+                let mut gen: Vec<Generation> = Vec::new();
+                for row in text.split("\n") {
+                    if row.is_empty() {
+                        continue;
+                    }
+                    let chunk = match serde_json::from_str(row) {
+                        Ok(OllamaAPIResponse::Generation(ollama_gen)) => ollama_gen.response,
+                        Ok(OllamaAPIResponse::Error(err)) => return Err(internal_error(err)),
+                        Err(err) => return Err(internal_error(err)),
+                    };
+                    gen.push(Generation {
+                        generated_text: chunk,
+                    })
+                }
+                Ok(gen)
+            }
+            Err(err) => Err(internal_error(err)),
+        }
+    }
+}
+
+const HUGGING_FACE_ADAPTOR: &str = "huggingface";
+
+pub struct Adaptors;
+impl Adaptors {
+    pub fn adapt_body(&self, prompt: String, params: CompletionParams) -> Value {
+        let adaptor = params.adaptor.clone();
+        match adaptor.unwrap_or(HUGGING_FACE_ADAPTOR.to_string()).as_str() {
+            "ollama" => AdaptOllamaRequest.adapt_body(prompt, params),
+            _ => AdaptHuggingFaceRequest.adapt_body(prompt, params),
+        }
+    }
+    pub fn adapt_headers(
+        &self,
+        adaptor: Option<String>,
+        api_token: Option<&String>,
+        ide: Ide,
+    ) -> Result<HeaderMap, jsonrpc::Error> {
+        match adaptor.unwrap_or(HUGGING_FACE_ADAPTOR.to_string()).as_str() {
+            "ollama" => AdaptOllamaRequest.adapt_headers(),
+            _ => AdaptHuggingFaceRequest.adapt_headers(api_token, ide),
+        }
+    }
+    pub fn adapt_blob(
+        &self,
+        adaptor: Option<String>,
+        text: Result<String, reqwest::Error>,
+    ) -> Result<Vec<Generation>, jsonrpc::Error> {
+        match adaptor.unwrap_or(HUGGING_FACE_ADAPTOR.to_string()).as_str() {
+            "ollama" => AdaptOllamaResponse.adapt_blob(text),
+            _ => AdaptHuggingFaceResponse.adapt_blob(text),
+        }
+    }
+}
diff --git a/crates/llm-ls/src/main.rs b/crates/llm-ls/src/main.rs
index a23d6c6..a37e6d2 100644
--- a/crates/llm-ls/src/main.rs
+++ b/crates/llm-ls/src/main.rs
@@ -1,3 +1,4 @@
+use adaptors::Adaptors;
 use document::Document;
 use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, USER_AGENT};
 use ropey::Rope;
@@ -18,12 +19,13 @@ use tracing_appender::rolling;
 use tracing_subscriber::EnvFilter;
 use uuid::Uuid;
 
+mod adaptors;
 mod document;
 mod language_id;
 
 const MAX_WARNING_REPEAT: Duration = Duration::from_secs(3_600);
-const NAME: &str = "llm-ls";
-const VERSION: &str = env!("CARGO_PKG_VERSION");
+pub const NAME: &str = "llm-ls";
+pub const VERSION: &str = env!("CARGO_PKG_VERSION");
 
 fn get_position_idx(rope: &Rope, row: usize, col: usize) -> Result<usize> {
     Ok(rope.try_line_to_char(row).map_err(internal_error)?
@@ -178,12 +180,12 @@ struct APIRequest {
 }
 
 #[derive(Debug, Serialize, Deserialize)]
-struct Generation {
+pub struct Generation {
     generated_text: String,
 }
 
 #[derive(Debug, Deserialize)]
-struct APIError {
+pub struct APIError {
     error: String,
 }
 
@@ -195,7 +197,7 @@ impl Display for APIError {
 
 #[derive(Debug, Deserialize)]
 #[serde(untagged)]
-enum APIResponse {
+pub enum APIResponse {
     Generation(Generation),
     Generations(Vec<Generation>),
     Error(APIError),
@@ -219,7 +221,7 @@ struct Completion {
 
 #[derive(Clone, Copy, Debug, Default, Deserialize, Serialize)]
 #[serde(rename_all = "lowercase")]
-enum Ide {
+pub enum Ide {
     Neovim,
     VSCode,
     JetBrains,
@@ -261,7 +263,7 @@ struct RejectedCompletion {
 
 #[derive(Debug, Deserialize, Serialize)]
 #[serde(rename_all = "camelCase")]
-struct CompletionParams {
+pub struct CompletionParams {
     #[serde(flatten)]
     text_document_position: TextDocumentPositionParams,
     request_params: RequestParams,
@@ -271,12 +273,12 @@ struct CompletionParams {
     fim: FimParams,
     api_token: Option<String>,
     model: String,
+    adaptor: Option<String>,
     tokens_to_clear: Vec<String>,
     tokenizer_config: Option<TokenizerConfig>,
     context_window: usize,
     tls_skip_verify_insecure: bool,
-    request_body: Option<HashMap<String, serde_json::Value>>,
-    inputs_key: Option<String>,
+    request_body: Option<HashMap<String, serde_json::Value>>,
 }
@@ -285,7 +287,7 @@ struct CompletionResult {
     completions: Vec<Completion>,
 }
 
-fn internal_error<E: Display>(err: E) -> Error {
+pub fn internal_error<E: Display>(err: E) -> Error {
     let err_msg = err.to_string();
     error!(err_msg);
     Error {
@@ -404,38 +406,22 @@ async fn request_completion(
     params: CompletionParams,
 ) -> Result<Vec<Generation>> {
     let t = Instant::now();
-    let CompletionParams {
-        model,
-        request_body,
-        inputs_key,
-        request_params,
-        api_token,
-        ide,
-        ..
-    } = params;
-    let mut body: HashMap<String, serde_json::Value> = HashMap::new();
-    body.extend(request_body.unwrap_or_default());
-    body.insert(
-        inputs_key.unwrap_or("input".to_string()),
-        serde_json::Value::String(prompt),
-    );
-    body.insert(
-        "parameters".to_string(),
-        serde_json::to_value(request_params).expect("Failed to serialize request_params"),
-    );
+    let model = params.model.clone();
+    let adaptor = params.adaptor.clone();
+    let api_token = params.api_token.clone();
+    let ide = params.ide.clone();
+
+    let json = Adaptors.adapt_body(prompt, params);
+    let headers = Adaptors.adapt_headers(adaptor.clone(), api_token.as_ref(), ide)?;
     let res = http_client
-        .post(build_url(&model))
-        .json(&body)
-        .headers(build_headers(api_token.as_ref(), ide)?)
+        .post(build_url(&model))
+        .json(&json)
+        .headers(headers)
         .send()
         .await
         .map_err(internal_error)?;
 
-    let generations = match res.json().await.map_err(internal_error)? {
-        APIResponse::Generation(gen) => vec![gen],
-        APIResponse::Generations(gens) => gens,
-        APIResponse::Error(err) => return Err(internal_error(err)),
-    };
+    let generations = Adaptors.adapt_blob(adaptor, res.text().await);
     let time = t.elapsed().as_millis();
     info!(
         model,
@@ -443,7 +429,7 @@ async fn request_completion(
         generations = serde_json::to_string(&generations).map_err(internal_error)?,
         "{model} computed generations in {time} ms"
     );
-    Ok(generations)
+    generations
 }
 
 fn parse_generations(
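
Note: a minimal sketch (not part of the patch) of the JSON the ollama adaptor builds at this stage — just the prompt plus the `model` field read from the client's `requestBody`. The model name is an illustrative assumption:

    fn main() {
        // Options such as temperature are only wired up in a later patch.
        let body = serde_json::json!({
            "prompt": "fn add(a: i32, b: i32) -> i32 {",
            "model": "codellama:7b", // assumed value from the client's requestBody
        });
        assert_eq!(body["model"], "codellama:7b");
        println!("{body}");
    }
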
From ef5bac3d129fb94b40e6d898a18758782a84344c Mon Sep 17 00:00:00 2001
From: noahbald
Date: Fri, 24 Nov 2023 22:08:31 +1100
Subject: [PATCH 03/12] Clean up adaptors

---
 crates/llm-ls/src/adaptors.rs | 227 +++++++++++++++++++---------------
 crates/llm-ls/src/main.rs     |  10 +-
 2 files changed, 129 insertions(+), 108 deletions(-)

diff --git a/crates/llm-ls/src/adaptors.rs b/crates/llm-ls/src/adaptors.rs
index 240379e..b6c1d29 100644
--- a/crates/llm-ls/src/adaptors.rs
+++ b/crates/llm-ls/src/adaptors.rs
@@ -6,137 +6,158 @@ use serde::{Deserialize, Serialize};
 use serde_json::Value;
 use tower_lsp::jsonrpc;
 
-struct AdaptHuggingFaceRequest;
-impl AdaptHuggingFaceRequest {
-    fn adapt_body(&self, prompt: String, params: CompletionParams) -> Value {
-        return serde_json::json!({
-            "inputs": prompt,
-            "parameters": params.request_params,
-        });
-    }
-    fn adapt_headers(
-        &self,
-        api_token: Option<&String>,
-        ide: Ide,
-    ) -> Result<HeaderMap, jsonrpc::Error> {
-        let mut headers = HeaderMap::new();
-        let user_agent = format!("{NAME}/{VERSION}; rust/unknown; ide/{ide:?}");
+fn build_tgi_body(prompt: String, params: CompletionParams) -> Value {
+    serde_json::json!({
+        "inputs": prompt,
+        "parameters": params.request_params,
+    })
+}
+
+fn build_tgi_headers(api_token: Option<&String>, ide: Ide) -> Result<HeaderMap, jsonrpc::Error> {
+    let mut headers = HeaderMap::new();
+    let user_agent = format!("{NAME}/{VERSION}; rust/unknown; ide/{ide:?}");
+    headers.insert(
+        USER_AGENT,
+        HeaderValue::from_str(&user_agent).map_err(internal_error)?,
+    );
+
+    if let Some(api_token) = api_token {
         headers.insert(
-            USER_AGENT,
-            HeaderValue::from_str(&user_agent).map_err(internal_error)?,
+            AUTHORIZATION,
+            HeaderValue::from_str(&format!("Bearer {api_token}")).map_err(internal_error)?,
         );
+    }
 
-        if let Some(api_token) = api_token {
-            headers.insert(
-                AUTHORIZATION,
-                HeaderValue::from_str(&format!("Bearer {api_token}")).map_err(internal_error)?,
-            );
-        }
+    Ok(headers)
+}
 
-        Ok(headers)
-    }
+fn parse_tgi_text(text: reqwest::Result<String>) -> Result<Vec<Generation>, jsonrpc::Error> {
+    let generations =
+        match serde_json::from_str(&text.unwrap_or_default()).map_err(internal_error)? {
+            APIResponse::Generation(gen) => vec![gen],
+            APIResponse::Generations(_) => {
+                return Err(internal_error(
+                    "TGI parser unexpectedly encountered api-inference",
+                ))
+            }
+            APIResponse::Error(err) => return Err(internal_error(err)),
+        };
+    Ok(generations)
 }
 
-struct AdaptHuggingFaceResponse;
-impl AdaptHuggingFaceResponse {
-    fn adapt_blob(&self, text: reqwest::Result<String>) -> Result<Vec<Generation>, jsonrpc::Error> {
-        let generations =
-            match serde_json::from_str(&text.unwrap_or_default()).map_err(internal_error)? {
-                APIResponse::Generation(gen) => vec![gen],
-                APIResponse::Generations(gens) => gens,
-                APIResponse::Error(err) => return Err(internal_error(err)),
-            };
-        Ok(generations)
-    }
+fn build_api_body(prompt: String, params: CompletionParams) -> Value {
+    build_tgi_body(prompt, params)
 }
 
+fn build_api_headers(api_token: Option<&String>, ide: Ide) -> Result<HeaderMap, jsonrpc::Error> {
+    build_tgi_headers(api_token, ide)
+}
+
+fn parse_api_text(text: reqwest::Result<String>) -> Result<Vec<Generation>, jsonrpc::Error> {
+    let generations =
+        match serde_json::from_str(&text.unwrap_or_default()).map_err(internal_error)? {
+            APIResponse::Generation(gen) => vec![gen],
+            APIResponse::Generations(gens) => gens,
+            APIResponse::Error(err) => return Err(internal_error(err)),
+        };
+    Ok(generations)
+}
+
-struct AdaptOllamaRequest;
-impl AdaptOllamaRequest {
-    fn adapt_body(&self, prompt: String, params: CompletionParams) -> Value {
-        let request_body = params.request_body.unwrap_or_default();
-        let body = serde_json::json!({
-            "prompt": prompt,
-            "model": request_body.get("model"),
-        });
-        body
-    }
-    fn adapt_headers(&self) -> Result<HeaderMap, jsonrpc::Error> {
-        let headers = HeaderMap::new();
-        Ok(headers)
-    }
+fn build_ollama_body(prompt: String, params: CompletionParams) -> Value {
+    let request_body = params.request_body.unwrap_or_default();
+    let body = serde_json::json!({
+        "prompt": prompt,
+        "model": request_body.get("model"),
+    });
+    body
+}
+fn build_ollama_headers() -> Result<HeaderMap, jsonrpc::Error> {
+    let headers = HeaderMap::new();
+    Ok(headers)
 }
 
 #[derive(Debug, Serialize, Deserialize)]
-pub struct OllamaGeneration {
+struct OllamaGeneration {
     response: String,
 }
 
 #[derive(Debug, Deserialize)]
 #[serde(untagged)]
-pub enum OllamaAPIResponse {
+enum OllamaAPIResponse {
     Generation(OllamaGeneration),
     Error(APIError),
 }
 
-struct AdaptOllamaResponse;
-impl AdaptOllamaResponse {
-    fn adapt_blob(
-        &self,
-        text: Result<String, reqwest::Error>,
-    ) -> Result<Vec<Generation>, jsonrpc::Error> {
-        match text {
-            Ok(text) => {
-                let mut gen: Vec<Generation> = Vec::new();
-                for row in text.split("\n") {
-                    if row.is_empty() {
-                        continue;
-                    }
-                    let chunk = match serde_json::from_str(row) {
-                        Ok(OllamaAPIResponse::Generation(ollama_gen)) => ollama_gen.response,
-                        Ok(OllamaAPIResponse::Error(err)) => return Err(internal_error(err)),
-                        Err(err) => return Err(internal_error(err)),
-                    };
-                    gen.push(Generation {
-                        generated_text: chunk,
-                    })
+fn parse_ollama_text(
+    text: Result<String, reqwest::Error>,
+) -> Result<Vec<Generation>, jsonrpc::Error> {
+    match text {
+        Ok(text) => {
+            let mut gen: Vec<Generation> = Vec::new();
+            for row in text.split('\n') {
+                if row.is_empty() {
+                    continue;
                 }
-                Ok(gen)
+                let chunk = match serde_json::from_str(row) {
+                    Ok(OllamaAPIResponse::Generation(ollama_gen)) => ollama_gen.response,
+                    Ok(OllamaAPIResponse::Error(err)) => return Err(internal_error(err)),
+                    Err(err) => return Err(internal_error(err)),
+                };
+                gen.push(Generation {
+                    generated_text: chunk,
+                })
             }
-            Err(err) => Err(internal_error(err)),
+            Ok(gen)
         }
+        Err(err) => Err(internal_error(err)),
     }
 }
 
-const HUGGING_FACE_ADAPTOR: &str = "huggingface";
+const TGI: &str = "tgi";
+const HUGGING_FACE: &str = "huggingface";
+const OLLAMA: &str = "ollama";
+const DEFAULT_ADAPTOR: &str = HUGGING_FACE;
 
-pub struct Adaptors;
-impl Adaptors {
-    pub fn adapt_body(&self, prompt: String, params: CompletionParams) -> Value {
-        let adaptor = params.adaptor.clone();
-        match adaptor.unwrap_or(HUGGING_FACE_ADAPTOR.to_string()).as_str() {
-            "ollama" => AdaptOllamaRequest.adapt_body(prompt, params),
-            _ => AdaptHuggingFaceRequest.adapt_body(prompt, params),
-        }
+fn unknown_adaptor_error(adaptor: String) -> jsonrpc::Error {
+    internal_error(format!("Unknown adaptor {}", adaptor))
+}
+
+pub fn adapt_body(prompt: String, params: CompletionParams) -> Result<Value, jsonrpc::Error> {
+    let adaptor = params
+        .adaptor
+        .clone()
+        .unwrap_or(DEFAULT_ADAPTOR.to_string());
+    match adaptor.as_str() {
+        TGI => Ok(build_tgi_body(prompt, params)),
+        HUGGING_FACE => Ok(build_api_body(prompt, params)),
+        OLLAMA => Ok(build_ollama_body(prompt, params)),
+        _ => Err(unknown_adaptor_error(adaptor)),
     }
-    pub fn adapt_headers(
-        &self,
-        adaptor: Option<String>,
-        api_token: Option<&String>,
-        ide: Ide,
-    ) -> Result<HeaderMap, jsonrpc::Error> {
-        match adaptor.unwrap_or(HUGGING_FACE_ADAPTOR.to_string()).as_str() {
-            "ollama" => AdaptOllamaRequest.adapt_headers(),
-            _ => AdaptHuggingFaceRequest.adapt_headers(api_token, ide),
-        }
+}
+
+pub fn adapt_headers(
+    adaptor: Option<String>,
+    api_token: Option<&String>,
+    ide: Ide,
+) -> Result<HeaderMap, jsonrpc::Error> {
+    let adaptor = adaptor.clone().unwrap_or(DEFAULT_ADAPTOR.to_string());
+    match adaptor.as_str() {
+        TGI => build_tgi_headers(api_token, ide),
+        HUGGING_FACE => build_api_headers(api_token, ide),
+        OLLAMA => build_ollama_headers(),
+        _ => Err(internal_error(adaptor)),
     }
-    pub fn adapt_blob(
-        &self,
-        adaptor: Option<String>,
-        text: Result<String, reqwest::Error>,
-    ) -> Result<Vec<Generation>, jsonrpc::Error> {
-        match adaptor.unwrap_or(HUGGING_FACE_ADAPTOR.to_string()).as_str() {
-            "ollama" => AdaptOllamaResponse.adapt_blob(text),
-            _ => AdaptHuggingFaceResponse.adapt_blob(text),
-        }
+}
+
+pub fn adapt_text(
+    adaptor: Option<String>,
+    text: Result<String, reqwest::Error>,
+) -> jsonrpc::Result<Vec<Generation>> {
+    let adaptor = adaptor.clone().unwrap_or(DEFAULT_ADAPTOR.to_string());
+    match adaptor.as_str() {
+        TGI => parse_tgi_text(text),
+        HUGGING_FACE => parse_api_text(text),
+        OLLAMA => parse_ollama_text(text),
+        _ => Err(unknown_adaptor_error(adaptor)),
     }
 }
diff --git a/crates/llm-ls/src/main.rs b/crates/llm-ls/src/main.rs
index a37e6d2..2a20a59 100644
--- a/crates/llm-ls/src/main.rs
+++ b/crates/llm-ls/src/main.rs
@@ -1,4 +1,4 @@
-use adaptors::Adaptors;
+use adaptors::{adapt_body, adapt_headers, adapt_text};
 use document::Document;
 use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, USER_AGENT};
 use ropey::Rope;
@@ -409,10 +409,10 @@ async fn request_completion(
     let model = params.model.clone();
     let adaptor = params.adaptor.clone();
     let api_token = params.api_token.clone();
-    let ide = params.ide.clone();
+    let ide = params.ide;
 
-    let json = Adaptors.adapt_body(prompt, params);
-    let headers = Adaptors.adapt_headers(adaptor.clone(), api_token.as_ref(), ide)?;
+    let json = adapt_body(prompt, params);
+    let headers = adapt_headers(adaptor.clone(), api_token.as_ref(), ide)?;
     let res = http_client
         .post(build_url(&model))
         .json(&json)
@@ -421,7 +421,7 @@ async fn request_completion(
         .headers(headers)
         .send()
         .await
         .map_err(internal_error)?;
 
-    let generations = Adaptors.adapt_blob(adaptor, res.text().await);
+    let generations = adapt_text(adaptor, res.text().await);
     let time = t.elapsed().as_millis();
     info!(
         model,
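
Note: a sketch (not code from the patch) of the string-keyed dispatch this cleanup settles on — one plain `match` on the adaptor name selecting per-backend functions, with unknown names rejected. `pick` is a hypothetical stand-in for the shared shape of adapt_body/adapt_headers/adapt_text:

    fn pick(adaptor: Option<&str>) -> Result<&'static str, String> {
        // Falls back to the huggingface adaptor when none is specified.
        match adaptor.unwrap_or("huggingface") {
            "tgi" => Ok("tgi body"),
            "huggingface" => Ok("api-inference body"),
            "ollama" => Ok("ollama body"),
            other => Err(format!("Unknown adaptor {other}")),
        }
    }

    fn main() {
        assert_eq!(pick(None).unwrap(), "api-inference body");
        assert_eq!(pick(Some("ollama")).unwrap(), "ollama body");
        assert!(pick(Some("mystery")).is_err());
    }
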
From e3db990996c5f68e581194ae8c688e739eb87dd7 Mon Sep 17 00:00:00 2001
From: noahbald
Date: Thu, 30 Nov 2023 22:40:02 +1100
Subject: [PATCH 04/12] feat: Port rhemlot's adaptor, improve structure of
 adaptors

---
 crates/llm-ls/src/adaptors.rs | 201 ++++++++++++++++++++++++----------
 crates/llm-ls/src/main.rs     |  36 +++---
 2 files changed, 160 insertions(+), 77 deletions(-)

diff --git a/crates/llm-ls/src/adaptors.rs b/crates/llm-ls/src/adaptors.rs
index b6c1d29..0f5a8e9 100644
--- a/crates/llm-ls/src/adaptors.rs
+++ b/crates/llm-ls/src/adaptors.rs
@@ -1,15 +1,18 @@
+use crate::RequestParams;
+
 use super::{
     internal_error, APIError, APIResponse, CompletionParams, Generation, Ide, NAME, VERSION,
 };
 use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, USER_AGENT};
 use serde::{Deserialize, Serialize};
 use serde_json::Value;
+use std::fmt::Display;
 use tower_lsp::jsonrpc;
 
-fn build_tgi_body(prompt: String, params: CompletionParams) -> Value {
+fn build_tgi_body(prompt: String, params: &RequestParams) -> Value {
     serde_json::json!({
         "inputs": prompt,
-        "parameters": params.request_params,
+        "parameters": params,
     })
 }
@@ -31,13 +34,13 @@ fn build_tgi_headers(api_token: Option<&String>, ide: Ide) -> Result<HeaderMap,
     Ok(headers)
 }
 
-fn parse_tgi_text(text: reqwest::Result<String>) -> Result<Vec<Generation>, jsonrpc::Error> {
+fn parse_tgi_text(text: &str) -> Result<Vec<Generation>, jsonrpc::Error> {
     let generations =
-        match serde_json::from_str(&text.unwrap_or_default()).map_err(internal_error)? {
+        match serde_json::from_str(text).map_err(internal_error)? {
             APIResponse::Generation(gen) => vec![gen],
             APIResponse::Generations(_) => {
                 return Err(internal_error(
-                    "TGI parser unexpectedly encountered api-inference",
+                    "You are attempting to parse a result in the API inference format when using the `tgi` adaptor",
                 ))
             }
             APIResponse::Error(err) => return Err(internal_error(err)),
@@ -45,7 +48,7 @@ fn parse_tgi_text(text: reqwest::Result<String>) -> Result<Vec<Generation>, jso
     Ok(generations)
 }
 
-fn build_api_body(prompt: String, params: CompletionParams) -> Value {
+fn build_api_body(prompt: String, params: &RequestParams) -> Value {
     build_tgi_body(prompt, params)
 }
 
@@ -53,27 +56,24 @@ fn build_api_headers(api_token: Option<&String>, ide: Ide) -> Result<HeaderMap,
     build_tgi_headers(api_token, ide)
 }
 
-fn parse_api_text(text: reqwest::Result<String>) -> Result<Vec<Generation>, jsonrpc::Error> {
-    let generations =
-        match serde_json::from_str(&text.unwrap_or_default()).map_err(internal_error)? {
-            APIResponse::Generation(gen) => vec![gen],
-            APIResponse::Generations(gens) => gens,
-            APIResponse::Error(err) => return Err(internal_error(err)),
-        };
+fn parse_api_text(text: &str) -> Result<Vec<Generation>, jsonrpc::Error> {
+    let generations = match serde_json::from_str(text).map_err(internal_error)? {
+        APIResponse::Generation(gen) => vec![gen],
+        APIResponse::Generations(gens) => gens,
+        APIResponse::Error(err) => return Err(internal_error(err)),
+    };
     Ok(generations)
 }
 
-fn build_ollama_body(prompt: String, params: CompletionParams) -> Value {
-    let request_body = params.request_body.unwrap_or_default();
-    let body = serde_json::json!({
+fn build_ollama_body(prompt: String, params: &CompletionParams) -> Value {
+    serde_json::json!({
         "prompt": prompt,
-        "model": request_body.get("model"),
-    });
-    body
+        "model": params.request_body.as_ref().unwrap().get("model"),
+        "stream": false,
+    })
 }
 fn build_ollama_headers() -> Result<HeaderMap, jsonrpc::Error> {
-    let headers = HeaderMap::new();
-    Ok(headers)
+    Ok(HeaderMap::new())
 }
 
 #[derive(Debug, Serialize, Deserialize)]
-pub struct OllamaGeneration {
+struct OllamaGeneration {
     response: String,
 }
 
+impl From<OllamaGeneration> for Generation {
+    fn from(value: OllamaGeneration) -> Self {
+        Generation {
+            generated_text: value.response,
+        }
+    }
+}
+
 #[derive(Debug, Deserialize)]
 #[serde(untagged)]
-pub enum OllamaAPIResponse {
+enum OllamaAPIResponse {
     Generation(OllamaGeneration),
     Error(APIError),
 }
 
-fn parse_ollama_text(
-    text: Result<String, reqwest::Error>,
-) -> Result<Vec<Generation>, jsonrpc::Error> {
-    match text {
-        Ok(text) => {
-            let mut gen: Vec<Generation> = Vec::new();
-            for row in text.split('\n') {
-                if row.is_empty() {
-                    continue;
-                }
-                let chunk = match serde_json::from_str(row) {
-                    Ok(OllamaAPIResponse::Generation(ollama_gen)) => ollama_gen.response,
-                    Ok(OllamaAPIResponse::Error(err)) => return Err(internal_error(err)),
-                    Err(err) => return Err(internal_error(err)),
-                };
-                gen.push(Generation {
-                    generated_text: chunk,
-                })
-            }
-            Ok(gen)
-        }
-        Err(err) => Err(internal_error(err)),
-    }
-}
+fn parse_ollama_text(text: &str) -> Result<Vec<Generation>, jsonrpc::Error> {
+    let generations = match serde_json::from_str(text).map_err(internal_error)? {
+        OllamaAPIResponse::Generation(gen) => vec![gen.into()],
+        OllamaAPIResponse::Error(err) => return Err(internal_error(err)),
+    };
+    Ok(generations)
+}
+
+fn build_openai_body(prompt: String, params: &CompletionParams) -> Value {
+    serde_json::json!({
+        "prompt": prompt,
+        "model": params.model,
+        "max_tokens": params.request_params.max_new_tokens,
+        "temperature": params.request_params.temperature,
+        "top_p": params.request_params.top_p,
+        "stop": params.request_params.stop_tokens.clone(),
+    })
+}
+
+fn build_openai_headers(api_token: Option<&String>, ide: Ide) -> Result<HeaderMap, jsonrpc::Error> {
+    build_api_headers(api_token, ide)
+}
+
+#[derive(Debug, Deserialize)]
+struct OpenAIGenerationChoice {
+    text: String,
+}
+
+impl From<OpenAIGenerationChoice> for Generation {
+    fn from(value: OpenAIGenerationChoice) -> Self {
+        Generation {
+            generated_text: value.text,
+        }
+    }
+}
+
+#[derive(Debug, Deserialize)]
+struct OpenAIGeneration {
+    choices: Vec<OpenAIGenerationChoice>,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(untagged)]
+enum OpenAIErrorLoc {
+    String(String),
+    Int(u32),
+}
+
+impl Display for OpenAIErrorLoc {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            OpenAIErrorLoc::String(s) => s.fmt(f),
+            OpenAIErrorLoc::Int(i) => i.fmt(f),
+        }
+    }
+}
+
+#[derive(Debug, Deserialize)]
+struct OpenAIErrorDetail {
+    loc: OpenAIErrorLoc,
+    msg: String,
+    r#type: String,
+}
+
+#[derive(Debug, Deserialize)]
+struct OpenAIError {
+    detail: Vec<OpenAIErrorDetail>,
+}
+
+impl Display for OpenAIError {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        for (i, item) in self.detail.iter().enumerate() {
+            if i != 0 {
+                writeln!(f)?;
+            }
+            write!(f, "{}: {} ({})", item.loc, item.msg, item.r#type)?;
+        }
+        Ok(())
+    }
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(untagged)]
+enum OpenAIAPIResponse {
+    Generation(OpenAIGeneration),
+    Error(OpenAIError),
+}
+
+fn parse_openai_text(text: &str) -> Result<Vec<Generation>, jsonrpc::Error> {
+    match serde_json::from_str(text).map_err(internal_error) {
+        Ok(OpenAIAPIResponse::Generation(completion)) => {
+            Ok(completion.choices.into_iter().map(|x| x.into()).collect())
+        }
+        Ok(OpenAIAPIResponse::Error(err)) => Err(internal_error(err)),
+        Err(err) => Err(internal_error(err)),
+    }
+}
 
 const TGI: &str = "tgi";
 const HUGGING_FACE: &str = "huggingface";
 const OLLAMA: &str = "ollama";
+const OPENAI: &str = "openai";
 const DEFAULT_ADAPTOR: &str = HUGGING_FACE;
 
-fn unknown_adaptor_error(adaptor: String) -> jsonrpc::Error {
-    internal_error(format!("Unknown adaptor {}", adaptor))
+fn unknown_adaptor_error(adaptor: Option<&String>) -> jsonrpc::Error {
+    internal_error(format!("Unknown adaptor {:?}", adaptor))
 }
 
-pub fn adapt_body(prompt: String, params: CompletionParams) -> Result<Value, jsonrpc::Error> {
-    let adaptor = params
-        .adaptor
-        .clone()
-        .unwrap_or(DEFAULT_ADAPTOR.to_string());
-    match adaptor.as_str() {
-        TGI => Ok(build_tgi_body(prompt, params)),
-        HUGGING_FACE => Ok(build_api_body(prompt, params)),
+pub fn adapt_body(prompt: String, params: &CompletionParams) -> Result<Value, jsonrpc::Error> {
+    match params
+        .adaptor
+        .as_ref()
+        .unwrap_or(&DEFAULT_ADAPTOR.to_string())
+        .as_str()
+    {
+        TGI => Ok(build_tgi_body(prompt, &params.request_params)),
+        HUGGING_FACE => Ok(build_api_body(prompt, &params.request_params)),
         OLLAMA => Ok(build_ollama_body(prompt, params)),
-        _ => Err(unknown_adaptor_error(adaptor)),
+        OPENAI => Ok(build_openai_body(prompt, params)),
+        _ => Err(unknown_adaptor_error(params.adaptor.as_ref())),
     }
 }
 
 pub fn adapt_headers(
-    adaptor: Option<String>,
+    adaptor: Option<&String>,
     api_token: Option<&String>,
     ide: Ide,
 ) -> Result<HeaderMap, jsonrpc::Error> {
-    let adaptor = adaptor.clone().unwrap_or(DEFAULT_ADAPTOR.to_string());
-    match adaptor.as_str() {
+    match adaptor.unwrap_or(&DEFAULT_ADAPTOR.to_string()).as_str() {
         TGI => build_tgi_headers(api_token, ide),
         HUGGING_FACE => build_api_headers(api_token, ide),
         OLLAMA => build_ollama_headers(),
-        _ => Err(internal_error(adaptor)),
+        OPENAI => build_openai_headers(api_token, ide),
+        _ => Err(unknown_adaptor_error(adaptor)),
     }
 }
 
-pub fn adapt_text(
-    adaptor: Option<String>,
-    text: Result<String, reqwest::Error>,
-) -> jsonrpc::Result<Vec<Generation>> {
-    let adaptor = adaptor.clone().unwrap_or(DEFAULT_ADAPTOR.to_string());
-    match adaptor.as_str() {
+pub fn parse_generations(adaptor: Option<&String>, text: &str) -> jsonrpc::Result<Vec<Generation>> {
+    match adaptor.unwrap_or(&DEFAULT_ADAPTOR.to_string()).as_str() {
         TGI => parse_tgi_text(text),
         HUGGING_FACE => parse_api_text(text),
         OLLAMA => parse_ollama_text(text),
+        OPENAI => parse_openai_text(text),
         _ => Err(unknown_adaptor_error(adaptor)),
     }
 }
diff --git a/crates/llm-ls/src/main.rs b/crates/llm-ls/src/main.rs
index 2a20a59..ede591a 100644
--- a/crates/llm-ls/src/main.rs
+++ b/crates/llm-ls/src/main.rs
@@ -1,4 +1,4 @@
-use adaptors::{adapt_body, adapt_headers, adapt_text};
+use adaptors::{adapt_body, adapt_headers, parse_generations};
 use document::Document;
 use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, USER_AGENT};
 use ropey::Rope;
@@ -403,25 +403,29 @@ fn build_prompt(
 async fn request_completion(
     http_client: &reqwest::Client,
     prompt: String,
-    params: CompletionParams,
+    params: &CompletionParams,
 ) -> Result<Vec<Generation>> {
     let t = Instant::now();
-    let model = params.model.clone();
-    let adaptor = params.adaptor.clone();
-    let api_token = params.api_token.clone();
-    let ide = params.ide;
 
     let json = adapt_body(prompt, params);
-    let headers = adapt_headers(adaptor.clone(), api_token.as_ref(), ide)?;
+    let headers = adapt_headers(
+        params.adaptor.as_ref(),
+        params.api_token.as_ref(),
+        params.ide,
+    )?;
     let res = http_client
-        .post(build_url(&model))
+        .post(build_url(&params.model))
         .json(&json)
         .headers(headers)
         .send()
         .await
         .map_err(internal_error)?;
 
-    let generations = adapt_text(adaptor, res.text().await);
+    let model = &params.model;
+    let generations = parse_generations(
+        params.adaptor.as_ref(),
+        res.text().await.unwrap_or(String::new()).as_str(),
+    );
     let time = t.elapsed().as_millis();
     info!(
         model,
@@ -432,7 +436,7 @@ async fn request_completion(
         generations = serde_json::to_string(&generations).map_err(internal_error)?,
         "{model} computed generations in {time} ms"
     );
-    Ok(generations)
+    generations
 }
 
-fn parse_generations(
+fn format_generations(
     generations: Vec<Generation>,
     tokens_to_clear: &[String],
     completion_type: CompletionType,
@@ -525,7 +529,7 @@ async fn download_tokenizer_file(
 async fn get_tokenizer(
     model: &str,
     tokenizer_map: &mut HashMap<String, Arc<Tokenizer>>,
-    tokenizer_config: Option<TokenizerConfig>,
+    tokenizer_config: Option<&TokenizerConfig>,
     http_client: &reqwest::Client,
     cache_dir: impl AsRef<Path>,
     api_token: Option<&String>,
@@ -557,7 +561,7 @@ async fn get_tokenizer(
             }
         }
         TokenizerConfig::Download { url, to } => {
-            download_tokenizer_file(http_client, &url, api_token, &to, ide).await?;
+            download_tokenizer_file(http_client, url, api_token, &to, ide).await?;
             match Tokenizer::from_file(to) {
                 Ok(tokenizer) => Some(Arc::new(tokenizer)),
                 Err(err) => {
@@ -625,11 +629,10 @@ impl Backend {
             return Ok(CompletionResult { request_id, completions: vec![]});
         }
 
-        let tokenizer_config = params.tokenizer_config.clone();
         let tokenizer = get_tokenizer(
             &params.model,
             &mut *self.tokenizer_map.write().await,
-            tokenizer_config,
+            params.tokenizer_config.as_ref(),
             &self.http_client,
             &self.cache_dir,
             params.api_token.as_ref(),
@@ -650,15 +653,14 @@ impl Backend {
         } else {
             &self.http_client
         };
-        let tokens_to_clear = params.tokens_to_clear.clone();
         let result = request_completion(
             http_client,
             prompt,
-            params,
+            &params,
         )
         .await?;
 
-        let completions = parse_generations(result, &tokens_to_clear, completion_type);
+        let completions = format_generations(result, &params.tokens_to_clear, completion_type);
         Ok(CompletionResult { request_id, completions })
     }.instrument(span).await
 }
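
Note: a minimal sketch (not part of the patch) of the success shape the new OpenAI-style adaptor parses — generations live under `choices[].text`. The structs mirror the ones added above; the JSON payload is an illustrative assumption:

    use serde::Deserialize;

    #[derive(Debug, Deserialize)]
    struct Choice {
        text: String,
    }

    #[derive(Debug, Deserialize)]
    struct Completion {
        choices: Vec<Choice>,
    }

    fn main() {
        let ok = r#"{"choices":[{"text":"return a + b;"}]}"#;
        let parsed: Completion = serde_json::from_str(ok).unwrap();
        // Each choice maps onto one Generation in the adaptor.
        assert_eq!(parsed.choices[0].text, "return a + b;");
    }
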
From ef1a8cc88693ae1c812a2c9ef9d8acfcc28ee0a5 Mon Sep 17 00:00:00 2001
From: noahbald
Date: Thu, 30 Nov 2023 22:59:41 +1100
Subject: [PATCH 05/12] fix: Update incorrect crate usage, revert derived Clone

---
 crates/llm-ls/src/adaptors.rs | 5 ++---
 crates/llm-ls/src/main.rs     | 4 ++--
 2 files changed, 4 insertions(+), 5 deletions(-)

diff --git a/crates/llm-ls/src/adaptors.rs b/crates/llm-ls/src/adaptors.rs
index 0f5a8e9..99e37b4 100644
--- a/crates/llm-ls/src/adaptors.rs
+++ b/crates/llm-ls/src/adaptors.rs
@@ -1,7 +1,6 @@
-use crate::RequestParams;
-
 use super::{
-    internal_error, APIError, APIResponse, CompletionParams, Generation, Ide, NAME, VERSION,
+    internal_error, APIError, APIResponse, CompletionParams, Generation, Ide, RequestParams, NAME,
+    VERSION,
 };
 use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, USER_AGENT};
 use serde::{Deserialize, Serialize};
diff --git a/crates/llm-ls/src/main.rs b/crates/llm-ls/src/main.rs
index ede591a..f8bc6db 100644
--- a/crates/llm-ls/src/main.rs
+++ b/crates/llm-ls/src/main.rs
@@ -122,7 +122,7 @@ fn should_complete(document: &Document, position: Position) -> Result<CompletionType> {
     }
 }
 
-#[derive(Debug, Deserialize, Serialize, Clone)]
+#[derive(Debug, Deserialize, Serialize)]
 #[serde(untagged)]
 enum TokenizerConfig {
     Local { path: PathBuf },
From: noahbald
Date: Sun, 3 Dec 2023 18:32:17 +1100
Subject: [PATCH 06/12] fix: Throw error on failed server response

---
 crates/llm-ls/src/main.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/crates/llm-ls/src/main.rs b/crates/llm-ls/src/main.rs
index f8bc6db..125c9fb 100644
--- a/crates/llm-ls/src/main.rs
+++ b/crates/llm-ls/src/main.rs
@@ -424,7 +424,7 @@ async fn request_completion(
     let model = &params.model;
     let generations = parse_generations(
         params.adaptor.as_ref(),
-        res.text().await.unwrap_or(String::new()).as_str(),
+        res.text().await.map_err(internal_error)?.as_str(),
     );
     let time = t.elapsed().as_millis();
     info!(

From 510a7c2ae384dc2e316edfbc748d8643fe62e243 Mon Sep 17 00:00:00 2001
From: noahbald
Date: Sun, 3 Dec 2023 18:43:24 +1100
Subject: [PATCH 07/12] feat: Add options to ollama adaptor

---
 crates/llm-ls/src/adaptors.rs | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/crates/llm-ls/src/adaptors.rs b/crates/llm-ls/src/adaptors.rs
index 99e37b4..47fabaf 100644
--- a/crates/llm-ls/src/adaptors.rs
+++ b/crates/llm-ls/src/adaptors.rs
@@ -69,6 +69,13 @@ fn build_ollama_body(prompt: String, params: &CompletionParams) -> Value {
         "prompt": prompt,
         "model": params.request_body.as_ref().unwrap().get("model"),
         "stream": false,
+        // As per [modelfile](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values)
+        "options": {
+            "num_predict": params.request_params.max_new_tokens,
+            "temperature": params.request_params.temperature,
+            "top_p": params.request_params.top_p,
+            "stop": params.request_params.stop_tokens.clone(),
+        }
     })
 }
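
Note: a sketch (not part of the patch) of the complete ollama request body after this change, with the request parameters mapped onto ollama's `options`. All concrete values are illustrative assumptions:

    fn main() {
        let body = serde_json::json!({
            "prompt": "fn main() {",
            "model": "codellama:7b",
            "stream": false,
            "options": {
                "num_predict": 60,    // from max_new_tokens
                "temperature": 0.2,
                "top_p": 0.95,
                "stop": ["<EOT>"],    // from stop_tokens
            }
        });
        println!("{}", serde_json::to_string_pretty(&body).unwrap());
    }
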
From 4899f86ed82fc424e166301028a062c92faa4433 Mon Sep 17 00:00:00 2001
From: noahbald
Date: Tue, 5 Dec 2023 18:53:01 +1100
Subject: [PATCH 08/12] fix: Remove wrapped "Ok" from json

---
 crates/llm-ls/src/main.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/crates/llm-ls/src/main.rs b/crates/llm-ls/src/main.rs
index 125c9fb..cbaf52c 100644
--- a/crates/llm-ls/src/main.rs
+++ b/crates/llm-ls/src/main.rs
@@ -407,7 +407,7 @@ async fn request_completion(
 ) -> Result<Vec<Generation>> {
     let t = Instant::now();
 
-    let json = adapt_body(prompt, params);
+    let json = adapt_body(prompt, params).map_err(internal_error)?;
     let headers = adapt_headers(
         params.adaptor.as_ref(),
         params.api_token.as_ref(),

From 25971bb8779972803bdc7539cca59a0d11f88931 Mon Sep 17 00:00:00 2001
From: noahbald
Date: Fri, 8 Dec 2023 20:13:30 +1100
Subject: [PATCH 09/12] fix: Use correct value for openai model

---
 crates/llm-ls/src/adaptors.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/crates/llm-ls/src/adaptors.rs b/crates/llm-ls/src/adaptors.rs
index 47fabaf..95e67b2 100644
--- a/crates/llm-ls/src/adaptors.rs
+++ b/crates/llm-ls/src/adaptors.rs
@@ -113,7 +113,7 @@ fn parse_ollama_text(text: &str) -> Result<Vec<Generation>, jsonrpc::Error> {
 fn build_openai_body(prompt: String, params: &CompletionParams) -> Value {
     serde_json::json!({
         "prompt": prompt,
-        "model": params.model,
+        "model": params.request_body.as_ref().unwrap().get("model"),
         "max_tokens": params.request_params.max_new_tokens,
         "temperature": params.request_params.temperature,
         "top_p": params.request_params.top_p,

From d2fd7a2d691b7449add0646b33f9b286b7b89ed8 Mon Sep 17 00:00:00 2001
From: noahbald
Date: Thu, 14 Dec 2023 19:36:03 +1100
Subject: [PATCH 10/12] fix: Provide useful error message when required
 request_body is missing

---
 crates/llm-ls/src/adaptors.rs | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/crates/llm-ls/src/adaptors.rs b/crates/llm-ls/src/adaptors.rs
index 95e67b2..92f3b4c 100644
--- a/crates/llm-ls/src/adaptors.rs
+++ b/crates/llm-ls/src/adaptors.rs
@@ -67,7 +67,7 @@ fn parse_api_text(text: &str) -> Result<Vec<Generation>, jsonrpc::Error> {
 fn build_ollama_body(prompt: String, params: &CompletionParams) -> Value {
     serde_json::json!({
         "prompt": prompt,
-        "model": params.request_body.as_ref().unwrap().get("model"),
+        "model": params.request_body.as_ref().ok_or_else(|| "missing request_body").get("model"),
         "stream": false,
         // As per [modelfile](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values)
         "options": {
@@ -113,7 +113,7 @@ fn parse_ollama_text(text: &str) -> Result<Vec<Generation>, jsonrpc::Error> {
 fn build_openai_body(prompt: String, params: &CompletionParams) -> Value {
     serde_json::json!({
         "prompt": prompt,
-        "model": params.request_body.as_ref().unwrap().get("model"),
+        "model": params.request_body.as_ref().ok_or_else(|| internal_error("missing request_body")).get("model"),
         "max_tokens": params.request_params.max_new_tokens,
         "temperature": params.request_params.temperature,
         "top_p": params.request_params.top_p,

From 4ee8c28390ed68abf747f026f8655ebf5698ce11 Mon Sep 17 00:00:00 2001
From: noahbald
Date: Tue, 19 Dec 2023 21:15:41 +1100
Subject: [PATCH 11/12] fix: Use internal_error on missing params

---
 crates/llm-ls/src/adaptors.rs | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/crates/llm-ls/src/adaptors.rs b/crates/llm-ls/src/adaptors.rs
index 92f3b4c..b0861ba 100644
--- a/crates/llm-ls/src/adaptors.rs
+++ b/crates/llm-ls/src/adaptors.rs
@@ -67,7 +67,7 @@ fn parse_api_text(text: &str) -> Result<Vec<Generation>, jsonrpc::Error> {
 fn build_ollama_body(prompt: String, params: &CompletionParams) -> Value {
     serde_json::json!({
         "prompt": prompt,
-        "model": params.request_body.as_ref().ok_or_else(|| "missing request_body").get("model"),
+        "model": params.request_body.as_ref().ok_or_else(|| internal_error("missing request_body")).expect("Unable to make request for ollama").get("model"),
         "stream": false,
         // As per [modelfile](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values)
         "options": {
@@ -113,7 +113,7 @@ fn parse_ollama_text(text: &str) -> Result<Vec<Generation>, jsonrpc::Error> {
 fn build_openai_body(prompt: String, params: &CompletionParams) -> Value {
     serde_json::json!({
         "prompt": prompt,
-        "model": params.request_body.as_ref().ok_or_else(|| internal_error("missing request_body")).get("model"),
+        "model": params.request_body.as_ref().ok_or_else(|| internal_error("missing request_body")).expect("Unable to make request for openai").get("model"),
         "max_tokens": params.request_params.max_new_tokens,
         "temperature": params.request_params.temperature,
         "top_p": params.request_params.top_p,
From 861e608e1fa99bc63059bedd2afa92aa25702aca Mon Sep 17 00:00:00 2001
From: noahbald
Date: Sun, 31 Dec 2023 10:57:39 +1100
Subject: [PATCH 12/12] fix: prevent crashing when using url model but repo
 tokenizer, prevent underflow crash on Windows

---
 crates/llm-ls/src/main.rs | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/crates/llm-ls/src/main.rs b/crates/llm-ls/src/main.rs
index cbaf52c..054aada 100644
--- a/crates/llm-ls/src/main.rs
+++ b/crates/llm-ls/src/main.rs
@@ -30,10 +30,10 @@ pub const VERSION: &str = env!("CARGO_PKG_VERSION");
 fn get_position_idx(rope: &Rope, row: usize, col: usize) -> Result<usize> {
     Ok(rope.try_line_to_char(row).map_err(internal_error)?
         + col.min(
-            rope.get_line(row.min(rope.len_lines() - 1))
+            rope.get_line(row.min(rope.len_lines().saturating_sub(1)))
                 .ok_or_else(|| internal_error(format!("failed to find line at {row}")))?
                 .len_chars()
-                - 1,
+                .saturating_sub(1),
         ))
 }
@@ -548,7 +548,7 @@ async fn get_tokenizer(
             }
         },
         TokenizerConfig::HuggingFace { repository } => {
-            let path = cache_dir.as_ref().join(model).join("tokenizer.json");
+            let path = cache_dir.as_ref().join(repository).join("tokenizer.json");
             let url =
                 format!("https://huggingface.co/{repository}/resolve/main/tokenizer.json");
             download_tokenizer_file(http_client, &url, api_token, &path, ide).await?;
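
Note: a tiny sketch (not part of the patch) of the arithmetic hazard the `saturating_sub` change guards against — subtracting 1 from a `usize` that can be 0 underflows, panicking in debug builds:

    fn main() {
        let len: usize = 0;
        // `len - 1` would panic here in a debug build (and wrap in release);
        // saturating_sub clamps the result at zero instead.
        assert_eq!(len.saturating_sub(1), 0);
    }
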