Skip to content

Update the base and index value for masked gather #130920

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 25 commits into from
Closed
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
72b0f4b
[X86] Update the value of base and index of masked gather for better …
Mar 12, 2025
a231c96
[X86] Update the value of base and index of masked gather for better …
Mar 12, 2025
1ff621f
Merge branch 'gatherBaseIndexfix' of github.com:rohitaggarwal007/llvm…
Mar 12, 2025
85d2e0e
Merge branch 'gatherBaseIndexfix' of github.com:rohitaggarwal007/llvm…
Mar 12, 2025
2f0897b
Merge branch 'gatherBaseIndexfix' of github.com:rohitaggarwal007/llvm…
Mar 12, 2025
cdb181d
Merge branch 'gatherBaseIndexfix' of github.com:rohitaggarwal007/llvm…
Mar 12, 2025
dd8762a
Merge branch 'gatherBaseIndexfix' of github.com:rohitaggarwal007/llvm…
Mar 13, 2025
07dd191
Merge branch 'gatherBaseIndexfix' of github.com:rohitaggarwal007/llvm…
Mar 13, 2025
7d840ed
Merge branch 'gatherBaseIndexfix' of github.com:rohitaggarwal007/llvm…
Mar 13, 2025
1bd64b8
Merge branch 'gatherBaseIndexfix' of github.com:rohitaggarwal007/llvm…
Mar 13, 2025
b00f0a9
Merge branch 'gatherBaseIndexfix' of github.com:rohitaggarwal007/llvm…
Mar 13, 2025
2268967
Merge branch 'gatherBaseIndexfix' of github.com:rohitaggarwal007/llvm…
Mar 13, 2025
8aeeb31
Merge branch 'gatherBaseIndexfix' of github.com:rohitaggarwal007/llvm…
Mar 13, 2025
f4e8b0c
Merge branch 'gatherBaseIndexfix' of github.com:rohitaggarwal007/llvm…
Mar 13, 2025
49f084e
Merge branch 'gatherBaseIndexfix' of github.com:rohitaggarwal007/llvm…
Mar 17, 2025
a7a52cd
Merge branch 'gatherBaseIndexfix' of github.com:rohitaggarwal007/llvm…
Mar 17, 2025
6252789
Merge branch 'gatherBaseIndexfix' of github.com:rohitaggarwal007/llvm…
Mar 17, 2025
8565941
Merge branch 'gatherBaseIndexfix' of github.com:rohitaggarwal007/llvm…
Mar 17, 2025
5eecf46
Merge branch 'gatherBaseIndexfix' of github.com:rohitaggarwal007/llvm…
Mar 17, 2025
ad31491
Merge branch 'gatherBaseIndexfix' of github.com:rohitaggarwal007/llvm…
Mar 17, 2025
ba2f9e7
Merge branch 'gatherBaseIndexfix' of github.com:rohitaggarwal007/llvm…
Mar 19, 2025
2e344a1
Merge branch 'llvm:main' into gatherBaseIndexfix
rohitaggarwal007 Apr 15, 2025
87e2533
Merge branch 'llvm:main' into gatherBaseIndexfix
rohitaggarwal007 Apr 16, 2025
f516be2
Update the masked_gather_scatter.ll
Apr 16, 2025
c2848c2
Remove redundant gatherBaseIndexFix.ll
Apr 16, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions llvm/include/llvm/CodeGen/TargetLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -5127,6 +5127,13 @@ class TargetLowering : public TargetLoweringBase {
SmallVectorImpl<SDValue> &Ops,
SelectionDAG &DAG) const;

// Target may override this function to decided whether it want to update the
// base and index value of a non-uniform gep
virtual bool updateBaseAndIndex(const Value *Ptr, SDValue &Base,
SDValue &Index, const SDLoc &DL,
const SDValue &Gep, SelectionDAG &DAG,
const BasicBlock *CurBB) const;

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Name and purpose unclear. We also have way too many poorly designed, narrow purpose hooks. You should try to avoid introducing this without a very strong justification for why there can't be a reasonable default behavior.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did not understood the comment.

//===--------------------------------------------------------------------===//
// Div utility functions
//
Expand Down
10 changes: 10 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4905,6 +4905,11 @@ void SelectionDAGBuilder::visitMaskedScatter(const CallInst &I) {
Scale = DAG.getTargetConstant(1, sdl, TLI.getPointerTy(DAG.getDataLayout()));
}

if (!UniformBase) {
TLI.updateBaseAndIndex(Ptr, Base, Index, getCurSDLoc(), getValue(Ptr), DAG,
I.getParent());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The DAG builder should avoid interpreting the incoming IR in target specific ways

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @arsenm
Where should I place the code snippet so that it is applicable for X86 target?

}

