Skip to content

[AMD][ROCDL] Add packed conversions fp8/bf8->bf16 and fp8/bf8->fp32 in ROCDL dialect #131850

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 21, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
Original file line number Diff line number Diff line change
Expand Up @@ -85,23 +85,24 @@ def AMDGPU_ExtPackedFp8Op :
AMDGPU_Op<"ext_packed_fp8", [Pure]>,
Arguments<(ins AnyTypeOf<[F8E5M2FNUZ, F8E4M3FNUZ, F8E5M2, F8E4M3FN,
VectorOfLengthAndType<[1, 2, 3, 4], [F8E5M2FNUZ, F8E4M3FNUZ, F8E5M2, F8E4M3FN]>]>:$source,
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<3>]>:$index)>,
Results<(outs F32:$res)> {
let summary = "Extend one of a vector of packed fp8 values to a float";
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<1>]>:$wordIndex)>,
Results<(outs FixedVectorOfLengthAndType<[2], [F32]>:$res)> {
let summary = "Extend a vector of packed fp8 values to two floats";

let description = [{
Extend the value `source[index]` to a 32-bit float and return it.
Extend the two 8-bit floats in `source[wordIndex]` to two 32-bit floats and return them.

This rather unusual signature arises from the fact that AMD GPUs cannot
easily work with sub 32-bit quantities, so the compiler intrinsics for
extending 8-bit floats (which are, currently, the only way to work with
this operation) take packed vectors of 4 such floats.

If the passed-in vector has fewer than four elements, or the input is scalar,
the remaining values in the <4 x i8> will be filled with with
the remaining values in the <4 x i8> will be filled with
undefined values as needed.
}];
let assemblyFormat = [{
attr-dict $source `[` $index `]` `:` type($source) `to` type($res)
attr-dict $source `[` $wordIndex `]` `:` type($source) `to` type($res)
}];
}

Expand Down
145 changes: 97 additions & 48 deletions mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -681,158 +681,184 @@ def ROCDL_CvtPkRtz:
}];
}

def ROCDL_CvtScaleF32PkFp8F16 :
def ROCDL_CvtScaleF32PkFp8F16Op :
ROCDL_IntrOp<"cvt.scalef32.pk.fp8.f16", [], [], [Pure], 1>,
Arguments<(ins ROCDL_V2I16Type: $old, ROCDL_V2F16Type: $src, F32: $scale, I1:$wordSel)> {
let summary = "Scale and convert f16 to packed fp8";
let description = [{
Scale `src` by the exponent in `scale` then convert to packed fp8.
Store the result in low/high word based on $wordSel, preserving the other word.
Scale `src` by the exponent in `scale`, then convert to packed fp8.
Store the result in low/high word of `old` based on $wordSel, preserving the other word.
}];
let assemblyFormat = [{
attr-dict $src `,` $scale `->` $old `[` $wordSel `]` `:` type($res)
}];
}

def ROCDL_CvtScaleF32PkFp8Bf16 :
def ROCDL_CvtScaleF32PkFp8Bf16Op :
ROCDL_IntrOp<"cvt.scalef32.pk.fp8.bf16", [], [], [Pure], 1>,
Arguments<(ins ROCDL_V2I16Type: $old, ROCDL_V2BF16Type: $src, F32: $scale, I1:$wordSel)> {
let summary = "Scale and convert packed bf16 to packed fp8";
let description = [{
Scale `src` by the exponent in `scale` then convert to packed fp8.
Store the result in low/high word based on $wordSel, preserving the other word.
Scale `src` by the exponent in `scale`, then convert to packed fp8.
Store the result in low/high word of `old` based on $wordSel, preserving the other word.
}];
let assemblyFormat = [{
attr-dict $src `,` $scale `->` $old `[` $wordSel `]` `:` type($res)
}];
}


def ROCDL_CvtScaleF32PkBf8F16 :
def ROCDL_CvtScaleF32PkBf8F16Op :
ROCDL_IntrOp<"cvt.scalef32.pk.bf8.f16", [], [], [Pure], 1>,
Arguments<(ins ROCDL_V2I16Type: $old, ROCDL_V2F16Type: $src, F32: $scale, I1:$wordSel)> {
let summary = "Scale and convert f16 to packed bf8";
let description = [{
Scale `src` by the exponent in `scale` then convert to packed bf8.
Store the result in low/high word based on $wordSel, preserving the other word.
Scale `src` by the exponent in `scale`, then convert to packed bf8.
Store the result in low/high word of `old` based on $wordSel, preserving the other word.
}];
let assemblyFormat = [{
attr-dict $src `,` $scale `->` $old `[` $wordSel `]` `:` type($res)
}];
}


