Skip to content

Commit 6d3ebd8

Browse files
[mlir][affine] Allow memref.cast in isDimOpValidSymbol (#74401)
`isDimOpValidSymbol` is used during the verification of `affine.for` ops. It is used to check if LB/UB values are valid symbols. This change adds support for `memref.cast`, which can be skipped over if it is a ranked -> ranked cast. This change fixes `mlir/test/Transforms/canonicalize.mlir`, which used to fail when verifying the IR after each pattern application (#74270). In this test case, a pattern that folds dynamic offsets/sizes/strides to static ones is applied. This pattern inserts a trivial `memref.cast` that can be folded away. This folding happens after the pattern application, so the IR fails to verify after applying the offsets/sizes/strides canonicalization pattern. Note: The verifier of `affine.for` violates MLIR guidelines. Only local properties of an op should be verified. The verifier should not inspect the defining ops of operands. (This would mean that constraints such as "operand is a valid affine symbol" cannot be verified.)
1 parent 5900014 commit 6d3ebd8

File tree

1 file changed

+12
-1
lines changed

1 file changed

+12
-1
lines changed

mlir/lib/Dialect/Affine/IR/AffineOps.cpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -354,8 +354,19 @@ static bool isDimOpValidSymbol(ShapedDimOpInterface dimOp, Region *region) {
354354
if (!index.has_value())
355355
return false;
356356

357+
// Skip over all memref.cast ops (if any).
358+
Operation *op = dimOp.getShapedValue().getDefiningOp();
359+
while (auto castOp = dyn_cast<memref::CastOp>(op)) {
360+
// Bail on unranked memrefs.
361+
if (isa<UnrankedMemRefType>(castOp.getSource().getType()))
362+
return false;
363+
op = castOp.getSource().getDefiningOp();
364+
if (!op)
365+
return false;
366+
}
367+
357368
int64_t i = index.value();
358-
return TypeSwitch<Operation *, bool>(dimOp.getShapedValue().getDefiningOp())
369+
return TypeSwitch<Operation *, bool>(op)
359370
.Case<memref::ViewOp, memref::SubViewOp, memref::AllocOp>(
360371
[&](auto op) { return isMemRefSizeValidSymbol(op, i, region); })
361372
.Default([](Operation *) { return false; });

0 commit comments

Comments
 (0)