EVT IdxVT = Index.getValueType();
EVT EltTy = IdxVT.getVectorElementType();
if (TLI.shouldExtendGSIndex(IdxVT, EltTy)) {
Expand Down Expand Up @@ -5024,6 +5029,11 @@ void SelectionDAGBuilder::visitMaskedGather(const CallInst &I) {
Scale = DAG.getTargetConstant(1, sdl, TLI.getPointerTy(DAG.getDataLayout()));
}

if (!UniformBase) {
TLI.updateBaseAndIndex(Ptr, Base, Index, getCurSDLoc(), getValue(Ptr), DAG,
I.getParent());
}

EVT IdxVT = Index.getValueType();
EVT EltTy = IdxVT.getVectorElementType();
if (TLI.shouldExtendGSIndex(IdxVT, EltTy)) {
Expand Down
8 changes: 8 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5655,6 +5655,14 @@ void TargetLowering::CollectTargetIntrinsicOperands(
const CallInst &I, SmallVectorImpl<SDValue> &Ops, SelectionDAG &DAG) const {
}

// By default, this function is disabled. Overriding target can enable it
bool TargetLowering::updateBaseAndIndex(const Value *Ptr, SDValue &Base,
SDValue &Index, const SDLoc &DL,
const SDValue &Gep, SelectionDAG &DAG,
const BasicBlock *CurBB) const {
return false;
}

std::pair<unsigned, const TargetRegisterClass *>
TargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *RI,
StringRef Constraint,
Expand Down
101 changes: 101 additions & 0 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,11 @@ static cl::opt<bool> MulConstantOptimization(
"SHIFT, LEA, etc."),
cl::Hidden);

static cl::opt<bool>
EnableBaseIndexUpdate("update-baseIndex", cl::init(true),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does not follow command line option naming convention

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also this is the kind of option that will never be used by anyone. Can you just do this unconditionally

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just added this flag for the safety incase any untest usecase occur. So that we can fall back to existing behavior.
Sure, I will remove the flag.

cl::desc("Update the value of base and index"),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Name and description don't mean anything standalone

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I will update the information

cl::Hidden);

X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
const X86Subtarget &STI)
: TargetLowering(TM), Subtarget(STI) {
Expand Down Expand Up @@ -61619,3 +61624,99 @@ Align X86TargetLowering::getPrefLoopAlignment(MachineLoop *ML) const {
return Align(1ULL << ExperimentalPrefInnermostLoopAlignment);
return TargetLowering::getPrefLoopAlignment();
}

// Target override this function to decided whether it want to update the base
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

decided -> decide

// and index value of a non-uniform gep
bool X86TargetLowering::updateBaseAndIndex(const Value *Ptr, SDValue &Base,
SDValue &Index, const SDLoc &DL,
const SDValue &Gep,
SelectionDAG &DAG,
const BasicBlock *CurBB) const {
if (!EnableBaseIndexUpdate)
return false;

const GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
Copy link
Collaborator

@topperc topperc Mar 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need to check that the Ptr is a GEP? It doesn't look like we use any information from it. The rest of the code is on SelectionDAG nodes unless I missed something.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To check whether ptr is derived from GEP. If not, we return false and transformation does not happen.
If the check does not fit here, i can move it to callsite

if (GEP && GEP->getParent() != CurBB)
return false;

SDValue nbase;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Variable names should be capitalized

/* For the gep instruction, we are trying to properly assign the base and
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use // comments

index value We are go through the lower code and iterate backward.
*/
if (Gep.getOpcode() == ISD::ADD) {
SDValue Op0 = Gep.getOperand(0); // base or add
SDValue Op1 = Gep.getOperand(1); // build vector or SHL
nbase = Op0;
SDValue Idx = Op1;
auto Flags = Gep->getFlags();

if (Op0->getOpcode() == ISD::ADD) { // add t15(base), t18(Idx)
SDValue Op00 = Op0.getOperand(0); // Base
nbase = Op00;
Idx = Op0.getOperand(1);
} else if (!(Op0->getOpcode() == ISD::BUILD_VECTOR &&
Op0.getOperand(0).getOpcode() == ISD::CopyFromReg)) {
return false;
}
SDValue nIndex;
if (Idx.getOpcode() == ISD::SHL) { // shl zext, BV
SDValue Op10 = Idx.getOperand(0); // Zext or Sext value
SDValue Op11 = Idx.getOperand(1); // Build vector of constant

unsigned IndexWidth = Op10.getScalarValueSizeInBits();
if ((Op10.getOpcode() == ISD::SIGN_EXTEND ||
Op10.getOpcode() == ISD::ZERO_EXTEND) &&
IndexWidth > 32 &&
Op10.getOperand(0).getScalarValueSizeInBits() <= 32 &&
DAG.ComputeNumSignBits(Op10) > (IndexWidth - 32) &&
Op11.getOpcode() == ISD::BUILD_VECTOR) {

KnownBits ExtKnown = DAG.computeKnownBits(Op10);
bool ExtIsNonNegative = ExtKnown.isNonNegative();
KnownBits ExtOpKnown = DAG.computeKnownBits(Op10.getOperand(0));
bool ExtOpIsNonNegative = ExtOpKnown.isNonNegative();
if (!(ExtIsNonNegative && ExtOpIsNonNegative))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

deMorgan

return false;

SDValue newOp10 =
Op10.getOperand(0); // Get the Operand zero from the ext
EVT VT = newOp10.getValueType(); // Use the
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"Use the" is an incomplete comment


auto *ConstEltNo = dyn_cast<ConstantSDNode>(Op11.getOperand(0));
if (!ConstEltNo) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Drop curly braces

return false;
}
SmallVector<SDValue, 8> Ops(VT.getVectorNumElements(),
DAG.getConstant(ConstEltNo->getZExtValue(),
DL, VT.getScalarType()));
nIndex = DAG.getNode(ISD::SHL, DL, VT, newOp10,
DAG.getBuildVector(VT, DL, Ops));
} else {
return false;
}
} else {
return false;
}
if (Op0 != nbase) {
auto *ConstEltNo = dyn_cast<ConstantSDNode>(Op1.getOperand(0));
if (!ConstEltNo) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Drop curly braces

return false;
}
SmallVector<SDValue, 8> Ops(
nIndex.getValueType().getVectorNumElements(),
DAG.getConstant(ConstEltNo->getZExtValue(), DL,
nIndex.getValueType().getScalarType()));
nIndex = DAG.getNode(ISD::ADD, DL, nIndex.getValueType(), nIndex,
DAG.getBuildVector(nIndex.getValueType(), DL, Ops),
Flags);
}
Base = nbase.getOperand(0);
Index = nIndex;
LLVM_DEBUG(dbgs() << "Successfull in updating the non uniform gep "
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Successful*

"information\n";
dbgs() << "updated base "; Base.dump();
dbgs() << "updated Index "; Index.dump(););
return true;
}
return false;
}
7 changes: 7 additions & 0 deletions llvm/lib/Target/X86/X86ISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -1671,6 +1671,13 @@ namespace llvm {
return TargetLoweringBase::getTypeToTransformTo(Context, VT);
}

