diff --git a/fsspec/caching.py b/fsspec/caching.py
index 511de1dee..ec8206768 100644
--- a/fsspec/caching.py
+++ b/fsspec/caching.py
@@ -50,14 +50,37 @@ class MMapCache(BaseCache):
     Ensure there is enough disc space in the temporary location.
 
     This cache method might only work on posix
+
+    Parameters
+    ----------
+    blocksize: int
+        How far to read ahead in numbers of bytes
+    fetcher: func
+        Function of the form f(start, end) which gets bytes from remote as
+        specified
+    size: int
+        How big this file is
+    location: str
+        Where to create the temporary file. If None, a temporary file is
+        created using tempfile.TemporaryFile().
+    blocks: set
+        Set of block numbers that have already been fetched. If None, an empty
+        set is created.
+    multi_fetcher: func
+        Function of the form f([(start, end), ...]) which returns a list of
+        bytes, one per requested range and in the same order. Used to fetch
+        several blocks at once; if None, fetcher is called once per block.
     """
 
     name = "mmap"
 
-    def __init__(self, blocksize, fetcher, size, location=None, blocks=None):
+    def __init__(
+        self, blocksize, fetcher, size, location=None, blocks=None, multi_fetcher=None
+    ):
         super().__init__(blocksize, fetcher, size)
         self.blocks = set() if blocks is None else blocks
         self.location = location
+        self.multi_fetcher = multi_fetcher
         self.cache = self._makefile()
 
     def _makefile(self):
@@ -93,15 +116,25 @@ def _fetch(self, start, end):
         start_block = start // self.blocksize
         end_block = end // self.blocksize
         need = [i for i in range(start_block, end_block + 1) if i not in self.blocks]
-        while need:
-            # TODO: not a for loop so we can consolidate blocks later to
-            # make fewer fetch calls; this could be parallel
-            i = need.pop(0)
-            sstart = i * self.blocksize
-            send = min(sstart + self.blocksize, self.size)
-            logger.debug(f"MMap get block #{i} ({sstart}-{send}")
-            self.cache[sstart:send] = self.fetcher(sstart, send)
-            self.blocks.add(i)
+        if not need:
+            return self.cache[start:end]
+
+        ranges = [
+            (i * self.blocksize, min((i + 1) * self.blocksize, self.size))
+            for i in need
+        ]
+        logger.debug(f"MMap get blocks {ranges}")
+        if self.multi_fetcher:
+            fetched = self.multi_fetcher(ranges)
+        else:
+            fetched = (self.fetcher(sstart, send) for sstart, send in ranges)
+
+        # Mark a block as cached only after its bytes are written to the mmap, so
+        # that a failed fetch is retried instead of leaving an unfilled hole.
+        for i, (sstart, send), data in zip(need, ranges, fetched):
+            logger.debug(f"MMap get block #{i} ({sstart}-{send})")
+            self.cache[sstart:send] = data
+            self.blocks.add(i)
 
         return self.cache[start:end]
 
diff --git a/fsspec/implementations/cached.py b/fsspec/implementations/cached.py
index 30aeb119d..0a2633db8 100644
--- a/fsspec/implementations/cached.py
+++ b/fsspec/implementations/cached.py
@@ -400,7 +400,28 @@ def _open(
                 )
         else:
             detail["blocksize"] = f.blocksize
-        f.cache = MMapCache(f.blocksize, f._fetch_range, f.size, fn, blocks)
+
+        def _fetch_ranges(ranges):
+            """Fetch several (start, end) ranges of ``path`` in one call."""
+            return self.fs.cat_ranges(
+                [path] * len(ranges),
+                [r[0] for r in ranges],
+                [r[1] for r in ranges],
+                **kwargs,
+            )
+
+        # only batch range requests for async filesystems without compression
+        multi_fetcher = None
+        if self.fs.async_impl and not self.compression:
+            multi_fetcher = _fetch_ranges
+        f.cache = MMapCache(
+            f.blocksize,
+            f._fetch_range,
+            f.size,
+            fn,
+            blocks,
+            multi_fetcher=multi_fetcher,
+        )
         close = f.close
         f.close = lambda: self.close_and_update(f, close)
         self.save_cache()
diff --git a/fsspec/tests/test_caches.py b/fsspec/tests/test_caches.py
index 65c80cdb4..75db97dcb 100644
--- a/fsspec/tests/test_caches.py
+++ b/fsspec/tests/test_caches.py
@@ -3,7 +3,13 @@
 
 import pytest
 
-from fsspec.caching import BlockCache, FirstChunkCache, caches, register_cache
+from fsspec.caching import (
+    BlockCache,
+    FirstChunkCache,
+    MMapCache,
+    caches,
+    register_cache,
+)
 
 
 def test_cache_getitem(Cache_imp):
@@ -43,6 +49,10 @@ def _fetcher(start, end):
 def letters_fetcher(start, end):
     return string.ascii_letters[start:end].encode()
 
+
+def multi_letters_fetcher(ranges):
+    return [string.ascii_letters[start:end].encode() for start, end in ranges]
+
 
 not_parts_caches = {k: v for k, v in caches.items() if k != "parts"}
 
@@ -81,6 +91,44 @@ def test_first_cache():
     c.fetcher = None
     assert c._fetch(1, 4) == letters_fetcher(1, 4)
 
+
+def test_mmap_cache(mocker):
+    fetcher = mocker.Mock(wraps=letters_fetcher)
+
+    c = MMapCache(5, fetcher, 52)
+    assert c._fetch(12, 15) == letters_fetcher(12, 15)
+    assert fetcher.call_count == 2
+    assert c._fetch(3, 10) == letters_fetcher(3, 10)
+    assert fetcher.call_count == 4
+    assert c._fetch(1, 4) == letters_fetcher(1, 4)
+    assert fetcher.call_count == 4
+
+    multi_fetcher = mocker.Mock(wraps=multi_letters_fetcher)
+    m = MMapCache(5, fetcher, size=52, multi_fetcher=multi_fetcher)
+    assert m._fetch(12, 15) == letters_fetcher(12, 15)
+    assert multi_fetcher.call_count == 1
+    assert m._fetch(3, 10) == letters_fetcher(3, 10)
+    assert multi_fetcher.call_count == 2
+    assert m._fetch(1, 4) == letters_fetcher(1, 4)
+    assert multi_fetcher.call_count == 2
+    assert fetcher.call_count == 4
+
+
+def test_mmap_cache_failed_fetch_is_retried():
+    calls = []
+
+    def flaky_fetcher(start, end):
+        calls.append((start, end))
+        if len(calls) == 1:
+            raise OSError("transient failure")
+        return letters_fetcher(start, end)
+
+    c = MMapCache(5, flaky_fetcher, 52)
+    with pytest.raises(OSError):
+        c._fetch(0, 5)
+    assert c.blocks == set()
+    assert c._fetch(0, 5) == letters_fetcher(0, 5)
+
 
 @pytest.mark.parametrize(
     "size_requests",