Skip to content

Commit 937cfdc

Browse files
authored
[X86] combineGatherScatter - split non-constant (add v, (splat b)) indices patterns and add the splat into the (scalar) base address (llvm#135201)
We already did this for constant cases; this patch generalizes the existing fold to attempt to extract the splat from either operand of an ADD node for the gather/scatter index value. This cleanup should also make it easier to add support for splitting vXi32 indices on x86_64 64-bit pointer targets in the future. Noticed while reviewing llvm#134979.
1 parent 1d3d3f4 commit 937cfdc

File tree

2 files changed

+42
-35
lines changed

2 files changed

+42
-35
lines changed

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 40 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -56517,6 +56517,7 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
5651756517
SDValue Base = GorS->getBasePtr();
5651856518
SDValue Scale = GorS->getScale();
5651956519
EVT IndexVT = Index.getValueType();
56520+
EVT IndexSVT = IndexVT.getVectorElementType();
5652056521
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
5652156522

5652256523
if (DCI.isBeforeLegalize()) {
@@ -56553,41 +56554,51 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
5655356554
}
5655456555

5655556556
EVT PtrVT = TLI.getPointerTy(DAG.getDataLayout());
56556-
// Try to move splat constant adders from the index operand to the base
56557+
56558+
// Try to move splat adders from the index operand to the base
5655756559
// pointer operand. Taking care to multiply by the scale. We can only do
5655856560
// this when index element type is the same as the pointer type.
5655956561
// Otherwise we need to be sure the math doesn't wrap before the scale.
56560-
if (Index.getOpcode() == ISD::ADD &&
56561-
IndexVT.getVectorElementType() == PtrVT && isa<ConstantSDNode>(Scale)) {
56562+
if (Index.getOpcode() == ISD::ADD && IndexSVT == PtrVT &&
56563+
isa<ConstantSDNode>(Scale)) {
5656256564
uint64_t ScaleAmt = Scale->getAsZExtVal();
56563-
if (auto *BV = dyn_cast<BuildVectorSDNode>(Index.getOperand(1))) {
56564-
BitVector UndefElts;
56565-
if (ConstantSDNode *C = BV->getConstantSplatNode(&UndefElts)) {
56566-
// FIXME: Allow non-constant?
56567-
if (UndefElts.none()) {
56568-
// Apply the scale.
56569-
APInt Adder = C->getAPIntValue() * ScaleAmt;
56570-
// Add it to the existing base.
56571-
Base = DAG.getNode(ISD::ADD, DL, PtrVT, Base,
56572-
DAG.getConstant(Adder, DL, PtrVT));
56573-
Index = Index.getOperand(0);
56574-
return rebuildGatherScatter(GorS, Index, Base, Scale, DAG);
56575-
}
56576-
}
5657756565

56578-
// It's also possible base is just a constant. In that case, just
56579-
// replace it with 0 and move the displacement into the index.
56580-
if (BV->isConstant() && isa<ConstantSDNode>(Base) &&
56581-
isOneConstant(Scale)) {
56582-
SDValue Splat = DAG.getSplatBuildVector(IndexVT, DL, Base);
56583-
// Combine the constant build_vector and the constant base.
56584-
Splat = DAG.getNode(ISD::ADD, DL, IndexVT, Index.getOperand(1), Splat);
56585-
// Add to the LHS of the original Index add.
56586-
Index = DAG.getNode(ISD::ADD, DL, IndexVT, Index.getOperand(0), Splat);
56587-
Base = DAG.getConstant(0, DL, Base.getValueType());
56588-
return rebuildGatherScatter(GorS, Index, Base, Scale, DAG);
56566+
for (unsigned I = 0; I != 2; ++I)
56567+
if (auto *BV = dyn_cast<BuildVectorSDNode>(Index.getOperand(I))) {
56568+
BitVector UndefElts;
56569+
if (SDValue Splat = BV->getSplatValue(&UndefElts)) {
56570+
if (UndefElts.none()) {
56571+
// If the splat value is constant we can add the scaled splat value
56572+
// to the existing base.
56573+
if (auto *C = dyn_cast<ConstantSDNode>(Splat)) {
56574+
APInt Adder = C->getAPIntValue() * ScaleAmt;
56575+
SDValue NewBase = DAG.getNode(ISD::ADD, DL, PtrVT, Base,
56576+
DAG.getConstant(Adder, DL, PtrVT));
56577+
SDValue NewIndex = Index.getOperand(1 - I);
56578+
return rebuildGatherScatter(GorS, NewIndex, NewBase, Scale, DAG);
56579+
}
56580+
// For non-constant cases, limit this to non-scaled cases.
56581+
if (ScaleAmt == 1) {
56582+
SDValue NewBase = DAG.getNode(ISD::ADD, DL, PtrVT, Base, Splat);
56583+
SDValue NewIndex = Index.getOperand(1 - I);
56584+
return rebuildGatherScatter(GorS, NewIndex, NewBase, Scale, DAG);
56585+
}
56586+
}
56587+
}
56588+
// It's also possible base is just a constant. In that case, just
56589+
// replace it with 0 and move the displacement into the index.
56590+
if (ScaleAmt == 1 && BV->isConstant() && isa<ConstantSDNode>(Base)) {
56591+
SDValue Splat = DAG.getSplatBuildVector(IndexVT, DL, Base);
56592+
// Combine the constant build_vector and the constant base.
56593+
Splat =
56594+
DAG.getNode(ISD::ADD, DL, IndexVT, Index.getOperand(I), Splat);
56595+
// Add to the other half of the original Index add.
56596+
SDValue NewIndex = DAG.getNode(ISD::ADD, DL, IndexVT,
56597+
Index.getOperand(1 - I), Splat);
56598+
SDValue NewBase = DAG.getConstant(0, DL, PtrVT);
56599+
return rebuildGatherScatter(GorS, NewIndex, NewBase, Scale, DAG);
56600+
}
5658956601
}
56590-
}
5659156602
}
5659256603

5659356604
if (DCI.isBeforeLegalizeOps()) {

llvm/test/CodeGen/X86/masked_gather_scatter.ll

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5028,12 +5028,10 @@ define {<16 x float>, <16 x float>} @test_gather_16f32_mask_index_pair(ptr %x, p
50285028
; X86-KNL-NEXT: movl {{[0-9]+}}(%esp), %eax
50295029
; X86-KNL-NEXT: movl {{[0-9]+}}(%esp), %ecx
50305030
; X86-KNL-NEXT: vpslld $4, (%ecx), %zmm2
5031-
; X86-KNL-NEXT: vpbroadcastd %eax, %zmm0
5032-
; X86-KNL-NEXT: vpaddd %zmm2, %zmm0, %zmm3
50335031
; X86-KNL-NEXT: kmovw %k1, %k2
50345032
; X86-KNL-NEXT: vmovaps %zmm1, %zmm0
50355033
; X86-KNL-NEXT: vgatherdps (%eax,%zmm2), %zmm0 {%k2}
5036-
; X86-KNL-NEXT: vgatherdps 4(,%zmm3), %zmm1 {%k1}
5034+
; X86-KNL-NEXT: vgatherdps 4(%eax,%zmm2), %zmm1 {%k1}
50375035
; X86-KNL-NEXT: retl
50385036
;
50395037
; X64-SKX-SMALL-LABEL: test_gather_16f32_mask_index_pair:
@@ -5097,12 +5095,10 @@ define {<16 x float>, <16 x float>} @test_gather_16f32_mask_index_pair(ptr %x, p
50975095
; X86-SKX-NEXT: movl {{[0-9]+}}(%esp), %eax
50985096
; X86-SKX-NEXT: movl {{[0-9]+}}(%esp), %ecx
50995097
; X86-SKX-NEXT: vpslld $4, (%ecx), %zmm2
5100-
; X86-SKX-NEXT: vpbroadcastd %eax, %zmm0
5101-
; X86-SKX-NEXT: vpaddd %zmm2, %zmm0, %zmm3
51025098
; X86-SKX-NEXT: kmovw %k1, %k2
51035099
; X86-SKX-NEXT: vmovaps %zmm1, %zmm0
51045100
; X86-SKX-NEXT: vgatherdps (%eax,%zmm2), %zmm0 {%k2}
5105-
; X86-SKX-NEXT: vgatherdps 4(,%zmm3), %zmm1 {%k1}
5101+
; X86-SKX-NEXT: vgatherdps 4(%eax,%zmm2), %zmm1 {%k1}
51065102
; X86-SKX-NEXT: retl
51075103
%wide.load = load <16 x i32>, ptr %arr, align 4
51085104
%and = and <16 x i32> %wide.load, <i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911>

0 commit comments

Comments
 (0)