From 5c185573fce7c16769add7758f95bbfc5fbeb84b Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 11 Sep 2024 20:46:17 -0700 Subject: [PATCH] Fix eqx (#725) * don't hit ray if we don't need to in TreeCache * make ray exit quieter * reduce log spam of wandb * .aider ignore * sigh * make BackgroundIterable work not in the background * fix regression caused by new Equinox --- .gitignore | 1 + pyproject.toml | 4 +- src/levanter/distributed.py | 2 +- src/levanter/store/cache.py | 8 +++- src/levanter/tracker/wandb.py | 10 +++-- src/levanter/utils/background_iterable.py | 48 +++++++++++++++------ src/levanter/utils/thread_utils.py | 29 +++++++++++++ tests/test_background_iterable.py | 51 ++++++++++++++--------- 8 files changed, 114 insertions(+), 39 deletions(-) diff --git a/.gitignore b/.gitignore index 835da2048..8a6acca53 100644 --- a/.gitignore +++ b/.gitignore @@ -152,3 +152,4 @@ ledger.json # local execution commands local_*.sh +.aider* diff --git a/pyproject.toml b/pyproject.toml index e390462da..de85d287e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,7 @@ classifiers = [ ] dependencies = [ "haliax>=1.4.dev307", - "equinox==0.11.5", + "equinox==0.11.3", "jaxtyping>=0.2.20", "tokenizers>=0.15.2", "transformers>=4.41.2", @@ -37,7 +37,7 @@ dependencies = [ "braceexpand>=0.1.7", "jmp>=0.0.3", "fsspec[http]>=2024.2,<2024.10", - "tensorstore>=0.1.62", + "tensorstore==0.1.64", "pytimeparse>=1.1.8", "humanfriendly==10.0", "safetensors[numpy]~=0.4.2", diff --git a/src/levanter/distributed.py b/src/levanter/distributed.py index ea0bbb3c7..112409743 100644 --- a/src/levanter/distributed.py +++ b/src/levanter/distributed.py @@ -252,7 +252,7 @@ def _munge_address_port(address: str): logger.info(f"Successfully started ray head on port {ray_port}.") # install an atexit handler to kill the head when we exit - atexit.register(lambda: os.system("ray stop -g 10 --force")) + atexit.register(lambda: os.system("ray stop -g 10 --force &> /dev/null")) elif start_workers: logger.info( f"Starting ray worker and connecting to {address}. We are process {jax.process_index()}." diff --git a/src/levanter/store/cache.py b/src/levanter/store/cache.py index 85b612f91..6db7693fe 100644 --- a/src/levanter/store/cache.py +++ b/src/levanter/store/cache.py @@ -54,7 +54,8 @@ LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" # TODO: should probably do this in terms of bytes -MIN_ITEMS_TO_WRITE = 8192 +# this is kinda silly, but the bigger the better. 
+MIN_ITEMS_TO_WRITE = 32 * 1024 MAX_TIME_BETWEEN_WRITES = 100.0 @@ -883,6 +884,7 @@ def __init__( self.logger = pylogging.getLogger(f"TreeCache.{name}") self._store_future: threading_Future[TreeStore] = threading_Future() self._stop = False + # assert _broker is None if self._broker is not None: self._monitor_thread = threading.Thread(target=self._monitor_metrics, daemon=True) @@ -1078,11 +1080,15 @@ async def _get_start_stops_async(self, slice): return start, step, stop def await_finished(self, timeout: Optional[float] = None): + if self._broker is None: + return x = ray.get(self.finished_sentinel(), timeout=timeout) self._attempt_to_load_store() return x async def finished(self): + if self._broker is None: + return x = await self.finished_sentinel() # TODO: make an async version of this self._attempt_to_load_store() diff --git a/src/levanter/tracker/wandb.py b/src/levanter/tracker/wandb.py index 1e95c0d3a..18f0251ec 100644 --- a/src/levanter/tracker/wandb.py +++ b/src/levanter/tracker/wandb.py @@ -45,6 +45,8 @@ def __init__(self, run: Optional[WandbRun]): else: self.run = run + self._last_warning_step = -500 + def log_hyperparameters(self, hparams: dict[str, Any]): self.run.config.update(hparams, allow_val_change=True) @@ -53,9 +55,11 @@ def log(self, metrics: dict[str, Any], *, step, commit=None): step = self.run.step if step < self.run.step: - logger.warning( - f"Step {step} is less than the current step {self.run.step}. Cowardly refusing to log metrics." - ) + if step - self._last_warning_step > 500: + logger.warning( + f"Step {step} is less than the current step {self.run.step}. Cowardly refusing to log metrics." + ) + self._last_warning_step = step return step = int(step) diff --git a/src/levanter/utils/background_iterable.py b/src/levanter/utils/background_iterable.py index 84c5a7789..4318b3f9b 100644 --- a/src/levanter/utils/background_iterable.py +++ b/src/levanter/utils/background_iterable.py @@ -6,6 +6,8 @@ import tblib +from levanter.utils.thread_utils import AsyncIteratorWrapper + Ex = TypeVar("Ex", covariant=True) @@ -36,30 +38,52 @@ def __init__(self, producer_fn: Callable[[], Union[Iterator[Ex], AsyncIterator[E self.max_capacity = max_capacity self._producer_fn = producer_fn self._stop_event = threading.Event() - self.q: queue.Queue = queue.Queue(self.max_capacity or 0) - self.thread = threading.Thread(target=self._fill_queue_with_batches) - self.thread.daemon = True - self.thread.start() + + if self.max_capacity is None or self.max_capacity >= 0: + self.q: queue.Queue = queue.Queue(self.max_capacity or 0) + self.thread: Optional[threading.Thread] = threading.Thread(target=self._fill_queue_with_batches) + self.thread.daemon = True + self.thread.start() + else: + # No background thread; consume items on demand + self.thread = None + self.iterator = self._producer_fn() + if not isinstance(self.iterator, Iterator): + self.iterator = AsyncIteratorWrapper(self.iterator) def __iter__(self): return self def __next__(self): - while not self._stop_event.is_set(): - batch = self.q.get() - if batch is _SENTINEL: + if self._stop_event.is_set(): + raise StopIteration + if self.thread is not None: + while not self._stop_event.is_set(): + batch = self.q.get() + if batch is _SENTINEL: + raise StopIteration + elif isinstance(batch, _ExceptionWrapper): + batch.reraise() + return batch + else: + # Consume the iterator directly on demand + try: + return next(self.iterator) + except StopIteration: + raise + except StopAsyncIteration: raise StopIteration - elif isinstance(batch, 
_ExceptionWrapper): - batch.reraise() - return batch - + except Exception as e: + raise e raise StopIteration def __del__(self): self.stop() - def stop(self): + def stop(self, wait: bool = True): self._stop_event.set() + if self.thread is not None and wait: + self.thread.join() def _fill_queue_with_batches(self): try: diff --git a/src/levanter/utils/thread_utils.py b/src/levanter/utils/thread_utils.py index 9c6e2ef36..0b4abcdaf 100644 --- a/src/levanter/utils/thread_utils.py +++ b/src/levanter/utils/thread_utils.py @@ -1,5 +1,7 @@ import asyncio +import threading from concurrent.futures import ThreadPoolExecutor +from typing import Iterator # Create a ThreadPoolExecutor @@ -26,3 +28,30 @@ def future_from_value(value): future = asyncio.Future() future.set_result(value) return future + + +class AsyncIteratorWrapper(Iterator): + def __init__(self, async_iter): + self.async_iter = async_iter + self.loop = asyncio.new_event_loop() + self.executor = ThreadPoolExecutor(max_workers=1) + self.thread = threading.Thread(target=self._run_loop, daemon=True) + self.thread.start() + + def _run_loop(self): + asyncio.set_event_loop(self.loop) + self.loop.run_forever() + + def _run_async_task(self, coro): + return asyncio.run_coroutine_threadsafe(coro, self.loop).result() + + def __iter__(self): + return self + + def __next__(self): + try: + return self._run_async_task(self.async_iter.__anext__()) + except StopAsyncIteration: + self.loop.call_soon_threadsafe(self.loop.stop) + self.thread.join() + raise StopIteration diff --git a/tests/test_background_iterable.py b/tests/test_background_iterable.py index 0da8d6ea6..603b01743 100644 --- a/tests/test_background_iterable.py +++ b/tests/test_background_iterable.py @@ -5,9 +5,10 @@ from levanter.utils.background_iterable import BackgroundIterable -def test_reentrancy(): +@pytest.mark.parametrize("max_capacity", [-1, None, 10]) +def test_reentrancy(max_capacity): test_data = list(range(1, 101)) - background_iterable = BackgroundIterable(lambda: iter(test_data), max_capacity=10) + background_iterable = BackgroundIterable(lambda: iter(test_data), max_capacity=max_capacity) iter1 = iter(background_iterable) iter2 = iter(background_iterable) @@ -19,9 +20,10 @@ def test_reentrancy(): assert data1 == test_data -def test_empty_iteration(): +@pytest.mark.parametrize("max_capacity", [-1, None, 10]) +def test_empty_iteration(max_capacity): # Create a BackgroundIterable instance with an empty producer function - background_iterable = BackgroundIterable(lambda: iter([]), max_capacity=10) + background_iterable = BackgroundIterable(lambda: iter([]), max_capacity=max_capacity) # Convert the iterator to a list for comparison data = list(background_iterable) @@ -30,13 +32,14 @@ def test_empty_iteration(): assert data == [] -def test_exception_handling(): +@pytest.mark.parametrize("max_capacity", [-1, None, 10]) +def test_exception_handling(max_capacity): # Create a producer function that raises an exception def producer_with_exception(): raise ValueError("Something went wrong!") # Create a BackgroundIterable instance with the producer function that raises an exception - background_iterable = BackgroundIterable(producer_with_exception, max_capacity=10) + background_iterable = BackgroundIterable(producer_with_exception, max_capacity=max_capacity) # Iterate over the BackgroundIterable and handle the raised exception with pytest.raises(ValueError): @@ -44,13 +47,14 @@ def producer_with_exception(): pass -def test_stop_event(): +@pytest.mark.parametrize("max_capacity", [-1, None, 10]) 
+def test_stop_event(max_capacity): def ongoing_process(): while True: for item in range(1, 101): yield item - background_iterable = BackgroundIterable(ongoing_process, max_capacity=10) + background_iterable = BackgroundIterable(ongoing_process, max_capacity=max_capacity) iter1 = iter(background_iterable) @@ -67,13 +71,15 @@ def ongoing_process(): @pytest.mark.asyncio -async def test_async_reentrancy(): +@pytest.mark.parametrize("max_capacity", [-1, None, 10]) +async def test_async_reentrancy(max_capacity): async def async_producer(): for i in range(1, 101): yield i - await asyncio.sleep(0.01) + if i % 10 == 0: + await asyncio.sleep(0.001) - background_iterable = BackgroundIterable(async_producer, max_capacity=10) + background_iterable = BackgroundIterable(async_producer, max_capacity=max_capacity) iter1 = iter(background_iterable) iter2 = iter(background_iterable) @@ -86,12 +92,13 @@ async def async_producer(): @pytest.mark.asyncio -async def test_async_empty_iteration(): +@pytest.mark.parametrize("max_capacity", [-1, None, 10]) +async def test_async_empty_iteration(max_capacity): async def async_producer(): if False: yield - background_iterable = BackgroundIterable(async_producer, max_capacity=10) + background_iterable = BackgroundIterable(async_producer, max_capacity=max_capacity) data = list(background_iterable) @@ -99,12 +106,13 @@ async def async_producer(): @pytest.mark.asyncio -async def test_async_exception_handling(): +@pytest.mark.parametrize("max_capacity", [-1, None, 10]) +async def test_async_exception_handling(max_capacity): async def async_producer_with_exception(): raise ValueError("Something went wrong!") yield 0 # have to make sure it's an async coroutine - background_iterable = BackgroundIterable(async_producer_with_exception, max_capacity=10) + background_iterable = BackgroundIterable(async_producer_with_exception, max_capacity=max_capacity) with pytest.raises(ValueError): for _ in background_iterable: @@ -112,21 +120,24 @@ async def async_producer_with_exception(): @pytest.mark.asyncio -async def test_async_stop_event(): +@pytest.mark.parametrize("max_capacity", [-1, None, 10]) +async def test_async_stop_event(max_capacity): async def ongoing_async_process(): while True: for item in range(1, 101): yield item - background_iterable = BackgroundIterable(ongoing_async_process, max_capacity=10) + background_iterable = BackgroundIterable(ongoing_async_process, max_capacity=max_capacity) iter1 = iter(background_iterable) for _ in range(5): - next(iter1) + q = next(iter1) + print(q) iter1.stop() + # this doesn't work b/c pytest is stupid with pytest.raises(StopIteration): - await next(iter1) - await next(iter1) + next(iter1) + next(iter1)
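
A minimal usage sketch (not part of the patch) of the two BackgroundIterable modes this change introduces. It mirrors the calls in tests/test_background_iterable.py; the producer functions below are illustrative only.

import asyncio
from levanter.utils.background_iterable import BackgroundIterable

def sync_producer():
    # plain iterator producer, as in the existing tests
    return iter(range(1, 101))

async def async_producer():
    # async-generator producer, as in the async tests
    for i in range(1, 101):
        yield i

# max_capacity=10: items are prefetched by a daemon thread into a bounded queue.
prefetched = BackgroundIterable(sync_producer, max_capacity=10)

# max_capacity=-1: no background thread is started; items are pulled from the
# producer on demand, and async producers are wrapped in AsyncIteratorWrapper.
on_demand = BackgroundIterable(async_producer, max_capacity=-1)

it = iter(on_demand)
first_five = [next(it) for _ in range(5)]
it.stop()  # stop(wait=True) joins the worker thread when one exists

Under the hood, the new AsyncIteratorWrapper in thread_utils.py runs the async generator on its own event loop in a helper thread, so the same synchronous next() interface works for both kinds of producer.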