Skip to content

Commit 4060bbe

Browse files
antiagainstGroverkss
authored andcommitted
[mlir][vector] Enable transfer op hoisting with dynamic indices (llvm#68500)
Recent changes (llvm#66930) disabled vector transfer ops hoisting with view-like intermediate ops. The recommended way is to fold subview ops into transfer op indices before invoking hoisting. That would mean now we see transfer op indices involving dynamic values, instead of static constant values before with subview ops. Therefore hoisting won't kick in anymore. This breaks downstream users. To fix it, this commit enables hoisting transfer ops with dynamic indices by using `ValueBoundsConstraintSet` to prove ranges are disjoint in `isDisjointTransferIndices`. Given that utility is used in many places including op folders, right now we introduce a flag to it and only set as true for "heavy" transforms in hoisting and load-store forwarding.
1 parent 51fdabc commit 4060bbe

File tree

13 files changed

+370
-60
lines changed

13 files changed

+370
-60
lines changed

mlir/include/mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,18 @@ class Value;
1818
namespace affine {
1919
void registerValueBoundsOpInterfaceExternalModels(DialectRegistry &registry);
2020

21-
/// Compute whether the given values are equal. Return "failure" if equality
22-
/// could not be determined. `value1`/`value2` must be index-typed.
21+
/// Compute a constant delta of the given two values. Return "failure" if we
22+
/// cannot determine a constant delta. `value1`/`value2` must be index-typed.
2323
///
24-
/// This function is similar to `ValueBoundsConstraintSet::areEqual`. To work
25-
/// around limitations in `FlatLinearConstraints`, this function fully composes
24+
/// This function is similar to
25+
/// `ValueBoundsConstraintSet::computeConstantDistance`. To work around
26+
/// limitations in `FlatLinearConstraints`, this function fully composes
2627
/// `value1` and `value2` (if they are the result of affine.apply ops) before
2728
/// populating the constraint set. The folding/composing logic can see
2829
/// opportunities for simplifications that the constraint set implementation
2930
/// cannot see.
30-
FailureOr<bool> fullyComposeAndCheckIfEqual(Value value1, Value value2);
31+
FailureOr<int64_t> fullyComposeAndComputeConstantDelta(Value value1,
32+
Value value2);
3133
} // namespace affine
3234
} // namespace mlir
3335

mlir/include/mlir/Dialect/Vector/IR/VectorOps.h

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -105,16 +105,23 @@ bool checkSameValueRAW(TransferWriteOp defWrite, TransferReadOp read);
105105
/// op.
106106
bool checkSameValueWAW(TransferWriteOp write, TransferWriteOp priorWrite);
107107

108-
/// Same behavior as `isDisjointTransferSet` but doesn't require the operations
109-
/// to have the same tensor/memref. This allows comparing operations accessing
110-
/// different tensors.
108+
/// Return true if we can prove that the transfer operations access disjoint
109+
/// memory, without requring the accessed tensor/memref to be the same.
110+
///
111+
/// If `testDynamicValueUsingBounds` is true, tries to test dynamic values
112+
/// via ValueBoundsOpInterface.
111113
bool isDisjointTransferIndices(VectorTransferOpInterface transferA,
112-
VectorTransferOpInterface transferB);
114+
VectorTransferOpInterface transferB,
115+
bool testDynamicValueUsingBounds = false);
113116

114117
/// Return true if we can prove that the transfer operations access disjoint
115-
/// memory.
118+
/// memory, requiring the operations to access the same tensor/memref.
119+
///
120+
/// If `testDynamicValueUsingBounds` is true, tries to test dynamic values
121+
/// via ValueBoundsOpInterface.
116122
bool isDisjointTransferSet(VectorTransferOpInterface transferA,
117-
VectorTransferOpInterface transferB);
123+
VectorTransferOpInterface transferB,
124+
bool testDynamicValueUsingBounds = false);
118125

119126
/// Return the result value of reducing two scalar/vector values with the
120127
/// corresponding arith operation.

mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,16 @@ class ValueBoundsConstraintSet {
176176
presburger::BoundType type, AffineMap map, ValueDimList mapOperands,
177177
StopConditionFn stopCondition = nullptr, bool closedUB = false);
178178

179+
/// Compute a constant delta between the given two values. Return "failure"
180+
/// if a constant delta could not be determined.
181+
///
182+
/// `dim1`/`dim2` must be `nullopt` if and only if `value1`/`value2` are
183+
/// index-typed.
184+
static FailureOr<int64_t>
185+
computeConstantDelta(Value value1, Value value2,
186+
std::optional<int64_t> dim1 = std::nullopt,
187+
std::optional<int64_t> dim2 = std::nullopt);
188+
179189
/// Compute whether the given values/dimensions are equal. Return "failure" if
180190
/// equality could not be determined.
181191
///

mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -103,8 +103,8 @@ void mlir::affine::registerValueBoundsOpInterfaceExternalModels(
103103
});
104104
}
105105

106-
FailureOr<bool> mlir::affine::fullyComposeAndCheckIfEqual(Value value1,
107-
Value value2) {
106+
FailureOr<int64_t>
107+
mlir::affine::fullyComposeAndComputeConstantDelta(Value value1, Value value2) {
108108
assert(value1.getType().isIndex() && "expected index type");
109109
assert(value2.getType().isIndex() && "expected index type");
110110

@@ -123,9 +123,6 @@ FailureOr<bool> mlir::affine::fullyComposeAndCheckIfEqual(Value value1,
123123
ValueDimList valueDims;
124124
for (Value v : mapOperands)
125125
valueDims.push_back({v, std::nullopt});
126-
FailureOr<int64_t> bound = ValueBoundsConstraintSet::computeConstantBound(
126+
return ValueBoundsConstraintSet::computeConstantBound(
127127
presburger::BoundType::EQ, map, valueDims);
128-
if (failed(bound))
129-
return failure();
130-
return *bound == 0;
131128
}

mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -173,16 +173,16 @@ void mlir::linalg::hoistRedundantVectorTransfers(func::FuncOp func) {
173173
if (auto transferWriteUse =
174174
dyn_cast<vector::TransferWriteOp>(use.getOwner())) {
175175
if (!vector::isDisjointTransferSet(
176-
cast<VectorTransferOpInterface>(transferWrite.getOperation()),
177-
cast<VectorTransferOpInterface>(
178-
transferWriteUse.getOperation())))
176+
cast<VectorTransferOpInterface>(*transferWrite),
177+
cast<VectorTransferOpInterface>(*transferWriteUse),
178+
/*testDynamicValueUsingBounds=*/true))
179179
return WalkResult::advance();
180180
} else if (auto transferReadUse =
181181
dyn_cast<vector::TransferReadOp>(use.getOwner())) {
182182
if (!vector::isDisjointTransferSet(
183-
cast<VectorTransferOpInterface>(transferWrite.getOperation()),
184-
cast<VectorTransferOpInterface>(
185-
transferReadUse.getOperation())))
183+
cast<VectorTransferOpInterface>(*transferWrite),
184+
cast<VectorTransferOpInterface>(*transferReadUse),
185+
/*testDynamicValueUsingBounds=*/true))
186186
return WalkResult::advance();
187187
} else {
188188
// Unknown use, we cannot prove that it doesn't alias with the

mlir/lib/Dialect/Vector/IR/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ add_mlir_dialect_library(MLIRVectorDialect
1111
MLIRVectorAttributesIncGen
1212

1313
LINK_LIBS PUBLIC
14+
MLIRAffineDialect
1415
MLIRArithDialect
1516
MLIRControlFlowInterfaces
1617
MLIRDataLayoutInterfaces
@@ -22,5 +23,6 @@ add_mlir_dialect_library(MLIRVectorDialect
2223
MLIRMemRefDialect
2324
MLIRSideEffectInterfaces
2425
MLIRTensorDialect
26+
MLIRValueBoundsOpInterface
2527
MLIRVectorInterfaces
2628
)

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 52 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
#include "mlir/Dialect/Vector/IR/VectorOps.h"
1515

16+
#include "mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h"
1617
#include "mlir/Dialect/Arith/IR/Arith.h"
1718
#include "mlir/Dialect/Arith/Utils/Utils.h"
1819
#include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -30,6 +31,7 @@
3031
#include "mlir/IR/OpImplementation.h"
3132
#include "mlir/IR/PatternMatch.h"
3233
#include "mlir/IR/TypeUtilities.h"
34+
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
3335
#include "mlir/Support/LLVM.h"
3436
#include "llvm/ADT/ArrayRef.h"
3537
#include "llvm/ADT/STLExtras.h"
@@ -168,39 +170,76 @@ bool mlir::vector::checkSameValueWAW(vector::TransferWriteOp write,
168170
}
169171

170172
bool mlir::vector::isDisjointTransferIndices(
171-
VectorTransferOpInterface transferA, VectorTransferOpInterface transferB) {
173+
VectorTransferOpInterface transferA, VectorTransferOpInterface transferB,
174+
bool testDynamicValueUsingBounds) {
172175
// For simplicity only look at transfer of same type.
173176
if (transferA.getVectorType() != transferB.getVectorType())
174177
return false;
175178
unsigned rankOffset = transferA.getLeadingShapedRank();
176179
for (unsigned i = 0, e = transferA.indices().size(); i < e; i++) {
177-
auto indexA = getConstantIntValue(transferA.indices()[i]);
178-
auto indexB = getConstantIntValue(transferB.indices()[i]);
179-
// If any of the indices are dynamic we cannot prove anything.
180-
if (!indexA.has_value() || !indexB.has_value())
181-
continue;
180+
Value indexA = transferA.indices()[i];
181+
Value indexB = transferB.indices()[i];
182+
std::optional<int64_t> cstIndexA = getConstantIntValue(indexA);
183+
std::optional<int64_t> cstIndexB = getConstantIntValue(indexB);
182184

183185
if (i < rankOffset) {
184186
// For leading dimensions, if we can prove that index are different we
185187
// know we are accessing disjoint slices.
186-
if (*indexA != *indexB)
187-
return true;
188+
if (cstIndexA.has_value() && cstIndexB.has_value()) {
189+
if (*cstIndexA != *cstIndexB)
190+
return true;
191+
continue;
192+
}
193+
if (testDynamicValueUsingBounds) {
194+
// First try to see if we can fully compose and simplify the affine
195+
// expression as a fast track.
196+
FailureOr<uint64_t> delta =
197+
affine::fullyComposeAndComputeConstantDelta(indexA, indexB);
198+
if (succeeded(delta) && *delta != 0)
199+
return true;
200+
201+
FailureOr<bool> testEqual =
202+
ValueBoundsConstraintSet::areEqual(indexA, indexB);
203+
if (succeeded(testEqual) && !testEqual.value())
204+
return true;
205+
}
188206
} else {
189207
// For this dimension, we slice a part of the memref we need to make sure
190208
// the intervals accessed don't overlap.
191-
int64_t distance = std::abs(*indexA - *indexB);
192-
if (distance >= transferA.getVectorType().getDimSize(i - rankOffset))
193-
return true;
209+
int64_t vectorDim = transferA.getVectorType().getDimSize(i - rankOffset);
210+
if (cstIndexA.has_value() && cstIndexB.has_value()) {
211+
int64_t distance = std::abs(*cstIndexA - *cstIndexB);
212+
if (distance >= vectorDim)
213+
return true;
214+
continue;
215+
}
216+
if (testDynamicValueUsingBounds) {
217+
// First try to see if we can fully compose and simplify the affine
218+
// expression as a fast track.
219+
FailureOr<int64_t> delta =
220+
affine::fullyComposeAndComputeConstantDelta(indexA, indexB);
221+
if (succeeded(delta) && std::abs(*delta) >= vectorDim)
222+
return true;
223+
224+
FailureOr<int64_t> computeDelta =
225+
ValueBoundsConstraintSet::computeConstantDelta(indexA, indexB);
226+
if (succeeded(computeDelta)) {
227+
if (std::abs(computeDelta.value()) >= vectorDim)
228+
return true;
229+
}
230+
}
194231
}
195232
}
196233
return false;
197234
}
198235

199236
bool mlir::vector::isDisjointTransferSet(VectorTransferOpInterface transferA,
200-
VectorTransferOpInterface transferB) {
237+
VectorTransferOpInterface transferB,
238+
bool testDynamicValueUsingBounds) {
201239
if (transferA.source() != transferB.source())
202240
return false;
203-
return isDisjointTransferIndices(transferA, transferB);
241+
return isDisjointTransferIndices(transferA, transferB,
242+
testDynamicValueUsingBounds);
204243
}
205244

206245
// Helper to iterate over n-D vector slice elements. Calculate the next

mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,8 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
142142
// Don't need to consider disjoint accesses.
143143
if (vector::isDisjointTransferSet(
144144
cast<VectorTransferOpInterface>(write.getOperation()),
145-
cast<VectorTransferOpInterface>(transferOp.getOperation())))
145+
cast<VectorTransferOpInterface>(transferOp.getOperation()),
146+
/*testDynamicValueUsingBounds=*/true))
146147
continue;
147148
}
148149
blockingAccesses.push_back(user);
@@ -217,7 +218,8 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
217218
// the write.
218219
if (vector::isDisjointTransferSet(
219220
cast<VectorTransferOpInterface>(write.getOperation()),
220-
cast<VectorTransferOpInterface>(read.getOperation())))
221+
cast<VectorTransferOpInterface>(read.getOperation()),
222+
/*testDynamicValueUsingBounds=*/true))
221223
continue;
222224
if (write.getSource() == read.getSource() &&
223225
dominators.dominates(write, read) && checkSameValueRAW(write, read)) {

mlir/lib/Interfaces/ValueBoundsOpInterface.cpp

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -484,25 +484,32 @@ FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
484484
return failure();
485485
}
486486

