Commit 51527c0

replace sysDesc with target info
1 parent efc2d86 commit 51527c0

3 files changed: +24 -75 lines changed


include/gc/Analysis/MatmulConfigAnalysis.h

Lines changed: 0 additions & 56 deletions
@@ -19,62 +19,6 @@ namespace gc {
 
 using namespace mlir;
 
-struct SystemDesc {
-  // get runtime OMP_NUM_THREADS
-  uint32_t getNumThreads() {
-    std::optional<Attribute> numThreads = layout.getDevicePropertyValue(
-        Builder(ctx).getStringAttr("CPU" /* device ID*/),
-        Builder(ctx).getStringAttr("num_threads"));
-    if (numThreads && isa<IntegerAttr>(*numThreads)) {
-      return dyn_cast<IntegerAttr>(*numThreads).getInt();
-    }
-    return 1;
-  }
-  // get cache size by cacheLevel
-  size_t getCacheSize(uint8_t cacheLevel) {
-    if (cacheLevel == 1) {
-      std::optional<Attribute> cacheSize = layout.getDevicePropertyValue(
-          Builder(ctx).getStringAttr("CPU" /* device ID*/),
-          Builder(ctx).getStringAttr("L1_cache_size_in_bytes"));
-      if (cacheSize && isa<IntegerAttr>(*cacheSize)) {
-        return dyn_cast<IntegerAttr>(*cacheSize).getInt();
-      }
-    } else if (cacheLevel == 2) {
-      std::optional<Attribute> cacheSize = layout.getDevicePropertyValue(
-          Builder(ctx).getStringAttr("CPU" /* device ID*/),
-          Builder(ctx).getStringAttr("L2_cache_size_in_bytes"));
-      if (cacheSize && isa<IntegerAttr>(*cacheSize)) {
-        return dyn_cast<IntegerAttr>(*cacheSize).getInt();
-      }
-    } else if (cacheLevel == 3) {
-      std::optional<Attribute> cacheSize = layout.getDevicePropertyValue(
-          Builder(ctx).getStringAttr("CPU" /* device ID*/),
-          Builder(ctx).getStringAttr("L3_cache_size_in_bytes"));
-      if (cacheSize && isa<IntegerAttr>(*cacheSize)) {
-        return dyn_cast<IntegerAttr>(*cacheSize).getInt();
-      }
-    }
-    return 0;
-  }
-
-  // get the maximum vector length in bits
-  size_t getMaxVectorLength() {
-    std::optional<Attribute> maxVectorLength = layout.getDevicePropertyValue(
-        Builder(ctx).getStringAttr("CPU" /* device ID*/),
-        Builder(ctx).getStringAttr("max_vector_width"));
-    if (maxVectorLength && isa<IntegerAttr>(*maxVectorLength)) {
-      return dyn_cast<IntegerAttr>(*maxVectorLength).getInt();
-    }
-    return 512;
-  }
-
-  SystemDesc(ModuleOp m) : layout(m), ctx(m->getContext()) {}
-
-private:
-  DataLayout layout;
-  MLIRContext *ctx;
-};
-
 // The configuration for matmul tiling
 // TODO: support batch matmul
 struct MatmulConfig {

lib/gc/Analysis/MatmulConfigAnalysis.cpp

Lines changed: 23 additions & 19 deletions
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "gc/Analysis/MatmulConfigAnalysis.h"
+#include "gc/Analysis/TargetDescriptionAnalysis.h"
 #include <limits>
 #include <llvm/Support/Debug.h>
 
@@ -64,7 +65,8 @@ getCandidate(uint32_t num, uint32_t floor,
 }
 
 // check if the threads are valid
-bool validateThreads(ArrayRef<uint32_t> threads, SystemDesc &sysDesc) {
+bool validateThreads(ArrayRef<uint32_t> threads,
+                     CPUTargetDescriptionAnalysis &sysDesc) {
   uint32_t numThreads = sysDesc.getNumThreads();
   uint32_t actualThreads = 1U;
   for (uint32_t t : threads)
@@ -77,24 +79,25 @@ bool validateThreads(ArrayRef<uint32_t> threads, SystemDesc &sysDesc) {
 double vectorRegEfficiencyCost(linalg::LinalgOp &linalgOp,
                                ArrayRef<uint32_t> shape,
                                const MatmulConfig &config,
-                               SystemDesc &sysDesc) {
+                               CPUTargetDescriptionAnalysis &sysDesc) {
   size_t dtypeSize = DataLayout().getTypeSizeInBits(
       ShapeAdaptor(linalgOp.getDpsInputs()[1].getType()).getElementType());
-  size_t maxVectorLength = sysDesc.getMaxVectorLength() / dtypeSize;
+  size_t maxVectorWidth = sysDesc.getMaxVectorWidth() / dtypeSize;
   // TODO: take matrix register like amx into account
-  double cost = (maxVectorLength - config.innerMostMBlock % maxVectorLength) %
-                    maxVectorLength * 1.0 / config.innerMostMBlock +
-                (maxVectorLength - config.innerMostKBlock % maxVectorLength) %
-                    maxVectorLength * 1.0 / config.innerMostKBlock +
-                (maxVectorLength - config.innerMostNBlock % maxVectorLength) %
-                    maxVectorLength * 1.0 / config.innerMostNBlock;
+  double cost = (maxVectorWidth - config.innerMostMBlock % maxVectorWidth) %
+                    maxVectorWidth * 1.0 / config.innerMostMBlock +
+                (maxVectorWidth - config.innerMostKBlock % maxVectorWidth) %
+                    maxVectorWidth * 1.0 / config.innerMostKBlock +
+                (maxVectorWidth - config.innerMostNBlock % maxVectorWidth) %
+                    maxVectorWidth * 1.0 / config.innerMostNBlock;
   return cost;
 }
 
 // calculate the cost of the workload balance
 double workloadBalancedCost(linalg::LinalgOp &linalgOp,
                             ArrayRef<uint32_t> shape,
-                            const MatmulConfig &config, SystemDesc &sysDesc) {
+                            const MatmulConfig &config,
+                            CPUTargetDescriptionAnalysis &sysDesc) {
   if (shape.size() < 3) {
     // Has an invalid shape
     return 0;
@@ -118,7 +121,7 @@ double workloadBalancedCost(linalg::LinalgOp &linalgOp,
 double memoryConsumptionOnThreadCost(linalg::LinalgOp &linalgOp,
                                      ArrayRef<uint32_t> shape,
                                      const MatmulConfig &config,
-                                     SystemDesc &sysDesc) {
+                                     CPUTargetDescriptionAnalysis &sysDesc) {
   if (shape.size() < 3) {
     // Has an invalid shape
     return 0;
@@ -141,7 +144,7 @@ double memoryConsumptionOnThreadCost(linalg::LinalgOp &linalgOp,
 double computationIntensityOnL2Cache(linalg::LinalgOp &linalgOp,
                                      ArrayRef<uint32_t> shape,
                                      const MatmulConfig &config,
-                                     SystemDesc &sysDesc) {
+                                     CPUTargetDescriptionAnalysis &sysDesc) {
   double fullLoadRatio = 0.7;
   uint32_t L2Cache = sysDesc.getCacheSize(2);
   size_t dtypeSize = DataLayout().getTypeSize(
@@ -157,16 +160,17 @@ double computationIntensityOnL2Cache(linalg::LinalgOp &linalgOp,
   return 1 / computationIntensity;
 }
 
-using CostModelFn =
-    std::function<double(linalg::LinalgOp &linalgOp, ArrayRef<uint32_t> shape,
-                         MatmulConfig cfg, SystemDesc &sysDesc)>;
+using CostModelFn = std::function<double(
+    linalg::LinalgOp &linalgOp, ArrayRef<uint32_t> shape, MatmulConfig cfg,
+    CPUTargetDescriptionAnalysis &sysDesc)>;
 
 // filter the config by the cost model
 std::vector<MatmulConfig>
 filterConfigByCostModel(ArrayRef<MatmulConfig> configs,
                         linalg::LinalgOp &linalgOp, ArrayRef<uint32_t> shape,
-                        SystemDesc &sysDesc, const CostModelFn &costModel,
-                        float preserveRatio = 0.5, float threshold = -1) {
+                        CPUTargetDescriptionAnalysis &sysDesc,
+                        const CostModelFn &costModel, float preserveRatio = 0.5,
+                        float threshold = -1) {
   std::vector<MatmulConfig> result;
   std::vector<float> costs;
   std::vector<size_t> idx;
@@ -196,7 +200,7 @@ filterConfigByCostModel(ArrayRef<MatmulConfig> configs,
 
 // prepare the config candidates
 std::vector<MatmulConfig>
-prepareConfigCandidates(Operation *root, SystemDesc &sysDesc,
+prepareConfigCandidates(Operation *root, CPUTargetDescriptionAnalysis &sysDesc,
                         ArrayRef<uint32_t> shape,
                         ArrayRef<uint32_t> givenInnermostBlock) {
   if (shape.size() < 3) {
@@ -347,7 +351,7 @@ bool readConfigFromAttrs(MatmulConfig &config, ArrayRef<NamedAttribute> attrs) {
 // previous matmul
 MatmulConfigAnalysis::MatmulConfigAnalysis(Operation *root) {
   if (auto linalgOp = dyn_cast<linalg::LinalgOp>(root)) {
-    SystemDesc sysDesc(root->getParentOfType<ModuleOp>());
+    CPUTargetDescriptionAnalysis sysDesc(root);
     SmallVector<SmallVector<DimType>> oprandDimType =
         *getOprandDimType(linalgOp);
     // get the origin M,N,K size
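
For orientation, a minimal usage sketch (not part of this commit) of the analysis that replaces SystemDesc, using only the accessors exercised in the diff above (getNumThreads, getMaxVectorWidth, getCacheSize); the helper name reportTargetInfo, the extra includes, and the namespace qualification are assumptions:

// Minimal sketch, not part of this commit. Only the accessors shown in the
// diff above are used; the helper name and namespace qualification are
// assumptions.
#include "gc/Analysis/TargetDescriptionAnalysis.h"

#include "mlir/IR/Operation.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;
using namespace mlir::gc; // assumed namespace of CPUTargetDescriptionAnalysis

static void reportTargetInfo(Operation *root) {
  // The analysis is constructed directly from the operation, as
  // MatmulConfigAnalysis now does, instead of from
  // root->getParentOfType<ModuleOp>() as the removed SystemDesc was.
  CPUTargetDescriptionAnalysis sysDesc(root);
  llvm::outs() << "num_threads: " << sysDesc.getNumThreads() << "\n"
               << "max_vector_width: " << sysDesc.getMaxVectorWidth() << "\n"
               << "L2_cache_size_in_bytes: " << sysDesc.getCacheSize(2) << "\n";
}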

test/mlir/unittests/Analysis/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -3,5 +3,6 @@ add_mlir_unittest(GCAnalysisTests
 )
 target_link_libraries(GCAnalysisTests
   PRIVATE
+  GcPasses
   GcAnalysis
   GcJitWrapper)
