@@ -16,6 +16,12 @@ namespace llvm {
16
16
struct fltSemantics ;
17
17
} // namespace llvm
18
18
19
+ // ===----------------------------------------------------------------------===//
20
+ // Tablegen Interface Declarations
21
+ // ===----------------------------------------------------------------------===//
22
+
23
+ #include " mlir/IR/BuiltinTypeInterfaces.h.inc"
24
+
19
25
namespace mlir {
20
26
class AffineExpr ;
21
27
class AffineMap ;
@@ -56,118 +62,67 @@ class FloatType : public Type {
56
62
};
57
63
58
64
// ===----------------------------------------------------------------------===//
59
- // ShapedType
65
+ // TensorType
60
66
// ===----------------------------------------------------------------------===//
61
67
62
- // / This is a common base class between Vector, UnrankedTensor, RankedTensor,
63
- // / and MemRef types because they share behavior and semantics around shape,
64
- // / rank, and fixed element type. Any type with these semantics should inherit
65
- // / from ShapedType.
66
- class ShapedType : public Type {
68
+ // / Tensor types represent multi-dimensional arrays, and have two variants:
69
+ // / RankedTensorType and UnrankedTensorType.
70
+ // / Note: This class attaches the ShapedType trait to act as a mixin to
71
+ // / provide many useful utility functions. This inheritance has no effect
72
+ // / on derived tensor types.
73
+ class TensorType : public Type , public ShapedType ::Trait<TensorType> {
67
74
public:
68
75
using Type::Type;
69
76
70
- // TODO: merge these two special values in a single one used everywhere.
71
- // Unfortunately, uses of `-1` have crept deep into the codebase now and are
72
- // hard to track.
73
- static constexpr int64_t kDynamicSize = -1 ;
74
- static constexpr int64_t kDynamicStrideOrOffset =
75
- std::numeric_limits<int64_t >::min();
76
-
77
- // / Return clone of this type with new shape and element type.
78
- ShapedType clone (ArrayRef<int64_t > shape, Type elementType);
79
- ShapedType clone (ArrayRef<int64_t > shape);
80
- ShapedType clone (Type elementType);
81
-
82
- // / Return the element type.
77
+ // / Returns the element type of this tensor type.
83
78
Type getElementType () const ;
84
79
85
- // / If an element type is an integer or a float, return its width. Otherwise,
86
- // / abort.
87
- unsigned getElementTypeBitWidth () const ;
88
-
89
- // / If it has static shape, return the number of elements. Otherwise, abort.
90
- int64_t getNumElements () const ;
91
-
92
- // / If this is a ranked type, return the rank. Otherwise, abort.
93
- int64_t getRank () const ;
94
-
95
- // / Whether or not this is a ranked type. Memrefs, vectors and ranked tensors
96
- // / have a rank, while unranked tensors do not.
80
+ // / Returns if this type is ranked, i.e. it has a known number of dimensions.
97
81
bool hasRank () const ;
98
82
99
- // / If this is a ranked type, return the shape. Otherwise, abort .
83
+ // / Returns the shape of this tensor type .
100
84
ArrayRef<int64_t > getShape () const ;
101
85
102
- // / If this is unranked type or any dimension has unknown size (<0), it
103
- // / doesn't have static shape. If all dimensions have known size (>= 0), it
104
- // / has static shape.
105
- bool hasStaticShape () const ;
106
-
107
- // / If this has a static shape and the shape is equal to `shape` return true.
108
- bool hasStaticShape (ArrayRef<int64_t > shape) const ;
109
-
110
- // / If this is a ranked type, return the number of dimensions with dynamic
111
- // / size. Otherwise, abort.
112
- int64_t getNumDynamicDims () const ;
113
-
114
- // / If this is ranked type, return the size of the specified dimension.
115
- // / Otherwise, abort.
116
- int64_t getDimSize (unsigned idx) const ;
117
-
118
- // / Returns true if this dimension has a dynamic size (for ranked types);
119
- // / aborts for unranked types.
120
- bool isDynamicDim (unsigned idx) const ;
121
-
122
- // / Returns the position of the dynamic dimension relative to just the dynamic
123
- // / dimensions, given its `index` within the shape.
124
- unsigned getDynamicDimIndex (unsigned index) const ;
86
+ // / Clone this type with the given shape and element type. If the
87
+ // / provided shape is `None`, the current shape of the type is used.
88
+ TensorType cloneWith (Optional<ArrayRef<int64_t >> shape,
89
+ Type elementType) const ;
125
90
126
- // / Get the total amount of bits occupied by a value of this type. This does
127
- // / not take into account any memory layout or widening constraints, e.g. a
128
- // / vector<3xi57> is reported to occupy 3x57=171 bit, even though in practice
129
- // / it will likely be stored as in a 4xi64 vector register. Fail an assertion
130
- // / if the size cannot be computed statically, i.e. if the type has a dynamic
131
- // / shape or if its elemental type does not have a known bit width.
132
- int64_t getSizeInBits () const ;
91
+ // / Return true if the specified element type is ok in a tensor.
92
+ static bool isValidElementType (Type type);
133
93
134
94
// / Methods for support type inquiry through isa, cast, and dyn_cast.
135
95
static bool classof (Type type);
136
96
137
- // / Whether the given dimension size indicates a dynamic dimension.
138
- static constexpr bool isDynamic (int64_t dSize) {
139
- return dSize == kDynamicSize ;
140
- }
141
- static constexpr bool isDynamicStrideOrOffset (int64_t dStrideOrOffset) {
142
- return dStrideOrOffset == kDynamicStrideOrOffset ;
143
- }
97
+ // / Allow implicit conversion to ShapedType.
98
+ operator ShapedType () const { return cast<ShapedType>(); }
144
99
};
145
100
146
101
// ===----------------------------------------------------------------------===//
147
- // TensorType
102
+ // BaseMemRefType
148
103
// ===----------------------------------------------------------------------===//
149
104
150
- // / Tensor types represent multi-dimensional arrays, and have two variants:
151
- // / RankedTensorType and UnrankedTensorType.
152
- class TensorType : public ShapedType {
105
+ // / This class provides a shared interface for ranked and unranked memref types.
106
+ // / Note: This class attaches the ShapedType trait to act as a mixin to
107
+ // / provide many useful utility functions. This inheritance has no effect
108
+ // / on derived memref types.
109
+ class BaseMemRefType : public Type , public ShapedType ::Trait<BaseMemRefType> {
153
110
public:
154
- using ShapedType::ShapedType ;
111
+ using Type::Type ;
155
112
156
- // / Return true if the specified element type is ok in a tensor .
157
- static bool isValidElementType ( Type type) ;
113
+ // / Returns the element type of this memref type .
114
+ Type getElementType () const ;
158
115
159
- // / Methods for support type inquiry through isa, cast, and dyn_cast.
160
- static bool classof (Type type);
161
- };
116
+ // / Returns if this type is ranked, i.e. it has a known number of dimensions.
117
+ bool hasRank () const ;
162
118
163
- // ===----------------------------------------------------------------------===//
164
- // BaseMemRefType
165
- // ===----------------------------------------------------------------------===//
119
+ // / Returns the shape of this memref type.
120
+ ArrayRef<int64_t > getShape () const ;
166
121
167
- // / Base MemRef for Ranked and Unranked variants
168
- class BaseMemRefType : public ShapedType {
169
- public:
170
- using ShapedType::ShapedType ;
122
+ // / Clone this type with the given shape and element type. If the
123
+ // / provided shape is `None`, the current shape of the type is used.
124
+ BaseMemRefType cloneWith (Optional<ArrayRef< int64_t >> shape,
125
+ Type elementType) const ;
171
126
172
127
// / Return true if the specified element type is ok in a memref.
173
128
static bool isValidElementType (Type type);
@@ -181,6 +136,9 @@ class BaseMemRefType : public ShapedType {
181
136
// / [deprecated] Returns the memory space in old raw integer representation.
182
137
// / New `Attribute getMemorySpace()` method should be used instead.
183
138
unsigned getMemorySpaceAsInt () const ;
139
+
140
+ // / Allow implicit conversion to ShapedType.
141
+ operator ShapedType () const { return cast<ShapedType>(); }
184
142
};
185
143
186
144
} // namespace mlir
@@ -192,12 +150,6 @@ class BaseMemRefType : public ShapedType {
192
150
#define GET_TYPEDEF_CLASSES
193
151
#include " mlir/IR/BuiltinTypes.h.inc"
194
152
195
- // ===----------------------------------------------------------------------===//
196
- // Tablegen Interface Declarations
197
- // ===----------------------------------------------------------------------===//
198
-
199
- #include " mlir/IR/BuiltinTypeInterfaces.h.inc"
200
-
201
153
namespace mlir {
202
154
203
155
// ===----------------------------------------------------------------------===//
@@ -439,11 +391,6 @@ inline FloatType FloatType::getF128(MLIRContext *ctx) {
439
391
return Float128Type::get (ctx);
440
392
}
441
393
442
- inline bool ShapedType::classof (Type type) {
443
- return type.isa <RankedTensorType, VectorType, UnrankedTensorType,
444
- UnrankedMemRefType, MemRefType>();
445
- }
446
-
447
394
inline bool TensorType::classof (Type type) {
448
395
return type.isa <RankedTensorType, UnrankedTensorType>();
449
396
}
0 commit comments