From 5047243b47c54c10d52c25ad441ac818e409a0d7 Mon Sep 17 00:00:00 2001 From: Knut Rand Date: Thu, 5 Sep 2024 11:07:56 +0200 Subject: [PATCH] feature: samples to/from csv --- climate_health/datatypes.py | 14 ++++++++++++++ tests/test_datatypes.py | 16 +++++++++++++++- 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/climate_health/datatypes.py b/climate_health/datatypes.py index da2fda8..2158969 100644 --- a/climate_health/datatypes.py +++ b/climate_health/datatypes.py @@ -260,6 +260,20 @@ class SummaryStatistics(TimeSeriesData): class Samples(TimeSeriesData): samples: float + def topandas(self): + n_samples = self.samples.shape[-1] + df = pd.DataFrame({'time_period': self.time_period.topandas()} | {f'sample_{i}': self.samples[:, i] for i in + range(n_samples)}) + return df + + @classmethod + def from_pandas(cls, data: pd.DataFrame, fill_missing=False) -> 'TimeSeriesData': + ptime = PeriodRange.from_strings(data.time_period.astype(str), fill_missing=fill_missing) + n_samples = sum(1 for col in data.columns if col.startswith('sample_')) + samples = np.array([data[f'sample_{i}'].values for i in range(n_samples)]).T + return cls(ptime, samples) + + to_pandas = topandas @dataclasses.dataclass class Quantile: diff --git a/tests/test_datatypes.py b/tests/test_datatypes.py index c423e82..54e7de7 100644 --- a/tests/test_datatypes.py +++ b/tests/test_datatypes.py @@ -3,7 +3,7 @@ import bionumpy as bnp import pytest from bionumpy.util.testing import assert_bnpdataclass_equal -from climate_health.datatypes import ClimateHealthTimeSeries, HealthData +from climate_health.datatypes import ClimateHealthTimeSeries, HealthData, Samples from climate_health.spatio_temporal_data.temporal_dataclass import DataSet from climate_health.time_period import PeriodRange from climate_health.time_period.dataclasses import Year @@ -59,3 +59,17 @@ def test_dataset_with_missing(dataset_with_missing): for location, data in health_data.items(): # assert data.start_timestamp == start assert data.end_timestamp == end + +@pytest.fixture() +def samples(): + time_period = PeriodRange.from_strings(['2010', '2011', '2012']) + return Samples( + time_period=time_period, + samples=np.random.rand(3, 10)) + + +def test_samples(samples, tmp_path): + path = tmp_path/'samples.csv' + samples.to_csv(path) + samples2 = Samples.from_csv(path) + assert_bnpdataclass_equal(samples, samples2)