Skip to content

Commit 21b4dfe

Browse files
committed
format code
1 parent 8e5d071 commit 21b4dfe

File tree

5 files changed

+79
-125
lines changed

5 files changed

+79
-125
lines changed

lib/gc/Analysis/CMakeLists.txt

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ gc_set_mlir_link_components(MLIR_LINK_COMPONENTS
22
MLIRIR
33
MLIRSupport)
44

5-
add_mlir_library(GCAnalysis
5+
gc_add_mlir_library(GCAnalysis
66
MatmulConfigAnalysis.cpp
77

88
ADDITIONAL_HEADER_DIRS
@@ -14,6 +14,5 @@ add_mlir_library(GCAnalysis
1414
LINK_LIBS PUBLIC
1515
${mlir_dialect_libs}
1616
${MLIR_LINK_COMPONENTS}
17-
)
18-
19-
set_property(GLOBAL APPEND PROPERTY GC_PASS_LIBS GCAnalysis)
17+
GcInterface
18+
)

lib/gc/Analysis/MatmulConfigAnalysis.cpp

Lines changed: 33 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,10 @@ getCandidate(uint32_t num, uint32_t floor,
4444
// factor
4545
std::vector<uint32_t> candidates;
4646
uint32_t upperbound = std::min(num, ceil);
47-
for (uint32_t i = floor; i <= upperbound; i++) {
48-
if (num % i == 0) {
47+
for (uint32_t i = floor; i <= upperbound; i++)
48+
if (num % i == 0)
4949
candidates.push_back(i);
50-
}
51-
}
50+
5251
// the pow of 2
5352
uint32_t candidate = 1U;
5453
while (candidate < floor)
@@ -68,9 +67,8 @@ getCandidate(uint32_t num, uint32_t floor,
6867
bool validateThreads(ArrayRef<uint32_t> threads, SystemDesc &sysDesc) {
6968
uint32_t numThreads = sysDesc.getNumThreads();
7069
uint32_t actualThreads = 1U;
71-
for (uint32_t t : threads) {
70+
for (uint32_t t : threads)
7271
actualThreads *= t;
73-
}
7472
return actualThreads == numThreads;
7573
}
7674

@@ -154,9 +152,8 @@ double computationIntensityOnL2Cache(linalg::LinalgOp &linalgOp,
154152
config.NBlock * config.KBlock +
155153
config.MBlock * config.KBlock;
156154
double computationIntensity = FLOPS / memoryConsumption;
157-
if (memoryConsumption * dtypeSize > L2Cache * fullLoadRatio) {
155+
if (memoryConsumption * dtypeSize > L2Cache * fullLoadRatio)
158156
computationIntensity /= outOfCachePenalty;
159-
}
160157
return 1 / computationIntensity;
161158
}
162159

@@ -183,19 +180,17 @@ filterConfigByCostModel(ArrayRef<MatmulConfig> configs,
183180
double thresholdCost = costs[idx[(size_t)(preserveRatio * configs.size())]];
184181
thresholdCost =
185182
threshold < thresholdCost && threshold > 0 ? threshold : thresholdCost;
186-
for (const auto &i : idx) {
187-
if (costs[i] <= thresholdCost) {
183+
for (const auto &i : idx)
184+
if (costs[i] <= thresholdCost)
188185
result.push_back(configs[i]);
189-
}
190-
}
186+
191187
LLVM_DEBUG(llvm::dbgs() << "thresholdCost is: " << thresholdCost
192188
<< "\nbest with cost: " << costs[idx[0]] << "\n"
193189
<< configs[idx[0]] << "\n worst with cost: "
194190
<< costs[idx[configs.size() - 1]] << "\n"
195191
<< configs[idx[configs.size() - 1]] << "\n");
196-
if (result.empty()) {
192+
if (result.empty())
197193
result = configs;
198-
}
199194
return result;
200195
}
201196

@@ -248,27 +243,23 @@ prepareConfigCandidates(Operation *root, SystemDesc &sysDesc,
248243
for (uint32_t MThreads : MThreadsCandidates) {
249244
for (uint32_t NThreads : NThreadsCandidates) {
250245
for (uint32_t KThreads : KThreadsCandidates) {
251-
if (!validateThreads({MThreads, NThreads, KThreads}, sysDesc)) {
246+
if (!validateThreads({MThreads, NThreads, KThreads}, sysDesc))
252247
continue;
253-
}
254248
for (uint32_t MBlock : MBlockCandidates) {
255249
for (uint32_t innerMostMBlock : innerMostMBlockCandidates) {
256250
if (MBlock % innerMostMBlock != 0 ||
257-
shape[0] % innerMostMBlock != 0) {
251+
shape[0] % innerMostMBlock != 0)
258252
continue;
259-
}
260253
for (uint32_t NBlock : NBlockCandidates) {
261254
for (uint32_t innerMostNBlock : innerMostNBlockCandidates) {
262255
if (NBlock % innerMostNBlock != 0 ||
263-
shape[1] % innerMostNBlock != 0) {
256+
shape[1] % innerMostNBlock != 0)
264257
continue;
265-
}
266258
for (uint32_t KBlock : KBlockCandidates) {
267259
for (uint32_t innerMostKBlock : innerMostKBlockCandidates) {
268260
if (KBlock % innerMostKBlock != 0 ||
269-
shape[2] % innerMostKBlock != 0) {
261+
shape[2] % innerMostKBlock != 0)
270262
continue;
271-
}
272263
MatmulConfig config{
273264
MThreads, NThreads, KThreads,
274265
MBlock, NBlock, KBlock,
@@ -293,14 +284,12 @@ bool validateConfig(const MatmulConfig &cfg) {
293284
if (cfg.MThreads <= 0 || cfg.NThreads <= 0 || cfg.KThreads <= 0 ||
294285
cfg.MBlock <= 0 || cfg.NBlock <= 0 || cfg.KBlock <= 0 ||
295286
cfg.innerMostMBlock <= 0 || cfg.innerMostNBlock <= 0 ||
296-
cfg.innerMostKBlock <= 0) {
287+
cfg.innerMostKBlock <= 0)
297288
return false;
298-
}
299289
if (cfg.MBlock % cfg.innerMostMBlock != 0 ||
300290
cfg.NBlock % cfg.innerMostNBlock != 0 ||
301-
cfg.KBlock % cfg.innerMostKBlock != 0) {
291+
cfg.KBlock % cfg.innerMostKBlock != 0)
302292
return false;
303-
}
304293
return true;
305294
}
306295

@@ -371,19 +360,16 @@ MatmulConfigAnalysis::MatmulConfigAnalysis(Operation *root) {
371360
uint32_t M = 1U, N = 1U, K = 1U;
372361
for (auto &&[s, dimType] :
373362
llvm::zip(linalgOp.getShape(linalgOp.getDpsInputOperand(0)),
374-
oprandDimType[0])) {
375-
if (dimType == DimType::M) {
363+
oprandDimType[0]))
364+
if (dimType == DimType::M)
376365
M *= s;
377-
}
378-
}
379366
for (auto &&[s, dimType] :
380367
llvm::zip(linalgOp.getShape(linalgOp.getDpsInputOperand(1)),
381368
oprandDimType[1])) {
382-
if (dimType == DimType::N) {
369+
if (dimType == DimType::N)
383370
N *= s;
384-
} else if (dimType == DimType::K) {
371+
else if (dimType == DimType::K)
385372
K *= s;
386-
}
387373
}
388374

389375
// innermost Block, if the layout is blockied layout, the innermost block
@@ -395,30 +381,30 @@ MatmulConfigAnalysis::MatmulConfigAnalysis(Operation *root) {
395381
SmallVector<uint32_t> givenInnermostBlock;
396382
if (MDimTypeIdx.size() > 1) {
397383
config.innerMostMBlock = 1;
398-
for (size_t i = 1UL; i < MDimTypeIdx.size(); i++) {
399-
config.innerMostMBlock *=
400-
linalgOp.getShape(linalgOp.getDpsInputOperand(0))[MDimTypeIdx[i]];
401-
}
384+
for (auto &&[i, d] : llvm::enumerate(MDimTypeIdx))
385+
if (i != 0)
386+
config.innerMostMBlock *=
387+
linalgOp.getShape(linalgOp.getDpsInputOperand(0))[d];
402388
givenInnermostBlock.push_back(config.innerMostMBlock);
403389
} else {
404390
givenInnermostBlock.push_back(0);
405391
}
406392
if (NDimTypeIdx.size() > 1) {
407393
config.innerMostNBlock = 1;
408-
for (size_t i = 1UL; i < NDimTypeIdx.size(); i++) {
409-
config.innerMostNBlock *=
410-
linalgOp.getShape(linalgOp.getDpsInputOperand(1))[NDimTypeIdx[i]];
411-
}
394+
for (auto &&[i, d] : llvm::enumerate(NDimTypeIdx))
395+
if (i != 0)
396+
config.innerMostNBlock *=
397+
linalgOp.getShape(linalgOp.getDpsInputOperand(1))[d];
412398
givenInnermostBlock.push_back(config.innerMostNBlock);
413399
} else {
414400
givenInnermostBlock.push_back(0);
415401
}
416402
if (KDimTypeIdx.size() > 1) {
417403
config.innerMostKBlock = 1;
418-
for (size_t i = 1UL; i < KDimTypeIdx.size(); i++) {
419-
config.innerMostKBlock *=
420-
linalgOp.getShape(linalgOp.getDpsInputOperand(1))[KDimTypeIdx[i]];
421-
}
404+
for (auto &&[i, d] : llvm::enumerate(KDimTypeIdx))
405+
if (i != 0)
406+
config.innerMostKBlock *=
407+
linalgOp.getShape(linalgOp.getDpsInputOperand(1))[d];
422408
givenInnermostBlock.push_back(config.innerMostKBlock);
423409
} else {
424410
givenInnermostBlock.push_back(0);
@@ -444,13 +430,11 @@ MatmulConfigAnalysis::MatmulConfigAnalysis(Operation *root) {
444430
SmallVector<uint32_t> shape = {M, N, K};
445431
std::vector<MatmulConfig> configCandidates =
446432
prepareConfigCandidates(root, sysDesc, shape, givenInnermostBlock);
447-
for (auto &&[fn, name, threshold] : costModelList) {
433+
for (auto &&[fn, name, threshold] : costModelList)
448434
configCandidates = filterConfigByCostModel(
449435
configCandidates, linalgOp, shape, sysDesc, fn, 0.5, threshold);
450-
}
451-
if (!configCandidates.empty()) {
436+
if (!configCandidates.empty())
452437
config = configCandidates[0];
453-
}
454438
}
455439

456440
LLVM_DEBUG(llvm::dbgs()

0 commit comments

Comments
 (0)