Skip to content

Commit bc8d8e6

Browse files
committed
[mlir] Fold shape.eq %a, %a to true
Differential Revision: https://reviews.llvm.org/D95430
1 parent 7b3ba8d commit bc8d8e6

File tree

2 files changed

+16
-1
lines changed

2 files changed

+16
-1
lines changed

mlir/lib/Dialect/Shape/IR/Shape.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -572,6 +572,8 @@ OpFoldResult CstrRequireOp::fold(ArrayRef<Attribute> operands) {
572572
//===----------------------------------------------------------------------===//
573573

574574
OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) {
575+
if (lhs() == rhs())
576+
return BoolAttr::get(true, getContext());
575577
auto lhs = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
576578
if (lhs == nullptr)
577579
return {};

mlir/test/Dialect/Shape/canonicalize.mlir

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -787,7 +787,7 @@ func @shape_eq_fold_0() -> i1 {
787787

788788
// -----
789789

790-
// Do not fold `shape_eq` for non-constant shapes.
790+
// Do not fold `shape_eq` for non-constant different shapes.
791791
// CHECK-LABEL: @shape_eq_do_not_fold
792792
// CHECK-SAME: (%[[A:.*]]: !shape.shape) -> i1
793793
func @shape_eq_do_not_fold(%a : !shape.shape) -> i1 {
@@ -799,6 +799,19 @@ func @shape_eq_do_not_fold(%a : !shape.shape) -> i1 {
799799
return %result : i1
800800
}
801801

802+
803+
// -----
804+
805+
// Fold `shape_eq` for non-constant but same shapes.
806+
// CHECK-LABEL: @shape_eq_do_fold
807+
// CHECK-SAME: (%[[A:.*]]: !shape.shape) -> i1
808+
func @shape_eq_do_fold(%a : !shape.shape) -> i1 {
809+
// CHECK: %[[RESULT:.*]] = constant true
810+
// CHECK: return %[[RESULT]] : i1
811+
%result = shape.shape_eq %a, %a : !shape.shape, !shape.shape
812+
return %result : i1
813+
}
814+
802815
// -----
803816

804817
// Fold `mul` for constant sizes.

0 commit comments

Comments
 (0)