diff --git a/zetta_utils/geometry/bbox.py b/zetta_utils/geometry/bbox.py index 8a8fa7a2c..f01a13dd3 100644 --- a/zetta_utils/geometry/bbox.py +++ b/zetta_utils/geometry/bbox.py @@ -1,6 +1,7 @@ # pylint: disable=missing-docstring, no-else-raise from __future__ import annotations +import math from itertools import product from math import floor from typing import Literal, Optional, Sequence, Union, cast @@ -205,24 +206,23 @@ def get_slice( return slice(dim_range_start_raw, dim_range_end_raw) if not allow_slice_rounding: - if dim_range_start_raw != round(dim_range_start_raw): + if dim_range_start_raw != math.floor(dim_range_start_raw): raise ValueError( f"{self} results in slice_start == " f"{dim_range_start_raw} along dimension {dim} " f"at resolution == {resolution} while " "`allow_slice_rounding` == False." ) - if dim_range_end_raw != round(dim_range_end_raw): + if dim_range_end_raw != math.ceil(dim_range_end_raw): raise ValueError( f"{self} results in slice_end == " f"{dim_range_end_raw} along dimension {dim} " f"at resolution == {resolution} while " "`allow_slice_rounding` == False." ) - slice_length = int(round(dim_range_end_raw - dim_range_start_raw)) result = slice( - floor(dim_range_start_raw), - floor(dim_range_start_raw) + slice_length, + math.floor(dim_range_start_raw), + math.ceil(dim_range_end_raw), ) return result @@ -576,9 +576,10 @@ def pformat(self, resolution: Optional[Sequence[float]] = None) -> str: # pragm f"({s.join([str(slice.stop) for slice in slices])})" ) - def get_size(self) -> Union[int, float]: # pragma: no cover + def get_size( + self, resolution: Sequence[float] = (1, 1, 1) + ) -> Union[int, float]: # pragma: no cover """Returns the volume of the box, in base units (i.e. `nm^3`).""" - resolution = (1, 1, 1) slices = self.to_slices(resolution, round_to_int=False) size = 1 for _, slc in enumerate(slices):