@@ -850,8 +850,123 @@ class InstructionsState {
850
850
static InstructionsState invalid() { return {nullptr, nullptr}; }
851
851
};
852
852
853
+ struct InterchangeableInstruction {
854
+ unsigned Opcode;
855
+ SmallVector<Value *> Ops;
856
+ template <class... ArgTypes>
857
+ InterchangeableInstruction(unsigned Opcode, ArgTypes &&...Args)
858
+ : Opcode(Opcode), Ops{std::forward<decltype(Args)>(Args)...} {}
859
+ };
860
+
861
+ bool operator<(const InterchangeableInstruction &LHS,
862
+ const InterchangeableInstruction &RHS) {
863
+ return LHS.Opcode < RHS.Opcode;
864
+ }
865
+
853
866
} // end anonymous namespace
854
867
868
+ /// \returns a sorted list of interchangeable instructions by instruction opcode
869
+ /// that \p I can be converted to.
870
+ /// e.g.,
871
+ /// x << y -> x * (2^y)
872
+ /// x << 1 -> x * 2
873
+ /// x << 0 -> x * 1 -> x - 0 -> x + 0 -> x & 11...1 -> x | 0
874
+ /// x * 0 -> x & 0
875
+ /// x * -1 -> 0 - x
876
+ /// TODO: support more patterns
877
+ static SmallVector<InterchangeableInstruction>
878
+ getInterchangeableInstruction(Instruction *I) {
879
+ // PII = Possible Interchangeable Instruction
880
+ SmallVector<InterchangeableInstruction> PII;
881
+ unsigned Opcode = I->getOpcode();
882
+ PII.emplace_back(Opcode, I->operands());
883
+ if (!is_contained({Instruction::Shl, Instruction::Mul, Instruction::Sub,
884
+ Instruction::Add},
885
+ Opcode))
886
+ return PII;
887
+ Constant *C;
888
+ if (match(I, m_BinOp(m_Value(), m_Constant(C)))) {
889
+ ConstantInt *V = nullptr;
890
+ if (auto *CI = dyn_cast<ConstantInt>(C)) {
891
+ V = CI;
892
+ } else if (auto *CDV = dyn_cast<ConstantDataVector>(C)) {
893
+ if (auto *CI = dyn_cast_if_present<ConstantInt>(CDV->getSplatValue()))
894
+ V = CI;
895
+ }
896
+ if (!V)
897
+ return PII;
898
+ Value *Op0 = I->getOperand(0);
899
+ Type *Op1Ty = I->getOperand(1)->getType();
900
+ const APInt &Op1Int = V->getValue();
901
+ Constant *Zero =
902
+ ConstantInt::get(Op1Ty, APInt::getZero(Op1Int.getBitWidth()));
903
+ Constant *UnsignedMax =
904
+ ConstantInt::get(Op1Ty, APInt::getMaxValue(Op1Int.getBitWidth()));
905
+ switch (Opcode) {
906
+ case Instruction::Shl: {
907
+ PII.emplace_back(Instruction::Mul, Op0,
908
+ ConstantInt::get(Op1Ty, 1 << Op1Int.getZExtValue()));
909
+ if (Op1Int.isZero()) {
910
+ PII.emplace_back(Instruction::Sub, Op0, Zero);
911
+ PII.emplace_back(Instruction::Add, Op0, Zero);
912
+ PII.emplace_back(Instruction::And, Op0, UnsignedMax);
913
+ PII.emplace_back(Instruction::Or, Op0, Zero);
914
+ }
915
+ break;
916
+ }
917
+ case Instruction::Mul: {
918
+ if (Op1Int.isOne()) {
919
+ PII.emplace_back(Instruction::Sub, Op0, Zero);
920
+ PII.emplace_back(Instruction::Add, Op0, Zero);
921
+ PII.emplace_back(Instruction::And, Op0, UnsignedMax);
922
+ PII.emplace_back(Instruction::Or, Op0, Zero);
923
+ } else if (Op1Int.isZero()) {
924
+ PII.emplace_back(Instruction::And, Op0, Zero);
925
+ } else if (Op1Int.isAllOnes()) {
926
+ PII.emplace_back(Instruction::Sub, Zero, Op0);
927
+ }
928
+ break;
929
+ }
930
+ case Instruction::Sub:
931
+ if (Op1Int.isZero()) {
932
+ PII.emplace_back(Instruction::Add, Op0, Zero);
933
+ PII.emplace_back(Instruction::And, Op0, UnsignedMax);
934
+ PII.emplace_back(Instruction::Or, Op0, Zero);
935
+ }
936
+ break;
937
+ case Instruction::Add:
938
+ if (Op1Int.isZero()) {
939
+ PII.emplace_back(Instruction::And, Op0, UnsignedMax);
940
+ PII.emplace_back(Instruction::Or, Op0, Zero);
941
+ }
942
+ break;
943
+ }
944
+ }
945
+ // std::set_intersection requires a sorted range.
946
+ sort(PII);
947
+ return PII;
948
+ }
949
+
950
+ /// \returns the Op and operands which \p I convert to.
951
+ static std::pair<Value *, SmallVector<Value *>>
952
+ getInterchangeableInstruction(Instruction *I, Instruction *MainOp,
953
+ Instruction *AltOp) {
954
+ SmallVector<InterchangeableInstruction> IIList =
955
+ getInterchangeableInstruction(I);
956
+ const auto *Iter = find_if(IIList, [&](const InterchangeableInstruction &II) {
957
+ return II.Opcode == MainOp->getOpcode();
958
+ });
959
+ if (Iter == IIList.end()) {
960
+ Iter = find_if(IIList, [&](const InterchangeableInstruction &II) {
961
+ return II.Opcode == AltOp->getOpcode();
962
+ });
963
+ assert(Iter != IIList.end() &&
964
+ "Cannot find an interchangeable instruction.");
965
+ return std::make_pair(AltOp, Iter->Ops);
966
+ }
967
+ return std::make_pair(MainOp, Iter->Ops);
968
+ }
969
+
855
970
/// \returns true if \p Opcode is allowed as part of the main/alternate
856
971
/// instruction for SLP vectorization.
857
972
///
@@ -965,6 +1080,22 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
965
1080
return InstructionsState::invalid();
966
1081
}
967
1082
bool AnyPoison = InstCnt != VL.size();
1083
+ // Currently, this is only used for binary ops.
1084
+ // TODO: support all instructions
1085
+ SmallVector<InterchangeableInstruction> InterchangeableOpcode =
1086
+ getInterchangeableInstruction(cast<Instruction>(V));
1087
+ SmallVector<InterchangeableInstruction> AlternateInterchangeableOpcode;
1088
+ auto UpdateInterchangeableOpcode =
1089
+ [](SmallVector<InterchangeableInstruction> &LHS,
1090
+ ArrayRef<InterchangeableInstruction> RHS) {
1091
+ SmallVector<InterchangeableInstruction> NewInterchangeableOpcode;
1092
+ std::set_intersection(LHS.begin(), LHS.end(), RHS.begin(), RHS.end(),
1093
+ std::back_inserter(NewInterchangeableOpcode));
1094
+ if (NewInterchangeableOpcode.empty())
1095
+ return false;
1096
+ LHS.swap(NewInterchangeableOpcode);
1097
+ return true;
1098
+ };
968
1099
for (int Cnt = 0, E = VL.size(); Cnt < E; Cnt++) {
969
1100
auto *I = dyn_cast<Instruction>(VL[Cnt]);
970
1101
if (!I)
@@ -977,14 +1108,32 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
977
1108
return InstructionsState::invalid();
978
1109
unsigned InstOpcode = I->getOpcode();
979
1110
if (IsBinOp && isa<BinaryOperator>(I)) {
980
- if (InstOpcode == Opcode || InstOpcode == AltOpcode)
1111
+ SmallVector<InterchangeableInstruction> ThisInterchangeableOpcode(
1112
+ getInterchangeableInstruction(I));
1113
+ if (UpdateInterchangeableOpcode(InterchangeableOpcode,
1114
+ ThisInterchangeableOpcode))
981
1115
continue;
982
- if (Opcode == AltOpcode && isValidForAlternation(InstOpcode) &&
983
- isValidForAlternation(Opcode)) {
984
- AltOpcode = InstOpcode;
985
- AltIndex = Cnt;
1116
+ if (AlternateInterchangeableOpcode.empty()) {
1117
+ InterchangeableOpcode.erase(
1118
+ remove_if(InterchangeableOpcode,
1119
+ [](const InterchangeableInstruction &I) {
1120
+ return !isValidForAlternation(I.Opcode);
1121
+ }),
1122
+ InterchangeableOpcode.end());
1123
+ ThisInterchangeableOpcode.erase(
1124
+ remove_if(ThisInterchangeableOpcode,
1125
+ [](const InterchangeableInstruction &I) {
1126
+ return !isValidForAlternation(I.Opcode);
1127
+ }),
1128
+ ThisInterchangeableOpcode.end());
1129
+ if (InterchangeableOpcode.empty() || ThisInterchangeableOpcode.empty())
1130
+ return InstructionsState::invalid();
1131
+ AlternateInterchangeableOpcode.swap(ThisInterchangeableOpcode);
986
1132
continue;
987
1133
}
1134
+ if (UpdateInterchangeableOpcode(AlternateInterchangeableOpcode,
1135
+ ThisInterchangeableOpcode))
1136
+ continue;
988
1137
} else if (IsCastOp && isa<CastInst>(I)) {
989
1138
Value *Op0 = IBase->getOperand(0);
990
1139
Type *Ty0 = Op0->getType();
@@ -1085,6 +1234,24 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
1085
1234
return InstructionsState::invalid();
1086
1235
}
1087
1236
1237
+ if (IsBinOp) {
1238
+ auto FindOp = [&](ArrayRef<InterchangeableInstruction> CandidateOp) {
1239
+ for (Value *V : VL) {
1240
+ if (isa<PoisonValue>(V))
1241
+ continue;
1242
+ for (const InterchangeableInstruction &I : CandidateOp)
1243
+ if (cast<Instruction>(V)->getOpcode() == I.Opcode)
1244
+ return cast<Instruction>(V);
1245
+ }
1246
+ llvm_unreachable(
1247
+ "Cannot find the candidate instruction for InstructionsState.");
1248
+ };
1249
+ Instruction *MainOp = FindOp(InterchangeableOpcode);
1250
+ Instruction *AltOp = AlternateInterchangeableOpcode.empty()
1251
+ ? MainOp
1252
+ : FindOp(AlternateInterchangeableOpcode);
1253
+ return InstructionsState(MainOp, AltOp);
1254
+ }
1088
1255
return InstructionsState(cast<Instruction>(V),
1089
1256
cast<Instruction>(VL[AltIndex]));
1090
1257
}
@@ -2416,42 +2583,46 @@ class BoUpSLP {
2416
2583
}
2417
2584
2418
2585
/// Go through the instructions in VL and append their operands.
2419
- void appendOperandsOfVL(ArrayRef<Value *> VL, Instruction *VL0) {
2586
+ void appendOperandsOfVL(ArrayRef<Value *> VL, Instruction *MainOp,
2587
+ Instruction *AltOp) {
2420
2588
assert(!VL.empty() && "Bad VL");
2421
2589
assert((empty() || VL.size() == getNumLanes()) &&
2422
2590
"Expected same number of lanes");
2423
2591
// IntrinsicInst::isCommutative returns true if swapping the first "two"
2424
2592
// arguments to the intrinsic produces the same result.
2425
2593
constexpr unsigned IntrinsicNumOperands = 2;
2426
- unsigned NumOperands = VL0 ->getNumOperands();
2427
- ArgSize = isa<IntrinsicInst>(VL0 ) ? IntrinsicNumOperands : NumOperands;
2594
+ unsigned NumOperands = MainOp ->getNumOperands();
2595
+ ArgSize = isa<IntrinsicInst>(MainOp ) ? IntrinsicNumOperands : NumOperands;
2428
2596
OpsVec.resize(NumOperands);
2429
2597
unsigned NumLanes = VL.size();
2430
- for (unsigned OpIdx = 0; OpIdx != NumOperands; ++OpIdx) {
2598
+ for (unsigned OpIdx : seq<unsigned>( NumOperands))
2431
2599
OpsVec[OpIdx].resize(NumLanes);
2432
- for (unsigned Lane = 0; Lane != NumLanes; ++Lane) {
2433
- assert((isa<Instruction>(VL[Lane]) || isa<PoisonValue>(VL[Lane])) &&
2434
- "Expected instruction or poison value");
2435
- // Our tree has just 3 nodes: the root and two operands.
2436
- // It is therefore trivial to get the APO. We only need to check the
2437
- // opcode of VL[Lane] and whether the operand at OpIdx is the LHS or
2438
- // RHS operand. The LHS operand of both add and sub is never attached
2439
- // to an inversese operation in the linearized form, therefore its APO
2440
- // is false. The RHS is true only if VL[Lane] is an inverse operation.
2441
-
2442
- // Since operand reordering is performed on groups of commutative
2443
- // operations or alternating sequences (e.g., +, -), we can safely
2444
- // tell the inverse operations by checking commutativity.
2445
- if (isa<PoisonValue>(VL[Lane])) {
2600
+ for (auto [Lane, V] : enumerate(VL)) {
2601
+ assert((isa<Instruction>(V) || isa<PoisonValue>(V)) &&
2602
+ "Expected instruction or poison value");
2603
+ if (isa<PoisonValue>(V)) {
2604
+ for (unsigned OpIdx : seq<unsigned>(NumOperands))
2446
2605
OpsVec[OpIdx][Lane] = {
2447
- PoisonValue::get(VL0 ->getOperand(OpIdx)->getType()), true,
2606
+ PoisonValue::get(MainOp ->getOperand(OpIdx)->getType()), true,
2448
2607
false};
2449
- continue;
2450
- }
2451
- bool IsInverseOperation = !isCommutative(cast<Instruction>(VL[Lane]));
2608
+ continue;
2609
+ }
2610
+ auto [SelectedOp, Ops] =
2611
+ getInterchangeableInstruction(cast<Instruction>(V), MainOp, AltOp);
2612
+ // Our tree has just 3 nodes: the root and two operands.
2613
+ // It is therefore trivial to get the APO. We only need to check the
2614
+ // opcode of V and whether the operand at OpIdx is the LHS or RHS
2615
+ // operand. The LHS operand of both add and sub is never attached to an
2616
+ // inverse operation in the linearized form, therefore its APO is
2617
+ // false. The RHS is true only if V is an inverse operation.
2618
+
2619
+ // Since operand reordering is performed on groups of commutative
2620
+ // operations or alternating sequences (e.g., +, -), we can safely
2621
+ // tell the inverse operations by checking commutativity.
2622
+ bool IsInverseOperation = !isCommutative(cast<Instruction>(SelectedOp));
2623
+ for (unsigned OpIdx : seq<unsigned>(NumOperands)) {
2452
2624
bool APO = (OpIdx == 0) ? false : IsInverseOperation;
2453
- OpsVec[OpIdx][Lane] = {cast<Instruction>(VL[Lane])->getOperand(OpIdx),
2454
- APO, false};
2625
+ OpsVec[OpIdx][Lane] = {Ops[OpIdx], APO, false};
2455
2626
}
2456
2627
}
2457
2628
}
@@ -2557,11 +2728,12 @@ class BoUpSLP {
2557
2728
2558
2729
public:
2559
2730
/// Initialize with all the operands of the instruction vector \p RootVL.
2560
- VLOperands(ArrayRef<Value *> RootVL, Instruction *VL0, const BoUpSLP &R)
2731
+ VLOperands(ArrayRef<Value *> RootVL, Instruction *MainOp,
2732
+ Instruction *AltOp, const BoUpSLP &R)
2561
2733
: TLI(*R.TLI), DL(*R.DL), SE(*R.SE), R(R),
2562
- L(R.LI->getLoopFor((VL0 ->getParent() ))) {
2734
+ L(R.LI->getLoopFor(MainOp ->getParent())) {
2563
2735
// Append all the operands of RootVL.
2564
- appendOperandsOfVL(RootVL, VL0 );
2736
+ appendOperandsOfVL(RootVL, MainOp, AltOp );
2565
2737
}
2566
2738
2567
2739
/// \Returns a value vector with the operands across all lanes for the
@@ -3351,7 +3523,7 @@ class BoUpSLP {
3351
3523
3352
3524
/// Set this bundle's operand from Scalars.
3353
3525
void setOperand(const BoUpSLP &R, bool RequireReorder = false) {
3354
- VLOperands Ops(Scalars, MainOp, R);
3526
+ VLOperands Ops(Scalars, MainOp, AltOp, R);
3355
3527
if (RequireReorder)
3356
3528
Ops.reorder();
3357
3529
for (unsigned I : seq<unsigned>(MainOp->getNumOperands()))
@@ -8592,7 +8764,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
8592
8764
TE->dump());
8593
8765
8594
8766
ValueList Left, Right;
8595
- VLOperands Ops(VL, VL0, *this);
8767
+ VLOperands Ops(VL, VL0, S.getAltOp(), *this);
8596
8768
if (cast<CmpInst>(VL0)->isCommutative()) {
8597
8769
// Commutative predicate - collect + sort operands of the instructions
8598
8770
// so that each side is more likely to have the same opcode.
@@ -15797,7 +15969,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
15797
15969
Value *V = Builder.CreateBinOp(
15798
15970
static_cast<Instruction::BinaryOps>(E->getOpcode()), LHS,
15799
15971
RHS);
15800
- propagateIRFlags(V, E->Scalars, VL0 , It == MinBWs.end());
15972
+ propagateIRFlags(V, E->Scalars, nullptr , It == MinBWs.end());
15801
15973
if (auto *I = dyn_cast<Instruction>(V)) {
15802
15974
V = ::propagateMetadata(I, E->Scalars);
15803
15975
// Drop nuw flags for abs(sub(commutative), true).
0 commit comments