From e991dd07e3dab4c9a91b4e53b8c43bce00ef5a37 Mon Sep 17 00:00:00 2001 From: Shicheng Zhou Date: Wed, 11 Sep 2024 16:30:13 -0700 Subject: [PATCH] support Models in files operations --- databricks/sdk/mixins/files.py | 18 +++++++++--------- tests/test_dbfs_mixins.py | 13 ++++++++----- 2 files changed, 17 insertions(+), 14 deletions(-) diff --git a/databricks/sdk/mixins/files.py b/databricks/sdk/mixins/files.py index 1e109a1a7..47c11747d 100644 --- a/databricks/sdk/mixins/files.py +++ b/databricks/sdk/mixins/files.py @@ -167,7 +167,7 @@ def __repr__(self) -> str: return f"<_DbfsIO {self._path} {'read' if self.readable() else 'write'}=True>" -class _VolumesIO(BinaryIO): +class _FilesIO(BinaryIO): def __init__(self, api: files.FilesAPI, path: str, *, read: bool, write: bool, overwrite: bool): self._buffer = [] @@ -262,7 +262,7 @@ def __exit__(self, __t, __value, __traceback): self.close() def __repr__(self) -> str: - return f"<_VolumesIO {self._path} {'read' if self.readable() else 'write'}=True>" + return f"<_FilesIO {self._path} {'read' if self.readable() else 'write'}=True>" class _Path(ABC): @@ -398,7 +398,7 @@ def __repr__(self) -> str: return f'<_LocalPath {self._path}>' -class _VolumesPath(_Path): +class _FilesPath(_Path): def __init__(self, api: files.FilesAPI, src: Union[str, pathlib.Path]): self._path = pathlib.PurePosixPath(str(src).replace('dbfs:', '').replace('file:', '')) @@ -411,7 +411,7 @@ def _is_dbfs(self) -> bool: return False def child(self, path: str) -> Self: - return _VolumesPath(self._api, str(self._path / path)) + return _FilesPath(self._api, str(self._path / path)) def _is_dir(self) -> bool: try: @@ -431,7 +431,7 @@ def exists(self) -> bool: return self.is_dir def open(self, *, read=False, write=False, overwrite=False) -> BinaryIO: - return _VolumesIO(self._api, self.as_string, read=read, write=write, overwrite=overwrite) + return _FilesIO(self._api, self.as_string, read=read, write=write, overwrite=overwrite) def list(self, *, recursive=False) -> Generator[files.FileInfo, None, None]: if not self.is_dir: @@ -458,13 +458,13 @@ def list(self, *, recursive=False) -> Generator[files.FileInfo, None, None]: def delete(self, *, recursive=False): if self.is_dir: for entry in self.list(recursive=False): - _VolumesPath(self._api, entry.path).delete(recursive=True) + _FilesPath(self._api, entry.path).delete(recursive=True) self._api.delete_directory(self.as_string) else: self._api.delete(self.as_string) def __repr__(self) -> str: - return f'<_VolumesPath {self._path}>' + return f'<_FilesPath {self._path}>' class _DbfsPath(_Path): @@ -589,8 +589,8 @@ def _path(self, src): 'UC Volumes paths, not external locations or DBFS mount points.') if src.scheme == 'file': return _LocalPath(src.geturl()) - if src.path.startswith('/Volumes'): - return _VolumesPath(self._files_api, src.geturl()) + if src.path.startswith(('/Volumes', '/Models')): + return _FilesPath(self._files_api, src.geturl()) return _DbfsPath(self._dbfs_api, src.geturl()) def copy(self, src: str, dst: str, *, recursive=False, overwrite=False): diff --git a/tests/test_dbfs_mixins.py b/tests/test_dbfs_mixins.py index 6bbaca7a2..ce86a2a80 100644 --- a/tests/test_dbfs_mixins.py +++ b/tests/test_dbfs_mixins.py @@ -1,8 +1,8 @@ import pytest from databricks.sdk.errors import NotFound -from databricks.sdk.mixins.files import (DbfsExt, _DbfsPath, _LocalPath, - _VolumesPath) +from databricks.sdk.mixins.files import (DbfsExt, _DbfsPath, _FilesPath, + _LocalPath) def test_moving_dbfs_file_to_local_dir(config, tmp_path, mocker): @@ -55,11 +55,14 @@ def test_moving_local_dir_to_dbfs(config, tmp_path, mocker): @pytest.mark.parametrize('path,expected_type', [('/path/to/file', _DbfsPath), - ('/Volumes/path/to/file', _VolumesPath), + ('/Volumes/path/to/file', _FilesPath), + ('/Models/path/to/file', _FilesPath), ('dbfs:/path/to/file', _DbfsPath), - ('dbfs:/Volumes/path/to/file', _VolumesPath), + ('dbfs:/Volumes/path/to/file', _FilesPath), + ('dbfs:/Models/path/to/file', _FilesPath), ('file:/path/to/file', _LocalPath), - ('file:/Volumes/path/to/file', _LocalPath), ]) + ('file:/Volumes/path/to/file', _LocalPath), + ('file:/Models/path/to/file', _LocalPath), ]) def test_fs_path(config, path, expected_type): dbfs_ext = DbfsExt(config) assert isinstance(dbfs_ext._path(path), expected_type)