def ROCDL_CvtScaleF32PkBf8Bf16 :
def ROCDL_CvtScaleF32PkBf8Bf16Op :
ROCDL_IntrOp<"cvt.scalef32.pk.bf8.bf16", [], [], [Pure], 1>,
Arguments<(ins ROCDL_V2I16Type: $old, ROCDL_V2BF16Type: $src, F32: $scale, I1:$wordSel)> {
let summary = "Scale and convert bf16 to packed bf8";
let description = [{
Scale `src` by the exponent in `scale` then convert to packed bf8.
Store the result in low/high word based on $wordSel, preserving the other word.
Scale `src` by the exponent in `scale`, then convert to packed bf8.
Store the result in low/high word of `old` based on $wordSel, preserving the other word.
}];
let assemblyFormat = [{
attr-dict $src `,` $scale `->` $old `[` $wordSel `]` `:` type($res)
}];
}

def ROCDL_CvtScaleF32SrFp8F16 :
def ROCDL_CvtScaleF32SrFp8F16Op :
ROCDL_IntrOp<"cvt.scalef32.sr.fp8.f16", [], [], [Pure], 1>,
Arguments<(ins I32:$old, F16:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
let summary = "Scale and convert f16 to packed fp8 using stochastic rounding";
let description = [{
Scale `src` by the exponent in `scale` then convert to packed p8 with stochastic rounding
using seed data in `seed`. store into the `byteSel`th byte of `old`, preserving the others.
Scale `src` by the exponent in `scale`, then convert to packed fp8 with stochastic rounding
using seed data in `seed`. Store into the `byteSel`th byte of `old`, preserving the others.

}];
let assemblyFormat = [{
attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
}];
}

def ROCDL_CvtScaleF32SrBf8F16 :
def ROCDL_CvtScaleF32SrBf8F16Op :
ROCDL_IntrOp<"cvt.scalef32.sr.bf8.f16", [], [], [Pure], 1>,
Arguments<(ins I32:$old, F16:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
let summary = "Scale and convert f16 to packed bf8 using stochastic rounding";
let description = [{
Scale `src` by the exponent in `scale` then convert to packed bf8 with stochastic rounding
using seed data in `seed`. store into the `byteSel`th byte of `old`, preserving the others.
Scale `src` by the exponent in `scale`, then convert to packed bf8 with stochastic rounding
using seed data in `seed`. Store into the `byteSel`th byte of `old`, preserving the others.

}];
let assemblyFormat = [{
attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
}];
}

def ROCDL_CvtScaleF32SrFp8Bf16 :
def ROCDL_CvtScaleF32SrFp8Bf16Op :
ROCDL_IntrOp<"cvt.scalef32.sr.fp8.bf16", [], [], [Pure], 1>,
Arguments<(ins I32:$old, BF16:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
let summary = "Scale and convert packed bf16 to packed fp8 using stochastic rounding";
let description = [{
Scale `src` by the exponent in `scale` then convert to packed fp8 with stochastic rounding
using seed data in `seed`. store into the `byteSel`th byte of `old`, preserving the others.
Scale `src` by the exponent in `scale`, then convert to packed fp8 with stochastic rounding
using seed data in `seed`. Store into the `byteSel`th byte of `old`, preserving the others.

}];
let assemblyFormat = [{
attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
}];
}

def ROCDL_CvtScaleF32SrBf8Bf16:
def ROCDL_CvtScaleF32SrBf8Bf16Op :
ROCDL_IntrOp<"cvt.scalef32.sr.bf8.bf16", [], [], [Pure], 1>,
Arguments<(ins I32:$old, BF16:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
let summary = "Scale and convert bf16 to packed bf8 using stochastic rounding";
let description = [{
Scale `src` by the exponent in `scale` then convert to packed p8 with stochastic rounding
using seed data in `seed`. store into the `byteSel`th byte of `old`, preserving the others.
Scale `src` by the exponent in `scale`, then convert to packed bf8 with stochastic rounding
using seed data in `seed`. Store into the `byteSel`th byte of `old`, preserving the others.

}];
let assemblyFormat = [{
attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
}];
}

