7
7
// ===----------------------------------------------------------------------===//
8
8
9
9
#include " gc/Analysis/MatmulConfigAnalysis.h"
10
+ #include " gc/Analysis/TargetDescriptionAnalysis.h"
10
11
#include < limits>
11
12
#include < llvm/Support/Debug.h>
12
13
@@ -64,7 +65,8 @@ getCandidate(uint32_t num, uint32_t floor,
64
65
}
65
66
66
67
// check if the threads are valid
67
- bool validateThreads (ArrayRef<uint32_t > threads, SystemDesc &sysDesc) {
68
+ bool validateThreads (ArrayRef<uint32_t > threads,
69
+ CPUTargetDescriptionAnalysis &sysDesc) {
68
70
uint32_t numThreads = sysDesc.getNumThreads ();
69
71
uint32_t actualThreads = 1U ;
70
72
for (uint32_t t : threads)
@@ -77,24 +79,25 @@ bool validateThreads(ArrayRef<uint32_t> threads, SystemDesc &sysDesc) {
77
79
double vectorRegEfficiencyCost (linalg::LinalgOp &linalgOp,
78
80
ArrayRef<uint32_t > shape,
79
81
const MatmulConfig &config,
80
- SystemDesc &sysDesc) {
82
+ CPUTargetDescriptionAnalysis &sysDesc) {
81
83
size_t dtypeSize = DataLayout ().getTypeSizeInBits (
82
84
ShapeAdaptor (linalgOp.getDpsInputs ()[1 ].getType ()).getElementType ());
83
- size_t maxVectorLength = sysDesc.getMaxVectorLength () / dtypeSize;
85
+ size_t maxVectorWidth = sysDesc.getMaxVectorWidth () / dtypeSize;
84
86
// TODO: take matrix register like amx into account
85
- double cost = (maxVectorLength - config.innerMostMBlock % maxVectorLength ) %
86
- maxVectorLength * 1.0 / config.innerMostMBlock +
87
- (maxVectorLength - config.innerMostKBlock % maxVectorLength ) %
88
- maxVectorLength * 1.0 / config.innerMostKBlock +
89
- (maxVectorLength - config.innerMostNBlock % maxVectorLength ) %
90
- maxVectorLength * 1.0 / config.innerMostNBlock ;
87
+ double cost = (maxVectorWidth - config.innerMostMBlock % maxVectorWidth ) %
88
+ maxVectorWidth * 1.0 / config.innerMostMBlock +
89
+ (maxVectorWidth - config.innerMostKBlock % maxVectorWidth ) %
90
+ maxVectorWidth * 1.0 / config.innerMostKBlock +
91
+ (maxVectorWidth - config.innerMostNBlock % maxVectorWidth ) %
92
+ maxVectorWidth * 1.0 / config.innerMostNBlock ;
91
93
return cost;
92
94
}
93
95
94
96
// calculate the cost of the workload balance
95
97
double workloadBalancedCost (linalg::LinalgOp &linalgOp,
96
98
ArrayRef<uint32_t > shape,
97
- const MatmulConfig &config, SystemDesc &sysDesc) {
99
+ const MatmulConfig &config,
100
+ CPUTargetDescriptionAnalysis &sysDesc) {
98
101
if (shape.size () < 3 ) {
99
102
// Has an invalid shape
100
103
return 0 ;
@@ -118,7 +121,7 @@ double workloadBalancedCost(linalg::LinalgOp &linalgOp,
118
121
double memoryConsumptionOnThreadCost (linalg::LinalgOp &linalgOp,
119
122
ArrayRef<uint32_t > shape,
120
123
const MatmulConfig &config,
121
- SystemDesc &sysDesc) {
124
+ CPUTargetDescriptionAnalysis &sysDesc) {
122
125
if (shape.size () < 3 ) {
123
126
// Has an invalid shape
124
127
return 0 ;
@@ -141,7 +144,7 @@ double memoryConsumptionOnThreadCost(linalg::LinalgOp &linalgOp,
141
144
double computationIntensityOnL2Cache (linalg::LinalgOp &linalgOp,
142
145
ArrayRef<uint32_t > shape,
143
146
const MatmulConfig &config,
144
- SystemDesc &sysDesc) {
147
+ CPUTargetDescriptionAnalysis &sysDesc) {
145
148
double fullLoadRatio = 0.7 ;
146
149
uint32_t L2Cache = sysDesc.getCacheSize (2 );
147
150
size_t dtypeSize = DataLayout ().getTypeSize (
@@ -157,16 +160,17 @@ double computationIntensityOnL2Cache(linalg::LinalgOp &linalgOp,
157
160
return 1 / computationIntensity;
158
161
}
159
162
160
- using CostModelFn =
161
- std::function< double ( linalg::LinalgOp &linalgOp, ArrayRef<uint32_t > shape,
162
- MatmulConfig cfg, SystemDesc &sysDesc)>;
163
+ using CostModelFn = std::function< double (
164
+ linalg::LinalgOp &linalgOp, ArrayRef<uint32_t > shape, MatmulConfig cfg ,
165
+ CPUTargetDescriptionAnalysis &sysDesc)>;
163
166
164
167
// filter the config by the cost model
165
168
std::vector<MatmulConfig>
166
169
filterConfigByCostModel (ArrayRef<MatmulConfig> configs,
167
170
linalg::LinalgOp &linalgOp, ArrayRef<uint32_t > shape,
168
- SystemDesc &sysDesc, const CostModelFn &costModel,
169
- float preserveRatio = 0.5 , float threshold = -1 ) {
171
+ CPUTargetDescriptionAnalysis &sysDesc,
172
+ const CostModelFn &costModel, float preserveRatio = 0.5 ,
173
+ float threshold = -1 ) {
170
174
std::vector<MatmulConfig> result;
171
175
std::vector<float > costs;
172
176
std::vector<size_t > idx;
@@ -196,7 +200,7 @@ filterConfigByCostModel(ArrayRef<MatmulConfig> configs,
196
200
197
201
// prepare the config candidates
198
202
std::vector<MatmulConfig>
199
- prepareConfigCandidates (Operation *root, SystemDesc &sysDesc,
203
+ prepareConfigCandidates (Operation *root, CPUTargetDescriptionAnalysis &sysDesc,
200
204
ArrayRef<uint32_t > shape,
201
205
ArrayRef<uint32_t > givenInnermostBlock) {
202
206
if (shape.size () < 3 ) {
@@ -347,7 +351,7 @@ bool readConfigFromAttrs(MatmulConfig &config, ArrayRef<NamedAttribute> attrs) {
347
351
// previous matmul
348
352
MatmulConfigAnalysis::MatmulConfigAnalysis (Operation *root) {
349
353
if (auto linalgOp = dyn_cast<linalg::LinalgOp>(root)) {
350
- SystemDesc sysDesc (root-> getParentOfType <ModuleOp>() );
354
+ CPUTargetDescriptionAnalysis sysDesc (root);
351
355
SmallVector<SmallVector<DimType>> oprandDimType =
352
356
*getOprandDimType (linalgOp);
353
357
// get the origin M,N,K size
0 commit comments