@@ -44,11 +44,10 @@ getCandidate(uint32_t num, uint32_t floor,
44
44
// factor
45
45
std::vector<uint32_t > candidates;
46
46
uint32_t upperbound = std::min (num, ceil);
47
- for (uint32_t i = floor; i <= upperbound; i++) {
48
- if (num % i == 0 ) {
47
+ for (uint32_t i = floor; i <= upperbound; i++)
48
+ if (num % i == 0 )
49
49
candidates.push_back (i);
50
- }
51
- }
50
+
52
51
// the pow of 2
53
52
uint32_t candidate = 1U ;
54
53
while (candidate < floor)
@@ -68,9 +67,8 @@ getCandidate(uint32_t num, uint32_t floor,
68
67
bool validateThreads (ArrayRef<uint32_t > threads, SystemDesc &sysDesc) {
69
68
uint32_t numThreads = sysDesc.getNumThreads ();
70
69
uint32_t actualThreads = 1U ;
71
- for (uint32_t t : threads) {
70
+ for (uint32_t t : threads)
72
71
actualThreads *= t;
73
- }
74
72
return actualThreads == numThreads;
75
73
}
76
74
@@ -154,9 +152,8 @@ double computationIntensityOnL2Cache(linalg::LinalgOp &linalgOp,
154
152
config.NBlock * config.KBlock +
155
153
config.MBlock * config.KBlock ;
156
154
double computationIntensity = FLOPS / memoryConsumption;
157
- if (memoryConsumption * dtypeSize > L2Cache * fullLoadRatio) {
155
+ if (memoryConsumption * dtypeSize > L2Cache * fullLoadRatio)
158
156
computationIntensity /= outOfCachePenalty;
159
- }
160
157
return 1 / computationIntensity;
161
158
}
162
159
@@ -183,19 +180,17 @@ filterConfigByCostModel(ArrayRef<MatmulConfig> configs,
183
180
double thresholdCost = costs[idx[(size_t )(preserveRatio * configs.size ())]];
184
181
thresholdCost =
185
182
threshold < thresholdCost && threshold > 0 ? threshold : thresholdCost;
186
- for (const auto &i : idx) {
187
- if (costs[i] <= thresholdCost) {
183
+ for (const auto &i : idx)
184
+ if (costs[i] <= thresholdCost)
188
185
result.push_back (configs[i]);
189
- }
190
- }
186
+
191
187
LLVM_DEBUG (llvm::dbgs () << " thresholdCost is: " << thresholdCost
192
188
<< " \n best with cost: " << costs[idx[0 ]] << " \n "
193
189
<< configs[idx[0 ]] << " \n worst with cost: "
194
190
<< costs[idx[configs.size () - 1 ]] << " \n "
195
191
<< configs[idx[configs.size () - 1 ]] << " \n " );
196
- if (result.empty ()) {
192
+ if (result.empty ())
197
193
result = configs;
198
- }
199
194
return result;
200
195
}
201
196
@@ -248,27 +243,23 @@ prepareConfigCandidates(Operation *root, SystemDesc &sysDesc,
248
243
for (uint32_t MThreads : MThreadsCandidates) {
249
244
for (uint32_t NThreads : NThreadsCandidates) {
250
245
for (uint32_t KThreads : KThreadsCandidates) {
251
- if (!validateThreads ({MThreads, NThreads, KThreads}, sysDesc)) {
246
+ if (!validateThreads ({MThreads, NThreads, KThreads}, sysDesc))
252
247
continue ;
253
- }
254
248
for (uint32_t MBlock : MBlockCandidates) {
255
249
for (uint32_t innerMostMBlock : innerMostMBlockCandidates) {
256
250
if (MBlock % innerMostMBlock != 0 ||
257
- shape[0 ] % innerMostMBlock != 0 ) {
251
+ shape[0 ] % innerMostMBlock != 0 )
258
252
continue ;
259
- }
260
253
for (uint32_t NBlock : NBlockCandidates) {
261
254
for (uint32_t innerMostNBlock : innerMostNBlockCandidates) {
262
255
if (NBlock % innerMostNBlock != 0 ||
263
- shape[1 ] % innerMostNBlock != 0 ) {
256
+ shape[1 ] % innerMostNBlock != 0 )
264
257
continue ;
265
- }
266
258
for (uint32_t KBlock : KBlockCandidates) {
267
259
for (uint32_t innerMostKBlock : innerMostKBlockCandidates) {
268
260
if (KBlock % innerMostKBlock != 0 ||
269
- shape[2 ] % innerMostKBlock != 0 ) {
261
+ shape[2 ] % innerMostKBlock != 0 )
270
262
continue ;
271
- }
272
263
MatmulConfig config{
273
264
MThreads, NThreads, KThreads,
274
265
MBlock, NBlock, KBlock,
@@ -293,14 +284,12 @@ bool validateConfig(const MatmulConfig &cfg) {
293
284
if (cfg.MThreads <= 0 || cfg.NThreads <= 0 || cfg.KThreads <= 0 ||
294
285
cfg.MBlock <= 0 || cfg.NBlock <= 0 || cfg.KBlock <= 0 ||
295
286
cfg.innerMostMBlock <= 0 || cfg.innerMostNBlock <= 0 ||
296
- cfg.innerMostKBlock <= 0 ) {
287
+ cfg.innerMostKBlock <= 0 )
297
288
return false ;
298
- }
299
289
if (cfg.MBlock % cfg.innerMostMBlock != 0 ||
300
290
cfg.NBlock % cfg.innerMostNBlock != 0 ||
301
- cfg.KBlock % cfg.innerMostKBlock != 0 ) {
291
+ cfg.KBlock % cfg.innerMostKBlock != 0 )
302
292
return false ;
303
- }
304
293
return true ;
305
294
}
306
295
@@ -371,19 +360,16 @@ MatmulConfigAnalysis::MatmulConfigAnalysis(Operation *root) {
371
360
uint32_t M = 1U , N = 1U , K = 1U ;
372
361
for (auto &&[s, dimType] :
373
362
llvm::zip (linalgOp.getShape (linalgOp.getDpsInputOperand (0 )),
374
- oprandDimType[0 ])) {
375
- if (dimType == DimType::M) {
363
+ oprandDimType[0 ]))
364
+ if (dimType == DimType::M)
376
365
M *= s;
377
- }
378
- }
379
366
for (auto &&[s, dimType] :
380
367
llvm::zip (linalgOp.getShape (linalgOp.getDpsInputOperand (1 )),
381
368
oprandDimType[1 ])) {
382
- if (dimType == DimType::N) {
369
+ if (dimType == DimType::N)
383
370
N *= s;
384
- } else if (dimType == DimType::K) {
371
+ else if (dimType == DimType::K)
385
372
K *= s;
386
- }
387
373
}
388
374
389
375
// innermost Block, if the layout is blockied layout, the innermost block
@@ -395,30 +381,30 @@ MatmulConfigAnalysis::MatmulConfigAnalysis(Operation *root) {
395
381
SmallVector<uint32_t > givenInnermostBlock;
396
382
if (MDimTypeIdx.size () > 1 ) {
397
383
config.innerMostMBlock = 1 ;
398
- for (size_t i = 1UL ; i < MDimTypeIdx. size (); i++) {
399
- config. innerMostMBlock *=
400
- linalgOp. getShape (linalgOp. getDpsInputOperand ( 0 ))[MDimTypeIdx[i]];
401
- }
384
+ for (auto &&[i, d] : llvm::enumerate (MDimTypeIdx))
385
+ if (i != 0 )
386
+ config. innerMostMBlock *=
387
+ linalgOp. getShape (linalgOp. getDpsInputOperand ( 0 ))[d];
402
388
givenInnermostBlock.push_back (config.innerMostMBlock );
403
389
} else {
404
390
givenInnermostBlock.push_back (0 );
405
391
}
406
392
if (NDimTypeIdx.size () > 1 ) {
407
393
config.innerMostNBlock = 1 ;
408
- for (size_t i = 1UL ; i < NDimTypeIdx. size (); i++) {
409
- config. innerMostNBlock *=
410
- linalgOp. getShape (linalgOp. getDpsInputOperand ( 1 ))[NDimTypeIdx[i]];
411
- }
394
+ for (auto &&[i, d] : llvm::enumerate (NDimTypeIdx))
395
+ if (i != 0 )
396
+ config. innerMostNBlock *=
397
+ linalgOp. getShape (linalgOp. getDpsInputOperand ( 1 ))[d];
412
398
givenInnermostBlock.push_back (config.innerMostNBlock );
413
399
} else {
414
400
givenInnermostBlock.push_back (0 );
415
401
}
416
402
if (KDimTypeIdx.size () > 1 ) {
417
403
config.innerMostKBlock = 1 ;
418
- for (size_t i = 1UL ; i < KDimTypeIdx. size (); i++) {
419
- config. innerMostKBlock *=
420
- linalgOp. getShape (linalgOp. getDpsInputOperand ( 1 ))[KDimTypeIdx[i]];
421
- }
404
+ for (auto &&[i, d] : llvm::enumerate (KDimTypeIdx))
405
+ if (i != 0 )
406
+ config. innerMostKBlock *=
407
+ linalgOp. getShape (linalgOp. getDpsInputOperand ( 1 ))[d];
422
408
givenInnermostBlock.push_back (config.innerMostKBlock );
423
409
} else {
424
410
givenInnermostBlock.push_back (0 );
@@ -444,13 +430,11 @@ MatmulConfigAnalysis::MatmulConfigAnalysis(Operation *root) {
444
430
SmallVector<uint32_t > shape = {M, N, K};
445
431
std::vector<MatmulConfig> configCandidates =
446
432
prepareConfigCandidates (root, sysDesc, shape, givenInnermostBlock);
447
- for (auto &&[fn, name, threshold] : costModelList) {
433
+ for (auto &&[fn, name, threshold] : costModelList)
448
434
configCandidates = filterConfigByCostModel (
449
435
configCandidates, linalgOp, shape, sysDesc, fn, 0.5 , threshold);
450
- }
451
- if (!configCandidates.empty ()) {
436
+ if (!configCandidates.empty ())
452
437
config = configCandidates[0 ];
453
- }
454
438
}
455
439
456
440
LLVM_DEBUG (llvm::dbgs ()
0 commit comments