Skip to content

Commit 1617753

Browse files
committed
Add dpnp/dpjit specific parfor
1 parent e0639f0 commit 1617753

File tree

2 files changed

+215
-1
lines changed

2 files changed

+215
-1
lines changed
Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,214 @@
1+
import warnings
2+
3+
from numba.core import config, errors, ir, types
4+
from numba.core.compiler_machinery import register_pass
5+
from numba.core.ir_utils import (
6+
dprint_func_ir,
7+
mk_alloc,
8+
mk_unique_var,
9+
next_label,
10+
)
11+
from numba.core.typed_passes import ParforPass as NumpyParforPass
12+
from numba.core.typed_passes import _reload_parfors
13+
from numba.parfors.parfor import (
14+
ConvertInplaceBinop,
15+
ConvertLoopPass,
16+
ConvertNumpyPass,
17+
ConvertReducePass,
18+
ConvertSetItemPass,
19+
Parfor,
20+
)
21+
from numba.parfors.parfor import ParforPass as _NumpyParforPass
22+
from numba.parfors.parfor import (
23+
_make_index_var,
24+
_mk_parfor_loops,
25+
repr_arrayexpr,
26+
signature,
27+
)
28+
from numba.stencils.stencilparfor import StencilPass
29+
30+
from numba_dpex.numba_patches.patch_arrayexpr_tree_to_ir import (
31+
_arrayexpr_tree_to_ir,
32+
)
33+
34+
35+
class ConvertDPNPPass(ConvertNumpyPass):
    """Array-expression-to-parfor converter used by the dpnp/dpjit pipeline.

    Everything is inherited from ``ConvertNumpyPass`` except
    ``_arrayexpr_to_parfor``, which builds the loop body with numba_dpex's
    patched ``_arrayexpr_tree_to_ir`` and passes the type of the output
    array on to ``mk_alloc``.
    """

    def __init__(self, pass_states):
        super().__init__(pass_states)

    def _arrayexpr_to_parfor(self, equiv_set, lhs, arrayexpr, avail_vars):
        """Build a ``Parfor`` that evaluates ``arrayexpr`` into ``lhs``.

        The array expression is a recursive expression tree that is applied
        element by element (a map).

        Parameters
        ----------
        equiv_set
            Equivalence set queried for the shape of ``lhs`` and forwarded
            to the expression lowering and to the ``Parfor`` node.
        lhs
            IR variable receiving the result; its ``scope`` and ``loc`` are
            reused for every node created here.
        arrayexpr
            Array-expression node whose ``expr`` tree is lowered.
        avail_vars
            Forwarded unchanged to ``_arrayexpr_tree_to_ir``.

        Returns
        -------
        Parfor
            Parfor whose init block allocates ``lhs`` and whose single body
            block computes one element and stores it with a ``SetItem``.
        """
        states = self.pass_states
        typemap = states.typemap
        scope, loc = lhs.scope, lhs.loc
        tree = arrayexpr.expr
        elem_typ = typemap[lhs.name].dtype

        # Size variables come from the shape recorded for lhs in the
        # equivalence set; loop nests and index variables are derived from
        # them.
        shape_vars = equiv_set.get_shape(lhs)
        loop_indices, loop_nests = _mk_parfor_loops(
            typemap, shape_vars, scope, loc
        )

        # Init block: allocation of the output array.
        prelude = ir.Block(scope, loc)
        prelude.body = mk_alloc(
            states.typingctx,
            typemap,
            states.calltypes,
            lhs,
            tuple(shape_vars),
            elem_typ,
            scope,
            loc,
            typemap[lhs.name],
        )

        # Body block: computes a single element into a fresh temporary.
        loop_label = next_label()
        loop_block = ir.Block(scope, loc)
        elem_result = ir.Var(scope, mk_unique_var("$expr_out_var"), loc)
        typemap[elem_result.name] = elem_typ

        idx_var, idx_typ = _make_index_var(
            typemap, scope, loop_indices, loop_block
        )

        elem_stmts = _arrayexpr_tree_to_ir(
            states.func_ir,
            states.typingctx,
            typemap,
            states.calltypes,
            equiv_set,
            prelude,
            elem_result,
            tree,
            idx_var,
            loop_indices,
            avail_vars,
        )
        loop_block.body.extend(elem_stmts)

        # Pattern string stored on the parfor describing its origin.
        pattern = "array expression {}".format(repr_arrayexpr(arrayexpr.expr))

        # The loop body is attached further down, once the store has been
        # appended to the body block.
        result = Parfor(
            loop_nests,
            prelude,
            {},
            loc,
            idx_var,
            equiv_set,
            pattern,
            states.flags,
        )

        # lhs[index] = element; register the call type of the store.
        store = ir.SetItem(lhs, idx_var, elem_result, loc)
        states.calltypes[store] = signature(
            types.none, typemap[lhs.name], idx_typ, elem_typ
        )
        loop_block.body.append(store)
        result.loop_body = {loop_label: loop_block}

        if config.DEBUG_ARRAY_OPT >= 1:
            print("parfor from arrayexpr")
            result.dump()
        return result
