Skip to content

Commit 9af3f96

Browse files
committed
replace generic op with named op
1 parent 3390b61 commit 9af3f96

File tree

2 files changed

+78
-49
lines changed

2 files changed

+78
-49
lines changed

lib/gc/Transforms/DeepTileContractionNamedOp.cpp

Lines changed: 39 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include "./Tiling.hpp"
1010
#include "gc/Dialect/Arith/Utils/EasyBuild.h"
11+
#include "gc/Dialect/Linalgx/LinalgxOps.h"
1112
#include "gc/IR/EasyBuild.h"
1213
#include "gc/IR/EasyBuildSCF.h"
1314
#include "mlir/AsmParser/AsmParser.h"
@@ -68,24 +69,23 @@ getOprandDimType(linalg::LinalgOp &linalgOp) {
6869
SmallVector<DimType>{DimType::M, DimType::K},
6970
SmallVector<DimType>{DimType::K, DimType::N},
7071
SmallVector<DimType>{DimType::M, DimType::N}};
71-
} else if (isa<linalg::GenericOp>(linalgOp)) {
72-
auto iteratorTypes = linalgOp.getIteratorTypesArray();
73-
if (iteratorTypes.size() == 7UL) {
74-
// 4Dx5D, brgemm vnni
75-
return SmallVector<SmallVector<DimType>>{
76-
SmallVector<DimType>{DimType::M, DimType::K, DimType::M, DimType::K},
77-
SmallVector<DimType>{DimType::N, DimType::K, DimType::K, DimType::N,
78-
DimType::K},
79-
SmallVector<DimType>{DimType::M, DimType::N, DimType::M, DimType::N}};
80-
} else if (iteratorTypes.size() == 6UL) {
81-
// 4Dx4D
82-
return SmallVector<SmallVector<DimType>>{
83-
SmallVector<DimType>{DimType::M, DimType::K, DimType::M, DimType::K},
84-
SmallVector<DimType>{DimType::N, DimType::K, DimType::K, DimType::N},
85-
SmallVector<DimType>{DimType::M, DimType::N, DimType::M, DimType::N}};
86-
}
87-
} else {
88-
return failure();
72+
} else if (llvm::isa<linalgx::Mm2DVnniOp>(linalgOp)) {
73+
return SmallVector<SmallVector<DimType>>{
74+
SmallVector<DimType>{DimType::M, DimType::K},
75+
SmallVector<DimType>{DimType::N, DimType::K, DimType::K, DimType::N,
76+
DimType::K},
77+
SmallVector<DimType>{DimType::M, DimType::N, DimType::M, DimType::N}};
78+
} else if (llvm::isa<linalgx::Mm4DVnniOp>(linalgOp)) {
79+
return SmallVector<SmallVector<DimType>>{
80+
SmallVector<DimType>{DimType::M, DimType::K, DimType::M, DimType::K},
81+
SmallVector<DimType>{DimType::N, DimType::K, DimType::K, DimType::N,
82+
DimType::K},
83+
SmallVector<DimType>{DimType::M, DimType::N, DimType::M, DimType::N}};
84+
} else if (llvm::isa<linalg::BatchMatmulOp>(linalgOp)) {
85+
return SmallVector<SmallVector<DimType>>{
86+
SmallVector<DimType>{DimType::Batch, DimType::M, DimType::K},
87+
SmallVector<DimType>{DimType::Batch, DimType::K, DimType::N},
88+
SmallVector<DimType>{DimType::Batch, DimType::M, DimType::N}};
8989
}
9090
return failure();
9191
}
@@ -136,7 +136,7 @@ MatmulConfig getDefaultMatmulConfig(linalg::LinalgOp &linalgOp) {
136136
cfg.KBlock = 64;
137137
cfg.MThreads = 2;
138138
cfg.NThreads = 2;
139-
cfg.KThreads = 1;
139+
cfg.KThreads = 2;
140140
return cfg;
141141
}
142142

@@ -784,8 +784,9 @@ struct deepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
784784
ValueRange{dataOprand, weightOprand}, resultOprand);
785785
} else {
786786
IRMapping mapping;
787-
matmul = dyn_cast<linalg::LinalgOp>(
788-
*rewriter.clone(*(currentOp.getOperation())));
787+
matmul = rewriter.create<linalgx::BatchReduceMatmulVnniOp>(
788+
resultOprand.getLoc(), resultOprand.getType(),
789+
ValueRange{dataOprand, weightOprand}, resultOprand);
789790
}
790791
Value result = matmul.getOperation()->getResult(0);
791792

