Skip to content

Commit 092372d

Browse files
[mlir][Tensor] Rework ReifyRankedShapedTypeInterface implementation for tensor.expand_shape op. (#113501)
The op carries the output-shape directly. This can be used directly. Also adds a method to get the shape as a `SmallVector<OpFoldResult>`. Signed-off-by: MaheshRavishankar <[email protected]>
1 parent 27c9173 commit 092372d

File tree

8 files changed

+43
-116
lines changed

8 files changed

+43
-116
lines changed

mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1165,6 +1165,9 @@ def Tensor_ExpandShapeOp : Tensor_ReassociativeReshapeOp<"expand_shape"> {
11651165
let extraClassDeclaration = commonExtraClassDeclaration # [{
11661166
int64_t getCorrespondingSourceDim(int64_t resultDim);
11671167

1168+
// Return the output shape as mixed static/dynamic shapes.
1169+
SmallVector<OpFoldResult> getMixedOutputShape();
1170+
11681171
// Infer the output shape for a tensor.expand_shape when it is possible
11691172
// to do so.
11701173
static FailureOr<SmallVector<OpFoldResult>> inferOutputShape(

mlir/include/mlir/Dialect/Utils/StaticValueUtils.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,9 @@ bool isEqualConstantIntOrValueArray(ArrayRef<OpFoldResult> ofrs1,
144144
/// Return a vector of OpFoldResults with the same size as staticValues, but
145145
/// all elements for which ShapedType::isDynamic is true, will be replaced by
146146
/// dynamicValues.
147+
SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
148+
ValueRange dynamicValues,
149+
MLIRContext *context);
147150
SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
148151
ValueRange dynamicValues, Builder &b);
149152

mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp

Lines changed: 20 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -16,24 +16,6 @@
1616
using namespace mlir;
1717
using namespace mlir::tensor;
1818

19-
/// Compute a map that for a given dimension of the expanded type gives the
20-
/// dimension in the collapsed type it maps to. Essentially its the inverse of
21-
/// the `reassocation` maps.
22-
static llvm::DenseMap<int64_t, int64_t>
23-
getExpandedDimToCollapsedDimMap(ArrayRef<AffineMap> reassociation) {
24-
llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim;
25-
for (const auto &map : enumerate(reassociation)) {
26-
unsigned startPos =
27-
cast<AffineDimExpr>(map.value().getResults().front()).getPosition();
28-
unsigned endPos =
29-
cast<AffineDimExpr>(map.value().getResults().back()).getPosition();
30-
for (auto dim : llvm::seq_inclusive(startPos, endPos)) {
31-
expandedDimToCollapsedDim[dim] = map.index();
32-
}
33-
}
34-
return expandedDimToCollapsedDim;
35-
}
36-
3719
/// For reshape op compute the shape at dimension `dimIndex` of the output in
3820
/// terms of shape of the `src`, when the reshape op is a collapsing
3921
/// operation. It is the product of the shape of the collapsed dimensions of the
@@ -76,84 +58,15 @@ static SmallVector<OpFoldResult, 4> getCollapsedOutputShapeFromInputShape(
7658
}));
7759
}
7860

