+from typing import Generic, Iterable, Tuple, Type, Callable
+
+import numpy as np
+import pandas as pd
+
+from ..api_types import PeriodObservation
+from .._legacy_dataset import TemporalIndexType, FeaturesT
+from ..datatypes import Location, add_field, remove_field, TimeSeriesArray, TimeSeriesData
+from ..time_period import PeriodRange
+from ..time_period.date_util_wrapper import TimeStamp
+import dataclasses
+
+
class TemporalDataclass(Generic[FeaturesT]):
    '''
    Wraps a dataclass in an object that can be sliced by time period.
    Call .data() to get the data back.
    '''

    def __init__(self, data: FeaturesT):
        self._data = data

    def __repr__(self):
        return f'{self.__class__.__name__}({self._data})'

    def _restrict_by_slice(self, period_range: slice):
        '''
        Slice the wrapped data using ``time_period.searchsorted``.

        An open bound (``None``) leaves that side of the slice unrestricted.
        The stop bound is located with ``side='right'``.
        '''
        assert period_range.step is None
        start, stop = (None, None)
        if period_range.start is not None:
            start = self._data.time_period.searchsorted(period_range.start)
        if period_range.stop is not None:
            stop = self._data.time_period.searchsorted(period_range.stop, side='right')
        return self._data[start:stop]

    def _pad_fields(self, new_time_period, n_before: int, n_after: int) -> 'TemporalDataclass[FeaturesT]':
        '''
        Return a new TemporalDataclass over ``new_time_period`` where every
        field except ``time_period`` is cast to float and padded with
        ``n_before`` leading and ``n_after`` trailing NaN values.
        '''
        padded = {
            field.name: np.pad(getattr(self._data, field.name).astype(float),
                               (n_before, n_after),
                               constant_values=np.nan)
            for field in dataclasses.fields(self._data)
            if field.name != 'time_period'}
        return TemporalDataclass(
            self._data.__class__(new_time_period, **padded))

    def fill_to_endpoint(self, end_time_stamp: TimeStamp) -> 'TemporalDataclass[FeaturesT]':
        '''
        Extend the series with NaN values so that it ends at ``end_time_stamp``.

        Returns ``self`` unchanged when the series already ends there.
        Asserts that ``end_time_stamp`` is not before the current end.
        '''
        if self.end_timestamp == end_time_stamp:
            return self
        old_time_period = self._data.time_period
        n_missing = old_time_period.delta.n_periods(self.end_timestamp, end_time_stamp)
        assert n_missing >= 0, (f'{n_missing} < 0', end_time_stamp, self.end_timestamp)
        new_time_period = PeriodRange(old_time_period.start_timestamp, end_time_stamp, old_time_period.delta)
        return self._pad_fields(new_time_period, 0, n_missing)

    def fill_to_range(self, start_timestamp, end_timestamp):
        '''
        Extend the series with NaN values on both sides so that it covers
        ``start_timestamp`` to ``end_timestamp``.

        Returns ``self`` unchanged when the series already covers exactly that
        range. Asserts that the requested range does not start after, or end
        before, the current series.
        '''
        if self.end_timestamp == end_timestamp and self.start_timestamp == start_timestamp:
            return self
        delta = self._data.time_period.delta
        # Count periods on both sides with delta.n_periods, the same way
        # fill_to_endpoint does, instead of mixing it with floor division.
        n_missing_start = delta.n_periods(start_timestamp, self.start_timestamp)
        n_missing_end = delta.n_periods(self.end_timestamp, end_timestamp)
        assert n_missing_end >= 0, (f'{n_missing_end} < 0', end_timestamp, self.end_timestamp)
        assert n_missing_start >= 0, (f'{n_missing_start} < 0', start_timestamp, self.start_timestamp)
        new_time_period = PeriodRange(start_timestamp, end_timestamp, delta)
        return self._pad_fields(new_time_period, n_missing_start, n_missing_end)

    def restrict_time_period(self, period_range: TemporalIndexType) -> 'TemporalDataclass[FeaturesT]':
        '''
        Return a new TemporalDataclass restricted to the given slice of periods.

        ``period_range`` must be a slice without a step. When the time period
        supports ``searchsorted`` that is used; otherwise a boolean mask keeps
        the periods that are ``>= start`` and ``<= stop``.
        '''
        assert isinstance(period_range, slice)
        assert period_range.step is None
        if hasattr(self._data.time_period, 'searchsorted'):
            return TemporalDataclass(self._restrict_by_slice(period_range))
        mask = np.full(len(self._data.time_period), True)
        if period_range.start is not None:
            mask = mask & (self._data.time_period >= period_range.start)
        if period_range.stop is not None:
            mask = mask & (self._data.time_period <= period_range.stop)
        return TemporalDataclass(self._data[mask])

    def data(self) -> FeaturesT:
        '''Return the wrapped dataclass.'''
        return self._data

    def to_pandas(self) -> pd.DataFrame:
        '''Return the wrapped data converted with its own ``to_pandas``.'''
        return self._data.to_pandas()

    def join(self, other):
        '''Concatenate the wrapped data of ``self`` and ``other``, in that order.'''
        return TemporalDataclass(np.concatenate([self._data, other._data]))

    @property
    def start_timestamp(self) -> pd.Timestamp:
        '''Start timestamp of the first period in the series.'''
        return self._data.time_period[0].start_timestamp

    @property
    def end_timestamp(self) -> pd.Timestamp:
        '''End timestamp of the last period in the series.'''
        return self._data.time_period[-1].end_timestamp
