
Restore #140171 with to_memref -> to_buffer #140355

Merged 2 commits on May 18, 2025. The diff below shows the changes from all commits.
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/Tensor/IR/TensorDialect.cpp
@@ -49,8 +49,8 @@ void TensorDialect::initialize() {
>();
addInterfaces<TensorInlinerInterface>();
declarePromisedInterfaces<
-      bufferization::BufferizableOpInterface, CastOp, CollapseShapeOp, DimOp,
-      EmptyOp, ExpandShapeOp, ExtractSliceOp, ExtractOp, FromElementsOp,
+      bufferization::BufferizableOpInterface, CastOp, CollapseShapeOp, ConcatOp,
+      DimOp, EmptyOp, ExpandShapeOp, ExtractSliceOp, ExtractOp, FromElementsOp,
GenerateOp, InsertOp, InsertSliceOp, PadOp, ParallelInsertSliceOp, RankOp,
ReshapeOp, SplatOp>();
declarePromisedInterfaces<transform::FindPayloadReplacementOpInterface,
129 changes: 129 additions & 0 deletions mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -1048,6 +1048,134 @@ struct SplatOpInterface
}
};

/// Bufferization of tensor.concat. Bufferizes to a new allocation that is
/// filled with copy ops. Similar to tensor.from_elements, but using memref.copy
/// on subviews instead of memref.store.
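///
/// A sketch of the lowering (SSA names illustrative; the exact IR depends on
/// the bufferization options):
///
///   %r = tensor.concat dim(0) %a, %b
///       : (tensor<2xf32>, tensor<3xf32>) -> tensor<5xf32>
///
/// becomes, roughly:
///
///   %alloc = memref.alloc() : memref<5xf32>
///   %sv0 = memref.subview %alloc[0] [2] [1]
///       : memref<5xf32> to memref<2xf32, strided<[1]>>
///   memref.copy %a_buf, %sv0
///   %sv1 = memref.subview %alloc[2] [3] [1]
///       : memref<5xf32> to memref<3xf32, strided<[1], offset: 2>>
///   memref.copy %b_buf, %sv1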
struct ConcatOpInterface
: public BufferizableOpInterface::ExternalModel<ConcatOpInterface,
tensor::ConcatOp> {

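// The result of tensor.concat is materialized in a newly allocated buffer.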
bool bufferizesToAllocation(Operation *op, Value value) const { return true; }

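// The operand buffers are never written to; all writes target the new
// allocation.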
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
return false;
}

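// Every operand is read when it is copied into the result buffer.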
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
return true;
}

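// The result is a fresh allocation, so it does not alias any operand buffer.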
AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
return {};
}

LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
OpBuilder::InsertionGuard g(rewriter);
auto concatOp = cast<tensor::ConcatOp>(op);

// Allocate memory.
Location loc = op->getLoc();
FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
rewriter, loc, concatOp.getResult(), options,
/*copy=*/false);
if (failed(tensorAlloc))
return failure();
auto tensorType = cast<RankedTensorType>(tensorAlloc->getType());

// TODO: Implement memory space for this op.
if (options.defaultMemorySpaceFn(tensorType) != Attribute())
return op->emitError("memory space not implemented yet");

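// A null layout attribute gives the destination buffer the default
// (identity) layout.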
MemRefLayoutAttrInterface layout;
MemRefType memrefType =
MemRefType::get(concatOp.getResultType().getShape(),
concatOp.getResultType().getElementType(), layout);
Value dstBuffer = rewriter.create<bufferization::ToBufferOp>(
op->getLoc(), memrefType, *tensorAlloc);

// Get the dimension along which the tensors are concatenated.
uint64_t concatDim = concatOp.getDim();
bool dynamicConcatDim = false;

SmallVector<OpFoldResult> offsets(tensorType.getRank(),
rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> strides(tensorType.getRank(),
rewriter.getIndexAttr(1));
SmallVector<OpFoldResult> sizes;

for (const auto &[dimIdx, dimSize] :
llvm::enumerate(tensorType.getShape())) {
if (dimSize == ShapedType::kDynamic) {
auto dimOp = rewriter.create<memref::DimOp>(loc, dstBuffer, dimIdx);
sizes.push_back(dimOp.getResult());
if (dimIdx == concatDim)
dynamicConcatDim = true;
} else {
sizes.push_back(rewriter.getIndexAttr(dimSize));
}
}

int64_t concatDimOffset = 0;
std::optional<Value> dynamicOffset;
std::optional<Value> dynamicSize;
if (dynamicConcatDim) {
// One or more operands have dynamic size, so we must accumulate the
// offset with arith ops.
dynamicOffset = rewriter.create<arith::ConstantIndexOp>(loc, 0);
}
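// Illustrative example: concatenating two tensor<8x?xf32> values along dim 1
// starts the offset at 0; after the first copy it becomes dim(%src0, 1), and
// it keeps accumulating for each subsequent operand.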

for (auto operand : concatOp.getInputs()) {
// Get the buffer for the operand.
FailureOr<Value> srcBuffer = getBuffer(rewriter, operand, options);
if (failed(srcBuffer))
return failure();

// Each operand may have a different size along the concat dimension,
// so the offset on that axis must accumulate through the loop, and the
// size must change to the size of the current operand.
auto operandTensorType = cast<RankedTensorType>(operand.getType());
int64_t operandConcatDimSize = operandTensorType.getDimSize(concatDim);

if (dynamicConcatDim) {
offsets[concatDim] = dynamicOffset.value();
dynamicSize = rewriter.create<memref::DimOp>(loc, *srcBuffer, concatDim)
.getResult();
sizes[concatDim] = dynamicSize.value();
} else {
sizes[concatDim] = rewriter.getIndexAttr(operandConcatDimSize);
offsets[concatDim] = rewriter.getIndexAttr(concatDimOffset);
}

// Create a subview of the destination buffer.
auto dstMemrefType = cast<MemRefType>(memrefType);
MemRefType subviewMemRefType =
memref::SubViewOp::inferRankReducedResultType(
operandTensorType.getShape(), dstMemrefType, offsets, sizes,
strides);
Value subview = rewriter.create<memref::SubViewOp>(
loc, subviewMemRefType, dstBuffer, offsets, sizes, strides);

// Copy the source buffer into the destination subview.
if (failed(options.createMemCpy(rewriter, loc, *srcBuffer, subview)))
return failure();

if (dynamicConcatDim) {
dynamicOffset = rewriter.create<arith::AddIOp>(
loc, dynamicOffset.value(), dynamicSize.value());
} else {
concatDimOffset += operandConcatDimSize;
}
}

replaceOpWithBufferizedValues(rewriter, op, dstBuffer);
return success();
}
};

} // namespace
} // namespace tensor
} // namespace mlir
@@ -1057,6 +1185,7 @@ void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
CastOp::attachInterface<CastOpInterface>(*ctx);
CollapseShapeOp::attachInterface<CollapseShapeOpInterface>(*ctx);
ConcatOp::attachInterface<ConcatOpInterface>(*ctx);
DimOp::attachInterface<DimOpInterface>(*ctx);
EmptyOp::attachInterface<EmptyOpInterface>(*ctx);
ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
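For reference, a minimal sketch (the header path is the in-tree one; the wrapper function is illustrative) of how a tool wires in these external models so that one-shot bufferization can query tensor.concat:

#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/IR/DialectRegistry.h"

void registerTensorBufferizationModels(mlir::DialectRegistry &registry) {
  // Attaches ConcatOpInterface together with the other tensor op models.
  mlir::tensor::registerBufferizableOpInterfaceExternalModels(registry);
}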
91 changes: 91 additions & 0 deletions mlir/test/Dialect/Tensor/bufferize.mlir
@@ -615,6 +615,97 @@ func.func @tensor.splat(%f: f32) -> tensor<10x2x4xf32> {

// -----

// CHECK-LABEL: func @tensor.concat(
// CHECK-SAME: %[[F:.*]]: tensor<8xf32>)
// CHECK: %[[F_MEMREF:.*]] = bufferization.to_buffer %[[F]]
// CHECK: %[[ALLOC:.*]] = memref.alloc() {{.*}} : memref<16xf32>
// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0] [8] [1]
// CHECK: memref.copy %[[F_MEMREF]], %[[SUBVIEW1]]
// CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][8] [8] [1]
// CHECK: memref.copy %[[F_MEMREF]], %[[SUBVIEW2]]
// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]]
// CHECK: return %[[RET]]
// CHECK: }
func.func @tensor.concat(%f: tensor<8xf32>) -> tensor<16xf32> {
%t = tensor.concat dim(0) %f, %f : (tensor<8xf32>, tensor<8xf32>) -> tensor<16xf32>
return %t : tensor<16xf32>
}

// -----

// CHECK-LABEL: func @tensor.concat_different_shapes(
// CHECK-SAME: %[[F:.*]]: tensor<8x4xf32>
// CHECK-SAME: %[[G:.*]]: tensor<8x5xf32>
// CHECK-DAG: %[[F_MEMREF:.*]] = bufferization.to_buffer %[[F]]
// CHECK-DAG: %[[G_MEMREF:.*]] = bufferization.to_buffer %[[G]]
// CHECK: %[[ALLOC:.*]] = memref.alloc() {{.*}} : memref<8x9xf32>
// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, 0] [8, 4] [1, 1]
// CHECK: memref.copy %[[F_MEMREF]], %[[SUBVIEW1]]
// CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, 4] [8, 5] [1, 1]
// CHECK: memref.copy %[[G_MEMREF]], %[[SUBVIEW2]]
// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]]
// CHECK: return %[[RET]]
// CHECK: }
func.func @tensor.concat_different_shapes(%f: tensor<8x4xf32>, %g: tensor<8x5xf32>) -> tensor<8x9xf32> {
%t = tensor.concat dim(1) %f, %g : (tensor<8x4xf32>, tensor<8x5xf32>) -> tensor<8x9xf32>
return %t : tensor<8x9xf32>
}

// -----

// CHECK-LABEL: func @tensor.concat_dynamic(
// CHECK-SAME: %[[F:.*]]: tensor<8x?xf32>,
// CHECK-SAME: %[[G:.*]]: tensor<8x?xf32>
// CHECK-DAG: %[[F_MEMREF:.*]] = bufferization.to_buffer %[[F]]
// CHECK-DAG: %[[G_MEMREF:.*]] = bufferization.to_buffer %[[G]]
// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[F_DIM:.*]] = memref.dim %[[F_MEMREF]], %[[c1]]
// CHECK-DAG: %[[G_DIM:.*]] = memref.dim %[[G_MEMREF]], %[[c1]]
// CHECK: %[[ALLOC:.*]] = memref.alloc
// CHECK-SAME: memref<8x?xf32>
// CHECK-DAG: %[[OFFSET:.*]] = arith.constant 0 : index
// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, %[[OFFSET]]] [8, %[[F_DIM]]] [1, 1]
// CHECK: memref.copy %[[F_MEMREF]], %[[SUBVIEW1]]
// CHECK: %[[OFFSET_2:.*]] = arith.addi %[[OFFSET]], %[[F_DIM]] : index
// CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, %[[OFFSET_2]]] [8, %[[G_DIM]]] [1, 1]
// CHECK: memref.copy %[[G_MEMREF]], %[[SUBVIEW2]]
// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]]
// CHECK: return %[[RET]]
// CHECK: }
func.func @tensor.concat_dynamic(%f: tensor<8x?xf32>, %g: tensor<8x?xf32>) -> tensor<8x?xf32> {
%t = tensor.concat dim(1) %f, %g : (tensor<8x?xf32>, tensor<8x?xf32>) -> tensor<8x?xf32>
return %t : tensor<8x?xf32>
}

// -----

// CHECK-LABEL: func @tensor.concat_dynamic_nonconcat_dim(
// CHECK-SAME: %[[F:.*]]: tensor<?x?xf32>,
// CHECK-SAME: %[[G:.*]]: tensor<?x?xf32>
// CHECK-DAG: %[[F_MEMREF:.*]] = bufferization.to_buffer %[[F]]
// CHECK-DAG: %[[G_MEMREF:.*]] = bufferization.to_buffer %[[G]]
// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[F_DIM:.*]] = memref.dim %[[F_MEMREF]], %[[c1]]
// CHECK-DAG: %[[G_DIM:.*]] = memref.dim %[[G_MEMREF]], %[[c1]]
// CHECK: %[[ALLOC:.*]] = memref.alloc
// CHECK-SAME: memref<?x?xf32>
// CHECK-DAG: %[[NON_CONCAT_DIM:.*]] = memref.dim %[[ALLOC]], %[[c0]]
// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, %[[c0]]] [%[[NON_CONCAT_DIM]], %[[F_DIM]]] [1, 1]
// CHECK: memref.copy %[[F_MEMREF]], %[[SUBVIEW1]]
// CHECK: %[[OFFSET_2:.*]] = arith.addi %[[c0]], %[[F_DIM]] : index
// CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, %[[OFFSET_2]]] [%[[NON_CONCAT_DIM]], %[[G_DIM]]] [1, 1]
// CHECK: memref.copy %[[G_MEMREF]], %[[SUBVIEW2]]
// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]]
// CHECK: return %[[RET]]
// CHECK: }
func.func @tensor.concat_dynamic_nonconcat_dim(%f: tensor<?x?xf32>, %g: tensor<?x?xf32>) -> tensor<?x?xf32> {
%t = tensor.concat dim(1) %f, %g : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
return %t : tensor<?x?xf32>
}

// -----

// CHECK-LABEL: func @tensor.splat_dynamic(
// CHECK-SAME: %[[F:[a-zA-Z0-9_]+]]: f32
// CHECK-SAME: %[[M:[a-zA-Z0-9_]+]]: index