
Commit 1bc0043

Restore #140171 with to_memref -> to_buffer (#140355)
#140171 was reverted because an op's name changed and I neglected to rebase before merging.

Co-authored-by: Jeremy Kun <[email protected]>
1 parent f2165b9 commit 1bc0043
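As context for the rename in the commit title: the bufferization cast from a tensor value to a buffer, formerly bufferization.to_memref, is now spelled bufferization.to_buffer. A minimal sketch of the renamed op and its inverse (the exact assembly format shown here is an assumption, not taken from this commit):

// Hypothetical snippet: round-trip a tensor through a buffer.
%m = bufferization.to_buffer %t : tensor<8xf32> to memref<8xf32>   // was: bufferization.to_memref
%u = bufferization.to_tensor %m : memref<8xf32> to tensor<8xf32>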

File tree

3 files changed: +222, -2 lines changed


mlir/lib/Dialect/Tensor/IR/TensorDialect.cpp

Lines changed: 2 additions & 2 deletions
@@ -49,8 +49,8 @@ void TensorDialect::initialize() {
       >();
   addInterfaces<TensorInlinerInterface>();
   declarePromisedInterfaces<
-      bufferization::BufferizableOpInterface, CastOp, CollapseShapeOp, DimOp,
-      EmptyOp, ExpandShapeOp, ExtractSliceOp, ExtractOp, FromElementsOp,
+      bufferization::BufferizableOpInterface, CastOp, CollapseShapeOp, ConcatOp,
+      DimOp, EmptyOp, ExpandShapeOp, ExtractSliceOp, ExtractOp, FromElementsOp,
       GenerateOp, InsertOp, InsertSliceOp, PadOp, ParallelInsertSliceOp, RankOp,
       ReshapeOp, SplatOp>();
   declarePromisedInterfaces<transform::FindPayloadReplacementOpInterface,

mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 129 additions & 0 deletions
@@ -1048,6 +1048,134 @@ struct SplatOpInterface
   }
 };
 
+/// Bufferization of tensor.concat. Bufferizes to a new allocation that is
+/// filled with copy ops. Similar to tensor.from_elements, but using memref.copy
+/// on subviews instead of memref.store.
+struct ConcatOpInterface
+    : public BufferizableOpInterface::ExternalModel<ConcatOpInterface,
+                                                    tensor::ConcatOp> {
+
+  bool bufferizesToAllocation(Operation *op, Value value) const { return true; }
+
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+                               const AnalysisState &state) const {
+    return false;
+  }
+
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              const AnalysisState &state) const {
+    return true;
+  }
+
+  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
+                                      const AnalysisState &state) const {
+    return {};
+  }
+
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          const BufferizationOptions &options) const {
+    OpBuilder::InsertionGuard g(rewriter);
+    auto concatOp = cast<tensor::ConcatOp>(op);
+
+    // Allocate memory.
+    Location loc = op->getLoc();
+    FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
+        rewriter, loc, concatOp.getResult(), options,
+        /*copy=*/false);
+    if (failed(tensorAlloc))
+      return failure();
+    auto tensorType = cast<RankedTensorType>(tensorAlloc->getType());
+
+    // TODO: Implement memory space for this op.
+    if (options.defaultMemorySpaceFn(tensorType) != Attribute())
+      return op->emitError("memory space not implemented yet");
+
+    MemRefLayoutAttrInterface layout;
+    MemRefType memrefType =
+        MemRefType::get(concatOp.getResultType().getShape(),
+                        concatOp.getResultType().getElementType(), layout);
+    Value dstBuffer = rewriter.create<bufferization::ToBufferOp>(
+        op->getLoc(), memrefType, *tensorAlloc);
+
+    // Extract the dimension for the concat op
+    uint64_t concatDim = concatOp.getDim();
+    bool dynamicConcatDim = false;
+
+    SmallVector<OpFoldResult> offsets(tensorType.getRank(),
+                                      rewriter.getIndexAttr(0));
+    SmallVector<OpFoldResult> strides(tensorType.getRank(),
+                                      rewriter.getIndexAttr(1));
+    SmallVector<OpFoldResult> sizes;
+
+    for (const auto &[dimIdx, dimSize] :
+         llvm::enumerate(tensorType.getShape())) {
+      if (dimSize == ShapedType::kDynamic) {
+        auto dimOp = rewriter.create<memref::DimOp>(loc, dstBuffer, dimIdx);
+        sizes.push_back(dimOp.getResult());
+        if (dimIdx == concatDim)
+          dynamicConcatDim = true;
+      } else {
+        sizes.push_back(rewriter.getIndexAttr(dimSize));
+      }
+    }
+
+    int64_t concatDimOffset = 0;
+    std::optional<Value> dynamicOffset;
+    std::optional<Value> dynamicSize;
+    if (dynamicConcatDim) {
+      // One or more operands have dynamic size, so we must accumulate the
+      // offset with arith ops.
+      dynamicOffset = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    }
+
+    for (auto operand : concatOp.getInputs()) {
+      // Get the buffer for the operand.
+      FailureOr<Value> srcBuffer = getBuffer(rewriter, operand, options);
+      if (failed(srcBuffer))
+        return failure();
+
+      // Each operand may have a different size along the concat dimension,
+      // so the offset on that axis must accumulate through the loop, and the
+      // size must change to the size of the current operand.
+      auto operandTensorType = cast<RankedTensorType>(operand.getType());
+      int64_t operandConcatDimSize = operandTensorType.getDimSize(concatDim);
+
+      if (dynamicConcatDim) {
+        offsets[concatDim] = dynamicOffset.value();
+        dynamicSize = rewriter.create<memref::DimOp>(loc, *srcBuffer, concatDim)
+                          .getResult();
+        sizes[concatDim] = dynamicSize.value();
+      } else {
+        sizes[concatDim] = rewriter.getIndexAttr(operandConcatDimSize);
+        offsets[concatDim] = rewriter.getIndexAttr(concatDimOffset);
+      }
+
+      // Create a subview of the destination buffer.
+      auto dstMemrefType = cast<MemRefType>(memrefType);
+      MemRefType subviewMemRefType =
+          memref::SubViewOp::inferRankReducedResultType(
+              operandTensorType.getShape(), dstMemrefType, offsets, sizes,
+              strides);
+      Value subview = rewriter.create<memref::SubViewOp>(
+          loc, subviewMemRefType, dstBuffer, offsets, sizes, strides);
+
+      // Copy the source buffer into the destination subview.
+      if (failed(options.createMemCpy(rewriter, loc, *srcBuffer, subview)))
+        return failure();
+
+      if (dynamicConcatDim) {
+        dynamicOffset = rewriter.create<arith::AddIOp>(
+            loc, dynamicOffset.value(), dynamicSize.value());
+      } else {
+        concatDimOffset += operandConcatDimSize;
+      }
+    }
+
+    replaceOpWithBufferizedValues(rewriter, op, dstBuffer);
+    return success();
+  }
+};
+
 } // namespace
 } // namespace tensor
 } // namespace mlir
@@ -1057,6 +1185,7 @@ void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
   registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
     CastOp::attachInterface<CastOpInterface>(*ctx);
     CollapseShapeOp::attachInterface<CollapseShapeOpInterface>(*ctx);
+    ConcatOp::attachInterface<ConcatOpInterface>(*ctx);
     DimOp::attachInterface<DimOpInterface>(*ctx);
     EmptyOp::attachInterface<EmptyOpInterface>(*ctx);
     ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
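To make the lowering concrete: for a static concat, the interface above allocates a destination buffer, takes one subview per operand at an accumulating offset along the concat dimension, and memref.copy's each operand into its subview. A hedged sketch of the resulting IR for a 1-D example (the subview result types are illustrative; the tests below check the exact output):

// tensor.concat dim(0) %f, %f : (tensor<8xf32>, tensor<8xf32>) -> tensor<16xf32>
// bufferizes roughly to:
%src = bufferization.to_buffer %f : tensor<8xf32> to memref<8xf32>
%alloc = memref.alloc() : memref<16xf32>
%sv0 = memref.subview %alloc[0] [8] [1] : memref<16xf32> to memref<8xf32, strided<[1]>>
memref.copy %src, %sv0 : memref<8xf32> to memref<8xf32, strided<[1]>>
%sv1 = memref.subview %alloc[8] [8] [1] : memref<16xf32> to memref<8xf32, strided<[1], offset: 8>>
memref.copy %src, %sv1 : memref<8xf32> to memref<8xf32, strided<[1], offset: 8>>
%res = bufferization.to_tensor %alloc : memref<16xf32> to tensor<16xf32>

When the concat dimension is dynamic, the offset is instead accumulated with arith.addi over memref.dim of each operand, as the dynamic tests below verify.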

mlir/test/Dialect/Tensor/bufferize.mlir

Lines changed: 91 additions & 0 deletions
@@ -615,6 +615,97 @@ func.func @tensor.splat(%f: f32) -> tensor<10x2x4xf32> {
 
 // -----
 
+// CHECK-LABEL: func @tensor.concat(
+// CHECK-SAME: %[[F:.*]]: tensor<8xf32>)
+// CHECK: %[[F_MEMREF:.*]] = bufferization.to_buffer %[[F]]
+// CHECK: %[[ALLOC:.*]] = memref.alloc() {{.*}} : memref<16xf32>
+// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0] [8] [1]
+// CHECK: memref.copy %[[F_MEMREF]], %[[SUBVIEW1]]
+// CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][8] [8] [1]
+// CHECK: memref.copy %[[F_MEMREF]], %[[SUBVIEW2]]
+// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]]
+// CHECK: return %[[RET]]
+// CHECK: }
+func.func @tensor.concat(%f: tensor<8xf32>) -> tensor<16xf32> {
+  %t = tensor.concat dim(0) %f, %f : (tensor<8xf32>, tensor<8xf32>) -> tensor<16xf32>
+  return %t : tensor<16xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @tensor.concat_different_shapes(
+// CHECK-SAME: %[[F:.*]]: tensor<8x4xf32>
+// CHECK-SAME: %[[G:.*]]: tensor<8x5xf32>
+// CHECK-DAG: %[[F_MEMREF:.*]] = bufferization.to_buffer %[[F]]
+// CHECK-DAG: %[[G_MEMREF:.*]] = bufferization.to_buffer %[[G]]
+// CHECK: %[[ALLOC:.*]] = memref.alloc() {{.*}} : memref<8x9xf32>
+// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, 0] [8, 4] [1, 1]
+// CHECK: memref.copy %[[F_MEMREF]], %[[SUBVIEW1]]
+// CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, 4] [8, 5] [1, 1]
+// CHECK: memref.copy %[[G_MEMREF]], %[[SUBVIEW2]]
+// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]]
+// CHECK: return %[[RET]]
+// CHECK: }
+func.func @tensor.concat_different_shapes(%f: tensor<8x4xf32>, %g: tensor<8x5xf32>) -> tensor<8x9xf32> {
+  %t = tensor.concat dim(1) %f, %g : (tensor<8x4xf32>, tensor<8x5xf32>) -> tensor<8x9xf32>
+  return %t : tensor<8x9xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @tensor.concat_dynamic(
+// CHECK-SAME: %[[F:.*]]: tensor<8x?xf32>,
+// CHECK-SAME: %[[G:.*]]: tensor<8x?xf32>
+// CHECK-DAG: %[[F_MEMREF:.*]] = bufferization.to_buffer %[[F]]
+// CHECK-DAG: %[[G_MEMREF:.*]] = bufferization.to_buffer %[[G]]
+// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[F_DIM:.*]] = memref.dim %[[F_MEMREF]], %[[c1]]
+// CHECK-DAG: %[[G_DIM:.*]] = memref.dim %[[G_MEMREF]], %[[c1]]
+// CHECK: %[[ALLOC:.*]] = memref.alloc
+// CHECK-SAME: memref<8x?xf32>
+// CHECK-DAG: %[[OFFSET:.*]] = arith.constant 0 : index
+// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, %[[OFFSET]]] [8, %[[F_DIM]]] [1, 1]
+// CHECK: memref.copy %[[F_MEMREF]], %[[SUBVIEW1]]
+// CHECK: %[[OFFSET_2:.*]] = arith.addi %[[OFFSET]], %[[F_DIM]] : index
+// CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, %[[OFFSET_2]]] [8, %[[G_DIM]]] [1, 1]
+// CHECK: memref.copy %[[G_MEMREF]], %[[SUBVIEW2]]
+// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]]
+// CHECK: return %[[RET]]
+// CHECK: }
+func.func @tensor.concat_dynamic(%f: tensor<8x?xf32>, %g: tensor<8x?xf32>) -> tensor<8x?xf32> {
+  %t = tensor.concat dim(1) %f, %g : (tensor<8x?xf32>, tensor<8x?xf32>) -> tensor<8x?xf32>
+  return %t : tensor<8x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @tensor.concat_dynamic_nonconcat_dim(
+// CHECK-SAME: %[[F:.*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[G:.*]]: tensor<?x?xf32>
+// CHECK-DAG: %[[F_MEMREF:.*]] = bufferization.to_buffer %[[F]]
+// CHECK-DAG: %[[G_MEMREF:.*]] = bufferization.to_buffer %[[G]]
+// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[F_DIM:.*]] = memref.dim %[[F_MEMREF]], %[[c1]]
+// CHECK-DAG: %[[G_DIM:.*]] = memref.dim %[[G_MEMREF]], %[[c1]]
+// CHECK: %[[ALLOC:.*]] = memref.alloc
+// CHECK-SAME: memref<?x?xf32>
+// CHECK-DAG: %[[NON_CONCAT_DIM:.*]] = memref.dim %[[ALLOC]], %[[c0]]
+// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, %[[c0]]] [%[[NON_CONCAT_DIM]], %[[F_DIM]]] [1, 1]
+// CHECK: memref.copy %[[F_MEMREF]], %[[SUBVIEW1]]
+// CHECK: %[[OFFSET_2:.*]] = arith.addi %[[c0]], %[[F_DIM]] : index
+// CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, %[[OFFSET_2]]] [%[[NON_CONCAT_DIM]], %[[G_DIM]]] [1, 1]
+// CHECK: memref.copy %[[G_MEMREF]], %[[SUBVIEW2]]
+// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]]
+// CHECK: return %[[RET]]
+// CHECK: }
+func.func @tensor.concat_dynamic_nonconcat_dim(%f: tensor<?x?xf32>, %g: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %t = tensor.concat dim(1) %f, %g : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %t : tensor<?x?xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @tensor.splat_dynamic(
 // CHECK-SAME: %[[F:[a-zA-Z0-9_]+]]: f32
 // CHECK-SAME: %[[M:[a-zA-Z0-9_]+]]: index
