Skip to content

Commit b5a2101

Browse files
[mlir][Tensor] Rework ReifyRankedShapedTypeInterface implementation for tensor.expand_shape op.
The op carries the output-shape directly. This can be used directly. Also adds a method to get the shape as a `SmallVector<OpFoldResult>`. Signed-off-by: MaheshRavishankar <[email protected]>
1 parent e2766b2 commit b5a2101

File tree

9 files changed

+43
-29
lines changed

9 files changed

+43
-29
lines changed

mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1160,6 +1160,9 @@ def Tensor_ExpandShapeOp : Tensor_ReassociativeReshapeOp<"expand_shape"> {
11601160
let extraClassDeclaration = commonExtraClassDeclaration # [{
11611161
int64_t getCorrespondingSourceDim(int64_t resultDim);
11621162

1163+
// Return output shape as mixes static/dynamic shapes.
1164+
SmallVector<OpFoldResult> getMixedOutputShape();
1165+
11631166
// Infer the output shape for a tensor.expand_shape when it is possible
11641167
// to do so.
11651168
static FailureOr<SmallVector<OpFoldResult>> inferOutputShape(

mlir/include/mlir/Dialect/Utils/StaticValueUtils.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,9 @@ bool isEqualConstantIntOrValueArray(ArrayRef<OpFoldResult> ofrs1,
125125
/// Return a vector of OpFoldResults with the same size a staticValues, but
126126
/// all elements for which ShapedType::isDynamic is true, will be replaced by
127127
/// dynamicValues.
128+
SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
129+
ValueRange dynamicValues,
130+
MLIRContext *context);
128131
SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
129132
ValueRange dynamicValues, Builder &b);
130133

mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
2121
#include "mlir/Dialect/SCF/IR/SCF.h"
2222
#include "mlir/Dialect/Tensor/IR/Tensor.h"
23+
#include "mlir/Interfaces/CastInterfaces.h"
2324
#include "mlir/Interfaces/InferTypeOpInterface.h"
2425
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
2526

mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -144,15 +144,14 @@ getReshapeOutputShapeFromInputShape(OpBuilder &builder, Location loc, Value src,
144144
builder, loc, src, dstStaticShape, reassocation);
145145
}
146146

147-
template <typename OpTy>
148-
struct ReifyExpandOrCollapseShapeOp
147+
struct ReifyCollapseShapeOp
149148
: public ReifyRankedShapedTypeOpInterface::ExternalModel<
150-
ReifyExpandOrCollapseShapeOp<OpTy>, OpTy> {
149+
ReifyCollapseShapeOp, CollapseShapeOp> {
151150
LogicalResult
152151
reifyResultShapes(Operation *op, OpBuilder &b,
153152
ReifiedRankedShapedTypeDims &reifiedReturnShapes) const {
154153
auto loc = op->getLoc();
155-
auto reshapeOp = cast<OpTy>(op);
154+
auto reshapeOp = cast<tensor::CollapseShapeOp>(op);
156155
reifiedReturnShapes.push_back(getReshapeOutputShapeFromInputShape(
157156
b, loc, reshapeOp.getSrc(), reshapeOp.getResultType().getShape(),
158157
reshapeOp.getReassociationMaps()));
@@ -162,6 +161,20 @@ struct ReifyExpandOrCollapseShapeOp
162161

163162
namespace {
164163

164+
struct ReifyExpandShapeOp
165+
: public ReifyRankedShapedTypeOpInterface::ExternalModel<ReifyExpandShapeOp,
166+
ExpandShapeOp> {
167+
LogicalResult
168+
reifyResultShapes(Operation *op, OpBuilder &b,
169+
ReifiedRankedShapedTypeDims &reifyResultShapes) const {
170+
auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
171+
SmallVector<OpFoldResult> resultShapes =
172+
expandShapeOp.getMixedOutputShape();
173+
reifyResultShapes.emplace_back(std::move(resultShapes));
174+
return success();
175+
}
176+
};
177+
165178
struct ReifyPadOp
166179
: public ReifyRankedShapedTypeOpInterface::ExternalModel<ReifyPadOp,
167180
PadOp> {
@@ -202,10 +215,8 @@ struct ReifyPadOp
202215
void mlir::tensor::registerInferTypeOpInterfaceExternalModels(
203216
DialectRegistry &registry) {
204217
registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
205-
ExpandShapeOp::attachInterface<
206-
ReifyExpandOrCollapseShapeOp<tensor::ExpandShapeOp>>(*ctx);
207-
CollapseShapeOp::attachInterface<
208-
ReifyExpandOrCollapseShapeOp<tensor::CollapseShapeOp>>(*ctx);
218+
ExpandShapeOp::attachInterface<ReifyExpandShapeOp>(*ctx);
219+
CollapseShapeOp::attachInterface<ReifyCollapseShapeOp>(*ctx);
209220
PadOp::attachInterface<ReifyPadOp>(*ctx);
210221
});
211222
}

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1675,6 +1675,10 @@ ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
16751675
return *outputShape;
16761676
}
16771677

1678+
SmallVector<OpFoldResult> ExpandShapeOp::getMixedOutputShape() {
1679+
return getMixedValues(getStaticOutputShape(), getOutputShape(), getContext());
1680+
}
1681+
16781682
void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
16791683
Type resultType, Value src,
16801684
ArrayRef<ReassociationIndices> reassociation,

mlir/lib/Dialect/Utils/StaticValueUtils.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,8 @@ bool isEqualConstantIntOrValueArray(ArrayRef<OpFoldResult> ofrs1,
177177
/// elements for which ShapedType::isDynamic is true, will be replaced by
178178
/// dynamicValues.
179179
SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
180-
ValueRange dynamicValues, Builder &b) {
180+
ValueRange dynamicValues,
181+
MLIRContext *context) {
181182
SmallVector<OpFoldResult> res;
182183
res.reserve(staticValues.size());
183184
unsigned numDynamic = 0;
@@ -186,10 +187,15 @@ SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
186187
int64_t value = staticValues[idx];
187188
res.push_back(ShapedType::isDynamic(value)
188189
? OpFoldResult{dynamicValues[numDynamic++]}
189-
: OpFoldResult{b.getI64IntegerAttr(staticValues[idx])});
190+
: OpFoldResult{IntegerAttr::get(
191+
IntegerType::get(context, 64), staticValues[idx])});
190192
}
191193
return res;
192194
}
195+
SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
196+
ValueRange dynamicValues, Builder &b) {
197+
return getMixedValues(staticValues, dynamicValues, b.getContext());
198+
}
193199

