@@ -842,8 +842,123 @@ class InstructionsState {
842
842
static InstructionsState invalid() { return {nullptr, nullptr}; }
843
843
};
844
844
845
+ struct InterchangeableInstruction {
846
+ unsigned Opcode;
847
+ SmallVector<Value *> Ops;
848
+ template <class... ArgTypes>
849
+ InterchangeableInstruction(unsigned Opcode, ArgTypes &&...Args)
850
+ : Opcode(Opcode), Ops{std::forward<decltype(Args)>(Args)...} {}
851
+ };
852
+
853
+ bool operator<(const InterchangeableInstruction &LHS,
854
+ const InterchangeableInstruction &RHS) {
855
+ return LHS.Opcode < RHS.Opcode;
856
+ }
857
+
845
858
} // end anonymous namespace
846
859
860
+ /// \returns a sorted list of interchangeable instructions by instruction opcode
861
+ /// that \p I can be converted to.
862
+ /// e.g.,
863
+ /// x << y -> x * (2^y)
864
+ /// x << 1 -> x * 2
865
+ /// x << 0 -> x * 1 -> x - 0 -> x + 0 -> x & 11...1 -> x | 0
866
+ /// x * 0 -> x & 0
867
+ /// x * -1 -> 0 - x
868
+ /// TODO: support more patterns
869
+ static SmallVector<InterchangeableInstruction>
870
+ getInterchangeableInstruction(Instruction *I) {
871
+ // PII = Possible Interchangeable Instruction
872
+ SmallVector<InterchangeableInstruction> PII;
873
+ unsigned Opcode = I->getOpcode();
874
+ PII.emplace_back(Opcode, I->operands());
875
+ if (!is_contained({Instruction::Shl, Instruction::Mul, Instruction::Sub,
876
+ Instruction::Add},
877
+ Opcode))
878
+ return PII;
879
+ Constant *C;
880
+ if (match(I, m_BinOp(m_Value(), m_Constant(C)))) {
881
+ ConstantInt *V = nullptr;
882
+ if (auto *CI = dyn_cast<ConstantInt>(C)) {
883
+ V = CI;
884
+ } else if (auto *CDV = dyn_cast<ConstantDataVector>(C)) {
885
+ if (auto *CI = dyn_cast_if_present<ConstantInt>(CDV->getSplatValue()))
886
+ V = CI;
887
+ }
888
+ if (!V)
889
+ return PII;
890
+ Value *Op0 = I->getOperand(0);
891
+ Type *Op1Ty = I->getOperand(1)->getType();
892
+ const APInt &Op1Int = V->getValue();
893
+ Constant *Zero =
894
+ ConstantInt::get(Op1Ty, APInt::getZero(Op1Int.getBitWidth()));
895
+ Constant *UnsignedMax =
896
+ ConstantInt::get(Op1Ty, APInt::getMaxValue(Op1Int.getBitWidth()));
897
+ switch (Opcode) {
898
+ case Instruction::Shl: {
899
+ PII.emplace_back(Instruction::Mul, Op0,
900
+ ConstantInt::get(Op1Ty, 1 << Op1Int.getZExtValue()));
901
+ if (Op1Int.isZero()) {
902
+ PII.emplace_back(Instruction::Sub, Op0, Zero);
903
+ PII.emplace_back(Instruction::Add, Op0, Zero);
904
+ PII.emplace_back(Instruction::And, Op0, UnsignedMax);
905
+ PII.emplace_back(Instruction::Or, Op0, Zero);
906
+ }
907
+ break;
908
+ }
909
+ case Instruction::Mul: {
910
+ if (Op1Int.isOne()) {
911
+ PII.emplace_back(Instruction::Sub, Op0, Zero);
912
+ PII.emplace_back(Instruction::Add, Op0, Zero);
913
+ PII.emplace_back(Instruction::And, Op0, UnsignedMax);
914
+ PII.emplace_back(Instruction::Or, Op0, Zero);
915
+ } else if (Op1Int.isZero()) {
916
+ PII.emplace_back(Instruction::And, Op0, Zero);
917
+ } else if (Op1Int.isAllOnes()) {
918
+ PII.emplace_back(Instruction::Sub, Zero, Op0);
919
+ }
920
+ break;
921
+ }
922
+ case Instruction::Sub:
923
+ if (Op1Int.isZero()) {
924
+ PII.emplace_back(Instruction::Add, Op0, Zero);
925
+ PII.emplace_back(Instruction::And, Op0, UnsignedMax);
926
+ PII.emplace_back(Instruction::Or, Op0, Zero);
927
+ }
928
+ break;
929
+ case Instruction::Add:
930
+ if (Op1Int.isZero()) {
931
+ PII.emplace_back(Instruction::And, Op0, UnsignedMax);
932
+ PII.emplace_back(Instruction::Or, Op0, Zero);
933
+ }
934
+ break;
935
+ }
936
+ }
937
+ // std::set_intersection requires a sorted range.
938
+ sort(PII);
939
+ return PII;
940
+ }
941
+
942
+ /// \returns the Op and operands which \p I convert to.
943
+ static std::pair<Value *, SmallVector<Value *>>
944
+ getInterchangeableInstruction(Instruction *I, Instruction *MainOp,
945
+ Instruction *AltOp) {
946
+ SmallVector<InterchangeableInstruction> IIList =
947
+ getInterchangeableInstruction(I);
948
+ const auto *Iter = find_if(IIList, [&](const InterchangeableInstruction &II) {
949
+ return II.Opcode == MainOp->getOpcode();
950
+ });
951
+ if (Iter == IIList.end()) {
952
+ Iter = find_if(IIList, [&](const InterchangeableInstruction &II) {
953
+ return II.Opcode == AltOp->getOpcode();
954
+ });
955
+ assert(Iter != IIList.end() &&
956
+ "Cannot find an interchangeable instruction.");
957
+ return std::make_pair(AltOp, Iter->Ops);
958
+ }
959
+ return std::make_pair(MainOp, Iter->Ops);
960
+ }
961
+
847
962
/// \returns true if \p Opcode is allowed as part of the main/alternate
848
963
/// instruction for SLP vectorization.
849
964
///
@@ -957,6 +1072,22 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
957
1072
return InstructionsState::invalid();
958
1073
}
959
1074
bool AnyPoison = InstCnt != VL.size();
1075
+ // Currently, this is only used for binary ops.
1076
+ // TODO: support all instructions
1077
+ SmallVector<InterchangeableInstruction> InterchangeableOpcode =
1078
+ getInterchangeableInstruction(cast<Instruction>(V));
1079
+ SmallVector<InterchangeableInstruction> AlternateInterchangeableOpcode;
1080
+ auto UpdateInterchangeableOpcode =
1081
+ [](SmallVector<InterchangeableInstruction> &LHS,
1082
+ ArrayRef<InterchangeableInstruction> RHS) {
1083
+ SmallVector<InterchangeableInstruction> NewInterchangeableOpcode;
1084
+ std::set_intersection(LHS.begin(), LHS.end(), RHS.begin(), RHS.end(),
1085
+ std::back_inserter(NewInterchangeableOpcode));
1086
+ if (NewInterchangeableOpcode.empty())
1087
+ return false;
1088
+ LHS.swap(NewInterchangeableOpcode);
1089
+ return true;
1090
+ };
960
1091
for (int Cnt = 0, E = VL.size(); Cnt < E; Cnt++) {
961
1092
auto *I = dyn_cast<Instruction>(VL[Cnt]);
962
1093
if (!I)
@@ -969,14 +1100,32 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
969
1100
return InstructionsState::invalid();
970
1101
unsigned InstOpcode = I->getOpcode();
971
1102
if (IsBinOp && isa<BinaryOperator>(I)) {
972
- if (InstOpcode == Opcode || InstOpcode == AltOpcode)
1103
+ SmallVector<InterchangeableInstruction> ThisInterchangeableOpcode(
1104
+ getInterchangeableInstruction(I));
1105
+ if (UpdateInterchangeableOpcode(InterchangeableOpcode,
1106
+ ThisInterchangeableOpcode))
973
1107
continue;
974
- if (Opcode == AltOpcode && isValidForAlternation(InstOpcode) &&
975
- isValidForAlternation(Opcode)) {
976
- AltOpcode = InstOpcode;
977
- AltIndex = Cnt;
1108
+ if (AlternateInterchangeableOpcode.empty()) {
1109
+ InterchangeableOpcode.erase(
1110
+ remove_if(InterchangeableOpcode,
1111
+ [](const InterchangeableInstruction &I) {
1112
+ return !isValidForAlternation(I.Opcode);
1113
+ }),
1114
+ InterchangeableOpcode.end());
1115
+ ThisInterchangeableOpcode.erase(
1116
+ remove_if(ThisInterchangeableOpcode,
1117
+ [](const InterchangeableInstruction &I) {
1118
+ return !isValidForAlternation(I.Opcode);
1119
+ }),
1120
+ ThisInterchangeableOpcode.end());
1121
+ if (InterchangeableOpcode.empty() || ThisInterchangeableOpcode.empty())
1122
+ return InstructionsState::invalid();
1123
+ AlternateInterchangeableOpcode.swap(ThisInterchangeableOpcode);
978
1124
continue;
979
1125
}
1126
+ if (UpdateInterchangeableOpcode(AlternateInterchangeableOpcode,
1127
+ ThisInterchangeableOpcode))
1128
+ continue;
980
1129
} else if (IsCastOp && isa<CastInst>(I)) {
981
1130
Value *Op0 = IBase->getOperand(0);
982
1131
Type *Ty0 = Op0->getType();
@@ -1077,6 +1226,24 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
1077
1226
return InstructionsState::invalid();
1078
1227
}
1079
1228
1229
+ if (IsBinOp) {
1230
+ auto FindOp = [&](ArrayRef<InterchangeableInstruction> CandidateOp) {
1231
+ for (Value *V : VL) {
1232
+ if (isa<PoisonValue>(V))
1233
+ continue;
1234
+ for (const InterchangeableInstruction &I : CandidateOp)
1235
+ if (cast<Instruction>(V)->getOpcode() == I.Opcode)
1236
+ return cast<Instruction>(V);
1237
+ }
1238
+ llvm_unreachable(
1239
+ "Cannot find the candidate instruction for InstructionsState.");
1240
+ };
1241
+ Instruction *MainOp = FindOp(InterchangeableOpcode);
1242
+ Instruction *AltOp = AlternateInterchangeableOpcode.empty()
1243
+ ? MainOp
1244
+ : FindOp(AlternateInterchangeableOpcode);
1245
+ return InstructionsState(MainOp, AltOp);
1246
+ }
1080
1247
return InstructionsState(cast<Instruction>(V),
1081
1248
cast<Instruction>(VL[AltIndex]));
1082
1249
}
@@ -2407,42 +2574,46 @@ class BoUpSLP {
2407
2574
}
2408
2575
2409
2576
/// Go through the instructions in VL and append their operands.
2410
- void appendOperandsOfVL(ArrayRef<Value *> VL, Instruction *VL0) {
2577
+ void appendOperandsOfVL(ArrayRef<Value *> VL, Instruction *MainOp,
2578
+ Instruction *AltOp) {
2411
2579
assert(!VL.empty() && "Bad VL");
2412
2580
assert((empty() || VL.size() == getNumLanes()) &&
2413
2581
"Expected same number of lanes");
2414
2582
// IntrinsicInst::isCommutative returns true if swapping the first "two"
2415
2583
// arguments to the intrinsic produces the same result.
2416
2584
constexpr unsigned IntrinsicNumOperands = 2;
2417
- unsigned NumOperands = VL0 ->getNumOperands();
2418
- ArgSize = isa<IntrinsicInst>(VL0 ) ? IntrinsicNumOperands : NumOperands;
2585
+ unsigned NumOperands = MainOp ->getNumOperands();
2586
+ ArgSize = isa<IntrinsicInst>(MainOp ) ? IntrinsicNumOperands : NumOperands;
2419
2587
OpsVec.resize(NumOperands);
2420
2588
unsigned NumLanes = VL.size();
2421
- for (unsigned OpIdx = 0; OpIdx != NumOperands; ++OpIdx) {
2589
+ for (unsigned OpIdx : seq<unsigned>( NumOperands))
2422
2590
OpsVec[OpIdx].resize(NumLanes);
2423
- for (unsigned Lane = 0; Lane != NumLanes; ++Lane) {
2424
- assert((isa<Instruction>(VL[Lane]) || isa<PoisonValue>(VL[Lane])) &&
2425
- "Expected instruction or poison value");
2426
- // Our tree has just 3 nodes: the root and two operands.
2427
- // It is therefore trivial to get the APO. We only need to check the
2428
- // opcode of VL[Lane] and whether the operand at OpIdx is the LHS or
2429
- // RHS operand. The LHS operand of both add and sub is never attached
2430
- // to an inversese operation in the linearized form, therefore its APO
2431
- // is false. The RHS is true only if VL[Lane] is an inverse operation.
2432
-
2433
- // Since operand reordering is performed on groups of commutative
2434
- // operations or alternating sequences (e.g., +, -), we can safely
2435
- // tell the inverse operations by checking commutativity.
2436
- if (isa<PoisonValue>(VL[Lane])) {
2591
+ for (auto [Lane, V] : enumerate(VL)) {
2592
+ assert((isa<Instruction>(V) || isa<PoisonValue>(V)) &&
2593
+ "Expected instruction or poison value");
2594
+ if (isa<PoisonValue>(V)) {
2595
+ for (unsigned OpIdx : seq<unsigned>(NumOperands))
2437
2596
OpsVec[OpIdx][Lane] = {
2438
- PoisonValue::get(VL0 ->getOperand(OpIdx)->getType()), true,
2597
+ PoisonValue::get(MainOp ->getOperand(OpIdx)->getType()), true,
2439
2598
false};
2440
- continue;
2441
- }
2442
- bool IsInverseOperation = !isCommutative(cast<Instruction>(VL[Lane]));
2599
+ continue;
2600
+ }
2601
+ auto [SelectedOp, Ops] =
2602
+ getInterchangeableInstruction(cast<Instruction>(V), MainOp, AltOp);
2603
+ // Our tree has just 3 nodes: the root and two operands.
2604
+ // It is therefore trivial to get the APO. We only need to check the
2605
+ // opcode of V and whether the operand at OpIdx is the LHS or RHS
2606
+ // operand. The LHS operand of both add and sub is never attached to an
2607
+ // inversese operation in the linearized form, therefore its APO is
2608
+ // false. The RHS is true only if V is an inverse operation.
2609
+
2610
+ // Since operand reordering is performed on groups of commutative
2611
+ // operations or alternating sequences (e.g., +, -), we can safely
2612
+ // tell the inverse operations by checking commutativity.
2613
+ bool IsInverseOperation = !isCommutative(cast<Instruction>(SelectedOp));
2614
+ for (unsigned OpIdx : seq<unsigned>(NumOperands)) {
2443
2615
bool APO = (OpIdx == 0) ? false : IsInverseOperation;
2444
- OpsVec[OpIdx][Lane] = {cast<Instruction>(VL[Lane])->getOperand(OpIdx),
2445
- APO, false};
2616
+ OpsVec[OpIdx][Lane] = {Ops[OpIdx], APO, false};
2446
2617
}
2447
2618
}
2448
2619
}
@@ -2549,11 +2720,12 @@ class BoUpSLP {
2549
2720
2550
2721
public:
2551
2722
/// Initialize with all the operands of the instruction vector \p RootVL.
2552
- VLOperands(ArrayRef<Value *> RootVL, Instruction *VL0, const BoUpSLP &R)
2723
+ VLOperands(ArrayRef<Value *> RootVL, Instruction *MainOp,
2724
+ Instruction *AltOp, const BoUpSLP &R)
2553
2725
: TLI(*R.TLI), DL(*R.DL), SE(*R.SE), R(R),
2554
- L(R.LI->getLoopFor((VL0 ->getParent() ))) {
2726
+ L(R.LI->getLoopFor(MainOp ->getParent())) {
2555
2727
// Append all the operands of RootVL.
2556
- appendOperandsOfVL(RootVL, VL0 );
2728
+ appendOperandsOfVL(RootVL, MainOp, AltOp );
2557
2729
}
2558
2730
2559
2731
/// \Returns a value vector with the operands across all lanes for the
@@ -3345,7 +3517,7 @@ class BoUpSLP {
3345
3517
3346
3518
/// Set this bundle's operand from Scalars.
3347
3519
void setOperand(const BoUpSLP &R, bool RequireReorder = false) {
3348
- VLOperands Ops(Scalars, MainOp, R);
3520
+ VLOperands Ops(Scalars, MainOp, AltOp, R);
3349
3521
if (RequireReorder)
3350
3522
Ops.reorder();
3351
3523
for (unsigned I : seq<unsigned>(MainOp->getNumOperands()))
@@ -8561,7 +8733,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
8561
8733
LLVM_DEBUG(dbgs() << "SLP: added a vector of compares.\n");
8562
8734
8563
8735
ValueList Left, Right;
8564
- VLOperands Ops(VL, VL0, *this);
8736
+ VLOperands Ops(VL, VL0, S.getAltOp(), *this);
8565
8737
if (cast<CmpInst>(VL0)->isCommutative()) {
8566
8738
// Commutative predicate - collect + sort operands of the instructions
8567
8739
// so that each side is more likely to have the same opcode.
@@ -15619,7 +15791,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
15619
15791
Value *V = Builder.CreateBinOp(
15620
15792
static_cast<Instruction::BinaryOps>(E->getOpcode()), LHS,
15621
15793
RHS);
15622
- propagateIRFlags(V, E->Scalars, VL0 , It == MinBWs.end());
15794
+ propagateIRFlags(V, E->Scalars, nullptr , It == MinBWs.end());
15623
15795
if (auto *I = dyn_cast<Instruction>(V)) {
15624
15796
V = ::propagateMetadata(I, E->Scalars);
15625
15797
// Drop nuw flags for abs(sub(commutative), true).
0 commit comments