@@ -53,101 +53,101 @@ class MatMulConfig(Config):
53
53
def __init__(
    self,
    op: OpView,
    MThreads: int = 1,
    KThreads: int = 1,
    NThreads: int = 1,
    MBlock: int = 1,
    KBlock: int = 1,
    NBlock: int = 1,
    innerMostMBlock: int = 1,
    innerMostKBlock: int = 1,
    innerMostNBlock: int = 1,
):
    """Create a matmul tuning configuration for ``op``.

    Args:
        op: The matmul operation. ``m`` and ``k`` are read from the shape
            of its first input and ``n`` from the shape of its second.
        MThreads, KThreads, NThreads: Stored as ``m_threads``,
            ``k_threads`` and ``n_threads``.
        MBlock, KBlock, NBlock: Stored as ``m_block``, ``k_block`` and
            ``n_block``.
        innerMostMBlock, innerMostKBlock, innerMostNBlock: Stored as
            ``innermost_m_block``, ``innermost_k_block`` and
            ``innermost_n_block``.
    """
    # you can set the default value and candidates by info from matmul_op
    self.m = op.inputs[0].type.shape[0]
    self.k = op.inputs[0].type.shape[1]
    self.n = op.inputs[1].type.shape[1]
    # self.input_a_dtype = str(op.inputs[0].type.element_type)

    # Thread budget comes from the environment; defaults to 1 when
    # OMP_NUM_THREADS is not set.
    self.num_threads = int(os.environ.get("OMP_NUM_THREADS", 1))

    self.m_threads = MThreads
    self.k_threads = KThreads
    self.n_threads = NThreads
    self.m_block = MBlock
    self.k_block = KBlock
    self.n_block = NBlock
    self.innermost_m_block = innerMostMBlock
    self.innermost_k_block = innerMostKBlock
    self.innermost_n_block = innerMostNBlock

    # NOTE(review): the base initializer runs last, presumably because it
    # relies on the attributes assigned above -- confirm against Config.
    super().__init__()
82
82
83
83
def init_candidates(self):
    """Populate ``self.field_candidates`` with the values to try per field.

    Thread-count fields get ``find_factors(self.num_threads)``. Block-size
    fields get every default block size that does not exceed the matching
    problem dimension (``self.m``, ``self.k`` or ``self.n``).
    """
    default_blocks = [16, 32, 64, 128, 256, 512]
    default_innermost_blocks = [16, 32]

    # Keys are the attribute names set in __init__.
    self.field_candidates["m_threads"] = find_factors(self.num_threads)
    self.field_candidates["k_threads"] = find_factors(self.num_threads)
    self.field_candidates["n_threads"] = find_factors(self.num_threads)

    self.field_candidates["m_block"] = [
        block for block in default_blocks if self.m >= block
    ]
    self.field_candidates["k_block"] = [
        block for block in default_blocks if self.k >= block
    ]
    self.field_candidates["n_block"] = [
        block for block in default_blocks if self.n >= block
    ]

    self.field_candidates["innermost_m_block"] = [
        block for block in default_innermost_blocks if self.m >= block
    ]
    self.field_candidates["innermost_k_block"] = [
        block for block in default_innermost_blocks if self.k >= block
    ]
    self.field_candidates["innermost_n_block"] = [
        block for block in default_innermost_blocks if self.n >= block
    ]
107
107
108
108
def init_constraints(self):
    """Populate ``self.field_constraints`` with one entry per field.

    An entry is either ``None`` (no constraint) or a two-argument callable
    taking a config object and the candidate value for that field, and
    returning ``True`` when the candidate is acceptable.
    """
    # example: using lambda to add constraints, adding constraints by the
    # order of the fields
    # NOTE: the first lambda parameter is named ``MatMulConfig`` in the
    # original code; it is the config object being checked, and the name is
    # kept unchanged in case a caller binds it by keyword.
    self.field_constraints["m_threads"] = None
    # The product of the thread counts chosen so far must divide
    # num_threads.
    self.field_constraints["k_threads"] = (
        lambda MatMulConfig, k_threads: self.num_threads
        % (MatMulConfig.m_threads * k_threads)
        == 0
    )
    self.field_constraints["n_threads"] = (
        lambda MatMulConfig, n_threads: self.num_threads
        % (MatMulConfig.m_threads * MatMulConfig.k_threads * n_threads)
        == 0
    )
    self.field_constraints["m_block"] = None
    self.field_constraints["k_block"] = None
    self.field_constraints["n_block"] = None
    # Each innermost block must evenly divide its enclosing block.
    self.field_constraints["innermost_m_block"] = (
        lambda MatMulConfig, innermost_m_block: MatMulConfig.m_block
        % innermost_m_block
        == 0
    )
    self.field_constraints["innermost_k_block"] = (
        lambda MatMulConfig, innermost_k_block: MatMulConfig.k_block
        % innermost_k_block
        == 0
    )
    self.field_constraints["innermost_n_block"] = (
        lambda MatMulConfig, innermost_n_block: MatMulConfig.n_block
        % innermost_n_block
        == 0
    )
139
139
140
140
def attach_to_ir(self, op: OpView):
    """Write the current field values onto ``op`` as integer attributes.

    Each field is stored in ``op.attributes`` under its IR attribute name
    as an ``IntegerAttr`` built with the ``T.i32()`` type.
    """
    # Maps IR attribute name -> current value of the matching field.
    attr_to_field = {
        "MThreads": self.m_threads,
        "KThreads": self.k_threads,
        "NThreads": self.n_threads,
        "MBlock": self.m_block,
        "KBlock": self.k_block,
        "NBlock": self.n_block,
        "innerMostMBlock": self.innermost_m_block,
        "innerMostKBlock": self.innermost_k_block,
        "innerMostNBlock": self.innermost_n_block,
    }
    for name, value in attr_to_field.items():
        op.attributes[name] = IntegerAttr.get(T.i32(), value)
@@ -158,15 +158,15 @@ def __repr__(self) -> str:
158
158
def __str__(self) -> str:
    """Return the tunable fields as a JSON string indented by 4 spaces.

    The fields are nested under a single top-level ``"MatMulConfig"`` key.
    """
    obj_dict = {
        "MatMulConfig": {
            "MThreads": self.m_threads,
            "KThreads": self.k_threads,
            "NThreads": self.n_threads,
            "MBlock": self.m_block,
            "KBlock": self.k_block,
            "NBlock": self.n_block,
            "innerMostMBlock": self.innermost_m_block,
            "innerMostKBlock": self.innermost_k_block,
            "innerMostNBlock": self.innermost_n_block,
        }
    }
    return json.dumps(obj_dict, indent=4)
0 commit comments