From 7834a19a1f26aedcbdfb496e36b51b9c0124776a Mon Sep 17 00:00:00 2001 From: Varik Matevosyan Date: Fri, 1 Nov 2024 17:20:27 +0400 Subject: [PATCH] Refactor lantern_extras SQL functions: - Remove `cohere_embedding`, `clip_text`, `clip_image` functions - Add `llm_embedding` function - Refactor arguments for `llm_completion` function - Refactor arguments for `add_embedding_job` function - Refactor arguments for `add_completion_job` function - Remove GUC `lantern_extras.openai_azure_api_token`, `lantern_extras.cohere_token` and use `lantern_extras.llm_token` instead --- .../src/embeddings/core/openai_runtime.rs | 18 +- .../tests/daemon_completion_test_with_db.rs | 2 +- lantern_cli/tests/embedding_test_with_db.rs | 6 +- lantern_cli/tests/query_completion_test.rs | 6 +- lantern_extras/README.md | 84 +++++-- lantern_extras/src/daemon.rs | 170 ++++++------- lantern_extras/src/embeddings.rs | 234 ++++++++---------- lantern_extras/src/lib.rs | 34 +-- 8 files changed, 272 insertions(+), 282 deletions(-) diff --git a/lantern_cli/src/embeddings/core/openai_runtime.rs b/lantern_cli/src/embeddings/core/openai_runtime.rs index f5c5c91d..8d421cd9 100644 --- a/lantern_cli/src/embeddings/core/openai_runtime.rs +++ b/lantern_cli/src/embeddings/core/openai_runtime.rs @@ -188,7 +188,7 @@ pub struct OpenAiRuntime<'a> { request_timeout: u64, base_url: String, headers: Vec<(String, String)>, - context: serde_json::Value, + system_prompt: serde_json::Value, dimensions: Option, deployment_type: OpenAiDeployment, #[allow(dead_code)] @@ -199,9 +199,8 @@ pub struct OpenAiRuntime<'a> { pub struct OpenAiRuntimeParams { pub base_url: Option, pub api_token: Option, - pub azure_api_token: Option, pub azure_entra_token: Option, - pub context: Option, + pub system_prompt: Option, pub dimensions: Option, } @@ -223,15 +222,14 @@ impl<'a> OpenAiRuntime<'a> { } OpenAiDeployment::Azure => { // https://learn.microsoft.com/en-us/azure/ai-services/openai/reference - if 
runtime_params.azure_api_token.is_none() - && runtime_params.azure_entra_token.is_none() + if runtime_params.api_token.is_none() && runtime_params.azure_entra_token.is_none() { anyhow::bail!( - "'azure_api_key' or 'azure_entra_id' is required for Azure OpenAi runtime" + "'api_token' or 'azure_entra_id' is required for Azure OpenAi runtime" ); } - if let Some(key) = runtime_params.azure_api_token { + if let Some(key) = runtime_params.api_token { ("api-key".to_owned(), format!("{}", key)) } else { ( @@ -242,7 +240,7 @@ impl<'a> OpenAiRuntime<'a> { } }; - let context = match &runtime_params.context { + let system_prompt = match &runtime_params.system_prompt { Some(system_prompt) => json!({ "role": "system", "content": system_prompt.clone()}), None => json!({ "role": "system", "content": "" }), }; @@ -257,7 +255,7 @@ impl<'a> OpenAiRuntime<'a> { auth_header, ], dimensions: runtime_params.dimensions, - context, + system_prompt, }) } @@ -388,7 +386,7 @@ impl<'a> OpenAiRuntime<'a> { serde_json::to_string(&json!({ "model": model_name, "messages": [ - self.context, + self.system_prompt, { "role": "user", "content": query } ] }))?, diff --git a/lantern_cli/tests/daemon_completion_test_with_db.rs b/lantern_cli/tests/daemon_completion_test_with_db.rs index 4cf98e20..068071ed 100644 --- a/lantern_cli/tests/daemon_completion_test_with_db.rs +++ b/lantern_cli/tests/daemon_completion_test_with_db.rs @@ -28,7 +28,7 @@ async fn test_daemon_completion_init_job() { ('Test5'); INSERT INTO _lantern_extras_internal.embedding_generation_jobs ("id", "table", src_column, dst_column, embedding_model, runtime, runtime_params, job_type, column_type) - VALUES (1, '{CLIENT_TABLE_NAME}', 'title', 'num', 'openai/gpt-4o', 'openai', '{{"api_token": "{api_token}", "context": "Given text testN, return the N as number without any quotes, so for Test1 you should return 1, Test105 you should return 105" }}', 'completion', 'INT'); + VALUES (1, '{CLIENT_TABLE_NAME}', 'title', 'num', 'openai/gpt-4o', 
'openai', '{{"api_token": "{api_token}", "system_prompt": "Given text testN, return the N as number without any quotes, so for Test1 you should return 1, Test105 you should return 105" }}', 'completion', 'INT'); "# )) .await diff --git a/lantern_cli/tests/embedding_test_with_db.rs b/lantern_cli/tests/embedding_test_with_db.rs index 5f9fc4f7..4e50bc55 100644 --- a/lantern_cli/tests/embedding_test_with_db.rs +++ b/lantern_cli/tests/embedding_test_with_db.rs @@ -163,7 +163,7 @@ async fn test_openai_completion_from_db() { limit: Some(10), filter: None, runtime: Runtime::OpenAi, - runtime_params: format!(r#"{{"api_token": "{api_token}", "context": "you will be given text, return postgres array of TEXT[] by splitting the text by characters skipping spaces. Example 'te st' -> {{t,e,s,t}} . Do not put tailing commas, do not put double or single quotes around characters" }}"#), + runtime_params: format!(r#"{{"api_token": "{api_token}", "system_prompt": "you will be given text, return postgres array of TEXT[] by splitting the text by characters skipping spaces. Example 'te st' -> {{t,e,s,t}} . 
Do not put tailing commas, do not put double or single quotes around characters" }}"#), create_column: true, stream: true, job_type: Some(EmbeddingJobType::Completion), @@ -241,7 +241,7 @@ async fn test_openai_completion_special_chars_from_db() { limit: Some(2), filter: None, runtime: Runtime::OpenAi, - runtime_params: format!(r#"{{"api_token": "{api_token}", "context": "for any input return multi line text which will contain escape characters which can potentially break postgres COPY" }}"#), + runtime_params: format!(r#"{{"api_token": "{api_token}", "system_prompt": "for any input return multi line text which will contain escape characters which can potentially break postgres COPY" }}"#), create_column: true, stream: true, job_type: Some(EmbeddingJobType::Completion), @@ -319,7 +319,7 @@ async fn test_openai_completion_failed_rows_from_db() { limit: Some(10), filter: None, runtime: Runtime::OpenAi, - runtime_params: format!(r#"{{"api_token": "{api_token}", "context": "you will be given text, return array by splitting the text by characters skipping spaces. Example 'te st' -> [t,e,s,t]" }}"#), + runtime_params: format!(r#"{{"api_token": "{api_token}", "system_prompt": "you will be given text, return array by splitting the text by characters skipping spaces. 
Example 'te st' -> [t,e,s,t]" }}"#), create_column: true, stream: true, job_type: Some(EmbeddingJobType::Completion), diff --git a/lantern_cli/tests/query_completion_test.rs b/lantern_cli/tests/query_completion_test.rs index f1f323ee..b0246f37 100644 --- a/lantern_cli/tests/query_completion_test.rs +++ b/lantern_cli/tests/query_completion_test.rs @@ -1,7 +1,7 @@ use lantern_cli::embeddings::core::{EmbeddingRuntime, Runtime}; use std::env; -static LLM_CONTEXT: &'static str = "You will be provided JSON with the following schema: {x: string}, answer to the message returning the x propery from the provided JSON object"; +static LLM_SYSTEM_PROMPT: &'static str = "You will be provided JSON with the following schema: {x: string}, answer to the message returning the x propery from the provided JSON object"; macro_rules! query_completion_test { ($($name:ident: $value:expr,)*) => { @@ -19,7 +19,7 @@ macro_rules! query_completion_test { return; } - let params = format!(r#"{{"api_token": "{api_token}", "context": "{LLM_CONTEXT}"}}"#); + let params = format!(r#"{{"api_token": "{api_token}", "system_prompt": "{LLM_SYSTEM_PROMPT}"}}"#); let runtime = EmbeddingRuntime::new(&runtime_name, None, ¶ms).unwrap(); let output = runtime.completion( @@ -62,7 +62,7 @@ macro_rules! 
query_completion_test_multiple { expected_output.push(output); } - let params = format!(r#"{{"api_token": "{api_token}", "context": "{LLM_CONTEXT}"}}"#); + let params = format!(r#"{{"api_token": "{api_token}", "system_prompt": "{LLM_SYSTEM_PROMPT}"}}"#); let runtime = EmbeddingRuntime::new(&runtime_name, None, ¶ms).unwrap(); let output = runtime.batch_completion( diff --git a/lantern_extras/README.md b/lantern_extras/README.md index 74d2f0cf..8e2dda5e 100644 --- a/lantern_extras/README.md +++ b/lantern_extras/README.md @@ -42,14 +42,34 @@ FROM papers; -- generate embeddings from other models which can be extended ```sql +SELECT llm_embedding( + input => 'User input', -- User prompt to LLM model + model => 'gpt-4o', -- Model for runtime to use (default: 'gpt-4o') + base_url => 'https://api.openai.com', -- If you have custom LLM deployment provide the server url. (default: OpenAi API URL) + api_token => '', -- API token for LLM server. (default: inferred from lantern_extras.llm_token GUC) + azure_entra_token => '', -- If this is Azure deployment it supports Auth with entra token too + dimensions => 1536, -- For new generation OpenAi models you can provide dimensions for returned embeddings. (default: 1536) + input_type => 'search_query', -- Needed only for cohere runtime to indicate if this input is for search or storing. (default: 'search_query'). Can also be 'search_document' + runtime => 'openai' -- Runtime to use. (default: 'openai'). 
Use `SELECT get_available_runtimes()` for list +); + -- generate text embedding -SELECT text_embedding('BAAI/bge-base-en', 'My text input'); +SELECT llm_embedding(model => 'BAAI/bge-base-en', input => 'My text input', runtime => 'ort'); -- generate image embedding with image url -SELECT image_embedding('clip/ViT-B-32-visual', 'https://link-to-your-image'); +SELECT llm_embedding(model => 'clip/ViT-B-32-visual', input => 'https://link-to-your-image', runtime => 'ort'); -- generate image embedding with image path (this path should be accessible from postgres server) -SELECT image_embedding('clip/ViT-B-32-visual', '/path/to/image/in-postgres-server'); +SELECT llm_embedding(model => 'clip/ViT-B-32-visual', input => '/path/to/image/in-postgres-server', runtime => 'ort'); -- get available list of models SELECT get_available_models(); +-- generate openai embeddings +SELECT llm_embedding(model => 'text-embedding-3-small', api_token => '', input => 'My text input', runtime => 'openai'); +-- generate embeddings from custom openai compatible servers +SELECT llm_embedding(model => 'intfloat/e5-mistral-7b-instruct', api_token => '', input => 'My text input', runtime => 'openai', base_url => 'https://my-llm-url'); +-- generate cohere embeddings +SELECT llm_embedding(model => 'embed-multilingual-light-v3.0', api_token => '', input => 'My text input', runtime => 'cohere'); +-- api_token can be set via GUC +SET lantern_extras.llm_token = ''; +SELECT llm_embedding(model => 'text-embedding-3-small', input => 'My text input', runtime => 'openai'); ``` ## Getting started @@ -135,7 +155,7 @@ To add new textual or visual models for generating vector embeddings you can fol After this your model should be callable from SQL like ```sql -SELECT text_embedding('your/model_name', 'Your text'); +SELECT llm_embedding(model => 'your/model_name', input => 'Your text', runtime => 'ort'); ``` ## Lantern Daemon in SQL @@ -158,14 +178,18 @@ To add a new embedding job, use the `add_embedding_job` 
function: ```sql SELECT add_embedding_job( - 'table_name', -- Name of the table - 'src_column', -- Source column for embeddings - 'dst_column', -- Destination column for embeddings - 'embedding_model', -- Embedding model to use - 'runtime', -- Runtime environment (default: 'ort') - 'runtime_params', -- Runtime parameters (default: '{}') - 'pk', -- Primary key column (default: 'id') - 'schema' -- Schema name (default: 'public') + table_name => 'articles', -- Name of the table + src_column => 'content', -- Source column for embeddings + dst_column => 'content_embedding', -- Destination column for embeddings (will be created automatically) + model => 'text-embedding-3-small', -- Model for runtime to use (default: 'text-embedding-3-small') + pk => 'id', -- Primary key of the table. It is required for table to have primary key (default: id) + schema => 'public', -- Schema on which the table is located (default: 'public') + base_url => 'https://api.openai.com', -- If you have custom LLM deployment provide the server url. (default: OpenAi API URL) + batch_size => 500, -- Batch size for the inputs to use when requesting LLM server. This is based on your API tier. (default: determined based on model and runtime) + dimensions => 1536, -- For new generation OpenAi models you can provide dimensions for returned embeddings. (default: 1536) + api_token => '', -- API token for LLM server. (default: inferred from lantern_extras.llm_token GUC) + azure_entra_token => '', -- If this is Azure deployment it supports Auth with entra token too + runtime => 'openai' -- Runtime to use. (default: 'openai'). 
Use `SELECT get_available_runtimes()` for list ); ``` @@ -200,17 +224,19 @@ To add a new completion job, use the `add_completion_job` function: ```sql SELECT add_completion_job( - 'table_name', -- Name of the table - 'src_column', -- Source column for embeddings - 'dst_column', -- Destination column for embeddings - 'context', -- System prompt to be used for LLM (default: lantern_extras.completion_context GUC) - 'column_type', -- Target column type to be used for destination (default: TEXT) - 'model', -- LLM model to use (default: 'gpt-4o') - 'batch_size', -- Batch size to use when sending batch requests (default: 2) - 'runtime', -- Runtime environment (default: 'openai') - 'runtime_params', -- Runtime parameters (default: '{}' inferred from GUC variables) - 'pk', -- Primary key column (default: 'id') - 'schema' -- Schema name (default: 'public') + table_name => 'articles', -- Name of the table + src_column => 'content', -- Source column whose text is sent to the LLM as user input + dst_column => 'content_summary', -- Destination column for llm response (will be created automatically) + system_prompt => 'Provide short summary for the given text', -- System prompt for LLM (default: '') + column_type => 'TEXT', -- Destination column type + model => 'gpt-4o', -- Model for runtime to use (default: 'gpt-4o') + pk => 'id', -- Primary key of the table. It is required for table to have primary key (default: id) + schema => 'public', -- Schema on which the table is located (default: 'public') + base_url => 'https://api.openai.com', -- If you have custom LLM deployment provide the server url. (default: OpenAi API URL) + batch_size => 10, -- Batch size for the inputs to use when requesting LLM server. This is based on your API tier. (default: determined based on model and runtime) + api_token => '', -- API token for LLM server. 
(default: inferred from lantern_extras.llm_token GUC) + azure_entra_token => '', -- If this is Azure deployment it supports Auth with entra token too + runtime => 'openai' -- Runtime to use. (default: 'openai'). Use `SELECT get_available_runtimes()` for list ); ``` @@ -258,6 +284,14 @@ This will return a table with the following columns: ***Calling LLM Completion API*** ```sql -SET lantern_extras.llm_token='xxxx'; -SELECT llm_completion(query, [model, context, base_url, runtime]); +SET lantern_extras.llm_token='xxxx'; -- this will be used as api_token if it is not passed via arguments +SELECT llm_completion( + user_prompt => 'User input', -- User prompt to LLM model + model => 'gpt-4o', -- Model for runtime to use (default: 'gpt-4o') + system_prompt => 'Provide short summary for the given text', -- System prompt for LLM (default: '') + base_url => 'https://api.openai.com', -- If you have custom LLM deployment provide the server url. (default: OpenAi API URL) + api_token => '', -- API token for LLM server. (default: inferred from lantern_extras.llm_token GUC) + azure_entra_token => '', -- If this is Azure deployment it supports Auth with entra token too + runtime => 'openai' -- Runtime to use. (default: 'openai'). 
Use `SELECT get_available_runtimes()` for list +); ``` diff --git a/lantern_extras/src/daemon.rs b/lantern_extras/src/daemon.rs index 4da41d43..6ca5aa86 100644 --- a/lantern_extras/src/daemon.rs +++ b/lantern_extras/src/daemon.rs @@ -15,11 +15,7 @@ use crate::{ DAEMON_DATABASES, ENABLE_DAEMON, }; -pub fn start_daemon( - embeddings: bool, - indexing: bool, - autotune: bool, -) -> Result<(), anyhow::Error> { +pub fn start_daemon(embeddings: bool, indexing: bool, autotune: bool) -> Result<(), anyhow::Error> { let (db, user, socket_path, port) = BackgroundWorker::transaction(|| { Spi::connect(|client| { let row = client @@ -75,7 +71,10 @@ pub fn start_daemon( let logger = Logger::new("Lantern Daemon", LogLevel::Debug); let cancellation_token = CancellationToken::new(); - let rt = tokio::runtime::Builder::new_current_thread().enable_all().build().unwrap();// Runtime::new().unwrap(); + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); // Runtime::new().unwrap(); rt.block_on(async { start( DaemonArgs { @@ -90,11 +89,12 @@ pub fn start_daemon( schema: String::from("_lantern_extras_internal"), target_db: Some(target_dbs.clone()), data_path: Some(DATA_PATH.to_owned()), - inside_postgres: true + inside_postgres: true, }, Some(logger.clone()), cancellation_token.clone(), - ).await?; + ) + .await?; tokio::select! 
{ _ = cancellation_token.cancelled() => { @@ -115,36 +115,37 @@ pub fn start_daemon( Ok::<(), anyhow::Error>(()) })?; - Ok(()) } #[pg_extern(immutable, parallel_unsafe, security_definer)] fn add_embedding_job<'a>( - table: &'a str, + table_name: &'a str, src_column: &'a str, dst_column: &'a str, - embedding_model: &'a str, - batch_size: default!(i32, -1), - runtime: default!(&'a str, "'ort'"), - runtime_params: default!(&'a str, "'{}'"), + model: default!(&'a str, "'text-embedding-3-small'"), pk: default!(&'a str, "'id'"), schema: default!(&'a str, "'public'"), + base_url: default!(&'a str, "''"), + batch_size: default!(i32, -1), + dimensions: default!(i32, 1536), + api_token: default!(&'a str, "''"), + azure_entra_token: default!(&'a str, "''"), + runtime: default!(&'a str, "'openai'"), ) -> Result { - let mut params = runtime_params.to_owned(); - if params == "{}" { - match runtime { - "openai" => { - params = get_openai_runtime_params("", "", 1536)?; - } - "cohere" => { - params = get_cohere_runtime_params("search_document")?; - } - _ => {} + let params = match runtime { + "openai" => { + get_openai_runtime_params(api_token, azure_entra_token, base_url, "", dimensions)? 
} - } - - let batch_size = if batch_size == -1 { "NULL".to_string() } else { batch_size.to_string() }; + "cohere" => get_cohere_runtime_params(api_token, "search_document")?, + _ => "{}".to_owned(), + }; + + let batch_size = if batch_size == -1 { + "NULL".to_string() + } else { + batch_size.to_string() + }; let id: Option = Spi::get_one_with_args( &format!( r#" @@ -152,16 +153,16 @@ fn add_embedding_job<'a>( INSERT INTO _lantern_extras_internal.embedding_generation_jobs ("table", "schema", pk, src_column, dst_column, embedding_model, runtime, runtime_params, batch_size) VALUES ($1, $2, $3, $4, $5, $6, $7, $8::jsonb, {batch_size}) RETURNING id; "#, - table = get_full_table_name(schema, table), + table = get_full_table_name(schema, table_name), dst_column = quote_ident(dst_column) ), vec![ - (PgBuiltInOids::TEXTOID.oid(), table.into_datum()), + (PgBuiltInOids::TEXTOID.oid(), table_name.into_datum()), (PgBuiltInOids::TEXTOID.oid(), schema.into_datum()), (PgBuiltInOids::TEXTOID.oid(), pk.into_datum()), (PgBuiltInOids::TEXTOID.oid(), src_column.into_datum()), (PgBuiltInOids::TEXTOID.oid(), dst_column.into_datum()), - (PgBuiltInOids::TEXTOID.oid(), embedding_model.into_datum()), + (PgBuiltInOids::TEXTOID.oid(), model.into_datum()), (PgBuiltInOids::TEXTOID.oid(), runtime.into_datum()), (PgBuiltInOids::TEXTOID.oid(), params.into_datum()), ], @@ -172,29 +173,33 @@ fn add_embedding_job<'a>( #[pg_extern(immutable, parallel_unsafe, security_definer)] fn add_completion_job<'a>( - table: &'a str, + table_name: &'a str, src_column: &'a str, dst_column: &'a str, - context: default!(&'a str, "''"), + system_prompt: default!(&'a str, "''"), column_type: default!(&'a str, "'TEXT'"), - embedding_model: default!(&'a str, "'gpt-4o'"), - batch_size: default!(i32, -1), - runtime: default!(&'a str, "'openai'"), - runtime_params: default!(&'a str, "'{}'"), + model: default!(&'a str, "'gpt-4o'"), pk: default!(&'a str, "'id'"), schema: default!(&'a str, "'public'"), + base_url: default!(&'a 
str, "''"), + batch_size: default!(i32, -1), + api_token: default!(&'a str, "''"), + azure_entra_token: default!(&'a str, "''"), + runtime: default!(&'a str, "'openai'"), ) -> Result { - let mut params = runtime_params.to_owned(); - if params == "{}" { - match runtime { - "openai" => { - params = get_openai_runtime_params("", context, 0)?; - } - _ => anyhow::bail!("Runtime {runtime} does not support completion jobs"), + let params = match runtime { + "openai" => { + get_openai_runtime_params(api_token, azure_entra_token, base_url, system_prompt, 0)? } - } + _ => anyhow::bail!("Runtime {runtime} does not support completion jobs"), + }; + + let batch_size = if batch_size == -1 { + "NULL".to_string() + } else { + batch_size.to_string() + }; - let batch_size = if batch_size == -1 { "NULL".to_string() } else { batch_size.to_string() }; let id: Option = Spi::get_one_with_args( &format!( r#" @@ -202,16 +207,16 @@ fn add_completion_job<'a>( INSERT INTO _lantern_extras_internal.embedding_generation_jobs ("table", "schema", pk, src_column, dst_column, embedding_model, runtime, runtime_params, column_type, batch_size, job_type) VALUES ($1, $2, $3, $4, $5, $6, $7, $8::jsonb, $9, {batch_size}, 'completion') RETURNING id; "#, - table = get_full_table_name(schema, table), + table = get_full_table_name(schema, table_name), dst_column = quote_ident(dst_column) ), vec![ - (PgBuiltInOids::TEXTOID.oid(), table.into_datum()), + (PgBuiltInOids::TEXTOID.oid(), table_name.into_datum()), (PgBuiltInOids::TEXTOID.oid(), schema.into_datum()), (PgBuiltInOids::TEXTOID.oid(), pk.into_datum()), (PgBuiltInOids::TEXTOID.oid(), src_column.into_datum()), (PgBuiltInOids::TEXTOID.oid(), dst_column.into_datum()), - (PgBuiltInOids::TEXTOID.oid(), embedding_model.into_datum()), + (PgBuiltInOids::TEXTOID.oid(), model.into_datum()), (PgBuiltInOids::TEXTOID.oid(), runtime.into_datum()), (PgBuiltInOids::TEXTOID.oid(), params.into_datum()), (PgBuiltInOids::TEXTOID.oid(), column_type.into_datum()), @@ -360,7 
+365,7 @@ pub mod tests { None, None, )?; - let id = client.select("SELECT add_embedding_job('t1', 'title', 'title_embedding', 'BAAI/bge-small-en', -1, 'ort', '{}', 'id', 'public')", None, None)?; + let id = client.select("SELECT add_embedding_job(table_name => 't1', src_column => 'title', dst_column => 'title_embedding', model => 'BAAI/bge-small-en', runtime => 'ort')", None, None)?; let id: Option = id.first().get(1)?; @@ -385,7 +390,7 @@ pub mod tests { )?; let id = client.select( " - SELECT add_completion_job('t1', 'title', 'title_embedding', 'my test context','TEXT[]'); + SELECT add_completion_job(table_name => 't1', src_column => 'title', dst_column => 'title_embedding', system_prompt => 'my test prompt', column_type => 'TEXT[]'); ", None, None, @@ -393,27 +398,27 @@ pub mod tests { let id: Option = id.first().get(1)?; assert_eq!(id.is_none(), false); - + let row = client.select( - "SELECT column_type, job_type, runtime, embedding_model, (runtime_params->'context')::text as context, batch_size FROM _lantern_extras_internal.embedding_generation_jobs WHERE id=$1", + "SELECT column_type, job_type, runtime, embedding_model, (runtime_params->'system_prompt')::text as system_prompt, batch_size FROM _lantern_extras_internal.embedding_generation_jobs WHERE id=$1", None, Some(vec![(PgBuiltInOids::INT4OID.oid(), id.into_datum())]) )?; - + let row = row.first(); assert_eq!(row.get::<&str>(1)?.unwrap(), "TEXT[]"); assert_eq!(row.get::<&str>(2)?.unwrap(), "completion"); assert_eq!(row.get::<&str>(3)?.unwrap(), "openai"); assert_eq!(row.get::<&str>(4)?.unwrap(), "gpt-4o"); - assert_eq!(row.get::<&str>(5)?.unwrap(), "\"my test context\""); + assert_eq!(row.get::<&str>(5)?.unwrap(), "\"my test prompt\""); assert_eq!(row.get::(6)?.is_none(), true); Ok::<(), anyhow::Error>(()) }) .unwrap(); } - + #[pg_test] fn test_add_daemon_completion_job_batch_size() { Spi::connect(|mut client| { @@ -422,14 +427,13 @@ pub mod tests { client.update( " CREATE TABLE t1 (id serial primary key, 
title text); - SET lantern_extras.openai_token='test'; ", None, None, )?; let id = client.select( " - SELECT add_completion_job('t1', 'title', 'title_embedding', 'my test context','TEXT[]', 'gpt-4o', 15); + SELECT add_completion_job(api_token => 'test', table_name => 't1', src_column => 'title', dst_column => 'title_embedding', system_prompt => 'my test prompt', column_type => 'TEXT[]', batch_size => 15, model => 'gpt-4o'); ", None, None, @@ -437,21 +441,22 @@ pub mod tests { let id: Option = id.first().get(1)?; assert_eq!(id.is_none(), false); - + let row = client.select( - "SELECT column_type, job_type, runtime, embedding_model, (runtime_params->'context')::text as context, batch_size FROM _lantern_extras_internal.embedding_generation_jobs WHERE id=$1", + "SELECT column_type, job_type, runtime, embedding_model, (runtime_params->'system_prompt')::text as system_prompt, batch_size, (runtime_params->'api_token')::text as api_token FROM _lantern_extras_internal.embedding_generation_jobs WHERE id=$1", None, Some(vec![(PgBuiltInOids::INT4OID.oid(), id.into_datum())]) )?; - + let row = row.first(); assert_eq!(row.get::<&str>(1)?.unwrap(), "TEXT[]"); assert_eq!(row.get::<&str>(2)?.unwrap(), "completion"); assert_eq!(row.get::<&str>(3)?.unwrap(), "openai"); assert_eq!(row.get::<&str>(4)?.unwrap(), "gpt-4o"); - assert_eq!(row.get::<&str>(5)?.unwrap(), "\"my test context\""); + assert_eq!(row.get::<&str>(5)?.unwrap(), "\"my test prompt\""); assert_eq!(row.get::(6)?.unwrap(), 15); + assert_eq!(row.get::<&str>(7)?.unwrap(), "\"test\""); Ok::<(), anyhow::Error>(()) }) @@ -473,7 +478,7 @@ pub mod tests { )?; let id = client.select( " - SELECT add_completion_job('t1', 'title', 'title_embedding', 'my test context','TEXT[]'); + SELECT add_completion_job(table_name => 't1', src_column => 'title', dst_column => 'title_embedding', system_prompt => 'my test prompt', column_type => 'TEXT[]'); ", None, None, @@ -481,13 +486,13 @@ pub mod tests { let id: Option = id.first().get(1)?; 
assert_eq!(id.is_none(), false); - + let row = client.select( "SELECT id, status, progress, error FROM get_completion_jobs() WHERE id=$1", None, Some(vec![(PgBuiltInOids::INT4OID.oid(), id.into_datum())]) )?; - + let row = row.first(); assert_eq!(row.get::(1)?.unwrap(), id.unwrap()); @@ -496,7 +501,7 @@ pub mod tests { }) .unwrap(); } - + #[pg_test] fn test_get_completion_job_failures() { Spi::connect(|mut client| { @@ -504,7 +509,7 @@ pub mod tests { std::thread::sleep(Duration::from_secs(5)); client.update( " - INSERT INTO _lantern_extras_internal.embedding_failure_info (job_id, row_id, value) VALUES + INSERT INTO _lantern_extras_internal.embedding_failure_info (job_id, row_id, value) VALUES (1, 1, '1test1'), (1, 2, '1test2'), (2, 1, '2test1'); @@ -512,33 +517,33 @@ pub mod tests { None, None, )?; - + let mut rows = client.select( "SELECT row_id, value FROM get_completion_job_failures($1)", None, Some(vec![(PgBuiltInOids::INT4OID.oid(), 1.into_datum())]) )?; - + assert_eq!(rows.len(), 2); let row = rows.next().unwrap(); assert_eq!(row.get::(1)?.unwrap(), 1); assert_eq!(row.get::<&str>(2)?.unwrap(), "1test1"); - + let row = rows.next().unwrap(); assert_eq!(row.get::(1)?.unwrap(), 2); assert_eq!(row.get::<&str>(2)?.unwrap(), "1test2"); - + let mut rows = client.select( "SELECT row_id, value FROM get_completion_job_failures($1)", None, Some(vec![(PgBuiltInOids::INT4OID.oid(), 2.into_datum())]) )?; - + assert_eq!(rows.len(), 1); - + let row = rows.next().unwrap(); assert_eq!(row.get::(1)?.unwrap(), 1); @@ -557,13 +562,12 @@ pub mod tests { client.update( " CREATE TABLE t1 (id serial primary key, title text); - SET lantern_extras.openai_token='test_openai'; - SET lantern_extras.cohere_token='test_cohere'; + SET lantern_extras.llm_token='test_llm_token'; ", None, None, )?; - let id = client.select("SELECT add_embedding_job('t1', 'title', 'title_embedding', 'BAAI/bge-small-en', -1, 'openai', '{}', 'id', 'public')", None, None)?; + let id = client.update("SELECT 
add_embedding_job(table_name => 't1', src_column => 'title', dst_column => 'title_embedding', runtime => 'openai')", None, None)?; let id: Option = id.first().get(1)?; @@ -572,9 +576,9 @@ pub mod tests { let rows = client.select("SELECT (runtime_params->'api_token')::text as token FROM _lantern_extras_internal.embedding_generation_jobs WHERE id=$1", None, Some(vec![(PgBuiltInOids::INT4OID.oid(), id.into_datum())]))?; let api_token: Option = rows.first().get(1)?; - assert_eq!(api_token.unwrap(), "\"test_openai\"".to_owned()); + assert_eq!(api_token.unwrap(), "\"test_llm_token\"".to_owned()); - let id = client.select("SELECT add_embedding_job('t1', 'title', 'title_embedding', 'BAAI/bge-small-en', -1, 'cohere', '{}', 'id', 'public')", None, None)?; + let id = client.select("SELECT add_embedding_job(table_name => 't1', src_column => 'title', dst_column => 'title_embedding', runtime => 'cohere')", None, None)?; let id: Option = id.first().get(1)?; @@ -583,7 +587,7 @@ pub mod tests { let rows = client.select("SELECT (runtime_params->'api_token')::text as token FROM _lantern_extras_internal.embedding_generation_jobs WHERE id=$1", None, Some(vec![(PgBuiltInOids::INT4OID.oid(), id.into_datum())]))?; let api_token: Option = rows.first().get(1)?; - assert_eq!(api_token.unwrap(), "\"test_cohere\"".to_owned()); + assert_eq!(api_token.unwrap(), "\"test_llm_token\"".to_owned()); Ok::<(), anyhow::Error>(()) }) .unwrap(); @@ -602,7 +606,7 @@ pub mod tests { None, )?; - let id = client.update("SELECT add_embedding_job('t1', 'title', 'title_embedding', 'BAAI/bge-small-en', -1, 'ort', '{}', 'id', 'public')", None, None)?; + let id = client.update("SELECT add_embedding_job(table_name => 't1', src_column => 'title', dst_column => 'title_embedding', model => 'BAAI/bge-small-en', runtime => 'ort')", None, None)?; let id: i32 = id.first().get(1)?.unwrap(); // queued @@ -688,8 +692,8 @@ pub mod tests { None, )?; - client.update("SELECT add_embedding_job('t1', 'title', 'title_embedding', 
'BAAI/bge-small-en', -1, 'ort', '{}', 'id', 'public')", None, None)?; - client.update("SELECT add_embedding_job('t1', 'title', 'title_embedding2', 'BAAI/bge-small-en', -1, 'ort', '{}', 'id', 'public')", None, None)?; + client.update("SELECT add_embedding_job(table_name => 't1', src_column => 'title', dst_column => 'title_embedding', model => 'BAAI/bge-small-en', runtime => 'ort')", None, None)?; + client.update("SELECT add_embedding_job(table_name => 't1', src_column => 'title', dst_column => 'title_embedding2', model => 'BAAI/bge-small-en', runtime => 'ort')", None, None)?; // queued let rows = client.select("SELECT status, progress, error FROM get_embedding_jobs()", None, None)?; @@ -720,7 +724,7 @@ pub mod tests { None, None, )?; - let id = client.update("SELECT add_embedding_job('t1', 'title', 'title_embedding', 'BAAI/bge-small-en', -1, 'ort', '{}', 'id', 'public')", None, None)?; + let id = client.update("SELECT add_embedding_job(table_name => 't1', src_column => 'title', dst_column => 'title_embedding', model => 'BAAI/bge-small-en', runtime => 'ort')", None, None)?; let id: i32 = id.first().get(1)?.unwrap(); client.update("SELECT cancel_embedding_job($1)", None, Some(vec![(PgBuiltInOids::INT4OID.oid(), id.into_datum())]))?; let rows = client.select("SELECT status, progress, error FROM get_embedding_job_status($1)", None, Some(vec![(PgBuiltInOids::INT4OID.oid(), id.into_datum())]))?; @@ -750,7 +754,7 @@ pub mod tests { None, None, )?; - let id = client.update("SELECT add_embedding_job('t1', 'title', 'title_embedding', 'BAAI/bge-small-en', -1, 'ort', '{}', 'id', 'public')", None, None)?; + let id = client.update("SELECT add_embedding_job(table_name => 't1', src_column => 'title', dst_column => 'title_embedding', model => 'BAAI/bge-small-en', runtime => 'ort')", None, None)?; let id: i32 = id.first().get(1)?.unwrap(); client.update("SELECT cancel_embedding_job($1)", None, Some(vec![(PgBuiltInOids::INT4OID.oid(), id.into_datum())]))?; client.update("SELECT 
resume_embedding_job($1)", None, Some(vec![(PgBuiltInOids::INT4OID.oid(), id.into_datum())]))?; diff --git a/lantern_extras/src/embeddings.rs b/lantern_extras/src/embeddings.rs index 0bea522e..102fa668 100644 --- a/lantern_extras/src/embeddings.rs +++ b/lantern_extras/src/embeddings.rs @@ -7,10 +7,7 @@ use lantern_cli::embeddings::{ }; use pgrx::prelude::*; -use crate::{ - COHERE_TOKEN, COMPLETION_CONTEXT, LLM_DEPLOYMENT_URL, LLM_TOKEN, OPENAI_AZURE_API_TOKEN, - OPENAI_AZURE_ENTRA_TOKEN, OPENAI_TOKEN, -}; +use crate::{LLM_DEPLOYMENT_URL, LLM_TOKEN, OPENAI_AZURE_ENTRA_TOKEN, OPENAI_TOKEN}; pub static ORT_RUNTIME_PARAMS: &'static str = r#"{ "cache": true }"#; @@ -26,16 +23,18 @@ fn get_dummy_runtime_params(runtime: &Runtime) -> String { } pub fn get_openai_runtime_params( + api_token: &str, + azure_entra_token: &str, base_url: &str, - context: &str, + system_prompt: &str, dimensions: i32, ) -> Result { - if OPENAI_TOKEN.get().is_none() + if api_token == "" + && OPENAI_TOKEN.get().is_none() && LLM_TOKEN.get().is_none() - && OPENAI_AZURE_API_TOKEN.get().is_none() && OPENAI_AZURE_ENTRA_TOKEN.get().is_none() { - error!("'lantern_extras.openai_token/lantern_extras.llm_token', 'lantern_extras.openai_azure_api_token' or 'lantern_extras.openai_azure_entra_token' is required for 'openai' runtime"); + error!("'lantern_extras.llm_token' or 'lantern_extras.openai_azure_entra_token' is required for 'openai' runtime"); } let dimensions = if dimensions > 0 { @@ -54,63 +53,69 @@ pub fn get_openai_runtime_params( Some(base_url.to_owned()) }; - let mut api_token = if let Some(api_token) = OPENAI_TOKEN.get() { - Some(api_token.to_str().unwrap().to_owned()) + let mut api_token = if api_token != "" { + Some(api_token.to_owned()) } else { None }; - if api_token.is_none() && LLM_TOKEN.get().is_some() { - api_token = Some(LLM_TOKEN.get().unwrap().to_str().unwrap().to_owned()); - } + if api_token.is_none() { + api_token = if let Some(api_token) = OPENAI_TOKEN.get() { + 
Some(api_token.to_str().unwrap().to_owned()) + } else { + None + }; - let azure_api_token = if let Some(api_token) = OPENAI_AZURE_API_TOKEN.get() { - Some(api_token.to_str().unwrap().to_owned()) - } else { - None - }; + if api_token.is_none() && LLM_TOKEN.get().is_some() { + api_token = Some(LLM_TOKEN.get().unwrap().to_str().unwrap().to_owned()); + } + } - let azure_entra_token = if let Some(api_token) = OPENAI_AZURE_ENTRA_TOKEN.get() { - Some(api_token.to_str().unwrap().to_owned()) + let mut azure_entra_token = if azure_entra_token != "" { + Some(azure_entra_token.to_owned()) } else { None }; - let context = if context == "" { - if let Some(guc_context) = COMPLETION_CONTEXT.get() { - Some(guc_context.to_str().unwrap().to_owned()) + if azure_entra_token.is_none() { + azure_entra_token = if let Some(api_token) = OPENAI_AZURE_ENTRA_TOKEN.get() { + Some(api_token.to_str().unwrap().to_owned()) } else { None - } - } else { - Some(context.to_owned()) - }; + }; + } let params = serde_json::to_string(&OpenAiRuntimeParams { dimensions, base_url, api_token, - azure_api_token, azure_entra_token, - context, + system_prompt: Some(system_prompt.to_owned()), })?; Ok(params) } -pub fn get_cohere_runtime_params(input_type: &str) -> Result { - if COHERE_TOKEN.get().is_none() && LLM_TOKEN.get().is_none() { - error!("'lantern_extras.cohere_token/lantern_extras.llm_token' is required for 'cohere' runtime"); +pub fn get_cohere_runtime_params( + api_token: &str, + input_type: &str, +) -> Result { + if api_token == "" && LLM_TOKEN.get().is_none() { + error!("'lantern_extras.llm_token' is required for 'cohere' runtime"); } - let mut api_token = if let Some(api_token) = COHERE_TOKEN.get() { - Some(api_token.to_str().unwrap().to_owned()) + let mut api_token = if api_token != "" { + Some(api_token.to_owned()) } else { None }; - if api_token.is_none() && LLM_TOKEN.get().is_some() { - api_token = Some(LLM_TOKEN.get().unwrap().to_str().unwrap().to_owned()); + if api_token.is_none() { + api_token 
= if let Some(api_token) = LLM_TOKEN.get() { + Some(api_token.to_str().unwrap().to_owned()) + } else { + None + }; } let runtime_params = serde_json::to_string(&CohereRuntimeParams { @@ -121,83 +126,76 @@ pub fn get_cohere_runtime_params(input_type: &str) -> Result(model_name: &'a str, text: &'a str) -> Result, anyhow::Error> { - let runtime = EmbeddingRuntime::new( - &Runtime::Ort, - Some(&(notice_fn as LoggerFn)), - &ORT_RUNTIME_PARAMS, - )?; - let rt = tokio::runtime::Builder::new_current_thread() - .enable_all() - .build()?; - let mut res = rt.block_on(runtime.process(model_name, &vec![text]))?; - Ok(res.embeddings.pop().unwrap()) -} - #[pg_extern(immutable, parallel_safe, create_or_replace)] -fn openai_embedding<'a>( - model_name: &'a str, - text: &'a str, +fn llm_completion<'a>( + user_prompt: &'a str, + model: default!(&'a str, "'gpt-4o'"), + system_prompt: default!(&'a str, "''"), base_url: default!(&'a str, "''"), - dimensions: default!(i32, 1536), -) -> Result, anyhow::Error> { - let runtime_params = get_openai_runtime_params(base_url, "", dimensions)?; - let runtime = EmbeddingRuntime::new( - &Runtime::OpenAi, - Some(&(notice_fn as LoggerFn)), - &runtime_params, - )?; + api_token: default!(&'a str, "''"), + azure_entra_token: default!(&'a str, "''"), + runtime: default!(&'a str, "'openai'"), +) -> Result { + let runtime_params = + get_openai_runtime_params(api_token, azure_entra_token, base_url, system_prompt, 0)?; + + let runtime = Runtime::try_from(runtime)?; + let embedding_runtime = + EmbeddingRuntime::new(&runtime, Some(&(notice_fn as LoggerFn)), &runtime_params)?; let rt = tokio::runtime::Builder::new_current_thread() .enable_all() .build()?; - let mut res = rt.block_on(runtime.process( - &get_clean_model_name(model_name, Runtime::OpenAi), - &vec![text], - ))?; - Ok(res.embeddings.pop().unwrap()) + let res = rt.block_on(embedding_runtime.completion(model, user_prompt))?; + Ok(res.message) } #[pg_extern(immutable, parallel_safe, create_or_replace)] 
-fn cohere_embedding<'a>( - model_name: &'a str, - text: &'a str, +fn llm_embedding<'a>( + input: &'a str, + model: default!(&'a str, "'text-embedding-3-small'"), + base_url: default!(&'a str, "''"), + api_token: default!(&'a str, "''"), + azure_entra_token: default!(&'a str, "''"), + dimensions: default!(i32, 1536), input_type: default!(&'a str, "'search_query'"), + runtime: default!(&'a str, "'openai'"), ) -> Result, anyhow::Error> { - let runtime_params = get_cohere_runtime_params(input_type)?; - let runtime = EmbeddingRuntime::new( - &Runtime::Cohere, - Some(&(notice_fn as LoggerFn)), - &runtime_params, - )?; + let runtime = Runtime::try_from(runtime)?; + let runtime_params = match runtime { + Runtime::Ort => ORT_RUNTIME_PARAMS.to_owned(), + Runtime::OpenAi => { + get_openai_runtime_params(api_token, azure_entra_token, base_url, "", dimensions)? + } + Runtime::Cohere => get_cohere_runtime_params(api_token, input_type)?, + }; + + let embedding_runtime = + EmbeddingRuntime::new(&runtime, Some(&(notice_fn as LoggerFn)), &runtime_params)?; + let rt = tokio::runtime::Builder::new_current_thread() .enable_all() .build()?; - let mut res = rt.block_on(runtime.process( - &get_clean_model_name(model_name, Runtime::Cohere), - &vec![text], - ))?; + + let mut res = rt + .block_on(embedding_runtime.process(&get_clean_model_name(model, runtime), &vec![input]))?; Ok(res.embeddings.pop().unwrap()) } #[pg_extern(immutable, parallel_safe)] -fn clip_text<'a>(text: &'a str) -> Result, anyhow::Error> { - text_embedding("clip/ViT-B-32-textual", text) +fn text_embedding<'a>(model_name: &'a str, text: &'a str) -> Result, anyhow::Error> { + return llm_embedding(text, model_name, "", "", "", 0, "", "ort"); } -#[pg_extern(immutable, parallel_safe)] -fn image_embedding<'a>( +#[pg_extern(immutable, parallel_safe, create_or_replace)] +fn openai_embedding<'a>( model_name: &'a str, - path_or_url: &'a str, + text: &'a str, + base_url: default!(&'a str, "''"), + dimensions: default!(i32, 1536), 
) -> Result, anyhow::Error> { - text_embedding(model_name, path_or_url) -} - -#[pg_extern(immutable, parallel_safe)] -fn clip_image<'a>(path_or_url: &'a str) -> Result, anyhow::Error> { - image_embedding("clip/ViT-B-32-visual", path_or_url) + return llm_embedding(text, model_name, base_url, "", "", dimensions, "", "openai"); } #[pg_extern(immutable, parallel_safe, create_or_replace)] @@ -226,28 +224,6 @@ fn get_available_runtimes() -> Result { return Ok(runtimes_str); } -#[pg_extern(immutable, parallel_safe, create_or_replace)] -fn llm_completion<'a>( - text: &'a str, - model_name: default!(&'a str, "'gpt-4o'"), - context: default!(&'a str, "''"), - base_url: default!(&'a str, "''"), - runtime: default!(&'a str, "'openai'"), -) -> Result { - let runtime_params = get_openai_runtime_params(base_url, context, 0)?; - - let runtime = Runtime::try_from(runtime)?; - let embedding_runtime = - EmbeddingRuntime::new(&runtime, Some(&(notice_fn as LoggerFn)), &runtime_params)?; - - let rt = tokio::runtime::Builder::new_current_thread() - .enable_all() - .build()?; - - let res = rt.block_on(embedding_runtime.completion(model_name, text))?; - Ok(res.message) -} - #[cfg(any(test, feature = "pg_test"))] #[pg_schema] pub mod tests { @@ -280,12 +256,22 @@ pub mod tests { } #[pg_test] - fn test_clip_text() { - let embedding = - Spi::get_one::>(&format!("SELECT clip_text('{HELLO_WORLD_TEXT}');")).unwrap(); - let distance = 1.0 - cosine_similarity(&embedding.unwrap(), HELLO_WORLD_CLIP_EMB); - assert!(distance < 0.01); + fn test_ort_text_embedding() { + let embedding_old_syntax = Spi::get_one::>(&format!( + "SELECT text_embedding('clip/ViT-B-32-textual', '{HELLO_WORLD_TEXT}');" + )) + .unwrap(); + let embedding_new_syntax = + Spi::get_one::>(&format!("SELECT llm_embedding(model => 'clip/ViT-B-32-textual', input => '{HELLO_WORLD_TEXT}', runtime => 'ort');")) + .unwrap(); + let distance1 = + 1.0 - cosine_similarity(&embedding_old_syntax.unwrap(), HELLO_WORLD_CLIP_EMB); + let distance2 = + 
1.0 - cosine_similarity(&embedding_new_syntax.unwrap(), HELLO_WORLD_CLIP_EMB); + assert!(distance1 < 0.01); + assert!(distance2 < 0.01); } + #[pg_test] fn test_cohere_embeddings() { static HELLO_WORLD_TEXT: &'static str = "Hello world!"; @@ -307,7 +293,7 @@ pub mod tests { Spi::connect(|mut client| { client.update( - &format!("SET lantern_extras.cohere_token='{cohere_token}'"), + &format!("SET lantern_extras.llm_token='{cohere_token}'"), None, None, )?; @@ -315,7 +301,7 @@ pub mod tests { .select( &format!( " - SELECT cohere_embedding('cohere/embed-multilingual-light-v3.0', '{HELLO_WORLD_TEXT}') as embedding + SELECT llm_embedding(model => 'cohere/embed-multilingual-light-v3.0', input => '{HELLO_WORLD_TEXT}', runtime => 'cohere') as embedding " ), None, @@ -331,6 +317,7 @@ pub mod tests { }) .unwrap(); } + #[pg_test(volatile, create_or_replace)] fn test_openai_embeddings() { static HELLO_WORLD_TEXT: &'static str = "Hello world!"; @@ -389,7 +376,7 @@ pub mod tests { .select( &format!( " - SELECT openai_embedding('openai/text-embedding-3-large','{HELLO_WORLD_TEXT}', '', 768) as embedding + SELECT llm_embedding(model => 'openai/text-embedding-3-large',input => '{HELLO_WORLD_TEXT}', dimensions => 768, api_token => '{openai_token}') as embedding " ), None, @@ -424,16 +411,11 @@ pub mod tests { None, None, )?; - client.update( - &format!("SET lantern_extras.completion_context='return x property from provided json object without any additional text and without quotes'"), - None, - None, - )?; let row = client .select( &format!( " - SELECT llm_completion('{{\"x\": \"test1\"}}') as response + SELECT llm_completion(user_prompt => '{{\"x\": \"test1\"}}', system_prompt => 'return x property from provided json object without any additional text and without quotes') as response " ), None, diff --git a/lantern_extras/src/lib.rs b/lantern_extras/src/lib.rs index e613f46a..44815da2 100644 --- a/lantern_extras/src/lib.rs +++ b/lantern_extras/src/lib.rs @@ -8,21 +8,17 @@ pub mod 
dotvecs; pub mod embeddings; pub mod external_index; +// this will be deprecated and removed on upcoming releases pub static OPENAI_TOKEN: GucSetting> = GucSetting::>::new(None); + pub static LLM_TOKEN: GucSetting> = GucSetting::>::new(None); pub static LLM_DEPLOYMENT_URL: GucSetting> = GucSetting::>::new(None); -pub static OPENAI_AZURE_API_TOKEN: GucSetting> = - GucSetting::>::new(None); pub static OPENAI_AZURE_ENTRA_TOKEN: GucSetting> = GucSetting::>::new(None); -pub static COHERE_TOKEN: GucSetting> = - GucSetting::>::new(None); pub static ENABLE_DAEMON: GucSetting = GucSetting::::new(false); -pub static COMPLETION_CONTEXT: GucSetting> = - GucSetting::>::new(None); pub static DAEMON_DATABASES: GucSetting> = GucSetting::>::new(None); @@ -61,27 +57,11 @@ pub unsafe extern "C" fn _PG_init() { GucContext::Userset, GucFlags::NO_SHOW_ALL, ); - GucRegistry::define_string_guc( - "lantern_extras.openai_azure_api_token", - "Azure API token.", - "Used when generating embeddings with Azure OpenAI models", - &OPENAI_AZURE_API_TOKEN, - GucContext::Userset, - GucFlags::NO_SHOW_ALL, - ); GucRegistry::define_string_guc( "lantern_extras.openai_azure_entra_token", "Azure Entra token.", "Used when generating embeddings with Azure OpenAI models", - &OPENAI_AZURE_API_TOKEN, - GucContext::Userset, - GucFlags::NO_SHOW_ALL, - ); - GucRegistry::define_string_guc( - "lantern_extras.cohere_token", - "Cohere API token.", - "Used when generating embeddings with Cohere models", - &COHERE_TOKEN, + &OPENAI_AZURE_ENTRA_TOKEN, GucContext::Userset, GucFlags::NO_SHOW_ALL, ); @@ -101,14 +81,6 @@ pub unsafe extern "C" fn _PG_init() { GucContext::Sighup, GucFlags::NO_SHOW_ALL, ); - GucRegistry::define_string_guc( - "lantern_extras.completion_context", - "Context to pass on LLM completion calls.", - "Used when calling completion method on LLMs", - &COMPLETION_CONTEXT, - GucContext::Userset, - GucFlags::NO_SHOW_ALL, - ); } #[pg_guard]