Skip to content

Commit 676bfb2

Browse files
committed
[mlir] Refactor ShapedType into an interface
ShapedType was created in a time before interfaces, and is one of the earliest type base classes in the ecosystem. This commit refactors ShapedType into an interface, which is what it would have been if interfaces had existed at that time. The API of ShapedType and it's derived classes are essentially untouched by this refactor, with the exception being the API surrounding kDynamicIndex (which requires a sole home). For now, the API of ShapedType and its name have been kept as consistent to the current state of the world as possible (to help with potential migration churn, among other reasons). Moving forward though, we should look into potentially restructuring its API and possible its name as well (it should really have "Interface" at the end like other interfaces at the very least). One other potentially interesting note is that I've attached the ShapedType::Trait to TensorType/BaseMemRefType to act as mixins for the ShapedType API. This is kind of weird, but allows for sharing the same API (i.e. preventing API loss from the transition from base class -> Interface). This inheritance doesn't affect any of the derived classes, it is just for API mixin. Differential Revision: https://reviews.llvm.org/D116962
1 parent a60e83f commit 676bfb2

File tree

14 files changed

+396
-307
lines changed

14 files changed

+396
-307
lines changed

mlir/include/mlir/IR/BuiltinTypeInterfaces.td

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,4 +41,151 @@ def MemRefElementTypeInterface : TypeInterface<"MemRefElementTypeInterface"> {
4141
}];
4242
}
4343

