Skip to content

Commit

Permalink
Generate call to dpnp instead of np for divide
Browse files Browse the repository at this point in the history
  • Loading branch information
ZzEeKkAa committed Mar 29, 2024
1 parent a8ec606 commit 328fb38
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 7 deletions.
4 changes: 0 additions & 4 deletions numba_dpex/dpnp_iface/dpnp_ufunc_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,6 @@ def _fill_ufunc_db_with_dpnp_ufuncs(ufunc_db):
op.types = npop.types
op.is_dpnp_ufunc = True
cp = copy.copy(_ufunc_db[npop])
if "'divide'" in str(npop):
# TODO: why do we need to do it only for divide?
# https://github.com/IntelPython/numba-dpex/issues/1270
ufunc_db.update({npop: cp})
ufunc_db.update({op: cp})
for key in list(ufunc_db[op].keys()):
if (
Expand Down
2 changes: 0 additions & 2 deletions numba_dpex/kernel_api_impl/spirv/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,6 @@ def load_additional_registries(self):
self.install_registry(ocldecl.registry)
self.install_registry(mathdecl.registry)
self.install_registry(cmathdecl.registry)
# TODO: https://github.com/IntelPython/numba-dpex/issues/1270
self.install_registry(npydecl.registry)
self.install_registry(dpnpdecl.registry)
self.install_registry(enumdecl.registry)

Expand Down
39 changes: 38 additions & 1 deletion numba_dpex/numba_patches/patch_arrayexpr_tree_to_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,14 @@
import math
import operator

import dpnp
from numba.core import errors, ir, types, typing
from numba.core.ir_utils import mk_unique_var
from numba.core.typing import npydecl
from numba.parfors import array_analysis, parfor

from numba_dpex.core.typing import dpnpdecl


def _ufunc_to_parfor_instr(
typemap,
Expand Down Expand Up @@ -53,6 +56,39 @@ def _ufunc_to_parfor_instr(
return el_typ


def get_dpnp_ufunc_typ(func):
"""get type of the incoming function from builtin registry"""
for k, v in dpnpdecl.registry.globals:
if k == func:
return v
raise RuntimeError("type for func ", func, " not found")


def _gen_dpnp_divide(arg1, arg2, out_ir, typemap):
"""generate np.divide() instead of / for array_expr to get numpy error model
like inf for division by zero (test_division_by_zero).
"""
scope = arg1.scope
loc = arg1.loc
g_np_var = ir.Var(scope, mk_unique_var("$np_g_var"), loc)
typemap[g_np_var.name] = types.misc.Module(dpnp)
g_np = ir.Global("dpnp", dpnp, loc)
g_np_assign = ir.Assign(g_np, g_np_var, loc)
# attr call: div_attr = getattr(g_np_var, divide)
div_attr_call = ir.Expr.getattr(g_np_var, "divide", loc)
attr_var = ir.Var(scope, mk_unique_var("$div_attr"), loc)
func_var_typ = get_dpnp_ufunc_typ(dpnp.divide)
typemap[attr_var.name] = func_var_typ
attr_assign = ir.Assign(div_attr_call, attr_var, loc)
# divide call: div_attr(arg1, arg2)
div_call = ir.Expr.call(attr_var, [arg1, arg2], (), loc)
func_typ = func_var_typ.get_call_type(
typing.Context(), [typemap[arg1.name], typemap[arg2.name]], {}
)
out_ir.extend([g_np_assign, attr_assign])
return func_typ, div_call


def _arrayexpr_tree_to_ir(
func_ir,
typingctx,
Expand Down Expand Up @@ -103,7 +139,8 @@ def _arrayexpr_tree_to_ir(
)
ir_expr = ir.Expr.binop(op, arg_vars[0], arg_vars[1], loc)
if op == operator.truediv:
func_typ, ir_expr = parfor._gen_np_divide(
# NUMBA_DPEX: is_dpnp_func check was added
func_typ, ir_expr = _gen_dpnp_divide(
arg_vars[0], arg_vars[1], out_ir, typemap
)
else:
Expand Down

0 comments on commit 328fb38

Please sign in to comment.