Skip to content

Commit 328fb38

Browse files
committed
Generate call to dpnp instead of np for divide
1 parent a8ec606 commit 328fb38

File tree

3 files changed

+38
-7
lines changed

3 files changed

+38
-7
lines changed

numba_dpex/dpnp_iface/dpnp_ufunc_db.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,10 +81,6 @@ def _fill_ufunc_db_with_dpnp_ufuncs(ufunc_db):
8181
op.types = npop.types
8282
op.is_dpnp_ufunc = True
8383
cp = copy.copy(_ufunc_db[npop])
84-
if "'divide'" in str(npop):
85-
# TODO: why do we need to do it only for divide?
86-
# https://github.com/IntelPython/numba-dpex/issues/1270
87-
ufunc_db.update({npop: cp})
8884
ufunc_db.update({op: cp})
8985
for key in list(ufunc_db[op].keys()):
9086
if (

numba_dpex/kernel_api_impl/spirv/target.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,8 +108,6 @@ def load_additional_registries(self):
108108
self.install_registry(ocldecl.registry)
109109
self.install_registry(mathdecl.registry)
110110
self.install_registry(cmathdecl.registry)
111-
# TODO: https://github.com/IntelPython/numba-dpex/issues/1270
112-
self.install_registry(npydecl.registry)
113111
self.install_registry(dpnpdecl.registry)
114112
self.install_registry(enumdecl.registry)
115113

numba_dpex/numba_patches/patch_arrayexpr_tree_to_ir.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,14 @@
77
import math
88
import operator
99

10+
import dpnp
1011
from numba.core import errors, ir, types, typing
1112
from numba.core.ir_utils import mk_unique_var
1213
from numba.core.typing import npydecl
1314
from numba.parfors import array_analysis, parfor
1415

16+
from numba_dpex.core.typing import dpnpdecl
17+
1518

1619
def _ufunc_to_parfor_instr(
1720
typemap,
@@ -53,6 +56,39 @@ def _ufunc_to_parfor_instr(
5356
return el_typ
5457

5558

59+
def get_dpnp_ufunc_typ(func):
60+
"""get type of the incoming function from builtin registry"""
61+
for k, v in dpnpdecl.registry.globals:
62+
if k == func:
63+
return v
64+
raise RuntimeError("type for func ", func, " not found")
65+
66+
67+
def _gen_dpnp_divide(arg1, arg2, out_ir, typemap):
68+
"""generate np.divide() instead of / for array_expr to get numpy error model
69+
like inf for division by zero (test_division_by_zero).
70+
"""
71+
scope = arg1.scope
72+
loc = arg1.loc
73+
g_np_var = ir.Var(scope, mk_unique_var("$np_g_var"), loc)
74+
typemap[g_np_var.name] = types.misc.Module(dpnp)
75+
g_np = ir.Global("dpnp", dpnp, loc)
76+
g_np_assign = ir.Assign(g_np, g_np_var, loc)
77+
# attr call: div_attr = getattr(g_np_var, divide)
78+
div_attr_call = ir.Expr.getattr(g_np_var, "divide", loc)
79+
attr_var = ir.Var(scope, mk_unique_var("$div_attr"), loc)
80+
func_var_typ = get_dpnp_ufunc_typ(dpnp.divide)
81+
typemap[attr_var.name] = func_var_typ
82+
attr_assign = ir.Assign(div_attr_call, attr_var, loc)
83+
# divide call: div_attr(arg1, arg2)
84+
div_call = ir.Expr.call(attr_var, [arg1, arg2], (), loc)
85+
func_typ = func_var_typ.get_call_type(
86+
typing.Context(), [typemap[arg1.name], typemap[arg2.name]], {}
87+
)
88+
out_ir.extend([g_np_assign, attr_assign])
89+
return func_typ, div_call
90+
91+
5692
def _arrayexpr_tree_to_ir(
5793
func_ir,
5894
typingctx,
@@ -103,7 +139,8 @@ def _arrayexpr_tree_to_ir(
103139
)
104140
ir_expr = ir.Expr.binop(op, arg_vars[0], arg_vars[1], loc)
105141
if op == operator.truediv:
106-
func_typ, ir_expr = parfor._gen_np_divide(
142+
# NUMBA_DPEX: is_dpnp_func check was added
143+
func_typ, ir_expr = _gen_dpnp_divide(
107144
arg_vars[0], arg_vars[1], out_ir, typemap
108145
)
109146
else:

0 commit comments

Comments
 (0)