Skip to content

Commit

Permalink
Provide 'side' to pick whether we want the closest, smaller or larger…
Browse files Browse the repository at this point in the history
… value
  • Loading branch information
cphyc committed Nov 1, 2023
1 parent d50bd5d commit a41c86a
Showing 1 changed file with 42 additions and 5 deletions.
47 changes: 42 additions & 5 deletions yt/data_objects/time_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import weakref
from abc import ABC, abstractmethod
from functools import wraps
from typing import Optional, Union
from typing import Literal, Optional, Union

import numpy as np
from more_itertools import always_iterable
Expand Down Expand Up @@ -445,6 +445,9 @@ def _get_by_attribute(
attribute: str,
value: Union[unyt_quantity, tuple[float, str]],
tolerance: Union[None, unyt_quantity, tuple[float, str]] = None,
side: Union[
Literal["nearest"], Literal["smaller"], Literal["larger"]
] = "nearest",
) -> "Dataset":
r"""
Get a dataset at or near to a given value.
Expand All @@ -462,8 +465,16 @@ def _get_by_attribute(
within the tolerance value. If None, simply return the
nearest dataset.
Default: None.
side : str
The side of the value to return. Can be 'nearest', 'smaller' or 'larger'.
Default: 'nearest'.
"""

if side not in ("nearest", "smaller", "larger"):
raise ValueError(
f"side must be 'nearest', 'smaller' or 'larger', got {side}"
)

# Use a binary search to find the closest value
iL = 0
iH = len(self._pre_outputs) - 1
Expand Down Expand Up @@ -518,7 +529,13 @@ def _get_by_attribute(
dsL = dsH = dsM
break

if abs(value - getattr(dsL, attribute)) < abs(value - getattr(dsH, attribute)):
if side == "smaller":
ds_best = dsL if sign > 0 else dsH
elif side == "larger":
ds_best = dsH if sign > 0 else dsL
elif abs(value - getattr(dsL, attribute)) < abs(
value - getattr(dsH, attribute)
):
ds_best = dsL
else:
ds_best = dsH
Expand All @@ -534,6 +551,9 @@ def get_by_time(
self,
time: Union[unyt_quantity, tuple],
tolerance: Union[None, unyt_quantity, tuple] = None,
side: Union[
Literal["nearest"], Literal["smaller"], Literal["larger"]
] = "nearest",
):
"""
Get a dataset at or near to a given time.
Expand All @@ -547,16 +567,28 @@ def get_by_time(
within the tolerance value. If None, simply return the
nearest dataset.
Default: None.
side : str
The side of the value to return. Can be 'nearest', 'smaller' or 'larger'.
Default: 'nearest'.
Examples
--------
>>> ds = ts.get_by_time((12, "Gyr"))
>>> t = ts[0].quan(12, "Gyr")
... ds = ts.get_by_time(t, tolerance=(100, "Myr"))
"""
return self._get_by_attribute("current_time", time, tolerance=tolerance)
return self._get_by_attribute(
"current_time", time, tolerance=tolerance, side=side
)

def get_by_redshift(self, redshift: float, tolerance: Optional[float] = None):
def get_by_redshift(
self,
redshift: float,
tolerance: Optional[float] = None,
side: Union[
Literal["nearest"], Literal["smaller"], Literal["larger"]
] = "nearest",
):
"""
Get a dataset at or near to a given time.
Expand All @@ -569,12 +601,17 @@ def get_by_redshift(self, redshift: float, tolerance: Optional[float] = None):
within the tolerance value. If None, simply return the
nearest dataset.
Default: None.
side : str
The side of the value to return. Can be 'nearest', 'smaller' or 'larger'.
Default: 'nearest'.
Examples
--------
>>> ds = ts.get_by_redshift(0.0)
"""
return self._get_by_attribute("current_redshift", redshift, tolerance=tolerance)
return self._get_by_attribute(
"current_redshift", redshift, tolerance=tolerance, side=side
)


class TimeSeriesQuantitiesContainer:
Expand Down

0 comments on commit a41c86a

Please sign in to comment.