+
+
+
+
[docs]
+
class DataSet(Generic[FeaturesT]):
    '''
    Class representing several time series at different locations.

    Internally every location maps to a TemporalDataclass wrapper; most
    accessors (``[]``, ``items``, ``values``) hand back the unwrapped data.
    '''

    def __init__(self, data_dict: dict[str, FeaturesT]):
        # Values that are already wrapped are stored as they are.
        self._data_dict = {
            loc: data if isinstance(data, TemporalDataclass) else TemporalDataclass(data)
            for loc, data in data_dict.items()}

    def __repr__(self):
        return f'{self.__class__.__name__}({self._data_dict})'

    def __getitem__(self, location: Location) -> FeaturesT:
        '''Return the unwrapped data for ``location``.'''
        return self._data_dict[location].data()

    def keys(self):
        '''Return the locations.'''
        return self._data_dict.keys()

    def items(self):
        '''Yield ``(location, unwrapped data)`` pairs.'''
        return ((k, d.data()) for k, d in self._data_dict.items())

    def values(self):
        '''Yield the unwrapped data for each location.'''
        return (d.data() for d in self._data_dict.values())

    @property
    def period_range(self) -> PeriodRange:
        '''
        Return the time period of the first location.

        NOTE: the time periods of the other locations are not compared with
        it, so this is only representative when all locations share a range.
        '''
        first_location = next(iter(self._data_dict))
        return self._data_dict[first_location].data().time_period

    @property
    def start_timestamp(self) -> pd.Timestamp:
        '''Earliest start timestamp over all locations.'''
        return min(data.start_timestamp for data in self.data())

    @property
    def end_timestamp(self) -> pd.Timestamp:
        '''Latest end timestamp over all locations.'''
        return max(data.end_timestamp for data in self.data())

    def get_locations(self, location: Iterable[Location]) -> 'DataSet[FeaturesT]':
        '''Return a new DataSet containing only the given locations.'''
        return self.__class__({loc: self._data_dict[loc] for loc in location})

    def get_location(self, location: Location) -> TemporalDataclass[FeaturesT]:
        '''Return the TemporalDataclass wrapper (not the unwrapped data) for ``location``.'''
        return self._data_dict[location]

    def restrict_time_period(self, period_range: TemporalIndexType) -> 'DataSet[FeaturesT]':
        '''Return a new DataSet where every location is restricted to ``period_range``.'''
        return self.__class__(
            {loc: data.restrict_time_period(period_range) for loc, data in self._data_dict.items()})

    def locations(self) -> Iterable[Location]:
        '''Return the locations.'''
        return self._data_dict.keys()

    def data(self) -> Iterable[TemporalDataclass[FeaturesT]]:
        '''Return the TemporalDataclass wrappers for all locations.'''
        return self._data_dict.values()

    def _add_location_to_dataframe(self, df, location):
        '''Set a ``location`` column on ``df`` (modified in place) and return it.'''
        df['location'] = location
        return df

    def to_pandas(self) -> pd.DataFrame:
        ''' Join the pandas frame for all locations with locations as column'''
        tables = [self._add_location_to_dataframe(data.to_pandas(), location)
                  for location, data in self._data_dict.items()]
        return pd.concat(tables)

    def interpolate(self):
        '''Return a new DataSet with ``interpolate()`` applied to each location's data.'''
        # Iterate the wrappers directly: self.items() already unwraps, so
        # calling .data() on what it yields would be one unwrap too many.
        return self.__class__(
            {loc: TemporalDataclass(wrapped.data().interpolate())
             for loc, wrapped in self._data_dict.items()})

    @classmethod
    def _fill_missing(cls, data_dict: dict[str, TemporalDataclass[FeaturesT]]):
        '''
        Fill missing values in a dictionary of TemporalDataclasses.

        Every series is extended to the earliest start and latest end found in
        the dictionary. The dictionary is updated in place and also returned.
        '''
        if not data_dict:
            # Nothing to align; min()/max() would raise on an empty sequence.
            return data_dict
        end = max(data.end_timestamp for data in data_dict.values())
        start = min(data.start_timestamp for data in data_dict.values())
        # Only existing keys are reassigned, so the dict does not change size
        # while it is being iterated.
        for location, data in data_dict.items():
            data_dict[location] = data.fill_to_range(start, end)
        return data_dict

    @classmethod
    def from_pandas(cls, df: pd.DataFrame, dataclass: Type[FeaturesT], fill_missing=False) -> 'DataSet[FeaturesT]':
        '''
        Create a DataSet from a pandas dataframe.
        The dataframe needs to have a 'location' column, and a 'time_period' column.
        The time_period column needs to have strings that can be parsed into a period.
        All fields in the dataclass need to be present in the dataframe.
        If 'fill_missing' is True, missing values will be filled with np.nan. Else all the time series need to be
        consecutive.

        After parsing, all locations are extended with np.nan to the common
        start and end of the data set.

        Parameters
        ----------
        df : pd.DataFrame
            The dataframe
        dataclass : Type[FeaturesT]
            The dataclass to use for the time series
        fill_missing : bool, optional
            If missing values should be filled, by default False

        Returns
        -------
        DataSet[FeaturesT]
            The DataSet

        Examples
        --------
        >>> import pandas as pd
        >>> from climate_health.spatio_temporal_data.temporal_dataclass import DataSet
        >>> from climate_health.datatypes import HealthData
        >>> df = pd.DataFrame({'location': ['Oslo', 'Oslo', 'Bergen', 'Bergen'],
        ...                    'time_period': ['2020-01', '2020-02', '2020-01', '2020-02'],
        ...                    'disease_cases': [10, 20, 30, 40]})
        >>> DataSet.from_pandas(df, HealthData)
        '''
        data_dict = {}
        for location, data in df.groupby('location'):
            data_dict[location] = TemporalDataclass(dataclass.from_pandas(data, fill_missing))
        data_dict = cls._fill_missing(data_dict)

        return cls(data_dict)

    def to_csv(self, file_name: str, mode='w'):
        '''Write the joined pandas frame to ``file_name`` using the given file mode.'''
        self.to_pandas().to_csv(file_name, mode=mode)

    @classmethod
    def df_from_pydantic_observations(cls, observations: list[PeriodObservation]) -> TimeSeriesData:
        '''
        Build time series data from a non-empty list of observations.

        The dataclass is created from the type of the first observation.
        '''
        df = pd.DataFrame([obs.model_dump() for obs in observations])
        dataclass = TimeSeriesData.create_class_from_basemodel(type(observations[0]))
        return dataclass.from_pandas(df)

    @classmethod
    def from_period_observations(cls, observation_dict: dict[str, list[PeriodObservation]]) -> 'DataSet[TimeSeriesData]':
        '''
        Create a DataSet from a dictionary of PeriodObservations.
        The keys are the location names, and the values are lists of PeriodObservations.

        Parameters
        ----------
        observation_dict : dict[str, list[PeriodObservation]]
            The dictionary of observations

        Returns
        -------
        DataSet[TimeSeriesData]
            The DataSet

        Examples
        --------
        >>> from climate_health.spatio_temporal_data.temporal_dataclass import DataSet
        >>> from climate_health.api_types import PeriodObservation
        >>> class HealthObservation(PeriodObservation):
        ...     disease_cases: int
        >>> observations = {'Oslo': [HealthObservation(time_period='2020-01', disease_cases=10),
        ...                          HealthObservation(time_period='2020-02', disease_cases=20)]}
        >>> dataset = DataSet.from_period_observations(observations)
        >>> dataset.to_pandas()
        '''
        data_dict = {}
        for location, observations in observation_dict.items():
            data_dict[location] = TemporalDataclass(cls.df_from_pydantic_observations(observations))
        return cls(data_dict)

    @classmethod
    def from_csv(cls, file_name: str, dataclass: Type[FeaturesT]) -> 'DataSet[FeaturesT]':
        '''Read ``file_name`` with pandas and build a DataSet via ``from_pandas``.'''
        return cls.from_pandas(pd.read_csv(file_name), dataclass)

    def join_on_time(self, other: 'DataSet[FeaturesT]') -> 'DataSet[Tuple[FeaturesT, FeaturesT]]':
        ''' Join two DataSets on time. Returns a new DataSet.
        Assumes other is later in time.
        Every location of this data set must also be present in ``other``.
        '''
        return self.__class__(
            {loc: self._data_dict[loc].join(other._data_dict[loc]) for loc in self.locations()})

    def add_fields(self, new_type, **kwargs: Callable):
        '''
        Return a new DataSet where each location's data gets additional fields.

        Each keyword maps a field name to a callable; the callable is invoked
        with the location's unwrapped data and its result is passed to
        ``add_field`` under that name together with ``new_type``.
        '''
        new_dict = {}
        # Iterate the wrappers directly: self.items() already unwraps, so
        # calling .data() on what it yields would be one unwrap too many.
        for loc, wrapped in self._data_dict.items():
            data = wrapped.data()
            new_values = {key: func(data) for key, func in kwargs.items()}
            new_dict[loc] = add_field(data, new_type, **new_values)
        return self.__class__(new_dict)

    def remove_field(self, field_name, new_class=None):
        '''Return a new DataSet where ``remove_field`` has been applied to each location's data.'''
        # See add_fields for why the wrappers are iterated instead of self.items().
        return self.__class__(
            {loc: remove_field(wrapped.data(), field_name, new_class)
             for loc, wrapped in self._data_dict.items()})

    @classmethod
    def from_fields(cls, dataclass: type[TimeSeriesData], fields: dict[str, 'DataSet[TimeSeriesArray]']):
        '''
        Combine one DataSet per field into a single DataSet of ``dataclass``.

        The resulting period range spans the earliest start to the latest end
        over all fields, using the delta of the first field's period range.
        Only locations present in every field are kept. Each field's data is
        extended to the common range with ``fill_to_range`` and its ``value``
        is passed to ``dataclass`` under the field's name.
        '''
        start_timestamp = min(data.start_timestamp for data in fields.values())
        end_timestamp = max(data.end_timestamp for data in fields.values())
        period_range = PeriodRange(start_timestamp, end_timestamp, fields[next(iter(fields))].period_range.delta)
        field_names = list(fields.keys())
        common_locations = set.intersection(*[set(field.keys()) for field in fields.values()])
        new_dict = {}
        for location in common_locations:
            new_dict[location] = dataclass(
                period_range,
                **{field: fields[field][location].fill_to_range(start_timestamp, end_timestamp).value
                   for field in field_names})
        return cls(new_dict)
+
+
+
+
+