
Commit bf69484

[VectorCombine] Add type shrinking and zext propagation for fixed-width vector types (llvm#104606)
Check whether `binop(zext(value), other)` is possible and profitable to transform into `zext(binop(value, trunc(other)))`. When a CPU architecture has an illegal scalar type iX but the vector type <N x iX> is legal, scalar expressions may be extended to a legal type iY before vectorisation. This extension can underutilize vector lanes, since more lanes could be processed per instruction at the narrower type. Vectorisers may not always recognize such type-shrinking opportunities, and this patch aims to address that limitation.
1 parent 0f47e3a commit bf69484
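A minimal before/after sketch of the rewrite, with hypothetical value names (the shape mirrors the @test_and case in the new test below, where <16 x i8> is legal on AArch64 but the scalar computation was widened to i32):

Before:
  %x.ext = zext <16 x i8> %x to <16 x i32>
  %r = and <16 x i32> %x.ext, %y

After:
  %y.tr = trunc <16 x i32> %y to <16 x i8>
  %n = and <16 x i8> %x, %y.tr
  %r = zext <16 x i8> %n to <16 x i32>

Because `and` with a zero-extended operand cannot set bits above the original element width, the known-bits check in the patch allows the operation to run on the narrow lanes.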

File tree: 2 files changed (+170, −0 lines)

llvm/lib/Transforms/Vectorize/VectorCombine.cpp

Lines changed: 94 additions & 0 deletions
@@ -119,6 +119,7 @@ class VectorCombine {
   bool foldShuffleFromReductions(Instruction &I);
   bool foldCastFromReductions(Instruction &I);
   bool foldSelectShuffle(Instruction &I, bool FromReduction = false);
+  bool shrinkType(Instruction &I);

   void replaceValue(Value &Old, Value &New) {
     Old.replaceAllUsesWith(&New);
@@ -2493,6 +2494,96 @@ bool VectorCombine::foldSelectShuffle(Instruction &I, bool FromReduction) {
   return true;
 }

+/// Check if the instruction depends on a ZExt that can be moved after the
+/// instruction, and move the ZExt if it is profitable. For example:
+///     logic(zext(x),y) -> zext(logic(x,trunc(y)))
+///     lshr(zext(x),y) -> zext(lshr(x,trunc(y)))
+/// The cost model takes into account whether zext(x) has other users and
+/// whether the ZExt can be propagated through them too.
+bool VectorCombine::shrinkType(llvm::Instruction &I) {
+  Value *ZExted, *OtherOperand;
+  if (!match(&I, m_c_BitwiseLogic(m_ZExt(m_Value(ZExted)),
+                                  m_Value(OtherOperand))) &&
+      !match(&I, m_LShr(m_ZExt(m_Value(ZExted)), m_Value(OtherOperand))))
+    return false;
+
+  Value *ZExtOperand = I.getOperand(I.getOperand(0) == OtherOperand ? 1 : 0);
+
+  auto *BigTy = cast<FixedVectorType>(I.getType());
+  auto *SmallTy = cast<FixedVectorType>(ZExted->getType());
+  unsigned BW = SmallTy->getElementType()->getPrimitiveSizeInBits();
+
+  // Check that the expression overall uses at most the same number of bits as
+  // ZExted.
+  KnownBits KB = computeKnownBits(&I, *DL);
+  if (KB.countMaxActiveBits() > BW)
+    return false;
+
+  // Calculate the costs of leaving the current IR as it is and of moving the
+  // ZExt later, along with adding truncates if needed.
+  TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+  InstructionCost ZExtCost = TTI.getCastInstrCost(
+      Instruction::ZExt, BigTy, SmallTy,
+      TargetTransformInfo::CastContextHint::None, CostKind);
+  InstructionCost CurrentCost = ZExtCost;
+  InstructionCost ShrinkCost = 0;
+
+  // Calculate the total cost and check that we can propagate through all ZExt
+  // users.
+  for (User *U : ZExtOperand->users()) {
+    auto *UI = cast<Instruction>(U);
+    if (UI == &I) {
+      CurrentCost +=
+          TTI.getArithmeticInstrCost(UI->getOpcode(), BigTy, CostKind);
+      ShrinkCost +=
+          TTI.getArithmeticInstrCost(UI->getOpcode(), SmallTy, CostKind);
+      ShrinkCost += ZExtCost;
+      continue;
+    }
+
+    if (!Instruction::isBinaryOp(UI->getOpcode()))
+      return false;
+
+    // Check if we can propagate the ZExt through its other users.
+    KB = computeKnownBits(UI, *DL);
+    if (KB.countMaxActiveBits() > BW)
+      return false;
+
+    CurrentCost += TTI.getArithmeticInstrCost(UI->getOpcode(), BigTy, CostKind);
+    ShrinkCost +=
+        TTI.getArithmeticInstrCost(UI->getOpcode(), SmallTy, CostKind);
+    ShrinkCost += ZExtCost;
+  }
+
+  // If the other instruction operand is not a constant, we'll need to
+  // generate a truncate instruction, so we have to adjust the cost.
+  if (!isa<Constant>(OtherOperand))
+    ShrinkCost += TTI.getCastInstrCost(
+        Instruction::Trunc, SmallTy, BigTy,
+        TargetTransformInfo::CastContextHint::None, CostKind);
+
+  // If the cost of shrinking types and leaving the IR is the same, we'll lean
+  // towards modifying the IR because shrinking opens opportunities for other
+  // shrinking optimisations.
+  if (ShrinkCost > CurrentCost)
+    return false;
+
+  Value *Op0 = ZExted;
+  if (auto *OI = dyn_cast<Instruction>(OtherOperand))
+    Builder.SetInsertPoint(OI->getNextNode());
+  Value *Op1 = Builder.CreateTrunc(OtherOperand, SmallTy);
+  Builder.SetInsertPoint(&I);
+  // Keep the order of operands the same.
+  if (I.getOperand(0) == OtherOperand)
+    std::swap(Op0, Op1);
+  Value *NewBinOp =
+      Builder.CreateBinOp((Instruction::BinaryOps)I.getOpcode(), Op0, Op1);
+  cast<Instruction>(NewBinOp)->copyIRFlags(&I);
+  cast<Instruction>(NewBinOp)->copyMetadata(I);
+  Value *NewZExtr = Builder.CreateZExt(NewBinOp, BigTy);
+  replaceValue(I, *NewZExtr);
+  return true;
+}
+
 /// This is the entry point for all transforms. Pass manager differences are
 /// handled in the callers of this function.
 bool VectorCombine::run() {
@@ -2560,6 +2651,9 @@ bool VectorCombine::run() {
     case Instruction::BitCast:
       MadeChange |= foldBitcastShuffle(I);
       break;
+    default:
+      MadeChange |= shrinkType(I);
+      break;
     }
   } else {
     switch (Opcode) {
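The lshr pattern listed in the doc comment above shrinks the same way. A sketch with hypothetical value names and a constant splat shift, shown with 4 lanes for brevity (the @multiuse test below exercises the 16-lane form):

Before:
  %x.ext = zext <4 x i8> %x to <4 x i32>
  %r = lshr <4 x i32> %x.ext, <i32 4, i32 4, i32 4, i32 4>

After:
  %n = lshr <4 x i8> %x, <i8 4, i8 4, i8 4, i8 4>
  %r = zext <4 x i8> %n to <4 x i32>

Since the shift amount is a constant, no extra trunc instruction is emitted for the other operand, which is the case the isa<Constant> check in the cost model accounts for.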
Lines changed: 76 additions & 0 deletions
@@ -0,0 +1,76 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -passes=vector-combine -S %s | FileCheck %s
+
+target triple = "aarch64"
+
+define i32 @test_and(<16 x i32> %a, ptr %b) {
+; CHECK-LABEL: @test_and(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[B:%.*]], align 1
+; CHECK-NEXT: [[TMP0:%.*]] = trunc <16 x i32> [[A:%.*]] to <16 x i8>
+; CHECK-NEXT: [[TMP1:%.*]] = and <16 x i8> [[WIDE_LOAD]], [[TMP0]]
+; CHECK-NEXT: [[TMP2:%.*]] = zext <16 x i8> [[TMP1]] to <16 x i32>
+; CHECK-NEXT: [[TMP3:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP2]])
+; CHECK-NEXT: ret i32 [[TMP3]]
+;
+entry:
+  %wide.load = load <16 x i8>, ptr %b, align 1
+  %0 = zext <16 x i8> %wide.load to <16 x i32>
+  %1 = and <16 x i32> %0, %a
+  %2 = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %1)
+  ret i32 %2
+}
+
+define i32 @test_mask_or(<16 x i32> %a, ptr %b) {
+; CHECK-LABEL: @test_mask_or(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[B:%.*]], align 1
+; CHECK-NEXT: [[A_MASKED:%.*]] = and <16 x i32> [[A:%.*]], <i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16>
+; CHECK-NEXT: [[TMP0:%.*]] = trunc <16 x i32> [[A_MASKED]] to <16 x i8>
+; CHECK-NEXT: [[TMP1:%.*]] = or <16 x i8> [[WIDE_LOAD]], [[TMP0]]
+; CHECK-NEXT: [[TMP2:%.*]] = zext <16 x i8> [[TMP1]] to <16 x i32>
+; CHECK-NEXT: [[TMP3:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP2]])
+; CHECK-NEXT: ret i32 [[TMP3]]
+;
+entry:
+  %wide.load = load <16 x i8>, ptr %b, align 1
+  %a.masked = and <16 x i32> %a, <i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16>
+  %0 = zext <16 x i8> %wide.load to <16 x i32>
+  %1 = or <16 x i32> %0, %a.masked
+  %2 = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %1)
+  ret i32 %2
+}
+
+define i32 @multiuse(<16 x i32> %u, <16 x i32> %v, ptr %b) {
+; CHECK-LABEL: @multiuse(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[U_MASKED:%.*]] = and <16 x i32> [[U:%.*]], <i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255>
+; CHECK-NEXT: [[TMP0:%.*]] = trunc <16 x i32> [[U_MASKED]] to <16 x i8>
+; CHECK-NEXT: [[V_MASKED:%.*]] = and <16 x i32> [[V:%.*]], <i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255>
+; CHECK-NEXT: [[TMP1:%.*]] = trunc <16 x i32> [[V_MASKED]] to <16 x i8>
+; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[B:%.*]], align 1
+; CHECK-NEXT: [[TMP2:%.*]] = lshr <16 x i8> [[WIDE_LOAD]], <i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4>
+; CHECK-NEXT: [[TMP3:%.*]] = or <16 x i8> [[TMP2]], [[TMP1]]
+; CHECK-NEXT: [[TMP4:%.*]] = zext <16 x i8> [[TMP3]] to <16 x i32>
+; CHECK-NEXT: [[TMP5:%.*]] = and <16 x i8> [[WIDE_LOAD]], <i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15>
+; CHECK-NEXT: [[TMP6:%.*]] = or <16 x i8> [[TMP5]], [[TMP0]]
+; CHECK-NEXT: [[TMP7:%.*]] = zext <16 x i8> [[TMP6]] to <16 x i32>
+; CHECK-NEXT: [[TMP8:%.*]] = add nuw nsw <16 x i32> [[TMP4]], [[TMP7]]
+; CHECK-NEXT: [[TMP9:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP8]])
+; CHECK-NEXT: ret i32 [[TMP9]]
+;
+entry:
+  %u.masked = and <16 x i32> %u, <i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255>
+  %v.masked = and <16 x i32> %v, <i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255>
+  %wide.load = load <16 x i8>, ptr %b, align 1
+  %0 = zext <16 x i8> %wide.load to <16 x i32>
+  %1 = lshr <16 x i32> %0, <i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4>
+  %2 = or <16 x i32> %1, %v.masked
+  %3 = and <16 x i32> %0, <i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15>
+  %4 = or <16 x i32> %3, %u.masked
+  %5 = add nuw nsw <16 x i32> %2, %4
+  %6 = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %5)
+  ret i32 %6
+}
+
+declare i32 @llvm.vector.reduce.add.v16i32(<16 x i32>)
