Skip to content

[InstCombine] Enable more fabs fold when the user ignores sign bit of zero/NaN #139861

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 21, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 43 additions & 2 deletions llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2773,6 +2773,47 @@ Instruction *InstCombinerImpl::foldAndOrOfSelectUsingImpliedCond(Value *Op,
return nullptr;
}

/// Return true if the sign bit of result can be ignored when the result is
/// zero.
static bool ignoreSignBitOfZero(Instruction &I) {
if (I.hasNoSignedZeros())
return true;

// Check if the sign bit is ignored by the only user.
if (!I.hasOneUse())
return false;
Instruction *User = I.user_back();

// fcmp treats both positive and negative zero as equal.
if (User->getOpcode() == Instruction::FCmp)
return true;

if (auto *FPOp = dyn_cast<FPMathOperator>(User))
return FPOp->hasNoSignedZeros();

return false;
}

/// Return true if the sign bit of result can be ignored when the result is NaN.
static bool ignoreSignBitOfNaN(Instruction &I) {
if (I.hasNoNaNs())
return true;

// Check if the sign bit is ignored by the only user.
if (!I.hasOneUse())
return false;
Instruction *User = I.user_back();

// fcmp ignores the sign bit of NaN.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All proper floating point instructions kind of ignore the sign bit of a nan, this is just one particular instance. Eventually we should have a utility function to identify all potentially canonicalizing instructions which we can ignore the nan sign bit from

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have generalized these two helpers to handle more FP ops/intrinsics. We may move them into ValueTracking in the future.

if (User->getOpcode() == Instruction::FCmp)
return true;

if (auto *FPOp = dyn_cast<FPMathOperator>(User))
return FPOp->hasNoNaNs();

return false;
}

