Skip to content

Commit c7995a6

Browse files
authored
[AArch64] Disallow vscale x 1 partial reductions (llvm#125252)
We don't want to allow partial reductions that result in a vscale x 1 type, because the backend cannot lower that type.
1 parent 5df62bd commit c7995a6

File tree

2 files changed

+198
-88
lines changed

2 files changed

+198
-88
lines changed

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

Lines changed: 15 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -4692,13 +4692,24 @@ InstructionCost AArch64TTIImpl::getPartialReductionCost(
46924692
EVT InputEVT = EVT::getEVT(InputTypeA);
46934693
EVT AccumEVT = EVT::getEVT(AccumType);
46944694

4695-
if (VF.isScalable() && !ST->isSVEorStreamingSVEAvailable())
4696-
return Invalid;
4695+
unsigned VFMinValue = VF.getKnownMinValue();
4696+
4697+
if (VF.isScalable()) {
4698+
if (!ST->isSVEorStreamingSVEAvailable())
4699+
return Invalid;
4700+
4701+
// Don't accept a partial reduction if the scaled accumulator is vscale x 1,
4702+
// since we can't lower that type.
4703+
unsigned Scale =
4704+
AccumEVT.getScalarSizeInBits() / InputEVT.getScalarSizeInBits();
4705+
if (VFMinValue == Scale)
4706+
return Invalid;
4707+
}
46974708
if (VF.isFixed() && (!ST->isNeonAvailable() || !ST->hasDotProd()))
46984709
return Invalid;
46994710

47004711
if (InputEVT == MVT::i8) {
4701-
switch (VF.getKnownMinValue()) {
4712+
switch (VFMinValue) {
47024713
default:
47034714
return Invalid;
47044715
case 8:
@@ -4717,7 +4728,7 @@ InstructionCost AArch64TTIImpl::getPartialReductionCost(
47174728
} else if (InputEVT == MVT::i16) {
47184729
// FIXME: Allow i32 accumulator but increase cost, as we would extend
47194730
// it to i64.
4720-
if (VF.getKnownMinValue() != 8 || AccumEVT != MVT::i64)
4731+
if (VFMinValue != 8 || AccumEVT != MVT::i64)
47214732
return Invalid;
47224733
} else
47234734
return Invalid;

0 commit comments

Comments
 (0)