-
Notifications
You must be signed in to change notification settings - Fork 13.5k
Restore #140171 with to_memref -> to_buffer #140355
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
…oncat op (llvm#140171) This restores the previously reverted commit with forward fixes
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-tensor Author: Jeremy Kun (j2kun) Changes#140171 was reverted because an op's name as changed and I neglected to rebase before merging. Full diff: https://github.com/llvm/llvm-project/pull/140355.diff 3 Files Affected:
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorDialect.cpp b/mlir/lib/Dialect/Tensor/IR/TensorDialect.cpp
index 8af087cbf0f61..e7d8f52d309c9 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorDialect.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorDialect.cpp
@@ -49,8 +49,8 @@ void TensorDialect::initialize() {
>();
addInterfaces<TensorInlinerInterface>();
declarePromisedInterfaces<
- bufferization::BufferizableOpInterface, CastOp, CollapseShapeOp, DimOp,
- EmptyOp, ExpandShapeOp, ExtractSliceOp, ExtractOp, FromElementsOp,
+ bufferization::BufferizableOpInterface, CastOp, CollapseShapeOp, ConcatOp,
+ DimOp, EmptyOp, ExpandShapeOp, ExtractSliceOp, ExtractOp, FromElementsOp,
GenerateOp, InsertOp, InsertSliceOp, PadOp, ParallelInsertSliceOp, RankOp,
ReshapeOp, SplatOp>();
declarePromisedInterfaces<transform::FindPayloadReplacementOpInterface,
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index c0e697292d2a0..6525e58d002a2 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -1048,6 +1048,134 @@ struct SplatOpInterface
}
};
+/// Bufferization of tensor.concat. Bufferizes to a new allocation that is
+/// filled with copy ops. Similar to tensor.from_elements, but using memref.copy
+/// on subviews instead of memref.store.
+struct ConcatOpInterface
+ : public BufferizableOpInterface::ExternalModel<ConcatOpInterface,
+ tensor::ConcatOp> {
+
+ bool bufferizesToAllocation(Operation *op, Value value) const { return true; }
+
+ bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+ const AnalysisState &state) const {
+ return false;
+ }
+
+ bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+ const AnalysisState &state) const {
+ return true;
+ }
+
+ AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
+ const AnalysisState &state) const {
+ return {};
+ }
+
+ LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+ const BufferizationOptions &options) const {
+ OpBuilder::InsertionGuard g(rewriter);
+ auto concatOp = cast<tensor::ConcatOp>(op);
+
+ // Allocate memory.
+ Location loc = op->getLoc();
+ FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
+ rewriter, loc, concatOp.getResult(), options,
+ /*copy=*/false);
+ if (failed(tensorAlloc))
+ return failure();
+ auto tensorType = cast<RankedTensorType>(tensorAlloc->getType());
+
+ // TODO: Implement memory space for this op.
+ if (options.defaultMemorySpaceFn(tensorType) != Attribute())
+ return op->emitError("memory space not implemented yet");
+
+ MemRefLayoutAttrInterface layout;
+ MemRefType memrefType =
+ MemRefType::get(concatOp.getResultType().getShape(),
+ concatOp.getResultType().getElementType(), layout);
+ Value dstBuffer = rewriter.create<bufferization::ToBufferOp>(
+ op->getLoc(), memrefType, *tensorAlloc);
+
+ // Extract the dimension for the concat op
+ uint64_t concatDim = concatOp.getDim();
+ bool dynamicConcatDim = false;
+
+ SmallVector<OpFoldResult> offsets(tensorType.getRank(),
+ rewriter.getIndexAttr(0));
+ SmallVector<OpFoldResult> strides(tensorType.getRank(),
+ rewriter.getIndexAttr(1));
+ SmallVector<OpFoldResult> sizes;
+
+ for (const auto &[dimIdx, dimSize] :
+ llvm::enumerate(tensorType.getShape())) {
+ if (dimSize == ShapedType::kDynamic) {
+ auto dimOp = rewriter.create<memref::DimOp>(loc, dstBuffer, dimIdx);
+ sizes.push_back(dimOp.getResult());
+ if (dimIdx == concatDim)
+ dynamicConcatDim = true;
+ } else {
+ sizes.push_back(rewriter.getIndexAttr(dimSize));
+ }
+ }
+
+ int64_t concatDimOffset = 0;
+ std::optional<Value> dynamicOffset;
+ std::optional<Value> dynamicSize;
+ if (dynamicConcatDim) {
+ // One or more operands have dynamic size, so we must accumulate the
+ // offset with arith ops.
+ dynamicOffset = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ }
+
+ for (auto operand : concatOp.getInputs()) {
+ // Get the buffer for the operand.
+ FailureOr<Value> srcBuffer = getBuffer(rewriter, operand, options);
+ if (failed(srcBuffer))
+ return failure();
+
+ // Each operand may have a different size along the concat dimension,
+ // so the offset on that axis must accumulate through the loop, and the
+ // size must change to the size of the current operand.
+ auto operandTensorType = cast<RankedTensorType>(operand.getType());
+ int64_t operandConcatDimSize = operandTensorType.getDimSize(concatDim);
+
+ if (dynamicConcatDim) {
+ offsets[concatDim] = dynamicOffset.value();
+ dynamicSize = rewriter.create<memref::DimOp>(loc, *srcBuffer, concatDim)
+ .getResult();
+ sizes[concatDim] = dynamicSize.value();
+ } else {
+ sizes[concatDim] = rewriter.getIndexAttr(operandConcatDimSize);
+ offsets[concatDim] = rewriter.getIndexAttr(concatDimOffset);
+ }
+
+ // Create a subview of the destination buffer.
+ auto dstMemrefType = cast<MemRefType>(memrefType);
+ MemRefType subviewMemRefType =
+ memref::SubViewOp::inferRankReducedResultType(
+ operandTensorType.getShape(), dstMemrefType, offsets, sizes,
+ strides);
+ Value subview = rewriter.create<memref::SubViewOp>(
+ loc, subviewMemRefType, dstBuffer, offsets, sizes, strides);
+
+ // Copy the source buffer into the destination subview.
+ if (failed(options.createMemCpy(rewriter, loc, *srcBuffer, subview)))
+ return failure();
+
+ if (dynamicConcatDim) {
+ dynamicOffset = rewriter.create<arith::AddIOp>(
+ loc, dynamicOffset.value(), dynamicSize.value());
+ } else {
+ concatDimOffset += operandConcatDimSize;
+ }
+ }
+
+ replaceOpWithBufferizedValues(rewriter, op, dstBuffer);
+ return success();
+ }
+};
+
} // namespace
} // namespace tensor
} // namespace mlir
@@ -1057,6 +1185,7 @@ void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
CastOp::attachInterface<CastOpInterface>(*ctx);
CollapseShapeOp::attachInterface<CollapseShapeOpInterface>(*ctx);
+ ConcatOp::attachInterface<ConcatOpInterface>(*ctx);
DimOp::attachInterface<DimOpInterface>(*ctx);
EmptyOp::attachInterface<EmptyOpInterface>(*ctx);
ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index 567c4abea488e..e9c3ba7e3b970 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -615,6 +615,97 @@ func.func @tensor.splat(%f: f32) -> tensor<10x2x4xf32> {
// -----
+// CHECK-LABEL: func @tensor.concat(
+// CHECK-SAME: %[[F:.*]]: tensor<8xf32>)
+// CHECK: %[[F_MEMREF:.*]] = bufferization.to_buffer %[[F]]
+// CHECK: %[[ALLOC:.*]] = memref.alloc() {{.*}} : memref<16xf32>
+// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0] [8] [1]
+// CHECK: memref.copy %[[F_MEMREF]], %[[SUBVIEW1]]
+// CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][8] [8] [1]
+// CHECK: memref.copy %[[F_MEMREF]], %[[SUBVIEW2]]
+// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]]
+// CHECK: return %[[RET]]
+// CHECK: }
+func.func @tensor.concat(%f: tensor<8xf32>) -> tensor<16xf32> {
+ %t = tensor.concat dim(0) %f, %f : (tensor<8xf32>, tensor<8xf32>) -> tensor<16xf32>
+ return %t : tensor<16xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @tensor.concat_different_shapes(
+// CHECK-SAME: %[[F:.*]]: tensor<8x4xf32>
+// CHECK-SAME: %[[G:.*]]: tensor<8x5xf32>
+// CHECK-DAG: %[[F_MEMREF:.*]] = bufferization.to_buffer %[[F]]
+// CHECK-DAG: %[[G_MEMREF:.*]] = bufferization.to_buffer %[[G]]
+// CHECK: %[[ALLOC:.*]] = memref.alloc() {{.*}} : memref<8x9xf32>
+// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, 0] [8, 4] [1, 1]
+// CHECK: memref.copy %[[F_MEMREF]], %[[SUBVIEW1]]
+// CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, 4] [8, 5] [1, 1]
+// CHECK: memref.copy %[[G_MEMREF]], %[[SUBVIEW2]]
+// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]]
+// CHECK: return %[[RET]]
+// CHECK: }
+func.func @tensor.concat_different_shapes(%f: tensor<8x4xf32>, %g: tensor<8x5xf32>) -> tensor<8x9xf32> {
+ %t = tensor.concat dim(1) %f, %g : (tensor<8x4xf32>, tensor<8x5xf32>) -> tensor<8x9xf32>
+ return %t : tensor<8x9xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @tensor.concat_dynamic(
+// CHECK-SAME: %[[F:.*]]: tensor<8x?xf32>,
+// CHECK-SAME: %[[G:.*]]: tensor<8x?xf32>
+// CHECK-DAG: %[[F_MEMREF:.*]] = bufferization.to_buffer %[[F]]
+// CHECK-DAG: %[[G_MEMREF:.*]] = bufferization.to_buffer %[[G]]
+// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[F_DIM:.*]] = memref.dim %[[F_MEMREF]], %[[c1]]
+// CHECK-DAG: %[[G_DIM:.*]] = memref.dim %[[G_MEMREF]], %[[c1]]
+// CHECK: %[[ALLOC:.*]] = memref.alloc
+// CHECK-SAME: memref<8x?xf32>
+// CHECK-DAG: %[[OFFSET:.*]] = arith.constant 0 : index
+// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, %[[OFFSET]]] [8, %[[F_DIM]]] [1, 1]
+// CHECK: memref.copy %[[F_MEMREF]], %[[SUBVIEW1]]
+// CHECK: %[[OFFSET_2:.*]] = arith.addi %[[OFFSET]], %[[F_DIM]] : index
+// CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, %[[OFFSET_2]]] [8, %[[G_DIM]]] [1, 1]
+// CHECK: memref.copy %[[G_MEMREF]], %[[SUBVIEW2]]
+// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]]
+// CHECK: return %[[RET]]
+// CHECK: }
+func.func @tensor.concat_dynamic(%f: tensor<8x?xf32>, %g: tensor<8x?xf32>) -> tensor<8x?xf32> {
+ %t = tensor.concat dim(1) %f, %g : (tensor<8x?xf32>, tensor<8x?xf32>) -> tensor<8x?xf32>
+ return %t : tensor<8x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @tensor.concat_dynamic_nonconcat_dim(
+// CHECK-SAME: %[[F:.*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[G:.*]]: tensor<?x?xf32>
+// CHECK-DAG: %[[F_MEMREF:.*]] = bufferization.to_buffer %[[F]]
+// CHECK-DAG: %[[G_MEMREF:.*]] = bufferization.to_buffer %[[G]]
+// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[F_DIM:.*]] = memref.dim %[[F_MEMREF]], %[[c1]]
+// CHECK-DAG: %[[G_DIM:.*]] = memref.dim %[[G_MEMREF]], %[[c1]]
+// CHECK: %[[ALLOC:.*]] = memref.alloc
+// CHECK-SAME: memref<?x?xf32>
+// CHECK-DAG: %[[NON_CONCAT_DIM:.*]] = memref.dim %[[ALLOC]], %[[c0]]
+// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, %[[c0]]] [%[[NON_CONCAT_DIM]], %[[F_DIM]]] [1, 1]
+// CHECK: memref.copy %[[F_MEMREF]], %[[SUBVIEW1]]
+// CHECK: %[[OFFSET_2:.*]] = arith.addi %[[c0]], %[[F_DIM]] : index
+// CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, %[[OFFSET_2]]] [%[[NON_CONCAT_DIM]], %[[G_DIM]]] [1, 1]
+// CHECK: memref.copy %[[G_MEMREF]], %[[SUBVIEW2]]
+// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]]
+// CHECK: return %[[RET]]
+// CHECK: }
+func.func @tensor.concat_dynamic_nonconcat_dim(%f: tensor<?x?xf32>, %g: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %t = tensor.concat dim(1) %f, %g : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %t : tensor<?x?xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @tensor.splat_dynamic(
// CHECK-SAME: %[[F:[a-zA-Z0-9_]+]]: f32
// CHECK-SAME: %[[M:[a-zA-Z0-9_]+]]: index
|
|
#140171 was reverted because an op's name changed and I neglected to rebase before merging.