
[mlir][bufferization] Add tensor-like and buffer-like interfaces #134220


Merged 4 commits, Apr 15, 2025
Changes from 2 commits
20 changes: 20 additions & 0 deletions mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
@@ -0,0 +1,20 @@
//===- BufferizationTypeInterfaces.h - Type Interfaces ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONTYPEINTERFACES_H_
#define MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONTYPEINTERFACES_H_

#include "mlir/IR/BuiltinTypeInterfaces.h" // for ShapedTypeInterface

//===----------------------------------------------------------------------===//
// Bufferization Type Interfaces
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h.inc"

#endif // MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONTYPEINTERFACES_H_
49 changes: 49 additions & 0 deletions mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td
@@ -0,0 +1,49 @@
//===- BufferizationTypeInterfaces.td - Type Interfaces ----*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This is the definition file for type interfaces used in Bufferization.
//
//===----------------------------------------------------------------------===//

#ifndef BUFFERIZATION_TYPE_INTERFACES
#define BUFFERIZATION_TYPE_INTERFACES

include "mlir/IR/OpBase.td"
include "mlir/IR/BuiltinTypeInterfaces.td"

def Bufferization_TensorLikeTypeInterface
: TypeInterface<"TensorLikeType", [ShapedTypeInterface]> {
Contributor Author:

note: this attaches ShapedTypeInterface but TensorType (base class of ranked tensor) also attaches ShapedTypeInterface. is there any risk that we can run into trouble due to:

ranked tensor type -> TensorType -> **ShapedTypeInterface**
                  |-> TensorLikeTypeInterface -> **ShapedTypeInterface**

Member:

Does it have to inherit the ShapedTypeInterface? If not, let's keep this simpler.

Contributor Author (andrey-golubev), Apr 4, 2025:

looking at the code, TensorLike doesn't need it (or at least I haven't seen a use). MemRefLike does need it: there are .hasRank() and .getShape() API usages (albeit the ones I've seen are in asserts).

(also, there are memref.getMemorySpace() usages - this is an "API" of BaseMemRefType, but I guess I could introduce it later once the switch to these new type interfaces happens.)

overall, my motivation is to have the ShapedTypeInterface APIs available to avoid the boilerplate of cast<ShapedTypeInterface>(tensorLike).getBlah(). however, since TensorLike doesn't seem to need it, maybe I can drop it at least there? (but then "tensor like" would not be a shaped type while "memref like" would be, which is kind of odd given that a memref is "created" from a tensor).

it feels like - in MLIR - tensor and memref are both shaped types "by design", but as we don't have generic type interfaces for those, this design has to kind of leak into bufferization.
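
For illustration, the kind of boilerplate being referred to (hypothetical caller code, not from this patch; assumes the caller holds a bufferization::TensorLikeType value named tensorLike):

// Without inheriting ShapedTypeInterface, shape queries need an explicit
// cast first:
auto shaped = mlir::cast<mlir::ShapedType>(tensorLike);
if (shaped.hasRank())
  llvm::ArrayRef<int64_t> shape = shaped.getShape();
// ...whereas with the inheritance it would simply be tensorLike.getShape().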

Member:

Which functions require these? The main places to look at are in BufferizableOpInterface.cpp. There are functions such as bufferization::getMemRefType. These will probably have to become interface methods on TensorTypeInterface. Apart from that, I don't think the bufferization driver itself really needs anything from the shaped type interface.

So I would recommend going without ShapedTypeInterface for now. Are there any places where you'd have to insert explicit casts because of this today?
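
For a rough idea only (not part of this patch; the name and signature are purely illustrative), such an interface method might eventually look something like:

// Hypothetical future method on the tensor-like interface, mirroring
// bufferization::getMemRefType(); illustrative sketch only.
mlir::FailureOr<bufferization::MemRefLikeType>
getBufferType(const bufferization::BufferizationOptions &options) const;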

Contributor Author:

I think I only saw the shaped type APIs around here (note NDEBUG):

#ifndef NDEBUG
  if (auto rankedTensorType = dyn_cast<RankedTensorType>(tensorType)) {
    assert(bufferType.hasRank() && callerType.hasRank() &&
           "expected ranked memrefs");
    assert(llvm::all_equal({bufferType.getShape(), callerType.getShape(),
                            rankedTensorType.getShape()}) &&
           "expected same shape");
  } else {
    assert(!bufferType.hasRank() && !callerType.hasRank() &&
           "expected unranked memrefs");
  }
#endif // NDEBUG

> Are there any places where you'd have to insert explicit casts because of this today?

I don't think I can see such places (well, maybe in the user code but then those places are likely to cast down to the actual type so kind of not an issue).

Member:

We're going to need a function such as MemRefTypeInterface::toTensorType(). It should be possible to reimplement these asserts based on that function. We then use operator== instead of checking shape, rank, etc.
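
A minimal sketch of what that rewrite could look like, assuming a hypothetical toTensorType() method on the memref-like interface (nothing here is part of this patch):

#ifndef NDEBUG
  // Hypothetical: compare whole types via the tensor they correspond to,
  // instead of checking rank and shape piecewise.
  assert(bufferType.toTensorType() == callerType.toTensorType() &&
         "expected caller and buffer types to correspond to the same tensor");
  assert(bufferType.toTensorType() == tensorType &&
         "expected buffer type to correspond to the original tensor type");
#endif // NDEBUG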

Contributor Author:


ack. then let's proceed without enforcing shaped type interface. worst case, can always be added later once the bulk of the code is migrated and new cases are discovered "lazily".

Contributor Author:

removed ShapedTypeInterface propagation.

let cppNamespace = "::mlir::bufferization";
let description = [{
Indicates that the type that attaches this interface can be treated as a
Member:

nit: Could be shortened to: Indicates that this type is a tensor type for bufferization purposes. Same for memref.

Contributor Author:

applied. albeit i left the piece about "similarly to a MLIR builtin X" to have a clearer reference to builtin tensor / memref.

tensor type (similarly to a MLIR builtin tensor) during bufferization.

Implementing this interface means that the type also implements
ShapedTypeInterface.

The interface currently has no methods as it is used by types to opt into
being supported by the bufferization procedures.
}];
}

def Bufferization_MemRefLikeTypeInterface
Member:

Can I ask for one more change? The term used in bufferization is "buffer", not memref. Can we change this to BufferLikeTypeInterface? E.g., we have getBufferType in BufferizableOpInterface.

Contributor Author:

I have thought of this as well but wasn't sure it's better than memref. If you prefer "buffer", sure.

Contributor Author:

done!

: TypeInterface<"MemRefLikeType", [ShapedTypeInterface]> {
let cppNamespace = "::mlir::bufferization";
let description = [{
Indicates that the type that attaches this interface can be treated as a
memref type (similarly to a MLIR builtin memref) during bufferization.

Implementing this interface means that the type also implements
ShapedTypeInterface.

The interface currently has no methods as it is used by types to opt into
being supported by the bufferization procedures.
}];
}

#endif // BUFFERIZATION_TYPE_INTERFACES
6 changes: 6 additions & 0 deletions mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt
@@ -10,3 +10,9 @@ mlir_tablegen(BufferizationEnums.h.inc -gen-enum-decls)
mlir_tablegen(BufferizationEnums.cpp.inc -gen-enum-defs)
add_public_tablegen_target(MLIRBufferizationEnumsIncGen)
add_dependencies(mlir-headers MLIRBufferizationEnumsIncGen)

set(LLVM_TARGET_DEFINITIONS BufferizationTypeInterfaces.td)
mlir_tablegen(BufferizationTypeInterfaces.h.inc -gen-type-interface-decls)
mlir_tablegen(BufferizationTypeInterfaces.cpp.inc -gen-type-interface-defs)
add_public_tablegen_target(MLIRBufferizationTypeInterfacesIncGen)
add_dependencies(mlir-headers MLIRBufferizationTypeInterfacesIncGen)
21 changes: 21 additions & 0 deletions mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
@@ -9,8 +9,10 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Transforms/InliningUtils.h"

@@ -51,6 +53,16 @@ struct BufferizationInlinerInterface : public DialectInlinerInterface {
return true;
}
};

template <typename Tensor>
struct BuiltinTensorExternalModel
: TensorLikeType::ExternalModel<BuiltinTensorExternalModel<Tensor>,
Tensor> {};

template <typename MemRef>
struct BuiltinMemRefExternalModel
: MemRefLikeType::ExternalModel<BuiltinMemRefExternalModel<MemRef>,
MemRef> {};
} // namespace

//===----------------------------------------------------------------------===//
@@ -63,6 +75,15 @@ void mlir::bufferization::BufferizationDialect::initialize() {
#include "mlir/Dialect/Bufferization/IR/BufferizationOps.cpp.inc"
>();
addInterfaces<BufferizationInlinerInterface>();

RankedTensorType::attachInterface<
BuiltinTensorExternalModel<RankedTensorType>>(*getContext());
UnrankedTensorType::attachInterface<
BuiltinTensorExternalModel<UnrankedTensorType>>(*getContext());
MemRefType::attachInterface<BuiltinMemRefExternalModel<MemRefType>>(
*getContext());
UnrankedMemRefType::attachInterface<
BuiltinMemRefExternalModel<UnrankedMemRefType>>(*getContext());
}

LogicalResult BufferizationDialect::verifyRegionArgAttribute(
1 change: 1 addition & 0 deletions mlir/test/lib/Dialect/Test/CMakeLists.txt
@@ -93,6 +93,7 @@ mlir_target_link_libraries(MLIRTestDialect PUBLIC
MLIRTransformUtils
MLIRTransforms
MLIRValueBoundsOpInterface
MLIRBufferizationDialect
)

add_mlir_translation_library(MLIRTestFromLLVMIRTranslation
44 changes: 44 additions & 0 deletions mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -19,6 +19,7 @@ include "TestAttrDefs.td"
include "TestInterfaces.td"
include "mlir/IR/BuiltinTypes.td"
include "mlir/Interfaces/DataLayoutInterfaces.td"
include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td"

// All of the types will extend this class.
class Test_Type<string name, list<Trait> traits = []>
@@ -403,4 +404,47 @@ def TestTypeOpAsmTypeInterface : Test_Type<"TestTypeOpAsmTypeInterface",
let mnemonic = "op_asm_type_interface";
}

def TestTensorType : Test_Type<"TestTensor", [Bufferization_TensorLikeTypeInterface]> {
let mnemonic = "test_tensor";
let parameters = (ins
ArrayRefParameter<"int64_t">:$shape,
"mlir::Type":$elementType
);
let assemblyFormat = "`<` `[` $shape `]` `,` $elementType `>`";

let extraClassDeclaration = [{
// ShapedTypeInterface:
bool hasRank() const {
return true;
}
test::TestTensorType cloneWith(std::optional<llvm::ArrayRef<int64_t>> shape,
mlir::Type elementType) const {
return test::TestTensorType::get(
getContext(), shape.value_or(getShape()), elementType);
}
}];
}

def TestMemrefType : Test_Type<"TestMemref", [Bufferization_MemRefLikeTypeInterface]> {
let mnemonic = "test_memref";
let parameters = (ins
ArrayRefParameter<"int64_t">:$shape,
"mlir::Type":$elementType,
DefaultValuedParameter<"mlir::Attribute", "nullptr">:$memSpace
);
let assemblyFormat = "`<` `[` $shape `]` `,` $elementType (`,` $memSpace^)? `>`";

let extraClassDeclaration = [{
// ShapedTypeInterface:
bool hasRank() const {
return true;
}
test::TestMemrefType cloneWith(std::optional<llvm::ArrayRef<int64_t>> shape,
mlir::Type elementType) const {
return test::TestMemrefType::get(
getContext(), shape.value_or(getShape()), elementType, getMemSpace());
}
}];
}

#endif // TEST_TYPEDEFS
1 change: 1 addition & 0 deletions mlir/test/lib/Dialect/Test/TestTypes.h
@@ -18,6 +18,7 @@
#include <tuple>

#include "TestTraits.h"
#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectImplementation.h"
77 changes: 77 additions & 0 deletions mlir/unittests/IR/InterfaceTest.cpp
@@ -6,6 +6,8 @@
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"
Contributor Author:


note: i haven't seen bufferization-specific unit tests, so ended up adding all tests here. these seem to test "TestDialect".

Member:

As suggested by the path, this tests libIR capabilities. We rarely use this style of unit tests. The way to test this would be to write a test-only pass that, e.g., looks at function signatures and adds an attribute to the function indicating whether the types of the function results implement a certain interface. It can then live in mlir/test/lib/Dialect/Bufferization. There is already a precedent.
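
For context, a minimal sketch of such a test-only pass (pass name, attribute name, and overall shape are illustrative assumptions, not the actual pass that was later added under mlir/test/lib/Dialect/Bufferization):

#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;

namespace {
// Walks every func.func and attaches an array attribute recording, per result
// type, whether it implements TensorLikeType, MemRefLikeType, or neither.
struct TestTensorLikeAndMemRefLikePass
    : PassWrapper<TestTensorLikeAndMemRefLikePass, OperationPass<ModuleOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTensorLikeAndMemRefLikePass)

  StringRef getArgument() const final { return "test-tensorlike-memreflike"; }

  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    getOperation().walk([&](func::FuncOp funcOp) {
      SmallVector<Attribute> kinds;
      for (Type resultType : funcOp.getResultTypes()) {
        StringRef kind = "neither";
        if (isa<bufferization::TensorLikeType>(resultType))
          kind = "tensor_like";
        else if (isa<bufferization::MemRefLikeType>(resultType))
          kind = "memref_like";
        kinds.push_back(StringAttr::get(ctx, kind));
      }
      funcOp->setAttr("test.type_kinds", ArrayAttr::get(ctx, kinds));
    });
  }
};
} // namespace

A FileCheck test would then run this through mlir-opt and verify the attached attribute.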

Contributor Author (andrey-golubev), Apr 4, 2025:

thanks for the pointer! i'll add it a bit later (seems to be quite an endeavour).

Contributor Author:

this is now re-done according to the suggestion. thanks!

#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinOps.h"
@@ -84,3 +86,78 @@ TEST(InterfaceTest, TestImplicitConversion) {
typeA = typeB;
EXPECT_EQ(typeA, typeB);
}

TEST(InterfaceTest, TestBuiltinTensorIsTensorLikeType) {
MLIRContext context;
// Note: attaches external model to builtins
context.loadDialect<bufferization::BufferizationDialect>();

auto builtinRankedTensor = mlir::RankedTensorType::get(
{1, 2, 3}, mlir::IntegerType::get(&context, 32));
EXPECT_TRUE(mlir::isa<bufferization::TensorLikeType>(builtinRankedTensor));
EXPECT_FALSE(mlir::isa<bufferization::MemRefLikeType>(builtinRankedTensor));

auto builtinUnrankedTensor =
mlir::UnrankedTensorType::get(mlir::IntegerType::get(&context, 32));
EXPECT_TRUE(mlir::isa<bufferization::TensorLikeType>(builtinUnrankedTensor));
EXPECT_FALSE(mlir::isa<bufferization::MemRefLikeType>(builtinUnrankedTensor));
}

TEST(InterfaceTest, TestCustomTensorIsTensorLikeType) {
MLIRContext context;
context.loadDialect<test::TestDialect>();

auto customTensorType = test::TestTensorType::get(
&context, {1, 2, 3}, mlir::IntegerType::get(&context, 32));
EXPECT_TRUE(mlir::isa<bufferization::TensorLikeType>(customTensorType));

auto customCloneType = customTensorType.cloneWith(
ArrayRef<int64_t>{3, 4, 5}, customTensorType.getElementType());
EXPECT_EQ(customTensorType.getElementType(),
customCloneType.getElementType());
EXPECT_TRUE(mlir::isa<bufferization::TensorLikeType>(customCloneType));
EXPECT_TRUE(mlir::isa<test::TestTensorType>(customCloneType));

// user-specified conversions
bufferization::TensorLikeType baseCopy = customTensorType;
std::ignore = baseCopy;
}

TEST(InterfaceTest, TestBuiltinMemrefIsMemRefLikeType) {
MLIRContext context;
// Note: attaches external model to builtins
context.loadDialect<bufferization::BufferizationDialect>();

auto builtinRankedMemref =
mlir::MemRefType::get({1, 2, 3}, mlir::IntegerType::get(&context, 32));
EXPECT_TRUE(mlir::isa<bufferization::MemRefLikeType>(builtinRankedMemref));
EXPECT_FALSE(mlir::isa<bufferization::TensorLikeType>(builtinRankedMemref));

auto builtinUnrankedMemref = mlir::UnrankedMemRefType::get(
mlir::IntegerType::get(&context, 32), nullptr);
EXPECT_TRUE(mlir::isa<bufferization::MemRefLikeType>(builtinUnrankedMemref));
EXPECT_FALSE(mlir::isa<bufferization::TensorLikeType>(builtinUnrankedMemref));
}

TEST(InterfaceTest, TestCustomMemrefIsMemRefLikeType) {
MLIRContext context;
context.loadDialect<test::TestDialect>();

auto customMemrefType = test::TestMemrefType::get(
&context, {1, 2, 3}, mlir::IntegerType::get(&context, 32),
mlir::StringAttr::get(&context, "some_memspace"));
EXPECT_TRUE(mlir::isa<bufferization::MemRefLikeType>(customMemrefType));

auto customCloneType = customMemrefType.cloneWith(
ArrayRef<int64_t>{3, 4, 5}, customMemrefType.getElementType());
EXPECT_EQ(customMemrefType.getElementType(),
customCloneType.getElementType());
EXPECT_TRUE(mlir::isa<bufferization::MemRefLikeType>(customCloneType));
EXPECT_TRUE(mlir::isa<test::TestMemrefType>(customCloneType));
EXPECT_EQ(customMemrefType.getMemSpace(),
mlir::cast<test::TestMemrefType>(customCloneType).getMemSpace());

// user-specified conversions
bufferization::MemRefLikeType baseCopy = customMemrefType;
std::ignore = baseCopy;
}