@@ -13785,11 +13785,10 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
13785
13785
return SDValue();
13786
13786
13787
13787
EVT BaseLdVT = BaseLd->getValueType(0);
13788
- SDValue BasePtr = BaseLd->getBasePtr();
13789
13788
13790
13789
// Go through the loads and check that they're strided
13791
- SmallVector<SDValue> Ptrs ;
13792
- Ptrs .push_back(BasePtr );
13790
+ SmallVector<LoadSDNode *> Lds ;
13791
+ Lds .push_back(BaseLd );
13793
13792
Align Align = BaseLd->getAlign();
13794
13793
for (SDValue Op : N->ops().drop_front()) {
13795
13794
auto *Ld = dyn_cast<LoadSDNode>(Op);
@@ -13798,60 +13797,38 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
13798
13797
Ld->getValueType(0) != BaseLdVT)
13799
13798
return SDValue();
13800
13799
13801
- Ptrs .push_back(Ld->getBasePtr() );
13800
+ Lds .push_back(Ld);
13802
13801
13803
13802
// The common alignment is the most restrictive (smallest) of all the loads
13804
13803
Align = std::min(Align, Ld->getAlign());
13805
13804
}
13806
13805
13807
- auto matchForwardStrided = [](ArrayRef<SDValue> Ptrs) {
13808
- SDValue Stride;
13809
- for (auto Idx : enumerate(Ptrs)) {
13810
- if (Idx.index() == 0)
13811
- continue;
13812
- SDValue Ptr = Idx.value();
13813
- // Check that each load's pointer is (add LastPtr, Stride)
13814
- if (Ptr.getOpcode() != ISD::ADD ||
13815
- Ptr.getOperand(0) != Ptrs[Idx.index()-1])
13816
- return SDValue();
13817
- SDValue Offset = Ptr.getOperand(1);
13818
- if (!Stride)
13819
- Stride = Offset;
13820
- else if (Offset != Stride)
13821
- return SDValue();
13822
- }
13823
- return Stride;
13824
- };
13825
- auto matchReverseStrided = [](ArrayRef<SDValue> Ptrs) {
13826
- SDValue Stride;
13827
- for (auto Idx : enumerate(Ptrs)) {
13828
- if (Idx.index() == Ptrs.size() - 1)
13829
- continue;
13830
- SDValue Ptr = Idx.value();
13831
- // Check that each load's pointer is (add NextPtr, Stride)
13832
- if (Ptr.getOpcode() != ISD::ADD ||
13833
- Ptr.getOperand(0) != Ptrs[Idx.index()+1])
13834
- return SDValue();
13835
- SDValue Offset = Ptr.getOperand(1);
13836
- if (!Stride)
13837
- Stride = Offset;
13838
- else if (Offset != Stride)
13839
- return SDValue();
13840
- }
13841
- return Stride;
13806
+ using PtrDiff = std::pair<SDValue, bool>;
13807
+ auto GetPtrDiff = [](LoadSDNode *Ld1,
13808
+ LoadSDNode *Ld2) -> std::optional<PtrDiff> {
13809
+ SDValue P1 = Ld1->getBasePtr();
13810
+ SDValue P2 = Ld2->getBasePtr();
13811
+ if (P2.getOpcode() == ISD::ADD && P2.getOperand(0) == P1)
13812
+ return {{P2.getOperand(1), false}};
13813
+ if (P1.getOpcode() == ISD::ADD && P1.getOperand(0) == P2)
13814
+ return {{P1.getOperand(1), true}};
13815
+
13816
+ return std::nullopt;
13842
13817
};
13843
13818
13844
- bool Reversed = false;
13845
- SDValue Stride = matchForwardStrided(Ptrs);
13846
- if (!Stride) {
13847
- Stride = matchReverseStrided(Ptrs);
13848
- Reversed = true;
13849
- // TODO: At this point, we've successfully matched a generalized gather
13850
- // load. Maybe we should emit that, and then move the specialized
13851
- // matchers above and below into a DAG combine?
13852
- if (!Stride)
13819
+ // Get the distance between the first and second loads
13820
+ auto BaseDiff = GetPtrDiff(Lds[0], Lds[1]);
13821
+ if (!BaseDiff)
13822
+ return SDValue();
13823
+
13824
+ // Check all the loads are the same distance apart
13825
+ for (auto *It = Lds.begin() + 1; It != Lds.end() - 1; It++)
13826
+ if (GetPtrDiff(*It, *std::next(It)) != BaseDiff)
13853
13827
return SDValue();
13854
- }
13828
+
13829
+ // TODO: At this point, we've successfully matched a generalized gather
13830
+ // load. Maybe we should emit that, and then move the specialized
13831
+ // matchers above and below into a DAG combine?
13855
13832
13856
13833
// Get the widened scalar type, e.g. v4i8 -> i64
13857
13834
unsigned WideScalarBitWidth =
@@ -13867,26 +13844,25 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
13867
13844
if (!TLI.isLegalStridedLoadStore(WideVecVT, Align))
13868
13845
return SDValue();
13869
13846
13847
+ auto [Stride, MustNegateStride] = *BaseDiff;
13848
+ if (MustNegateStride)
13849
+ Stride = DAG.getNegative(Stride, DL, Stride.getValueType());
13850
+
13870
13851
SDVTList VTs = DAG.getVTList({WideVecVT, MVT::Other});
13871
13852
SDValue IntID =
13872
13853
DAG.getTargetConstant(Intrinsic::riscv_masked_strided_load, DL,
13873
13854
Subtarget.getXLenVT());
13874
- if (Reversed)
13875
- Stride = DAG.getNegative(Stride, DL, Stride->getValueType(0));
13855
+
13876
13856
SDValue AllOneMask =
13877
13857
DAG.getSplat(WideVecVT.changeVectorElementType(MVT::i1), DL,
13878
13858
DAG.getConstant(1, DL, MVT::i1));
13879
13859
13880
- SDValue Ops[] = {BaseLd->getChain(),
13881
- IntID,
13882
- DAG.getUNDEF(WideVecVT),
13883
- BasePtr,
13884
- Stride,
13885
- AllOneMask};
13860
+ SDValue Ops[] = {BaseLd->getChain(), IntID, DAG.getUNDEF(WideVecVT),
13861
+ BaseLd->getBasePtr(), Stride, AllOneMask};
13886
13862
13887
13863
uint64_t MemSize;
13888
13864
if (auto *ConstStride = dyn_cast<ConstantSDNode>(Stride);
13889
- ConstStride && !Reversed && ConstStride->getSExtValue() >= 0)
13865
+ ConstStride && ConstStride->getSExtValue() >= 0)
13890
13866
// total size = (elsize * n) + (stride - elsize) * (n-1)
13891
13867
// = elsize + stride * (n-1)
13892
13868
MemSize = WideScalarVT.getSizeInBits() +
0 commit comments