[X86][SelectionDAG] Fix the Gather's base and index by modifying the Scale value #137813

Merged
merged 35 commits into from
May 13, 2025
Changes from 31 commits
Commits
741acb0
Fix the Gather's base and index for one use or multiple uses of Index…
Mar 19, 2025
9d395b1
Fix the Gather's base and index for one use or multiple uses of Index…
Mar 19, 2025
0625ae4
Merge branch 'gatherMultipleOccurrence' of github.com:rohitaggarwal00…
Apr 9, 2025
84d9a5f
Merge branch 'gatherMultipleOccurrence' of github.com:rohitaggarwal00…
Apr 9, 2025
76d9167
Merge branch 'gatherMultipleOccurrence' of github.com:rohitaggarwal00…
Apr 9, 2025
b73051c
Merge branch 'gatherMultipleOccurrence' of github.com:rohitaggarwal00…
Apr 9, 2025
a43aa1d
Merge branch 'gatherMultipleOccurrence' of github.com:rohitaggarwal00…
Apr 9, 2025
03b3daf
Changes
Apr 9, 2025
0e8856f
Occurrence
Apr 9, 2025
7eb7663
Update gatherBaseIndexFix.ll
rohitaggarwal007 Apr 9, 2025
4ed4a4d
squash! Changes
Apr 9, 2025
0c4eb0f
s Merge branch 'gatherMultipleOccurrence' of github.com:rohitaggarwal…
Apr 9, 2025
dc1b6d6
Merge branch 'llvm:main' into gatherMultipleOccurrence
rohitaggarwal007 Apr 10, 2025
34df281
Merge branch 'main' into gatherMultipleOccurrence
rohitaggarwal007 Apr 11, 2025
5aaa2ab
Merge branch 'llvm:main' into gatherMultipleOccurrence
rohitaggarwal007 Apr 14, 2025
8ce9360
Merge branch 'llvm:main' into gatherMultipleOccurrence
rohitaggarwal007 Apr 16, 2025
bfbbe9f
Update the masked_gather_scatter.ll
Apr 16, 2025
4bb9f5b
Remove redundant gatherBaseIndexFix.ll
Apr 16, 2025
fb131f8
Remove updateBaseIndex function and update the Checks for testcase
Apr 17, 2025
47bf70b
Revert back the hasOne check
Apr 17, 2025
716943e
Fold operation
Apr 18, 2025
b9e571a
Restrict Scale so that it can happen fully or none. Added a logic for…
Apr 22, 2025
f0a175a
Add test case and merge use cases for SHL
Apr 23, 2025
4b57e4a
Update undef to poison in testcase
Apr 23, 2025
8995df9
Merge branch 'gatherMultipleOccurrence' of github.com:rohitaggarwal00…
Apr 29, 2025
5a12c7d
Update the testcase
Apr 29, 2025
d2f1352
Remove the unwanted tests
Apr 29, 2025
ed72aa3
Log2 handling and updating the comments
May 5, 2025
1644def
Fix formatting
May 5, 2025
7ca7d48
Merge branch 'llvm:main' into gatherMultipleOccurrence
rohitaggarwal007 May 6, 2025
94f16e6
Code movement to if(IndexWidth > 32) condition
May 6, 2025
cfcf542
Merge branch 'llvm:main' into gatherMultipleOccurrence
rohitaggarwal007 May 12, 2025
7ab9924
Hoist the variable to reuse it
May 12, 2025
62fa5ad
Split the PR into two PRs.
May 13, 2025
ad2e419
Merge branch 'main' into gatherMultipleOccurrence
RKSimon May 13, 2025
88 changes: 69 additions & 19 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -56652,31 +56652,81 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
const TargetLowering &TLI = DAG.getTargetLoweringInfo();

if (DCI.isBeforeLegalize()) {
// Attempt to move shifted index into the address scale, allows further
// index truncation below.
if (Index.getOpcode() == ISD::SHL && isa<ConstantSDNode>(Scale)) {
unsigned BitWidth = Index.getScalarValueSizeInBits();
Collaborator

Hoist the equivalent IndexWidth above this and remove BitWidth

Contributor Author

Sure

Contributor Author

Done

unsigned ScaleAmt = Scale->getAsZExtVal();
assert(isPowerOf2_32(ScaleAmt) && "Scale must be a power of 2");
unsigned Log2ScaleAmt = Log2_32(ScaleAmt);
unsigned MaskBits = BitWidth - Log2ScaleAmt;
APInt DemandedBits = APInt::getLowBitsSet(BitWidth, MaskBits);
if (TLI.SimplifyDemandedBits(Index, DemandedBits, DCI)) {
if (N->getOpcode() != ISD::DELETED_NODE)
DCI.AddToWorklist(N);
return SDValue(N, 0);
}
if (auto MinShAmt = DAG.getValidMinimumShiftAmount(Index)) {
if (*MinShAmt >= 1 && (*MinShAmt + Log2ScaleAmt) < 4 &&
DAG.ComputeNumSignBits(Index.getOperand(0)) > 1) {
SDValue ShAmt = Index.getOperand(1);
SDValue NewShAmt =
DAG.getNode(ISD::SUB, DL, ShAmt.getValueType(), ShAmt,
DAG.getConstant(1, DL, ShAmt.getValueType()));
SDValue NewIndex = DAG.getNode(ISD::SHL, DL, Index.getValueType(),
Index.getOperand(0), NewShAmt);
SDValue NewScale =
DAG.getConstant(ScaleAmt * 2, DL, Scale.getValueType());
return rebuildGatherScatter(GorS, NewIndex, Base, NewScale, DAG);
}
}
}
unsigned IndexWidth = Index.getScalarValueSizeInBits();

// Shrink indices if they are larger than 32-bits.
// Only do this before legalize types since v2i64 could become v2i32.
// FIXME: We could check that the type is legal if we're after legalize
// types, but then we would need to construct test cases where that happens.
if (IndexWidth > 32 && DAG.ComputeNumSignBits(Index) > (IndexWidth - 32)) {
EVT NewVT = IndexVT.changeVectorElementType(MVT::i32);

// FIXME: We could support more than just constant fold, but we need to
// careful with costing. A truncate that can be optimized out would be
// fine. Otherwise we might only want to create a truncate if it avoids a
// split.
if (SDValue TruncIndex =
DAG.FoldConstantArithmetic(ISD::TRUNCATE, DL, NewVT, Index))
return rebuildGatherScatter(GorS, TruncIndex, Base, Scale, DAG);

// Shrink any sign/zero extends from 32 or smaller to larger than 32 if
// there are sufficient sign bits. Only do this before legalize types to
// avoid creating illegal types in truncate.
if ((Index.getOpcode() == ISD::SIGN_EXTEND ||
Index.getOpcode() == ISD::ZERO_EXTEND) &&
Index.getOperand(0).getScalarValueSizeInBits() <= 32) {
Index = DAG.getNode(ISD::TRUNCATE, DL, NewVT, Index);
return rebuildGatherScatter(GorS, Index, Base, Scale, DAG);
// \ComputeNumSignBits value is recomputed for the shift Index
if (IndexWidth > 32) {
// If the index is a left shift, \ComputeNumSignBits we are recomputing
// the number of sign bits from the shifted value. We are trying to enable
// the optimization in which we can shrink indices if they are larger than
// 32-bits. Using the existing fold techniques implemented below.
unsigned ComputeNumSignBits = DAG.ComputeNumSignBits(Index);
if (Index.getOpcode() == ISD::SHL) {
if (auto MinShAmt = DAG.getValidMinimumShiftAmount(Index)) {
if (DAG.ComputeNumSignBits(Index.getOperand(0)) > 1) {
ComputeNumSignBits += *MinShAmt;
}
}
}
Collaborator

I don't think we need this anymore

Contributor Author

@RKSimon, I do not fully understand the comment. Does it mean moving this outside the IndexWidth > 32 check, or removing this code snippet?

if (Index.getOpcode() == ISD::SHL) {
  if (auto MinShAmt = DAG.getValidMinimumShiftAmount(Index)) {
    if (DAG.ComputeNumSignBits(Index.getOperand(0)) > 1) {
      ComputeNumSignBits += *MinShAmt;
    }
  }
}

We cannot remove this code: it is required, and our test case fails without it.

Collaborator

In which case, please can you pull it out of this patch - I'm happy with the rest of it, but I'm concerned that this will only work under very specific conditions - better to pull it out and you can work on it in a follow up PR.

Contributor Author

When the SHL has a shift amount of 4 or greater, we recompute the number of sign bits so that we can truncate.

Sure, I will create a separate PR for this.

Contributor Author

Done, removed the code and created a separate PR.
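
For readers following this thread, here is a minimal scalar sketch of the check being discussed: a 64-bit index can be narrowed to 32 bits when it has more than IndexWidth - 32 sign bits, and a left shift consumes sign bits, which is why a shifted index can fail the check. The helper names numSignBits and fitsInI32 are hypothetical stand-ins, not LLVM APIs, and the sketch does not reproduce the removed MinShAmt adjustment.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical scalar stand-in for SelectionDAG::ComputeNumSignBits:
// counts the leading bits that equal the sign bit, including the sign bit.
static unsigned numSignBits(int64_t V) {
  uint64_t U = static_cast<uint64_t>(V);
  uint64_t Sign = U >> 63;
  unsigned N = 1;
  for (int Bit = 62; Bit >= 0 && ((U >> Bit) & 1) == Sign; --Bit)
    ++N;
  return N;
}

// Mirrors the shape of the condition in the diff:
//   ComputeNumSignBits(Index) > (IndexWidth - 32)
// Here IndexWidth is fixed at 64 for illustration.
static bool fitsInI32(int64_t V) {
  const unsigned IndexWidth = 64;
  return numSignBits(V) > (IndexWidth - 32);
}
```

In this model, INT32_MAX passes the check, while the same value shifted left by 4 does not, because the shift has used up four of its sign bits.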

if (ComputeNumSignBits > (IndexWidth - 32)) {
Collaborator

Probably better if you can move everything inside if (IndexWidth > 32)

Contributor Author

@RKSimon, Done. Moved the code inside the if (IndexWidth > 32)

EVT NewVT = IndexVT.changeVectorElementType(MVT::i32);

// FIXME: We could support more than just constant fold, but we need to
// careful with costing. A truncate that can be optimized out would be
// fine. Otherwise we might only want to create a truncate if it avoids
// a split.
if (SDValue TruncIndex =
DAG.FoldConstantArithmetic(ISD::TRUNCATE, DL, NewVT, Index))
return rebuildGatherScatter(GorS, TruncIndex, Base, Scale, DAG);

// Shrink any sign/zero extends from 32 or smaller to larger than 32 if
// there are sufficient sign bits. Only do this before legalize types to
// avoid creating illegal types in truncate.
if ((Index.getOpcode() == ISD::SIGN_EXTEND ||
Index.getOpcode() == ISD::ZERO_EXTEND) &&
Index.getOperand(0).getScalarValueSizeInBits() <= 32) {
Index = DAG.getNode(ISD::TRUNCATE, DL, NewVT, Index);
return rebuildGatherScatter(GorS, Index, Base, Scale, DAG);
}

// Shrink if we remove an illegal type.
if (!TLI.isTypeLegal(Index.getValueType()) && TLI.isTypeLegal(NewVT)) {
Index = DAG.getNode(ISD::TRUNCATE, DL, NewVT, Index);
return rebuildGatherScatter(GorS, Index, Base, Scale, DAG);
}
}
}
}
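
The shift-into-scale fold at the top of this hunk relies on the identity Base + (V << k) * S == Base + (V << (k - 1)) * (2 * S), applied only when the whole shift fits in an x86 scale of at most 8. The following is a scalar sketch of that idea under simplifying assumptions: LaneAddr, effectiveAddress, log2Scale and foldOneShiftBit are hypothetical names, the index is modelled as an unsigned 64-bit value, and the patch's additional ComputeNumSignBits guard is omitted.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical scalar model of one gather lane. The index is (Value << ShAmt)
// and the address is Base + Index * Scale, computed modulo 2^64.
struct LaneAddr {
  uint64_t Base;
  uint64_t Value;  // operand of the SHL that forms the index
  unsigned ShAmt;  // constant shift amount, assumed < 64
  unsigned Scale;  // x86 allows 1, 2, 4 or 8
};

static uint64_t effectiveAddress(const LaneAddr &A) {
  return A.Base + (A.Value << A.ShAmt) * A.Scale;
}

// Log2 of a power-of-two scale.
static unsigned log2Scale(unsigned Scale) {
  unsigned Log2 = 0;
  while ((1u << Log2) < Scale)
    ++Log2;
  return Log2;
}

// Move one bit of shift into the scale, but only when the entire shift would
// still fit in a scale of at most 8, i.e. ShAmt + log2(Scale) < 4.
static bool foldOneShiftBit(LaneAddr &A) {
  if (A.ShAmt < 1 || A.ShAmt + log2Scale(A.Scale) >= 4)
    return false;
  A.ShAmt -= 1;
  A.Scale *= 2;
  return true;
}
```

Calling foldOneShiftBit repeatedly on a lane with ShAmt 2 and Scale 1 ends with ShAmt 0 and Scale 4 and leaves the effective address unchanged; a lane with ShAmt 3 and Scale 2 is left alone because the full shift would not fit. Repeating the step in a loop is an assumption about how the combiner revisits the rebuilt node.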