17
17
from typing import Dict , Tuple
18
18
19
19
import benchgc .util
20
- import gc_mlir ._mlir_libs ._mlir .ir
21
20
import torch
22
21
from benchgc .mlir .util import MLIRCache
23
22
from gc_mlir import ir
@@ -42,6 +41,19 @@ def ref_constant(
42
41
)
43
42
else :
44
43
raise Exception ("only support splat value now" )
44
+ elif isinstance (value , ir .IntegerAttr ):
45
+ return (torch .full (size = tuple (), fill_value = value .__int__ (), dtype = torch .int ),)
46
+ elif isinstance (value , ir .DenseIntElementsAttr ):
47
+ if value .is_splat :
48
+ return (
49
+ torch .full (
50
+ size = tuple (value .type .shape ),
51
+ fill_value = value .get_splat_value ().value ,
52
+ dtype = benchgc .util .get_dtype (str (value .get_splat_value ().type )),
53
+ ),
54
+ )
55
+ else :
56
+ raise Exception ("only support splat value now" )
45
57
else :
46
58
raise Exception ("Not support constant type %s" , type (value ))
47
59
@@ -56,3 +68,39 @@ def ref_addf(
56
68
cache : MLIRCache , op : ir .OpView , var : Dict [str , torch .Tensor ]
57
69
) -> Tuple [torch .Tensor , ...]:
58
70
return (var [cache .opr [0 ]] + var [cache .opr [1 ]],)
71
+
72
+
73
+ def ref_maxf (
74
+ cache : MLIRCache , op : ir .OpView , var : Dict [str , torch .Tensor ]
75
+ ) -> Tuple [torch .Tensor , ...]:
76
+ return (torch .max (var [cache .opr [0 ]], var [cache .opr [1 ]]),)
77
+
78
+
79
+ def ref_minf (
80
+ cache : MLIRCache , op : ir .OpView , var : Dict [str , torch .Tensor ]
81
+ ) -> Tuple [torch .Tensor , ...]:
82
+ return (torch .min (var [cache .opr [0 ]], var [cache .opr [1 ]]),)
83
+
84
+
85
+ def ref_muli (
86
+ cache : MLIRCache , op : ir .OpView , var : Dict [str , torch .Tensor ]
87
+ ) -> Tuple [torch .Tensor , ...]:
88
+ return (var [cache .opr [0 ]] * var [cache .opr [1 ]],)
89
+
90
+
91
+ def ref_addi (
92
+ cache : MLIRCache , op : ir .OpView , var : Dict [str , torch .Tensor ]
93
+ ) -> Tuple [torch .Tensor , ...]:
94
+ return (var [cache .opr [0 ]] + var [cache .opr [1 ]],)
95
+
96
+
97
+ def ref_maxsi (
98
+ cache : MLIRCache , op : ir .OpView , var : Dict [str , torch .Tensor ]
99
+ ) -> Tuple [torch .Tensor , ...]:
100
+ return (torch .max (var [cache .opr [0 ]], var [cache .opr [1 ]]),)
101
+
102
+
103
+ def ref_minsi (
104
+ cache : MLIRCache , op : ir .OpView , var : Dict [str , torch .Tensor ]
105
+ ) -> Tuple [torch .Tensor , ...]:
106
+ return (torch .min (var [cache .opr [0 ]], var [cache .opr [1 ]]),)
0 commit comments