Skip to content

[mlir] Convert TensorType and BaseMemRefType to interfaces #133053

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class XeGPUTypeDef<string name, string typeMnemonic, list<Trait> traits = [],
}

def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
[ShapedTypeInterface], "::mlir::TensorType"> {
[TensorTypeInterface]> {
let summary = "TensorDesc describing regions of interested data.";
let description = [{
TensorDesc is a type designed to describe regions of the interested data as well as some
Expand Down Expand Up @@ -105,7 +105,6 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
];

let extraClassDeclaration = [{
using TensorType::clone;
using mlir::ShapedType::Trait<TensorDescType>::getElementTypeBitWidth;
using mlir::ShapedType::Trait<TensorDescType>::getRank;
using mlir::ShapedType::Trait<TensorDescType>::getNumElements;
Expand All @@ -115,8 +114,11 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
using mlir::ShapedType::Trait<TensorDescType>::getDimSize;
using mlir::ShapedType::Trait<TensorDescType>::getDynamicDimIndex;

TensorDescType cloneWith(std::optional<ArrayRef<int64_t>> shape, Type elementType) const;
bool hasRank() const { return true; }

TensorDescType clone(::mlir::Type elementType) {
return llvm::cast<TensorDescType>(cloneWith(getShape(), elementType));
return cloneWith(getShape(), elementType);
}

BlockTensorDescAttr getEncodingAsBlockTensorDescAttr() const {
Expand Down Expand Up @@ -144,7 +146,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
return MemorySpace::Global;
}

int getArrayLength() {
int getArrayLength() const {
auto attr = getEncoding();
auto block_attr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(attr);
assert((!attr || block_attr) && "invalid on non BlockTensorDescAttr.");
Expand All @@ -154,7 +156,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
return 1;
}

bool getBoundaryCheck() {
bool getBoundaryCheck() const {
auto attr = getEncoding();
auto block_attr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(attr);
assert((!attr || block_attr) && "invalid on non BlockTensorDescAttr.");
Expand Down
68 changes: 64 additions & 4 deletions mlir/include/mlir/IR/BuiltinTypeInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -143,21 +143,21 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {

/// Return the number of elements present in the given shape.
static int64_t getNumElements(ArrayRef<int64_t> shape);
}];

let extraSharedClassDeclaration = [{
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is it necessary to split this into extraClassDeclaration and extraSharedClassDeclaration?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't remember anymore (this patch is actually ~year old at this point :) - i never had time to brush it up).
I think this has to be like this because of reliance on $_type (see https://mlir.llvm.org/docs/Interfaces/) to implement cloneWith (so ShapedType::clone() would behind the scenes call the derived type's methods).

/// Return a clone of this type with the given new shape and element type.
/// The returned type is ranked, even if this type is unranked.
auto clone(::llvm::ArrayRef<int64_t> shape, Type elementType) {
return cloneWith(shape, elementType);
return $_type.cloneWith(shape, elementType);
}

/// Return a clone of this type with the given new shape. The returned type
/// is ranked, even if this type is unranked.
auto clone(::llvm::ArrayRef<int64_t> shape) {
return cloneWith(shape, getElementType());
return $_type.cloneWith(shape, $_type.getElementType());
}
}];

let extraSharedClassDeclaration = [{
/// Return a clone of this type with the given new element type. The
/// returned type is ranked if and only if this type is ranked. In that
/// case, the returned type has the same shape as this type.
Expand Down Expand Up @@ -227,4 +227,64 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {
}];
}

//===----------------------------------------------------------------------===//
// TensorTypeInterface
//===----------------------------------------------------------------------===//

def TensorTypeInterface : TypeInterface<"TensorType", [ShapedTypeInterface]> {
let cppNamespace = "::mlir";
let description = [{
This interface provides a shared interface type for ranked, unranked and any
user-specified tensor types.

This interface attaches the ShapedTypeInterface to act as a mixin to
provide many useful utility functions.
}];

let extraClassDeclaration = [{
/// Return true if the specified element type is ok in a tensor.
static bool isValidElementType(::mlir::Type type);
}];

let extraClassOf = [{
return $_type.hasTrait<::mlir::TensorType::Trait>();
}];
}

//===----------------------------------------------------------------------===//
// BaseMemRefTypeInterface
//===----------------------------------------------------------------------===//

def BaseMemRefTypeInterface : TypeInterface<"BaseMemRefType", [ShapedTypeInterface]> {
let cppNamespace = "::mlir";
let description = [{
This interface provides a shared interface type for ranked, unranked and any
user-specified memref types.

This interface attaches the ShapedTypeInterface to act as a mixin to
provide many useful utility functions.
}];

let methods = [
InterfaceMethod<[{
Returns the memory space in which data referred to by this memref resides.
}],
"::mlir::Attribute", "getMemorySpace">,
InterfaceMethod<[{
[deprecated] Returns the memory space in old raw integer representation.
New `Attribute getMemorySpace()` method should be used instead.
}],
"unsigned", "getMemorySpaceAsInt">,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note to reviewers: this seems to still exist, however it has been marked deprecated for a while.

];

let extraClassDeclaration = [{
/// Return true if the specified element type is ok in a memref.
static bool isValidElementType(::mlir::Type type);
}];

let extraClassOf = [{
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not familiar with extraClassOf. What does it do?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

afaik, this is to generate {TensorType, BaseMemRefType}::classof(). i believe this enables mlir::isa<TensorType>(myTensor) machinery.

return $_type.hasTrait<::mlir::BaseMemRefType::Trait>();
}];
}

#endif // MLIR_IR_BUILTINTYPEINTERFACES_TD_
110 changes: 0 additions & 110 deletions mlir/include/mlir/IR/BuiltinTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,108 +43,6 @@ template <typename ConcreteType>
class ValueSemantics
: public TypeTrait::TraitBase<ConcreteType, ValueSemantics> {};

//===----------------------------------------------------------------------===//
// TensorType
//===----------------------------------------------------------------------===//

/// Tensor types represent multi-dimensional arrays, and have two variants:
/// RankedTensorType and UnrankedTensorType.
/// Note: This class attaches the ShapedType trait to act as a mixin to
/// provide many useful utility functions. This inheritance has no effect
/// on derived tensor types.
class TensorType : public Type, public ShapedType::Trait<TensorType> {
public:
using Type::Type;

/// Returns the element type of this tensor type.
Type getElementType() const;

/// Returns if this type is ranked, i.e. it has a known number of dimensions.
bool hasRank() const;

/// Returns the shape of this tensor type.
ArrayRef<int64_t> getShape() const;

/// Clone this type with the given shape and element type. If the
/// provided shape is `std::nullopt`, the current shape of the type is used.
TensorType cloneWith(std::optional<ArrayRef<int64_t>> shape,
Type elementType) const;

// Make sure that base class overloads are visible.
using ShapedType::Trait<TensorType>::clone;

/// Return a clone of this type with the given new shape and element type.
/// The returned type is ranked, even if this type is unranked.
RankedTensorType clone(ArrayRef<int64_t> shape, Type elementType) const;

/// Return a clone of this type with the given new shape. The returned type
/// is ranked, even if this type is unranked.
RankedTensorType clone(ArrayRef<int64_t> shape) const;

/// Return true if the specified element type is ok in a tensor.
static bool isValidElementType(Type type);

/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(Type type);

/// Allow implicit conversion to ShapedType.
operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
};

//===----------------------------------------------------------------------===//
// BaseMemRefType
//===----------------------------------------------------------------------===//

/// This class provides a shared interface for ranked and unranked memref types.
/// Note: This class attaches the ShapedType trait to act as a mixin to
/// provide many useful utility functions. This inheritance has no effect
/// on derived memref types.
class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
public:
using Type::Type;

/// Returns the element type of this memref type.
Type getElementType() const;

/// Returns if this type is ranked, i.e. it has a known number of dimensions.
bool hasRank() const;

/// Returns the shape of this memref type.
ArrayRef<int64_t> getShape() const;

/// Clone this type with the given shape and element type. If the
/// provided shape is `std::nullopt`, the current shape of the type is used.
BaseMemRefType cloneWith(std::optional<ArrayRef<int64_t>> shape,
Type elementType) const;

// Make sure that base class overloads are visible.
using ShapedType::Trait<BaseMemRefType>::clone;

/// Return a clone of this type with the given new shape and element type.
/// The returned type is ranked, even if this type is unranked.
MemRefType clone(ArrayRef<int64_t> shape, Type elementType) const;

/// Return a clone of this type with the given new shape. The returned type
/// is ranked, even if this type is unranked.
MemRefType clone(ArrayRef<int64_t> shape) const;

/// Return true if the specified element type is ok in a memref.
static bool isValidElementType(Type type);

/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(Type type);

/// Returns the memory space in which data referred to by this memref resides.
Attribute getMemorySpace() const;

/// [deprecated] Returns the memory space in old raw integer representation.
/// New `Attribute getMemorySpace()` method should be used instead.
unsigned getMemorySpaceAsInt() const;

/// Allow implicit conversion to ShapedType.
operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
};

} // namespace mlir

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -390,21 +288,13 @@ class FixedVectorType : public VectorType {
// Deferred Method Definitions
//===----------------------------------------------------------------------===//

inline bool BaseMemRefType::classof(Type type) {
return llvm::isa<MemRefType, UnrankedMemRefType>(type);
}

inline bool BaseMemRefType::isValidElementType(Type type) {
return type.isIntOrIndexOrFloat() ||
llvm::isa<ComplexType, MemRefType, VectorType, UnrankedMemRefType>(
type) ||
llvm::isa<MemRefElementTypeInterface>(type);
}

inline bool TensorType::classof(Type type) {
return llvm::isa<RankedTensorType, UnrankedTensorType>(type);
}

//===----------------------------------------------------------------------===//
// Type Utilities
//===----------------------------------------------------------------------===//
Expand Down
Loading