117+
118+
119+
class _ParforPass(_NumpyParforPass):
    """Parfor conversion driver that uses ``ConvertDPNPPass`` for the
    ``numpy`` option; the remaining stages are numba's own converters."""

    def run(self):
        """Run the parfor conversion pass.

        Replaces Numpy calls with Parfors where possible. Stages run in a
        fixed order (stencil, setitem, numpy, reduction, prange, in-place
        binop), each only when the matching flag on ``self.options`` is
        set. Diagnostics are set up once all stages have run.
        """
        self._pre_run()

        # Stencil translation takes a different constructor signature than
        # the converters below, so it is handled on its own.
        if self.options.stencil:
            StencilPass(
                self.func_ir,
                self.typemap,
                self.calltypes,
                self.array_analysis,
                self.typingctx,
                self.targetctx,
                self.flags,
            ).run()

        # (option name, converter class) in execution order. The flag and
        # the block dictionary are looked up immediately before each stage.
        stages = (
            ("setitem", ConvertSetItemPass),
            ("numpy", ConvertDPNPPass),
            ("reduction", ConvertReducePass),
            ("prange", ConvertLoopPass),
            ("inplace_binop", ConvertInplaceBinop),
        )
        for option_name, converter in stages:
            if getattr(self.options, option_name):
                converter(self).run(self.func_ir.blocks)

        # Set up diagnostics now that the parfors have been found.
        self.diagnostics.setup(self.func_ir, self.options.fusion)

        dprint_func_ir(self.func_ir, "after parfor pass")
151+
152+
153+
@register_pass(mutates_CFG=True, analysis_only=False)
class ParforPass(NumpyParforPass):
    """Compiler pass that runs the dpnp-specific ``_ParforPass`` driver."""

    # TODO: do we care about name?
    _name = "dpnp_parfor_pass"

    def __init__(self):
        super().__init__()

    def run_pass(self, state):
        """Convert data-parallel computations into Parfor nodes.

        Runs ``_ParforPass`` over ``state.func_ir``, emits a
        ``NumbaPerformanceWarning`` when no ``Parfor`` statement is found
        afterwards (unless performance warnings are disabled or the IR
        location's filename is ``"<string>"``), and appends
        ``_reload_parfors`` to ``state.reload_init``.

        Returns
        -------
        bool
            Always ``True``.
        """
        # Ensure we have an IR and type information.
        assert state.func_ir
        _ParforPass(
            state.func_ir,
            state.typemap,
            state.calltypes,
            state.return_type,
            state.typingctx,
            state.targetctx,
            state.flags.auto_parallel,
            state.flags,
            state.metadata,
            state.parfor_diagnostics,
        ).run()

        # Check the parfor pass worked and warn if it didn't; the scan
        # stops at the first Parfor statement.
        found_parfor = any(
            isinstance(stmt, Parfor)
            for block in state.func_ir.blocks.values()
            for stmt in block.body
        )

        # Parfor calls the compiler chain again with a string, so IR whose
        # filename is "<string>" is exempt from the warning.
        if (
            not found_parfor
            and not config.DISABLE_PERFORMANCE_WARNINGS
            and state.func_ir.loc.filename != "<string>"
        ):
            url = (
                "https://numba.readthedocs.io/en/stable/user/"
                "parallel.html#diagnostics"
            )
            template = (
                "\nThe keyword argument 'parallel=True' was specified "
                "but no transformation for parallel execution was "
                "possible.\n\nTo find out why, try turning on parallel "
                "diagnostics, see %s for help."
            )
            warnings.warn(
                errors.NumbaPerformanceWarning(
                    template % url, state.func_ir.loc
                )
            )

        # Add reload function to initialize the parallel backend.
        state.reload_init.append(_reload_parfors)
        return True

numba_dpex/core/pipelines/dpjit_compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,13 @@
1313
NoPythonSupportedFeatureValidation,
1414
NopythonTypeInference,
1515
ParforFusionPass,
16-
ParforPass,
1716
ParforPreLoweringPass,
1817
PreLowerStripPhis,
1918
PreParforPass,
2019
)
2120

2221
from numba_dpex.core.exceptions import UnsupportedCompilationModeError
22+
from numba_dpex.core.parfors.parfor_pass import ParforPass
2323
from numba_dpex.core.passes import (
2424
DumpParforDiagnostics,
2525
NoPythonBackend,

0 commit comments

Comments
 (0)