
Restore #140171 with to_memref -> to_buffer #140355

Merged
merged 2 commits into llvm:main on May 18, 2025

Conversation

j2kun (Contributor) commented May 17, 2025

#140171 was reverted because an op's name changed and I neglected to rebase before merging.
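
For context, the forward fix tracks the rename of bufferization.to_memref to bufferization.to_buffer. A minimal sketch of the renamed op as the updated tests exercise it (illustrative IR, not taken from this diff; the assembly format is an assumption and may differ by revision):

// Sketch only: op name per this PR's tests; exact assembly format assumed.
func.func @to_buffer_example(%t: tensor<8xf32>) -> memref<8xf32> {
  %m = bufferization.to_buffer %t : tensor<8xf32> to memref<8xf32>
  return %m : memref<8xf32>
}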

j2kun added 2 commits May 16, 2025 20:28
…oncat op (llvm#140171)

This restores the previously reverted commit with forward fixes
llvmbot (Member) commented May 17, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-tensor

Author: Jeremy Kun (j2kun)

Changes

#140171 was reverted because an op's name was changed and I neglected to rebase before merging.


Full diff: https://github.com/llvm/llvm-project/pull/140355.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/Tensor/IR/TensorDialect.cpp (+2-2)
  • (modified) mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp (+129)
  • (modified) mlir/test/Dialect/Tensor/bufferize.mlir (+91)
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorDialect.cpp b/mlir/lib/Dialect/Tensor/IR/TensorDialect.cpp
index 8af087cbf0f61..e7d8f52d309c9 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorDialect.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorDialect.cpp
@@ -49,8 +49,8 @@ void TensorDialect::initialize() {
       >();
   addInterfaces<TensorInlinerInterface>();
   declarePromisedInterfaces<
-      bufferization::BufferizableOpInterface, CastOp, CollapseShapeOp, DimOp,
-      EmptyOp, ExpandShapeOp, ExtractSliceOp, ExtractOp, FromElementsOp,
+      bufferization::BufferizableOpInterface, CastOp, CollapseShapeOp, ConcatOp,
+      DimOp, EmptyOp, ExpandShapeOp, ExtractSliceOp, ExtractOp, FromElementsOp,
       GenerateOp, InsertOp, InsertSliceOp, PadOp, ParallelInsertSliceOp, RankOp,
       ReshapeOp, SplatOp>();
   declarePromisedInterfaces<transform::FindPayloadReplacementOpInterface,
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index c0e697292d2a0..6525e58d002a2 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -1048,6 +1048,134 @@ struct SplatOpInterface
   }
 };
 
+/// Bufferization of tensor.concat. Bufferizes to a new allocation that is
+/// filled with copy ops. Similar to tensor.from_elements, but using memref.copy
+/// on subviews instead of memref.store.
+struct ConcatOpInterface
+    : public BufferizableOpInterface::ExternalModel<ConcatOpInterface,
+                                                    tensor::ConcatOp> {
+
+  bool bufferizesToAllocation(Operation *op, Value value) const { return true; }
+
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+                               const AnalysisState &state) const {
+    return false;
+  }
+
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              const AnalysisState &state) const {
+    return true;
+  }
+
+  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
+                                      const AnalysisState &state) const {
+    return {};
+  }
+
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          const BufferizationOptions &options) const {
+    OpBuilder::InsertionGuard g(rewriter);
+    auto concatOp = cast<tensor::ConcatOp>(op);
+
+    // Allocate memory.
+    Location loc = op->getLoc();
+    FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
+        rewriter, loc, concatOp.getResult(), options,
+        /*copy=*/false);
+    if (failed(tensorAlloc))
+      return failure();
+    auto tensorType = cast<RankedTensorType>(tensorAlloc->getType());
+
+    // TODO: Implement memory space for this op.
+    if (options.defaultMemorySpaceFn(tensorType) != Attribute())
+      return op->emitError("memory space not implemented yet");
+
+    MemRefLayoutAttrInterface layout;
+    MemRefType memrefType =
+        MemRefType::get(concatOp.getResultType().getShape(),
+                        concatOp.getResultType().getElementType(), layout);
+    Value dstBuffer = rewriter.create<bufferization::ToBufferOp>(
+        op->getLoc(), memrefType, *tensorAlloc);
+
+    // Extract the dimension for the concat op.
+    uint64_t concatDim = concatOp.getDim();
+    bool dynamicConcatDim = false;
+
+    SmallVector<OpFoldResult> offsets(tensorType.getRank(),
+                                      rewriter.getIndexAttr(0));
+    SmallVector<OpFoldResult> strides(tensorType.getRank(),
+                                      rewriter.getIndexAttr(1));
+    SmallVector<OpFoldResult> sizes;
+
+    for (const auto &[dimIdx, dimSize] :
+         llvm::enumerate(tensorType.getShape())) {
+      if (dimSize == ShapedType::kDynamic) {
+        auto dimOp = rewriter.create<memref::DimOp>(loc, dstBuffer, dimIdx);
+        sizes.push_back(dimOp.getResult());
+        if (dimIdx == concatDim)
+          dynamicConcatDim = true;
+      } else {
+        sizes.push_back(rewriter.getIndexAttr(dimSize));
+      }
+    }
+
+    int64_t concatDimOffset = 0;
+    std::optional<Value> dynamicOffset;
+    std::optional<Value> dynamicSize;
+    if (dynamicConcatDim) {
+      // One or more operands have dynamic size, so we must accumulate the
+      // offset with arith ops.
+      dynamicOffset = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    }
+
+    for (auto operand : concatOp.getInputs()) {
+      // Get the buffer for the operand.
+      FailureOr<Value> srcBuffer = getBuffer(rewriter, operand, options);
+      if (failed(srcBuffer))
+        return failure();
+
+      // Each operand may have a different size along the concat dimension,
+      // so the offset on that axis must accumulate through the loop, and the
+      // size must change to the size of the current operand.
+      auto operandTensorType = cast<RankedTensorType>(operand.getType());
+      int64_t operandConcatDimSize = operandTensorType.getDimSize(concatDim);
+
+      if (dynamicConcatDim) {
+        offsets[concatDim] = dynamicOffset.value();
+        dynamicSize = rewriter.create<memref::DimOp>(loc, *srcBuffer, concatDim)
+                          .getResult();
+        sizes[concatDim] = dynamicSize.value();
+      } else {
+        sizes[concatDim] = rewriter.getIndexAttr(operandConcatDimSize);
+        offsets[concatDim] = rewriter.getIndexAttr(concatDimOffset);
+      }
+
+      // Create a subview of the destination buffer.
+      auto dstMemrefType = cast<MemRefType>(memrefType);
+      MemRefType subviewMemRefType =
+          memref::SubViewOp::inferRankReducedResultType(
+              operandTensorType.getShape(), dstMemrefType, offsets, sizes,
+              strides);
+      Value subview = rewriter.create<memref::SubViewOp>(
+          loc, subviewMemRefType, dstBuffer, offsets, sizes, strides);
+
+      // Copy the source buffer into the destination subview.
+      if (failed(options.createMemCpy(rewriter, loc, *srcBuffer, subview)))
+        return failure();
+
+      if (dynamicConcatDim) {
+        dynamicOffset = rewriter.create<arith::AddIOp>(
+            loc, dynamicOffset.value(), dynamicSize.value());
+      } else {
+        concatDimOffset += operandConcatDimSize;
+      }
+    }
+
+    replaceOpWithBufferizedValues(rewriter, op, dstBuffer);
+    return success();
+  }
+};
+
 } // namespace
 } // namespace tensor
 } // namespace mlir
@@ -1057,6 +1185,7 @@ void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
   registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
     CastOp::attachInterface<CastOpInterface>(*ctx);
     CollapseShapeOp::attachInterface<CollapseShapeOpInterface>(*ctx);
+    ConcatOp::attachInterface<ConcatOpInterface>(*ctx);
     DimOp::attachInterface<DimOpInterface>(*ctx);
     EmptyOp::attachInterface<EmptyOpInterface>(*ctx);
     ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
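
For orientation between the implementation above and the tests below, here is a rough sketch of the IR the interface produces for a fully static concat, reconstructed from the @tensor.concat test expectations that follow (SSA names and exact subview result types are illustrative; assembly formats approximate):

// Sketch only: mirrors the CHECK lines of the @tensor.concat test below.
func.func @concat_static_sketch(%f: tensor<8xf32>) -> tensor<16xf32> {
  // tensor.concat dim(0) %f, %f bufferizes roughly to:
  %src = bufferization.to_buffer %f : tensor<8xf32> to memref<8xf32>
  %alloc = memref.alloc() : memref<16xf32>
  // One subview + copy per operand; the offset along dim(0) advances by
  // each operand's size along the concat dimension (0, then 8).
  %sv0 = memref.subview %alloc[0] [8] [1] : memref<16xf32> to memref<8xf32, strided<[1]>>
  memref.copy %src, %sv0 : memref<8xf32> to memref<8xf32, strided<[1]>>
  %sv1 = memref.subview %alloc[8] [8] [1] : memref<16xf32> to memref<8xf32, strided<[1], offset: 8>>
  memref.copy %src, %sv1 : memref<8xf32> to memref<8xf32, strided<[1], offset: 8>>
  %ret = bufferization.to_tensor %alloc : memref<16xf32> to tensor<16xf32>
  return %ret : tensor<16xf32>
}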
diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index 567c4abea488e..e9c3ba7e3b970 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -615,6 +615,97 @@ func.func @tensor.splat(%f: f32) -> tensor<10x2x4xf32> {
 
 // -----
 
+// CHECK-LABEL:   func @tensor.concat(
+// CHECK-SAME:        %[[F:.*]]: tensor<8xf32>)
+// CHECK:           %[[F_MEMREF:.*]] = bufferization.to_buffer %[[F]]
+// CHECK:           %[[ALLOC:.*]] = memref.alloc() {{.*}} : memref<16xf32>
+// CHECK:           %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0] [8] [1]
+// CHECK:           memref.copy %[[F_MEMREF]], %[[SUBVIEW1]]
+// CHECK:           %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][8] [8] [1]
+// CHECK:           memref.copy %[[F_MEMREF]], %[[SUBVIEW2]]
+// CHECK:           %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]]
+// CHECK:           return %[[RET]]
+// CHECK:         }
+func.func @tensor.concat(%f: tensor<8xf32>) -> tensor<16xf32> {
+  %t = tensor.concat dim(0) %f, %f : (tensor<8xf32>, tensor<8xf32>) -> tensor<16xf32>
+  return %t : tensor<16xf32>
+}
+
+// -----
+
+// CHECK-LABEL:   func @tensor.concat_different_shapes(
+// CHECK-SAME:        %[[F:.*]]: tensor<8x4xf32>
+// CHECK-SAME:        %[[G:.*]]: tensor<8x5xf32>
+// CHECK-DAG:       %[[F_MEMREF:.*]] = bufferization.to_buffer %[[F]]
+// CHECK-DAG:       %[[G_MEMREF:.*]] = bufferization.to_buffer %[[G]]
+// CHECK:           %[[ALLOC:.*]] = memref.alloc() {{.*}} : memref<8x9xf32>
+// CHECK:           %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, 0] [8, 4] [1, 1]
+// CHECK:           memref.copy %[[F_MEMREF]], %[[SUBVIEW1]]
+// CHECK:           %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, 4] [8, 5] [1, 1]
+// CHECK:           memref.copy %[[G_MEMREF]], %[[SUBVIEW2]]
+// CHECK:           %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]]
+// CHECK:           return %[[RET]]
+// CHECK:         }
+func.func @tensor.concat_different_shapes(%f: tensor<8x4xf32>, %g: tensor<8x5xf32>) -> tensor<8x9xf32> {
+  %t = tensor.concat dim(1) %f, %g : (tensor<8x4xf32>, tensor<8x5xf32>) -> tensor<8x9xf32>
+  return %t : tensor<8x9xf32>
+}
+
+// -----
+
+// CHECK-LABEL:   func @tensor.concat_dynamic(
+// CHECK-SAME:        %[[F:.*]]: tensor<8x?xf32>,
+// CHECK-SAME:        %[[G:.*]]: tensor<8x?xf32>
+// CHECK-DAG:       %[[F_MEMREF:.*]] = bufferization.to_buffer %[[F]]
+// CHECK-DAG:       %[[G_MEMREF:.*]] = bufferization.to_buffer %[[G]]
+// CHECK-DAG:       %[[c1:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[F_DIM:.*]] = memref.dim %[[F_MEMREF]], %[[c1]]
+// CHECK-DAG:       %[[G_DIM:.*]] = memref.dim %[[G_MEMREF]], %[[c1]]
+// CHECK:           %[[ALLOC:.*]] = memref.alloc
+// CHECK-SAME:                                    memref<8x?xf32>
+// CHECK-DAG:       %[[OFFSET:.*]] = arith.constant 0 : index
+// CHECK:           %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, %[[OFFSET]]] [8, %[[F_DIM]]] [1, 1]
+// CHECK:           memref.copy %[[F_MEMREF]], %[[SUBVIEW1]]
+// CHECK:           %[[OFFSET_2:.*]] = arith.addi %[[OFFSET]], %[[F_DIM]] : index
+// CHECK:           %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, %[[OFFSET_2]]] [8, %[[G_DIM]]] [1, 1]
+// CHECK:           memref.copy %[[G_MEMREF]], %[[SUBVIEW2]]
+// CHECK:           %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]]
+// CHECK:           return %[[RET]]
+// CHECK:         }
+func.func @tensor.concat_dynamic(%f: tensor<8x?xf32>, %g: tensor<8x?xf32>) -> tensor<8x?xf32> {
+  %t = tensor.concat dim(1) %f, %g : (tensor<8x?xf32>, tensor<8x?xf32>) -> tensor<8x?xf32>
+  return %t : tensor<8x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL:   func @tensor.concat_dynamic_nonconcat_dim(
+// CHECK-SAME:        %[[F:.*]]: tensor<?x?xf32>,
+// CHECK-SAME:        %[[G:.*]]: tensor<?x?xf32>
+// CHECK-DAG:       %[[F_MEMREF:.*]] = bufferization.to_buffer %[[F]]
+// CHECK-DAG:       %[[G_MEMREF:.*]] = bufferization.to_buffer %[[G]]
+// CHECK-DAG:       %[[c1:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[c0:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[F_DIM:.*]] = memref.dim %[[F_MEMREF]], %[[c1]]
+// CHECK-DAG:       %[[G_DIM:.*]] = memref.dim %[[G_MEMREF]], %[[c1]]
+// CHECK:           %[[ALLOC:.*]] = memref.alloc
+// CHECK-SAME:                                    memref<?x?xf32>
+// CHECK-DAG:       %[[NON_CONCAT_DIM:.*]] = memref.dim %[[ALLOC]], %[[c0]]
+// CHECK:           %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, %[[c0]]] [%[[NON_CONCAT_DIM]], %[[F_DIM]]] [1, 1]
+// CHECK:           memref.copy %[[F_MEMREF]], %[[SUBVIEW1]]
+// CHECK:           %[[OFFSET_2:.*]] = arith.addi %[[c0]], %[[F_DIM]] : index
+// CHECK:           %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, %[[OFFSET_2]]] [%[[NON_CONCAT_DIM]], %[[G_DIM]]] [1, 1]
+// CHECK:           memref.copy %[[G_MEMREF]], %[[SUBVIEW2]]
+// CHECK:           %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]]
+// CHECK:           return %[[RET]]
+// CHECK:         }
+func.func @tensor.concat_dynamic_nonconcat_dim(%f: tensor<?x?xf32>, %g: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %t = tensor.concat dim(1) %f, %g : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %t : tensor<?x?xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @tensor.splat_dynamic(
 // CHECK-SAME:  %[[F:[a-zA-Z0-9_]+]]: f32
 // CHECK-SAME:  %[[M:[a-zA-Z0-9_]+]]: index


j2kun merged commit 1bc0043 into llvm:main on May 18, 2025
14 checks passed