44+
//===----------------------------------------------------------------------===//
45+
// ShapedType
46+
//===----------------------------------------------------------------------===//
47+
48+
def ShapedTypeInterface : TypeInterface<"ShapedType"> {
49+
let cppNamespace = "::mlir";
50+
let description = [{
51+
This interface provides a common API for interacting with multi-dimensional
52+
container types. These types contain a shape and an element type.
53+
54+
A shape is a list of sizes corresponding to the dimensions of the container.
55+
If the number of dimensions in the shape is unknown, the shape is "unranked".
56+
If the number of dimensions is known, the shape "ranked". The sizes of the
57+
dimensions of the shape must be positive, or kDynamicSize (in which case the
58+
size of the dimension is dynamic, or not statically known).
59+
}];
60+
let methods = [
61+
InterfaceMethod<[{
62+
Returns a clone of this type with the given shape and element
63+
type. If a shape is not provided, the current shape of the type is used.
64+
}],
65+
"::mlir::ShapedType", "cloneWith", (ins
66+
"::llvm::Optional<::llvm::ArrayRef<int64_t>>":$shape,
67+
"::mlir::Type":$elementType
68+
)>,
69+
70+
InterfaceMethod<[{
71+
Returns the element type of this shaped type.
72+
}],
73+
"::mlir::Type", "getElementType">,
74+
75+
InterfaceMethod<[{
76+
Returns if this type is ranked, i.e. it has a known number of dimensions.
77+
}],
78+
"bool", "hasRank">,
79+
80+
InterfaceMethod<[{
81+
Returns the shape of this type if it is ranked, otherwise asserts.
82+
}],
83+
"::llvm::ArrayRef<int64_t>", "getShape">,
84+
];
85+
86+
let extraClassDeclaration = [{
87+
// TODO: merge these two special values in a single one used everywhere.
88+
// Unfortunately, uses of `-1` have crept deep into the codebase now and are
89+
// hard to track.
90+
static constexpr int64_t kDynamicSize = -1;
91+
static constexpr int64_t kDynamicStrideOrOffset =
92+
std::numeric_limits<int64_t>::min();
93+
94+
/// Whether the given dimension size indicates a dynamic dimension.
95+
static constexpr bool isDynamic(int64_t dSize) {
96+
return dSize == kDynamicSize;
97+
}
98+
static constexpr bool isDynamicStrideOrOffset(int64_t dStrideOrOffset) {
99+
return dStrideOrOffset == kDynamicStrideOrOffset;
100+
}
101+
102+
/// Return the number of elements present in the given shape.
103+
static int64_t getNumElements(ArrayRef<int64_t> shape);
104+
105+
/// Returns the total amount of bits occupied by a value of this type. This
106+
/// does not take into account any memory layout or widening constraints,
107+
/// e.g. a vector<3xi57> may report to occupy 3x57=171 bit, even though in
108+
/// practice it will likely be stored as in a 4xi64 vector register. Fails
109+
/// with an assertion if the size cannot be computed statically, e.g. if the
110+
/// type has a dynamic shape or if its elemental type does not have a known
111+
/// bit width.
112+
int64_t getSizeInBits() const;
113+
}];
114+
115+
let extraSharedClassDeclaration = [{
116+
/// Return a clone of this type with the given new shape and element type.
117+
auto clone(::llvm::ArrayRef<int64_t> shape, Type elementType) {
118+
return $_type.cloneWith(shape, elementType);
119+
}
120+
/// Return a clone of this type with the given new shape.
121+
auto clone(::llvm::ArrayRef<int64_t> shape) {
122+
return $_type.cloneWith(shape, $_type.getElementType());
123+
}
124+
/// Return a clone of this type with the given new element type.
125+
auto clone(::mlir::Type elementType) {
126+
return $_type.cloneWith(/*shape=*/llvm::None, elementType);
127+
}
128+
129+
/// If an element type is an integer or a float, return its width. Otherwise,
130+
/// abort.
131+
unsigned getElementTypeBitWidth() const {
132+
return $_type.getElementType().getIntOrFloatBitWidth();
133+
}
134+
135+
/// If this is a ranked type, return the rank. Otherwise, abort.
136+
int64_t getRank() const {
137+
assert($_type.hasRank() && "cannot query rank of unranked shaped type");
138+
return $_type.getShape().size();
139+
}
140+
141+
/// If it has static shape, return the number of elements. Otherwise, abort.
142+
int64_t getNumElements() const {
143+
assert(hasStaticShape() && "cannot get element count of dynamic shaped type");
144+
return ::mlir::ShapedType::getNumElements($_type.getShape());
145+
}
146+
147+
/// Returns true if this dimension has a dynamic size (for ranked types);
148+
/// aborts for unranked types.
149+
bool isDynamicDim(unsigned idx) const {
150+
assert(idx < getRank() && "invalid index for shaped type");
151+
return ::mlir::ShapedType::isDynamic($_type.getShape()[idx]);
152+
}
153+
154+
/// Returns if this type has a static shape, i.e. if the type is ranked and
155+
/// all dimensions have known size (>= 0).
156+
bool hasStaticShape() const {
157+
return $_type.hasRank() &&
158+
llvm::none_of($_type.getShape(), ::mlir::ShapedType::isDynamic);
159+
}
160+
161+
/// Returns if this type has a static shape and the shape is equal to
162+
/// `shape` return true.
163+
bool hasStaticShape(::llvm::ArrayRef<int64_t> shape) const {
164+
return hasStaticShape() && $_type.getShape() == shape;
165+
}
166+
167+
/// If this is a ranked type, return the number of dimensions with dynamic
168+
/// size. Otherwise, abort.
169+
int64_t getNumDynamicDims() const {
170+
return llvm::count_if($_type.getShape(), ::mlir::ShapedType::isDynamic);
171+
}
172+
173+
/// If this is ranked type, return the size of the specified dimension.
174+
/// Otherwise, abort.
175+
int64_t getDimSize(unsigned idx) const {
176+
assert(idx < getRank() && "invalid index for shaped type");
177+
return $_type.getShape()[idx];
178+
}
179+
180+
/// Returns the position of the dynamic dimension relative to just the dynamic
181+
/// dimensions, given its `index` within the shape.
182+
unsigned getDynamicDimIndex(unsigned index) const {
183+
assert(index < getRank() && "invalid index");
184+
assert(::mlir::ShapedType::isDynamic(getDimSize(index)) && "invalid index");
185+
return llvm::count_if($_type.getShape().take_front(index),
186+
::mlir::ShapedType::isDynamic);
187+
}
188+
}];
189+
}
190+
44191
#endif // MLIR_IR_BUILTINTYPEINTERFACES_TD_

mlir/include/mlir/IR/BuiltinTypes.h

Lines changed: 44 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,12 @@ namespace llvm {
1616
struct fltSemantics;
1717
} // namespace llvm
1818

