@@ -468,6 +468,12 @@ Value *VPInstruction::generate(VPTransformState &State) {
468
468
Value *A = State.get (getOperand (0 ));
469
469
return Builder.CreateNot (A, Name);
470
470
}
471
+ case Instruction::ExtractElement: {
472
+ assert (State.VF .isVector () && " Only extract elements from vectors" );
473
+ Value *Vec = State.get (getOperand (0 ));
474
+ Value *Idx = State.get (getOperand (1 ), /* IsScalar=*/ true );
475
+ return Builder.CreateExtractElement (Vec, Idx, Name);
476
+ }
471
477
case Instruction::ICmp: {
472
478
bool OnlyFirstLaneUsed = vputils::onlyFirstLaneUsed (this );
473
479
Value *A = State.get (getOperand (0 ), OnlyFirstLaneUsed);
@@ -723,12 +729,10 @@ Value *VPInstruction::generate(VPTransformState &State) {
723
729
Value *A = State.get (getOperand (0 ));
724
730
return Builder.CreateOrReduce (A);
725
731
}
726
- case VPInstruction::ExtractFirstActive: {
727
- Value *Vec = State.get (getOperand (0 ));
728
- Value *Mask = State.get (getOperand (1 ));
729
- Value *Ctz = Builder.CreateCountTrailingZeroElems (
730
- Builder.getInt64Ty (), Mask, true , " first.active.lane" );
731
- return Builder.CreateExtractElement (Vec, Ctz, " early.exit.value" );
732
+ case VPInstruction::FirstActiveLane: {
733
+ Value *Mask = State.get (getOperand (0 ));
734
+ return Builder.CreateCountTrailingZeroElems (Builder.getInt64Ty (), Mask,
735
+ true , Name);
732
736
}
733
737
default :
734
738
llvm_unreachable (" Unsupported opcode for instruction" );
@@ -755,22 +759,24 @@ InstructionCost VPInstruction::computeCost(ElementCount VF,
755
759
}
756
760
757
761
switch (getOpcode ()) {
762
+ case Instruction::ExtractElement: {
763
+ // Add on the cost of extracting the element.
764
+ auto *VecTy = toVectorTy (Ctx.Types .inferScalarType (getOperand (0 )), VF);
765
+ return Ctx.TTI .getVectorInstrCost (Instruction::ExtractElement, VecTy,
766
+ Ctx.CostKind );
767
+ }
758
768
case VPInstruction::AnyOf: {
759
769
auto *VecTy = toVectorTy (Ctx.Types .inferScalarType (this ), VF);
760
770
return Ctx.TTI .getArithmeticReductionCost (
761
771
Instruction::Or, cast<VectorType>(VecTy), std::nullopt, Ctx.CostKind );
762
772
}
763
- case VPInstruction::ExtractFirstActive : {
773
+ case VPInstruction::FirstActiveLane : {
764
774
// Calculate the cost of determining the lane index.
765
- auto *PredTy = toVectorTy (Ctx.Types .inferScalarType (getOperand (1 )), VF);
775
+ auto *PredTy = toVectorTy (Ctx.Types .inferScalarType (getOperand (0 )), VF);
766
776
IntrinsicCostAttributes Attrs (Intrinsic::experimental_cttz_elts,
767
777
Type::getInt64Ty (Ctx.LLVMCtx ),
768
778
{PredTy, Type::getInt1Ty (Ctx.LLVMCtx )});
769
- InstructionCost Cost = Ctx.TTI .getIntrinsicInstrCost (Attrs, Ctx.CostKind );
770
- // Add on the cost of extracting the element.
771
- auto *VecTy = toVectorTy (Ctx.Types .inferScalarType (getOperand (0 )), VF);
772
- return Cost + Ctx.TTI .getVectorInstrCost (Instruction::ExtractElement, VecTy,
773
- Ctx.CostKind );
779
+ return Ctx.TTI .getIntrinsicInstrCost (Attrs, Ctx.CostKind );
774
780
}
775
781
case VPInstruction::FirstOrderRecurrenceSplice: {
776
782
assert (VF.isVector () && " Scalar FirstOrderRecurrenceSplice?" );
@@ -793,7 +799,8 @@ InstructionCost VPInstruction::computeCost(ElementCount VF,
793
799
794
800
bool VPInstruction::isVectorToScalar () const {
795
801
return getOpcode () == VPInstruction::ExtractFromEnd ||
796
- getOpcode () == VPInstruction::ExtractFirstActive ||
802
+ getOpcode () == Instruction::ExtractElement ||
803
+ getOpcode () == VPInstruction::FirstActiveLane ||
797
804
getOpcode () == VPInstruction::ComputeReductionResult ||
798
805
getOpcode () == VPInstruction::AnyOf;
799
806
}
@@ -853,13 +860,14 @@ bool VPInstruction::opcodeMayReadOrWriteFromMemory() const {
853
860
if (Instruction::isBinaryOp (getOpcode ()))
854
861
return false ;
855
862
switch (getOpcode ()) {
863
+ case Instruction::ExtractElement:
856
864
case Instruction::ICmp:
857
865
case Instruction::Select:
858
866
case VPInstruction::AnyOf:
859
867
case VPInstruction::CalculateTripCountMinusVF:
860
868
case VPInstruction::CanonicalIVIncrementForPart:
861
869
case VPInstruction::ExtractFromEnd:
862
- case VPInstruction::ExtractFirstActive :
870
+ case VPInstruction::FirstActiveLane :
863
871
case VPInstruction::FirstOrderRecurrenceSplice:
864
872
case VPInstruction::LogicalAnd:
865
873
case VPInstruction::Not:
@@ -878,6 +886,8 @@ bool VPInstruction::onlyFirstLaneUsed(const VPValue *Op) const {
878
886
switch (getOpcode ()) {
879
887
default :
880
888
return false ;
889
+ case Instruction::ExtractElement:
890
+ return Op == getOperand (1 );
881
891
case Instruction::PHI:
882
892
return true ;
883
893
case Instruction::ICmp:
@@ -970,7 +980,6 @@ void VPInstruction::print(raw_ostream &O, const Twine &Indent,
970
980
case VPInstruction::Broadcast:
971
981
O << " broadcast" ;
972
982
break ;
973
-
974
983
case VPInstruction::ExtractFromEnd:
975
984
O << " extract-from-end" ;
976
985
break ;
@@ -986,8 +995,8 @@ void VPInstruction::print(raw_ostream &O, const Twine &Indent,
986
995
case VPInstruction::AnyOf:
987
996
O << " any-of" ;
988
997
break ;
989
- case VPInstruction::ExtractFirstActive :
990
- O << " extract- first-active" ;
998
+ case VPInstruction::FirstActiveLane :
999
+ O << " first-active-lane " ;
991
1000
break ;
992
1001
default :
993
1002
O << Instruction::getOpcodeName (getOpcode ());
0 commit comments