@@ -830,18 +831,32 @@ struct deepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
830831
return success();
831832
}
832833

834+
bool checkLinalgMatmulType(linalg::LinalgOp linalgOp) const {
835+
return llvm::isa<linalg::MatmulOp>(linalgOp) ||
836+
llvm::isa<linalgx::Mm2DVnniOp>(linalgOp) ||
837+
llvm::isa<linalgx::Mm4DVnniOp>(linalgOp) ||
838+
llvm::isa<linalgx::MultiBatchMatmulOp>(linalgOp) ||
839+
llvm::isa<linalg::BatchMatmulOp>(linalgOp);
840+
}
841+
833842
LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
834843
PatternRewriter &rewriter) const override {
844+
if (!checkLinalgMatmulType(linalgOp))
845+
return failure();
835846
if (linalgOp.hasPureBufferSemantics())
836847
return failure();
837-
OpBuilder::InsertionGuard guard(rewriter);
838-
rewriter.setInsertionPoint(linalgOp);
848+
839849
if (linalgOp.getOperation()->getParentOfType<scf::ForallOp>() ||
840850
!linalgOp || linalgOp.getNumDpsInputs() != 2)
841851
return failure();
842-
Operation *fillOp = findParentFillOp(linalgOp.getDpsInits()[0]);
852+
853+
OpBuilder::InsertionGuard guard(rewriter);
854+
rewriter.setInsertionPoint(linalgOp);
843855
linalg::LinalgOp originOp =
844856
dyn_cast<linalg::LinalgOp>(*rewriter.clone(*(linalgOp.getOperation())));
857+
linalgOp = *linalg::generalizeNamedOp(rewriter, linalgOp);
858+
Operation *fillOp = findParentFillOp(linalgOp.getDpsInits()[0]);
859+
845860
// Step 1. generate the outer loop
846861
MatmulConfig cfg = getDefaultMatmulConfig(linalgOp);
847862
auto outerLoopResult = outerLoopGeneration(rewriter, linalgOp, cfg,
Lines changed: 39 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,21 @@
11
// RUN: gc-opt --split-input-file --deep-tile-contraction-named-op %s
22

3-
// -----
3+
// // -----
44

5-
/// CHECK-LABEL: @blocked_matmul_f32
6-
func.func @blocked_matmul_f32(%arg0: tensor<128x128x32x32xf32>) -> tensor<128x128x32x32xf32> {
7-
%cst = arith.constant dense<1.000000e+00> : tensor<128x128x32x32xf32>
8-
%cst_0 = arith.constant 0.000000e+00 : f32
9-
%0 = tensor.empty() : tensor<128x128x32x32xf32>
10-
%1 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<128x128x32x32xf32>) -> tensor<128x128x32x32xf32>
11-
%2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d5, d4)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%arg0, %cst : tensor<128x128x32x32xf32>, tensor<128x128x32x32xf32>) outs(%1 : tensor<128x128x32x32xf32>) {
12-
^bb0(%in: f32, %in_1: f32, %out: f32):
13-
%3 = arith.mulf %in, %in_1 : f32
14-
%4 = arith.addf %out, %3 : f32
15-
linalg.yield %4 : f32
16-
} -> tensor<128x128x32x32xf32>
17-
return %2 : tensor<128x128x32x32xf32>
18-
}
5+
// /// CHECK-LABEL: @matmul_4Dx4D_f32
6+
// func.func @matmul_4Dx4D_f32(%arg0: tensor<128x128x32x32xf32>) -> tensor<128x128x32x32xf32> {
7+
// %cst = arith.constant dense<1.000000e+00> : tensor<128x128x32x32x1xf32>
8+
// %cst_0 = arith.constant 0.000000e+00 : f32
9+
// %0 = tensor.empty() : tensor<128x128x32x32xf32>
10+
// %1 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<128x128x32x32xf32>) -> tensor<128x128x32x32xf32>
11+
// %2 = linalgx.mm4d_vnni ins(%arg0, %cst : tensor<128x128x32x32xf32>, tensor<128x128x32x32x1xf32>) outs(%1 : tensor<128x128x32x32xf32>) -> tensor<128x128x32x32xf32>
12+
// return %2 : tensor<128x128x32x32xf32>
13+
// }
1914

2015
// -----
2116

