@@ -354,6 +354,35 @@ bool mlir::tensor::canFoldIntoProducerOp(CastOp castOp) {
354
354
castOp.getType ());
355
355
}
356
356
357
+ bool mlir::tensor::hasFoldableTensorCastOperand (Operation *op) {
358
+ return llvm::any_of (op->getOpOperands (), [&](OpOperand &opOperand) {
359
+ if (llvm::isa<BlockArgument>(opOperand.get ()))
360
+ return false ;
361
+ auto castOp = opOperand.get ().getDefiningOp <tensor::CastOp>();
362
+ return castOp && canFoldIntoConsumerOp (castOp);
363
+ });
364
+ }
365
+
366
+ SmallVector<Value> mlir::tensor::getUpdatedOperandsAfterCastOpFolding (
367
+ DestinationStyleOpInterface op, SmallVector<Type> &newResTy) {
368
+ SmallVector<Value> newOperands;
369
+ newOperands.reserve (op->getNumOperands ());
370
+
371
+ assert (hasFoldableTensorCastOperand (op) && " No foldable CastOp operands!" );
372
+
373
+ // Assumes that the result has dpsInits followed by nonDpsInits.
374
+ int64_t dpsInitIdx = 0 ;
375
+ for (OpOperand &opOperand : op->getOpOperands ()) {
376
+ auto tensorCastOp = opOperand.get ().getDefiningOp <tensor::CastOp>();
377
+ bool fold = canFoldIntoConsumerOp (tensorCastOp);
378
+ newOperands.push_back (fold ? tensorCastOp.getOperand () : opOperand.get ());
379
+ if (op.isDpsInit (&opOperand) &&
380
+ !llvm::isa<MemRefType>(newOperands.back ().getType ()))
381
+ newResTy[dpsInitIdx++] = newOperands.back ().getType ();
382
+ }
383
+ return newOperands;
384
+ }
385
+
357
386
/// Performs folding of any operand of `op` if it comes from a tensor::CastOp
358
387
/// that can be folded.
359
388
LogicalResult mlir::tensor::foldTensorCast (Operation *op) {
@@ -4777,34 +4806,7 @@ bool foldTensorCastPrecondition(DestinationStyleOpInterface op) {
4777
4806
isa<LoopLikeOpInterface>(op.getOperation ()))
4778
4807
return false ;
4779
4808
4780
- // If no operand comes from a tensor::CastOp and can be folded then fail.
4781
- bool hasTensorCastOperand =
4782
- llvm::any_of (op->getOpOperands (), [&](OpOperand &opOperand) {
4783
- if (llvm::isa<BlockArgument>(opOperand.get ()))
4784
- return false ;
4785
- auto castOp = opOperand.get ().getDefiningOp <tensor::CastOp>();
4786
- return castOp && canFoldIntoConsumerOp (castOp);
4787
- });
4788
-
4789
- return hasTensorCastOperand;
4790
- }
4791
-
4792
- static SmallVector<Value> getNewOperands (DestinationStyleOpInterface op,
4793
- SmallVector<Type> &newResTy) {
4794
- SmallVector<Value> newOperands;
4795
- newOperands.reserve (op->getNumOperands ());
4796
-
4797
- // Assumes that the result has dpsInits followed by nonDpsInits.
4798
- int64_t dpsInitIdx = 0 ;
4799
- for (OpOperand &opOperand : op->getOpOperands ()) {
4800
- auto tensorCastOp = opOperand.get ().getDefiningOp <tensor::CastOp>();
4801
- bool fold = canFoldIntoConsumerOp (tensorCastOp);
4802
- newOperands.push_back (fold ? tensorCastOp.getOperand () : opOperand.get ());
4803
- if (op.isDpsInit (&opOperand) &&
4804
- !llvm::isa<MemRefType>(newOperands.back ().getType ()))
4805
- newResTy[dpsInitIdx++] = newOperands.back ().getType ();
4806
- }
4807
- return newOperands;
4809
+ return hasFoldableTensorCastOperand (op);
4808
4810
}
4809
4811
4810
4812
// Given the (potentially) updated packed type, `newPackedTy`, generates an
@@ -4868,7 +4870,8 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
4868
4870
return failure ();
4869
4871
4870
4872
SmallVector<Type> newResultTypes (op->getResultTypes ());
4871
- SmallVector<Value> newOperands = getNewOperands (op, newResultTypes);
4873
+ SmallVector<Value> newOperands =
4874
+ getUpdatedOperandsAfterCastOpFolding (op, newResultTypes);
4872
4875
4873
4876
// Get the updated mixed-tile-sizes attribute.
4874
4877
SmallVector<OpFoldResult> newMixedTileSizes =
@@ -4920,7 +4923,8 @@ struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
4920
4923
return failure ();
4921
4924
4922
4925
SmallVector<Type> newResultTypes (op->getResultTypes ());
4923
- SmallVector<Value> newOperands = getNewOperands (op, newResultTypes);
4926
+ SmallVector<Value> newOperands =
4927
+ getUpdatedOperandsAfterCastOpFolding (op, newResultTypes);
4924
4928
Value sourceTensor = newOperands[0 ];
4925
4929
4926
4930
// Get the updated mixed-tile-sizes attribute.
@@ -4980,7 +4984,8 @@ struct FoldTensorCastProducerOp
4980
4984
return failure ();
4981
4985
4982
4986
SmallVector<Type> newResultTypes (op->getResultTypes ());
4983
- SmallVector<Value> newOperands = getNewOperands (op, newResultTypes);
4987
+ SmallVector<Value> newOperands =
4988
+ getUpdatedOperandsAfterCastOpFolding (op, newResultTypes);
4984
4989
4985
4990
// Clone op
4986
4991
auto newOp = clone (rewriter, op, newResultTypes, newOperands);
0 commit comments