Skip to content

Commit 11fd6c3

Browse files
committed
squash commits to simplify rebase
1 parent deae5ee commit 11fd6c3

File tree

3 files changed

+191
-12
lines changed

3 files changed

+191
-12
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 80 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5575,6 +5575,34 @@ LogicalResult ShapeCastOp::verify() {
55755575
return success();
55765576
}
55775577

5578+
namespace {
5579+
5580+
/// Return true if `transpose` does not permute a pair of non-unit dims.
5581+
/// By `order preserving` we mean that the flattened versions of the input and
5582+
/// output vectors are (numerically) identical. In other words `transpose` is
5583+
/// effectively a shape cast.
5584+
bool isOrderPreserving(TransposeOp transpose) {
5585+
ArrayRef<int64_t> permutation = transpose.getPermutation();
5586+
VectorType sourceType = transpose.getSourceVectorType();
5587+
ArrayRef<int64_t> inShape = sourceType.getShape();
5588+
ArrayRef<bool> inDimIsScalable = sourceType.getScalableDims();
5589+
auto isNonScalableUnitDim = [&](int64_t dim) {
5590+
return inShape[dim] == 1 && !inDimIsScalable[dim];
5591+
};
5592+
int64_t current = 0;
5593+
for (auto p : permutation) {
5594+
if (!isNonScalableUnitDim(p)) {
5595+
if (p < current) {
5596+
return false;
5597+
}
5598+
current = p;
5599+
}
5600+
}
5601+
return true;
5602+
}
5603+
5604+
} // namespace
5605+
55785606
OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
55795607

55805608
VectorType resultType = getType();
@@ -5583,17 +5611,32 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
55835611
if (getSource().getType() == resultType)
55845612
return getSource();
55855613

5586-
// Y = shape_cast(shape_cast(X)))
5587-
// -> X, if X and Y have same type
5588-
// -> shape_cast(X) otherwise.
5589-
if (auto otherOp = getSource().getDefiningOp<ShapeCastOp>()) {
5590-
VectorType srcType = otherOp.getSource().getType();
5591-
if (resultType == srcType)
5592-
return otherOp.getSource();
5593-
setOperand(otherOp.getSource());
5614+
// shape_cast(shape_cast(x)) -> shape_cast(x)
5615+
if (auto precedingShapeCast = getSource().getDefiningOp<ShapeCastOp>()) {
5616+
setOperand(precedingShapeCast.getSource());
55945617
return getResult();
55955618
}
55965619

5620+
// shape_cast(transpose(x)) -> shape_cast(x)
5621+
if (auto transpose = getSource().getDefiningOp<TransposeOp>()) {
5622+
// This folder does
5623+
// shape_cast(transpose) -> shape_cast
5624+
// But another pattern, ConvertIllegalShapeCastOpsToTransposes, does
5625+
// shape_cast -> shape_cast(transpose)
5626+
// i.e. the complete opposite. When paired, these 2 patterns can cause
5627+
// infinite cycles in pattern rewriting.
5628+
// ConvertIllegalShapeCastOpsToTransposes only matches on scalable
5629+
// vectors, so by disabling this folder for scalable vectors the
5630+
// cycle is avoided.
5631+
// TODO: Check if ConvertIllegalShapeCastOpsToTransposes is
5632+
// still needed. If it's not, then we can fold here.
5633+
if (!transpose.getType().isScalable() && isOrderPreserving(transpose)) {
5634+
setOperand(transpose.getVector());
5635+
return getResult();
5636+
}
5637+
return {};
5638+
}
5639+
55975640
// Y = shape_cast(broadcast(X))
55985641
// -> X, if X and Y have same type
55995642
if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
@@ -5619,7 +5662,7 @@ namespace {
56195662
/// Helper function that computes a new vector type based on the input vector
56205663
/// type by removing the trailing one dims:
56215664
///
5622-
/// vector<4x1x1xi1> --> vector<4x1>
5665+
/// vector<4x1x1xi1> --> vector<4x1xi1>
56235666
///
56245667
static VectorType trimTrailingOneDims(VectorType oldType) {
56255668
ArrayRef<int64_t> oldShape = oldType.getShape();
@@ -6086,6 +6129,32 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
60866129
}
60876130
};
60886131

6132+
/// Folds transpose(shape_cast) into a new shape_cast.
6133+
class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
6134+
public:
6135+
using OpRewritePattern::OpRewritePattern;
6136+
6137+
LogicalResult matchAndRewrite(TransposeOp transposeOp,
6138+
PatternRewriter &rewriter) const override {
6139+
auto shapeCastOp =
6140+
transposeOp.getVector().getDefiningOp<vector::ShapeCastOp>();
6141+
if (!shapeCastOp)
6142+
return failure();
6143+
if (!isOrderPreserving(transposeOp))
6144+
return failure();
6145+
6146+
VectorType resultType = transposeOp.getType();
6147+
6148+
// We don't need to check isValidShapeCast at this point, because it is
6149+
// guaranteed that merging the transpose into the the shape_cast is a valid
6150+
// shape_cast, because the transpose just inserts/removes ones.
6151+
6152+
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transposeOp, resultType,
6153+
shapeCastOp.getSource());
6154+
return success();
6155+
}
6156+
};
6157+
60896158
/// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is
60906159
/// 'order preserving', where 'order preserving' means the flattened
60916160
/// inputs and outputs of the transpose have identical (numerical) values.
@@ -6184,8 +6253,8 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
61846253

