[mlir] Convert TensorType and BaseMemRefType to interfaces #133053
Conversation
The existing design assumes that "TensorType" is only the built-in (un)ranked tensor and "BaseMemRefType" is only the built-in (un)ranked memref. As a result, generic logic operating on "tensors" and "memrefs" is limited to the built-ins; compatible user-defined types are not allowed. This matters, for instance, in one-shot bufferization when converting a "user tensor" to a "user memref" via the common infrastructure. Remove this behaviour, which appears accidental, by following in the footsteps of ShapedType (see 676bfb2). As with ShapedType, "tensor" and "memref" seem to have always aspired to be interfaces.
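The core of the change can be illustrated with a standalone sketch (NOT actual MLIR API; all names below are hypothetical stand-ins): before the patch, "is this a tensor?" was a closed `classof` check against the two built-in types, whereas after the patch membership is an open trait check, mirroring the `extraClassOf: $_type.hasTrait<...>()` hook in the diff below.

```cpp
#include <cassert>

// Simplified standalone model of the design change. `Type`, `TensorTrait`,
// and the concrete types here are illustrative only.
struct Type { virtual ~Type() = default; };
struct TensorTrait {};  // analogous role to ::mlir::TensorType::Trait

struct RankedTensor : Type, TensorTrait {};
struct UserTensor   : Type, TensorTrait {};  // a downstream dialect's type
struct IntegerType  : Type {};

// Before: a closed set of built-ins (what TensorType::classof used to do).
bool isTensorClosed(const Type &t) {
  return dynamic_cast<const RankedTensor *>(&t) != nullptr;
}

// After: any type carrying the trait qualifies, built-in or user-defined.
bool isTensorOpen(const Type &t) {
  return dynamic_cast<const TensorTrait *>(&t) != nullptr;
}
```

Under the closed check a `UserTensor` is rejected even though it models a tensor; under the trait-based check it participates in the same generic logic as the built-ins, which is exactly what one-shot bufferization of user types needs.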
@llvm/pr-subscribers-mlir-tosa @llvm/pr-subscribers-mlir-tensor Author: Andrei Golubev (andrey-golubev). Patch is 36.19 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/133053.diff 11 Files Affected:
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index ccd91a928e1dd..248ef9f855b14 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -31,7 +31,7 @@ class XeGPUTypeDef<string name, string typeMnemonic, list<Trait> traits = [],
}
def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
- [ShapedTypeInterface], "::mlir::TensorType"> {
+ [TensorTypeInterface]> {
let summary = "TensorDesc describing regions of interested data.";
let description = [{
TensorDesc is a type designed to describe regions of the interested data as well as some
@@ -105,7 +105,6 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
];
let extraClassDeclaration = [{
- using TensorType::clone;
using mlir::ShapedType::Trait<TensorDescType>::getElementTypeBitWidth;
using mlir::ShapedType::Trait<TensorDescType>::getRank;
using mlir::ShapedType::Trait<TensorDescType>::getNumElements;
@@ -115,8 +114,11 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
using mlir::ShapedType::Trait<TensorDescType>::getDimSize;
using mlir::ShapedType::Trait<TensorDescType>::getDynamicDimIndex;
+ TensorDescType cloneWith(std::optional<ArrayRef<int64_t>> shape, Type elementType) const;
+ bool hasRank() const { return true; }
+
TensorDescType clone(::mlir::Type elementType) {
- return llvm::cast<TensorDescType>(cloneWith(getShape(), elementType));
+ return cloneWith(getShape(), elementType);
}
BlockTensorDescAttr getEncodingAsBlockTensorDescAttr() const {
@@ -144,7 +146,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
return MemorySpace::Global;
}
- int getArrayLength() {
+ int getArrayLength() const {
auto attr = getEncoding();
auto block_attr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(attr);
assert((!attr || block_attr) && "invalid on non BlockTensorDescAttr.");
@@ -154,7 +156,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
return 1;
}
- bool getBoundaryCheck() {
+ bool getBoundaryCheck() const {
auto attr = getEncoding();
auto block_attr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(attr);
assert((!attr || block_attr) && "invalid on non BlockTensorDescAttr.");
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
index 8aa2c55570153..a26b7f25fcf10 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
@@ -143,21 +143,21 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {
/// Return the number of elements present in the given shape.
static int64_t getNumElements(ArrayRef<int64_t> shape);
+ }];
+ let extraSharedClassDeclaration = [{
/// Return a clone of this type with the given new shape and element type.
/// The returned type is ranked, even if this type is unranked.
auto clone(::llvm::ArrayRef<int64_t> shape, Type elementType) {
- return cloneWith(shape, elementType);
+ return $_type.cloneWith(shape, elementType);
}
/// Return a clone of this type with the given new shape. The returned type
/// is ranked, even if this type is unranked.
auto clone(::llvm::ArrayRef<int64_t> shape) {
- return cloneWith(shape, getElementType());
+ return $_type.cloneWith(shape, $_type.getElementType());
}
- }];
- let extraSharedClassDeclaration = [{
/// Return a clone of this type with the given new element type. The
/// returned type is ranked if and only if this type is ranked. In that
/// case, the returned type has the same shape as this type.
@@ -227,4 +227,64 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {
}];
}
+//===----------------------------------------------------------------------===//
+// TensorTypeInterface
+//===----------------------------------------------------------------------===//
+
+def TensorTypeInterface : TypeInterface<"TensorType", [ShapedTypeInterface]> {
+ let cppNamespace = "::mlir";
+ let description = [{
+ This interface provides a shared interface type for ranked, unranked and any
+ user-specified tensor types.
+
+ This interface attaches the ShapedTypeInterface to act as a mixin to
+ provide many useful utility functions.
+ }];
+
+ let extraClassDeclaration = [{
+ /// Return true if the specified element type is ok in a tensor.
+ static bool isValidElementType(::mlir::Type type);
+ }];
+
+ let extraClassOf = [{
+ return $_type.hasTrait<::mlir::TensorType::Trait>();
+ }];
+}
+
+//===----------------------------------------------------------------------===//
+// BaseMemRefTypeInterface
+//===----------------------------------------------------------------------===//
+
+def BaseMemRefTypeInterface : TypeInterface<"BaseMemRefType", [ShapedTypeInterface]> {
+ let cppNamespace = "::mlir";
+ let description = [{
+ This interface provides a shared interface type for ranked, unranked and any
+ user-specified memref types.
+
+ This interface attaches the ShapedTypeInterface to act as a mixin to
+ provide many useful utility functions.
+ }];
+
+ let methods = [
+ InterfaceMethod<[{
+ Returns the memory space in which data referred to by this memref resides.
+ }],
+ "::mlir::Attribute", "getMemorySpace">,
+ InterfaceMethod<[{
+ [deprecated] Returns the memory space in old raw integer representation.
+ New `Attribute getMemorySpace()` method should be used instead.
+ }],
+ "unsigned", "getMemorySpaceAsInt">,
+ ];
+
+ let extraClassDeclaration = [{
+ /// Return true if the specified element type is ok in a memref.
+ static bool isValidElementType(::mlir::Type type);
+ }];
+
+ let extraClassOf = [{
+ return $_type.hasTrait<::mlir::BaseMemRefType::Trait>();
+ }];
+}
+
#endif // MLIR_IR_BUILTINTYPEINTERFACES_TD_
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index df1e02732617d..4f3365492f720 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -43,108 +43,6 @@ template <typename ConcreteType>
class ValueSemantics
: public TypeTrait::TraitBase<ConcreteType, ValueSemantics> {};
-//===----------------------------------------------------------------------===//
-// TensorType
-//===----------------------------------------------------------------------===//
-
-/// Tensor types represent multi-dimensional arrays, and have two variants:
-/// RankedTensorType and UnrankedTensorType.
-/// Note: This class attaches the ShapedType trait to act as a mixin to
-/// provide many useful utility functions. This inheritance has no effect
-/// on derived tensor types.
-class TensorType : public Type, public ShapedType::Trait<TensorType> {
-public:
- using Type::Type;
-
- /// Returns the element type of this tensor type.
- Type getElementType() const;
-
- /// Returns if this type is ranked, i.e. it has a known number of dimensions.
- bool hasRank() const;
-
- /// Returns the shape of this tensor type.
- ArrayRef<int64_t> getShape() const;
-
- /// Clone this type with the given shape and element type. If the
- /// provided shape is `std::nullopt`, the current shape of the type is used.
- TensorType cloneWith(std::optional<ArrayRef<int64_t>> shape,
- Type elementType) const;
-
- // Make sure that base class overloads are visible.
- using ShapedType::Trait<TensorType>::clone;
-
- /// Return a clone of this type with the given new shape and element type.
- /// The returned type is ranked, even if this type is unranked.
- RankedTensorType clone(ArrayRef<int64_t> shape, Type elementType) const;
-
- /// Return a clone of this type with the given new shape. The returned type
- /// is ranked, even if this type is unranked.
- RankedTensorType clone(ArrayRef<int64_t> shape) const;
-
- /// Return true if the specified element type is ok in a tensor.
- static bool isValidElementType(Type type);
-
- /// Methods for support type inquiry through isa, cast, and dyn_cast.
- static bool classof(Type type);
-
- /// Allow implicit conversion to ShapedType.
- operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
-};
-
-//===----------------------------------------------------------------------===//
-// BaseMemRefType
-//===----------------------------------------------------------------------===//
-
-/// This class provides a shared interface for ranked and unranked memref types.
-/// Note: This class attaches the ShapedType trait to act as a mixin to
-/// provide many useful utility functions. This inheritance has no effect
-/// on derived memref types.
-class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
-public:
- using Type::Type;
-
- /// Returns the element type of this memref type.
- Type getElementType() const;
-
- /// Returns if this type is ranked, i.e. it has a known number of dimensions.
- bool hasRank() const;
-
- /// Returns the shape of this memref type.
- ArrayRef<int64_t> getShape() const;
-
- /// Clone this type with the given shape and element type. If the
- /// provided shape is `std::nullopt`, the current shape of the type is used.
- BaseMemRefType cloneWith(std::optional<ArrayRef<int64_t>> shape,
- Type elementType) const;
-
- // Make sure that base class overloads are visible.
- using ShapedType::Trait<BaseMemRefType>::clone;
-
- /// Return a clone of this type with the given new shape and element type.
- /// The returned type is ranked, even if this type is unranked.
- MemRefType clone(ArrayRef<int64_t> shape, Type elementType) const;
-
- /// Return a clone of this type with the given new shape. The returned type
- /// is ranked, even if this type is unranked.
- MemRefType clone(ArrayRef<int64_t> shape) const;
-
- /// Return true if the specified element type is ok in a memref.
- static bool isValidElementType(Type type);
-
- /// Methods for support type inquiry through isa, cast, and dyn_cast.
- static bool classof(Type type);
-
- /// Returns the memory space in which data referred to by this memref resides.
- Attribute getMemorySpace() const;
-
- /// [deprecated] Returns the memory space in old raw integer representation.
- /// New `Attribute getMemorySpace()` method should be used instead.
- unsigned getMemorySpaceAsInt() const;
-
- /// Allow implicit conversion to ShapedType.
- operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
-};
-
} // namespace mlir
//===----------------------------------------------------------------------===//
@@ -390,10 +288,6 @@ class FixedVectorType : public VectorType {
// Deferred Method Definitions
//===----------------------------------------------------------------------===//
-inline bool BaseMemRefType::classof(Type type) {
- return llvm::isa<MemRefType, UnrankedMemRefType>(type);
-}
-
inline bool BaseMemRefType::isValidElementType(Type type) {
return type.isIntOrIndexOrFloat() ||
llvm::isa<ComplexType, MemRefType, VectorType, UnrankedMemRefType>(
@@ -401,10 +295,6 @@ inline bool BaseMemRefType::isValidElementType(Type type) {
llvm::isa<MemRefElementTypeInterface>(type);
}
-inline bool TensorType::classof(Type type) {
- return llvm::isa<RankedTensorType, UnrankedTensorType>(type);
-}
-
//===----------------------------------------------------------------------===//
// Type Utilities
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index af474b3e3ec47..575ae6a263b1b 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -542,8 +542,8 @@ def Builtin_Integer : Builtin_Type<"Integer", "integer"> {
//===----------------------------------------------------------------------===//
def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [
- ShapedTypeInterface
- ], "BaseMemRefType"> {
+ BaseMemRefTypeInterface
+ ]> {
let summary = "Shaped reference to a region of memory";
let description = [{
Syntax:
@@ -794,7 +794,7 @@ def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [
"unsigned":$memorySpaceInd)>
];
let extraClassDeclaration = [{
- using BaseMemRefType::clone;
+ using ShapedType::Trait<MemRefType>::clone;
using ShapedType::Trait<MemRefType>::getElementTypeBitWidth;
using ShapedType::Trait<MemRefType>::getRank;
using ShapedType::Trait<MemRefType>::getNumElements;
@@ -854,6 +854,13 @@ def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [
/// Return "true" if the last dimension has a static unit stride. Also
/// return "true" for types with no strides.
bool isLastDimUnitStride();
+
+ /// Returns if this type is ranked (always true).
+ bool hasRank() const { return true; }
+
+ /// Returns a clone of this type with the given shape and element
+ /// type. If a shape is not provided, the current shape of the type is used.
+ MemRefType cloneWith(std::optional<ArrayRef<int64_t>> shape, Type elementType) const;
}];
let skipDefaultBuilders = 1;
let genVerifyDecl = 1;
@@ -934,8 +941,8 @@ def Builtin_Opaque : Builtin_Type<"Opaque", "opaque"> {
//===----------------------------------------------------------------------===//
def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "tensor", [
- ShapedTypeInterface, ValueSemantics
- ], "TensorType"> {
+ TensorTypeInterface, ValueSemantics
+ ]> {
let summary = "Multi-dimensional array with a fixed number of dimensions";
let description = [{
Syntax:
@@ -1016,7 +1023,7 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "tensor", [
}]>
];
let extraClassDeclaration = [{
- using TensorType::clone;
+ using ShapedType::Trait<RankedTensorType>::clone;
using ShapedType::Trait<RankedTensorType>::getElementTypeBitWidth;
using ShapedType::Trait<RankedTensorType>::getRank;
using ShapedType::Trait<RankedTensorType>::getNumElements;
@@ -1033,7 +1040,7 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "tensor", [
/// Return a clone of this type with the given new element type and the same
/// shape as this type.
RankedTensorType clone(::mlir::Type elementType) {
- return ::llvm::cast<RankedTensorType>(cloneWith(getShape(), elementType));
+ return cloneWith(getShape(), elementType);
}
/// Return a clone of this type without the encoding.
@@ -1041,6 +1048,13 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "tensor", [
return RankedTensorType::get(getShape(), getElementType());
}
+ /// Returns if this type is ranked (always true).
+ bool hasRank() const { return true; }
+
+ /// Returns a clone of this type with the given shape and element
+ /// type. If a shape is not provided, the current shape of the type is used.
+ RankedTensorType cloneWith(std::optional<ArrayRef<int64_t>> shape, Type elementType) const;
+
/// Return a clone of this type with the given new encoding and the same
/// shape and element type as this type.
RankedTensorType cloneWithEncoding(::mlir::Attribute encoding) {
@@ -1123,8 +1137,8 @@ def Builtin_Tuple : Builtin_Type<"Tuple", "tuple"> {
//===----------------------------------------------------------------------===//
def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", "unranked_memref", [
- ShapedTypeInterface
- ], "BaseMemRefType"> {
+ BaseMemRefTypeInterface
+ ]> {
let summary = "Shaped reference, with unknown rank, to a region of memory";
let description = [{
Syntax:
@@ -1170,7 +1184,7 @@ def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", "unranked_memref", [
}]>
];
let extraClassDeclaration = [{
- using BaseMemRefType::clone;
+ using ShapedType::Trait<UnrankedMemRefType>::clone;
using ShapedType::Trait<UnrankedMemRefType>::getElementTypeBitWidth;
using ShapedType::Trait<UnrankedMemRefType>::getRank;
using ShapedType::Trait<UnrankedMemRefType>::getNumElements;
@@ -1186,11 +1200,12 @@ def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", "unranked_memref", [
/// New `Attribute getMemorySpace()` method should be used instead.
unsigned getMemorySpaceAsInt() const;
- /// Return a clone of this type with the given new element type and the same
- /// shape as this type.
- MemRefType clone(::mlir::Type elementType) {
- return ::llvm::cast<MemRefType>(cloneWith(getShape(), elementType));
- }
+ /// Returns if this type is ranked (always false).
+ bool hasRank() const { return false; }
+
+ /// Returns a clone of this type with the given shape and element
+ /// type. If a shape is not provided, the current shape of the type is used.
+ BaseMemRefType cloneWith(std::optional<ArrayRef<int64_t>> shape, Type elementType) const;
}];
let skipDefaultBuilders = 1;
let genVerifyDecl = 1;
@@ -1201,8 +1216,8 @@ def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", "unranked_memref", [
//===----------------------------------------------------------------------===//
def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
- ShapedTypeInterface, ValueSemantics
- ], "TensorType"> {
+ TensorTypeInterface, ValueSemantics
+ ]> {
let summary = "Multi-dimensional array with unknown dimensions";
let description = [{
Syntax:
@@ -1229,7 +1244,7 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
}]>
];
let extraClassDeclaration = [{
- using TensorType::clone;
+ using ShapedType::Trait<UnrankedTensorType>::clone;
using ShapedType::Trait<UnrankedTensorType>::getElementTypeBitWidth;
using ShapedType::Trait<UnrankedTensorType>::getRank;
using ShapedType::Trait<UnrankedTensorType>::getNumElements;
@@ -1240,6 +1255,13 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
using ShapedType::Trait<UnrankedTensorType>::getDynamicDimIndex;
ArrayRef<int64_t> getShape() const { return std::nullopt; }
+
+ /// Returns if this type is ranked (always false).
+ bool hasRank() const { return false; }
+
+ /// Returns a clone of this type with the given shape and element
+ /// type. If a shape is not provided, the current shape of the type is used.
+ TensorType cloneWith(std::optional<ArrayRef<int64_t>> shape, Type elementType) const;
}];
let skipDefaultBuilders = 1;
let genVerifyDecl = 1;
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
index 5f23a33049f87..1d1bcee8600a8 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -41,7 +41,7 @@ TensorType inferReshapeInputType(TypedValue<TensorType> input,
// 0D tensor. While such construct is not incorrect on its own, bufferization
// cannot properly handle it at the moment, so we avoid it.
SmallVector<int64_t> shape(input.getType().getRank(), 1);
- return input.getType().clone(shape);
+ return mlir::cast<TensorType>(input.getType().clone(shape));
}
// Infer the result type of 'tensor.expand_shape' in the collapse-expand
@@ -51,7 +51,7 @@ TensorType inferReshapeExpandedType(TensorType inputType,
// Special case for 0D output tensor. Note: Watch out when using Type::clone()
// with just '{}', as it will invoke the incorrect overload.
if (newShape.empty())
- return inputType.clone(ArrayRef<int64_t>{});
+ return mlir::cast<TensorType>(inputType.clone(ArrayRef<int64_t>{}));
// Check if the input is static, and if so, get its total size
bool inputIsStatic = inputType.hasStaticShape();
@@ -98,7 +98,7 @@ TensorType inferReshapeExpandedType(TensorType inputType,
assert(!inputIsStatic || resultIsStatic);
// Create result type
- return inputType.clone(resultShape);
+ return mlir::cast<TensorType>(inputType.clone(resultShape));
}
// Infer the result type of 'tensor.collapse_shape' in the collapse-expand
@@ -108,11 +108,11 @@ TensorType inferReshapeCollapsedType(TensorType lhsType, Ten...
[truncated]
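The diff shows each implementing type supplying its own `hasRank` and `cloneWith`. The contract those methods obey can be sketched in a standalone model (hypothetical names and signatures, not the real generated MLIR interface code): `cloneWith` takes an optional shape, falling back to the current shape when none is given, and generic utilities are written only against the interface.

```cpp
#include <cassert>
#include <cstdint>
#include <memory>
#include <vector>

// Illustrative stand-in for the interface contract in the patch.
struct TensorLike {
  virtual ~TensorLike() = default;
  virtual bool hasRank() const = 0;
  virtual std::vector<int64_t> getShape() const = 0;
  // If `shape` is null, reuse the current shape (the std::optional fallback
  // in the real cloneWith signature).
  virtual std::unique_ptr<TensorLike>
  cloneWith(const std::vector<int64_t> *shape) const = 0;
};

struct SimpleRankedTensor : TensorLike {
  std::vector<int64_t> shape;
  explicit SimpleRankedTensor(std::vector<int64_t> s) : shape(std::move(s)) {}
  bool hasRank() const override { return true; }  // ranked: always true
  std::vector<int64_t> getShape() const override { return shape; }
  std::unique_ptr<TensorLike>
  cloneWith(const std::vector<int64_t> *s) const override {
    return std::make_unique<SimpleRankedTensor>(s ? *s : shape);
  }
};

// Generic logic written against the interface: works for any implementer,
// built-in or user-defined, which is the point of the patch.
std::unique_ptr<TensorLike> collapseTo1D(const TensorLike &t) {
  int64_t n = 1;
  for (int64_t d : t.getShape())
    n *= d;
  std::vector<int64_t> flat{n};
  return t.cloneWith(&flat);
}
```

A downstream type like `XeGPU_TensorDesc` plays the role of `SimpleRankedTensor` here: once it implements `hasRank`/`cloneWith` and attaches the interface, utilities such as `collapseTo1D` apply to it unchanged.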
@llvm/pr-subscribers-mlir-gpu
let description = [{
Syntax:
@@ -794,7 +794,7 @@ def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [
"unsigned":$memorySpaceInd)>
];
let extraClassDeclaration = [{
- using BaseMemRefType::clone;
+ using ShapedType::Trait<MemRefType>::clone;
using ShapedType::Trait<MemRefType>::getElementTypeBitWidth;
using ShapedType::Trait<MemRefType>::getRank;
using ShapedType::Trait<MemRefType>::getNumElements;
@@ -854,6 +854,13 @@ def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [
/// Return "true" if the last dimension has a static unit stride. Also
/// return "true" for types with no strides.
bool isLastDimUnitStride();
+
+ /// Returns if this type is ranked (always true).
+ bool hasRank() const { return true; }
+
+ /// Returns a clone of this type with the given shape and element
+ /// type. If a shape is not provided, the current shape of the type is used.
+ MemRefType cloneWith(std::optional<ArrayRef<int64_t>> shape, Type elementType) const;
}];
let skipDefaultBuilders = 1;
let genVerifyDecl = 1;
@@ -934,8 +941,8 @@ def Builtin_Opaque : Builtin_Type<"Opaque", "opaque"> {
//===----------------------------------------------------------------------===//
def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "tensor", [
- ShapedTypeInterface, ValueSemantics
- ], "TensorType"> {
+ TensorTypeInterface, ValueSemantics
+ ]> {
let summary = "Multi-dimensional array with a fixed number of dimensions";
let description = [{
Syntax:
@@ -1016,7 +1023,7 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "tensor", [
}]>
];
let extraClassDeclaration = [{
- using TensorType::clone;
+ using ShapedType::Trait<RankedTensorType>::clone;
using ShapedType::Trait<RankedTensorType>::getElementTypeBitWidth;
using ShapedType::Trait<RankedTensorType>::getRank;
using ShapedType::Trait<RankedTensorType>::getNumElements;
@@ -1033,7 +1040,7 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "tensor", [
/// Return a clone of this type with the given new element type and the same
/// shape as this type.
RankedTensorType clone(::mlir::Type elementType) {
- return ::llvm::cast<RankedTensorType>(cloneWith(getShape(), elementType));
+ return cloneWith(getShape(), elementType);
}
/// Return a clone of this type without the encoding.
@@ -1041,6 +1048,13 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "tensor", [
return RankedTensorType::get(getShape(), getElementType());
}
+ /// Returns if this type is ranked (always true).
+ bool hasRank() const { return true; }
+
+ /// Returns a clone of this type with the given shape and element
+ /// type. If a shape is not provided, the current shape of the type is used.
+ RankedTensorType cloneWith(std::optional<ArrayRef<int64_t>> shape, Type elementType) const;
+
/// Return a clone of this type with the given new encoding and the same
/// shape and element type as this type.
RankedTensorType cloneWithEncoding(::mlir::Attribute encoding) {
@@ -1123,8 +1137,8 @@ def Builtin_Tuple : Builtin_Type<"Tuple", "tuple"> {
//===----------------------------------------------------------------------===//
def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", "unranked_memref", [
- ShapedTypeInterface
- ], "BaseMemRefType"> {
+ BaseMemRefTypeInterface
+ ]> {
let summary = "Shaped reference, with unknown rank, to a region of memory";
let description = [{
Syntax:
@@ -1170,7 +1184,7 @@ def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", "unranked_memref", [
}]>
];
let extraClassDeclaration = [{
- using BaseMemRefType::clone;
+ using ShapedType::Trait<UnrankedMemRefType>::clone;
using ShapedType::Trait<UnrankedMemRefType>::getElementTypeBitWidth;
using ShapedType::Trait<UnrankedMemRefType>::getRank;
using ShapedType::Trait<UnrankedMemRefType>::getNumElements;
@@ -1186,11 +1200,12 @@ def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", "unranked_memref", [
/// New `Attribute getMemorySpace()` method should be used instead.
unsigned getMemorySpaceAsInt() const;
- /// Return a clone of this type with the given new element type and the same
- /// shape as this type.
- MemRefType clone(::mlir::Type elementType) {
- return ::llvm::cast<MemRefType>(cloneWith(getShape(), elementType));
- }
+ /// Returns if this type is ranked (always false).
+ bool hasRank() const { return false; }
+
+ /// Returns a clone of this type with the given shape and element
+ /// type. If a shape is not provided, the current shape of the type is used.
+ BaseMemRefType cloneWith(std::optional<ArrayRef<int64_t>> shape, Type elementType) const;
}];
let skipDefaultBuilders = 1;
let genVerifyDecl = 1;
@@ -1201,8 +1216,8 @@ def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", "unranked_memref", [
//===----------------------------------------------------------------------===//
def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
- ShapedTypeInterface, ValueSemantics
- ], "TensorType"> {
+ TensorTypeInterface, ValueSemantics
+ ]> {
let summary = "Multi-dimensional array with unknown dimensions";
let description = [{
Syntax:
@@ -1229,7 +1244,7 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
}]>
];
let extraClassDeclaration = [{
- using TensorType::clone;
+ using ShapedType::Trait<UnrankedTensorType>::clone;
using ShapedType::Trait<UnrankedTensorType>::getElementTypeBitWidth;
using ShapedType::Trait<UnrankedTensorType>::getRank;
using ShapedType::Trait<UnrankedTensorType>::getNumElements;
@@ -1240,6 +1255,13 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
using ShapedType::Trait<UnrankedTensorType>::getDynamicDimIndex;
ArrayRef<int64_t> getShape() const { return std::nullopt; }
+
+ /// Returns if this type is ranked (always false).
+ bool hasRank() const { return false; }
+
+ /// Returns a clone of this type with the given shape and element
+ /// type. If a shape is not provided, the current shape of the type is used.
+ TensorType cloneWith(std::optional<ArrayRef<int64_t>> shape, Type elementType) const;
}];
let skipDefaultBuilders = 1;
let genVerifyDecl = 1;
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
index 5f23a33049f87..1d1bcee8600a8 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -41,7 +41,7 @@ TensorType inferReshapeInputType(TypedValue<TensorType> input,
// 0D tensor. While such construct is not incorrect on its own, bufferization
// cannot properly handle it at the moment, so we avoid it.
SmallVector<int64_t> shape(input.getType().getRank(), 1);
- return input.getType().clone(shape);
+ return mlir::cast<TensorType>(input.getType().clone(shape));
}
// Infer the result type of 'tensor.expand_shape' in the collapse-expand
@@ -51,7 +51,7 @@ TensorType inferReshapeExpandedType(TensorType inputType,
// Special case for 0D output tensor. Note: Watch out when using Type::clone()
// with just '{}', as it will invoke the incorrect overload.
if (newShape.empty())
- return inputType.clone(ArrayRef<int64_t>{});
+ return mlir::cast<TensorType>(inputType.clone(ArrayRef<int64_t>{}));
// Check if the input is static, and if so, get its total size
bool inputIsStatic = inputType.hasStaticShape();
@@ -98,7 +98,7 @@ TensorType inferReshapeExpandedType(TensorType inputType,
assert(!inputIsStatic || resultIsStatic);
// Create result type
- return inputType.clone(resultShape);
+ return mlir::cast<TensorType>(inputType.clone(resultShape));
}
// Infer the result type of 'tensor.collapse_shape' in the collapse-expand
@@ -108,11 +108,11 @@ TensorType inferReshapeCollapsedType(TensorType lhsType, Ten...
[truncated]
@llvm/pr-subscribers-mlir Author: Andrei Golubev (andrey-golubev)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index ccd91a928e1dd..248ef9f855b14 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -31,7 +31,7 @@ class XeGPUTypeDef<string name, string typeMnemonic, list<Trait> traits = [],
}
def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
- [ShapedTypeInterface], "::mlir::TensorType"> {
+ [TensorTypeInterface]> {
let summary = "TensorDesc describing regions of interested data.";
let description = [{
TensorDesc is a type designed to describe regions of the interested data as well as some
@@ -105,7 +105,6 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
];
let extraClassDeclaration = [{
- using TensorType::clone;
using mlir::ShapedType::Trait<TensorDescType>::getElementTypeBitWidth;
using mlir::ShapedType::Trait<TensorDescType>::getRank;
using mlir::ShapedType::Trait<TensorDescType>::getNumElements;
@@ -115,8 +114,11 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
using mlir::ShapedType::Trait<TensorDescType>::getDimSize;
using mlir::ShapedType::Trait<TensorDescType>::getDynamicDimIndex;
+ TensorDescType cloneWith(std::optional<ArrayRef<int64_t>> shape, Type elementType) const;
+ bool hasRank() const { return true; }
+
TensorDescType clone(::mlir::Type elementType) {
- return llvm::cast<TensorDescType>(cloneWith(getShape(), elementType));
+ return cloneWith(getShape(), elementType);
}
BlockTensorDescAttr getEncodingAsBlockTensorDescAttr() const {
@@ -144,7 +146,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
return MemorySpace::Global;
}
- int getArrayLength() {
+ int getArrayLength() const {
auto attr = getEncoding();
auto block_attr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(attr);
assert((!attr || block_attr) && "invalid on non BlockTensorDescAttr.");
@@ -154,7 +156,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
return 1;
}
- bool getBoundaryCheck() {
+ bool getBoundaryCheck() const {
auto attr = getEncoding();
auto block_attr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(attr);
assert((!attr || block_attr) && "invalid on non BlockTensorDescAttr.");
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
index 8aa2c55570153..a26b7f25fcf10 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
@@ -143,21 +143,21 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {
/// Return the number of elements present in the given shape.
static int64_t getNumElements(ArrayRef<int64_t> shape);
+ }];
+ let extraSharedClassDeclaration = [{
/// Return a clone of this type with the given new shape and element type.
/// The returned type is ranked, even if this type is unranked.
auto clone(::llvm::ArrayRef<int64_t> shape, Type elementType) {
- return cloneWith(shape, elementType);
+ return $_type.cloneWith(shape, elementType);
}
/// Return a clone of this type with the given new shape. The returned type
/// is ranked, even if this type is unranked.
auto clone(::llvm::ArrayRef<int64_t> shape) {
- return cloneWith(shape, getElementType());
+ return $_type.cloneWith(shape, $_type.getElementType());
}
- }];
- let extraSharedClassDeclaration = [{
/// Return a clone of this type with the given new element type. The
/// returned type is ranked if and only if this type is ranked. In that
/// case, the returned type has the same shape as this type.
@@ -227,4 +227,64 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {
}];
}
+//===----------------------------------------------------------------------===//
+// TensorTypeInterface
+//===----------------------------------------------------------------------===//
+
+def TensorTypeInterface : TypeInterface<"TensorType", [ShapedTypeInterface]> {
+ let cppNamespace = "::mlir";
+ let description = [{
+ This interface provides a shared interface type for ranked, unranked and any
+ user-specified tensor types.
+
+ This interface attaches the ShapedTypeInterface to act as a mixin to
+ provide many useful utility functions.
+ }];
+
+ let extraClassDeclaration = [{
+ /// Return true if the specified element type is ok in a tensor.
+ static bool isValidElementType(::mlir::Type type);
+ }];
+
+ let extraClassOf = [{
+ return $_type.hasTrait<::mlir::TensorType::Trait>();
+ }];
+}
+
+//===----------------------------------------------------------------------===//
+// BaseMemRefTypeInterface
+//===----------------------------------------------------------------------===//
+
+def BaseMemRefTypeInterface : TypeInterface<"BaseMemRefType", [ShapedTypeInterface]> {
+ let cppNamespace = "::mlir";
+ let description = [{
+ This interface provides a shared interface type for ranked, unranked and any
+ user-specified memref types.
+
+ This interface attaches the ShapedTypeInterface to act as a mixin to
+ provide many useful utility functions.
+ }];
+
+ let methods = [
+ InterfaceMethod<[{
+ Returns the memory space in which data referred to by this memref resides.
+ }],
+ "::mlir::Attribute", "getMemorySpace">,
+ InterfaceMethod<[{
+ [deprecated] Returns the memory space in old raw integer representation.
+ New `Attribute getMemorySpace()` method should be used instead.
+ }],
+ "unsigned", "getMemorySpaceAsInt">,
+ ];
+
+ let extraClassDeclaration = [{
+ /// Return true if the specified element type is ok in a memref.
+ static bool isValidElementType(::mlir::Type type);
+ }];
+
+ let extraClassOf = [{
+ return $_type.hasTrait<::mlir::BaseMemRefType::Trait>();
+ }];
+}
+
#endif // MLIR_IR_BUILTINTYPEINTERFACES_TD_
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index df1e02732617d..4f3365492f720 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -43,108 +43,6 @@ template <typename ConcreteType>
class ValueSemantics
: public TypeTrait::TraitBase<ConcreteType, ValueSemantics> {};
-//===----------------------------------------------------------------------===//
-// TensorType
-//===----------------------------------------------------------------------===//
-
-/// Tensor types represent multi-dimensional arrays, and have two variants:
-/// RankedTensorType and UnrankedTensorType.
-/// Note: This class attaches the ShapedType trait to act as a mixin to
-/// provide many useful utility functions. This inheritance has no effect
-/// on derived tensor types.
-class TensorType : public Type, public ShapedType::Trait<TensorType> {
-public:
- using Type::Type;
-
- /// Returns the element type of this tensor type.
- Type getElementType() const;
-
- /// Returns if this type is ranked, i.e. it has a known number of dimensions.
- bool hasRank() const;
-
- /// Returns the shape of this tensor type.
- ArrayRef<int64_t> getShape() const;
-
- /// Clone this type with the given shape and element type. If the
- /// provided shape is `std::nullopt`, the current shape of the type is used.
- TensorType cloneWith(std::optional<ArrayRef<int64_t>> shape,
- Type elementType) const;
-
- // Make sure that base class overloads are visible.
- using ShapedType::Trait<TensorType>::clone;
-
- /// Return a clone of this type with the given new shape and element type.
- /// The returned type is ranked, even if this type is unranked.
- RankedTensorType clone(ArrayRef<int64_t> shape, Type elementType) const;
-
- /// Return a clone of this type with the given new shape. The returned type
- /// is ranked, even if this type is unranked.
- RankedTensorType clone(ArrayRef<int64_t> shape) const;
-
- /// Return true if the specified element type is ok in a tensor.
- static bool isValidElementType(Type type);
-
- /// Methods for support type inquiry through isa, cast, and dyn_cast.
- static bool classof(Type type);
-
- /// Allow implicit conversion to ShapedType.
- operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
-};
[truncated]
@River707 @joker-eph I'd also likely need some feedback from you guys!
I'm not sure whether the tests are sufficient (please say if you think more are needed!). Added the most obvious things for now, as tensors / memrefs are incidentally tested "everywhere".
[deprecated] Returns the memory space in old raw integer representation.
New `Attribute getMemorySpace()` method should be used instead.
}],
"unsigned", "getMemorySpaceAsInt">,
note to reviewers: this method does still exist; however, it has been marked deprecated for a while.
@@ -41,7 +41,7 @@ TensorType inferReshapeInputType(TypedValue<TensorType> input,
// 0D tensor. While such construct is not incorrect on its own, bufferization
// cannot properly handle it at the moment, so we avoid it.
SmallVector<int64_t> shape(input.getType().getRank(), 1);
return input.getType().clone(shape);
return mlir::cast<TensorType>(input.getType().clone(shape));
note to reviewers: I see some types "shadow" the `clone` method from ShapedTypeInterface by providing their own implementation with a different return type. Unfortunately, it didn't seem possible to do the same for TensorType / BaseMemRefType now that they are interfaces. I guess this could be a reasonable inconvenience?
I think that's acceptable. We have a similar issue with `Operation::clone`.
Nice cleanup! It always felt a bit odd that `class TensorType` and `class BaseMemRefType` were defined in C++ instead of Tablegen. We didn't have interface inheritance from the beginning, maybe that's the reason why it was designed like that.
  let extraSharedClassDeclaration = [{
Why is it necessary to split this into `extraClassDeclaration` and `extraSharedClassDeclaration`?
I don't remember anymore (this patch is actually ~a year old at this point :) - I never had time to brush it up). I think this has to be like this because of the reliance on `$_type` (see https://mlir.llvm.org/docs/Interfaces/) to implement `cloneWith` (so `ShapedType::clone()` would, behind the scenes, call the derived type's methods).
    static bool isValidElementType(::mlir::Type type);
  }];

  let extraClassOf = [{
I'm not familiar with `extraClassOf`. What does it do?
Afaik, this is to generate `{TensorType, BaseMemRefType}::classof()`. I believe this enables the `mlir::isa<TensorType>(myTensor)` machinery.
@@ -41,7 +41,7 @@ TensorType inferReshapeInputType(TypedValue<TensorType> input,
   // 0D tensor. While such construct is not incorrect on its own, bufferization
   // cannot properly handle it at the moment, so we avoid it.
   SmallVector<int64_t> shape(input.getType().getRank(), 1);
-  return input.getType().clone(shape);
+  return mlir::cast<TensorType>(input.getType().clone(shape));
Just to clarify: Is the cast necessary because `clone` is now calling the `ShapedType` implementation, which returns a `ShapedType`?
correct
@@ -77,9 +77,9 @@ struct CastOpInterface
   // Case 3: Ranked tensor -> ranked tensor. The offsets and strides do not
   // change.
   auto rankedResultType = cast<RankedTensorType>(castOp.getType());
-  return MemRefType::get(
+  return llvm::cast<BaseMemRefType>(MemRefType::get(
Why is the cast necessary here? (And in other places.) I expected that `MemRefType` inherits from `BaseMemRefType` (talking about the auto-generated C++ classes). Is that not the case? (Same how you can assign a `MemRefType` to a `ShapedType` variable, I think...)
I think this stems from BaseMemRefType / TensorType now being interfaces (tl;dr: the C++ doesn't compile without it).
My hand-wavy explanation: MemRefType inherits BaseMemRefType::Trait and not BaseMemRefType directly, so from a pure C++ standpoint we cannot implicitly upcast here. This still used to work some time back, around LLVM 16-17 I believe, but since then there was some other cleanup that prevented casts from "X" to "X interface", afair.
That's a bit odd because `class TensorType` used to be defined as `class TensorType : public Type, public ShapedType::Trait<TensorType>` and it was possible to assign a `TensorType` to a `ShapedType` variable. It looks like it is the same setup here.
Good point! I think this was due to the explicit conversion operator (`operator ShapedType() const { return llvm::cast<ShapedType>(*this); }` + there's a similar one for base memref as well).
I guess depending on what our long-term goal is, I may add a similar thing for the derived types.
Actually, I just noticed I'm missing at least the {TensorType, BaseMemRefType} -> ShapedType operators (but maybe this is no longer as important - again due to the "indirect" inheritance).
Can we add this to the `extra(Shared?)ClassDeclaration` section of `MemRefType`, `UnrankedMemRefType`, `RankedTensorType`, `UnrankedTensorType`? (Converting to the new interfaces.) Then it's really almost NFC for existing code.
Let me try! Somehow I didn't think of `extra(Shared?)ClassDeclaration` but it might actually be rather nice if it works.
Do you think we need `operator ShapedType() const { return llvm::cast<ShapedType>(*this); }` as well? (I mean, maybe just for the sake of keeping source compatibility.)
No strong opinion about `ShapedType`. Sure, why not...
Well, turns out there's no free lunch here...
So even with the operators in place, we get issues when doing `MemRefType` -> `FailureOr<BaseMemRefType>`. I vaguely remember some corner of the C++ standard saying something along the lines of "one cannot have two consecutive user-defined conversions in an expression", which I guess is what happens here: there's a conversion from MemRefType to BaseMemRefType (user-defined) and then a call to the FailureOr ctor (also user-defined).
I guess for the case in question to work, we'd also need an extra FailureOr ctor (likely SFINAE-guarded) that does roughly `template<typename U> FailureOr(const U& u) : FailureOr(T(u)) {}`, but I'm not sure it's a good idea to extend this PR for that; I'd rather have it in the follow-up (it may turn out there's some other problem once this is done...).
So: for now, I just switched from `mlir::cast<X>(...)` to `X(...)` + extended the tests to verify that `X x = y;` works.
Makes sense. Then there's nothing we can do about `FailureOr<...>` etc.
Thank you! I believe you're right: interfaces appeared later, and at that point memrefs and tensors were already fully functional. The same was done for ShapedTypeInterface - at some earlier point it was also a base class. Looking at

  inline bool BaseMemRefType::classof(Type type) {
    return llvm::isa<MemRefType, UnrankedMemRefType>(type);
  }
  inline bool TensorType::classof(Type type) {
    return llvm::isa<RankedTensorType, UnrankedTensorType>(type);
  }

kind of suggests both types ought to be interfaces.
Looks good to me, but let's wait 1-2 more days to give others a chance to take a look.
That deserves an RFC on Discourse IMO; this isn't a trivial change to the builtin infra.
We didn't have ODS for types at the beginning, we didn't even have Type Interfaces either... Everything was C++!
Fair enough. Created: https://discourse.llvm.org/t/rfc-changing-base-types-for-tensors-and-memrefs-from-c-base-classes-to-type-interfaces/85509
Superseded by an alternative solution done in #134220
Existing design assumes "TensorType" is only a built-in (un)ranked tensor and "BaseMemRefType" is only a built-in (un)ranked memref. This means that the generic logic operating on "tensors" and "memrefs" is limited to just built-ins, no compatible user types allowed. For instance, this becomes important in one-shot bufferization when converting "user tensor" to "user memref" via the common infrastructure.
Remove this behaviour - that seems accidental - by following the footsteps of ShapedType (see 676bfb2). As with ShapedType, "tensor" and "memref" seem to always aspire to be interfaces.