194200
/// Decompose a vector of mixed static or dynamic values into the corresponding
195201
/// pair of arrays. This is the inverse function of `getMixedValues`.

mlir/lib/Interfaces/InferTypeOpInterface.cpp

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -48,14 +48,6 @@ mlir::reifyResultShapes(OpBuilder &b, Operation *op,
4848
assert(shapedType.getRank() ==
4949
static_cast<int64_t>(reifiedReturnShapes[resultIdx].size()) &&
5050
"incorrect implementation of ReifyRankedShapedTypeOpInterface");
51-
for (int64_t dim = 0; dim < shapedType.getRank(); ++dim) {
52-
// reifyResultShapes must return:
53-
// * Attribute for static dimensions
54-
// * Value for dynamic dimensions
55-
assert(shapedType.isDynamicDim(dim) ==
56-
reifiedReturnShapes[resultIdx][dim].is<Value>() &&
57-
"incorrect implementation of ReifyRankedShapedTypeOpInterface");
58-
}
5951
++resultIdx;
6052
}
6153
// Assert that every shaped value result was reified.

mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -210,15 +210,12 @@ func.func @dim_reshape_expansion(%arg0 : tensor<6x5x?xf32>, %sz0: index) -> (ind
210210
%3 = tensor.dim %0, %c4 : tensor<2x3x5x4x?x7xf32>
211211
return %1, %2, %3 : index, index, index
212212
}
213-
// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)>
214213
// CHECK: func @dim_reshape_expansion
215214
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<6x5x?xf32>
216-
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
215+
// CHECK-SAME: %[[ARG1:.+]]: index
217216
// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
218217
// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
219-
// CHECK: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C2]]
220-
// CHECK: %[[D1:.+]] = affine.apply #[[MAP]]()[%[[D0]]]
221-
// CHECK: return %[[C3]], %[[C4]], %[[D1]]
218+
// CHECK: return %[[C3]], %[[C4]], %[[ARG1]]
222219

223220
// -----
224221

mlir/test/Dialect/Tensor/fold-empty-op.mlir

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ module attributes {transform.with_named_sequence} {
1010
}
1111
}
1212

13-
// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)>
1413
// CHECK: #[[$MAP2:.+]] = affine_map<()[s0] -> (s0 * 28)>
1514

1615
func.func @empty_reshape_expansion(%arg0 : index, %sz0: index) -> tensor<2x3x5x4x?x7xf32> {
@@ -19,11 +18,9 @@ func.func @empty_reshape_expansion(%arg0 : index, %sz0: index) -> tensor<2x3x5x4
1918
return %1 : tensor<2x3x5x4x?x7xf32>
2019
}
2120
// CHECK-LABEL: func @empty_reshape_expansion
22-
// CHECK-SAME: %[[ARG0:.+]]: index
23-
// CHECK: %[[OLD_INIT:.+]] = tensor.empty(%{{.*}}) : tensor<6x5x?xf32>
24-
// CHECK-NEXT: %[[DIM:.*]] = tensor.dim %[[OLD_INIT]]
25-
// CHECK-NEXT: %[[D:.+]] = affine.apply #[[$MAP]]()[%[[DIM]]]
26-
// CHECK-NEXT: %[[INIT:.+]] = tensor.empty(%[[D]])
21+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
22+
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
23+
// CHECK-NEXT: %[[INIT:.+]] = tensor.empty(%[[ARG1]])
2724
// CHECK-NEXT: return %[[INIT]]
2825

2926
func.func @empty_reshape_collapse(%arg0 : index) -> tensor<6x5x?xf32> {

0 commit comments

Comments
 (0)