61856254
void vector::TransposeOp::getCanonicalizationPatterns(
61866255
RewritePatternSet &results, MLIRContext *context) {
6187-
results.add<FoldTransposeCreateMask, TransposeFolder, FoldTransposeSplat,
6188-
FoldTransposeBroadcast>(context);
6256+
results.add<FoldTransposeCreateMask, FoldTransposeShapeCast, TransposeFolder,
6257+
FoldTransposeSplat, FoldTransposeBroadcast>(context);
61896258
}
61906259

61916260
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ func.func @create_vector_mask_to_constant_mask() -> (vector<4x3xi1>) {
88
%0 = vector.create_mask %c3, %c2 : vector<4x3xi1>
99
return %0 : vector<4x3xi1>
1010
}
11+
1112
// -----
1213

1314
// CHECK-LABEL: create_scalable_vector_mask_to_constant_mask
@@ -3061,7 +3062,6 @@ func.func @insert_vector_poison(%a: vector<4x8xf32>)
30613062
return %1 : vector<4x8xf32>
30623063
}
30633064

3064-
30653065
// -----
30663066

30673067
// CHECK-LABEL: @insert_scalar_poison_idx

mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,3 +137,113 @@ func.func @negative_broadcast_transpose_021(%arg0 : vector<3x1x3xi8>) -> vector<
137137
return %1 : vector<3x3x3xi8>
138138
}
139139

