From 96ddad44d7ff81dc9a04c847f78b0d81a39aaf58 Mon Sep 17 00:00:00 2001
From: saksham56
Date: Thu, 7 Nov 2024 15:59:10 +0530
Subject: [PATCH] Changed the filecount approach

---
 sage/embedder.py | 39 ++++++++++++++-------------------------
 1 file changed, 14 insertions(+), 25 deletions(-)

diff --git a/sage/embedder.py b/sage/embedder.py
index b01d21e..234eb33 100644
--- a/sage/embedder.py
+++ b/sage/embedder.py
@@ -52,23 +52,20 @@ def __init__(
 
     def embed_dataset(self, chunks_per_batch: int, max_embedding_jobs: int = None) -> str:
         """Issues batch embedding jobs for the entire dataset. Returns the filename containing the job IDs."""
-        total_chunks = 0
-        for content, metadata in self.data_manager.walk():
-            chunks = self.chunker.chunk(content, metadata)
-            total_chunks += len(chunks)
+        num_files = len([x for x in self.data_manager.walk(get_content=False)])
 
         batch = []
         batch_ids = {}  # job_id -> metadata
         chunk_count = 0
         dataset_name = self.data_manager.dataset_id.replace("/", "_")
 
-        pbar = tqdm(total=total_chunks, desc="Processing chunks", unit="chunk")
+        pbar = tqdm(total=num_files, desc="Processing files", unit="file")
 
         for content, metadata in self.data_manager.walk():
             chunks = self.chunker.chunk(content, metadata)
             chunk_count += len(chunks)
             batch.extend(chunks)
-            pbar.update(len(chunks))
+            pbar.update(1)
 
             if len(batch) > chunks_per_batch:
                 for i in range(0, len(batch), chunks_per_batch):
@@ -229,21 +226,18 @@ def __init__(self, data_manager: DataManager, chunker: Chunker, embedding_model:
 
     def embed_dataset(self, chunks_per_batch: int, max_embedding_jobs: int = None):
         """Issues batch embedding jobs for the entire dataset."""
-        total_chunks = 0
-        for content, metadata in self.data_manager.walk():
-            chunks = self.chunker.chunk(content, metadata)
-            total_chunks += len(chunks)
+        num_files = len([x for x in self.data_manager.walk(get_content=False)])
 
         batch = []
         chunk_count = 0
 
-        pbar = tqdm(total=total_chunks, desc="Processing chunks", unit="chunk")
+        pbar = tqdm(total=num_files, desc="Processing files", unit="file")
 
         for content, metadata in self.data_manager.walk():
             chunks = self.chunker.chunk(content, metadata)
             chunk_count += len(chunks)
             batch.extend(chunks)
-            pbar.update(len(chunks))
+            pbar.update(1)
 
             token_count = chunk_count * self.chunker.max_tokens
             if token_count % 900_000 == 0:
@@ -313,23 +307,17 @@ def embed_dataset(self, chunks_per_batch: int, max_embedding_jobs: int = None):
 
         if chunks_per_batch > 64:
             raise ValueError("Marqo enforces a limit of 64 chunks per batch.")
 
-        total_chunks = 0
-        for content, metadata in self.data_manager.walk():
-            chunks = self.chunker.chunk(content, metadata)
-            total_chunks += len(chunks)
-
+        num_files = len([x for x in self.data_manager.walk(get_content=False)])
         chunk_count = 0
         batch = []
         job_count = 0
-
-        pbar = tqdm(total=total_chunks, desc="Processing chunks", unit="chunk")
+        pbar = tqdm(total=num_files, desc="Processing files", unit="file")
         for content, metadata in self.data_manager.walk():
             chunks = self.chunker.chunk(content, metadata)
             chunk_count += len(chunks)
             batch.extend(chunks)
-            pbar.update(len(chunks))
-
+            pbar.update(1)
             if len(batch) > chunks_per_batch:
                 for i in range(0, len(batch), chunks_per_batch):
                     sub_batch = batch[i : i + chunks_per_batch]
@@ -345,13 +333,12 @@ def embed_dataset(self, chunks_per_batch: int, max_embedding_jobs: int = None):
 
                         pbar.close()
                         return
                 batch = []
-
         if batch:
             self.index.add_documents(documents=[chunk.metadata for chunk in batch], tensor_fields=[TEXT_FIELD])
+        pbar.close()
         logging.info(f"Successfully embedded {chunk_count} chunks.")
-
 
     def embeddings_are_ready(self) -> bool:
         """Checks whether the batch embedding jobs are done."""
         # Marqo indexes documents synchronously, so once embed_dataset() returns, the embeddings are ready.
@@ -381,16 +368,18 @@ def _make_batch_request(self, chunks: List[Chunk]) -> Dict:
 
     def embed_dataset(self, chunks_per_batch: int, max_embedding_jobs: int = None):
         """Issues batch embedding jobs for the entire dataset."""
+        num_files = len([x for x in self.data_manager.walk(get_content=False)])
 
         batch = []
         chunk_count = 0
         request_count = 0
         last_request_time = time.time()
-
+        pbar = tqdm(total=num_files, desc="Processing files", unit="file")
         for content, metadata in self.data_manager.walk():
             chunks = self.chunker.chunk(content, metadata)
             chunk_count += len(chunks)
             batch.extend(chunks)
+            pbar.update(1)
 
             if len(batch) > chunks_per_batch:
                 for i in range(0, len(batch), chunks_per_batch):
@@ -423,7 +412,7 @@ def embed_dataset(self, chunks_per_batch: int, max_embedding_jobs: int = None):
             result = self._make_batch_request(batch)
             for chunk, embedding in zip(batch, result["embedding"]):
                 self.embedding_data.append((chunk.metadata, embedding))
-
+        pbar.close()
         logging.info(f"Successfully embedded {chunk_count} chunks.")
 
     def embeddings_are_ready(self, *args, **kwargs) -> bool:
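
Reviewer note, appended after the diff for context: a minimal sketch of the pattern this patch converges on across the embedders, assuming the DataManager.walk(get_content=False) and Chunker.chunk interfaces shown above. embed_with_file_progress is a hypothetical name, and the generator-based count is an equivalent spelling of the patch's len([x for x in ...]) that avoids materializing a throwaway list:

    # Hypothetical sketch, not repository code.
    from tqdm import tqdm

    def embed_with_file_progress(data_manager, chunker):
        # Counting files needs only a cheap metadata-only walk; no content
        # is read and nothing is chunked just to size the progress bar.
        num_files = sum(1 for _ in data_manager.walk(get_content=False))

        chunk_count = 0
        pbar = tqdm(total=num_files, desc="Processing files", unit="file")
        for content, metadata in data_manager.walk():
            chunk_count += len(chunker.chunk(content, metadata))
            pbar.update(1)  # one tick per file, not per chunk
        pbar.close()
        return chunk_count

The trade-off: the removed pre-pass chunked every file once just to compute the bar's total and then again to embed, doubling the chunking work, while counting files keeps the total cheap at the cost of coarser, per-file progress.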