diff --git a/mikeio1d/res1d.py b/mikeio1d/res1d.py index e9fcd39f..4a0f6353 100644 --- a/mikeio1d/res1d.py +++ b/mikeio1d/res1d.py @@ -32,8 +32,11 @@ from .query import QueryDataStructure # noqa: F401 from .query import QueryDataGlobal # noqa: F401 +from .result_query.query_data_adapter import QueryDataAdapter + from .various import mike1d_quantities # noqa: F401 from .various import NAME_DELIMITER +from .various import make_list_if_not_iterable from .quantities import TimeSeriesId @@ -164,44 +167,14 @@ def _get_timeseries_ids_to_read( if queries is None: return self.result_network.queue - timeseries_ids = self._validate_queries_as_timeseries_ids(queries) - - return timeseries_ids - - def _validate_queries_as_timeseries_ids( - self, queries: List[QueryData] | List[TimeSeriesId] - ) -> List[TimeSeriesId]: - """Validates the user supplied query(ies) and converts them to a TimeSeriesId objects. - - Parameters - ---------- - queries : List[QueryData] | List[TimeSeriesId] - List of queries or time series ids supplied in read() method. - - Returns - ------- - List of TimeSeriesId objects. - """ - try: - iter(queries) - except TypeError: - queries = [queries] + queries = make_list_if_not_iterable(queries) - is_already_timeseries_id = isinstance(queries[0], TimeSeriesId) - if is_already_timeseries_id: + is_already_time_series_ids = isinstance(queries[0], TimeSeriesId) + if is_already_time_series_ids: return queries - timeseries_ids = [] - for q in queries: - q._update_query(self) - q._check_invalid_quantity(self) - tsid = q.to_timeseries_id() - if tsid.is_valid(self): - timeseries_ids.append(tsid) - else: - q._check_invalid_values(None) - - return timeseries_ids + queries = QueryDataAdapter.convert_queries_to_time_series_ids(self, queries) + return queries # endregion Private methods diff --git a/mikeio1d/result_query/query_data_adapter.py b/mikeio1d/result_query/query_data_adapter.py new file mode 100644 index 00000000..14235c13 --- /dev/null +++ b/mikeio1d/result_query/query_data_adapter.py @@ -0,0 +1,64 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing import List + + from ..res1d import Res1D + from .query_data import QueryData + +from ..quantities import TimeSeriesId + + +class QueryDataAdapter: + """Adapter class for converting QueryData TimeSeriesId. + + Parameters + ---------- + query : QueryData + Query object to adapt. + + """ + + def __init__(self, query: QueryData): + self._query = query + + def to_timeseries_id(self, res1d: Res1D) -> TimeSeriesId: + """Convert query to timeseries id.""" + self._validate_query_with_res1d(res1d) + + time_series_id = self._query.to_timeseries_id() + + if not time_series_id.is_valid(res1d): + self._query._check_invalid_values(None) + + return time_series_id + + def _validate_query_with_res1d(self, res1d: Res1D): + """Validate query with res1d object.""" + query = self._query + + query._update_query(res1d) + query._check_invalid_quantity(res1d) + + @staticmethod + def convert_queries_to_time_series_ids( + res1d: Res1D, queries: list[QueryData] + ) -> List[TimeSeriesId]: + """ + Convert queries to TimeSeriesId objects. + + Parameters + ---------- + res1d : Res1D + Res1D object (required for query validation) + queries : list[QueryData] + List of query objects to convert. + + Returns + ------- + List[TimeSeriesId] + List of timeseries ids. + """ + return [QueryDataAdapter(query).to_timeseries_id(res1d) for query in queries] diff --git a/mikeio1d/various.py b/mikeio1d/various.py index a84b0b80..fe1a7df0 100644 --- a/mikeio1d/various.py +++ b/mikeio1d/various.py @@ -50,3 +50,23 @@ def pyproj_crs_from_projection_string(projection_string: str): except Exception: warnings.warn("Could not parse projection string. Returning None.") return None + + +def make_list_if_not_iterable(obj) -> list: + """ + Boxes non-iterable objects into a list. + + Parameters + ---------- + obj : object + Object to box. + + Returns + ------- + list + List with one element if obj is not iterable, otherwise obj. + """ + if not hasattr(obj, "__iter__"): + return [obj] + else: + return obj