Skip to content

Commit dae827c

Browse files
[mlir][IR] Turn FloatType into a type interface
This makes it possible to add new floating point types in downstream projects. Also removes one place where we had to hard-code all existing floating point types (`FloatType::classof`).
1 parent e7412a5 commit dae827c

File tree

7 files changed

+107
-124
lines changed

7 files changed

+107
-124
lines changed

mlir/include/mlir/IR/BuiltinTypeInterfaces.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,15 @@
1111

1212
#include "mlir/IR/Types.h"
1313

14+
namespace llvm {
15+
struct fltSemantics;
16+
} // namespace llvm
17+
18+
namespace mlir {
19+
class FloatType;
20+
class MLIRContext;
21+
} // namespace mlir
22+
1423
#include "mlir/IR/BuiltinTypeInterfaces.h.inc"
1524

1625
#endif // MLIR_IR_BUILTINTYPEINTERFACES_H

mlir/include/mlir/IR/BuiltinTypeInterfaces.td

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,46 @@
1616

1717
include "mlir/IR/OpBase.td"
1818

19+
def FloatTypeInterface : TypeInterface<"FloatType"> {
20+
let cppNamespace = "::mlir";
21+
let methods = [
22+
InterfaceMethod<"TODO", "const llvm::fltSemantics &", "getFloatSemantics", (ins)>,
23+
];
24+
25+
let extraClassDeclaration = [{
26+
// Convenience factories.
27+
static FloatType getBF16(MLIRContext *ctx);
28+
static FloatType getF16(MLIRContext *ctx);
29+
static FloatType getF32(MLIRContext *ctx);
30+
static FloatType getTF32(MLIRContext *ctx);
31+
static FloatType getF64(MLIRContext *ctx);
32+
static FloatType getF80(MLIRContext *ctx);
33+
static FloatType getF128(MLIRContext *ctx);
34+
static FloatType getFloat8E5M2(MLIRContext *ctx);
35+
static FloatType getFloat8E4M3(MLIRContext *ctx);
36+
static FloatType getFloat8E4M3FN(MLIRContext *ctx);
37+
static FloatType getFloat8E5M2FNUZ(MLIRContext *ctx);
38+
static FloatType getFloat8E4M3FNUZ(MLIRContext *ctx);
39+
static FloatType getFloat8E4M3B11FNUZ(MLIRContext *ctx);
40+
static FloatType getFloat8E3M4(MLIRContext *ctx);
41+
static FloatType getFloat4E2M1FN(MLIRContext *ctx);
42+
static FloatType getFloat6E2M3FN(MLIRContext *ctx);
43+
static FloatType getFloat6E3M2FN(MLIRContext *ctx);
44+
static FloatType getFloat8E8M0FNU(MLIRContext *ctx);
45+
46+
/// Return the bitwidth of this float type.
47+
unsigned getWidth();
48+
49+
/// Return the width of the mantissa of this type.
50+
/// The width includes the integer bit.
51+
unsigned getFPMantissaWidth();
52+
53+
/// Get or create a new FloatType with bitwidth scaled by `scale`.
54+
/// Return null if the scaled element type cannot be represented.
55+
FloatType scaleElementBitwidth(unsigned scale);
56+
}];
57+
}
58+
1959
//===----------------------------------------------------------------------===//
2060
// MemRefElementTypeInterface
2161
//===----------------------------------------------------------------------===//

mlir/include/mlir/IR/BuiltinTypes.h