// Canonicalize select with fcmp to fabs(). -0.0 makes this tricky. We need
// fast-math-flags (nsz) or fsub with +0.0 (not fneg) for this to work.
static Instruction *foldSelectWithFCmpToFabs(SelectInst &SI,
Expand All @@ -2797,7 +2838,7 @@ static Instruction *foldSelectWithFCmpToFabs(SelectInst &SI,
// of NAN, but IEEE-754 specifies the signbit of NAN values with
// fneg/fabs operations.
if (match(TrueVal, m_FSub(m_PosZeroFP(), m_Specific(X))) &&
(cast<FPMathOperator>(CondVal)->hasNoNaNs() || SI.hasNoNaNs() ||
(cast<FPMathOperator>(CondVal)->hasNoNaNs() || ignoreSignBitOfNaN(SI) ||
isKnownNeverNaN(X, /*Depth=*/0,
IC.getSimplifyQuery().getWithInstruction(
cast<Instruction>(CondVal))))) {
Expand Down Expand Up @@ -2844,7 +2885,7 @@ static Instruction *foldSelectWithFCmpToFabs(SelectInst &SI,
// Note: We require "nnan" for this fold because fcmp ignores the signbit
// of NAN, but IEEE-754 specifies the signbit of NAN values with
// fneg/fabs operations.
if (!SI.hasNoSignedZeros() || !SI.hasNoNaNs())
if (!ignoreSignBitOfZero(SI) || !ignoreSignBitOfNaN(SI))
return nullptr;

if (Swap)
Expand Down
101 changes: 101 additions & 0 deletions llvm/test/Transforms/InstCombine/fabs.ll
Original file line number Diff line number Diff line change
Expand Up @@ -1276,3 +1276,104 @@ define <2 x float> @test_select_neg_negx_x_wrong_type(<2 x float> %value) {
%value.addr.0.i = select i1 %a1, <2 x float> %fneg.i, <2 x float> %value
ret <2 x float> %value.addr.0.i
}

define i1 @test_fabs_used_by_fcmp(float %x, float %y) {
; CHECK-LABEL: @test_fabs_used_by_fcmp(
; CHECK-NEXT: [[SEL:%.*]] = call float @llvm.fabs.f32(float [[X:%.*]])
; CHECK-NEXT: [[CMP2:%.*]] = fcmp olt float [[SEL]], [[Y:%.*]]
; CHECK-NEXT: ret i1 [[CMP2]]
;
%cmp = fcmp oge float %x, 0.000000e+00
%neg = fneg float %x
%sel = select i1 %cmp, float %x, float %neg
%cmp2 = fcmp olt float %sel, %y
ret i1 %cmp2
}

define float @test_fabs_used_by_fpop_nnan_nsz(float %x, float %y) {
; CHECK-LABEL: @test_fabs_used_by_fpop_nnan_nsz(
; CHECK-NEXT: [[SEL:%.*]] = call float @llvm.fabs.f32(float [[X:%.*]])
; CHECK-NEXT: [[ADD:%.*]] = fadd nnan nsz float [[SEL]], [[Y:%.*]]
; CHECK-NEXT: ret float [[ADD]]
;
%cmp = fcmp oge float %x, 0.000000e+00
%neg = fneg float %x
%sel = select i1 %cmp, float %x, float %neg
%add = fadd nnan nsz float %sel, %y
ret float %add
}

define i1 @test_fabs_fsub_used_by_fcmp(float %x, float %y) {
; CHECK-LABEL: @test_fabs_fsub_used_by_fcmp(
; CHECK-NEXT: [[SEL:%.*]] = call float @llvm.fabs.f32(float [[X:%.*]])
; CHECK-NEXT: [[CMP2:%.*]] = fcmp olt float [[SEL]], [[Y:%.*]]
; CHECK-NEXT: ret i1 [[CMP2]]
;
%cmp = fcmp ogt float %x, 0.000000e+00
%neg = fsub float 0.000000e+00, %x
%sel = select i1 %cmp, float %x, float %neg
%cmp2 = fcmp olt float %sel, %y
ret i1 %cmp2
}

define float @test_fabs_fsub_used_by_fpop_nnan(float %x, float %y) {
; CHECK-LABEL: @test_fabs_fsub_used_by_fpop_nnan(
; CHECK-NEXT: [[SEL:%.*]] = call float @llvm.fabs.f32(float [[X:%.*]])
; CHECK-NEXT: [[ADD:%.*]] = fadd nnan float [[SEL]], [[Y:%.*]]
; CHECK-NEXT: ret float [[ADD]]
;
%cmp = fcmp ogt float %x, 0.000000e+00
%neg = fsub float 0.000000e+00, %x
%sel = select i1 %cmp, float %x, float %neg
%add = fadd nnan float %sel, %y
ret float %add
}

; Negative tests

define float @test_fabs_used_by_fpop_nnan(float %x, float %y) {
; CHECK-LABEL: @test_fabs_used_by_fpop_nnan(
; CHECK-NEXT: [[CMP:%.*]] = fcmp oge float [[X:%.*]], 0.000000e+00
; CHECK-NEXT: [[NEG:%.*]] = fneg float [[X]]
; CHECK-NEXT: [[SEL:%.*]] = select i1 [[CMP]], float [[X]], float [[NEG]]
; CHECK-NEXT: [[ADD:%.*]] = fadd nnan float [[SEL]], [[Y:%.*]]
; CHECK-NEXT: ret float [[ADD]]
;
%cmp = fcmp oge float %x, 0.000000e+00
%neg = fneg float %x
%sel = select i1 %cmp, float %x, float %neg
%add = fadd nnan float %sel, %y
ret float %add
}

define float @test_fabs_used_by_fpop_nsz(float %x, float %y) {
; CHECK-LABEL: @test_fabs_used_by_fpop_nsz(
; CHECK-NEXT: [[CMP:%.*]] = fcmp oge float [[X:%.*]], 0.000000e+00
; CHECK-NEXT: [[NEG:%.*]] = fneg float [[X]]
; CHECK-NEXT: [[SEL:%.*]] = select i1 [[CMP]], float [[X]], float [[NEG]]
; CHECK-NEXT: [[ADD:%.*]] = fadd nsz float [[SEL]], [[Y:%.*]]
; CHECK-NEXT: ret float [[ADD]]
;
%cmp = fcmp oge float %x, 0.000000e+00
%neg = fneg float %x
%sel = select i1 %cmp, float %x, float %neg
%add = fadd nsz float %sel, %y
ret float %add
}

define i1 @test_fabs_used_by_fcmp_multiuse(float %x, float %y) {
; CHECK-LABEL: @test_fabs_used_by_fcmp_multiuse(
; CHECK-NEXT: [[CMP:%.*]] = fcmp oge float [[X:%.*]], 0.000000e+00
; CHECK-NEXT: [[NEG:%.*]] = fneg float [[X]]
; CHECK-NEXT: [[SEL:%.*]] = select i1 [[CMP]], float [[X]], float [[NEG]]
; CHECK-NEXT: [[CMP2:%.*]] = fcmp olt float [[SEL]], [[Y:%.*]]
; CHECK-NEXT: call void @use(float [[SEL]])
; CHECK-NEXT: ret i1 [[CMP2]]
;
%cmp = fcmp oge float %x, 0.000000e+00
%neg = fneg float %x
%sel = select i1 %cmp, float %x, float %neg
%cmp2 = fcmp olt float %sel, %y
call void @use(float %sel)
ret i1 %cmp2
}
6 changes: 3 additions & 3 deletions llvm/test/Transforms/InstCombine/fneg.ll
Original file line number Diff line number Diff line change
Expand Up @@ -709,7 +709,7 @@ define float @select_common_op_fneg_false(float %x, i1 %b) {

define float @fabs(float %a) {
; CHECK-LABEL: @fabs(
; CHECK-NEXT: [[FNEG1:%.*]] = call nnan ninf nsz float @llvm.fabs.f32(float [[A:%.*]])
; CHECK-NEXT: [[FNEG1:%.*]] = call float @llvm.fabs.f32(float [[A:%.*]])
; CHECK-NEXT: ret float [[FNEG1]]
;
%fneg = fneg float %a
Expand All @@ -721,7 +721,7 @@ define float @fabs(float %a) {

define float @fnabs(float %a) {
; CHECK-LABEL: @fnabs(
; CHECK-NEXT: [[TMP1:%.*]] = call fast float @llvm.fabs.f32(float [[A:%.*]])
; CHECK-NEXT: [[TMP1:%.*]] = call float @llvm.fabs.f32(float [[A:%.*]])
; CHECK-NEXT: [[FNEG1:%.*]] = fneg fast float [[TMP1]]
; CHECK-NEXT: ret float [[FNEG1]]
;
Expand All @@ -734,7 +734,7 @@ define float @fnabs(float %a) {

define float @fnabs_1(float %a) {
; CHECK-LABEL: @fnabs_1(
; CHECK-NEXT: [[TMP1:%.*]] = call fast float @llvm.fabs.f32(float [[A:%.*]])
; CHECK-NEXT: [[TMP1:%.*]] = call float @llvm.fabs.f32(float [[A:%.*]])
; CHECK-NEXT: [[FNEG1:%.*]] = fneg fast float [[TMP1]]
; CHECK-NEXT: ret float [[FNEG1]]
;
Expand Down