Skip to content

Commit 80949fe

Browse files
[mlir][Interfaces][WIP] Expose public compare API
Also use `compare` API for `areEqual` etc.
1 parent 40327a6 commit 80949fe

File tree

7 files changed

+361
-114
lines changed

7 files changed

+361
-114
lines changed

mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h

Lines changed: 47 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,8 @@ class ValueBoundsConstraintSet
211211
/// Comparison operator for `ValueBoundsConstraintSet::compare`.
212212
enum ComparisonOperator { LT, LE, EQ, GT, GE };
213213

214-
/// Try to prove that, based on the current state of this constraint set
214+
/// Populate constraints for lhs/rhs (until the stop condition is met). Then,
215+
/// try to prove that, based on the current state of this constraint set
215216
/// (i.e., without analyzing additional IR or adding new constraints), the
216217
/// "lhs" value/dim is LE/LT/EQ/GT/GE than the "rhs" value/dim.
217218
///
@@ -220,24 +221,37 @@ class ValueBoundsConstraintSet
220221
/// proven. This could be because the specified relation does in fact not hold
221222
/// or because there is not enough information in the constraint set. In other
222223
/// words, if we do not know for sure, this function returns "false".
223-
bool compare(Value lhs, std::optional<int64_t> lhsDim, ComparisonOperator cmp,
224-
Value rhs, std::optional<int64_t> rhsDim);
224+
bool populateAndCompare(OpFoldResult lhs, std::optional<int64_t> lhsDim,
225+
ComparisonOperator cmp, OpFoldResult rhs,
226+
std::optional<int64_t> rhsDim);
227+
228+
/// Return "true" if "lhs cmp rhs" was proven to hold. Return "false" if the
229+
/// specified relation could not be proven. This could be because the
230+
/// specified relation does in fact not hold or because there is not enough
231+
/// information in the constraint set. In other words, if we do not know for
232+
/// sure, this function returns "false".
233+
///
234+
/// This function keeps traversing the backward slice of lhs/rhs until could
235+
/// prove the relation or until it ran out of IR.
236+
static bool compare(OpFoldResult lhs, std::optional<int64_t> lhsDim,
237+
ComparisonOperator cmp, OpFoldResult rhs,
238+
std::optional<int64_t> rhsDim);
239+
static bool compare(AffineMap lhs, ValueDimList lhsOperands,
240+
ComparisonOperator cmp, AffineMap rhs,
241+
ValueDimList rhsOperands);
242+
static bool compare(AffineMap lhs, ArrayRef<Value> lhsOperands,
243+
ComparisonOperator cmp, AffineMap rhs,
244+
ArrayRef<Value> rhsOperands);
225245

