Skip to content

[WebAssembly] Support promoting lower lanes of f16x8 to f32x4. #129786

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions clang/lib/Headers/wasm_simd128.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ typedef int __i32x2 __attribute__((__vector_size__(8), __aligned__(8)));
typedef unsigned int __u32x2
__attribute__((__vector_size__(8), __aligned__(8)));
typedef float __f32x2 __attribute__((__vector_size__(8), __aligned__(8)));
typedef __fp16 __f16x4 __attribute__((__vector_size__(8), __aligned__(8)));

#define __DEFAULT_FN_ATTRS \
__attribute__((__always_inline__, __nodebug__, __target__("simd128"), \
Expand Down Expand Up @@ -2010,6 +2011,14 @@ static __inline__ v128_t __FP16_FN_ATTRS wasm_f16x8_convert_u16x8(v128_t __a) {
return (v128_t) __builtin_convertvector((__u16x8)__a, __f16x8);
}

// Converts lanes 0-3 of __a, viewed as eight half-precision lanes, into a
// vector of four single-precision lanes.
static __inline__ v128_t __FP16_FN_ATTRS
wasm_f32x4_promote_low_f16x8(v128_t __a) {
  __f16x8 __halves = (__f16x8)__a;
  __f16x4 __low = {__halves[0], __halves[1], __halves[2], __halves[3]};
  return (v128_t)__builtin_convertvector(__low, __f32x4);
}

static __inline__ v128_t __FP16_FN_ATTRS wasm_f16x8_relaxed_madd(v128_t __a,
v128_t __b,
v128_t __c) {
Expand Down
6 changes: 6 additions & 0 deletions cross-project-tests/intrinsic-header-tests/wasm_simd128.c
Original file line number Diff line number Diff line change
Expand Up @@ -1033,6 +1033,12 @@ v128_t test_f64x2_promote_low_f32x4(v128_t a) {
return wasm_f64x2_promote_low_f32x4(a);
}

// Exercises wasm_f32x4_promote_low_f16x8. The expected-output pattern below
// requires the f32x4.promote_low_f16x8 mnemonic to be the last text on its
// line.
// CHECK-LABEL: test_f32x4_promote_low_f16x8:
// CHECK: f32x4.promote_low_f16x8{{$}}
v128_t test_f32x4_promote_low_f16x8(v128_t a) {
return wasm_f32x4_promote_low_f16x8(a);
}

// CHECK-LABEL: test_i8x16_shuffle:
// CHECK: i8x16.shuffle 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1,
// 0{{$}}
Expand Down
55 changes: 40 additions & 15 deletions llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2341,7 +2341,7 @@ WebAssemblyTargetLowering::LowerEXTEND_VECTOR_INREG(SDValue Op,

static SDValue LowerConvertLow(SDValue Op, SelectionDAG &DAG) {
SDLoc DL(Op);
if (Op.getValueType() != MVT::v2f64)
if (Op.getValueType() != MVT::v2f64 && Op.getValueType() != MVT::v4f32)
return SDValue();

auto GetConvertedLane = [](SDValue Op, unsigned &Opcode, SDValue &SrcVec,
Expand All @@ -2354,6 +2354,7 @@ static SDValue LowerConvertLow(SDValue Op, SelectionDAG &DAG) {
Opcode = WebAssemblyISD::CONVERT_LOW_U;
break;
case ISD::FP_EXTEND:
case ISD::FP16_TO_FP:
Opcode = WebAssemblyISD::PROMOTE_LOW;
break;
default:
Expand All @@ -2372,36 +2373,60 @@ static SDValue LowerConvertLow(SDValue Op, SelectionDAG &DAG) {
return true;
};

unsigned LHSOpcode, RHSOpcode, LHSIndex, RHSIndex;
SDValue LHSSrcVec, RHSSrcVec;
if (!GetConvertedLane(Op.getOperand(0), LHSOpcode, LHSSrcVec, LHSIndex) ||
!GetConvertedLane(Op.getOperand(1), RHSOpcode, RHSSrcVec, RHSIndex))
unsigned NumLanes = Op.getValueType() == MVT::v2f64 ? 2 : 4;
unsigned FirstOpcode = 0, SecondOpcode = 0, ThirdOpcode = 0, FourthOpcode = 0;
unsigned FirstIndex = 0, SecondIndex = 0, ThirdIndex = 0, FourthIndex = 0;
SDValue FirstSrcVec, SecondSrcVec, ThirdSrcVec, FourthSrcVec;

if (!GetConvertedLane(Op.getOperand(0), FirstOpcode, FirstSrcVec,
FirstIndex) ||
!GetConvertedLane(Op.getOperand(1), SecondOpcode, SecondSrcVec,
SecondIndex))
return SDValue();

// If we're converting to v4f32, check the third and fourth lanes, too.
if (NumLanes == 4 && (!GetConvertedLane(Op.getOperand(2), ThirdOpcode,
ThirdSrcVec, ThirdIndex) ||
!GetConvertedLane(Op.getOperand(3), FourthOpcode,
FourthSrcVec, FourthIndex)))
return SDValue();

if (FirstOpcode != SecondOpcode)
return SDValue();

if (LHSOpcode != RHSOpcode)
// TODO: Add an optimization similar to the v2f64 one below that shuffles the
// vectors when the lanes are in the wrong order or come from different
// source vectors.
if (NumLanes == 4 &&
(FirstOpcode != ThirdOpcode || FirstOpcode != FourthOpcode ||
FirstSrcVec != SecondSrcVec || FirstSrcVec != ThirdSrcVec ||
FirstSrcVec != FourthSrcVec || FirstIndex != 0 || SecondIndex != 1 ||
ThirdIndex != 2 || FourthIndex != 3))
return SDValue();

MVT ExpectedSrcVT;
switch (LHSOpcode) {
switch (FirstOpcode) {
case WebAssemblyISD::CONVERT_LOW_S:
case WebAssemblyISD::CONVERT_LOW_U:
ExpectedSrcVT = MVT::v4i32;
break;
case WebAssemblyISD::PROMOTE_LOW:
ExpectedSrcVT = MVT::v4f32;
ExpectedSrcVT = NumLanes == 2 ? MVT::v4f32 : MVT::v8i16;
break;
}
if (LHSSrcVec.getValueType() != ExpectedSrcVT)
if (FirstSrcVec.getValueType() != ExpectedSrcVT)
return SDValue();

auto Src = LHSSrcVec;
if (LHSIndex != 0 || RHSIndex != 1 || LHSSrcVec != RHSSrcVec) {
auto Src = FirstSrcVec;
if (NumLanes == 2 &&
(FirstIndex != 0 || SecondIndex != 1 || FirstSrcVec != SecondSrcVec)) {
// Shuffle the source vector so that the converted lanes are the low lanes.
Src = DAG.getVectorShuffle(
ExpectedSrcVT, DL, LHSSrcVec, RHSSrcVec,
{static_cast<int>(LHSIndex), static_cast<int>(RHSIndex) + 4, -1, -1});
Src = DAG.getVectorShuffle(ExpectedSrcVT, DL, FirstSrcVec, SecondSrcVec,
{static_cast<int>(FirstIndex),
static_cast<int>(SecondIndex) + 4, -1, -1});
}
return DAG.getNode(LHSOpcode, DL, MVT::v2f64, Src);
return DAG.getNode(FirstOpcode, DL, NumLanes == 2 ? MVT::v2f64 : MVT::v4f32,
Src);
}

SDValue WebAssemblyTargetLowering::LowerBUILD_VECTOR(SDValue Op,
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
Original file line number Diff line number Diff line change
Expand Up @@ -1468,6 +1468,8 @@ defm "" : SIMDConvert<F32x4, F64x2, demote_zero,
def promote_t : SDTypeProfile<1, 1, [SDTCisVec<0>, SDTCisVec<1>]>;
def promote_low : SDNode<"WebAssemblyISD::PROMOTE_LOW", promote_t>;
defm "" : SIMDConvert<F64x2, F32x4, promote_low, "promote_low_f32x4", 0x5f>;
// Half-precision variant: reuses the promote_low node with the F32x4 / I16x8
// vector types and the "promote_low_f16x8" name. NOTE(review): 0x14b is
// presumably the SIMD opcode and I16x8 the source type holding the f16 lanes
// -- confirm against the HalfPrecisionConvert definition.
defm "" : HalfPrecisionConvert<F32x4, I16x8, promote_low, "promote_low_f16x8",
0x14b>;

// Lower extending loads to load64_zero + promote_low
def extloadv2f32 : PatFrag<(ops node:$ptr), (extload node:$ptr)> {
Expand Down
20 changes: 20 additions & 0 deletions llvm/test/CodeGen/WebAssembly/half-precision.ll
Original file line number Diff line number Diff line change
Expand Up @@ -369,3 +369,23 @@ define <8 x half> @shuffle_poison_v8f16(<8 x half> %x, <8 x half> %y) {
i32 poison, i32 poison, i32 poison, i32 poison>
ret <8 x half> %res
}

; Extract-then-extend form: lanes 0-3 of %x are selected with a shufflevector
; and the resulting <4 x half> is extended to <4 x float>. The expected output
; is one f32x4.promote_low_f16x8 on the argument, immediately followed by the
; return.
define <4 x float> @promote_low_v4f32(<8 x half> %x) {
; CHECK-LABEL: promote_low_v4f32:
; CHECK: .functype promote_low_v4f32 (v128) -> (v128){{$}}
; CHECK-NEXT: f32x4.promote_low_f16x8 $push[[R:[0-9]+]]=, $0
; CHECK-NEXT: return $pop[[R]]
%v = shufflevector <8 x half> %x, <8 x half> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
%a = fpext <4 x half> %v to <4 x float>
ret <4 x float> %a
}

; Extend-then-extract form: all eight lanes of %x are extended to float first,
; then lanes 0-3 of the <8 x float> result are selected. The expected output
; is the same as for the extract-then-extend form: one f32x4.promote_low_f16x8
; on the argument, immediately followed by the return.
define <4 x float> @promote_low_v4f32_2(<8 x half> %x) {
; CHECK-LABEL: promote_low_v4f32_2:
; CHECK: .functype promote_low_v4f32_2 (v128) -> (v128)
; CHECK-NEXT: f32x4.promote_low_f16x8 $push[[R:[0-9]+]]=, $0
; CHECK-NEXT: return $pop[[R]]
%v = fpext <8 x half> %x to <8 x float>
%a = shufflevector <8 x float> %v, <8 x float> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
ret <4 x float> %a
}
3 changes: 3 additions & 0 deletions llvm/test/MC/WebAssembly/simd-encodings.s
Original file line number Diff line number Diff line change
Expand Up @@ -935,4 +935,7 @@ main:
# CHECK: f16x8.convert_i16x8_u # encoding: [0xfd,0xc8,0x02]
f16x8.convert_i16x8_u

# CHECK: f32x4.promote_low_f16x8 # encoding: [0xfd,0xcb,0x02]
f32x4.promote_low_f16x8

end_function