Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

[Feature] Support Models in dbutils.fs operations #750

Merged
merged 1 commit into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions databricks/sdk/mixins/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def __repr__(self) -> str:
return f"<_DbfsIO {self._path} {'read' if self.readable() else 'write'}=True>"


class _VolumesIO(BinaryIO):
class _FilesIO(BinaryIO):

def __init__(self, api: files.FilesAPI, path: str, *, read: bool, write: bool, overwrite: bool):
self._buffer = []
Expand Down Expand Up @@ -262,7 +262,7 @@ def __exit__(self, __t, __value, __traceback):
self.close()

def __repr__(self) -> str:
return f"<_VolumesIO {self._path} {'read' if self.readable() else 'write'}=True>"
return f"<_FilesIO {self._path} {'read' if self.readable() else 'write'}=True>"


class _Path(ABC):
Expand Down Expand Up @@ -398,7 +398,7 @@ def __repr__(self) -> str:
return f'<_LocalPath {self._path}>'


class _VolumesPath(_Path):
class _FilesPath(_Path):

def __init__(self, api: files.FilesAPI, src: Union[str, pathlib.Path]):
self._path = pathlib.PurePosixPath(str(src).replace('dbfs:', '').replace('file:', ''))
Expand All @@ -411,7 +411,7 @@ def _is_dbfs(self) -> bool:
return False

def child(self, path: str) -> Self:
return _VolumesPath(self._api, str(self._path / path))
return _FilesPath(self._api, str(self._path / path))

def _is_dir(self) -> bool:
try:
Expand All @@ -431,7 +431,7 @@ def exists(self) -> bool:
return self.is_dir

def open(self, *, read=False, write=False, overwrite=False) -> BinaryIO:
return _VolumesIO(self._api, self.as_string, read=read, write=write, overwrite=overwrite)
return _FilesIO(self._api, self.as_string, read=read, write=write, overwrite=overwrite)

def list(self, *, recursive=False) -> Generator[files.FileInfo, None, None]:
if not self.is_dir:
Expand All @@ -458,13 +458,13 @@ def list(self, *, recursive=False) -> Generator[files.FileInfo, None, None]:
def delete(self, *, recursive=False):
if self.is_dir:
for entry in self.list(recursive=False):
_VolumesPath(self._api, entry.path).delete(recursive=True)
_FilesPath(self._api, entry.path).delete(recursive=True)
self._api.delete_directory(self.as_string)
else:
self._api.delete(self.as_string)

def __repr__(self) -> str:
return f'<_VolumesPath {self._path}>'
return f'<_FilesPath {self._path}>'


class _DbfsPath(_Path):
Expand Down Expand Up @@ -589,8 +589,8 @@ def _path(self, src):
'UC Volumes paths, not external locations or DBFS mount points.')
if src.scheme == 'file':
return _LocalPath(src.geturl())
if src.path.startswith('/Volumes'):
return _VolumesPath(self._files_api, src.geturl())
if src.path.startswith(('/Volumes', '/Models')):
return _FilesPath(self._files_api, src.geturl())
return _DbfsPath(self._dbfs_api, src.geturl())

def copy(self, src: str, dst: str, *, recursive=False, overwrite=False):
Expand Down
13 changes: 8 additions & 5 deletions tests/test_dbfs_mixins.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import pytest

from databricks.sdk.errors import NotFound
from databricks.sdk.mixins.files import (DbfsExt, _DbfsPath, _LocalPath,
_VolumesPath)
from databricks.sdk.mixins.files import (DbfsExt, _DbfsPath, _FilesPath,
_LocalPath)


def test_moving_dbfs_file_to_local_dir(config, tmp_path, mocker):
Expand Down Expand Up @@ -55,11 +55,14 @@ def test_moving_local_dir_to_dbfs(config, tmp_path, mocker):


@pytest.mark.parametrize('path,expected_type', [('/path/to/file', _DbfsPath),
('/Volumes/path/to/file', _VolumesPath),
('/Volumes/path/to/file', _FilesPath),
('/Models/path/to/file', _FilesPath),
('dbfs:/path/to/file', _DbfsPath),
('dbfs:/Volumes/path/to/file', _VolumesPath),
('dbfs:/Volumes/path/to/file', _FilesPath),
('dbfs:/Models/path/to/file', _FilesPath),
('file:/path/to/file', _LocalPath),
('file:/Volumes/path/to/file', _LocalPath), ])
('file:/Volumes/path/to/file', _LocalPath),
('file:/Models/path/to/file', _LocalPath), ])
def test_fs_path(config, path, expected_type):
dbfs_ext = DbfsExt(config)
assert isinstance(dbfs_ext._path(path), expected_type)
Expand Down
Loading