226246
/// Compute whether the given values/dimensions are equal. Return "failure" if
227247
/// equality could not be determined.
228248
///
229249
/// `dim1`/`dim2` must be `nullopt` if and only if `value1`/`value2` are
230250
/// index-typed.
231-
static FailureOr<bool> areEqual(Value value1, Value value2,
251+
static FailureOr<bool> areEqual(OpFoldResult value1, OpFoldResult value2,
232252
std::optional<int64_t> dim1 = std::nullopt,
233253
std::optional<int64_t> dim2 = std::nullopt);
234254

235-
/// Compute whether the given values/attributes are equal. Return "failure" if
236-
/// equality could not be determined.
237-
///
238-
/// `ofr1`/`ofr2` must be of index type.
239-
static FailureOr<bool> areEqual(OpFoldResult ofr1, OpFoldResult ofr2);
240-
241255
/// Return "true" if the given slices are guaranteed to be overlapping.
242256
/// Return "false" if the given slices are guaranteed to be non-overlapping.
243257
/// Return "failure" if unknown.
@@ -294,6 +308,20 @@ class ValueBoundsConstraintSet
294308

295309
ValueBoundsConstraintSet(MLIRContext *ctx, StopConditionFn stopCondition);
296310

311+
/// Return "true" if, based on the current state of the constraint system,
312+
/// "lhs cmp rhs" was proven to hold. Return "false" if the specified relation
313+
/// could not be proven. This could be because the specified relation does in
314+
/// fact not hold or because there is not enough information in the constraint
315+
/// set. In other words, if we do not know for sure, this function returns
316+
/// "false".
317+
///
318+
/// This function does not analyze any IR and does not populate any additional
319+
/// constraints.
320+
bool compareValueDims(OpFoldResult lhs, std::optional<int64_t> lhsDim,
321+
ComparisonOperator cmp, OpFoldResult rhs,
322+
std::optional<int64_t> rhsDim);
323+
bool comparePos(int64_t lhsPos, ComparisonOperator cmp, int64_t rhsPos);
324+
297325
/// Given an affine map with a single result (and map operands), add a new
298326
/// column to the constraint set that represents the result of the map.
299327
/// Traverse additional IR starting from the map operands as needed (as long
@@ -319,6 +347,10 @@ class ValueBoundsConstraintSet
319347
/// set.
320348
AffineExpr getPosExpr(int64_t pos);
321349

350+
/// Return "true" if the given value/dim is mapped (i.e., has a corresponding
351+
/// column in the constraint system).
352+
bool isMapped(Value value, std::optional<int64_t> dim = std::nullopt) const;
353+
322354
/// Insert a value/dimension into the constraint set. If `isSymbol` is set to
323355
/// "false", a dimension is added. The value/dimension is added to the
324356
/// worklist if `addToWorklist` is set.
@@ -338,6 +370,11 @@ class ValueBoundsConstraintSet
338370
/// dimensions but not for symbols.
339371
int64_t insert(bool isSymbol = true);
340372

373+
/// Insert the given affine map and its bound operands as a new column in the
374+
/// constraint system. Return the position of the new column. Any operands
375+
/// that were not analyzed yet are put on the worklist.
376+
int64_t insert(AffineMap map, ValueDimList operands, bool isSymbol = true);
377+
341378
/// Project out the given column in the constraint set.
342379
void projectOut(int64_t pos);
343380

mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp

Lines changed: 9 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -58,20 +58,11 @@ struct ForOpInterface
5858
Value iterArg = forOp.getRegionIterArg(iterArgIdx);
5959
Value initArg = forOp.getInitArgs()[iterArgIdx];
6060

61-
// Populate constraints for the yielded value.
62-
cstr.populateConstraints(yieldedValue, dim);
63-
// Populate constraints for the iter_arg. This is just to ensure that the
64-
// iter_arg is mapped in the constraint set, which is a prerequisite for
65-
// `compare`. It may lead to a recursive call to this function in case the
66-
// iter_arg was not visited when the constraints for the yielded value were
67-
// populated, but no additional work is done.
68-
cstr.populateConstraints(iterArg, dim);
69-
7061
// An EQ constraint can be added if the yielded value (dimension size)
7162
// equals the corresponding block argument (dimension size).
72-
if (cstr.compare(yieldedValue, dim,
73-
ValueBoundsConstraintSet::ComparisonOperator::EQ, iterArg,
74-
dim)) {
63+
if (cstr.populateAndCompare(
64+
yieldedValue, dim, ValueBoundsConstraintSet::ComparisonOperator::EQ,
65+
iterArg, dim)) {
7566
if (dim.has_value()) {
7667
cstr.bound(value)[*dim] == cstr.getExpr(initArg, dim);
7768
} else {
@@ -113,10 +104,6 @@ struct IfOpInterface
113104
Value thenValue = ifOp.thenYield().getResults()[resultNum];
114105
Value elseValue = ifOp.elseYield().getResults()[resultNum];
115106

116-
// Populate constraints for the yielded value (and all values on the
117-
// backward slice, as long as the current stop condition is not satisfied).
118-
cstr.populateConstraints(thenValue, dim);
119-
cstr.populateConstraints(elseValue, dim);
120107
auto boundsBuilder = cstr.bound(value);
121108
if (dim)
122109
boundsBuilder[*dim];
@@ -125,9 +112,9 @@ struct IfOpInterface
125112
// If thenValue <= elseValue:
126113
// * result <= elseValue
127114
// * result >= thenValue
128-
if (cstr.compare(thenValue, dim,
129-
ValueBoundsConstraintSet::ComparisonOperator::LE,
130-
elseValue, dim)) {
115+
if (cstr.populateAndCompare(
116+
thenValue, dim, ValueBoundsConstraintSet::ComparisonOperator::LE,
117+
elseValue, dim)) {
131118
if (dim) {
132119
cstr.bound(value)[*dim] >= cstr.getExpr(thenValue, dim);
133120
cstr.bound(value)[*dim] <= cstr.getExpr(elseValue, dim);
@@ -139,9 +126,9 @@ struct IfOpInterface
139126
// If elseValue <= thenValue:
140127
// * result <= thenValue
141128
// * result >= elseValue
142-
if (cstr.compare(elseValue, dim,
143-
ValueBoundsConstraintSet::ComparisonOperator::LE,
144-
thenValue, dim)) {
129+
if (cstr.populateAndCompare(
130+
elseValue, dim, ValueBoundsConstraintSet::ComparisonOperator::LE,
131+
thenValue, dim)) {
145132
if (dim) {
146133
cstr.bound(value)[*dim] >= cstr.getExpr(elseValue, dim);
147134
cstr.bound(value)[*dim] <= cstr.getExpr(thenValue, dim);

0 commit comments

Comments
 (0)