#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
+ #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
@@ -30,6 +31,130 @@ using namespace mlir;
using namespace mlir::arith;
using namespace mlir::tensor;

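+ // Packs the inputs and inits of `linalgOp` into the layouts suggested by
+ // `opLayout`, clones the op onto the packed operands, then unpacks the
+ // results back to the original layouts (analogous to linalg::pack).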
+ static FailureOr<linalg::PackResult> packNamedOp(RewriterBase &rewriter,
+                                                  linalg::LinalgOp linalgOp,
+                                                  OperatorLayout opLayout) {
+   std::cout << "----------------------------------" << std::endl;
+   std::cout << "Visiting op in packNamedOp ";
+   linalgOp->getName().print(llvm::errs());
+   std::cout << std::endl;
+   std::cout << "----------------------------------" << std::endl;
+   Location loc = linalgOp->getLoc();
+   SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
+   SmallVector<utils::IteratorType> iteratorTypes =
+       linalgOp.getIteratorTypesArray();
+
+   SmallVector<tensor::PackOp> packOps;
+   SmallVector<tensor::UnPackOp> unPackOps;
+   SmallVector<Value> inputsAndInits, results;
+   SmallVector<OpOperand *> initOperands = llvm::to_vector(llvm::map_range(
+       linalgOp.getDpsInitsMutable(), [](OpOperand &o) { return &o; }));
+   SmallVector<OpOperand *> inputOperands = linalgOp.getDpsInputOperands();
+   std::cout << "Num of input operands: " << inputOperands.size() << std::endl;
+   std::cout << "Num of init operands: " << initOperands.size() << std::endl;
+   SmallVector<TensorLayout> inputLayouts = opLayout.getSupportedInputLayouts();
+   SmallVector<TensorLayout> initLayouts = opLayout.getSupportedOutputLayouts();
+   std::cout << "Num of input layouts: " << inputLayouts.size() << std::endl;
+   std::cout << "Num of init layouts: " << initLayouts.size() << std::endl;
+
+   // Check that all inputs and inits are tensors; otherwise there is no need
+   // for layout propagation.
+   bool allTensor =
+       llvm::all_of(inputOperands,
+                    [](OpOperand *opOperand) {
+                      return opOperand->get().getType().isa<TensorType>();
+                    }) &&
+       llvm::all_of(initOperands, [](OpOperand *opOperand) {
+         return opOperand->get().getType().isa<TensorType>();
+       });
+   std::cout << "All inputs and inits are tensors? " << allTensor << std::endl;
+   if (!allTensor) {
+     return rewriter.notifyMatchFailure(linalgOp,
+                                        "the op does not need packing");
+   }
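+   // Pack every input and init into its target layout. When the constant tile
+   // sizes evenly divide a static operand shape, no padding value is needed;
+   // otherwise a zero padding value is used for now.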
+   for (const auto &operandsList : {inputOperands, initOperands}) {
+     for (OpOperand *opOperand : operandsList) {
+       int64_t pos = opOperand->getOperandNumber();
+       std::cout << "pos: " << pos << std::endl;
+       Value operand = opOperand->get();
+       TensorLayout targetLayout = pos >= inputLayouts.size()
+                                       ? initLayouts[pos - inputLayouts.size()]
+                                       : inputLayouts[pos];
+       SmallVector<int64_t> outerPerm = targetLayout.getOuterAxis();
+       SmallVector<int64_t> innerPos = targetLayout.getInnerAxis();
+       SmallVector<OpFoldResult> innerPackSizes = targetLayout.getTileSizes();
+
+       std::cout << "Suggested layout: " << targetLayout << std::endl;
+
+       std::cout << "Operand shape: ";
+       for (auto dim :
+            llvm::cast<RankedTensorType>(operand.getType()).getShape()) {
+         std::cout << dim << ", ";
+       }
+       std::cout << std::endl;
+
+       Value dest = tensor::PackOp::createDestinationTensor(
+           rewriter, loc, operand, innerPackSizes, innerPos, outerPerm);
+       ShapedType operandType = cast<ShapedType>(operand.getType());
+       bool areConstantTiles =
+           llvm::all_of(innerPackSizes, [](OpFoldResult tile) {
+             return getConstantIntValue(tile).has_value();
+           });
+       if (areConstantTiles && operandType.hasStaticShape() &&
+           !tensor::PackOp::requirePaddingValue(
+               operandType.getShape(), innerPos,
+               cast<ShapedType>(dest.getType()).getShape(), {},
+               innerPackSizes)) {
+         packOps.push_back(rewriter.create<tensor::PackOp>(
+             loc, operand, dest, innerPos, innerPackSizes, std::nullopt,
+             outerPerm));
+       } else {
+         // TODO: value of the padding attribute should be determined by
+         // consumers.
+         auto zeroAttr =
+             rewriter.getZeroAttr(getElementTypeOrSelf(dest.getType()));
+         Value zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
+         packOps.push_back(rewriter.create<tensor::PackOp>(
+             loc, operand, dest, innerPos, innerPackSizes, zero, outerPerm));
+       }
+       inputsAndInits.push_back(packOps.back());
+     }
+   }
+
+   // Step 3. Build the packed op, use the type of `inits` as result types.
+   ValueRange inputs =
+       ValueRange{inputsAndInits}.take_front(linalgOp.getNumDpsInputs());
+   ValueRange inits =
+       ValueRange{inputsAndInits}.take_back(linalgOp.getNumDpsInits());
+   // TODO(yifei): the axis info of reduce/broadcast/transpose may change
+   auto packedLinalgOp = mlir::clone(
+       rewriter, linalgOp, SmallVector<Type>{inputsAndInits.back().getType()},
+       inputsAndInits);
+
+   // Step 4. Unpack all the op results.
+   for (OpResult result : packedLinalgOp->getResults()) {
+     int64_t resultNum = result.getResultNumber();
+     tensor::PackOp maybePackedInit =
+         inits[resultNum].getDefiningOp<tensor::PackOp>();
+     if (!maybePackedInit) {
+       results.push_back(result);
+       continue;
+     }
+     // Build the symmetrical UnPackOp to the existing PackOp.
+     unPackOps.push_back(rewriter.create<tensor::UnPackOp>(
+         packedLinalgOp->getLoc(), result, maybePackedInit.getSource(),
+         maybePackedInit.getInnerDimsPos(), maybePackedInit.getMixedTiles()));
+     results.push_back(unPackOps.back());
+   }
+
+   // Step 5. Replace `linalgOp`.
+   rewriter.replaceOp(linalgOp, results);
+
+   // Return packedLinalgOp.
+   return linalg::PackResult{
+       packOps, cast<linalg::LinalgOp>(packedLinalgOp.getOperation()),
+       unPackOps};
+ }
+
class PropagateLayout : public impl::PropagateLayoutBase<PropagateLayout> {
public:
  using impl::PropagateLayoutBase<PropagateLayout>::PropagateLayoutBase;
@@ -42,24 +167,37 @@ void PropagateLayout::runOnOperation() {
  IRRewriter rewriter(ctx);
  // walk the entire graph
  auto &layoutAnalysisResult = getAnalysis<GlobalAnalysis>();
-   graph->walk([&](linalg::LinalgOp linalgOp) {
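+   // Collect the linalg ops to pack up front; contraction ops are skipped here.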
+   SmallVector<Operation *> packTODOList;
+   graph->walk([&](Operation *op) {
+     if (isa<linalg::LinalgOp>(op) && !mlir::linalg::isaContractionOpInterface(
+                                          dyn_cast<linalg::LinalgOp>(op))) {
+       packTODOList.push_back(op);
+     }
+   });
+   for (auto op : packTODOList) {
    std::cout << std::endl;
    std::cout << "----------------------------------" << std::endl;
    std::cout << "Visiting op ";
-     linalgOp.getOperation()->getName().print(llvm::errs());
+     op->getName().print(llvm::errs());
    std::cout << std::endl;
    std::cout << "----------------------------------" << std::endl;
-     FailureOr<OperatorLayout> opLayout =
-         layoutAnalysisResult.getOpLayout(linalgOp);
+     FailureOr<OperatorLayout> opLayout = layoutAnalysisResult.getOpLayout(op);
    if (failed(opLayout)) {
      std::cout << "infer failed" << std::endl;
    } else {
      // pack op into ideal layout
      std::cout << "-------- supported layouts -------" << std::endl;
      std::cout << *opLayout << std::endl;
      // insert pack
+       OpBuilder::InsertionGuard guard(rewriter);
+       rewriter.setInsertionPoint(op);
+       if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
+         FailureOr<linalg::PackResult> packedOp =
+             packNamedOp(rewriter, linalgOp, *opLayout);
+       }
+       graph->dump();
    }
-   });
+   }
  graph->dump();
}