Skip to content

Commit 0ea4fb9

Browse files
[AMD][ROCDL] Add packed conversions fp8/bf8->bf16 and fp8/bf8->fp32 in ROCDL dialect (#131850)
- Add packed conversions fp8/bf8->bf16 for gfx950 and fp8/bf8->fp32 for gfx942 in ROCDL dialect - Update amdgpu.ext_packed_fp8 lowering to use ROCDL packed fp8/bf8->f32 conversions for vector target types and ROCDL scalar fp8/bf8->fp32 for scalar target type. --------- Co-authored-by: Jungwook Park <[email protected]>
1 parent f0eeb9f commit 0ea4fb9

File tree

11 files changed

+309
-199
lines changed

11 files changed

+309
-199
lines changed

mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,18 +86,20 @@ def AMDGPU_ExtPackedFp8Op :
8686
Arguments<(ins AnyTypeOf<[F8E5M2FNUZ, F8E4M3FNUZ, F8E5M2, F8E4M3FN,
8787
VectorOfLengthAndType<[1, 2, 3, 4], [F8E5M2FNUZ, F8E4M3FNUZ, F8E5M2, F8E4M3FN]>]>:$source,
8888
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<3>]>:$index)>,
89-
Results<(outs F32:$res)> {
90-
let summary = "Extend one of a vector of packed fp8 values to a float";
89+
Results<(outs AnyTypeOf<[F32, FixedVectorOfLengthAndType<[2], [F32]>]>:$res)> {
90+
let summary = "Extend a fp8 value to a float or a vector of packed fp8 values to two floats";
91+
9192
let description = [{
92-
Extend the value `source[index]` to a 32-bit float and return it.
93+
Extend one or two 8-bit floats in `source[index]` to a 32-bit float or
94+
two floats and return them.
9395

9496
This rather unusual signature arises from the fact that AMD GPUs cannot
9597
easily work with sub 32-bit quantities, so the compiler intrinsics for
9698
extending 8-bit floats (which are, currently, the only way to work with
9799
this operation) take packed vectors of 4 such floats.
98100

99101
If the passed-in vector has fewer than four elements, or the input is scalar,
100-
the remaining values in the <4 x i8> will be filled with with
102+
the remaining values in the <4 x i8> will be filled with
101103
undefined values as needed.
102104
}];
103105
let assemblyFormat = [{

mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td

Lines changed: 97 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -681,158 +681,184 @@ def ROCDL_CvtPkRtz:
681681
}];
682682
}
683683

684-
def ROCDL_CvtScaleF32PkFp8F16 :
684+
def ROCDL_CvtScaleF32PkFp8F16Op :
685685
ROCDL_IntrOp<"cvt.scalef32.pk.fp8.f16", [], [], [Pure], 1>,
686686
Arguments<(ins ROCDL_V2I16Type: $old, ROCDL_V2F16Type: $src, F32: $scale, I1:$wordSel)> {
687687
let summary = "Scale and convert f16 to packed fp8";
688688
let description = [{
689-
Scale `src` by the exponent in `scale` then convert to packed fp8.
690-
Store the result in low/high word based on $wordSel, preserving the other word.
689+
Scale `src` by the exponent in `scale`, then convert to packed fp8.
690+
Store the result in low/high word of `old` based on $wordSel, preserving the other word.
691691
}];
692692
let assemblyFormat = [{
693693
attr-dict $src `,` $scale `->` $old `[` $wordSel `]` `:` type($res)
694694
}];
695695
}
696696

697-
def ROCDL_CvtScaleF32PkFp8Bf16 :
697+
def ROCDL_CvtScaleF32PkFp8Bf16Op :
698698
ROCDL_IntrOp<"cvt.scalef32.pk.fp8.bf16", [], [], [Pure], 1>,
699699
Arguments<(ins ROCDL_V2I16Type: $old, ROCDL_V2BF16Type: $src, F32: $scale, I1:$wordSel)> {
700700
let summary = "Scale and convert packed bf16 to packed fp8";
701701
let description = [{
702-
Scale `src` by the exponent in `scale` then convert to packed fp8.
703-
Store the result in low/high word based on $wordSel, preserving the other word.
702+
Scale `src` by the exponent in `scale`, then convert to packed fp8.
703+
Store the result in low/high word of `old` based on $wordSel, preserving the other word.
704704
}];
705705
let assemblyFormat = [{
706706
attr-dict $src `,` $scale `->` $old `[` $wordSel `]` `:` type($res)
707707
}];
708708
}
709709

710710

711-
def ROCDL_CvtScaleF32PkBf8F16 :
711+
def ROCDL_CvtScaleF32PkBf8F16Op :
712712
ROCDL_IntrOp<"cvt.scalef32.pk.bf8.f16", [], [], [Pure], 1>,
713713
Arguments<(ins ROCDL_V2I16Type: $old, ROCDL_V2F16Type: $src, F32: $scale, I1:$wordSel)> {
714714
let summary = "Scale and convert f16 to packed bf8";
715715
let description = [{
716-
Scale `src` by the exponent in `scale` then convert to packed bf8.
717-
Store the result in low/high word based on $wordSel, preserving the other word.
716+
Scale `src` by the exponent in `scale`, then convert to packed bf8.
717+
Store the result in low/high word of `old` based on $wordSel, preserving the other word.
718718
}];
719719
let assemblyFormat = [{
720720
attr-dict $src `,` $scale `->` $old `[` $wordSel `]` `:` type($res)
721721
}];
722722
}
723723

724724

725-
def ROCDL_CvtScaleF32PkBf8Bf16 :
725+
def ROCDL_CvtScaleF32PkBf8Bf16Op :
726726
ROCDL_IntrOp<"cvt.scalef32.pk.bf8.bf16", [], [], [Pure], 1>,
727727
Arguments<(ins ROCDL_V2I16Type: $old, ROCDL_V2BF16Type: $src, F32: $scale, I1:$wordSel)> {
728728
let summary = "Scale and convert bf16 to packed bf8";
729729
let description = [{
730-
Scale `src` by the exponent in `scale` then convert to packed bf8.
731-
Store the result in low/high word based on $wordSel, preserving the other word.
730+
Scale `src` by the exponent in `scale`, then convert to packed bf8.
731+
Store the result in low/high word of `old` based on $wordSel, preserving the other word.
732732
}];
733733
let assemblyFormat = [{
734734
attr-dict $src `,` $scale `->` $old `[` $wordSel `]` `:` type($res)
735735
}];
736736
}
737737

738-
def ROCDL_CvtScaleF32SrFp8F16 :
738+
def ROCDL_CvtScaleF32SrFp8F16Op :
739739
ROCDL_IntrOp<"cvt.scalef32.sr.fp8.f16", [], [], [Pure], 1>,
740740
Arguments<(ins I32:$old, F16:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
741741
let summary = "Scale and convert f16 to packed fp8 using stochastic rounding";
742742
let description = [{
743-
Scale `src` by the exponent in `scale` then convert to packed p8 with stochastic rounding
744-
using seed data in `seed`. store into the `byteSel`th byte of `old`, preserving the others.
743+
Scale `src` by the exponent in `scale`, then convert to packed p8 with stochastic rounding
744+
using seed data in `seed`. Store into the `byteSel`th byte of `old`, preserving the others.
745745

746746
}];
747747
let assemblyFormat = [{
748748
attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
749749
}];
750750
}
751751

752-
def ROCDL_CvtScaleF32SrBf8F16 :
752+
def ROCDL_CvtScaleF32SrBf8F16Op :
753753
ROCDL_IntrOp<"cvt.scalef32.sr.bf8.f16", [], [], [Pure], 1>,
754754
Arguments<(ins I32:$old, F16:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
755755
let summary = "Scale and convert f16 to packed bf8 using stochastic rounding";
756756
let description = [{
757-
Scale `src` by the exponent in `scale` then convert to packed bf8 with stochastic rounding
758-
using seed data in `seed`. store into the `byteSel`th byte of `old`, preserving the others.
757+
Scale `src` by the exponent in `scale`, then convert to packed bf8 with stochastic rounding
758+
using seed data in `seed`. Store into the `byteSel`th byte of `old`, preserving the others.
759759

760760
}];
761761
let assemblyFormat = [{
762762
attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
763763
}];
764764
}
765765

766-
def ROCDL_CvtScaleF32SrFp8Bf16 :
766+
def ROCDL_CvtScaleF32SrFp8Bf16Op :
767767
ROCDL_IntrOp<"cvt.scalef32.sr.fp8.bf16", [], [], [Pure], 1>,
768768
Arguments<(ins I32:$old, BF16:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
769769
let summary = "Scale and convert packed bf16 to packed fp8 using stochastic rounding";
770770
let description = [{
771-
Scale `src` by the exponent in `scale` then convert to packed fp8 with stochastic rounding
772-
using seed data in `seed`. store into the `byteSel`th byte of `old`, preserving the others.
771+
Scale `src` by the exponent in `scale`, then convert to packed fp8 with stochastic rounding
772+
using seed data in `seed`. Store into the `byteSel`th byte of `old`, preserving the others.
773773

774774
}];
775775
let assemblyFormat = [{
776776
attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
777777
}];
778778
}
779779

780-
def ROCDL_CvtScaleF32SrBf8Bf16:
780+
def ROCDL_CvtScaleF32SrBf8Bf16Op :
781781
ROCDL_IntrOp<"cvt.scalef32.sr.bf8.bf16", [], [], [Pure], 1>,
782782
Arguments<(ins I32:$old, BF16:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
783783
let summary = "Scale and convert bf16 to packed fp8 using stochastic rounding";
784784
let description = [{
785-
Scale `src` by the exponent in `scale` then convert to packed p8 with stochastic rounding
786-
using seed data in `seed`. store into the `byteSel`th byte of `old`, preserving the others.
785+
Scale `src` by the exponent in `scale`, then convert to packed p8 with stochastic rounding
786+
using seed data in `seed`. Store into the `byteSel`th byte of `old`, preserving the others.
787787

788788
}];
789789
let assemblyFormat = [{
790790
attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
791791
}];
792792
}
793793

794-
def ROCDL_CvtScaleF32PkF16Fp8 :
794+
def ROCDL_CvtScaleF32PkF16Fp8Op :
795795
ROCDL_IntrOp<"cvt.scalef32.pk.f16.fp8", [], [], [Pure], 1>,
796796
Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> {
797-
let summary = "Scale and convert fp8 to packed f16";
798-
let description = [{ Scale `src` based on $wordSel by the exponent in `scale`
799-
then convert to packed f16.
797+
let summary = "Convert fp8 to packed f16 and scale";
798+
let description = [{ Convert `src` based on $wordSel to packed f16, then scale
799+
the packed values by the exponent in `scale`.
800800
}];
801801
let assemblyFormat = [{
802802
attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res)
803803
}];
804804
}
805805

806-
def ROCDL_CvtScaleF32PkF16Bf8 :
806+
def ROCDL_CvtScaleF32PkF16Bf8Op :
807807
ROCDL_IntrOp<"cvt.scalef32.pk.f16.bf8", [], [], [Pure], 1>,
808808
Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> {
809-
let summary = "Scale and convert bf8 to packed f16";
810-
let description = [{ Scale `src` based on $wordSel by the exponent in `scale`
811-
then convert to packed f16.
809+
let summary = "convert bf8 to packed f16 and scale";
810+
let description = [{ Convert `src` based on $wordSel to packed f16, then scale
811+
the packed values by exponent in `scale`.
812812
}];
813813
let assemblyFormat = [{
814814
attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res)
815815
}];
816816
}
817817

818-
def ROCDL_CvtScaleF16Fp8 :
818+
def ROCDL_CvtScaleF32PkBf16Fp8Op :
819+
ROCDL_IntrOp<"cvt.scalef32.pk.bf16.fp8", [], [], [Pure], 1>,
820+
Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> {
821+
let summary = "Convert fp8 to packed bf16 and scale";
822+
let description = [{ Convert `src` based on $wordSel to packed bf16, then scale
823+
the packed values by the exponent in `scale`.
824+
}];
825+
let assemblyFormat = [{
826+
attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res)
827+
}];
828+
}
829+
830+
def ROCDL_CvtScaleF32PkBf16Bf8Op :
831+
ROCDL_IntrOp<"cvt.scalef32.pk.bf16.bf8", [], [], [Pure], 1>,
832+
Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> {
833+
let summary = "Convert bf8 to packed bf16 and scale";
834+
let description = [{ Convert `src` based on $wordSel to packed bf16, then scale
835+
the packed values by the exponent in `scale`.
836+
}];
837+
let assemblyFormat = [{
838+
attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res)
839+
}];
840+
}
841+
842+
def ROCDL_CvtScaleF16Fp8Op :
819843
ROCDL_IntrOp<"cvt.scalef32.f16.fp8", [], [], [Pure], 1>,
820844
Arguments<(ins ROCDL_V2F16Type:$old, I32:$src, F32: $scale, I32:$byteSel, I1:$wordSel)> {
821845
let summary = "Scale and convert fp8 to f16";
822-
let description = [{ Scale `src` based on $wordSel by the exponent in `scale`
823-
then convert to f16 store into the `byteSel`th byte of `old`, preserving the others.
846+
let description = [{ Convert `src` based on $wordSel to f16, then scale the value
847+
by the exponent in `scale`. Store the result into the `byteSel`th byte of `old`,
848+
preserving the others.
824849
}];
825850
let assemblyFormat = [{
826851
attr-dict $src `[` $wordSel `]` `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
827852
}];
828853
}
829854

830-
def ROCDL_CvtScaleF16Bf8 :
855+
def ROCDL_CvtScaleF16Bf8Op :
831856
ROCDL_IntrOp<"cvt.scalef32.f16.bf8", [], [], [Pure], 1>,
832857
Arguments<(ins ROCDL_V2F16Type:$old, I32:$src, F32: $scale, I32:$byteSel, I1:$wordSel)> {
833858
let summary = "Scale and convert fp8 to f16";
834-
let description = [{ Scale `src` based on $wordSel by the exponent in `scale`
835-
then convert to f16 store into the `byteSel`th byte of `old`, preserving the others.
859+
let description = [{ Convert `src` based on $wordSel to f16, then scale the value
860+
by the exponent in `scale`. Store the result into the `byteSel`th byte of `old`,
861+
preserving the others.
836862
}];
837863
let assemblyFormat = [{
838864
attr-dict $src `[` $wordSel `]` `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
@@ -842,25 +868,25 @@ def ROCDL_CvtScaleF16Bf8 :
842868
//===---------------------------------------------------------------------===//
843869
// 32-bit float intrinsics
844870
//===---------------------------------------------------------------------===//
845-
def ROCDL_CvtScale32PkF32Fp8 :
871+
def ROCDL_CvtScaleF32PkF32Fp8Op :
846872
ROCDL_IntrOp<"cvt.scalef32.pk.f32.fp8", [], [], [Pure], 1>,
847873
Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> {
848874
let summary = "Scale and convert packed fp8 to packed f32";
849875
let description = [{
850-
Scale `src` by the exponent in `scale` then convert to packed fp32.
851-
Store the result in low/high word based on $wordSel, preserving the other word.
876+
Convert `src` based on $wordSel to packed fp32, then scale the packed values by
877+
the exponent in `scale`. Store the result in a vector.
852878
}];
853879
let assemblyFormat = [{
854880
attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res)
855881
}];
856882
}
857-
def ROCDL_CvtScale32PkF32Bf8 :
883+
def ROCDL_CvtScaleF32PkF32Bf8Op :
858884
ROCDL_IntrOp<"cvt.scalef32.pk.f32.bf8", [], [], [Pure], 1>,
859885
Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> {
860886
let summary = "Scale and convert packed bf8 to packed f32";
861887
let description = [{
862-
Scale `src` by the exponent in `scale` then convert to packed fp32.
863-
Store the result in low/high word based on $wordSel, preserving the other word.
888+
Convert `src` based on $wordSel to packed fp32, then scale the packed values by
889+
the exponent in `scale`. Store the result in a vector.
864890
}];
865891
let assemblyFormat = [{
866892
attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res)
@@ -869,7 +895,7 @@ def ROCDL_CvtScale32PkF32Bf8 :
869895
//===---------------------------------------------------------------------===//
870896
// 8-bit float scale intrinsics
871897
//===---------------------------------------------------------------------===//
872-
def ROCDL_CvtScaleF32PkFp8F32:
898+
def ROCDL_CvtScaleF32PkFp8F32Op :
873899
ROCDL_IntrOp<"cvt.scalef32.pk.fp8.f32", [], [], [Pure], 1>,
874900
Arguments<(ins ROCDL_V2I16Type:$old, F32:$srcA, F32:$srcB, F32:$scale, I1:$wordSel)> {
875901
let summary = "Scale and convert two f32's to packed fp8";
@@ -882,7 +908,7 @@ def ROCDL_CvtScaleF32PkFp8F32:
882908
}];
883909
}
884910

885-
def ROCDL_CvtScaleF32PkBf8F32:
911+
def ROCDL_CvtScaleF32PkBf8F32Op :
886912
ROCDL_IntrOp<"cvt.scalef32.pk.bf8.f32", [], [], [Pure], 1>,
887913
Arguments<(ins ROCDL_V2I16Type:$old, F32:$srcA, F32:$srcB, F32: $scale, I1:$wordSel)> {
888914
let summary = "Scale and convert two f32's to packed bf8";
@@ -895,7 +921,7 @@ def ROCDL_CvtScaleF32PkBf8F32:
895921
}];
896922
}
897923

898-
def ROCDL_CvtScaleF32SrFp8F32:
924+
def ROCDL_CvtScaleF32SrFp8F32Op :
899925
ROCDL_IntrOp<"cvt.scalef32.sr.fp8.f32", [], [], [Pure], 1>,
900926
Arguments<(ins I32:$old, F32:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
901927
let summary = "Scale and convert f32 to fp8 using stochastic rounding";
@@ -909,7 +935,7 @@ def ROCDL_CvtScaleF32SrFp8F32:
909935
}
910936

911937

912-
def ROCDL_CvtScaleF32SrBf8F32:
938+
def ROCDL_CvtScaleF32SrBf8F32Op :
913939
ROCDL_IntrOp<"cvt.scalef32.sr.bf8.f32", [], [], [Pure], 1>,
914940
Arguments<(ins I32:$old, F32:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
915941
let summary = "Scale and convert f32 to bf8 using stochastic rounding";
@@ -978,6 +1004,29 @@ def ROCDL_CvtScaleF32Fp8Op :
9781004
}];
9791005
}
9801006

1007+
def ROCDL_CvtPkF32Fp8Op :
1008+
ROCDL_IntrOp<"cvt.pk.f32.fp8", [], [], [Pure], 1>,
1009+
Arguments<(ins I32:$src, I1:$wordSel)> {
1010+
let summary = "Convert packed fp8 to packed f32";
1011+
let description = [{
1012+
Convert `src` based on $wordSel to packed fp32.
1013+
}];
1014+
let assemblyFormat = [{
1015+
attr-dict $src `[` $wordSel `]` `:` type($res)
1016+
}];
1017+
}
1018+
1019+
def ROCDL_CvtPkF32Bf8Op :
1020+
ROCDL_IntrOp<"cvt.pk.f32.bf8", [], [], [Pure], 1>,
1021+
Arguments<(ins I32:$src, I1:$wordSel)> {
1022+
let summary = "Convert packed bf8 to packed f32";
1023+
let description = [{
1024+
Convert `src` based on $wordSel to packed fp32,
1025+
}];
1026+
let assemblyFormat = [{
1027+
attr-dict $src `[` $wordSel `]` `:` type($res)
1028+
}];
1029+
}
9811030

9821031
def ROCDL_CvtPkBf8F32Op :
9831032
ROCDL_IntrOp<"cvt.pk.bf8.f32", [], [], [Pure], 1>,

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -959,6 +959,7 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
959959

960960
Value source = adaptor.getSource();
961961
auto sourceVecType = dyn_cast<VectorType>(op.getSource().getType());
962+
auto resultVecType = dyn_cast<VectorType>(op.getResult().getType());
962963
Type sourceElemType = getElementTypeOrSelf(op.getSource());
963964
// Extend to a v4i8
964965
if (!sourceVecType || sourceVecType.getNumElements() < 4) {
@@ -977,13 +978,24 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
977978
source = longVec;
978979
}
979980
Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
980-
Value wordSel = createI32Constant(rewriter, loc, op.getIndex());
981-
if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
982-
rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
983-
wordSel);
984-
} else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
985-
rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
986-
wordSel);
981+
if (resultVecType) {
982+
Value wordSel = createI1Constant(rewriter, loc, op.getIndex());
983+
if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
984+
rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Bf8Op>(op, f32, i32Source,
985+
wordSel);
986+
} else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
987+
rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Fp8Op>(op, f32, i32Source,
988+
wordSel);
989+
}
990+
} else {
991+
Value byteSel = createI32Constant(rewriter, loc, op.getIndex());
992+
if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
993+
rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
994+
byteSel);
995+
} else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
996+
rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
997+
byteSel);
998+
}
987999
}
9881000
return success();
9891001
}

0 commit comments

Comments
 (0)