Commit 928564c
[RISCV] Combine a gather to a larger element type (#66694)
If we have a gather load whose indices correspond to valid offsets for a gather with an element type twice that of our source, we can reduce the number of indices and perform the operation at the larger element type. This is generally profitable since we halve VL, and these operations are linear in VL. This may require some additional VL/VTYPE toggles, but it appears to be worthwhile on the whole.
1 parent 915ebb0 commit 928564c
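
To make the pairing condition concrete, here is a small standalone C++ sketch (with hypothetical helper names, not the LLVM implementation shown in the diff below): constant byte offsets pair up for the widened gather when each odd-position offset is exactly the preceding offset plus the element size, and only the even-position offsets survive as the new index list.

#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

// Hypothetical illustration of the pairing check: given constant byte offsets
// for a gather of ElementSize-byte elements, return the halved offset list
// usable at 2*ElementSize, or std::nullopt if the offsets do not pair up.
static std::optional<std::vector<uint64_t>>
widenGatherOffsets(const std::vector<uint64_t> &Offsets, unsigned ElementSize) {
  if (Offsets.size() % 2 != 0)
    return std::nullopt;
  std::vector<uint64_t> Widened;
  for (size_t I = 0; I < Offsets.size(); I += 2) {
    // The even offset must be element-aligned and the odd offset must point at
    // the immediately following element, so the pair forms one contiguous
    // 2*ElementSize chunk.
    if (Offsets[I] % ElementSize != 0 ||
        Offsets[I + 1] != Offsets[I] + ElementSize)
      return std::nullopt;
    Widened.push_back(Offsets[I]);
  }
  return Widened;
}

int main() {
  // E.g. an i16 gather reading element pairs at byte offsets 0/2, 16/18, ...
  std::vector<uint64_t> Offsets = {0, 2, 16, 18, 32, 34, 48, 50};
  if (auto Widened = widenGatherOffsets(Offsets, /*ElementSize=*/2)) {
    std::cout << "i32 gather byte offsets:";
    for (uint64_t O : *Widened)
      std::cout << ' ' << O; // prints: 0 16 32 48
    std::cout << '\n';
  }
  return 0;
}

With eight i16 offsets reduced to four i32 offsets, VL for the gather is halved, which is where the win described above comes from.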

File tree: 2 files changed, +92 -16 lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 76 additions & 0 deletions
@@ -13589,6 +13589,52 @@ static bool matchIndexAsShuffle(EVT VT, SDValue Index, SDValue Mask,
   return ActiveLanes.all();
 }

+/// Match the index of a gather or scatter operation as an operation
+/// with twice the element width and half the number of elements. This is
+/// generally profitable (if legal) because these operations are linear
+/// in VL, so even if we cause some extra VTYPE/VL toggles, we still
+/// come out ahead.
+static bool matchIndexAsWiderOp(EVT VT, SDValue Index, SDValue Mask,
+                                Align BaseAlign, const RISCVSubtarget &ST) {
+  if (!ISD::isConstantSplatVectorAllOnes(Mask.getNode()))
+    return false;
+  if (!ISD::isBuildVectorOfConstantSDNodes(Index.getNode()))
+    return false;
+
+  // Attempt a doubling. If we can use an element type 4x or 8x in
+  // size, this will happen via multiple iterations of the transform.
+  const unsigned NumElems = VT.getVectorNumElements();
+  if (NumElems % 2 != 0)
+    return false;
+
+  const unsigned ElementSize = VT.getScalarStoreSize();
+  const unsigned WiderElementSize = ElementSize * 2;
+  if (WiderElementSize > ST.getELen() / 8)
+    return false;
+
+  if (!ST.enableUnalignedVectorMem() && BaseAlign < WiderElementSize)
+    return false;
+
+  for (unsigned i = 0; i < Index->getNumOperands(); i++) {
+    // TODO: We've found an active bit of UB, and could be
+    // more aggressive here if desired.
+    if (Index->getOperand(i)->isUndef())
+      return false;
+    // TODO: This offset check is too strict if we support fully
+    // misaligned memory operations.
+    uint64_t C = Index->getConstantOperandVal(i);
+    if (C % ElementSize != 0)
+      return false;
+    if (i % 2 == 0)
+      continue;
+    uint64_t Last = Index->getConstantOperandVal(i - 1);
+    if (C != Last + ElementSize)
+      return false;
+  }
+  return true;
+}
+
+
 SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
                                                DAGCombinerInfo &DCI) const {
   SelectionDAG &DAG = DCI.DAG;
@@ -14020,6 +14066,36 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
           DAG.getVectorShuffle(VT, DL, Load, DAG.getUNDEF(VT), ShuffleMask);
       return DAG.getMergeValues({Shuffle, Load.getValue(1)}, DL);
     }
+
+    if (MGN->getExtensionType() == ISD::NON_EXTLOAD &&
+        matchIndexAsWiderOp(VT, Index, MGN->getMask(),
+                            MGN->getMemOperand()->getBaseAlign(), Subtarget)) {
+      SmallVector<SDValue> NewIndices;
+      for (unsigned i = 0; i < Index->getNumOperands(); i += 2)
+        NewIndices.push_back(Index.getOperand(i));
+      EVT IndexVT = Index.getValueType()
+                        .getHalfNumVectorElementsVT(*DAG.getContext());
+      Index = DAG.getBuildVector(IndexVT, DL, NewIndices);
+
+      unsigned ElementSize = VT.getScalarStoreSize();
+      EVT WideScalarVT = MVT::getIntegerVT(ElementSize * 8 * 2);
+      auto EltCnt = VT.getVectorElementCount();
+      assert(EltCnt.isKnownEven() && "Splitting vector, but not in half!");
+      EVT WideVT = EVT::getVectorVT(*DAG.getContext(), WideScalarVT,
+                                    EltCnt.divideCoefficientBy(2));
+      SDValue Passthru = DAG.getBitcast(WideVT, MGN->getPassThru());
+      EVT MaskVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1,
+                                    EltCnt.divideCoefficientBy(2));
+      SDValue Mask = DAG.getSplat(MaskVT, DL, DAG.getConstant(1, DL, MVT::i1));
+
+      SDValue Gather =
+          DAG.getMaskedGather(DAG.getVTList(WideVT, MVT::Other), WideVT, DL,
+                              {MGN->getChain(), Passthru, Mask, MGN->getBasePtr(),
+                               Index, ScaleOp},
+                              MGN->getMemOperand(), IndexType, ISD::NON_EXTLOAD);
+      SDValue Result = DAG.getBitcast(VT, Gather.getValue(0));
+      return DAG.getMergeValues({Result, Gather.getValue(1)}, DL);
+    }
     break;
   }
   case ISD::MSCATTER:{
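
As a sanity check on the semantics (a host-side C++ sketch under the byte-offset layout used in the example above, not LLVM or RVV code), loading each adjacent pair of 16-bit elements as a single 32-bit value touches exactly the same bytes as two 16-bit loads, which is why the combine can halve the index count and bitcast the wide result back to the original type.

#include <cassert>
#include <cstdint>
#include <cstring>

int main() {
  // Source "memory": eight 16-bit elements.
  uint16_t Mem[8] = {10, 11, 20, 21, 30, 31, 40, 41};
  const char *Bytes = reinterpret_cast<const char *>(Mem);

  // Narrow gather: pairs of adjacent i16 elements at byte offsets 0, 8, 16, 24.
  uint16_t Narrow[4][2];
  for (int I = 0; I < 4; ++I) {
    std::memcpy(&Narrow[I][0], Bytes + 8 * I, 2);
    std::memcpy(&Narrow[I][1], Bytes + 8 * I + 2, 2);
  }

  // Widened gather: one i32 load per pair at the same even byte offsets.
  uint32_t Wide[4];
  for (int I = 0; I < 4; ++I)
    std::memcpy(&Wide[I], Bytes + 8 * I, 4);

  // Both variants read identical bytes, so reinterpreting the wide result
  // reproduces the narrow gather exactly.
  assert(std::memcmp(Narrow, Wide, sizeof(Wide)) == 0);
  return 0;
}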

llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-gather.ll

Lines changed: 16 additions & 16 deletions
@@ -13024,19 +13024,19 @@ define <4 x i32> @mgather_narrow_edge_case(ptr %base) {
 define <8 x i16> @mgather_strided_2xSEW(ptr %base) {
 ; RV32-LABEL: mgather_strided_2xSEW:
 ; RV32:       # %bb.0:
-; RV32-NEXT:    lui a1, %hi(.LCPI107_0)
-; RV32-NEXT:    addi a1, a1, %lo(.LCPI107_0)
-; RV32-NEXT:    vsetivli zero, 8, e16, m1, ta, ma
-; RV32-NEXT:    vle8.v v9, (a1)
+; RV32-NEXT:    vsetivli zero, 4, e8, mf4, ta, ma
+; RV32-NEXT:    vid.v v8
+; RV32-NEXT:    vsll.vi v9, v8, 3
+; RV32-NEXT:    vsetvli zero, zero, e32, m1, ta, ma
 ; RV32-NEXT:    vluxei8.v v8, (a0), v9
 ; RV32-NEXT:    ret
 ;
 ; RV64V-LABEL: mgather_strided_2xSEW:
 ; RV64V:       # %bb.0:
-; RV64V-NEXT:    lui a1, %hi(.LCPI107_0)
-; RV64V-NEXT:    addi a1, a1, %lo(.LCPI107_0)
-; RV64V-NEXT:    vsetivli zero, 8, e16, m1, ta, ma
-; RV64V-NEXT:    vle8.v v9, (a1)
+; RV64V-NEXT:    vsetivli zero, 4, e8, mf4, ta, ma
+; RV64V-NEXT:    vid.v v8
+; RV64V-NEXT:    vsll.vi v9, v8, 3
+; RV64V-NEXT:    vsetvli zero, zero, e32, m1, ta, ma
 ; RV64V-NEXT:    vluxei8.v v8, (a0), v9
 ; RV64V-NEXT:    ret
 ;
@@ -13141,19 +13141,19 @@ define <8 x i16> @mgather_strided_2xSEW(ptr %base) {
 define <8 x i16> @mgather_gather_2xSEW(ptr %base) {
 ; RV32-LABEL: mgather_gather_2xSEW:
 ; RV32:       # %bb.0:
-; RV32-NEXT:    lui a1, %hi(.LCPI108_0)
-; RV32-NEXT:    addi a1, a1, %lo(.LCPI108_0)
-; RV32-NEXT:    vsetivli zero, 8, e16, m1, ta, ma
-; RV32-NEXT:    vle8.v v9, (a1)
+; RV32-NEXT:    lui a1, 82176
+; RV32-NEXT:    addi a1, a1, 1024
+; RV32-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
+; RV32-NEXT:    vmv.s.x v9, a1
 ; RV32-NEXT:    vluxei8.v v8, (a0), v9
 ; RV32-NEXT:    ret
 ;
 ; RV64V-LABEL: mgather_gather_2xSEW:
 ; RV64V:       # %bb.0:
-; RV64V-NEXT:    lui a1, %hi(.LCPI108_0)
-; RV64V-NEXT:    addi a1, a1, %lo(.LCPI108_0)
-; RV64V-NEXT:    vsetivli zero, 8, e16, m1, ta, ma
-; RV64V-NEXT:    vle8.v v9, (a1)
+; RV64V-NEXT:    lui a1, 82176
+; RV64V-NEXT:    addiw a1, a1, 1024
+; RV64V-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
+; RV64V-NEXT:    vmv.s.x v9, a1
 ; RV64V-NEXT:    vluxei8.v v8, (a0), v9
 ; RV64V-NEXT:    ret
 ;
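
For the updated mgather_strided_2xSEW output above, the index computation performed by the new code can be read off the vid.v/vsll.vi pair; below is a small sketch of that arithmetic (assuming, as the emitted offsets imply, that the test gathers pairs of i16 elements eight bytes apart):

#include <cstdio>

int main() {
  // vid.v v8 produces lane IDs 0..3; vsll.vi v9, v8, 3 turns them into byte
  // offsets 0, 8, 16, 24; each e32 indexed load then covers two consecutive
  // i16 elements starting at that offset.
  for (unsigned Lane = 0; Lane < 4; ++Lane) {
    unsigned ByteOffset = Lane << 3;
    unsigned FirstElt = ByteOffset / 2; // index into the underlying i16 array
    std::printf("lane %u: byte offset %2u -> i16 elements %u and %u\n", Lane,
                ByteOffset, FirstElt, FirstElt + 1);
  }
  return 0;
}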
