diff --git a/exo/download/hf/hf_shard_download.py b/exo/download/hf/hf_shard_download.py index eb562c3c9..cc94f4a4a 100644 --- a/exo/download/hf/hf_shard_download.py +++ b/exo/download/hf/hf_shard_download.py @@ -18,8 +18,11 @@ def __init__(self, quick_check: bool = False, max_parallel_downloads: int = 4): self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]() async def ensure_shard(self, shard: Shard) -> Path: + """Ensure a shard is downloaded and return its path. Downloads are protected from cancellation.""" if shard in self.completed_downloads: + if DEBUG >= 2: print(f"Using completed download for {shard}") return self.completed_downloads[shard] + if self.quick_check: repo_root = get_repo_root(shard.model_id) snapshots_dir = repo_root/"snapshots" @@ -29,10 +32,13 @@ async def ensure_shard(self, shard: Shard) -> Path: most_recent_dir = max(visible_dirs, key=lambda x: x.stat().st_mtime) return most_recent_dir - # If a download on this shard is already in progress, keep that one - for active_shard in self.active_downloads: - if active_shard == shard: - if DEBUG >= 2: print(f"Download already in progress for {shard}. Keeping that one.") + # Check for active download + if shard in self.active_downloads: + if DEBUG >= 2: print(f"Using existing download for {shard}") + try: + return await self.active_downloads[shard] + except asyncio.CancelledError: + if DEBUG >= 2: print(f"Ignoring cancellation for existing download of {shard}") return await self.active_downloads[shard] # Cancel any downloads for this model_id on a different shard @@ -50,17 +56,42 @@ async def ensure_shard(self, shard: Shard) -> Path: traceback.print_exc() self.active_downloads = {active_shard: task for active_shard, task in self.active_downloads.items() if active_shard.model_id != shard.model_id} - # Start new download - download_task = asyncio.create_task(self._download_shard(shard)) + # Start new protected download + event = asyncio.Event() + result = None + error = None + + async def protected_download(): + nonlocal result, error + try: + if DEBUG >= 2: print(f"Starting protected download for {shard}") + result = await self._download_shard(shard) + self.completed_downloads[shard] = result + if DEBUG >= 2: print(f"Download completed for {shard}: {result}") + return result + except Exception as e: + if DEBUG >= 2: print(f"Error in download for {shard}: {e}") + error = e + raise + finally: + event.set() + + download_task = asyncio.create_task(protected_download()) self.active_downloads[shard] = download_task + try: - path = await download_task - self.completed_downloads[shard] = path - return path + if DEBUG >= 2: print(f"Waiting for download to complete for {shard}") + try: + return await download_task + except asyncio.CancelledError: + if DEBUG >= 2: print(f"Ignoring cancellation and waiting for download to complete for {shard}") + await event.wait() + if error: + raise error + return result finally: - # Ensure the task is removed even if an exception occurs - print(f"Removing download task for {shard}: {shard in self.active_downloads}") - if shard in self.active_downloads: + if DEBUG >= 2: print(f"Cleaning up download task for {shard}") + if shard in self.active_downloads and self.active_downloads[shard] is download_task: self.active_downloads.pop(shard) async def _download_shard(self, shard: Shard) -> Path: diff --git a/test/test_hf_shard_download.py b/test/test_hf_shard_download.py new file mode 100644 index 000000000..b121c8865 --- /dev/null +++ b/test/test_hf_shard_download.py @@ -0,0 +1,135 @@ +import pytest +import asyncio +from pathlib import Path +from unittest import mock +from exo.download.hf.hf_shard_download import HFShardDownloader +from exo.inference.shard import Shard +from exo import DEBUG + +class MockPath: + def __init__(self, exists=True): + self.exists = lambda: exists + self.iterdir = lambda: [] + + def __truediv__(self, other): + return self + +mock_was_cancelled = False + +async def mock_download_shard(self, shard): + """Mock the _download_shard method to simulate a long download""" + global mock_was_cancelled + try: + if DEBUG >= 2: print(f"Starting mock download for {shard}") + await asyncio.sleep(0.5) # Simulate download time + if DEBUG >= 2: print(f"Mock download completed for {shard}") + return Path("/mock/download/path") + except asyncio.CancelledError: + mock_was_cancelled = True + if DEBUG >= 2: print(f"Mock download was cancelled for {shard}, completing anyway") + # Continue despite cancellation + await asyncio.sleep(0.5) + if DEBUG >= 2: print(f"Mock download completed after cancellation for {shard}") + return Path("/mock/download/path") + +@pytest.mark.asyncio +async def test_download_protection(): + """Test that downloads are protected from cancellation""" + global mock_was_cancelled + mock_was_cancelled = False + + with mock.patch('exo.download.hf.hf_shard_download.get_repo_root', return_value=MockPath(exists=False)), \ + mock.patch.object(HFShardDownloader, '_download_shard', mock_download_shard): + + downloader = HFShardDownloader() + shard = Shard(model_id="test-model", start_layer=0, end_layer=1, n_layers=1) + + # Create a future we'll use to control the test flow + download_complete = asyncio.Future() + + async def do_download(): + try: + result = await downloader.ensure_shard(shard) + download_complete.set_result(result) + except Exception as e: + if not download_complete.done(): + download_complete.set_exception(e) + + # Start the download + task = asyncio.create_task(do_download()) + + # Give it a moment to start + await asyncio.sleep(0.1) + + # Try to cancel the task + if DEBUG >= 2: print("Attempting to cancel download task") + task.cancel() + + # Wait for result with timeout + try: + result = await asyncio.wait_for(download_complete, timeout=2.0) + assert isinstance(result, Path), "Should return a Path" + assert shard not in downloader.active_downloads, "Download task should be cleaned up" + assert shard in downloader.completed_downloads, "Download should be marked as completed" + if DEBUG >= 2: print("Download completed successfully despite cancellation") + except asyncio.TimeoutError: + pytest.fail("Download did not complete in time") + except Exception as e: + pytest.fail(f"Download failed with error: {e}") + +@pytest.mark.asyncio +async def test_multiple_downloads(): + """Test handling multiple downloads for the same shard""" + with mock.patch('exo.download.hf.hf_shard_download.get_repo_root', return_value=MockPath(exists=False)), \ + mock.patch.object(HFShardDownloader, '_download_shard', mock_download_shard): + + downloader = HFShardDownloader() + shard = Shard(model_id="test-model", start_layer=0, end_layer=1, n_layers=1) + + # Start both downloads with a small delay between them + if DEBUG >= 2: print("Starting first download") + download1 = asyncio.create_task(downloader.ensure_shard(shard)) + + await asyncio.sleep(0.2) # Give first download time to start + + if DEBUG >= 2: print("Starting second download") + download2 = asyncio.create_task(downloader.ensure_shard(shard)) + + # Wait for both downloads to complete + if DEBUG >= 2: print("Waiting for downloads to complete") + + path1 = await download1 + if DEBUG >= 2: print(f"First download completed with path: {path1}") + + path2 = await download2 + if DEBUG >= 2: print(f"Second download completed with path: {path2}") + + # Verify results + assert isinstance(path1, Path), "First download should return a Path" + assert isinstance(path2, Path), "Second download should return a Path" + assert path1 == path2, "Multiple downloads should return same path" + assert shard not in downloader.active_downloads, "Download tasks should be cleaned up" + assert shard in downloader.completed_downloads, "Download should be marked as completed" + +@pytest.mark.asyncio +async def test_download_error_handling(): + """Test that errors during download are handled properly""" + + async def mock_download_error(self, shard): + await asyncio.sleep(0.1) # Simulate some work + raise Exception("Download failed") + + with mock.patch('exo.download.hf.hf_shard_download.get_repo_root', return_value=MockPath(exists=False)), \ + mock.patch.object(HFShardDownloader, '_download_shard', mock_download_error): + + downloader = HFShardDownloader() + shard = Shard(model_id="test-model", start_layer=0, end_layer=1, n_layers=1) + + with pytest.raises(Exception) as exc_info: + await downloader.ensure_shard(shard) + + assert str(exc_info.value) == "Download failed" + assert shard not in downloader.active_downloads, "Failed download should be cleaned up" + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) \ No newline at end of file