Skip to content

Commit 5586541

Browse files
authored
[mlir][tensor] Make useful Tensor utilities public (#126802)
1. Extract the main logic from `foldTensorCastPrecondition` into a dedicated helper hook: `hasFoldableTensorCastOperand`. This allows for reusing the corresponding checks. 2. Rename `getNewOperands` to `getUpdatedOperandsAfterCastOpFolding` for better clarity and documentation of its functionality. 3. These updated hooks will be reused in: * #123902. This PR makes them public. **Note:** Moving these hooks to `Tensor/Utils` is not feasible because `MLIRTensorUtils` depends on `MLIRTensorDialect` (CMake targets). If these hooks were moved to `Utils`, it would create a dependency of `MLIRTensorDialect` on `MLIRTensorUtils`, leading to a circular dependency.
1 parent 1c207f1 commit 5586541

File tree

2 files changed

+48
-31
lines changed

2 files changed

+48
-31
lines changed

mlir/include/mlir/Dialect/Tensor/IR/Tensor.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,18 @@ bool canFoldIntoConsumerOp(CastOp castOp);
116116
/// this method provides a check that it is worth doing the canonicalization.
117117
bool canFoldIntoProducerOp(CastOp castOp);
118118

119+
/// Return true if any of the operands of `op` is a CastOp that can be folded
120+
/// into its consumer, i.e. `op`. This is effectively a convenience wrapper for
121+
/// `canFoldIntoProducerOp`.
122+
bool hasFoldableTensorCastOperand(Operation *op);
123+
124+
/// Assuming that `op` contains at least one operand that is a foldable CastOp
125+
/// (i.e. `hasFoldableTensorCastOperand` returns true), calculate the updated
126+
/// operands.
127+
SmallVector<Value>
128+
getUpdatedOperandsAfterCastOpFolding(DestinationStyleOpInterface op,
129+
SmallVector<Type> &newResTy);
130+
119131
/// Performs folding of any operand of `op` if it comes from a tensor::CastOp
120132
/// that can be folded.
121133
LogicalResult foldTensorCast(Operation *op);

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 36 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,35 @@ bool mlir::tensor::canFoldIntoProducerOp(CastOp castOp) {
354354
castOp.getType());
355355
}
356356

357+
bool mlir::tensor::hasFoldableTensorCastOperand(Operation *op) {
358+
return llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
359+
if (llvm::isa<BlockArgument>(opOperand.get()))
360+
return false;
361+
auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
362+
return castOp && canFoldIntoConsumerOp(castOp);
363+
});
364+
}
365+
366+
SmallVector<Value> mlir::tensor::getUpdatedOperandsAfterCastOpFolding(
367+
DestinationStyleOpInterface op, SmallVector<Type> &newResTy) {
368+
SmallVector<Value> newOperands;
369+
newOperands.reserve(op->getNumOperands());
370+
371+
assert(hasFoldableTensorCastOperand(op) && "No foldable CastOp operands!");
372+
373+
// Assumes that the result has dpsInits followed by nonDpsInits.
374+
int64_t dpsInitIdx = 0;
375+
for (OpOperand &opOperand : op->getOpOperands()) {
376+
auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
377+
bool fold = canFoldIntoConsumerOp(tensorCastOp);
378+
newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
379+
if (op.isDpsInit(&opOperand) &&
380+
!llvm::isa<MemRefType>(newOperands.back().getType()))
381+
newResTy[dpsInitIdx++] = newOperands.back().getType();
382+
}
383+
return newOperands;
384+
}
385+
357386
/// Performs folding of any operand of `op` if it comes from a tensor::CastOp
358387
/// that can be folded.
359388
LogicalResult mlir::tensor::foldTensorCast(Operation *op) {
@@ -4777,34 +4806,7 @@ bool foldTensorCastPrecondition(DestinationStyleOpInterface op) {
47774806
isa<LoopLikeOpInterface>(op.getOperation()))
47784807
return false;
47794808

4780-
// If no operand comes from a tensor::CastOp and can be folded then fail.
4781-
bool hasTensorCastOperand =
4782-
llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
4783-
if (llvm::isa<BlockArgument>(opOperand.get()))
4784-
return false;
4785-
auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
4786-
return castOp && canFoldIntoConsumerOp(castOp);
4787-
});
4788-
4789-
return hasTensorCastOperand;
4790-
}
4791-
4792-
static SmallVector<Value> getNewOperands(DestinationStyleOpInterface op,
4793-
SmallVector<Type> &newResTy) {
4794-
SmallVector<Value> newOperands;
4795-
newOperands.reserve(op->getNumOperands());
4796-
4797-
// Assumes that the result has dpsInits followed by nonDpsInits.
4798-
int64_t dpsInitIdx = 0;
4799-
for (OpOperand &opOperand : op->getOpOperands()) {
4800-
auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
4801-
bool fold = canFoldIntoConsumerOp(tensorCastOp);
4802-
newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
4803-
if (op.isDpsInit(&opOperand) &&
4804-
!llvm::isa<MemRefType>(newOperands.back().getType()))
4805-
newResTy[dpsInitIdx++] = newOperands.back().getType();
4806-
}
4807-
return newOperands;
4809+
return hasFoldableTensorCastOperand(op);
48084810
}
48094811

48104812
// Given the (potentially) updated packed type, `newPackedTy`, generates an
@@ -4868,7 +4870,8 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
48684870
return failure();
48694871

48704872
SmallVector<Type> newResultTypes(op->getResultTypes());
4871-
SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);
4873+
SmallVector<Value> newOperands =
4874+
getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
48724875

48734876
// Get the updated mixed-tile-sizes attribute.
48744877
SmallVector<OpFoldResult> newMixedTileSizes =
@@ -4920,7 +4923,8 @@ struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
49204923
return failure();
49214924

49224925
SmallVector<Type> newResultTypes(op->getResultTypes());
4923-
SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);
4926+
SmallVector<Value> newOperands =
4927+
getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
49244928
Value sourceTensor = newOperands[0];
49254929

49264930
// Get the updated mixed-tile-sizes attribute.
@@ -4980,7 +4984,8 @@ struct FoldTensorCastProducerOp
49804984
return failure();
49814985

49824986
SmallVector<Type> newResultTypes(op->getResultTypes());
4983-
SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);
4987+
SmallVector<Value> newOperands =
4988+
getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
49844989

49854990
// Clone op
49864991
auto newOp = clone(rewriter, op, newResultTypes, newOperands);

0 commit comments

Comments
 (0)