Skip to content

Commit 339753d

Browse files
[mlir][vector][NFC] isDisjointTransferIndices: Use getConstantIntValue (#65931)
Use `getConstantIntValue` instead of matching for `arith::ConstantOp`.
1 parent 2484678 commit 339753d

File tree

1 file changed

+5
-8
lines changed

1 file changed

+5
-8
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -172,24 +172,21 @@ bool mlir::vector::isDisjointTransferIndices(
172172
return false;
173173
unsigned rankOffset = transferA.getLeadingShapedRank();
174174
for (unsigned i = 0, e = transferA.indices().size(); i < e; i++) {
175-
auto indexA = transferA.indices()[i].getDefiningOp<arith::ConstantOp>();
176-
auto indexB = transferB.indices()[i].getDefiningOp<arith::ConstantOp>();
175+
auto indexA = getConstantIntValue(transferA.indices()[i]);
176+
auto indexB = getConstantIntValue(transferB.indices()[i]);
177177
// If any of the indices are dynamic we cannot prove anything.
178-
if (!indexA || !indexB)
178+
if (!indexA.has_value() || !indexB.has_value())
179179
continue;
180180

181181
if (i < rankOffset) {
182182
// For leading dimensions, if we can prove that index are different we
183183
// know we are accessing disjoint slices.
184-
if (llvm::cast<IntegerAttr>(indexA.getValue()).getInt() !=
185-
llvm::cast<IntegerAttr>(indexB.getValue()).getInt())
184+
if (*indexA != *indexB)
186185
return true;
187186
} else {
188187
// For this dimension, we slice a part of the memref we need to make sure
189188
// the intervals accessed don't overlap.
190-
int64_t distance =
191-
std::abs(llvm::cast<IntegerAttr>(indexA.getValue()).getInt() -
192-
llvm::cast<IntegerAttr>(indexB.getValue()).getInt());
189+
int64_t distance = std::abs(*indexA - *indexB);
193190
if (distance >= transferA.getVectorType().getDimSize(i - rankOffset))
194191
return true;
195192
}

0 commit comments

Comments
 (0)