[VectorCombine] Add type shrinking and zext propagation for fixed-width vector types #104606
Conversation
@llvm/pr-subscribers-llvm-transforms

Author: Igor Kirillov (igogo-x86)

Changes: check that binop(zext(value), other) is possible and profitable to transform into zext(binop(value, trunc(other))).

Full diff: https://github.com/llvm/llvm-project/pull/104606.diff

2 Files Affected:
diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 99bd383ab0dead..c2f4315928a40c 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -119,6 +119,7 @@ class VectorCombine {
bool foldShuffleFromReductions(Instruction &I);
bool foldCastFromReductions(Instruction &I);
bool foldSelectShuffle(Instruction &I, bool FromReduction = false);
+ bool shrinkType(Instruction &I);
void replaceValue(Value &Old, Value &New) {
Old.replaceAllUsesWith(&New);
@@ -2493,6 +2494,106 @@ bool VectorCombine::foldSelectShuffle(Instruction &I, bool FromReduction) {
return true;
}
+/// Check if the instruction depends on a ZExt and whether that ZExt can be
+/// moved after the instruction. Move the ZExt if it is profitable.
+bool VectorCombine::shrinkType(llvm::Instruction &I) {
+ Value *ZExted, *OtherOperand;
+ if (match(&I, m_c_BinOp(m_ZExt(m_Value(ZExted)), m_Value(OtherOperand)))) {
+ if (I.getOpcode() != Instruction::And && I.getOpcode() != Instruction::Or &&
+ I.getOpcode() != Instruction::Xor && I.getOpcode() != Instruction::LShr)
+ return false;
+
+  // For LShr, the ZExt must be applied to the first operand (the value
+  // being shifted)
+ if (I.getOpcode() == Instruction::LShr && I.getOperand(1) != OtherOperand)
+ return false;
+
+ Instruction *ZExtOperand = cast<Instruction>(
+ I.getOperand(I.getOperand(0) == OtherOperand ? 1 : 0));
+
+ auto *BigTy = cast<FixedVectorType>(I.getType());
+ auto *SmallTy = cast<FixedVectorType>(ZExted->getType());
+ auto BW = SmallTy->getElementType()->getPrimitiveSizeInBits();
+
+ // Check that the expression overall uses at most the same number of bits as
+ // ZExted
+ auto KB = computeKnownBits(&I, *DL);
+ auto IBW = KB.getBitWidth() - KB.Zero.countLeadingOnes();
+ if (IBW > BW)
+ return false;
+
+ bool HasUNZExtableUser = false;
+
+  // Calculate the cost of leaving the current IR as it is versus moving the
+  // ZExt later, along with adding truncates where needed
+ InstructionCost ZExtCost = TTI.getCastInstrCost(
+ Instruction::ZExt, BigTy, SmallTy,
+ TargetTransformInfo::CastContextHint::None, TTI::TCK_RecipThroughput);
+ InstructionCost CurrentCost = ZExtCost;
+ InstructionCost ShrinkCost = 0;
+
+ for (User *U : ZExtOperand->users()) {
+ auto *UI = cast<Instruction>(U);
+ if (UI == &I) {
+ CurrentCost += TTI.getArithmeticInstrCost(UI->getOpcode(), BigTy);
+ ShrinkCost += TTI.getArithmeticInstrCost(UI->getOpcode(), SmallTy);
+ ShrinkCost += ZExtCost;
+ continue;
+ }
+
+ if (!Instruction::isBinaryOp(UI->getOpcode())) {
+ HasUNZExtableUser = true;
+ continue;
+ }
+
+ // Check if we can propagate ZExt through its other users
+ auto KB = computeKnownBits(UI, *DL);
+ auto UBW = KB.getBitWidth() - KB.Zero.countLeadingOnes();
+ if (UBW <= BW) {
+ CurrentCost += TTI.getArithmeticInstrCost(UI->getOpcode(), BigTy);
+ ShrinkCost += TTI.getArithmeticInstrCost(UI->getOpcode(), SmallTy);
+ ShrinkCost += ZExtCost;
+ } else {
+ HasUNZExtableUser = true;
+ }
+ }
+
+  // The ZExt can't be removed, so account for its extra cost
+ if (HasUNZExtableUser)
+ ShrinkCost += ZExtCost;
+
+  // If the other binop operand is not a constant, we'll need to generate a
+  // truncate instruction, so adjust the cost accordingly
+ if (!isa<Constant>(OtherOperand))
+ ShrinkCost += TTI.getCastInstrCost(
+ Instruction::Trunc, SmallTy, BigTy,
+ TargetTransformInfo::CastContextHint::None, TTI::TCK_RecipThroughput);
+
+ // If the cost of shrinking types and leaving the IR is the same, we'll lean
+ // towards modifying the IR because shrinking opens opportunities for other
+ // shrinking optimisations.
+ if (ShrinkCost > CurrentCost)
+ return false;
+
+ auto *Op0 = ZExted;
+ if (auto *OI = dyn_cast<Instruction>(OtherOperand))
+ Builder.SetInsertPoint(OI->getNextNode());
+ auto *Op1 = Builder.CreateTrunc(OtherOperand, SmallTy);
+ Builder.SetInsertPoint(&I);
+ // Keep the order of operands the same
+ if (I.getOperand(0) == OtherOperand)
+ std::swap(Op0, Op1);
+ auto *NewBinOp =
+ Builder.CreateBinOp((Instruction::BinaryOps)I.getOpcode(), Op0, Op1);
+ cast<Instruction>(NewBinOp)->copyIRFlags(&I);
+ cast<Instruction>(NewBinOp)->copyMetadata(I);
+ auto *NewZExtr = Builder.CreateZExt(NewBinOp, BigTy);
+ replaceValue(I, *NewZExtr);
+ return true;
+ }
+ return false;
+}
+
/// This is the entry point for all transforms. Pass manager differences are
/// handled in the callers of this function.
bool VectorCombine::run() {
@@ -2560,6 +2661,9 @@ bool VectorCombine::run() {
case Instruction::BitCast:
MadeChange |= foldBitcastShuffle(I);
break;
+ default:
+ MadeChange |= shrinkType(I);
+ break;
}
} else {
switch (Opcode) {
diff --git a/llvm/test/Transforms/VectorCombine/AArch64/shrink-types.ll b/llvm/test/Transforms/VectorCombine/AArch64/shrink-types.ll
new file mode 100644
index 00000000000000..0166656cf734f5
--- /dev/null
+++ b/llvm/test/Transforms/VectorCombine/AArch64/shrink-types.ll
@@ -0,0 +1,76 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -passes=vector-combine -S %s | FileCheck %s
+
+target triple = "aarch64"
+
+define i32 @test_and(<16 x i32> %a, ptr %b) {
+; CHECK-LABEL: @test_and(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[B:%.*]], align 1
+; CHECK-NEXT: [[TMP0:%.*]] = trunc <16 x i32> [[A:%.*]] to <16 x i8>
+; CHECK-NEXT: [[TMP1:%.*]] = and <16 x i8> [[WIDE_LOAD]], [[TMP0]]
+; CHECK-NEXT: [[TMP2:%.*]] = zext <16 x i8> [[TMP1]] to <16 x i32>
+; CHECK-NEXT: [[TMP3:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP2]])
+; CHECK-NEXT: ret i32 [[TMP3]]
+;
+entry:
+ %wide.load = load <16 x i8>, ptr %b, align 1
+ %0 = zext <16 x i8> %wide.load to <16 x i32>
+ %1 = and <16 x i32> %0, %a
+ %2 = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %1)
+ ret i32 %2
+}
+
+define i32 @test_mask_or(<16 x i32> %a, ptr %b) {
+; CHECK-LABEL: @test_mask_or(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[B:%.*]], align 1
+; CHECK-NEXT: [[A_MASKED:%.*]] = and <16 x i32> [[A:%.*]], <i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16>
+; CHECK-NEXT: [[TMP0:%.*]] = trunc <16 x i32> [[A_MASKED]] to <16 x i8>
+; CHECK-NEXT: [[TMP1:%.*]] = or <16 x i8> [[WIDE_LOAD]], [[TMP0]]
+; CHECK-NEXT: [[TMP2:%.*]] = zext <16 x i8> [[TMP1]] to <16 x i32>
+; CHECK-NEXT: [[TMP3:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP2]])
+; CHECK-NEXT: ret i32 [[TMP3]]
+;
+entry:
+ %wide.load = load <16 x i8>, ptr %b, align 1
+ %a.masked = and <16 x i32> %a, <i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16>
+ %0 = zext <16 x i8> %wide.load to <16 x i32>
+ %1 = or <16 x i32> %0, %a.masked
+ %2 = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %1)
+ ret i32 %2
+}
+
+define i32 @multiuse(<16 x i32> %u, <16 x i32> %v, ptr %b) {
+; CHECK-LABEL: @multiuse(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[U_MASKED:%.*]] = and <16 x i32> [[U:%.*]], <i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255>
+; CHECK-NEXT: [[TMP0:%.*]] = trunc <16 x i32> [[U_MASKED]] to <16 x i8>
+; CHECK-NEXT: [[V_MASKED:%.*]] = and <16 x i32> [[V:%.*]], <i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255>
+; CHECK-NEXT: [[TMP1:%.*]] = trunc <16 x i32> [[V_MASKED]] to <16 x i8>
+; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[B:%.*]], align 1
+; CHECK-NEXT: [[TMP2:%.*]] = lshr <16 x i8> [[WIDE_LOAD]], <i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4, i8 4>
+; CHECK-NEXT: [[TMP3:%.*]] = or <16 x i8> [[TMP2]], [[TMP1]]
+; CHECK-NEXT: [[TMP4:%.*]] = zext <16 x i8> [[TMP3]] to <16 x i32>
+; CHECK-NEXT: [[TMP5:%.*]] = and <16 x i8> [[WIDE_LOAD]], <i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15>
+; CHECK-NEXT: [[TMP6:%.*]] = or <16 x i8> [[TMP5]], [[TMP0]]
+; CHECK-NEXT: [[TMP7:%.*]] = zext <16 x i8> [[TMP6]] to <16 x i32>
+; CHECK-NEXT: [[TMP8:%.*]] = add nuw nsw <16 x i32> [[TMP4]], [[TMP7]]
+; CHECK-NEXT: [[TMP9:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP8]])
+; CHECK-NEXT: ret i32 [[TMP9]]
+;
+entry:
+ %u.masked = and <16 x i32> %u, <i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255>
+ %v.masked = and <16 x i32> %v, <i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255>
+ %wide.load = load <16 x i8>, ptr %b, align 1
+ %0 = zext <16 x i8> %wide.load to <16 x i32>
+ %1 = lshr <16 x i32> %0, <i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4>
+ %2 = or <16 x i32> %1, %v.masked
+ %3 = and <16 x i32> %0, <i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15>
+ %4 = or <16 x i32> %3, %u.masked
+ %5 = add nuw nsw <16 x i32> %2, %4
+ %6 = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %5)
+ ret i32 %6
+}
+
+declare i32 @llvm.vector.reduce.add.v16i32(<16 x i32>)
Force-pushed from 0afa95e to 049d02a
There was a failing test: the ZExt had two users, and my patch allowed us to calculate the cost of the IR when moving only one ZExt. For X86 the costs are these: Cost of lshr <8 x i16> is 15. That makes it profitable to shrink types, but I am not sure this is a good optimisation (the number of assembly instructions increases after running llc).

So, I turned off ZExt propagation if it cannot propagate through all users, but I wish the cost model were more accurate.
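A hypothetical minimal shape of that situation (not the actual failing test; names and types are illustrative). Here `%z` has two users, and only the `and` can be shrunk, so the original `zext` stays alive for the reduction and its cost is paid on top of the new narrow op:

```llvm
declare i16 @llvm.vector.reduce.add.v8i16(<8 x i16>)

define i16 @two_users(<8 x i8> %a, <8 x i16> %b) {
  ; the zext has two users below
  %z = zext <8 x i8> %a to <8 x i16>
  ; shrinkable user: could become an i8 'and' followed by a zext
  %and = and <8 x i16> %z, %b
  ; non-shrinkable user: forces the original zext to remain
  %r0 = call i16 @llvm.vector.reduce.add.v8i16(<8 x i16> %z)
  %r1 = call i16 @llvm.vector.reduce.add.v8i16(<8 x i16> %and)
  %sum = add i16 %r0, %r1
  ret i16 %sum
}
```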
ping
```cpp
Instruction *ZExtOperand =
    cast<Instruction>(I.getOperand(I.getOperand(0) == OtherOperand ? 1 : 0));
```
This makes me nervous, but I'm not sure how to avoid it...
Agreed. We don't need to cast into Instruction, though, so it now looks a bit less scary :) I grepped through LLVM and found some examples of the same code, so we are not the first to face this problem.
```cpp
bool VectorCombine::shrinkType(llvm::Instruction &I) {
  Value *ZExted, *OtherOperand;
  if (!match(&I, m_c_BitwiseLogic(m_ZExt(m_Value(ZExted)),
                                  m_Value(OtherOperand))) &&
```
What happens if both are zero-extended?
If both ZExts are applied to the same type, InstCombine will handle it. And if the types are different, InstCombine will transform the expression into the form recognised by this patch:
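A sketch of the different-types case (hypothetical values; this illustrates the canonicalisation described above, not code from this patch):

```llvm
; before: both operands zero-extended, from different source types
%x = zext <8 x i8> %a to <8 x i32>
%y = zext <8 x i16> %b to <8 x i32>
%r = and <8 x i32> %x, %y

; after InstCombine: the logic op is narrowed to the wider source type,
; leaving exactly the binop(zext(value), other) shape shrinkType matches
%a16 = zext <8 x i8> %a to <8 x i16>
%n = and <8 x i16> %a16, %b
%r2 = zext <8 x i16> %n to <8 x i32>
```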
Force-pushed from 049d02a to 1261041
[VectorCombine] Add type shrinking and zext propagation for fixed-width vector types

Check that binop(zext(value), other) is possible and profitable to transform into: zext(binop(value, trunc(other))).

When a CPU architecture has an illegal scalar type iX but the vector type <N x iX> is legal, scalar expressions may be extended to a legal type iY before vectorisation. This extension can result in underutilization of vector lanes, as more lanes could be used by one instruction at the narrower type. Vectorisers may not always recognize opportunities for type shrinking, and this patch aims to address that limitation.
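A minimal before/after sketch of the transform (it mirrors the test_and case in the test file added by this patch):

```llvm
; before: the 'and' executes on <16 x i32>
%z = zext <16 x i8> %x to <16 x i32>
%r = and <16 x i32> %z, %y

; after: the 'and' executes on <16 x i8>; the extend happens once, afterwards
%t = trunc <16 x i32> %y to <16 x i8>
%n = and <16 x i8> %x, %t
%r2 = zext <16 x i8> %n to <16 x i32>
```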
Force-pushed from 1261041 to fefb949
This looks OK to me if there are no other comments. LGTM.
```cpp
/// Check if instruction depends on ZExt and this ZExt can be moved after the
/// instruction. Move ZExt if it is profitable. For example:
///     logic(zext(x),y) -> zext(logic(x,trunc(y)))
///     lshr((zext(x),y) -> zext(lshr(x,trunc(y)))
```
Typo in the second example: `lshr((zext(x),y)` -> `lshr(zext(x),y)`
LGTM
```cpp
Value *Op0 = ZExted;
if (auto *OI = dyn_cast<Instruction>(OtherOperand))
  Builder.SetInsertPoint(OI->getNextNode());
```
I think this is currently broken.
What if OI->getNextNode() is a PHI?
For my out-of-tree target, the following fails:

```
opt -passes="vector-combine" bbi-99058.ll -o /dev/null
```

with

```
PHI nodes not grouped at top of basic block!
  %vec.ind = phi <4 x i16> [ zeroinitializer, %entry ], [ zeroinitializer, %vector.body ]
label %vector.body
LLVM ERROR: Broken module found, compilation aborted!
```

and I think it's because the trunc created on the next line is inserted before the second PHI in the bb.

If you simply comment out the

```cpp
if (ShrinkCost > CurrentCost)
  return false;
```

code at line 2567 above, it happens in-tree as well. (I'm sure the testcase can be modified in some way so it happens even with the cost comparison at 2567 for some target, but I didn't manage right now.)

bbi-99058.ll in my example is:

```llvm
define i64 @func_1() {
entry:
  br label %vector.body

vector.body:                                      ; preds = %vector.body, %entry
  %vec.phi = phi <4 x i32> [ zeroinitializer, %entry ], [ %1, %vector.body ]
  %vec.ind = phi <4 x i16> [ zeroinitializer, %entry ], [ zeroinitializer, %vector.body ]
  %0 = zext <4 x i16> zeroinitializer to <4 x i32>
  %1 = and <4 x i32> %vec.phi, %0
  br label %vector.body
}
```
Created a fix - #108228
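A minimal sketch of the kind of guard this needs (an assumption about the shape of the fix; see #108228 for the actual patch). When OtherOperand is defined by a PHI, inserting right after it can land between PHI nodes, so the insertion point has to skip past the whole PHI group:

```cpp
// Hypothetical guard: never insert the trunc between PHI nodes.
if (auto *OI = dyn_cast<Instruction>(OtherOperand)) {
  if (isa<PHINode>(OI))
    Builder.SetInsertPoint(OI->getParent()->getFirstNonPHI());
  else
    Builder.SetInsertPoint(OI->getNextNode());
}
```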
…idth (#108705)

Consider the following case:

```llvm
define <2 x i32> @test(<2 x i64> %vec.ind16, <2 x i32> %broadcast.splat20) {
  %19 = icmp eq <2 x i64> %vec.ind16, zeroinitializer
  %20 = zext <2 x i1> %19 to <2 x i32>
  %21 = lshr <2 x i32> %20, %broadcast.splat20
  ret <2 x i32> %21
}
```

After #104606, we shrink the lshr into:

```llvm
define <2 x i32> @test(<2 x i64> %vec.ind16, <2 x i32> %broadcast.splat20) {
  %1 = icmp eq <2 x i64> %vec.ind16, zeroinitializer
  %2 = trunc <2 x i32> %broadcast.splat20 to <2 x i1>
  %3 = lshr <2 x i1> %1, %2
  %4 = zext <2 x i1> %3 to <2 x i32>
  ret <2 x i32> %4
}
```

This is incorrect since `lshr i1 X, 1` returns `poison`. This patch adds an additional check on the shamt operand: the lshr is shrunk only if the shamt is known to be less than the bit width of the smaller type. As `computeKnownBits(&I, *DL).countMaxActiveBits() > BW` always evaluates to true for `lshr(zext(X), Y)`, this check only applies to the bitwise logical instructions.

Alive2: https://alive2.llvm.org/ce/z/j_RmTa

Fixes #108698.
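A sketch of that check (assumed shape, not the verbatim #108705 patch; BW is the narrow element bit width from the surrounding function):

```cpp
// Only shrink an lshr when the shift amount is provably smaller than
// the narrow bit width: 'lshr iN %x, K' with K >= N yields poison.
if (I.getOpcode() == Instruction::LShr) {
  KnownBits ShamtKnown = computeKnownBits(I.getOperand(1), *DL);
  if (ShamtKnown.getMaxValue().uge(BW))
    return false;
}
```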
Check that binop(zext(value), other) is possible and profitable to transform into: zext(binop(value, trunc(other))).

When a CPU architecture has an illegal scalar type iX but the vector type <N x iX> is legal, scalar expressions may be extended to a legal type iY before vectorisation. This extension can result in underutilization of vector lanes, as more lanes could be used by one instruction at the narrower type. Vectorisers may not always recognize opportunities for type shrinking, and this patch aims to address that limitation.