Skip to content

Commit

Permalink
Retry the file if the parser fails
Browse files Browse the repository at this point in the history
  • Loading branch information
ankush-cohere committed Dec 20, 2024
1 parent c3b7d9a commit ba777c0
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 4 deletions.
5 changes: 4 additions & 1 deletion cohere/compass/clients/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def __init__(
username: Optional[str] = None,
password: Optional[str] = None,
num_workers: int = 4,
retries: int = 3,
):
"""
Initialize the CompassParserClient.
Expand All @@ -79,8 +80,9 @@ def __init__(
self.username = username or os.getenv("COHERE_COMPASS_USERNAME")
self.password = password or os.getenv("COHERE_COMPASS_PASSWORD")
self.session = requests.Session()
self.thread_pool = ThreadPoolExecutor(num_workers)
self.thread_pool = ThreadPoolExecutor(num_workers * 2)
self.num_workers = num_workers
self.retries = retries

self.metadata_config = metadata_config
logger.info(
Expand Down Expand Up @@ -181,6 +183,7 @@ def process_file(i: int) -> list[CompassDocument]:
process_file,
range(len(filenames)),
max_queued=self.num_workers,
retries=self.retries,
):
yield from results

Expand Down
31 changes: 28 additions & 3 deletions cohere/compass/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Python imports
import base64
import glob
import logging
import os
import uuid
from collections.abc import Iterable, Iterator
Expand All @@ -23,9 +24,15 @@
T = TypeVar("T")
U = TypeVar("U")

logger = logging.getLogger(__name__)


def imap_queued(
    executor: Executor,
    f: Callable[[T], U],
    it: Iterable[T],
    max_queued: int,
    retries: int = 3,
) -> Iterator[U]:
    """
    Similar to Python's `map`, but uses an executor to parallelize the calls.

    Results are yielded as they complete, not necessarily in input order. At
    most `max_queued` futures are kept in flight at a time. Each input is
    attempted up to `retries` times; inputs that still fail after the last
    attempt are logged and skipped (no result is yielded for them).

    :param executor: the executor to submit work to.
    :param f: the function to call.
    :param it: the iterable to map over.
    :param max_queued: the maximum number of futures to keep in flight.
    :param retries: maximum number of attempts to make per input.
    :returns: an iterator over the results of the successful calls.
    """
    assert max_queued >= 1

    def execute_with_retry(x: T) -> U:
        """Call f(x) up to `retries` times; re-raise the final failure."""
        # Call f directly rather than re-submitting it to the executor:
        # nested submissions from inside a worker task can deadlock a
        # saturated pool (every worker waiting on a future no worker is
        # free to run).
        for attempt in range(retries):
            try:
                return f(x)
            except Exception as e:
                logger.info(f"Attempt {attempt + 1} failed for input {x}: {e}")
                if attempt == retries - 1:
                    # Out of attempts: propagate so the caller-side drain
                    # can log and skip this input (instead of silently
                    # yielding None, which a bare return would cause).
                    raise
        raise AssertionError("unreachable")  # keeps type-checkers happy

    # Map each future back to its input so failure logs name the item that
    # actually failed (not whichever item was submitted most recently).
    future_to_input: dict[futures.Future[U], T] = {}

    def drain(done: Iterable[futures.Future[U]]) -> Iterator[U]:
        """Yield results of completed futures, logging and skipping failures."""
        for future in done:
            failed_input = future_to_input.pop(future)
            try:
                yield future.result()
            except Exception:
                logger.error(
                    f"Cannot process file {failed_input} after {retries} attempts"
                )

    futures_set: set[futures.Future[U]] = set()

    for x in it:
        future = executor.submit(execute_with_retry, x)
        future_to_input[future] = x
        futures_set.add(future)
        # Backpressure: block until we are under the in-flight limit.
        while len(futures_set) > max_queued:
            done, futures_set = futures.wait(
                futures_set, return_when=futures.FIRST_COMPLETED
            )
            yield from drain(done)

    # Drain whatever is still in flight after the input is exhausted; this
    # path now gets the same failure handling as the mid-stream drain.
    yield from drain(futures.as_completed(futures_set))
Expand Down

0 comments on commit ba777c0

Please sign in to comment.