-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[SelectionDAG] Make ARITH_FENCE support half and bfloat type #90836
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-backend-x86 Author: Phoebe Wang (phoebewang) ChangesFull diff: https://github.com/llvm/llvm-project/pull/90836.diff 3 Files Affected:
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
index abe5be76382556..00f94e48a3f9ad 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
@@ -2825,6 +2825,8 @@ void DAGTypeLegalizer::SoftPromoteHalfResult(SDNode *N, unsigned ResNo) {
report_fatal_error("Do not know how to soft promote this operator's "
"result!");
+ case ISD::ARITH_FENCE:
+ R = SoftPromoteHalfRes_ARITH_FENCE(N); break;
case ISD::BITCAST: R = SoftPromoteHalfRes_BITCAST(N); break;
case ISD::ConstantFP: R = SoftPromoteHalfRes_ConstantFP(N); break;
case ISD::EXTRACT_VECTOR_ELT:
@@ -2904,6 +2906,11 @@ void DAGTypeLegalizer::SoftPromoteHalfResult(SDNode *N, unsigned ResNo) {
SetSoftPromotedHalf(SDValue(N, ResNo), R);
}
+SDValue DAGTypeLegalizer::SoftPromoteHalfRes_ARITH_FENCE(SDNode *N) {
+ return DAG.getNode(ISD::ARITH_FENCE, SDLoc(N), MVT::i16,
+ BitConvertToInteger(N->getOperand(0)));
+}
+
SDValue DAGTypeLegalizer::SoftPromoteHalfRes_BITCAST(SDNode *N) {
return BitConvertToInteger(N->getOperand(0));
}
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index 49be824deb5134..e9714f6f72b6bb 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -726,6 +726,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
void SetSoftPromotedHalf(SDValue Op, SDValue Result);
void SoftPromoteHalfResult(SDNode *N, unsigned ResNo);
+ SDValue SoftPromoteHalfRes_ARITH_FENCE(SDNode *N);
SDValue SoftPromoteHalfRes_BinOp(SDNode *N);
SDValue SoftPromoteHalfRes_BITCAST(SDNode *N);
SDValue SoftPromoteHalfRes_ConstantFP(SDNode *N);
diff --git a/llvm/test/CodeGen/X86/arithmetic_fence2.ll b/llvm/test/CodeGen/X86/arithmetic_fence2.ll
index 6a854b58fc02d0..bc80e1a70112d3 100644
--- a/llvm/test/CodeGen/X86/arithmetic_fence2.ll
+++ b/llvm/test/CodeGen/X86/arithmetic_fence2.ll
@@ -157,6 +157,91 @@ define <8 x float> @f6(<8 x float> %a) {
ret <8 x float> %3
}
+define half @f7(half %a) nounwind {
+; X86-LABEL: f7:
+; X86: # %bb.0:
+; X86-NEXT: subl $12, %esp
+; X86-NEXT: pinsrw $0, {{[0-9]+}}(%esp), %xmm0
+; X86-NEXT: pextrw $0, %xmm0, %eax
+; X86-NEXT: movw %ax, (%esp)
+; X86-NEXT: calll __extendhfsf2
+; X86-NEXT: fstps {{[0-9]+}}(%esp)
+; X86-NEXT: movss {{.*#+}} xmm0 = mem[0],zero,zero,zero
+; X86-NEXT: addss %xmm0, %xmm0
+; X86-NEXT: movss %xmm0, (%esp)
+; X86-NEXT: calll __truncsfhf2
+; X86-NEXT: pextrw $0, %xmm0, %eax
+; X86-NEXT: movw %ax, (%esp)
+; X86-NEXT: calll __extendhfsf2
+; X86-NEXT: fstps {{[0-9]+}}(%esp)
+; X86-NEXT: movss {{.*#+}} xmm0 = mem[0],zero,zero,zero
+; X86-NEXT: addss %xmm0, %xmm0
+; X86-NEXT: movss %xmm0, (%esp)
+; X86-NEXT: calll __truncsfhf2
+; X86-NEXT: addl $12, %esp
+; X86-NEXT: retl
+;
+; X64-LABEL: f7:
+; X64: # %bb.0:
+; X64-NEXT: pushq %rax
+; X64-NEXT: callq __extendhfsf2@PLT
+; X64-NEXT: addss %xmm0, %xmm0
+; X64-NEXT: callq __truncsfhf2@PLT
+; X64-NEXT: callq __extendhfsf2@PLT
+; X64-NEXT: addss %xmm0, %xmm0
+; X64-NEXT: callq __truncsfhf2@PLT
+; X64-NEXT: popq %rax
+; X64-NEXT: retq
+ %1 = fadd fast half %a, %a
+ %t = call half @llvm.arithmetic.fence.f16(half %1)
+ %2 = fadd fast half %a, %a
+ %3 = fadd fast half %1, %2
+ ret half %3
+}
+
+define bfloat @f8(bfloat %a) nounwind {
+; X86-LABEL: f8:
+; X86: # %bb.0:
+; X86-NEXT: pushl %eax
+; X86-NEXT: movl {{[0-9]+}}(%esp), %eax
+; X86-NEXT: shll $16, %eax
+; X86-NEXT: movd %eax, %xmm0
+; X86-NEXT: addss %xmm0, %xmm0
+; X86-NEXT: movss %xmm0, (%esp)
+; X86-NEXT: calll __truncsfbf2
+; X86-NEXT: pextrw $0, %xmm0, %eax
+; X86-NEXT: shll $16, %eax
+; X86-NEXT: movd %eax, %xmm0
+; X86-NEXT: addss %xmm0, %xmm0
+; X86-NEXT: movss %xmm0, (%esp)
+; X86-NEXT: calll __truncsfbf2
+; X86-NEXT: popl %eax
+; X86-NEXT: retl
+;
+; X64-LABEL: f8:
+; X64: # %bb.0:
+; X64-NEXT: pushq %rax
+; X64-NEXT: pextrw $0, %xmm0, %eax
+; X64-NEXT: shll $16, %eax
+; X64-NEXT: movd %eax, %xmm0
+; X64-NEXT: addss %xmm0, %xmm0
+; X64-NEXT: callq __truncsfbf2@PLT
+; X64-NEXT: pextrw $0, %xmm0, %eax
+; X64-NEXT: shll $16, %eax
+; X64-NEXT: movd %eax, %xmm0
+; X64-NEXT: addss %xmm0, %xmm0
+; X64-NEXT: callq __truncsfbf2@PLT
+; X64-NEXT: popq %rax
+; X64-NEXT: retq
+ %1 = fadd fast bfloat %a, %a
+ %t = call bfloat @llvm.arithmetic.fence.bf16(bfloat %1)
+ %2 = fadd fast bfloat %a, %a
+ %3 = fadd fast bfloat %1, %2
+ ret bfloat %3
+}
+
+declare half @llvm.arithmetic.fence.f16(half)
+declare bfloat @llvm.arithmetic.fence.bf16(bfloat)
declare float @llvm.arithmetic.fence.f32(float)
declare double @llvm.arithmetic.fence.f64(double)
declare <2 x float> @llvm.arithmetic.fence.v2f32(<2 x float>)
|
@llvm/pr-subscribers-llvm-selectiondag Author: Phoebe Wang (phoebewang) ChangesFull diff: https://github.com/llvm/llvm-project/pull/90836.diff 3 Files Affected:
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
index abe5be76382556..00f94e48a3f9ad 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
@@ -2825,6 +2825,8 @@ void DAGTypeLegalizer::SoftPromoteHalfResult(SDNode *N, unsigned ResNo) {
report_fatal_error("Do not know how to soft promote this operator's "
"result!");
+ case ISD::ARITH_FENCE:
+ R = SoftPromoteHalfRes_ARITH_FENCE(N); break;
case ISD::BITCAST: R = SoftPromoteHalfRes_BITCAST(N); break;
case ISD::ConstantFP: R = SoftPromoteHalfRes_ConstantFP(N); break;
case ISD::EXTRACT_VECTOR_ELT:
@@ -2904,6 +2906,11 @@ void DAGTypeLegalizer::SoftPromoteHalfResult(SDNode *N, unsigned ResNo) {
SetSoftPromotedHalf(SDValue(N, ResNo), R);
}
+SDValue DAGTypeLegalizer::SoftPromoteHalfRes_ARITH_FENCE(SDNode *N) {
+ return DAG.getNode(ISD::ARITH_FENCE, SDLoc(N), MVT::i16,
+ BitConvertToInteger(N->getOperand(0)));
+}
+
SDValue DAGTypeLegalizer::SoftPromoteHalfRes_BITCAST(SDNode *N) {
return BitConvertToInteger(N->getOperand(0));
}
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index 49be824deb5134..e9714f6f72b6bb 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -726,6 +726,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
void SetSoftPromotedHalf(SDValue Op, SDValue Result);
void SoftPromoteHalfResult(SDNode *N, unsigned ResNo);
+ SDValue SoftPromoteHalfRes_ARITH_FENCE(SDNode *N);
SDValue SoftPromoteHalfRes_BinOp(SDNode *N);
SDValue SoftPromoteHalfRes_BITCAST(SDNode *N);
SDValue SoftPromoteHalfRes_ConstantFP(SDNode *N);
diff --git a/llvm/test/CodeGen/X86/arithmetic_fence2.ll b/llvm/test/CodeGen/X86/arithmetic_fence2.ll
index 6a854b58fc02d0..bc80e1a70112d3 100644
--- a/llvm/test/CodeGen/X86/arithmetic_fence2.ll
+++ b/llvm/test/CodeGen/X86/arithmetic_fence2.ll
@@ -157,6 +157,91 @@ define <8 x float> @f6(<8 x float> %a) {
ret <8 x float> %3
}
+define half @f7(half %a) nounwind {
+; X86-LABEL: f7:
+; X86: # %bb.0:
+; X86-NEXT: subl $12, %esp
+; X86-NEXT: pinsrw $0, {{[0-9]+}}(%esp), %xmm0
+; X86-NEXT: pextrw $0, %xmm0, %eax
+; X86-NEXT: movw %ax, (%esp)
+; X86-NEXT: calll __extendhfsf2
+; X86-NEXT: fstps {{[0-9]+}}(%esp)
+; X86-NEXT: movss {{.*#+}} xmm0 = mem[0],zero,zero,zero
+; X86-NEXT: addss %xmm0, %xmm0
+; X86-NEXT: movss %xmm0, (%esp)
+; X86-NEXT: calll __truncsfhf2
+; X86-NEXT: pextrw $0, %xmm0, %eax
+; X86-NEXT: movw %ax, (%esp)
+; X86-NEXT: calll __extendhfsf2
+; X86-NEXT: fstps {{[0-9]+}}(%esp)
+; X86-NEXT: movss {{.*#+}} xmm0 = mem[0],zero,zero,zero
+; X86-NEXT: addss %xmm0, %xmm0
+; X86-NEXT: movss %xmm0, (%esp)
+; X86-NEXT: calll __truncsfhf2
+; X86-NEXT: addl $12, %esp
+; X86-NEXT: retl
+;
+; X64-LABEL: f7:
+; X64: # %bb.0:
+; X64-NEXT: pushq %rax
+; X64-NEXT: callq __extendhfsf2@PLT
+; X64-NEXT: addss %xmm0, %xmm0
+; X64-NEXT: callq __truncsfhf2@PLT
+; X64-NEXT: callq __extendhfsf2@PLT
+; X64-NEXT: addss %xmm0, %xmm0
+; X64-NEXT: callq __truncsfhf2@PLT
+; X64-NEXT: popq %rax
+; X64-NEXT: retq
+ %1 = fadd fast half %a, %a
+ %t = call half @llvm.arithmetic.fence.f16(half %1)
+ %2 = fadd fast half %a, %a
+ %3 = fadd fast half %1, %2
+ ret half %3
+}
+
+define bfloat @f8(bfloat %a) nounwind {
+; X86-LABEL: f8:
+; X86: # %bb.0:
+; X86-NEXT: pushl %eax
+; X86-NEXT: movl {{[0-9]+}}(%esp), %eax
+; X86-NEXT: shll $16, %eax
+; X86-NEXT: movd %eax, %xmm0
+; X86-NEXT: addss %xmm0, %xmm0
+; X86-NEXT: movss %xmm0, (%esp)
+; X86-NEXT: calll __truncsfbf2
+; X86-NEXT: pextrw $0, %xmm0, %eax
+; X86-NEXT: shll $16, %eax
+; X86-NEXT: movd %eax, %xmm0
+; X86-NEXT: addss %xmm0, %xmm0
+; X86-NEXT: movss %xmm0, (%esp)
+; X86-NEXT: calll __truncsfbf2
+; X86-NEXT: popl %eax
+; X86-NEXT: retl
+;
+; X64-LABEL: f8:
+; X64: # %bb.0:
+; X64-NEXT: pushq %rax
+; X64-NEXT: pextrw $0, %xmm0, %eax
+; X64-NEXT: shll $16, %eax
+; X64-NEXT: movd %eax, %xmm0
+; X64-NEXT: addss %xmm0, %xmm0
+; X64-NEXT: callq __truncsfbf2@PLT
+; X64-NEXT: pextrw $0, %xmm0, %eax
+; X64-NEXT: shll $16, %eax
+; X64-NEXT: movd %eax, %xmm0
+; X64-NEXT: addss %xmm0, %xmm0
+; X64-NEXT: callq __truncsfbf2@PLT
+; X64-NEXT: popq %rax
+; X64-NEXT: retq
+ %1 = fadd fast bfloat %a, %a
+ %t = call bfloat @llvm.arithmetic.fence.bf16(bfloat %1)
+ %2 = fadd fast bfloat %a, %a
+ %3 = fadd fast bfloat %1, %2
+ ret bfloat %3
+}
+
+declare half @llvm.arithmetic.fence.f16(half)
+declare bfloat @llvm.arithmetic.fence.bf16(bfloat)
declare float @llvm.arithmetic.fence.f32(float)
declare double @llvm.arithmetic.fence.f64(double)
declare <2 x float> @llvm.arithmetic.fence.v2f32(<2 x float>)
|
You can test this locally with the following command:git-clang-format --diff a015f015db21e02cbce4ff9d15d0b293e45d0831 1cf434579aa84faeb1eb2e501b2e2603b40fa4fc -- llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h View the diff from clang-format here.diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
index 00f94e48a3..362e924d3c 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
@@ -2826,7 +2826,8 @@ void DAGTypeLegalizer::SoftPromoteHalfResult(SDNode *N, unsigned ResNo) {
"result!");
case ISD::ARITH_FENCE:
- R = SoftPromoteHalfRes_ARITH_FENCE(N); break;
+ R = SoftPromoteHalfRes_ARITH_FENCE(N);
+ break;
case ISD::BITCAST: R = SoftPromoteHalfRes_BITCAST(N); break;
case ISD::ConstantFP: R = SoftPromoteHalfRes_ConstantFP(N); break;
case ISD::EXTRACT_VECTOR_ELT:
|
%1 = fadd fast bfloat %a, %a | ||
%t = call bfloat @llvm.arithmetic.fence.bf16(bfloat %1) | ||
%2 = fadd fast bfloat %a, %a | ||
%3 = fadd fast bfloat %1, %2 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use named values in tests. Also drop the flags
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually can drop all the instructions that aren't just the fence
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
%3 = fadd fast bfloat %1, %2 | ||
ret bfloat %3 | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Test vectors too?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
8ec4468
to
7c01fe3
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM but should also test the 3x case
%b = call <2 x half> @llvm.arithmetic.fence.v2f16(<2 x half> %a) | ||
ret <2 x half> %b | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You skipped 3 x, the case most likely to break
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
No description provided.