Skip to content

Commit f959a73

Browse files
committed
support dlti
1 parent d672629 commit f959a73

File tree

5 files changed

+103
-31
lines changed

5 files changed

+103
-31
lines changed

include/gc/Analysis/MatmulConfigAnalysis.h

Lines changed: 47 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -10,64 +10,86 @@
1010
#define MLIR_ANALYSIS_MATMULCONFIGANALYSIS_H
1111

1212
#include "gc/Dialect/Linalgx/LinalgxOps.h"
13+
#include "mlir/Dialect/DLTI/DLTI.h"
1314
#include "mlir/Dialect/Linalg/IR/Linalg.h"
14-
#include <cstring>
15+
#include "mlir/Interfaces/DataLayoutInterfaces.h"
1516

1617
namespace mlir {
1718
namespace gc {
1819

1920
using namespace mlir;
2021

21-
// A mock for the target information
22-
// TODO: replace it with upstream hardware description model
2322
struct SystemDesc {
24-
25-
static int getPositiveIntFromStr(char *str, int defaultValue = 1) {
26-
if (!str || strlen(str) == 0 || str[0] > '9' || str[0] < '0') {
27-
return defaultValue;
28-
}
29-
auto val = std::stoi(str);
30-
return val > 0 ? val : defaultValue;
31-
}
32-
3323
// get runtime OMP_NUM_THREADS
3424
uint32_t getNumThreads() {
35-
char *numThreads = getenv("OMP_NUM_THREADS");
36-
return getPositiveIntFromStr(numThreads, 1);
25+
std::optional<Attribute> numThreads = layout.getDevicePropertyValue(
26+
Builder(ctx).getStringAttr("CPU" /* device ID*/),
27+
Builder(ctx).getStringAttr("num_threads"));
28+
if (numThreads && isa<IntegerAttr>(*numThreads))
29+
{
30+
return dyn_cast<IntegerAttr>(*numThreads).getInt();
31+
}
32+
return 1;
3733
}
3834
// get cache size by cacheLevel
3935
size_t getCacheSize(uint8_t cacheLevel) {
4036
if (cacheLevel == 1) {
41-
char *cacheSize = getenv("L1_CACHE_SIZE");
42-
return getPositiveIntFromStr(cacheSize, 0);
37+
std::optional<Attribute> cacheSize = layout.getDevicePropertyValue(
38+
Builder(ctx).getStringAttr("CPU" /* device ID*/),
39+
Builder(ctx).getStringAttr("L1_cache_size_in_bytes"));
40+
if (cacheSize && isa<IntegerAttr>(*cacheSize))
41+
{
42+
return dyn_cast<IntegerAttr>(*cacheSize).getInt();
43+
}
4344
} else if (cacheLevel == 2) {
44-
char *cacheSize = getenv("L2_CACHE_SIZE");
45-
return getPositiveIntFromStr(cacheSize, 0);
45+
std::optional<Attribute> cacheSize = layout.getDevicePropertyValue(
46+
Builder(ctx).getStringAttr("CPU" /* device ID*/),
47+
Builder(ctx).getStringAttr("L2_cache_size_in_bytes"));
48+
if (cacheSize && isa<IntegerAttr>(*cacheSize))
49+
{
50+
return dyn_cast<IntegerAttr>(*cacheSize).getInt();
51+
}
4652
} else if (cacheLevel == 3) {
47-
char *cacheSize = getenv("L3_CACHE_SIZE");
48-
return getPositiveIntFromStr(cacheSize, 0);
53+
std::optional<Attribute> cacheSize = layout.getDevicePropertyValue(
54+
Builder(ctx).getStringAttr("CPU" /* device ID*/),
55+
Builder(ctx).getStringAttr("L3_cache_size_in_bytes"));
56+
if (cacheSize && isa<IntegerAttr>(*cacheSize)) {
57+
return dyn_cast<IntegerAttr>(*cacheSize).getInt();
58+
}
4959
}
5060
return 0;
5161
}
5262

5363
// get the maximum vector length in bits
5464
size_t getMaxVectorLength() {
55-
char *maxVectorLanes = getenv("MAX_VECTOR_LENGTH");
56-
return getPositiveIntFromStr(maxVectorLanes, 512);
65+
std::optional<Attribute> maxVectorLength = layout.getDevicePropertyValue(
66+
Builder(ctx).getStringAttr("CPU" /* device ID*/),
67+
Builder(ctx).getStringAttr("max_vector_width"));
68+
if (maxVectorLength && isa<IntegerAttr>(*maxVectorLength))
69+
{
70+
return dyn_cast<IntegerAttr>(*maxVectorLength).getInt();
71+
}
72+
return 512;
5773
}
74+
75+
SystemDesc(ModuleOp m) : layout(m), ctx(m->getContext()) {}
76+
77+
private:
78+
DataLayout layout;
79+
MLIRContext *ctx;
5880
};
5981

6082
// The configuration for matmul tiling
6183
// TODO: support batch matmul
6284
struct MatmulConfig {
6385
// The number of threads distributed to M, N, K
6486
uint32_t MThreads, NThreads, KThreads;
65-
// The innermost block size for M, N, K which will be directly converted to
66-
// brgemm.
67-
uint32_t innerMostMBlock, innerMostNBlock, innerMostKBlock;
6887
// The outer block size for M, N, K which will be used to decide the loop tile
6988
// size in single thread
7089
uint32_t MBlock, NBlock, KBlock;
90+
// The innermost block size for M, N, K which will be directly converted to
91+
// brgemm.
92+
uint32_t innerMostMBlock, innerMostNBlock, innerMostKBlock;
7193
};
7294

7395
enum DimType { Batch, M, N, K };

lib/gc/Analysis/MatmulConfigAnalysis.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ double vectorRegEfficiencyCost(linalg::LinalgOp &linalgOp,
8888
size_t dtypeSize = DataLayout().getTypeSizeInBits(
8989
ShapeAdaptor(linalgOp.getDpsInputs()[1].getType()).getElementType());
9090
size_t maxVectorLength = sysDesc.getMaxVectorLength() / dtypeSize;
91+
// TODO: take matrix register like amx into account
9192
double cost = (maxVectorLength - config.innerMostMBlock % maxVectorLength) %
9293
maxVectorLength * 1.0 / config.innerMostMBlock +
9394
(maxVectorLength - config.innerMostKBlock % maxVectorLength) %
@@ -270,8 +271,8 @@ prepareConfigCandidates(Operation *root, SystemDesc &sysDesc,
270271
continue;
271272
}
272273
MatmulConfig config{
273-
MBlock, NBlock, KBlock,
274274
MThreads, NThreads, KThreads,
275+
MBlock, NBlock, KBlock,
275276
innerMostMBlock, innerMostNBlock, innerMostKBlock};
276277
configs.push_back(config);
277278
}
@@ -311,13 +312,13 @@ bool readConfigFromAttrs(MatmulConfig &config, ArrayRef<NamedAttribute> attrs) {
311312
} else if (attr.getName() == "MThreads") {
312313
config.MThreads = cast<IntegerAttr>(attr.getValue()).getInt();
313314
cfgItemCnt++;
314-
} else if (attr.getName() == "innerMostMBlock") {
315+
} else if (attr.getName() == "innermostMBlock") {
315316
config.innerMostMBlock = cast<IntegerAttr>(attr.getValue()).getInt();
316317
cfgItemCnt++;
317-
} else if (attr.getName() == "innerMostNBlock") {
318+
} else if (attr.getName() == "innermostNBlock") {
318319
config.innerMostNBlock = cast<IntegerAttr>(attr.getValue()).getInt();
319320
cfgItemCnt++;
320-
} else if (attr.getName() == "innerMostKBlock") {
321+
} else if (attr.getName() == "innermostKBlock") {
321322
config.innerMostKBlock = cast<IntegerAttr>(attr.getValue()).getInt();
322323
cfgItemCnt++;
323324
}
@@ -338,7 +339,7 @@ bool readConfigFromAttrs(MatmulConfig &config, ArrayRef<NamedAttribute> attrs) {
338339
// previous matmul
339340
MatmulConfigAnalysis::MatmulConfigAnalysis(Operation *root) {
340341
if (auto linalgOp = dyn_cast<linalg::LinalgOp>(root)) {
341-
SystemDesc sysDesc;
342+
SystemDesc sysDesc(root->getParentOfType<ModuleOp>());
342343
SmallVector<SmallVector<DimType>> oprandDimType =
343344
*getOprandDimType(linalgOp);
344345
// get the origin M,N,K size

lib/gc/Transforms/DeepTileContractionNamedOp.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ static Operation *findParentFillOp(Value val) {
243243
llvm::find(skipOpList, currentOp->getName().getStringRef()) !=
244244
skipOpList.end() &&
245245
!isa<linalg::FillOp>(currentOp)) {
246-
currentOp = currentOp->getResult(0).getDefiningOp();
246+
currentOp = currentOp->getOperand(0).getDefiningOp();
247247
}
248248
if (currentOp && isa<linalg::FillOp>(currentOp)) {
249249
return currentOp;

lib/gc/Transforms/TilingUtil.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
namespace mlir {
1717
namespace linalgX {
1818

19+
// An enhancement for the upstream pass to support tiling reduction for MKmk
20+
// like cases (with multiple reduction iterators).
1921
FailureOr<linalg::ForallReductionTilingResult> tileReductionUsingForall(
2022
RewriterBase &b, PartialReductionOpInterface op,
2123
ArrayRef<OpFoldResult> threadNums, ArrayRef<OpFoldResult> tileSizes,

test/gc/Transform/deepTileContractionNamedOp.mlir

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,3 +108,50 @@ func.func @matmul_2Dx4D_bf16(%arg0: tensor<4096x4096xbf16>, %arg1: tensor<128x12
108108
return %2 : tensor<4096x4096xbf16>
109109
}
110110

111+
// -----
112+
113+
module attributes {
114+
dlti.target_system_spec = #dlti.target_system_spec<
115+
"CPU": #dlti.target_device_spec<
116+
#dlti.dl_entry<"L1_cache_size_in_bytes", 49152 : i32>,
117+
#dlti.dl_entry<"L2_cache_size_in_bytes", 2097152 : i32>,
118+
#dlti.dl_entry<"L3_cache_size_in_bytes", 110100480 : i32>,
119+
#dlti.dl_entry<"num_threads", 56 : i32>,
120+
#dlti.dl_entry<"max_vector_width", 512 : i32>>
121+
>} {
122+
/// CHECK-LABEL: @matmul_2Dx4D_bf16_with_dlti
123+
func.func @matmul_2Dx4D_bf16_with_dlti(%arg0: tensor<4096x4096xbf16>, %arg1: tensor<128x128x16x32x2xbf16>) -> tensor<4096x4096xbf16> {
124+
%cst_0 = arith.constant 0.000000e+00 : bf16
125+
%0 = tensor.empty() : tensor<4096x4096xbf16>
126+
%1 = linalg.fill ins(%cst_0 : bf16) outs(%0 : tensor<4096x4096xbf16>) -> tensor<4096x4096xbf16>
127+
// CHECK: scf.forall
128+
// CHECK: tensor.extract_slice
129+
// CHECK: scf.forall
130+
// CHECK: tensor.extract_slice
131+
// CHECK: scf.forall
132+
// CHECK: tensor.extract_slice
133+
// CHECK: scf.for
134+
// CHECK: tensor.extract_slice
135+
// CHECK: scf.for
136+
// CHECK: scf.for
137+
// CHECK: tensor.extract_slice
138+
// CHECK: tensor.extract_slice
139+
// CHECK: scf.for
140+
// CHECK: tensor.extract_slice
141+
// CHECK: tensor.extract_slice
142+
// CHECK: linalg.transpose
143+
// CHECK: scf.if
144+
// CHECK: linalg.fill
145+
// CHECK: linalgx.batch_reduce_matmul_vnni
146+
// CHECK: else
147+
// CHECK: linalgx.batch_reduce_matmul_vnni
148+
// CHECK: scf.forall.in_parallel
149+
// CHECK: scf.forall.in_parallel
150+
// CHECK: scf.forall.in_parallel
151+
// CHECK: linalg.reduce
152+
// CHECK: linalg.copy
153+
%2 = linalgx.mm2d_vnni ins(%arg0, %arg1 : tensor<4096x4096xbf16>, tensor<128x128x16x32x2xbf16>) outs(%1 : tensor<4096x4096xbf16>) -> tensor<4096x4096xbf16>
154+
return %2 : tensor<4096x4096xbf16>
155+
}
156+
157+
}

0 commit comments

Comments
 (0)