@@ -29,14 +29,9 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &ss,
29
29
30
30
template <typename T>
31
31
static llvm::raw_ostream &operator <<(llvm::raw_ostream &ss,
32
- std::vector<T> arry ) {
32
+ std::vector<T> array ) {
33
33
ss << " [" ;
34
- for (auto [idx, a] : llvm::enumerate (arry)) {
35
- if (idx != 0 ) {
36
- ss << " , " ;
37
- }
38
- ss << a;
39
- }
34
+ llvm::interleaveComma (array, ss);
40
35
ss << " ]" ;
41
36
return ss;
42
37
}
@@ -174,7 +169,7 @@ std::vector<MatmulConfig>
174
169
filterConfigByCostModel (ArrayRef<MatmulConfig> configs,
175
170
linalg::LinalgOp &linalgOp, ArrayRef<uint32_t > shape,
176
171
SystemDesc &sysDesc, const CostModelFn &costModel,
177
- float eliminationRatio = 0.5 , float threshold = -1 ) {
172
+ float preserveRatio = 0.5 , float threshold = -1 ) {
178
173
std::vector<MatmulConfig> result;
179
174
std::vector<float > costs;
180
175
std::vector<size_t > idx;
@@ -185,8 +180,7 @@ filterConfigByCostModel(ArrayRef<MatmulConfig> configs,
185
180
std::stable_sort (idx.begin (), idx.end (), [&costs](size_t i1, size_t i2) {
186
181
return costs[i1] < costs[i2];
187
182
});
188
- double thresholdCost =
189
- costs[idx[(size_t )(eliminationRatio * configs.size ())]];
183
+ double thresholdCost = costs[idx[(size_t )(preserveRatio * configs.size ())]];
190
184
thresholdCost =
191
185
threshold < thresholdCost && threshold > 0 ? threshold : thresholdCost;
192
186
for (size_t i = 0 ; i < configs.size (); i++) {
@@ -210,6 +204,11 @@ std::vector<MatmulConfig>
210
204
prepareConfigCandidates (Operation *root, SystemDesc &sysDesc,
211
205
ArrayRef<uint32_t > shape,
212
206
ArrayRef<uint32_t > givenInnermostBlock) {
207
+ if (shape.size () < 3 ) {
208
+ LLVM_DEBUG (llvm::dbgs ()
209
+ << " The shape is invalid, no candidate is generated\n " );
210
+ return {};
211
+ }
213
212
std::vector<MatmulConfig> configs;
214
213
uint32_t threads = sysDesc.getNumThreads ();
215
214
std::vector<uint32_t > MThreadsCandidates =
@@ -290,6 +289,21 @@ prepareConfigCandidates(Operation *root, SystemDesc &sysDesc,
290
289
return configs;
291
290
}
292
291
292
+ bool validateConfig (const MatmulConfig &cfg) {
293
+ if (cfg.MThreads <= 0 || cfg.NThreads <= 0 || cfg.KThreads <= 0 ||
294
+ cfg.MBlock <= 0 || cfg.NBlock <= 0 || cfg.KBlock <= 0 ||
295
+ cfg.innerMostMBlock <= 0 || cfg.innerMostNBlock <= 0 ||
296
+ cfg.innerMostKBlock <= 0 ) {
297
+ return false ;
298
+ }
299
+ if (cfg.MBlock % cfg.innerMostMBlock != 0 ||
300
+ cfg.NBlock % cfg.innerMostNBlock != 0 ||
301
+ cfg.KBlock % cfg.innerMostKBlock != 0 ) {
302
+ return false ;
303
+ }
304
+ return true ;
305
+ }
306
+
293
307
// read the config from the attributes for tuning
294
308
bool readConfigFromAttrs (MatmulConfig &config, ArrayRef<NamedAttribute> attrs) {
295
309
size_t cfgItemCnt = 0 ;
@@ -323,7 +337,12 @@ bool readConfigFromAttrs(MatmulConfig &config, ArrayRef<NamedAttribute> attrs) {
323
337
cfgItemCnt++;
324
338
}
325
339
}
326
- return cfgItemCnt == 9 ;
340
+ if (validateConfig (config)) {
341
+ return cfgItemCnt == 9 ;
342
+ } else {
343
+ LLVM_DEBUG (llvm::dbgs () << " The predefined config is invalid\n " );
344
+ return false ;
345
+ }
327
346
}
328
347
329
348
// Analyze the workload and system description to generate the default config
0 commit comments