Skip to content

Commit 974b8ca

Browse files
committed
fix comments
1 parent ed5180d commit 974b8ca

File tree

1 file changed

+30
-11
lines changed

1 file changed

+30
-11
lines changed

lib/gc/Analysis/MatmulConfigAnalysis.cpp

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,9 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &ss,
2929

3030
template <typename T>
3131
static llvm::raw_ostream &operator<<(llvm::raw_ostream &ss,
32-
std::vector<T> arry) {
32+
std::vector<T> array) {
3333
ss << "[";
34-
for (auto [idx, a] : llvm::enumerate(arry)) {
35-
if (idx != 0) {
36-
ss << ", ";
37-
}
38-
ss << a;
39-
}
34+
llvm::interleaveComma(array, ss);
4035
ss << "]";
4136
return ss;
4237
}
@@ -174,7 +169,7 @@ std::vector<MatmulConfig>
174169
filterConfigByCostModel(ArrayRef<MatmulConfig> configs,
175170
linalg::LinalgOp &linalgOp, ArrayRef<uint32_t> shape,
176171
SystemDesc &sysDesc, const CostModelFn &costModel,
177-
float eliminationRatio = 0.5, float threshold = -1) {
172+
float preserveRatio = 0.5, float threshold = -1) {
178173
std::vector<MatmulConfig> result;
179174
std::vector<float> costs;
180175
std::vector<size_t> idx;
@@ -185,8 +180,7 @@ filterConfigByCostModel(ArrayRef<MatmulConfig> configs,
185180
std::stable_sort(idx.begin(), idx.end(), [&costs](size_t i1, size_t i2) {
186181
return costs[i1] < costs[i2];
187182
});
188-
double thresholdCost =
189-
costs[idx[(size_t)(eliminationRatio * configs.size())]];
183+
double thresholdCost = costs[idx[(size_t)(preserveRatio * configs.size())]];
190184
thresholdCost =
191185
threshold < thresholdCost && threshold > 0 ? threshold : thresholdCost;
192186
for (size_t i = 0; i < configs.size(); i++) {
@@ -210,6 +204,11 @@ std::vector<MatmulConfig>
210204
prepareConfigCandidates(Operation *root, SystemDesc &sysDesc,
211205
ArrayRef<uint32_t> shape,
212206
ArrayRef<uint32_t> givenInnermostBlock) {
207+
if (shape.size() < 3) {
208+
LLVM_DEBUG(llvm::dbgs()
209+
<< "The shape is invalid, no candidate is generated\n");
210+
return {};
211+
}
213212
std::vector<MatmulConfig> configs;
214213
uint32_t threads = sysDesc.getNumThreads();
215214
std::vector<uint32_t> MThreadsCandidates =
@@ -290,6 +289,21 @@ prepareConfigCandidates(Operation *root, SystemDesc &sysDesc,
290289
return configs;
291290
}
292291

292+
bool validateConfig(const MatmulConfig &cfg) {
293+
if (cfg.MThreads <= 0 || cfg.NThreads <= 0 || cfg.KThreads <= 0 ||
294+
cfg.MBlock <= 0 || cfg.NBlock <= 0 || cfg.KBlock <= 0 ||
295+
cfg.innerMostMBlock <= 0 || cfg.innerMostNBlock <= 0 ||
296+
cfg.innerMostKBlock <= 0) {
297+
return false;
298+
}
299+
if (cfg.MBlock % cfg.innerMostMBlock != 0 ||
300+
cfg.NBlock % cfg.innerMostNBlock != 0 ||
301+
cfg.KBlock % cfg.innerMostKBlock != 0) {
302+
return false;
303+
}
304+
return true;
305+
}
306+
293307
// read the config from the attributes for tuning
294308
bool readConfigFromAttrs(MatmulConfig &config, ArrayRef<NamedAttribute> attrs) {
295309
size_t cfgItemCnt = 0;
@@ -323,7 +337,12 @@ bool readConfigFromAttrs(MatmulConfig &config, ArrayRef<NamedAttribute> attrs) {
323337
cfgItemCnt++;
324338
}
325339
}
326-
return cfgItemCnt == 9;
340+
if (validateConfig(config)) {
341+
return cfgItemCnt == 9;
342+
} else {
343+
LLVM_DEBUG(llvm::dbgs() << "The predefined config is invalid\n");
344+
return false;
345+
}
327346
}
328347

329348
// Analyze the workload and system description to generate the default config

0 commit comments

Comments
 (0)