79-
/// For an expanding reshape op, compute the value for a dimension of the output
80-
/// from the shape of the input.
81-
static OpFoldResult getExpandedOutputDimFromInputShape(
82-
OpBuilder &builder, Location loc, int64_t dimIndex, Value src,
83-
ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation,
84-
llvm::DenseMap<int64_t, int64_t> &expandedDimToCollapsedDim) {
85-
if (!ShapedType::isDynamic(dstStaticShape[dimIndex])) {
86-
// Static dimension: return Attribute.
87-
return builder.getIndexAttr(dstStaticShape[dimIndex]);
88-
}
89-
unsigned sourceDimPos = expandedDimToCollapsedDim[dimIndex];
90-
unsigned startPos =
91-
cast<AffineDimExpr>(reassociation[sourceDimPos].getResults().front())
92-
.getPosition();
93-
unsigned endPos =
94-
cast<AffineDimExpr>(reassociation[sourceDimPos].getResults().back())
95-
.getPosition();
96-
int64_t linearizedStaticDim = 1;
97-
for (auto d :
98-
llvm::enumerate(dstStaticShape.slice(startPos, endPos - startPos + 1))) {
99-
if (d.index() + startPos == static_cast<unsigned>(dimIndex))
100-
continue;
101-
assert(!ShapedType::isDynamic(d.value()) &&
102-
"single dimension cannot be expanded into multiple dynamic "
103-
"dimensions");
104-
linearizedStaticDim *= d.value();
105-
}
106-
OpFoldResult sourceDim =
107-
builder.create<tensor::DimOp>(loc, src, sourceDimPos).getResult();
108-
109-
// Dynamic dimension: return Value.
110-
return affine::makeComposedAffineApply(
111-
builder, loc,
112-
AffineMap::get(
113-
0, 1,
114-
builder.getAffineSymbolExpr(0).floorDiv(linearizedStaticDim)),
115-
sourceDim)
116-
->getResult(0);
117-
}
118-
119-
/// Given the `src` of an expanding reshape op, the reassociation maps and the
120-
/// result type, compute the shape of the result of the reshape.
121-
static SmallVector<OpFoldResult, 4> getExpandedOutputShapeFromInputShape(
122-
OpBuilder &builder, Location loc, Value src,
123-
ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation) {
124-
llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim =
125-
getExpandedDimToCollapsedDimMap(reassociation);
126-
return llvm::to_vector<4>(llvm::map_range(
127-
llvm::seq<int64_t>(0, dstStaticShape.size()), [&](int64_t dim) {
128-
return getExpandedOutputDimFromInputShape(builder, loc, dim, src,
129-
dstStaticShape, reassociation,
130-
expandedDimToCollapsedDim);
131-
}));
132-
}
133-
134-
static SmallVector<OpFoldResult, 4>
135-
getReshapeOutputShapeFromInputShape(OpBuilder &builder, Location loc, Value src,
136-
ArrayRef<int64_t> dstStaticShape,
137-
ArrayRef<AffineMap> reassocation) {
138-
return dstStaticShape.size() >
139-
static_cast<size_t>(
140-
llvm::cast<ShapedType>(src.getType()).getRank())
141-
? getExpandedOutputShapeFromInputShape(
142-
builder, loc, src, dstStaticShape, reassocation)
143-
: getCollapsedOutputShapeFromInputShape(
144-
builder, loc, src, dstStaticShape, reassocation);
145-
}
146-
147-
template <typename OpTy>
148-
struct ReifyExpandOrCollapseShapeOp
61+
struct ReifyCollapseShapeOp
14962
: public ReifyRankedShapedTypeOpInterface::ExternalModel<
150-
ReifyExpandOrCollapseShapeOp<OpTy>, OpTy> {
63+
ReifyCollapseShapeOp, CollapseShapeOp> {
15164
LogicalResult
15265
reifyResultShapes(Operation *op, OpBuilder &b,
15366
ReifiedRankedShapedTypeDims &reifiedReturnShapes) const {
15467
auto loc = op->getLoc();
155-
auto reshapeOp = cast<OpTy>(op);
156-
reifiedReturnShapes.push_back(getReshapeOutputShapeFromInputShape(
68+
auto reshapeOp = cast<tensor::CollapseShapeOp>(op);
69+
reifiedReturnShapes.push_back(getCollapsedOutputShapeFromInputShape(
15770
b, loc, reshapeOp.getSrc(), reshapeOp.getResultType().getShape(),
15871
reshapeOp.getReassociationMaps()));
15972
return success();
@@ -162,6 +75,20 @@ struct ReifyExpandOrCollapseShapeOp
16275

16376
namespace {
16477

78+
struct ReifyExpandShapeOp
79+
: public ReifyRankedShapedTypeOpInterface::ExternalModel<ReifyExpandShapeOp,
80+
ExpandShapeOp> {
81+
LogicalResult
82+
reifyResultShapes(Operation *op, OpBuilder &b,
83+
ReifiedRankedShapedTypeDims &reifyResultShapes) const {
84+
auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
85+
SmallVector<OpFoldResult> resultShapes =
86+
expandShapeOp.getMixedOutputShape();
87+
reifyResultShapes.emplace_back(std::move(resultShapes));
88+
return success();
89+
}
90+
};
91+
16592
struct ReifyPadOp
16693
: public ReifyRankedShapedTypeOpInterface::ExternalModel<ReifyPadOp,
16794
PadOp> {
@@ -202,10 +129,8 @@ struct ReifyPadOp
202129
void mlir::tensor::registerInferTypeOpInterfaceExternalModels(
203130
DialectRegistry &registry) {
204131
registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
205-
ExpandShapeOp::attachInterface<
206-
ReifyExpandOrCollapseShapeOp<tensor::ExpandShapeOp>>(*ctx);
207-
CollapseShapeOp::attachInterface<
208-
ReifyExpandOrCollapseShapeOp<tensor::CollapseShapeOp>>(*ctx);
132+
ExpandShapeOp::attachInterface<ReifyExpandShapeOp>(*ctx);
133+
CollapseShapeOp::attachInterface<ReifyCollapseShapeOp>(*ctx);
209134
PadOp::attachInterface<ReifyPadOp>(*ctx);
210135
});
211136
}

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1732,6 +1732,10 @@ ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
17321732
return *outputShape;
17331733
}
17341734

1735+
SmallVector<OpFoldResult> ExpandShapeOp::getMixedOutputShape() {
1736+
return getMixedValues(getStaticOutputShape(), getOutputShape(), getContext());
1737+
}
1738+
17351739
void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
17361740
Type resultType, Value src,
17371741
ArrayRef<ReassociationIndices> reassociation,

mlir/lib/Dialect/Utils/StaticValueUtils.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,8 @@ bool isEqualConstantIntOrValueArray(ArrayRef<OpFoldResult> ofrs1,
191191
/// elements for which ShapedType::isDynamic is true, will be replaced by
192192
/// dynamicValues.
193193
SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
194-
ValueRange dynamicValues, Builder &b) {
194+
ValueRange dynamicValues,
195+
MLIRContext *context) {
195196
SmallVector<OpFoldResult> res;
196197
res.reserve(staticValues.size());
197198
unsigned numDynamic = 0;
@@ -200,10 +201,15 @@ SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
200201
int64_t value = staticValues[idx];
201202
res.push_back(ShapedType::isDynamic(value)
202203
? OpFoldResult{dynamicValues[numDynamic++]}
203-
: OpFoldResult{b.getI64IntegerAttr(staticValues[idx])});
204+
: OpFoldResult{IntegerAttr::get(
205+
IntegerType::get(context, 64), staticValues[idx])});
204206
}
205207
return res;
206208
}
209+
SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
210+
ValueRange dynamicValues, Builder &b) {
211+
return getMixedValues(staticValues, dynamicValues, b.getContext());
212+
}
207213

208214
/// Decompose a vector of mixed static or dynamic values into the corresponding
209215
/// pair of arrays. This is the inverse function of `getMixedValues`.

mlir/lib/Interfaces/InferTypeOpInterface.cpp

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -48,14 +48,6 @@ mlir::reifyResultShapes(OpBuilder &b, Operation *op,
4848
assert(shapedType.getRank() ==
4949
static_cast<int64_t>(reifiedReturnShapes[resultIdx].size()) &&
5050
"incorrect implementation of ReifyRankedShapedTypeOpInterface");
51-
for (int64_t dim = 0; dim < shapedType.getRank(); ++dim) {
52-
// reifyResultShapes must return:
53-
// * Attribute for static dimensions
54-
// * Value for dynamic dimensions
55-
assert(shapedType.isDynamicDim(dim) ==
56-
isa<Value>(reifiedReturnShapes[resultIdx][dim]) &&
57-
"incorrect implementation of ReifyRankedShapedTypeOpInterface");
58-
}
5951
++resultIdx;
6052
}
6153
// Assert that every shaped value result was reified.

mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -210,15 +210,12 @@ func.func @dim_reshape_expansion(%arg0 : tensor<6x5x?xf32>, %sz0: index) -> (ind
210210
%3 = tensor.dim %0, %c4 : tensor<2x3x5x4x?x7xf32>
211211
return %1, %2, %3 : index, index, index
212212
}
213-
// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)>
214213
// CHECK: func @dim_reshape_expansion
215214
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<6x5x?xf32>
216-
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
215+
// CHECK-SAME: %[[ARG1:.+]]: index
217216
// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
218217
// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
219-
// CHECK: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C2]]
220-
// CHECK: %[[D1:.+]] = affine.apply #[[MAP]]()[%[[D0]]]
221-
// CHECK: return %[[C3]], %[[C4]], %[[D1]]
218+
// CHECK: return %[[C3]], %[[C4]], %[[ARG1]]
222219

223220
// -----
224221

mlir/test/Dialect/Tensor/fold-empty-op.mlir

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ module attributes {transform.with_named_sequence} {
1010
}
1111
}
1212

