Skip to content

Commit b5b61cc

Browse files
authored
[VectorCombine] Preserve the maximal legal FPMathFlags during foldShuffleToIdentity (#94295)
`VectorCombine::foldShuffleToIdentity` does not preserve fast-math flags when folding the shuffle, leading to unexpected vectorized results and missed optimizations with FMA instructions. We can conservatively take the maximal legal set of fast-math flags (i.e., the flags common to every instruction being folded together) whenever we fold shuffles to identity, which enables further optimizations in the backend. --------- Co-authored-by: Henry Jiang <[email protected]>
1 parent 07b9d23 commit b5b61cc

File tree

2 files changed

+83
-15
lines changed

2 files changed

+83
-15
lines changed

llvm/lib/Transforms/Vectorize/VectorCombine.cpp

Lines changed: 38 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#include "llvm/Transforms/Vectorize/VectorCombine.h"
1616
#include "llvm/ADT/DenseMap.h"
17+
#include "llvm/ADT/STLExtras.h"
1718
#include "llvm/ADT/ScopeExit.h"
1819
#include "llvm/ADT/Statistic.h"
1920
#include "llvm/Analysis/AssumptionCache.h"
@@ -1736,23 +1737,47 @@ static Value *generateNewInstTree(ArrayRef<InstLane> Item, FixedVectorType *Ty,
17361737
Ops[Idx] = generateNewInstTree(generateInstLaneVectorFromOperand(Item, Idx),
17371738
Ty, IdentityLeafs, SplatLeafs, Builder);
17381739
}
1740+
1741+
SmallVector<Value *, 8> ValueList;
1742+
for (const auto &Lane : Item)
1743+
if (Lane.first)
1744+
ValueList.push_back(Lane.first);
1745+
17391746
Builder.SetInsertPoint(I);
17401747
Type *DstTy =
17411748
FixedVectorType::get(I->getType()->getScalarType(), Ty->getNumElements());
1742-
if (auto *BI = dyn_cast<BinaryOperator>(I))
1743-
return Builder.CreateBinOp((Instruction::BinaryOps)BI->getOpcode(), Ops[0],
1744-
Ops[1]);
1745-
if (auto *CI = dyn_cast<CmpInst>(I))
1746-
return Builder.CreateCmp(CI->getPredicate(), Ops[0], Ops[1]);
1747-
if (auto *SI = dyn_cast<SelectInst>(I))
1748-
return Builder.CreateSelect(Ops[0], Ops[1], Ops[2], "", SI);
1749-
if (auto *CI = dyn_cast<CastInst>(I))
1750-
return Builder.CreateCast((Instruction::CastOps)CI->getOpcode(), Ops[0],
1751-
DstTy);
1752-
if (II)
1753-
return Builder.CreateIntrinsic(DstTy, II->getIntrinsicID(), Ops);
1749+
if (auto *BI = dyn_cast<BinaryOperator>(I)) {
1750+
auto *Value = Builder.CreateBinOp((Instruction::BinaryOps)BI->getOpcode(),
1751+
Ops[0], Ops[1]);
1752+
propagateIRFlags(Value, ValueList);
1753+
return Value;
1754+
}
1755+
if (auto *CI = dyn_cast<CmpInst>(I)) {
1756+
auto *Value = Builder.CreateCmp(CI->getPredicate(), Ops[0], Ops[1]);
1757+
propagateIRFlags(Value, ValueList);
1758+
return Value;
1759+
}
1760+
if (auto *SI = dyn_cast<SelectInst>(I)) {
1761+
auto *Value = Builder.CreateSelect(Ops[0], Ops[1], Ops[2], "", SI);
1762+
propagateIRFlags(Value, ValueList);
1763+
return Value;
1764+
}
1765+
if (auto *CI = dyn_cast<CastInst>(I)) {
1766+
auto *Value = Builder.CreateCast((Instruction::CastOps)CI->getOpcode(),
1767+
Ops[0], DstTy);
1768+
propagateIRFlags(Value, ValueList);
1769+
return Value;
1770+
}
1771+
if (II) {
1772+
auto *Value = Builder.CreateIntrinsic(DstTy, II->getIntrinsicID(), Ops);
1773+
propagateIRFlags(Value, ValueList);
1774+
return Value;
1775+
}
17541776
assert(isa<UnaryInstruction>(I) && "Unexpected instruction type in Generate");
1755-
return Builder.CreateUnOp((Instruction::UnaryOps)I->getOpcode(), Ops[0]);
1777+
auto *Value =
1778+
Builder.CreateUnOp((Instruction::UnaryOps)I->getOpcode(), Ops[0]);
1779+
propagateIRFlags(Value, ValueList);
1780+
return Value;
17561781
}
17571782

