Skip to content

Commit 2d27682

Browse files
authored
support reduce op with fast implementation (#314)
1 parent 866d2ee commit 2d27682

18 files changed

+564
-141
lines changed

scripts/correctness.sh

+10-2
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,15 @@ python3 -m benchgc --verbose 0 --driver linalg --case matmul --md 0:32x128xbf16
1010

1111
# f32
1212

13+
# reduce
14+
15+
python3 -m benchgc --verbose 0 --driver linalg --case reduce.add --md 0:128x64x8xf32 --md 1:128xf32 --dimensions=1 --dimensions=2 || FAIL=1
16+
python3 -m benchgc --verbose 0 --driver linalg --case reduce.mul --md 0:128x8xf32 --md 1:128xf32 --dimensions=1 || FAIL=1
17+
python3 -m benchgc --verbose 0 --driver linalg --case reduce.max --md 0:128x64x8xf32 --md 1:128xf32 --dimensions=1 --dimensions=2 || FAIL=1
18+
python3 -m benchgc --verbose 0 --driver linalg --case reduce.min --md 0:128x64x8xf32 --md 1:128xf32 --dimensions=1 --dimensions=2 || FAIL=1
19+
python3 -m benchgc --verbose 0 --driver linalg --case reduce.l1 --md 0:128x64x8xf32 --md 1:128xf32 --dimensions=1 --dimensions=2 || FAIL=1
20+
python3 -m benchgc --verbose 0 --driver linalg --case reduce.l2_square --md 0:128x64x8xf32 --md 1:128xf32 --dimensions=1 --dimensions=2 || FAIL=1
21+
1322
# misc
1423
python3 -m benchgc --verbose 0 --driver linalg --case fill --md 0:f32 --md 1:32x4096xf32 --cmp 1:P:0:0 || FAIL=1
1524
python3 -m benchgc --verbose 0 --driver linalg --case copy --md 0:1024x1024xf32 --md 1:1024x1024xbf16 || FAIL=1
@@ -92,9 +101,8 @@ python3 -m benchgc --verbose 0 --driver linalg --case pooling_nwc_max --md 0:4x3
92101
python3 -m benchgc --verbose 0 --driver linalg --case pooling_nwc_sum --md 0:4x32x4xf32 --md 1:4xf32 --md 2:4x13x4xf32 --strides 2 --dilations 2 || FAIL=1
93102
python3 -m benchgc --verbose 0 --driver linalg --case pooling_nwc_min --md 0:4x32x4xf32 --md 1:4xf32 --md 2:4x13x4xf32 --strides 2 --dilations 2 || FAIL=1
94103

95-
# generic / reduce
104+
# generic
96105
python3 -m benchgc --verbose 0 --driver mlir --case ${CASE_DIR}/generic.mlir || FAIL=1
97-
python3 -m benchgc --verbose 0 --driver mlir --case ${CASE_DIR}/reduce.mlir || FAIL=1
98106

99107
# softmax
100108
# python3 -m benchgc --verbose 0 --driver linalg --case softmax --md 0:32x4096xf32 --md 1:32x4096xf32 --dimension 1 || FAIL=1

test/benchgc/CMakeLists.txt

+1
Original file line number | Diff line number | Diff line change
@@ -40,3 +40,4 @@ add_subdirectory("src/benchgc/linalg")
4040
add_subdirectory("src/benchgc/tensor")
4141
add_subdirectory("src/benchgc/arith")
4242
add_subdirectory("src/benchgc/pattern")
43+
add_subdirectory("src/benchgc/math")

test/benchgc/cases/reduce.mlir

-12
This file was deleted.

test/benchgc/src/benchgc/__main__.py

+11-7
Original file line number | Diff line number | Diff line change
@@ -192,7 +192,7 @@ def add_pattern_options(parser: argparse.ArgumentParser):
192192
get_pattern_clz(pattern_name).add_args(parser)
193193

194194

195-
def get_module_and_args(flags):
195+
def get_module_and_args(flags: argparse.Namespace):
196196
args: List[Arg] = []
197197
if flags.driver in ["mlir", "pattern"]:
198198
# we need to find all args by reading the entry function
@@ -203,6 +203,8 @@ def get_module_and_args(flags):
203203
elif flags.driver == "pattern":
204204
pattern_clz = get_pattern_clz(flags.case)
205205
module = pattern_clz(ctx, flags).ir_module
206+
else:
207+
raise Exception("unexpected error")
206208

207209
entry = benchgc.mlir.util.get_kernel_func_from_module(module, flags.entry)
208210
idx: int = 0
@@ -235,7 +237,10 @@ def get_module_and_args(flags):
235237

236238
from .linalg import mlir_op
237239

238-
mlir_func = mlir_op[flags.case]
240+
if flags.case.startswith("reduce."):
241+
mlir_func = mlir_op["reduce"]
242+
else:
243+
mlir_func = mlir_op[flags.case]
239244
module = mlir_func(flags, args)
240245
else:
241246
raise Exception(f"unsupported driver {flags.driver}")
@@ -269,7 +274,7 @@ def get_module_and_args(flags):
269274
return module, args
270275

271276

272-
def correctness_testing(flags, module, args):
277+
def correctness_testing(flags: argparse.Namespace, module: ir.Module, args: List[Arg]):
273278
ref_args: List[torch.Tensor] = []
274279
gc_args: List[torch.Tensor | int] = []
275280
ref_tensors: Dict[str, torch.Tensor] = {}
@@ -290,9 +295,8 @@ def correctness_testing(flags, module, args):
290295
ref_out = runner.ref_run(entry, ref_tensors)
291296

292297
# we need to swap the result into the args if some arg is the return value
293-
if ref_out is not None:
294-
for i in range(len(ref_out)):
295-
ref_args[0 - i - 1] = ref_out[0 - i - 1]
298+
for i in range(len(ref_out)):
299+
ref_args[0 - i - 1] = ref_out[0 - i - 1]
296300

297301
mlir_args = get_mlir_args(gc_args)
298302
passes = "any(gc-cpu-pipeline)"
@@ -323,7 +327,7 @@ def correctness_testing(flags, module, args):
323327
print(f"PASSED: {flags.driver}.{flags.case}")
324328

325329

326-
def performance_testing(flags, module, args):
330+
def performance_testing(flags: argparse.Namespace, module: ir.Module, args: List[Arg]):
327331
gc_args: List[torch.Tensor | int] = []
328332
gc_tensors: Dict[str, torch.Tensor] = {}
329333
for i in range(len(args)):

test/benchgc/src/benchgc/arg/__init__.py

+2
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,7 @@
2323
import benchgc.arg.eltwise as eltwise
2424
import benchgc.arg.matmul as matmul
2525
import benchgc.arg.pool as pool
26+
import benchgc.arg.reduce as reduce
2627
import benchgc.arg.softmax as softmax
2728
import benchgc.util
2829
import torch
@@ -36,6 +37,7 @@
3637
"softmax": softmax,
3738
"conv": conv,
3839
"pool": pool,
40+
"reduce": reduce,
3941
}
4042

