Skip to content

Commit ef33659

Browse files
committed
[AMDGPU] Accept arbitrary sized sources in CalculateByteProvider
This allows working with e.g. v8i8 / v16i8 sources. It is generally useful, but is primarily beneficial when allowing e.g. v8i8s to be passed to branches directly through registers. As such, this is the first in a series of patches to enable that work. However, it effects https://reviews.llvm.org/D155995, so it has been implemented on top of that. Differential Revision: https://reviews.llvm.org/D159036 Change-Id: Idfcb57dacd0c32cab040fe4dd4ac2ec762750664
1 parent 840d0b7 commit ef33659

File tree

6 files changed

+1498
-113
lines changed

6 files changed

+1498
-113
lines changed

llvm/lib/Target/AMDGPU/SIISelLowering.cpp

Lines changed: 124 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -10834,8 +10834,7 @@ calculateSrcByte(const SDValue Op, uint64_t DestByte, uint64_t SrcIndex = 0,
1083410834
if (Depth >= 6)
1083510835
return std::nullopt;
1083610836

10837-
auto ValueSize = Op.getValueSizeInBits();
10838-
if (ValueSize != 8 && ValueSize != 16 && ValueSize != 32)
10837+
if (Op.getValueSizeInBits() < 8)
1083910838
return std::nullopt;
1084010839

1084110840
switch (Op->getOpcode()) {
@@ -11126,8 +11125,6 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
1112611125
auto VecIdx = IdxOp->getZExtValue();
1112711126
auto ScalarSize = Op.getScalarValueSizeInBits();
1112811127
if (ScalarSize != 32) {
11129-
if ((VecIdx + 1) * ScalarSize > 32)
11130-
return std::nullopt;
1113111128
Index = ScalarSize == 8 ? VecIdx : VecIdx * 2 + Index;
1113211129
}
1113311130

@@ -11213,9 +11210,6 @@ static bool hasNon16BitAccesses(uint64_t PermMask, SDValue &Op,
1121311210
int Low16 = PermMask & 0xffff;
1121411211
int Hi16 = (PermMask & 0xffff0000) >> 16;
1121511212

11216-
assert(Op.getValueType().isByteSized());
11217-
assert(OtherOp.getValueType().isByteSized());
11218-
1121911213
auto TempOp = peekThroughBitcasts(Op);
1122011214
auto TempOtherOp = peekThroughBitcasts(OtherOp);
1122111215

@@ -11233,15 +11227,38 @@ static bool hasNon16BitAccesses(uint64_t PermMask, SDValue &Op,
1123311227
return !addresses16Bits(Low16) || !addresses16Bits(Hi16);
1123411228
}
1123511229

11230+
static SDValue getDWordFromOffset(SelectionDAG &DAG, SDLoc SL, SDValue Src,
11231+
unsigned DWordOffset) {
11232+
SDValue Ret;
11233+
if (Src.getValueSizeInBits() <= 32)
11234+
return DAG.getBitcastedAnyExtOrTrunc(Src, SL, MVT::i32);
11235+
11236+
if (Src.getValueSizeInBits() >= 256) {
11237+
assert(!(Src.getValueSizeInBits() % 32));
11238+
Ret = DAG.getBitcast(
11239+
MVT::getVectorVT(MVT::i32, Src.getValueSizeInBits() / 32), Src);
11240+
return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, MVT::i32, Ret,
11241+
DAG.getConstant(DWordOffset, SL, MVT::i32));
11242+
}
11243+
11244+
Ret = DAG.getBitcastedAnyExtOrTrunc(
11245+
Src, SL, MVT::getIntegerVT(Src.getValueSizeInBits()));
11246+
if (DWordOffset) {
11247+
auto Shifted = DAG.getNode(ISD::SRL, SL, Ret.getValueType(), Ret,
11248+
DAG.getConstant(DWordOffset * 32, SL, MVT::i32));
11249+
return DAG.getNode(ISD::TRUNCATE, SL, MVT::i32, Shifted);
11250+
}
11251+
11252+
return DAG.getBitcastedAnyExtOrTrunc(Ret, SL, MVT::i32);
11253+
}
11254+
1123611255
static SDValue matchPERM(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
1123711256
SelectionDAG &DAG = DCI.DAG;
1123811257
EVT VT = N->getValueType(0);
11239-
11240-
if (VT != MVT::i32)
11241-
return SDValue();
11258+
SmallVector<ByteProvider<SDValue>, 8> PermNodes;
1124211259

1124311260
// VT is known to be MVT::i32, so we need to provide 4 bytes.
11244-
SmallVector<ByteProvider<SDValue>, 8> PermNodes;
11261+
assert(VT == MVT::i32);
1124511262
for (int i = 0; i < 4; i++) {
1124611263
// Find the ByteProvider that provides the ith byte of the result of OR
1124711264
std::optional<ByteProvider<SDValue>> P =
@@ -11255,42 +11272,40 @@ static SDValue matchPERM(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
1125511272
if (PermNodes.size() != 4)
1125611273
return SDValue();
1125711274

11258-
int FirstSrc = 0;
11259-
std::optional<int> SecondSrc;
11275+
std::pair<unsigned, unsigned> FirstSrc(0, PermNodes[0].SrcOffset / 4);
11276+
std::optional<std::pair<unsigned, unsigned>> SecondSrc;
1126011277
uint64_t PermMask = 0x00000000;
1126111278
for (size_t i = 0; i < PermNodes.size(); i++) {
1126211279
auto PermOp = PermNodes[i];
1126311280
// Since the mask is applied to Src1:Src2, Src1 bytes must be offset
1126411281
// by sizeof(Src2) = 4
1126511282
int SrcByteAdjust = 4;
1126611283

11267-
if (!PermOp.hasSameSrc(PermNodes[FirstSrc])) {
11268-
if (SecondSrc.has_value())
11269-
if (!PermOp.hasSameSrc(PermNodes[*SecondSrc]))
11284+
// If the Src uses a byte from a different DWORD, then it corresponds
11285+
// with a difference source
11286+
if (!PermOp.hasSameSrc(PermNodes[FirstSrc.first]) ||
11287+
((PermOp.SrcOffset / 4) != FirstSrc.second)) {
11288+
if (SecondSrc)
11289+
if (!PermOp.hasSameSrc(PermNodes[SecondSrc->first]) ||
11290+
((PermOp.SrcOffset / 4) != SecondSrc->second))
1127011291
return SDValue();
1127111292

1127211293
// Set the index of the second distinct Src node
11273-
SecondSrc = i;
11274-
assert(!(PermNodes[*SecondSrc].Src->getValueSizeInBits() % 8));
11294+
SecondSrc = {i, PermNodes[i].SrcOffset / 4};
11295+
assert(!(PermNodes[SecondSrc->first].Src->getValueSizeInBits() % 8));
1127511296
SrcByteAdjust = 0;
1127611297
}
11277-
assert(PermOp.SrcOffset + SrcByteAdjust < 8);
11298+
assert((PermOp.SrcOffset % 4) + SrcByteAdjust < 8);
1127811299
assert(!DAG.getDataLayout().isBigEndian());
11279-
PermMask |= (PermOp.SrcOffset + SrcByteAdjust) << (i * 8);
11300+
PermMask |= ((PermOp.SrcOffset % 4) + SrcByteAdjust) << (i * 8);
1128011301
}
11281-
11282-
SDValue Op = *PermNodes[FirstSrc].Src;
11283-
SDValue OtherOp = SecondSrc.has_value() ? *PermNodes[*SecondSrc].Src
11284-
: *PermNodes[FirstSrc].Src;
11285-
11286-
// Check that we haven't just recreated the same FSHR node.
11287-
if (N->getOpcode() == ISD::FSHR &&
11288-
(N->getOperand(0) == Op || N->getOperand(0) == OtherOp) &&
11289-
(N->getOperand(1) == Op || N->getOperand(1) == OtherOp))
11290-
return SDValue();
11302+
SDLoc DL(N);
11303+
SDValue Op = *PermNodes[FirstSrc.first].Src;
11304+
Op = getDWordFromOffset(DAG, DL, Op, FirstSrc.second);
11305+
assert(Op.getValueSizeInBits() == 32);
1129111306

1129211307
// Check that we are not just extracting the bytes in order from an op
11293-
if (Op == OtherOp && Op.getValueSizeInBits() == 32) {
11308+
if (!SecondSrc) {
1129411309
int Low16 = PermMask & 0xffff;
1129511310
int Hi16 = (PermMask & 0xffff0000) >> 16;
1129611311

@@ -11302,8 +11317,16 @@ static SDValue matchPERM(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
1130211317
return DAG.getBitcast(MVT::getIntegerVT(32), Op);
1130311318
}
1130411319

11320+
SDValue OtherOp =
11321+
SecondSrc.has_value() ? *PermNodes[SecondSrc->first].Src : Op;
11322+
11323+
if (SecondSrc)
11324+
OtherOp = getDWordFromOffset(DAG, DL, OtherOp, SecondSrc->second);
11325+
11326+
assert(Op.getValueSizeInBits() == 32);
11327+
1130511328
if (hasNon16BitAccesses(PermMask, Op, OtherOp)) {
11306-
SDLoc DL(N);
11329+
1130711330
assert(Op.getValueType().isByteSized() &&
1130811331
OtherOp.getValueType().isByteSized());
1130911332

@@ -11318,7 +11341,6 @@ static SDValue matchPERM(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
1131811341
return DAG.getNode(AMDGPUISD::PERM, DL, MVT::i32, Op, OtherOp,
1131911342
DAG.getConstant(PermMask, DL, MVT::i32));
1132011343
}
11321-
1132211344
return SDValue();
1132311345
}
1132411346

@@ -12794,17 +12816,24 @@ static unsigned addPermMasks(unsigned First, unsigned Second) {
1279412816
return (FirstNoCs | SecondNoCs) | (FirstCs & SecondCs);
1279512817
}
1279612818

12819+
struct DotSrc {
12820+
SDValue SrcOp;
12821+
int64_t PermMask;
12822+
int64_t DWordOffset;
12823+
};
12824+
1279712825
static void placeSources(ByteProvider<SDValue> &Src0,
1279812826
ByteProvider<SDValue> &Src1,
12799-
SmallVectorImpl<std::pair<SDValue, unsigned>> &Src0s,
12800-
SmallVectorImpl<std::pair<SDValue, unsigned>> &Src1s,
12801-
int Step) {
12827+
SmallVectorImpl<DotSrc> &Src0s,
12828+
SmallVectorImpl<DotSrc> &Src1s, int Step) {
1280212829

1280312830
assert(Src0.Src.has_value() && Src1.Src.has_value());
1280412831
// Src0s and Src1s are empty, just place arbitrarily.
1280512832
if (Step == 0) {
12806-
Src0s.push_back({*Src0.Src, (Src0.SrcOffset << 24) + 0x0c0c0c});
12807-
Src1s.push_back({*Src1.Src, (Src1.SrcOffset << 24) + 0x0c0c0c});
12833+
Src0s.push_back({*Src0.Src, ((Src0.SrcOffset % 4) << 24) + 0x0c0c0c,
12834+
Src0.SrcOffset / 4});
12835+
Src1s.push_back({*Src1.Src, ((Src1.SrcOffset % 4) << 24) + 0x0c0c0c,
12836+
Src1.SrcOffset / 4});
1280812837
return;
1280912838
}
1281012839

@@ -12817,38 +12846,38 @@ static void placeSources(ByteProvider<SDValue> &Src0,
1281712846
unsigned FMask = 0xFF << (8 * (3 - Step));
1281812847

1281912848
unsigned FirstMask =
12820-
BPP.first.SrcOffset << (8 * (3 - Step)) | (ZeroMask & ~FMask);
12849+
(BPP.first.SrcOffset % 4) << (8 * (3 - Step)) | (ZeroMask & ~FMask);
1282112850
unsigned SecondMask =
12822-
BPP.second.SrcOffset << (8 * (3 - Step)) | (ZeroMask & ~FMask);
12851+
(BPP.second.SrcOffset % 4) << (8 * (3 - Step)) | (ZeroMask & ~FMask);
1282312852
// Attempt to find Src vector which contains our SDValue, if so, add our
1282412853
// perm mask to the existing one. If we are unable to find a match for the
1282512854
// first SDValue, attempt to find match for the second.
1282612855
int FirstGroup = -1;
1282712856
for (int I = 0; I < 2; I++) {
12828-
SmallVectorImpl<std::pair<SDValue, unsigned>> &Srcs =
12829-
I == 0 ? Src0s : Src1s;
12830-
auto MatchesFirst = [&BPP](std::pair<SDValue, unsigned> IterElt) {
12831-
return IterElt.first == *BPP.first.Src;
12857+
SmallVectorImpl<DotSrc> &Srcs = I == 0 ? Src0s : Src1s;
12858+
auto MatchesFirst = [&BPP](DotSrc &IterElt) {
12859+
return IterElt.SrcOp == *BPP.first.Src &&
12860+
(IterElt.DWordOffset == (BPP.first.SrcOffset / 4));
1283212861
};
1283312862

1283412863
auto Match = llvm::find_if(Srcs, MatchesFirst);
1283512864
if (Match != Srcs.end()) {
12836-
Match->second = addPermMasks(FirstMask, Match->second);
12865+
Match->PermMask = addPermMasks(FirstMask, Match->PermMask);
1283712866
FirstGroup = I;
1283812867
break;
1283912868
}
1284012869
}
1284112870
if (FirstGroup != -1) {
12842-
SmallVectorImpl<std::pair<SDValue, unsigned>> &Srcs =
12843-
FirstGroup == 1 ? Src0s : Src1s;
12844-
auto MatchesSecond = [&BPP](std::pair<SDValue, unsigned> IterElt) {
12845-
return IterElt.first == *BPP.second.Src;
12871+
SmallVectorImpl<DotSrc> &Srcs = FirstGroup == 1 ? Src0s : Src1s;
12872+
auto MatchesSecond = [&BPP](DotSrc &IterElt) {
12873+
return IterElt.SrcOp == *BPP.second.Src &&
12874+
(IterElt.DWordOffset == (BPP.second.SrcOffset / 4));
1284612875
};
1284712876
auto Match = llvm::find_if(Srcs, MatchesSecond);
1284812877
if (Match != Srcs.end()) {
12849-
Match->second = addPermMasks(SecondMask, Match->second);
12878+
Match->PermMask = addPermMasks(SecondMask, Match->PermMask);
1285012879
} else
12851-
Srcs.push_back({*BPP.second.Src, SecondMask});
12880+
Srcs.push_back({*BPP.second.Src, SecondMask, BPP.second.SrcOffset / 4});
1285212881
return;
1285312882
}
1285412883
}
@@ -12860,29 +12889,32 @@ static void placeSources(ByteProvider<SDValue> &Src0,
1286012889
unsigned FMask = 0xFF << (8 * (3 - Step));
1286112890

1286212891
Src0s.push_back(
12863-
{*Src0.Src, (Src0.SrcOffset << (8 * (3 - Step)) | (ZeroMask & ~FMask))});
12892+
{*Src0.Src,
12893+
((Src0.SrcOffset % 4) << (8 * (3 - Step)) | (ZeroMask & ~FMask)),
12894+
Src1.SrcOffset / 4});
1286412895
Src1s.push_back(
12865-
{*Src1.Src, (Src1.SrcOffset << (8 * (3 - Step)) | (ZeroMask & ~FMask))});
12896+
{*Src1.Src,
12897+
((Src1.SrcOffset % 4) << (8 * (3 - Step)) | (ZeroMask & ~FMask)),
12898+
Src1.SrcOffset / 4});
1286612899

1286712900
return;
1286812901
}
1286912902

12870-
static SDValue
12871-
resolveSources(SelectionDAG &DAG, SDLoc SL,
12872-
SmallVectorImpl<std::pair<SDValue, unsigned>> &Srcs,
12873-
bool IsSigned, bool IsAny) {
12903+
static SDValue resolveSources(SelectionDAG &DAG, SDLoc SL,
12904+
SmallVectorImpl<DotSrc> &Srcs, bool IsSigned,
12905+
bool IsAny) {
1287412906

1287512907
// If we just have one source, just permute it accordingly.
1287612908
if (Srcs.size() == 1) {
1287712909
auto Elt = Srcs.begin();
12878-
auto EltVal = DAG.getBitcastedAnyExtOrTrunc(Elt->first, SL, MVT::i32);
12910+
auto EltOp = getDWordFromOffset(DAG, SL, Elt->SrcOp, Elt->DWordOffset);
1287912911

12880-
// v_perm will produce the original value.
12881-
if (Elt->second == 0x3020100)
12882-
return EltVal;
12912+
// v_perm will produce the original value
12913+
if (Elt->PermMask == 0x3020100)
12914+
return EltOp;
1288312915

12884-
return DAG.getNode(AMDGPUISD::PERM, SL, MVT::i32, EltVal, EltVal,
12885-
DAG.getConstant(Elt->second, SL, MVT::i32));
12916+
return DAG.getNode(AMDGPUISD::PERM, SL, MVT::i32, EltOp, EltOp,
12917+
DAG.getConstant(Elt->PermMask, SL, MVT::i32));
1288612918
}
1288712919

1288812920
auto FirstElt = Srcs.begin();
@@ -12893,8 +12925,8 @@ resolveSources(SelectionDAG &DAG, SDLoc SL,
1289312925
// If we have multiple sources in the chain, combine them via perms (using
1289412926
// calculated perm mask) and Ors.
1289512927
while (true) {
12896-
auto FirstMask = FirstElt->second;
12897-
auto SecondMask = SecondElt->second;
12928+
auto FirstMask = FirstElt->PermMask;
12929+
auto SecondMask = SecondElt->PermMask;
1289812930

1289912931
unsigned FirstCs = FirstMask & 0x0c0c0c0c;
1290012932
unsigned FirstPlusFour = FirstMask | 0x04040404;
@@ -12904,9 +12936,9 @@ resolveSources(SelectionDAG &DAG, SDLoc SL,
1290412936

1290512937
auto PermMask = addPermMasks(FirstMask, SecondMask);
1290612938
auto FirstVal =
12907-
DAG.getBitcastedAnyExtOrTrunc(FirstElt->first, SL, MVT::i32);
12939+
getDWordFromOffset(DAG, SL, FirstElt->SrcOp, FirstElt->DWordOffset);
1290812940
auto SecondVal =
12909-
DAG.getBitcastedAnyExtOrTrunc(SecondElt->first, SL, MVT::i32);
12941+
getDWordFromOffset(DAG, SL, SecondElt->SrcOp, SecondElt->DWordOffset);
1291012942

1291112943
Perms.push_back(DAG.getNode(AMDGPUISD::PERM, SL, MVT::i32, FirstVal,
1291212944
SecondVal,
@@ -12920,12 +12952,12 @@ resolveSources(SelectionDAG &DAG, SDLoc SL,
1292012952
// If we only have a FirstElt, then just combine that into the cumulative
1292112953
// source node.
1292212954
if (SecondElt == Srcs.end()) {
12923-
auto EltVal =
12924-
DAG.getBitcastedAnyExtOrTrunc(FirstElt->first, SL, MVT::i32);
12955+
auto EltOp =
12956+
getDWordFromOffset(DAG, SL, FirstElt->SrcOp, FirstElt->DWordOffset);
1292512957

1292612958
Perms.push_back(
12927-
DAG.getNode(AMDGPUISD::PERM, SL, MVT::i32, EltVal, EltVal,
12928-
DAG.getConstant(FirstElt->second, SL, MVT::i32)));
12959+
DAG.getNode(AMDGPUISD::PERM, SL, MVT::i32, EltOp, EltOp,
12960+
DAG.getConstant(FirstElt->PermMask, SL, MVT::i32)));
1292912961
break;
1293012962
}
1293112963
}
@@ -12936,9 +12968,8 @@ resolveSources(SelectionDAG &DAG, SDLoc SL,
1293612968
: Perms[0];
1293712969
}
1293812970

12939-
static void fixMasks(SmallVectorImpl<std::pair<SDValue, unsigned>> &Srcs,
12940-
unsigned ChainLength) {
12941-
for (auto &[EntryVal, EntryMask] : Srcs) {
12971+
static void fixMasks(SmallVectorImpl<DotSrc> &Srcs, unsigned ChainLength) {
12972+
for (auto &[EntryVal, EntryMask, EntryOffset] : Srcs) {
1294212973
EntryMask = EntryMask >> ((4 - ChainLength) * 8);
1294312974
auto ZeroMask = ChainLength == 2 ? 0x0c0c0000 : 0x0c000000;
1294412975
EntryMask += ZeroMask;
@@ -13003,8 +13034,8 @@ SDValue SITargetLowering::performAddCombine(SDNode *N,
1300313034
(Subtarget->hasDot1Insts() || Subtarget->hasDot8Insts())) {
1300413035
SDValue TempNode(N, 0);
1300513036
std::optional<bool> IsSigned;
13006-
SmallVector<std::pair<SDValue, unsigned>, 4> Src0s;
13007-
SmallVector<std::pair<SDValue, unsigned>, 4> Src1s;
13037+
SmallVector<DotSrc, 4> Src0s;
13038+
SmallVector<DotSrc, 4> Src1s;
1300813039
SmallVector<SDValue, 4> Src2s;
1300913040

1301013041
// Match the v_dot4 tree, while collecting src nodes.
@@ -13082,11 +13113,11 @@ SDValue SITargetLowering::performAddCombine(SDNode *N,
1308213113
// (commutation).
1308313114
bool UseOriginalSrc = false;
1308413115
if (ChainLength == 4 && Src0s.size() == 1 && Src1s.size() == 1 &&
13085-
Src0s.begin()->second == Src1s.begin()->second &&
13086-
Src0s.begin()->first.getValueSizeInBits() == 32 &&
13087-
Src1s.begin()->first.getValueSizeInBits() == 32) {
13116+
Src0s.begin()->PermMask == Src1s.begin()->PermMask &&
13117+
Src0s.begin()->SrcOp.getValueSizeInBits() >= 32 &&
13118+
Src1s.begin()->SrcOp.getValueSizeInBits() >= 32) {
1308813119
SmallVector<unsigned, 4> SrcBytes;
13089-
auto Src0Mask = Src0s.begin()->second;
13120+
auto Src0Mask = Src0s.begin()->PermMask;
1309013121
SrcBytes.push_back(Src0Mask & 0xFF000000);
1309113122
bool UniqueEntries = true;
1309213123
for (auto I = 1; I < 4; I++) {
@@ -13101,11 +13132,19 @@ SDValue SITargetLowering::performAddCombine(SDNode *N,
1310113132

1310213133
if (UniqueEntries) {
1310313134
UseOriginalSrc = true;
13104-
// Must be 32 bits to enter above conditional.
13105-
assert(Src0s.begin()->first.getValueSizeInBits() == 32);
13106-
assert(Src1s.begin()->first.getValueSizeInBits() == 32);
13107-
Src0 = DAG.getBitcast(MVT::getIntegerVT(32), Src0s.begin()->first);
13108-
Src1 = DAG.getBitcast(MVT::getIntegerVT(32), Src1s.begin()->first);
13135+
13136+
auto FirstElt = Src0s.begin();
13137+
auto FirstEltOp =
13138+
getDWordFromOffset(DAG, SL, FirstElt->SrcOp, FirstElt->DWordOffset);
13139+
13140+
auto SecondElt = Src1s.begin();
13141+
auto SecondEltOp = getDWordFromOffset(DAG, SL, SecondElt->SrcOp,
13142+
SecondElt->DWordOffset);
13143+
13144+
Src0 = DAG.getBitcastedAnyExtOrTrunc(FirstEltOp, SL,
13145+
MVT::getIntegerVT(32));
13146+
Src1 = DAG.getBitcastedAnyExtOrTrunc(SecondEltOp, SL,
13147+
MVT::getIntegerVT(32));
1310913148
}
1311013149
}
1311113150

0 commit comments

Comments
 (0)