def ROCDL_CvtScaleF32PkF16Fp8 :
def ROCDL_CvtScaleF32PkF16Fp8Op :
ROCDL_IntrOp<"cvt.scalef32.pk.f16.fp8", [], [], [Pure], 1>,
Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> {
let summary = "Scale and convert fp8 to packed f16";
let description = [{ Scale `src` based on $wordSel by the exponent in `scale`
then convert to packed f16.
let summary = "Convert fp8 to packed f16 and scale";
let description = [{ Convert `src` based on $wordSel to packed f16, then scale
the packed values by the exponent in `scale`.
}];
let assemblyFormat = [{
attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res)
}];
}

def ROCDL_CvtScaleF32PkF16Bf8 :
def ROCDL_CvtScaleF32PkF16Bf8Op :
ROCDL_IntrOp<"cvt.scalef32.pk.f16.bf8", [], [], [Pure], 1>,
Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> {
let summary = "Scale and convert bf8 to packed f16";
let description = [{ Scale `src` based on $wordSel by the exponent in `scale`
then convert to packed f16.
let summary = "Convert bf8 to packed f16 and scale";
let description = [{ Convert `src` based on $wordSel to packed f16, then scale
the packed values by the exponent in `scale`.
}];
let assemblyFormat = [{
attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res)
}];
}

def ROCDL_CvtScaleF16Fp8 :
def ROCDL_CvtScaleF32PkBf16Fp8Op :
ROCDL_IntrOp<"cvt.scalef32.pk.bf16.fp8", [], [], [Pure], 1>,
Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> {
let summary = "Convert fp8 to packed bf16 and scale";
let description = [{ Convert `src` based on $wordSel to packed bf16, then scale
the packed values by the exponent in `scale`.
}];
let assemblyFormat = [{
attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res)
}];
}

def ROCDL_CvtScaleF32PkBf16Bf8Op :
ROCDL_IntrOp<"cvt.scalef32.pk.bf16.bf8", [], [], [Pure], 1>,
Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> {
let summary = "Convert bf8 to packed bf16 and scale";
let description = [{ Convert `src` based on $wordSel to packed bf16, then scale
the packed values by the exponent in `scale`.
}];
let assemblyFormat = [{
attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res)
}];
}

