Update the base and index value for masked gather #130920

Closed
wants to merge 25 commits
Changes from 19 of 25 commits
Commits
72b0f4b
[X86] Update the value of base and index of masked gather for better …
Mar 12, 2025
a231c96
[X86] Update the value of base and index of masked gather for better …
Mar 12, 2025
1ff621f
Merge branch 'gatherBaseIndexfix' of github.com:rohitaggarwal007/llvm…
Mar 12, 2025
85d2e0e
Merge branch 'gatherBaseIndexfix' of github.com:rohitaggarwal007/llvm…
Mar 12, 2025
2f0897b
Merge branch 'gatherBaseIndexfix' of github.com:rohitaggarwal007/llvm…
Mar 12, 2025
cdb181d
Merge branch 'gatherBaseIndexfix' of github.com:rohitaggarwal007/llvm…
Mar 12, 2025
dd8762a
Merge branch 'gatherBaseIndexfix' of github.com:rohitaggarwal007/llvm…
Mar 13, 2025
07dd191
Merge branch 'gatherBaseIndexfix' of github.com:rohitaggarwal007/llvm…
Mar 13, 2025
7d840ed
Merge branch 'gatherBaseIndexfix' of github.com:rohitaggarwal007/llvm…
Mar 13, 2025
1bd64b8
Merge branch 'gatherBaseIndexfix' of github.com:rohitaggarwal007/llvm…
Mar 13, 2025
b00f0a9
Merge branch 'gatherBaseIndexfix' of github.com:rohitaggarwal007/llvm…
Mar 13, 2025
2268967
Merge branch 'gatherBaseIndexfix' of github.com:rohitaggarwal007/llvm…
Mar 13, 2025
8aeeb31
Merge branch 'gatherBaseIndexfix' of github.com:rohitaggarwal007/llvm…
Mar 13, 2025
f4e8b0c
Merge branch 'gatherBaseIndexfix' of github.com:rohitaggarwal007/llvm…
Mar 13, 2025
49f084e
Merge branch 'gatherBaseIndexfix' of github.com:rohitaggarwal007/llvm…
Mar 17, 2025
a7a52cd
Merge branch 'gatherBaseIndexfix' of github.com:rohitaggarwal007/llvm…
Mar 17, 2025
6252789
Merge branch 'gatherBaseIndexfix' of github.com:rohitaggarwal007/llvm…
Mar 17, 2025
8565941
Merge branch 'gatherBaseIndexfix' of github.com:rohitaggarwal007/llvm…
Mar 17, 2025
5eecf46
Merge branch 'gatherBaseIndexfix' of github.com:rohitaggarwal007/llvm…
Mar 17, 2025
ad31491
Merge branch 'gatherBaseIndexfix' of github.com:rohitaggarwal007/llvm…
Mar 17, 2025
ba2f9e7
Merge branch 'gatherBaseIndexfix' of github.com:rohitaggarwal007/llvm…
Mar 19, 2025
2e344a1
Merge branch 'llvm:main' into gatherBaseIndexfix
rohitaggarwal007 Apr 15, 2025
87e2533
Merge branch 'llvm:main' into gatherBaseIndexfix
rohitaggarwal007 Apr 16, 2025
f516be2
Update the masked_gather_scatter.ll
Apr 16, 2025
c2848c2
Remove redundant gatherBaseIndexFix.ll
Apr 16, 2025
114 changes: 113 additions & 1 deletion llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -126,6 +126,11 @@ static cl::opt<bool> MulConstantOptimization(
"SHIFT, LEA, etc."),
cl::Hidden);

static cl::opt<bool>
EnableBaseIndexUpdate("update-baseIndex", cl::init(true),
Contributor:

Does not follow command line option naming convention

Contributor:

Also, this is the kind of option that will never be used by anyone. Can you just do this unconditionally?

Contributor Author:

I just added this flag as a safety net, in case any untested use case occurs, so that we can fall back to the existing behavior.
Sure, I will remove the flag.

cl::desc("Update the value of base and index"),
Contributor:

Name and description don't mean anything standalone

Contributor Author:

Sure, I will update the information
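Taken together, the two threads above ask for a convention-conforming option name and a description that stands on its own. Purely as an illustration (the variable name, option string, and wording below are hypothetical, and the flag may instead be dropped entirely in favor of doing the transform unconditionally), a conforming declaration could look like:

    static cl::opt<bool> EnableGatherBaseIndexUpdate(
        "x86-update-gather-base-index", cl::init(true), cl::Hidden,
        cl::desc("Rewrite the base and index operands of non-uniform masked "
                 "gather/scatter addresses so the vector index can be computed "
                 "in a narrower type"));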

cl::Hidden);

X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
const X86Subtarget &STI)
: TargetLowering(TM), Subtarget(STI) {
@@ -56370,6 +56375,112 @@ static SDValue rebuildGatherScatter(MaskedGatherScatterSDNode *GorS,
Scatter->isTruncatingStore());
}

// The target overrides this function to decide whether it wants to update the
// base and index value of a non-uniform GEP.
static bool updateBaseAndIndex(SDValue &Base, SDValue &Index, const SDLoc &DL,
const SDValue &Gep, SelectionDAG &DAG) {
if (!EnableBaseIndexUpdate)
return false;

SDValue Nbase;
SDValue Nindex;
bool Changed = false;
// This lambda checks the opcode of Index and updates the index.
auto checkAndUpdateIndex = [&](SDValue &Idx) {
if (Idx.getOpcode() == ISD::SHL) { // shl zext, BV
SDValue Op10 = Idx.getOperand(0); // Zext or Sext value
SDValue Op11 = Idx.getOperand(1); // Build vector of constant

unsigned IndexWidth = Op10.getScalarValueSizeInBits();
if ((Op10.getOpcode() == ISD::SIGN_EXTEND ||
Op10.getOpcode() == ISD::ZERO_EXTEND) &&
IndexWidth > 32 &&
Op10.getOperand(0).getScalarValueSizeInBits() <= 32 &&
DAG.ComputeNumSignBits(Op10) > (IndexWidth - 32) &&
Op11.getOpcode() == ISD::BUILD_VECTOR) {

KnownBits ExtKnown = DAG.computeKnownBits(Op10);
bool ExtIsNonNegative = ExtKnown.isNonNegative();
KnownBits ExtOpKnown = DAG.computeKnownBits(Op10.getOperand(0));
bool ExtOpIsNonNegative = ExtOpKnown.isNonNegative();
if (!ExtIsNonNegative || !ExtOpIsNonNegative)
return false;

SDValue NewOp10 =
Op10.getOperand(0); // Get the Operand zero from the ext
EVT VT = NewOp10.getValueType(); // Use the operand's type to determine
// the type of index

auto *ConstEltNo = dyn_cast<ConstantSDNode>(Op11.getOperand(0));
Collaborator:

This is assuming that the shl is uniform - use getValidMinimumShiftAmount instead (replaces the BUILD_VECTOR check above as well)

Contributor Author:

Ok, my understanding is this
From

      unsigned IndexWidth = Op10.getScalarValueSizeInBits();
      if ((Op10.getOpcode() == ISD::SIGN_EXTEND ||
           Op10.getOpcode() == ISD::ZERO_EXTEND) &&
          IndexWidth > 32 &&
          Op10.getOperand(0).getScalarValueSizeInBits() <= 32 &&
          DAG.ComputeNumSignBits(Op10) > (IndexWidth - 32) &&
          **Op11.getOpcode() == ISD::BUILD_VECTOR**) {

to

 unsigned IndexWidth = Op10.getScalarValueSizeInBits();
      if ((Op10.getOpcode() == ISD::SIGN_EXTEND ||
           Op10.getOpcode() == ISD::ZERO_EXTEND) &&
          IndexWidth > 32 &&
          Op10.getOperand(0).getScalarValueSizeInBits() <= 32 &&
          DAG.ComputeNumSignBits(Op10) > (IndexWidth - 32) &&
          **DAG.getValidMinimumShiftAmount(Idx)**) {

Please correct me
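
A rough sketch of that shape (assuming SelectionDAG's std::optional<uint64_t>-returning getValidMinimumShiftAmount overload; since the helper already returns the uniform minimum shift amount, the later ConstEltNo / BUILD_VECTOR handling could go away as well):

      std::optional<uint64_t> ShAmt = DAG.getValidMinimumShiftAmount(Idx);
      unsigned IndexWidth = Op10.getScalarValueSizeInBits();
      if ((Op10.getOpcode() == ISD::SIGN_EXTEND ||
           Op10.getOpcode() == ISD::ZERO_EXTEND) &&
          IndexWidth > 32 &&
          Op10.getOperand(0).getScalarValueSizeInBits() <= 32 &&
          DAG.ComputeNumSignBits(Op10) > (IndexWidth - 32) && ShAmt) {
        // ... existing known-bits checks on Op10 and its operand ...
        // Splat the known shift amount instead of rebuilding it from the
        // first element of Op11's BUILD_VECTOR.
        Nindex = DAG.getNode(ISD::SHL, DL, VT, NewOp10,
                             DAG.getConstant(*ShAmt, DL, VT));
      }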

if (!ConstEltNo)
return false;

SmallVector<SDValue, 8> Ops(VT.getVectorNumElements(),
DAG.getConstant(ConstEltNo->getZExtValue(),
DL, VT.getScalarType()));
Nindex = DAG.getNode(ISD::SHL, DL, VT, NewOp10,
DAG.getBuildVector(VT, DL, Ops));
return true;
}
}
return false;
};

// For the GEP instruction, we are trying to properly assign the base and
// index value. We go through the lowered code and iterate backward.
if (isNullConstant(Base) && Gep.getOpcode() == ISD::ADD) {
SDValue Op0 = Gep.getOperand(0); // base or add
SDValue Op1 = Gep.getOperand(1); // build vector or SHL
Nbase = Op0;
SDValue Idx = Op1;
auto Flags = Gep->getFlags();

if (Op0->getOpcode() == ISD::ADD) { // add t15(base), t18(Idx)
SDValue Op00 = Op0.getOperand(0); // Base
Nbase = Op00;
Idx = Op0.getOperand(1);
} else if (!(Op0->getOpcode() == ISD::BUILD_VECTOR &&
Op0.getOperand(0).getOpcode() == ISD::CopyFromReg)) {
return false;
}
if (!checkAndUpdateIndex(Idx)) {
return false;
}
if (Op0 != Nbase) {
auto *ConstEltNo = dyn_cast<ConstantSDNode>(Op1.getOperand(0));
if (!ConstEltNo)
return false;

SmallVector<SDValue, 8> Ops(
Nindex.getValueType().getVectorNumElements(),
DAG.getConstant(ConstEltNo->getZExtValue(), DL,
Nindex.getValueType().getScalarType()));
Nindex = DAG.getNode(ISD::ADD, DL, Nindex.getValueType(), Nindex,
DAG.getBuildVector(Nindex.getValueType(), DL, Ops),
Flags);
}
Base = Nbase.getOperand(0);
Index = Nindex;
Changed = true;
} else if (Base.getOpcode() == ISD::CopyFromReg ||
(Base.getOpcode() == ISD::ADD &&
Base.getOperand(0).getOpcode() == ISD::CopyFromReg &&
isConstOrConstSplat(Base.getOperand(1)))) {
if (checkAndUpdateIndex(Index)) {
Index = Nindex;
Changed = true;
}
}
if (Changed) {
LLVM_DEBUG(dbgs() << "Successful in updating the non uniform gep "
"information\n";
dbgs() << "updated base "; Base.dump();
dbgs() << "updated Index "; Index.dump(););
return true;
}
return false;
}

static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
TargetLowering::DAGCombinerInfo &DCI) {
SDLoc DL(N);
@@ -56380,6 +56491,8 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
const TargetLowering &TLI = DAG.getTargetLoweringInfo();

if (DCI.isBeforeLegalize()) {
if (updateBaseAndIndex(Base, Index, DL, Index, DAG))
return rebuildGatherScatter(GorS, Index, Base, Scale, DAG);
unsigned IndexWidth = Index.getScalarValueSizeInBits();

// Shrink constant indices if they are larger than 32-bits.
@@ -56475,7 +56588,6 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
return SDValue(N, 0);
}
}

return SDValue();
}

76 changes: 76 additions & 0 deletions llvm/test/CodeGen/X86/gatherBaseIndexFix.ll
@@ -0,0 +1,76 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
; RUN: llc -mtriple=x86_64-unknown-unknown -mattr=+avx512f,+avx512bw,+avx512vl,+avx512dq -mcpu=znver5 < %s | FileCheck %s
; RUN: llc -update-baseIndex -mtriple=x86_64-unknown-unknown -mattr=+avx512f,+avx512bw,+avx512vl,+avx512dq -mcpu=znver5 < %s | FileCheck %s
; RUN: llc -update-baseIndex=false -mtriple=x86_64-unknown-unknown -mattr=+avx512f,+avx512bw,+avx512vl,+avx512dq -mcpu=znver5 < %s | FileCheck %s -check-prefix=OLD

%struct.pt = type { float, float, float, i32 }
Collaborator:

This test needs to be reduced down a lot: remove all the unnecessary attributes etc. and keep just the minimal IR needed to show the problem.

Contributor Author:

Sure

Contributor Author (rohitaggarwal007, Mar 12, 2025):

@RKSimon, updated the test case

Collaborator:

Thanks - now you should be able to use the utils/update_llc_test_checks.py script to generate more thorough CHECK lines without too much noise.
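
(For example, run from the monorepo root, something along these lines regenerates the CHECK/OLD lines from the RUN lines; the build directory path here is just an assumption:

    llvm/utils/update_llc_test_checks.py --llc-binary build/bin/llc llvm/test/CodeGen/X86/gatherBaseIndexFix.ll
)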


define <16 x float> @test_gather_16f32_1(ptr %x, ptr %arr, <16 x i1> %mask, <16 x float> %src0) {
; CHECK-LABEL: test_gather_16f32_1:
; CHECK: # %bb.0:
; CHECK-NEXT: vpslld $4, (%rsi), %zmm2
; CHECK-NEXT: vpsllw $7, %xmm0, %xmm0
; CHECK-NEXT: vpmovb2m %xmm0, %k1
; CHECK-NEXT: vgatherdps (%rdi,%zmm2), %zmm1 {%k1}
; CHECK-NEXT: vmovaps %zmm1, %zmm0
; CHECK-NEXT: retq
;
; OLD-LABEL: test_gather_16f32_1:
; OLD: # %bb.0:
; OLD-NEXT: vpsllw $7, %xmm0, %xmm0
; OLD-NEXT: vmovdqu64 (%rsi), %zmm4
; OLD-NEXT: vextractf64x4 $1, %zmm1, %ymm3
; OLD-NEXT: vpmovb2m %xmm0, %k1
; OLD-NEXT: vpandd {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to16}, %zmm4, %zmm0
; OLD-NEXT: kshiftrw $8, %k1, %k2
; OLD-NEXT: vextracti64x4 $1, %zmm0, %ymm2
; OLD-NEXT: vpmovzxdq {{.*#+}} zmm0 = ymm0[0],zero,ymm0[1],zero,ymm0[2],zero,ymm0[3],zero,ymm0[4],zero,ymm0[5],zero,ymm0[6],zero,ymm0[7],zero
; OLD-NEXT: vpmovzxdq {{.*#+}} zmm2 = ymm2[0],zero,ymm2[1],zero,ymm2[2],zero,ymm2[3],zero,ymm2[4],zero,ymm2[5],zero,ymm2[6],zero,ymm2[7],zero
; OLD-NEXT: vpsllq $4, %zmm0, %zmm0
; OLD-NEXT: vpsllq $4, %zmm2, %zmm2
; OLD-NEXT: vgatherqps (%rdi,%zmm0), %ymm1 {%k1}
; OLD-NEXT: vgatherqps (%rdi,%zmm2), %ymm3 {%k2}
; OLD-NEXT: vinsertf64x4 $1, %ymm3, %zmm1, %zmm0
; OLD-NEXT: retq
%wide.load = load <16 x i32>, ptr %arr, align 4
%and = and <16 x i32> %wide.load, <i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911>
%zext = zext <16 x i32> %and to <16 x i64>
%ptrs = getelementptr inbounds %struct.pt, ptr %x, <16 x i64> %zext
%res = call <16 x float> @llvm.masked.gather.v16f32.v16p0(<16 x ptr> %ptrs, i32 4, <16 x i1> %mask, <16 x float> %src0)
ret <16 x float> %res
}

define <16 x float> @test_gather_16f32_2(ptr %x, ptr %arr, <16 x i1> %mask, <16 x float> %src0) {
; CHECK-LABEL: test_gather_16f32_2:
; CHECK: # %bb.0:
; CHECK-NEXT: vpslld $4, (%rsi), %zmm2
; CHECK-NEXT: vpsllw $7, %xmm0, %xmm0
; CHECK-NEXT: vpmovb2m %xmm0, %k1
; CHECK-NEXT: vgatherdps 4(%rdi,%zmm2), %zmm1 {%k1}
; CHECK-NEXT: vmovaps %zmm1, %zmm0
; CHECK-NEXT: retq
;
; OLD-LABEL: test_gather_16f32_2:
; OLD: # %bb.0:
; OLD-NEXT: vpsllw $7, %xmm0, %xmm0
; OLD-NEXT: vmovdqu64 (%rsi), %zmm4
; OLD-NEXT: vextractf64x4 $1, %zmm1, %ymm3
; OLD-NEXT: vpmovb2m %xmm0, %k1
; OLD-NEXT: vpandd {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to16}, %zmm4, %zmm0
; OLD-NEXT: kshiftrw $8, %k1, %k2
; OLD-NEXT: vextracti64x4 $1, %zmm0, %ymm2
; OLD-NEXT: vpmovzxdq {{.*#+}} zmm0 = ymm0[0],zero,ymm0[1],zero,ymm0[2],zero,ymm0[3],zero,ymm0[4],zero,ymm0[5],zero,ymm0[6],zero,ymm0[7],zero
; OLD-NEXT: vpmovzxdq {{.*#+}} zmm2 = ymm2[0],zero,ymm2[1],zero,ymm2[2],zero,ymm2[3],zero,ymm2[4],zero,ymm2[5],zero,ymm2[6],zero,ymm2[7],zero
; OLD-NEXT: vpsllq $4, %zmm0, %zmm0
; OLD-NEXT: vpsllq $4, %zmm2, %zmm2
; OLD-NEXT: vgatherqps 4(%rdi,%zmm0), %ymm1 {%k1}
; OLD-NEXT: vgatherqps 4(%rdi,%zmm2), %ymm3 {%k2}
; OLD-NEXT: vinsertf64x4 $1, %ymm3, %zmm1, %zmm0
; OLD-NEXT: retq
%wide.load = load <16 x i32>, ptr %arr, align 4
%and = and <16 x i32> %wide.load, <i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911>
%zext = zext <16 x i32> %and to <16 x i64>
%ptrs = getelementptr inbounds %struct.pt, ptr %x, <16 x i64> %zext, i32 1
%res = call <16 x float> @llvm.masked.gather.v16f32.v16p0(<16 x ptr> %ptrs, i32 4, <16 x i1> %mask, <16 x float> %src0)
ret <16 x float> %res
}