Skip to content

[mlir][bufferization] Use TensorLike, BufferLike type interfaces #136736

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <optional>

#include "mlir/Dialect/Bufferization/IR/BufferizationEnums.h.inc"
#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"

namespace mlir {
class OpBuilder;
Expand Down Expand Up @@ -259,18 +260,18 @@ struct BufferizationOptions {
std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;
/// Initializer function for analysis state.
using AnalysisStateInitFn = std::function<void(AnalysisState &)>;
/// Tensor -> MemRef type converter.
/// Parameters: tensor type, memory space, func op, bufferization options
/// TensorLike -> BufferLike type converter.
/// Parameters: tensor-like type, memory space, func op, bufferization options
using FunctionArgTypeConverterFn =
std::function<BaseMemRefType(TensorType, Attribute memorySpace,
std::function<BufferLikeType(TensorLikeType, Attribute memorySpace,
func::FuncOp, const BufferizationOptions &)>;
/// Tensor -> MemRef type converter.
/// TensorLike -> BufferLike type converter.
/// Parameters: Value, memory space, bufferization options
using UnknownTypeConverterFn = std::function<BaseMemRefType(
using UnknownTypeConverterFn = std::function<BufferLikeType(
Value, Attribute memorySpace, const BufferizationOptions &)>;
// Produce a MemorySpace attribute from a tensor-like type
using DefaultMemorySpaceFn =
std::function<std::optional<Attribute>(TensorType t)>;
std::function<std::optional<Attribute>(TensorLikeType t)>;

BufferizationOptions();

Expand Down Expand Up @@ -360,7 +361,7 @@ struct BufferizationOptions {
// Returning std::nullopt will cause bufferization to fail (useful to indicate
// failure to determine memory space for a tensor-like type).
DefaultMemorySpaceFn defaultMemorySpaceFn =
[](TensorType t) -> std::optional<Attribute> { return Attribute(); };
[](TensorLikeType t) -> std::optional<Attribute> { return Attribute(); };

/// If set to `true`, the analysis is skipped. A buffer is copied before every
/// write. This flag cannot be used together with `testAnalysisOnly = true`.
Expand Down Expand Up @@ -600,7 +601,7 @@ FailureOr<Value> getBuffer(RewriterBase &rewriter, Value value,
/// IR, this function can be used.
///
/// This function is a wrapper around BufferizableOpInterface::getBufferType.
FailureOr<BaseMemRefType> getBufferType(Value value,
FailureOr<BufferLikeType> getBufferType(Value value,
const BufferizationOptions &options);

/// Return the buffer type for a given Value (tensor) after bufferization
Expand All @@ -613,7 +614,7 @@ FailureOr<BaseMemRefType> getBufferType(Value value,
/// IR, this function can be used.
///
/// This function is a wrapper around `BufferizableOpInterface::getBufferType`.
FailureOr<BaseMemRefType> getBufferType(Value value,
FailureOr<BufferLikeType> getBufferType(Value value,
const BufferizationOptions &options,
SmallVector<Value> &invocationStack);

Expand Down Expand Up @@ -693,7 +694,7 @@ AliasingOpOperandList defaultGetAliasingOpOperands(Value value,
/// This is the default implementation of
/// BufferizableOpInterface::getBufferType. Should not be called from other
/// places.
FailureOr<BaseMemRefType>
FailureOr<BufferLikeType>
defaultGetBufferType(Value value, const BufferizationOptions &options,
SmallVector<Value> &invocationStack);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
Note: This interface method should never be called directly from user
code. Always use `bufferization::getBufferType`.
}],
/*retType=*/"::mlir::FailureOr<::mlir::BaseMemRefType>",
/*retType=*/"::mlir::FailureOr<::mlir::bufferization::BufferLikeType>",
/*methodName=*/"getBufferType",
/*args=*/(ins "::mlir::Value":$value,
"const ::mlir::bufferization::BufferizationOptions &":$options,
Expand Down
17 changes: 10 additions & 7 deletions mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.td"
include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td"
include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td"
include "mlir/Dialect/Bufferization/IR/BufferizationBase.td"
include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td"
include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
Expand Down Expand Up @@ -109,7 +110,7 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
AliasingValueList getAliasingValues(
OpOperand &opOperand, const AnalysisState &state);

FailureOr<BaseMemRefType> getBufferType(
FailureOr<BufferLikeType> getBufferType(
Value value, const BufferizationOptions &options,
SmallVector<Value> &invocationStack);

Expand Down Expand Up @@ -438,11 +439,11 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
away. However, such IR is no longer bufferizable with One-Shot Bufferize.
}];

let arguments = (ins Arg<AnyRankedOrUnrankedMemRef,
let arguments = (ins Arg<Bufferization_BufferLikeTypeInterface,
"the reference to load from",
[MemReadAt<0, FullEffect>]>:$memref,
UnitAttr:$restrict, UnitAttr:$writable);
let results = (outs AnyTensor:$result);
let results = (outs Bufferization_TensorLikeTypeInterface:$result);

let extraClassDeclaration = [{
/// The result of a to_tensor is always a tensor.
Expand All @@ -465,10 +466,10 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [

bool isWritable(Value value, const AnalysisState &state);

FailureOr<BaseMemRefType> getBufferType(
FailureOr<BufferLikeType> getBufferType(
Value value, const BufferizationOptions &options,
SmallVector<Value> &invocationStack) {
return ::llvm::cast<BaseMemRefType>(getMemref().getType());
return ::llvm::cast<BufferLikeType>(getMemref().getType());
}
}];

Expand All @@ -493,6 +494,7 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
// ToMemrefOp
//===----------------------------------------------------------------------===//

// TODO: rename to "to_buffer"
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

side-note: I would prefer to do it in a separate patch (either before or after this one) since renaming would be (almost?) NFC

def Bufferization_ToMemrefOp : Bufferization_Op<"to_memref", [
BufferizableOpInterface,
SameOperandsAndResultShape,
Expand All @@ -519,8 +521,9 @@ def Bufferization_ToMemrefOp : Bufferization_Op<"to_memref", [
the returned buffer) will not be written to.
}];

let arguments = (ins AnyTensor:$tensor, UnitAttr:$read_only);
let results = (outs AnyRankedOrUnrankedMemRef:$memref);
let arguments = (ins Bufferization_TensorLikeTypeInterface:$tensor,
UnitAttr:$read_only);
let results = (outs Bufferization_BufferLikeTypeInterface:$memref);

let extraClassDeclaration = [{
//===------------------------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// Bufferization Type Interfaces
//===----------------------------------------------------------------------===//

#include "mlir/IR/Attributes.h" // mlir::Attribute
#include "mlir/IR/Types.h"

#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h.inc"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,17 @@ def Bufferization_BufferLikeTypeInterface
let description = [{
Indicates that this type is a buffer type (similarly to a MLIR builtin
memref) for bufferization purposes.

The interface currently has no methods as it is used by types to opt into
being supported by the bufferization procedures.
}];

let methods = [
InterfaceMethod<
/*desc=*/[{
Returns the memory space in which data referred to by this buffer resides.
}],
/*retType=*/"::mlir::Attribute",
/*methodName=*/"getMemorySpace"
>,
];
}

#endif // BUFFERIZATION_TYPE_INTERFACES
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ template <typename ConcreteModel, typename ConcreteOp>
struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
: public BufferizableOpInterface::ExternalModel<ConcreteModel, ConcreteOp> {

FailureOr<BaseMemRefType>
FailureOr<BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
SmallVector<Value> &invocationStack) const {
// Note: The user may want to override this function for OpResults in
Expand All @@ -46,7 +46,7 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
// operand types of all forwarded values. If these are all the same type,
// take that type. Otherwise, take only the memory space and fall back to a
// buffer type with a fully dynamic layout map.
BaseMemRefType bufferType;
BufferLikeType bufferType;
auto tensorType = cast<TensorType>(value.getType());
for (OpOperand *opOperand :
detail::getCallerOpOperands(cast<BlockArgument>(value))) {
Expand All @@ -59,13 +59,13 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
continue;

// Compute the bufferized type of the forwarded operand.
BaseMemRefType callerType;
if (auto memrefType =
dyn_cast<BaseMemRefType>(opOperand->get().getType())) {
BufferLikeType callerType;
if (auto bufferLikeType =
dyn_cast<BufferLikeType>(opOperand->get().getType())) {
// The operand was already bufferized. Take its type directly.
callerType = memrefType;
callerType = bufferLikeType;
} else {
FailureOr<BaseMemRefType> maybeCallerType =
FailureOr<BufferLikeType> maybeCallerType =
bufferization::getBufferType(opOperand->get(), options,
invocationStack);
if (failed(maybeCallerType))
Expand All @@ -86,14 +86,20 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
// of the earlier forwarded operands, fall back to a buffer type with a
// fully dynamic layout map.
#ifndef NDEBUG
assert(mlir::isa<BaseMemRefType>(bufferType) &&
mlir::isa<BaseMemRefType>(callerType) && "expected memrefs");
auto memrefType = mlir::cast<BaseMemRefType>(bufferType);
auto callerMemrefType = mlir::cast<BaseMemRefType>(callerType);

if (auto rankedTensorType = dyn_cast<RankedTensorType>(tensorType)) {
assert(bufferType.hasRank() && callerType.hasRank() &&
assert(memrefType.hasRank() && callerMemrefType.hasRank() &&
"expected ranked memrefs");
assert(llvm::all_equal({bufferType.getShape(), callerType.getShape(),
rankedTensorType.getShape()}) &&
"expected same shape");
assert(
llvm::all_equal({memrefType.getShape(), callerMemrefType.getShape(),
rankedTensorType.getShape()}) &&
"expected same shape");
} else {
assert(!bufferType.hasRank() && !callerType.hasRank() &&
assert(!memrefType.hasRank() && !callerMemrefType.hasRank() &&
"expected unranked memrefs");
}
#endif // NDEBUG
Expand All @@ -102,8 +108,9 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
return op->emitOpError("incoming operands of block argument have "
"inconsistent memory spaces");

bufferType = getMemRefTypeWithFullyDynamicLayout(
tensorType, bufferType.getMemorySpace());
bufferType =
mlir::cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
tensorType, bufferType.getMemorySpace()));
}

if (!bufferType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ struct ConstantOpInterface
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
auto constantOp = cast<arith::ConstantOp>(op);
auto type = dyn_cast<RankedTensorType>(constantOp.getType());
auto type = dyn_cast<TensorLikeType>(constantOp.getType());

// Only ranked tensors are supported.
if (!type)
Expand Down Expand Up @@ -176,7 +176,7 @@ struct SelectOpInterface
return success();
}

FailureOr<BaseMemRefType>
FailureOr<bufferization::BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
SmallVector<Value> &invocationStack) const {
auto selectOp = cast<arith::SelectOp>(op);
Expand All @@ -195,10 +195,11 @@ struct SelectOpInterface
// If the buffers have different types, they differ only in their layout
// map.
auto memrefType = llvm::cast<MemRefType>(*trueType);
return getMemRefTypeWithFullyDynamicLayout(
RankedTensorType::get(memrefType.getShape(),
memrefType.getElementType()),
memrefType.getMemorySpace());
return mlir::cast<bufferization::BufferLikeType>(
getMemRefTypeWithFullyDynamicLayout(
RankedTensorType::get(memrefType.getShape(),
memrefType.getElementType()),
memrefType.getMemorySpace()));
}
};

Expand Down
Loading
Loading