Skip to content

Commit

Permalink
Move QueryData to TimeSeriesId logic out of Res1D into an adapter.
Browse files Browse the repository at this point in the history
  • Loading branch information
ryan-kipawa committed Jan 17, 2024
1 parent a912c4c commit 5540ec7
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 35 deletions.
43 changes: 8 additions & 35 deletions mikeio1d/res1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,11 @@
from .query import QueryDataStructure # noqa: F401
from .query import QueryDataGlobal # noqa: F401

from .result_query.query_data_adapter import QueryDataAdapter

from .various import mike1d_quantities # noqa: F401
from .various import NAME_DELIMITER
from .various import make_list_if_not_iterable

from .quantities import TimeSeriesId

Expand Down Expand Up @@ -164,44 +167,14 @@ def _get_timeseries_ids_to_read(
if queries is None:
return self.result_network.queue

timeseries_ids = self._validate_queries_as_timeseries_ids(queries)

return timeseries_ids

def _validate_queries_as_timeseries_ids(
self, queries: List[QueryData] | List[TimeSeriesId]
) -> List[TimeSeriesId]:
"""Validates the user supplied query(ies) and converts them to a TimeSeriesId objects.
Parameters
----------
queries : List[QueryData] | List[TimeSeriesId]
List of queries or time series ids supplied in read() method.
Returns
-------
List of TimeSeriesId objects.
"""
try:
iter(queries)
except TypeError:
queries = [queries]
queries = make_list_if_not_iterable(queries)

is_already_timeseries_id = isinstance(queries[0], TimeSeriesId)
if is_already_timeseries_id:
is_already_time_series_ids = isinstance(queries[0], TimeSeriesId)
if is_already_time_series_ids:
return queries

timeseries_ids = []
for q in queries:
q._update_query(self)
q._check_invalid_quantity(self)
tsid = q.to_timeseries_id()
if tsid.is_valid(self):
timeseries_ids.append(tsid)
else:
q._check_invalid_values(None)

return timeseries_ids
queries = QueryDataAdapter.convert_queries_to_time_series_ids(self, queries)
return queries

# endregion Private methods

Expand Down
64 changes: 64 additions & 0 deletions mikeio1d/result_query/query_data_adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
from typing import List

from ..res1d import Res1D
from .query_data import QueryData

from ..quantities import TimeSeriesId


class QueryDataAdapter:
"""Adapter class for converting QueryData TimeSeriesId.
Parameters
----------
query : QueryData
Query object to adapt.
"""

def __init__(self, query: QueryData):
self._query = query

def to_timeseries_id(self, res1d: Res1D) -> TimeSeriesId:
"""Convert query to timeseries id."""
self._validate_query_with_res1d(res1d)

time_series_id = self._query.to_timeseries_id()

if not time_series_id.is_valid(res1d):
self._query._check_invalid_values(None)

return time_series_id

def _validate_query_with_res1d(self, res1d: Res1D):
"""Validate query with res1d object."""
query = self._query

query._update_query(res1d)
query._check_invalid_quantity(res1d)

@staticmethod
def convert_queries_to_time_series_ids(
res1d: Res1D, queries: list[QueryData]
) -> List[TimeSeriesId]:
"""
Convert queries to TimeSeriesId objects.
Parameters
----------
res1d : Res1D
Res1D object (required for query validation)
queries : list[QueryData]
List of query objects to convert.
Returns
-------
List[TimeSeriesId]
List of timeseries ids.
"""
return [QueryDataAdapter(query).to_timeseries_id(res1d) for query in queries]
20 changes: 20 additions & 0 deletions mikeio1d/various.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,23 @@ def pyproj_crs_from_projection_string(projection_string: str):
except Exception:
warnings.warn("Could not parse projection string. Returning None.")
return None


def make_list_if_not_iterable(obj) -> list:
"""
Boxes non-iterable objects into a list.
Parameters
----------
obj : object
Object to box.
Returns
-------
list
List with one element if obj is not iterable, otherwise obj.
"""
if not hasattr(obj, "__iter__"):
return [obj]
else:
return obj

0 comments on commit 5540ec7

Please sign in to comment.