Skip to content

Commit fcea2d2

Browse files
committed
feat: add option to keep type parameters
1 parent 2ee6f00 commit fcea2d2

File tree

1 file changed

+49
-14
lines changed

1 file changed

+49
-14
lines changed

src/awkward/operations/ak_enforce_type.py

Lines changed: 49 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -42,19 +42,29 @@ def enforce_type(
4242

4343
def _impl(array, type, highlevel, behavior):
4444
layout = ak.to_layout(array)
45-
type_ = (
46-
ak.types.from_datashape(type, highlevel=False)
47-
if isinstance(type, str)
48-
else type
49-
)
45+
46+
if isinstance(type, str):
47+
type_ = ak.types.from_datashape(type, highlevel=False)
48+
49+
def select_parameters(type_: ak.types.Type, layout: ak.contents.Content):
50+
return layout.parameters
51+
52+
else:
53+
54+
def select_parameters(type_: ak.types.Type, layout: ak.contents.Content):
55+
return type_.parameters
56+
57+
type_ = type
5058

5159
def recurse(
5260
type: ak.types.Type, layout: ak.contents.Content
5361
) -> ak.contents.Content:
5462
# Early exit - unknown layouts take the form of the type.
5563
if layout.is_unknown:
5664
type_form = ak.forms.from_type(type)
57-
return type_form.length_zero_array()
65+
return type_form.length_zero_array(highlevel=False).copy(
66+
parameters=select_parameters(type, layout)
67+
)
5868

5969
# If we want to lose the option
6070
elif layout.is_option and not isinstance(type, ak.types.OptionType):
@@ -67,11 +77,15 @@ def recurse(
6777

6878
# Indexed nodes are invisible to layouts
6979
elif layout.is_indexed:
70-
return recurse(type, layout.content)
80+
return recurse(type, layout.content).copy(
81+
parameters=select_parameters(type, layout)
82+
)
7183

7284
if isinstance(type, ak.types.NumpyType):
7385
assert layout.is_numpy
7486

87+
layout = layout.copy(parameters=select_parameters(type, layout))
88+
7589
dtype = primitive_to_dtype(type.primitive)
7690
if np.issubdtype(layout.dtype, dtype):
7791
return layout
@@ -88,9 +102,15 @@ def recurse(
88102
layout.is_numpy and layout.inner_shape[0] == type.size
89103
):
90104
layout = layout.to_ListOffsetArray64(True)
91-
return layout.copy(content=recurse(type.content, layout.content))
105+
return layout.copy(
106+
content=recurse(type.content, layout.content),
107+
parameters=select_parameters(type, layout),
108+
)
92109
elif layout.is_list:
93-
return layout.copy(content=recurse(type.content, layout.content))
110+
return layout.copy(
111+
content=recurse(type.content, layout.content),
112+
parameters=select_parameters(type, layout),
113+
)
94114
else:
95115
raise wrap_error(
96116
AssertionError(f"expected list type, found {type(layout)!r}")
@@ -105,10 +125,16 @@ def recurse(
105125
if (layout.is_regular and layout.size == type.size) or (
106126
layout.is_numpy and layout.inner_shape[0] == type.size
107127
):
108-
return layout.copy(content=recurse(type.content, layout.content))
128+
return layout.copy(
129+
content=recurse(type.content, layout.content),
130+
parameters=select_parameters(type, layout),
131+
)
109132
elif layout.is_list:
110133
layout = layout.to_RegularArray()
111-
return layout.copy(content=recurse(type.content, layout.content))
134+
return layout.copy(
135+
content=recurse(type.content, layout.content),
136+
parameters=select_parameters(type, layout),
137+
)
112138
else:
113139
raise wrap_error(
114140
AssertionError(f"expected list type, found {type(layout)!r}")
@@ -118,18 +144,27 @@ def recurse(
118144
assert layout.is_record
119145
assert layout.fields == type.fields # TODO: do we care about order?
120146
return layout.copy(
121-
contents=[recurse(x, y) for x, y in zip(type.contents, layout.contents)]
147+
contents=[
148+
recurse(x, y) for x, y in zip(type.contents, layout.contents)
149+
],
150+
parameters=select_parameters(type, layout),
122151
)
123152

124153
elif isinstance(type, ak.types.UnionType):
125154
assert layout.is_union
126155
return layout.copy(
127-
contents=[recurse(x, y) for x, y in zip(type.contents, layout.contents)]
156+
contents=[
157+
recurse(x, y) for x, y in zip(type.contents, layout.contents)
158+
],
159+
parameters=select_parameters(type, layout),
128160
)
129161

130162
elif isinstance(type, ak.types.OptionType):
131163
if layout.is_option:
132-
return layout.copy(content=recurse(type.content, layout.content))
164+
return layout.copy(
165+
content=recurse(type.content, layout.content),
166+
parameters=select_parameters(type, layout),
167+
)
133168
else:
134169
return ak.contents.UnmaskedArray(recurse(type.content, layout))
135170

0 commit comments

Comments
 (0)