From 1c957896af02afa375ea8bef94a4f732faf7363d Mon Sep 17 00:00:00 2001
From: Moritz Onken <moritz.onken@emdgroup.com>
Date: Sat, 19 Aug 2023 16:05:58 -0400
Subject: [PATCH] Call cat_ranges in blockcache for async filesystems

---
 fsspec/caching.py                | 43 +++++++++++++++++++++++++++++---
 fsspec/implementations/cached.py | 12 ++++++++-
 fsspec/tests/test_caches.py      | 26 ++++++++++++++++++-
 3 files changed, 75 insertions(+), 6 deletions(-)

diff --git a/fsspec/caching.py b/fsspec/caching.py
index 511de1dee..ec8206768 100644
--- a/fsspec/caching.py
+++ b/fsspec/caching.py
@@ -50,14 +50,35 @@ class MMapCache(BaseCache):
     Ensure there is enough disc space in the temporary location.
 
     This cache method might only work on posix
+
+    Parameters
+    ----------
+    blocksize: int
+        Size in bytes of each cache block; missing data is fetched block-wise
+    fetcher: func
+        Function of the form f(start, end) which gets bytes from remote as
+        specified
+    size: int
+        How big this file is
+    location: str
+        Where to create the temporary file. If None, a temporary file is
+        created using tempfile.TemporaryFile().
+    blocks: set
+        Set of block numbers that have already been fetched. If None, an empty
+        set is created.
+    multi_fetcher: func
+        Function of the form f([(start, end), ...]) which gets several byte
+        ranges from remote in one call and returns their contents in the same
+        order. If not specified, the fetcher function is called once per block.
     """
 
     name = "mmap"
 
-    def __init__(self, blocksize, fetcher, size, location=None, blocks=None):
+    def __init__(self, blocksize, fetcher, size, location=None, blocks=None, multi_fetcher=None):
         super().__init__(blocksize, fetcher, size)
         self.blocks = set() if blocks is None else blocks
         self.location = location
+        self.multi_fetcher = multi_fetcher
         self.cache = self._makefile()
 
     def _makefile(self):
@@ -93,16 +114,30 @@ def _fetch(self, start, end):
         start_block = start // self.blocksize
         end_block = end // self.blocksize
         need = [i for i in range(start_block, end_block + 1) if i not in self.blocks]
+        ranges = []
         while need:
             # TODO: not a for loop so we can consolidate blocks later to
-            # make fewer fetch calls; this could be parallel
+            # make fewer fetch calls
             i = need.pop(0)
             sstart = i * self.blocksize
             send = min(sstart + self.blocksize, self.size)
-            logger.debug(f"MMap get block #{i} ({sstart}-{send}")
-            self.cache[sstart:send] = self.fetcher(sstart, send)
+            ranges.append((sstart, send))
             self.blocks.add(i)
 
+        if not ranges:
+            return self.cache[start:end]
+
+        if self.multi_fetcher:
+            logger.debug(f"MMap get blocks {ranges}")
+            # One batched call; results are paired with ranges by position.
+            for (sstart, send), data in zip(ranges, self.multi_fetcher(ranges)):
+                logger.debug(f"MMap get block ({sstart}-{send})")
+                self.cache[sstart:send] = data
+        else:
+            for (sstart, send) in ranges:
+                logger.debug(f"MMap get block ({sstart}-{send})")
+                self.cache[sstart:send] = self.fetcher(sstart, send)
+
         return self.cache[start:end]
 
     def __getstate__(self):
diff --git a/fsspec/implementations/cached.py b/fsspec/implementations/cached.py
index 30aeb119d..0a2633db8 100644
--- a/fsspec/implementations/cached.py
+++ b/fsspec/implementations/cached.py
@@ -400,7 +400,17 @@ def _open(
                 )
         else:
             detail["blocksize"] = f.blocksize
-        f.cache = MMapCache(f.blocksize, f._fetch_range, f.size, fn, blocks)
+
+        def _fetch_ranges(ranges):
+            # Read every (start, end) byte range of ``path`` from ``self.fs``
+            # with a single ``cat_ranges`` call.
+            paths = [path] * len(ranges)
+            starts = [first for first, _ in ranges]
+            stops = [last for _, last in ranges]
+            return self.fs.cat_ranges(paths, starts, stops, **kwargs)
+
+        multi_fetcher = None if not self.fs.async_impl or self.compression else _fetch_ranges
+        f.cache = MMapCache(f.blocksize, f._fetch_range, f.size, fn, blocks, multi_fetcher=multi_fetcher)
         close = f.close
         f.close = lambda: self.close_and_update(f, close)
         self.save_cache()
diff --git a/fsspec/tests/test_caches.py b/fsspec/tests/test_caches.py
index 65c80cdb4..75db97dcb 100644
--- a/fsspec/tests/test_caches.py
+++ b/fsspec/tests/test_caches.py
@@ -3,7 +3,7 @@
 
 import pytest
 
-from fsspec.caching import BlockCache, FirstChunkCache, caches, register_cache
+from fsspec.caching import BlockCache, FirstChunkCache, MMapCache, caches, register_cache
 
 
 def test_cache_getitem(Cache_imp):
@@ -43,6 +43,10 @@ def _fetcher(start, end):
 def letters_fetcher(start, end):
     return string.ascii_letters[start:end].encode()
 
+def multi_letters_fetcher(ranges):
+    """Return the encoded ascii_letters slice for each (start, end) range."""
+    return [string.ascii_letters[start:end].encode() for start, end in ranges]
+
 
 not_parts_caches = {k: v for k, v in caches.items() if k != "parts"}
 
@@ -81,6 +85,26 @@ def test_first_cache():
     c.fetcher = None
     assert c._fetch(1, 4) == letters_fetcher(1, 4)
 
+def test_mmap_cache(mocker):
+    """Missing blocks are fetched once each, or batched when multi_fetcher is set."""
+    fetcher = mocker.Mock(wraps=letters_fetcher)
+    c = MMapCache(5, fetcher, 52)
+    assert c._fetch(12, 15) == letters_fetcher(12, 15)
+    assert fetcher.call_count == 2
+    assert c._fetch(3, 10) == letters_fetcher(3, 10)
+    assert fetcher.call_count == 4
+    assert c._fetch(1, 4) == letters_fetcher(1, 4)
+    assert fetcher.call_count == 4
+    # With a multi_fetcher, each _fetch needing new blocks makes one batched call.
+    multi_fetcher = mocker.Mock(wraps=multi_letters_fetcher)
+    m = MMapCache(5, fetcher, size=52, multi_fetcher=multi_fetcher)
+    assert m._fetch(12, 15) == letters_fetcher(12, 15)
+    assert multi_fetcher.call_count == 1
+    assert m._fetch(3, 10) == letters_fetcher(3, 10)
+    assert multi_fetcher.call_count == 2
+    assert m._fetch(1, 4) == letters_fetcher(1, 4)
+    assert multi_fetcher.call_count == 2
+    assert fetcher.call_count == 4
 
 @pytest.mark.parametrize(
     "size_requests",