Skip to content

[InstCombine] Fold select of clamped shifts #114797

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2761,6 +2761,75 @@ static Instruction *foldSelectWithSRem(SelectInst &SI, InstCombinerImpl &IC,
return nullptr;
}

static Instruction *foldSelectWithClampedShift(SelectInst &SI,
InstCombinerImpl &IC,
IRBuilderBase &Builder) {
Value *CondVal = SI.getCondition();
Value *TrueVal = SI.getTrueValue();
Value *FalseVal = SI.getFalseValue();
Type *SelType = SI.getType();
uint64_t BW = SelType->getScalarSizeInBits();

auto MatchClampedShift = [&](Value *V, Value *Amt) -> BinaryOperator * {
Value *X, *Limit;
Instruction *I;

// Fold (select (icmp_ugt A, BW-1), TrueVal, (shift X, (umin A, C)))
// --> (select (icmp_ugt A, BW-1), TrueVal, (shift X, A))
// Fold (select (icmp_ult A, BW), (shift X, (umin A, C)), FalseVal)
// --> (select (icmp_ult A, BW), (shift X, A), FalseVal)
// iff C >= BW-1
if (match(V, m_OneUse(m_Shift(m_Value(X),
m_UMin(m_Specific(Amt), m_Value(Limit)))))) {
KnownBits KnownLimit = IC.computeKnownBits(Limit, 0, &SI);
if (KnownLimit.getMinValue().uge(BW - 1))
return cast<BinaryOperator>(V);
}

// Fold (select (icmp_ugt A, BW-1), (shift X, (and A, C)), FalseVal)
// --> (select (icmp_ugt A, BW-1), (shift X, A), FalseVal)
// Fold (select (icmp_ult A, BW), (shift X, (and A, C)), FalseVal)
// --> (select (icmp_ult A, BW), (shift X, A), FalseVal)
// iff Pow2 element width we just demand the amt mask bits.
if (isPowerOf2_64(BW) &&
match(V, m_OneUse(m_Shift(m_Value(X), m_Instruction(I))))) {
KnownBits Known(BW);
APInt DemandedBits = APInt::getLowBitsSet(BW, Log2_64(BW));
if (Value *NewAmt = IC.SimplifyMultipleUseDemandedBits(
I, DemandedBits, Known, /*Depth=*/0,
IC.getSimplifyQuery().getWithInstruction(I)))
return Amt == NewAmt ? cast<BinaryOperator>(V) : nullptr;
}

return nullptr;
};

Value *Amt;
if (match(CondVal, m_SpecificICmp(ICmpInst::ICMP_UGT, m_Value(Amt),
m_SpecificInt(BW - 1)))) {
if (BinaryOperator *ShiftI = MatchClampedShift(FalseVal, Amt)) {
Amt = Builder.CreateFreeze(Amt);
return SelectInst::Create(
Builder.CreateICmpUGT(Amt, cast<Instruction>(CondVal)->getOperand(1)),
TrueVal,
Builder.CreateBinOp(ShiftI->getOpcode(), ShiftI->getOperand(0), Amt));
}
}

if (match(CondVal, m_SpecificICmp(ICmpInst::ICMP_ULT, m_Value(Amt),
m_SpecificInt(BW)))) {
if (BinaryOperator *ShiftI = MatchClampedShift(TrueVal, Amt)) {
Amt = Builder.CreateFreeze(Amt);
return SelectInst::Create(
Builder.CreateICmpULT(Amt, cast<Instruction>(CondVal)->getOperand(1)),
Builder.CreateBinOp(ShiftI->getOpcode(), ShiftI->getOperand(0), Amt),
FalseVal);
}
}

return nullptr;
}

