Skip to content

Commit

Permalink
feature: samples to/from csv
Browse files Browse the repository at this point in the history
  • Loading branch information
knutdrand committed Sep 5, 2024
1 parent eea872f commit 5047243
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 1 deletion.
14 changes: 14 additions & 0 deletions climate_health/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,20 @@ class SummaryStatistics(TimeSeriesData):
class Samples(TimeSeriesData):
    """Time series holding several sampled values for every time period.

    The methods below index ``samples`` as ``samples[period, sample_index]``.
    """
    samples: float

    def topandas(self):
        """Convert to a DataFrame.

        Returns:
            A DataFrame with a ``time_period`` column followed by one
            ``sample_<i>`` column for each index along the last axis of
            ``samples``.
        """
        n_samples = self.samples.shape[-1]
        columns = {'time_period': self.time_period.topandas()}
        for i in range(n_samples):
            columns[f'sample_{i}'] = self.samples[:, i]
        return pd.DataFrame(columns)

    @classmethod
    def from_pandas(cls, data: pd.DataFrame, fill_missing=False) -> 'Samples':
        """Build an instance from a DataFrame laid out like ``topandas`` output.

        Args:
            data: Frame with a ``time_period`` column and sample columns
                named ``sample_0``, ``sample_1``, ... (consecutive indices;
                the columns are looked up by those exact names).
            fill_missing: Forwarded to ``PeriodRange.from_strings``.

        Returns:
            A new ``cls`` whose ``samples`` array has one row per row of
            ``data`` and one column per ``sample_<i>`` column.
        """
        # NOTE(review): fill_missing only affects the period range here; the
        # sample rows are not padded -- confirm lengths stay aligned when True.
        ptime = PeriodRange.from_strings(data.time_period.astype(str), fill_missing=fill_missing)
        n_samples = sum(1 for col in data.columns if col.startswith('sample_'))
        if n_samples == 0:
            # np.array([]).T would collapse to shape (0,); keep the period axis
            # so a frame without sample columns still yields a 2-D array.
            samples = np.empty((len(data), 0))
        else:
            samples = np.array([data[f'sample_{i}'].values for i in range(n_samples)]).T
        return cls(ptime, samples)

    to_pandas = topandas

@dataclasses.dataclass
class Quantile:
Expand Down
16 changes: 15 additions & 1 deletion tests/test_datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import bionumpy as bnp
import pytest
from bionumpy.util.testing import assert_bnpdataclass_equal
from climate_health.datatypes import ClimateHealthTimeSeries, HealthData
from climate_health.datatypes import ClimateHealthTimeSeries, HealthData, Samples
from climate_health.spatio_temporal_data.temporal_dataclass import DataSet
from climate_health.time_period import PeriodRange
from climate_health.time_period.dataclasses import Year
Expand Down Expand Up @@ -59,3 +59,17 @@ def test_dataset_with_missing(dataset_with_missing):
for location, data in health_data.items():
# assert data.start_timestamp == start
assert data.end_timestamp == end

@pytest.fixture()
def samples():
    """Samples over the periods '2010', '2011', '2012' with a random 3x10 value array."""
    periods = PeriodRange.from_strings(['2010', '2011', '2012'])
    draws = np.random.rand(3, 10)
    return Samples(time_period=periods, samples=draws)


def test_samples(samples, tmp_path):
    """Write ``samples`` to a CSV file, read it back, and compare with the original."""
    csv_file = tmp_path / 'samples.csv'
    samples.to_csv(csv_file)
    restored = Samples.from_csv(csv_file)
    assert_bnpdataclass_equal(samples, restored)

0 comments on commit 5047243

Please sign in to comment.