Skip to content

Commit 00eaff3

Browse files
[mlir][bufferization] Add tensor-like and buffer-like interfaces (#134220)
Current one-shot bufferization infrastructure operates on top of TensorType and BaseMemRefType. These are non-extensible base classes of the respective builtins: tensor and memref. Thus, the infrastructure is bound to work only with builtin tensor/memref types. At the same time, there are customization points that allow one to provide custom logic to control the bufferization behavior. This patch introduces new type interfaces: tensor-like and buffer-like that aim to supersede TensorType/BaseMemRefType within the bufferization dialect and allow custom tensors / memrefs to be used. Additionally, these new type interfaces are attached to the respective builtin types so that the switch is seamless. Note that this patch does very minimal initial work, it does NOT refactor bufferization infrastructure. See https://discourse.llvm.org/t/rfc-changing-base-types-for-tensors-and-memrefs-from-c-base-classes-to-type-interfaces/85509
1 parent 96e3876 commit 00eaff3

File tree

13 files changed

+290
-5
lines changed

13 files changed

+290
-5
lines changed
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
//===- BufferizationTypeInterfaces.h - Type Interfaces ----------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONTYPEINTERFACES_H_
10+
#define MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONTYPEINTERFACES_H_
11+
12+
//===----------------------------------------------------------------------===//
13+
// Bufferization Type Interfaces
14+
//===----------------------------------------------------------------------===//
15+
16+
#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h.inc"
17+
18+
#endif // MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONTYPEINTERFACES_H_
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
//===- BufferizationTypeInterfaces.td - Type Interfaces ----*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This is the definition file for type interfaces used in Bufferization.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef BUFFERIZATION_TYPE_INTERFACES
14+
#define BUFFERIZATION_TYPE_INTERFACES
15+
16+
include "mlir/IR/OpBase.td"
17+
18+
def Bufferization_TensorLikeTypeInterface
19+
: TypeInterface<"TensorLikeType"> {
20+
let cppNamespace = "::mlir::bufferization";
21+
let description = [{
22+
Indicates that this type is a tensor type (similarly to a MLIR builtin
23+
tensor) for bufferization purposes.
24+
25+
The interface currently has no methods as it is used by types to opt into
26+
being supported by the bufferization procedures.
27+
}];
28+
}
29+
30+
def Bufferization_BufferLikeTypeInterface
31+
: TypeInterface<"BufferLikeType"> {
32+
let cppNamespace = "::mlir::bufferization";
33+
let description = [{
34+
Indicates that this type is a buffer type (similarly to a MLIR builtin
35+
memref) for bufferization purposes.
36+
37+
The interface currently has no methods as it is used by types to opt into
38+
being supported by the bufferization procedures.
39+
}];
40+
}
41+
42+
#endif // BUFFERIZATION_TYPE_INTERFACES

mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,9 @@ mlir_tablegen(BufferizationEnums.h.inc -gen-enum-decls)
1010
mlir_tablegen(BufferizationEnums.cpp.inc -gen-enum-defs)
1111
add_public_tablegen_target(MLIRBufferizationEnumsIncGen)
1212
add_dependencies(mlir-headers MLIRBufferizationEnumsIncGen)
13+
14+
set(LLVM_TARGET_DEFINITIONS BufferizationTypeInterfaces.td)
15+
mlir_tablegen(BufferizationTypeInterfaces.h.inc -gen-type-interface-decls)
16+
mlir_tablegen(BufferizationTypeInterfaces.cpp.inc -gen-type-interface-defs)
17+
add_public_tablegen_target(MLIRBufferizationTypeInterfacesIncGen)
18+
add_dependencies(mlir-headers MLIRBufferizationTypeInterfacesIncGen)

mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -471,6 +471,10 @@ def OneShotBufferizePass : Pass<"one-shot-bufferize", "ModuleOp"> {
471471
Statistic<"numTensorOutOfPlace", "num-tensor-out-of-place",
472472
"Number of out-of-place tensor OpOperands">,
473473
];
474+
475+
let dependentDialects = [
476+
"bufferization::BufferizationDialect", "memref::MemRefDialect"
477+
];
474478
}
475479

476480
def PromoteBuffersToStackPass

mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,10 @@
99
#include "mlir/Dialect/Affine/IR/AffineOps.h"
1010
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
1111
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
12+
#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"
1213
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1314
#include "mlir/Dialect/Tensor/IR/Tensor.h"
15+
#include "mlir/IR/BuiltinTypes.h"
1416
#include "mlir/Interfaces/FunctionInterfaces.h"
1517
#include "mlir/Transforms/InliningUtils.h"
1618

@@ -51,6 +53,16 @@ struct BufferizationInlinerInterface : public DialectInlinerInterface {
5153
return true;
5254
}
5355
};
56+
57+
template <typename Tensor>
58+
struct BuiltinTensorExternalModel
59+
: TensorLikeType::ExternalModel<BuiltinTensorExternalModel<Tensor>,
60+
Tensor> {};
61+
62+
template <typename MemRef>
63+
struct BuiltinMemRefExternalModel
64+
: BufferLikeType::ExternalModel<BuiltinMemRefExternalModel<MemRef>,
65+
MemRef> {};
5466
} // namespace
5567

5668
//===----------------------------------------------------------------------===//
@@ -63,6 +75,20 @@ void mlir::bufferization::BufferizationDialect::initialize() {
6375
#include "mlir/Dialect/Bufferization/IR/BufferizationOps.cpp.inc"
6476
>();
6577
addInterfaces<BufferizationInlinerInterface>();
78+
79+
// Note: Unlike with other external models, declaring bufferization's
80+
// "promised interfaces" in builtins for TensorLike and BufferLike type
81+
// interfaces is not possible (due to builtins being independent of
82+
// bufferization). Thus, the compromise is to attach these interfaces directly
83+
// during dialect initialization.
84+
RankedTensorType::attachInterface<
85+
BuiltinTensorExternalModel<RankedTensorType>>(*getContext());
86+
UnrankedTensorType::attachInterface<
87+
BuiltinTensorExternalModel<UnrankedTensorType>>(*getContext());
88+
MemRefType::attachInterface<BuiltinMemRefExternalModel<MemRefType>>(
89+
*getContext());
90+
UnrankedMemRefType::attachInterface<
91+
BuiltinMemRefExternalModel<UnrankedMemRefType>>(*getContext());
6692
}
6793