17581783
// Starting from a shuffle, look up through operands tracking the shuffled index

llvm/test/Transforms/VectorCombine/AArch64/shuffletoidentity.ll

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -828,10 +828,10 @@ define void @v8f64interleave(i64 %0, ptr %1, ptr %x, double %z) {
828828
; CHECK-NEXT: [[BROADCAST_SPLATINSERT:%.*]] = insertelement <2 x double> poison, double [[Z:%.*]], i64 0
829829
; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <2 x double> [[BROADCAST_SPLATINSERT]], <2 x double> poison, <16 x i32> zeroinitializer
830830
; CHECK-NEXT: [[WIDE_VEC:%.*]] = load <16 x double>, ptr [[TMP1:%.*]], align 8
831-
; CHECK-NEXT: [[TMP3:%.*]] = fmul <16 x double> [[WIDE_VEC]], [[TMP2]]
831+
; CHECK-NEXT: [[TMP3:%.*]] = fmul fast <16 x double> [[WIDE_VEC]], [[TMP2]]
832832
; CHECK-NEXT: [[TMP4:%.*]] = getelementptr inbounds double, ptr [[X:%.*]], i64 [[TMP0:%.*]]
833833
; CHECK-NEXT: [[WIDE_VEC34:%.*]] = load <16 x double>, ptr [[TMP4]], align 8
834-
; CHECK-NEXT: [[INTERLEAVED_VEC:%.*]] = fadd <16 x double> [[WIDE_VEC34]], [[TMP3]]
834+
; CHECK-NEXT: [[INTERLEAVED_VEC:%.*]] = fadd fast <16 x double> [[WIDE_VEC34]], [[TMP3]]
835835
; CHECK-NEXT: [[TMP5:%.*]] = or disjoint i64 [[TMP0]], 7
836836
; CHECK-NEXT: [[TMP6:%.*]] = getelementptr inbounds double, ptr [[X]], i64 [[TMP5]]
837837
; CHECK-NEXT: [[TMP7:%.*]] = getelementptr inbounds i8, ptr [[TMP6]], i64 -56
@@ -937,5 +937,48 @@ define <4 x float> @fadd_mismatched_types(<4 x float> %x, <4 x float> %y) {
937937
ret <4 x float> %extshuf
938938
}
939939

940+
define void @maximal_legal_fpmath(ptr %addr1, ptr %addr2, ptr %result, float %val) {
941+
; CHECK-LABEL: define void @maximal_legal_fpmath(
942+
; CHECK-SAME: ptr [[ADDR1:%.*]], ptr [[ADDR2:%.*]], ptr [[RESULT:%.*]], float [[VAL:%.*]]) {
943+
; CHECK-NEXT: [[SPLATINSERT:%.*]] = insertelement <4 x float> poison, float [[VAL]], i64 0
944+
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <4 x float> [[SPLATINSERT]], <4 x float> poison, <16 x i32> zeroinitializer
945+
; CHECK-NEXT: [[VEC1:%.*]] = load <16 x float>, ptr [[ADDR1]], align 4
946+
; CHECK-NEXT: [[VEC2:%.*]] = load <16 x float>, ptr [[ADDR2]], align 4
947+
; CHECK-NEXT: [[TMP2:%.*]] = fmul contract <16 x float> [[TMP1]], [[VEC2]]
948+
; CHECK-NEXT: [[INTERLEAVED_VEC:%.*]] = fadd reassoc contract <16 x float> [[VEC1]], [[TMP2]]
949+
; CHECK-NEXT: store <16 x float> [[INTERLEAVED_VEC]], ptr [[RESULT]], align 4
950+
; CHECK-NEXT: ret void
951+
;
952+
%splatinsert = insertelement <4 x float> poison, float %val, i64 0
953+
%incoming.vec = shufflevector <4 x float> %splatinsert, <4 x float> poison, <4 x i32> zeroinitializer
954+
; Verifies the folded ops keep only the fast-math flags shared by every lane. %incoming.vec is a splat of %val; %vec1 is split below into four stride-4 lanes.
955+
%vec1 = load <16 x float>, ptr %addr1, align 4
956+
%strided.vec1 = shufflevector <16 x float> %vec1, <16 x float> poison, <4 x i32> <i32 0, i32 4, i32 8, i32 12>
957+
%strided.vec2 = shufflevector <16 x float> %vec1, <16 x float> poison, <4 x i32> <i32 1, i32 5, i32 9, i32 13>
958+
%strided.vec3 = shufflevector <16 x float> %vec1, <16 x float> poison, <4 x i32> <i32 2, i32 6, i32 10, i32 14>
959+
%strided.vec4 = shufflevector <16 x float> %vec1, <16 x float> poison, <4 x i32> <i32 3, i32 7, i32 11, i32 15>
960+
; Same stride-4 de-interleave, applied to the second loaded vector %vec2.
961+
%vec2 = load <16 x float>, ptr %addr2, align 4
962+
%strided.vec6 = shufflevector <16 x float> %vec2, <16 x float> poison, <4 x i32> <i32 0, i32 4, i32 8, i32 12>
963+
%strided.vec7 = shufflevector <16 x float> %vec2, <16 x float> poison, <4 x i32> <i32 1, i32 5, i32 9, i32 13>
964+
%strided.vec8 = shufflevector <16 x float> %vec2, <16 x float> poison, <4 x i32> <i32 2, i32 6, i32 10, i32 14>
965+
%strided.vec9 = shufflevector <16 x float> %vec2, <16 x float> poison, <4 x i32> <i32 3, i32 7, i32 11, i32 15>
966+
; Per-lane fmul/fadd with deliberately different flags: the four fmuls share only 'contract' and the four fadds share only 'reassoc contract', matching the folded <16 x float> ops expected above.
967+
%1 = fmul fast <4 x float> %incoming.vec, %strided.vec6
968+
%2 = fadd fast <4 x float> %strided.vec1, %1
969+
%3 = fmul contract <4 x float> %incoming.vec, %strided.vec7
970+
%4 = fadd fast <4 x float> %strided.vec2, %3
971+
%5 = fmul contract reassoc <4 x float> %incoming.vec, %strided.vec8
972+
%6 = fadd fast <4 x float> %strided.vec3, %5
973+
%7 = fmul contract reassoc <4 x float> %incoming.vec, %strided.vec9
974+
%8 = fadd contract reassoc <4 x float> %strided.vec4, %7
975+
; Concatenate the four lane results and re-interleave them so each element returns to its original position in the 16-wide vector.
976+
%9 = shufflevector <4 x float> %2, <4 x float> %4, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
977+
%10 = shufflevector <4 x float> %6, <4 x float> %8, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
978+
%interleaved.vec = shufflevector <8 x float> %9, <8 x float> %10, <16 x i32> <i32 0, i32 4, i32 8, i32 12, i32 1, i32 5, i32 9, i32 13, i32 2, i32 6, i32 10, i32 14, i32 3, i32 7, i32 11, i32 15>
979+
store <16 x float> %interleaved.vec, ptr %result, align 4
980+
981+
ret void
982+
}
940983

941984
declare void @use(<4 x i8>)

0 commit comments

Comments
 (0)