
Commit 182a65a

[RISCV] Refactor performCONCAT_VECTORSCombine. NFC (#69068)
Instead of doing a forward pass for positive strides and a reverse pass for negative strides, we can do a single pass and negate the offset if the pointers turn out to be in reverse order. We can extend getPtrDiff later in #68726 to handle more constant offset sequences.
1 parent 1e8ab99 commit 182a65a
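
To illustrate the refactor, here is a minimal standalone sketch of the single-pass match, assuming plain integer addresses in place of SDValue pointers; the names getPtrDiff and matchStride below are illustrative, not the actual LLVM API:

#include <cassert>
#include <cstdint>
#include <optional>
#include <utility>
#include <vector>

// Distance between two consecutive base addresses, plus a flag that is
// true when the second address is below the first (reverse order).
using PtrDiff = std::pair<int64_t, bool>;

static std::optional<PtrDiff> getPtrDiff(int64_t P1, int64_t P2) {
  if (P2 >= P1)
    return {{P2 - P1, false}};
  return {{P1 - P2, true}};
}

// One pass: take the diff of the first two addresses, require every
// consecutive pair to produce the same diff, then negate the stride
// if the addresses run backwards.
static std::optional<int64_t> matchStride(const std::vector<int64_t> &Ptrs) {
  if (Ptrs.size() < 2)
    return std::nullopt;
  PtrDiff BaseDiff = *getPtrDiff(Ptrs[0], Ptrs[1]);
  for (size_t I = 1; I + 1 < Ptrs.size(); ++I)
    if (*getPtrDiff(Ptrs[I], Ptrs[I + 1]) != BaseDiff)
      return std::nullopt;
  auto [Stride, Reversed] = BaseDiff;
  return Reversed ? -Stride : Stride;
}

int main() {
  assert(matchStride({100, 108, 116, 124}) == 8);  // forward stride
  assert(matchStride({124, 116, 108, 100}) == -8); // reversed, negated
  assert(!matchStride({100, 108, 120}));           // not equally strided
}

In the actual combine below, the same shape appears with SDValue pointers: GetPtrDiff only matches (add Base, Offset) nodes, so it can also fail and return std::nullopt for unrelated pointers.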


llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 34 additions & 58 deletions
@@ -13785,11 +13785,10 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
     return SDValue();
 
   EVT BaseLdVT = BaseLd->getValueType(0);
-  SDValue BasePtr = BaseLd->getBasePtr();
 
   // Go through the loads and check that they're strided
-  SmallVector<SDValue> Ptrs;
-  Ptrs.push_back(BasePtr);
+  SmallVector<LoadSDNode *> Lds;
+  Lds.push_back(BaseLd);
   Align Align = BaseLd->getAlign();
   for (SDValue Op : N->ops().drop_front()) {
     auto *Ld = dyn_cast<LoadSDNode>(Op);
@@ -13798,60 +13797,38 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
         Ld->getValueType(0) != BaseLdVT)
       return SDValue();
 
-    Ptrs.push_back(Ld->getBasePtr());
+    Lds.push_back(Ld);
 
     // The common alignment is the most restrictive (smallest) of all the loads
     Align = std::min(Align, Ld->getAlign());
   }
 
-  auto matchForwardStrided = [](ArrayRef<SDValue> Ptrs) {
-    SDValue Stride;
-    for (auto Idx : enumerate(Ptrs)) {
-      if (Idx.index() == 0)
-        continue;
-      SDValue Ptr = Idx.value();
-      // Check that each load's pointer is (add LastPtr, Stride)
-      if (Ptr.getOpcode() != ISD::ADD ||
-          Ptr.getOperand(0) != Ptrs[Idx.index()-1])
-        return SDValue();
-      SDValue Offset = Ptr.getOperand(1);
-      if (!Stride)
-        Stride = Offset;
-      else if (Offset != Stride)
-        return SDValue();
-    }
-    return Stride;
-  };
-  auto matchReverseStrided = [](ArrayRef<SDValue> Ptrs) {
-    SDValue Stride;
-    for (auto Idx : enumerate(Ptrs)) {
-      if (Idx.index() == Ptrs.size() - 1)
-        continue;
-      SDValue Ptr = Idx.value();
-      // Check that each load's pointer is (add NextPtr, Stride)
-      if (Ptr.getOpcode() != ISD::ADD ||
-          Ptr.getOperand(0) != Ptrs[Idx.index()+1])
-        return SDValue();
-      SDValue Offset = Ptr.getOperand(1);
-      if (!Stride)
-        Stride = Offset;
-      else if (Offset != Stride)
-        return SDValue();
-    }
-    return Stride;
+  using PtrDiff = std::pair<SDValue, bool>;
+  auto GetPtrDiff = [](LoadSDNode *Ld1,
+                       LoadSDNode *Ld2) -> std::optional<PtrDiff> {
+    SDValue P1 = Ld1->getBasePtr();
+    SDValue P2 = Ld2->getBasePtr();
+    if (P2.getOpcode() == ISD::ADD && P2.getOperand(0) == P1)
+      return {{P2.getOperand(1), false}};
+    if (P1.getOpcode() == ISD::ADD && P1.getOperand(0) == P2)
+      return {{P1.getOperand(1), true}};
+
+    return std::nullopt;
   };
 
-  bool Reversed = false;
-  SDValue Stride = matchForwardStrided(Ptrs);
-  if (!Stride) {
-    Stride = matchReverseStrided(Ptrs);
-    Reversed = true;
-    // TODO: At this point, we've successfully matched a generalized gather
-    // load. Maybe we should emit that, and then move the specialized
-    // matchers above and below into a DAG combine?
-    if (!Stride)
+  // Get the distance between the first and second loads
+  auto BaseDiff = GetPtrDiff(Lds[0], Lds[1]);
+  if (!BaseDiff)
+    return SDValue();
+
+  // Check all the loads are the same distance apart
+  for (auto *It = Lds.begin() + 1; It != Lds.end() - 1; It++)
+    if (GetPtrDiff(*It, *std::next(It)) != BaseDiff)
      return SDValue();
-  }
+
+  // TODO: At this point, we've successfully matched a generalized gather
+  // load. Maybe we should emit that, and then move the specialized
+  // matchers above and below into a DAG combine?
 
   // Get the widened scalar type, e.g. v4i8 -> i64
   unsigned WideScalarBitWidth =
@@ -13867,26 +13844,25 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
   if (!TLI.isLegalStridedLoadStore(WideVecVT, Align))
     return SDValue();
 
+  auto [Stride, MustNegateStride] = *BaseDiff;
+  if (MustNegateStride)
+    Stride = DAG.getNegative(Stride, DL, Stride.getValueType());
+
   SDVTList VTs = DAG.getVTList({WideVecVT, MVT::Other});
   SDValue IntID =
       DAG.getTargetConstant(Intrinsic::riscv_masked_strided_load, DL,
                             Subtarget.getXLenVT());
-  if (Reversed)
-    Stride = DAG.getNegative(Stride, DL, Stride->getValueType(0));
+
   SDValue AllOneMask =
      DAG.getSplat(WideVecVT.changeVectorElementType(MVT::i1), DL,
                   DAG.getConstant(1, DL, MVT::i1));
 
-  SDValue Ops[] = {BaseLd->getChain(),
-                   IntID,
-                   DAG.getUNDEF(WideVecVT),
-                   BasePtr,
-                   Stride,
-                   AllOneMask};
+  SDValue Ops[] = {BaseLd->getChain(), IntID, DAG.getUNDEF(WideVecVT),
+                   BaseLd->getBasePtr(), Stride, AllOneMask};
 
   uint64_t MemSize;
   if (auto *ConstStride = dyn_cast<ConstantSDNode>(Stride);
-      ConstStride && !Reversed && ConstStride->getSExtValue() >= 0)
+      ConstStride && ConstStride->getSExtValue() >= 0)
    // total size = (elsize * n) + (stride - elsize) * (n-1)
    //            = elsize + stride * (n-1)
    MemSize = WideScalarVT.getSizeInBits() +
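
As a quick sanity check of the memory-size comment in the final hunk, with illustrative numbers: for n = 4 loads of 64-bit elements at a constant 128-bit stride, (elsize * n) + (stride - elsize) * (n-1) = 256 + 192 = 448 bits, matching elsize + stride * (n-1) = 64 + 384 = 448 bits. The two forms agree because elsize * n - elsize * (n-1) = elsize.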
