Skip to content

Commit

Permalink
feat: make nonexact slices always expand
Browse files Browse the repository at this point in the history
  • Loading branch information
supersergiy committed Feb 14, 2025
1 parent 35674b5 commit 5c89534
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions zetta_utils/geometry/bbox.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# pylint: disable=missing-docstring, no-else-raise
from __future__ import annotations

import math
from itertools import product
from math import floor
from typing import Literal, Optional, Sequence, Union, cast
Expand Down Expand Up @@ -205,24 +206,23 @@ def get_slice(
return slice(dim_range_start_raw, dim_range_end_raw)

if not allow_slice_rounding:
if dim_range_start_raw != round(dim_range_start_raw):
if dim_range_start_raw != math.floor(dim_range_start_raw):
raise ValueError(
f"{self} results in slice_start == "
f"{dim_range_start_raw} along dimension {dim} "
f"at resolution == {resolution} while "
"`allow_slice_rounding` == False."
)
if dim_range_end_raw != round(dim_range_end_raw):
if dim_range_end_raw != math.ceil(dim_range_end_raw):
raise ValueError(
f"{self} results in slice_end == "
f"{dim_range_end_raw} along dimension {dim} "
f"at resolution == {resolution} while "
"`allow_slice_rounding` == False."
)
slice_length = int(round(dim_range_end_raw - dim_range_start_raw))
result = slice(
floor(dim_range_start_raw),
floor(dim_range_start_raw) + slice_length,
math.floor(dim_range_start_raw),
math.ceil(dim_range_end_raw),
)
return result

Expand Down Expand Up @@ -576,9 +576,10 @@ def pformat(self, resolution: Optional[Sequence[float]] = None) -> str: # pragm
f"({s.join([str(slice.stop) for slice in slices])})"
)

def get_size(self) -> Union[int, float]: # pragma: no cover
def get_size(
self, resolution: Sequence[float] = (1, 1, 1)
) -> Union[int, float]: # pragma: no cover
"""Returns the volume of the box, in base units (i.e. `nm^3`)."""
resolution = (1, 1, 1)
slices = self.to_slices(resolution, round_to_int=False)
size = 1
for _, slc in enumerate(slices):
Expand Down

0 comments on commit 5c89534

Please sign in to comment.