25
25
module {{
26
26
27
27
util.func private @mmt_super_block_scaled_offset_q4_unsigned_3d_ {n }_ {k }_ {sup_count }_ {sub_count }_ {bs }_ {a_type }(
28
- %a: !a_tensor_type ,
29
- %d: !d_tensor_type ,
28
+ %a: !a_tensor_type ,
29
+ %d: !d_tensor_type ,
30
30
%dmin: !dmin_tensor_type ,
31
31
%sb_scales_hi_i8: !sb_hi_i8_type ,
32
32
%sb_scales_low_i8: !sb_low_i8_type ,
@@ -59,11 +59,11 @@ util.func private @mmt_super_block_scaled_offset_q4_unsigned_3d_{n}_{k}_{sup_cou
59
59
affine_map <(d0 , d1 , d2 , d3 ) -> (d0 , d1 , d2 )>, // sb_mins_hi[n, sup, sub]
60
60
affine_map <(d0 , d1 , d2 , d3 ) -> (d0 , d1 , d2 )>, // sb_mins_low[n, sup, sub]
61
61
affine_map <(d0 , d1 , d2 , d3 ) -> (d0 , d1 , d2 , d3 )> // out b_grouped[n, sup, sub, bs]
62
- ],
62
+ ],
63
63
iterator_types = [" parallel" , " parallel" , " parallel" , " parallel" ] }}
64
64
ins (
65
65
%qs , %d , %dmin , %sb_scales_hi , %sb_scales_low , %sb_mins_hi , %sb_mins_low :
66
- !qs_tensor_type , !d_tensor_type , !dmin_tensor_type ,
66
+ !qs_tensor_type , !d_tensor_type , !dmin_tensor_type ,
67
67
!sb_hi_i2_type , !sb_low_i4_type , !sb_hi_i2_type , !sb_low_i4_type
68
68
)
69
69
outs (%b_grouped : !b_grouped_tensor_type ) {{
@@ -74,7 +74,7 @@ util.func private @mmt_super_block_scaled_offset_q4_unsigned_3d_{n}_{k}_{sup_cou
74
74
%shift_4 = arith.constant 4 : i32
75
75
%d_element_ext = arith.extf %d_element : !scale_type to !a_type
76
76
%dmin_element_ext = arith.extf %dmin_element : !scale_type to !a_type
77
-
77
+
78
78
// Combine sub-block scale.
79
79
%sb_scale_low_i32 = arith.extui %sb_scales_low_element : i4 to i32
80
80
%sb_scale_hi_i32 = arith.extui %sb_scales_hi_element : i2 to i32
@@ -111,8 +111,8 @@ util.func private @mmt_super_block_scaled_offset_q4_unsigned_3d_{n}_{k}_{sup_cou
111
111
affine_map <(d0 , d1 , d2 , d3 , d4 , d5 ) -> (d0 , d1 , d3 , d4 , d5 )>, // aexp
112
112
affine_map <(d0 , d1 , d2 , d3 , d4 , d5 ) -> (d2 , d3 , d4 , d5 )>, // b_grouped_dequant
113
113
affine_map <(d0 , d1 , d2 , d3 , d4 , d5 ) -> (d0 , d1 , d2 )> // out
114
- ],
115
- iterator_types = [" parallel" , " parallel" , " parallel" , " reduction" , " reduction" , " reduction" ] }}
114
+ ],
115
+ iterator_types = [" parallel" , " parallel" , " parallel" , " reduction" , " reduction" , " reduction" ] }}
116
116
ins (%aexp , %b_grouped_dequant : !aexp_tensor_type , !b_grouped_tensor_type )
117
117
outs (%result_fill : !c_tensor_type ) {{
118
118
^bb0 (%a_element: !a_type , %b_element: !a_type , %out: !a_type ):
0 commit comments