Add more airtable logging (#3862)
* Add more airtable logging

* Add multithreading

* Remove empty comment
Weves authored Jan 31, 2025
1 parent 5e21dc6 commit 288daa4
Showing 6 changed files with 124 additions and 15 deletions.
6 changes: 6 additions & 0 deletions backend/onyx/configs/app_configs.py
@@ -478,6 +478,12 @@
# 0 disables this behavior and is the default.
INDEXING_TRACER_INTERVAL = int(os.environ.get("INDEXING_TRACER_INTERVAL") or 0)

# Enable multi-threaded embedding model calls for parallel processing
# Note: only applies for API-based embedding models
INDEXING_EMBEDDING_MODEL_NUM_THREADS = int(
    os.environ.get("INDEXING_EMBEDDING_MODEL_NUM_THREADS") or 1
)

# During an indexing attempt, specifies the number of batches which are allowed to
# raise exceptions without aborting the attempt.
INDEXING_EXCEPTION_LIMIT = int(os.environ.get("INDEXING_EXCEPTION_LIMIT") or 0)
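As a usage note, here is a hedged sketch of how the `int(os.environ.get(...) or default)` pattern above resolves. The exported value is an assumed deployment detail, not something this commit sets:

import os

# Assumption for illustration: the deployment exports this variable.
os.environ["INDEXING_EMBEDDING_MODEL_NUM_THREADS"] = "8"

# An unset or empty variable is falsy, so `or 1` falls back to single-threaded.
num_threads = int(os.environ.get("INDEXING_EMBEDDING_MODEL_NUM_THREADS") or 1)
print(num_threads)  # 8; would print 1 if the variable were unset or ""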
51 changes: 42 additions & 9 deletions backend/onyx/connectors/airtable/airtable_connector.py
@@ -1,3 +1,5 @@
from concurrent.futures import as_completed
from concurrent.futures import ThreadPoolExecutor
from io import BytesIO
from typing import Any

@@ -274,6 +276,11 @@ def _process_record(
            field_val = fields.get(field_name)
            field_type = field_schema.type

            logger.debug(
                f"Processing field '{field_name}' of type '{field_type}' "
                f"for record '{record_id}'."
            )

            field_sections, field_metadata = self._process_field(
                field_id=field_schema.id,
                field_name=field_name,
@@ -327,19 +334,45 @@ def load_from_state(self) -> GenerateDocumentsOutput:
                primary_field_name = field.name
                break

        record_documents: list[Document] = []
        for record in records:
            document = self._process_record(
                record=record,
                table_schema=table_schema,
                primary_field_name=primary_field_name,
            )
            if document:
                record_documents.append(document)
        logger.info(f"Starting to process Airtable records for {table.name}.")

        # Process records in parallel batches using ThreadPoolExecutor
        PARALLEL_BATCH_SIZE = 16
        max_workers = min(PARALLEL_BATCH_SIZE, len(records))

        # Accumulate documents across batches so undersized batches are not dropped
        record_documents: list[Document] = []

        # Process records in batches
        for i in range(0, len(records), PARALLEL_BATCH_SIZE):
            batch_records = records[i : i + PARALLEL_BATCH_SIZE]

            with ThreadPoolExecutor(max_workers=max_workers) as executor:
                # Submit batch tasks
                future_to_record = {
                    executor.submit(
                        self._process_record,
                        record=record,
                        table_schema=table_schema,
                        primary_field_name=primary_field_name,
                    ): record
                    for record in batch_records
                }

                # Wait for all tasks in this batch to complete
                for future in as_completed(future_to_record):
                    record = future_to_record[future]
                    try:
                        document = future.result()
                        if document:
                            record_documents.append(document)
                    except Exception:
                        logger.exception(f"Failed to process record {record['id']}")
                        raise

            # After the batch is complete, yield if we've hit the batch size
            if len(record_documents) >= self.batch_size:
                yield record_documents
                record_documents = []

        # Yield any remaining records
        if record_documents:
            yield record_documents
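For reference, the batching pattern above in a minimal, self-contained form. This is a hedged sketch: `worker` stands in for `_process_record`, and all names here are invented for the example:

from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Callable, Iterator

def process_in_parallel_batches(
    items: list[Any],
    worker: Callable[[Any], Any],
    parallel_batch_size: int = 16,
    yield_size: int = 16,
) -> Iterator[list[Any]]:
    """Run `worker` over `items` in bounded thread batches, yielding results in groups."""
    max_workers = min(parallel_batch_size, len(items))
    results: list[Any] = []
    for i in range(0, len(items), parallel_batch_size):
        batch = items[i : i + parallel_batch_size]
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = [executor.submit(worker, item) for item in batch]
            for future in as_completed(futures):
                result = future.result()  # re-raises any worker exception
                if result is not None:
                    results.append(result)
        if len(results) >= yield_size:
            yield results
            results = []
    if results:
        yield results

# Example usage: for docs in process_in_parallel_batches(list(range(50)), lambda x: x * 2): ...

Bounding the pool at the batch size keeps at most 16 Airtable requests in flight at once, which limits pressure on the API while still overlapping the per-record I/O.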
13 changes: 12 additions & 1 deletion backend/onyx/connectors/connector_runner.py
@@ -1,4 +1,5 @@
import sys
import time
from datetime import datetime

from onyx.connectors.interfaces import BaseConnector
@@ -45,7 +46,17 @@ def __init__(
    def run(self) -> GenerateDocumentsOutput:
        """Adds additional exception logging to the connector."""
        try:
            yield from self.doc_batch_generator
            start = time.monotonic()
            for batch in self.doc_batch_generator:
                # log how long the connector took to build this batch
                logger.debug(
                    f"Connector took {time.monotonic() - start} seconds to build a batch."
                )

                yield batch

                start = time.monotonic()

        except Exception:
            exc_type, _, exc_traceback = sys.exc_info()

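The pattern in `run` — resetting the clock after each `yield` — measures only the time the upstream generator spends building a batch, not the time the consumer spends using it. A minimal hedged sketch of the same idea (names illustrative):

import time
from collections.abc import Iterator

def timed_batches(batch_generator: Iterator[list]) -> Iterator[list]:
    """Yield batches unchanged while logging how long each took to produce."""
    start = time.monotonic()
    for batch in batch_generator:
        print(f"Batch took {time.monotonic() - start:.2f}s to build.")
        yield batch
        # This line runs only when the consumer asks for the next batch, so the
        # time the consumer spent processing the previous batch is excluded.
        start = time.monotonic()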
10 changes: 10 additions & 0 deletions backend/onyx/connectors/models.py
@@ -150,6 +150,16 @@ class Document(DocumentBase):
    id: str  # This must be unique or during indexing/reindexing, chunks will be overwritten
    source: DocumentSource

    def get_total_char_length(self) -> int:
        """Calculate the total character length of the document including sections, metadata, and identifiers."""
        section_length = sum(len(section.text) for section in self.sections)
        identifier_length = len(self.semantic_identifier) + len(self.title or "")
        metadata_length = sum(
            len(k) + len(v) if isinstance(v, str) else len(k) + sum(len(x) for x in v)
            for k, v in self.metadata.items()
        )
        return section_length + identifier_length + metadata_length

    def to_short_descriptor(self) -> str:
        """Used when logging the identity of a document"""
        return f"ID: '{self.id}'; Semantic ID: '{self.semantic_identifier}'"
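For intuition, the same arithmetic on invented values — a standalone sketch, not the real `Document` model:

sections = ["Hello world", "Second section"]    # 11 + 14 = 25 chars
semantic_identifier = "My Doc"                  # 6 chars
title = "Doc"                                   # 3 chars
metadata = {"tags": ["a", "bb"], "owner": "me"} # (4 + 1 + 2) + (5 + 2) = 14 chars

section_length = sum(len(s) for s in sections)
identifier_length = len(semantic_identifier) + len(title or "")
metadata_length = sum(
    len(k) + len(v) if isinstance(v, str) else len(k) + sum(len(x) for x in v)
    for k, v in metadata.items()
)
print(section_length + identifier_length + metadata_length)  # 25 + 9 + 14 = 48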
9 changes: 9 additions & 0 deletions backend/onyx/indexing/indexing_pipeline.py
@@ -380,6 +380,15 @@ def index_doc_batch(
            new_docs=0, total_docs=len(filtered_documents), total_chunks=0
        )

    doc_descriptors = [
        {
            "doc_id": doc.id,
            "doc_length": doc.get_total_char_length(),
        }
        for doc in ctx.updatable_docs
    ]
    logger.debug(f"Starting indexing process for documents: {doc_descriptors}")

    logger.debug("Starting chunking")
    chunks: list[DocAwareChunk] = chunker.chunk(ctx.updatable_docs)
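Illustratively, with invented ids and lengths, the debug line above renders as:

docs = [("doc-1", 5213), ("doc-2", 842)]  # invented (id, length) pairs
doc_descriptors = [{"doc_id": d, "doc_length": n} for d, n in docs]
print(f"Starting indexing process for documents: {doc_descriptors}")
# Starting indexing process for documents: [{'doc_id': 'doc-1', 'doc_length': 5213}, {'doc_id': 'doc-2', 'doc_length': 842}]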
50 changes: 45 additions & 5 deletions backend/onyx/natural_language_processing/search_nlp_models.py
@@ -1,6 +1,8 @@
import threading
import time
from collections.abc import Callable
from concurrent.futures import as_completed
from concurrent.futures import ThreadPoolExecutor
from functools import wraps
from typing import Any

@@ -11,6 +13,7 @@
from requests import Response
from retry import retry

from onyx.configs.app_configs import INDEXING_EMBEDDING_MODEL_NUM_THREADS
from onyx.configs.app_configs import LARGE_CHUNK_RATIO
from onyx.configs.app_configs import SKIP_WARM_UP
from onyx.configs.model_configs import BATCH_SIZE_ENCODE_CHUNKS
@@ -155,6 +158,7 @@ def _batch_encode_texts(
        text_type: EmbedTextType,
        batch_size: int,
        max_seq_length: int,
        num_threads: int = INDEXING_EMBEDDING_MODEL_NUM_THREADS,
    ) -> list[Embedding]:
        text_batches = batch_list(texts, batch_size)

@@ -163,12 +167,15 @@
        )

        embeddings: list[Embedding] = []
        for idx, text_batch in enumerate(text_batches, start=1):

        def process_batch(
            batch_idx: int, text_batch: list[str]
        ) -> tuple[int, list[Embedding]]:
            if self.callback:
                if self.callback.should_stop():
                    raise RuntimeError("_batch_encode_texts detected stop signal")

            logger.debug(f"Encoding batch {idx} of {len(text_batches)}")
            logger.debug(f"Encoding batch {batch_idx} of {len(text_batches)}")
            embed_request = EmbedRequest(
                model_name=self.model_name,
                texts=text_batch,
@@ -185,10 +192,43 @@ def _batch_encode_texts(
            )

            response = self._make_model_server_request(embed_request)
            embeddings.extend(response.embeddings)
            return batch_idx, response.embeddings

        # only multi-thread if:
        # 1. num_threads is at least 1
        # 2. we are using an API-based embedding model (provider_type is not None)
        # 3. there is more than one batch (no point in threading otherwise)
        if num_threads >= 1 and self.provider_type and len(text_batches) > 1:
            with ThreadPoolExecutor(max_workers=num_threads) as executor:
                future_to_batch = {
                    executor.submit(process_batch, idx, batch): idx
                    for idx, batch in enumerate(text_batches, start=1)
                }

                # Collect results in order
                batch_results: list[tuple[int, list[Embedding]]] = []
                for future in as_completed(future_to_batch):
                    try:
                        result = future.result()
                        batch_results.append(result)
                        if self.callback:
                            self.callback.progress("_batch_encode_texts", 1)
                    except Exception:
                        logger.exception("Embedding model failed to process batch")
                        raise

                # Sort by batch index and extend embeddings
                batch_results.sort(key=lambda x: x[0])
                for _, batch_embeddings in batch_results:
                    embeddings.extend(batch_embeddings)
        else:
            # Original sequential processing
            for idx, text_batch in enumerate(text_batches, start=1):
                _, batch_embeddings = process_batch(idx, text_batch)
                embeddings.extend(batch_embeddings)
                if self.callback:
                    self.callback.progress("_batch_encode_texts", 1)

        if self.callback:
            self.callback.progress("_batch_encode_texts", 1)
        return embeddings

def encode(
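The core trick in `_batch_encode_texts` — submit indexed batches, collect with `as_completed`, then sort by index to restore input order — in a minimal, self-contained form. This is a hedged sketch; `encode_fn` and all other names are illustrative, not the library's API:

from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Callable

def encode_batches_in_order(
    batches: list[list[str]],
    encode_fn: Callable[[list[str]], list[list[float]]],
    num_threads: int = 4,
) -> list[list[float]]:
    """Encode batches concurrently, but return embeddings in input order."""
    if num_threads > 1 and len(batches) > 1:
        with ThreadPoolExecutor(max_workers=num_threads) as executor:
            future_to_idx = {
                executor.submit(encode_fn, batch): idx
                for idx, batch in enumerate(batches)
            }
            results = [
                (future_to_idx[f], f.result()) for f in as_completed(future_to_idx)
            ]
        # as_completed yields in completion order, so sort by submitted index
        results.sort(key=lambda pair: pair[0])
        return [emb for _, embs in results for emb in embs]
    # Sequential fallback: a single batch or a single thread gains nothing
    return [emb for batch in batches for emb in encode_fn(batch)]

Sorting by the submitted index keeps embeddings aligned with their input texts even though threads finish out of order, which matters whenever the caller pairs embeddings with texts positionally. Threading only pays off for API-based providers, since a local model server is already the bottleneck.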
