@@ -1048,6 +1048,134 @@ struct SplatOpInterface
}
};
+/// Bufferization of tensor.concat. Bufferizes to a new allocation that is
+/// filled with copy ops. Similar to tensor.from_elements, but using memref.copy
+/// on subviews instead of memref.store.
+struct ConcatOpInterface
+    : public BufferizableOpInterface::ExternalModel<ConcatOpInterface,
+                                                    tensor::ConcatOp> {
+
+  bool bufferizesToAllocation(Operation *op, Value value) const { return true; }
+
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+                               const AnalysisState &state) const {
+    return false;
+  }
+
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              const AnalysisState &state) const {
+    return true;
+  }
+
+  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
+                                      const AnalysisState &state) const {
+    return {};
+  }
+
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          const BufferizationOptions &options) const {
+    OpBuilder::InsertionGuard g(rewriter);
+    auto concatOp = cast<tensor::ConcatOp>(op);
+
+    // Allocate memory.
+    Location loc = op->getLoc();
+    FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
+        rewriter, loc, concatOp.getResult(), options,
+        /*copy=*/false);
+    if (failed(tensorAlloc))
+      return failure();
+    auto tensorType = cast<RankedTensorType>(tensorAlloc->getType());
+
+    // TODO: Implement memory space for this op.
+    if (options.defaultMemorySpaceFn(tensorType) != Attribute())
+      return op->emitError("memory space not implemented yet");
+
+    MemRefLayoutAttrInterface layout;
+    MemRefType memrefType =
+        MemRefType::get(concatOp.getResultType().getShape(),
+                        concatOp.getResultType().getElementType(), layout);
+    Value dstBuffer = rewriter.create<bufferization::ToBufferOp>(
+        op->getLoc(), memrefType, *tensorAlloc);
+
+    // Extract the dimension for the concat op.
+    uint64_t concatDim = concatOp.getDim();
+    bool dynamicConcatDim = false;
+
+    SmallVector<OpFoldResult> offsets(tensorType.getRank(),
+                                      rewriter.getIndexAttr(0));
+    SmallVector<OpFoldResult> strides(tensorType.getRank(),
+                                      rewriter.getIndexAttr(1));
+    SmallVector<OpFoldResult> sizes;
+
+    for (const auto &[dimIdx, dimSize] :
+         llvm::enumerate(tensorType.getShape())) {
+      if (dimSize == ShapedType::kDynamic) {
+        auto dimOp = rewriter.create<memref::DimOp>(loc, dstBuffer, dimIdx);
+        sizes.push_back(dimOp.getResult());
+        if (dimIdx == concatDim)
+          dynamicConcatDim = true;
+      } else {
+        sizes.push_back(rewriter.getIndexAttr(dimSize));
+      }
+    }
+
+    int64_t concatDimOffset = 0;
+    std::optional<Value> dynamicOffset;
+    std::optional<Value> dynamicSize;
+    if (dynamicConcatDim) {
+      // One or more operands have dynamic size, so we must accumulate the
+      // offset with arith ops.
+      dynamicOffset = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    }
+
+    for (auto operand : concatOp.getInputs()) {
+      // Get the buffer for the operand.
+      FailureOr<Value> srcBuffer = getBuffer(rewriter, operand, options);
+      if (failed(srcBuffer))
+        return failure();
+
+      // Each operand may have a different size along the concat dimension,
+      // so the offset on that axis must accumulate through the loop, and the
+      // size must change to the size of the current operand.
+      auto operandTensorType = cast<RankedTensorType>(operand.getType());
+      int64_t operandConcatDimSize = operandTensorType.getDimSize(concatDim);
+
+      if (dynamicConcatDim) {
+        offsets[concatDim] = dynamicOffset.value();
+        dynamicSize = rewriter.create<memref::DimOp>(loc, *srcBuffer, concatDim)
+                          .getResult();
+        sizes[concatDim] = dynamicSize.value();
+      } else {
+        sizes[concatDim] = rewriter.getIndexAttr(operandConcatDimSize);
+        offsets[concatDim] = rewriter.getIndexAttr(concatDimOffset);
+      }
+
+      // Create a subview of the destination buffer.
+      auto dstMemrefType = cast<MemRefType>(memrefType);
+      MemRefType subviewMemRefType =
+          memref::SubViewOp::inferRankReducedResultType(
+              operandTensorType.getShape(), dstMemrefType, offsets, sizes,
+              strides);
+      Value subview = rewriter.create<memref::SubViewOp>(
+          loc, subviewMemRefType, dstBuffer, offsets, sizes, strides);
+
+      // Copy the source buffer into the destination subview.
+      if (failed(options.createMemCpy(rewriter, loc, *srcBuffer, subview)))
+        return failure();
+
+      if (dynamicConcatDim) {
+        dynamicOffset = rewriter.create<arith::AddIOp>(
+            loc, dynamicOffset.value(), dynamicSize.value());
+      } else {
+        concatDimOffset += operandConcatDimSize;
+      }
+    }
+
+    replaceOpWithBufferizedValues(rewriter, op, dstBuffer);
+    return success();
+  }
+};
+
} // namespace
} // namespace tensor
} // namespace mlir
@@ -1057,6 +1185,7 @@ void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
  registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
    CastOp::attachInterface<CastOpInterface>(*ctx);
    CollapseShapeOp::attachInterface<CollapseShapeOpInterface>(*ctx);
+    ConcatOp::attachInterface<ConcatOpInterface>(*ctx);
    DimOp::attachInterface<DimOpInterface>(*ctx);
    EmptyOp::attachInterface<EmptyOpInterface>(*ctx);
    ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
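For context, here is a minimal usage sketch. It is not part of this patch: the header path and the IR in the trailing comment are assumptions based on the interface above, and only the registration hook itself appears in the diff.

// Sketch only: header path and the example IR in the comments are
// assumptions, not code from this patch.
#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/IR/DialectRegistry.h"

void registerTensorBufferizationModels(mlir::DialectRegistry &registry) {
  // Attaches ConcatOpInterface (along with the other tensor op models) so
  // that one-shot bufferization can handle tensor.concat.
  mlir::tensor::registerBufferizableOpInterfaceExternalModels(registry);
}

// With the model attached, a static-shape concat such as
//   %r = tensor.concat dim(0) %a, %b
//        : (tensor<2x4xf32>, tensor<3x4xf32>) -> tensor<5x4xf32>
// bufferizes roughly to an allocation of the 5x4 result, one memref.subview
// per input at offsets 0 and 2 along dimension 0, and a memref.copy into each
// subview; when the concat dimension is dynamic, the offset is accumulated
// with arith ops as in the bufferize() method above.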