140+
141+
// -----
142+
143+
// Test of FoldTransposeShapeCast
144+
// In this test, the permutation maps the non-unit dimensions (1 and 2) as follows:
145+
// 1 -> 0
146+
// 2 -> 4
147+
// Because 0 < 4, this permutation is order preserving and effectively a shape_cast.
148+
// CHECK-LABEL: @transpose_shape_cast
149+
// CHECK-SAME: %[[ARG:.*]]: vector<1x4x4x1x1xi8>) -> vector<4x4xi8> {
150+
// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
151+
// CHECK-SAME: vector<1x4x4x1x1xi8> to vector<4x4xi8>
152+
// CHECK: return %[[SHAPE_CAST]] : vector<4x4xi8>
153+
func.func @transpose_shape_cast(%arg : vector<1x4x4x1x1xi8>) -> vector<4x4xi8> {
154+
%0 = vector.transpose %arg, [1, 0, 3, 4, 2]
155+
: vector<1x4x4x1x1xi8> to vector<4x1x1x1x4xi8>
156+
%1 = vector.shape_cast %0 : vector<4x1x1x1x4xi8> to vector<4x4xi8>
157+
return %1 : vector<4x4xi8>
158+
}
159+
160+
// -----
161+
162+
// Test of FoldTransposeShapeCast
163+
// In this test, the mapping of non-unit dimensions (1 and 2) is as follows:
164+
// 1 -> 2
165+
// 2 -> 1
166+
// As this is not increasing (2 > 1), this transpose is not order
167+
// preserving and cannot be treated as a shape_cast.
168+
// CHECK-LABEL: @negative_transpose_shape_cast
169+
// CHECK-SAME: %[[ARG:.*]]: vector<1x4x4x1xi8>) -> vector<4x4xi8> {
170+
// CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[ARG]]
171+
// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[TRANSPOSE]]
172+
// CHECK: return %[[SHAPE_CAST]] : vector<4x4xi8>
173+
func.func @negative_transpose_shape_cast(%arg : vector<1x4x4x1xi8>) -> vector<4x4xi8> {
174+
%0 = vector.transpose %arg, [0, 2, 1, 3]
175+
: vector<1x4x4x1xi8> to vector<1x4x4x1xi8>
176+
%1 = vector.shape_cast %0 : vector<1x4x4x1xi8> to vector<4x4xi8>
177+
return %1 : vector<4x4xi8>
178+
}
179+
180+
// -----
181+
182+
// Test of FoldTransposeShapeCast
183+
// Currently the conversion shape_cast(transpose) -> shape_cast is disabled for
184+
// scalable vectors because of bad interaction with ConvertIllegalShapeCastOpsToTransposes
185+
// CHECK-LABEL: @negative_transpose_shape_cast_scalable
186+
// CHECK: vector.transpose
187+
// CHECK: vector.shape_cast
188+
func.func @negative_transpose_shape_cast_scalable(%arg : vector<[4]x1xi8>) -> vector<[4]xi8> {
189+
%0 = vector.transpose %arg, [1, 0] : vector<[4]x1xi8> to vector<1x[4]xi8>
190+
%1 = vector.shape_cast %0 : vector<1x[4]xi8> to vector<[4]xi8>
191+
return %1 : vector<[4]xi8>
192+
}
193+
194+
// -----
195+
196+
// Test of shape_cast folding.
197+
// The conversion transpose(shape_cast) -> shape_cast is not disabled for scalable
198+
// vectors.
199+
// CHECK-LABEL: @shape_cast_transpose_scalable
200+
// CHECK: vector.shape_cast
201+
// CHECK-SAME: vector<[4]xi8> to vector<[4]x1xi8>
202+
func.func @shape_cast_transpose_scalable(%arg : vector<[4]xi8>) -> vector<[4]x1xi8> {
203+
%0 = vector.shape_cast %arg : vector<[4]xi8> to vector<1x[4]xi8>
204+
%1 = vector.transpose %0, [1, 0] : vector<1x[4]xi8> to vector<[4]x1xi8>
205+
return %1 : vector<[4]x1xi8>
206+
}
207+
208+
// -----
209+
210+
// Test of shape_cast folding.
211+
// A transpose that is 'order preserving' can be treated like a shape_cast.
212+
// CHECK-LABEL: @shape_cast_transpose
213+
// CHECK-SAME: %[[ARG:.*]]: vector<2x3x1x1xi8>) -> vector<6x1x1xi8> {
214+
// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
215+
// CHECK-SAME: vector<2x3x1x1xi8> to vector<6x1x1xi8>
216+
// CHECK: return %[[SHAPE_CAST]] : vector<6x1x1xi8>
217+
func.func @shape_cast_transpose(%arg : vector<2x3x1x1xi8>) -> vector<6x1x1xi8> {
218+
%0 = vector.shape_cast %arg : vector<2x3x1x1xi8> to vector<6x1x1xi8>
219+
%1 = vector.transpose %0, [0, 2, 1]
220+
: vector<6x1x1xi8> to vector<6x1x1xi8>
221+
return %1 : vector<6x1x1xi8>
222+
}
223+
224+
// -----
225+
226+
// Test of shape_cast folding.
227+
// Scalable dimensions should be treated as non-unit dimensions.
228+
// CHECK-LABEL: @shape_cast_transpose_scalable
229+
// CHECK: vector.shape_cast
230+
// CHECK: vector.transpose
231+
func.func @shape_cast_transpose_scalable_unit(%arg : vector<[1]x4x1xi8>) -> vector<4x[1]xi8> {
232+
%0 = vector.shape_cast %arg : vector<[1]x4x1xi8> to vector<[1]x4xi8>
233+
%1 = vector.transpose %0, [1, 0] : vector<[1]x4xi8> to vector<4x[1]xi8>
234+
return %1 : vector<4x[1]xi8>
235+
}
236+
237+
// -----
238+
239+
// Test of shape_cast (not) folding.
240+
// CHECK-LABEL: @negative_shape_cast_transpose
241+
// CHECK-SAME: %[[ARG:.*]]: vector<6xi8>) -> vector<2x3xi8> {
242+
// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
243+
// CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[SHAPE_CAST]]
244+
// CHECK: return %[[TRANSPOSE]] : vector<2x3xi8>
245+
func.func @negative_shape_cast_transpose(%arg : vector<6xi8>) -> vector<2x3xi8> {
246+
%0 = vector.shape_cast %arg : vector<6xi8> to vector<3x2xi8>
247+
%1 = vector.transpose %0, [1, 0] : vector<3x2xi8> to vector<2x3xi8>
248+
return %1 : vector<2x3xi8>
249+
}

0 commit comments

Comments
 (0)