def ROCDL_CvtScaleF16Fp8Op :
ROCDL_IntrOp<"cvt.scalef32.f16.fp8", [], [], [Pure], 1>,
Arguments<(ins ROCDL_V2F16Type:$old, I32:$src, F32: $scale, I32:$byteSel, I1:$wordSel)> {
let summary = "Scale and convert fp8 to f16";
let description = [{ Scale `src` based on $wordSel by the exponent in `scale`
then convert to f16 store into the `byteSel`th byte of `old`, preserving the others.
let description = [{ Convert `src` based on $wordSel to f16, then scale the value
by the exponent in `scale`. Store the result into the `byteSel`th byte of `old`,
preserving the others.
}];
let assemblyFormat = [{
attr-dict $src `[` $wordSel `]` `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
}];
}

def ROCDL_CvtScaleF16Bf8 :
def ROCDL_CvtScaleF16Bf8Op :
ROCDL_IntrOp<"cvt.scalef32.f16.bf8", [], [], [Pure], 1>,
Arguments<(ins ROCDL_V2F16Type:$old, I32:$src, F32: $scale, I32:$byteSel, I1:$wordSel)> {
let summary = "Scale and convert bf8 to f16";
let description = [{ Scale `src` based on $wordSel by the exponent in `scale`
then convert to f16 store into the `byteSel`th byte of `old`, preserving the others.
let description = [{ Convert `src` based on $wordSel to f16, then scale the value
by the exponent in `scale`. Store the result into the `byteSel`th byte of `old`,
preserving the others.
}];
let assemblyFormat = [{
attr-dict $src `[` $wordSel `]` `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
Expand All @@ -842,25 +868,25 @@ def ROCDL_CvtScaleF16Bf8 :
//===---------------------------------------------------------------------===//
// 32-bit float intrinsics
//===---------------------------------------------------------------------===//
def ROCDL_CvtScale32PkF32Fp8 :
def ROCDL_CvtScaleF32PkF32Fp8Op :
ROCDL_IntrOp<"cvt.scalef32.pk.f32.fp8", [], [], [Pure], 1>,
Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> {
let summary = "Scale and convert packed fp8 to packed f32";
let description = [{
Scale `src` by the exponent in `scale` then convert to packed fp32.
Store the result in low/high word based on $wordSel, preserving the other word.
Convert `src` based on $wordSel to packed fp32, then scale the packed values by
the exponent in `scale`. Store the result in a vector.
}];
let assemblyFormat = [{
attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res)
}];
}
def ROCDL_CvtScale32PkF32Bf8 :
def ROCDL_CvtScaleF32PkF32Bf8Op :
ROCDL_IntrOp<"cvt.scalef32.pk.f32.bf8", [], [], [Pure], 1>,
Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> {
let summary = "Scale and convert packed bf8 to packed f32";
let description = [{
Scale `src` by the exponent in `scale` then convert to packed fp32.
Store the result in low/high word based on $wordSel, preserving the other word.
Convert `src` based on $wordSel to packed fp32, then scale the packed values by
the exponent in `scale`. Store the result in a vector.
}];
let assemblyFormat = [{
attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res)
Expand All @@ -869,7 +895,7 @@ def ROCDL_CvtScale32PkF32Bf8 :
//===---------------------------------------------------------------------===//
// 8-bit float scale intrinsics
//===---------------------------------------------------------------------===//
def ROCDL_CvtScaleF32PkFp8F32:
def ROCDL_CvtScaleF32PkFp8F32Op :
ROCDL_IntrOp<"cvt.scalef32.pk.fp8.f32", [], [], [Pure], 1>,
Arguments<(ins ROCDL_V2I16Type:$old, F32:$srcA, F32:$srcB, F32:$scale, I1:$wordSel)> {
let summary = "Scale and convert two f32's to packed fp8";
Expand All @@ -882,7 +908,7 @@ def ROCDL_CvtScaleF32PkFp8F32:
}];
}

def ROCDL_CvtScaleF32PkBf8F32:
def ROCDL_CvtScaleF32PkBf8F32Op :
ROCDL_IntrOp<"cvt.scalef32.pk.bf8.f32", [], [], [Pure], 1>,
Arguments<(ins ROCDL_V2I16Type:$old, F32:$srcA, F32:$srcB, F32: $scale, I1:$wordSel)> {
let summary = "Scale and convert two f32's to packed bf8";
Expand All @@ -895,7 +921,7 @@ def ROCDL_CvtScaleF32PkBf8F32:
}];
}

def ROCDL_CvtScaleF32SrFp8F32:
def ROCDL_CvtScaleF32SrFp8F32Op :
ROCDL_IntrOp<"cvt.scalef32.sr.fp8.f32", [], [], [Pure], 1>,
Arguments<(ins I32:$old, F32:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
let summary = "Scale and convert f32 to fp8 using stochastic rounding";
Expand All @@ -909,7 +935,7 @@ def ROCDL_CvtScaleF32SrFp8F32:
}


def ROCDL_CvtScaleF32SrBf8F32:
def ROCDL_CvtScaleF32SrBf8F32Op :
ROCDL_IntrOp<"cvt.scalef32.sr.bf8.f32", [], [], [Pure], 1>,
Arguments<(ins I32:$old, F32:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
let summary = "Scale and convert f32 to bf8 using stochastic rounding";
Expand Down Expand Up @@ -978,6 +1004,29 @@ def ROCDL_CvtScaleF32Fp8Op :
}];
}

def ROCDL_CvtPkF32Fp8Op :
ROCDL_IntrOp<"cvt.pk.f32.fp8", [], [], [Pure], 1>,
Arguments<(ins I32:$src, I1:$wordSel)> {
let summary = "Convert packed fp8 to packed f32";
let description = [{
Convert `src` based on $wordSel to packed fp32.
}];
let assemblyFormat = [{
attr-dict $src `[` $wordSel `]` `:` type($res)
}];
}

def ROCDL_CvtPkF32Bf8Op :
ROCDL_IntrOp<"cvt.pk.f32.bf8", [], [], [Pure], 1>,
Arguments<(ins I32:$src, I1:$wordSel)> {
let summary = "Convert packed bf8 to packed f32";
let description = [{
Convert `src` based on $wordSel to packed fp32.
}];
let assemblyFormat = [{
attr-dict $src `[` $wordSel `]` `:` type($res)
}];
}

def ROCDL_CvtPkBf8F32Op :
ROCDL_IntrOp<"cvt.pk.bf8.f32", [], [], [Pure], 1>,
Expand Down
10 changes: 5 additions & 5 deletions mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -977,13 +977,13 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
source = longVec;
}
Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
Value wordSel = createI32Constant(rewriter, loc, op.getIndex());
Value wordSel = createI1Constant(rewriter, loc, op.getWordIndex());
if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
wordSel);
rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Bf8Op>(op, f32, i32Source,
wordSel);
} else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
wordSel);
rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Fp8Op>(op, f32, i32Source,
wordSel);
}
return success();
}
Expand Down
Loading