4143

test/benchgc/src/benchgc/arg/reduce.py

+45-12
Original file line number | Diff line number | Diff line change
@@ -14,12 +14,42 @@
1414
# limitations under the License.
1515
################################################################################
1616

17-
from typing import List, Tuple
17+
import argparse
18+
from typing import List, Set, Tuple
1819

1920
import benchgc.arg
2021
import benchgc.util
2122
import torch
22-
23+
from benchgc.arg.arg import Arg
24+
from benchgc.arg.compare import p2p
25+
26+
op: Set[str] = set(
27+
[
28+
"linalg.reduce.add",
29+
"linalg.reduce.mul",
30+
"linalg.reduce.max",
31+
"linalg.reduce.min",
32+
"linalg.reduce.l1",
33+
"linalg.reduce.l2_square",
34+
]
35+
)
36+
37+
38+
def default_fill(
39+
flags: argparse.Namespace,
40+
arg: Arg,
41+
arglist: List[Arg],
42+
):
43+
if arg.index > 0:
44+
raise Exception("reduce fill: dst filling is not allowed")
45+
arg.fill_param = [
46+
"reduce",
47+
flags.case,
48+
arglist[0].dtype,
49+
arglist[1].dtype,
50+
str(arglist[0].nelem() // arglist[1].nelem()),
51+
]
52+
arg.fill_type = "D"
2353

2454
def fill(shape: List[int], dtype: torch.dtype, params: List[str]) -> torch.Tensor:
2555

@@ -30,22 +60,17 @@ def fill(shape: List[int], dtype: torch.dtype, params: List[str]) -> torch.Tenso
3060

3161
safe_to_reduce_elems: int = benchgc.util.get_problem_bounds(op, sdtype)[0]
3262

33-
neutral_value: float = 1.0 if op == "mul" else 0.0
63+
neutral_value: float = 1.0 if op == "reduce.mul" else 0.0
3464

3565
shift: float = (
3666
1.0
37-
if (
38-
op == "mean"
39-
or op == "min"
40-
and not sdtype.is_signed
41-
and not ddtype.is_signed
42-
)
67+
if (op == "reduce.min" and not sdtype.is_signed and not ddtype.is_signed)
4368
else 0.0
4469
)
4570

4671
value_range: int = benchgc.util.get_problem_bounds(op, sdtype)[1]
4772

48-
is_mul_fp: bool = op == "mul" and sdtype.is_floating_point
73+
is_mul_fp: bool = op == "reduce.mul" and sdtype.is_floating_point
4974
min_range: int = -value_range if is_mul_fp else 1
5075

5176
index = torch.arange(benchgc.util.nelem(shape)).reshape(shape)
@@ -69,10 +94,18 @@ def fill(shape: List[int], dtype: torch.dtype, params: List[str]) -> torch.Tenso
6994
return value.to(dtype)
7095

7196

97+
def default_compare(
98+
flags: argparse.Namespace,
99+
arg: Arg,
100+
arglist: List[Arg],
101+
):
102+
arg.cmp_type = "D"
103+
arg.cmp_param = ["reduce", arg.dtype, flags.case]
104+
72105
def compare(
73-
ref: torch.Tensor, res: torch.Tensor, verbose: int
106+
param: List[str], ref: torch.Tensor, res: torch.Tensor, verbose: int
74107
) -> Tuple[bool, bool | None]:
75108
dtype = ref.dtype
76109
ref = ref.to(torch.float)
77110
res = res.to(torch.float)
78-
return benchgc.arg.p2p(benchgc.util.get_eps(dtype), 30.0, ref, res, verbose)
111+
return p2p(benchgc.util.get_eps(dtype), 30.0, ref, res, verbose)

test/benchgc/src/benchgc/arith/basic.py

+49-1
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,6 @@
1717
from typing import Dict, Tuple
1818

1919
import benchgc.util
20-
import gc_mlir._mlir_libs._mlir.ir
2120
import torch
2221
from benchgc.mlir.util import MLIRCache
2322
from gc_mlir import ir
@@ -42,6 +41,19 @@ def ref_constant(
4241
)
4342
else:
4443
raise Exception("only support splat value now")
44+
elif isinstance(value, ir.IntegerAttr):
45+
return (torch.full(size=tuple(), fill_value=value.__int__(), dtype=torch.int),)
46+
elif isinstance(value, ir.DenseIntElementsAttr):
47+
if value.is_splat:
48+
return (
49+
torch.full(
50+
size=tuple(value.type.shape),
51+
fill_value=value.get_splat_value().value,
52+
dtype=benchgc.util.get_dtype(str(value.get_splat_value().type)),
53+
),
54+
)
55+
else:
56+
raise Exception("only support splat value now")
4557
else:
4658
raise Exception("Not support constant type %s", type(value))
4759

@@ -56,3 +68,39 @@ def ref_addf(
5668
cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor]
5769
) -> Tuple[torch.Tensor, ...]:
5870
return (var[cache.opr[0]] + var[cache.opr[1]],)
71+
72+
73+
def ref_maxf(
74+
cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor]
75+
) -> Tuple[torch.Tensor, ...]:
76+
return (torch.max(var[cache.opr[0]], var[cache.opr[1]]),)
77+
78+
79+
def ref_minf(
80+
cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor]
81+
) -> Tuple[torch.Tensor, ...]:
82+
return (torch.min(var[cache.opr[0]], var[cache.opr[1]]),)
83+
84+
85+
def ref_muli(
86+
cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor]
87+
) -> Tuple[torch.Tensor, ...]:
88+
return (var[cache.opr[0]] * var[cache.opr[1]],)
89+
90+
91+
def ref_addi(
92+
cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor]
93+
) -> Tuple[torch.Tensor, ...]:
94+
return (var[cache.opr[0]] + var[cache.opr[1]],)
95+
96+
97+
def ref_maxsi(
98+
cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor]
99+
) -> Tuple[torch.Tensor, ...]:
100+
return (torch.max(var[cache.opr[0]], var[cache.opr[1]]),)
101+
102+
103+
def ref_minsi(
104+
cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor]
105+
) -> Tuple[torch.Tensor, ...]:
106+
return (torch.min(var[cache.opr[0]], var[cache.opr[1]]),)

