@@ -59,6 +59,27 @@ static void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
59
59
dispatchIndexOpFoldResult (ofr, dynamicVec, staticVec, sentinel);
60
60
}
61
61
62
+ // / Return true if ofr1 and ofr2 are the same integer constant attribute values
63
+ // / or the same SSA value.
64
+ // / Ignore integer bitwitdh and type mismatch that come from the fact there is
65
+ // / no IndexAttr and that IndexType have no bitwidth.
66
+ bool mlir::isEqualConstantIntOrValue (OpFoldResult op1, OpFoldResult op2) {
67
+ auto getConstantIntValue = [](OpFoldResult ofr) -> llvm::Optional<int64_t > {
68
+ Attribute attr = ofr.dyn_cast <Attribute>();
69
+ // Note: isa+cast-like pattern allows writing the condition below as 1 line.
70
+ if (!attr && ofr.get <Value>().getDefiningOp <ConstantOp>())
71
+ attr = ofr.get <Value>().getDefiningOp <ConstantOp>().getValue ();
72
+ if (auto intAttr = attr.dyn_cast_or_null <IntegerAttr>())
73
+ return intAttr.getValue ().getSExtValue ();
74
+ return llvm::None;
75
+ };
76
+ auto cst1 = getConstantIntValue (op1), cst2 = getConstantIntValue (op2);
77
+ if (cst1 && cst2 && *cst1 == *cst2)
78
+ return true ;
79
+ auto v1 = op1.dyn_cast <Value>(), v2 = op2.dyn_cast <Value>();
80
+ return v1 && v2 && v1 == v2;
81
+ }
82
+
62
83
// ===----------------------------------------------------------------------===//
63
84
// StandardOpsDialect Interfaces
64
85
// ===----------------------------------------------------------------------===//
@@ -3557,6 +3578,34 @@ void SubTensorOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
3557
3578
context);
3558
3579
}
3559
3580
3581
+ //
3582
+ static LogicalResult
3583
+ foldIdentityOffsetSizeAndStrideOpInterface (OffsetSizeAndStrideOpInterface op,
3584
+ ShapedType shapedType) {
3585
+ OpBuilder b (op.getContext ());
3586
+ for (OpFoldResult ofr : op.getMixedOffsets ())
3587
+ if (!isEqualConstantIntOrValue (ofr, b.getIndexAttr (0 )))
3588
+ return failure ();
3589
+ // Rank-reducing noops only need to inspect the leading dimensions: llvm::zip
3590
+ // is appropriate.
3591
+ auto shape = shapedType.getShape ();
3592
+ for (auto it : llvm::zip (op.getMixedSizes (), shape))
3593
+ if (!isEqualConstantIntOrValue (std::get<0 >(it),
3594
+ b.getIndexAttr (std::get<1 >(it))))
3595
+ return failure ();
3596
+ for (OpFoldResult ofr : op.getMixedStrides ())
3597
+ if (!isEqualConstantIntOrValue (ofr, b.getIndexAttr (1 )))
3598
+ return failure ();
3599
+ return success ();
3600
+ }
3601
+
3602
+ OpFoldResult SubTensorOp::fold (ArrayRef<Attribute>) {
3603
+ if (getSourceType () == getType () &&
3604
+ succeeded (foldIdentityOffsetSizeAndStrideOpInterface (*this , getType ())))
3605
+ return this ->source ();
3606
+ return OpFoldResult ();
3607
+ }
3608
+
3560
3609
// ===----------------------------------------------------------------------===//
3561
3610
// SubTensorInsertOp
3562
3611
// ===----------------------------------------------------------------------===//
@@ -3597,6 +3646,13 @@ void mlir::SubTensorInsertOp::build(OpBuilder &b, OperationState &result,
3597
3646
build (b, result, source, dest, offsetValues, sizeValues, strideValues);
3598
3647
}
3599
3648
3649
+ OpFoldResult SubTensorInsertOp::fold (ArrayRef<Attribute>) {
3650
+ if (getSourceType () == getType () &&
3651
+ succeeded (foldIdentityOffsetSizeAndStrideOpInterface (*this , getType ())))
3652
+ return this ->source ();
3653
+ return OpFoldResult ();
3654
+ }
3655
+
3600
3656
// ===----------------------------------------------------------------------===//
3601
3657
// TensorLoadOp
3602
3658
// ===----------------------------------------------------------------------===//
0 commit comments