Skip to content

Commit d5ab38f

Browse files
authored
[RISCV] Support select/merge like ops for bf16 vectors when have Zvfbfmin (#91936)
1 parent 4b70294 commit d5ab38f

11 files changed

+1150
-38
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1102,6 +1102,15 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
11021102
ISD::EXTRACT_SUBVECTOR},
11031103
VT, Custom);
11041104
setOperationAction({ISD::LOAD, ISD::STORE}, VT, Custom);
1105+
if (Subtarget.hasStdExtZfbfmin()) {
1106+
if (Subtarget.hasVInstructionsF16())
1107+
setOperationAction(ISD::SPLAT_VECTOR, VT, Legal);
1108+
else if (Subtarget.hasVInstructionsF16Minimal())
1109+
setOperationAction(ISD::SPLAT_VECTOR, VT, Custom);
1110+
}
1111+
setOperationAction({ISD::VP_MERGE, ISD::VP_SELECT, ISD::SELECT}, VT,
1112+
Custom);
1113+
setOperationAction(ISD::SELECT_CC, VT, Expand);
11051114
// TODO: Promote to fp32.
11061115
}
11071116
}
@@ -1331,6 +1340,15 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
13311340
ISD::EXTRACT_SUBVECTOR},
13321341
VT, Custom);
13331342
setOperationAction({ISD::LOAD, ISD::STORE}, VT, Custom);
1343+
if (Subtarget.hasStdExtZfbfmin()) {
1344+
if (Subtarget.hasVInstructionsF16())
1345+
setOperationAction(ISD::SPLAT_VECTOR, VT, Legal);
1346+
else if (Subtarget.hasVInstructionsF16Minimal())
1347+
setOperationAction(ISD::SPLAT_VECTOR, VT, Custom);
1348+
}
1349+
setOperationAction(
1350+
{ISD::VP_MERGE, ISD::VP_SELECT, ISD::VSELECT, ISD::SELECT}, VT,
1351+
Custom);
13341352
// TODO: Promote to fp32.
13351353
continue;
13361354
}
@@ -6704,10 +6722,16 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
67046722
case ISD::BUILD_VECTOR:
67056723
return lowerBUILD_VECTOR(Op, DAG, Subtarget);
67066724
case ISD::SPLAT_VECTOR:
6707-
if (Op.getValueType().getScalarType() == MVT::f16 &&
6708-
(Subtarget.hasVInstructionsF16Minimal() &&
6709-
!Subtarget.hasVInstructionsF16())) {
6710-
if (Op.getValueType() == MVT::nxv32f16)
6725+
if ((Op.getValueType().getScalarType() == MVT::f16 &&
6726+
(Subtarget.hasVInstructionsF16Minimal() &&
6727+
Subtarget.hasStdExtZfhminOrZhinxmin() &&
6728+
!Subtarget.hasVInstructionsF16())) ||
6729+
(Op.getValueType().getScalarType() == MVT::bf16 &&
6730+
(Subtarget.hasVInstructionsBF16() && Subtarget.hasStdExtZfbfmin() &&
6731+
Subtarget.hasVInstructionsF16Minimal() &&
6732+
!Subtarget.hasVInstructionsF16()))) {
6733+
if (Op.getValueType() == MVT::nxv32f16 ||
6734+
Op.getValueType() == MVT::nxv32bf16)
67116735
return SplitVectorOp(Op, DAG);
67126736
SDLoc DL(Op);
67136737
SDValue NewScalar =

llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -382,7 +382,20 @@ class GetIntVTypeInfo<VTypeInfo vti> {
382382
// Equivalent integer vector type. Eg.
383383
// VI8M1 → VI8M1 (identity)
384384
// VF64M4 → VI64M4
385-
VTypeInfo Vti = !cast<VTypeInfo>(!subst("VF", "VI", !cast<string>(vti)));
385+
VTypeInfo Vti = !cast<VTypeInfo>(!subst("VBF", "VI",
386+
!subst("VF", "VI",
387+
!cast<string>(vti))));
388+
}
389+
390+
// This functor is used to obtain the fp vector type that has the same SEW and
391+
// multiplier as the input parameter type.
392+
class GetFpVTypeInfo<VTypeInfo vti> {
393+
// Equivalent floating-point vector type. Eg.
394+
// VF16M1 → VF16M1 (identity)
395+
// VBF16M1 → VF16M1
396+
VTypeInfo Vti = !cast<VTypeInfo>(!subst("VBF", "VF",
397+
!subst("VI", "VF",
398+
!cast<string>(vti))));
386399
}
387400

388401
class MTypeInfo<ValueType Mas, LMULInfo M, string Bx> {

llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1394,7 +1394,7 @@ defm : VPatFPSetCCSDNode_VV_VF_FV<SETOLE, "PseudoVMFLE", "PseudoVMFGE">;
13941394
// Floating-point vselects:
13951395
// 11.15. Vector Integer Merge Instructions
13961396
// 13.15. Vector Floating-Point Merge Instruction
1397-
foreach fvti = AllFloatVectors in {
1397+
foreach fvti = !listconcat(AllFloatVectors, AllBFloatVectors) in {
13981398
defvar ivti = GetIntVTypeInfo<fvti>.Vti;
13991399
let Predicates = GetVTypePredicates<ivti>.Predicates in {
14001400
def : Pat<(fvti.Vector (vselect (fvti.Mask V0), fvti.RegClass:$rs1,
@@ -1412,7 +1412,9 @@ foreach fvti = AllFloatVectors in {
14121412
fvti.RegClass:$rs2, 0, (fvti.Mask V0), fvti.AVL, fvti.Log2SEW)>;
14131413

14141414
}
1415-
let Predicates = GetVTypePredicates<fvti>.Predicates in
1415+
1416+
let Predicates = !listconcat(GetVTypePredicates<GetFpVTypeInfo<fvti>.Vti>.Predicates,
1417+
GetVTypeScalarPredicates<fvti>.Predicates) in
14161418
def : Pat<(fvti.Vector (vselect (fvti.Mask V0),
14171419
(SplatFPOp fvti.ScalarRegClass:$rs1),
14181420
fvti.RegClass:$rs2)),
@@ -1475,7 +1477,7 @@ foreach fvtiToFWti = AllWidenableBFloatToFloatVectors in {
14751477
//===----------------------------------------------------------------------===//
14761478

14771479
foreach fvti = !listconcat(AllFloatVectors, AllBFloatVectors) in {
1478-
let Predicates = !listconcat(GetVTypePredicates<fvti>.Predicates,
1480+
let Predicates = !listconcat(GetVTypePredicates<GetFpVTypeInfo<fvti>.Vti>.Predicates,
14791481
GetVTypeScalarPredicates<fvti>.Predicates) in
14801482
def : Pat<(fvti.Vector (riscv_vfmv_v_f_vl undef, fvti.ScalarRegClass:$rs1, srcvalue)),
14811483
(!cast<Instruction>("PseudoVFMV_V_"#fvti.ScalarSuffix#"_"#fvti.LMul.MX)

llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2604,7 +2604,7 @@ foreach vti = AllFloatVectors in {
26042604
}
26052605
}
26062606

2607-
foreach fvti = AllFloatVectors in {
2607+
foreach fvti = !listconcat(AllFloatVectors, AllBFloatVectors) in {
26082608
// Floating-point vselects:
26092609
// 11.15. Vector Integer Merge Instructions
26102610
// 13.15. Vector Floating-Point Merge Instruction
@@ -2639,7 +2639,8 @@ foreach fvti = AllFloatVectors in {
26392639
GPR:$vl, fvti.Log2SEW)>;
26402640
}
26412641

2642-
let Predicates = GetVTypePredicates<fvti>.Predicates in {
2642+
let Predicates = !listconcat(GetVTypePredicates<GetFpVTypeInfo<fvti>.Vti>.Predicates,
2643+
GetVTypeScalarPredicates<fvti>.Predicates) in {
26432644
def : Pat<(fvti.Vector (riscv_vmerge_vl (fvti.Mask V0),
26442645
(SplatFPOp fvti.ScalarRegClass:$rs1),
26452646
fvti.RegClass:$rs2,

llvm/test/CodeGen/RISCV/rvv/fixed-vectors-select-fp.ll

Lines changed: 124 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
2-
; RUN: llc -mtriple=riscv32 -mattr=+d,+zfh,+zvfh,+v -target-abi=ilp32d \
2+
; RUN: llc -mtriple=riscv32 -mattr=+d,+zfh,+zvfh,+v,+experimental-zfbfmin,+experimental-zvfbfmin -target-abi=ilp32d \
33
; RUN: -verify-machineinstrs < %s | FileCheck %s
4-
; RUN: llc -mtriple=riscv64 -mattr=+d,+zfh,+zvfh,+v -target-abi=lp64d \
4+
; RUN: llc -mtriple=riscv64 -mattr=+d,+zfh,+zvfh,+v,+experimental-zfbfmin,+experimental-zvfbfmin -target-abi=lp64d \
55
; RUN: -verify-machineinstrs < %s | FileCheck %s
6-
; RUN: llc -mtriple=riscv32 -mattr=+d,+zfh,+zvfhmin,+v,+m -target-abi=ilp32d -riscv-v-vector-bits-min=128 \
6+
; RUN: llc -mtriple=riscv32 -mattr=+d,+zfh,+zvfhmin,+v,+m,+experimental-zfbfmin,+experimental-zvfbfmin -target-abi=ilp32d -riscv-v-vector-bits-min=128 \
77
; RUN: -verify-machineinstrs < %s | FileCheck %s
8-
; RUN: llc -mtriple=riscv64 -mattr=+d,+zfh,+zvfhmin,+v,+m -target-abi=lp64d -riscv-v-vector-bits-min=128 \
8+
; RUN: llc -mtriple=riscv64 -mattr=+d,+zfh,+zvfhmin,+v,+m,+experimental-zfbfmin,+experimental-zvfbfmin -target-abi=lp64d -riscv-v-vector-bits-min=128 \
99
; RUN: -verify-machineinstrs < %s | FileCheck %s
1010

1111
define <2 x half> @select_v2f16(i1 zeroext %c, <2 x half> %a, <2 x half> %b) {
@@ -343,3 +343,123 @@ define <16 x double> @selectcc_v16f64(double %a, double %b, <16 x double> %c, <1
343343
%v = select i1 %cmp, <16 x double> %c, <16 x double> %d
344344
ret <16 x double> %v
345345
}
346+
347+
define <2 x bfloat> @select_v2bf16(i1 zeroext %c, <2 x bfloat> %a, <2 x bfloat> %b) {
348+
; CHECK-LABEL: select_v2bf16:
349+
; CHECK: # %bb.0:
350+
; CHECK-NEXT: vsetivli zero, 2, e8, mf8, ta, ma
351+
; CHECK-NEXT: vmv.v.x v10, a0
352+
; CHECK-NEXT: vmsne.vi v0, v10, 0
353+
; CHECK-NEXT: vsetvli zero, zero, e16, mf4, ta, ma
354+
; CHECK-NEXT: vmerge.vvm v8, v9, v8, v0
355+
; CHECK-NEXT: ret
356+
%v = select i1 %c, <2 x bfloat> %a, <2 x bfloat> %b
357+
ret <2 x bfloat> %v
358+
}
359+
360+
define <2 x bfloat> @selectcc_v2bf16(bfloat %a, bfloat %b, <2 x bfloat> %c, <2 x bfloat> %d) {
361+
; CHECK-LABEL: selectcc_v2bf16:
362+
; CHECK: # %bb.0:
363+
; CHECK-NEXT: fcvt.s.bf16 fa5, fa1
364+
; CHECK-NEXT: fcvt.s.bf16 fa4, fa0
365+
; CHECK-NEXT: feq.s a0, fa4, fa5
366+
; CHECK-NEXT: vsetivli zero, 2, e8, mf8, ta, ma
367+
; CHECK-NEXT: vmv.v.x v10, a0
368+
; CHECK-NEXT: vmsne.vi v0, v10, 0
369+
; CHECK-NEXT: vsetvli zero, zero, e16, mf4, ta, ma
370+
; CHECK-NEXT: vmerge.vvm v8, v9, v8, v0
371+
; CHECK-NEXT: ret
372+
%cmp = fcmp oeq bfloat %a, %b
373+
%v = select i1 %cmp, <2 x bfloat> %c, <2 x bfloat> %d
374+
ret <2 x bfloat> %v
375+
}
376+
377+
define <4 x bfloat> @select_v4bf16(i1 zeroext %c, <4 x bfloat> %a, <4 x bfloat> %b) {
378+
; CHECK-LABEL: select_v4bf16:
379+
; CHECK: # %bb.0:
380+
; CHECK-NEXT: vsetivli zero, 4, e8, mf4, ta, ma
381+
; CHECK-NEXT: vmv.v.x v10, a0
382+
; CHECK-NEXT: vmsne.vi v0, v10, 0
383+
; CHECK-NEXT: vsetvli zero, zero, e16, mf2, ta, ma
384+
; CHECK-NEXT: vmerge.vvm v8, v9, v8, v0
385+
; CHECK-NEXT: ret
386+
%v = select i1 %c, <4 x bfloat> %a, <4 x bfloat> %b
387+
ret <4 x bfloat> %v
388+
}
389+
390+
define <4 x bfloat> @selectcc_v4bf16(bfloat %a, bfloat %b, <4 x bfloat> %c, <4 x bfloat> %d) {
391+
; CHECK-LABEL: selectcc_v4bf16:
392+
; CHECK: # %bb.0:
393+
; CHECK-NEXT: fcvt.s.bf16 fa5, fa1
394+
; CHECK-NEXT: fcvt.s.bf16 fa4, fa0
395+
; CHECK-NEXT: feq.s a0, fa4, fa5
396+
; CHECK-NEXT: vsetivli zero, 4, e8, mf4, ta, ma
397+
; CHECK-NEXT: vmv.v.x v10, a0
398+
; CHECK-NEXT: vmsne.vi v0, v10, 0
399+
; CHECK-NEXT: vsetvli zero, zero, e16, mf2, ta, ma
400+
; CHECK-NEXT: vmerge.vvm v8, v9, v8, v0
401+
; CHECK-NEXT: ret
402+
%cmp = fcmp oeq bfloat %a, %b
403+
%v = select i1 %cmp, <4 x bfloat> %c, <4 x bfloat> %d
404+
ret <4 x bfloat> %v
405+
}
406+
407+
define <8 x bfloat> @select_v8bf16(i1 zeroext %c, <8 x bfloat> %a, <8 x bfloat> %b) {
408+
; CHECK-LABEL: select_v8bf16:
409+
; CHECK: # %bb.0:
410+
; CHECK-NEXT: vsetivli zero, 8, e8, mf2, ta, ma
411+
; CHECK-NEXT: vmv.v.x v10, a0
412+
; CHECK-NEXT: vmsne.vi v0, v10, 0
413+
; CHECK-NEXT: vsetvli zero, zero, e16, m1, ta, ma
414+
; CHECK-NEXT: vmerge.vvm v8, v9, v8, v0
415+
; CHECK-NEXT: ret
416+
%v = select i1 %c, <8 x bfloat> %a, <8 x bfloat> %b
417+
ret <8 x bfloat> %v
418+
}
419+
420+
define <8 x bfloat> @selectcc_v8bf16(bfloat %a, bfloat %b, <8 x bfloat> %c, <8 x bfloat> %d) {
421+
; CHECK-LABEL: selectcc_v8bf16:
422+
; CHECK: # %bb.0:
423+
; CHECK-NEXT: fcvt.s.bf16 fa5, fa1
424+
; CHECK-NEXT: fcvt.s.bf16 fa4, fa0
425+
; CHECK-NEXT: feq.s a0, fa4, fa5
426+
; CHECK-NEXT: vsetivli zero, 8, e8, mf2, ta, ma
427+
; CHECK-NEXT: vmv.v.x v10, a0
428+
; CHECK-NEXT: vmsne.vi v0, v10, 0
429+
; CHECK-NEXT: vsetvli zero, zero, e16, m1, ta, ma
430+
; CHECK-NEXT: vmerge.vvm v8, v9, v8, v0
431+
; CHECK-NEXT: ret
432+
%cmp = fcmp oeq bfloat %a, %b
433+
%v = select i1 %cmp, <8 x bfloat> %c, <8 x bfloat> %d
434+
ret <8 x bfloat> %v
435+
}
436+
437+
define <16 x bfloat> @select_v16bf16(i1 zeroext %c, <16 x bfloat> %a, <16 x bfloat> %b) {
438+
; CHECK-LABEL: select_v16bf16:
439+
; CHECK: # %bb.0:
440+
; CHECK-NEXT: vsetivli zero, 16, e8, m1, ta, ma
441+
; CHECK-NEXT: vmv.v.x v12, a0
442+
; CHECK-NEXT: vmsne.vi v0, v12, 0
443+
; CHECK-NEXT: vsetvli zero, zero, e16, m2, ta, ma
444+
; CHECK-NEXT: vmerge.vvm v8, v10, v8, v0
445+
; CHECK-NEXT: ret
446+
%v = select i1 %c, <16 x bfloat> %a, <16 x bfloat> %b
447+
ret <16 x bfloat> %v
448+
}
449+
450+
define <16 x bfloat> @selectcc_v16bf16(bfloat %a, bfloat %b, <16 x bfloat> %c, <16 x bfloat> %d) {
451+
; CHECK-LABEL: selectcc_v16bf16:
452+
; CHECK: # %bb.0:
453+
; CHECK-NEXT: fcvt.s.bf16 fa5, fa1
454+
; CHECK-NEXT: fcvt.s.bf16 fa4, fa0
455+
; CHECK-NEXT: feq.s a0, fa4, fa5
456+
; CHECK-NEXT: vsetivli zero, 16, e8, m1, ta, ma
457+
; CHECK-NEXT: vmv.v.x v12, a0
458+
; CHECK-NEXT: vmsne.vi v0, v12, 0
459+
; CHECK-NEXT: vsetvli zero, zero, e16, m2, ta, ma
460+
; CHECK-NEXT: vmerge.vvm v8, v10, v8, v0
461+
; CHECK-NEXT: ret
462+
%cmp = fcmp oeq bfloat %a, %b
463+
%v = select i1 %cmp, <16 x bfloat> %c, <16 x bfloat> %d
464+
ret <16 x bfloat> %v
465+
}

0 commit comments

Comments
 (0)