Skip to content

[mlir] Convert TensorType and BaseMemRefType to interfaces #133053

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Conversation

andrey-golubev
Copy link
Contributor

Existing design assumes "TensorType" is only a built-in (un)ranked tensor and "BaseMemRefType" is only a built-in (un)ranked memref. This means that the generic logic operating on "tensors" and "memrefs" is limited to just built-ins, no compatible user types allowed. For instance, this becomes important in one-shot bufferization when converting "user tensor" to "user memref" via the common infrastructure.

Remove this behaviour - that seems accidental - by following the footsteps of ShapedType (see 676bfb2). As with ShapedType, "tensor" and "memref" seem to always aspire to be interfaces.

Existing design assumes "TensorType" is only a built-in (un)ranked
tensor and "BaseMemRefType" is only a built-in (un)ranked memref. This
means that the generic logic operating on "tensors" and "memrefs" is
limited to just built-ins, no compatible user types allowed. For
instance, this becomes important in one-shot bufferization when
converting "user tensor" to "user memref" via the common infrastructure.

Remove this behaviour - that seems accidental - by following the
footsteps of ShapedType (see 676bfb2).
As with ShapedType, "tensor" and "memref" seem to always aspire to be
interfaces.
@llvmbot
Copy link
Member

llvmbot commented Mar 26, 2025

@llvm/pr-subscribers-mlir-tosa
@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir-tensor

Author: Andrei Golubev (andrey-golubev)

Changes

Existing design assumes "TensorType" is only a built-in (un)ranked tensor and "BaseMemRefType" is only a built-in (un)ranked memref. This means that the generic logic operating on "tensors" and "memrefs" is limited to just built-ins, no compatible user types allowed. For instance, this becomes important in one-shot bufferization when converting "user tensor" to "user memref" via the common infrastructure.

Remove this behaviour - that seems accidental - by following the footsteps of ShapedType (see 676bfb2). As with ShapedType, "tensor" and "memref" seem to always aspire to be interfaces.


Patch is 36.19 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/133053.diff

