Skip to content

[WebAssembly] Implement all f16x8 binary instructions. #93360

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions clang/include/clang/Basic/BuiltinsWebAssembly.def
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,10 @@ TARGET_BUILTIN(__builtin_wasm_min_f64x2, "V2dV2dV2d", "nc", "simd128")
TARGET_BUILTIN(__builtin_wasm_max_f64x2, "V2dV2dV2d", "nc", "simd128")
TARGET_BUILTIN(__builtin_wasm_pmin_f64x2, "V2dV2dV2d", "nc", "simd128")
TARGET_BUILTIN(__builtin_wasm_pmax_f64x2, "V2dV2dV2d", "nc", "simd128")
// Half-precision (f16x8) binary min/max builtins. Unlike the f32x4/f64x2
// variants above, these are gated on the "half-precision" target feature
// rather than "simd128".
TARGET_BUILTIN(__builtin_wasm_min_f16x8, "V8hV8hV8h", "nc", "half-precision")
TARGET_BUILTIN(__builtin_wasm_max_f16x8, "V8hV8hV8h", "nc", "half-precision")
// pmin/pmax are the wasm "pseudo" min/max (comparison-select semantics rather
// than IEEE minimum/maximum) -- see the wasm SIMD spec for exact behavior.
TARGET_BUILTIN(__builtin_wasm_pmin_f16x8, "V8hV8hV8h", "nc", "half-precision")
TARGET_BUILTIN(__builtin_wasm_pmax_f16x8, "V8hV8hV8h", "nc", "half-precision")

TARGET_BUILTIN(__builtin_wasm_ceil_f32x4, "V4fV4f", "nc", "simd128")
TARGET_BUILTIN(__builtin_wasm_floor_f32x4, "V4fV4f", "nc", "simd128")
Expand Down
4 changes: 4 additions & 0 deletions clang/lib/CodeGen/CGBuiltin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20779,6 +20779,7 @@ Value *CodeGenFunction::EmitWebAssemblyBuiltinExpr(unsigned BuiltinID,
}
case WebAssembly::BI__builtin_wasm_min_f32:
case WebAssembly::BI__builtin_wasm_min_f64:
case WebAssembly::BI__builtin_wasm_min_f16x8:
case WebAssembly::BI__builtin_wasm_min_f32x4:
case WebAssembly::BI__builtin_wasm_min_f64x2: {
Value *LHS = EmitScalarExpr(E->getArg(0));
Expand All @@ -20789,6 +20790,7 @@ Value *CodeGenFunction::EmitWebAssemblyBuiltinExpr(unsigned BuiltinID,
}
case WebAssembly::BI__builtin_wasm_max_f32:
case WebAssembly::BI__builtin_wasm_max_f64:
case WebAssembly::BI__builtin_wasm_max_f16x8:
case WebAssembly::BI__builtin_wasm_max_f32x4:
case WebAssembly::BI__builtin_wasm_max_f64x2: {
Value *LHS = EmitScalarExpr(E->getArg(0));
Expand All @@ -20797,6 +20799,7 @@ Value *CodeGenFunction::EmitWebAssemblyBuiltinExpr(unsigned BuiltinID,
CGM.getIntrinsic(Intrinsic::maximum, ConvertType(E->getType()));
return Builder.CreateCall(Callee, {LHS, RHS});
}
case WebAssembly::BI__builtin_wasm_pmin_f16x8:
case WebAssembly::BI__builtin_wasm_pmin_f32x4:
case WebAssembly::BI__builtin_wasm_pmin_f64x2: {
Value *LHS = EmitScalarExpr(E->getArg(0));
Expand All @@ -20805,6 +20808,7 @@ Value *CodeGenFunction::EmitWebAssemblyBuiltinExpr(unsigned BuiltinID,
CGM.getIntrinsic(Intrinsic::wasm_pmin, ConvertType(E->getType()));
return Builder.CreateCall(Callee, {LHS, RHS});
}
case WebAssembly::BI__builtin_wasm_pmax_f16x8:
case WebAssembly::BI__builtin_wasm_pmax_f32x4:
case WebAssembly::BI__builtin_wasm_pmax_f64x2: {
Value *LHS = EmitScalarExpr(E->getArg(0));
Expand Down
24 changes: 24 additions & 0 deletions clang/test/CodeGen/builtins-wasm.c
Original file line number Diff line number Diff line change
Expand Up @@ -825,6 +825,30 @@ float extract_lane_f16x8(f16x8 a, int i) {
// WEBASSEMBLY-NEXT: ret float %0
return __builtin_wasm_extract_lane_f16x8(a, i);
}

// __builtin_wasm_min_f16x8 lowers to the generic @llvm.minimum.v8f16
// intrinsic (NaN-propagating minimum).
f16x8 min_f16x8(f16x8 a, f16x8 b) {
// WEBASSEMBLY: %0 = tail call <8 x half> @llvm.minimum.v8f16(<8 x half> %a, <8 x half> %b)
// WEBASSEMBLY-NEXT: ret <8 x half> %0
return __builtin_wasm_min_f16x8(a, b);
}

// __builtin_wasm_max_f16x8 lowers to the generic @llvm.maximum.v8f16
// intrinsic (NaN-propagating maximum).
f16x8 max_f16x8(f16x8 a, f16x8 b) {
// WEBASSEMBLY: %0 = tail call <8 x half> @llvm.maximum.v8f16(<8 x half> %a, <8 x half> %b)
// WEBASSEMBLY-NEXT: ret <8 x half> %0
return __builtin_wasm_max_f16x8(a, b);
}

// __builtin_wasm_pmin_f16x8 lowers to the target-specific @llvm.wasm.pmin
// intrinsic (pseudo-minimum; select-based, not IEEE minimum).
f16x8 pmin_f16x8(f16x8 a, f16x8 b) {
// WEBASSEMBLY: %0 = tail call <8 x half> @llvm.wasm.pmin.v8f16(<8 x half> %a, <8 x half> %b)
// WEBASSEMBLY-NEXT: ret <8 x half> %0
return __builtin_wasm_pmin_f16x8(a, b);
}

// __builtin_wasm_pmax_f16x8 lowers to the target-specific @llvm.wasm.pmax
// intrinsic (pseudo-maximum; select-based, not IEEE maximum).
f16x8 pmax_f16x8(f16x8 a, f16x8 b) {
// WEBASSEMBLY: %0 = tail call <8 x half> @llvm.wasm.pmax.v8f16(<8 x half> %a, <8 x half> %b)
// WEBASSEMBLY-NEXT: ret <8 x half> %0
return __builtin_wasm_pmax_f16x8(a, b);
}
__externref_t externref_null() {
return __builtin_wasm_ref_null_extern();
// WEBASSEMBLY: tail call ptr addrspace(10) @llvm.wasm.ref.null.extern()
Expand Down
5 changes: 5 additions & 0 deletions llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,11 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
setTruncStoreAction(T, MVT::f16, Expand);
}

if (Subtarget->hasHalfPrecision()) {
setOperationAction(ISD::FMINIMUM, MVT::v8f16, Legal);
setOperationAction(ISD::FMAXIMUM, MVT::v8f16, Legal);
}

// Expand unavailable integer operations.
for (auto Op :
{ISD::BSWAP, ISD::SMUL_LOHI, ISD::UMUL_LOHI, ISD::MULHS, ISD::MULHU,
Expand Down
43 changes: 34 additions & 9 deletions llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
Original file line number Diff line number Diff line change
Expand Up @@ -16,33 +16,34 @@
multiclass ABSTRACT_SIMD_I<dag oops_r, dag iops_r, dag oops_s, dag iops_s,
list<dag> pattern_r, string asmstr_r,
string asmstr_s, bits<32> simdop,
Predicate simd_level> {
list<Predicate> reqs> {
defm "" : I<oops_r, iops_r, oops_s, iops_s, pattern_r, asmstr_r, asmstr_s,
!if(!ge(simdop, 0x100),
!or(0xfd0000, !and(0xffff, simdop)),
!or(0xfd00, !and(0xff, simdop)))>,
Requires<[simd_level]>;
Requires<reqs>;
}

multiclass SIMD_I<dag oops_r, dag iops_r, dag oops_s, dag iops_s,
list<dag> pattern_r, string asmstr_r = "",
string asmstr_s = "", bits<32> simdop = -1> {
string asmstr_s = "", bits<32> simdop = -1,
list<Predicate> reqs = []> {
defm "" : ABSTRACT_SIMD_I<oops_r, iops_r, oops_s, iops_s, pattern_r, asmstr_r,
asmstr_s, simdop, HasSIMD128>;
asmstr_s, simdop, !listconcat([HasSIMD128], reqs)>;
}

multiclass RELAXED_I<dag oops_r, dag iops_r, dag oops_s, dag iops_s,
list<dag> pattern_r, string asmstr_r = "",
string asmstr_s = "", bits<32> simdop = -1> {
defm "" : ABSTRACT_SIMD_I<oops_r, iops_r, oops_s, iops_s, pattern_r, asmstr_r,
asmstr_s, simdop, HasRelaxedSIMD>;
asmstr_s, simdop, [HasRelaxedSIMD]>;
}

multiclass HALF_PRECISION_I<dag oops_r, dag iops_r, dag oops_s, dag iops_s,
list<dag> pattern_r, string asmstr_r = "",
string asmstr_s = "", bits<32> simdop = -1> {
defm "" : ABSTRACT_SIMD_I<oops_r, iops_r, oops_s, iops_s, pattern_r, asmstr_r,
asmstr_s, simdop, HasHalfPrecision>;
asmstr_s, simdop, [HasHalfPrecision]>;
}


Expand Down Expand Up @@ -152,6 +153,19 @@ def F64x2 : Vec {
let prefix = "f64x2";
}

// Vector descriptor for the half-precision f16x8 type.
def F16x8 : Vec {
let vt = v8f16;
// Same-width integer form, used by patterns that operate on the bitcast
// integer vector (e.g. the pmin/pmax vselect patterns).
let int_vt = v8i16;
// Lane values are handled as f32 (lane_vt/lane_rc) -- presumably because
// scalar f16 values are promoted to f32 on wasm; lane loads accordingly use
// the f16-to-f32 load intrinsic below. TODO(review): confirm.
let lane_vt = f32;
let lane_rc = F32;
let lane_bits = 16;
let lane_idx = LaneIdx8;
let lane_load = int_wasm_loadf16_f32;
let splat = PatFrag<(ops node:$x), (v8f16 (splat_vector (f16 $x)))>;
let prefix = "f16x8";
}

// TODO: Include F16x8 here when half precision is better supported.
defvar AllVecs = [I8x16, I16x8, I32x4, I64x2, F32x4, F64x2];
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given that we can't add F16x8 here, it's not "all vectors" anymore... Should we rename it to something? If so, to what?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I hope to include F16x8 here when we better support it and the regular patterns work for it. I've added a comment for now, but can change the name if wanted.

defvar IntVecs = [I8x16, I16x8, I32x4, I64x2];

Expand Down Expand Up @@ -781,13 +795,19 @@ def : Pat<(v2i64 (nodes[0] (v2f64 V128:$lhs), (v2f64 V128:$rhs))),
// Bitwise operations
//===----------------------------------------------------------------------===//

multiclass SIMDBinary<Vec vec, SDPatternOperator node, string name, bits<32> simdop> {
multiclass SIMDBinary<Vec vec, SDPatternOperator node, string name,
bits<32> simdop, list<Predicate> reqs = []> {
defm _#vec : SIMD_I<(outs V128:$dst), (ins V128:$lhs, V128:$rhs),
(outs), (ins),
[(set (vec.vt V128:$dst),
(node (vec.vt V128:$lhs), (vec.vt V128:$rhs)))],
vec.prefix#"."#name#"\t$dst, $lhs, $rhs",
vec.prefix#"."#name, simdop>;
vec.prefix#"."#name, simdop, reqs>;
}

// Convenience wrapper around SIMDBinary that additionally requires the
// HasHalfPrecision predicate (on top of HasSIMD128 added by SIMD_I).
multiclass HalfPrecisionBinary<Vec vec, SDPatternOperator node, string name,
bits<32> simdop> {
defm "" : SIMDBinary<vec, node, name, simdop, [HasHalfPrecision]>;
}

multiclass SIMDBitwise<SDPatternOperator node, string name, bits<32> simdop,
Expand Down Expand Up @@ -1199,6 +1219,7 @@ def : Pat<(v2f64 (froundeven (v2f64 V128:$src))), (NEAREST_F64x2 V128:$src)>;
multiclass SIMDBinaryFP<SDPatternOperator node, string name, bits<32> baseInst> {
defm "" : SIMDBinary<F32x4, node, name, baseInst>;
defm "" : SIMDBinary<F64x2, node, name, !add(baseInst, 12)>;
defm "" : HalfPrecisionBinary<F16x8, node, name, !add(baseInst, 80)>;
}

// Addition: add
Expand Down Expand Up @@ -1242,7 +1263,7 @@ defm PMAX : SIMDBinaryFP<pmax, "pmax", 235>;
// Also match the pmin/pmax cases where the operands are int vectors (but the
// comparison is still a floating point comparison). This can happen when using
// the wasm_simd128.h intrinsics because v128_t is an integer vector.
foreach vec = [F32x4, F64x2] in {
foreach vec = [F32x4, F64x2, F16x8] in {
defvar pmin = !cast<NI>("PMIN_"#vec);
defvar pmax = !cast<NI>("PMAX_"#vec);
def : Pat<(vec.int_vt (vselect
Expand All @@ -1266,6 +1287,10 @@ def : Pat<(v2f64 (int_wasm_pmin (v2f64 V128:$lhs), (v2f64 V128:$rhs))),
(PMIN_F64x2 V128:$lhs, V128:$rhs)>;
def : Pat<(v2f64 (int_wasm_pmax (v2f64 V128:$lhs), (v2f64 V128:$rhs))),
(PMAX_F64x2 V128:$lhs, V128:$rhs)>;
// Select the f16x8 pmin/pmax instructions directly from the target-specific
// intrinsics, mirroring the f32x4/f64x2 patterns above.
def : Pat<(v8f16 (int_wasm_pmin (v8f16 V128:$lhs), (v8f16 V128:$rhs))),
(PMIN_F16x8 V128:$lhs, V128:$rhs)>;
def : Pat<(v8f16 (int_wasm_pmax (v8f16 V128:$lhs), (v8f16 V128:$rhs))),
(PMAX_F16x8 V128:$lhs, V128:$rhs)>;

//===----------------------------------------------------------------------===//
// Conversions
Expand Down
68 changes: 68 additions & 0 deletions llvm/test/CodeGen/WebAssembly/half-precision.ll
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,71 @@ define float @extract_lane_v8f16(<8 x half> %v) {
%r = call float @llvm.wasm.extract.lane.f16x8(<8 x half> %v, i32 1)
ret float %r
}

; A plain IR fadd on <8 x half> should select f16x8.add.
; CHECK-LABEL: add_v8f16:
; CHECK: f16x8.add $push0=, $0, $1
; CHECK-NEXT: return $pop0
define <8 x half> @add_v8f16(<8 x half> %a, <8 x half> %b) {
%r = fadd <8 x half> %a, %b
ret <8 x half> %r
}

; A plain IR fsub on <8 x half> should select f16x8.sub.
; CHECK-LABEL: sub_v8f16:
; CHECK: f16x8.sub $push0=, $0, $1
; CHECK-NEXT: return $pop0
define <8 x half> @sub_v8f16(<8 x half> %a, <8 x half> %b) {
%r = fsub <8 x half> %a, %b
ret <8 x half> %r
}

; A plain IR fmul on <8 x half> should select f16x8.mul.
; CHECK-LABEL: mul_v8f16:
; CHECK: f16x8.mul $push0=, $0, $1
; CHECK-NEXT: return $pop0
define <8 x half> @mul_v8f16(<8 x half> %a, <8 x half> %b) {
%r = fmul <8 x half> %a, %b
ret <8 x half> %r
}

; A plain IR fdiv on <8 x half> should select f16x8.div.
; CHECK-LABEL: div_v8f16:
; CHECK: f16x8.div $push0=, $0, $1
; CHECK-NEXT: return $pop0
define <8 x half> @div_v8f16(<8 x half> %a, <8 x half> %b) {
%r = fdiv <8 x half> %a, %b
ret <8 x half> %r
}

; The generic NaN-propagating minimum intrinsic should select f16x8.min
; (made Legal for v8f16 in WebAssemblyISelLowering under half-precision).
; CHECK-LABEL: min_intrinsic_v8f16:
; CHECK: f16x8.min $push0=, $0, $1
; CHECK-NEXT: return $pop0
declare <8 x half> @llvm.minimum.v8f16(<8 x half>, <8 x half>)
define <8 x half> @min_intrinsic_v8f16(<8 x half> %x, <8 x half> %y) {
%a = call <8 x half> @llvm.minimum.v8f16(<8 x half> %x, <8 x half> %y)
ret <8 x half> %a
}

; The generic NaN-propagating maximum intrinsic should select f16x8.max.
; CHECK-LABEL: max_intrinsic_v8f16:
; CHECK: f16x8.max $push0=, $0, $1
; CHECK-NEXT: return $pop0
declare <8 x half> @llvm.maximum.v8f16(<8 x half>, <8 x half>)
define <8 x half> @max_intrinsic_v8f16(<8 x half> %x, <8 x half> %y) {
%a = call <8 x half> @llvm.maximum.v8f16(<8 x half> %x, <8 x half> %y)
ret <8 x half> %a
}

; The target-specific pseudo-minimum intrinsic should select f16x8.pmin
; (via the Pat added in WebAssemblyInstrSIMD.td).
; CHECK-LABEL: pmin_intrinsic_v8f16:
; CHECK: f16x8.pmin $push0=, $0, $1
; CHECK-NEXT: return $pop0
declare <8 x half> @llvm.wasm.pmin.v8f16(<8 x half>, <8 x half>)
define <8 x half> @pmin_intrinsic_v8f16(<8 x half> %a, <8 x half> %b) {
%v = call <8 x half> @llvm.wasm.pmin.v8f16(<8 x half> %a, <8 x half> %b)
ret <8 x half> %v
}

; The target-specific pseudo-maximum intrinsic should select f16x8.pmax.
; CHECK-LABEL: pmax_intrinsic_v8f16:
; CHECK: f16x8.pmax $push0=, $0, $1
; CHECK-NEXT: return $pop0
declare <8 x half> @llvm.wasm.pmax.v8f16(<8 x half>, <8 x half>)
define <8 x half> @pmax_intrinsic_v8f16(<8 x half> %a, <8 x half> %b) {
%v = call <8 x half> @llvm.wasm.pmax.v8f16(<8 x half> %a, <8 x half> %b)
ret <8 x half> %v
}
24 changes: 24 additions & 0 deletions llvm/test/MC/WebAssembly/simd-encodings.s
Original file line number Diff line number Diff line change
Expand Up @@ -851,4 +851,28 @@ main:
# CHECK: f16x8.extract_lane 1 # encoding: [0xfd,0xa1,0x02,0x01]
f16x8.extract_lane 1

# f16x8 binary operations. All use the SIMD prefix byte 0xfd followed by the
# LEB128-encoded opcode: [0xb4,0x02]..[0xbb,0x02] decode to 0x134..0x13b.
# CHECK: f16x8.add # encoding: [0xfd,0xb4,0x02]
f16x8.add

# CHECK: f16x8.sub # encoding: [0xfd,0xb5,0x02]
f16x8.sub

# CHECK: f16x8.mul # encoding: [0xfd,0xb6,0x02]
f16x8.mul

# CHECK: f16x8.div # encoding: [0xfd,0xb7,0x02]
f16x8.div

# CHECK: f16x8.min # encoding: [0xfd,0xb8,0x02]
f16x8.min

# CHECK: f16x8.max # encoding: [0xfd,0xb9,0x02]
f16x8.max

# CHECK: f16x8.pmin # encoding: [0xfd,0xba,0x02]
f16x8.pmin

# CHECK: f16x8.pmax # encoding: [0xfd,0xbb,0x02]
f16x8.pmax


end_function
Loading