Skip to content

Commit bf43fff

Browse files
committed
make isBinOpWithConstantInt support left hand side operand
1 parent 0e8d567 commit bf43fff

File tree

6 files changed

+71
-50
lines changed

6 files changed

+71
-50
lines changed

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Lines changed: 56 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -828,20 +828,6 @@ class InterchangeableInstruction {
828828
protected:
829829
Instruction *const MainOp;
830830

831-
/// Return non nullptr if the right operand of I is ConstantInt.
832-
static ConstantInt *isBinOpWithConstantInt(Instruction *I) {
833-
Constant *C;
834-
if (!match(I, m_BinOp(m_Value(), m_Constant(C))))
835-
return nullptr;
836-
if (auto *CI = dyn_cast<ConstantInt>(C))
837-
return CI;
838-
if (auto *CDV = dyn_cast<ConstantDataVector>(C)) {
839-
if (auto *CI = dyn_cast_if_present<ConstantInt>(CDV->getSplatValue()))
840-
return CI;
841-
}
842-
return nullptr;
843-
}
844-
845831
public:
846832
InterchangeableInstruction(Instruction *MainOp) : MainOp(MainOp) {}
847833
virtual bool isSame(Instruction *I) {
@@ -867,6 +853,29 @@ class InterchangeableBinOp final : public InterchangeableInstruction {
867853
MaskType Mask = 0b11111111;
868854
MaskType SeenBefore = 0;
869855

856+
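/// Return {ConstantInt operand of I (scalar, or the splat value of a constant data vector), its operand index}; .first is nullptr if there is none. For Sub, Shl and AShr only the right operand is considered.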
857+
static std::pair<ConstantInt *, unsigned>
858+
isBinOpWithConstantInt(Instruction *I) {
859+
unsigned Opcode = I->getOpcode();
860+
unsigned Pos = 1;
861+
Constant *C;
862+
if (!match(I, m_BinOp(m_Value(), m_Constant(C)))) {
863+
if (Opcode == Instruction::Sub || Opcode == Instruction::Shl ||
864+
Opcode == Instruction::AShr)
865+
return std::make_pair(nullptr, Pos);
866+
if (!match(I, m_BinOp(m_Constant(C), m_Value())))
867+
return std::make_pair(nullptr, Pos);
868+
Pos = 0;
869+
}
870+
if (auto *CI = dyn_cast<ConstantInt>(C))
871+
return std::make_pair(CI, Pos);
872+
if (auto *CDV = dyn_cast<ConstantDataVector>(C)) {
873+
if (auto *CI = dyn_cast_if_present<ConstantInt>(CDV->getSplatValue()))
874+
return std::make_pair(CI, Pos);
875+
}
876+
return std::make_pair(nullptr, Pos);
877+
}
878+
870879
static MaskType opcodeToMask(unsigned Opcode) {
871880
switch (Opcode) {
872881
case Instruction::Shl:
@@ -904,26 +913,26 @@ class InterchangeableBinOp final : public InterchangeableInstruction {
904913
if (!binary_search(SupportedOp, Opcode))
905914
return false;
906915
SeenBefore |= opcodeToMask(Opcode);
907-
ConstantInt *CI = isBinOpWithConstantInt(I);
916+
ConstantInt *CI = isBinOpWithConstantInt(I).first;
908917
if (CI) {
909-
const APInt &Op1Int = CI->getValue();
918+
const APInt &CIValue = CI->getValue();
910919
switch (Opcode) {
911920
case Instruction::Shl:
912-
if (Op1Int.isZero())
921+
if (CIValue.isZero())
913922
return true;
914923
return tryAnd(0b101);
915924
case Instruction::Mul:
916-
if (Op1Int.isOne())
925+
if (CIValue.isOne())
917926
return true;
918-
if (Op1Int.isPowerOf2())
927+
if (CIValue.isPowerOf2())
919928
return tryAnd(0b101);
920929
break;
921930
case Instruction::And:
922-
if (Op1Int.isAllOnes())
931+
if (CIValue.isAllOnes())
923932
return true;
924933
break;
925934
default:
926-
if (Op1Int.isZero())
935+
if (CIValue.isZero())
927936
return true;
928937
break;
929938
}
@@ -957,41 +966,48 @@ class InterchangeableBinOp final : public InterchangeableInstruction {
957966
unsigned FromOpcode = MainOp->getOpcode();
958967
if (FromOpcode == ToOpcode)
959968
return SmallVector<Value *>(MainOp->operands());
960-
const APInt &Op1Int = isBinOpWithConstantInt(MainOp)->getValue();
961-
unsigned Op1IntBitWidth = Op1Int.getBitWidth();
962-
APInt RHSV;
969+
auto [CI, Pos] = isBinOpWithConstantInt(MainOp);
970+
const APInt &FromCIValue = CI->getValue();
971+
unsigned FromCIValueBitWidth = FromCIValue.getBitWidth();
972+
APInt ToCIValue;
963973
switch (FromOpcode) {
964974
case Instruction::Shl:
965975
if (ToOpcode == Instruction::Mul) {
966-
RHSV = APInt::getOneBitSet(Op1IntBitWidth, Op1Int.getZExtValue());
976+
ToCIValue = APInt::getOneBitSet(FromCIValueBitWidth,
977+
FromCIValue.getZExtValue());
967978
} else {
968-
assert(Op1Int.isZero() && "Cannot convert the instruction.");
969-
RHSV = ToOpcode == Instruction::And ? APInt::getAllOnes(Op1IntBitWidth)
970-
: APInt::getZero(Op1IntBitWidth);
979+
assert(FromCIValue.isZero() && "Cannot convert the instruction.");
980+
ToCIValue = ToOpcode == Instruction::And
981+
? APInt::getAllOnes(FromCIValueBitWidth)
982+
: APInt::getZero(FromCIValueBitWidth);
971983
}
972984
break;
973985
case Instruction::Mul:
974-
assert(Op1Int.isPowerOf2() && "Cannot convert the instruction.");
986+
assert(FromCIValue.isPowerOf2() && "Cannot convert the instruction.");
975987
if (ToOpcode == Instruction::Shl) {
976-
RHSV = APInt(Op1IntBitWidth, Op1Int.logBase2());
988+
ToCIValue = APInt(FromCIValueBitWidth, FromCIValue.logBase2());
977989
} else {
978-
assert(Op1Int.isOne() && "Cannot convert the instruction.");
979-
RHSV = ToOpcode == Instruction::And ? APInt::getAllOnes(Op1IntBitWidth)
980-
: APInt::getZero(Op1IntBitWidth);
990+
assert(FromCIValue.isOne() && "Cannot convert the instruction.");
991+
ToCIValue = ToOpcode == Instruction::And
992+
? APInt::getAllOnes(FromCIValueBitWidth)
993+
: APInt::getZero(FromCIValueBitWidth);
981994
}
982995
break;
983996
case Instruction::And:
984-
assert(Op1Int.isAllOnes() && "Cannot convert the instruction.");
985-
RHSV = ToOpcode == Instruction::Mul
986-
? APInt::getOneBitSet(Op1IntBitWidth, 0)
987-
: APInt::getZero(Op1IntBitWidth);
997+
assert(FromCIValue.isAllOnes() && "Cannot convert the instruction.");
998+
ToCIValue = ToOpcode == Instruction::Mul
999+
? APInt::getOneBitSet(FromCIValueBitWidth, 0)
1000+
: APInt::getZero(FromCIValueBitWidth);
9881001
break;
9891002
default:
990-
RHSV = APInt::getZero(Op1IntBitWidth);
1003+
ToCIValue = APInt::getZero(FromCIValueBitWidth);
9911004
break;
9921005
}
993-
return {MainOp->getOperand(0),
994-
ConstantInt::get(MainOp->getOperand(1)->getType(), RHSV)};
1006+
auto LHS = MainOp->getOperand(1 - Pos);
1007+
auto RHS = ConstantInt::get(MainOp->getOperand(Pos)->getType(), ToCIValue);
1008+
if (Pos == 1)
1009+
return SmallVector<Value *>({LHS, RHS});
1010+
return SmallVector<Value *>({RHS, LHS});
9951011
}
9961012
};
9971013

llvm/test/Transforms/SLPVectorizer/AArch64/gather-with-minbith-user.ll

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,16 @@ define void @h() {
55
; CHECK-LABEL: define void @h() {
66
; CHECK-NEXT: entry:
77
; CHECK-NEXT: [[ARRAYIDX2:%.*]] = getelementptr i8, ptr null, i64 16
8-
; CHECK-NEXT: [[ARRAYIDX18:%.*]] = getelementptr i8, ptr null, i64 24
9-
; CHECK-NEXT: store <4 x i16> zeroinitializer, ptr [[ARRAYIDX2]], align 2
10-
; CHECK-NEXT: store <4 x i16> zeroinitializer, ptr [[ARRAYIDX18]], align 2
8+
; CHECK-NEXT: [[TMP0:%.*]] = call <8 x i1> @llvm.vector.insert.v8i1.v2i1(<8 x i1> <i1 false, i1 false, i1 false, i1 false, i1 poison, i1 poison, i1 false, i1 false>, <2 x i1> zeroinitializer, i64 4)
9+
; CHECK-NEXT: [[TMP1:%.*]] = call <8 x i1> @llvm.vector.insert.v8i1.v2i1(<8 x i1> <i1 poison, i1 poison, i1 poison, i1 poison, i1 false, i1 false, i1 poison, i1 poison>, <2 x i1> zeroinitializer, i64 0)
10+
; CHECK-NEXT: [[TMP2:%.*]] = call <8 x i1> @llvm.vector.insert.v8i1.v2i1(<8 x i1> [[TMP1]], <2 x i1> zeroinitializer, i64 2)
11+
; CHECK-NEXT: [[TMP3:%.*]] = call <8 x i1> @llvm.vector.insert.v8i1.v2i1(<8 x i1> [[TMP2]], <2 x i1> zeroinitializer, i64 6)
12+
; CHECK-NEXT: [[TMP4:%.*]] = sub <8 x i1> [[TMP0]], [[TMP3]]
13+
; CHECK-NEXT: [[TMP5:%.*]] = add <8 x i1> [[TMP0]], [[TMP3]]
14+
; CHECK-NEXT: [[TMP6:%.*]] = shufflevector <8 x i1> [[TMP4]], <8 x i1> [[TMP5]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 12, i32 13, i32 14, i32 15>
15+
; CHECK-NEXT: [[TMP7:%.*]] = or <8 x i1> [[TMP6]], zeroinitializer
16+
; CHECK-NEXT: [[TMP8:%.*]] = zext <8 x i1> [[TMP7]] to <8 x i16>
17+
; CHECK-NEXT: store <8 x i16> [[TMP8]], ptr [[ARRAYIDX2]], align 2
1118
; CHECK-NEXT: ret void
1219
;
1320
entry:

llvm/test/Transforms/SLPVectorizer/X86/extract-scalar-from-undef.ll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@ define i64 @foo(i32 %tmp7) {
88
; CHECK-NEXT: [[TMP4:%.*]] = sub <8 x i32> [[TMP0]], <i32 0, i32 0, i32 poison, i32 0, i32 0, i32 poison, i32 0, i32 poison>
99
; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <8 x i32> <i32 0, i32 0, i32 0, i32 0, i32 poison, i32 poison, i32 poison, i32 0>, <8 x i32> [[TMP4]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 14, i32 poison, i32 poison, i32 7>
1010
; CHECK-NEXT: [[TMP13:%.*]] = insertelement <8 x i32> [[TMP2]], i32 0, i32 5
11-
; CHECK-NEXT: [[TMP6:%.*]] = add nsw <8 x i32> [[TMP13]], [[TMP4]]
1211
; CHECK-NEXT: [[TMP5:%.*]] = sub nsw <8 x i32> [[TMP13]], [[TMP4]]
13-
; CHECK-NEXT: [[TMP7:%.*]] = shufflevector <8 x i32> [[TMP6]], <8 x i32> [[TMP5]], <8 x i32> <i32 0, i32 9, i32 10, i32 11, i32 4, i32 5, i32 14, i32 15>
12+
; CHECK-NEXT: [[TMP6:%.*]] = add nsw <8 x i32> [[TMP13]], [[TMP4]]
13+
; CHECK-NEXT: [[TMP7:%.*]] = shufflevector <8 x i32> [[TMP5]], <8 x i32> [[TMP6]], <8 x i32> <i32 8, i32 1, i32 2, i32 3, i32 12, i32 13, i32 6, i32 7>
1414
; CHECK-NEXT: [[TMP8:%.*]] = add <8 x i32> zeroinitializer, [[TMP7]]
1515
; CHECK-NEXT: [[TMP9:%.*]] = xor <8 x i32> [[TMP8]], zeroinitializer
1616
; CHECK-NEXT: [[TMP10:%.*]] = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> [[TMP9]])

llvm/test/Transforms/SLPVectorizer/X86/multi-extracts-bv-combined.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ define i32 @foo() {
88
; CHECK-NEXT: [[D:%.*]] = load i32, ptr null, align 4
99
; CHECK-NEXT: [[TMP0:%.*]] = insertelement <4 x i32> <i32 0, i32 undef, i32 1, i32 0>, i32 [[D]], i32 1
1010
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <4 x i32> [[TMP0]], <4 x i32> poison, <8 x i32> <i32 0, i32 1, i32 1, i32 2, i32 3, i32 1, i32 1, i32 1>
11-
; CHECK-NEXT: [[TMP2:%.*]] = or <8 x i32> zeroinitializer, [[TMP1]]
11+
; CHECK-NEXT: [[TMP2:%.*]] = add <8 x i32> zeroinitializer, [[TMP1]]
1212
; CHECK-NEXT: store <8 x i32> [[TMP2]], ptr getelementptr inbounds ([64 x i32], ptr null, i64 0, i64 15), align 4
1313
; CHECK-NEXT: ret i32 0
1414
;

llvm/test/Transforms/SLPVectorizer/X86/reorder_diamond_match.ll

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,8 @@ define void @test() {
1414
; CHECK-NEXT: [[TMP9:%.*]] = add <4 x i16> [[TMP7]], [[TMP8]]
1515
; CHECK-NEXT: [[TMP10:%.*]] = sub <4 x i16> [[TMP7]], [[TMP8]]
1616
; CHECK-NEXT: [[TMP11:%.*]] = shufflevector <4 x i16> [[TMP9]], <4 x i16> [[TMP10]], <4 x i32> <i32 1, i32 4, i32 3, i32 6>
17-
; CHECK-NEXT: [[TMP12:%.*]] = add <4 x i16> zeroinitializer, [[TMP11]]
1817
; CHECK-NEXT: [[TMP13:%.*]] = sub <4 x i16> zeroinitializer, [[TMP11]]
19-
; CHECK-NEXT: [[TMP14:%.*]] = shufflevector <4 x i16> [[TMP12]], <4 x i16> [[TMP13]], <4 x i32> <i32 0, i32 1, i32 6, i32 7>
20-
; CHECK-NEXT: [[TMP15:%.*]] = sext <4 x i16> [[TMP14]] to <4 x i32>
18+
; CHECK-NEXT: [[TMP15:%.*]] = sext <4 x i16> [[TMP13]] to <4 x i32>
2119
; CHECK-NEXT: store <4 x i32> [[TMP15]], ptr [[TMP2]], align 16
2220
; CHECK-NEXT: ret void
2321
;

llvm/test/Transforms/SLPVectorizer/shuffle-mask-resized.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ define i32 @test() {
1212
; CHECK-NEXT: br i1 false, label [[BB4:%.*]], label [[BB3]]
1313
; CHECK: bb3:
1414
; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <2 x i32> [[TMP0]], <2 x i32> <i32 0, i32 poison>, <2 x i32> <i32 2, i32 1>
15-
; CHECK-NEXT: [[TMP5]] = or <2 x i32> zeroinitializer, [[TMP2]]
15+
; CHECK-NEXT: [[TMP5]] = add <2 x i32> zeroinitializer, [[TMP2]]
1616
; CHECK-NEXT: br label [[BB1]]
1717
; CHECK: bb4:
1818
; CHECK-NEXT: [[TMP6:%.*]] = phi <8 x i32> [ [[TMP1]], [[BB1]] ]

0 commit comments

Comments
 (0)