Commit 0afa95e

[VectorCombine] Add type shrinking and zext propagation for fixed-width vector types
Check whether it is possible and profitable to transform binop(zext(value), other) into zext(binop(value, trunc(other))). When a CPU architecture has an illegal scalar type iX but the vector type <N x iX> is legal, scalar expressions may be extended to a legal type iY before vectorisation. This extension can leave vector lanes underutilised, since more lanes could be processed per instruction at the narrower type. Vectorisers do not always recognise such type-shrinking opportunities, and this patch aims to address that limitation.
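
For illustration, a minimal sketch of the intended rewrite on fixed-width vectors (the function and value names here are illustrative, not taken from the patch; the @test_and case in the new test file exercises the same shape):

  ; Before: the zext forces the 'and' into the wide element type.
  define <16 x i32> @shape_before(<16 x i8> %x, <16 x i32> %y) {
    %wide = zext <16 x i8> %x to <16 x i32>
    %res = and <16 x i32> %wide, %y
    ret <16 x i32> %res
  }

  ; After (the form the transform aims to produce): truncate the other
  ; operand, do the 'and' in the narrow type, then extend the result.
  define <16 x i32> @shape_after(<16 x i8> %x, <16 x i32> %y) {
    %narrow.y = trunc <16 x i32> %y to <16 x i8>
    %narrow = and <16 x i8> %x, %narrow.y
    %res = zext <16 x i8> %narrow to <16 x i32>
    ret <16 x i32> %res
  }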
1 parent c4bf949 commit 0afa95e

2 files changed: +180 -0 lines changed

llvm/lib/Transforms/Vectorize/VectorCombine.cpp

Lines changed: 104 additions & 0 deletions
@@ -119,6 +119,7 @@ class VectorCombine {
   bool foldShuffleFromReductions(Instruction &I);
   bool foldCastFromReductions(Instruction &I);
   bool foldSelectShuffle(Instruction &I, bool FromReduction = false);
+  bool shrinkType(Instruction &I);

   void replaceValue(Value &Old, Value &New) {
     Old.replaceAllUsesWith(&New);
@@ -2493,6 +2494,106 @@ bool VectorCombine::foldSelectShuffle(Instruction &I, bool FromReduction) {
   return true;
 }

+/// Check if instruction depends on ZExt and this ZExt can be moved after the
+/// instruction. Move ZExt if it is profitable
+bool VectorCombine::shrinkType(llvm::Instruction &I) {
+  Value *ZExted, *OtherOperand;
+  if (match(&I, m_c_BinOp(m_ZExt(m_Value(ZExted)), m_Value(OtherOperand)))) {
+    if (I.getOpcode() != Instruction::And && I.getOpcode() != Instruction::Or &&
+        I.getOpcode() != Instruction::Xor && I.getOpcode() != Instruction::LShr)
+      return false;
+
+    // In case of LShr extraction, ZExtOperand should be applied to the first
+    // operand
+    if (I.getOpcode() == Instruction::LShr && I.getOperand(1) != OtherOperand)
+      return false;
+
+    Instruction *ZExtOperand = cast<Instruction>(
+        I.getOperand(I.getOperand(0) == OtherOperand ? 1 : 0));
+
+    auto *BigTy = cast<FixedVectorType>(I.getType());
+    auto *SmallTy = cast<FixedVectorType>(ZExted->getType());
+    auto BW = SmallTy->getElementType()->getPrimitiveSizeInBits();
+
+    // Check that the expression overall uses at most the same number of bits as
+    // ZExted
+    auto KB = computeKnownBits(&I, *DL);
+    auto IBW = KB.getBitWidth() - KB.Zero.countLeadingOnes();
+    if (IBW > BW)
+      return false;
+
+    bool HasUNZExtableUser = false;
+
+    // Calculate costs of leaving current IR as it is and moving ZExt operation
+    // later, along with adding truncates if needed
+    InstructionCost ZExtCost = TTI.getCastInstrCost(
+        Instruction::ZExt, BigTy, SmallTy,
+        TargetTransformInfo::CastContextHint::None, TTI::TCK_RecipThroughput);
+    InstructionCost CurrentCost = ZExtCost;
+    InstructionCost ShrinkCost = 0;
+
+    for (User *U : ZExtOperand->users()) {
+      auto *UI = cast<Instruction>(U);
+      if (UI == &I) {
+        CurrentCost += TTI.getArithmeticInstrCost(UI->getOpcode(), BigTy);
+        ShrinkCost += TTI.getArithmeticInstrCost(UI->getOpcode(), SmallTy);
+        ShrinkCost += ZExtCost;
+        continue;
+      }
+
+      if (!Instruction::isBinaryOp(UI->getOpcode())) {
+        HasUNZExtableUser = true;
+        continue;
+      }
+
+      // Check if we can propagate ZExt through its other users
+      auto KB = computeKnownBits(UI, *DL);
+      auto UBW = KB.getBitWidth() - KB.Zero.countLeadingOnes();
+      if (UBW <= BW) {
+        CurrentCost += TTI.getArithmeticInstrCost(UI->getOpcode(), BigTy);
+        ShrinkCost += TTI.getArithmeticInstrCost(UI->getOpcode(), SmallTy);
+        ShrinkCost += ZExtCost;
+      } else {
+        HasUNZExtableUser = true;
+      }
+    }
+
+    // ZExt can't remove, add extra cost
+    if (HasUNZExtableUser)
+      ShrinkCost += ZExtCost;
+
+    // If the other instruction operand is not a constant, we'll need to
+    // generate a truncate instruction. So we have to adjust cost
+    if (!isa<Constant>(OtherOperand))
+      ShrinkCost += TTI.getCastInstrCost(
+          Instruction::Trunc, SmallTy, BigTy,
+          TargetTransformInfo::CastContextHint::None, TTI::TCK_RecipThroughput);
+
+    // If the cost of shrinking types and leaving the IR is the same, we'll lean
+    // towards modifying the IR because shrinking opens opportunities for other
+    // shrinking optimisations.
+    if (ShrinkCost > CurrentCost)
+      return false;
+
+    auto *Op0 = ZExted;
+    if (auto *OI = dyn_cast<Instruction>(OtherOperand))
+      Builder.SetInsertPoint(OI->getNextNode());
+    auto *Op1 = Builder.CreateTrunc(OtherOperand, SmallTy);
+    Builder.SetInsertPoint(&I);
+    // Keep the order of operands the same
+    if (I.getOperand(0) == OtherOperand)
+      std::swap(Op0, Op1);
+    auto *NewBinOp =
+        Builder.CreateBinOp((Instruction::BinaryOps)I.getOpcode(), Op0, Op1);
+    cast<Instruction>(NewBinOp)->copyIRFlags(&I);
+    cast<Instruction>(NewBinOp)->copyMetadata(I);
+    auto *NewZExtr = Builder.CreateZExt(NewBinOp, BigTy);
+    replaceValue(I, *NewZExtr);
+    return true;
+  }
+  return false;
+}
+
 /// This is the entry point for all transforms. Pass manager differences are
 /// handled in the callers of this function.
 bool VectorCombine::run() {
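
As a sketch of the known-bits guard in the hunk above (a hypothetical example, not part of the commit's tests): the 'or' below sets bit 8 of every lane, so the result needs 9 significant bits while the zext source is only 8 bits wide. IBW (9) exceeds BW (8), and shrinkType declines to narrow the operation.

  define <8 x i32> @no_shrink_wide_or(<8 x i8> %x) {
    %z = zext <8 x i8> %x to <8 x i32>
    %r = or <8 x i32> %z, <i32 256, i32 256, i32 256, i32 256, i32 256, i32 256, i32 256, i32 256>
    ret <8 x i32> %r
  }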
@@ -2560,6 +2661,9 @@ bool VectorCombine::run() {
       case Instruction::BitCast:
         MadeChange |= foldBitcastShuffle(I);
         break;
+      default:
+        MadeChange |= shrinkType(I);
+        break;
       }
     } else {
       switch (Opcode) {
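
One restriction worth illustrating (hypothetical IR, not from the commit): for lshr, the zext must be the shifted value, i.e. operand 0. If the zext instead supplies the shift amount, as below, the early operand-order check in shrinkType rejects the candidate and the instruction is left alone.

  define <16 x i32> @no_shrink_shift_amount(<16 x i8> %amt, <16 x i32> %y) {
    %z = zext <16 x i8> %amt to <16 x i32>
    %r = lshr <16 x i32> %y, %z
    ret <16 x i32> %r
  }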
Lines changed: 76 additions & 0 deletions
@@ -0,0 +1,76 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -passes=vector-combine -S %s | FileCheck %s
+
+target triple = "aarch64"
+
+define i32 @test_and(<16 x i32> %a, ptr %b) {
+; CHECK-LABEL: @test_and(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[B:%.*]], align 1
+; CHECK-NEXT:    [[TMP0:%.*]] = trunc <16 x i32> [[A:%.*]] to <16 x i8>
+; CHECK-NEXT:    [[TMP1:%.*]] = and <16 x i8> [[WIDE_LOAD]], [[TMP0]]
+; CHECK-NEXT:    [[TMP2:%.*]] = zext <16 x i8> [[TMP1]] to <16 x i32>
+; CHECK-NEXT:    [[TMP3:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP2]])
+; CHECK-NEXT:    ret i32 [[TMP3]]
+;
+entry:
+  %wide.load = load <16 x i8>, ptr %b, align 1
+  %0 = zext <16 x i8> %wide.load to <16 x i32>
+  %1 = and <16 x i32> %0, %a
+  %2 = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %1)
+  ret i32 %2
+}
+
+define i32 @test_mask_or(<16 x i32> %a, ptr %b) {
+; CHECK-LABEL: @test_mask_or(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[B:%.*]], align 1
+; CHECK-NEXT:    [[A_MASKED:%.*]] = and <16 x i32> [[A:%.*]], <i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16>
+; CHECK-NEXT:    [[TMP0:%.*]] = trunc <16 x i32> [[A_MASKED]] to <16 x i8>
+; CHECK-NEXT:    [[TMP1:%.*]] = or <16 x i8> [[WIDE_LOAD]], [[TMP0]]
+; CHECK-NEXT:    [[TMP2:%.*]] = zext <16 x i8> [[TMP1]] to <16 x i32>
+; CHECK-NEXT:    [[TMP3:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP2]])
+; CHECK-NEXT:    ret i32 [[TMP3]]
+;
+entry:
+  %wide.load = load <16 x i8>, ptr %b, align 1
+  %a.masked = and <16 x i32> %a, <i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16>
+  %0 = zext <16 x i8> %wide.load to <16 x i32>
+  %1 = or <16 x i32> %0, %a.masked
+  %2 = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %1)
+  ret i32 %2
+}
+
+define i32 @multiuse(<16 x i32> %u, <16 x i32> %v, ptr %b) {
+; CHECK-LABEL: @multiuse(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[U_MASKED:%.*]] = and <16 x i32> [[U:%.*]], <i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255>
+; CHECK-NEXT:    [[TMP0:%.*]] = trunc <16 x i32> [[U_MASKED]] to <16 x i8>
+; CHECK-NEXT:    [[V_MASKED:%.*]] = and <16 x i32> [[V:%.*]], <i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255>
+; CHECK-NEXT:    [[TMP1:%.*]] = trunc <16 x i32> [[V_MASKED]] to <16 x i8>
+; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[B:%.*]], align 1
+; CHECK-NEXT:    [[TMP2:%.*]] = lshr <16 x i8> [[WIDE_LOAD]], <i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4>
+; CHECK-NEXT:    [[TMP3:%.*]] = or <16 x i8> [[TMP2]], [[TMP1]]
+; CHECK-NEXT:    [[TMP4:%.*]] = zext <16 x i8> [[TMP3]] to <16 x i32>
+; CHECK-NEXT:    [[TMP5:%.*]] = and <16 x i8> [[WIDE_LOAD]], <i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15>
+; CHECK-NEXT:    [[TMP6:%.*]] = or <16 x i8> [[TMP5]], [[TMP0]]
+; CHECK-NEXT:    [[TMP7:%.*]] = zext <16 x i8> [[TMP6]] to <16 x i32>
+; CHECK-NEXT:    [[TMP8:%.*]] = add nuw nsw <16 x i32> [[TMP4]], [[TMP7]]
+; CHECK-NEXT:    [[TMP9:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP8]])
+; CHECK-NEXT:    ret i32 [[TMP9]]
+;
+entry:
+  %u.masked = and <16 x i32> %u, <i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255>
+  %v.masked = and <16 x i32> %v, <i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255>
+  %wide.load = load <16 x i8>, ptr %b, align 1
+  %0 = zext <16 x i8> %wide.load to <16 x i32>
+  %1 = lshr <16 x i32> %0, <i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4>
+  %2 = or <16 x i32> %1, %v.masked
+  %3 = and <16 x i32> %0, <i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15>
+  %4 = or <16 x i32> %3, %u.masked
+  %5 = add nuw nsw <16 x i32> %2, %4
+  %6 = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %5)
+  ret i32 %6
+}
+
+declare i32 @llvm.vector.reduce.add.v16i32(<16 x i32>)
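
The @multiuse test above covers a zext whose users are all shrinkable binary operators. A complementary case the cost loop accounts for, sketched below with hypothetical IR that is not part of the commit: the zext also has a non-binary-op user (the first store), so the original zext cannot be removed. ShrinkCost is then charged an extra ZExtCost (plus a trunc for the non-constant operand), and whether the 'and' still gets narrowed depends on the target's cost model.

  define void @zext_kept_for_store(<16 x i8> %x, <16 x i32> %y, ptr %p, ptr %q) {
    %z = zext <16 x i8> %x to <16 x i32>
    store <16 x i32> %z, ptr %p
    %r = and <16 x i32> %z, %y
    store <16 x i32> %r, ptr %q
    ret void
  }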