static Value *foldSelectWithFrozenICmp(SelectInst &Sel, InstCombiner::BuilderTy &Builder) {
FreezeInst *FI = dyn_cast<FreezeInst>(Sel.getCondition());
if (!FI)
Expand Down Expand Up @@ -3871,6 +3940,9 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
if (Instruction *I = foldSelectExtConst(SI))
return I;

if (Instruction *I = foldSelectWithClampedShift(SI, *this, Builder))
return I;

if (Instruction *I = foldSelectWithSRem(SI, *this, Builder))
return I;

Expand Down
236 changes: 236 additions & 0 deletions llvm/test/Transforms/InstCombine/select-shift-clamp.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,236 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
; RUN: opt -S -passes=instcombine < %s | FileCheck %s

declare void @use_i17(i17)
declare void @use_i32(i32)

; Fold (select (icmp_ugt A, BW-1), (shift X, (and A, C)), FalseVal)
; --> (select (icmp_ugt A, BW-1), (shift X, A), FalseVal)
; Fold (select (icmp_ult A, BW), (shift X, (and A, C)), FalseVal)
; --> (select (icmp_ult A, BW), (shift X, A), FalseVal)
; iff Pow2 element width and C masks all amt bits.

define i32 @select_ult_shl_clamp_and_i32(i32 %a0, i32 %a1, i32 %a2) {
; CHECK-LABEL: @select_ult_shl_clamp_and_i32(
; CHECK-NEXT: [[A1:%.*]] = freeze i32 [[A3:%.*]]
; CHECK-NEXT: [[TMP1:%.*]] = shl i32 [[A0:%.*]], [[A1]]
; CHECK-NEXT: [[C:%.*]] = icmp ult i32 [[A1]], 32
; CHECK-NEXT: [[R:%.*]] = select i1 [[C]], i32 [[TMP1]], i32 [[A2:%.*]]
; CHECK-NEXT: ret i32 [[R]]
;
%c = icmp ult i32 %a1, 32
%m = and i32 %a1, 31
%s = shl i32 %a0, %m
%r = select i1 %c, i32 %s, i32 %a2
ret i32 %r
}

define i32 @select_ule_ashr_clamp_and_i32(i32 %a0, i32 %a1, i32 %a2) {
; CHECK-LABEL: @select_ule_ashr_clamp_and_i32(
; CHECK-NEXT: [[A1:%.*]] = freeze i32 [[A3:%.*]]
; CHECK-NEXT: [[TMP1:%.*]] = ashr i32 [[A0:%.*]], [[A1]]
; CHECK-NEXT: [[C:%.*]] = icmp ult i32 [[A1]], 32
; CHECK-NEXT: [[R:%.*]] = select i1 [[C]], i32 [[TMP1]], i32 [[A2:%.*]]
; CHECK-NEXT: ret i32 [[R]]
;
%c = icmp ule i32 %a1, 31
%m = and i32 %a1, 127
%s = ashr i32 %a0, %m
%r = select i1 %c, i32 %s, i32 %a2
ret i32 %r
}

define i32 @select_ugt_lshr_clamp_and_i32(i32 %a0, i32 %a1, i32 %a2) {
; CHECK-LABEL: @select_ugt_lshr_clamp_and_i32(
; CHECK-NEXT: [[A1:%.*]] = freeze i32 [[A3:%.*]]
; CHECK-NEXT: [[TMP1:%.*]] = lshr i32 [[A0:%.*]], [[A1]]
; CHECK-NEXT: [[C:%.*]] = icmp ugt i32 [[A1]], 31
; CHECK-NEXT: [[R:%.*]] = select i1 [[C]], i32 [[A2:%.*]], i32 [[TMP1]]
; CHECK-NEXT: ret i32 [[R]]
;
%c = icmp ugt i32 %a1, 31
%m = and i32 %a1, 31
%s = lshr i32 %a0, %m
%r = select i1 %c, i32 %a2, i32 %s
ret i32 %r
}

define i32 @select_uge_shl_clamp_and_i32(i32 %a0, i32 %a1, i32 %a2) {
; CHECK-LABEL: @select_uge_shl_clamp_and_i32(
; CHECK-NEXT: [[A1:%.*]] = freeze i32 [[A3:%.*]]
; CHECK-NEXT: [[TMP1:%.*]] = shl i32 [[A0:%.*]], [[A1]]
; CHECK-NEXT: [[C:%.*]] = icmp ugt i32 [[A1]], 31
; CHECK-NEXT: [[R:%.*]] = select i1 [[C]], i32 [[A2:%.*]], i32 [[TMP1]]
; CHECK-NEXT: ret i32 [[R]]
;
%c = icmp uge i32 %a1, 32
%m = and i32 %a1, 63
%s = shl i32 %a0, %m
%r = select i1 %c, i32 %a2, i32 %s
ret i32 %r
}

; negative test - multiuse
define i32 @select_ule_ashr_clamp_and_i32_multiuse(i32 %a0, i32 %a1, i32 %a2) {
; CHECK-LABEL: @select_ule_ashr_clamp_and_i32_multiuse(
; CHECK-NEXT: [[C:%.*]] = icmp ult i32 [[A1:%.*]], 32
; CHECK-NEXT: [[M:%.*]] = and i32 [[A1]], 127
; CHECK-NEXT: [[S:%.*]] = ashr i32 [[A0:%.*]], [[M]]
; CHECK-NEXT: [[R:%.*]] = select i1 [[C]], i32 [[S]], i32 [[A2:%.*]]
; CHECK-NEXT: call void @use_i32(i32 [[S]])
; CHECK-NEXT: ret i32 [[R]]
;
%c = icmp ule i32 %a1, 31
%m = and i32 %a1, 127
%s = ashr i32 %a0, %m
%r = select i1 %c, i32 %s, i32 %a2
call void @use_i32(i32 %s)
ret i32 %r
}

; negative test - mask doesn't cover all legal amount bit
define i32 @select_ult_shl_clamp_and_i32_badmask(i32 %a0, i32 %a1, i32 %a2) {
; CHECK-LABEL: @select_ult_shl_clamp_and_i32_badmask(
; CHECK-NEXT: [[C:%.*]] = icmp ult i32 [[A1:%.*]], 32
; CHECK-NEXT: [[M:%.*]] = and i32 [[A1]], 28
; CHECK-NEXT: [[S:%.*]] = shl i32 [[A0:%.*]], [[M]]
; CHECK-NEXT: [[R:%.*]] = select i1 [[C]], i32 [[S]], i32 [[A2:%.*]]
; CHECK-NEXT: ret i32 [[R]]
;
%c = icmp ult i32 %a1, 32
%m = and i32 %a1, 28
%s = shl i32 %a0, %m
%r = select i1 %c, i32 %s, i32 %a2
ret i32 %r
}

; negative test - non-pow2
define i17 @select_uge_lshr_clamp_and_i17_nonpow2(i17 %a0, i17 %a1, i17 %a2) {
; CHECK-LABEL: @select_uge_lshr_clamp_and_i17_nonpow2(
; CHECK-NEXT: [[C:%.*]] = icmp ugt i17 [[A1:%.*]], 16
; CHECK-NEXT: [[M:%.*]] = and i17 [[A1]], 255
; CHECK-NEXT: [[S:%.*]] = lshr i17 [[A0:%.*]], [[M]]
; CHECK-NEXT: [[R:%.*]] = select i1 [[C]], i17 [[A2:%.*]], i17 [[S]]
; CHECK-NEXT: ret i17 [[R]]
;
%c = icmp uge i17 %a1, 17
%m = and i17 %a1, 255
%s = lshr i17 %a0, %m
%r = select i1 %c, i17 %a2, i17 %s
ret i17 %r
}

; Fold (select (icmp_ugt A, BW-1), TrueVal, (shift X, (umin A, C)))
; --> (select (icmp_ugt A, BW-1), TrueVal, (shift X, A))
; Fold (select (icmp_ult A, BW), (shift X, (umin A, C)), FalseVal)
; --> (select (icmp_ult A, BW), (shift X, A), FalseVal)
; iff C >= BW-1

define i32 @select_ult_shl_clamp_umin_i32(i32 %a0, i32 %a1, i32 %a2) {
; CHECK-LABEL: @select_ult_shl_clamp_umin_i32(
; CHECK-NEXT: [[A1:%.*]] = freeze i32 [[A3:%.*]]
; CHECK-NEXT: [[TMP1:%.*]] = shl i32 [[A0:%.*]], [[A1]]
; CHECK-NEXT: [[C:%.*]] = icmp ult i32 [[A1]], 32
; CHECK-NEXT: [[R:%.*]] = select i1 [[C]], i32 [[TMP1]], i32 [[A2:%.*]]
; CHECK-NEXT: ret i32 [[R]]
;
%c = icmp ult i32 %a1, 32
%m = call i32 @llvm.umin.i32(i32 %a1, i32 31)
%s = shl i32 %a0, %m
%r = select i1 %c, i32 %s, i32 %a2
ret i32 %r
}

define i17 @select_ule_ashr_clamp_umin_i17(i17 %a0, i17 %a1, i17 %a2) {
; CHECK-LABEL: @select_ule_ashr_clamp_umin_i17(
; CHECK-NEXT: [[A1:%.*]] = freeze i17 [[A3:%.*]]
; CHECK-NEXT: [[TMP1:%.*]] = ashr i17 [[A0:%.*]], [[A1]]
; CHECK-NEXT: [[C:%.*]] = icmp ult i17 [[A1]], 17
; CHECK-NEXT: [[R:%.*]] = select i1 [[C]], i17 [[TMP1]], i17 [[A2:%.*]]
; CHECK-NEXT: ret i17 [[R]]
;
%c = icmp ule i17 %a1, 16
%m = call i17 @llvm.umin.i17(i17 %a1, i17 17)
%s = ashr i17 %a0, %m
%r = select i1 %c, i17 %s, i17 %a2
ret i17 %r
}

define i32 @select_ugt_shl_clamp_umin_i32(i32 %a0, i32 %a1, i32 %a2) {
; CHECK-LABEL: @select_ugt_shl_clamp_umin_i32(
; CHECK-NEXT: [[A1:%.*]] = freeze i32 [[A3:%.*]]
; CHECK-NEXT: [[S:%.*]] = shl i32 [[A0:%.*]], [[A1]]
; CHECK-NEXT: [[C:%.*]] = icmp ugt i32 [[A1]], 31
; CHECK-NEXT: [[R:%.*]] = select i1 [[C]], i32 [[A2:%.*]], i32 [[S]]
; CHECK-NEXT: ret i32 [[R]]
;
%c = icmp ugt i32 %a1, 31
%m = call i32 @llvm.umin.i32(i32 %a1, i32 128)
%s = shl i32 %a0, %m
%r = select i1 %c, i32 %a2, i32 %s
ret i32 %r
}

define <2 x i32> @select_uge_lshr_clamp_umin_v2i32(<2 x i32> %a0, <2 x i32> %a1, <2 x i32> %a2) {
; CHECK-LABEL: @select_uge_lshr_clamp_umin_v2i32(
; CHECK-NEXT: [[A1:%.*]] = freeze <2 x i32> [[A3:%.*]]
; CHECK-NEXT: [[S:%.*]] = lshr <2 x i32> [[A0:%.*]], [[A1]]
; CHECK-NEXT: [[C:%.*]] = icmp ugt <2 x i32> [[A1]], splat (i32 31)
; CHECK-NEXT: [[R:%.*]] = select <2 x i1> [[C]], <2 x i32> [[A2:%.*]], <2 x i32> [[S]]
; CHECK-NEXT: ret <2 x i32> [[R]]
;
%c = icmp uge <2 x i32> %a1, <i32 32, i32 32>
%m = call <2 x i32> @llvm.umin.v2i32(<2 x i32> %a1, <2 x i32> <i32 63, i32 31>)
%s = lshr <2 x i32> %a0, %m
%r = select <2 x i1> %c, <2 x i32> %a2, <2 x i32> %s
ret <2 x i32> %r
}

; negative test - multiuse
define i32 @select_ugt_shl_clamp_umin_i32_multiuse(i32 %a0, i32 %a1, i32 %a2) {
; CHECK-LABEL: @select_ugt_shl_clamp_umin_i32_multiuse(
; CHECK-NEXT: [[C:%.*]] = icmp ugt i32 [[A1:%.*]], 32
; CHECK-NEXT: [[M:%.*]] = call i32 @llvm.umin.i32(i32 [[A1]], i32 128)
; CHECK-NEXT: [[S:%.*]] = shl i32 [[A0:%.*]], [[M]]
; CHECK-NEXT: [[R:%.*]] = select i1 [[C]], i32 [[A2:%.*]], i32 [[S]]
; CHECK-NEXT: call void @use_i32(i32 [[S]])
; CHECK-NEXT: ret i32 [[R]]
;
%c = icmp ugt i32 %a1, 32
%m = call i32 @llvm.umin.i32(i32 %a1, i32 128)
%s = shl i32 %a0, %m
%r = select i1 %c, i32 %a2, i32 %s
call void @use_i32(i32 %s)
ret i32 %r
}

; negative test - umin limit doesn't cover all legal amounts
define i17 @select_uge_lshr_clamp_umin_i17_badlimit(i17 %a0, i17 %a1, i17 %a2) {
; CHECK-LABEL: @select_uge_lshr_clamp_umin_i17_badlimit(
; CHECK-NEXT: [[C:%.*]] = icmp ugt i17 [[A1:%.*]], 15
; CHECK-NEXT: [[M:%.*]] = call i17 @llvm.umin.i17(i17 [[A1]], i17 12)
; CHECK-NEXT: [[S:%.*]] = lshr i17 [[A0:%.*]], [[M]]
; CHECK-NEXT: [[R:%.*]] = select i1 [[C]], i17 [[A2:%.*]], i17 [[S]]
; CHECK-NEXT: ret i17 [[R]]
;
%c = icmp uge i17 %a1, 16
%m = call i17 @llvm.umin.i17(i17 %a1, i17 12)
%s = lshr i17 %a0, %m
%r = select i1 %c, i17 %a2, i17 %s
ret i17 %r
}

define range(i64 0, -9223372036854775807) <4 x i64> @PR109888(<4 x i64> %0) {
; CHECK-LABEL: @PR109888(
; CHECK-NEXT: [[TMP0:%.*]] = freeze <4 x i64> [[TMP1:%.*]]
; CHECK-NEXT: [[TMP2:%.*]] = shl nuw <4 x i64> splat (i64 1), [[TMP0]]
; CHECK-NEXT: [[C:%.*]] = icmp ult <4 x i64> [[TMP0]], splat (i64 64)
; CHECK-NEXT: [[R:%.*]] = select <4 x i1> [[C]], <4 x i64> [[TMP2]], <4 x i64> zeroinitializer
; CHECK-NEXT: ret <4 x i64> [[R]]
;
%c = icmp ult <4 x i64> %0, <i64 64, i64 64, i64 64, i64 64>
%m = and <4 x i64> %0, <i64 63, i64 63, i64 63, i64 63>
%s = shl nuw <4 x i64> <i64 1, i64 1, i64 1, i64 1>, %m
%r = select <4 x i1> %c, <4 x i64> %s, <4 x i64> zeroinitializer
ret <4 x i64> %r
}