Skip to content

Commit 051612c

Browse files
authored
[mlir][ValueBounds] memref.dim and tensor.dim are always positive (#122804)
Add the constraint that the length of a memref or tensor dimension is always non-negative (at least 0) even if we don't know which dimension we're querying the length of.
1 parent 8ce81f1 commit 051612c

File tree

4 files changed

+24
-0
lines changed

4 files changed

+24
-0
lines changed

mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ struct DimOpInterface
5151
auto dimOp = cast<DimOp>(op);
5252
assert(value == dimOp.getResult() && "invalid value");
5353

54+
cstr.bound(value) >= 0;
5455
auto constIndex = dimOp.getConstantIndex();
5556
if (!constIndex.has_value())
5657
return;

mlir/lib/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ struct DimOpInterface
3838
auto dimOp = cast<DimOp>(op);
3939
assert(value == dimOp.getResult() && "invalid value");
4040

41+
cstr.bound(value) >= 0;
4142
auto constIndex = dimOp.getConstantIndex();
4243
if (!constIndex.has_value())
4344
return;

mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,17 @@ func.func @memref_dim(%m: memref<?xf32>) -> index {
5252

5353
// -----
5454

55+
// CHECK-LABEL: func @memref_dim_all_positive(
56+
func.func @memref_dim_all_positive(%m: memref<?xf32>, %x: index) {
57+
%c0 = arith.constant 0 : index
58+
%0 = memref.dim %m, %x : memref<?xf32>
59+
// expected-remark @below{{true}}
60+
"test.compare"(%0, %c0) {cmp = "GE"} : (index, index) -> ()
61+
return
62+
}
63+
64+
// -----
65+
5566
// CHECK-LABEL: func @memref_get_global(
5667
// CHECK: %[[c4:.*]] = arith.constant 4 : index
5768
// CHECK: return %[[c4]]

mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,17 @@ func.func @dim(%t: tensor<?xf32>) -> index {
4444

4545
// -----
4646

47+
// CHECK-LABEL: func @dim_all_positive(
48+
func.func @dim_all_positive(%t: tensor<?xf32>, %x: index) {
49+
%c0 = arith.constant 0 : index
50+
%0 = tensor.dim %t, %x : tensor<?xf32>
51+
// expected-remark @below{{true}}
52+
"test.compare"(%0, %c0) {cmp = "GE" } : (index, index) -> ()
53+
return
54+
}
55+
56+
// -----
57+
4758
// CHECK-LABEL: func @empty(
4859
// CHECK-SAME: %[[sz:.*]]: index
4960
// CHECK: %[[c6:.*]] = arith.constant 6 : index

0 commit comments

Comments
 (0)