Skip to content

Commit

Permalink
Merge branch 'main' into pfackeldey/weakref_Array_mask
Browse files Browse the repository at this point in the history
  • Loading branch information
pfackeldey authored Dec 17, 2024
2 parents 7b28ed8 + 564126d commit 662b8cb
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 43 deletions.
62 changes: 19 additions & 43 deletions src/awkward/_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

from __future__ import annotations

import builtins
import sys
import threading
import warnings
from collections.abc import Callable, Collection, Iterable, Mapping
Expand Down Expand Up @@ -51,11 +49,6 @@ def __call__(self):
return self.func(*self.args, **self.kwargs)


class KeyError(builtins.KeyError):
def __str__(self):
return super(Exception, self).__str__()


class ErrorContext:
# Any other threads should get a completely independent _slate.
_slate = threading.local()
Expand All @@ -75,50 +68,33 @@ def __enter__(self):
self._slate.__dict__["__primary_context__"] = self

def __exit__(self, exception_type, exception_value, traceback):
try:
if (
exception_type is not None
and issubclass(exception_type, Exception)
and self.primary() is self
):
# Step out of the way so that another ErrorContext can become primary.
# Is this necessary to do here? (We're about to raise an exception anyway)
self._slate.__dict__.clear()
# Handle caught exception
if (
exception_type is not None
and issubclass(exception_type, Exception)
and self.primary() is self
):
self.handle_exception(exception_type, exception_value)
finally:
raise self.decorate_exception(exception_type, exception_value)
else:
# Step out of the way so that another ErrorContext can become primary.
if self.primary() is self:
self._slate.__dict__.clear()

def handle_exception(self, cls: type[E], exception: E):
if sys.version_info >= (3, 11, 0, "final"):
self.decorate_exception(cls, exception)
else:
raise self.decorate_exception(cls, exception)

def decorate_exception(self, cls: type[E], exception: E) -> Exception:
if sys.version_info >= (3, 11, 0, "final"):
if issubclass(cls, (NotImplementedError, AssertionError)):
exception.add_note(
"\n\nSee if this has been reported at https://github.com/scikit-hep/awkward/issues"
)
def _add_note(exception: E, note: str) -> E:
if hasattr(exception, "add_note"):
exception.add_note(note)
else:
exception.add_note(self.note)
exception.__notes__ = [note]
return exception
else:
new_exception: Exception
if issubclass(cls, (NotImplementedError, AssertionError)):
# Raise modified exception
new_exception = cls(
str(exception)
+ "\n\nSee if this has been reported at https://github.com/scikit-hep/awkward/issues"
)
new_exception.__cause__ = exception
elif issubclass(cls, builtins.KeyError):
new_exception = KeyError(self.format_exception(exception))
new_exception.__cause__ = exception
else:
new_exception = cls(self.format_exception(exception))
new_exception.__cause__ = exception
return new_exception

note = self.note
if issubclass(cls, (NotImplementedError, AssertionError)):
note = "\n\nSee if this has been reported at https://github.com/scikit-hep/awkward/issues"
return _add_note(exception, note)

def format_argument(self, width, value):
from awkward import contents, highlevel, record
Expand Down
3 changes: 3 additions & 0 deletions src/awkward/prettyprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@ def get_at(data: Content, index: int):


def get_field(data: Content, field: str):
if isinstance(data._layout, ak.record.Record):
if data._layout._array.content(field)._is_getitem_at_placeholder():
return PlaceholderValue()
out = data._layout._getitem_field(field)
if isinstance(out, ak.contents.NumpyArray):
array_param = out.parameter("__array__")
Expand Down
3 changes: 3 additions & 0 deletions tests/test_1447_jax_autodiff_slices_ufuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@

import numpy as np
import pytest
from packaging.version import parse as parse_version

import awkward as ak

jax = pytest.importorskip("jax")
jax.config.update("jax_platform_name", "cpu")
jax.config.update("jax_enable_x64", True)
if parse_version(jax.__version__) >= parse_version("0.4.36"):
jax.config.update("jax_data_dependent_tracing_fallback", True)

ak.jax.register_and_check()

Expand Down

0 comments on commit 662b8cb

Please sign in to comment.