11 Files Affected:

  • (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td (+7-5)
  • (modified) mlir/include/mlir/IR/BuiltinTypeInterfaces.td (+64-4)
  • (modified) mlir/include/mlir/IR/BuiltinTypes.h (-110)
  • (modified) mlir/include/mlir/IR/BuiltinTypes.td (+40-18)
  • (modified) mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp (+6-6)
  • (modified) mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp (+1-2)
  • (modified) mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp (+15-12)
  • (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp (+7)
  • (modified) mlir/lib/IR/BuiltinTypes.cpp (+30-91)
  • (modified) mlir/test/lib/Dialect/Test/TestTypeDefs.td (+46)
  • (modified) mlir/unittests/IR/InterfaceTest.cpp (+34)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index ccd91a928e1dd..248ef9f855b14 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -31,7 +31,7 @@ class XeGPUTypeDef<string name, string typeMnemonic, list<Trait> traits = [],
 }
 
 def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
-        [ShapedTypeInterface], "::mlir::TensorType"> {
+        [TensorTypeInterface]> {
   let summary = "TensorDesc describing regions of interested data.";
   let description = [{
     TensorDesc is a type designed to describe regions of the interested data as well as some
@@ -105,7 +105,6 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
   ];
 
   let extraClassDeclaration = [{
-    using TensorType::clone;
     using mlir::ShapedType::Trait<TensorDescType>::getElementTypeBitWidth;
     using mlir::ShapedType::Trait<TensorDescType>::getRank;
     using mlir::ShapedType::Trait<TensorDescType>::getNumElements;
@@ -115,8 +114,11 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
     using mlir::ShapedType::Trait<TensorDescType>::getDimSize;
     using mlir::ShapedType::Trait<TensorDescType>::getDynamicDimIndex;
 
+    TensorDescType cloneWith(std::optional<ArrayRef<int64_t>> shape, Type elementType) const;
+    bool hasRank() const { return true; }
+
     TensorDescType clone(::mlir::Type elementType) {
-      return llvm::cast<TensorDescType>(cloneWith(getShape(), elementType));
+      return cloneWith(getShape(), elementType);
     }
 
     BlockTensorDescAttr getEncodingAsBlockTensorDescAttr() const {
@@ -144,7 +146,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
       return MemorySpace::Global;
     }
 
-    int getArrayLength() {
+    int getArrayLength() const {
       auto attr = getEncoding();
       auto block_attr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(attr);
       assert((!attr || block_attr) && "invalid on non BlockTensorDescAttr.");
@@ -154,7 +156,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
       return 1;
     }
 
-    bool getBoundaryCheck() {
+    bool getBoundaryCheck() const {
       auto attr = getEncoding();
       auto block_attr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(attr);
       assert((!attr || block_attr) && "invalid on non BlockTensorDescAttr.");
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
index 8aa2c55570153..a26b7f25fcf10 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
@@ -143,21 +143,21 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {
 
     /// Return the number of elements present in the given shape.
     static int64_t getNumElements(ArrayRef<int64_t> shape);
+  }];
 
+  let extraSharedClassDeclaration = [{
     /// Return a clone of this type with the given new shape and element type.
     /// The returned type is ranked, even if this type is unranked.
     auto clone(::llvm::ArrayRef<int64_t> shape, Type elementType) {
-      return cloneWith(shape, elementType);
+      return $_type.cloneWith(shape, elementType);
     }
 
     /// Return a clone of this type with the given new shape. The returned type
     /// is ranked, even if this type is unranked.
     auto clone(::llvm::ArrayRef<int64_t> shape) {
-      return cloneWith(shape, getElementType());
+      return $_type.cloneWith(shape, $_type.getElementType());
     }
-  }];
 
-  let extraSharedClassDeclaration = [{
     /// Return a clone of this type with the given new element type. The
     /// returned type is ranked if and only if this type is ranked. In that
     /// case, the returned type has the same shape as this type.
@@ -227,4 +227,64 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// TensorTypeInterface
+//===----------------------------------------------------------------------===//
+
+def TensorTypeInterface : TypeInterface<"TensorType", [ShapedTypeInterface]> {
+  let cppNamespace = "::mlir";
+  let description = [{
+    This interface provides a shared interface type for ranked, unranked and any
+    user-specified tensor types.
+
+    This interface attaches the ShapedTypeInterface to act as a mixin to
+    provide many useful utility functions.
+  }];
+
+  let extraClassDeclaration = [{
+    /// Return true if the specified element type is ok in a tensor.
+    static bool isValidElementType(::mlir::Type type);
+  }];
+
+  let extraClassOf = [{
+    return $_type.hasTrait<::mlir::TensorType::Trait>();
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// BaseMemRefTypeInterface
+//===----------------------------------------------------------------------===//
+
+def BaseMemRefTypeInterface : TypeInterface<"BaseMemRefType", [ShapedTypeInterface]> {
+  let cppNamespace = "::mlir";
+  let description = [{
+    This interface provides a shared interface type for ranked, unranked and any
+    user-specified memref types.
+
+    This interface attaches the ShapedTypeInterface to act as a mixin to
+    provide many useful utility functions.
+  }];
+
+  let methods = [
+    InterfaceMethod<[{
+      Returns the memory space in which data referred to by this memref resides.
+    }],
+    "::mlir::Attribute", "getMemorySpace">,
+    InterfaceMethod<[{
+      [deprecated] Returns the memory space in old raw integer representation.
+      New `Attribute getMemorySpace()` method should be used instead.
+    }],
+    "unsigned", "getMemorySpaceAsInt">,
+  ];
+
+  let extraClassDeclaration = [{
+    /// Return true if the specified element type is ok in a memref.
+    static bool isValidElementType(::mlir::Type type);
+  }];
+
+  let extraClassOf = [{
+    return $_type.hasTrait<::mlir::BaseMemRefType::Trait>();
+  }];
+}
+
 #endif // MLIR_IR_BUILTINTYPEINTERFACES_TD_
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index df1e02732617d..4f3365492f720 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -43,108 +43,6 @@ template <typename ConcreteType>
 class ValueSemantics
     : public TypeTrait::TraitBase<ConcreteType, ValueSemantics> {};
 
-//===----------------------------------------------------------------------===//
-// TensorType
-//===----------------------------------------------------------------------===//
-
-/// Tensor types represent multi-dimensional arrays, and have two variants:
-/// RankedTensorType and UnrankedTensorType.
-/// Note: This class attaches the ShapedType trait to act as a mixin to
-///       provide many useful utility functions. This inheritance has no effect
-///       on derived tensor types.
-class TensorType : public Type, public ShapedType::Trait<TensorType> {
-public:
-  using Type::Type;
-
-  /// Returns the element type of this tensor type.
-  Type getElementType() const;
-
-  /// Returns if this type is ranked, i.e. it has a known number of dimensions.
-  bool hasRank() const;
-
-  /// Returns the shape of this tensor type.
-  ArrayRef<int64_t> getShape() const;
-
-  /// Clone this type with the given shape and element type. If the
-  /// provided shape is `std::nullopt`, the current shape of the type is used.
-  TensorType cloneWith(std::optional<ArrayRef<int64_t>> shape,
-                       Type elementType) const;
-
-  // Make sure that base class overloads are visible.
-  using ShapedType::Trait<TensorType>::clone;
-
-  /// Return a clone of this type with the given new shape and element type.
-  /// The returned type is ranked, even if this type is unranked.
-  RankedTensorType clone(ArrayRef<int64_t> shape, Type elementType) const;
-
-  /// Return a clone of this type with the given new shape. The returned type
-  /// is ranked, even if this type is unranked.
-  RankedTensorType clone(ArrayRef<int64_t> shape) const;
-
-  /// Return true if the specified element type is ok in a tensor.
-  static bool isValidElementType(Type type);
-
-  /// Methods for support type inquiry through isa, cast, and dyn_cast.
-  static bool classof(Type type);
-
-  /// Allow implicit conversion to ShapedType.
-  operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
-};
-
-//===----------------------------------------------------------------------===//
-// BaseMemRefType
-//===----------------------------------------------------------------------===//
-
-/// This class provides a shared interface for ranked and unranked memref types.
-/// Note: This class attaches the ShapedType trait to act as a mixin to
-///       provide many useful utility functions. This inheritance has no effect
-///       on derived memref types.
-class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
-public:
-  using Type::Type;
-
-  /// Returns the element type of this memref type.
-  Type getElementType() const;
-
-  /// Returns if this type is ranked, i.e. it has a known number of dimensions.
-  bool hasRank() const;
-
-  /// Returns the shape of this memref type.
-  ArrayRef<int64_t> getShape() const;
-
-  /// Clone this type with the given shape and element type. If the
-  /// provided shape is `std::nullopt`, the current shape of the type is used.
-  BaseMemRefType cloneWith(std::optional<ArrayRef<int64_t>> shape,
-                           Type elementType) const;
-
-  // Make sure that base class overloads are visible.
-  using ShapedType::Trait<BaseMemRefType>::clone;
-
-  /// Return a clone of this type with the given new shape and element type.
-  /// The returned type is ranked, even if this type is unranked.
-  MemRefType clone(ArrayRef<int64_t> shape, Type elementType) const;
-
-  /// Return a clone of this type with the given new shape. The returned type
-  /// is ranked, even if this type is unranked.
-  MemRefType clone(ArrayRef<int64_t> shape) const;
-
-  /// Return true if the specified element type is ok in a memref.
-  static bool isValidElementType(Type type);
-
-  /// Methods for support type inquiry through isa, cast, and dyn_cast.
-  static bool classof(Type type);
-
-  /// Returns the memory space in which data referred to by this memref resides.
-  Attribute getMemorySpace() const;
-
-  /// [deprecated] Returns the memory space in old raw integer representation.
-  /// New `Attribute getMemorySpace()` method should be used instead.
-  unsigned getMemorySpaceAsInt() const;
-
-  /// Allow implicit conversion to ShapedType.
-  operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
-};
-
 } // namespace mlir
 
 //===----------------------------------------------------------------------===//
@@ -390,10 +288,6 @@ class FixedVectorType : public VectorType {
 // Deferred Method Definitions
 //===----------------------------------------------------------------------===//
 
-inline bool BaseMemRefType::classof(Type type) {
-  return llvm::isa<MemRefType, UnrankedMemRefType>(type);
-}
-
 inline bool BaseMemRefType::isValidElementType(Type type) {
   return type.isIntOrIndexOrFloat() ||
          llvm::isa<ComplexType, MemRefType, VectorType, UnrankedMemRefType>(
@@ -401,10 +295,6 @@ inline bool BaseMemRefType::isValidElementType(Type type) {
          llvm::isa<MemRefElementTypeInterface>(type);
 }
 
-inline bool TensorType::classof(Type type) {
-  return llvm::isa<RankedTensorType, UnrankedTensorType>(type);
-}
-
 //===----------------------------------------------------------------------===//
 // Type Utilities
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index af474b3e3ec47..575ae6a263b1b 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -542,8 +542,8 @@ def Builtin_Integer : Builtin_Type<"Integer", "integer"> {
 //===----------------------------------------------------------------------===//
 
 def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [
-    ShapedTypeInterface
-  ], "BaseMemRefType"> {
+    BaseMemRefTypeInterface
+  ]> {
   let summary = "Shaped reference to a region of memory";
   let description = [{
     Syntax:
@@ -794,7 +794,7 @@ def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [
       "unsigned":$memorySpaceInd)>
   ];
   let extraClassDeclaration = [{
-    using BaseMemRefType::clone;
+    using ShapedType::Trait<MemRefType>::clone;
     using ShapedType::Trait<MemRefType>::getElementTypeBitWidth;
     using ShapedType::Trait<MemRefType>::getRank;
     using ShapedType::Trait<MemRefType>::getNumElements;
@@ -854,6 +854,13 @@ def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [
     /// Return "true" if the last dimension has a static unit stride. Also
     /// return "true" for types with no strides.
     bool isLastDimUnitStride();
+
+    /// Returns if this type is ranked (always true).
+    bool hasRank() const { return true; }
+
+    /// Returns a clone of this type with the given shape and element
+    /// type. If a shape is not provided, the current shape of the type is used.
+    MemRefType cloneWith(std::optional<ArrayRef<int64_t>> shape, Type elementType) const;
   }];
   let skipDefaultBuilders = 1;
   let genVerifyDecl = 1;
@@ -934,8 +941,8 @@ def Builtin_Opaque : Builtin_Type<"Opaque", "opaque"> {
 //===----------------------------------------------------------------------===//
 
 def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "tensor", [
-    ShapedTypeInterface, ValueSemantics
-  ], "TensorType"> {
+    TensorTypeInterface, ValueSemantics
+  ]> {
   let summary = "Multi-dimensional array with a fixed number of dimensions";
   let description = [{
     Syntax:
@@ -1016,7 +1023,7 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "tensor", [
     }]>
   ];
   let extraClassDeclaration = [{
-    using TensorType::clone;
+    using ShapedType::Trait<RankedTensorType>::clone;
     using ShapedType::Trait<RankedTensorType>::getElementTypeBitWidth;
     using ShapedType::Trait<RankedTensorType>::getRank;
     using ShapedType::Trait<RankedTensorType>::getNumElements;
@@ -1033,7 +1040,7 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "tensor", [
     /// Return a clone of this type with the given new element type and the same
     /// shape as this type.
     RankedTensorType clone(::mlir::Type elementType) {
-      return ::llvm::cast<RankedTensorType>(cloneWith(getShape(), elementType));
+      return cloneWith(getShape(), elementType);
     }
 
     /// Return a clone of this type without the encoding.
@@ -1041,6 +1048,13 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "tensor", [
       return RankedTensorType::get(getShape(), getElementType());
     }
 
+    /// Returns if this type is ranked (always true).
+    bool hasRank() const { return true; }
+
+    /// Returns a clone of this type with the given shape and element
+    /// type. If a shape is not provided, the current shape of the type is used.
+    RankedTensorType cloneWith(std::optional<ArrayRef<int64_t>> shape, Type elementType) const;
+
     /// Return a clone of this type with the given new encoding and the same
     /// shape and element type as this type.
     RankedTensorType cloneWithEncoding(::mlir::Attribute encoding) {
@@ -1123,8 +1137,8 @@ def Builtin_Tuple : Builtin_Type<"Tuple", "tuple"> {
 //===----------------------------------------------------------------------===//
 
 def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", "unranked_memref", [
-    ShapedTypeInterface
-  ], "BaseMemRefType"> {
+    BaseMemRefTypeInterface
+  ]> {
   let summary = "Shaped reference, with unknown rank, to a region of memory";
   let description = [{
     Syntax:
@@ -1170,7 +1184,7 @@ def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", "unranked_memref", [
     }]>
   ];
   let extraClassDeclaration = [{
-    using BaseMemRefType::clone;
+    using ShapedType::Trait<UnrankedMemRefType>::clone;
     using ShapedType::Trait<UnrankedMemRefType>::getElementTypeBitWidth;
     using ShapedType::Trait<UnrankedMemRefType>::getRank;
     using ShapedType::Trait<UnrankedMemRefType>::getNumElements;
@@ -1186,11 +1200,12 @@ def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", "unranked_memref", [
     /// New `Attribute getMemorySpace()` method should be used instead.
     unsigned getMemorySpaceAsInt() const;
 
-    /// Return a clone of this type with the given new element type and the same
-    /// shape as this type.
-    MemRefType clone(::mlir::Type elementType) {
-      return ::llvm::cast<MemRefType>(cloneWith(getShape(), elementType));
-    }
+    /// Returns if this type is ranked (always false).
+    bool hasRank() const { return false; }
+
+    /// Returns a clone of this type with the given shape and element
+    /// type. If a shape is not provided, the current shape of the type is used.
+    BaseMemRefType cloneWith(std::optional<ArrayRef<int64_t>> shape, Type elementType) const;
   }];
   let skipDefaultBuilders = 1;
   let genVerifyDecl = 1;
@@ -1201,8 +1216,8 @@ def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", "unranked_memref", [
 //===----------------------------------------------------------------------===//
 
 def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
-    ShapedTypeInterface, ValueSemantics
-  ], "TensorType"> {
+    TensorTypeInterface, ValueSemantics
+  ]> {
   let summary = "Multi-dimensional array with unknown dimensions";
   let description = [{
     Syntax:
@@ -1229,7 +1244,7 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
     }]>
   ];
   let extraClassDeclaration = [{
-    using TensorType::clone;
+    using ShapedType::Trait<UnrankedTensorType>::clone;
     using ShapedType::Trait<UnrankedTensorType>::getElementTypeBitWidth;
     using ShapedType::Trait<UnrankedTensorType>::getRank;
     using ShapedType::Trait<UnrankedTensorType>::getNumElements;
@@ -1240,6 +1255,13 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
     using ShapedType::Trait<UnrankedTensorType>::getDynamicDimIndex;
 
     ArrayRef<int64_t> getShape() const { return std::nullopt; }
+
+    /// Returns if this type is ranked (always false).
+    bool hasRank() const { return false; }
+
+    /// Returns a clone of this type with the given shape and element
+    /// type. If a shape is not provided, the current shape of the type is used.
+    TensorType cloneWith(std::optional<ArrayRef<int64_t>> shape, Type elementType) const;
   }];
   let skipDefaultBuilders = 1;
   let genVerifyDecl = 1;
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
index 5f23a33049f87..1d1bcee8600a8 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -41,7 +41,7 @@ TensorType inferReshapeInputType(TypedValue<TensorType> input,
   // 0D tensor. While such construct is not incorrect on its own, bufferization
   // cannot properly handle it at the moment, so we avoid it.
   SmallVector<int64_t> shape(input.getType().getRank(), 1);
-  return input.getType().clone(shape);
+  return mlir::cast<TensorType>(input.getType().clone(shape));
 }
 
 // Infer the result type of 'tensor.expand_shape' in the collapse-expand
@@ -51,7 +51,7 @@ TensorType inferReshapeExpandedType(TensorType inputType,
   // Special case for 0D output tensor. Note: Watch out when using Type::clone()
   // with just '{}', as it will invoke the incorrect overload.
   if (newShape.empty())
-    return inputType.clone(ArrayRef<int64_t>{});
+    return mlir::cast<TensorType>(inputType.clone(ArrayRef<int64_t>{}));
 
   // Check if the input is static, and if so, get its total size
   bool inputIsStatic = inputType.hasStaticShape();
@@ -98,7 +98,7 @@ TensorType inferReshapeExpandedType(TensorType inputType,
   assert(!inputIsStatic || resultIsStatic);
 
   // Create result type
-  return inputType.clone(resultShape);
+  return mlir::cast<TensorType>(inputType.clone(resultShape));
 }
 
 // Infer the result type of 'tensor.collapse_shape' in the collapse-expand
@@ -108,11 +108,11 @@ TensorType inferReshapeCollapsedType(TensorType lhsType, Ten...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Mar 26, 2025

@llvm/pr-subscribers-mlir-gpu

Author: Andrei Golubev (andrey-golubev)

Changes

Existing design assumes "TensorType" is only a built-in (un)ranked tensor and "BaseMemRefType" is only a built-in (un)ranked memref. This means that the generic logic operating on "tensors" and "memrefs" is limited to just built-ins, no compatible user types allowed. For instance, this becomes important in one-shot bufferization when converting "user tensor" to "user memref" via the common infrastructure.

Remove this behaviour - that seems accidental - by following the footsteps of ShapedType (see 676bfb2). As with ShapedType, "tensor" and "memref" seem to always aspire to be interfaces.


Patch is 36.19 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/133053.diff

11 Files Affected:

  • (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td (+7-5)
  • (modified) mlir/include/mlir/IR/BuiltinTypeInterfaces.td (+64-4)
  • (modified) mlir/include/mlir/IR/BuiltinTypes.h (-110)
  • (modified) mlir/include/mlir/IR/BuiltinTypes.td (+40-18)
  • (modified) mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp (+6-6)
  • (modified) mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp (+1-2)
  • (modified) mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp (+15-12)
  • (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp (+7)
  • (modified) mlir/lib/IR/BuiltinTypes.cpp (+30-91)
  • (modified) mlir/test/lib/Dialect/Test/TestTypeDefs.td (+46)
  • (modified) mlir/unittests/IR/InterfaceTest.cpp (+34)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index ccd91a928e1dd..248ef9f855b14 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -31,7 +31,7 @@ class XeGPUTypeDef<string name, string typeMnemonic, list<Trait> traits = [],
 }
 
 def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
-        [ShapedTypeInterface], "::mlir::TensorType"> {
+        [TensorTypeInterface]> {
   let summary = "TensorDesc describing regions of interested data.";
   let description = [{
     TensorDesc is a type designed to describe regions of the interested data as well as some
@@ -105,7 +105,6 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
   ];
 
   let extraClassDeclaration = [{
-    using TensorType::clone;
     using mlir::ShapedType::Trait<TensorDescType>::getElementTypeBitWidth;
     using mlir::ShapedType::Trait<TensorDescType>::getRank;
     using mlir::ShapedType::Trait<TensorDescType>::getNumElements;
@@ -115,8 +114,11 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
     using mlir::ShapedType::Trait<TensorDescType>::getDimSize;
     using mlir::ShapedType::Trait<TensorDescType>::getDynamicDimIndex;
 
+    TensorDescType cloneWith(std::optional<ArrayRef<int64_t>> shape, Type elementType) const;
+    bool hasRank() const { return true; }
+
     TensorDescType clone(::mlir::Type elementType) {
-      return llvm::cast<TensorDescType>(cloneWith(getShape(), elementType));
+      return cloneWith(getShape(), elementType);
     }
 
     BlockTensorDescAttr getEncodingAsBlockTensorDescAttr() const {
@@ -144,7 +146,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
       return MemorySpace::Global;
     }
 
-    int getArrayLength() {
+    int getArrayLength() const {
       auto attr = getEncoding();
       auto block_attr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(attr);
       assert((!attr || block_attr) && "invalid on non BlockTensorDescAttr.");
@@ -154,7 +156,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
       return 1;
     }
 
-    bool getBoundaryCheck() {
+    bool getBoundaryCheck() const {
       auto attr = getEncoding();
       auto block_attr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(attr);
       assert((!attr || block_attr) && "invalid on non BlockTensorDescAttr.");
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
index 8aa2c55570153..a26b7f25fcf10 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
@@ -143,21 +143,21 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {
 
     /// Return the number of elements present in the given shape.
     static int64_t getNumElements(ArrayRef<int64_t> shape);
+  }];
 
+  let extraSharedClassDeclaration = [{
     /// Return a clone of this type with the given new shape and element type.
     /// The returned type is ranked, even if this type is unranked.
     auto clone(::llvm::ArrayRef<int64_t> shape, Type elementType) {
-      return cloneWith(shape, elementType);
+      return $_type.cloneWith(shape, elementType);
     }
 
     /// Return a clone of this type with the given new shape. The returned type
     /// is ranked, even if this type is unranked.
     auto clone(::llvm::ArrayRef<int64_t> shape) {
-      return cloneWith(shape, getElementType());
+      return $_type.cloneWith(shape, $_type.getElementType());
     }
-  }];
 
-  let extraSharedClassDeclaration = [{
     /// Return a clone of this type with the given new element type. The
     /// returned type is ranked if and only if this type is ranked. In that
     /// case, the returned type has the same shape as this type.
@@ -227,4 +227,64 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// TensorTypeInterface
+//===----------------------------------------------------------------------===//
+
+def TensorTypeInterface : TypeInterface<"TensorType", [ShapedTypeInterface]> {
+  let cppNamespace = "::mlir";
+  let description = [{
+    This interface provides a shared interface type for ranked, unranked and any
+    user-specified tensor types.
+
+    This interface attaches the ShapedTypeInterface to act as a mixin to
+    provide many useful utility functions.
+  }];
+
+  let extraClassDeclaration = [{
+    /// Return true if the specified element type is ok in a tensor.
+    static bool isValidElementType(::mlir::Type type);
+  }];
+
+  let extraClassOf = [{
+    return $_type.hasTrait<::mlir::TensorType::Trait>();
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// BaseMemRefTypeInterface
+//===----------------------------------------------------------------------===//
+
+def BaseMemRefTypeInterface : TypeInterface<"BaseMemRefType", [ShapedTypeInterface]> {
+  let cppNamespace = "::mlir";
+  let description = [{
+    This interface provides a shared interface type for ranked, unranked and any
+    user-specified memref types.
+
+    This interface attaches the ShapedTypeInterface to act as a mixin to
+    provide many useful utility functions.
+  }];
+
+  let methods = [
+    InterfaceMethod<[{
+      Returns the memory space in which data referred to by this memref resides.
+    }],
+    "::mlir::Attribute", "getMemorySpace">,
+    InterfaceMethod<[{
+      [deprecated] Returns the memory space in old raw integer representation.
+      New `Attribute getMemorySpace()` method should be used instead.
+    }],
+    "unsigned", "getMemorySpaceAsInt">,
+  ];
+
+  let extraClassDeclaration = [{
+    /// Return true if the specified element type is ok in a memref.
+    static bool isValidElementType(::mlir::Type type);
+  }];
+
+  let extraClassOf = [{
+    return $_type.hasTrait<::mlir::BaseMemRefType::Trait>();
+  }];
+}
+
 #endif // MLIR_IR_BUILTINTYPEINTERFACES_TD_
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index df1e02732617d..4f3365492f720 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -43,108 +43,6 @@ template <typename ConcreteType>
 class ValueSemantics
     : public TypeTrait::TraitBase<ConcreteType, ValueSemantics> {};
 
-//===----------------------------------------------------------------------===//
-// TensorType
-//===----------------------------------------------------------------------===//
-
-/// Tensor types represent multi-dimensional arrays, and have two variants:
-/// RankedTensorType and UnrankedTensorType.
-/// Note: This class attaches the ShapedType trait to act as a mixin to
-///       provide many useful utility functions. This inheritance has no effect
-///       on derived tensor types.
-class TensorType : public Type, public ShapedType::Trait<TensorType> {
-public:
-  using Type::Type;
-
-  /// Returns the element type of this tensor type.
-  Type getElementType() const;
-
-  /// Returns if this type is ranked, i.e. it has a known number of dimensions.
-  bool hasRank() const;
-
-  /// Returns the shape of this tensor type.
-  ArrayRef<int64_t> getShape() const;
-
-  /// Clone this type with the given shape and element type. If the
-  /// provided shape is `std::nullopt`, the current shape of the type is used.
-  TensorType cloneWith(std::optional<ArrayRef<int64_t>> shape,
-                       Type elementType) const;
-
-  // Make sure that base class overloads are visible.
-  using ShapedType::Trait<TensorType>::clone;
-
-  /// Return a clone of this type with the given new shape and element type.
-  /// The returned type is ranked, even if this type is unranked.
-  RankedTensorType clone(ArrayRef<int64_t> shape, Type elementType) const;
-
-  /// Return a clone of this type with the given new shape. The returned type
-  /// is ranked, even if this type is unranked.
-  RankedTensorType clone(ArrayRef<int64_t> shape) const;
-
-  /// Return true if the specified element type is ok in a tensor.
-  static bool isValidElementType(Type type);
-
-  /// Methods for support type inquiry through isa, cast, and dyn_cast.
-  static bool classof(Type type);
-
-  /// Allow implicit conversion to ShapedType.
-  operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
-};
-
-//===----------------------------------------------------------------------===//
-// BaseMemRefType
-//===----------------------------------------------------------------------===//
-
-/// This class provides a shared interface for ranked and unranked memref types.
-/// Note: This class attaches the ShapedType trait to act as a mixin to
-///       provide many useful utility functions. This inheritance has no effect
-///       on derived memref types.
-class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
-public:
-  using Type::Type;
-
-  /// Returns the element type of this memref type.
-  Type getElementType() const;
-
-  /// Returns if this type is ranked, i.e. it has a known number of dimensions.
-  bool hasRank() const;
-
-  /// Returns the shape of this memref type.
-  ArrayRef<int64_t> getShape() const;
-
-  /// Clone this type with the given shape and element type. If the
-  /// provided shape is `std::nullopt`, the current shape of the type is used.
-  BaseMemRefType cloneWith(std::optional<ArrayRef<int64_t>> shape,
-                           Type elementType) const;
-
-  // Make sure that base class overloads are visible.
-  using ShapedType::Trait<BaseMemRefType>::clone;
-
-  /// Return a clone of this type with the given new shape and element type.
-  /// The returned type is ranked, even if this type is unranked.
-  MemRefType clone(ArrayRef<int64_t> shape, Type elementType) const;
-
-  /// Return a clone of this type with the given new shape. The returned type
-  /// is ranked, even if this type is unranked.
-  MemRefType clone(ArrayRef<int64_t> shape) const;
-
-  /// Return true if the specified element type is ok in a memref.
-  static bool isValidElementType(Type type);
-
-  /// Methods for support type inquiry through isa, cast, and dyn_cast.
-  static bool classof(Type type);
-
-  /// Returns the memory space in which data referred to by this memref resides.
-  Attribute getMemorySpace() const;
-
-  /// [deprecated] Returns the memory space in old raw integer representation.
-  /// New `Attribute getMemorySpace()` method should be used instead.
-  unsigned getMemorySpaceAsInt() const;
-
-  /// Allow implicit conversion to ShapedType.
-  operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
-};
-
 } // namespace mlir
 
 //===----------------------------------------------------------------------===//
@@ -390,10 +288,6 @@ class FixedVectorType : public VectorType {
 // Deferred Method Definitions
 //===----------------------------------------------------------------------===//
 
-inline bool BaseMemRefType::classof(Type type) {
-  return llvm::isa<MemRefType, UnrankedMemRefType>(type);
-}
-
 inline bool BaseMemRefType::isValidElementType(Type type) {
   return type.isIntOrIndexOrFloat() ||
          llvm::isa<ComplexType, MemRefType, VectorType, UnrankedMemRefType>(
@@ -401,10 +295,6 @@ inline bool BaseMemRefType::isValidElementType(Type type) {
          llvm::isa<MemRefElementTypeInterface>(type);
 }
 
-inline bool TensorType::classof(Type type) {
-  return llvm::isa<RankedTensorType, UnrankedTensorType>(type);
-}
-
 //===----------------------------------------------------------------------===//
 // Type Utilities
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index af474b3e3ec47..575ae6a263b1b 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -542,8 +542,8 @@ def Builtin_Integer : Builtin_Type<"Integer", "integer"> {
 //===----------------------------------------------------------------------===//
 
 def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [
-    ShapedTypeInterface
-  ], "BaseMemRefType"> {
+    BaseMemRefTypeInterface
+  ]> {
   let summary = "Shaped reference to a region of memory";
   let description = [{
     Syntax:
@@ -794,7 +794,7 @@ def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [
       "unsigned":$memorySpaceInd)>
   ];
   let extraClassDeclaration = [{
-    using BaseMemRefType::clone;
+    using ShapedType::Trait<MemRefType>::clone;
     using ShapedType::Trait<MemRefType>::getElementTypeBitWidth;
     using ShapedType::Trait<MemRefType>::getRank;
     using ShapedType::Trait<MemRefType>::getNumElements;
@@ -854,6 +854,13 @@ def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [
     /// Return "true" if the last dimension has a static unit stride. Also
     /// return "true" for types with no strides.
     bool isLastDimUnitStride();
+
+    /// Returns if this type is ranked (always true).
+    bool hasRank() const { return true; }
+
+    /// Returns a clone of this type with the given shape and element
+    /// type. If a shape is not provided, the current shape of the type is used.
+    MemRefType cloneWith(std::optional<ArrayRef<int64_t>> shape, Type elementType) const;
   }];
   let skipDefaultBuilders = 1;
   let genVerifyDecl = 1;
@@ -934,8 +941,8 @@ def Builtin_Opaque : Builtin_Type<"Opaque", "opaque"> {
 //===----------------------------------------------------------------------===//
 
 def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "tensor", [
-    ShapedTypeInterface, ValueSemantics
-  ], "TensorType"> {
+    TensorTypeInterface, ValueSemantics
+  ]> {
   let summary = "Multi-dimensional array with a fixed number of dimensions";
   let description = [{
     Syntax:
@@ -1016,7 +1023,7 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "tensor", [
     }]>
   ];
   let extraClassDeclaration = [{
-    using TensorType::clone;
+    using ShapedType::Trait<RankedTensorType>::clone;
     using ShapedType::Trait<RankedTensorType>::getElementTypeBitWidth;
     using ShapedType::Trait<RankedTensorType>::getRank;
     using ShapedType::Trait<RankedTensorType>::getNumElements;
@@ -1033,7 +1040,7 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "tensor", [
     /// Return a clone of this type with the given new element type and the same
     /// shape as this type.
     RankedTensorType clone(::mlir::Type elementType) {
-      return ::llvm::cast<RankedTensorType>(cloneWith(getShape(), elementType));
+      return cloneWith(getShape(), elementType);
     }
 
     /// Return a clone of this type without the encoding.
@@ -1041,6 +1048,13 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "tensor", [
       return RankedTensorType::get(getShape(), getElementType());
     }
 
+    /// Returns if this type is ranked (always true).
+    bool hasRank() const { return true; }
+
+    /// Returns a clone of this type with the given shape and element
+    /// type. If a shape is not provided, the current shape of the type is used.
+    RankedTensorType cloneWith(std::optional<ArrayRef<int64_t>> shape, Type elementType) const;
+
     /// Return a clone of this type with the given new encoding and the same
     /// shape and element type as this type.
     RankedTensorType cloneWithEncoding(::mlir::Attribute encoding) {
@@ -1123,8 +1137,8 @@ def Builtin_Tuple : Builtin_Type<"Tuple", "tuple"> {
 //===----------------------------------------------------------------------===//
 
 def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", "unranked_memref", [
-    ShapedTypeInterface
-  ], "BaseMemRefType"> {
+    BaseMemRefTypeInterface
+  ]> {
   let summary = "Shaped reference, with unknown rank, to a region of memory";
   let description = [{
     Syntax:
@@ -1170,7 +1184,7 @@ def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", "unranked_memref", [
     }]>
   ];
   let extraClassDeclaration = [{
-    using BaseMemRefType::clone;
+    using ShapedType::Trait<UnrankedMemRefType>::clone;
     using ShapedType::Trait<UnrankedMemRefType>::getElementTypeBitWidth;
     using ShapedType::Trait<UnrankedMemRefType>::getRank;
     using ShapedType::Trait<UnrankedMemRefType>::getNumElements;
@@ -1186,11 +1200,12 @@ def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", "unranked_memref", [
     /// New `Attribute getMemorySpace()` method should be used instead.
     unsigned getMemorySpaceAsInt() const;
 
-    /// Return a clone of this type with the given new element type and the same
-    /// shape as this type.
-    MemRefType clone(::mlir::Type elementType) {
-      return ::llvm::cast<MemRefType>(cloneWith(getShape(), elementType));
-    }
+    /// Returns if this type is ranked (always false).
+    bool hasRank() const { return false; }
+
+    /// Returns a clone of this type with the given shape and element
+    /// type. If a shape is not provided, the current shape of the type is used.
+    BaseMemRefType cloneWith(std::optional<ArrayRef<int64_t>> shape, Type elementType) const;
   }];
   let skipDefaultBuilders = 1;
   let genVerifyDecl = 1;
@@ -1201,8 +1216,8 @@ def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", "unranked_memref", [
 //===----------------------------------------------------------------------===//
 
 def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
-    ShapedTypeInterface, ValueSemantics
-  ], "TensorType"> {
+    TensorTypeInterface, ValueSemantics
+  ]> {
   let summary = "Multi-dimensional array with unknown dimensions";
   let description = [{
     Syntax:
@@ -1229,7 +1244,7 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
     }]>
   ];
   let extraClassDeclaration = [{
-    using TensorType::clone;
+    using ShapedType::Trait<UnrankedTensorType>::clone;
     using ShapedType::Trait<UnrankedTensorType>::getElementTypeBitWidth;
     using ShapedType::Trait<UnrankedTensorType>::getRank;
     using ShapedType::Trait<UnrankedTensorType>::getNumElements;
@@ -1240,6 +1255,13 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
     using ShapedType::Trait<UnrankedTensorType>::getDynamicDimIndex;
 
     ArrayRef<int64_t> getShape() const { return std::nullopt; }
+
+    /// Returns if this type is ranked (always false).
+    bool hasRank() const { return false; }
+
+    /// Returns a clone of this type with the given shape and element
+    /// type. If a shape is not provided, the current shape of the type is used.
+    TensorType cloneWith(std::optional<ArrayRef<int64_t>> shape, Type elementType) const;
   }];
   let skipDefaultBuilders = 1;
   let genVerifyDecl = 1;
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
index 5f23a33049f87..1d1bcee8600a8 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -41,7 +41,7 @@ TensorType inferReshapeInputType(TypedValue<TensorType> input,
   // 0D tensor. While such construct is not incorrect on its own, bufferization
   // cannot properly handle it at the moment, so we avoid it.
   SmallVector<int64_t> shape(input.getType().getRank(), 1);
-  return input.getType().clone(shape);
+  return mlir::cast<TensorType>(input.getType().clone(shape));
 }
 
 // Infer the result type of 'tensor.expand_shape' in the collapse-expand
@@ -51,7 +51,7 @@ TensorType inferReshapeExpandedType(TensorType inputType,
   // Special case for 0D output tensor. Note: Watch out when using Type::clone()
   // with just '{}', as it will invoke the incorrect overload.
   if (newShape.empty())
-    return inputType.clone(ArrayRef<int64_t>{});
+    return mlir::cast<TensorType>(inputType.clone(ArrayRef<int64_t>{}));
 
   // Check if the input is static, and if so, get its total size
   bool inputIsStatic = inputType.hasStaticShape();
@@ -98,7 +98,7 @@ TensorType inferReshapeExpandedType(TensorType inputType,
   assert(!inputIsStatic || resultIsStatic);
 
   // Create result type
-  return inputType.clone(resultShape);
+  return mlir::cast<TensorType>(inputType.clone(resultShape));
 }
 
 // Infer the result type of 'tensor.collapse_shape' in the collapse-expand
@@ -108,11 +108,11 @@ TensorType inferReshapeCollapsedType(TensorType lhsType, Ten...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Mar 26, 2025

@llvm/pr-subscribers-mlir

Author: Andrei Golubev (andrey-golubev)

Changes

Existing design assumes "TensorType" is only a built-in (un)ranked tensor and "BaseMemRefType" is only a built-in (un)ranked memref. This means that the generic logic operating on "tensors" and "memrefs" is limited to just built-ins, no compatible user types allowed. For instance, this becomes important in one-shot bufferization when converting "user tensor" to "user memref" via the common infrastructure.

Remove this behaviour - that seems accidental - by following the footsteps of ShapedType (see 676bfb2). As with ShapedType, "tensor" and "memref" seem to always aspire to be interfaces.


Patch is 36.19 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/133053.diff

11 Files Affected:

  • (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td (+7-5)
  • (modified) mlir/include/mlir/IR/BuiltinTypeInterfaces.td (+64-4)
  • (modified) mlir/include/mlir/IR/BuiltinTypes.h (-110)
  • (modified) mlir/include/mlir/IR/BuiltinTypes.td (+40-18)
  • (modified) mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp (+6-6)
  • (modified) mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp (+1-2)
  • (modified) mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp (+15-12)
  • (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp (+7)
  • (modified) mlir/lib/IR/BuiltinTypes.cpp (+30-91)
  • (modified) mlir/test/lib/Dialect/Test/TestTypeDefs.td (+46)
  • (modified) mlir/unittests/IR/InterfaceTest.cpp (+34)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index ccd91a928e1dd..248ef9f855b14 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -31,7 +31,7 @@ class XeGPUTypeDef<string name, string typeMnemonic, list<Trait> traits = [],
 }
 
 def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
-        [ShapedTypeInterface], "::mlir::TensorType"> {
+        [TensorTypeInterface]> {
   let summary = "TensorDesc describing regions of interested data.";
   let description = [{
     TensorDesc is a type designed to describe regions of the interested data as well as some
@@ -105,7 +105,6 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
   ];
 
   let extraClassDeclaration = [{
-    using TensorType::clone;
     using mlir::ShapedType::Trait<TensorDescType>::getElementTypeBitWidth;
     using mlir::ShapedType::Trait<TensorDescType>::getRank;
     using mlir::ShapedType::Trait<TensorDescType>::getNumElements;
@@ -115,8 +114,11 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
     using mlir::ShapedType::Trait<TensorDescType>::getDimSize;
     using mlir::ShapedType::Trait<TensorDescType>::getDynamicDimIndex;
 
+    TensorDescType cloneWith(std::optional<ArrayRef<int64_t>> shape, Type elementType) const;
+    bool hasRank() const { return true; }
+
     TensorDescType clone(::mlir::Type elementType) {
-      return llvm::cast<TensorDescType>(cloneWith(getShape(), elementType));
+      return cloneWith(getShape(), elementType);
     }
 
     BlockTensorDescAttr getEncodingAsBlockTensorDescAttr() const {
@@ -144,7 +146,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
       return MemorySpace::Global;
     }
 
-    int getArrayLength() {
+    int getArrayLength() const {
       auto attr = getEncoding();
       auto block_attr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(attr);
       assert((!attr || block_attr) && "invalid on non BlockTensorDescAttr.");
@@ -154,7 +156,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
       return 1;
     }
 
-    bool getBoundaryCheck() {
+    bool getBoundaryCheck() const {
       auto attr = getEncoding();
       auto block_attr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(attr);
       assert((!attr || block_attr) && "invalid on non BlockTensorDescAttr.");
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
index 8aa2c55570153..a26b7f25fcf10 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
@@ -143,21 +143,21 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {
 
     /// Return the number of elements present in the given shape.
     static int64_t getNumElements(ArrayRef<int64_t> shape);
+  }];
 
+  let extraSharedClassDeclaration = [{
     /// Return a clone of this type with the given new shape and element type.
     /// The returned type is ranked, even if this type is unranked.
     auto clone(::llvm::ArrayRef<int64_t> shape, Type elementType) {
-      return cloneWith(shape, elementType);
+      return $_type.cloneWith(shape, elementType);
     }
 
     /// Return a clone of this type with the given new shape. The returned type
     /// is ranked, even if this type is unranked.
     auto clone(::llvm::ArrayRef<int64_t> shape) {
-      return cloneWith(shape, getElementType());
+      return $_type.cloneWith(shape, $_type.getElementType());
     }
-  }];
 
-  let extraSharedClassDeclaration = [{
     /// Return a clone of this type with the given new element type. The
     /// returned type is ranked if and only if this type is ranked. In that
     /// case, the returned type has the same shape as this type.
@@ -227,4 +227,64 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// TensorTypeInterface
+//===----------------------------------------------------------------------===//
+
+def TensorTypeInterface : TypeInterface<"TensorType", [ShapedTypeInterface]> {
+  let cppNamespace = "::mlir";
+  let description = [{
+    This interface provides a shared interface type for ranked, unranked and any
+    user-specified tensor types.
+
+    This interface attaches the ShapedTypeInterface to act as a mixin to
+    provide many useful utility functions.
+  }];
+
+  let extraClassDeclaration = [{
+    /// Return true if the specified element type is ok in a tensor.
+    static bool isValidElementType(::mlir::Type type);
+  }];
+
+  let extraClassOf = [{
+    return $_type.hasTrait<::mlir::TensorType::Trait>();
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// BaseMemRefTypeInterface
+//===----------------------------------------------------------------------===//
+
+def BaseMemRefTypeInterface : TypeInterface<"BaseMemRefType", [ShapedTypeInterface]> {
+  let cppNamespace = "::mlir";
+  let description = [{
+    This interface provides a shared interface type for ranked, unranked and any
+    user-specified memref types.
+
+    This interface attaches the ShapedTypeInterface to act as a mixin to
+    provide many useful utility functions.
+  }];
+
+  let methods = [
+    InterfaceMethod<[{
+      Returns the memory space in which data referred to by this memref resides.
+    }],
+    "::mlir::Attribute", "getMemorySpace">,
+    InterfaceMethod<[{
+      [deprecated] Returns the memory space in old raw integer representation.
+      New `Attribute getMemorySpace()` method should be used instead.
+    }],
+    "unsigned", "getMemorySpaceAsInt">,
+  ];
+
+  let extraClassDeclaration = [{
+    /// Return true if the specified element type is ok in a memref.
+    static bool isValidElementType(::mlir::Type type);
+  }];
+
+  let extraClassOf = [{
+    return $_type.hasTrait<::mlir::BaseMemRefType::Trait>();
+  }];
+}
+
 #endif // MLIR_IR_BUILTINTYPEINTERFACES_TD_
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index df1e02732617d..4f3365492f720 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -43,108 +43,6 @@ template <typename ConcreteType>
 class ValueSemantics
     : public TypeTrait::TraitBase<ConcreteType, ValueSemantics> {};
 
-//===----------------------------------------------------------------------===//
-// TensorType
-//===----------------------------------------------------------------------===//
-
-/// Tensor types represent multi-dimensional arrays, and have two variants:
-/// RankedTensorType and UnrankedTensorType.
-/// Note: This class attaches the ShapedType trait to act as a mixin to
-///       provide many useful utility functions. This inheritance has no effect
-///       on derived tensor types.
-class TensorType : public Type, public ShapedType::Trait<TensorType> {
-public:
-  using Type::Type;
-
-  /// Returns the element type of this tensor type.
-  Type getElementType() const;
-
-  /// Returns if this type is ranked, i.e. it has a known number of dimensions.
-  bool hasRank() const;
-
-  /// Returns the shape of this tensor type.
-  ArrayRef<int64_t> getShape() const;
-
-  /// Clone this type with the given shape and element type. If the
-  /// provided shape is `std::nullopt`, the current shape of the type is used.
-  TensorType cloneWith(std::optional<ArrayRef<int64_t>> shape,
-                       Type elementType) const;
-
-  // Make sure that base class overloads are visible.
-  using ShapedType::Trait<TensorType>::clone;
-
-  /// Return a clone of this type with the given new shape and element type.
-  /// The returned type is ranked, even if this type is unranked.
-  RankedTensorType clone(ArrayRef<int64_t> shape, Type elementType) const;
-
-  /// Return a clone of this type with the given new shape. The returned type
-  /// is ranked, even if this type is unranked.
-  RankedTensorType clone(ArrayRef<int64_t> shape) const;
-
-  /// Return true if the specified element type is ok in a tensor.
-  static bool isValidElementType(Type type);
-
-  /// Methods for support type inquiry through isa, cast, and dyn_cast.
-  static bool classof(Type type);
-
-  /// Allow implicit conversion to ShapedType.
-  operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
-};
-
-//===----------------------------------------------------------------------===//
-// BaseMemRefType
-//===----------------------------------------------------------------------===//
-
-/// This class provides a shared interface for ranked and unranked memref types.
-/// Note: This class attaches the ShapedType trait to act as a mixin to
-///       provide many useful utility functions. This inheritance has no effect
-///       on derived memref types.
-class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
-public:
-  using Type::Type;
-
-  /// Returns the element type of this memref type.
-  Type getElementType() const;
-
-  /// Returns if this type is ranked, i.e. it has a known number of dimensions.
-  bool hasRank() const;
-
-  /// Returns the shape of this memref type.
-  ArrayRef<int64_t> getShape() const;
-
-  /// Clone this type with the given shape and element type. If the
-  /// provided shape is `std::nullopt`, the current shape of the type is used.
-  BaseMemRefType cloneWith(std::optional<ArrayRef<int64_t>> shape,
-                           Type elementType) const;
-
-  // Make sure that base class overloads are visible.
-  using ShapedType::Trait<BaseMemRefType>::clone;
-
-  /// Return a clone of this type with the given new shape and element type.
-  /// The returned type is ranked, even if this type is unranked.
-  MemRefType clone(ArrayRef<int64_t> shape, Type elementType) const;
-
-  /// Return a clone of this type with the given new shape. The returned type
-  /// is ranked, even if this type is unranked.
-  MemRefType clone(ArrayRef<int64_t> shape) const;
-
-  /// Return true if the specified element type is ok in a memref.
-  static bool isValidElementType(Type type);
-
-  /// Methods for support type inquiry through isa, cast, and dyn_cast.
-  static bool classof(Type type);
-
-  /// Returns the memory space in which data referred to by this memref resides.
-  Attribute getMemorySpace() const;
-
-  /// [deprecated] Returns the memory space in old raw integer representation.
-  /// New `Attribute getMemorySpace()` method should be used instead.
-  unsigned getMemorySpaceAsInt() const;
-
-  /// Allow implicit conversion to ShapedType.
-  operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
-};
-
 } // namespace mlir
 
 //===----------------------------------------------------------------------===//
@@ -390,10 +288,6 @@ class FixedVectorType : public VectorType {
 // Deferred Method Definitions
 //===----------------------------------------------------------------------===//
 
-inline bool BaseMemRefType::classof(Type type) {
-  return llvm::isa<MemRefType, UnrankedMemRefType>(type);
-}
-
 inline bool BaseMemRefType::isValidElementType(Type type) {
   return type.isIntOrIndexOrFloat() ||
          llvm::isa<ComplexType, MemRefType, VectorType, UnrankedMemRefType>(
@@ -401,10 +295,6 @@ inline bool BaseMemRefType::isValidElementType(Type type) {
          llvm::isa<MemRefElementTypeInterface>(type);
 }
 
-inline bool TensorType::classof(Type type) {
-  return llvm::isa<RankedTensorType, UnrankedTensorType>(type);
-}
-
 //===----------------------------------------------------------------------===//
 // Type Utilities
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index af474b3e3ec47..575ae6a263b1b 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -542,8 +542,8 @@ def Builtin_Integer : Builtin_Type<"Integer", "integer"> {
 //===----------------------------------------------------------------------===//
 
 def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [
-    ShapedTypeInterface
-  ], "BaseMemRefType"> {
+    BaseMemRefTypeInterface
+  ]> {
   let summary = "Shaped reference to a region of memory";
   let description = [{
     Syntax:
@@ -794,7 +794,7 @@ def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [
       "unsigned":$memorySpaceInd)>
   ];
   let extraClassDeclaration = [{
-    using BaseMemRefType::clone;
+    using ShapedType::Trait<MemRefType>::clone;
     using ShapedType::Trait<MemRefType>::getElementTypeBitWidth;
     using ShapedType::Trait<MemRefType>::getRank;
     using ShapedType::Trait<MemRefType>::getNumElements;
@@ -854,6 +854,13 @@ def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [
     /// Return "true" if the last dimension has a static unit stride. Also
     /// return "true" for types with no strides.
     bool isLastDimUnitStride();
+
+    /// Returns if this type is ranked (always true).
+    bool hasRank() const { return true; }
+
+    /// Returns a clone of this type with the given shape and element
+    /// type. If a shape is not provided, the current shape of the type is used.
+    MemRefType cloneWith(std::optional<ArrayRef<int64_t>> shape, Type elementType) const;
   }];
   let skipDefaultBuilders = 1;
   let genVerifyDecl = 1;
@@ -934,8 +941,8 @@ def Builtin_Opaque : Builtin_Type<"Opaque", "opaque"> {
 //===----------------------------------------------------------------------===//
 
 def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "tensor", [
-    ShapedTypeInterface, ValueSemantics
-  ], "TensorType"> {
+    TensorTypeInterface, ValueSemantics
+  ]> {
   let summary = "Multi-dimensional array with a fixed number of dimensions";
   let description = [{
     Syntax:
@@ -1016,7 +1023,7 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "tensor", [
     }]>
   ];
   let extraClassDeclaration = [{
-    using TensorType::clone;
+    using ShapedType::Trait<RankedTensorType>::clone;
     using ShapedType::Trait<RankedTensorType>::getElementTypeBitWidth;
     using ShapedType::Trait<RankedTensorType>::getRank;
     using ShapedType::Trait<RankedTensorType>::getNumElements;
@@ -1033,7 +1040,7 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "tensor", [
     /// Return a clone of this type with the given new element type and the same
     /// shape as this type.
     RankedTensorType clone(::mlir::Type elementType) {
-      return ::llvm::cast<RankedTensorType>(cloneWith(getShape(), elementType));
+      return cloneWith(getShape(), elementType);
     }
 
     /// Return a clone of this type without the encoding.
@@ -1041,6 +1048,13 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "tensor", [
       return RankedTensorType::get(getShape(), getElementType());
     }
 
+    /// Returns if this type is ranked (always true).
+    bool hasRank() const { return true; }
+
+    /// Returns a clone of this type with the given shape and element
+    /// type. If a shape is not provided, the current shape of the type is used.
+    RankedTensorType cloneWith(std::optional<ArrayRef<int64_t>> shape, Type elementType) const;
+
     /// Return a clone of this type with the given new encoding and the same
     /// shape and element type as this type.
     RankedTensorType cloneWithEncoding(::mlir::Attribute encoding) {
@@ -1123,8 +1137,8 @@ def Builtin_Tuple : Builtin_Type<"Tuple", "tuple"> {
 //===----------------------------------------------------------------------===//
 
 def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", "unranked_memref", [
-    ShapedTypeInterface
-  ], "BaseMemRefType"> {
+    BaseMemRefTypeInterface
+  ]> {
   let summary = "Shaped reference, with unknown rank, to a region of memory";
   let description = [{
     Syntax:
@@ -1170,7 +1184,7 @@ def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", "unranked_memref", [
     }]>
   ];
   let extraClassDeclaration = [{
-    using BaseMemRefType::clone;
+    using ShapedType::Trait<UnrankedMemRefType>::clone;
     using ShapedType::Trait<UnrankedMemRefType>::getElementTypeBitWidth;
     using ShapedType::Trait<UnrankedMemRefType>::getRank;
     using ShapedType::Trait<UnrankedMemRefType>::getNumElements;
@@ -1186,11 +1200,12 @@ def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", "unranked_memref", [
     /// New `Attribute getMemorySpace()` method should be used instead.
     unsigned getMemorySpaceAsInt() const;
 
-    /// Return a clone of this type with the given new element type and the same
-    /// shape as this type.
-    MemRefType clone(::mlir::Type elementType) {
-      return ::llvm::cast<MemRefType>(cloneWith(getShape(), elementType));
-    }
+    /// Returns if this type is ranked (always false).
+    bool hasRank() const { return false; }
+
+    /// Returns a clone of this type with the given shape and element
+    /// type. If a shape is not provided, the current shape of the type is used.
+    BaseMemRefType cloneWith(std::optional<ArrayRef<int64_t>> shape, Type elementType) const;
   }];
   let skipDefaultBuilders = 1;
   let genVerifyDecl = 1;
@@ -1201,8 +1216,8 @@ def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", "unranked_memref", [
 //===----------------------------------------------------------------------===//
 
 def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
-    ShapedTypeInterface, ValueSemantics
-  ], "TensorType"> {
+    TensorTypeInterface, ValueSemantics
+  ]> {
   let summary = "Multi-dimensional array with unknown dimensions";
   let description = [{
     Syntax:
@@ -1229,7 +1244,7 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
     }]>
   ];
   let extraClassDeclaration = [{
-    using TensorType::clone;
+    using ShapedType::Trait<UnrankedTensorType>::clone;
     using ShapedType::Trait<UnrankedTensorType>::getElementTypeBitWidth;
     using ShapedType::Trait<UnrankedTensorType>::getRank;
     using ShapedType::Trait<UnrankedTensorType>::getNumElements;
@@ -1240,6 +1255,13 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
     using ShapedType::Trait<UnrankedTensorType>::getDynamicDimIndex;
 
     ArrayRef<int64_t> getShape() const { return std::nullopt; }
+
+    /// Returns if this type is ranked (always false).
+    bool hasRank() const { return false; }
+
+    /// Returns a clone of this type with the given shape and element
+    /// type. If a shape is not provided, the current shape of the type is used.
+    TensorType cloneWith(std::optional<ArrayRef<int64_t>> shape, Type elementType) const;
   }];
   let skipDefaultBuilders = 1;
   let genVerifyDecl = 1;
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
index 5f23a33049f87..1d1bcee8600a8 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -41,7 +41,7 @@ TensorType inferReshapeInputType(TypedValue<TensorType> input,
   // 0D tensor. While such construct is not incorrect on its own, bufferization
   // cannot properly handle it at the moment, so we avoid it.
   SmallVector<int64_t> shape(input.getType().getRank(), 1);
-  return input.getType().clone(shape);
+  return mlir::cast<TensorType>(input.getType().clone(shape));
 }
 
 // Infer the result type of 'tensor.expand_shape' in the collapse-expand
@@ -51,7 +51,7 @@ TensorType inferReshapeExpandedType(TensorType inputType,
   // Special case for 0D output tensor. Note: Watch out when using Type::clone()
   // with just '{}', as it will invoke the incorrect overload.
   if (newShape.empty())
-    return inputType.clone(ArrayRef<int64_t>{});
+    return mlir::cast<TensorType>(inputType.clone(ArrayRef<int64_t>{}));
 
   // Check if the input is static, and if so, get its total size
   bool inputIsStatic = inputType.hasStaticShape();
@@ -98,7 +98,7 @@ TensorType inferReshapeExpandedType(TensorType inputType,
   assert(!inputIsStatic || resultIsStatic);
 
   // Create result type
-  return inputType.clone(resultShape);
+  return mlir::cast<TensorType>(inputType.clone(resultShape));
 }
 
 // Infer the result type of 'tensor.collapse_shape' in the collapse-expand
@@ -108,11 +108,11 @@ TensorType inferReshapeCollapsedType(TensorType lhsType, Ten...
[truncated]

@andrey-golubev
Copy link
Contributor Author

@River707 @joker-eph I'd also likely need some feedback from you guys!

Copy link
Contributor Author

@andrey-golubev andrey-golubev left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure whether the tests are sufficient (please tell if you think there needs to be more!). Added the most obvious things for now as tensors / memrefs are incidentally tested "everywhere".

[deprecated] Returns the memory space in old raw integer representation.
New `Attribute getMemorySpace()` method should be used instead.
}],
"unsigned", "getMemorySpaceAsInt">,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note to reviewers: this seems to still exist, however it has been marked deprecated for a while.

@@ -41,7 +41,7 @@ TensorType inferReshapeInputType(TypedValue<TensorType> input,
// 0D tensor. While such construct is not incorrect on its own, bufferization
// cannot properly handle it at the moment, so we avoid it.
SmallVector<int64_t> shape(input.getType().getRank(), 1);
return input.getType().clone(shape);
return mlir::cast<TensorType>(input.getType().clone(shape));
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note to reviewers: i see some types "shadow" clone method from ShapedTypeInterface by providing own implementation with different return type. It didn't seem possible to do so for TensorType / BaseMemRefType that are now interfaces unfortunately. i guess, this could be a reasonable inconvenience?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that's acceptable. We have a similar issue with Operation::clone.

Copy link
Member

@matthias-springer matthias-springer left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice cleanup! It always felt a bit odd that class TensorType and class BaseMemRefType were defined in C++ instead of Tablegen. We didn't have interface inheritance from the beginning, maybe that's the reason why it was designed like that.


let extraSharedClassDeclaration = [{
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is it necessary to split this into extraClassDeclaration and extraSharedClassDeclaration?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't remember anymore (this patch is actually ~year old at this point :) - i never had time to brush it up).
I think this has to be like this because of reliance on $_type (see https://mlir.llvm.org/docs/Interfaces/) to implement cloneWith (so ShapedType::clone() would behind the scenes call the derived type's methods).

static bool isValidElementType(::mlir::Type type);
}];

let extraClassOf = [{
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not familiar with extraClassOf. What does it do?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

afaik, this is to generate {TensorType, BaseMemRefType}::classof(). i believe this enables mlir::isa<TensorType>(myTensor) machinery.

@@ -41,7 +41,7 @@ TensorType inferReshapeInputType(TypedValue<TensorType> input,
// 0D tensor. While such construct is not incorrect on its own, bufferization
// cannot properly handle it at the moment, so we avoid it.
SmallVector<int64_t> shape(input.getType().getRank(), 1);
return input.getType().clone(shape);
return mlir::cast<TensorType>(input.getType().clone(shape));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to clarify: Is the cast necessary because clone is now calling the ShapedType implementation, which returns a ShapedType?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

correct

@@ -77,9 +77,9 @@ struct CastOpInterface
// Case 3: Ranked tensor -> ranked tensor. The offsets and strides do not
// change.
auto rankedResultType = cast<RankedTensorType>(castOp.getType());
return MemRefType::get(
return llvm::cast<BaseMemRefType>(MemRefType::get(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is the cast necessary here? (And in other places.) I expected that MemRefType inherits from BaseMemRefType (talking about the auto-generated C++ classes). Is that not the case?

(Same how you can assign a MemRefType to a ShapedType variable, I think...)

Copy link
Contributor Author

@andrey-golubev andrey-golubev Mar 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this stems from BaseMemRefType / TensorType now being interfaces.

(tl;dr: C++ doesn't compile without it).
my hand-wavy explanation: the MemRefType inherits BaseMemRefType::Trait and not BaseMemRefType directly, so from pure C++ standpoint, we cannot implicitly upcast here. this still used to work some time back around LLVM 16-17 i believe but since then there was some other clean up that prevented casts from "X" to "X interface" afair.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a bit odd because class TensorType used to be defined as class TensorType : public Type, public ShapedType::Trait<TensorType> and it was possible to assign a TensorType to a ShapedType variable. It looks like it is the same setup here.

Copy link
Contributor Author

@andrey-golubev andrey-golubev Mar 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good point! i think this was due to explicit conversion operator (operator ShapedType() const { return llvm::cast<ShapedType>(*this); } + there's similar one for base memref as well).

i guess depending on what our long-term goal is, i may add similar thing for derived types.

actually, i just noticed i miss at least {TensorType, BaseMemRefType} -> ShapedType operators (but maybe this is no longer as important - again due to "indirect" inheritance).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add this to the extra(Shared?)ClassDeclaration section of MemRefType, UnrankedMemRefType, RankedTensorType, UnrankedTensorType? (Converting to the new interfaces.) Then it's really almost NFC for existing code.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let me try! Somehow i didn't think of extra(Shared?)ClassDeclaration but it might actually be rather nice if it works.
do you think we need operator ShapedType() const { return llvm::cast<ShapedType>(*this); } as well? (i mean, maybe just for the sake of keeping the source compatibility)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No strong opinion about ShapedType. Sure, why not...

Copy link
Contributor Author

@andrey-golubev andrey-golubev Mar 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

well, turns out there's no free lunch here...
so even with the operators in place we get issues when doing: MemRefType -> FailureOr<BaseMemRefType>. I vaguely remember some corner of C++ standard saying something along the lines of "one cannot have two consecutive user-specified conversions in an expression".
which i guess is what happens here: there's a conversion from MemRefType to BaseMemRefType (user-specified) and then a call to ctor FailureOr (also user-specified)?

I guess for the case in question to work, we also need an extra FailureOr's ctor (SFINAE-guarded likely) that does roughly: template<typename U> FailureOr(const U& u) : FailureOr(T(u)) {} but i'm not sure it's a good idea to extend this PR for that, i'd rather have it in the follow-up (it may turn out there's some other problem once this is done...).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so: for now, i just switched from mlir::cast<X>(...) to X(...) + extended the tests to verify that X x = y; works.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense. Then there's nothing we can do about FailureOr<...> etc.

@andrey-golubev
Copy link
Contributor Author

Nice cleanup! It always felt a bit odd that class TensorType and class BaseMemRefType were defined in C++ instead of Tablegen. We didn't have interface inheritance from the beginning, maybe that's the reason why it was designed like that.

Thank you! I believe you're right: interfaces appeared later and at that point memrefs and tensors were already fully functional. Same was done to ShapedTypeInterface - at some earlier point it was also a base class. Looking at,

inline bool BaseMemRefType::classof(Type type) {
  return llvm::isa<MemRefType, UnrankedMemRefType>(type);
}
inline bool TensorType::classof(Type type) {
  return llvm::isa<RankedTensorType, UnrankedTensorType>(type);
}

kind of suggests both types ought to be interfaces.

Copy link
Member

@matthias-springer matthias-springer left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me, but let's wait a 1-2 more days to give others a chance to take a look.

@joker-eph
Copy link
Collaborator

That deserves an RFC on Discourse IMO, this isn't a trivial change to the builtin infra.

@joker-eph
Copy link
Collaborator

It always felt a bit odd that class TensorType and class BaseMemRefType were defined in C++ instead of Tablegen. We didn't have interface inheritance from the beginning, maybe that's the reason why it was designed like that.

We didn't have ODS for types at the beginning, we didn't even have Type Interfaces either... Everything was C++!

@andrey-golubev
Copy link
Contributor Author

That deserves an RFC on Discourse IMO, this isn't a trivial change to the builtin infra.

Fair enough. Created: https://discourse.llvm.org/t/rfc-changing-base-types-for-tensors-and-memrefs-from-c-base-classes-to-type-interfaces/85509

@andrey-golubev
Copy link
Contributor Author

Superseded by an alternative solution done by #134220

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants