@@ -10834,8 +10834,7 @@ calculateSrcByte(const SDValue Op, uint64_t DestByte, uint64_t SrcIndex = 0,
10834
10834
if (Depth >= 6)
10835
10835
return std::nullopt;
10836
10836
10837
- auto ValueSize = Op.getValueSizeInBits();
10838
- if (ValueSize != 8 && ValueSize != 16 && ValueSize != 32)
10837
+ if (Op.getValueSizeInBits() < 8)
10839
10838
return std::nullopt;
10840
10839
10841
10840
switch (Op->getOpcode()) {
@@ -11126,8 +11125,6 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
11126
11125
auto VecIdx = IdxOp->getZExtValue();
11127
11126
auto ScalarSize = Op.getScalarValueSizeInBits();
11128
11127
if (ScalarSize != 32) {
11129
- if ((VecIdx + 1) * ScalarSize > 32)
11130
- return std::nullopt;
11131
11128
Index = ScalarSize == 8 ? VecIdx : VecIdx * 2 + Index;
11132
11129
}
11133
11130
@@ -11213,9 +11210,6 @@ static bool hasNon16BitAccesses(uint64_t PermMask, SDValue &Op,
11213
11210
int Low16 = PermMask & 0xffff;
11214
11211
int Hi16 = (PermMask & 0xffff0000) >> 16;
11215
11212
11216
- assert(Op.getValueType().isByteSized());
11217
- assert(OtherOp.getValueType().isByteSized());
11218
-
11219
11213
auto TempOp = peekThroughBitcasts(Op);
11220
11214
auto TempOtherOp = peekThroughBitcasts(OtherOp);
11221
11215
@@ -11233,15 +11227,38 @@ static bool hasNon16BitAccesses(uint64_t PermMask, SDValue &Op,
11233
11227
return !addresses16Bits(Low16) || !addresses16Bits(Hi16);
11234
11228
}
11235
11229
11230
// Builds a 32-bit (MVT::i32) value holding the DWORD of Src selected by
// DWordOffset. DAG creates the nodes and SL is the location attached to them.
// Sources of at most 32 bits are converted straight to i32; DWordOffset is not
// consulted on that path.
+ static SDValue getDWordFromOffset(SelectionDAG &DAG, SDLoc SL, SDValue Src,
11231
+ unsigned DWordOffset) {
11232
+ SDValue Ret;
11233
+ if (Src.getValueSizeInBits() <= 32)
11234
+ return DAG.getBitcastedAnyExtOrTrunc(Src, SL, MVT::i32);
11235
+
11236
// 256 bits or wider: reinterpret Src as a vector of i32 and extract element
// number DWordOffset.
+ if (Src.getValueSizeInBits() >= 256) {
11237
+ assert(!(Src.getValueSizeInBits() % 32));
11238
+ Ret = DAG.getBitcast(
11239
+ MVT::getVectorVT(MVT::i32, Src.getValueSizeInBits() / 32), Src);
11240
+ return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, MVT::i32, Ret,
11241
+ DAG.getConstant(DWordOffset, SL, MVT::i32));
11242
+ }
11243
+
11244
// Remaining widths (33..255 bits): convert Src to an integer of its own bit
// width, then shift right by DWordOffset * 32 bits and truncate to i32.
// NOTE(review): this depends on MVT::getIntegerVT producing a usable type for
// the width in question — confirm for widths that are not a simple integer
// type.
+ Ret = DAG.getBitcastedAnyExtOrTrunc(
11245
+ Src, SL, MVT::getIntegerVT(Src.getValueSizeInBits()));
11246
+ if (DWordOffset) {
11247
+ auto Shifted = DAG.getNode(ISD::SRL, SL, Ret.getValueType(), Ret,
11248
+ DAG.getConstant(DWordOffset * 32, SL, MVT::i32));
11249
+ return DAG.getNode(ISD::TRUNCATE, SL, MVT::i32, Shifted);
11250
+ }
11251
+
11252
// DWordOffset == 0: no shift is needed, just convert the value to i32.
+ return DAG.getBitcastedAnyExtOrTrunc(Ret, SL, MVT::i32);
11253
+ }
11254
+
11236
11255
static SDValue matchPERM(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
11237
11256
SelectionDAG &DAG = DCI.DAG;
11238
11257
EVT VT = N->getValueType(0);
11239
-
11240
- if (VT != MVT::i32)
11241
- return SDValue();
11258
+ SmallVector<ByteProvider<SDValue>, 8> PermNodes;
11242
11259
11243
11260
// VT is known to be MVT::i32, so we need to provide 4 bytes.
11244
- SmallVector<ByteProvider<SDValue>, 8> PermNodes ;
11261
+ assert(VT == MVT::i32) ;
11245
11262
for (int i = 0; i < 4; i++) {
11246
11263
// Find the ByteProvider that provides the ith byte of the result of OR
11247
11264
std::optional<ByteProvider<SDValue>> P =
@@ -11255,42 +11272,40 @@ static SDValue matchPERM(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
11255
11272
if (PermNodes.size() != 4)
11256
11273
return SDValue();
11257
11274
11258
- int FirstSrc = 0 ;
11259
- std::optional<int > SecondSrc;
11275
+ std::pair<unsigned, unsigned> FirstSrc(0, PermNodes[0].SrcOffset / 4) ;
11276
+ std::optional<std::pair<unsigned, unsigned> > SecondSrc;
11260
11277
uint64_t PermMask = 0x00000000;
11261
11278
for (size_t i = 0; i < PermNodes.size(); i++) {
11262
11279
auto PermOp = PermNodes[i];
11263
11280
// Since the mask is applied to Src1:Src2, Src1 bytes must be offset
11264
11281
// by sizeof(Src2) = 4
11265
11282
int SrcByteAdjust = 4;
11266
11283
11267
- if (!PermOp.hasSameSrc(PermNodes[FirstSrc])) {
11268
- if (SecondSrc.has_value())
11269
- if (!PermOp.hasSameSrc(PermNodes[*SecondSrc]))
11284
+ // If the Src uses a byte from a different DWORD, then it corresponds
11285
+ // with a different source
11286
+ if (!PermOp.hasSameSrc(PermNodes[FirstSrc.first]) ||
11287
+ ((PermOp.SrcOffset / 4) != FirstSrc.second)) {
11288
+ if (SecondSrc)
11289
+ if (!PermOp.hasSameSrc(PermNodes[SecondSrc->first]) ||
11290
+ ((PermOp.SrcOffset / 4) != SecondSrc->second))
11270
11291
return SDValue();
11271
11292
11272
11293
// Set the index of the second distinct Src node
11273
- SecondSrc = i ;
11274
- assert(!(PermNodes[* SecondSrc].Src->getValueSizeInBits() % 8));
11294
+ SecondSrc = {i, PermNodes[i].SrcOffset / 4} ;
11295
+ assert(!(PermNodes[SecondSrc->first ].Src->getValueSizeInBits() % 8));
11275
11296
SrcByteAdjust = 0;
11276
11297
}
11277
- assert(PermOp.SrcOffset + SrcByteAdjust < 8);
11298
+ assert(( PermOp.SrcOffset % 4) + SrcByteAdjust < 8);
11278
11299
assert(!DAG.getDataLayout().isBigEndian());
11279
- PermMask |= (PermOp.SrcOffset + SrcByteAdjust) << (i * 8);
11300
+ PermMask |= (( PermOp.SrcOffset % 4) + SrcByteAdjust) << (i * 8);
11280
11301
}
11281
-
11282
- SDValue Op = *PermNodes[FirstSrc].Src;
11283
- SDValue OtherOp = SecondSrc.has_value() ? *PermNodes[*SecondSrc].Src
11284
- : *PermNodes[FirstSrc].Src;
11285
-
11286
- // Check that we haven't just recreated the same FSHR node.
11287
- if (N->getOpcode() == ISD::FSHR &&
11288
- (N->getOperand(0) == Op || N->getOperand(0) == OtherOp) &&
11289
- (N->getOperand(1) == Op || N->getOperand(1) == OtherOp))
11290
- return SDValue();
11302
+ SDLoc DL(N);
11303
+ SDValue Op = *PermNodes[FirstSrc.first].Src;
11304
+ Op = getDWordFromOffset(DAG, DL, Op, FirstSrc.second);
11305
+ assert(Op.getValueSizeInBits() == 32);
11291
11306
11292
11307
// Check that we are not just extracting the bytes in order from an op
11293
- if (Op == OtherOp && Op.getValueSizeInBits() == 32 ) {
11308
+ if (!SecondSrc ) {
11294
11309
int Low16 = PermMask & 0xffff;
11295
11310
int Hi16 = (PermMask & 0xffff0000) >> 16;
11296
11311
@@ -11302,8 +11317,16 @@ static SDValue matchPERM(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
11302
11317
return DAG.getBitcast(MVT::getIntegerVT(32), Op);
11303
11318
}
11304
11319
11320
+ SDValue OtherOp =
11321
+ SecondSrc.has_value() ? *PermNodes[SecondSrc->first].Src : Op;
11322
+
11323
+ if (SecondSrc)
11324
+ OtherOp = getDWordFromOffset(DAG, DL, OtherOp, SecondSrc->second);
11325
+
11326
+ assert(Op.getValueSizeInBits() == 32);
11327
+
11305
11328
if (hasNon16BitAccesses(PermMask, Op, OtherOp)) {
11306
- SDLoc DL(N);
11329
+
11307
11330
assert(Op.getValueType().isByteSized() &&
11308
11331
OtherOp.getValueType().isByteSized());
11309
11332
@@ -11318,7 +11341,6 @@ static SDValue matchPERM(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
11318
11341
return DAG.getNode(AMDGPUISD::PERM, DL, MVT::i32, Op, OtherOp,
11319
11342
DAG.getConstant(PermMask, DL, MVT::i32));
11320
11343
}
11321
-
11322
11344
return SDValue();
11323
11345
}
11324
11346
@@ -12794,17 +12816,24 @@ static unsigned addPermMasks(unsigned First, unsigned Second) {
12794
12816
return (FirstNoCs | SecondNoCs) | (FirstCs & SecondCs);
12795
12817
}
12796
12818
12819
// One source operand collected while matching a dot-product chain, together
// with the byte-select mask and the DWORD of the operand that the mask refers
// to.
+ struct DotSrc {
12820
// The operand that supplies the bytes.
+ SDValue SrcOp;
12821
// Perm mask accumulated for this operand (merged through addPermMasks) and
// later emitted as the constant operand of an AMDGPUISD::PERM node.
+ int64_t PermMask;
12822
// Index of the 32-bit chunk of SrcOp being read; filled from SrcOffset / 4
// and passed to getDWordFromOffset.
+ int64_t DWordOffset;
12823
+ };
12824
+
12797
12825
static void placeSources(ByteProvider<SDValue> &Src0,
12798
12826
ByteProvider<SDValue> &Src1,
12799
- SmallVectorImpl<std::pair<SDValue, unsigned>> &Src0s,
12800
- SmallVectorImpl<std::pair<SDValue, unsigned>> &Src1s,
12801
- int Step) {
12827
+ SmallVectorImpl<DotSrc> &Src0s,
12828
+ SmallVectorImpl<DotSrc> &Src1s, int Step) {
12802
12829
12803
12830
assert(Src0.Src.has_value() && Src1.Src.has_value());
12804
12831
// Src0s and Src1s are empty, just place arbitrarily.
12805
12832
if (Step == 0) {
12806
- Src0s.push_back({*Src0.Src, (Src0.SrcOffset << 24) + 0x0c0c0c});
12807
- Src1s.push_back({*Src1.Src, (Src1.SrcOffset << 24) + 0x0c0c0c});
12833
+ Src0s.push_back({*Src0.Src, ((Src0.SrcOffset % 4) << 24) + 0x0c0c0c,
12834
+ Src0.SrcOffset / 4});
12835
+ Src1s.push_back({*Src1.Src, ((Src1.SrcOffset % 4) << 24) + 0x0c0c0c,
12836
+ Src1.SrcOffset / 4});
12808
12837
return;
12809
12838
}
12810
12839
@@ -12817,38 +12846,38 @@ static void placeSources(ByteProvider<SDValue> &Src0,
12817
12846
unsigned FMask = 0xFF << (8 * (3 - Step));
12818
12847
12819
12848
unsigned FirstMask =
12820
- BPP.first.SrcOffset << (8 * (3 - Step)) | (ZeroMask & ~FMask);
12849
+ ( BPP.first.SrcOffset % 4) << (8 * (3 - Step)) | (ZeroMask & ~FMask);
12821
12850
unsigned SecondMask =
12822
- BPP.second.SrcOffset << (8 * (3 - Step)) | (ZeroMask & ~FMask);
12851
+ ( BPP.second.SrcOffset % 4) << (8 * (3 - Step)) | (ZeroMask & ~FMask);
12823
12852
// Attempt to find Src vector which contains our SDValue, if so, add our
12824
12853
// perm mask to the existing one. If we are unable to find a match for the
12825
12854
// first SDValue, attempt to find match for the second.
12826
12855
int FirstGroup = -1;
12827
12856
for (int I = 0; I < 2; I++) {
12828
- SmallVectorImpl<std::pair<SDValue, unsigned>> &Srcs =
12829
- I == 0 ? Src0s : Src1s;
12830
- auto MatchesFirst = [& BPP](std::pair<SDValue, unsigned> IterElt) {
12831
- return IterElt.first == * BPP.first.Src ;
12857
+ SmallVectorImpl<DotSrc> &Srcs = I == 0 ? Src0s : Src1s;
12858
+ auto MatchesFirst = [&BPP](DotSrc &IterElt) {
12859
+ return IterElt.SrcOp == * BPP.first.Src &&
12860
+ ( IterElt.DWordOffset == ( BPP.first.SrcOffset / 4)) ;
12832
12861
};
12833
12862
12834
12863
auto Match = llvm::find_if(Srcs, MatchesFirst);
12835
12864
if (Match != Srcs.end()) {
12836
- Match->second = addPermMasks(FirstMask, Match->second );
12865
+ Match->PermMask = addPermMasks(FirstMask, Match->PermMask );
12837
12866
FirstGroup = I;
12838
12867
break;
12839
12868
}
12840
12869
}
12841
12870
if (FirstGroup != -1) {
12842
- SmallVectorImpl<std::pair<SDValue, unsigned>> &Srcs =
12843
- FirstGroup == 1 ? Src0s : Src1s;
12844
- auto MatchesSecond = [& BPP](std::pair<SDValue, unsigned> IterElt) {
12845
- return IterElt.first == * BPP.second.Src ;
12871
+ SmallVectorImpl<DotSrc> &Srcs = FirstGroup == 1 ? Src0s : Src1s;
12872
+ auto MatchesSecond = [&BPP](DotSrc &IterElt) {
12873
+ return IterElt.SrcOp == * BPP.second.Src &&
12874
+ ( IterElt.DWordOffset == ( BPP.second.SrcOffset / 4)) ;
12846
12875
};
12847
12876
auto Match = llvm::find_if(Srcs, MatchesSecond);
12848
12877
if (Match != Srcs.end()) {
12849
- Match->second = addPermMasks(SecondMask, Match->second );
12878
+ Match->PermMask = addPermMasks(SecondMask, Match->PermMask );
12850
12879
} else
12851
- Srcs.push_back({*BPP.second.Src, SecondMask});
12880
+ Srcs.push_back({*BPP.second.Src, SecondMask, BPP.second.SrcOffset / 4 });
12852
12881
return;
12853
12882
}
12854
12883
}
@@ -12860,29 +12889,32 @@ static void placeSources(ByteProvider<SDValue> &Src0,
12860
12889
unsigned FMask = 0xFF << (8 * (3 - Step));
12861
12890
12862
12891
// No existing entry matched: start a new Src0s entry. Its byte selector
// (SrcOffset % 4) and its DWORD index (SrcOffset / 4) must both be derived
// from Src0, the operand stored in the entry; taking the DWORD index from
// Src1 would make a later getDWordFromOffset read the wrong 32-bit chunk.
Src0s.push_back(
12863
- {*Src0.Src, (Src0.SrcOffset << (8 * (3 - Step)) | (ZeroMask & ~FMask))});
12892
+ {*Src0.Src,
12893
+ ((Src0.SrcOffset % 4) << (8 * (3 - Step)) | (ZeroMask & ~FMask)),
12894
+ Src0.SrcOffset / 4});
12864
12895
Src1s.push_back(
12865
- {*Src1.Src, (Src1.SrcOffset << (8 * (3 - Step)) | (ZeroMask & ~FMask))});
12896
+ {*Src1.Src,
12897
+ ((Src1.SrcOffset % 4) << (8 * (3 - Step)) | (ZeroMask & ~FMask)),
12898
+ Src1.SrcOffset / 4});
12866
12899
12867
12900
return;
12868
12901
}
12869
12902
12870
- static SDValue
12871
- resolveSources(SelectionDAG &DAG, SDLoc SL,
12872
- SmallVectorImpl<std::pair<SDValue, unsigned>> &Srcs,
12873
- bool IsSigned, bool IsAny) {
12903
+ static SDValue resolveSources(SelectionDAG &DAG, SDLoc SL,
12904
+ SmallVectorImpl<DotSrc> &Srcs, bool IsSigned,
12905
+ bool IsAny) {
12874
12906
12875
12907
// If we just have one source, just permute it accordingly.
12876
12908
if (Srcs.size() == 1) {
12877
12909
auto Elt = Srcs.begin();
12878
- auto EltVal = DAG.getBitcastedAnyExtOrTrunc(Elt->first , SL, MVT::i32 );
12910
+ auto EltOp = getDWordFromOffset(DAG , SL, Elt->SrcOp, Elt->DWordOffset );
12879
12911
12880
- // v_perm will produce the original value.
12881
- if (Elt->second == 0x3020100)
12882
- return EltVal ;
12912
+ // v_perm will produce the original value
12913
+ if (Elt->PermMask == 0x3020100)
12914
+ return EltOp ;
12883
12915
12884
- return DAG.getNode(AMDGPUISD::PERM, SL, MVT::i32, EltVal, EltVal ,
12885
- DAG.getConstant(Elt->second , SL, MVT::i32));
12916
+ return DAG.getNode(AMDGPUISD::PERM, SL, MVT::i32, EltOp, EltOp ,
12917
+ DAG.getConstant(Elt->PermMask , SL, MVT::i32));
12886
12918
}
12887
12919
12888
12920
auto FirstElt = Srcs.begin();
@@ -12893,8 +12925,8 @@ resolveSources(SelectionDAG &DAG, SDLoc SL,
12893
12925
// If we have multiple sources in the chain, combine them via perms (using
12894
12926
// calculated perm mask) and Ors.
12895
12927
while (true) {
12896
- auto FirstMask = FirstElt->second ;
12897
- auto SecondMask = SecondElt->second ;
12928
+ auto FirstMask = FirstElt->PermMask ;
12929
+ auto SecondMask = SecondElt->PermMask ;
12898
12930
12899
12931
unsigned FirstCs = FirstMask & 0x0c0c0c0c;
12900
12932
unsigned FirstPlusFour = FirstMask | 0x04040404;
@@ -12904,9 +12936,9 @@ resolveSources(SelectionDAG &DAG, SDLoc SL,
12904
12936
12905
12937
auto PermMask = addPermMasks(FirstMask, SecondMask);
12906
12938
auto FirstVal =
12907
- DAG.getBitcastedAnyExtOrTrunc(FirstElt->first , SL, MVT::i32 );
12939
+ getDWordFromOffset(DAG , SL, FirstElt->SrcOp, FirstElt->DWordOffset );
12908
12940
auto SecondVal =
12909
- DAG.getBitcastedAnyExtOrTrunc(SecondElt->first , SL, MVT::i32 );
12941
+ getDWordFromOffset(DAG , SL, SecondElt->SrcOp, SecondElt->DWordOffset );
12910
12942
12911
12943
Perms.push_back(DAG.getNode(AMDGPUISD::PERM, SL, MVT::i32, FirstVal,
12912
12944
SecondVal,
@@ -12920,12 +12952,12 @@ resolveSources(SelectionDAG &DAG, SDLoc SL,
12920
12952
// If we only have a FirstElt, then just combine that into the cumulative
12921
12953
// source node.
12922
12954
if (SecondElt == Srcs.end()) {
12923
- auto EltVal =
12924
- DAG.getBitcastedAnyExtOrTrunc(FirstElt->first , SL, MVT::i32 );
12955
+ auto EltOp =
12956
+ getDWordFromOffset(DAG , SL, FirstElt->SrcOp, FirstElt->DWordOffset );
12925
12957
12926
12958
Perms.push_back(
12927
- DAG.getNode(AMDGPUISD::PERM, SL, MVT::i32, EltVal, EltVal ,
12928
- DAG.getConstant(FirstElt->second , SL, MVT::i32)));
12959
+ DAG.getNode(AMDGPUISD::PERM, SL, MVT::i32, EltOp, EltOp ,
12960
+ DAG.getConstant(FirstElt->PermMask , SL, MVT::i32)));
12929
12961
break;
12930
12962
}
12931
12963
}
@@ -12936,9 +12968,8 @@ resolveSources(SelectionDAG &DAG, SDLoc SL,
12936
12968
: Perms[0];
12937
12969
}
12938
12970
12939
- static void fixMasks(SmallVectorImpl<std::pair<SDValue, unsigned>> &Srcs,
12940
- unsigned ChainLength) {
12941
- for (auto &[EntryVal, EntryMask] : Srcs) {
12971
+ static void fixMasks(SmallVectorImpl<DotSrc> &Srcs, unsigned ChainLength) {
12972
+ for (auto &[EntryVal, EntryMask, EntryOffset] : Srcs) {
12942
12973
EntryMask = EntryMask >> ((4 - ChainLength) * 8);
12943
12974
auto ZeroMask = ChainLength == 2 ? 0x0c0c0000 : 0x0c000000;
12944
12975
EntryMask += ZeroMask;
@@ -13003,8 +13034,8 @@ SDValue SITargetLowering::performAddCombine(SDNode *N,
13003
13034
(Subtarget->hasDot1Insts() || Subtarget->hasDot8Insts())) {
13004
13035
SDValue TempNode(N, 0);
13005
13036
std::optional<bool> IsSigned;
13006
- SmallVector<std::pair<SDValue, unsigned> , 4> Src0s;
13007
- SmallVector<std::pair<SDValue, unsigned> , 4> Src1s;
13037
+ SmallVector<DotSrc , 4> Src0s;
13038
+ SmallVector<DotSrc , 4> Src1s;
13008
13039
SmallVector<SDValue, 4> Src2s;
13009
13040
13010
13041
// Match the v_dot4 tree, while collecting src nodes.
@@ -13082,11 +13113,11 @@ SDValue SITargetLowering::performAddCombine(SDNode *N,
13082
13113
// (commutation).
13083
13114
bool UseOriginalSrc = false;
13084
13115
if (ChainLength == 4 && Src0s.size() == 1 && Src1s.size() == 1 &&
13085
- Src0s.begin()->second == Src1s.begin()->second &&
13086
- Src0s.begin()->first .getValueSizeInBits() = = 32 &&
13087
- Src1s.begin()->first .getValueSizeInBits() = = 32) {
13116
+ Src0s.begin()->PermMask == Src1s.begin()->PermMask &&
13117
+ Src0s.begin()->SrcOp .getValueSizeInBits() > = 32 &&
13118
+ Src1s.begin()->SrcOp .getValueSizeInBits() > = 32) {
13088
13119
SmallVector<unsigned, 4> SrcBytes;
13089
- auto Src0Mask = Src0s.begin()->second ;
13120
+ auto Src0Mask = Src0s.begin()->PermMask ;
13090
13121
SrcBytes.push_back(Src0Mask & 0xFF000000);
13091
13122
bool UniqueEntries = true;
13092
13123
for (auto I = 1; I < 4; I++) {
@@ -13101,11 +13132,19 @@ SDValue SITargetLowering::performAddCombine(SDNode *N,
13101
13132
13102
13133
if (UniqueEntries) {
13103
13134
UseOriginalSrc = true;
13104
- // Must be 32 bits to enter above conditional.
13105
- assert(Src0s.begin()->first.getValueSizeInBits() == 32);
13106
- assert(Src1s.begin()->first.getValueSizeInBits() == 32);
13107
- Src0 = DAG.getBitcast(MVT::getIntegerVT(32), Src0s.begin()->first);
13108
- Src1 = DAG.getBitcast(MVT::getIntegerVT(32), Src1s.begin()->first);
13135
+
13136
+ auto FirstElt = Src0s.begin();
13137
+ auto FirstEltOp =
13138
+ getDWordFromOffset(DAG, SL, FirstElt->SrcOp, FirstElt->DWordOffset);
13139
+
13140
+ auto SecondElt = Src1s.begin();
13141
+ auto SecondEltOp = getDWordFromOffset(DAG, SL, SecondElt->SrcOp,
13142
+ SecondElt->DWordOffset);
13143
+
13144
+ Src0 = DAG.getBitcastedAnyExtOrTrunc(FirstEltOp, SL,
13145
+ MVT::getIntegerVT(32));
13146
+ Src1 = DAG.getBitcastedAnyExtOrTrunc(SecondEltOp, SL,
13147
+ MVT::getIntegerVT(32));
13109
13148
}
13110
13149
}
13111
13150
0 commit comments