Fix eqx (#725)
* don't hit ray if we don't need to in TreeCache

* make ray exit quieter

* reduce log spam of wandb

* .aider ignore

* sigh

* make BackgroundIterable work not in the background

* fix regression caused by new Equinox
dlwh committed Sep 12, 2024
1 parent a91ef81 commit 5c18557
Showing 8 changed files with 114 additions and 39 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -152,3 +152,4 @@ ledger.json
 
 # local execution commands
 local_*.sh
+.aider*
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -22,7 +22,7 @@ classifiers = [
 ]
 dependencies = [
     "haliax>=1.4.dev307",
-    "equinox==0.11.5",
+    "equinox==0.11.3",
     "jaxtyping>=0.2.20",
     "tokenizers>=0.15.2",
     "transformers>=4.41.2",
@@ -37,7 +37,7 @@ dependencies = [
     "braceexpand>=0.1.7",
     "jmp>=0.0.3",
     "fsspec[http]>=2024.2,<2024.10",
-    "tensorstore>=0.1.62",
+    "tensorstore==0.1.64",
     "pytimeparse>=1.1.8",
     "humanfriendly==10.0",
     "safetensors[numpy]~=0.4.2",
2 changes: 1 addition & 1 deletion src/levanter/distributed.py
@@ -252,7 +252,7 @@ def _munge_address_port(address: str):
         logger.info(f"Successfully started ray head on port {ray_port}.")
 
         # install an atexit handler to kill the head when we exit
-        atexit.register(lambda: os.system("ray stop -g 10 --force"))
+        atexit.register(lambda: os.system("ray stop -g 10 --force &> /dev/null"))
     elif start_workers:
         logger.info(
             f"Starting ray worker and connecting to {address}. We are process {jax.process_index()}."
8 changes: 7 additions & 1 deletion src/levanter/store/cache.py
@@ -54,7 +54,8 @@
 LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
 
 # TODO: should probably do this in terms of bytes
-MIN_ITEMS_TO_WRITE = 8192
+# this is kinda silly, but the bigger the better.
+MIN_ITEMS_TO_WRITE = 32 * 1024
 MAX_TIME_BETWEEN_WRITES = 100.0
 
 
@@ -883,6 +884,7 @@ def __init__(
         self.logger = pylogging.getLogger(f"TreeCache.{name}")
         self._store_future: threading_Future[TreeStore] = threading_Future()
         self._stop = False
+        # assert _broker is None
 
         if self._broker is not None:
             self._monitor_thread = threading.Thread(target=self._monitor_metrics, daemon=True)
@@ -1078,11 +1080,15 @@ async def _get_start_stops_async(self, slice):
         return start, step, stop
 
     def await_finished(self, timeout: Optional[float] = None):
+        if self._broker is None:
+            return
         x = ray.get(self.finished_sentinel(), timeout=timeout)
         self._attempt_to_load_store()
         return x
 
     async def finished(self):
+        if self._broker is None:
+            return
         x = await self.finished_sentinel()
         # TODO: make an async version of this
         self._attempt_to_load_store()
10 changes: 7 additions & 3 deletions src/levanter/tracker/wandb.py
@@ -45,6 +45,8 @@ def __init__(self, run: Optional[WandbRun]):
         else:
             self.run = run
 
+        self._last_warning_step = -500
+
     def log_hyperparameters(self, hparams: dict[str, Any]):
         self.run.config.update(hparams, allow_val_change=True)
 
@@ -53,9 +55,11 @@ def log(self, metrics: dict[str, Any], *, step, commit=None):
             step = self.run.step
 
         if step < self.run.step:
-            logger.warning(
-                f"Step {step} is less than the current step {self.run.step}. Cowardly refusing to log metrics."
-            )
+            if step - self._last_warning_step > 500:
+                logger.warning(
+                    f"Step {step} is less than the current step {self.run.step}. Cowardly refusing to log metrics."
+                )
+                self._last_warning_step = step
             return
 
         step = int(step)
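
The change above throttles the "step is less than the current step" warning so out-of-order logging warns at most once per 500 steps instead of on every call. A minimal standalone sketch of the same throttling pattern (the ThrottledWarner name is hypothetical; this is illustrative, not levanter code):

import logging

logger = logging.getLogger(__name__)


class ThrottledWarner:
    # Emit the "stale step" warning at most once per `interval` offending steps.
    def __init__(self, interval: int = 500):
        self._interval = interval
        # Start far enough back that the first offending step still warns.
        self._last_warning_step = -interval

    def warn_if_stale(self, step: int, current_step: int) -> None:
        if step - self._last_warning_step > self._interval:
            logger.warning(
                f"Step {step} is less than the current step {current_step}. Refusing to log metrics."
            )
            self._last_warning_step = step
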
48 changes: 36 additions & 12 deletions src/levanter/utils/background_iterable.py
@@ -6,6 +6,8 @@
 
 import tblib
 
+from levanter.utils.thread_utils import AsyncIteratorWrapper
+
 
 Ex = TypeVar("Ex", covariant=True)
 
@@ -36,30 +38,52 @@ def __init__(self, producer_fn: Callable[[], Union[Iterator[Ex], AsyncIterator[E
         self.max_capacity = max_capacity
         self._producer_fn = producer_fn
         self._stop_event = threading.Event()
-        self.q: queue.Queue = queue.Queue(self.max_capacity or 0)
-        self.thread = threading.Thread(target=self._fill_queue_with_batches)
-        self.thread.daemon = True
-        self.thread.start()
+
+        if self.max_capacity is None or self.max_capacity >= 0:
+            self.q: queue.Queue = queue.Queue(self.max_capacity or 0)
+            self.thread: Optional[threading.Thread] = threading.Thread(target=self._fill_queue_with_batches)
+            self.thread.daemon = True
+            self.thread.start()
+        else:
+            # No background thread; consume items on demand
+            self.thread = None
+            self.iterator = self._producer_fn()
+            if not isinstance(self.iterator, Iterator):
+                self.iterator = AsyncIteratorWrapper(self.iterator)
 
     def __iter__(self):
         return self
 
     def __next__(self):
-        while not self._stop_event.is_set():
-            batch = self.q.get()
-            if batch is _SENTINEL:
+        if self._stop_event.is_set():
+            raise StopIteration
+        if self.thread is not None:
+            while not self._stop_event.is_set():
+                batch = self.q.get()
+                if batch is _SENTINEL:
+                    raise StopIteration
+                elif isinstance(batch, _ExceptionWrapper):
+                    batch.reraise()
+                return batch
+        else:
+            # Consume the iterator directly on demand
+            try:
+                return next(self.iterator)
+            except StopIteration:
+                raise
+            except StopAsyncIteration:
                 raise StopIteration
-            elif isinstance(batch, _ExceptionWrapper):
-                batch.reraise()
-            return batch
-
+            except Exception as e:
+                raise e
         raise StopIteration
 
     def __del__(self):
         self.stop()
 
-    def stop(self):
+    def stop(self, wait: bool = True):
         self._stop_event.set()
+        if self.thread is not None and wait:
+            self.thread.join()
 
     def _fill_queue_with_batches(self):
         try:
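
With this change, a max_capacity of None or a non-negative value keeps the original behavior (a daemon thread prefetches items into a bounded queue), while a negative max_capacity skips the thread entirely and pulls items from the producer on the calling thread, wrapping async producers in AsyncIteratorWrapper. A usage sketch mirroring the parametrized tests below (assumes levanter is importable):

from levanter.utils.background_iterable import BackgroundIterable

data = list(range(1, 101))

# Default / non-negative capacity: items are prefetched by a background daemon thread.
prefetched = BackgroundIterable(lambda: iter(data), max_capacity=10)

# Negative capacity: no thread is started; the producer is consumed on demand.
on_demand = BackgroundIterable(lambda: iter(data), max_capacity=-1)

assert list(prefetched) == data
assert list(on_demand) == data
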
29 changes: 29 additions & 0 deletions src/levanter/utils/thread_utils.py
@@ -1,5 +1,7 @@
 import asyncio
+import threading
 from concurrent.futures import ThreadPoolExecutor
+from typing import Iterator
 
 
 # Create a ThreadPoolExecutor
@@ -26,3 +28,30 @@ def future_from_value(value):
     future = asyncio.Future()
     future.set_result(value)
     return future
+
+
+class AsyncIteratorWrapper(Iterator):
+    def __init__(self, async_iter):
+        self.async_iter = async_iter
+        self.loop = asyncio.new_event_loop()
+        self.executor = ThreadPoolExecutor(max_workers=1)
+        self.thread = threading.Thread(target=self._run_loop, daemon=True)
+        self.thread.start()
+
+    def _run_loop(self):
+        asyncio.set_event_loop(self.loop)
+        self.loop.run_forever()
+
+    def _run_async_task(self, coro):
+        return asyncio.run_coroutine_threadsafe(coro, self.loop).result()
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        try:
+            return self._run_async_task(self.async_iter.__anext__())
+        except StopAsyncIteration:
+            self.loop.call_soon_threadsafe(self.loop.stop)
+            self.thread.join()
+            raise StopIteration
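
AsyncIteratorWrapper runs a private event loop on a daemon thread and blocks on each __anext__ call, which is what lets the new on-demand mode of BackgroundIterable consume async producers without its own worker thread. A small usage sketch (assumes levanter is importable; countdown is a made-up example producer):

from levanter.utils.thread_utils import AsyncIteratorWrapper


async def countdown(n):
    # Made-up async generator used only for illustration.
    for i in range(n, 0, -1):
        yield i


for value in AsyncIteratorWrapper(countdown(3)):
    print(value)  # prints 3, then 2, then 1
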
51 changes: 31 additions & 20 deletions tests/test_background_iterable.py
@@ -5,9 +5,10 @@
 from levanter.utils.background_iterable import BackgroundIterable
 
 
-def test_reentrancy():
+@pytest.mark.parametrize("max_capacity", [-1, None, 10])
+def test_reentrancy(max_capacity):
     test_data = list(range(1, 101))
-    background_iterable = BackgroundIterable(lambda: iter(test_data), max_capacity=10)
+    background_iterable = BackgroundIterable(lambda: iter(test_data), max_capacity=max_capacity)
 
     iter1 = iter(background_iterable)
     iter2 = iter(background_iterable)
@@ -19,9 +20,10 @@ def test_reentrancy():
     assert data1 == test_data
 
 
-def test_empty_iteration():
+@pytest.mark.parametrize("max_capacity", [-1, None, 10])
+def test_empty_iteration(max_capacity):
     # Create a BackgroundIterable instance with an empty producer function
-    background_iterable = BackgroundIterable(lambda: iter([]), max_capacity=10)
+    background_iterable = BackgroundIterable(lambda: iter([]), max_capacity=max_capacity)
 
     # Convert the iterator to a list for comparison
     data = list(background_iterable)
@@ -30,27 +32,29 @@ def test_empty_iteration():
     assert data == []
 
 
-def test_exception_handling():
+@pytest.mark.parametrize("max_capacity", [-1, None, 10])
+def test_exception_handling(max_capacity):
     # Create a producer function that raises an exception
     def producer_with_exception():
         raise ValueError("Something went wrong!")
 
     # Create a BackgroundIterable instance with the producer function that raises an exception
-    background_iterable = BackgroundIterable(producer_with_exception, max_capacity=10)
+    background_iterable = BackgroundIterable(producer_with_exception, max_capacity=max_capacity)
 
     # Iterate over the BackgroundIterable and handle the raised exception
     with pytest.raises(ValueError):
         for _ in background_iterable:
             pass
 
 
-def test_stop_event():
+@pytest.mark.parametrize("max_capacity", [-1, None, 10])
+def test_stop_event(max_capacity):
     def ongoing_process():
         while True:
             for item in range(1, 101):
                 yield item
 
-    background_iterable = BackgroundIterable(ongoing_process, max_capacity=10)
+    background_iterable = BackgroundIterable(ongoing_process, max_capacity=max_capacity)
 
     iter1 = iter(background_iterable)
 
@@ -67,13 +71,15 @@ def ongoing_process():
 
 
 @pytest.mark.asyncio
-async def test_async_reentrancy():
+@pytest.mark.parametrize("max_capacity", [-1, None, 10])
+async def test_async_reentrancy(max_capacity):
     async def async_producer():
         for i in range(1, 101):
             yield i
-            await asyncio.sleep(0.01)
+            if i % 10 == 0:
+                await asyncio.sleep(0.001)
 
-    background_iterable = BackgroundIterable(async_producer, max_capacity=10)
+    background_iterable = BackgroundIterable(async_producer, max_capacity=max_capacity)
 
     iter1 = iter(background_iterable)
     iter2 = iter(background_iterable)
@@ -86,47 +92,52 @@ async def async_producer():
 
 
 @pytest.mark.asyncio
-async def test_async_empty_iteration():
+@pytest.mark.parametrize("max_capacity", [-1, None, 10])
+async def test_async_empty_iteration(max_capacity):
     async def async_producer():
         if False:
             yield
 
-    background_iterable = BackgroundIterable(async_producer, max_capacity=10)
+    background_iterable = BackgroundIterable(async_producer, max_capacity=max_capacity)
 
     data = list(background_iterable)
 
     assert data == []
 
 
 @pytest.mark.asyncio
-async def test_async_exception_handling():
+@pytest.mark.parametrize("max_capacity", [-1, None, 10])
+async def test_async_exception_handling(max_capacity):
     async def async_producer_with_exception():
         raise ValueError("Something went wrong!")
         yield 0  # have to make sure it's an async coroutine
 
-    background_iterable = BackgroundIterable(async_producer_with_exception, max_capacity=10)
+    background_iterable = BackgroundIterable(async_producer_with_exception, max_capacity=max_capacity)
 
     with pytest.raises(ValueError):
         for _ in background_iterable:
             pass
 
 
 @pytest.mark.asyncio
-async def test_async_stop_event():
+@pytest.mark.parametrize("max_capacity", [-1, None, 10])
+async def test_async_stop_event(max_capacity):
     async def ongoing_async_process():
         while True:
             for item in range(1, 101):
                 yield item
 
-    background_iterable = BackgroundIterable(ongoing_async_process, max_capacity=10)
+    background_iterable = BackgroundIterable(ongoing_async_process, max_capacity=max_capacity)
 
     iter1 = iter(background_iterable)
 
     for _ in range(5):
-        next(iter1)
+        q = next(iter1)
+        print(q)
 
     iter1.stop()
 
+    # this doesn't work b/c pytest is stupid
     with pytest.raises(StopIteration):
-        await next(iter1)
-        await next(iter1)
+        next(iter1)
+        next(iter1)