22-
/// CHECK-LABEL: @plain_matmul_f32
23-
func.func @plain_matmul_f32(%arg0: tensor<4096x4096xf32>) -> tensor<4096x4096xf32> {
17+
/// CHECK-LABEL: @matmul_2Dx2D_f32
18+
func.func @matmul_2Dx2D_f32(%arg0: tensor<4096x4096xf32>) -> tensor<4096x4096xf32> {
2419
%cst = arith.constant dense<1.000000e+00> : tensor<4096x4096xf32>
2520
%cst_0 = arith.constant 0.000000e+00 : f32
2621
%0 = tensor.empty() : tensor<4096x4096xf32>
@@ -29,20 +24,39 @@ func.func @plain_matmul_f32(%arg0: tensor<4096x4096xf32>) -> tensor<4096x4096xf3
2924
return %2 : tensor<4096x4096xf32>
3025
}
3126

27+
// // -----
28+
29+
// /// CHECK-LABEL: @matmul_2Dx4D_f32
30+
// func.func @matmul_4Dx4D_f32(%arg0: tensor<4096x4096xf32>) -> tensor<4096x4096xf32> {
31+
// %cst = arith.constant dense<1.000000e+00> : tensor<128x128x32x32x1xf32>
32+
// %cst_0 = arith.constant 0.000000e+00 : f32
33+
// %0 = tensor.empty() : tensor<4096x4096xf32>
34+
// %1 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
35+
// %2 = linalgx.mm2d_vnni ins(%arg0, %cst : tensor<4096x4096xf32>, tensor<128x128x32x32x1xf32>) outs(%1 : tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
36+
// return %2 : tensor<4096x4096xf32>
37+
// }
38+
3239
// -----
3340

34-
/// CHECK-LABEL: @blocked_matmul_bf16
35-
func.func @blocked_matmul_bf16(%arg0: tensor<128x128x32x32xbf16>) -> tensor<128x128x32x32xbf16> {
41+
/// CHECK-LABEL: @matmul_4Dx4D_bf16
42+
func.func @matmul_4Dx4D_bf16(%arg0: tensor<128x128x32x32xbf16>) -> tensor<128x128x32x32xbf16> {
3643
%cst = arith.constant dense<1.000000e+00> : tensor<128x128x16x32x2xbf16>
3744
%cst_0 = arith.constant 0.000000e+00 : bf16
3845
%0 = tensor.empty() : tensor<128x128x32x32xbf16>
3946
%1 = linalg.fill ins(%cst_0 : bf16) outs(%0 : tensor<128x128x32x32xbf16>) -> tensor<128x128x32x32xbf16>
40-
%2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2, d4, d6)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d2, d6 floordiv 2, d5, d3)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d4, d5)>], iterator_types = ["parallel", "parallel", "reduction", "reduction", "parallel", "parallel", "reduction"]} ins(%arg0, %cst : tensor<128x128x32x32xbf16>, tensor<128x128x16x32x2xbf16>) outs(%1 : tensor<128x128x32x32xbf16>) {
41-
^bb0(%in: bf16, %in_1: bf16, %out: bf16):
42-
%3 = arith.mulf %in, %in_1 : bf16
43-
%4 = arith.addf %out, %3 : bf16
44-
linalg.yield %4 : bf16
45-
} -> tensor<128x128x32x32xbf16>
47+
%2 = linalgx.mm4d_vnni ins(%arg0, %cst : tensor<128x128x32x32xbf16>, tensor<128x128x16x32x2xbf16>) outs(%1 : tensor<128x128x32x32xbf16>) -> tensor<128x128x32x32xbf16>
4648
return %2 : tensor<128x128x32x32xbf16>
4749
}
4850

51+
// // -----
52+
53+
// /// CHECK-LABEL: @matmul_2Dx4D_bf16
54+
// func.func @matmul_4Dx4D_bf16(%arg0: tensor<4096x4096xbf16>) -> tensor<4096x4096xbf16> {
55+
// %cst = arith.constant dense<1.000000e+00> : tensor<128x128x16x32x2xbf16>
56+
// %cst_0 = arith.constant 0.000000e+00 : bf16
57+
// %0 = tensor.empty() : tensor<4096x4096xbf16>
58+
// %1 = linalg.fill ins(%cst_0 : bf16) outs(%0 : tensor<4096x4096xbf16>) -> tensor<4096x4096xbf16>
59+
// %2 = linalgx.mm2d_vnni ins(%arg0, %cst : tensor<4096x4096xbf16>, tensor<128x128x16x32x2xbf16>) outs(%1 : tensor<4096x4096xbf16>) -> tensor<4096x4096xbf16>
60+
// return %2 : tensor<4096x4096xbf16>
61+
// }
62+

0 commit comments

Comments (0)