-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[WebAssembly] Implement all f16x8 binary instructions. #93360
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
This reuses most of the code that was created for f32x4 and f64x2 binary instructions and tries to follow how they were implemented. add/sub/mul/div — use regular LL instructions. min/max — use the minimum/maximum intrinsics, and also have builtins. pmin/pmax — use the wasm.pmin/wasm.pmax intrinsics, and also have builtins. Specified at: https://github.com/WebAssembly/half-precision/blob/29a9b9462c9285d4ccc1a5dc39214ddfd1892658/proposals/half-precision/Overview.md
@llvm/pr-subscribers-mc @llvm/pr-subscribers-clang Author: Brendan Dahl (brendandahl) ChangesThis reuses most of the code that was created for f32x4 and f64x2 binary instructions and tries to follow how they were implemented. add/sub/mul/div - use regular LL instructions Full diff: https://github.com/llvm/llvm-project/pull/93360.diff 7 Files Affected:
diff --git a/clang/include/clang/Basic/BuiltinsWebAssembly.def b/clang/include/clang/Basic/BuiltinsWebAssembly.def
index fd8c1b480d6da..4e48ff48b60f5 100644
--- a/clang/include/clang/Basic/BuiltinsWebAssembly.def
+++ b/clang/include/clang/Basic/BuiltinsWebAssembly.def
@@ -135,6 +135,10 @@ TARGET_BUILTIN(__builtin_wasm_min_f64x2, "V2dV2dV2d", "nc", "simd128")
TARGET_BUILTIN(__builtin_wasm_max_f64x2, "V2dV2dV2d", "nc", "simd128")
TARGET_BUILTIN(__builtin_wasm_pmin_f64x2, "V2dV2dV2d", "nc", "simd128")
TARGET_BUILTIN(__builtin_wasm_pmax_f64x2, "V2dV2dV2d", "nc", "simd128")
+TARGET_BUILTIN(__builtin_wasm_min_f16x8, "V8hV8hV8h", "nc", "half-precision")
+TARGET_BUILTIN(__builtin_wasm_max_f16x8, "V8hV8hV8h", "nc", "half-precision")
+TARGET_BUILTIN(__builtin_wasm_pmin_f16x8, "V8hV8hV8h", "nc", "half-precision")
+TARGET_BUILTIN(__builtin_wasm_pmax_f16x8, "V8hV8hV8h", "nc", "half-precision")
TARGET_BUILTIN(__builtin_wasm_ceil_f32x4, "V4fV4f", "nc", "simd128")
TARGET_BUILTIN(__builtin_wasm_floor_f32x4, "V4fV4f", "nc", "simd128")
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 0549afa12e430..f8be7182b5267 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -20779,6 +20779,7 @@ Value *CodeGenFunction::EmitWebAssemblyBuiltinExpr(unsigned BuiltinID,
}
case WebAssembly::BI__builtin_wasm_min_f32:
case WebAssembly::BI__builtin_wasm_min_f64:
+ case WebAssembly::BI__builtin_wasm_min_f16x8:
case WebAssembly::BI__builtin_wasm_min_f32x4:
case WebAssembly::BI__builtin_wasm_min_f64x2: {
Value *LHS = EmitScalarExpr(E->getArg(0));
@@ -20789,6 +20790,7 @@ Value *CodeGenFunction::EmitWebAssemblyBuiltinExpr(unsigned BuiltinID,
}
case WebAssembly::BI__builtin_wasm_max_f32:
case WebAssembly::BI__builtin_wasm_max_f64:
+ case WebAssembly::BI__builtin_wasm_max_f16x8:
case WebAssembly::BI__builtin_wasm_max_f32x4:
case WebAssembly::BI__builtin_wasm_max_f64x2: {
Value *LHS = EmitScalarExpr(E->getArg(0));
@@ -20797,6 +20799,7 @@ Value *CodeGenFunction::EmitWebAssemblyBuiltinExpr(unsigned BuiltinID,
CGM.getIntrinsic(Intrinsic::maximum, ConvertType(E->getType()));
return Builder.CreateCall(Callee, {LHS, RHS});
}
+ case WebAssembly::BI__builtin_wasm_pmin_f16x8:
case WebAssembly::BI__builtin_wasm_pmin_f32x4:
case WebAssembly::BI__builtin_wasm_pmin_f64x2: {
Value *LHS = EmitScalarExpr(E->getArg(0));
@@ -20805,6 +20808,7 @@ Value *CodeGenFunction::EmitWebAssemblyBuiltinExpr(unsigned BuiltinID,
CGM.getIntrinsic(Intrinsic::wasm_pmin, ConvertType(E->getType()));
return Builder.CreateCall(Callee, {LHS, RHS});
}
+ case WebAssembly::BI__builtin_wasm_pmax_f16x8:
case WebAssembly::BI__builtin_wasm_pmax_f32x4:
case WebAssembly::BI__builtin_wasm_pmax_f64x2: {
Value *LHS = EmitScalarExpr(E->getArg(0));
diff --git a/clang/test/CodeGen/builtins-wasm.c b/clang/test/CodeGen/builtins-wasm.c
index 93a6ab06081c9..d6ee4f68700dc 100644
--- a/clang/test/CodeGen/builtins-wasm.c
+++ b/clang/test/CodeGen/builtins-wasm.c
@@ -825,6 +825,30 @@ float extract_lane_f16x8(f16x8 a, int i) {
// WEBASSEMBLY-NEXT: ret float %0
return __builtin_wasm_extract_lane_f16x8(a, i);
}
+
+f16x8 min_f16x8(f16x8 a, f16x8 b) {
+ // WEBASSEMBLY: %0 = tail call <8 x half> @llvm.minimum.v8f16(<8 x half> %a, <8 x half> %b)
+ // WEBASSEMBLY-NEXT: ret <8 x half> %0
+ return __builtin_wasm_min_f16x8(a, b);
+}
+
+f16x8 max_f16x8(f16x8 a, f16x8 b) {
+ // WEBASSEMBLY: %0 = tail call <8 x half> @llvm.maximum.v8f16(<8 x half> %a, <8 x half> %b)
+ // WEBASSEMBLY-NEXT: ret <8 x half> %0
+ return __builtin_wasm_max_f16x8(a, b);
+}
+
+f16x8 pmin_f16x8(f16x8 a, f16x8 b) {
+ // WEBASSEMBLY: %0 = tail call <8 x half> @llvm.wasm.pmin.v8f16(<8 x half> %a, <8 x half> %b)
+ // WEBASSEMBLY-NEXT: ret <8 x half> %0
+ return __builtin_wasm_pmin_f16x8(a, b);
+}
+
+f16x8 pmax_f16x8(f16x8 a, f16x8 b) {
+ // WEBASSEMBLY: %0 = tail call <8 x half> @llvm.wasm.pmax.v8f16(<8 x half> %a, <8 x half> %b)
+ // WEBASSEMBLY-NEXT: ret <8 x half> %0
+ return __builtin_wasm_pmax_f16x8(a, b);
+}
__externref_t externref_null() {
return __builtin_wasm_ref_null_extern();
// WEBASSEMBLY: tail call ptr addrspace(10) @llvm.wasm.ref.null.extern()
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
index 518b6932a0c87..7cbae1bef8ef4 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
@@ -142,6 +142,11 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
setTruncStoreAction(T, MVT::f16, Expand);
}
+ if (Subtarget->hasHalfPrecision()) {
+ setOperationAction(ISD::FMINIMUM, MVT::v8f16, Legal);
+ setOperationAction(ISD::FMAXIMUM, MVT::v8f16, Legal);
+ }
+
// Expand unavailable integer operations.
for (auto Op :
{ISD::BSWAP, ISD::SMUL_LOHI, ISD::UMUL_LOHI, ISD::MULHS, ISD::MULHU,
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
index 558e3d859dcd8..83260fbaa700b 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
@@ -16,33 +16,34 @@
multiclass ABSTRACT_SIMD_I<dag oops_r, dag iops_r, dag oops_s, dag iops_s,
list<dag> pattern_r, string asmstr_r,
string asmstr_s, bits<32> simdop,
- Predicate simd_level> {
+ list<Predicate> reqs> {
defm "" : I<oops_r, iops_r, oops_s, iops_s, pattern_r, asmstr_r, asmstr_s,
!if(!ge(simdop, 0x100),
!or(0xfd0000, !and(0xffff, simdop)),
!or(0xfd00, !and(0xff, simdop)))>,
- Requires<[simd_level]>;
+ Requires<reqs>;
}
multiclass SIMD_I<dag oops_r, dag iops_r, dag oops_s, dag iops_s,
list<dag> pattern_r, string asmstr_r = "",
- string asmstr_s = "", bits<32> simdop = -1> {
+ string asmstr_s = "", bits<32> simdop = -1,
+ list<Predicate> reqs = []> {
defm "" : ABSTRACT_SIMD_I<oops_r, iops_r, oops_s, iops_s, pattern_r, asmstr_r,
- asmstr_s, simdop, HasSIMD128>;
+ asmstr_s, simdop, !listconcat([HasSIMD128], reqs)>;
}
multiclass RELAXED_I<dag oops_r, dag iops_r, dag oops_s, dag iops_s,
list<dag> pattern_r, string asmstr_r = "",
string asmstr_s = "", bits<32> simdop = -1> {
defm "" : ABSTRACT_SIMD_I<oops_r, iops_r, oops_s, iops_s, pattern_r, asmstr_r,
- asmstr_s, simdop, HasRelaxedSIMD>;
+ asmstr_s, simdop, [HasRelaxedSIMD]>;
}
multiclass HALF_PRECISION_I<dag oops_r, dag iops_r, dag oops_s, dag iops_s,
list<dag> pattern_r, string asmstr_r = "",
string asmstr_s = "", bits<32> simdop = -1> {
defm "" : ABSTRACT_SIMD_I<oops_r, iops_r, oops_s, iops_s, pattern_r, asmstr_r,
- asmstr_s, simdop, HasHalfPrecision>;
+ asmstr_s, simdop, [HasHalfPrecision]>;
}
@@ -152,6 +153,18 @@ def F64x2 : Vec {
let prefix = "f64x2";
}
+def F16x8 : Vec {
+ let vt = v8f16;
+ let int_vt = v8i16;
+ let lane_vt = f32;
+ let lane_rc = F32;
+ let lane_bits = 16;
+ let lane_idx = LaneIdx8;
+ let lane_load = int_wasm_loadf16_f32;
+ let splat = PatFrag<(ops node:$x), (v8f16 (splat_vector (f16 $x)))>;
+ let prefix = "f16x8";
+}
+
defvar AllVecs = [I8x16, I16x8, I32x4, I64x2, F32x4, F64x2];
defvar IntVecs = [I8x16, I16x8, I32x4, I64x2];
@@ -781,13 +794,14 @@ def : Pat<(v2i64 (nodes[0] (v2f64 V128:$lhs), (v2f64 V128:$rhs))),
// Bitwise operations
//===----------------------------------------------------------------------===//
-multiclass SIMDBinary<Vec vec, SDPatternOperator node, string name, bits<32> simdop> {
+multiclass SIMDBinary<Vec vec, SDPatternOperator node, string name,
+ bits<32> simdop, list<Predicate> reqs = []> {
defm _#vec : SIMD_I<(outs V128:$dst), (ins V128:$lhs, V128:$rhs),
(outs), (ins),
[(set (vec.vt V128:$dst),
(node (vec.vt V128:$lhs), (vec.vt V128:$rhs)))],
vec.prefix#"."#name#"\t$dst, $lhs, $rhs",
- vec.prefix#"."#name, simdop>;
+ vec.prefix#"."#name, simdop, reqs>;
}
multiclass SIMDBitwise<SDPatternOperator node, string name, bits<32> simdop,
@@ -1199,6 +1213,7 @@ def : Pat<(v2f64 (froundeven (v2f64 V128:$src))), (NEAREST_F64x2 V128:$src)>;
multiclass SIMDBinaryFP<SDPatternOperator node, string name, bits<32> baseInst> {
defm "" : SIMDBinary<F32x4, node, name, baseInst>;
defm "" : SIMDBinary<F64x2, node, name, !add(baseInst, 12)>;
+ defm "" : SIMDBinary<F16x8, node, name, !add(baseInst, 80), [HasHalfPrecision]>;
}
// Addition: add
@@ -1242,7 +1257,7 @@ defm PMAX : SIMDBinaryFP<pmax, "pmax", 235>;
// Also match the pmin/pmax cases where the operands are int vectors (but the
// comparison is still a floating point comparison). This can happen when using
// the wasm_simd128.h intrinsics because v128_t is an integer vector.
-foreach vec = [F32x4, F64x2] in {
+foreach vec = [F32x4, F64x2, F16x8] in {
defvar pmin = !cast<NI>("PMIN_"#vec);
defvar pmax = !cast<NI>("PMAX_"#vec);
def : Pat<(vec.int_vt (vselect
@@ -1266,6 +1281,10 @@ def : Pat<(v2f64 (int_wasm_pmin (v2f64 V128:$lhs), (v2f64 V128:$rhs))),
(PMIN_F64x2 V128:$lhs, V128:$rhs)>;
def : Pat<(v2f64 (int_wasm_pmax (v2f64 V128:$lhs), (v2f64 V128:$rhs))),
(PMAX_F64x2 V128:$lhs, V128:$rhs)>;
+def : Pat<(v8f16 (int_wasm_pmin (v8f16 V128:$lhs), (v8f16 V128:$rhs))),
+ (PMIN_F16x8 V128:$lhs, V128:$rhs)>;
+def : Pat<(v8f16 (int_wasm_pmax (v8f16 V128:$lhs), (v8f16 V128:$rhs))),
+ (PMAX_F16x8 V128:$lhs, V128:$rhs)>;
//===----------------------------------------------------------------------===//
// Conversions
diff --git a/llvm/test/CodeGen/WebAssembly/half-precision.ll b/llvm/test/CodeGen/WebAssembly/half-precision.ll
index d9d3f6be800fd..73ccea8d652db 100644
--- a/llvm/test/CodeGen/WebAssembly/half-precision.ll
+++ b/llvm/test/CodeGen/WebAssembly/half-precision.ll
@@ -35,3 +35,71 @@ define float @extract_lane_v8f16(<8 x half> %v) {
%r = call float @llvm.wasm.extract.lane.f16x8(<8 x half> %v, i32 1)
ret float %r
}
+
+; CHECK-LABEL: add_v8f16:
+; CHECK: f16x8.add $push0=, $0, $1
+; CHECK-NEXT: return $pop0
+define <8 x half> @add_v8f16(<8 x half> %a, <8 x half> %b) {
+ %r = fadd <8 x half> %a, %b
+ ret <8 x half> %r
+}
+
+; CHECK-LABEL: sub_v8f16:
+; CHECK: f16x8.sub $push0=, $0, $1
+; CHECK-NEXT: return $pop0
+define <8 x half> @sub_v8f16(<8 x half> %a, <8 x half> %b) {
+ %r = fsub <8 x half> %a, %b
+ ret <8 x half> %r
+}
+
+; CHECK-LABEL: mul_v8f16:
+; CHECK: f16x8.mul $push0=, $0, $1
+; CHECK-NEXT: return $pop0
+define <8 x half> @mul_v8f16(<8 x half> %a, <8 x half> %b) {
+ %r = fmul <8 x half> %a, %b
+ ret <8 x half> %r
+}
+
+; CHECK-LABEL: div_v8f16:
+; CHECK: f16x8.div $push0=, $0, $1
+; CHECK-NEXT: return $pop0
+define <8 x half> @div_v8f16(<8 x half> %a, <8 x half> %b) {
+ %r = fdiv <8 x half> %a, %b
+ ret <8 x half> %r
+}
+
+; CHECK-LABEL: min_intrinsic_v8f16:
+; CHECK: f16x8.min $push0=, $0, $1
+; CHECK-NEXT: return $pop0
+declare <8 x half> @llvm.minimum.v8f16(<8 x half>, <8 x half>)
+define <8 x half> @min_intrinsic_v8f16(<8 x half> %x, <8 x half> %y) {
+ %a = call <8 x half> @llvm.minimum.v8f16(<8 x half> %x, <8 x half> %y)
+ ret <8 x half> %a
+}
+
+; CHECK-LABEL: max_intrinsic_v8f16:
+; CHECK: f16x8.max $push0=, $0, $1
+; CHECK-NEXT: return $pop0
+declare <8 x half> @llvm.maximum.v8f16(<8 x half>, <8 x half>)
+define <8 x half> @max_intrinsic_v8f16(<8 x half> %x, <8 x half> %y) {
+ %a = call <8 x half> @llvm.maximum.v8f16(<8 x half> %x, <8 x half> %y)
+ ret <8 x half> %a
+}
+
+; CHECK-LABEL: pmin_intrinsic_v8f16:
+; CHECK: f16x8.pmin $push0=, $0, $1
+; CHECK-NEXT: return $pop0
+declare <8 x half> @llvm.wasm.pmin.v8f16(<8 x half>, <8 x half>)
+define <8 x half> @pmin_intrinsic_v8f16(<8 x half> %a, <8 x half> %b) {
+ %v = call <8 x half> @llvm.wasm.pmin.v8f16(<8 x half> %a, <8 x half> %b)
+ ret <8 x half> %v
+}
+
+; CHECK-LABEL: pmax_intrinsic_v8f16:
+; CHECK: f16x8.pmax $push0=, $0, $1
+; CHECK-NEXT: return $pop0
+declare <8 x half> @llvm.wasm.pmax.v8f16(<8 x half>, <8 x half>)
+define <8 x half> @pmax_intrinsic_v8f16(<8 x half> %a, <8 x half> %b) {
+ %v = call <8 x half> @llvm.wasm.pmax.v8f16(<8 x half> %a, <8 x half> %b)
+ ret <8 x half> %v
+}
diff --git a/llvm/test/MC/WebAssembly/simd-encodings.s b/llvm/test/MC/WebAssembly/simd-encodings.s
index d397188a9882e..113a23da776fa 100644
--- a/llvm/test/MC/WebAssembly/simd-encodings.s
+++ b/llvm/test/MC/WebAssembly/simd-encodings.s
@@ -851,4 +851,28 @@ main:
# CHECK: f16x8.extract_lane 1 # encoding: [0xfd,0xa1,0x02,0x01]
f16x8.extract_lane 1
+ # CHECK: f16x8.add # encoding: [0xfd,0xb4,0x02]
+ f16x8.add
+
+ # CHECK: f16x8.sub # encoding: [0xfd,0xb5,0x02]
+ f16x8.sub
+
+ # CHECK: f16x8.mul # encoding: [0xfd,0xb6,0x02]
+ f16x8.mul
+
+ # CHECK: f16x8.div # encoding: [0xfd,0xb7,0x02]
+ f16x8.div
+
+ # CHECK: f16x8.min # encoding: [0xfd,0xb8,0x02]
+ f16x8.min
+
+ # CHECK: f16x8.max # encoding: [0xfd,0xb9,0x02]
+ f16x8.max
+
+ # CHECK: f16x8.pmin # encoding: [0xfd,0xba,0x02]
+ f16x8.pmin
+
+ # CHECK: f16x8.pmax # encoding: [0xfd,0xbb,0x02]
+ f16x8.pmax
+
end_function
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Now you have commit access!
@@ -1199,6 +1213,7 @@ def : Pat<(v2f64 (froundeven (v2f64 V128:$src))), (NEAREST_F64x2 V128:$src)>; | |||
multiclass SIMDBinaryFP<SDPatternOperator node, string name, bits<32> baseInst> { | |||
defm "" : SIMDBinary<F32x4, node, name, baseInst>; | |||
defm "" : SIMDBinary<F64x2, node, name, !add(baseInst, 12)>; | |||
defm "" : SIMDBinary<F16x8, node, name, !add(baseInst, 80), [HasHalfPrecision]>; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I understand why it's added, and I wish we could multi-inherit from SIMDBinary
and HALF_PRECISION_I
, but I'm not sure if we can do it... (Can we?)
I'm not strongly opinionated about it and it's basically just a matter of preference, but how about adding a multiclass
like HalfPrecisionBinary
or something that inherits from SIMDBinary
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I ended up adding HalfPrecisionBinary
. I was hoping there was some way I could pass a multiclass id as a parameter so I could then pass in SIMD_I
or HALF_PRECISION_I
as an argument, but I couldn't figure out a way to make that work.
let splat = PatFrag<(ops node:$x), (v8f16 (splat_vector (f16 $x)))>; | ||
let prefix = "f16x8"; | ||
} | ||
|
||
defvar AllVecs = [I8x16, I16x8, I32x4, I64x2, F32x4, F64x2]; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Given that we can't add F16x8
here, it's not "all vectors" anymore... Should we rename it to something? If so, to what?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I hope to include F16x8 here when we better support it and the regular patterns work for it. I've added a comment for now, but can change the name if wanted.
This reuses most of the code that was created for f32x4 and f64x2 binary instructions and tries to follow how they were implemented. add/sub/mul/div — use regular LL instructions. min/max — use the minimum/maximum intrinsics, and also have builtins. pmin/pmax — use the wasm.pmin/wasm.pmax intrinsics, and also have builtins. Specified at: https://github.com/WebAssembly/half-precision/blob/29a9b9462c9285d4ccc1a5dc39214ddfd1892658/proposals/half-precision/Overview.md
This reuses most of the code that was created for f32x4 and f64x2 binary instructions and tries to follow how they were implemented.
add/sub/mul/div - use regular LL instructions
min/max - use the minimum/maximum intrinsic, and also have builtins
pmin/pmax - use the wasm.pmin/wasm.pmax intrinsics and also have builtins
Specified at:
https://github.com/WebAssembly/half-precision/blob/29a9b9462c9285d4ccc1a5dc39214ddfd1892658/proposals/half-precision/Overview.md