487-
FailureOr<bool>
488-
ValueBoundsConstraintSet::areEqual(Value value1, Value value2,
489-
std::optional<int64_t> dim1,
490-
std::optional<int64_t> dim2) {
487+
FailureOr<int64_t>
488+
ValueBoundsConstraintSet::computeConstantDelta(Value value1, Value value2,
489+
std::optional<int64_t> dim1,
490+
std::optional<int64_t> dim2) {
491491
#ifndef NDEBUG
492492
assertValidValueDim(value1, dim1);
493493
assertValidValueDim(value2, dim2);
494494
#endif // NDEBUG
495495

496-
// Subtract the two values/dimensions from each other. If the result is 0,
497-
// both are equal.
498496
Builder b(value1.getContext());
499497
AffineMap map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0,
500498
b.getAffineDimExpr(0) - b.getAffineDimExpr(1));
501-
FailureOr<int64_t> bound = computeConstantBound(
502-
presburger::BoundType::EQ, map, {{value1, dim1}, {value2, dim2}});
503-
if (failed(bound))
499+
return computeConstantBound(presburger::BoundType::EQ, map,
500+
{{value1, dim1}, {value2, dim2}});
501+
}
502+
503+
FailureOr<bool>
504+
ValueBoundsConstraintSet::areEqual(Value value1, Value value2,
505+
std::optional<int64_t> dim1,
506+
std::optional<int64_t> dim2) {
507+
// Subtract the two values/dimensions from each other. If the result is 0,
508+
// both are equal.
509+
FailureOr<int64_t> delta = computeConstantDelta(value1, value2, dim1, dim2);
510+
if (failed(delta))
504511
return failure();
505-
return *bound == 0;
512+
return *delta == 0;
506513
}
507514

508515
ValueBoundsConstraintSet::BoundBuilder &

0 commit comments

Comments
 (0)