dvcfs: optimize get() by reducing index.info() calls #10540

Merged Aug 27, 2024 (2 commits)
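At a high level, the changes below make DVCFileSystem.get() collect per-file metadata in a single walk(..., detail=True) pass and thread it through to download(), so importing no longer issues one index.info() call per file: only one per directory, as the updated test asserts. A hedged end-to-end sketch (the repo URL is illustrative):

from dvc.api import DVCFileSystem

fs = DVCFileSystem("https://github.com/iterative/example-get-started")
# A recursive get of a directory now walks the index once with detail=True
# and fans file downloads out over a thread pool, instead of stat-ing
# every file individually.
fs.get("data", "local-data", recursive=True, batch_size=4)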
8 changes: 4 additions & 4 deletions dvc/dependency/repo.py
@@ -98,12 +98,13 @@ def download(self, to: "Output", jobs: Optional[int] = None):

files = super().download(to=to, jobs=jobs)
if not isinstance(to.fs, LocalFileSystem):
return files
return

hashes: list[tuple[str, HashInfo, dict[str, Any]]] = []
for src_path, dest_path in files:
for src_path, dest_path, *rest in files:
try:
hash_info = self.fs.info(src_path)["dvc_info"]["entry"].hash_info
info = rest[0] if rest else self.fs.info(src_path)
hash_info = info["dvc_info"]["entry"].hash_info
dest_info = to.fs.info(dest_path)
except (KeyError, AttributeError):
# If no hash info found, just keep going and output will be hashed later
@@ -112,7 +113,6 @@ def download(self, to: "Output", jobs: Optional[int] = None):
hashes.append((dest_path, hash_info, dest_info))
cache = to.cache if to.use_cache else to.local_cache
cache.state.save_many(hashes, to.fs)
return files

def update(self, rev: Optional[str] = None):
if rev:
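The change above lets download() reuse metadata that DVCFileSystem._get() already gathered: entries may now arrive as (src, dest, info) 3-tuples, and the extra self.fs.info(src_path) lookup happens only for plain 2-tuples. A minimal sketch of the unpacking pattern, with hypothetical paths and a stand-in lookup_info helper:

from typing import Any

# Entries may be 2-tuples (src, dest) or 3-tuples (src, dest, info).
entries: list[tuple] = [
    ("remote/foo", "local/foo"),               # no piggybacked info
    ("remote/bar", "local/bar", {"size": 4}),  # info gathered during _get()
]

def lookup_info(src: str) -> dict[str, Any]:
    # Hypothetical stand-in for self.fs.info(src); the PR's point is to
    # avoid this call whenever info was already collected.
    print(f"extra info() call for {src}")
    return {"size": 0}

for src, dest, *rest in entries:
    info = rest[0] if rest else lookup_info(src)  # reuse info when present
    print(src, dest, info["size"])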
25 changes: 14 additions & 11 deletions dvc/fs/__init__.py
@@ -1,5 +1,5 @@
import glob
from typing import Optional
from typing import Optional, Union
from urllib.parse import urlparse

from dvc.config import ConfigError as RepoConfigError
@@ -47,12 +47,24 @@

def download(
fs: "FileSystem", fs_path: str, to: str, jobs: Optional[int] = None
) -> list[tuple[str, str]]:
) -> list[Union[tuple[str, str], tuple[str, str, dict]]]:
from dvc.scm import lfs_prefetch

from .callbacks import TqdmCallback

with TqdmCallback(desc=f"Downloading {fs.name(fs_path)}", unit="files") as cb:
if isinstance(fs, DVCFileSystem):
lfs_prefetch(
fs,
[
f"{fs.normpath(glob.escape(fs_path))}/**"
if fs.isdir(fs_path)
else glob.escape(fs_path)
],
)
if not glob.has_magic(fs_path):
return fs._get(fs_path, to, batch_size=jobs, callback=cb)

# NOTE: We use dvc-objects generic.copy over fs.get since it makes file
# download atomic and avoids fsspec glob/regex path expansion.
if fs.isdir(fs_path):
@@ -69,15 +81,6 @@ def download(
from_infos = [fs_path]
to_infos = [to]

if isinstance(fs, DVCFileSystem):
lfs_prefetch(
fs,
[
f"{fs.normpath(glob.escape(fs_path))}/**"
if fs.isdir(fs_path)
else glob.escape(fs_path)
],
)
cb.set_size(len(from_infos))
jobs = jobs or fs.jobs
generic.copy(fs, from_infos, localfs, to_infos, callback=cb, batch_size=jobs)
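Two things changed in download(): the lfs_prefetch call moved ahead of the copy logic so the new fast path also benefits from LFS prefetching, and paths without glob characters now short-circuit into fs._get(...), which batches the transfer and returns the (src, dest[, info]) tuples that RepoDependency.download() consumes. The branch hinges on glob.has_magic, a long-standing (if undocumented) stdlib helper; a quick illustration:

import glob

# Wildcard paths keep the generic.copy route; plain paths take the new
# batched fs._get fast path.
assert glob.has_magic("data/*.csv")
assert not glob.has_magic("data/raw")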
148 changes: 146 additions & 2 deletions dvc/fs/dvc.py
@@ -6,20 +6,24 @@
import threading
from collections import deque
from contextlib import ExitStack, suppress
from glob import has_magic
from typing import TYPE_CHECKING, Any, Callable, Optional, Union

from fsspec.spec import AbstractFileSystem
from fsspec.spec import DEFAULT_CALLBACK, AbstractFileSystem
from funcy import wrap_with

from dvc.log import logger
from dvc_objects.fs.base import FileSystem
from dvc.utils.threadpool import ThreadPoolExecutor
from dvc_objects.fs.base import AnyFSPath, FileSystem

from .data import DataFileSystem

if TYPE_CHECKING:
from dvc.repo import Repo
from dvc.types import DictStrAny, StrPath

from .callbacks import Callback

logger = logger.getChild(__name__)

RepoFactory = Union[Callable[..., "Repo"], type["Repo"]]
@@ -474,9 +478,110 @@ def _info(  # noqa: C901
info["name"] = path
return info

def get(
self,
rpath,
lpath,
recursive=False,
callback=DEFAULT_CALLBACK,
maxdepth=None,
batch_size=None,
**kwargs,
):
self._get(
rpath,
lpath,
recursive=recursive,
callback=callback,
maxdepth=maxdepth,
batch_size=batch_size,
**kwargs,
)

def _get( # noqa: C901
self,
rpath,
lpath,
recursive=False,
callback=DEFAULT_CALLBACK,
maxdepth=None,
batch_size=None,
**kwargs,
) -> list[Union[tuple[str, str], tuple[str, str, dict]]]:
if (
isinstance(rpath, list)
or isinstance(lpath, list)
or has_magic(rpath)
or not self.exists(rpath)
or not recursive
):
super().get(
rpath,
lpath,
recursive=recursive,
callback=callback,
maxdepth=maxdepth,
**kwargs,
)
return []

if os.path.isdir(lpath) or lpath.endswith(os.path.sep):
lpath = self.join(lpath, os.path.basename(rpath))

if self.isfile(rpath):
with callback.branched(rpath, lpath) as child:
self.get_file(rpath, lpath, callback=child, **kwargs)
return [(rpath, lpath)]

_files = []
_dirs: list[str] = []
for root, dirs, files in self.walk(rpath, maxdepth=maxdepth, detail=True):
if files:
callback.set_size((callback.size or 0) + len(files))

parts = self.relparts(root, rpath)
if parts in ((os.curdir,), ("",)):
parts = ()
dest_root = os.path.join(lpath, *parts)
if not maxdepth or len(parts) < maxdepth - 1:
_dirs.extend(f"{dest_root}{os.path.sep}{d}" for d in dirs)

key = self._get_key_from_relative(root)
_, dvc_fs, _ = self._get_subrepo_info(key)

for name, info in files.items():
src_path = f"{root}{self.sep}{name}"
dest_path = f"{dest_root}{os.path.sep}{name}"
_files.append((dvc_fs, src_path, dest_path, info))

os.makedirs(lpath, exist_ok=True)
for d in _dirs:
os.mkdir(d)

def _get_file(arg):
dvc_fs, src, dest, info = arg
dvc_info = info.get("dvc_info")
if dvc_info and dvc_fs:
dvc_path = dvc_info["name"]
dvc_fs.get_file(
dvc_path, dest, callback=callback, info=dvc_info, **kwargs
)
else:
self.get_file(src, dest, callback=callback, **kwargs)
return src, dest, info

with ThreadPoolExecutor(max_workers=batch_size) as executor:
return list(executor.imap_unordered(_get_file, _files))

def get_file(self, rpath, lpath, **kwargs):
key = self._get_key_from_relative(rpath)
fs_path = self._from_key(key)

dirpath = os.path.dirname(lpath)
if dirpath:
# makedirs raises error if the string is empty
os.makedirs(dirpath, exist_ok=True)

try:
return self.repo.fs.get_file(fs_path, lpath, **kwargs)
except FileNotFoundError:
@@ -553,6 +658,45 @@ def immutable(self):
def getcwd(self):
return self.fs.getcwd()

def _get(
self,
from_info: Union[AnyFSPath, list[AnyFSPath]],
to_info: Union[AnyFSPath, list[AnyFSPath]],
callback: "Callback" = DEFAULT_CALLBACK,
recursive: bool = False,
batch_size: Optional[int] = None,
**kwargs,
) -> list[Union[tuple[str, str], tuple[str, str, dict]]]:
# FileSystem.get is non-recursive when both arguments are lists;
# otherwise, it is recursive.
recursive = not (isinstance(from_info, list) and isinstance(to_info, list))
return self.fs._get(
from_info,
to_info,
callback=callback,
recursive=recursive,
batch_size=batch_size,
**kwargs,
)

def get(
self,
from_info: Union[AnyFSPath, list[AnyFSPath]],
to_info: Union[AnyFSPath, list[AnyFSPath]],
callback: "Callback" = DEFAULT_CALLBACK,
recursive: bool = False,
batch_size: Optional[int] = None,
**kwargs,
) -> None:
self._get(
from_info,
to_info,
callback=callback,
batch_size=batch_size,
recursive=recursive,
**kwargs,
)

@property
def fsid(self) -> str:
return self.fs.fsid
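The get()/_get() split keeps the public API fsspec-compatible (get() returns None) while _get() returns the tuple list that download() uses to skip repeat info() calls. _get() takes the optimized single-walk path only when that is safe; list arguments, glob patterns, missing paths, and non-recursive calls all defer to fsspec's generic implementation and return an empty list. A hypothetical predicate restating those guards from the diff:

from glob import has_magic

def takes_fast_path(fs, rpath, recursive: bool) -> bool:
    # Mirrors the early-exit test in DVCFileSystem._get(): anything the
    # optimized walk cannot handle falls back to AbstractFileSystem.get().
    return (
        not isinstance(rpath, list)
        and not has_magic(rpath)
        and fs.exists(rpath)
        and recursive
    )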
47 changes: 38 additions & 9 deletions tests/func/test_import.py
Expand Up @@ -13,7 +13,7 @@
from dvc.testing.tmp_dir import make_subrepo
from dvc.utils.fs import remove
from dvc_data.hashfile import hash
from dvc_data.index.index import DataIndexDirError
from dvc_data.index.index import DataIndex, DataIndexDirError


def test_import(tmp_dir, scm, dvc, erepo_dir):
@@ -725,12 +725,41 @@ def test_import_invalid_configs(tmp_dir, scm, dvc, erepo_dir):
)


def test_import_no_hash(tmp_dir, scm, dvc, erepo_dir, mocker):
@pytest.mark.parametrize(
"files,expected_info_calls",
[
({"foo": "foo"}, {("foo",)}),
(
{
"dir": {
"bar": "bar",
"subdir": {"lorem": "ipsum", "nested": {"lorem": "lorem"}},
}
},
# info calls should be made only for directories
{("dir",), ("dir", "subdir"), ("dir", "subdir", "nested")},
),
],
)
def test_import_no_hash(
tmp_dir, scm, dvc, erepo_dir, mocker, files, expected_info_calls
):
with erepo_dir.chdir():
erepo_dir.dvc_gen("foo", "foo content", commit="create foo")

spy = mocker.spy(hash, "file_md5")
stage = dvc.imp(os.fspath(erepo_dir), "foo", "foo_imported")
assert spy.call_count == 1
for call in spy.call_args_list:
assert stage.outs[0].fs_path != call.args[0]
erepo_dir.dvc_gen(files, commit="create foo")

file_md5_spy = mocker.spy(hash, "file_md5")
index_info_spy = mocker.spy(DataIndex, "info")
name = next(iter(files))

dvc.imp(os.fspath(erepo_dir), name, "out")

local_hashes = [
call.args[0]
for call in file_md5_spy.call_args_list
if call.args[1].protocol == "local"
]
# no files should be hashed; existing metadata should be reused
assert not local_hashes
assert {
call.args[1] for call in index_info_spy.call_args_list
} == expected_info_calls
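The reworked test pins the optimization down with two spies: no file_md5 call may target a path on the local filesystem (downloaded files must reuse existing hash metadata), and DataIndex.info may be called only with directory keys, never once per file. Because the spies wrap methods at the class level, call.args[0] is the instance and call.args[1] is the first real argument. A minimal sketch of that pytest-mock pattern on a toy class:

# Requires pytest-mock (provides the `mocker` fixture used above).
class Index:
    def info(self, key):
        return {"key": key}

def test_spy_records_keys(mocker):
    spy = mocker.spy(Index, "info")
    Index().info(("dir",))
    # Spying at class level: args[0] is the instance, args[1] the key.
    assert {call.args[1] for call in spy.call_args_list} == {("dir",)}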