Commit

Changed the filecount approach
Saksham1387 committed Nov 7, 2024
1 parent f2f9128 commit 96ddad4
Showing 1 changed file with 14 additions and 25 deletions.
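The commit replaces an upfront pass that chunked every file merely to size the tqdm progress bar (and therefore chunked the whole dataset twice) with a cheap file count via walk(get_content=False); the bar then advances once per file. Below is a minimal, self-contained sketch of the two approaches, assuming only that DataManager.walk(get_content=False) yields one item per file without reading contents, as the diff suggests; FakeDataManager and fake_chunker are illustrative stand-ins, not code from this repository.

from tqdm import tqdm

class FakeDataManager:
    # Illustrative stand-in for sage's DataManager (hypothetical).
    def __init__(self, files):
        self.files = files  # {path: content}

    def walk(self, get_content=True):
        # With get_content=False we never read file bodies, which is
        # what makes the new file count cheap.
        for path, content in self.files.items():
            yield (content, {"file_path": path}) if get_content else path

def fake_chunker(content, metadata):
    # Stand-in for Chunker.chunk(): naive fixed-size chunks.
    return [content[i:i + 10] for i in range(0, len(content), 10)]

dm = FakeDataManager({f"file{i}.py": "x" * 95 for i in range(3)})

# Old approach: chunk the entire dataset once just to compute the bar's
# total, then chunk it all over again in the embedding loop.
total_chunks = sum(len(fake_chunker(c, m)) for c, m in dm.walk())
pbar = tqdm(total=total_chunks, desc="Processing chunks", unit="chunk")
for content, metadata in dm.walk():
    pbar.update(len(fake_chunker(content, metadata)))
pbar.close()

# New approach: count files cheaply up front; one update per file, and
# each file is chunked exactly once.
num_files = len([x for x in dm.walk(get_content=False)])
pbar = tqdm(total=num_files, desc="Processing files", unit="file")
for content, metadata in dm.walk():
    chunks = fake_chunker(content, metadata)  # chunks go on to embedding
    pbar.update(1)
pbar.close()

The tradeoff is coarser progress granularity (files rather than chunks) in exchange for not chunking the dataset twice; the chunk_count variable in the diff still records the exact number of chunks embedded for the final log message.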
39 changes: 14 additions & 25 deletions sage/embedder.py
@@ -52,23 +52,20 @@ def __init__(

     def embed_dataset(self, chunks_per_batch: int, max_embedding_jobs: int = None) -> str:
         """Issues batch embedding jobs for the entire dataset. Returns the filename containing the job IDs."""
-        total_chunks = 0
-        for content, metadata in self.data_manager.walk():
-            chunks = self.chunker.chunk(content, metadata)
-            total_chunks += len(chunks)
+        num_files = len([x for x in self.data_manager.walk(get_content=False)])
 
         batch = []
         batch_ids = {}  # job_id -> metadata
         chunk_count = 0
         dataset_name = self.data_manager.dataset_id.replace("/", "_")
 
-        pbar = tqdm(total=total_chunks, desc="Processing chunks", unit="chunk")
+        pbar = tqdm(total=num_files, desc="Processing files", unit="file")
 
         for content, metadata in self.data_manager.walk():
             chunks = self.chunker.chunk(content, metadata)
             chunk_count += len(chunks)
             batch.extend(chunks)
-            pbar.update(len(chunks))
+            pbar.update(1)
 
             if len(batch) > chunks_per_batch:
                 for i in range(0, len(batch), chunks_per_batch):
@@ -229,21 +226,18 @@ def __init__(self, data_manager: DataManager, chunker: Chunker, embedding_model:

     def embed_dataset(self, chunks_per_batch: int, max_embedding_jobs: int = None):
         """Issues batch embedding jobs for the entire dataset."""
-        total_chunks = 0
-        for content, metadata in self.data_manager.walk():
-            chunks = self.chunker.chunk(content, metadata)
-            total_chunks += len(chunks)
+        num_files = len([x for x in self.data_manager.walk(get_content=False)])
 
         batch = []
         chunk_count = 0
 
-        pbar = tqdm(total=total_chunks, desc="Processing chunks", unit="chunk")
+        pbar = tqdm(total=num_files, desc="Processing files", unit="file")
 
         for content, metadata in self.data_manager.walk():
             chunks = self.chunker.chunk(content, metadata)
             chunk_count += len(chunks)
             batch.extend(chunks)
-            pbar.update(len(chunks))
+            pbar.update(1)
 
             token_count = chunk_count * self.chunker.max_tokens
             if token_count % 900_000 == 0:
@@ -313,23 +307,17 @@ def embed_dataset(self, chunks_per_batch: int, max_embedding_jobs: int = None):
         if chunks_per_batch > 64:
             raise ValueError("Marqo enforces a limit of 64 chunks per batch.")
 
-        total_chunks = 0
-        for content, metadata in self.data_manager.walk():
-            chunks = self.chunker.chunk(content, metadata)
-            total_chunks += len(chunks)
-
+        num_files = len([x for x in self.data_manager.walk(get_content=False)])
         chunk_count = 0
         batch = []
         job_count = 0
 
-        pbar = tqdm(total=total_chunks, desc="Processing chunks", unit="chunk")
+        pbar = tqdm(total=num_files, desc="Processing files", unit="file")
 
         for content, metadata in self.data_manager.walk():
             chunks = self.chunker.chunk(content, metadata)
             chunk_count += len(chunks)
             batch.extend(chunks)
-            pbar.update(len(chunks))
-
+            pbar.update(1)
             if len(batch) > chunks_per_batch:
                 for i in range(0, len(batch), chunks_per_batch):
                     sub_batch = batch[i : i + chunks_per_batch]
@@ -345,13 +333,12 @@ def embed_dataset(self, chunks_per_batch: int, max_embedding_jobs: int = None):
                         pbar.close()
                         return
             batch = []
-
         if batch:
             self.index.add_documents(documents=[chunk.metadata for chunk in batch], tensor_fields=[TEXT_FIELD])
 
         pbar.close()
         logging.info(f"Successfully embedded {chunk_count} chunks.")
 
 
     def embeddings_are_ready(self) -> bool:
         """Checks whether the batch embedding jobs are done."""
         # Marqo indexes documents synchronously, so once embed_dataset() returns, the embeddings are ready.
@@ -381,16 +368,18 @@ def _make_batch_request(self, chunks: List[Chunk]) -> Dict:

     def embed_dataset(self, chunks_per_batch: int, max_embedding_jobs: int = None):
         """Issues batch embedding jobs for the entire dataset."""
+        num_files = len([x for x in self.data_manager.walk(get_content=False)])
         batch = []
         chunk_count = 0
 
         request_count = 0
         last_request_time = time.time()
 
+        pbar = tqdm(total=num_files, desc="Processing files", unit="file")
         for content, metadata in self.data_manager.walk():
             chunks = self.chunker.chunk(content, metadata)
             chunk_count += len(chunks)
             batch.extend(chunks)
+            pbar.update(1)
 
             if len(batch) > chunks_per_batch:
                 for i in range(0, len(batch), chunks_per_batch):
@@ -423,7 +412,7 @@ def embed_dataset(self, chunks_per_batch: int, max_embedding_jobs: int = None):
             result = self._make_batch_request(batch)
             for chunk, embedding in zip(batch, result["embedding"]):
                 self.embedding_data.append((chunk.metadata, embedding))
-
+        pbar.close()
         logging.info(f"Successfully embedded {chunk_count} chunks.")
 
     def embeddings_are_ready(self, *args, **kwargs) -> bool:
