Skip to content

Commit

Permalink
refactor: introduce parametrised nplike types (#2795)
Browse files Browse the repository at this point in the history
* chore: enable mypy for more modules

* refactor: allow numpylike to be parametrised by the arraylike

* fix: add missing type hint

* fix: use NDArray instead of ArrayLike

* refactor: rename `numpylike` -> `numpy_like`
  • Loading branch information
agoose77 authored Nov 3, 2023
1 parent c572295 commit 4ed2397
Show file tree
Hide file tree
Showing 154 changed files with 741 additions and 607 deletions.
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,9 @@ module = [
'awkward._backends.*',
'awkward.forms.*',
'awkward.types.*',
'awkward._errors',
'awkward._dispatch',
'awkward.index',
]
ignore_errors = false
ignore_missing_imports = true
Expand Down
2 changes: 1 addition & 1 deletion src/awkward/_backends/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import awkward as ak
from awkward._kernels import KernelError
from awkward._nplikes.numpy import Numpy
from awkward._nplikes.numpylike import NumpyLike, NumpyMetadata
from awkward._nplikes.numpy_like import NumpyLike, NumpyMetadata
from awkward._singleton import PublicSingleton
from awkward._typing import Callable, Tuple, TypeAlias, TypeVar

Expand Down
2 changes: 1 addition & 1 deletion src/awkward/_backends/cupy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from awkward._kernels import CupyKernel, NumpyKernel
from awkward._nplikes.cupy import Cupy
from awkward._nplikes.numpy import Numpy
from awkward._nplikes.numpylike import NumpyMetadata
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._typing import Final

np = NumpyMetadata.instance()
Expand Down
2 changes: 1 addition & 1 deletion src/awkward/_backends/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from awkward._backends.backend import Backend
from awkward._nplikes.numpy import Numpy
from awkward._nplikes.numpylike import NumpyLike, NumpyMetadata
from awkward._nplikes.numpy_like import NumpyLike, NumpyMetadata
from awkward._typing import Callable, TypeAlias, TypeVar, cast
from awkward._util import UNSET, Sentinel

Expand Down
2 changes: 1 addition & 1 deletion src/awkward/_backends/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from awkward._kernels import JaxKernel
from awkward._nplikes.jax import Jax
from awkward._nplikes.numpy import Numpy
from awkward._nplikes.numpylike import NumpyMetadata
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._typing import Final

np = NumpyMetadata.instance()
Expand Down
2 changes: 1 addition & 1 deletion src/awkward/_backends/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from awkward._backends.dispatch import register_backend
from awkward._kernels import NumpyKernel
from awkward._nplikes.numpy import Numpy
from awkward._nplikes.numpylike import NumpyMetadata
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._typing import Final

np = NumpyMetadata.instance()
Expand Down
2 changes: 1 addition & 1 deletion src/awkward/_backends/typetracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from awkward._backends.dispatch import register_backend
from awkward._kernels import TypeTracerKernel
from awkward._nplikes.numpy import Numpy
from awkward._nplikes.numpylike import NumpyMetadata
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._nplikes.typetracer import MaybeNone, TypeTracer, TypeTracerArray
from awkward._typing import Final

Expand Down
2 changes: 1 addition & 1 deletion src/awkward/_behavior.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from awkward._typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
from awkward._nplikes.numpylike import UfuncLike
from awkward._nplikes.numpy_like import UfuncLike
from awkward._reducers import Reducer
from awkward.contents.content import Content
from awkward.highlevel import Array
Expand Down
2 changes: 1 addition & 1 deletion src/awkward/_broadcasting.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from awkward._backends.backend import Backend
from awkward._backends.dispatch import backend_of
from awkward._nplikes.numpy import Numpy
from awkward._nplikes.numpylike import NumpyMetadata
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._nplikes.shape import ShapeItem, unknown_length
from awkward._parameters import (
parameters_are_empty,
Expand Down
2 changes: 1 addition & 1 deletion src/awkward/_categorical.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE
from awkward._nplikes.numpy import Numpy
from awkward._nplikes.numpylike import NumpyMetadata
from awkward._nplikes.numpy_like import NumpyMetadata

np = NumpyMetadata.instance()
numpy = Numpy.instance()
Expand Down
2 changes: 1 addition & 1 deletion src/awkward/_connect/cling.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import awkward as ak
from awkward._nplikes.numpy import Numpy
from awkward._nplikes.numpylike import NumpyMetadata
from awkward._nplikes.numpy_like import NumpyMetadata

np = NumpyMetadata.instance()
numpy = Numpy.instance()
Expand Down
2 changes: 1 addition & 1 deletion src/awkward/_connect/jax/reducers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import awkward as ak
from awkward import _reducers
from awkward._nplikes.numpylike import NumpyMetadata
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._nplikes.shape import ShapeItem
from awkward._reducers import Reducer
from awkward._typing import Final, Self, TypeVar
Expand Down
2 changes: 1 addition & 1 deletion src/awkward/_connect/jax/trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from awkward._layout import wrap_layout
from awkward._nplikes.jax import Jax
from awkward._nplikes.numpy import Numpy
from awkward._nplikes.numpylike import NumpyMetadata
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._typing import Generic, TypeVar, Union
from awkward.contents import Content
from awkward.record import Record
Expand Down
2 changes: 1 addition & 1 deletion src/awkward/_connect/numba/arrayview.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import awkward as ak
from awkward._behavior import behavior_of, overlay_behavior
from awkward._layout import wrap_layout
from awkward._nplikes.numpylike import NumpyMetadata
from awkward._nplikes.numpy_like import NumpyMetadata

np = NumpyMetadata.instance()

Expand Down
2 changes: 1 addition & 1 deletion src/awkward/_connect/pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import awkward as ak
from awkward._backends.numpy import NumpyBackend
from awkward._nplikes.numpy import Numpy
from awkward._nplikes.numpylike import NumpyMetadata
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._parameters import parameters_union

np = NumpyMetadata.instance()
Expand Down
2 changes: 1 addition & 1 deletion src/awkward/_connect/rdataframe/from_rdataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from awkward._backends.numpy import NumpyBackend
from awkward._layout import wrap_layout
from awkward._nplikes.numpy import Numpy
from awkward._nplikes.numpylike import NumpyMetadata
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward.types.numpytype import primitive_to_dtype

cpp_type_of = {
Expand Down
2 changes: 1 addition & 1 deletion src/awkward/_do.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import awkward as ak
from awkward._backends.backend import Backend
from awkward._nplikes.numpylike import NumpyMetadata
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._typing import Any, AxisMaybeNone, Literal
from awkward.contents.content import ActionType, Content
from awkward.errors import AxisError
Expand Down
18 changes: 11 additions & 7 deletions src/awkward/_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@

import numpy

from awkward._nplikes.numpylike import NumpyMetadata
from awkward._typing import Any, TypeVar
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._typing import Any, ParamSpec, TypeVar

np = NumpyMetadata.instance()

Expand Down Expand Up @@ -75,13 +75,13 @@ def __exit__(self, exception_type, exception_value, traceback):
if self.primary() is self:
self._slate.__dict__.clear()

def handle_exception(self, cls: type[E], exception: E) -> E:
def handle_exception(self, cls: type[E], exception: E):
if sys.version_info >= (3, 11, 0, "final"):
self.decorate_exception(cls, exception)
else:
raise self.decorate_exception(cls, exception)

def decorate_exception(self, cls: type[E], exception: E) -> E:
def decorate_exception(self, cls: type[E], exception: E) -> Exception:
if sys.version_info >= (3, 11, 0, "final"):
if issubclass(cls, (NotImplementedError, AssertionError)):
exception.add_note(
Expand All @@ -91,6 +91,7 @@ def decorate_exception(self, cls: type[E], exception: E) -> E:
exception.add_note(self.note)
return exception
else:
new_exception: Exception
if issubclass(cls, (NotImplementedError, AssertionError)):
# Raise modified exception
new_exception = cls(
Expand Down Expand Up @@ -212,6 +213,8 @@ def any_backend_is_delayed(
return False

def __init__(self, name, args: Iterable[Any], kwargs: Mapping[str, Any]):
string_args: list[str] | PartialFunction
string_kwargs: dict[str, str] | PartialFunction
if self.primary() is None and (
self.any_backend_is_delayed(args)
or self.any_backend_is_delayed(kwargs.values())
Expand Down Expand Up @@ -419,12 +422,13 @@ def deprecate(
warnings.warn(warning, category, stacklevel=stacklevel + 1)


T = TypeVar("T", bound=Callable)
T = TypeVar("T")
P = ParamSpec("P")


def with_operation_context(func: T) -> T:
def with_operation_context(func: Callable[P, T]) -> Callable[P, T]:
@wraps(func)
def wrapper(*args, **kwargs):
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
# NOTE: this decorator assumes that the operation is exposed under `ak.`
with OperationErrorContext(f"ak.{func.__qualname__}", args, kwargs):
return func(*args, **kwargs)
Expand Down
2 changes: 1 addition & 1 deletion src/awkward/_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from awkward._nplikes.cupy import Cupy
from awkward._nplikes.jax import Jax
from awkward._nplikes.numpy import Numpy
from awkward._nplikes.numpylike import NumpyMetadata
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._nplikes.typetracer import try_touch_data
from awkward._typing import Protocol, TypeAlias

Expand Down
2 changes: 1 addition & 1 deletion src/awkward/_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from awkward._nplikes.dispatch import nplike_of_obj
from awkward._nplikes.jax import Jax
from awkward._nplikes.numpy import Numpy
from awkward._nplikes.numpylike import NumpyMetadata
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._typing import TYPE_CHECKING
from awkward.errors import AxisError

Expand Down
2 changes: 1 addition & 1 deletion src/awkward/_lookup.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE

import awkward as ak
from awkward._nplikes.numpylike import NumpyMetadata
from awkward._nplikes.numpy_like import NumpyMetadata

np = NumpyMetadata.instance()

Expand Down
3 changes: 2 additions & 1 deletion src/awkward/_nplikes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from awkward._typing import TYPE_CHECKING

if TYPE_CHECKING:
from awkward._nplikes.numpylike import ArrayLike, NumpyLike
from awkward._nplikes.numpy_like import NumpyLike
from awkward._nplikes.array_like import ArrayLike


def to_nplike(
Expand Down
Loading

0 comments on commit 4ed2397

Please sign in to comment.