13-
// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)>
1413
// CHECK: #[[$MAP2:.+]] = affine_map<()[s0] -> (s0 * 28)>
1514

1615
func.func @empty_reshape_expansion(%arg0 : index, %sz0: index) -> tensor<2x3x5x4x?x7xf32> {
@@ -19,11 +18,9 @@ func.func @empty_reshape_expansion(%arg0 : index, %sz0: index) -> tensor<2x3x5x4
1918
return %1 : tensor<2x3x5x4x?x7xf32>
2019
}
2120
// CHECK-LABEL: func @empty_reshape_expansion
22-
// CHECK-SAME: %[[ARG0:.+]]: index
23-
// CHECK: %[[OLD_INIT:.+]] = tensor.empty(%{{.*}}) : tensor<6x5x?xf32>
24-
// CHECK-NEXT: %[[DIM:.*]] = tensor.dim %[[OLD_INIT]]
25-
// CHECK-NEXT: %[[D:.+]] = affine.apply #[[$MAP]]()[%[[DIM]]]
26-
// CHECK-NEXT: %[[INIT:.+]] = tensor.empty(%[[D]])
21+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
22+
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
23+
// CHECK-NEXT: %[[INIT:.+]] = tensor.empty(%[[ARG1]])
2724
// CHECK-NEXT: return %[[INIT]]
2825

2926
func.func @empty_reshape_collapse(%arg0 : index) -> tensor<6x5x?xf32> {

0 commit comments

Comments
 (0)