Skip to content

Commit

Permalink
Awkward Arrays of momentum vectors in Numba. (#112)
Browse files Browse the repository at this point in the history
  • Loading branch information
jpivarski authored Jun 11, 2021
1 parent 6dd9462 commit d9cd956
Show file tree
Hide file tree
Showing 2 changed files with 232 additions and 44 deletions.
221 changes: 177 additions & 44 deletions src/vector/_backends/awkward_.py
Original file line number Diff line number Diff line change
Expand Up @@ -1012,26 +1012,49 @@ class MomentumRecord4D(MomentumAwkward4D, ak.Record):
# implementation of behaviors in Numba ########################################


def _aztype_of(recordarraytype: typing.Any) -> typing.Any:
def _aztype_of(recordarraytype: typing.Any, is_momentum: bool) -> typing.Any:
import numba

cls: typing.Union[
typing.Type[AzimuthalObjectXY],
typing.Type[AzimuthalObjectRhoPhi],
]

try:
x_index = recordarraytype.recordlookup.index("x")
except ValueError:
x_index = None
try:
y_index = recordarraytype.recordlookup.index("y")
except ValueError:
y_index = None
try:
rho_index = recordarraytype.recordlookup.index("rho")
except ValueError:
rho_index = None
x_index = None
y_index = None
rho_index = None
phi_index = None

if is_momentum:
try:
x_index = recordarraytype.recordlookup.index("px")
except ValueError:
x_index = None
if x_index is None:
try:
x_index = recordarraytype.recordlookup.index("x")
except ValueError:
x_index = None
if is_momentum:
try:
y_index = recordarraytype.recordlookup.index("py")
except ValueError:
y_index = None
if y_index is None:
try:
y_index = recordarraytype.recordlookup.index("y")
except ValueError:
y_index = None
if is_momentum:
try:
rho_index = recordarraytype.recordlookup.index("pt")
except ValueError:
rho_index = None
if rho_index is None:
try:
rho_index = recordarraytype.recordlookup.index("rho")
except ValueError:
rho_index = None
try:
phi_index = recordarraytype.recordlookup.index("phi")
except ValueError:
Expand All @@ -1047,6 +1070,11 @@ def _aztype_of(recordarraytype: typing.Any) -> typing.Any:
coord2 = recordarraytype.contenttypes[phi_index].arraytype.dtype
cls = AzimuthalObjectRhoPhi

elif is_momentum:
raise numba.TypingError(
f"{recordarraytype} is missing azimuthal fields: px/py (x/y) or pt/phi (rho/phi)"
)

else:
raise numba.TypingError(
f"{recordarraytype} is missing azimuthal fields: x/y or rho/phi"
Expand All @@ -1055,7 +1083,7 @@ def _aztype_of(recordarraytype: typing.Any) -> typing.Any:
return numba.typeof(cls(coord1.cast_python_value(0), coord2.cast_python_value(0)))


def _ltype_of(recordarraytype: typing.Any) -> typing.Any:
def _ltype_of(recordarraytype: typing.Any, is_momentum: bool) -> typing.Any:
import numba

cls: typing.Union[
Expand All @@ -1064,10 +1092,20 @@ def _ltype_of(recordarraytype: typing.Any) -> typing.Any:
typing.Type[LongitudinalObjectEta],
]

try:
z_index = recordarraytype.recordlookup.index("z")
except ValueError:
z_index = None
z_index = None
theta_index = None
eta_index = None

if is_momentum:
try:
z_index = recordarraytype.recordlookup.index("pz")
except ValueError:
z_index = None
if z_index is None:
try:
z_index = recordarraytype.recordlookup.index("z")
except ValueError:
z_index = None
try:
theta_index = recordarraytype.recordlookup.index("theta")
except ValueError:
Expand All @@ -1089,6 +1127,11 @@ def _ltype_of(recordarraytype: typing.Any) -> typing.Any:
coord1 = recordarraytype.contenttypes[eta_index].arraytype.dtype
cls = LongitudinalObjectEta

elif is_momentum:
raise numba.TypingError(
f"{recordarraytype} is missing longitudinal fields: pz (z) or theta or eta"
)

else:
raise numba.TypingError(
f"{recordarraytype} is missing longitudinal fields: z or theta or eta"
Expand All @@ -1097,22 +1140,57 @@ def _ltype_of(recordarraytype: typing.Any) -> typing.Any:
return numba.typeof(cls(coord1.cast_python_value(0)))


def _ttype_of(recordarraytype: typing.Any) -> typing.Any:
def _ttype_of(recordarraytype: typing.Any, is_momentum: bool) -> typing.Any:
import numba

cls: typing.Union[
typing.Type[TemporalObjectT],
typing.Type[TemporalObjectTau],
]

try:
t_index = recordarraytype.recordlookup.index("t")
except ValueError:
t_index = None
try:
tau_index = recordarraytype.recordlookup.index("tau")
except ValueError:
tau_index = None
t_index = None
tau_index = None

if is_momentum:
try:
t_index = recordarraytype.recordlookup.index("E")
except ValueError:
t_index = None
if is_momentum and t_index is None:
try:
t_index = recordarraytype.recordlookup.index("e")
except ValueError:
t_index = None
if is_momentum and t_index is None:
try:
t_index = recordarraytype.recordlookup.index("energy")
except ValueError:
t_index = None
if t_index is None:
try:
t_index = recordarraytype.recordlookup.index("t")
except ValueError:
t_index = None
if is_momentum:
try:
tau_index = recordarraytype.recordlookup.index("M")
except ValueError:
tau_index = None
if is_momentum and tau_index is None:
try:
tau_index = recordarraytype.recordlookup.index("m")
except ValueError:
tau_index = None
if is_momentum and tau_index is None:
try:
tau_index = recordarraytype.recordlookup.index("mass")
except ValueError:
tau_index = None
if tau_index is None:
try:
tau_index = recordarraytype.recordlookup.index("tau")
except ValueError:
tau_index = None

if t_index is not None:
coord1 = recordarraytype.contenttypes[t_index].arraytype.dtype
Expand All @@ -1122,6 +1200,11 @@ def _ttype_of(recordarraytype: typing.Any) -> typing.Any:
coord1 = recordarraytype.contenttypes[tau_index].arraytype.dtype
cls = TemporalObjectTau

elif is_momentum:
raise numba.TypingError(
f"{recordarraytype} is missing temporal fields: E/e/energy (t) or M/m/mass (tau)"
)

else:
raise numba.TypingError(
f"{recordarraytype} is missing temporal fields: t or tau"
Expand All @@ -1135,80 +1218,112 @@ def _numba_typer_Vector2D(viewtype: typing.Any) -> typing.Any:

# These clearly exist, a bug somewhere, but ignoring them for now
return vector._backends.numba_object.VectorObject2DType( # type: ignore
_aztype_of(viewtype.arrayviewtype.type)
_aztype_of(viewtype.arrayviewtype.type, False)
)


def _numba_typer_Vector3D(viewtype: typing.Any) -> typing.Any:
import vector._backends.numba_object

return vector._backends.numba_object.VectorObject3DType( # type: ignore
_aztype_of(viewtype.arrayviewtype.type),
_ltype_of(viewtype.arrayviewtype.type),
_aztype_of(viewtype.arrayviewtype.type, False),
_ltype_of(viewtype.arrayviewtype.type, False),
)


def _numba_typer_Vector4D(viewtype: typing.Any) -> typing.Any:
import vector._backends.numba_object

return vector._backends.numba_object.VectorObject4DType( # type: ignore
_aztype_of(viewtype.arrayviewtype.type),
_ltype_of(viewtype.arrayviewtype.type),
_ttype_of(viewtype.arrayviewtype.type),
_aztype_of(viewtype.arrayviewtype.type, False),
_ltype_of(viewtype.arrayviewtype.type, False),
_ttype_of(viewtype.arrayviewtype.type, False),
)


def _numba_typer_Momentum2D(viewtype: typing.Any) -> typing.Any:
import vector._backends.numba_object

return vector._backends.numba_object.MomentumObject2DType( # type: ignore
_aztype_of(viewtype.arrayviewtype.type)
_aztype_of(viewtype.arrayviewtype.type, True)
)


def _numba_typer_Momentum3D(viewtype: typing.Any) -> typing.Any:
import vector._backends.numba_object

return vector._backends.numba_object.MomentumObject3DType( # type: ignore
_aztype_of(viewtype.arrayviewtype.type),
_ltype_of(viewtype.arrayviewtype.type),
_aztype_of(viewtype.arrayviewtype.type, True),
_ltype_of(viewtype.arrayviewtype.type, True),
)


def _numba_typer_Momentum4D(viewtype: typing.Any) -> typing.Any:
import vector._backends.numba_object

return vector._backends.numba_object.MomentumObject4DType( # type: ignore
_aztype_of(viewtype.arrayviewtype.type),
_ltype_of(viewtype.arrayviewtype.type),
_ttype_of(viewtype.arrayviewtype.type),
_aztype_of(viewtype.arrayviewtype.type, True),
_ltype_of(viewtype.arrayviewtype.type, True),
_ttype_of(viewtype.arrayviewtype.type, True),
)


def _numba_lower(
context: typing.Any, builder: typing.Any, sig: typing.Any, args: typing.Any
) -> typing.Any:
from vector._backends.numba_object import ( # type: ignore
_awkward_numba_E,
_awkward_numba_e,
_awkward_numba_energy,
_awkward_numba_eta,
_awkward_numba_M,
_awkward_numba_m,
_awkward_numba_mass,
_awkward_numba_ptphi,
_awkward_numba_pxpy,
_awkward_numba_pxy,
_awkward_numba_pz,
_awkward_numba_rhophi,
_awkward_numba_t,
_awkward_numba_tau,
_awkward_numba_theta,
_awkward_numba_xpy,
_awkward_numba_xy,
_awkward_numba_z,
)

vectorcls = sig.return_type.instance_class

fields = sig.args[0].arrayviewtype.type.recordlookup

if issubclass(vectorcls, (VectorObject2D, VectorObject3D, VectorObject4D)):
if issubclass(sig.return_type.azimuthaltype.instance_class, AzimuthalXY):
azimuthal = _awkward_numba_xy
if "x" in fields and "y" in fields:
azimuthal = _awkward_numba_xy
elif "x" in fields and "py" in fields:
azimuthal = _awkward_numba_xpy
elif "px" in fields and "y" in fields:
azimuthal = _awkward_numba_pxy
elif "px" in fields and "py" in fields:
azimuthal = _awkward_numba_pxpy
else:
raise AssertionError
elif issubclass(sig.return_type.azimuthaltype.instance_class, AzimuthalRhoPhi):
azimuthal = _awkward_numba_rhophi
if "rho" in fields and "phi" in fields:
azimuthal = _awkward_numba_rhophi
elif "pt" in fields and "phi" in fields:
azimuthal = _awkward_numba_ptphi
else:
raise AssertionError

if issubclass(vectorcls, (VectorObject3D, VectorObject4D)):
if issubclass(sig.return_type.longitudinaltype.instance_class, LongitudinalZ):
longitudinal = _awkward_numba_z
if "z" in fields:
longitudinal = _awkward_numba_z
elif "pz" in fields:
longitudinal = _awkward_numba_pz
else:
raise AssertionError
elif issubclass(
sig.return_type.longitudinaltype.instance_class, LongitudinalTheta
):
Expand All @@ -1220,9 +1335,27 @@ def _numba_lower(

if issubclass(vectorcls, VectorObject4D):
if issubclass(sig.return_type.temporaltype.instance_class, TemporalT):
temporal = _awkward_numba_t
if "t" in fields:
temporal = _awkward_numba_t
elif "E" in fields:
temporal = _awkward_numba_E
elif "e" in fields:
temporal = _awkward_numba_e
elif "energy" in fields:
temporal = _awkward_numba_energy
else:
raise AssertionError
elif issubclass(sig.return_type.temporaltype.instance_class, TemporalTau):
temporal = _awkward_numba_tau
if "tau" in fields:
temporal = _awkward_numba_tau
elif "M" in fields:
temporal = _awkward_numba_M
elif "m" in fields:
temporal = _awkward_numba_m
elif "mass" in fields:
temporal = _awkward_numba_mass
else:
raise AssertionError

if issubclass(vectorcls, VectorObject2D):

Expand Down
Loading

0 comments on commit d9cd956

Please sign in to comment.