Skip to content

Commit

Permalink
chore: combine axestuple
Browse files Browse the repository at this point in the history
Signed-off-by: Henry Schreiner <[email protected]>
  • Loading branch information
henryiii committed Jan 29, 2025
1 parent 44d8ca9 commit 87c0d3a
Show file tree
Hide file tree
Showing 6 changed files with 132 additions and 129 deletions.
2 changes: 1 addition & 1 deletion noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def pylint(session: nox.Session) -> None:
Run pylint.
"""

session.install("pylint==3.2.*")
session.install("pylint==3.3.*")
session.install("-e.")
session.run("pylint", "boost_histogram", *session.posargs)

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ messages_control.disable = [
"too-many-locals",
"too-many-return-statements",
"too-many-statements",
"too-many-positional-arguments",
"wrong-import-position",
]

Expand Down
131 changes: 127 additions & 4 deletions src/boost_histogram/axis/__init__.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,27 @@
from __future__ import annotations

import copy
from typing import Any, Callable, Iterable, Iterator, TypeVar, Union
from functools import partial
from typing import (
Any,
Callable,
ClassVar,
Iterable,
Iterator,
Literal,
TypedDict,
TypeVar,
Union,
)

import numpy as np # pylint: disable=unused-import

import boost_histogram

from .._core import axis as ca
from .._internal.traits import Traits
from .._internal.utils import cast, register
from .._internal.utils import cast, register, zip_strict
from . import transform
from ._axes_tuple import ArrayTuple, AxesTuple
from .transform import AxisTransform

__all__ = [
Expand Down Expand Up @@ -297,7 +307,7 @@ def __init__(
overflow: bool = True,
growth: bool = False,
circular: bool = False,
transform: AxisTransform | None = None,
transform: AxisTransform | None = None, # pylint: disable=redefined-outer-name
__dict__: dict[str, Any] | None = None,
):
"""
Expand Down Expand Up @@ -769,3 +779,116 @@ def _repr_args_(self) -> list[str]:

ret += super()._repr_args_()
return ret


class MGridOpts(TypedDict):
sparse: bool
indexing: Literal["ij", "xy"]


A = TypeVar("A", bound="ArrayTuple")


class ArrayTuple(tuple): # type: ignore[type-arg]
__slots__ = ()
# This is an exhaustive list as of NumPy 1.19
_REDUCTIONS = frozenset(("sum", "any", "all", "min", "max", "prod"))

def __getattr__(self, name: str) -> Any:
if name in self._REDUCTIONS:
return partial(getattr(np, name), np.broadcast_arrays(*self))

return self.__class__(getattr(a, name) for a in self)

def __dir__(self) -> list[str]:
names = dir(self.__class__) + dir("np.typing.NDArray[Any]")
return sorted(n for n in names if not n.startswith("_"))

def __call__(self, *args: Any, **kwargs: Any) -> Any:
return self.__class__(a(*args, **kwargs) for a in self)

def broadcast(self: A) -> A:
"""
The arrays in this tuple will be compressed if possible to save memory.
Use this method to broadcast them out into their full memory
representation.
"""
return self.__class__(np.broadcast_arrays(*self))


B = TypeVar("B", bound="AxesTuple")


class AxesTuple(tuple): # type: ignore[type-arg]
__slots__ = ()
_MGRIDOPTS: ClassVar[MGridOpts] = {"sparse": True, "indexing": "ij"}

def __init__(self, __iterable: Iterable[Axis]) -> None:
for item in self:
if not isinstance(item, Axis):
raise TypeError(
f"Only an iterable of Axis supported in AxesTuple, got {item}"
)
super().__init__()

@property
def size(self) -> tuple[int, ...]:
return tuple(s.size for s in self)

@property
def extent(self) -> tuple[int, ...]:
return tuple(s.extent for s in self)

@property
def centers(self) -> ArrayTuple:
gen = (s.centers for s in self)
return ArrayTuple(np.meshgrid(*gen, **self._MGRIDOPTS))

@property
def edges(self) -> ArrayTuple:
gen = (s.edges for s in self)
return ArrayTuple(np.meshgrid(*gen, **self._MGRIDOPTS))

@property
def widths(self) -> ArrayTuple:
gen = (s.widths for s in self)
return ArrayTuple(np.meshgrid(*gen, **self._MGRIDOPTS))

def value(self, *indexes: float) -> tuple[float, ...]:
if len(indexes) != len(self):
raise IndexError(
"Must have the same number of arguments as the number of axes"
)
return tuple(self[i].value(indexes[i]) for i in range(len(indexes)))

def bin(self, *indexes: float) -> tuple[float, ...]:
if len(indexes) != len(self):
raise IndexError(
"Must have the same number of arguments as the number of axes"
)
return tuple(self[i].bin(indexes[i]) for i in range(len(indexes)))

def index(self, *values: float) -> tuple[float, ...]: # type: ignore[override, override]
if len(values) != len(self):
raise IndexError(
"Must have the same number of arguments as the number of axes"
)
return tuple(self[i].index(values[i]) for i in range(len(values)))

def __getitem__(self, item: Any) -> Any:
result = super().__getitem__(item)
return self.__class__(result) if isinstance(result, tuple) else result

def __getattr__(self, attr: str) -> tuple[Any, ...]:
return tuple(getattr(s, attr) for s in self)

def __setattr__(self, attr: str, values: Any) -> None:
try:
super().__setattr__(attr, values)
except AttributeError:
for s, v in zip_strict(self, values):
s.__setattr__(attr, v)

value.__doc__ = Axis.value.__doc__
index.__doc__ = Axis.index.__doc__
bin.__doc__ = Axis.bin.__doc__
121 changes: 0 additions & 121 deletions src/boost_histogram/axis/_axes_tuple.py

This file was deleted.

2 changes: 1 addition & 1 deletion src/boost_histogram/axis/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def __dir__() -> list[str]:

class AxisTransform:
__slots__ = ("_this",)
_family: object
_family: ClassVar[object] # pylint: disable=declare-non-slot
_this: ca.transform._BaseTransform

def __init_subclass__(cls, *, family: object) -> None:
Expand Down
4 changes: 2 additions & 2 deletions tests/test_histogram_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ def test_pick_flowbin(ax):

def test_axes_tuple():
h = bh.Histogram(bh.axis.Regular(10, 0, 1))
assert isinstance(h.axes[:1], bh._internal.axestuple.AxesTuple)
assert isinstance(h.axes[:1], bh.axis.AxesTuple)
assert isinstance(h.axes[0], bh.axis.Regular)

(before,) = h.axes.centers[:1]
Expand All @@ -390,7 +390,7 @@ def test_axes_tuple_Nd():
h = bh.Histogram(
bh.axis.Integer(0, 5), bh.axis.Integer(0, 4), bh.axis.Integer(0, 6)
)
assert isinstance(h.axes[:2], bh._internal.axestuple.AxesTuple)
assert isinstance(h.axes[:2], bh.axis.AxesTuple)
assert isinstance(h.axes[1], bh.axis.Integer)

b1, b2 = h.axes.centers[1:3]
Expand Down

0 comments on commit 87c0d3a

Please sign in to comment.