6894
LogicalResult BufferizationDialect::verifyRegionArgAttribute(

mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -57,11 +57,6 @@ struct OneShotBufferizePass
5757
OneShotBufferizePass> {
5858
using Base::Base;
5959

60-
void getDependentDialects(DialectRegistry &registry) const override {
61-
registry
62-
.insert<bufferization::BufferizationDialect, memref::MemRefDialect>();
63-
}
64-
6560
void runOnOperation() override {
6661
OneShotBufferizationOptions opt;
6762
if (!options) {
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
// RUN: mlir-opt %s -test-tensorlike-bufferlike -split-input-file | FileCheck %s
2+
3+
// CHECK: func.func @builtin_unranked
4+
// CHECK-SAME: {found = {operand_0 = "is_tensor_like", result_0 = "is_buffer_like"}}
5+
func.func @builtin_unranked(%t: tensor<*xf32>) -> (memref<*xf32>)
6+
{
7+
%0 = bufferization.to_memref %t : tensor<*xf32> to memref<*xf32>
8+
return %0 : memref<*xf32>
9+
}
10+
11+
// -----
12+
13+
// CHECK: func.func @builtin_ranked
14+
// CHECK-SAME: {found = {operand_0 = "is_tensor_like", result_0 = "is_buffer_like"}}
15+
func.func @builtin_ranked(%t: tensor<42xf32>) -> (memref<42xf32>)
16+
{
17+
%0 = bufferization.to_memref %t : tensor<42xf32> to memref<42xf32>
18+
return %0 : memref<42xf32>
19+
}
20+
21+
// -----
22+
23+
// CHECK: func.func @custom_tensor
24+
// CHECK-SAME: {found = {operand_0 = "is_tensor_like"}}
25+
func.func @custom_tensor(%t: !test.test_tensor<[42], f32>) -> ()
26+
{
27+
return
28+
}
29+
30+
// -----
31+
32+
// CHECK: func.func @custom_memref
33+
// CHECK-SAME: {found = {operand_0 = "is_buffer_like"}}
34+
func.func @custom_memref(%t: !test.test_memref<[42], f32>) -> ()
35+
{
36+
return
37+
}

mlir/test/lib/Dialect/Bufferization/CMakeLists.txt

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Exclude tests from libMLIR.so
22
add_mlir_library(MLIRBufferizationTestPasses
33
TestTensorCopyInsertion.cpp
4+
TestTensorLikeAndBufferLike.cpp
45

56
EXCLUDE_FROM_LIBMLIR
67
)
@@ -9,4 +10,11 @@ mlir_target_link_libraries(MLIRBufferizationTestPasses PUBLIC
910
MLIRBufferizationTransforms
1011
MLIRIR
1112
MLIRPass
13+
MLIRTestDialect
1214
)
15+
16+
target_include_directories(MLIRBufferizationTestPasses
17+
PRIVATE
18+
${CMAKE_CURRENT_SOURCE_DIR}/../../Dialect/Test
19+
${CMAKE_CURRENT_BINARY_DIR}/../../Dialect/Test
20+
)
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
//===- TestTensorLikeAndBufferLike.cpp - Bufferization Test -----*- c++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "TestDialect.h"
10+
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
11+
#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"
12+
#include "mlir/Dialect/Func/IR/FuncOps.h"
13+
#include "mlir/IR/Attributes.h"
14+
#include "mlir/IR/BuiltinAttributes.h"
15+
#include "mlir/Pass/Pass.h"
16+
17+
#include <string>
18+
19+
using namespace mlir;
20+
21+
namespace {
22+
std::string getImplementationStatus(Type type) {
23+
if (isa<bufferization::TensorLikeType>(type)) {
24+
return "is_tensor_like";
25+
}
26+
if (isa<bufferization::BufferLikeType>(type)) {
27+
return "is_buffer_like";
28+
}
29+
return {};
30+
}
31+
32+
DictionaryAttr findAllImplementeesOfTensorOrBufferLike(func::FuncOp funcOp) {
33+
llvm::SmallVector<NamedAttribute> attributes;
34+
35+
const auto funcType = funcOp.getFunctionType();
36+
for (auto [index, inputType] : llvm::enumerate(funcType.getInputs())) {
37+
const auto status = getImplementationStatus(inputType);
38+
if (status.empty()) {
39+
continue;
40+
}
41+
42+
attributes.push_back(
43+
NamedAttribute(StringAttr::get(funcOp.getContext(),
44+
"operand_" + std::to_string(index)),
45+
StringAttr::get(funcOp.getContext(), status)));
46+
}
47+
48+
for (auto [index, resultType] : llvm::enumerate(funcType.getResults())) {
49+
const auto status = getImplementationStatus(resultType);
50+
if (status.empty()) {
51+
continue;
52+
}
53+
54+
attributes.push_back(NamedAttribute(
55+
StringAttr::get(funcOp.getContext(), "result_" + std::to_string(index)),
56+
StringAttr::get(funcOp.getContext(), status)));
57+
}
58+
59+
return mlir::DictionaryAttr::get(funcOp.getContext(), attributes);
60+
}
61+
62+
/// This pass tests whether specified types implement TensorLike and (or)
63+
/// BufferLike type interfaces defined in bufferization.
64+
///
65+
/// The pass analyses operation signature. When the aforementioned interface
66+
/// implementation found, an attribute is added to the operation, signifying the
67+
/// associated operand / result.
68+
struct TestTensorLikeAndBufferLikePass
69+
: public PassWrapper<TestTensorLikeAndBufferLikePass,
70+
OperationPass<ModuleOp>> {
71+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTensorLikeAndBufferLikePass)
72+
73+
void getDependentDialects(DialectRegistry &registry) const override {
74+
registry.insert<bufferization::BufferizationDialect, test::TestDialect>();
75+
}
76+
StringRef getArgument() const final { return "test-tensorlike-bufferlike"; }
77+
StringRef getDescription() const final {
78+
return "Module pass to test custom types that implement TensorLike / "
79+
"BufferLike interfaces";
80+
}
81+
82+
void runOnOperation() override {
83+
auto op = getOperation();
84+
85+
op.walk([](func::FuncOp funcOp) {
86+
const auto dict = findAllImplementeesOfTensorOrBufferLike(funcOp);
87+
if (!dict.empty()) {
88+
funcOp->setAttr("found", dict);
89+
}
90+
});
91+
}
92+
};
93+
} // namespace
94+
95+
namespace mlir::test {
96+
void registerTestTensorLikeAndBufferLikePass() {
97+
PassRegistration<TestTensorLikeAndBufferLikePass>();
98+
}
99+
} // namespace mlir::test

mlir/test/lib/Dialect/Test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ mlir_target_link_libraries(MLIRTestDialect PUBLIC
9393
MLIRTransformUtils
9494
MLIRTransforms
9595
MLIRValueBoundsOpInterface
96+
MLIRBufferizationDialect
9697
)
9798

9899
add_mlir_translation_library(MLIRTestFromLLVMIRTranslation

mlir/test/lib/Dialect/Test/TestTypeDefs.td

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ include "TestAttrDefs.td"
1919
include "TestInterfaces.td"
2020
include "mlir/IR/BuiltinTypes.td"
2121
include "mlir/Interfaces/DataLayoutInterfaces.td"
22+
include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td"
2223

2324
// All of the types will extend this class.
2425
class Test_Type<string name, list<Trait> traits = []>
@@ -403,4 +404,49 @@ def TestTypeOpAsmTypeInterface : Test_Type<"TestTypeOpAsmTypeInterface",
403404
let mnemonic = "op_asm_type_interface";
404405
}
405406

407+
def TestTensorType : Test_Type<"TestTensor",
408+
[Bufferization_TensorLikeTypeInterface, ShapedTypeInterface]> {
409+
let mnemonic = "test_tensor";
410+
let parameters = (ins
411+
ArrayRefParameter<"int64_t">:$shape,
412+
"mlir::Type":$elementType
413+
);
414+
let assemblyFormat = "`<` `[` $shape `]` `,` $elementType `>`";
415+
416+
let extraClassDeclaration = [{
417+
// ShapedTypeInterface:
418+
bool hasRank() const {
419+
return true;
420+
}
421+
test::TestTensorType cloneWith(std::optional<llvm::ArrayRef<int64_t>> shape,
422+
mlir::Type elementType) const {
423+
return test::TestTensorType::get(
424+
getContext(), shape.value_or(getShape()), elementType);
425+
}
426+
}];
427+
}
428+
429+
def TestMemrefType : Test_Type<"TestMemref",
430+
[Bufferization_BufferLikeTypeInterface, ShapedTypeInterface]> {
431+
let mnemonic = "test_memref";
432+
let parameters = (ins
433+
ArrayRefParameter<"int64_t">:$shape,
434+
"mlir::Type":$elementType,
435+
DefaultValuedParameter<"mlir::Attribute", "nullptr">:$memSpace
436+
);
437+
let assemblyFormat = "`<` `[` $shape `]` `,` $elementType (`,` $memSpace^)? `>`";
438+
439+
let extraClassDeclaration = [{
440+
// ShapedTypeInterface:
441+
bool hasRank() const {
442+
return true;
443+
}
444+
test::TestMemrefType cloneWith(std::optional<llvm::ArrayRef<int64_t>> shape,
445+
mlir::Type elementType) const {
446+
return test::TestMemrefType::get(
447+
getContext(), shape.value_or(getShape()), elementType, getMemSpace());
448+
}
449+
}];
450+
}
451+
406452
#endif // TEST_TYPEDEFS

mlir/test/lib/Dialect/Test/TestTypes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include <tuple>
1919

2020
#include "TestTraits.h"
21+
#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"
2122
#include "mlir/IR/Diagnostics.h"
2223
#include "mlir/IR/Dialect.h"
2324
#include "mlir/IR/DialectImplementation.h"

mlir/tools/mlir-opt/mlir-opt.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ void registerTestSPIRVCPURunnerPipeline();
150150
void registerTestSPIRVFuncSignatureConversion();
151151
void registerTestSPIRVVectorUnrolling();
152152
void registerTestTensorCopyInsertionPass();
153+
void registerTestTensorLikeAndBufferLikePass();
153154
void registerTestTensorTransforms();
154155
void registerTestTopologicalSortAnalysisPass();
155156
void registerTestTransformDialectEraseSchedulePass();
@@ -293,6 +294,7 @@ void registerTestPasses() {
293294
mlir::test::registerTestSPIRVFuncSignatureConversion();
294295
mlir::test::registerTestSPIRVVectorUnrolling();
295296
mlir::test::registerTestTensorCopyInsertionPass();
297+
mlir::test::registerTestTensorLikeAndBufferLikePass();
296298
mlir::test::registerTestTensorTransforms();
297299
mlir::test::registerTestTopologicalSortAnalysisPass();
298300
mlir::test::registerTestTransformDialectEraseSchedulePass();

0 commit comments

Comments
 (0)