diff --git a/giskard/datasets/base/__init__.py b/giskard/datasets/base/__init__.py
index 233cba66b6..4d7a48b4f3 100644
--- a/giskard/datasets/base/__init__.py
+++ b/giskard/datasets/base/__init__.py
@@ -34,8 +34,6 @@
 if TYPE_CHECKING:
     from mlflow import MlflowClient
 
-SAMPLE_SIZE = 1000
-
 logger = logging.getLogger(__name__)
 
 
@@ -526,10 +524,22 @@ def cast_column_to_dtypes(df, column_dtypes):
 
     @classmethod
     def load(cls, local_path: str):
-        with open(local_path, "rb") as ds_stream:
-            return pd.read_csv(
-                ZstdDecompressor().stream_reader(ds_stream), keep_default_na=False, na_values=["_GSK_NA_"]
-            )
+        # load metadata
+        with open(Path(local_path) / "giskard-dataset-meta.yaml", "r") as meta_f:
+            meta = yaml.safe_load(meta_f)
+
+        # load data
+        with open(Path(local_path) / "data.csv.zst", "rb") as ds_stream:
+            df = pd.read_csv(ZstdDecompressor().stream_reader(ds_stream), keep_default_na=False, na_values=["_GSK_NA_"])
+
+        return cls(
+            df,
+            name=meta.get("name"),
+            target=meta.get("target"),
+            cat_columns=[k for k in meta["category_features"].keys()],
+            column_types=meta.get("column_types"),
+            original_id=meta.get("id"),
+        )
 
     @staticmethod
     def _cat_columns(meta):
@@ -543,21 +553,17 @@ def _cat_columns(meta):
     def cat_columns(self):
         return self._cat_columns(self.meta)
 
-    def save(self, local_path: Path, dataset_id):
-        with open(local_path / "data.csv.zst", "wb") as f, open(local_path / "data.sample.csv.zst", "wb") as f_sample:
+    def save(self, local_path: str):
+        with (open(Path(local_path) / "data.csv.zst", "wb") as f,):
             uncompressed_bytes = save_df(self.df)
             compressed_bytes = compress(uncompressed_bytes)
             f.write(compressed_bytes)
             original_size_bytes, compressed_size_bytes = len(uncompressed_bytes), len(compressed_bytes)
 
-            uncompressed_bytes = save_df(self.df.sample(min(SAMPLE_SIZE, len(self.df.index))))
-            compressed_bytes = compress(uncompressed_bytes)
-            f_sample.write(compressed_bytes)
-
             with open(Path(local_path) / "giskard-dataset-meta.yaml", "w") as meta_f:
                 yaml.dump(
                     {
-                        "id": dataset_id,
+                        "id": str(self.id),
                         "name": self.meta.name,
                         "target": self.meta.target,
                         "column_types": self.meta.column_types,
diff --git a/tests/datasets/test_dataset_serialization.py b/tests/datasets/test_dataset_serialization.py
new file mode 100644
index 0000000000..c572fb77bb
--- /dev/null
+++ b/tests/datasets/test_dataset_serialization.py
@@ -0,0 +1,57 @@
+import tempfile
+
+import pandas as pd
+import pytest
+
+from giskard.datasets import Dataset
+
+
+@pytest.mark.parametrize(
+    "dataset",
+    [
+        Dataset(
+            pd.DataFrame(
+                {
+                    "question": [
+                        "What is the capital of France?",
+                        "What is the capital of Germany?",
+                    ]
+                }
+            ),
+            column_types={"question": "text"},
+            target=None,
+        ),
+        Dataset(
+            pd.DataFrame(
+                {
+                    "country": ["France", "Germany", "France", "Germany", "France"],
+                    "capital": ["Paris", "Berlin", "Paris", "Berlin", "Paris"],
+                }
+            ),
+            column_types={"country": "category", "capital": "category"},
+            cat_columns=["country", "capital"],
+            target=None,
+        ),
+        Dataset(
+            pd.DataFrame(
+                {
+                    "x": [1, 2, 3, 4, 5],
+                    "y": [2, 4, 6, 8, 10],
+                }
+            ),
+            column_types={"x": "numeric", "y": "numeric"},
+            target="y",
+        ),
+    ],
+    ids=["text", "category", "numeric"],
+)
+def test_save_and_load_dataset(dataset: Dataset):
+    with tempfile.TemporaryDirectory() as tmp_test_folder:
+        dataset.save(tmp_test_folder)
+
+        loaded_dataset = Dataset.load(tmp_test_folder)
+
+        assert loaded_dataset.id != dataset.id
+        assert loaded_dataset.original_id == dataset.id
+        assert pd.DataFrame.equals(loaded_dataset.df, dataset.df)
+        assert loaded_dataset.meta == dataset.meta