Test 1 #2

Open · wants to merge 10 commits into main
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
@@ -130,7 +130,7 @@ SELECT * FROM products limit 2;
```sql
SELECT vectorize.table(
job_name => 'product_search_hf',
"table" => 'products',
"table_name" => 'products',
primary_key => 'product_id',
columns => ARRAY['product_name', 'description'],
transformer => 'sentence-transformers/multi-qa-MiniLM-L6-dot-v1'
2 changes: 1 addition & 1 deletion README.md
@@ -131,7 +131,7 @@ Create a job to vectorize the products table. We'll specify the tables primary k
```sql
SELECT vectorize.table(
job_name => 'product_search_hf',
"table" => 'products',
"table_name" => 'products',
primary_key => 'product_id',
columns => ARRAY['product_name', 'description'],
transformer => 'sentence-transformers/all-MiniLM-L6-v2',
1 change: 1 addition & 0 deletions core/Cargo.toml
@@ -21,6 +21,7 @@ lazy_static = "1.4.0"
log = "0.4.21"
ollama-rs = "=0.2.1"
pgmq = "0.29"
+ pgrx = "=0.12.5"
regex = "1.9.2"
reqwest = {version = "0.11.18", features = ["json"] }
serde = { version = "1.0.173", features = ["derive"] }
4 changes: 2 additions & 2 deletions core/src/types.rs
@@ -1,5 +1,6 @@
use chrono::serde::ts_seconds_option::deserialize as from_tsopt;

+ use pgrx::pg_sys::Oid;
use serde::{Deserialize, Serialize};
use sqlx::types::chrono::Utc;
use sqlx::FromRow;
@@ -103,8 +104,7 @@ pub enum TableMethod {

#[derive(Clone, Debug, Default, Serialize, Deserialize, FromRow)]
pub struct JobParams {
- pub schema: String,
- pub table: String,
+ pub table: PgOid,
pub columns: Vec<String>,
pub update_time_col: Option<String>,
pub table_method: TableMethod,
6 changes: 2 additions & 4 deletions docs/api/search.md
@@ -8,11 +8,10 @@ Initialize a table for vector search. Generates embeddings and index. Creates tr

```sql
vectorize."table"(
"table" TEXT,
"table_name" REGCLASS,
"columns" TEXT[],
"job_name" TEXT,
"primary_key" TEXT,
"schema" TEXT DEFAULT 'public',
"update_col" TEXT DEFAULT 'last_updated_at',
"transformer" TEXT DEFAULT 'sentence-transformers/all-MiniLM-L6-v2',
"index_dist_type" vectorize.IndexDist DEFAULT 'pgv_hnsw_cosine',
@@ -23,12 +22,11 @@

| Parameter | Type | Description |
| :--- | :---- | :--- |
- | table | text | The name of the table to be initialized. |
+ | table_name | regclass | The name of the table to be initialized. Automatically includes schema information. |
| columns | text | The name of the columns that contains the content that is used for context for RAG. Multiple columns are concatenated. |
| job_name | text | A unique name for the project. |
| primary_key | text | The name of the column that contains the unique record id. |
| args | json | Additional arguments for the transformer. Defaults to '{}'. |
- | schema | text | The name of the schema where the table is located. Defaults to 'public'. |
| update_col | text | Column specifying the last time the record was updated. Required for cron-like schedule. Defaults to `last_updated_at` |
| transformer | text | The name of the transformer to use for the embeddings. Defaults to 'text-embedding-ada-002'. |
| index_dist_type | IndexDist | The name of index type to build. Defaults to 'pgv_hnsw_cosine'. |
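Since `table_name` is a `regclass`, a schema-qualified name resolves automatically. A minimal sketch of the updated call, assuming the `products` table from the README lives in `public`:

```sql
SELECT vectorize.table(
    job_name => 'product_search_hf',
    table_name => 'public.products',
    primary_key => 'product_id',
    columns => ARRAY['product_name', 'description'],
    transformer => 'sentence-transformers/all-MiniLM-L6-v2'
);
```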
1 change: 1 addition & 0 deletions extension/Cargo.toml
@@ -38,6 +38,7 @@ sqlx = { version = "0.8", features = [
"chrono",
"json"
] }
+ text-splitter = "0.22.0"
thiserror = "1.0.44"
tiktoken-rs = "0.5.7"
tokio = {version = "1.29.1", features = ["rt-multi-thread"] }
27 changes: 27 additions & 0 deletions extension/sql/meta.sql
@@ -20,6 +20,33 @@ GRANT SELECT ON ALL SEQUENCES IN SCHEMA vectorize TO pg_monitor;
ALTER DEFAULT PRIVILEGES IN SCHEMA vectorize GRANT SELECT ON TABLES TO pg_monitor;
ALTER DEFAULT PRIVILEGES IN SCHEMA vectorize GRANT SELECT ON SEQUENCES TO pg_monitor;

+ CREATE OR REPLACE FUNCTION handle_table_drop()
+ RETURNS event_trigger AS $$
+ DECLARE
+     obj RECORD;
+     schema_name TEXT;
+     table_name TEXT;
+ BEGIN
+     FOR obj IN SELECT * FROM pg_event_trigger_dropped_objects() LOOP
+         IF obj.object_type = 'table' THEN
+             schema_name := split_part(obj.object_identity, '.', 1);
+             table_name := split_part(obj.object_identity, '.', 2);
+ 
+             -- Perform cleanup: delete the associated job from the vectorize.job table
+             DELETE FROM vectorize.job
+             WHERE params ->> 'table' = table_name
+             AND params ->> 'schema' = schema_name;
+         END IF;
+     END LOOP;
+ END;
+ $$ LANGUAGE plpgsql;
+ 
+ DROP EVENT TRIGGER IF EXISTS vectorize_job_drop_trigger;
+ 
+ CREATE EVENT TRIGGER vectorize_job_drop_trigger
+ ON sql_drop
+ WHEN TAG IN ('DROP TABLE')
+ EXECUTE FUNCTION handle_table_drop();

INSERT INTO vectorize.prompts (prompt_type, sys_prompt, user_prompt)
VALUES (
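A quick sanity check for the new event trigger, using a hypothetical table that had a vectorize job attached:

```sql
DROP TABLE public.my_products;

-- The trigger's cleanup should have removed the matching job row:
SELECT count(*) FROM vectorize.job
WHERE params ->> 'table' = 'my_products'
  AND params ->> 'schema' = 'public';
-- expected: 0
```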
18 changes: 18 additions & 0 deletions extension/sql/vectorize--0.18.2--0.18.3.sql
@@ -0,0 +1,18 @@
+ DROP function vectorize."table";
+ 
+ -- vectorize::api::table
+ CREATE FUNCTION vectorize."table"(
+     "table_name" REGCLASS, /* PgOid */
+     "columns" TEXT[], /* alloc::vec::Vec<alloc::string::String> */
+     "job_name" TEXT, /* alloc::string::String */
+     "primary_key" TEXT, /* alloc::string::String */
+     "args" json DEFAULT '{}', /* pgrx::datum::json::Json */
+     "update_col" TEXT DEFAULT 'last_updated_at', /* alloc::string::String */
+     "transformer" vectorize.Transformer DEFAULT 'openai', /* vectorize::types::Transformer */
+     "search_alg" vectorize.SimilarityAlg DEFAULT 'pgv_cosine_similarity', /* vectorize::types::SimilarityAlg */
+     "table_method" vectorize.TableMethod DEFAULT 'append', /* vectorize::types::TableMethod */
+     "schedule" TEXT DEFAULT '* * * * *' /* alloc::string::String */
+ ) RETURNS TEXT /* core::result::Result<alloc::string::String, anyhow::Error> */
+ STRICT
+ LANGUAGE c /* Rust */
+ AS 'MODULE_PATHNAME', 'table_wrapper';
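Postgres applies this migration script when the extension is updated across these versions, e.g.:

```sql
ALTER EXTENSION vectorize UPDATE TO '0.18.3';
```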
104 changes: 99 additions & 5 deletions extension/src/api.rs
@@ -5,31 +5,109 @@ use crate::search::{self, init_table};
use crate::transformers::generic::env_interpolate_string;
use crate::transformers::transform;
use crate::types;
+ use text_splitter::TextSplitter;

use anyhow::Result;
use pgrx::prelude::*;
use vectorize_core::types::Model;

+ // Fixed-width chunker used by chunk_table; named distinctly from the
+ // #[pg_extern] chunk_text defined below. Splits on char boundaries so
+ // multi-byte UTF-8 input cannot panic a byte-index slice.
+ fn chunk_text_fixed(text: &str, max_chunk_size: usize) -> Vec<String> {
+     // Guard against a zero chunk size, which would otherwise never advance
+     let size = max_chunk_size.max(1);
+     text.chars()
+         .collect::<Vec<char>>()
+         .chunks(size)
+         .map(|chunk| chunk.iter().collect())
+         .collect()
+ }

+ #[pg_extern]
+ fn chunk_table(
+     input_table: &str,
+     column_name: &str,
+     primary_key: &str,
+     max_chunk_size: default!(i32, 1000),
+     output_table: default!(&str, "'chunked_data'"),
+ ) -> Result<String> {
+     let max_chunk_size = max_chunk_size as usize;
+ 
+     // Retrieve the caller-supplied primary key and text column from the input table
+     let query = format!("SELECT {}, {} FROM {}", primary_key, column_name, input_table);
+ 
+     // NOTE: Spi::get_two returns values from only the first row of the
+     // result set, so only that single row is chunked below.
+     let (id_opt, text_opt): (Option<i32>, Option<String>) = Spi::get_two(&query)?;
+     let rows = vec![(id_opt, text_opt)];
+ 
+     // Holds (original_id, chunk_index, chunk)
+     let mut chunked_rows: Vec<(i32, i32, String)> = Vec::new();
+ 
+     // Chunk the data, tracking the original id and the chunk index
+     for (id_opt, text_opt) in rows {
+         // Only process rows where both the id and the text exist
+         if let (Some(id), Some(text)) = (id_opt, text_opt) {
+             let chunks = chunk_text_fixed(&text, max_chunk_size);
+             for (index, chunk) in chunks.iter().enumerate() {
+                 chunked_rows.push((id, index as i32, chunk.clone()));
+             }
+         }
+     }
+ 
+     // Create the output table with an additional column for the chunk index
+     let create_table_query = format!(
+         "CREATE TABLE IF NOT EXISTS {} (id SERIAL PRIMARY KEY, original_id INT, chunk_index INT, chunk TEXT)",
+         output_table
+     );
+     Spi::run(&create_table_query)
+         .map_err(|e| anyhow::anyhow!("Failed to create table {}: {}", output_table, e))?;
+ 
+     // Insert the chunked rows into the output table
+     for (original_id, chunk_index, chunk) in chunked_rows {
+         let insert_query = format!(
+             "INSERT INTO {} (original_id, chunk_index, chunk) VALUES ($1, $2, $3)",
+             output_table
+         );
+         Spi::run_with_args(&insert_query, Some(vec![
+             (pgrx::PgOid::Custom(pgrx::pg_sys::INT4OID), original_id.into_datum()),
+             (pgrx::PgOid::Custom(pgrx::pg_sys::INT4OID), chunk_index.into_datum()),
+             (pgrx::PgOid::Custom(pgrx::pg_sys::TEXTOID), chunk.into_datum()),
+         ]))?;
+     }
+ 
+     Ok(format!("Chunked data inserted into table: {}", output_table))
+ }

#[allow(clippy::too_many_arguments)]
#[pg_extern]
fn table(
-    table: &str,
+    table_name: PgOid,
columns: Vec<String>,
job_name: &str,
primary_key: &str,
-    schema: default!(&str, "'public'"),
update_col: default!(String, "'last_updated_at'"),
index_dist_type: default!(types::IndexDist, "'pgv_hnsw_cosine'"),
transformer: default!(&str, "'sentence-transformers/all-MiniLM-L6-v2'"),
table_method: default!(types::TableMethod, "'join'"),
// cron-like for a cron based update model, or 'realtime' for a trigger-based
schedule: default!(&str, "'* * * * *'"),
) -> Result<String> {

let model = Model::new(transformer)?;

init_table(
job_name,
-        schema,
-        table,
+        table_name,
columns,
primary_key,
Some(update_col),
@@ -100,7 +178,6 @@ fn init_rag(
let transformer_model = Model::new(transformer)?;
init_table(
agent_name,
-        schema,
table_name,
columns,
unique_record_id,
@@ -167,3 +244,20 @@ fn env_interpolate_guc(guc_name: &str) -> Result<String> {
.unwrap_or_else(|| panic!("no value set for guc: {guc_name}"));
env_interpolate_string(&g)
}

+ /// Splits a document into smaller chunks of text based on a maximum number of characters
+ ///
+ /// # Example
+ ///
+ /// ```sql
+ /// SELECT vectorize.chunk_text('This is a sample text to demonstrate chunking.', 20);
+ ///
+ /// -- The splitter prefers whitespace boundaries, so the output resembles:
+ /// -- ["This is a sample", "text to demonstrate", "chunking."]
+ /// ```
+ #[pg_extern]
+ fn chunk_text(document: &str, max_characters: i32) -> Vec<String> {
+     // TextSplitter caps each chunk at max_characters, splitting on
+     // semantic boundaries (whitespace, sentences) where possible
+     let splitter = TextSplitter::new(max_characters as usize);
+     splitter.chunks(document).map(|s| s.to_string()).collect()
+ }
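A hedged sketch of exercising the two new functions from SQL (table and column names here are hypothetical, and chunk_table currently chunks only the first row it fetches):

```sql
-- Chunk one text column of an existing table into a new table.
SELECT vectorize.chunk_table(
    input_table => 'documents',
    column_name => 'body',
    primary_key => 'id',
    max_chunk_size => 500,
    output_table => 'documents_chunked'
);

-- Chunk a single string directly.
SELECT vectorize.chunk_text('some long document text', 100);
```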
9 changes: 3 additions & 6 deletions extension/src/init.rs
@@ -236,7 +236,7 @@ fn append_embedding_column(job_name: &str, schema: &str, table: &str, col_type:
)
}

- pub fn get_column_datatype(schema: &str, table: &str, column: &str) -> Result<String> {
+ pub fn get_column_datatype(table: &str, column: &str) -> Result<String> {
Spi::get_one_with_args(
"
SELECT data_type
@@ -247,23 +247,20 @@ pub fn get_column_datatype(schema: &str, table: &str, column: &St
AND column_name = $3
",
vec![
-            (PgBuiltInOids::TEXTOID.oid(), schema.into_datum()),
(PgBuiltInOids::TEXTOID.oid(), table.into_datum()),
(PgBuiltInOids::TEXTOID.oid(), column.into_datum()),
],
)
.map_err(|_| {
anyhow!(
"One of schema:`{}`, table:`{}`, column:`{}` does not exist.",
schema,
"One of table:`{}`, column:`{}` does not exist.",
table,
column
)
})?
.ok_or_else(|| {
anyhow!(
"An unknown error occurred while fetching the data type for column `{}` in `{}.{}`.",
schema,
"An unknown error occurred while fetching the data type for column `{}` in `{}`.",
table,
column
)
16 changes: 8 additions & 8 deletions extension/src/search.rs
@@ -4,7 +4,7 @@ use crate::init;
use crate::job::{create_event_trigger, create_trigger_handler, initalize_table_job};
use crate::transformers::openai;
use crate::transformers::transform;
- use crate::util;
+ use crate::util::*;

use anyhow::{Context, Result};
use pgrx::prelude::*;
@@ -15,8 +15,7 @@ use vectorize_core::types::{self, Model, ModelSource, TableMethod, VectorizeMeta
#[allow(clippy::too_many_arguments)]
pub fn init_table(
job_name: &str,
-    schema: &str,
-    table: &str,
+    table_name: PgOid,
columns: Vec<String>,
primary_key: &str,
update_col: Option<String>,
@@ -26,14 +25,16 @@
// cron-like for a cron based update model, or 'realtime' for a trigger-based
schedule: &str,
) -> Result<String> {
+    let table_name_str = pg_oid_to_table_name(table_name);
+
// validate table method
// realtime is only compatible with the join method
if schedule == "realtime" && table_method != TableMethod::join {
error!("realtime schedule is only compatible with the join table method");
}

// get prim key type
-    let pkey_type = init::get_column_datatype(schema, table, primary_key)?;
+    let pkey_type = init::get_column_datatype(&table_name_str, primary_key)?;
init::init_pgmq()?;

let guc_configs = get_guc_configs(&transformer.source);
@@ -99,8 +100,7 @@ };
};

let valid_params = types::JobParams {
-        schema: schema.to_string(),
-        table: table.to_string(),
+        table: table_name_str.clone(),
columns: columns.clone(),
update_time_col: update_col,
table_method: table_method.clone(),
@@ -160,8 +160,8 @@ pub fn init_table(
// setup triggers
// create the trigger if not exists
let trigger_handler = create_trigger_handler(job_name, &columns, primary_key);
-    let insert_trigger = create_event_trigger(job_name, schema, table, "INSERT");
-    let update_trigger = create_event_trigger(job_name, schema, table, "UPDATE");
+    let insert_trigger = create_event_trigger(job_name, table_name_str.clone(), "INSERT");
+    let update_trigger = create_event_trigger(job_name, table_name_str.clone(), "UPDATE");
let _: Result<_, spi::Error> = Spi::connect(|mut c| {
let _r = c.update(&trigger_handler, None, None)?;
let _r = c.update(&insert_trigger, None, None)?;
12 changes: 12 additions & 0 deletions extension/src/util.rs
@@ -1,4 +1,5 @@
use anyhow::Result;
+ use pgrx::pg_sys::{regclassout, Oid};
use pgrx::spi::SpiTupleTable;
use pgrx::*;
use sqlx::postgres::{PgConnectOptions, PgPoolOptions};
@@ -195,6 +196,17 @@ pub fn get_pg_options(cfg: Config) -> Result<PgConnectOptions> {
}
}

+ pub fn pg_oid_to_table_name(oid: PgOid) -> String {
+     // Resolve the relation's name from the system catalog
+     let query = "SELECT relname FROM pg_class WHERE oid = $1";
+     Spi::get_one_with_args(
+         query,
+         vec![(PgBuiltInOids::REGCLASSOID.oid(), oid.into_datum())],
+     )
+     .expect("Failed to fetch table name from oid")
+     .unwrap_or_else(|| panic!("Table name not found for oid: {}", oid.value()))
+ }

pub async fn ready(conn: &Pool<Postgres>) -> bool {
sqlx::query_scalar(
"SELECT EXISTS (
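The helper reduces to a catalog lookup; the equivalent SQL, shown with a hypothetical table, is:

```sql
-- relname is the bare relation name, with no schema qualification
SELECT relname FROM pg_class WHERE oid = 'public.products'::regclass;
-- => products
```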