Skip to content

Commit 02dfbbf

Browse files
authored
[SelectionDAG] Make ARITH_FENCE support half and bfloat type (#90836)
1 parent 7a484d3 commit 02dfbbf

File tree

3 files changed

+162
-0
lines changed

3 files changed

+162
-0
lines changed

llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2863,6 +2863,8 @@ void DAGTypeLegalizer::SoftPromoteHalfResult(SDNode *N, unsigned ResNo) {
28632863
report_fatal_error("Do not know how to soft promote this operator's "
28642864
"result!");
28652865

2866+
case ISD::ARITH_FENCE:
2867+
R = SoftPromoteHalfRes_ARITH_FENCE(N); break;
28662868
case ISD::BITCAST: R = SoftPromoteHalfRes_BITCAST(N); break;
28672869
case ISD::ConstantFP: R = SoftPromoteHalfRes_ConstantFP(N); break;
28682870
case ISD::EXTRACT_VECTOR_ELT:
@@ -2942,6 +2944,11 @@ void DAGTypeLegalizer::SoftPromoteHalfResult(SDNode *N, unsigned ResNo) {
29422944
SetSoftPromotedHalf(SDValue(N, ResNo), R);
29432945
}
29442946

2947+
SDValue DAGTypeLegalizer::SoftPromoteHalfRes_ARITH_FENCE(SDNode *N) {
2948+
return DAG.getNode(ISD::ARITH_FENCE, SDLoc(N), MVT::i16,
2949+
BitConvertToInteger(N->getOperand(0)));
2950+
}
2951+
29452952
SDValue DAGTypeLegalizer::SoftPromoteHalfRes_BITCAST(SDNode *N) {
29462953
return BitConvertToInteger(N->getOperand(0));
29472954
}

llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -728,6 +728,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
728728
void SetSoftPromotedHalf(SDValue Op, SDValue Result);
729729

730730
void SoftPromoteHalfResult(SDNode *N, unsigned ResNo);
731+
SDValue SoftPromoteHalfRes_ARITH_FENCE(SDNode *N);
731732
SDValue SoftPromoteHalfRes_BinOp(SDNode *N);
732733
SDValue SoftPromoteHalfRes_BITCAST(SDNode *N);
733734
SDValue SoftPromoteHalfRes_ConstantFP(SDNode *N);

llvm/test/CodeGen/X86/arithmetic_fence2.ll

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,160 @@ define <8 x float> @f6(<8 x float> %a) {
157157
ret <8 x float> %3
158158
}
159159

160+
define half @f7(half %a) nounwind {
161+
; X86-LABEL: f7:
162+
; X86: # %bb.0:
163+
; X86-NEXT: pinsrw $0, {{[0-9]+}}(%esp), %xmm0
164+
; X86-NEXT: #ARITH_FENCE
165+
; X86-NEXT: retl
166+
;
167+
; X64-LABEL: f7:
168+
; X64: # %bb.0:
169+
; X64-NEXT: #ARITH_FENCE
170+
; X64-NEXT: retq
171+
%b = call half @llvm.arithmetic.fence.f16(half %a)
172+
ret half %b
173+
}
174+
175+
define bfloat @f8(bfloat %a) nounwind {
176+
; X86-LABEL: f8:
177+
; X86: # %bb.0:
178+
; X86-NEXT: movzwl {{[0-9]+}}(%esp), %eax
179+
; X86-NEXT: #ARITH_FENCE
180+
; X86-NEXT: pinsrw $0, %eax, %xmm0
181+
; X86-NEXT: retl
182+
;
183+
; X64-LABEL: f8:
184+
; X64: # %bb.0:
185+
; X64-NEXT: pextrw $0, %xmm0, %eax
186+
; X64-NEXT: #ARITH_FENCE
187+
; X64-NEXT: pinsrw $0, %eax, %xmm0
188+
; X64-NEXT: retq
189+
%b = call bfloat @llvm.arithmetic.fence.bf16(bfloat %a)
190+
ret bfloat %b
191+
}
192+
193+
define <2 x half> @f9(<2 x half> %a) nounwind {
194+
; X86-LABEL: f9:
195+
; X86: # %bb.0:
196+
; X86-NEXT: movdqa %xmm0, %xmm1
197+
; X86-NEXT: psrld $16, %xmm1
198+
; X86-NEXT: #ARITH_FENCE
199+
; X86-NEXT: #ARITH_FENCE
200+
; X86-NEXT: punpcklwd {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1],xmm0[2],xmm1[2],xmm0[3],xmm1[3]
201+
; X86-NEXT: retl
202+
;
203+
; X64-LABEL: f9:
204+
; X64: # %bb.0:
205+
; X64-NEXT: movdqa %xmm0, %xmm1
206+
; X64-NEXT: psrld $16, %xmm1
207+
; X64-NEXT: #ARITH_FENCE
208+
; X64-NEXT: #ARITH_FENCE
209+
; X64-NEXT: punpcklwd {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1],xmm0[2],xmm1[2],xmm0[3],xmm1[3]
210+
; X64-NEXT: retq
211+
%b = call <2 x half> @llvm.arithmetic.fence.v2f16(<2 x half> %a)
212+
ret <2 x half> %b
213+
}
214+
215+
define <3 x bfloat> @f10(<3 x bfloat> %a) nounwind {
216+
; X86-LABEL: f10:
217+
; X86: # %bb.0:
218+
; X86-NEXT: pextrw $0, %xmm0, %eax
219+
; X86-NEXT: movdqa %xmm0, %xmm1
220+
; X86-NEXT: psrld $16, %xmm1
221+
; X86-NEXT: pextrw $0, %xmm1, %ecx
222+
; X86-NEXT: shufps {{.*#+}} xmm0 = xmm0[1,1,1,1]
223+
; X86-NEXT: pextrw $0, %xmm0, %edx
224+
; X86-NEXT: #ARITH_FENCE
225+
; X86-NEXT: #ARITH_FENCE
226+
; X86-NEXT: #ARITH_FENCE
227+
; X86-NEXT: pinsrw $0, %eax, %xmm0
228+
; X86-NEXT: pinsrw $0, %ecx, %xmm1
229+
; X86-NEXT: punpcklwd {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1],xmm0[2],xmm1[2],xmm0[3],xmm1[3]
230+
; X86-NEXT: pinsrw $0, %edx, %xmm1
231+
; X86-NEXT: punpckldq {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1]
232+
; X86-NEXT: retl
233+
;
234+
; X64-LABEL: f10:
235+
; X64: # %bb.0:
236+
; X64-NEXT: pextrw $0, %xmm0, %eax
237+
; X64-NEXT: movdqa %xmm0, %xmm1
238+
; X64-NEXT: psrld $16, %xmm1
239+
; X64-NEXT: pextrw $0, %xmm1, %ecx
240+
; X64-NEXT: shufps {{.*#+}} xmm0 = xmm0[1,1,1,1]
241+
; X64-NEXT: pextrw $0, %xmm0, %edx
242+
; X64-NEXT: #ARITH_FENCE
243+
; X64-NEXT: #ARITH_FENCE
244+
; X64-NEXT: #ARITH_FENCE
245+
; X64-NEXT: pinsrw $0, %eax, %xmm0
246+
; X64-NEXT: pinsrw $0, %ecx, %xmm1
247+
; X64-NEXT: punpcklwd {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1],xmm0[2],xmm1[2],xmm0[3],xmm1[3]
248+
; X64-NEXT: pinsrw $0, %edx, %xmm1
249+
; X64-NEXT: punpckldq {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1]
250+
; X64-NEXT: retq
251+
%b = call <3 x bfloat> @llvm.arithmetic.fence.v3bf16(<3 x bfloat> %a)
252+
ret <3 x bfloat> %b
253+
}
254+
255+
define <4 x bfloat> @f11(<4 x bfloat> %a) nounwind {
256+
; X86-LABEL: f11:
257+
; X86: # %bb.0:
258+
; X86-NEXT: pushl %esi
259+
; X86-NEXT: movdqa %xmm0, %xmm1
260+
; X86-NEXT: psrlq $48, %xmm1
261+
; X86-NEXT: pextrw $0, %xmm1, %eax
262+
; X86-NEXT: movdqa %xmm0, %xmm1
263+
; X86-NEXT: shufps {{.*#+}} xmm1 = xmm1[1,1],xmm0[1,1]
264+
; X86-NEXT: pextrw $0, %xmm1, %edx
265+
; X86-NEXT: pextrw $0, %xmm0, %ecx
266+
; X86-NEXT: psrld $16, %xmm0
267+
; X86-NEXT: pextrw $0, %xmm0, %esi
268+
; X86-NEXT: #ARITH_FENCE
269+
; X86-NEXT: #ARITH_FENCE
270+
; X86-NEXT: #ARITH_FENCE
271+
; X86-NEXT: #ARITH_FENCE
272+
; X86-NEXT: pinsrw $0, %eax, %xmm0
273+
; X86-NEXT: pinsrw $0, %edx, %xmm1
274+
; X86-NEXT: punpcklwd {{.*#+}} xmm1 = xmm1[0],xmm0[0],xmm1[1],xmm0[1],xmm1[2],xmm0[2],xmm1[3],xmm0[3]
275+
; X86-NEXT: pinsrw $0, %ecx, %xmm0
276+
; X86-NEXT: pinsrw $0, %esi, %xmm2
277+
; X86-NEXT: punpcklwd {{.*#+}} xmm0 = xmm0[0],xmm2[0],xmm0[1],xmm2[1],xmm0[2],xmm2[2],xmm0[3],xmm2[3]
278+
; X86-NEXT: punpckldq {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1]
279+
; X86-NEXT: popl %esi
280+
; X86-NEXT: retl
281+
;
282+
; X64-LABEL: f11:
283+
; X64: # %bb.0:
284+
; X64-NEXT: movdqa %xmm0, %xmm1
285+
; X64-NEXT: psrlq $48, %xmm1
286+
; X64-NEXT: pextrw $0, %xmm1, %eax
287+
; X64-NEXT: movdqa %xmm0, %xmm1
288+
; X64-NEXT: shufps {{.*#+}} xmm1 = xmm1[1,1],xmm0[1,1]
289+
; X64-NEXT: pextrw $0, %xmm1, %ecx
290+
; X64-NEXT: pextrw $0, %xmm0, %edx
291+
; X64-NEXT: psrld $16, %xmm0
292+
; X64-NEXT: pextrw $0, %xmm0, %esi
293+
; X64-NEXT: #ARITH_FENCE
294+
; X64-NEXT: #ARITH_FENCE
295+
; X64-NEXT: #ARITH_FENCE
296+
; X64-NEXT: #ARITH_FENCE
297+
; X64-NEXT: pinsrw $0, %eax, %xmm0
298+
; X64-NEXT: pinsrw $0, %ecx, %xmm1
299+
; X64-NEXT: punpcklwd {{.*#+}} xmm1 = xmm1[0],xmm0[0],xmm1[1],xmm0[1],xmm1[2],xmm0[2],xmm1[3],xmm0[3]
300+
; X64-NEXT: pinsrw $0, %edx, %xmm0
301+
; X64-NEXT: pinsrw $0, %esi, %xmm2
302+
; X64-NEXT: punpcklwd {{.*#+}} xmm0 = xmm0[0],xmm2[0],xmm0[1],xmm2[1],xmm0[2],xmm2[2],xmm0[3],xmm2[3]
303+
; X64-NEXT: punpckldq {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1]
304+
; X64-NEXT: retq
305+
%b = call <4 x bfloat> @llvm.arithmetic.fence.v4bf16(<4 x bfloat> %a)
306+
ret <4 x bfloat> %b
307+
}
308+
309+
declare half @llvm.arithmetic.fence.f16(half)
310+
declare bfloat @llvm.arithmetic.fence.bf16(bfloat)
311+
declare <2 x half> @llvm.arithmetic.fence.v2f16(<2 x half>)
312+
declare <3 x bfloat> @llvm.arithmetic.fence.v3bf16(<3 x bfloat>)
313+
declare <4 x bfloat> @llvm.arithmetic.fence.v4bf16(<4 x bfloat>)
160314
declare float @llvm.arithmetic.fence.f32(float)
161315
declare double @llvm.arithmetic.fence.f64(double)
162316
declare <2 x float> @llvm.arithmetic.fence.v2f32(<2 x float>)

0 commit comments

Comments
 (0)