Skip to content

Commit 469b9cb

Browse files
authored
[mlir][VectorOps] Don't fold extract chains that include dynamic indices (#68333)
This is not yet supported and previously led to a confusing crash where an extract op with a kDynamic marker, but no dynamic positions was created. The verifier has also been updated to check for this, and hint at where the problem is likely to be.
1 parent 4cb6c1c commit 469b9cb

File tree

2 files changed

+23
-0
lines changed

2 files changed

+23
-0
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

+11
Original file line numberDiff line numberDiff line change
@@ -1244,6 +1244,14 @@ bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
12441244
}
12451245

12461246
LogicalResult vector::ExtractOp::verify() {
1247+
// Note: This check must come before getMixedPosition() to prevent a crash.
1248+
auto dynamicMarkersCount =
1249+
llvm::count_if(getStaticPosition(), ShapedType::isDynamic);
1250+
if (static_cast<size_t>(dynamicMarkersCount) != getDynamicPosition().size())
1251+
return emitOpError(
1252+
"mismatch between dynamic and static positions (kDynamic marker but no "
1253+
"corresponding dynamic position) -- this can only happen due to an "
1254+
"incorrect fold/rewrite");
12471255
auto position = getMixedPosition();
12481256
if (position.size() > static_cast<unsigned>(getSourceVectorType().getRank()))
12491257
return emitOpError(
@@ -1285,6 +1293,9 @@ static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
12851293
globalPosition.append(extrPos.rbegin(), extrPos.rend());
12861294
while (ExtractOp nextOp = currentOp.getVector().getDefiningOp<ExtractOp>()) {
12871295
currentOp = nextOp;
1296+
// TODO: Canonicalization for dynamic position not implemented yet.
1297+
if (currentOp.hasDynamicPosition())
1298+
return failure();
12881299
ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
12891300
globalPosition.append(extrPos.rbegin(), extrPos.rend());
12901301
}

mlir/test/Dialect/Vector/canonicalize.mlir

+12
Original file line numberDiff line numberDiff line change
@@ -1693,6 +1693,18 @@ func.func @extract_insert_chain(%a: vector<2x16xf32>, %b: vector<12x8x16xf32>, %
16931693

16941694
// -----
16951695

1696+
// CHECK-LABEL: extract_from_extract_chain_should_not_fold_dynamic_extracts
1697+
// CHECK-SAME: (%[[VEC:.*]]: vector<2x4xf32>, %[[IDX:.*]]: index)
1698+
// CHECK: %[[A:.*]] = vector.extract %[[VEC]][%[[IDX]]] : vector<4xf32> from vector<2x4xf32>
1699+
// CHECK: %[[B:.*]] = vector.extract %[[A]][1] : f32 from vector<4xf32>
1700+
func.func @extract_from_extract_chain_should_not_fold_dynamic_extracts(%v: vector<2x4xf32>, %index: index) -> f32 {
1701+
%0 = vector.extract %v[%index] : vector<4xf32> from vector<2x4xf32>
1702+
%1 = vector.extract %0[1] : f32 from vector<4xf32>
1703+
return %1 : f32
1704+
}
1705+
1706+
// -----
1707+
16961708
// CHECK-LABEL: extract_extract_strided2
16971709
// CHECK-SAME: %[[A:.*]]: vector<2x4xf32>
16981710
// CHECK: %[[V:.*]] = vector.extract %[[A]][1] : vector<4xf32> from vector<2x4xf32>

0 commit comments

Comments
 (0)