test/benchgc/src/benchgc/bench.py

+9-9
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,7 @@
1818
import ctypes
1919
import random
2020
import timeit
21-
from typing import List, Tuple
21+
from typing import Any, List, Tuple
2222

2323
import numpy as np
2424
from benchgc.mlir.util import (
@@ -34,10 +34,10 @@ def py_timeit_bench(
3434
ir_module: ir.Module,
3535
entry_name: str,
3636
pipeline: str,
37-
mlir_args: list,
38-
ir_printing=False,
39-
repeat_time=100,
40-
warm_up=20,
37+
mlir_args: List[Any],
38+
ir_printing: bool = False,
39+
repeat_time: int = 100,
40+
warm_up: int = 20,
4141
) -> Tuple[float, float]:
4242
"""benchmark mlir with python timeit."""
4343
compiler = GraphCompiler(pipeline)
@@ -64,10 +64,10 @@ def mlir_wrapper_bench(
6464
ir_module: ir.Module,
6565
entry_name: str,
6666
pipeline: str,
67-
mlir_args: list,
68-
ir_printing=False,
69-
repeat_time=100,
70-
warm_up=20,
67+
mlir_args: List[Any],
68+
ir_printing: bool = False,
69+
repeat_time: int = 100,
70+
warm_up: int = 20,
7171
) -> Tuple[float, float]:
7272
"""benchmark mlir with a wrapper func."""
7373
kernel_func = get_kernel_func_from_module(ir_module, entry_name)

test/benchgc/src/benchgc/linalg/__init__.py

+1
Original file line number | Diff line number | Diff line change
@@ -41,6 +41,7 @@
4141
"softmax",
4242
"conv",
4343
"pool",
44+
"reduce",
4445
]:
4546
mod = importlib.import_module(f"benchgc.linalg.{dri}")
4647
for key in mod.__dict__:

0 commit comments

Comments
 (0)