Update the base and index value for masked gather #130920


Closed · rohitaggarwal007 wants to merge 25 commits

Conversation

rohitaggarwal007
Contributor

While lowering a masked gather on X86, the gather instruction was being split into two gather instructions. The type-legalization phase in SelectionDAG found that the index type was v16i64 and, since a ZMM register cannot hold sixteen 64-bit values, decided to split the operand into a vector factor of 8, i.e. v8i64. This led to poor codegen and loss of vectorization benefits.

I have recalculated the Index and Base of a getelementptr that has a non-uniform base. The recalculated values are then used by DAG.getMaskedGather to create the instruction.

@RKSimon @phoebewang Please review the PR

@llvmbot llvmbot added backend:X86 llvm:SelectionDAG SelectionDAGISel as well labels Mar 12, 2025
@llvmbot
Member

llvmbot commented Mar 12, 2025

@llvm/pr-subscribers-backend-x86

@llvm/pr-subscribers-llvm-selectiondag

Author: Rohit Aggarwal (rohitaggarwal007)

Changes

While lowering a masked gather on X86, the gather instruction was being split into two gather instructions. The type-legalization phase in SelectionDAG found that the index type was v16i64 and, since a ZMM register cannot hold sixteen 64-bit values, decided to split the operand into a vector factor of 8, i.e. v8i64. This led to poor codegen and loss of vectorization benefits.

I have recalculated the Index and Base of a getelementptr that has a non-uniform base. The recalculated values are then used by DAG.getMaskedGather to create the instruction.

@RKSimon @phoebewang Please review the PR


Patch is 22.51 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/130920.diff

6 Files Affected:

  • (modified) llvm/include/llvm/CodeGen/TargetLowering.h (+7)
  • (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp (+10)
  • (modified) llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp (+8)
  • (modified) llvm/lib/Target/X86/X86ISelLowering.cpp (+101)
  • (modified) llvm/lib/Target/X86/X86ISelLowering.h (+7)
  • (added) llvm/test/CodeGen/X86/gatherBaseIndexFix.ll (+210)
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 2089d47e9cbc8..46b28b7f5813d 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -5127,6 +5127,13 @@ class TargetLowering : public TargetLoweringBase {
                                               SmallVectorImpl<SDValue> &Ops,
                                               SelectionDAG &DAG) const;
 
+  // Target may override this function to decided whether it want to update the
+  // base and index value of a non-uniform gep
+  virtual bool updateBaseAndIndex(const Value *Ptr, SDValue &Base,
+                                  SDValue &Index, const SDLoc &DL,
+                                  const SDValue &Gep, SelectionDAG &DAG,
+                                  const BasicBlock *CurBB) const;
+
   //===--------------------------------------------------------------------===//
   // Div utility functions
   //
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 14bb1d943d2d6..60d50ed735e09 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -4905,6 +4905,11 @@ void SelectionDAGBuilder::visitMaskedScatter(const CallInst &I) {
     Scale = DAG.getTargetConstant(1, sdl, TLI.getPointerTy(DAG.getDataLayout()));
   }
 
+  if (!UniformBase) {
+    TLI.updateBaseAndIndex(Ptr, Base, Index, getCurSDLoc(), getValue(Ptr), DAG,
+                           I.getParent());
+  }
+
   EVT IdxVT = Index.getValueType();
   EVT EltTy = IdxVT.getVectorElementType();
   if (TLI.shouldExtendGSIndex(IdxVT, EltTy)) {
@@ -5024,6 +5029,11 @@ void SelectionDAGBuilder::visitMaskedGather(const CallInst &I) {
     Scale = DAG.getTargetConstant(1, sdl, TLI.getPointerTy(DAG.getDataLayout()));
   }
 
+  if (!UniformBase) {
+    TLI.updateBaseAndIndex(Ptr, Base, Index, getCurSDLoc(), getValue(Ptr), DAG,
+                           I.getParent());
+  }
+
   EVT IdxVT = Index.getValueType();
   EVT EltTy = IdxVT.getVectorElementType();
   if (TLI.shouldExtendGSIndex(IdxVT, EltTy)) {
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index bd72718c49031..eb2ac6044cb6b 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -5655,6 +5655,14 @@ void TargetLowering::CollectTargetIntrinsicOperands(
     const CallInst &I, SmallVectorImpl<SDValue> &Ops, SelectionDAG &DAG) const {
 }
 
+// By default, this function is disabled. Overriding target can enable it
+bool TargetLowering::updateBaseAndIndex(const Value *Ptr, SDValue &Base,
+                                        SDValue &Index, const SDLoc &DL,
+                                        const SDValue &Gep, SelectionDAG &DAG,
+                                        const BasicBlock *CurBB) const {
+  return false;
+}
+
 std::pair<unsigned, const TargetRegisterClass *>
 TargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *RI,
                                              StringRef Constraint,
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 24e5d8bfc404c..528e823abe01b 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -126,6 +126,11 @@ static cl::opt<bool> MulConstantOptimization(
              "SHIFT, LEA, etc."),
     cl::Hidden);
 
+static cl::opt<bool>
+    EnableBaseIndexUpdate("update-baseIndex", cl::init(true),
+                          cl::desc("Update the value of base and index"),
+                          cl::Hidden);
+
 X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
                                      const X86Subtarget &STI)
     : TargetLowering(TM), Subtarget(STI) {
@@ -61619,3 +61624,99 @@ Align X86TargetLowering::getPrefLoopAlignment(MachineLoop *ML) const {
     return Align(1ULL << ExperimentalPrefInnermostLoopAlignment);
   return TargetLowering::getPrefLoopAlignment();
 }
+
+// Target override this function to decided whether it want to update the base
+// and index value of a non-uniform gep
+bool X86TargetLowering::updateBaseAndIndex(const Value *Ptr, SDValue &Base,
+                                           SDValue &Index, const SDLoc &DL,
+                                           const SDValue &Gep,
+                                           SelectionDAG &DAG,
+                                           const BasicBlock *CurBB) const {
+  if (!EnableBaseIndexUpdate)
+    return false;
+
+  const GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
+  if (GEP && GEP->getParent() != CurBB)
+    return false;
+
+  SDValue nbase;
+  /* For the gep instruction, we are trying to properly assign the base and
+  index value We are go through the lower code and iterate backward.
+  */
+  if (Gep.getOpcode() == ISD::ADD) {
+    SDValue Op0 = Gep.getOperand(0); // base or  add
+    SDValue Op1 = Gep.getOperand(1); // build vector or SHL
+    nbase = Op0;
+    SDValue Idx = Op1;
+    auto Flags = Gep->getFlags();
+
+    if (Op0->getOpcode() == ISD::ADD) { // add t15(base), t18(Idx)
+      SDValue Op00 = Op0.getOperand(0); // Base
+      nbase = Op00;
+      Idx = Op0.getOperand(1);
+    } else if (!(Op0->getOpcode() == ISD::BUILD_VECTOR &&
+                 Op0.getOperand(0).getOpcode() == ISD::CopyFromReg)) {
+      return false;
+    }
+    SDValue nIndex;
+    if (Idx.getOpcode() == ISD::SHL) {  // shl zext, BV
+      SDValue Op10 = Idx.getOperand(0); // Zext or Sext value
+      SDValue Op11 = Idx.getOperand(1); // Build vector of constant
+
+      unsigned IndexWidth = Op10.getScalarValueSizeInBits();
+      if ((Op10.getOpcode() == ISD::SIGN_EXTEND ||
+           Op10.getOpcode() == ISD::ZERO_EXTEND) &&
+          IndexWidth > 32 &&
+          Op10.getOperand(0).getScalarValueSizeInBits() <= 32 &&
+          DAG.ComputeNumSignBits(Op10) > (IndexWidth - 32) &&
+          Op11.getOpcode() == ISD::BUILD_VECTOR) {
+
+        KnownBits ExtKnown = DAG.computeKnownBits(Op10);
+        bool ExtIsNonNegative = ExtKnown.isNonNegative();
+        KnownBits ExtOpKnown = DAG.computeKnownBits(Op10.getOperand(0));
+        bool ExtOpIsNonNegative = ExtOpKnown.isNonNegative();
+        if (!(ExtIsNonNegative && ExtOpIsNonNegative))
+          return false;
+
+        SDValue newOp10 =
+            Op10.getOperand(0);          // Get the Operand zero from the ext
+        EVT VT = newOp10.getValueType(); // Use the
+
+        auto *ConstEltNo = dyn_cast<ConstantSDNode>(Op11.getOperand(0));
+        if (!ConstEltNo) {
+          return false;
+        }
+        SmallVector<SDValue, 8> Ops(VT.getVectorNumElements(),
+                                    DAG.getConstant(ConstEltNo->getZExtValue(),
+                                                    DL, VT.getScalarType()));
+        nIndex = DAG.getNode(ISD::SHL, DL, VT, newOp10,
+                             DAG.getBuildVector(VT, DL, Ops));
+      } else {
+        return false;
+      }
+    } else {
+      return false;
+    }
+    if (Op0 != nbase) {
+      auto *ConstEltNo = dyn_cast<ConstantSDNode>(Op1.getOperand(0));
+      if (!ConstEltNo) {
+        return false;
+      }
+      SmallVector<SDValue, 8> Ops(
+          nIndex.getValueType().getVectorNumElements(),
+          DAG.getConstant(ConstEltNo->getZExtValue(), DL,
+                          nIndex.getValueType().getScalarType()));
+      nIndex = DAG.getNode(ISD::ADD, DL, nIndex.getValueType(), nIndex,
+                           DAG.getBuildVector(nIndex.getValueType(), DL, Ops),
+                           Flags);
+    }
+    Base = nbase.getOperand(0);
+    Index = nIndex;
+    LLVM_DEBUG(dbgs() << "Successfull in updating the non uniform gep "
+                         "information\n";
+               dbgs() << "updated base "; Base.dump();
+               dbgs() << "updated Index "; Index.dump(););
+    return true;
+  }
+  return false;
+}
diff --git a/llvm/lib/Target/X86/X86ISelLowering.h b/llvm/lib/Target/X86/X86ISelLowering.h
index 4a2b35e9efe7c..c092055329c58 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.h
+++ b/llvm/lib/Target/X86/X86ISelLowering.h
@@ -1671,6 +1671,13 @@ namespace llvm {
       return TargetLoweringBase::getTypeToTransformTo(Context, VT);
     }
 
+    // Target override this function to decided whether it want to update the
+    // base and index value of a non-uniform gep
+    bool updateBaseAndIndex(const Value *Ptr, SDValue &Base, SDValue &Index,
+                            const SDLoc &DL, const SDValue &Gep,
+                            SelectionDAG &DAG,
+                            const BasicBlock *CurBB) const override;
+
   protected:
     std::pair<const TargetRegisterClass *, uint8_t>
     findRepresentativeClass(const TargetRegisterInfo *TRI,
diff --git a/llvm/test/CodeGen/X86/gatherBaseIndexFix.ll b/llvm/test/CodeGen/X86/gatherBaseIndexFix.ll
new file mode 100644
index 0000000000000..ae24188ece58d
--- /dev/null
+++ b/llvm/test/CodeGen/X86/gatherBaseIndexFix.ll
@@ -0,0 +1,210 @@
+; RUN: llc  -mtriple=x86_64-unknown-unknown -mattr=+avx512f,+avx512bw,+avx512vl,+avx512dq -mcpu=znver5 < %s | FileCheck %s
+; RUN: llc -update-baseIndex -mtriple=x86_64-unknown-unknown -mattr=+avx512f,+avx512bw,+avx512vl,+avx512dq -mcpu=znver5 < %s | FileCheck %s
+; RUN: llc -update-baseIndex=false -mtriple=x86_64-unknown-unknown -mattr=+avx512f,+avx512bw,+avx512vl,+avx512dq -mcpu=znver5 < %s | FileCheck %s -check-prefix=OLD
+
+; ModuleID = 'qwdemo.c'
+source_filename = "qwdemo.c"
+target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
+target triple = "x86_64-unknown-linux-gnu"
+
+%struct.pt = type { float, float, float, i32 }
+
+; Function Attrs: nofree nosync nounwind memory(argmem: readwrite) uwtable
+define dso_local i32 @foo(float noundef %cut_coulsq, ptr noalias nocapture noundef readonly %jlist, i32 noundef %jnum, ptr noalias nocapture noundef readonly %x, ptr noalias nocapture noundef writeonly %trsq, ptr noalias nocapture noundef writeonly %tdelx, ptr noalias nocapture noundef writeonly %tdely, ptr noalias nocapture noundef writeonly %tdelz, ptr noalias nocapture noundef writeonly %tjtype, ptr noalias nocapture noundef writeonly %tj, ptr noalias nocapture noundef readnone %tx, ptr noalias nocapture noundef readnone %ty, ptr noalias nocapture noundef readnone %tz) local_unnamed_addr #0 {
+entry:
+  %0 = load float, ptr %x, align 4, !tbaa !5
+  %y = getelementptr inbounds %struct.pt, ptr %x, i64 0, i32 1
+  %1 = load float, ptr %y, align 4, !tbaa !11
+  %z = getelementptr inbounds %struct.pt, ptr %x, i64 0, i32 2
+  %2 = load float, ptr %z, align 4, !tbaa !12
+  %cmp62 = icmp sgt i32 %jnum, 0
+  br i1 %cmp62, label %for.body.preheader, label %for.cond.cleanup
+
+for.body.preheader:                               ; preds = %entry
+  %wide.trip.count = zext i32 %jnum to i64
+  %min.iters.check = icmp ult i32 %jnum, 16
+  br i1 %min.iters.check, label %for.body.preheader75, label %vector.ph
+
+vector.ph:                                        ; preds = %for.body.preheader
+  %n.vec = and i64 %wide.trip.count, 4294967280
+  %broadcast.splatinsert = insertelement <16 x float> poison, float %0, i64 0
+  %broadcast.splat = shufflevector <16 x float> %broadcast.splatinsert, <16 x float> poison, <16 x i32> zeroinitializer
+  %broadcast.splatinsert67 = insertelement <16 x float> poison, float %1, i64 0
+  %broadcast.splat68 = shufflevector <16 x float> %broadcast.splatinsert67, <16 x float> poison, <16 x i32> zeroinitializer
+  %broadcast.splatinsert70 = insertelement <16 x float> poison, float %2, i64 0
+  %broadcast.splat71 = shufflevector <16 x float> %broadcast.splatinsert70, <16 x float> poison, <16 x i32> zeroinitializer
+  %broadcast.splatinsert72 = insertelement <16 x float> poison, float %cut_coulsq, i64 0
+  %broadcast.splat73 = shufflevector <16 x float> %broadcast.splatinsert72, <16 x float> poison, <16 x i32> zeroinitializer
+  br label %vector.body
+
+; CHECK-LABEL: .LBB0_6:
+; CHECK:   vgatherdps      (%rdx,%zmm12), %zmm13 {%k1}
+; CHECK:   vgatherdps      (%rdx,%zmm14), %zmm15 {%k1}
+; CHECK:   vgatherdps      (%rdx,%zmm17), %zmm16 {%k1}
+
+; OLD-LABEL: .LBB0_6:
+; OLD:  vgatherqps      (%rdx,%zmm12), %ymm15 {%k1}
+; OLD:  vgatherqps      (%rdx,%zmm11), %ymm12 {%k1}
+; OLD:  vgatherqps      4(,%zmm14), %ymm12 {%k1}
+; OLD:  vgatherqps      4(,%zmm13), %ymm15 {%k1}
+; OLD:  vgatherqps      8(,%zmm14), %ymm15 {%k1}
+; OLD:  vgatherqps      8(,%zmm13), %ymm16 {%k1}
+
+vector.body:                                      ; preds = %vector.body, %vector.ph
+  %index = phi i64 [ 0, %vector.ph ], [ %index.next, %vector.body ]
+  %pred.index = phi i32 [ 0, %vector.ph ], [ %predphi, %vector.body ]
+  %3 = getelementptr inbounds i32, ptr %jlist, i64 %index
+  %wide.load = load <16 x i32>, ptr %3, align 4, !tbaa !13
+  %4 = and <16 x i32> %wide.load, <i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911>
+  %5 = zext <16 x i32> %4 to <16 x i64>
+  %6 = getelementptr inbounds %struct.pt, ptr %x, <16 x i64> %5
+  %wide.masked.gather = tail call <16 x float> @llvm.masked.gather.v16f32.v16p0(<16 x ptr> %6, i32 4, <16 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>, <16 x float> poison), !tbaa !5
+  %7 = fsub <16 x float> %broadcast.splat, %wide.masked.gather
+  %8 = getelementptr inbounds %struct.pt, ptr %x, <16 x i64> %5, i32 1
+  %wide.masked.gather66 = tail call <16 x float> @llvm.masked.gather.v16f32.v16p0(<16 x ptr> %8, i32 4, <16 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>, <16 x float> poison), !tbaa !11
+  %9 = fsub <16 x float> %broadcast.splat68, %wide.masked.gather66
+  %10 = getelementptr inbounds %struct.pt, ptr %x, <16 x i64> %5, i32 2
+  %wide.masked.gather69 = tail call <16 x float> @llvm.masked.gather.v16f32.v16p0(<16 x ptr> %10, i32 4, <16 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>, <16 x float> poison), !tbaa !12
+  %11 = fsub <16 x float> %broadcast.splat71, %wide.masked.gather69
+  %12 = fmul <16 x float> %9, %9
+  %13 = tail call <16 x float> @llvm.fmuladd.v16f32(<16 x float> %7, <16 x float> %7, <16 x float> %12)
+  %14 = tail call <16 x float> @llvm.fmuladd.v16f32(<16 x float> %11, <16 x float> %11, <16 x float> %13)
+  %15 = fcmp olt <16 x float> %14, %broadcast.splat73
+  %16 = sext i32 %pred.index to i64
+  %17 = getelementptr float, ptr %trsq, i64 %16
+  tail call void @llvm.masked.compressstore.v16f32(<16 x float> %14, ptr %17, <16 x i1> %15), !tbaa !14
+  %18 = getelementptr float, ptr %tdelx, i64 %16
+  tail call void @llvm.masked.compressstore.v16f32(<16 x float> %7, ptr %18, <16 x i1> %15), !tbaa !14
+  %19 = getelementptr float, ptr %tdely, i64 %16
+  tail call void @llvm.masked.compressstore.v16f32(<16 x float> %9, ptr %19, <16 x i1> %15), !tbaa !14
+  %20 = getelementptr float, ptr %tdelz, i64 %16
+  tail call void @llvm.masked.compressstore.v16f32(<16 x float> %11, ptr %20, <16 x i1> %15), !tbaa !14
+  %21 = getelementptr inbounds %struct.pt, ptr %x, <16 x i64> %5, i32 3
+  %wide.masked.gather74 = tail call <16 x i32> @llvm.masked.gather.v16i32.v16p0(<16 x ptr> %21, i32 4, <16 x i1> %15, <16 x i32> poison), !tbaa !15
+  %22 = getelementptr i32, ptr %tjtype, i64 %16
+  tail call void @llvm.masked.compressstore.v16i32(<16 x i32> %wide.masked.gather74, ptr %22, <16 x i1> %15), !tbaa !13
+  %23 = getelementptr i32, ptr %tj, i64 %16
+  tail call void @llvm.masked.compressstore.v16i32(<16 x i32> %wide.load, ptr %23, <16 x i1> %15), !tbaa !13
+  %24 = bitcast <16 x i1> %15 to i16
+  %mask.popcnt = tail call i16 @llvm.ctpop.i16(i16 %24), !range !16
+  %popcnt.cmp.not = icmp eq i16 %24, 0
+  %narrow = select i1 %popcnt.cmp.not, i16 0, i16 %mask.popcnt
+  %popcnt.inc = zext i16 %narrow to i32
+  %predphi = add i32 %pred.index, %popcnt.inc
+  %index.next = add nuw i64 %index, 16
+  %25 = icmp eq i64 %index.next, %n.vec
+  br i1 %25, label %middle.block, label %vector.body, !llvm.loop !17
+
+middle.block:                                     ; preds = %vector.body
+  %cmp.n = icmp eq i64 %n.vec, %wide.trip.count
+  br i1 %cmp.n, label %for.cond.cleanup, label %for.body.preheader75
+
+for.body.preheader75:                             ; preds = %for.body.preheader, %middle.block
+  %indvars.iv.ph = phi i64 [ 0, %for.body.preheader ], [ %n.vec, %middle.block ]
+  %ej.064.ph = phi i32 [ 0, %for.body.preheader ], [ %predphi, %middle.block ]
+  br label %for.body
+
+for.cond.cleanup:                                 ; preds = %if.end, %middle.block, %entry
+  %ej.0.lcssa = phi i32 [ 0, %entry ], [ %predphi, %middle.block ], [ %ej.1, %if.end ]
+  ret i32 %ej.0.lcssa
+
+for.body:                                         ; preds = %for.body.preheader75, %if.end
+  %indvars.iv = phi i64 [ %indvars.iv.next, %if.end ], [ %indvars.iv.ph, %for.body.preheader75 ]
+  %ej.064 = phi i32 [ %ej.1, %if.end ], [ %ej.064.ph, %for.body.preheader75 ]
+  %arrayidx4 = getelementptr inbounds i32, ptr %jlist, i64 %indvars.iv
+  %26 = load i32, ptr %arrayidx4, align 4, !tbaa !13
+  %and = and i32 %26, 536870911
+  %idxprom5 = zext i32 %and to i64
+  %arrayidx6 = getelementptr inbounds %struct.pt, ptr %x, i64 %idxprom5
+  %27 = load float, ptr %arrayidx6, align 4, !tbaa !5
+  %sub = fsub float %0, %27
+  %y10 = getelementptr inbounds %struct.pt, ptr %x, i64 %idxprom5, i32 1
+  %28 = load float, ptr %y10, align 4, !tbaa !11
+  %sub11 = fsub float %1, %28
+  %z14 = getelementptr inbounds %struct.pt, ptr %x, i64 %idxprom5, i32 2
+  %29 = load float, ptr %z14, align 4, !tbaa !12
+  %sub15 = fsub float %2, %29
+  %mul16 = fmul float %sub11, %sub11
+  %30 = tail call float @llvm.fmuladd.f32(float %sub, float %sub, float %mul16)
+  %31 = tail call float @llvm.fmuladd.f32(float %sub15, float %sub15, float %30)
+  %cmp17 = fcmp olt float %31, %cut_coulsq
+  br i1 %cmp17, label %if.then, label %if.end
+
+if.then:                                          ; preds = %for.body
+  %idxprom18 = sext i32 %ej.064 to i64
+  %arrayidx19 = getelementptr inbounds float, ptr %trsq, i64 %idxprom18
+  store float %31, ptr %arrayidx19, align 4, !tbaa !14
+  %arrayidx21 = getelementptr inbounds float, ptr %tdelx, i64 %idxprom18
+  store float %sub, ptr %arrayidx21, align 4, !tbaa !14
+  %arrayidx23 = getelementptr inbounds float, ptr %tdely, i64 %idxprom18
+  store float %sub11, ptr %arrayidx23, align 4, !tbaa !14
+  %arrayidx25 = getelementptr inbounds float, ptr %tdelz, i64 %idxprom18
+  store float %sub15, ptr %arrayidx25, align 4, !tbaa !14
+  %w = getelementptr inbounds %struct.pt, ptr %x, i64 %idxprom5, i32 3
+  %32 = load i32, ptr %w, align 4, !tbaa !15
+  %arrayidx29 = getelementptr inbounds i32, ptr %tjtype, i64 %idxprom18
+  store i32 %32, ptr %arrayidx29, align 4, !tbaa !13
+  %arrayidx33 = getelementptr inbounds i32, ptr %tj, i64 %idxprom18
+  store i32 %26, ptr %arrayidx33, align 4, !tbaa !13
+  %inc = add nsw i32 %ej.064, 1
+  br label %if.end
+
+if.end:                                           ; preds = %if.then, %for.body
+  %ej.1 = phi i32 [ %inc, %if.then ], [ %ej.064, %for.body ]
+  %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
+  %exitcond.not = icmp eq i64 %indvars.iv.next, %wide.trip.count
+  br i1 %exitcond.not, label %for.cond.cleanup, label %for.body, !llvm.loop !21
+}
+
+; Function Attrs: mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none)
+declare float @llvm.fmuladd.f32(float, float, float) #1
+
+; Function Attrs: nocallback nofree nosync nounwind willreturn memory(read)
+declare <16 x float> @llvm.masked.gather.v16f32.v16p0(<16 x ptr>, i32 immarg, <16 x i1>, <16 x float>) #2
+
+; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none)
+declare <16 x float> @llvm.fmuladd.v16f32(<16 x float>, <16 x float>, <16 x float>) #3
+
+; Function Attrs: nocallback ...
[truncated]

Comment on lines 5130 to 5136
// Target may override this function to decided whether it want to update the
// base and index value of a non-uniform gep
virtual bool updateBaseAndIndex(const Value *Ptr, SDValue &Base,
SDValue &Index, const SDLoc &DL,
const SDValue &Gep, SelectionDAG &DAG,
const BasicBlock *CurBB) const;

Contributor

Name and purpose unclear. We also have way too many poorly designed, narrow purpose hooks. You should try to avoid introducing this without a very strong justification for why there can't be a reasonable default behavior.

Contributor Author

I did not understand the comment.

@@ -4905,6 +4905,11 @@ void SelectionDAGBuilder::visitMaskedScatter(const CallInst &I) {
Scale = DAG.getTargetConstant(1, sdl, TLI.getPointerTy(DAG.getDataLayout()));
}

if (!UniformBase) {
TLI.updateBaseAndIndex(Ptr, Base, Index, getCurSDLoc(), getValue(Ptr), DAG,
I.getParent());
Contributor

The DAG builder should avoid interpreting the incoming IR in target-specific ways

Contributor Author

Hi @arsenm
Where should I place the code snippet so that it applies to the X86 target?

@dtcxzyw dtcxzyw requested review from RKSimon and phoebewang March 12, 2025 08:10
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
target triple = "x86_64-unknown-linux-gnu"

%struct.pt = type { float, float, float, i32 }
Collaborator

this test needs to be reduced down a lot - remove all the unnecessary attributes etc. and just contain the minimal IR to show the problem

Contributor Author

Sure

Contributor Author

@rohitaggarwal007 Mar 12, 2025

@RKSimon, updated the test case

Collaborator

Thanks - now you should be able to use the utils/update_llc_test_checks.py script to generate more thorough CHECK lines without too much noise.

define <16 x float> @test_gather_16f32_1(ptr %x, ptr %arr, <16 x i1> %mask, <16 x float> %src0) {
%wide.load = load <16 x i32>, ptr %arr, align 4
%4 = and <16 x i32> %wide.load, <i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911>
Collaborator

(style) avoid numbered variable names

Contributor Author

Done

@@ -61619,3 +61624,99 @@ Align X86TargetLowering::getPrefLoopAlignment(MachineLoop *ML) const {
return Align(1ULL << ExperimentalPrefInnermostLoopAlignment);
return TargetLowering::getPrefLoopAlignment();
}

// Target override this function to decided whether it want to update the base
Collaborator

decided -> decide

return false;

SDValue nbase;
/* For the gep instruction, we are trying to properly assign the base and
Collaborator

Use // comments

if (GEP && GEP->getParent() != CurBB)
return false;

SDValue nbase;
Collaborator

Variable names should be capitalized

bool ExtIsNonNegative = ExtKnown.isNonNegative();
KnownBits ExtOpKnown = DAG.computeKnownBits(Op10.getOperand(0));
bool ExtOpIsNonNegative = ExtOpKnown.isNonNegative();
if (!(ExtIsNonNegative && ExtOpIsNonNegative))
Collaborator

deMorgan
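
A minimal sketch of the rewrite presumably being requested here, with De Morgan's law applied so the negation distributes over both operands (behavior unchanged):

    // Before: negated conjunction.
    if (!(ExtIsNonNegative && ExtOpIsNonNegative))
      return false;

    // After De Morgan's law: negation distributed over the operands.
    if (!ExtIsNonNegative || !ExtOpIsNonNegative)
      return false;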

EVT VT = newOp10.getValueType(); // Use the

auto *ConstEltNo = dyn_cast<ConstantSDNode>(Op11.getOperand(0));
if (!ConstEltNo) {
Collaborator

Drop curly braces

}
if (Op0 != nbase) {
auto *ConstEltNo = dyn_cast<ConstantSDNode>(Op1.getOperand(0));
if (!ConstEltNo) {
Collaborator

Drop curly braces

}
Base = nbase.getOperand(0);
Index = nIndex;
LLVM_DEBUG(dbgs() << "Successfull in updating the non uniform gep "
Collaborator

Successful*


SDValue newOp10 =
Op10.getOperand(0); // Get the Operand zero from the ext
EVT VT = newOp10.getValueType(); // Use the
Collaborator

"Use the" is an incomplete comment

if (!EnableBaseIndexUpdate)
return false;

const GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
Collaborator

@topperc Mar 13, 2025

Why do we need to check that the Ptr is a GEP? It doesn't look like we use any information from it. The rest of the code is on SelectionDAG nodes unless I missed something.

Contributor Author

To check whether Ptr is derived from a GEP. If not, we return false and the transformation does not happen.
If the check does not fit here, I can move it to the call site.

@topperc
Collaborator

topperc commented Mar 13, 2025

Can this be done in combineGatherScatter in X86ISelLowering.cpp before type legalization?

@rohitaggarwal007
Contributor Author

Can this be done in combineGatherScatter in X86ISelLowering.cpp before type legalization?

Yes, it is possible there as well.
I thought it would be good to fix it while forming the initial DAG and let the combiner do the further optimization, so I followed this approach.


github-actions bot commented Mar 13, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

Collaborator

@RKSimon left a comment

I'd much prefer to see this fixed in x86 in combineGatherScatter, where there's already a lot of similar code - the TLI callback seems too x86 specific tbh.

@rohitaggarwal007
Contributor Author

I'd much prefer to see this fixed in x86 in combineGatherScatter, where there's already a lot of similar code - the TLI callback seems too x86 specific tbh.

Sure, I will move the code to combineGatherScatter.

@rohitaggarwal007
Contributor Author

@topperc @arsenm @RKSimon I have implemented the suggestions.
Please check.

Collaborator

@RKSimon left a comment

I'm thinking this is more complicated than it needs to be - we should be able to transfer some/all of the SHL into the gather/scatter scale operand (max scale = 8) - assuming we don't lose the sign bit. By doing that we should(?) be able to then recognize when it's safe to remove the zext to v16i64 (illegal type so should always be able to truncate to v16i32) - does that make sense?

EVT VT = NewOp10.getValueType(); // Use the operand's type to determine
// the type of index

auto *ConstEltNo = dyn_cast<ConstantSDNode>(Op11.getOperand(0));
Collaborator

This is assuming that the shl is uniform - use getValidMinimumShiftAmount instead (replaces the BUILD_VECTOR check above as well)

Contributor Author

OK, my understanding is this.
From:

      unsigned IndexWidth = Op10.getScalarValueSizeInBits();
      if ((Op10.getOpcode() == ISD::SIGN_EXTEND ||
           Op10.getOpcode() == ISD::ZERO_EXTEND) &&
          IndexWidth > 32 &&
          Op10.getOperand(0).getScalarValueSizeInBits() <= 32 &&
          DAG.ComputeNumSignBits(Op10) > (IndexWidth - 32) &&
          **Op11.getOpcode() == ISD::BUILD_VECTOR**) {

to

 unsigned IndexWidth = Op10.getScalarValueSizeInBits();
      if ((Op10.getOpcode() == ISD::SIGN_EXTEND ||
           Op10.getOpcode() == ISD::ZERO_EXTEND) &&
          IndexWidth > 32 &&
          Op10.getOperand(0).getScalarValueSizeInBits() <= 32 &&
          DAG.ComputeNumSignBits(Op10) > (IndexWidth - 32) &&
          **DAG.getValidMinimumShiftAmount(Idx)**) {

Please correct me
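
For reference, SelectionDAG::getValidMinimumShiftAmount returns a std::optional<uint64_t>, so using it directly as the last condition compiles but discards the amount. A sketch along the reviewer's lines (reusing Op10 and Idx from the patch; illustrative, not the final code) would bind the result so the shift amount remains available when rebuilding the index:

    unsigned IndexWidth = Op10.getScalarValueSizeInBits();
    // Smallest per-element shift amount, if all elements are known constants.
    std::optional<uint64_t> MinShAmt = DAG.getValidMinimumShiftAmount(Idx);
    if ((Op10.getOpcode() == ISD::SIGN_EXTEND ||
         Op10.getOpcode() == ISD::ZERO_EXTEND) &&
        IndexWidth > 32 &&
        Op10.getOperand(0).getScalarValueSizeInBits() <= 32 &&
        DAG.ComputeNumSignBits(Op10) > (IndexWidth - 32) && MinShAmt) {
      // *MinShAmt replaces reading operand 0 of the BUILD_VECTOR, and also
      // covers non-splat constant shift amounts.
    }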

@@ -126,6 +126,11 @@ static cl::opt<bool> MulConstantOptimization(
"SHIFT, LEA, etc."),
cl::Hidden);

static cl::opt<bool>
EnableBaseIndexUpdate("update-baseIndex", cl::init(true),
Contributor

Does not follow command line option naming convention

Contributor

Also this is the kind of option that will never be used by anyone. Can you just do this unconditionally?

Contributor Author

I just added this flag for safety, in case any untested use case occurs, so that we can fall back to the existing behavior.
Sure, I will remove the flag.

@@ -126,6 +126,11 @@ static cl::opt<bool> MulConstantOptimization(
"SHIFT, LEA, etc."),
cl::Hidden);

static cl::opt<bool>
EnableBaseIndexUpdate("update-baseIndex", cl::init(true),
cl::desc("Update the value of base and index"),
Contributor

Name and description don't mean anything standalone

Contributor Author

Sure, I will update the information

@rohitaggarwal007
Contributor Author

I'm thinking this is more complicated than it needs to be - we should be able to transfer some/all of the SHL into the gather/scatter scale operand (max scale = 8) - assuming we don't lose the sign bit. By doing that we should(?) be able to then recognize when it's safe to remove the zext to v16i64 (illegal type so should always be able to truncate to v16i32) - does that make sense?

I didn't understand it completely. How will changing the scale operand impact the Index operand? I might be missing some knowledge here. Thanks.

@RKSimon
Collaborator

RKSimon commented Mar 21, 2025

x86 natively supports gather/scatter nodes with scale values of 1/2/4/8 (similar to regular address math). If the index node is ISD::SHL by a constant shift amount, we might be able to transfer some/all of the shift amount into the scale:

   //  index = index_src << { 1, 2, 3, 4 } ; min_shift_amount = 1
   if (numsignbits(index_src) > (1 + min_shift_amount)) {
     newscale = oldscale * (1 << min_shift_amount) ; iff newscale <= 8
     newindex  = index_src << { 0, 1, 2, 3 } ;
  }

this could then result in newindex being able to be safely truncated to a vXi32 type
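
A rough C++ rendering of that pseudocode as it might sit in a gather/scatter combine. This is a sketch, not the actual experiment branch; Index, Scale, DL and DAG are assumed to come from the surrounding combine:

    if (Index.getOpcode() == ISD::SHL) {
      if (std::optional<uint64_t> MinShAmt =
              DAG.getValidMinimumShiftAmount(Index)) {
        uint64_t OldScale = cast<ConstantSDNode>(Scale)->getZExtValue();
        uint64_t NewScale = OldScale << *MinShAmt;
        // Keep the x86 limit and mirror the sign-bit guard above.
        if (NewScale <= 8 &&
            DAG.ComputeNumSignBits(Index.getOperand(0)) > 1 + *MinShAmt) {
          EVT VT = Index.getValueType();
          // Lower every per-element shift amount by MinShAmt and fold the
          // difference into the scale.
          SDValue Adj = DAG.getConstant(*MinShAmt, DL, VT);
          SDValue NewAmt =
              DAG.getNode(ISD::SUB, DL, VT, Index.getOperand(1), Adj);
          Index = DAG.getNode(ISD::SHL, DL, VT, Index.getOperand(0), NewAmt);
          Scale = DAG.getTargetConstant(NewScale, DL, Scale.getValueType());
        }
      }
    }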

@rohitaggarwal007
Contributor Author

x86 natively supports gather/scatter nodes with scale values of 1/2/4/8 (similar to regular address math). If the index node is ISD::SHL by a constant shift amount, we might be able to transfer some/all of the shift amount into the scale:

   //  index = index_src << { 1, 2, 3, 4 } ; min_shift_amount = 1
   if (numsignbits(index_src) > (1 + min_shift_amount)) {
     newscale = oldscale * (1 << min_shift_amount) ; iff newscale <= 8
     newindex  = index_src << { 0, 1, 2, 3 } ;
  }

this could then result in newindex being able to be safely truncated to a vXi32 type

Thanks for the reply.

The pattern of the index is not always an ISD::SHL; it can also be an ISD::ADD. As we know, the ptr here is non-uniform since the base value is 0, and the index can take several patterns:

  1. index -> add(BASE, ISD::SHL)
    We reassign base = BASE and index = ISD::SHL.
    In this pattern, your suggestion can be applied, as per my understanding.
  2. index -> add(add(BASE, ISD::SHL), Constant)
    We reassign base = BASE and index = add(ISD::SHL, Constant).
    In this case, your suggestion for the scale does not apply directly.
    To make it applicable, I should change this to base = BASE + Constant and index = ISD::SHL (see the sketch below).

I will try this and update the patch.

Just one question: on changing the scale in the micro-coded instruction, there will be an extra multiplication operation to calculate the absolute address. Will this cost an extra operation, or is it handled in another pipe while decoding (in parallel)?
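
For pattern 2, a hedged sketch of the rebase described above; all names are illustrative, and it assumes the outer constant is a uniform splat and BASE is a splat of a scalar pointer:

    SDValue Inner = Gep.getOperand(0);       // add(BASE, ISD::SHL)
    SDValue OffSplat = Gep.getOperand(1);    // splatted constant byte offset
    SDValue BaseSplat = Inner.getOperand(0); // splat of the scalar base
    SDValue Shl = Inner.getOperand(1);

    if (ConstantSDNode *C = isConstOrConstSplat(OffSplat)) {
      SDValue ScalarBase = BaseSplat.getOperand(0); // scalar feeding the splat
      EVT PtrVT = ScalarBase.getValueType();
      // Fold the uniform offset into the scalar base: base = BASE + Constant.
      Base = DAG.getNode(ISD::ADD, DL, PtrVT, ScalarBase,
                         DAG.getConstant(C->getZExtValue(), DL, PtrVT));
      // The SHL alone becomes the index, so the scale transfer can apply.
      Index = Shl;
    }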

@rohitaggarwal007
Contributor Author

@RKSimon, I tried the suggestion, but the value of newscale comes out to be 16, failing the newscale <= 8 check. getValidMinimumShiftAmount() returns 4, and 1 << 4 gives 16, hence we bail out.
What does this mean, and how can we use a partial shift amount, if that is possible?

Thanks
Rohit Aggarwal

@RKSimon
Collaborator

RKSimon commented Mar 24, 2025

Here is my experiment: main...RKSimon:llvm-project:x86-gather-scatter-addr-scale

It still needs work: it stupidly moves 1 bit of shift left per combine instead of doing something smarter, plus we should probably only bother trying to do this if we think it will eventually allow truncation.

Anyway, take a look and see if there's anything you can use that will help.

@rohitaggarwal007
Contributor Author

rohitaggarwal007 commented Mar 25, 2025

Here is my experiment: main...RKSimon:llvm-project:x86-gather-scatter-addr-scale

It still needs work: it stupidly moves 1 bit of shift left per combine instead of doing something smarter, plus we should probably only bother trying to do this if we think it will eventually allow truncation.

Anyway, take a look and see if there's anything you can use that will help.

Hi @RKSimon

I commented out my changes and applied your patch. It works for the reduced test case attached in the commit.
But in a case with two GEPs, it fails:

define {<16 x float>, <16 x float>} @test_gather_16f32_3(ptr %x, ptr %arr, <16 x i1> %mask, <16 x float> %src0)  {
; CHECK-LABEL: test_gather_16f32_3:
  %wide.load = load <16 x i32>, ptr %arr, align 4
  %and = and <16 x i32> %wide.load, <i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911>
  %zext = zext <16 x i32> %and to <16 x i64>
  %ptrs1 = getelementptr inbounds %struct.pt, ptr %x, <16 x i64> %zext
  %res1 = call <16 x float> @llvm.masked.gather.v16f32.v16p0(<16 x ptr> %ptrs1, i32 4, <16 x i1> %mask, <16 x float> %src0)
  %ptrs = getelementptr inbounds %struct.pt, ptr %x, <16 x i64> %zext, i32 1
  %res = call <16 x float> @llvm.masked.gather.v16f32.v16p0(<16 x ptr> %ptrs, i32 4, <16 x i1> %mask, <16 x float> %src0)
  %pair1 = insertvalue {<16 x float>, <16 x float>} undef, <16 x float> %res1, 0
  %pair2 = insertvalue {<16 x float>, <16 x float>} %pair1, <16 x float> %res, 1
  ret {<16 x float>, <16 x float>} %pair2
  }

In the DAGCombiner, refineUniformBase is called to combine the gather again. For %res, when the DAG tries to simplify the base, separate the operand from the index, and update the base, it bails out:

  if (!isNullConstant(BasePtr) && !Index.hasOneUse())
    return false;

In our case the index has multiple uses, so we bail out.

Is it reasonable to relax this condition? The combiner creates new nodes when combining!

Additionally, when I debugged the logs, I found that instructions which had already been combined and processed were still referring to these uses. That looks like weird behavior.

@rohitaggarwal007
Contributor Author

Gentle reminder!
Waiting for a response.

@rohitaggarwal007
Contributor Author

Hi @RKSimon
Should I try to use some portion of the code from your suggestion and continue with my logic?

@RKSimon
Collaborator

RKSimon commented Apr 8, 2025

Please can you start with splitting off the new test coverage to a new PR with trunk's current codegen? It'll make it easier to compare against different approaches.

(sorry for slow reply).

@rohitaggarwal007
Contributor Author

@RKSimon Created a new PR with the test cases. I used the logic you provided and made some changes in refineUniformBase for test case #3: #134979

RKSimon added a commit that referenced this pull request Apr 9, 2025
AllinLeeYL pushed a commit to AllinLeeYL/llvm-project that referenced this pull request Apr 10, 2025
@rohitaggarwal007
Contributor Author

@RKSimon, I have updated the test case in the masked_gather_scatter.ll file.

var-const pushed a commit to ldionne/llvm-project that referenced this pull request Apr 17, 2025
@rohitaggarwal007
Contributor Author

#137813 and #139703 are alternative solutions to this problem.

@rohitaggarwal007 rohitaggarwal007 deleted the gatherBaseIndexfix branch May 20, 2025 09:00