
Commit 58f363e

support simple structured ops in layout propagation
1 parent a9c4ff3 commit 58f363e

File tree

include/gc/Analysis/GlobalAnalysis.h
lib/gc/Analysis/GlobalAnalysis.cpp
lib/gc/Transforms/PropagateLayout.cpp

3 files changed: +167 -16 lines changed

include/gc/Analysis/GlobalAnalysis.h

Lines changed: 5 additions & 4 deletions
@@ -23,6 +23,7 @@
 
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Pass/Pass.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/DenseMap.h"
 
@@ -34,7 +35,7 @@ using namespace mlir;
 class TensorLayout {
 public:
   TensorLayout(ArrayRef<int64_t> outerAxis, ArrayRef<int64_t> innerAxis,
-               ArrayRef<int64_t> tileSizes) {
+               ArrayRef<OpFoldResult> tileSizes) {
     assert(innerAxis.size() == tileSizes.size());
     for (auto oa : outerAxis) {
       OuterAxis.push_back(oa);
@@ -59,7 +60,7 @@ class TensorLayout {
     SmallVector<int64_t> outerAxis(rank, 0);
     std::iota(outerAxis.begin(), outerAxis.end(), 0);
     return TensorLayout(outerAxis, SmallVector<int64_t>{},
-                        SmallVector<int64_t>{});
+                        SmallVector<OpFoldResult>{});
   }
 
   size_t getTensorRank() const { return OuterAxis.size(); }
@@ -68,7 +69,7 @@ class TensorLayout {
 
   SmallVector<int64_t> getInnerAxis() const { return InnerAxis; }
 
-  SmallVector<int64_t> getTileSizes() const { return TileSizes; }
+  SmallVector<OpFoldResult> getTileSizes() const { return TileSizes; }
 
   friend std::ostream &operator<<(std::ostream &ss, const TensorLayout &layout);
 
@@ -77,7 +78,7 @@ class TensorLayout {
 private:
   SmallVector<int64_t> OuterAxis;
   SmallVector<int64_t> InnerAxis;
-  SmallVector<int64_t> TileSizes;
+  SmallVector<OpFoldResult> TileSizes;
 };
 
 class OperatorLayout {
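
Note: after this change a layout's tile sizes travel as OpFoldResult, so each tile can be either a static size (an index attribute) or a dynamic one (an SSA value). A minimal construction sketch, assuming an MLIRContext *ctx is in scope; it mirrors how GlobalAnalysis.cpp below builds its hardcoded layouts with rewriter.getIndexAttr:

OpBuilder b(ctx);
// Static 32x32 tiles expressed as index attributes rather than raw int64_t.
TensorLayout aLayout(/*outerAxis=*/{0, 1}, /*innerAxis=*/{0, 1},
                     SmallVector<OpFoldResult>{b.getIndexAttr(32),
                                               b.getIndexAttr(32)});
// A dynamic tile size would simply be passed as a Value instead of an attribute.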

lib/gc/Analysis/GlobalAnalysis.cpp

Lines changed: 19 additions & 7 deletions
@@ -16,7 +16,7 @@ namespace gc {
 std::ostream &operator<<(std::ostream &ss, const TensorLayout &layout) {
   SmallVector<int64_t> outerAxis = layout.getOuterAxis();
   SmallVector<int64_t> innerAxis = layout.getInnerAxis();
-  SmallVector<int64_t> tileSizes = layout.getTileSizes();
+  SmallVector<OpFoldResult> tileSizes = layout.getTileSizes();
   ss << "[";
   for (size_t i = 0; i < outerAxis.size(); ++i) {
     if (i != 0) {
@@ -35,7 +35,9 @@ std::ostream &operator<<(std::ostream &ss, const TensorLayout &layout) {
     if (i != 0) {
       ss << ", ";
     }
-    ss << tileSizes[i];
+    if (getConstantIntValue(tileSizes[i]).has_value()) {
+      ss << *getConstantIntValue(tileSizes[i]);
+    }
   }
   ss << "}";
 }
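
For reference, getConstantIntValue only yields an integer for attribute-backed OpFoldResults, so with the printer change above a dynamic tile size is simply skipped. A minimal sketch, not part of the commit, assuming a Builder b and an SSA index value v are in scope:

OpFoldResult staticTile = b.getIndexAttr(32); // attribute-backed, prints "32"
OpFoldResult dynamicTile = v;                 // Value-backed, printed as nothing
assert(getConstantIntValue(staticTile).has_value() &&
       *getConstantIntValue(staticTile) == 32);
assert(!getConstantIntValue(dynamicTile).has_value());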
@@ -120,10 +122,10 @@ inferTargetLayout(TensorLayout layoutBase,
   int64_t dimDifference = indexMap.size() - layoutBase.getTensorRank();
   SmallVector<int64_t> baseOuterAxis = layoutBase.getOuterAxis();
   SmallVector<int64_t> baseInnerAxis = layoutBase.getInnerAxis();
-  SmallVector<int64_t> baseTileSizes = layoutBase.getTileSizes();
+  SmallVector<OpFoldResult> baseTileSizes = layoutBase.getTileSizes();
   SmallVector<int64_t> targetOuterAxis;
   SmallVector<int64_t> targetInnerAxis;
-  SmallVector<int64_t> targetTileSizes;
+  SmallVector<OpFoldResult> targetTileSizes;
   DenseMap<int64_t, int64_t> reverseIndexMap =
       getReversedIndexMap(indexMap, layoutBase.getTensorRank());
   for (auto oa : baseOuterAxis) {
@@ -184,18 +186,28 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) {
     }
 
     // ------ Get Current Op's Suggested Layout & Do Propagation ------
+    IRRewriter rewriter(linalgOp);
     if (mlir::linalg::isaContractionOpInterface(linalgOp)) {
       // query the cost model
       // OperatorLayout suggestedLayout = costModel->queryLayout(linalgOp,
       // curInputLayouts);
 
       // hardcode one for now
       // A side layout, [0, 1, 0, 1]; {32, 32}
-      TensorLayout A_layout({0, 1}, {0, 1}, {32, 32});
+      TensorLayout A_layout(
+          {0, 1}, {0, 1},
+          SmallVector<OpFoldResult>{rewriter.getIndexAttr(32),
+                                    rewriter.getIndexAttr(32)});
       // B side layout, [1, 0, 0, 1]; {32, 32}
-      TensorLayout B_layout({1, 0}, {0, 1}, {32, 32});
+      TensorLayout B_layout(
+          {1, 0}, {0, 1},
+          SmallVector<OpFoldResult>{rewriter.getIndexAttr(32),
+                                    rewriter.getIndexAttr(32)});
       // C side layout, [0, 1, 0, 1]; {32, 32}
-      TensorLayout C_layout({0, 1}, {0, 1}, {32, 32});
+      TensorLayout C_layout(
+          {0, 1}, {0, 1},
+          SmallVector<OpFoldResult>{rewriter.getIndexAttr(32),
+                                    rewriter.getIndexAttr(32)});
       OperatorLayout suggestedLayout({A_layout, B_layout}, {C_layout});
       layout[linalgOp] = suggestedLayout;
     } else {
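
Concretely, the three hardcoded layouts above describe a standard blocked-GEMM packing: with 32x32 tiles, a 128x128 A or C operand becomes a 4x4x32x32 tensor (MKmk and MNmn blocking), and B becomes 4x4x32x32 with its outer dims swapped (NKkn). A hypothetical helper, not in the commit, that computes the packed shape a TensorLayout with static tiles implies:

// Sketch only: assumes static tile sizes and evenly divisible dimensions.
SmallVector<int64_t> packedShape(ArrayRef<int64_t> srcShape,
                                 const TensorLayout &layout) {
  SmallVector<int64_t> outerDims(srcShape.begin(), srcShape.end());
  SmallVector<int64_t> innerAxis = layout.getInnerAxis();
  SmallVector<OpFoldResult> tileSizes = layout.getTileSizes();
  SmallVector<int64_t> innerTiles;
  for (auto [dim, tile] : llvm::zip(innerAxis, tileSizes)) {
    int64_t t = *getConstantIntValue(tile);
    outerDims[dim] /= t; // e.g. 128 / 32 -> 4
    innerTiles.push_back(t);
  }
  SmallVector<int64_t> result;
  for (int64_t o : layout.getOuterAxis()) // outer dims, possibly permuted
    result.push_back(outerDims[o]);
  result.append(innerTiles.begin(), innerTiles.end()); // then the tile dims
  return result;
}
// e.g. packedShape({128, 128}, A_layout) -> {4, 4, 32, 32}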

lib/gc/Transforms/PropagateLayout.cpp

Lines changed: 143 additions & 5 deletions
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
@@ -30,6 +31,130 @@ using namespace mlir;
 using namespace mlir::arith;
 using namespace mlir::tensor;
 
+static FailureOr<linalg::PackResult> packNamedOp(RewriterBase &rewriter,
+                                                 linalg::LinalgOp linalgOp,
+                                                 OperatorLayout opLayout) {
+  std::cout << "----------------------------------" << std::endl;
+  std::cout << " Visiting op in packNamedOp ";
+  linalgOp->getName().print(llvm::errs());
+  std::cout << std::endl;
+  std::cout << "----------------------------------" << std::endl;
+  Location loc = linalgOp->getLoc();
+  SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
+  SmallVector<utils::IteratorType> iteratorTypes =
+      linalgOp.getIteratorTypesArray();
+
+  SmallVector<tensor::PackOp> packOps;
+  SmallVector<tensor::UnPackOp> unPackOps;
+  SmallVector<Value> inputsAndInits, results;
+  SmallVector<OpOperand *> initOperands = llvm::to_vector(llvm::map_range(
+      linalgOp.getDpsInitsMutable(), [](OpOperand &o) { return &o; }));
+  SmallVector<OpOperand *> inputOperands = linalgOp.getDpsInputOperands();
+  std::cout << "Num of input operands: " << inputOperands.size() << std::endl;
+  std::cout << "Num of init operands: " << initOperands.size() << std::endl;
+  SmallVector<TensorLayout> inputLayouts = opLayout.getSupportedInputLayouts();
+  SmallVector<TensorLayout> initLayouts = opLayout.getSupportedOutputLayouts();
+  std::cout << "Num of input layouts: " << inputLayouts.size() << std::endl;
+  std::cout << "Num of init layouts: " << initLayouts.size() << std::endl;
+
+  // check all inputs and inits are tensor, otherwise no need for layout
+  // propagation
+  bool allTensor =
+      llvm::all_of(inputOperands,
+                   [](OpOperand *opOperand) {
+                     return opOperand->get().getType().isa<TensorType>();
+                   }) &&
+      llvm::all_of(initOperands, [](OpOperand *opOperand) {
+        return opOperand->get().getType().isa<TensorType>();
+      });
+  std::cout << "The op's input is all tensor?" << allTensor << std::endl;
+  if (!allTensor) {
+    return failure("the op does not need packing.");
+  }
+  for (const auto &operandsList : {inputOperands, initOperands}) {
+    for (OpOperand *opOperand : operandsList) {
+      int64_t pos = opOperand->getOperandNumber();
+      std::cout << "pos: " << pos << std::endl;
+      Value operand = opOperand->get();
+      TensorLayout targetLayout = pos >= inputLayouts.size()
+                                      ? initLayouts[pos - inputLayouts.size()]
+                                      : inputLayouts[pos];
+      SmallVector<int64_t> outerPerm = targetLayout.getOuterAxis();
+      SmallVector<int64_t> innerPos = targetLayout.getInnerAxis();
+      SmallVector<OpFoldResult> innerPackSizes = targetLayout.getTileSizes();
+
+      std::cout << "Suggested layout: " << targetLayout << std::endl;
+
+      std::cout << "Operand shape: ";
+      for (auto dim :
+           llvm::cast<RankedTensorType>(operand.getType()).getShape()) {
+        std::cout << dim << ", ";
+      }
+      std::cout << std::endl;
+
+      Value dest = tensor::PackOp::createDestinationTensor(
+          rewriter, loc, operand, innerPackSizes, innerPos, outerPerm);
+      ShapedType operandType = cast<ShapedType>(operand.getType());
+      bool areConstantTiles =
+          llvm::all_of(innerPackSizes, [](OpFoldResult tile) {
+            return getConstantIntValue(tile).has_value();
+          });
+      if (areConstantTiles && operandType.hasStaticShape() &&
+          !tensor::PackOp::requirePaddingValue(
+              operandType.getShape(), innerPos,
+              cast<ShapedType>(dest.getType()).getShape(), {},
+              innerPackSizes)) {
+        packOps.push_back(rewriter.create<tensor::PackOp>(
+            loc, operand, dest, innerPos, innerPackSizes, std::nullopt,
+            outerPerm));
+      } else {
+        // TODO: value of the padding attribute should be determined by
+        // consumers.
+        auto zeroAttr =
+            rewriter.getZeroAttr(getElementTypeOrSelf(dest.getType()));
+        Value zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
+        packOps.push_back(rewriter.create<tensor::PackOp>(
+            loc, operand, dest, innerPos, innerPackSizes, zero, outerPerm));
+      }
+      inputsAndInits.push_back(packOps.back());
+    }
+  }
+
+  // Step 3. Build the packed op, use the type of `inits` as result types.
+  ValueRange inputs =
+      ValueRange{inputsAndInits}.take_front(linalgOp.getNumDpsInputs());
+  ValueRange inits =
+      ValueRange{inputsAndInits}.take_back(linalgOp.getNumDpsInits());
+  // TODO(yifei): the axis info of reduce/broadcast/transpose may change
+  auto packedLinalgOp = mlir::clone(
+      rewriter, linalgOp, SmallVector<Type>{inputsAndInits.back().getType()},
+      inputsAndInits);
+
+  // Step 4. Unpack all the op results.
+  for (OpResult result : packedLinalgOp->getResults()) {
+    int64_t resultNum = result.getResultNumber();
+    tensor::PackOp maybePackedInit =
+        inits[resultNum].getDefiningOp<tensor::PackOp>();
+    if (!maybePackedInit) {
+      results.push_back(result);
+      continue;
+    }
+    // Build the symmetrical UnPackOp to the existing PackOp.
+    unPackOps.push_back(rewriter.create<tensor::UnPackOp>(
+        packedLinalgOp->getLoc(), result, maybePackedInit.getSource(),
+        maybePackedInit.getInnerDimsPos(), maybePackedInit.getMixedTiles()));
+    results.push_back(unPackOps.back());
+  }
+
+  // Step 5. Replace `linalgOp`.
+  rewriter.replaceOp(linalgOp, results);
+
+  // Return packedLinalgOp.
+  return linalg::PackResult{
+      packOps, cast<linalg::LinalgOp>(packedLinalgOp.getOperation()),
+      unPackOps};
+}
+
 class PropagateLayout : public impl::PropagateLayoutBase<PropagateLayout> {
 public:
   using impl::PropagateLayoutBase<PropagateLayout>::PropagateLayoutBase;
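
Reading aid for packNamedOp above: each operand's TensorLayout maps one-to-one onto the tensor.pack parameters, outerAxis to outer_dims_perm, innerAxis to inner_dims_pos, and tileSizes to inner_tiles. A condensed sketch of what one loop iteration builds, using the same calls as above and assuming operand, targetLayout, rewriter, and loc are in scope:

SmallVector<int64_t> outerPerm = targetLayout.getOuterAxis();  // outer_dims_perm
SmallVector<int64_t> innerPos = targetLayout.getInnerAxis();   // inner_dims_pos
SmallVector<OpFoldResult> tiles = targetLayout.getTileSizes(); // inner_tiles
// The destination tensor carries the packed type, e.g. 128x128 -> 4x4x32x32.
Value dest = tensor::PackOp::createDestinationTensor(rewriter, loc, operand,
                                                     tiles, innerPos, outerPerm);
Value packed = rewriter.create<tensor::PackOp>(
    loc, operand, dest, innerPos, tiles, /*paddingValue=*/std::nullopt, outerPerm);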
@@ -42,24 +167,37 @@ void PropagateLayout::runOnOperation() {
   IRRewriter rewriter(ctx);
   // walk the entire graph
   auto &layoutAnalysisResult = getAnalysis<GlobalAnalysis>();
-  graph->walk([&](linalg::LinalgOp linalgOp) {
+  SmallVector<Operation *> packTODOList;
+  graph->walk([&](Operation *op) {
+    if (isa<linalg::LinalgOp>(op) && !mlir::linalg::isaContractionOpInterface(
+                                         dyn_cast<linalg::LinalgOp>(op))) {
+      packTODOList.push_back(op);
+    }
+  });
+  for (auto op : packTODOList) {
     std::cout << std::endl;
     std::cout << "----------------------------------" << std::endl;
     std::cout << "Visiting op ";
-    linalgOp.getOperation()->getName().print(llvm::errs());
+    op->getName().print(llvm::errs());
     std::cout << std::endl;
     std::cout << "----------------------------------" << std::endl;
-    FailureOr<OperatorLayout> opLayout =
-        layoutAnalysisResult.getOpLayout(linalgOp);
+    FailureOr<OperatorLayout> opLayout = layoutAnalysisResult.getOpLayout(op);
     if (failed(opLayout)) {
       std::cout << "infer failed" << std::endl;
     } else {
       // pack op into ideal layout
       std::cout << "-------- supported layouts -------" << std::endl;
       std::cout << *opLayout << std::endl;
       // insert pack
+      OpBuilder::InsertionGuard guard(rewriter);
+      rewriter.setInsertionPoint(op);
+      if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
+        FailureOr<linalg::PackResult> packedOp =
+            packNamedOp(rewriter, linalgOp, *opLayout);
+      }
+      graph->dump();
     }
-  });
+  }
   graph->dump();
 }
65203
