Skip to content

Commit

Permalink
feat: expose attrs in typetracer (#2806)
Browse files Browse the repository at this point in the history
* feat: expose attrs in typetracer

* feat: add attrs to ak.typetracer

* test: add simple test
  • Loading branch information
agoose77 authored Nov 8, 2023
1 parent 8a2fa20 commit ec261c2
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 23 deletions.
108 changes: 85 additions & 23 deletions src/awkward/typetracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@

from __future__ import annotations

from collections.abc import Callable, Mapping

import awkward.forms
from awkward._backends.typetracer import TypeTracerBackend
from awkward._behavior import behavior_of
from awkward._do import touch_data as _touch_data
from awkward._errors import deprecate
from awkward._layout import wrap_layout
from awkward._layout import HighLevelContext, wrap_layout
from awkward._nplikes.placeholder import PlaceholderArray
from awkward._nplikes.shape import unknown_length
from awkward._nplikes.typetracer import (
Expand All @@ -18,13 +19,12 @@
from awkward._nplikes.typetracer import (
typetracer_with_report as _typetracer_with_report,
)
from awkward._typing import TypeVar
from awkward._typing import Any, TypeVar
from awkward._util import UNSET
from awkward.contents import Content
from awkward.forms import Form
from awkward.forms.form import regularize_buffer_key
from awkward.highlevel import Array, Record
from awkward.operations.ak_to_layout import to_layout
from awkward.types.numpytype import is_primitive

__all__ = [
Expand All @@ -40,75 +40,125 @@
T = TypeVar("T", Array, Record)


def _length_0_1_if_typetracer(array, function, highlevel: bool, behavior) -> T:
def _length_0_1_if_typetracer(
array,
function,
highlevel: bool,
behavior: Mapping | None,
attrs: Mapping[str, Any],
) -> T:
typetracer_backend = TypeTracerBackend.instance()

layout = to_layout(array, primitive_policy="error")
behavior = behavior_of(array, behavior=behavior)
with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
layout = ctx.unwrap(
array,
allow_unknown=False,
allow_record=True,
primitive_policy="error",
none_policy="error",
string_policy="as-characters",
)

if layout.backend is typetracer_backend:
_touch_data(layout)
layout = function(layout.form, highlevel=False)

return wrap_layout(layout, behavior=behavior, highlevel=highlevel)
return ctx.wrap(layout, highlevel=highlevel)


def length_zero_if_typetracer(array, *, highlevel: bool = True, behavior=None) -> T:
def length_zero_if_typetracer(
array: Any,
*,
highlevel: bool = True,
behavior: Mapping | None = None,
attrs: Mapping[str, Any] | None = None,
) -> T:
"""
Args:
array: Array-like data (anything #ak.to_layout recognizes).
highlevel (bool): If True, return an #ak.Array; otherwise, return
a low-level #ak.contents.Content subclass.
behavior (None or dict): Custom #ak.behavior for the output array, if
high-level.
attrs (None or dict): Custom attributes for the output array, if
high-level.
Recursively touches the data of an array, before returning a length-zero
NumPy-backed array iff the given array has a typetracer backend; otherwise, a
shallow copy of the original array is returned.
"""
return _length_0_1_if_typetracer(array, Form.length_zero_array, highlevel, behavior)
return _length_0_1_if_typetracer(
array, Form.length_zero_array, highlevel, behavior, attrs
)


def length_one_if_typetracer(array, *, highlevel: bool = True, behavior=None) -> T:
def length_one_if_typetracer(
array: Any,
*,
highlevel: bool = True,
behavior: Mapping | None = None,
attrs: Mapping[str, Any] | None = None,
) -> Array | Record:
"""
Args:
array: Array-like data (anything #ak.to_layout recognizes).
highlevel (bool): If True, return an #ak.Array; otherwise, return
a low-level #ak.contents.Content subclass.
behavior (None or dict): Custom #ak.behavior for the output array, if
high-level.
attrs (None or dict): Custom attributes for the output array, if
high-level.
Recursively touches the data of an array, before returning a length-one
NumPy-backed array iff the given array has a typetracer backend; otherwise, a
shallow copy of the original array is returned.
"""
return _length_0_1_if_typetracer(array, Form.length_one_array, highlevel, behavior)
return _length_0_1_if_typetracer(
array, Form.length_one_array, highlevel, behavior, attrs
)


def touch_data(array, *, highlevel: bool = True, behavior=None) -> T:
def touch_data(
array: Any,
*,
highlevel: bool = True,
behavior: Mapping | None = None,
attrs: Mapping[str, Any] | None = None,
) -> Array | Record:
"""
Args:
array: Array-like data (anything #ak.to_layout recognizes).
highlevel (bool): If True, return an #ak.Array; otherwise, return
a low-level #ak.contents.Content subclass.
behavior (None or dict): Custom #ak.behavior for the output array, if
high-level.
attrs (None or dict): Custom attributes for the output array, if
high-level.
Recursively touches the data and returns a shallow copy of the given array.
"""
behavior = behavior_of(array, behavior=behavior)
layout = to_layout(array, primitive_policy="error")
with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
layout = ctx.unwrap(
array,
allow_unknown=False,
allow_record=True,
primitive_policy="error",
none_policy="error",
string_policy="as-characters",
)

_touch_data(layout)
return wrap_layout(layout, behavior=behavior, highlevel=highlevel)
return ctx.wrap(layout, highlevel=highlevel)


def typetracer_with_report(
form,
form: Form | str | Mapping,
forget_length: bool = UNSET,
*,
buffer_key="{form_key}",
buffer_key: str | Callable = "{form_key}",
highlevel: bool = False,
behavior=None,
behavior: Mapping | None = None,
attrs: Mapping[str, Any] | None = None,
) -> tuple[Content, TypeTracerReport]:
"""
Args:
Expand All @@ -122,6 +172,8 @@ def typetracer_with_report(
a low-level #ak.contents.Content subclass.
behavior (None or dict): Custom #ak.behavior for the output array, if
high-level.
attrs (None or dict): Custom attributes for the output array, if
high-level.
Returns a typetracer array and associated report object built from a form
with labelled form keys.
Expand Down Expand Up @@ -153,10 +205,18 @@ def typetracer_with_report(
layout, report = _typetracer_with_report(
form, getkey=getkey, forget_length=forget_length
)
return wrap_layout(layout, behavior=behavior, highlevel=highlevel), report
return wrap_layout(
layout, behavior=behavior, highlevel=highlevel, attrs=attrs
), report


def typetracer_from_form(form, *, highlevel: bool = True, behavior=None):
def typetracer_from_form(
form: Form | str | Mapping,
*,
highlevel: bool = True,
behavior: Mapping | None = None,
attrs: Mapping[str, Any] | None = None,
) -> Array | Content:
"""
Args:
form (#ak.forms.Form or str/dict equivalent): The form of the Awkward
Expand All @@ -165,6 +225,8 @@ def typetracer_from_form(form, *, highlevel: bool = True, behavior=None):
a low-level #ak.contents.Content subclass.
behavior (None or dict): Custom #ak.behavior for the output array, if
high-level.
attrs (None or dict): Custom attributes for the output array, if
high-level.
Returns a typetracer array built from a form.
"""
Expand All @@ -173,7 +235,7 @@ def typetracer_from_form(form, *, highlevel: bool = True, behavior=None):
form = awkward.forms.NumpyForm(form)
else:
form = awkward.forms.from_json(form)
elif isinstance(form, dict):
elif isinstance(form, Mapping):
form = awkward.forms.from_dict(form)
elif not isinstance(form, awkward.forms.Form):
raise TypeError(
Expand All @@ -182,4 +244,4 @@ def typetracer_from_form(form, *, highlevel: bool = True, behavior=None):

layout = form.length_zero_array(highlevel=False).to_typetracer(forget_length=True)

return wrap_layout(layout, behavior=behavior, highlevel=highlevel)
return wrap_layout(layout, behavior=behavior, highlevel=highlevel, attrs=attrs)
50 changes: 50 additions & 0 deletions tests/test_2806_attrs_typetracer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE

from __future__ import annotations

import pytest

import awkward as ak
from awkward.typetracer import (
typetracer_with_report,
)

# Attribute mapping handed to `attrs=` by the tests in this module, which
# assert with `is` that this exact object comes back on the result's `.attrs`.
SOME_ATTRS = {"foo": "FOO"}


def test_typetracer_with_report():
    """Check that `typetracer_with_report` passes `attrs` through to its high-level result."""
    # Three jagged fields of identical shape, zipped into one record array.
    fields = {
        "x": [[0.2, 0.3, 0.4], [1, 2, 3], [1, 1, 2]],
        "y": [[0.1, 0.1, 0.2], [3, 1, 2], [2, 1, 2]],
        "z": [[0.1, 0.1, 0.2], [3, 1, 2], [2, 1, 2]],
    }
    zipped = ak.zip(fields)
    # Label every node of the form so it can be given to typetracer_with_report.
    keyed_form = ak.to_layout(zipped).form_with_key("node{id}")

    # An explicit mapping must come back as the very same object.
    with_attrs, _report = typetracer_with_report(
        keyed_form, highlevel=True, attrs=SOME_ATTRS
    )
    assert with_attrs.attrs is SOME_ATTRS

    # With attrs=None the private `_attrs` slot is expected to stay None.
    without_attrs, _report = typetracer_with_report(
        keyed_form, highlevel=True, attrs=None
    )
    assert without_attrs._attrs is None


@pytest.mark.parametrize(
    "function",
    [
        ak.typetracer.touch_data,
        ak.typetracer.length_zero_if_typetracer,
        ak.typetracer.length_one_if_typetracer,
    ],
)
def test_function(function):
    """Check that each parametrized `ak.typetracer` helper passes `attrs` through.

    Every helper is called once with an explicit mapping and once with its
    default arguments on the same record array.
    """
    fields = {
        "x": [[0.2, 0.3, 0.4], [1, 2, 3], [1, 1, 2]],
        "y": [[0.1, 0.1, 0.2], [3, 1, 2], [2, 1, 2]],
        "z": [[0.1, 0.1, 0.2], [3, 1, 2], [2, 1, 2]],
    }
    record_array = ak.zip(fields)

    # An explicit mapping must come back as the very same object.
    with_attrs = function(record_array, attrs=SOME_ATTRS)
    assert with_attrs.attrs is SOME_ATTRS

    # Without `attrs`, the private `_attrs` slot is expected to stay None.
    without_attrs = function(record_array)
    assert without_attrs._attrs is None

0 comments on commit ec261c2

Please sign in to comment.