diff --git a/src/vector/_backends/awkward_.py b/src/vector/_backends/awkward_.py index 2442ee07..8cfcfc58 100644 --- a/src/vector/_backends/awkward_.py +++ b/src/vector/_backends/awkward_.py @@ -1012,7 +1012,7 @@ class MomentumRecord4D(MomentumAwkward4D, ak.Record): # implementation of behaviors in Numba ######################################## -def _aztype_of(recordarraytype: typing.Any) -> typing.Any: +def _aztype_of(recordarraytype: typing.Any, is_momentum: bool) -> typing.Any: import numba cls: typing.Union[ @@ -1020,18 +1020,41 @@ def _aztype_of(recordarraytype: typing.Any) -> typing.Any: typing.Type[AzimuthalObjectRhoPhi], ] - try: - x_index = recordarraytype.recordlookup.index("x") - except ValueError: - x_index = None - try: - y_index = recordarraytype.recordlookup.index("y") - except ValueError: - y_index = None - try: - rho_index = recordarraytype.recordlookup.index("rho") - except ValueError: - rho_index = None + x_index = None + y_index = None + rho_index = None + phi_index = None + + if is_momentum: + try: + x_index = recordarraytype.recordlookup.index("px") + except ValueError: + x_index = None + if x_index is None: + try: + x_index = recordarraytype.recordlookup.index("x") + except ValueError: + x_index = None + if is_momentum: + try: + y_index = recordarraytype.recordlookup.index("py") + except ValueError: + y_index = None + if y_index is None: + try: + y_index = recordarraytype.recordlookup.index("y") + except ValueError: + y_index = None + if is_momentum: + try: + rho_index = recordarraytype.recordlookup.index("pt") + except ValueError: + rho_index = None + if rho_index is None: + try: + rho_index = recordarraytype.recordlookup.index("rho") + except ValueError: + rho_index = None try: phi_index = recordarraytype.recordlookup.index("phi") except ValueError: @@ -1047,6 +1070,11 @@ def _aztype_of(recordarraytype: typing.Any) -> typing.Any: coord2 = recordarraytype.contenttypes[phi_index].arraytype.dtype cls = AzimuthalObjectRhoPhi + elif is_momentum: + raise numba.TypingError( + f"{recordarraytype} is missing azimuthal fields: px/py (x/y) or pt/phi (rho/phi)" + ) + else: raise numba.TypingError( f"{recordarraytype} is missing azimuthal fields: x/y or rho/phi" @@ -1055,7 +1083,7 @@ def _aztype_of(recordarraytype: typing.Any) -> typing.Any: return numba.typeof(cls(coord1.cast_python_value(0), coord2.cast_python_value(0))) -def _ltype_of(recordarraytype: typing.Any) -> typing.Any: +def _ltype_of(recordarraytype: typing.Any, is_momentum: bool) -> typing.Any: import numba cls: typing.Union[ @@ -1064,10 +1092,20 @@ def _ltype_of(recordarraytype: typing.Any) -> typing.Any: typing.Type[LongitudinalObjectEta], ] - try: - z_index = recordarraytype.recordlookup.index("z") - except ValueError: - z_index = None + z_index = None + theta_index = None + eta_index = None + + if is_momentum: + try: + z_index = recordarraytype.recordlookup.index("pz") + except ValueError: + z_index = None + if z_index is None: + try: + z_index = recordarraytype.recordlookup.index("z") + except ValueError: + z_index = None try: theta_index = recordarraytype.recordlookup.index("theta") except ValueError: @@ -1089,6 +1127,11 @@ def _ltype_of(recordarraytype: typing.Any) -> typing.Any: coord1 = recordarraytype.contenttypes[eta_index].arraytype.dtype cls = LongitudinalObjectEta + elif is_momentum: + raise numba.TypingError( + f"{recordarraytype} is missing longitudinal fields: pz (z) or theta or eta" + ) + else: raise numba.TypingError( f"{recordarraytype} is missing longitudinal fields: z or theta or eta" @@ -1097,7 +1140,7 @@ def _ltype_of(recordarraytype: typing.Any) -> typing.Any: return numba.typeof(cls(coord1.cast_python_value(0))) -def _ttype_of(recordarraytype: typing.Any) -> typing.Any: +def _ttype_of(recordarraytype: typing.Any, is_momentum: bool) -> typing.Any: import numba cls: typing.Union[ @@ -1105,14 +1148,49 @@ def _ttype_of(recordarraytype: typing.Any) -> typing.Any: typing.Type[TemporalObjectTau], ] - try: - t_index = recordarraytype.recordlookup.index("t") - except ValueError: - t_index = None - try: - tau_index = recordarraytype.recordlookup.index("tau") - except ValueError: - tau_index = None + t_index = None + tau_index = None + + if is_momentum: + try: + t_index = recordarraytype.recordlookup.index("E") + except ValueError: + t_index = None + if is_momentum and t_index is None: + try: + t_index = recordarraytype.recordlookup.index("e") + except ValueError: + t_index = None + if is_momentum and t_index is None: + try: + t_index = recordarraytype.recordlookup.index("energy") + except ValueError: + t_index = None + if t_index is None: + try: + t_index = recordarraytype.recordlookup.index("t") + except ValueError: + t_index = None + if is_momentum: + try: + tau_index = recordarraytype.recordlookup.index("M") + except ValueError: + tau_index = None + if is_momentum and tau_index is None: + try: + tau_index = recordarraytype.recordlookup.index("m") + except ValueError: + tau_index = None + if is_momentum and tau_index is None: + try: + tau_index = recordarraytype.recordlookup.index("mass") + except ValueError: + tau_index = None + if tau_index is None: + try: + tau_index = recordarraytype.recordlookup.index("tau") + except ValueError: + tau_index = None if t_index is not None: coord1 = recordarraytype.contenttypes[t_index].arraytype.dtype @@ -1122,6 +1200,11 @@ def _ttype_of(recordarraytype: typing.Any) -> typing.Any: coord1 = recordarraytype.contenttypes[tau_index].arraytype.dtype cls = TemporalObjectTau + elif is_momentum: + raise numba.TypingError( + f"{recordarraytype} is missing temporal fields: E/e/energy (t) or M/m/mass (tau)" + ) + else: raise numba.TypingError( f"{recordarraytype} is missing temporal fields: t or tau" @@ -1135,7 +1218,7 @@ def _numba_typer_Vector2D(viewtype: typing.Any) -> typing.Any: # These clearly exist, a bug somewhere, but ignoring them for now return vector._backends.numba_object.VectorObject2DType( # type: ignore - _aztype_of(viewtype.arrayviewtype.type) + _aztype_of(viewtype.arrayviewtype.type, False) ) @@ -1143,8 +1226,8 @@ def _numba_typer_Vector3D(viewtype: typing.Any) -> typing.Any: import vector._backends.numba_object return vector._backends.numba_object.VectorObject3DType( # type: ignore - _aztype_of(viewtype.arrayviewtype.type), - _ltype_of(viewtype.arrayviewtype.type), + _aztype_of(viewtype.arrayviewtype.type, False), + _ltype_of(viewtype.arrayviewtype.type, False), ) @@ -1152,9 +1235,9 @@ def _numba_typer_Vector4D(viewtype: typing.Any) -> typing.Any: import vector._backends.numba_object return vector._backends.numba_object.VectorObject4DType( # type: ignore - _aztype_of(viewtype.arrayviewtype.type), - _ltype_of(viewtype.arrayviewtype.type), - _ttype_of(viewtype.arrayviewtype.type), + _aztype_of(viewtype.arrayviewtype.type, False), + _ltype_of(viewtype.arrayviewtype.type, False), + _ttype_of(viewtype.arrayviewtype.type, False), ) @@ -1162,7 +1245,7 @@ def _numba_typer_Momentum2D(viewtype: typing.Any) -> typing.Any: import vector._backends.numba_object return vector._backends.numba_object.MomentumObject2DType( # type: ignore - _aztype_of(viewtype.arrayviewtype.type) + _aztype_of(viewtype.arrayviewtype.type, True) ) @@ -1170,8 +1253,8 @@ def _numba_typer_Momentum3D(viewtype: typing.Any) -> typing.Any: import vector._backends.numba_object return vector._backends.numba_object.MomentumObject3DType( # type: ignore - _aztype_of(viewtype.arrayviewtype.type), - _ltype_of(viewtype.arrayviewtype.type), + _aztype_of(viewtype.arrayviewtype.type, True), + _ltype_of(viewtype.arrayviewtype.type, True), ) @@ -1179,9 +1262,9 @@ def _numba_typer_Momentum4D(viewtype: typing.Any) -> typing.Any: import vector._backends.numba_object return vector._backends.numba_object.MomentumObject4DType( # type: ignore - _aztype_of(viewtype.arrayviewtype.type), - _ltype_of(viewtype.arrayviewtype.type), - _ttype_of(viewtype.arrayviewtype.type), + _aztype_of(viewtype.arrayviewtype.type, True), + _ltype_of(viewtype.arrayviewtype.type, True), + _ttype_of(viewtype.arrayviewtype.type, True), ) @@ -1189,26 +1272,58 @@ def _numba_lower( context: typing.Any, builder: typing.Any, sig: typing.Any, args: typing.Any ) -> typing.Any: from vector._backends.numba_object import ( # type: ignore + _awkward_numba_E, + _awkward_numba_e, + _awkward_numba_energy, _awkward_numba_eta, + _awkward_numba_M, + _awkward_numba_m, + _awkward_numba_mass, + _awkward_numba_ptphi, + _awkward_numba_pxpy, + _awkward_numba_pxy, + _awkward_numba_pz, _awkward_numba_rhophi, _awkward_numba_t, _awkward_numba_tau, _awkward_numba_theta, + _awkward_numba_xpy, _awkward_numba_xy, _awkward_numba_z, ) vectorcls = sig.return_type.instance_class + fields = sig.args[0].arrayviewtype.type.recordlookup + if issubclass(vectorcls, (VectorObject2D, VectorObject3D, VectorObject4D)): if issubclass(sig.return_type.azimuthaltype.instance_class, AzimuthalXY): - azimuthal = _awkward_numba_xy + if "x" in fields and "y" in fields: + azimuthal = _awkward_numba_xy + elif "x" in fields and "py" in fields: + azimuthal = _awkward_numba_xpy + elif "px" in fields and "y" in fields: + azimuthal = _awkward_numba_pxy + elif "px" in fields and "py" in fields: + azimuthal = _awkward_numba_pxpy + else: + raise AssertionError elif issubclass(sig.return_type.azimuthaltype.instance_class, AzimuthalRhoPhi): - azimuthal = _awkward_numba_rhophi + if "rho" in fields and "phi" in fields: + azimuthal = _awkward_numba_rhophi + elif "pt" in fields and "phi" in fields: + azimuthal = _awkward_numba_ptphi + else: + raise AssertionError if issubclass(vectorcls, (VectorObject3D, VectorObject4D)): if issubclass(sig.return_type.longitudinaltype.instance_class, LongitudinalZ): - longitudinal = _awkward_numba_z + if "z" in fields: + longitudinal = _awkward_numba_z + elif "pz" in fields: + longitudinal = _awkward_numba_pz + else: + raise AssertionError elif issubclass( sig.return_type.longitudinaltype.instance_class, LongitudinalTheta ): @@ -1220,9 +1335,27 @@ def _numba_lower( if issubclass(vectorcls, VectorObject4D): if issubclass(sig.return_type.temporaltype.instance_class, TemporalT): - temporal = _awkward_numba_t + if "t" in fields: + temporal = _awkward_numba_t + elif "E" in fields: + temporal = _awkward_numba_E + elif "e" in fields: + temporal = _awkward_numba_e + elif "energy" in fields: + temporal = _awkward_numba_energy + else: + raise AssertionError elif issubclass(sig.return_type.temporaltype.instance_class, TemporalTau): - temporal = _awkward_numba_tau + if "tau" in fields: + temporal = _awkward_numba_tau + elif "M" in fields: + temporal = _awkward_numba_M + elif "m" in fields: + temporal = _awkward_numba_m + elif "mass" in fields: + temporal = _awkward_numba_mass + else: + raise AssertionError if issubclass(vectorcls, VectorObject2D): diff --git a/src/vector/_backends/numba_object.py b/src/vector/_backends/numba_object.py index 7e8e1e9f..48017b84 100644 --- a/src/vector/_backends/numba_object.py +++ b/src/vector/_backends/numba_object.py @@ -3460,16 +3460,41 @@ def _awkward_numba_xy(record): return AzimuthalObjectXY(record["x"], record["y"]) +@numba.jit(nopython=True) +def _awkward_numba_pxy(record): + return AzimuthalObjectXY(record["px"], record["y"]) + + +@numba.jit(nopython=True) +def _awkward_numba_xpy(record): + return AzimuthalObjectXY(record["x"], record["py"]) + + +@numba.jit(nopython=True) +def _awkward_numba_pxpy(record): + return AzimuthalObjectXY(record["px"], record["py"]) + + @numba.jit(nopython=True) def _awkward_numba_rhophi(record): return AzimuthalObjectRhoPhi(record["rho"], record["phi"]) +@numba.jit(nopython=True) +def _awkward_numba_ptphi(record): + return AzimuthalObjectRhoPhi(record["pt"], record["phi"]) + + @numba.jit(nopython=True) def _awkward_numba_z(record): return LongitudinalObjectZ(record["z"]) +@numba.jit(nopython=True) +def _awkward_numba_pz(record): + return LongitudinalObjectZ(record["pz"]) + + @numba.jit(nopython=True) def _awkward_numba_theta(record): return LongitudinalObjectTheta(record["theta"]) @@ -3485,6 +3510,36 @@ def _awkward_numba_t(record): return TemporalObjectT(record["t"]) +@numba.jit(nopython=True) +def _awkward_numba_E(record): + return TemporalObjectT(record["E"]) + + +@numba.jit(nopython=True) +def _awkward_numba_e(record): + return TemporalObjectT(record["e"]) + + +@numba.jit(nopython=True) +def _awkward_numba_energy(record): + return TemporalObjectT(record["energy"]) + + @numba.jit(nopython=True) def _awkward_numba_tau(record): return TemporalObjectTau(record["tau"]) + + +@numba.jit(nopython=True) +def _awkward_numba_M(record): + return TemporalObjectTau(record["M"]) + + +@numba.jit(nopython=True) +def _awkward_numba_m(record): + return TemporalObjectTau(record["m"]) + + +@numba.jit(nopython=True) +def _awkward_numba_mass(record): + return TemporalObjectTau(record["mass"])