Lines changed: 0 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ struct fltSemantics;
2525
namespace mlir {
2626
class AffineExpr;
2727
class AffineMap;
28-
class FloatType;
2928
class IndexType;
3029
class IntegerType;
3130
class MemRefType;
@@ -44,52 +43,6 @@ template <typename ConcreteType>
4443
class ValueSemantics
4544
: public TypeTrait::TraitBase<ConcreteType, ValueSemantics> {};
4645

47-
//===----------------------------------------------------------------------===//
48-
// FloatType
49-
//===----------------------------------------------------------------------===//
50-
51-
class FloatType : public Type {
52-
public:
53-
using Type::Type;
54-
55-
// Convenience factories.
56-
static FloatType getBF16(MLIRContext *ctx);
57-
static FloatType getF16(MLIRContext *ctx);
58-
static FloatType getF32(MLIRContext *ctx);
59-
static FloatType getTF32(MLIRContext *ctx);
60-
static FloatType getF64(MLIRContext *ctx);
61-
static FloatType getF80(MLIRContext *ctx);
62-
static FloatType getF128(MLIRContext *ctx);
63-
static FloatType getFloat8E5M2(MLIRContext *ctx);
64-
static FloatType getFloat8E4M3(MLIRContext *ctx);
65-
static FloatType getFloat8E4M3FN(MLIRContext *ctx);
66-
static FloatType getFloat8E5M2FNUZ(MLIRContext *ctx);
67-
static FloatType getFloat8E4M3FNUZ(MLIRContext *ctx);
68-
static FloatType getFloat8E4M3B11FNUZ(MLIRContext *ctx);
69-
static FloatType getFloat8E3M4(MLIRContext *ctx);
70-
static FloatType getFloat4E2M1FN(MLIRContext *ctx);
71-
static FloatType getFloat6E2M3FN(MLIRContext *ctx);
72-
static FloatType getFloat6E3M2FN(MLIRContext *ctx);
73-
static FloatType getFloat8E8M0FNU(MLIRContext *ctx);
74-
75-
/// Methods for support type inquiry through isa, cast, and dyn_cast.
76-
static bool classof(Type type);
77-
78-
/// Return the bitwidth of this float type.
79-
unsigned getWidth();
80-
81-
/// Return the width of the mantissa of this type.
82-
/// The width includes the integer bit.
83-
unsigned getFPMantissaWidth();
84-
85-
/// Get or create a new FloatType with bitwidth scaled by `scale`.
86-
/// Return null if the scaled element type cannot be represented.
87-
FloatType scaleElementBitwidth(unsigned scale);
88-
89-
/// Return the floating semantics of this float type.
90-
const llvm::fltSemantics &getFloatSemantics();
91-
};
92-
9346
//===----------------------------------------------------------------------===//
9447
// TensorType
9548
//===----------------------------------------------------------------------===//
@@ -448,15 +401,6 @@ inline bool BaseMemRefType::isValidElementType(Type type) {
448401
llvm::isa<MemRefElementTypeInterface>(type);
449402
}
450403

451-
inline bool FloatType::classof(Type type) {
452-
return llvm::isa<Float4E2M1FNType, Float6E2M3FNType, Float6E3M2FNType,
453-
Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType,
454-
Float8E5M2FNUZType, Float8E4M3FNUZType,
455-
Float8E4M3B11FNUZType, Float8E3M4Type, Float8E8M0FNUType,
456-
BFloat16Type, Float16Type, FloatTF32Type, Float32Type,
457-
Float64Type, Float80Type, Float128Type>(type);
458-
}
459-
460404
inline FloatType FloatType::getFloat4E2M1FN(MLIRContext *ctx) {
461405
return Float4E2M1FNType::get(ctx);
462406
}

mlir/include/mlir/IR/BuiltinTypes.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,9 @@ def Builtin_Complex : Builtin_Type<"Complex", "complex"> {
8080

8181
// Base class for Builtin dialect float types.
8282
class Builtin_FloatType<string name, string mnemonic>
83-
: Builtin_Type<name, mnemonic, /*traits=*/[], "::mlir::FloatType"> {
83+
: Builtin_Type<name, mnemonic, /*traits=*/[
84+
DeclareTypeInterfaceMethods<FloatTypeInterface,
85+
["getFloatSemantics"]>]> {
8486
let extraClassDeclaration = [{
8587
static }] # name # [{Type get(MLIRContext *context);
8688
}];

mlir/lib/IR/BuiltinTypeInterfaces.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include "mlir/IR/BuiltinTypes.h"
1010
#include "mlir/IR/Diagnostics.h"
11+
#include "llvm/ADT/APFloat.h"
1112
#include "llvm/ADT/Sequence.h"
1213

1314
using namespace mlir;
@@ -19,6 +20,34 @@ using namespace mlir::detail;
1920

2021
#include "mlir/IR/BuiltinTypeInterfaces.cpp.inc"
2122

23+
//===----------------------------------------------------------------------===//
24+
// FloatType
25+
//===----------------------------------------------------------------------===//
26+
27+
unsigned FloatType::getWidth() {
28+
return APFloat::semanticsSizeInBits(getFloatSemantics());
29+
}
30+
31+
FloatType FloatType::scaleElementBitwidth(unsigned scale) {
32+
if (!scale)
33+
return FloatType();
34+
MLIRContext *ctx = getContext();
35+
if (isF16() || isBF16()) {
36+
if (scale == 2)
37+
return FloatType::getF32(ctx);
38+
if (scale == 4)
39+
return FloatType::getF64(ctx);
40+
}
41+
if (isF32())
42+
if (scale == 2)
43+
return FloatType::getF64(ctx);
44+
return FloatType();
45+
}
46+
47+
unsigned FloatType::getFPMantissaWidth() {
48+
return APFloat::semanticsPrecision(getFloatSemantics());
49+
}
50+
2251
//===----------------------------------------------------------------------===//
2352
// ShapedType
2453
//===----------------------------------------------------------------------===//

mlir/lib/IR/BuiltinTypes.cpp

Lines changed: 25 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -87,73 +87,32 @@ IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
8787
}
8888

8989
//===----------------------------------------------------------------------===//
90-
// Float Type
91-
//===----------------------------------------------------------------------===//
92-
93-
unsigned FloatType::getWidth() {
94-
return APFloat::semanticsSizeInBits(getFloatSemantics());
95-
}
96-
97-
/// Returns the floating semantics for the given type.
98-
const llvm::fltSemantics &FloatType::getFloatSemantics() {
99-
if (llvm::isa<Float4E2M1FNType>(*this))
100-
return APFloat::Float4E2M1FN();
101-
if (llvm::isa<Float6E2M3FNType>(*this))
102-
return APFloat::Float6E2M3FN();
103-
if (llvm::isa<Float6E3M2FNType>(*this))
104-
return APFloat::Float6E3M2FN();
105-
if (llvm::isa<Float8E5M2Type>(*this))
106-
return APFloat::Float8E5M2();
107-
if (llvm::isa<Float8E4M3Type>(*this))
108-
return APFloat::Float8E4M3();
109-
if (llvm::isa<Float8E4M3FNType>(*this))
110-
return APFloat::Float8E4M3FN();
111-
if (llvm::isa<Float8E5M2FNUZType>(*this))
112-
return APFloat::Float8E5M2FNUZ();
113-
if (llvm::isa<Float8E4M3FNUZType>(*this))
114-
return APFloat::Float8E4M3FNUZ();
115-
if (llvm::isa<Float8E4M3B11FNUZType>(*this))
116-
return APFloat::Float8E4M3B11FNUZ();
117-
if (llvm::isa<Float8E3M4Type>(*this))
118-
return APFloat::Float8E3M4();
119-
if (llvm::isa<Float8E8M0FNUType>(*this))
120-
return APFloat::Float8E8M0FNU();
121-
if (llvm::isa<BFloat16Type>(*this))
122-
return APFloat::BFloat();
123-
if (llvm::isa<Float16Type>(*this))
124-
return APFloat::IEEEhalf();
125-
if (llvm::isa<FloatTF32Type>(*this))
126-
return APFloat::FloatTF32();
127-
if (llvm::isa<Float32Type>(*this))
128-
return APFloat::IEEEsingle();
129-
if (llvm::isa<Float64Type>(*this))
130-
return APFloat::IEEEdouble();
131-
if (llvm::isa<Float80Type>(*this))
132-
return APFloat::x87DoubleExtended();
133-
if (llvm::isa<Float128Type>(*this))
134-
return APFloat::IEEEquad();
135-
llvm_unreachable("non-floating point type used");
136-
}
137-
138-
FloatType FloatType::scaleElementBitwidth(unsigned scale) {
139-
if (!scale)
140-
return FloatType();
141-
MLIRContext *ctx = getContext();
142-
if (isF16() || isBF16()) {
143-
if (scale == 2)
144-
return FloatType::getF32(ctx);
145-
if (scale == 4)
146-
return FloatType::getF64(ctx);
147-
}
148-
if (isF32())
149-
if (scale == 2)
150-
return FloatType::getF64(ctx);
151-
return FloatType();
152-
}
90+
// Float Types
91+
//===----------------------------------------------------------------------===//
15392

154-
unsigned FloatType::getFPMantissaWidth() {
155-
return APFloat::semanticsPrecision(getFloatSemantics());
156-
}
93+
#define FLOAT_TYPE_SEMANTICS(TYPE, SEM) \
94+
const llvm::fltSemantics &TYPE::getFloatSemantics() const { \
95+
return APFloat::SEM(); \
96+
}
97+
FLOAT_TYPE_SEMANTICS(Float4E2M1FNType, Float4E2M1FN)
98+
FLOAT_TYPE_SEMANTICS(Float6E2M3FNType, Float6E2M3FN)
99+
FLOAT_TYPE_SEMANTICS(Float6E3M2FNType, Float6E3M2FN)
100+
FLOAT_TYPE_SEMANTICS(Float8E5M2Type, Float8E5M2)
101+
FLOAT_TYPE_SEMANTICS(Float8E4M3Type, Float8E4M3)
102+
FLOAT_TYPE_SEMANTICS(Float8E4M3FNType, Float8E4M3FN)
103+
FLOAT_TYPE_SEMANTICS(Float8E5M2FNUZType, Float8E5M2FNUZ)
104+
FLOAT_TYPE_SEMANTICS(Float8E4M3FNUZType, Float8E4M3FNUZ)
105+
FLOAT_TYPE_SEMANTICS(Float8E4M3B11FNUZType, Float8E4M3B11FNUZ)
106+
FLOAT_TYPE_SEMANTICS(Float8E3M4Type, Float8E3M4)
107+
FLOAT_TYPE_SEMANTICS(Float8E8M0FNUType, Float8E8M0FNU)
108+
FLOAT_TYPE_SEMANTICS(BFloat16Type, BFloat)
109+
FLOAT_TYPE_SEMANTICS(Float16Type, IEEEhalf)
110+
FLOAT_TYPE_SEMANTICS(FloatTF32Type, FloatTF32)
111+
FLOAT_TYPE_SEMANTICS(Float32Type, IEEEsingle)
112+
FLOAT_TYPE_SEMANTICS(Float64Type, IEEEdouble)
113+
FLOAT_TYPE_SEMANTICS(Float80Type, x87DoubleExtended)
114+
FLOAT_TYPE_SEMANTICS(Float128Type, IEEEquad)
115+
#undef FLOAT_TYPE_SEMANTICS
157116

158117
//===----------------------------------------------------------------------===//
159118
// FunctionType

mlir/unittests/IR/InterfaceAttachmentTest.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ struct Model
4343
/// overrides default methods.
4444
struct OverridingModel
4545
: public TestExternalTypeInterface::ExternalModel<OverridingModel,
46-
FloatType> {
46+
Float32Type> {
4747
unsigned getBitwidthPlusArg(Type type, unsigned arg) const {
4848
return type.getIntOrFloatBitWidth() + arg;
4949
}

0 commit comments

Comments
 (0)