19+
//===----------------------------------------------------------------------===//
20+
// Tablegen Interface Declarations
21+
//===----------------------------------------------------------------------===//
22+
23+
#include "mlir/IR/BuiltinTypeInterfaces.h.inc"
24+
1925
namespace mlir {
2026
class AffineExpr;
2127
class AffineMap;
@@ -56,118 +62,67 @@ class FloatType : public Type {
5662
};
5763

5864
//===----------------------------------------------------------------------===//
59-
// ShapedType
65+
// TensorType
6066
//===----------------------------------------------------------------------===//
6167

62-
/// This is a common base class between Vector, UnrankedTensor, RankedTensor,
63-
/// and MemRef types because they share behavior and semantics around shape,
64-
/// rank, and fixed element type. Any type with these semantics should inherit
65-
/// from ShapedType.
66-
class ShapedType : public Type {
68+
/// Tensor types represent multi-dimensional arrays, and have two variants:
69+
/// RankedTensorType and UnrankedTensorType.
70+
/// Note: This class attaches the ShapedType trait to act as a mixin to
71+
/// provide many useful utility functions. This inheritance has no effect
72+
/// on derived tensor types.
73+
class TensorType : public Type, public ShapedType::Trait<TensorType> {
6774
public:
6875
using Type::Type;
6976

70-
// TODO: merge these two special values in a single one used everywhere.
71-
// Unfortunately, uses of `-1` have crept deep into the codebase now and are
72-
// hard to track.
73-
static constexpr int64_t kDynamicSize = -1;
74-
static constexpr int64_t kDynamicStrideOrOffset =
75-
std::numeric_limits<int64_t>::min();
76-
77-
/// Return clone of this type with new shape and element type.
78-
ShapedType clone(ArrayRef<int64_t> shape, Type elementType);
79-
ShapedType clone(ArrayRef<int64_t> shape);
80-
ShapedType clone(Type elementType);
81-
82-
/// Return the element type.
77+
/// Returns the element type of this tensor type.
8378
Type getElementType() const;
8479

85-
/// If an element type is an integer or a float, return its width. Otherwise,
86-
/// abort.
87-
unsigned getElementTypeBitWidth() const;
88-
89-
/// If it has static shape, return the number of elements. Otherwise, abort.
90-
int64_t getNumElements() const;
91-
92-
/// If this is a ranked type, return the rank. Otherwise, abort.
93-
int64_t getRank() const;
94-
95-
/// Whether or not this is a ranked type. Memrefs, vectors and ranked tensors
96-
/// have a rank, while unranked tensors do not.
80+
/// Returns if this type is ranked, i.e. it has a known number of dimensions.
9781
bool hasRank() const;
9882

99-
/// If this is a ranked type, return the shape. Otherwise, abort.
83+
/// Returns the shape of this tensor type.
10084
ArrayRef<int64_t> getShape() const;
10185

102-
/// If this is unranked type or any dimension has unknown size (<0), it
103-
/// doesn't have static shape. If all dimensions have known size (>= 0), it
104-
/// has static shape.
105-
bool hasStaticShape() const;
106-
107-
/// If this has a static shape and the shape is equal to `shape` return true.
108-
bool hasStaticShape(ArrayRef<int64_t> shape) const;
109-
110-
/// If this is a ranked type, return the number of dimensions with dynamic
111-
/// size. Otherwise, abort.
112-
int64_t getNumDynamicDims() const;
113-
114-
/// If this is ranked type, return the size of the specified dimension.
115-
/// Otherwise, abort.
116-
int64_t getDimSize(unsigned idx) const;
117-
118-
/// Returns true if this dimension has a dynamic size (for ranked types);
119-
/// aborts for unranked types.
120-
bool isDynamicDim(unsigned idx) const;
121-
122-
/// Returns the position of the dynamic dimension relative to just the dynamic
123-
/// dimensions, given its `index` within the shape.
124-
unsigned getDynamicDimIndex(unsigned index) const;
86+
/// Clone this type with the given shape and element type. If the
87+
/// provided shape is `None`, the current shape of the type is used.
88+
TensorType cloneWith(Optional<ArrayRef<int64_t>> shape,
89+
Type elementType) const;
12590

126-
/// Get the total amount of bits occupied by a value of this type. This does
127-
/// not take into account any memory layout or widening constraints, e.g. a
128-
/// vector<3xi57> is reported to occupy 3x57=171 bit, even though in practice
129-
/// it will likely be stored as in a 4xi64 vector register. Fail an assertion
130-
/// if the size cannot be computed statically, i.e. if the type has a dynamic
131-
/// shape or if its elemental type does not have a known bit width.
132-
int64_t getSizeInBits() const;
91+
/// Return true if the specified element type is ok in a tensor.
92+
static bool isValidElementType(Type type);
13393

13494
/// Methods for support type inquiry through isa, cast, and dyn_cast.
13595
static bool classof(Type type);
13696

137-
/// Whether the given dimension size indicates a dynamic dimension.
138-
static constexpr bool isDynamic(int64_t dSize) {
139-
return dSize == kDynamicSize;
140-
}
141-
static constexpr bool isDynamicStrideOrOffset(int64_t dStrideOrOffset) {
142-
return dStrideOrOffset == kDynamicStrideOrOffset;
143-
}
97+
/// Allow implicit conversion to ShapedType.
98+
operator ShapedType() const { return cast<ShapedType>(); }
14499
};
145100

146101
//===----------------------------------------------------------------------===//
147-
// TensorType
102+
// BaseMemRefType
148103
//===----------------------------------------------------------------------===//
149104

150-
/// Tensor types represent multi-dimensional arrays, and have two variants:
151-
/// RankedTensorType and UnrankedTensorType.
152-
class TensorType : public ShapedType {
105+
/// This class provides a shared interface for ranked and unranked memref types.
106+
/// Note: This class attaches the ShapedType trait to act as a mixin to
107+
/// provide many useful utility functions. This inheritance has no effect
108+
/// on derived memref types.
109+
class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
153110
public:
154-
using ShapedType::ShapedType;
111+
using Type::Type;
155112

156-
/// Return true if the specified element type is ok in a tensor.
157-
static bool isValidElementType(Type type);
113+
/// Returns the element type of this memref type.
114+
Type getElementType() const;
158115

159-
/// Methods for support type inquiry through isa, cast, and dyn_cast.
160-
static bool classof(Type type);
161-
};
116+
/// Returns if this type is ranked, i.e. it has a known number of dimensions.
117+
bool hasRank() const;
162118

163-
//===----------------------------------------------------------------------===//
164-
// BaseMemRefType
165-
//===----------------------------------------------------------------------===//
119+
/// Returns the shape of this memref type.
120+
ArrayRef<int64_t> getShape() const;
166121

167-
/// Base MemRef for Ranked and Unranked variants
168-
class BaseMemRefType : public ShapedType {
169-
public:
170-
using ShapedType::ShapedType;
122+
/// Clone this type with the given shape and element type. If the
123+
/// provided shape is `None`, the current shape of the type is used.
124+
BaseMemRefType cloneWith(Optional<ArrayRef<int64_t>> shape,
125+
Type elementType) const;
171126

172127
/// Return true if the specified element type is ok in a memref.
173128
static bool isValidElementType(Type type);
@@ -181,6 +136,9 @@ class BaseMemRefType : public ShapedType {
181136
/// [deprecated] Returns the memory space in old raw integer representation.
182137
/// New `Attribute getMemorySpace()` method should be used instead.
183138
unsigned getMemorySpaceAsInt() const;
139+
140+
/// Allow implicit conversion to ShapedType.
141+
operator ShapedType() const { return cast<ShapedType>(); }
184142
};
185143

186144
} // namespace mlir
@@ -192,12 +150,6 @@ class BaseMemRefType : public ShapedType {
192150
#define GET_TYPEDEF_CLASSES
193151
#include "mlir/IR/BuiltinTypes.h.inc"
194152

195-
//===----------------------------------------------------------------------===//
196-
// Tablegen Interface Declarations
197-
//===----------------------------------------------------------------------===//
198-
199-
#include "mlir/IR/BuiltinTypeInterfaces.h.inc"
200-
201153
namespace mlir {
202154

203155
//===----------------------------------------------------------------------===//
@@ -439,11 +391,6 @@ inline FloatType FloatType::getF128(MLIRContext *ctx) {
439391
return Float128Type::get(ctx);
440392
}
441393

442-
inline bool ShapedType::classof(Type type) {
443-
return type.isa<RankedTensorType, VectorType, UnrankedTensorType,
444-
UnrankedMemRefType, MemRefType>();
445-
}
446-
447394
inline bool TensorType::classof(Type type) {
448395
return type.isa<RankedTensorType, UnrankedTensorType>();
449396
}

0 commit comments

Comments
 (0)