Commit 077e0c1

AMDGPU: Generalize truncate of shift of cast build_vector combine (llvm#125617)
Previously we only handled cases that looked like extracting the high element of a 64-bit value with a shift. Generalize this to handle a shift by any multiple of the source element size. I was hoping this would help avoid some regressions, but it did not. It does, however, reduce the number of steps the DAG takes to process these cases. NFC-ish; I have yet to find an example where this changes the final output.
1 parent 749372b commit 077e0c1
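
For illustration, a minimal standalone C++ sketch of the index arithmetic the generalized combine relies on (PartIndex = BitIndex / SrcEltSize in the patch below); the helper name isAlignedElementExtract and the main driver are hypothetical, for demonstration only:

#include <cassert>

// Sketch: trunc (srl (bitcast build_vector), K) is a direct element
// extract whenever the shift amount K is an exact multiple of the
// source element size and indexes an element inside the vector.
static bool isAlignedElementExtract(unsigned BitIndex, unsigned SrcEltSize,
                                    unsigned NumElts, unsigned &PartIndex) {
  PartIndex = BitIndex / SrcEltSize; // element selected by the shift
  return PartIndex * SrcEltSize == BitIndex && PartIndex < NumElts;
}

int main() {
  unsigned Part;
  // i128 viewed as <4 x i32>: shifts of 0/32/64/96 select elements 0..3.
  assert(isAlignedElementExtract(96, 32, 4, Part) && Part == 3);
  // A shift of 33 is not element-aligned, so the combine cannot fire.
  assert(!isAlignedElementExtract(33, 32, 4, Part));
  // The old combine matched only 2 * K == source bit-width, i.e. the
  // high-element extract of a two-element vector.
  assert(isAlignedElementExtract(32, 32, 2, Part) && Part == 1);
  return 0;
}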

2 files changed: +154 -11 lines

llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp

Lines changed: 14 additions & 11 deletions
@@ -4217,18 +4217,21 @@ SDValue AMDGPUTargetLowering::performTruncateCombine(
   // trunc (srl (bitcast (build_vector x, y))), 16 -> trunc (bitcast y)
   if (Src.getOpcode() == ISD::SRL && !VT.isVector()) {
     if (auto *K = isConstOrConstSplat(Src.getOperand(1))) {
-      if (2 * K->getZExtValue() == Src.getValueType().getScalarSizeInBits()) {
-        SDValue BV = stripBitcast(Src.getOperand(0));
-        if (BV.getOpcode() == ISD::BUILD_VECTOR &&
-            BV.getValueType().getVectorNumElements() == 2) {
-          SDValue SrcElt = BV.getOperand(1);
-          EVT SrcEltVT = SrcElt.getValueType();
-          if (SrcEltVT.isFloatingPoint()) {
-            SrcElt = DAG.getNode(ISD::BITCAST, SL,
-                                 SrcEltVT.changeTypeToInteger(), SrcElt);
+      SDValue BV = stripBitcast(Src.getOperand(0));
+      if (BV.getOpcode() == ISD::BUILD_VECTOR) {
+        EVT SrcEltVT = BV.getOperand(0).getValueType();
+        unsigned SrcEltSize = SrcEltVT.getSizeInBits();
+        unsigned BitIndex = K->getZExtValue();
+        unsigned PartIndex = BitIndex / SrcEltSize;
+
+        if (PartIndex * SrcEltSize == BitIndex &&
+            PartIndex < BV.getNumOperands()) {
+          if (SrcEltVT.getSizeInBits() == VT.getSizeInBits()) {
+            SDValue SrcElt =
+                DAG.getNode(ISD::BITCAST, SL, SrcEltVT.changeTypeToInteger(),
+                            BV.getOperand(PartIndex));
+            return DAG.getNode(ISD::TRUNCATE, SL, VT, SrcElt);
           }
-
-          return DAG.getNode(ISD::TRUNCATE, SL, VT, SrcElt);
         }
       }
     }
Lines changed: 140 additions & 0 deletions
@@ -0,0 +1,140 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mtriple=amdgcn-amd-amdhsa -mcpu=gfx900 < %s | FileCheck %s
+
+; extract element 0 as shift
+define i32 @cast_v4i32_to_i128_trunc_i32(<4 x i32> %arg) {
+; CHECK-LABEL: cast_v4i32_to_i128_trunc_i32:
+; CHECK:       ; %bb.0:
+; CHECK-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; CHECK-NEXT:    s_setpc_b64 s[30:31]
+  %bigint = bitcast <4 x i32> %arg to i128
+  %trunc = trunc i128 %bigint to i32
+  ret i32 %trunc
+}
+
+; extract element 1 as shift
+define i32 @cast_v4i32_to_i128_lshr_32_trunc_i32(<4 x i32> %arg) {
+; CHECK-LABEL: cast_v4i32_to_i128_lshr_32_trunc_i32:
+; CHECK:       ; %bb.0:
+; CHECK-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; CHECK-NEXT:    v_mov_b32_e32 v0, v1
+; CHECK-NEXT:    s_setpc_b64 s[30:31]
+  %bigint = bitcast <4 x i32> %arg to i128
+  %srl = lshr i128 %bigint, 32
+  %trunc = trunc i128 %srl to i32
+  ret i32 %trunc
+}
+
+; extract element 2 as shift
+define i32 @cast_v4i32_to_i128_lshr_64_trunc_i32(<4 x i32> %arg) {
+; CHECK-LABEL: cast_v4i32_to_i128_lshr_64_trunc_i32:
+; CHECK:       ; %bb.0:
+; CHECK-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; CHECK-NEXT:    v_mov_b32_e32 v0, v2
+; CHECK-NEXT:    s_setpc_b64 s[30:31]
+  %bigint = bitcast <4 x i32> %arg to i128
+  %srl = lshr i128 %bigint, 64
+  %trunc = trunc i128 %srl to i32
+  ret i32 %trunc
+}
+
+; extract element 3 as shift
+define i32 @cast_v4i32_to_i128_lshr_96_trunc_i32(<4 x i32> %arg) {
+; CHECK-LABEL: cast_v4i32_to_i128_lshr_96_trunc_i32:
+; CHECK:       ; %bb.0:
+; CHECK-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; CHECK-NEXT:    v_mov_b32_e32 v0, v3
+; CHECK-NEXT:    s_setpc_b64 s[30:31]
+  %bigint = bitcast <4 x i32> %arg to i128
+  %srl = lshr i128 %bigint, 96
+  %trunc = trunc i128 %srl to i32
+  ret i32 %trunc
+}
+
+; Shift not aligned to element, not a simple extract
+define i32 @cast_v4i32_to_i128_lshr_33_trunc_i32(<4 x i32> %arg) {
+; CHECK-LABEL: cast_v4i32_to_i128_lshr_33_trunc_i32:
+; CHECK:       ; %bb.0:
+; CHECK-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; CHECK-NEXT:    v_alignbit_b32 v0, v2, v1, 1
+; CHECK-NEXT:    s_setpc_b64 s[30:31]
+  %bigint = bitcast <4 x i32> %arg to i128
+  %srl = lshr i128 %bigint, 33
+  %trunc = trunc i128 %srl to i32
+  ret i32 %trunc
+}
+
+; extract misaligned element
+define i32 @cast_v4i32_to_i128_lshr_31_trunc_i32(<4 x i32> %arg) {
+; CHECK-LABEL: cast_v4i32_to_i128_lshr_31_trunc_i32:
+; CHECK:       ; %bb.0:
+; CHECK-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; CHECK-NEXT:    v_alignbit_b32 v0, v1, v0, 31
+; CHECK-NEXT:    s_setpc_b64 s[30:31]
+  %bigint = bitcast <4 x i32> %arg to i128
+  %srl = lshr i128 %bigint, 31
+  %trunc = trunc i128 %srl to i32
+  ret i32 %trunc
+}
+
+; extract misaligned element
+define i32 @cast_v4i32_to_i128_lshr_48_trunc_i32(<4 x i32> %arg) {
+; CHECK-LABEL: cast_v4i32_to_i128_lshr_48_trunc_i32:
+; CHECK:       ; %bb.0:
+; CHECK-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; CHECK-NEXT:    s_mov_b32 s4, 0x1000706
+; CHECK-NEXT:    v_perm_b32 v0, v1, v2, s4
+; CHECK-NEXT:    s_setpc_b64 s[30:31]
+  %bigint = bitcast <4 x i32> %arg to i128
+  %srl = lshr i128 %bigint, 48
+  %trunc = trunc i128 %srl to i32
+  ret i32 %trunc
+}
+
+; extract elements 1 and 2 with shift
+define i64 @cast_v4i32_to_i128_lshr_32_trunc_i64(<4 x i32> %arg) {
+; CHECK-LABEL: cast_v4i32_to_i128_lshr_32_trunc_i64:
+; CHECK:       ; %bb.0:
+; CHECK-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; CHECK-NEXT:    v_mov_b32_e32 v0, v1
+; CHECK-NEXT:    v_mov_b32_e32 v1, v2
+; CHECK-NEXT:    s_setpc_b64 s[30:31]
+  %bigint = bitcast <4 x i32> %arg to i128
+  %srl = lshr i128 %bigint, 32
+  %trunc = trunc i128 %srl to i64
+  ret i64 %trunc
+}
+
+; extract elements 2 and 3 with shift
+define i64 @cast_v4i32_to_i128_lshr_64_trunc_i64(<4 x i32> %arg) {
+; CHECK-LABEL: cast_v4i32_to_i128_lshr_64_trunc_i64:
+; CHECK:       ; %bb.0:
+; CHECK-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; CHECK-NEXT:    v_mov_b32_e32 v1, v3
+; CHECK-NEXT:    v_mov_b32_e32 v0, v2
+; CHECK-NEXT:    s_setpc_b64 s[30:31]
+  %bigint = bitcast <4 x i32> %arg to i128
+  %srl = lshr i128 %bigint, 64
+  %trunc = trunc i128 %srl to i64
+  ret i64 %trunc
+}
+
+; FIXME: We don't process this case because we see multiple bitcasts
+; before a 32-bit build_vector
+define i32 @build_vector_i16_to_shift(i16 %arg0, i16 %arg1, i16 %arg2, i16 %arg3) {
+; CHECK-LABEL: build_vector_i16_to_shift:
+; CHECK:       ; %bb.0:
+; CHECK-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; CHECK-NEXT:    s_mov_b32 s4, 0x5040100
+; CHECK-NEXT:    v_perm_b32 v0, v3, v2, s4
+; CHECK-NEXT:    s_setpc_b64 s[30:31]
+  %ins.0 = insertelement <4 x i16> poison, i16 %arg0, i32 0
+  %ins.1 = insertelement <4 x i16> %ins.0, i16 %arg1, i32 1
+  %ins.2 = insertelement <4 x i16> %ins.1, i16 %arg2, i32 2
+  %ins.3 = insertelement <4 x i16> %ins.2, i16 %arg3, i32 3
+
+  %cast = bitcast <4 x i16> %ins.3 to i64
+  %srl = lshr i64 %cast, 32
+  %trunc = trunc i64 %srl to i32
+  ret i32 %trunc
+}
