Skip to content

[InstCombine] Try the flipped strictness of predicate in foldICmpShlConstant #92773

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 28, 2024

Conversation

dtcxzyw
Copy link
Member

@dtcxzyw dtcxzyw commented May 20, 2024

This patch extends the transform (icmp pred iM (shl iM %v, N), C) -> (icmp pred i(M-N) (trunc %v iM to i(M-N)), (trunc (C>>N)) to handle icmps with the flipped strictness of predicate.

See the following case:

icmp ult i64 (shl X, 32), 8589934593 ->
icmp ule i64 (shl X, 32), 8589934592 ->
icmp ule i32 (trunc X, i32), 2 ->
icmp ult i32 (trunc X, i32), 3

Fixes the regression introduced by #86111 (comment).

Alive2 proofs: https://alive2.llvm.org/ce/z/-sp5n3

nuw cannot be propagated as we always use ashr here. I don't see the value of fixing this (see the test test_icmp_shl_nuw).

@dtcxzyw dtcxzyw requested review from asb and goldsteinn May 20, 2024 15:43
@dtcxzyw dtcxzyw requested a review from nikic as a code owner May 20, 2024 15:43
@llvmbot
Copy link
Member

llvmbot commented May 20, 2024

@llvm/pr-subscribers-llvm-transforms

Author: Yingwei Zheng (dtcxzyw)

Changes

This patch extends the transform (icmp pred iM (shl iM %v, N), C) -> (icmp pred i(M-N) (trunc %v iM to i(M-N)), (trunc (C>>N)) to handle icmps with the flipped strictness of predicate.

See the following case:

icmp ult i64 (shl X, 32), 8589934593 ->
icmp ule i64 (shl X, 32), 8589934592 ->
icmp ule i32 (trunc X, i32), 2 ->
icmp ult i32 (trunc X, i32), 3

Fixes the regression introduced by #86111 (comment).

Alive2 proofs: https://alive2.llvm.org/ce/z/-sp5n3

nuw cannot be propagated as we always use ashr here. I don't see the value of fixing this (see the test test_icmp_shl_nuw).


Full diff: https://github.com/llvm/llvm-project/pull/92773.diff

2 Files Affected:

  • (modified) llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp (+31-8)
  • (modified) llvm/test/Transforms/InstCombine/icmp.ll (+77)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 542a1c82b127a..f5952362f56bd 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -2414,14 +2414,37 @@ Instruction *InstCombinerImpl::foldICmpShlConstant(ICmpInst &Cmp,
   // free on the target. It has the additional benefit of comparing to a
   // smaller constant that may be more target-friendly.
   unsigned Amt = ShiftAmt->getLimitedValue(TypeBits - 1);
-  if (Shl->hasOneUse() && Amt != 0 && C.countr_zero() >= Amt &&
-      DL.isLegalInteger(TypeBits - Amt)) {
-    Type *TruncTy = IntegerType::get(Cmp.getContext(), TypeBits - Amt);
-    if (auto *ShVTy = dyn_cast<VectorType>(ShType))
-      TruncTy = VectorType::get(TruncTy, ShVTy->getElementCount());
-    Constant *NewC =
-        ConstantInt::get(TruncTy, C.ashr(*ShiftAmt).trunc(TypeBits - Amt));
-    return new ICmpInst(Pred, Builder.CreateTrunc(X, TruncTy), NewC);
+  if (Shl->hasOneUse() && Amt != 0 && DL.isLegalInteger(TypeBits - Amt)) {
+    auto FoldICmpShlToICmpTrunc = [&](ICmpInst::Predicate Pred,
+                                      const APInt &C) -> Instruction * {
+      if (C.countr_zero() < Amt)
+        return nullptr;
+      Type *TruncTy = ShType->getWithNewBitWidth(TypeBits - Amt);
+      Constant *NewC =
+          ConstantInt::get(TruncTy, C.ashr(*ShiftAmt).trunc(TypeBits - Amt));
+      return new ICmpInst(
+          Pred, Builder.CreateTrunc(X, TruncTy, "", Shl->hasNoSignedWrap()),
+          NewC);
+    };
+
+    if (Instruction *Res = FoldICmpShlToICmpTrunc(Pred, C))
+      return Res;
+
+    // Try the flipped strictness predicate.
+    // e.g.:
+    // icmp ult i64 (shl X, 32), 8589934593 ->
+    // icmp ule i64 (shl X, 32), 8589934592 ->
+    // icmp ule i32 (trunc X, i32), 2 ->
+    // icmp ult i32 (trunc X, i32), 3
+    if (auto FlippedStrictness =
+            InstCombiner::getFlippedStrictnessPredicateAndConstant(
+                Pred, ConstantInt::get(ShType->getContext(), C))) {
+      ICmpInst::Predicate NewPred = FlippedStrictness->first;
+      const APInt &NewC =
+          cast<ConstantInt>(FlippedStrictness->second)->getValue();
+      if (Instruction *Res = FoldICmpShlToICmpTrunc(NewPred, NewC))
+        return Res;
+    }
   }
 
   return nullptr;
diff --git a/llvm/test/Transforms/InstCombine/icmp.ll b/llvm/test/Transforms/InstCombine/icmp.ll
index 2d786c8f48833..319e13235ccd6 100644
--- a/llvm/test/Transforms/InstCombine/icmp.ll
+++ b/llvm/test/Transforms/InstCombine/icmp.ll
@@ -5198,3 +5198,80 @@ define i1 @icmp_freeze_sext(i16 %x, i16 %y) {
   %cmp2 = icmp uge i16 %ext.fr, %y
   ret i1 %cmp2
 }
+
+define i1 @test_icmp_shl(i64 %x) {
+; CHECK-LABEL: @test_icmp_shl(
+; CHECK-NEXT:    [[TMP1:%.*]] = trunc i64 [[X:%.*]] to i32
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i32 [[TMP1]], 3
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %shl = shl i64 %x, 32
+  %cmp = icmp ult i64 %shl, 8589934593
+  ret i1 %cmp
+}
+
+define i1 @test_icmp_shl_multiuse(i64 %x) {
+; CHECK-LABEL: @test_icmp_shl_multiuse(
+; CHECK-NEXT:    [[SHL:%.*]] = shl i64 [[X:%.*]], 32
+; CHECK-NEXT:    call void @use_i64(i64 [[SHL]])
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i64 [[SHL]], 8589934593
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %shl = shl i64 %x, 32
+  call void @use_i64(i64 %shl)
+  %cmp = icmp ult i64 %shl, 8589934593
+  ret i1 %cmp
+}
+
+define i1 @test_icmp_shl_illegal_length(i64 %x) {
+; CHECK-LABEL: @test_icmp_shl_illegal_length(
+; CHECK-NEXT:    [[SHL:%.*]] = shl i64 [[X:%.*]], 31
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i64 [[SHL]], 8589934593
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %shl = shl i64 %x, 31
+  %cmp = icmp ult i64 %shl, 8589934593
+  ret i1 %cmp
+}
+
+define i1 @test_icmp_shl_invalid_rhsc(i64 %x) {
+; CHECK-LABEL: @test_icmp_shl_invalid_rhsc(
+; CHECK-NEXT:    [[SHL:%.*]] = shl i64 [[X:%.*]], 32
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i64 [[SHL]], 8589934595
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %shl = shl i64 %x, 32
+  %cmp = icmp ult i64 %shl, 8589934595
+  ret i1 %cmp
+}
+
+define i1 @test_icmp_shl_nuw(i64 %x) {
+; CHECK-LABEL: @test_icmp_shl_nuw(
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i64 [[X:%.*]], 3
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %shl = shl nuw i64 %x, 32
+  %cmp = icmp ult i64 %shl, 8589934593
+  ret i1 %cmp
+}
+
+define i1 @test_icmp_shl_nsw(i64 %x) {
+; CHECK-LABEL: @test_icmp_shl_nsw(
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i64 [[X:%.*]], 3
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %shl = shl nsw i64 %x, 32
+  %cmp = icmp ult i64 %shl, 8589934593
+  ret i1 %cmp
+}
+
+define <2 x i1> @test_icmp_shl_vec(<2 x i64> %x) {
+; CHECK-LABEL: @test_icmp_shl_vec(
+; CHECK-NEXT:    [[TMP1:%.*]] = trunc <2 x i64> [[X:%.*]] to <2 x i32>
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult <2 x i32> [[TMP1]], <i32 3, i32 3>
+; CHECK-NEXT:    ret <2 x i1> [[CMP]]
+;
+  %shl = shl <2 x i64> %x, splat(i64 32)
+  %cmp = icmp ult <2 x i64> %shl, splat(i64 8589934593)
+  ret <2 x i1> %cmp
+}

@dtcxzyw dtcxzyw linked an issue May 20, 2024 that may be closed by this pull request
dtcxzyw added a commit to dtcxzyw/llvm-opt-benchmark that referenced this pull request May 20, 2024
Constant *NewC =
ConstantInt::get(TruncTy, C.ashr(*ShiftAmt).trunc(TypeBits - Amt));
return new ICmpInst(Pred, Builder.CreateTrunc(X, TruncTy), NewC);
if (Shl->hasOneUse() && Amt != 0 && DL.isLegalInteger(TypeBits - Amt)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we be using shouldChangeType(ShType, TruncTy) to account for vecs? (Or the case that both are illegal)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we be using shouldChangeType(ShType, TruncTy) to account for vecs? (Or the case that both are illegal)

It breaks some vec tests. So I decide to use shouldChangeType(OldWidth, NewWidth).

Constant *NewC =
ConstantInt::get(TruncTy, C.ashr(*ShiftAmt).trunc(TypeBits - Amt));
return new ICmpInst(
Pred, Builder.CreateTrunc(X, TruncTy, "", Shl->hasNoSignedWrap()),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing propagation of nuw?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nuw cannot be propagated as we always use ashr here. I don't see the value of fixing this (see the test test_icmp_shl_nuw).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You have nuw propagation in your proofs no?


define i1 @src_ult_nuw(i32 %x, i32 %c) {
  %cttz = call i32 @llvm.cttz.i32(i32 %c, i1 true)
  %cond = icmp uge i32 %cttz, 16
  call void @llvm.assume(i1 %cond)

  %shl = shl nuw i32 %x, 16
  %icmp = icmp ult i32 %shl, %c
  ret i1 %icmp
}

define i1 @tgt_ult_nuw(i32 %x, i32 %c) {
  %trunc = trunc i32 %x to i16
  %ashr_c = ashr i32 %c, 16
  %trunc_c = trunc nuw i32 %ashr_c to i16
  %icmp = icmp ult i16 %trunc, %trunc_c
  ret i1 %icmp

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

----------------------------------------
define i1 @src_eq_nuw(i32 %x, i32 %c) {
#0:
  %cttz = cttz i32 %c, 1
  %cond = icmp uge i32 %cttz, 16
  assume i1 %cond
  %shl = shl nuw i32 %x, 16
  %icmp = icmp eq i32 %shl, %c
  ret i1 %icmp
}
=>
define i1 @tgt_eq_nuw(i32 %x, i32 %c) {
#0:
  %trunc = trunc i32 %x to i16
  %ashr_c = ashr i32 %c, 16
  %trunc_c = trunc nuw i32 %ashr_c to i16
  %icmp = icmp eq i16 %trunc, %trunc_c
  ret i1 %icmp
}
Transformation doesn't verify!

ERROR: Target is more poisonous than source

Example:
i32 %x = #x00000000 (0)
i32 %c = #xbfff0000 (3221159936, -1073807360)

Source:
i32 %cttz = #x00000010 (16)
i1 %cond = #x1 (1)
i32 %shl = #x00000000 (0)
i1 %icmp = #x0 (0)

Target:
i16 %trunc = #x0000 (0)
i32 %ashr_c = #xffffbfff (4294950911, -16385)
i16 %trunc_c = poison
i1 %icmp = poison
Source value: #x0 (0)
Target value: poison


----------------------------------------
define i1 @src_ult_nuw(i32 %x, i32 %c) {
#0:
  %cttz = cttz i32 %c, 1
  %cond = icmp uge i32 %cttz, 16
  assume i1 %cond
  %shl = shl nuw i32 %x, 16
  %icmp = icmp ult i32 %shl, %c
  ret i1 %icmp
}
=>
define i1 @tgt_ult_nuw(i32 %x, i32 %c) {
#0:
  %trunc = trunc i32 %x to i16
  %ashr_c = ashr i32 %c, 16
  %trunc_c = trunc nuw i32 %ashr_c to i16
  %icmp = icmp ult i16 %trunc, %trunc_c
  ret i1 %icmp
}
Transformation doesn't verify!

ERROR: Target is more poisonous than source

Example:
i32 %x = #x00000000 (0)
i32 %c = #xbfff0000 (3221159936, -1073807360)

Source:
i32 %cttz = #x00000010 (16)
i1 %cond = #x1 (1)
i32 %shl = #x00000000 (0)
i1 %icmp = #x1 (1)

Target:
i16 %trunc = #x0000 (0)
i32 %ashr_c = #xffffbfff (4294950911, -16385)
i16 %trunc_c = poison
i1 %icmp = poison
Source value: #x1 (1)
Target value: poison


----------------------------------------
define i1 @src_slt_nuw(i32 %x, i32 %c) {
#0:
  %cttz = cttz i32 %c, 1
  %cond = icmp uge i32 %cttz, 16
  assume i1 %cond
  %shl = shl nuw i32 %x, 16
  %icmp = icmp slt i32 %shl, %c
  ret i1 %icmp
}
=>
define i1 @tgt_slt_nuw(i32 %x, i32 %c) {
#0:
  %trunc = trunc i32 %x to i16
  %ashr_c = ashr i32 %c, 16
  %trunc_c = trunc nuw i32 %ashr_c to i16
  %icmp = icmp slt i16 %trunc, %trunc_c
  ret i1 %icmp
}
Transformation doesn't verify!

ERROR: Target is more poisonous than source

Example:
i32 %x = #x00000000 (0)
i32 %c = #xbfff0000 (3221159936, -1073807360)

Source:
i32 %cttz = #x00000010 (16)
i1 %cond = #x1 (1)
i32 %shl = #x00000000 (0)
i1 %icmp = #x0 (0)

Target:
i16 %trunc = #x0000 (0)
i32 %ashr_c = #xffffbfff (4294950911, -16385)
i16 %trunc_c = poison
i1 %icmp = poison
Source value: #x0 (0)
Target value: poison

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you don't have nsw and have nuw, you can just use lshr and propagate nuw? If you have both you have to drop either nuw or nsw.
https://alive2.llvm.org/ce/z/9tjTDb

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you don't have nsw and have nuw, you can just use lshr and propagate nuw? If you have both you have to drop either nuw or nsw. https://alive2.llvm.org/ce/z/9tjTDb

Yeah, I don't know which one is better when we have both flags :(
BTW, as the motivation cases (see dtcxzyw/llvm-opt-benchmark#616) have no flags on shl, I think it is ok to just keep it simple.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you don't have nsw and have nuw, you can just use lshr and propagate nuw? If you have both you have to drop either nuw or nsw. https://alive2.llvm.org/ce/z/9tjTDb

Yeah, I don't know which one is better when we have both flags :(

I guess knowledge about 1 extra bit vs more knowledge about 1 less bit. For my money id stick w/ nuw.

BTW, as the motivation cases (see dtcxzyw/llvm-opt-benchmark#616) have no flags on shl, I think it is ok to just keep it simple.

IMO its not a big increase in complexity, and since we regularly have missed-optimizations related to throwing away info, think its better to just handle right from the start.

; CHECK-NEXT: [[CMP:%.*]] = icmp ult <2 x i32> [[TMP1]], <i32 3, i32 3>
; CHECK-NEXT: ret <2 x i1> [[CMP]]
;
%shl = shl <2 x i64> %x, splat(i64 32)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

didn't know about the splat syntax, thats nice :)

ConstantInt::get(TruncTy, C.ashr(*ShiftAmt).trunc(TypeBits - Amt));
return new ICmpInst(Pred, Builder.CreateTrunc(X, TruncTy), NewC);
if (Shl->hasOneUse() && Amt != 0 &&
shouldChangeType(ShType->getScalarSizeInBits(), TypeBits - Amt)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To account for Vec types should construct ShType->getWithNewBitWidth(TypeBits - Amt) and use that + ShType for this check.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#92773 (comment)

It breaks some vec tests.

Copy link
Contributor

@nikic nikic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@dtcxzyw dtcxzyw merged commit 5c7c1f6 into llvm:main May 28, 2024
4 checks passed
@dtcxzyw dtcxzyw deleted the perf/fold-icmp-shl-to-icmp-trunc branch May 28, 2024 04:47
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Update diff May 10th 2024, 1:32:19 pm
4 participants