-
Notifications
You must be signed in to change notification settings - Fork 13.6k
Update the base and index value for masked gather #130920
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 19 commits
72b0f4b
a231c96
1ff621f
85d2e0e
2f0897b
cdb181d
dd8762a
07dd191
7d840ed
1bd64b8
b00f0a9
2268967
8aeeb31
f4e8b0c
49f084e
a7a52cd
6252789
8565941
5eecf46
ad31491
ba2f9e7
2e344a1
87e2533
f516be2
c2848c2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -126,6 +126,11 @@ static cl::opt<bool> MulConstantOptimization( | |
"SHIFT, LEA, etc."), | ||
cl::Hidden); | ||
|
||
static cl::opt<bool> | ||
EnableBaseIndexUpdate("update-baseIndex", cl::init(true), | ||
cl::desc("Update the value of base and index"), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Name and description don't mean anything standalone There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure, I will update the information |
||
cl::Hidden); | ||
|
||
X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, | ||
const X86Subtarget &STI) | ||
: TargetLowering(TM), Subtarget(STI) { | ||
|
@@ -56370,6 +56375,112 @@ static SDValue rebuildGatherScatter(MaskedGatherScatterSDNode *GorS, | |
Scatter->isTruncatingStore()); | ||
} | ||
|
||
// Target override this function to decide whether it want to update the base | ||
// and index value of a non-uniform gep | ||
static bool updateBaseAndIndex(SDValue &Base, SDValue &Index, const SDLoc &DL, | ||
const SDValue &Gep, SelectionDAG &DAG) { | ||
if (!EnableBaseIndexUpdate) | ||
return false; | ||
|
||
SDValue Nbase; | ||
SDValue Nindex; | ||
bool Changed = false; | ||
// This function check the opcode of Index and update the index | ||
auto checkAndUpdateIndex = [&](SDValue &Idx) { | ||
if (Idx.getOpcode() == ISD::SHL) { // shl zext, BV | ||
SDValue Op10 = Idx.getOperand(0); // Zext or Sext value | ||
SDValue Op11 = Idx.getOperand(1); // Build vector of constant | ||
|
||
unsigned IndexWidth = Op10.getScalarValueSizeInBits(); | ||
if ((Op10.getOpcode() == ISD::SIGN_EXTEND || | ||
Op10.getOpcode() == ISD::ZERO_EXTEND) && | ||
IndexWidth > 32 && | ||
Op10.getOperand(0).getScalarValueSizeInBits() <= 32 && | ||
DAG.ComputeNumSignBits(Op10) > (IndexWidth - 32) && | ||
Op11.getOpcode() == ISD::BUILD_VECTOR) { | ||
|
||
KnownBits ExtKnown = DAG.computeKnownBits(Op10); | ||
bool ExtIsNonNegative = ExtKnown.isNonNegative(); | ||
KnownBits ExtOpKnown = DAG.computeKnownBits(Op10.getOperand(0)); | ||
bool ExtOpIsNonNegative = ExtOpKnown.isNonNegative(); | ||
if (!ExtIsNonNegative || !ExtOpIsNonNegative) | ||
return false; | ||
|
||
SDValue NewOp10 = | ||
Op10.getOperand(0); // Get the Operand zero from the ext | ||
EVT VT = NewOp10.getValueType(); // Use the operand's type to determine | ||
// the type of index | ||
|
||
auto *ConstEltNo = dyn_cast<ConstantSDNode>(Op11.getOperand(0)); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is assuming that the shl is uniform - use getValidMinimumShiftAmount instead (replaces the BUILD_VECTOR check above as well) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ok, my understanding is this
to
Please correct me |
||
if (!ConstEltNo) | ||
return false; | ||
|
||
SmallVector<SDValue, 8> Ops(VT.getVectorNumElements(), | ||
DAG.getConstant(ConstEltNo->getZExtValue(), | ||
DL, VT.getScalarType())); | ||
Nindex = DAG.getNode(ISD::SHL, DL, VT, NewOp10, | ||
DAG.getBuildVector(VT, DL, Ops)); | ||
return true; | ||
} | ||
} | ||
return false; | ||
}; | ||
|
||
// For the gep instruction, we are trying to properly assign the base and | ||
// index value We are go through the lower code and iterate backward. | ||
if (isNullConstant(Base) && Gep.getOpcode() == ISD::ADD) { | ||
SDValue Op0 = Gep.getOperand(0); // base or add | ||
SDValue Op1 = Gep.getOperand(1); // build vector or SHL | ||
Nbase = Op0; | ||
SDValue Idx = Op1; | ||
auto Flags = Gep->getFlags(); | ||
|
||
if (Op0->getOpcode() == ISD::ADD) { // add t15(base), t18(Idx) | ||
SDValue Op00 = Op0.getOperand(0); // Base | ||
Nbase = Op00; | ||
Idx = Op0.getOperand(1); | ||
} else if (!(Op0->getOpcode() == ISD::BUILD_VECTOR && | ||
Op0.getOperand(0).getOpcode() == ISD::CopyFromReg)) { | ||
return false; | ||
} | ||
if (!checkAndUpdateIndex(Idx)) { | ||
return false; | ||
} | ||
if (Op0 != Nbase) { | ||
auto *ConstEltNo = dyn_cast<ConstantSDNode>(Op1.getOperand(0)); | ||
if (!ConstEltNo) | ||
return false; | ||
|
||
SmallVector<SDValue, 8> Ops( | ||
Nindex.getValueType().getVectorNumElements(), | ||
DAG.getConstant(ConstEltNo->getZExtValue(), DL, | ||
Nindex.getValueType().getScalarType())); | ||
Nindex = DAG.getNode(ISD::ADD, DL, Nindex.getValueType(), Nindex, | ||
DAG.getBuildVector(Nindex.getValueType(), DL, Ops), | ||
Flags); | ||
} | ||
Base = Nbase.getOperand(0); | ||
Index = Nindex; | ||
Changed = true; | ||
} else if (Base.getOpcode() == ISD::CopyFromReg || | ||
(Base.getOpcode() == ISD::ADD && | ||
Base.getOperand(0).getOpcode() == ISD::CopyFromReg && | ||
isConstOrConstSplat(Base.getOperand(1)))) { | ||
if (checkAndUpdateIndex(Index)) { | ||
Index = Nindex; | ||
Changed = true; | ||
} | ||
} | ||
if (Changed) { | ||
LLVM_DEBUG(dbgs() << "Successful in updating the non uniform gep " | ||
"information\n"; | ||
dbgs() << "updated base "; Base.dump(); | ||
dbgs() << "updated Index "; Index.dump();); | ||
return true; | ||
} | ||
return false; | ||
} | ||
|
||
static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG, | ||
TargetLowering::DAGCombinerInfo &DCI) { | ||
SDLoc DL(N); | ||
|
@@ -56380,6 +56491,8 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG, | |
const TargetLowering &TLI = DAG.getTargetLoweringInfo(); | ||
|
||
if (DCI.isBeforeLegalize()) { | ||
if (updateBaseAndIndex(Base, Index, DL, Index, DAG)) | ||
return rebuildGatherScatter(GorS, Index, Base, Scale, DAG); | ||
unsigned IndexWidth = Index.getScalarValueSizeInBits(); | ||
|
||
// Shrink constant indices if they are larger than 32-bits. | ||
|
@@ -56475,7 +56588,6 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG, | |
return SDValue(N, 0); | ||
} | ||
} | ||
|
||
return SDValue(); | ||
} | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5 | ||
; RUN: llc -mtriple=x86_64-unknown-unknown -mattr=+avx512f,+avx512bw,+avx512vl,+avx512dq -mcpu=znver5 < %s | FileCheck %s | ||
; RUN: llc -update-baseIndex -mtriple=x86_64-unknown-unknown -mattr=+avx512f,+avx512bw,+avx512vl,+avx512dq -mcpu=znver5 < %s | FileCheck %s | ||
; RUN: llc -update-baseIndex=false -mtriple=x86_64-unknown-unknown -mattr=+avx512f,+avx512bw,+avx512vl,+avx512dq -mcpu=znver5 < %s | FileCheck %s -check-prefix=OLD | ||
|
||
%struct.pt = type { float, float, float, i32 } | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this test needs to be reduced down a lot - remove all the unnecessary attributes etc. and just contain the minimal IR to show the problem There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @RKSimon, updated the test case There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks - now you should be able to use the utils/update_llc_test_checks.py script to generate more thorough CHECK lines without too much noise. |
||
|
||
define <16 x float> @test_gather_16f32_1(ptr %x, ptr %arr, <16 x i1> %mask, <16 x float> %src0) { | ||
; CHECK-LABEL: test_gather_16f32_1: | ||
; CHECK: # %bb.0: | ||
; CHECK-NEXT: vpslld $4, (%rsi), %zmm2 | ||
; CHECK-NEXT: vpsllw $7, %xmm0, %xmm0 | ||
; CHECK-NEXT: vpmovb2m %xmm0, %k1 | ||
; CHECK-NEXT: vgatherdps (%rdi,%zmm2), %zmm1 {%k1} | ||
; CHECK-NEXT: vmovaps %zmm1, %zmm0 | ||
; CHECK-NEXT: retq | ||
; | ||
; OLD-LABEL: test_gather_16f32_1: | ||
; OLD: # %bb.0: | ||
; OLD-NEXT: vpsllw $7, %xmm0, %xmm0 | ||
; OLD-NEXT: vmovdqu64 (%rsi), %zmm4 | ||
; OLD-NEXT: vextractf64x4 $1, %zmm1, %ymm3 | ||
; OLD-NEXT: vpmovb2m %xmm0, %k1 | ||
; OLD-NEXT: vpandd {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to16}, %zmm4, %zmm0 | ||
; OLD-NEXT: kshiftrw $8, %k1, %k2 | ||
; OLD-NEXT: vextracti64x4 $1, %zmm0, %ymm2 | ||
; OLD-NEXT: vpmovzxdq {{.*#+}} zmm0 = ymm0[0],zero,ymm0[1],zero,ymm0[2],zero,ymm0[3],zero,ymm0[4],zero,ymm0[5],zero,ymm0[6],zero,ymm0[7],zero | ||
; OLD-NEXT: vpmovzxdq {{.*#+}} zmm2 = ymm2[0],zero,ymm2[1],zero,ymm2[2],zero,ymm2[3],zero,ymm2[4],zero,ymm2[5],zero,ymm2[6],zero,ymm2[7],zero | ||
; OLD-NEXT: vpsllq $4, %zmm0, %zmm0 | ||
; OLD-NEXT: vpsllq $4, %zmm2, %zmm2 | ||
; OLD-NEXT: vgatherqps (%rdi,%zmm0), %ymm1 {%k1} | ||
; OLD-NEXT: vgatherqps (%rdi,%zmm2), %ymm3 {%k2} | ||
; OLD-NEXT: vinsertf64x4 $1, %ymm3, %zmm1, %zmm0 | ||
; OLD-NEXT: retq | ||
%wide.load = load <16 x i32>, ptr %arr, align 4 | ||
%and = and <16 x i32> %wide.load, <i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911> | ||
%zext = zext <16 x i32> %and to <16 x i64> | ||
%ptrs = getelementptr inbounds %struct.pt, ptr %x, <16 x i64> %zext | ||
%res = call <16 x float> @llvm.masked.gather.v16f32.v16p0(<16 x ptr> %ptrs, i32 4, <16 x i1> %mask, <16 x float> %src0) | ||
ret <16 x float> %res | ||
} | ||
|
||
define <16 x float> @test_gather_16f32_2(ptr %x, ptr %arr, <16 x i1> %mask, <16 x float> %src0) { | ||
; CHECK-LABEL: test_gather_16f32_2: | ||
; CHECK: # %bb.0: | ||
; CHECK-NEXT: vpslld $4, (%rsi), %zmm2 | ||
; CHECK-NEXT: vpsllw $7, %xmm0, %xmm0 | ||
; CHECK-NEXT: vpmovb2m %xmm0, %k1 | ||
; CHECK-NEXT: vgatherdps 4(%rdi,%zmm2), %zmm1 {%k1} | ||
; CHECK-NEXT: vmovaps %zmm1, %zmm0 | ||
; CHECK-NEXT: retq | ||
; | ||
; OLD-LABEL: test_gather_16f32_2: | ||
; OLD: # %bb.0: | ||
; OLD-NEXT: vpsllw $7, %xmm0, %xmm0 | ||
; OLD-NEXT: vmovdqu64 (%rsi), %zmm4 | ||
; OLD-NEXT: vextractf64x4 $1, %zmm1, %ymm3 | ||
; OLD-NEXT: vpmovb2m %xmm0, %k1 | ||
; OLD-NEXT: vpandd {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to16}, %zmm4, %zmm0 | ||
; OLD-NEXT: kshiftrw $8, %k1, %k2 | ||
; OLD-NEXT: vextracti64x4 $1, %zmm0, %ymm2 | ||
; OLD-NEXT: vpmovzxdq {{.*#+}} zmm0 = ymm0[0],zero,ymm0[1],zero,ymm0[2],zero,ymm0[3],zero,ymm0[4],zero,ymm0[5],zero,ymm0[6],zero,ymm0[7],zero | ||
; OLD-NEXT: vpmovzxdq {{.*#+}} zmm2 = ymm2[0],zero,ymm2[1],zero,ymm2[2],zero,ymm2[3],zero,ymm2[4],zero,ymm2[5],zero,ymm2[6],zero,ymm2[7],zero | ||
; OLD-NEXT: vpsllq $4, %zmm0, %zmm0 | ||
; OLD-NEXT: vpsllq $4, %zmm2, %zmm2 | ||
; OLD-NEXT: vgatherqps 4(%rdi,%zmm0), %ymm1 {%k1} | ||
; OLD-NEXT: vgatherqps 4(%rdi,%zmm2), %ymm3 {%k2} | ||
; OLD-NEXT: vinsertf64x4 $1, %ymm3, %zmm1, %zmm0 | ||
; OLD-NEXT: retq | ||
%wide.load = load <16 x i32>, ptr %arr, align 4 | ||
%and = and <16 x i32> %wide.load, <i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911> | ||
%zext = zext <16 x i32> %and to <16 x i64> | ||
%ptrs = getelementptr inbounds %struct.pt, ptr %x, <16 x i64> %zext, i32 1 | ||
%res = call <16 x float> @llvm.masked.gather.v16f32.v16p0(<16 x ptr> %ptrs, i32 4, <16 x i1> %mask, <16 x float> %src0) | ||
ret <16 x float> %res | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does not follow command line option naming convention
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also this is the kind of option that will never be used by anyone. Can you just do this unconditionally
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I just added this flag for the safety incase any untest usecase occur. So that we can fall back to existing behavior.
Sure, I will remove the flag.