Skip to content

Commit

Permalink
dvcfs: optimize get() by reducing index.info() calls
Browse files Browse the repository at this point in the history
  • Loading branch information
skshetry committed Aug 26, 2024
1 parent 3f2584d commit 309631c
Show file tree
Hide file tree
Showing 4 changed files with 192 additions and 24 deletions.
8 changes: 4 additions & 4 deletions dvc/dependency/repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,12 +98,13 @@ def download(self, to: "Output", jobs: Optional[int] = None):

files = super().download(to=to, jobs=jobs)
if not isinstance(to.fs, LocalFileSystem):
return files
return

hashes: list[tuple[str, HashInfo, dict[str, Any]]] = []
for src_path, dest_path in files:
for src_path, dest_path, *rest in files:
try:
hash_info = self.fs.info(src_path)["dvc_info"]["entry"].hash_info
info = rest[0] if rest else self.fs.info(src_path)
hash_info = info["dvc_info"]["entry"].hash_info
dest_info = to.fs.info(dest_path)
except (KeyError, AttributeError):
# If no hash info found, just keep going and output will be hashed later
Expand All @@ -112,7 +113,6 @@ def download(self, to: "Output", jobs: Optional[int] = None):
hashes.append((dest_path, hash_info, dest_info))
cache = to.cache if to.use_cache else to.local_cache
cache.state.save_many(hashes, to.fs)
return files

def update(self, rev: Optional[str] = None):
if rev:
Expand Down
25 changes: 14 additions & 11 deletions dvc/fs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import glob
from typing import Optional
from typing import Optional, Union
from urllib.parse import urlparse

from dvc.config import ConfigError as RepoConfigError
Expand Down Expand Up @@ -47,12 +47,24 @@

def download(
fs: "FileSystem", fs_path: str, to: str, jobs: Optional[int] = None
) -> list[tuple[str, str]]:
) -> list[Union[tuple[str, str], tuple[str, str, dict]]]:
from dvc.scm import lfs_prefetch

from .callbacks import TqdmCallback

with TqdmCallback(desc=f"Downloading {fs.name(fs_path)}", unit="files") as cb:
if isinstance(fs, DVCFileSystem):
lfs_prefetch(
fs,
[
f"{fs.normpath(glob.escape(fs_path))}/**"
if fs.isdir(fs_path)
else glob.escape(fs_path)
],
)
if not glob.has_magic(fs_path):
return fs._get(fs_path, to, batch_size=jobs, callback=cb)

