Skip to content

Commit 96bb281

Browse files
authored
[AArch64] Prevent unnecessary truncation in bool vector reduce code generation (#120096)
Prevent unnecessarily truncating results of 128-bit-wide vector comparisons to 64-bit-wide vector values in boolean vector reduce operations.
1 parent 3ed2a81 commit 96bb281

File tree

4 files changed: 742 additions (+742) and 71 deletions (-71)

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 20 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -15928,17 +15928,32 @@ static SDValue getVectorBitwiseReduce(unsigned Opcode, SDValue Vec, EVT VT,
1592815928
return getVectorBitwiseReduce(Opcode, HalfVec, VT, DL, DAG);
1592915929
}
1593015930

15931-
// Vectors that are less than 64 bits get widened to neatly fit a 64 bit
15932-
// register, so e.g. <4 x i1> gets lowered to <4 x i16>. Sign extending to
15933-
// this element size leads to the best codegen, since e.g. setcc results
15934-
// might need to be truncated otherwise.
15935-
EVT ExtendedVT = MVT::getIntegerVT(std::max(64u / NumElems, 8u));
15931+
// Results of setcc operations get widened to 128 bits if their input
15932+
// operands are 128 bits wide, otherwise vectors that are less than 64 bits
15933+
// get widened to neatly fit a 64 bit register, so e.g. <4 x i1> gets
15934+
// lowered to either <4 x i16> or <4 x i32>. Sign extending to this element
15935+
// size leads to the best codegen, since e.g. setcc results might need to be
15936+
// truncated otherwise.
15937+
unsigned ExtendedWidth = 64;
15938+
if (Vec.getOpcode() == ISD::SETCC &&
15939+
Vec.getOperand(0).getValueSizeInBits() >= 128) {
15940+
ExtendedWidth = 128;
15941+
}
15942+
EVT ExtendedVT = MVT::getIntegerVT(std::max(ExtendedWidth / NumElems, 8u));
1593615943

1593715944
// any_ext doesn't work with umin/umax, so only use it for uadd.
1593815945
unsigned ExtendOp =
1593915946
ScalarOpcode == ISD::XOR ? ISD::ANY_EXTEND : ISD::SIGN_EXTEND;
1594015947
SDValue Extended = DAG.getNode(
1594115948
ExtendOp, DL, VecVT.changeVectorElementType(ExtendedVT), Vec);
15949+
// The uminp/uminv and umaxp/umaxv instructions don't have .2d variants, so
15950+
// in that case we bitcast the sign extended values from v2i64 to v4i32
15951+
// before reduction for optimal code generation.
15952+
if ((ScalarOpcode == ISD::AND || ScalarOpcode == ISD::OR) &&
15953+
NumElems == 2 && ExtendedWidth == 128) {
15954+
Extended = DAG.getBitcast(MVT::v4i32, Extended);
15955+
ExtendedVT = MVT::i32;
15956+
}
1594215957
switch (ScalarOpcode) {
1594315958
case ISD::AND:
1594415959
Result = DAG.getNode(ISD::VECREDUCE_UMIN, DL, ExtendedVT, Extended);

llvm/test/CodeGen/AArch64/illegal-floating-point-vector-compares.ll

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -12,8 +12,7 @@ define i1 @unordered_floating_point_compare_on_v8f32(<8 x float> %a_vec) {
1212
; CHECK-NEXT: mov w8, #1 // =0x1
1313
; CHECK-NEXT: uzp1 v0.8h, v0.8h, v1.8h
1414
; CHECK-NEXT: mvn v0.16b, v0.16b
15-
; CHECK-NEXT: xtn v0.8b, v0.8h
16-
; CHECK-NEXT: umaxv b0, v0.8b
15+
; CHECK-NEXT: umaxv h0, v0.8h
1716
; CHECK-NEXT: fmov w9, s0
1817
; CHECK-NEXT: bic w0, w8, w9
1918
; CHECK-NEXT: ret

0 commit comments

Comments
 (0)