Skip to content

Commit 5718460

Browse files
authored
[mlir][vector] Relax constraints on shape_cast (#136587)
`vector.shape_cast` was initially designed to be the union of collapse_shape and expand_shape. An inconsistency in the verifier allowed any shape cast when the rank did not change, which led to a strange middle ground: you could cast from shape (4,3) to (3,4), but not from (4,3) to (2,3,2). That issue was fixed (the verifier was made stricter) in #135855, but further feedback there (and polling) suggested that `vector.shape_cast` should instead allow all shape casts — behaving more like tensor.reshape than tensor.collapse_shape/tensor.expand_shape. This PR makes that simplification by relaxing the verifier.
1 parent 526ae7f commit 5718460

File tree

6 files changed

+51
-136
lines changed

6 files changed

+51
-136
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2244,18 +2244,8 @@ def Vector_ShapeCastOp :
22442244
Results<(outs AnyVectorOfAnyRank:$result)> {
22452245
let summary = "shape_cast casts between vector shapes";
22462246
let description = [{
2247-
The shape_cast operation casts between an n-D source vector shape and
2248-
a k-D result vector shape (the element type remains the same).
2249-
2250-
If reducing rank (n > k), result dimension sizes must be a product
2251-
of contiguous source dimension sizes.
2252-
If expanding rank (n < k), source dimensions must factor into a
2253-
contiguous sequence of destination dimension sizes.
2254-
Each source dim is expanded (or contiguous sequence of source dims combined)
2255-
in source dimension list order (i.e. 0 <= i < n), to produce a contiguous
2256-
sequence of result dims (or a single result dim), in result dimension list
2257-
order (i.e. 0 <= j < k). The product of all source dimension sizes and all
2258-
result dimension sizes must match.
2247+
Casts to a vector with the same number of elements, element type, and
2248+
number of scalable dimensions.
22592249

22602250
It is currently assumed that this operation does not require moving data,
22612251
and that it will be folded away before lowering vector operations.
@@ -2265,15 +2255,13 @@ def Vector_ShapeCastOp :
22652255
2-D MLIR vector to a 1-D flattened LLVM vector.shape_cast lowering to LLVM
22662256
is supported in that particular case, for now.
22672257

2268-
Example:
2258+
Examples:
22692259

22702260
```mlir
2271-
// Example casting to a lower vector rank.
2272-
%1 = vector.shape_cast %0 : vector<5x1x4x3xf32> to vector<20x3xf32>
2273-
2274-
// Example casting to a higher vector rank.
2275-
%3 = vector.shape_cast %2 : vector<10x12x8xf32> to vector<5x2x3x4x8xf32>
2261+
%1 = vector.shape_cast %0 : vector<4x3xf32> to vector<3x2x2xf32>
22762262

2263+
// Example with 2 scalable dimensions (the number of scalable dimensions must be preserved).
2264+
%3 = vector.shape_cast %2 : vector<[2]x3x[4]xi8> to vector<3x[1]x[8]xi8>
22772265
```
22782266
}];
22792267
let extraClassDeclaration = [{

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 26 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -5546,124 +5546,56 @@ void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
55465546
setResultRanges(getResult(), argRanges.front());
55475547
}
55485548

5549-
/// Returns true if each element of 'a' is equal to the product of a contiguous
5550-
/// sequence of the elements of 'b'. Returns false otherwise.
5551-
static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
5552-
unsigned rankA = a.size();
5553-
unsigned rankB = b.size();
5554-
assert(rankA < rankB);
5555-
5556-
auto isOne = [](int64_t v) { return v == 1; };
5557-
5558-
// Special-case for n-D to 0-d shape cast. 'b' must be all ones to be shape
5559-
// casted to a 0-d vector.
5560-
if (rankA == 0 && llvm::all_of(b, isOne))
5561-
return true;
5549+
LogicalResult ShapeCastOp::verify() {
55625550

5563-
unsigned i = 0;
5564-
unsigned j = 0;
5565-
while (i < rankA && j < rankB) {
5566-
int64_t dimA = a[i];
5567-
int64_t dimB = 1;
5568-
while (dimB < dimA && j < rankB)
5569-
dimB *= b[j++];
5570-
if (dimA != dimB)
5571-
break;
5572-
++i;
5551+
VectorType sourceType = getSourceVectorType();
5552+
VectorType resultType = getResultVectorType();
55735553

5574-
// Handle the case when trailing dimensions are of size 1.
5575-
// Include them into the contiguous sequence.
5576-
if (i < rankA && llvm::all_of(a.slice(i), isOne))
5577-
i = rankA;
5578-
if (j < rankB && llvm::all_of(b.slice(j), isOne))
5579-
j = rankB;
5580-
}
5554+
// Check that element type is preserved
5555+
if (sourceType.getElementType() != resultType.getElementType())
5556+
return emitOpError("has different source and result element types");
55815557

5582-
return i == rankA && j == rankB;
5583-
}
5584-
5585-
static LogicalResult verifyVectorShapeCast(Operation *op,
5586-
VectorType sourceVectorType,
5587-
VectorType resultVectorType) {
5588-
// Check that element type is the same.
5589-
if (sourceVectorType.getElementType() != resultVectorType.getElementType())
5590-
return op->emitOpError("source/result vectors must have same element type");
5591-
auto sourceShape = sourceVectorType.getShape();
5592-
auto resultShape = resultVectorType.getShape();
5593-
5594-
// Check that product of source dim sizes matches product of result dim sizes.
5595-
int64_t sourceDimProduct = std::accumulate(
5596-
sourceShape.begin(), sourceShape.end(), 1LL, std::multiplies<int64_t>{});
5597-
int64_t resultDimProduct = std::accumulate(
5598-
resultShape.begin(), resultShape.end(), 1LL, std::multiplies<int64_t>{});
5599-
if (sourceDimProduct != resultDimProduct)
5600-
return op->emitOpError("source/result number of elements must match");
5601-
5602-
// Check that expanding/contracting rank cases.
5603-
unsigned sourceRank = sourceVectorType.getRank();
5604-
unsigned resultRank = resultVectorType.getRank();
5605-
if (sourceRank < resultRank) {
5606-
if (!isValidShapeCast(sourceShape, resultShape))
5607-
return op->emitOpError("invalid shape cast");
5608-
} else if (sourceRank > resultRank) {
5609-
if (!isValidShapeCast(resultShape, sourceShape))
5610-
return op->emitOpError("invalid shape cast");
5558+
// Check that number of elements is preserved
5559+
int64_t sourceNElms = sourceType.getNumElements();
5560+
int64_t resultNElms = resultType.getNumElements();
5561+
if (sourceNElms != resultNElms) {
5562+
return emitOpError() << "has different number of elements at source ("
5563+
<< sourceNElms << ") and result (" << resultNElms
5564+
<< ")";
56115565
}
56125566

56135567
// Check that (non-)scalability is preserved
5614-
int64_t sourceNScalableDims = sourceVectorType.getNumScalableDims();
5615-
int64_t resultNScalableDims = resultVectorType.getNumScalableDims();
5568+
int64_t sourceNScalableDims = sourceType.getNumScalableDims();
5569+
int64_t resultNScalableDims = resultType.getNumScalableDims();
56165570
if (sourceNScalableDims != resultNScalableDims)
5617-
return op->emitOpError("different number of scalable dims at source (")
5618-
<< sourceNScalableDims << ") and result (" << resultNScalableDims
5619-
<< ")";
5620-
sourceVectorType.getNumDynamicDims();
5621-
5622-
return success();
5623-
}
5624-
5625-
LogicalResult ShapeCastOp::verify() {
5626-
auto sourceVectorType =
5627-
llvm::dyn_cast_or_null<VectorType>(getSource().getType());
5628-
auto resultVectorType =
5629-
llvm::dyn_cast_or_null<VectorType>(getResult().getType());
5630-
5631-
// Check if source/result are of vector type.
5632-
if (sourceVectorType && resultVectorType)
5633-
return verifyVectorShapeCast(*this, sourceVectorType, resultVectorType);
5571+
return emitOpError() << "has different number of scalable dims at source ("
5572+
<< sourceNScalableDims << ") and result ("
5573+
<< resultNScalableDims << ")";
56345574

56355575
return success();
56365576
}
56375577

56385578
OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
56395579

5580+
VectorType resultType = getType();
5581+
56405582
// No-op shape cast.
5641-
if (getSource().getType() == getType())
5583+
if (getSource().getType() == resultType)
56425584
return getSource();
56435585

5644-
VectorType resultType = getType();
5645-
5646-
// Canceling shape casts.
5586+
// Y = shape_cast(shape_cast(X)))
5587+
// -> X, if X and Y have same type
5588+
// -> shape_cast(X) otherwise.
56475589
if (auto otherOp = getSource().getDefiningOp<ShapeCastOp>()) {
5648-
5649-
// Only allows valid transitive folding (expand/collapse dimensions).
56505590
VectorType srcType = otherOp.getSource().getType();
56515591
if (resultType == srcType)
56525592
return otherOp.getSource();
5653-
if (srcType.getRank() < resultType.getRank()) {
5654-
if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
5655-
return {};
5656-
} else if (srcType.getRank() > resultType.getRank()) {
5657-
if (!isValidShapeCast(resultType.getShape(), srcType.getShape()))
5658-
return {};
5659-
} else {
5660-
return {};
5661-
}
56625593
setOperand(otherOp.getSource());
56635594
return getResult();
56645595
}
56655596

5666-
// Cancelling broadcast and shape cast ops.
5597+
// Y = shape_cast(broadcast(X))
5598+
// -> X, if X and Y have same type
56675599
if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
56685600
if (bcastOp.getSourceType() == resultType)
56695601
return bcastOp.getSource();

mlir/test/Dialect/Vector/CPU/X86/vector-transpose-lowering.mlir

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,7 @@ func.func @transpose4x8xf32(%arg0: vector<4x8xf32>) -> vector<8x4xf32> {
2626
// CHECK-NEXT: vector.insert {{.*}}[1]
2727
// CHECK-NEXT: vector.insert {{.*}}[2]
2828
// CHECK-NEXT: vector.insert {{.*}}[3]
29-
// CHECK-NEXT: vector.shape_cast {{.*}} vector<4x8xf32> to vector<32xf32>
30-
// CHECK-NEXT: vector.shape_cast {{.*}} vector<32xf32> to vector<8x4xf32>
29+
// CHECK-NEXT: vector.shape_cast {{.*}} vector<4x8xf32> to vector<8x4xf32>
3130
%0 = vector.transpose %arg0, [1, 0] : vector<4x8xf32> to vector<8x4xf32>
3231
return %0 : vector<8x4xf32>
3332
}
@@ -54,8 +53,7 @@ func.func @transpose021_1x4x8xf32(%arg0: vector<1x4x8xf32>) -> vector<1x8x4xf32>
5453
// CHECK-NEXT: vector.insert {{.*}}[1]
5554
// CHECK-NEXT: vector.insert {{.*}}[2]
5655
// CHECK-NEXT: vector.insert {{.*}}[3]
57-
// CHECK-NEXT: vector.shape_cast {{.*}} vector<4x8xf32> to vector<32xf32>
58-
// CHECK-NEXT: vector.shape_cast {{.*}} vector<32xf32> to vector<1x8x4xf32>
56+
// CHECK-NEXT: vector.shape_cast {{.*}} vector<4x8xf32> to vector<1x8x4xf32>
5957
%0 = vector.transpose %arg0, [0, 2, 1] : vector<1x4x8xf32> to vector<1x8x4xf32>
6058
return %0 : vector<1x8x4xf32>
6159
}

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -977,10 +977,9 @@ func.func @insert_no_fold_scalar_to_0d(%v: vector<f32>) -> vector<f32> {
977977

978978
// -----
979979

980-
// CHECK-LABEL: dont_fold_expand_collapse
981-
// CHECK: %[[A:.*]] = vector.shape_cast %{{.*}} : vector<1x1x64xf32> to vector<1x1x8x8xf32>
982-
// CHECK: %[[B:.*]] = vector.shape_cast %{{.*}} : vector<1x1x8x8xf32> to vector<8x8xf32>
983-
// CHECK: return %[[B]] : vector<8x8xf32>
980+
// CHECK-LABEL: fold_expand_collapse
981+
// CHECK: %[[A:.*]] = vector.shape_cast %{{.*}} : vector<1x1x64xf32> to vector<8x8xf32>
982+
// CHECK: return %[[A]] : vector<8x8xf32>
984983
func.func @fold_expand_collapse(%arg0: vector<1x1x64xf32>) -> vector<8x8xf32> {
985984
%0 = vector.shape_cast %arg0 : vector<1x1x64xf32> to vector<1x1x8x8xf32>
986985
%1 = vector.shape_cast %0 : vector<1x1x8x8xf32> to vector<8x8xf32>

mlir/test/Dialect/Vector/invalid.mlir

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1165,34 +1165,21 @@ func.func @cannot_print_string_with_source_set(%vec: vector<[4]xf32>) {
11651165

11661166
// -----
11671167

1168+
11681169
func.func @shape_cast_wrong_element_type(%arg0 : vector<5x1x3x2xf32>) {
1169-
// expected-error@+1 {{op source/result vectors must have same element type}}
1170+
// expected-error@+1 {{'vector.shape_cast' op has different source and result element types}}
11701171
%0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<15x2xi32>
11711172
}
11721173

11731174
// -----
11741175

11751176
func.func @shape_cast_wrong_num_elements(%arg0 : vector<5x1x3x2xf32>) {
1176-
// expected-error@+1 {{op source/result number of elements must match}}
1177+
// expected-error@+1 {{'vector.shape_cast' op has different number of elements at source (30) and result (20)}}
11771178
%0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<10x2xf32>
11781179
}
11791180

11801181
// -----
11811182

1182-
func.func @shape_cast_invalid_rank_reduction(%arg0 : vector<5x1x3x2xf32>) {
1183-
// expected-error@+1 {{invalid shape cast}}
1184-
%0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<2x15xf32>
1185-
}
1186-
1187-
// -----
1188-
1189-
func.func @shape_cast_invalid_rank_expansion(%arg0 : vector<15x2xf32>) {
1190-
// expected-error@+1 {{invalid shape cast}}
1191-
%0 = vector.shape_cast %arg0 : vector<15x2xf32> to vector<5x2x3x1xf32>
1192-
}
1193-
1194-
// -----
1195-
11961183
func.func @shape_cast_scalability_flag_is_dropped(%arg0 : vector<15x[2]xf32>) {
11971184
// expected-error@+1 {{different number of scalable dims at source (1) and result (0)}}
11981185
%0 = vector.shape_cast %arg0 : vector<15x[2]xf32> to vector<30xf32>

mlir/test/Dialect/Vector/ops.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -564,6 +564,17 @@ func.func @shape_cast(%arg0 : vector<5x1x3x2xf32>,
564564
return %0, %1, %2, %3 : vector<15x2xf32>, vector<8xf32>, vector<16xf32>, vector<16x1xf32>
565565
}
566566

567+
// A vector.shape_cast can cast between any 2 shapes as long as the
568+
// number of elements is preserved. For those familiar with the tensor
569+
// dialect: this behaviour is like the tensor.reshape operation, i.e.
570+
// less restrictive than tensor.collapse_shape and tensor.expand_shape
571+
// CHECK-LABEL: @shape_cast_general_reshape
572+
func.func @shape_cast_general_reshape(%arg0 : vector<2x3xf32>) -> (vector<3x1x2xf32>) {
573+
// CHECK: vector.shape_cast %{{.*}} : vector<2x3xf32> to vector<3x1x2xf32>
574+
%0 = vector.shape_cast %arg0 : vector<2x3xf32> to vector<3x1x2xf32>
575+
return %0 : vector<3x1x2xf32>
576+
}
577+
567578
// CHECK-LABEL: @shape_cast_0d
568579
func.func @shape_cast_0d(%arg0 : vector<1x1x1x1xf32>) -> (vector<1x1x1x1xf32>) {
569580

0 commit comments

Comments
 (0)