# NOTE: We use dvc-objects generic.copy over fs.get since it makes file
# download atomic and avoids fsspec glob/regex path expansion.
if fs.isdir(fs_path):
Expand All @@ -69,15 +81,6 @@ def download(
from_infos = [fs_path]
to_infos = [to]

if isinstance(fs, DVCFileSystem):
lfs_prefetch(
fs,
[
f"{fs.normpath(glob.escape(fs_path))}/**"
if fs.isdir(fs_path)
else glob.escape(fs_path)
],
)
cb.set_size(len(from_infos))
jobs = jobs or fs.jobs
generic.copy(fs, from_infos, localfs, to_infos, callback=cb, batch_size=jobs)
Expand Down
141 changes: 139 additions & 2 deletions dvc/fs/dvc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,24 @@
import threading
from collections import deque
from contextlib import ExitStack, suppress
from glob import has_magic
from typing import TYPE_CHECKING, Any, Callable, Optional, Union

from fsspec.spec import AbstractFileSystem
from fsspec.spec import DEFAULT_CALLBACK, AbstractFileSystem
from funcy import wrap_with

from dvc.log import logger
from dvc_objects.fs.base import FileSystem
from dvc.utils.threadpool import ThreadPoolExecutor
from dvc_objects.fs.base import AnyFSPath, FileSystem

from .data import DataFileSystem

if TYPE_CHECKING:
from dvc.repo import Repo
from dvc.types import DictStrAny, StrPath

from .callbacks import Callback

logger = logger.getChild(__name__)

RepoFactory = Union[Callable[..., "Repo"], type["Repo"]]
Expand Down Expand Up @@ -474,9 +478,103 @@ def _info( # noqa: C901
info["name"] = path
return info

def get(
    self,
    rpath,
    lpath,
    recursive=False,
    callback=DEFAULT_CALLBACK,
    maxdepth=None,
    batch_size=None,
    **kwargs,
):
    """Download *rpath* to *lpath*, discarding the transfer report.

    fsspec-compatible entry point; it simply forwards everything to
    :meth:`_get`. Callers that want the list of ``(src, dest[, info])``
    tuples should call ``_get`` directly.
    """
    kwargs.update(
        recursive=recursive,
        callback=callback,
        maxdepth=maxdepth,
        batch_size=batch_size,
    )
    self._get(rpath, lpath, **kwargs)

def _get(
    self,
    rpath,
    lpath,
    recursive=False,
    callback=DEFAULT_CALLBACK,
    maxdepth=None,
    batch_size=None,
    **kwargs,
) -> list[Union[tuple[str, str], tuple[str, str, dict]]]:
    """Download *rpath* to *lpath* and return the transfer report.

    Returns a list of ``(src, dest)`` or ``(src, dest, info)`` tuples, one
    per downloaded file, so callers can reuse the ``info`` dicts gathered
    during the walk instead of issuing extra ``info()``/index lookups.
    Falls back to the generic fsspec implementation (and returns ``[]``)
    for list arguments, glob patterns, missing paths and non-recursive
    calls.
    """
    if (
        isinstance(rpath, list)
        or isinstance(lpath, list)
        or has_magic(rpath)
        or not self.exists(rpath)
        or not recursive
    ):
        super().get(
            rpath,
            lpath,
            recursive=recursive,
            callback=callback,
            maxdepth=maxdepth,
            **kwargs,
        )
        return []

    # Mirror fsspec semantics: copying a directory into an existing
    # directory nests it under the source's basename.
    if os.path.isdir(lpath) or lpath.endswith(os.path.sep):
        lpath = self.join(lpath, os.path.basename(rpath))

    if self.isfile(rpath):
        os.makedirs(os.path.dirname(lpath), exist_ok=True)
        with callback.branched(rpath, lpath) as child:
            self.get_file(rpath, lpath, callback=child, **kwargs)
        return [(rpath, lpath)]

    # Collect every file (with the info the walk already produced) and
    # every destination directory in a single pass.
    _files = []
    _dirs: list[str] = []
    for root, dirs, files in self.walk(rpath, maxdepth=maxdepth, detail=True):
        if files:
            callback.set_size((callback.size or 0) + len(files))

        rel = self.relpath(root, rpath)
        dest_root = lpath if rel in ("", os.curdir) else os.path.join(lpath, rel)
        _dirs.extend(f"{dest_root}{os.path.sep}{d}" for d in dirs)

        key = self._get_key_from_relative(root)
        _, dvc_fs, _ = self._get_subrepo_info(key)

        for name, info in files.items():
            src_path = f"{root}{self.sep}{name}"
            dest_path = f"{dest_root}{os.path.sep}{name}"
            _files.append((dvc_fs, src_path, dest_path, info))

    os.makedirs(lpath, exist_ok=True)
    for d in _dirs:
        # exist_ok so that re-downloading into an already-populated tree
        # does not raise FileExistsError (plain os.mkdir would), matching
        # the exist_ok used for lpath above and the fact that files are
        # overwritten silently.
        os.makedirs(d, exist_ok=True)

    def _get_file(arg):
        dvc_fs, src, dest, info = arg
        dvc_info = info.get("dvc_info")
        if dvc_info and dvc_fs:
            # dvc-tracked file: fetch straight from the data filesystem,
            # reusing the info we already have so no index re-resolution
            # happens per file.
            dvc_path = dvc_info["name"]
            dvc_fs.get_file(
                dvc_path, dest, callback=callback, info=dvc_info, **kwargs
            )
        else:
            self.get_file(src, dest, callback=callback, **kwargs)
        return src, dest, info

    with ThreadPoolExecutor(max_workers=batch_size) as executor:
        return list(executor.imap_unordered(_get_file, _files))

def get_file(self, rpath, lpath, **kwargs):
key = self._get_key_from_relative(rpath)
fs_path = self._from_key(key)

try:
return self.repo.fs.get_file(fs_path, lpath, **kwargs)
except FileNotFoundError:
Expand Down Expand Up @@ -553,6 +651,45 @@ def immutable(self):
def getcwd(self):
    """Delegate the current working directory to the wrapped filesystem."""
    wrapped = self.fs
    return wrapped.getcwd()

def _get(
    self,
    from_info: Union[AnyFSPath, list[AnyFSPath]],
    to_info: Union[AnyFSPath, list[AnyFSPath]],
    callback: "Callback" = DEFAULT_CALLBACK,
    recursive: bool = False,
    batch_size: Optional[int] = None,
    **kwargs,
) -> list[Union[tuple[str, str], tuple[str, str, dict]]]:
    """Forward the download to the wrapped DVC filesystem.

    NOTE: the incoming ``recursive`` flag is deliberately ignored.
    ``FileSystem.get`` is non-recursive only when both arguments are
    lists and recursive in every other case, so the flag is recomputed
    here from the argument types.
    """
    both_are_lists = isinstance(from_info, list) and isinstance(to_info, list)
    return self.fs._get(
        from_info,
        to_info,
        callback=callback,
        recursive=not both_are_lists,
        batch_size=batch_size,
        **kwargs,
    )

def get(
    self,
    from_info: Union[AnyFSPath, list[AnyFSPath]],
    to_info: Union[AnyFSPath, list[AnyFSPath]],
    callback: "Callback" = DEFAULT_CALLBACK,
    recursive: bool = False,
    batch_size: Optional[int] = None,
    **kwargs,
) -> None:
    """Download *from_info* to *to_info*, dropping the transfer report.

    Public counterpart of :meth:`_get`; use ``_get`` when the per-file
    ``(src, dest[, info])`` tuples are needed.
    """
    kwargs["callback"] = callback
    kwargs["recursive"] = recursive
    kwargs["batch_size"] = batch_size
    self._get(from_info, to_info, **kwargs)

@property
def fsid(self) -> str:
    """Identifier of the wrapped filesystem, exposed unchanged."""
    inner = self.fs
    return inner.fsid
Expand Down
42 changes: 35 additions & 7 deletions tests/func/test_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from dvc.testing.tmp_dir import make_subrepo
from dvc.utils.fs import remove
from dvc_data.hashfile import hash
from dvc_data.index.index import DataIndexDirError
from dvc_data.index.index import DataIndex, DataIndexDirError


def test_import(tmp_dir, scm, dvc, erepo_dir):
Expand Down Expand Up @@ -725,12 +725,40 @@ def test_import_invalid_configs(tmp_dir, scm, dvc, erepo_dir):
)


def test_import_no_hash(tmp_dir, scm, dvc, erepo_dir, mocker):
@pytest.mark.parametrize(
"files",
[
{"foo": "foo"},
{
"dir": {
"bar": "bar",
"subdir": {"lorem": "ipsum", "nested": {"lorem": "lorem"}},
}
},
],
)
def test_import_no_hash(tmp_dir, scm, dvc, erepo_dir, mocker, files):
with erepo_dir.chdir():
erepo_dir.dvc_gen("foo", "foo content", commit="create foo")
erepo_dir.dvc_gen(files, commit="create foo")

spy = mocker.spy(hash, "file_md5")
stage = dvc.imp(os.fspath(erepo_dir), "foo", "foo_imported")
assert spy.call_count == 1
for call in spy.call_args_list:
assert stage.outs[0].fs_path != call.args[0]
index_info_spy = mocker.spy(DataIndex, "info")
name = next(iter(files))
dvc.imp(os.fspath(erepo_dir), name, "out")

local_hashes = [
call.args[0] for call in spy.call_args_list if call.args[1].protocol == "local"
]
assert not local_hashes

expected_info_calls = {(name,)}
if isinstance(files[name], dict):
dirs = {
(name, *d.relative_to(tmp_dir / "out").parts)
for d in (tmp_dir / "out").rglob("*")
if d.is_dir()
}
expected_info_calls.update(dirs)
assert {
call.args[1] for call in index_info_spy.call_args_list
} == expected_info_calls

0 comments on commit 309631c

Please sign in to comment.