diff --git a/yt/data_objects/time_series.py b/yt/data_objects/time_series.py index 4a16d81f91..d23fea0e47 100644 --- a/yt/data_objects/time_series.py +++ b/yt/data_objects/time_series.py @@ -6,7 +6,7 @@ import weakref from abc import ABC, abstractmethod from functools import wraps -from typing import Optional, Union +from typing import Literal, Optional, Union import numpy as np from more_itertools import always_iterable @@ -445,6 +445,9 @@ def _get_by_attribute( attribute: str, value: Union[unyt_quantity, tuple[float, str]], tolerance: Union[None, unyt_quantity, tuple[float, str]] = None, + side: Union[ + Literal["nearest"], Literal["smaller"], Literal["larger"] + ] = "nearest", ) -> "Dataset": r""" Get a dataset at or near to a given value. @@ -462,8 +465,16 @@ def _get_by_attribute( within the tolerance value. If None, simply return the nearest dataset. Default: None. + side : str + The side of the value to return. Can be 'nearest', 'smaller' or 'larger'. + Default: 'nearest'. """ + if side not in ("nearest", "smaller", "larger"): + raise ValueError( + f"side must be 'nearest', 'smaller' or 'larger', got {side}" + ) + # Use a binary search to find the closest value iL = 0 iH = len(self._pre_outputs) - 1 @@ -518,7 +529,13 @@ def _get_by_attribute( dsL = dsH = dsM break - if abs(value - getattr(dsL, attribute)) < abs(value - getattr(dsH, attribute)): + if side == "smaller": + ds_best = dsL if sign > 0 else dsH + elif side == "larger": + ds_best = dsH if sign > 0 else dsL + elif abs(value - getattr(dsL, attribute)) < abs( + value - getattr(dsH, attribute) + ): ds_best = dsL else: ds_best = dsH @@ -534,6 +551,9 @@ def get_by_time( self, time: Union[unyt_quantity, tuple], tolerance: Union[None, unyt_quantity, tuple] = None, + side: Union[ + Literal["nearest"], Literal["smaller"], Literal["larger"] + ] = "nearest", ): """ Get a dataset at or near to a given time. @@ -547,6 +567,9 @@ def get_by_time( within the tolerance value. If None, simply return the nearest dataset. Default: None. + side : str + The side of the value to return. Can be 'nearest', 'smaller' or 'larger'. + Default: 'nearest'. Examples -------- @@ -554,9 +577,18 @@ def get_by_time( >>> t = ts[0].quan(12, "Gyr") ... ds = ts.get_by_time(t, tolerance=(100, "Myr")) """ - return self._get_by_attribute("current_time", time, tolerance=tolerance) + return self._get_by_attribute( + "current_time", time, tolerance=tolerance, side=side + ) - def get_by_redshift(self, redshift: float, tolerance: Optional[float] = None): + def get_by_redshift( + self, + redshift: float, + tolerance: Optional[float] = None, + side: Union[ + Literal["nearest"], Literal["smaller"], Literal["larger"] + ] = "nearest", + ): """ Get a dataset at or near to a given time. @@ -569,6 +601,9 @@ def get_by_redshift(self, redshift: float, tolerance: Optional[float] = None): within the tolerance value. If None, simply return the nearest dataset. Default: None. + side : str + The side of the value to return. Can be 'nearest', 'smaller' or 'larger'. + Default: 'nearest'. Examples --------