Skip to content

Commit 425d1aa

Browse files
lukel97tstellar
authored andcommitted
[RISCV] Handle scalarized reductions in getArithmeticReductionCost
This fixes a crash reported at llvm#114250 (comment) If the vector type isn't legal at all, e.g. bfloat with +zvfbfmin, then the legalized type will be scalarized. So use getScalarType() instead of getVectorElement() when checking for f16/bf16. (cherry picked from commit 053451c)
1 parent 2d7ad98 commit 425d1aa

File tree

2 files changed

+137
-35
lines changed

2 files changed

+137
-35
lines changed

llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1658,9 +1658,8 @@ RISCVTTIImpl::getArithmeticReductionCost(unsigned Opcode, VectorType *Ty,
16581658
break;
16591659
case ISD::FADD:
16601660
// We can't promote f16/bf16 fadd reductions.
1661-
if ((LT.second.getVectorElementType() == MVT::f16 &&
1662-
!ST->hasVInstructionsF16()) ||
1663-
LT.second.getVectorElementType() == MVT::bf16)
1661+
if ((LT.second.getScalarType() == MVT::f16 && !ST->hasVInstructionsF16()) ||
1662+
LT.second.getScalarType() == MVT::bf16)
16641663
return BaseT::getArithmeticReductionCost(Opcode, Ty, FMF, CostKind);
16651664
if (TTI::requiresOrderedReduction(FMF)) {
16661665
Opcodes.push_back(RISCV::VFMV_S_F);

0 commit comments

Comments
 (0)