Skip to content

Commit cb1fed3

Browse files
committed
[NVPTX] Correctly guard int -> bf16 on PTX version and SM version
1 parent 7fa8585 commit cb1fed3

File tree

2 files changed

+12
-10
lines changed

2 files changed

+12
-10
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -788,13 +788,15 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
788788

789789
// sm_80 only has conversions between f32 and bf16. Custom lower all other
790790
// bf16 conversions.
791-
if (STI.hasBF16Math() &&
792-
(STI.getSmVersion() < 90 || STI.getPTXVersion() < 78)) {
791+
if (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78) {
793792
for (MVT VT : {MVT::i1, MVT::i16, MVT::i32, MVT::i64}) {
794793
setOperationAction(
795794
{ISD::SINT_TO_FP, ISD::UINT_TO_FP, ISD::FP_TO_SINT, ISD::FP_TO_UINT},
796795
VT, Custom);
797796
}
797+
setOperationAction(
798+
{ISD::SINT_TO_FP, ISD::UINT_TO_FP, ISD::FP_TO_SINT, ISD::FP_TO_UINT},
799+
MVT::bf16, Custom);
798800
}
799801

800802
setOperationAction(ISD::FROUND, MVT::f16, Promote);

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -3247,23 +3247,23 @@ def : Pat<(f16 (uint_to_fp Int64Regs:$a)),
32473247

32483248
// sint -> bf16
32493249
def : Pat<(bf16 (sint_to_fp Int1Regs:$a)),
3250-
(CVT_bf16_s32 (SELP_u32ii 1, 0, Int1Regs:$a), CvtRN)>;
3250+
(CVT_bf16_s32 (SELP_u32ii 1, 0, Int1Regs:$a), CvtRN)>, Requires<[hasPTX<78>, hasSM<90>]>;
32513251
def : Pat<(bf16 (sint_to_fp Int16Regs:$a)),
3252-
(CVT_bf16_s16 Int16Regs:$a, CvtRN)>;
3252+
(CVT_bf16_s16 Int16Regs:$a, CvtRN)>, Requires<[hasPTX<78>, hasSM<90>]>;
32533253
def : Pat<(bf16 (sint_to_fp Int32Regs:$a)),
3254-
(CVT_bf16_s32 Int32Regs:$a, CvtRN)>;
3254+
(CVT_bf16_s32 Int32Regs:$a, CvtRN)>, Requires<[hasPTX<78>, hasSM<90>]>;
32553255
def : Pat<(bf16 (sint_to_fp Int64Regs:$a)),
3256-
(CVT_bf16_s64 Int64Regs:$a, CvtRN)>;
3256+
(CVT_bf16_s64 Int64Regs:$a, CvtRN)>, Requires<[hasPTX<78>, hasSM<90>]>;
32573257

32583258
// uint -> bf16
32593259
def : Pat<(bf16 (uint_to_fp Int1Regs:$a)),
3260-
(CVT_bf16_u32 (SELP_u32ii 1, 0, Int1Regs:$a), CvtRN)>;
3260+
(CVT_bf16_u32 (SELP_u32ii 1, 0, Int1Regs:$a), CvtRN)>, Requires<[hasPTX<78>, hasSM<90>]>;
32613261
def : Pat<(bf16 (uint_to_fp Int16Regs:$a)),
3262-
(CVT_bf16_u16 Int16Regs:$a, CvtRN)>;
3262+
(CVT_bf16_u16 Int16Regs:$a, CvtRN)>, Requires<[hasPTX<78>, hasSM<90>]>;
32633263
def : Pat<(bf16 (uint_to_fp Int32Regs:$a)),
3264-
(CVT_bf16_u32 Int32Regs:$a, CvtRN)>;
3264+
(CVT_bf16_u32 Int32Regs:$a, CvtRN)>, Requires<[hasPTX<78>, hasSM<90>]>;
32653265
def : Pat<(bf16 (uint_to_fp Int64Regs:$a)),
3266-
(CVT_bf16_u64 Int64Regs:$a, CvtRN)>;
3266+
(CVT_bf16_u64 Int64Regs:$a, CvtRN)>, Requires<[hasPTX<78>, hasSM<90>]>;
32673267

32683268
// sint -> f32
32693269
def : Pat<(f32 (sint_to_fp Int1Regs:$a)),

0 commit comments

Comments
 (0)