diff --git a/extension/src/api.rs b/extension/src/api.rs
index 700cc6b..eb41aa0 100644
--- a/extension/src/api.rs
+++ b/extension/src/api.rs
@@ -10,6 +10,77 @@ use anyhow::Result;
 use pgrx::prelude::*;
 use vectorize_core::types::Model;
 
+fn chunk_text(text: &str, max_chunk_size: usize) -> Vec<String> {
+    let mut chunks = Vec::new();
+    let mut start = 0;
+
+    // Loop through the text and create fixed-size chunks
+    while start < text.len() {
+        let end = (start + max_chunk_size).min(text.len());
+        let chunk = text[start..end].to_string();
+        chunks.push(chunk);
+        start = end;
+    }
+
+    chunks
+}
+
+#[pg_extern]
+fn chunk_table(
+    input_table: &str,
+    column_name: &str,
+    max_chunk_size: default!(i32, 1000),
+    output_table: default!(&str, "'chunked_data'"),
+) -> Result<String> {
+    let max_chunk_size = max_chunk_size as usize;
+
+    // Retrieve rows from the input table
+    let query = format!("SELECT id, {} FROM {}", column_name, input_table);
+
+    // NOTE: get_two reads only the first row of the result set
+    let (id_opt, text_opt): (Option<i32>, Option<String>) = Spi::get_two(&query)?;
+    let rows = vec![(id_opt, text_opt)]; // Wrap in a vector for iteration
+
+    // Prepare to hold chunked rows
+    let mut chunked_rows: Vec<(i32, i32, String)> = Vec::new(); // (original_id, chunk_index, chunk)
+
+    // Chunk the data and keep track of the original id and chunk index
+    for (id_opt, text_opt) in rows {
+        // Only process rows where both id and text exist
+        if let (Some(id), Some(text)) = (id_opt, text_opt) {
+            let chunks = chunk_text(&text, max_chunk_size);
+            for (index, chunk) in chunks.iter().enumerate() {
+                chunked_rows.push((id, index as i32, chunk.clone())); // Add chunk index
+            }
+        }
+    }
+
+    // Create output table with an additional column for chunk index
+    let create_table_query = format!(
+        "CREATE TABLE IF NOT EXISTS {} (id SERIAL PRIMARY KEY, original_id INT, chunk_index INT, chunk TEXT)",
+        output_table
+    );
+    Spi::run(&create_table_query)
+        .map_err(|e| anyhow::anyhow!("Failed to create table {}: {}", output_table, e))?;
+
+    // Insert chunked rows into output table
+    for (original_id, chunk_index, chunk) in chunked_rows {
+        let insert_query = format!(
+            "INSERT INTO {} (original_id, chunk_index, chunk) VALUES ($1, $2, $3)",
+            output_table
+        );
+        Spi::run_with_args(&insert_query, Some(vec![
+            (PgBuiltInOids::INT4OID.oid(), original_id.into_datum()),
+            (PgBuiltInOids::INT4OID.oid(), chunk_index.into_datum()),
+            (PgBuiltInOids::TEXTOID.oid(), chunk.into_datum()),
+        ]))?;
+    }
+
+    Ok(format!("Chunked data inserted into table: {}", output_table))
+}
+
 #[allow(clippy::too_many_arguments)]
 #[pg_extern]
 fn table(
@@ -26,7 +97,15 @@ fn table(
     table_method: default!(types::TableMethod, "'join'"),
     // cron-like for a cron based update model, or 'realtime' for a trigger-based
     schedule: default!(&str, "'* * * * *'"),
+    chunk_input: default!(bool, false),   // New parameter to enable chunking
+    max_chunk_size: default!(i32, 1000),  // New parameter for chunk size
 ) -> Result<String> {
+    if chunk_input {
+        // Call chunk_table if chunking is enabled
+        chunk_table(table, &columns[0], max_chunk_size, "chunked_data")?;
+    }
+
+    // Proceed with the original table initialization logic
     let model = Model::new(transformer)?;
     init_table(
         job_name,
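To illustrate what the new chunk_table function writes into the chunked_data output table, here is a small Python sketch of the same fixed-size chunking. It is an illustration only; the Rust helper above slices by byte offset, while this sketch works on characters.

def chunk_text(text: str, max_chunk_size: int) -> list[str]:
    # Fixed-size chunking, mirroring the Rust helper above
    return [text[i:i + max_chunk_size] for i in range(0, len(text), max_chunk_size)]

rows = [(1, "abcdefghij"), (2, "klmno")]
chunked_rows = []
for original_id, text in rows:
    for chunk_index, chunk in enumerate(chunk_text(text, 4)):
        chunked_rows.append((original_id, chunk_index, chunk))

# chunked_rows now matches the (original_id, chunk_index, chunk) layout of the output table:
# [(1, 0, 'abcd'), (1, 1, 'efgh'), (1, 2, 'ij'), (2, 0, 'klmn'), (2, 1, 'o')]
print(chunked_rows)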
diff --git a/vector-serve/app/routes/transform.py b/vector-serve/app/routes/transform.py
index 8b5527d..b146073 100644
--- a/vector-serve/app/routes/transform.py
+++ b/vector-serve/app/routes/transform.py
@@ -6,6 +6,34 @@ from fastapi import APIRouter, Header, HTTPException, Request
 from pydantic import BaseModel, conlist
 
+# Chunking functions
+def chunk_text(text, max_length):
+    """Splits text into smaller chunks based on a maximum character length."""
+    import re
+    sentences = re.split(r'(?<=[.!?])\s+', text)  # Split on sentence boundaries
+    chunks = []
+    current_chunk = ""
+
+    for sentence in sentences:
+        if len(current_chunk) + len(sentence) <= max_length:
+            current_chunk += sentence + " "
+        else:
+            chunks.append(current_chunk.strip())
+            current_chunk = sentence + " "
+    if current_chunk:
+        chunks.append(current_chunk.strip())
+
+    return chunks
+
+
+def chunk_table(input_data, max_length):
+    """Chunk each item in a list of texts into smaller pieces."""
+    chunked_data = []
+    for text in input_data:
+        chunks = chunk_text(text, max_length)
+        chunked_data.extend(chunks)
+    return chunked_data
+
+
 router = APIRouter(tags=["transform"])
 
 logging.basicConfig(level=logging.DEBUG)
@@ -40,13 +68,26 @@ def batch_transform(
     request: Request, payload: Batch, authorization: str = Header(None)
 ) -> ResponseModel:
     logging.info({"batch-predict-len": len(payload.input)})
-    batches = chunk_list(payload.input, BATCH_SIZE)
+
+    # Return an empty response if the input is empty or contains only empty strings
+    if all(not text.strip() for text in payload.input):
+        logging.warning("Received empty input.")
+        return ResponseModel(data=[], model=payload.model)
+
+    # Preprocess by chunking large texts in payload.input; adjust max_length as needed
+    chunked_input = chunk_table(payload.input, max_length=500)
+    if not chunked_input:
+        logging.warning("No valid chunks created from input.")
+        return ResponseModel(data=[], model=payload.model)
+
+    batches = chunk_list(chunked_input, BATCH_SIZE)
+    num_batches = len(batches)
+
     responses: list[list[float]] = []
 
     requested_model = model_org_name(payload.model)
-    api_key = parse_header(authorization)
+    api_key = parse_header(authorization)
     try:
         model = get_model(
             model_name=requested_model,
@@ -66,10 +107,14 @@ def batch_transform(
                 sentences=batch, normalize_embeddings=payload.normalize
             ).tolist()
        )
+    logging.info("Completed %s batches", num_batches)
+
+    # Construct the embedding response
     embeds = [
         Embedding(embedding=embedding, index=i)
         for i, embedding in enumerate(responses)
     ]
+
     return ResponseModel(
         data=embeds,
         model=requested_model,
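For reference, a quick usage sketch of the sentence-boundary chunker added above; the import path is an assumption based on this repository layout and may need adjusting.

from app.routes.transform import chunk_text  # assumed import path

text = (
    "Postgres stores rows in pages. Each page is 8kB by default. "
    "Large documents should be split before embedding!"
)

# Split on sentence boundaries, keeping each chunk near the requested maximum;
# a single sentence longer than max_length still becomes one oversized chunk.
chunks = chunk_text(text, max_length=60)
for i, chunk in enumerate(chunks):
    print(i, repr(chunk))
# Prints two chunks, each ending at a sentence boundary and at most 60 characters long.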
" * 1000 # Create a long document + + payload = { + "input": [long_text], + "model": "all-MiniLM-L6-v2", + "normalize": False + } + + response = test_client.post("/v1/embeddings", json=payload) + assert response.status_code == 200 + response_data = response.json() + + # Verify that chunking occurred correctly + assert len(response_data["data"]) > 1 # More than one chunk returned + + # Validate that each chunk is of appropriate length + for chunk in response_data["data"]: + assert len(chunk['embedding']) > 0 # Check that each chunk has an embedding + + +def test_small_input(test_client): + small_text = "Short text." + payload = { + "input": [small_text], + "model": "all-MiniLM-L6-v2", + "normalize": False + } + + response = test_client.post("/v1/embeddings", json=payload) + assert response.status_code == 200 + response_data = response.json() + + assert len(response_data["data"]) == 1 # Should return one chunk for small input + assert response_data["data"][0]['embedding'] is not None # Check that the embedding exists + + +def test_empty_input(test_client): + payload = { + "input": [""], + "model": "all-MiniLM-L6-v2", + "normalize": False + } + + response = test_client.post("/v1/embeddings", json=payload) + assert response.status_code == 200 + response_data = response.json() + + # Expect no chunks for empty input + assert len(response_data["data"]) == 0 # No chunks should be created + + +def test_boundary_chunking(test_client): + boundary_text = "A" * 500 # Exactly at the chunk size + payload = { + "input": [boundary_text], + "model": "all-MiniLM-L6-v2", + "normalize": False + } + + response = test_client.post("/v1/embeddings", json=payload) + assert response.status_code == 200 + response_data = response.json() + + assert len(response_data["data"]) == 1 # Should return one chunk + assert response_data["data"][0]['embedding'] is not None # Check that the embedding exists \ No newline at end of file