Skip to content

Commit 8ed000d

Browse files
committed
[LV][NFC] Refactor code for extracting first active element
Refactor the code to extract the first active element of a vector in the early exit block, in preparation for PR #130766. I've replaced the VPInstruction::ExtractFirstActive nodes with a combination of a new VPInstruction::FirstActiveLane node and a Instruction::ExtractElement node.
1 parent 9b83ffb commit 8ed000d

File tree

5 files changed

+46
-30
lines changed

5 files changed

+46
-30
lines changed

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -877,9 +877,8 @@ class VPInstruction : public VPRecipeWithIRFlags,
877877
// Returns a scalar boolean value, which is true if any lane of its (only
878878
// boolean) vector operand is true.
879879
AnyOf,
880-
// Extracts the first active lane of a vector, where the first operand is
881-
// the predicate, and the second operand is the vector to extract.
882-
ExtractFirstActive,
880+
// Calculates the first active lane index of the vector predicate operand.
881+
FirstActiveLane,
883882
};
884883

885884
private:

llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPInstruction *R) {
5050
return SetResultTyFromOp();
5151

5252
switch (Opcode) {
53+
case Instruction::ExtractElement:
54+
return inferScalarType(R->getOperand(0));
5355
case Instruction::Select: {
5456
Type *ResTy = inferScalarType(R->getOperand(1));
5557
VPValue *OtherV = R->getOperand(2);
@@ -82,7 +84,8 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPInstruction *R) {
8284
case VPInstruction::CanonicalIVIncrementForPart:
8385
case VPInstruction::AnyOf:
8486
return SetResultTyFromOp();
85-
case VPInstruction::ExtractFirstActive:
87+
case VPInstruction::FirstActiveLane:
88+
return Type::getIntNTy(Ctx, 64);
8689
case VPInstruction::ExtractFromEnd: {
8790
Type *BaseTy = inferScalarType(R->getOperand(0));
8891
if (auto *VecTy = dyn_cast<VectorType>(BaseTy))

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -468,6 +468,11 @@ Value *VPInstruction::generate(VPTransformState &State) {
468468
Value *A = State.get(getOperand(0));
469469
return Builder.CreateNot(A, Name);
470470
}
471+
case Instruction::ExtractElement: {
472+
Value *Vec = State.get(getOperand(0));
473+
Value *Idx = State.get(getOperand(1), true);
474+
return Builder.CreateExtractElement(Vec, Idx, Name);
475+
}
471476
case Instruction::ICmp: {
472477
bool OnlyFirstLaneUsed = vputils::onlyFirstLaneUsed(this);
473478
Value *A = State.get(getOperand(0), OnlyFirstLaneUsed);
@@ -723,12 +728,10 @@ Value *VPInstruction::generate(VPTransformState &State) {
723728
Value *A = State.get(getOperand(0));
724729
return Builder.CreateOrReduce(A);
725730
}
726-
case VPInstruction::ExtractFirstActive: {
727-
Value *Vec = State.get(getOperand(0));
728-
Value *Mask = State.get(getOperand(1));
729-
Value *Ctz = Builder.CreateCountTrailingZeroElems(
730-
Builder.getInt64Ty(), Mask, true, "first.active.lane");
731-
return Builder.CreateExtractElement(Vec, Ctz, "early.exit.value");
731+
case VPInstruction::FirstActiveLane: {
732+
Value *Mask = State.get(getOperand(0));
733+
return Builder.CreateCountTrailingZeroElems(Builder.getInt64Ty(), Mask,
734+
true, Name);
732735
}
733736
default:
734737
llvm_unreachable("Unsupported opcode for instruction");
@@ -755,22 +758,24 @@ InstructionCost VPInstruction::computeCost(ElementCount VF,
755758
}
756759

757760
switch (getOpcode()) {
761+
case Instruction::ExtractElement: {
762+
// Add on the cost of extracting the element.
763+
auto *VecTy = toVectorTy(Ctx.Types.inferScalarType(getOperand(0)), VF);
764+
return Ctx.TTI.getVectorInstrCost(Instruction::ExtractElement, VecTy,
765+
Ctx.CostKind);
766+
}
758767
case VPInstruction::AnyOf: {
759768
auto *VecTy = toVectorTy(Ctx.Types.inferScalarType(this), VF);
760769
return Ctx.TTI.getArithmeticReductionCost(
761770
Instruction::Or, cast<VectorType>(VecTy), std::nullopt, Ctx.CostKind);
762771
}
763-
case VPInstruction::ExtractFirstActive: {
772+
case VPInstruction::FirstActiveLane: {
764773
// Calculate the cost of determining the lane index.
765-
auto *PredTy = toVectorTy(Ctx.Types.inferScalarType(getOperand(1)), VF);
774+
auto *PredTy = toVectorTy(Ctx.Types.inferScalarType(getOperand(0)), VF);
766775
IntrinsicCostAttributes Attrs(Intrinsic::experimental_cttz_elts,
767776
Type::getInt64Ty(Ctx.LLVMCtx),
768777
{PredTy, Type::getInt1Ty(Ctx.LLVMCtx)});
769-
InstructionCost Cost = Ctx.TTI.getIntrinsicInstrCost(Attrs, Ctx.CostKind);
770-
// Add on the cost of extracting the element.
771-
auto *VecTy = toVectorTy(Ctx.Types.inferScalarType(getOperand(0)), VF);
772-
return Cost + Ctx.TTI.getVectorInstrCost(Instruction::ExtractElement, VecTy,
773-
Ctx.CostKind);
778+
return Ctx.TTI.getIntrinsicInstrCost(Attrs, Ctx.CostKind);
774779
}
775780
case VPInstruction::FirstOrderRecurrenceSplice: {
776781
assert(VF.isVector() && "Scalar FirstOrderRecurrenceSplice?");
@@ -793,7 +798,8 @@ InstructionCost VPInstruction::computeCost(ElementCount VF,
793798

794799
bool VPInstruction::isVectorToScalar() const {
795800
return getOpcode() == VPInstruction::ExtractFromEnd ||
796-
getOpcode() == VPInstruction::ExtractFirstActive ||
801+
getOpcode() == Instruction::ExtractElement ||
802+
getOpcode() == VPInstruction::FirstActiveLane ||
797803
getOpcode() == VPInstruction::ComputeReductionResult ||
798804
getOpcode() == VPInstruction::AnyOf;
799805
}
@@ -853,13 +859,14 @@ bool VPInstruction::opcodeMayReadOrWriteFromMemory() const {
853859
if (Instruction::isBinaryOp(getOpcode()))
854860
return false;
855861
switch (getOpcode()) {
862+
case Instruction::ExtractElement:
856863
case Instruction::ICmp:
857864
case Instruction::Select:
858865
case VPInstruction::AnyOf:
859866
case VPInstruction::CalculateTripCountMinusVF:
860867
case VPInstruction::CanonicalIVIncrementForPart:
861868
case VPInstruction::ExtractFromEnd:
862-
case VPInstruction::ExtractFirstActive:
869+
case VPInstruction::FirstActiveLane:
863870
case VPInstruction::FirstOrderRecurrenceSplice:
864871
case VPInstruction::LogicalAnd:
865872
case VPInstruction::Not:
@@ -970,7 +977,6 @@ void VPInstruction::print(raw_ostream &O, const Twine &Indent,
970977
case VPInstruction::Broadcast:
971978
O << "broadcast";
972979
break;
973-
974980
case VPInstruction::ExtractFromEnd:
975981
O << "extract-from-end";
976982
break;
@@ -986,8 +992,8 @@ void VPInstruction::print(raw_ostream &O, const Twine &Indent,
986992
case VPInstruction::AnyOf:
987993
O << "any-of";
988994
break;
989-
case VPInstruction::ExtractFirstActive:
990-
O << "extract-first-active";
995+
case VPInstruction::FirstActiveLane:
996+
O << "first-active-lane";
991997
break;
992998
default:
993999
O << Instruction::getOpcodeName(getOpcode());

llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2158,10 +2158,14 @@ void VPlanTransforms::handleUncountableEarlyExit(
21582158
ExitIRI->extractLastLaneOfOperand(MiddleBuilder);
21592159
}
21602160
// Add the incoming value from the early exit.
2161-
if (!IncomingFromEarlyExit->isLiveIn())
2162-
IncomingFromEarlyExit =
2163-
EarlyExitB.createNaryOp(VPInstruction::ExtractFirstActive,
2164-
{IncomingFromEarlyExit, EarlyExitTakenCond});
2161+
if (!IncomingFromEarlyExit->isLiveIn()) {
2162+
VPValue *FirstActiveLane = EarlyExitB.createNaryOp(
2163+
VPInstruction::FirstActiveLane, {EarlyExitTakenCond}, nullptr,
2164+
"first.active.lane");
2165+
IncomingFromEarlyExit = EarlyExitB.createNaryOp(
2166+
Instruction::ExtractElement, {IncomingFromEarlyExit, FirstActiveLane},
2167+
nullptr, "early.exit.value");
2168+
}
21652169
ExitIRI->addOperand(IncomingFromEarlyExit);
21662170
}
21672171
MiddleBuilder.createNaryOp(VPInstruction::BranchOnCond, {IsEarlyExitTaken});

llvm/test/Transforms/LoopVectorize/AArch64/early_exit_costs.ll

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,10 @@ define i64 @same_exit_block_pre_inc_use1_sve() #1 {
1111
; CHECK-LABEL: LV: Checking a loop in 'same_exit_block_pre_inc_use1_sve'
1212
; CHECK: LV: Selecting VF: vscale x 16
1313
; CHECK: Calculating cost of work in exit block vector.early.exit
14-
; CHECK-NEXT: Cost of 6 for VF vscale x 16: EMIT vp<{{.*}}> = extract-first-active
15-
; CHECK-NEXT: Cost of 6 for VF vscale x 16: EMIT vp<{{.*}}> = extract-first-active
14+
; CHECK-NEXT: Cost of 4 for VF vscale x 16: EMIT vp<{{.*}}> = first-active-lane vp<{{.*}}>
15+
; CHECK-NEXT: Cost of 2 for VF vscale x 16: EMIT vp<{{.*}}> = extractelement ir<{{.*}}>, vp<{{.*}}>
16+
; CHECK-NEXT: Cost of 4 for VF vscale x 16: EMIT vp<{{.*}}>.1 = first-active-lane vp<{{.*}}>
17+
; CHECK-NEXT: Cost of 2 for VF vscale x 16: EMIT vp<{{.*}}>.1 = extractelement ir<{{.*}}>, vp<%first.active.lane>.1
1618
; CHECK: LV: Minimum required TC for runtime checks to be profitable:32
1719
entry:
1820
%p1 = alloca [1024 x i8]
@@ -48,8 +50,10 @@ define i64 @same_exit_block_pre_inc_use1_nosve() {
4850
; CHECK-LABEL: LV: Checking a loop in 'same_exit_block_pre_inc_use1_nosve'
4951
; CHECK: LV: Selecting VF: 16
5052
; CHECK: Calculating cost of work in exit block vector.early.exit
51-
; CHECK-NEXT: Cost of 50 for VF 16: EMIT vp<{{.*}}> = extract-first-active
52-
; CHECK-NEXT: Cost of 50 for VF 16: EMIT vp<{{.*}}> = extract-first-active
53+
; CHECK-NEXT: Cost of 48 for VF 16: EMIT vp<{{.*}}> = first-active-lane vp<{{.*}}>
54+
; CHECK-NEXT: Cost of 2 for VF 16: EMIT vp<{{.*}}> = extractelement ir<{{.*}}>, vp<{{.*}}>
55+
; CHECK-NEXT: Cost of 48 for VF 16: EMIT vp<{{.*}}>.1 = first-active-lane vp<{{.*}}>
56+
; CHECK-NEXT: Cost of 2 for VF 16: EMIT vp<{{.*}}>.1 = extractelement ir<{{.*}}>, vp<%first.active.lane>.1
5357
; CHECK: LV: Minimum required TC for runtime checks to be profitable:176
5458
; CHECK-NEXT: LV: Vectorization is not beneficial: expected trip count < minimum profitable VF (64 < 176)
5559
; CHECK-NEXT: LV: Too many memory checks needed.

0 commit comments

Comments
 (0)