Update the base and index value for masked gather #130920
Conversation
@llvm/pr-subscribers-backend-x86 @llvm/pr-subscribers-llvm-selectiondag

Author: Rohit Aggarwal (rohitaggarwal007)

Changes: While lowering a masked gather on X86, the gather was being split into two gather instructions. The type-legalization phase in SelectionDAG found that the index type is v16i64, and since a ZMM register cannot hold 16 x 64-bit values, it decided to split the operand to a vector factor of 8, i.e. v8i64. This led to poor codegen and loss of the vectorization benefit. I have recalculated the index and base of the getelementptr which has a non-uniform base. The recalculated values are then used by DAG.getMaskedGather to create the instruction.

@RKSimon @phoebewang Please review the PR

Patch is 22.51 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/130920.diff

6 Files Affected:
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 2089d47e9cbc8..46b28b7f5813d 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -5127,6 +5127,13 @@ class TargetLowering : public TargetLoweringBase {
SmallVectorImpl<SDValue> &Ops,
SelectionDAG &DAG) const;
+ // Target may override this function to decided whether it want to update the
+ // base and index value of a non-uniform gep
+ virtual bool updateBaseAndIndex(const Value *Ptr, SDValue &Base,
+ SDValue &Index, const SDLoc &DL,
+ const SDValue &Gep, SelectionDAG &DAG,
+ const BasicBlock *CurBB) const;
+
//===--------------------------------------------------------------------===//
// Div utility functions
//
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 14bb1d943d2d6..60d50ed735e09 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -4905,6 +4905,11 @@ void SelectionDAGBuilder::visitMaskedScatter(const CallInst &I) {
Scale = DAG.getTargetConstant(1, sdl, TLI.getPointerTy(DAG.getDataLayout()));
}
+ if (!UniformBase) {
+ TLI.updateBaseAndIndex(Ptr, Base, Index, getCurSDLoc(), getValue(Ptr), DAG,
+ I.getParent());
+ }
+
EVT IdxVT = Index.getValueType();
EVT EltTy = IdxVT.getVectorElementType();
if (TLI.shouldExtendGSIndex(IdxVT, EltTy)) {
@@ -5024,6 +5029,11 @@ void SelectionDAGBuilder::visitMaskedGather(const CallInst &I) {
Scale = DAG.getTargetConstant(1, sdl, TLI.getPointerTy(DAG.getDataLayout()));
}
+ if (!UniformBase) {
+ TLI.updateBaseAndIndex(Ptr, Base, Index, getCurSDLoc(), getValue(Ptr), DAG,
+ I.getParent());
+ }
+
EVT IdxVT = Index.getValueType();
EVT EltTy = IdxVT.getVectorElementType();
if (TLI.shouldExtendGSIndex(IdxVT, EltTy)) {
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index bd72718c49031..eb2ac6044cb6b 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -5655,6 +5655,14 @@ void TargetLowering::CollectTargetIntrinsicOperands(
const CallInst &I, SmallVectorImpl<SDValue> &Ops, SelectionDAG &DAG) const {
}
+// By default, this function is disabled. Overriding target can enable it
+bool TargetLowering::updateBaseAndIndex(const Value *Ptr, SDValue &Base,
+ SDValue &Index, const SDLoc &DL,
+ const SDValue &Gep, SelectionDAG &DAG,
+ const BasicBlock *CurBB) const {
+ return false;
+}
+
std::pair<unsigned, const TargetRegisterClass *>
TargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *RI,
StringRef Constraint,
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 24e5d8bfc404c..528e823abe01b 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -126,6 +126,11 @@ static cl::opt<bool> MulConstantOptimization(
"SHIFT, LEA, etc."),
cl::Hidden);
+static cl::opt<bool>
+ EnableBaseIndexUpdate("update-baseIndex", cl::init(true),
+ cl::desc("Update the value of base and index"),
+ cl::Hidden);
+
X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
const X86Subtarget &STI)
: TargetLowering(TM), Subtarget(STI) {
@@ -61619,3 +61624,99 @@ Align X86TargetLowering::getPrefLoopAlignment(MachineLoop *ML) const {
return Align(1ULL << ExperimentalPrefInnermostLoopAlignment);
return TargetLowering::getPrefLoopAlignment();
}
+
+// Target override this function to decided whether it want to update the base
+// and index value of a non-uniform gep
+bool X86TargetLowering::updateBaseAndIndex(const Value *Ptr, SDValue &Base,
+ SDValue &Index, const SDLoc &DL,
+ const SDValue &Gep,
+ SelectionDAG &DAG,
+ const BasicBlock *CurBB) const {
+ if (!EnableBaseIndexUpdate)
+ return false;
+
+ const GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
+ if (GEP && GEP->getParent() != CurBB)
+ return false;
+
+ SDValue nbase;
+ /* For the gep instruction, we are trying to properly assign the base and
+ index value We are go through the lower code and iterate backward.
+ */
+ if (Gep.getOpcode() == ISD::ADD) {
+ SDValue Op0 = Gep.getOperand(0); // base or add
+ SDValue Op1 = Gep.getOperand(1); // build vector or SHL
+ nbase = Op0;
+ SDValue Idx = Op1;
+ auto Flags = Gep->getFlags();
+
+ if (Op0->getOpcode() == ISD::ADD) { // add t15(base), t18(Idx)
+ SDValue Op00 = Op0.getOperand(0); // Base
+ nbase = Op00;
+ Idx = Op0.getOperand(1);
+ } else if (!(Op0->getOpcode() == ISD::BUILD_VECTOR &&
+ Op0.getOperand(0).getOpcode() == ISD::CopyFromReg)) {
+ return false;
+ }
+ SDValue nIndex;
+ if (Idx.getOpcode() == ISD::SHL) { // shl zext, BV
+ SDValue Op10 = Idx.getOperand(0); // Zext or Sext value
+ SDValue Op11 = Idx.getOperand(1); // Build vector of constant
+
+ unsigned IndexWidth = Op10.getScalarValueSizeInBits();
+ if ((Op10.getOpcode() == ISD::SIGN_EXTEND ||
+ Op10.getOpcode() == ISD::ZERO_EXTEND) &&
+ IndexWidth > 32 &&
+ Op10.getOperand(0).getScalarValueSizeInBits() <= 32 &&
+ DAG.ComputeNumSignBits(Op10) > (IndexWidth - 32) &&
+ Op11.getOpcode() == ISD::BUILD_VECTOR) {
+
+ KnownBits ExtKnown = DAG.computeKnownBits(Op10);
+ bool ExtIsNonNegative = ExtKnown.isNonNegative();
+ KnownBits ExtOpKnown = DAG.computeKnownBits(Op10.getOperand(0));
+ bool ExtOpIsNonNegative = ExtOpKnown.isNonNegative();
+ if (!(ExtIsNonNegative && ExtOpIsNonNegative))
+ return false;
+
+ SDValue newOp10 =
+ Op10.getOperand(0); // Get the Operand zero from the ext
+ EVT VT = newOp10.getValueType(); // Use the
+
+ auto *ConstEltNo = dyn_cast<ConstantSDNode>(Op11.getOperand(0));
+ if (!ConstEltNo) {
+ return false;
+ }
+ SmallVector<SDValue, 8> Ops(VT.getVectorNumElements(),
+ DAG.getConstant(ConstEltNo->getZExtValue(),
+ DL, VT.getScalarType()));
+ nIndex = DAG.getNode(ISD::SHL, DL, VT, newOp10,
+ DAG.getBuildVector(VT, DL, Ops));
+ } else {
+ return false;
+ }
+ } else {
+ return false;
+ }
+ if (Op0 != nbase) {
+ auto *ConstEltNo = dyn_cast<ConstantSDNode>(Op1.getOperand(0));
+ if (!ConstEltNo) {
+ return false;
+ }
+ SmallVector<SDValue, 8> Ops(
+ nIndex.getValueType().getVectorNumElements(),
+ DAG.getConstant(ConstEltNo->getZExtValue(), DL,
+ nIndex.getValueType().getScalarType()));
+ nIndex = DAG.getNode(ISD::ADD, DL, nIndex.getValueType(), nIndex,
+ DAG.getBuildVector(nIndex.getValueType(), DL, Ops),
+ Flags);
+ }
+ Base = nbase.getOperand(0);
+ Index = nIndex;
+ LLVM_DEBUG(dbgs() << "Successfull in updating the non uniform gep "
+ "information\n";
+ dbgs() << "updated base "; Base.dump();
+ dbgs() << "updated Index "; Index.dump(););
+ return true;
+ }
+ return false;
+}
diff --git a/llvm/lib/Target/X86/X86ISelLowering.h b/llvm/lib/Target/X86/X86ISelLowering.h
index 4a2b35e9efe7c..c092055329c58 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.h
+++ b/llvm/lib/Target/X86/X86ISelLowering.h
@@ -1671,6 +1671,13 @@ namespace llvm {
return TargetLoweringBase::getTypeToTransformTo(Context, VT);
}
+ // Target override this function to decided whether it want to update the
+ // base and index value of a non-uniform gep
+ bool updateBaseAndIndex(const Value *Ptr, SDValue &Base, SDValue &Index,
+ const SDLoc &DL, const SDValue &Gep,
+ SelectionDAG &DAG,
+ const BasicBlock *CurBB) const override;
+
protected:
std::pair<const TargetRegisterClass *, uint8_t>
findRepresentativeClass(const TargetRegisterInfo *TRI,
diff --git a/llvm/test/CodeGen/X86/gatherBaseIndexFix.ll b/llvm/test/CodeGen/X86/gatherBaseIndexFix.ll
new file mode 100644
index 0000000000000..ae24188ece58d
--- /dev/null
+++ b/llvm/test/CodeGen/X86/gatherBaseIndexFix.ll
@@ -0,0 +1,210 @@
+; RUN: llc -mtriple=x86_64-unknown-unknown -mattr=+avx512f,+avx512bw,+avx512vl,+avx512dq -mcpu=znver5 < %s | FileCheck %s
+; RUN: llc -update-baseIndex -mtriple=x86_64-unknown-unknown -mattr=+avx512f,+avx512bw,+avx512vl,+avx512dq -mcpu=znver5 < %s | FileCheck %s
+; RUN: llc -update-baseIndex=false -mtriple=x86_64-unknown-unknown -mattr=+avx512f,+avx512bw,+avx512vl,+avx512dq -mcpu=znver5 < %s | FileCheck %s -check-prefix=OLD
+
+; ModuleID = 'qwdemo.c'
+source_filename = "qwdemo.c"
+target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
+target triple = "x86_64-unknown-linux-gnu"
+
+%struct.pt = type { float, float, float, i32 }
+
+; Function Attrs: nofree nosync nounwind memory(argmem: readwrite) uwtable
+define dso_local i32 @foo(float noundef %cut_coulsq, ptr noalias nocapture noundef readonly %jlist, i32 noundef %jnum, ptr noalias nocapture noundef readonly %x, ptr noalias nocapture noundef writeonly %trsq, ptr noalias nocapture noundef writeonly %tdelx, ptr noalias nocapture noundef writeonly %tdely, ptr noalias nocapture noundef writeonly %tdelz, ptr noalias nocapture noundef writeonly %tjtype, ptr noalias nocapture noundef writeonly %tj, ptr noalias nocapture noundef readnone %tx, ptr noalias nocapture noundef readnone %ty, ptr noalias nocapture noundef readnone %tz) local_unnamed_addr #0 {
+entry:
+ %0 = load float, ptr %x, align 4, !tbaa !5
+ %y = getelementptr inbounds %struct.pt, ptr %x, i64 0, i32 1
+ %1 = load float, ptr %y, align 4, !tbaa !11
+ %z = getelementptr inbounds %struct.pt, ptr %x, i64 0, i32 2
+ %2 = load float, ptr %z, align 4, !tbaa !12
+ %cmp62 = icmp sgt i32 %jnum, 0
+ br i1 %cmp62, label %for.body.preheader, label %for.cond.cleanup
+
+for.body.preheader: ; preds = %entry
+ %wide.trip.count = zext i32 %jnum to i64
+ %min.iters.check = icmp ult i32 %jnum, 16
+ br i1 %min.iters.check, label %for.body.preheader75, label %vector.ph
+
+vector.ph: ; preds = %for.body.preheader
+ %n.vec = and i64 %wide.trip.count, 4294967280
+ %broadcast.splatinsert = insertelement <16 x float> poison, float %0, i64 0
+ %broadcast.splat = shufflevector <16 x float> %broadcast.splatinsert, <16 x float> poison, <16 x i32> zeroinitializer
+ %broadcast.splatinsert67 = insertelement <16 x float> poison, float %1, i64 0
+ %broadcast.splat68 = shufflevector <16 x float> %broadcast.splatinsert67, <16 x float> poison, <16 x i32> zeroinitializer
+ %broadcast.splatinsert70 = insertelement <16 x float> poison, float %2, i64 0
+ %broadcast.splat71 = shufflevector <16 x float> %broadcast.splatinsert70, <16 x float> poison, <16 x i32> zeroinitializer
+ %broadcast.splatinsert72 = insertelement <16 x float> poison, float %cut_coulsq, i64 0
+ %broadcast.splat73 = shufflevector <16 x float> %broadcast.splatinsert72, <16 x float> poison, <16 x i32> zeroinitializer
+ br label %vector.body
+
+; CHECK-LABEL: .LBB0_6:
+; CHECK: vgatherdps (%rdx,%zmm12), %zmm13 {%k1}
+; CHECK: vgatherdps (%rdx,%zmm14), %zmm15 {%k1}
+; CHECK: vgatherdps (%rdx,%zmm17), %zmm16 {%k1}
+
+; OLD-LABEL: .LBB0_6:
+; OLD: vgatherqps (%rdx,%zmm12), %ymm15 {%k1}
+; OLD: vgatherqps (%rdx,%zmm11), %ymm12 {%k1}
+; OLD: vgatherqps 4(,%zmm14), %ymm12 {%k1}
+; OLD: vgatherqps 4(,%zmm13), %ymm15 {%k1}
+; OLD: vgatherqps 8(,%zmm14), %ymm15 {%k1}
+; OLD: vgatherqps 8(,%zmm13), %ymm16 {%k1}
+
+vector.body: ; preds = %vector.body, %vector.ph
+ %index = phi i64 [ 0, %vector.ph ], [ %index.next, %vector.body ]
+ %pred.index = phi i32 [ 0, %vector.ph ], [ %predphi, %vector.body ]
+ %3 = getelementptr inbounds i32, ptr %jlist, i64 %index
+ %wide.load = load <16 x i32>, ptr %3, align 4, !tbaa !13
+ %4 = and <16 x i32> %wide.load, <i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911>
+ %5 = zext <16 x i32> %4 to <16 x i64>
+ %6 = getelementptr inbounds %struct.pt, ptr %x, <16 x i64> %5
+ %wide.masked.gather = tail call <16 x float> @llvm.masked.gather.v16f32.v16p0(<16 x ptr> %6, i32 4, <16 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>, <16 x float> poison), !tbaa !5
+ %7 = fsub <16 x float> %broadcast.splat, %wide.masked.gather
+ %8 = getelementptr inbounds %struct.pt, ptr %x, <16 x i64> %5, i32 1
+ %wide.masked.gather66 = tail call <16 x float> @llvm.masked.gather.v16f32.v16p0(<16 x ptr> %8, i32 4, <16 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>, <16 x float> poison), !tbaa !11
+ %9 = fsub <16 x float> %broadcast.splat68, %wide.masked.gather66
+ %10 = getelementptr inbounds %struct.pt, ptr %x, <16 x i64> %5, i32 2
+ %wide.masked.gather69 = tail call <16 x float> @llvm.masked.gather.v16f32.v16p0(<16 x ptr> %10, i32 4, <16 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>, <16 x float> poison), !tbaa !12
+ %11 = fsub <16 x float> %broadcast.splat71, %wide.masked.gather69
+ %12 = fmul <16 x float> %9, %9
+ %13 = tail call <16 x float> @llvm.fmuladd.v16f32(<16 x float> %7, <16 x float> %7, <16 x float> %12)
+ %14 = tail call <16 x float> @llvm.fmuladd.v16f32(<16 x float> %11, <16 x float> %11, <16 x float> %13)
+ %15 = fcmp olt <16 x float> %14, %broadcast.splat73
+ %16 = sext i32 %pred.index to i64
+ %17 = getelementptr float, ptr %trsq, i64 %16
+ tail call void @llvm.masked.compressstore.v16f32(<16 x float> %14, ptr %17, <16 x i1> %15), !tbaa !14
+ %18 = getelementptr float, ptr %tdelx, i64 %16
+ tail call void @llvm.masked.compressstore.v16f32(<16 x float> %7, ptr %18, <16 x i1> %15), !tbaa !14
+ %19 = getelementptr float, ptr %tdely, i64 %16
+ tail call void @llvm.masked.compressstore.v16f32(<16 x float> %9, ptr %19, <16 x i1> %15), !tbaa !14
+ %20 = getelementptr float, ptr %tdelz, i64 %16
+ tail call void @llvm.masked.compressstore.v16f32(<16 x float> %11, ptr %20, <16 x i1> %15), !tbaa !14
+ %21 = getelementptr inbounds %struct.pt, ptr %x, <16 x i64> %5, i32 3
+ %wide.masked.gather74 = tail call <16 x i32> @llvm.masked.gather.v16i32.v16p0(<16 x ptr> %21, i32 4, <16 x i1> %15, <16 x i32> poison), !tbaa !15
+ %22 = getelementptr i32, ptr %tjtype, i64 %16
+ tail call void @llvm.masked.compressstore.v16i32(<16 x i32> %wide.masked.gather74, ptr %22, <16 x i1> %15), !tbaa !13
+ %23 = getelementptr i32, ptr %tj, i64 %16
+ tail call void @llvm.masked.compressstore.v16i32(<16 x i32> %wide.load, ptr %23, <16 x i1> %15), !tbaa !13
+ %24 = bitcast <16 x i1> %15 to i16
+ %mask.popcnt = tail call i16 @llvm.ctpop.i16(i16 %24), !range !16
+ %popcnt.cmp.not = icmp eq i16 %24, 0
+ %narrow = select i1 %popcnt.cmp.not, i16 0, i16 %mask.popcnt
+ %popcnt.inc = zext i16 %narrow to i32
+ %predphi = add i32 %pred.index, %popcnt.inc
+ %index.next = add nuw i64 %index, 16
+ %25 = icmp eq i64 %index.next, %n.vec
+ br i1 %25, label %middle.block, label %vector.body, !llvm.loop !17
+
+middle.block: ; preds = %vector.body
+ %cmp.n = icmp eq i64 %n.vec, %wide.trip.count
+ br i1 %cmp.n, label %for.cond.cleanup, label %for.body.preheader75
+
+for.body.preheader75: ; preds = %for.body.preheader, %middle.block
+ %indvars.iv.ph = phi i64 [ 0, %for.body.preheader ], [ %n.vec, %middle.block ]
+ %ej.064.ph = phi i32 [ 0, %for.body.preheader ], [ %predphi, %middle.block ]
+ br label %for.body
+
+for.cond.cleanup: ; preds = %if.end, %middle.block, %entry
+ %ej.0.lcssa = phi i32 [ 0, %entry ], [ %predphi, %middle.block ], [ %ej.1, %if.end ]
+ ret i32 %ej.0.lcssa
+
+for.body: ; preds = %for.body.preheader75, %if.end
+ %indvars.iv = phi i64 [ %indvars.iv.next, %if.end ], [ %indvars.iv.ph, %for.body.preheader75 ]
+ %ej.064 = phi i32 [ %ej.1, %if.end ], [ %ej.064.ph, %for.body.preheader75 ]
+ %arrayidx4 = getelementptr inbounds i32, ptr %jlist, i64 %indvars.iv
+ %26 = load i32, ptr %arrayidx4, align 4, !tbaa !13
+ %and = and i32 %26, 536870911
+ %idxprom5 = zext i32 %and to i64
+ %arrayidx6 = getelementptr inbounds %struct.pt, ptr %x, i64 %idxprom5
+ %27 = load float, ptr %arrayidx6, align 4, !tbaa !5
+ %sub = fsub float %0, %27
+ %y10 = getelementptr inbounds %struct.pt, ptr %x, i64 %idxprom5, i32 1
+ %28 = load float, ptr %y10, align 4, !tbaa !11
+ %sub11 = fsub float %1, %28
+ %z14 = getelementptr inbounds %struct.pt, ptr %x, i64 %idxprom5, i32 2
+ %29 = load float, ptr %z14, align 4, !tbaa !12
+ %sub15 = fsub float %2, %29
+ %mul16 = fmul float %sub11, %sub11
+ %30 = tail call float @llvm.fmuladd.f32(float %sub, float %sub, float %mul16)
+ %31 = tail call float @llvm.fmuladd.f32(float %sub15, float %sub15, float %30)
+ %cmp17 = fcmp olt float %31, %cut_coulsq
+ br i1 %cmp17, label %if.then, label %if.end
+
+if.then: ; preds = %for.body
+ %idxprom18 = sext i32 %ej.064 to i64
+ %arrayidx19 = getelementptr inbounds float, ptr %trsq, i64 %idxprom18
+ store float %31, ptr %arrayidx19, align 4, !tbaa !14
+ %arrayidx21 = getelementptr inbounds float, ptr %tdelx, i64 %idxprom18
+ store float %sub, ptr %arrayidx21, align 4, !tbaa !14
+ %arrayidx23 = getelementptr inbounds float, ptr %tdely, i64 %idxprom18
+ store float %sub11, ptr %arrayidx23, align 4, !tbaa !14
+ %arrayidx25 = getelementptr inbounds float, ptr %tdelz, i64 %idxprom18
+ store float %sub15, ptr %arrayidx25, align 4, !tbaa !14
+ %w = getelementptr inbounds %struct.pt, ptr %x, i64 %idxprom5, i32 3
+ %32 = load i32, ptr %w, align 4, !tbaa !15
+ %arrayidx29 = getelementptr inbounds i32, ptr %tjtype, i64 %idxprom18
+ store i32 %32, ptr %arrayidx29, align 4, !tbaa !13
+ %arrayidx33 = getelementptr inbounds i32, ptr %tj, i64 %idxprom18
+ store i32 %26, ptr %arrayidx33, align 4, !tbaa !13
+ %inc = add nsw i32 %ej.064, 1
+ br label %if.end
+
+if.end: ; preds = %if.then, %for.body
+ %ej.1 = phi i32 [ %inc, %if.then ], [ %ej.064, %for.body ]
+ %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
+ %exitcond.not = icmp eq i64 %indvars.iv.next, %wide.trip.count
+ br i1 %exitcond.not, label %for.cond.cleanup, label %for.body, !llvm.loop !21
+}
+
+; Function Attrs: mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none)
+declare float @llvm.fmuladd.f32(float, float, float) #1
+
+; Function Attrs: nocallback nofree nosync nounwind willreturn memory(read)
+declare <16 x float> @llvm.masked.gather.v16f32.v16p0(<16 x ptr>, i32 immarg, <16 x i1>, <16 x float>) #2
+
+; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none)
+declare <16 x float> @llvm.fmuladd.v16f32(<16 x float>, <16 x float>, <16 x float>) #3
+
+; Function Attrs: nocallback ...
[truncated]
// Target may override this function to decided whether it want to update the
// base and index value of a non-uniform gep
virtual bool updateBaseAndIndex(const Value *Ptr, SDValue &Base,
                                SDValue &Index, const SDLoc &DL,
                                const SDValue &Gep, SelectionDAG &DAG,
                                const BasicBlock *CurBB) const;
Name and purpose unclear. We also have way too many poorly designed, narrow purpose hooks. You should try to avoid introducing this without a very strong justification for why there can't be a reasonable default behavior.
I did not understand the comment.
@@ -4905,6 +4905,11 @@ void SelectionDAGBuilder::visitMaskedScatter(const CallInst &I) {
    Scale = DAG.getTargetConstant(1, sdl, TLI.getPointerTy(DAG.getDataLayout()));
  }

+ if (!UniformBase) {
+   TLI.updateBaseAndIndex(Ptr, Base, Index, getCurSDLoc(), getValue(Ptr), DAG,
+                          I.getParent());
The DAG builder should avoid interpreting the incoming IR in target specific ways
Hi @arsenm
Where should I place the code snippet so that it is applicable for X86 target?
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
target triple = "x86_64-unknown-linux-gnu"

%struct.pt = type { float, float, float, i32 }
this test needs to be reduced down a lot - remove all the unnecessary attributes etc. and just contain the minimal IR to show the problem
Sure
@RKSimon, updated the test case
Thanks - now you should be able to use the utils/update_llc_test_checks.py script to generate more thorough CHECK lines without too much noise.
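For reference, a typical invocation of the script looks like this (paths are illustrative and depend on where your llc binary was built):

    python3 llvm/utils/update_llc_test_checks.py \
        --llc-binary=build/bin/llc \
        llvm/test/CodeGen/X86/gatherBaseIndexFix.ll

The script reads the RUN lines in the test file and regenerates the CHECK lines from the current llc output.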
define <16 x float> @test_gather_16f32_1(ptr %x, ptr %arr, <16 x i1> %mask, <16 x float> %src0) {
  %wide.load = load <16 x i32>, ptr %arr, align 4
  %4 = and <16 x i32> %wide.load, <i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911>
(style) avoid numbered variable names
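For example, with descriptive names instead of numbered ones (hypothetical renaming; splat constant elided for brevity):

    %masked = and <16 x i32> %wide.load, <i32 536870911, i32 536870911, ...>
    %idx64 = zext <16 x i32> %masked to <16 x i64>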
Done
@@ -61619,3 +61624,99 @@ Align X86TargetLowering::getPrefLoopAlignment(MachineLoop *ML) const {
  return Align(1ULL << ExperimentalPrefInnermostLoopAlignment);
  return TargetLowering::getPrefLoopAlignment();
}

+// Target override this function to decided whether it want to update the base
decided -> decide
  return false;

SDValue nbase;
/* For the gep instruction, we are trying to properly assign the base and
Use // comments
if (GEP && GEP->getParent() != CurBB)
  return false;

SDValue nbase;
Variable names should be capitalized
bool ExtIsNonNegative = ExtKnown.isNonNegative();
KnownBits ExtOpKnown = DAG.computeKnownBits(Op10.getOperand(0));
bool ExtOpIsNonNegative = ExtOpKnown.isNonNegative();
if (!(ExtIsNonNegative && ExtOpIsNonNegative))
deMorgan
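That is, a minimal sketch of the suggested form, equivalent by De Morgan's law and avoiding the double negation:

    // Bail out unless both the extended value and its source are non-negative.
    if (!ExtIsNonNegative || !ExtOpIsNonNegative)
      return false;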
EVT VT = newOp10.getValueType(); // Use the

auto *ConstEltNo = dyn_cast<ConstantSDNode>(Op11.getOperand(0));
if (!ConstEltNo) {
Drop curly braces
}
if (Op0 != nbase) {
  auto *ConstEltNo = dyn_cast<ConstantSDNode>(Op1.getOperand(0));
  if (!ConstEltNo) {
Drop curly braces
}
Base = nbase.getOperand(0);
Index = nIndex;
LLVM_DEBUG(dbgs() << "Successfull in updating the non uniform gep "
Successful*
SDValue newOp10 =
    Op10.getOperand(0);          // Get the Operand zero from the ext
EVT VT = newOp10.getValueType(); // Use the
"Use the" is an incomplete comment
if (!EnableBaseIndexUpdate)
  return false;

const GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
Why do we need to check that the Ptr is a GEP? It doesn't look like we use any information from it. The rest of the code is on SelectionDAG nodes unless I missed something.
To check whether the ptr is derived from a GEP. If not, we return false and the transformation does not happen.
If the check does not fit here, I can move it to the call site.
Can this be done in
Yes, it is possible there also.
✅ With the latest revision this PR passed the C/C++ code formatter.
I'd much prefer to see this fixed in x86 in combineGatherScatter, where there's already a lot of similar code - the TLI callback seems too x86 specific tbh.
Sure, I will move the code to combineGatherScatter.
I'm thinking this is more complicated than it needs to be - we should be able to transfer some/all of the SHL into the gather/scatter scale operand (max scale = 8) - assuming we don't lose the sign bit. By doing that we should(?) be able to then recognize when its safe to remove the zext to v16i64 (illegal type so should always be able to truncate to v16i32) - does that make sense?
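A minimal sketch of the fold being suggested (illustrative only; it assumes an X86 DAG combine over the gather/scatter node, with Index/Scale the node's operands, ScaleAmt the current scale value, and getValidShiftAmount available with the signature shown — none of this is the final patch):

    // If the index is a uniform shl by a constant and the combined scale still
    // fits x86's 1/2/4/8 addressing encoding, fold the shift into the scale.
    if (Index.getOpcode() == ISD::SHL) {
      if (std::optional<uint64_t> ShAmt = DAG.getValidShiftAmount(Index)) {
        uint64_t NewScale = ScaleAmt << *ShAmt;
        if (NewScale <= 8) {
          Index = Index.getOperand(0); // drop the shl
          Scale = DAG.getTargetConstant(NewScale, DL, Scale.getValueType());
          // With the shift gone, the index may now be narrow enough to
          // truncate to vXi32, avoiding the v16i64 split.
        }
      }
    }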
EVT VT = NewOp10.getValueType(); // Use the operand's type to determine
                                 // the type of index

auto *ConstEltNo = dyn_cast<ConstantSDNode>(Op11.getOperand(0));
This is assuming that the shl is uniform - use getValidMinimumShiftAmount instead (replaces the BUILD_VECTOR check above as well)
Ok, my understanding is to go from this:
unsigned IndexWidth = Op10.getScalarValueSizeInBits();
if ((Op10.getOpcode() == ISD::SIGN_EXTEND ||
Op10.getOpcode() == ISD::ZERO_EXTEND) &&
IndexWidth > 32 &&
Op10.getOperand(0).getScalarValueSizeInBits() <= 32 &&
DAG.ComputeNumSignBits(Op10) > (IndexWidth - 32) &&
**Op11.getOpcode() == ISD::BUILD_VECTOR**) {
to this:
unsigned IndexWidth = Op10.getScalarValueSizeInBits();
if ((Op10.getOpcode() == ISD::SIGN_EXTEND ||
Op10.getOpcode() == ISD::ZERO_EXTEND) &&
IndexWidth > 32 &&
Op10.getOperand(0).getScalarValueSizeInBits() <= 32 &&
DAG.ComputeNumSignBits(Op10) > (IndexWidth - 32) &&
**DAG.getValidMinimumShiftAmount(Idx)**) {
Please correct me
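(One detail worth noting: getValidMinimumShiftAmount returns an optional shift amount — assumed here to be std::optional<uint64_t> — so the result has to be tested and then consumed rather than used directly in the condition, e.g.:)

    std::optional<uint64_t> MinShAmt = DAG.getValidMinimumShiftAmount(Idx);
    if ((Op10.getOpcode() == ISD::SIGN_EXTEND ||
         Op10.getOpcode() == ISD::ZERO_EXTEND) &&
        IndexWidth > 32 &&
        Op10.getOperand(0).getScalarValueSizeInBits() <= 32 &&
        DAG.ComputeNumSignBits(Op10) > (IndexWidth - 32) && MinShAmt) {
      uint64_t ShAmt = *MinShAmt; // lower bound on the per-lane shift amount
      // ...
    }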
@@ -126,6 +126,11 @@ static cl::opt<bool> MulConstantOptimization(
             "SHIFT, LEA, etc."),
    cl::Hidden);

+static cl::opt<bool>
+    EnableBaseIndexUpdate("update-baseIndex", cl::init(true),
Does not follow command line option naming convention
Also this is the kind of option that will never be used by anyone. Can you just do this unconditionally
I just added this flag for safety, in case any untested use case occurs, so that we can fall back to the existing behavior.
Sure, I will remove the flag.
@@ -126,6 +126,11 @@ static cl::opt<bool> MulConstantOptimization(
             "SHIFT, LEA, etc."),
    cl::Hidden);

+static cl::opt<bool>
+    EnableBaseIndexUpdate("update-baseIndex", cl::init(true),
+                          cl::desc("Update the value of base and index"),
Name and description don't mean anything standalone
Sure, I will update the information
I didn't understand it completely. How will changing the scale operand impact the Index operand? I might be missing some knowledge here. Thanks
x86 natively supports gather/scatter nodes with scale values of 1/2/4/8 (similar to regular address math). If the index node is ISD::SHL by a constant shift amount, we might be able to transfer some/all of the shift amount into the scale; this could then allow the new index to be safely truncated to a vXi32 type.
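As a worked illustration (hypothetical values, pseudo-notation rather than real IR): with scale = 1 and an index of shl <16 x i64> %idx, splat(2), the address math base + (%idx << 2) * 1 is exactly base + %idx * 4, so the shift can be folded away by setting scale = 4. If %idx is then known to fit in 32 bits, the v16i64 index can be truncated to v16i32 and the gather no longer needs splitting:

    ; before: v16i64 index forces the gather to be split
    addr = base + (shl <16 x i64> %idx, splat(2)) * 1
    ; after: shift folded into scale, index truncated
    addr = base + (trunc <16 x i64> %idx to <16 x i32>) * 4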
Thanks for the reply. The index pattern is not always ISD::SHL; it could also be ISD::ADD. As we know, the ptr there is non-uniform, since the base value is 0 and the index can take several patterns.
I will try this and update the patch. Just one question: on changing the scale in the micro-coded instruction, will there be an extra multiplication to calculate the absolute address, i.e. an extra operation? Or is this handled in another pipe while decoding (in parallel)?
@RKSimon, I tried the suggestion, but the value of newscale comes out to be 16, which fails the newscale <= 8 check. getValidMinimumShiftAmount() returns 4, and 1 << 4 gives 16, hence we bail out. Thanks
Here is my experiment: main...RKSimon:llvm-project:x86-gather-scatter-addr-scale. It still needs work: it stupidly moves 1 bit of shift left per combine instead of doing something smarter, plus we should probably only bother trying to do this if we think it will eventually allow truncation. Anyway, take a look and see if there's anything you can use that will help.
Hi @RKSimon, I commented out my changes and applied your patch. It works for the reduced test case, and for the test case attached in the commit.
However, in the DAGCombiner, when refineUniformBase is called to combine the gather again for %res, the DAG tries to simplify the base, separate the operand from the index, and update the base, but it bails out: in our case the index has multiple uses.
Additionally, when I debugged the logs, I found that instructions which were already combined and processed still refer to these uses. Looks like weird behavior.
Gentle Reminder!
Hi @RKSimon
Please can you start with splitting off the new test coverage to a new PR with trunk's current codegen? It'll make it easier to compare against different approaches. (Sorry for the slow reply.)
@RKSimon, I have updated the testcase in the masked_gather_scatter.ll file.