Skip to content

[SelectionDAG] Make ARITH_FENCE support half and bfloat type #90836

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 5, 2024

Conversation

phoebewang
Copy link
Contributor

No description provided.

@phoebewang phoebewang requested review from RKSimon and topperc May 2, 2024 09:16
@llvmbot llvmbot added backend:X86 llvm:SelectionDAG SelectionDAGISel as well labels May 2, 2024
@llvmbot
Copy link
Member

llvmbot commented May 2, 2024

@llvm/pr-subscribers-backend-x86

Author: Phoebe Wang (phoebewang)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/90836.diff

3 Files Affected:

  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp (+7)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h (+1)
  • (modified) llvm/test/CodeGen/X86/arithmetic_fence2.ll (+85)
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
index abe5be76382556..00f94e48a3f9ad 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
@@ -2825,6 +2825,8 @@ void DAGTypeLegalizer::SoftPromoteHalfResult(SDNode *N, unsigned ResNo) {
     report_fatal_error("Do not know how to soft promote this operator's "
                        "result!");
 
+  case ISD::ARITH_FENCE:
+    R = SoftPromoteHalfRes_ARITH_FENCE(N); break;
   case ISD::BITCAST:    R = SoftPromoteHalfRes_BITCAST(N); break;
   case ISD::ConstantFP: R = SoftPromoteHalfRes_ConstantFP(N); break;
   case ISD::EXTRACT_VECTOR_ELT:
@@ -2904,6 +2906,11 @@ void DAGTypeLegalizer::SoftPromoteHalfResult(SDNode *N, unsigned ResNo) {
     SetSoftPromotedHalf(SDValue(N, ResNo), R);
 }
 
+SDValue DAGTypeLegalizer::SoftPromoteHalfRes_ARITH_FENCE(SDNode *N) {
+  return DAG.getNode(ISD::ARITH_FENCE, SDLoc(N), MVT::i16,
+                     BitConvertToInteger(N->getOperand(0)));
+}
+
 SDValue DAGTypeLegalizer::SoftPromoteHalfRes_BITCAST(SDNode *N) {
   return BitConvertToInteger(N->getOperand(0));
 }
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index 49be824deb5134..e9714f6f72b6bb 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -726,6 +726,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   void SetSoftPromotedHalf(SDValue Op, SDValue Result);
 
   void SoftPromoteHalfResult(SDNode *N, unsigned ResNo);
+  SDValue SoftPromoteHalfRes_ARITH_FENCE(SDNode *N);
   SDValue SoftPromoteHalfRes_BinOp(SDNode *N);
   SDValue SoftPromoteHalfRes_BITCAST(SDNode *N);
   SDValue SoftPromoteHalfRes_ConstantFP(SDNode *N);
diff --git a/llvm/test/CodeGen/X86/arithmetic_fence2.ll b/llvm/test/CodeGen/X86/arithmetic_fence2.ll
index 6a854b58fc02d0..bc80e1a70112d3 100644
--- a/llvm/test/CodeGen/X86/arithmetic_fence2.ll
+++ b/llvm/test/CodeGen/X86/arithmetic_fence2.ll
@@ -157,6 +157,91 @@ define <8 x float> @f6(<8 x float> %a) {
   ret <8 x float> %3
 }
 
+define half @f7(half %a) nounwind {
+; X86-LABEL: f7:
+; X86:       # %bb.0:
+; X86-NEXT:    subl $12, %esp
+; X86-NEXT:    pinsrw $0, {{[0-9]+}}(%esp), %xmm0
+; X86-NEXT:    pextrw $0, %xmm0, %eax
+; X86-NEXT:    movw %ax, (%esp)
+; X86-NEXT:    calll __extendhfsf2
+; X86-NEXT:    fstps {{[0-9]+}}(%esp)
+; X86-NEXT:    movss {{.*#+}} xmm0 = mem[0],zero,zero,zero
+; X86-NEXT:    addss %xmm0, %xmm0
+; X86-NEXT:    movss %xmm0, (%esp)
+; X86-NEXT:    calll __truncsfhf2
+; X86-NEXT:    pextrw $0, %xmm0, %eax
+; X86-NEXT:    movw %ax, (%esp)
+; X86-NEXT:    calll __extendhfsf2
+; X86-NEXT:    fstps {{[0-9]+}}(%esp)
+; X86-NEXT:    movss {{.*#+}} xmm0 = mem[0],zero,zero,zero
+; X86-NEXT:    addss %xmm0, %xmm0
+; X86-NEXT:    movss %xmm0, (%esp)
+; X86-NEXT:    calll __truncsfhf2
+; X86-NEXT:    addl $12, %esp
+; X86-NEXT:    retl
+;
+; X64-LABEL: f7:
+; X64:       # %bb.0:
+; X64-NEXT:    pushq %rax
+; X64-NEXT:    callq __extendhfsf2@PLT
+; X64-NEXT:    addss %xmm0, %xmm0
+; X64-NEXT:    callq __truncsfhf2@PLT
+; X64-NEXT:    callq __extendhfsf2@PLT
+; X64-NEXT:    addss %xmm0, %xmm0
+; X64-NEXT:    callq __truncsfhf2@PLT
+; X64-NEXT:    popq %rax
+; X64-NEXT:    retq
+  %1 = fadd fast half %a, %a
+  %t = call half @llvm.arithmetic.fence.f16(half %1)
+  %2 = fadd fast half %a, %a
+  %3 = fadd fast half %1, %2
+  ret half %3
+}
+
+define bfloat @f8(bfloat %a) nounwind {
+; X86-LABEL: f8:
+; X86:       # %bb.0:
+; X86-NEXT:    pushl %eax
+; X86-NEXT:    movl {{[0-9]+}}(%esp), %eax
+; X86-NEXT:    shll $16, %eax
+; X86-NEXT:    movd %eax, %xmm0
+; X86-NEXT:    addss %xmm0, %xmm0
+; X86-NEXT:    movss %xmm0, (%esp)
+; X86-NEXT:    calll __truncsfbf2
+; X86-NEXT:    pextrw $0, %xmm0, %eax
+; X86-NEXT:    shll $16, %eax
+; X86-NEXT:    movd %eax, %xmm0
+; X86-NEXT:    addss %xmm0, %xmm0
+; X86-NEXT:    movss %xmm0, (%esp)
+; X86-NEXT:    calll __truncsfbf2
+; X86-NEXT:    popl %eax
+; X86-NEXT:    retl
+;
+; X64-LABEL: f8:
+; X64:       # %bb.0:
+; X64-NEXT:    pushq %rax
+; X64-NEXT:    pextrw $0, %xmm0, %eax
+; X64-NEXT:    shll $16, %eax
+; X64-NEXT:    movd %eax, %xmm0
+; X64-NEXT:    addss %xmm0, %xmm0
+; X64-NEXT:    callq __truncsfbf2@PLT
+; X64-NEXT:    pextrw $0, %xmm0, %eax
+; X64-NEXT:    shll $16, %eax
+; X64-NEXT:    movd %eax, %xmm0
+; X64-NEXT:    addss %xmm0, %xmm0
+; X64-NEXT:    callq __truncsfbf2@PLT
+; X64-NEXT:    popq %rax
+; X64-NEXT:    retq
+  %1 = fadd fast bfloat %a, %a
+  %t = call bfloat @llvm.arithmetic.fence.bf16(bfloat %1)
+  %2 = fadd fast bfloat %a, %a
+  %3 = fadd fast bfloat %1, %2
+  ret bfloat %3
+}
+
+declare half @llvm.arithmetic.fence.f16(half)
+declare bfloat @llvm.arithmetic.fence.bf16(bfloat)
 declare float @llvm.arithmetic.fence.f32(float)
 declare double @llvm.arithmetic.fence.f64(double)
 declare <2 x float> @llvm.arithmetic.fence.v2f32(<2 x float>)

@llvmbot
Copy link
Member

llvmbot commented May 2, 2024

@llvm/pr-subscribers-llvm-selectiondag

Author: Phoebe Wang (phoebewang)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/90836.diff

3 Files Affected:

  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp (+7)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h (+1)
  • (modified) llvm/test/CodeGen/X86/arithmetic_fence2.ll (+85)
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
index abe5be76382556..00f94e48a3f9ad 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
@@ -2825,6 +2825,8 @@ void DAGTypeLegalizer::SoftPromoteHalfResult(SDNode *N, unsigned ResNo) {
     report_fatal_error("Do not know how to soft promote this operator's "
                        "result!");
 
+  case ISD::ARITH_FENCE:
+    R = SoftPromoteHalfRes_ARITH_FENCE(N); break;
   case ISD::BITCAST:    R = SoftPromoteHalfRes_BITCAST(N); break;
   case ISD::ConstantFP: R = SoftPromoteHalfRes_ConstantFP(N); break;
   case ISD::EXTRACT_VECTOR_ELT:
@@ -2904,6 +2906,11 @@ void DAGTypeLegalizer::SoftPromoteHalfResult(SDNode *N, unsigned ResNo) {
     SetSoftPromotedHalf(SDValue(N, ResNo), R);
 }
 
+SDValue DAGTypeLegalizer::SoftPromoteHalfRes_ARITH_FENCE(SDNode *N) {
+  return DAG.getNode(ISD::ARITH_FENCE, SDLoc(N), MVT::i16,
+                     BitConvertToInteger(N->getOperand(0)));
+}
+
 SDValue DAGTypeLegalizer::SoftPromoteHalfRes_BITCAST(SDNode *N) {
   return BitConvertToInteger(N->getOperand(0));
 }
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index 49be824deb5134..e9714f6f72b6bb 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -726,6 +726,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   void SetSoftPromotedHalf(SDValue Op, SDValue Result);
 
   void SoftPromoteHalfResult(SDNode *N, unsigned ResNo);
+  SDValue SoftPromoteHalfRes_ARITH_FENCE(SDNode *N);
   SDValue SoftPromoteHalfRes_BinOp(SDNode *N);
   SDValue SoftPromoteHalfRes_BITCAST(SDNode *N);
   SDValue SoftPromoteHalfRes_ConstantFP(SDNode *N);
diff --git a/llvm/test/CodeGen/X86/arithmetic_fence2.ll b/llvm/test/CodeGen/X86/arithmetic_fence2.ll
index 6a854b58fc02d0..bc80e1a70112d3 100644
--- a/llvm/test/CodeGen/X86/arithmetic_fence2.ll
+++ b/llvm/test/CodeGen/X86/arithmetic_fence2.ll
@@ -157,6 +157,91 @@ define <8 x float> @f6(<8 x float> %a) {
   ret <8 x float> %3
 }
 
+define half @f7(half %a) nounwind {
+; X86-LABEL: f7:
+; X86:       # %bb.0:
+; X86-NEXT:    subl $12, %esp
+; X86-NEXT:    pinsrw $0, {{[0-9]+}}(%esp), %xmm0
+; X86-NEXT:    pextrw $0, %xmm0, %eax
+; X86-NEXT:    movw %ax, (%esp)
+; X86-NEXT:    calll __extendhfsf2
+; X86-NEXT:    fstps {{[0-9]+}}(%esp)
+; X86-NEXT:    movss {{.*#+}} xmm0 = mem[0],zero,zero,zero
+; X86-NEXT:    addss %xmm0, %xmm0
+; X86-NEXT:    movss %xmm0, (%esp)
+; X86-NEXT:    calll __truncsfhf2
+; X86-NEXT:    pextrw $0, %xmm0, %eax
+; X86-NEXT:    movw %ax, (%esp)
+; X86-NEXT:    calll __extendhfsf2
+; X86-NEXT:    fstps {{[0-9]+}}(%esp)
+; X86-NEXT:    movss {{.*#+}} xmm0 = mem[0],zero,zero,zero
+; X86-NEXT:    addss %xmm0, %xmm0
+; X86-NEXT:    movss %xmm0, (%esp)
+; X86-NEXT:    calll __truncsfhf2
+; X86-NEXT:    addl $12, %esp
+; X86-NEXT:    retl
+;
+; X64-LABEL: f7:
+; X64:       # %bb.0:
+; X64-NEXT:    pushq %rax
+; X64-NEXT:    callq __extendhfsf2@PLT
+; X64-NEXT:    addss %xmm0, %xmm0
+; X64-NEXT:    callq __truncsfhf2@PLT
+; X64-NEXT:    callq __extendhfsf2@PLT
+; X64-NEXT:    addss %xmm0, %xmm0
+; X64-NEXT:    callq __truncsfhf2@PLT
+; X64-NEXT:    popq %rax
+; X64-NEXT:    retq
+  %1 = fadd fast half %a, %a
+  %t = call half @llvm.arithmetic.fence.f16(half %1)
+  %2 = fadd fast half %a, %a
+  %3 = fadd fast half %1, %2
+  ret half %3
+}
+
+define bfloat @f8(bfloat %a) nounwind {
+; X86-LABEL: f8:
+; X86:       # %bb.0:
+; X86-NEXT:    pushl %eax
+; X86-NEXT:    movl {{[0-9]+}}(%esp), %eax
+; X86-NEXT:    shll $16, %eax
+; X86-NEXT:    movd %eax, %xmm0
+; X86-NEXT:    addss %xmm0, %xmm0
+; X86-NEXT:    movss %xmm0, (%esp)
+; X86-NEXT:    calll __truncsfbf2
+; X86-NEXT:    pextrw $0, %xmm0, %eax
+; X86-NEXT:    shll $16, %eax
+; X86-NEXT:    movd %eax, %xmm0
+; X86-NEXT:    addss %xmm0, %xmm0
+; X86-NEXT:    movss %xmm0, (%esp)
+; X86-NEXT:    calll __truncsfbf2
+; X86-NEXT:    popl %eax
+; X86-NEXT:    retl
+;
+; X64-LABEL: f8:
+; X64:       # %bb.0:
+; X64-NEXT:    pushq %rax
+; X64-NEXT:    pextrw $0, %xmm0, %eax
+; X64-NEXT:    shll $16, %eax
+; X64-NEXT:    movd %eax, %xmm0
+; X64-NEXT:    addss %xmm0, %xmm0
+; X64-NEXT:    callq __truncsfbf2@PLT
+; X64-NEXT:    pextrw $0, %xmm0, %eax
+; X64-NEXT:    shll $16, %eax
+; X64-NEXT:    movd %eax, %xmm0
+; X64-NEXT:    addss %xmm0, %xmm0
+; X64-NEXT:    callq __truncsfbf2@PLT
+; X64-NEXT:    popq %rax
+; X64-NEXT:    retq
+  %1 = fadd fast bfloat %a, %a
+  %t = call bfloat @llvm.arithmetic.fence.bf16(bfloat %1)
+  %2 = fadd fast bfloat %a, %a
+  %3 = fadd fast bfloat %1, %2
+  ret bfloat %3
+}
+
+declare half @llvm.arithmetic.fence.f16(half)
+declare bfloat @llvm.arithmetic.fence.bf16(bfloat)
 declare float @llvm.arithmetic.fence.f32(float)
 declare double @llvm.arithmetic.fence.f64(double)
 declare <2 x float> @llvm.arithmetic.fence.v2f32(<2 x float>)

Copy link

github-actions bot commented May 2, 2024

⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff a015f015db21e02cbce4ff9d15d0b293e45d0831 1cf434579aa84faeb1eb2e501b2e2603b40fa4fc -- llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
View the diff from clang-format here.
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
index 00f94e48a3..362e924d3c 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
@@ -2826,7 +2826,8 @@ void DAGTypeLegalizer::SoftPromoteHalfResult(SDNode *N, unsigned ResNo) {
                        "result!");
 
   case ISD::ARITH_FENCE:
-    R = SoftPromoteHalfRes_ARITH_FENCE(N); break;
+    R = SoftPromoteHalfRes_ARITH_FENCE(N);
+    break;
   case ISD::BITCAST:    R = SoftPromoteHalfRes_BITCAST(N); break;
   case ISD::ConstantFP: R = SoftPromoteHalfRes_ConstantFP(N); break;
   case ISD::EXTRACT_VECTOR_ELT:

%1 = fadd fast bfloat %a, %a
%t = call bfloat @llvm.arithmetic.fence.bf16(bfloat %1)
%2 = fadd fast bfloat %a, %a
%3 = fadd fast bfloat %1, %2
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use named values in tests. Also drop the flags

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually can drop all the instructions that aren't just the fence

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

%3 = fadd fast bfloat %1, %2
ret bfloat %3
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test vectors too?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

Copy link
Contributor

@arsenm arsenm left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM but should also test the 3x case

%b = call <2 x half> @llvm.arithmetic.fence.v2f16(<2 x half> %a)
ret <2 x half> %b
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You skipped 3 x, the case most likely to break

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

@phoebewang phoebewang merged commit 02dfbbf into llvm:main May 5, 2024
3 of 4 checks passed
@phoebewang phoebewang deleted the arithmetic_fence branch May 5, 2024 05:08
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backend:X86 llvm:SelectionDAG SelectionDAGISel as well
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants