
Commit b74a46c

[X86] combineGatherScatter - split non-constant (add v, (splat b)) indices patterns and add the splat into the (scalar) base address
We already did this for constant cases; this patch generalizes the existing fold to attempt to extract the splat from either operand of an ADD node used as the gather/scatter index value. This cleanup should also make it easier to add support for splitting vXi32 indices on x86_64 64-bit pointer targets in the future. Noticed while reviewing llvm#134979
1 parent 4ea57b3
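For illustration, here is a minimal LLVM IR sketch of the kind of pattern the generalized fold targets (the function and value names are hypothetical, not taken from the updated tests). It assumes a 64-bit target where the vector index element type matches the pointer type and the effective scale is 1 (an i8 GEP), i.e. the newly handled non-constant splat case:

declare <8 x float> @llvm.masked.gather.v8f32.v8p0(<8 x ptr>, i32 immarg, <8 x i1>, <8 x float>)

define <8 x float> @gather_index_plus_splat(ptr %base, <8 x i64> %idx, i64 %b, <8 x i1> %mask) {
  ; Splat the scalar %b and add it to the vector index: (add %idx, (splat %b)).
  %b.head  = insertelement <8 x i64> poison, i64 %b, i64 0
  %b.splat = shufflevector <8 x i64> %b.head, <8 x i64> poison, <8 x i32> zeroinitializer
  %sum     = add <8 x i64> %idx, %b.splat
  ; An i8 GEP keeps the effective gather scale at 1, so the non-constant splat may move.
  %ptrs    = getelementptr i8, ptr %base, <8 x i64> %sum
  %res     = call <8 x float> @llvm.masked.gather.v8f32.v8p0(<8 x ptr> %ptrs, i32 4, <8 x i1> %mask, <8 x float> poison)
  ret <8 x float> %res
}

With the combine below, the splatted %b can be added to the scalar base pointer (%base + %b) and %idx used directly as the gather index. The updated masked_gather_scatter.ll checks show the same effect on x86-32: the vpbroadcastd/vpaddd pair disappears and the gathers address through (%eax,%zmm2) instead of a separately summed index register.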

2 files changed: +42, -35 lines
llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 40 additions & 29 deletions
@@ -56521,6 +56521,7 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
   SDValue Base = GorS->getBasePtr();
   SDValue Scale = GorS->getScale();
   EVT IndexVT = Index.getValueType();
+  EVT IndexSVT = IndexVT.getVectorElementType();
   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
 
   if (DCI.isBeforeLegalize()) {
@@ -56557,41 +56558,51 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
   }
 
   EVT PtrVT = TLI.getPointerTy(DAG.getDataLayout());
-  // Try to move splat constant adders from the index operand to the base
+
+  // Try to move splat adders from the index operand to the base
   // pointer operand. Taking care to multiply by the scale. We can only do
   // this when index element type is the same as the pointer type.
   // Otherwise we need to be sure the math doesn't wrap before the scale.
-  if (Index.getOpcode() == ISD::ADD &&
-      IndexVT.getVectorElementType() == PtrVT && isa<ConstantSDNode>(Scale)) {
+  if (Index.getOpcode() == ISD::ADD && IndexSVT == PtrVT &&
+      isa<ConstantSDNode>(Scale)) {
     uint64_t ScaleAmt = Scale->getAsZExtVal();
-    if (auto *BV = dyn_cast<BuildVectorSDNode>(Index.getOperand(1))) {
-      BitVector UndefElts;
-      if (ConstantSDNode *C = BV->getConstantSplatNode(&UndefElts)) {
-        // FIXME: Allow non-constant?
-        if (UndefElts.none()) {
-          // Apply the scale.
-          APInt Adder = C->getAPIntValue() * ScaleAmt;
-          // Add it to the existing base.
-          Base = DAG.getNode(ISD::ADD, DL, PtrVT, Base,
-                             DAG.getConstant(Adder, DL, PtrVT));
-          Index = Index.getOperand(0);
-          return rebuildGatherScatter(GorS, Index, Base, Scale, DAG);
-        }
-      }
 
-      // It's also possible base is just a constant. In that case, just
-      // replace it with 0 and move the displacement into the index.
-      if (BV->isConstant() && isa<ConstantSDNode>(Base) &&
-          isOneConstant(Scale)) {
-        SDValue Splat = DAG.getSplatBuildVector(IndexVT, DL, Base);
-        // Combine the constant build_vector and the constant base.
-        Splat = DAG.getNode(ISD::ADD, DL, IndexVT, Index.getOperand(1), Splat);
-        // Add to the LHS of the original Index add.
-        Index = DAG.getNode(ISD::ADD, DL, IndexVT, Index.getOperand(0), Splat);
-        Base = DAG.getConstant(0, DL, Base.getValueType());
-        return rebuildGatherScatter(GorS, Index, Base, Scale, DAG);
+    for (unsigned I = 0; I != 2; ++I)
+      if (auto *BV = dyn_cast<BuildVectorSDNode>(Index.getOperand(I))) {
+        BitVector UndefElts;
+        if (SDValue Splat = BV->getSplatValue(&UndefElts)) {
+          if (UndefElts.none()) {
+            // If the splat value is constant we can add the scaled splat value
+            // to the existing base.
+            if (auto *C = dyn_cast<ConstantSDNode>(Splat)) {
+              APInt Adder = C->getAPIntValue() * ScaleAmt;
+              SDValue NewBase = DAG.getNode(ISD::ADD, DL, PtrVT, Base,
+                                            DAG.getConstant(Adder, DL, PtrVT));
+              SDValue NewIndex = Index.getOperand(1 - I);
+              return rebuildGatherScatter(GorS, NewIndex, NewBase, Scale, DAG);
+            }
+            // For non-constant cases, limit this to non-scaled cases.
+            if (ScaleAmt == 1) {
+              SDValue NewBase = DAG.getNode(ISD::ADD, DL, PtrVT, Base, Splat);
+              SDValue NewIndex = Index.getOperand(1 - I);
+              return rebuildGatherScatter(GorS, NewIndex, NewBase, Scale, DAG);
+            }
+          }
+        }
+        // It's also possible base is just a constant. In that case, just
+        // replace it with 0 and move the displacement into the index.
+        if (ScaleAmt == 1 && BV->isConstant() && isa<ConstantSDNode>(Base)) {
+          SDValue Splat = DAG.getSplatBuildVector(IndexVT, DL, Base);
+          // Combine the constant build_vector and the constant base.
+          Splat =
+              DAG.getNode(ISD::ADD, DL, IndexVT, Index.getOperand(I), Splat);
+          // Add to the other half of the original Index add.
+          SDValue NewIndex = DAG.getNode(ISD::ADD, DL, IndexVT,
+                                         Index.getOperand(1 - I), Splat);
+          SDValue NewBase = DAG.getConstant(0, DL, PtrVT);
+          return rebuildGatherScatter(GorS, NewIndex, NewBase, Scale, DAG);
+        }
       }
-    }
   }
 
   if (DCI.isBeforeLegalizeOps()) {

llvm/test/CodeGen/X86/masked_gather_scatter.ll

Lines changed: 2 additions & 6 deletions
@@ -5028,12 +5028,10 @@ define {<16 x float>, <16 x float>} @test_gather_16f32_mask_index_pair(ptr %x, p
 ; X86-KNL-NEXT: movl {{[0-9]+}}(%esp), %eax
 ; X86-KNL-NEXT: movl {{[0-9]+}}(%esp), %ecx
 ; X86-KNL-NEXT: vpslld $4, (%ecx), %zmm2
-; X86-KNL-NEXT: vpbroadcastd %eax, %zmm0
-; X86-KNL-NEXT: vpaddd %zmm2, %zmm0, %zmm3
 ; X86-KNL-NEXT: kmovw %k1, %k2
 ; X86-KNL-NEXT: vmovaps %zmm1, %zmm0
 ; X86-KNL-NEXT: vgatherdps (%eax,%zmm2), %zmm0 {%k2}
-; X86-KNL-NEXT: vgatherdps 4(,%zmm3), %zmm1 {%k1}
+; X86-KNL-NEXT: vgatherdps 4(%eax,%zmm2), %zmm1 {%k1}
 ; X86-KNL-NEXT: retl
 ;
 ; X64-SKX-SMALL-LABEL: test_gather_16f32_mask_index_pair:
@@ -5097,12 +5095,10 @@ define {<16 x float>, <16 x float>} @test_gather_16f32_mask_index_pair(ptr %x, p
 ; X86-SKX-NEXT: movl {{[0-9]+}}(%esp), %eax
 ; X86-SKX-NEXT: movl {{[0-9]+}}(%esp), %ecx
 ; X86-SKX-NEXT: vpslld $4, (%ecx), %zmm2
-; X86-SKX-NEXT: vpbroadcastd %eax, %zmm0
-; X86-SKX-NEXT: vpaddd %zmm2, %zmm0, %zmm3
 ; X86-SKX-NEXT: kmovw %k1, %k2
 ; X86-SKX-NEXT: vmovaps %zmm1, %zmm0
 ; X86-SKX-NEXT: vgatherdps (%eax,%zmm2), %zmm0 {%k2}
-; X86-SKX-NEXT: vgatherdps 4(,%zmm3), %zmm1 {%k1}
+; X86-SKX-NEXT: vgatherdps 4(%eax,%zmm2), %zmm1 {%k1}
 ; X86-SKX-NEXT: retl
   %wide.load = load <16 x i32>, ptr %arr, align 4
   %and = and <16 x i32> %wide.load, <i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911>