// Target override this function to decided whether it want to update the
// base and index value of a non-uniform gep
bool updateBaseAndIndex(const Value *Ptr, SDValue &Base, SDValue &Index,
const SDLoc &DL, const SDValue &Gep,
SelectionDAG &DAG,
const BasicBlock *CurBB) const override;

protected:
std::pair<const TargetRegisterClass *, uint8_t>
findRepresentativeClass(const TargetRegisterInfo *TRI,
Expand Down
37 changes: 37 additions & 0 deletions llvm/test/CodeGen/X86/gatherBaseIndexFix.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
; RUN: llc -mtriple=x86_64-unknown-unknown -mattr=+avx512f,+avx512bw,+avx512vl,+avx512dq -mcpu=znver5 < %s | FileCheck %s
; RUN: llc -update-baseIndex -mtriple=x86_64-unknown-unknown -mattr=+avx512f,+avx512bw,+avx512vl,+avx512dq -mcpu=znver5 < %s | FileCheck %s
; RUN: llc -update-baseIndex=false -mtriple=x86_64-unknown-unknown -mattr=+avx512f,+avx512bw,+avx512vl,+avx512dq -mcpu=znver5 < %s | FileCheck %s -check-prefix=OLD

%struct.pt = type { float, float, float, i32 }
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this test needs to be reduced down a lot - remove all the unnecessary attributes etc. and just contain the minimal IR to show the problem

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure

Copy link
Contributor Author

@rohitaggarwal007 rohitaggarwal007 Mar 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@RKSimon, updated the test case

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks - now you should be able to use the utils/update_llc_test_checks.py script to generate more thorough CHECK lines without too much noise.


; CHECK-LABEL: test_gather_16f32_1:
; CHECK: vgatherdps

; OLD-LABEL: test_gather_16f32_1:
; OLD: vgatherqps
; OLD: vgatherqps

define <16 x float> @test_gather_16f32_1(ptr %x, ptr %arr, <16 x i1> %mask, <16 x float> %src0) {
%wide.load = load <16 x i32>, ptr %arr, align 4
%4 = and <16 x i32> %wide.load, <i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911>
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(style) avoid numbered variable names

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

%5 = zext <16 x i32> %4 to <16 x i64>
%ptrs = getelementptr inbounds %struct.pt, ptr %x, <16 x i64> %5
%res = call <16 x float> @llvm.masked.gather.v16f32.v16p0(<16 x ptr> %ptrs, i32 4, <16 x i1> %mask, <16 x float> %src0)
ret <16 x float> %res
}

; CHECK-LABEL: test_gather_16f32_2:
; CHECK: vgatherdps

; OLD-LABEL: test_gather_16f32_2:
; OLD: vgatherqps
; OLD: vgatherqps

define <16 x float> @test_gather_16f32_2(ptr %x, ptr %arr, <16 x i1> %mask, <16 x float> %src0) {
%wide.load = load <16 x i32>, ptr %arr, align 4
%4 = and <16 x i32> %wide.load, <i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911>
%5 = zext <16 x i32> %4 to <16 x i64>
%ptrs = getelementptr inbounds %struct.pt, ptr %x, <16 x i64> %5, i32 1
%res = call <16 x float> @llvm.masked.gather.v16f32.v16p0(<16 x ptr> %ptrs, i32 4, <16 x i1> %mask, <16 x float> %src0)
ret <16 x float> %res
}