diff --git a/rubicon_ml/repository/base.py b/rubicon_ml/repository/base.py index d4b1c3da..6617a29d 100644 --- a/rubicon_ml/repository/base.py +++ b/rubicon_ml/repository/base.py @@ -43,6 +43,8 @@ class BaseRepository: """ def __init__(self, root_dir: str, **storage_options): + self._df_storage_options = {} # should only be non-empty for S3 logging + self.filesystem = fsspec.filesystem(self.PROTOCOL, **storage_options) self.root_dir = root_dir.rstrip("/") @@ -614,7 +616,7 @@ def _persist_dataframe( df.write_parquet(path) else: # Dask or pandas - df.to_parquet(path, engine="pyarrow") + df.to_parquet(path, engine="pyarrow", storage_options=self._df_storage_options) def _read_dataframe(self, path, df_type: Literal["pandas", "dask", "polars"] = "pandas"): """Reads the dataframe `df` from the configured filesystem.""" @@ -623,7 +625,7 @@ def _read_dataframe(self, path, df_type: Literal["pandas", "dask", "polars"] = " if df_type == "pandas": path = f"{path}/data.parquet" - df = pd.read_parquet(path, engine="pyarrow") + df = pd.read_parquet(path, engine="pyarrow", storage_options=self._df_storage_options) elif df_type == "polars": try: from polars import read_parquet @@ -633,7 +635,7 @@ def _read_dataframe(self, path, df_type: Literal["pandas", "dask", "polars"] = " "to read dataframes with `df_type`='polars'. `pip install polars` " "or `conda install polars` to continue." ) - df = read_parquet(path) + df = read_parquet(path, storage_options=self._df_storage_options) elif df_type == "dask": try: @@ -645,7 +647,7 @@ def _read_dataframe(self, path, df_type: Literal["pandas", "dask", "polars"] = " "or `conda install dask` to continue." ) - df = dd.read_parquet(path, engine="pyarrow") + df = dd.read_parquet(path, engine="pyarrow", storage_options=self._df_storage_options) else: raise ValueError(f"`df_type` must be one of {acceptable_types}") diff --git a/rubicon_ml/repository/memory.py b/rubicon_ml/repository/memory.py index 9a7ea714..5416ef80 100644 --- a/rubicon_ml/repository/memory.py +++ b/rubicon_ml/repository/memory.py @@ -28,6 +28,8 @@ class MemoryRepository(LocalRepository): PROTOCOL = "memory" def __init__(self, root_dir=None, **storage_options): + self._df_storage_options = {} # should only be non-empty for S3 logging + self.filesystem = fsspec.filesystem(self.PROTOCOL, **storage_options) self.root_dir = root_dir.rstrip("/") if root_dir is not None else "/root" diff --git a/rubicon_ml/repository/s3.py b/rubicon_ml/repository/s3.py index 76bf3e7d..e37fa008 100644 --- a/rubicon_ml/repository/s3.py +++ b/rubicon_ml/repository/s3.py @@ -1,3 +1,5 @@ +import fsspec + from rubicon_ml.repository import BaseRepository from rubicon_ml.repository.utils import json @@ -18,6 +20,12 @@ class S3Repository(BaseRepository): PROTOCOL = "s3" + def __init__(self, root_dir: str, **storage_options): + self._df_storage_options = storage_options + + self.filesystem = fsspec.filesystem(self.PROTOCOL, **storage_options) + self.root_dir = root_dir.rstrip("/") + def _persist_bytes(self, bytes_data, path): """Persists the raw bytes `bytes_data` to the S3 bucket defined by `path`. diff --git a/tests/fixtures.py b/tests/fixtures.py index 87136d7a..2d0fad81 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -119,6 +119,7 @@ def rubicon_local_filesystem_client(): rubicon = Rubicon( persistence="filesystem", root_dir=os.path.join(os.path.dirname(os.path.realpath(__file__)), "rubicon"), + storage_option_a="test", # should be ignored when logging local dfs ) # teardown after yield @@ -221,7 +222,7 @@ def test_dataframe(): def memory_repository(): """Setup an in-memory repository and clean it up afterwards.""" root_dir = "/in-memory-root" - repository = MemoryRepository(root_dir) + repository = MemoryRepository(root_dir, storage_option_a="test") yield repository repository.filesystem.rm(root_dir, recursive=True) diff --git a/tests/unit/repository/test_base_repo.py b/tests/unit/repository/test_base_repo.py index df3b9c88..6156fed2 100644 --- a/tests/unit/repository/test_base_repo.py +++ b/tests/unit/repository/test_base_repo.py @@ -403,7 +403,11 @@ def test_persist_dataframe(mock_to_parquet, memory_repository): # calls `BaseRepository._persist_dataframe` despite class using `MemoryRepository` super(MemoryRepository, repository)._persist_dataframe(df, path) - mock_to_parquet.assert_called_once_with(f"{path}/data.parquet", engine="pyarrow") + mock_to_parquet.assert_called_once_with( + f"{path}/data.parquet", + engine="pyarrow", + storage_options={}, + ) @patch("polars.DataFrame.write_parquet") @@ -426,7 +430,11 @@ def test_read_dataframe(mock_read_parquet, memory_repository): # calls `BaseRepository._read_dataframe` despite class using `MemoryRepository` super(MemoryRepository, repository)._read_dataframe(path) - mock_read_parquet.assert_called_once_with(f"{path}/data.parquet", engine="pyarrow") + mock_read_parquet.assert_called_once_with( + f"{path}/data.parquet", + engine="pyarrow", + storage_options={}, + ) def test_read_dataframe_value_error(memory_repository): diff --git a/tests/unit/repository/test_s3_repo.py b/tests/unit/repository/test_s3_repo.py index 33a660a1..75cb55d8 100644 --- a/tests/unit/repository/test_s3_repo.py +++ b/tests/unit/repository/test_s3_repo.py @@ -1,6 +1,7 @@ import uuid from unittest.mock import patch +import pandas as pd import pytest import s3fs @@ -49,3 +50,30 @@ def test_persist_domain_throws_error(mock_open): s3_repo._persist_domain(project, project_metadata_path) mock_open.assert_not_called() + + +@patch("s3fs.core.S3FileSystem.mkdirs") +@patch("pandas.DataFrame.to_parquet") +def test_persist_dataframe(mock_to_parquet, mock_mkdirs): + s3_repo = S3Repository(root_dir="s3://bucket/root", storage_option_a="test") + df = pd.DataFrame([[0, 1], [1, 0]], columns=["a", "b"]) + + s3_repo._persist_dataframe(df, s3_repo.root_dir) + + mock_to_parquet.assert_called_once_with( + f"{s3_repo.root_dir}/data.parquet", + engine="pyarrow", + storage_options={"storage_option_a": "test"}, + ) + + +@patch("pandas.read_parquet") +def test_read_dataframe(mock_read_parquet): + s3_repo = S3Repository(root_dir="s3://bucket/root", storage_option_a="test") + s3_repo._read_dataframe(s3_repo.root_dir) + + mock_read_parquet.assert_called_once_with( + f"{s3_repo.root_dir}/data.parquet", + engine="pyarrow", + storage_options={"storage_option_a": "test"}, + )