
Commit 8a57d82

[mlir] Add Scalar Broadcast TOSA Depthwise Conv (#110806)
Support broadcasting of the depthwise conv2d bias in the tosa->linalg named lowering in the case that the bias is a rank-1 tensor with exactly one element. In this case TOSA specifies the value should first be broadcast across the bias dimension and then across the result tensor.

Add `lit` tests for depthwise conv2d with a scalar bias and for conv3d, which was already supported but missing coverage.

Signed-off-by: Jack Frankland <[email protected]>
1 parent 56736c7 commit 8a57d82
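For illustration, here is a condensed sketch of what the scalar-bias lowering now produces, assembled from the new lit test below rather than copied from mlir-opt output; the value names %arg2, %conv and %init and the map aliases are placeholders. The bias input is read through the constant affine map (0), so the single element is broadcast to every point of the result:

#bcast = affine_map<(d0, d1, d2, d3) -> (0)>          // scalar bias: always read element 0
#ident = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>

%init = tensor.empty() : tensor<1x5x5x33xf32>
%biased = linalg.generic
    {indexing_maps = [#bcast, #ident, #ident],
     iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
    ins(%arg2, %conv : tensor<1xf32>, tensor<1x5x5x33xf32>)
    outs(%init : tensor<1x5x5x33xf32>) {
^bb0(%b: f32, %c: f32, %out: f32):
  // Add the broadcast bias element to each convolution result element.
  %sum = arith.addf %b, %c : f32
  linalg.yield %sum : f32
} -> tensor<1x5x5x33xf32>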

2 files changed: 57 additions, 22 deletions

mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp

Lines changed: 27 additions & 22 deletions
@@ -88,15 +88,14 @@ linalgIntBroadcastExtSIAdd(PatternRewriter &rewriter, Location loc, Value bias,
       .getResult(0);
 }
 
-// Broadcast the source value to all the outer dimensions of the result value.
-// If required, the element type is expanded using an arith.extsi operation.
-static mlir::Value linalgBroadcastAndMaybeExtSI(PatternRewriter &rewriter,
-                                                Location loc, Value source,
-                                                Value result) {
+// Construct the affine map that a linalg generic would use to broadcast the
+// source tensor into the shape of the result tensor.
+static AffineMap getBroadcastingMap(PatternRewriter &rewriter, Value source,
+                                    Value result) {
   ShapedType resultTy = cast<ShapedType>(result.getType());
   ShapedType sourceTy = cast<ShapedType>(source.getType());
-  int64_t resultRank = resultTy.getRank();
-  int64_t sourceRank = sourceTy.getRank();
+  const int64_t resultRank = resultTy.getRank();
+  const int64_t sourceRank = sourceTy.getRank();
 
   // The source tensor is broadcast to all the outer dimensions of the
   // result tensor.
@@ -115,14 +114,21 @@ static mlir::Value linalgBroadcastAndMaybeExtSI(PatternRewriter &rewriter,
     }
   }
 
-  // Creating maps for the input and output of the broacast-like generic op.
-  SmallVector<AffineMap, 2> indexingMaps = {
-      // Broadcast the last dimension of the bias to all output dimensions.
-      AffineMap::get(/*dimCount=*/resultRank,
-                     /*symbolCount=*/0, sourceDims, rewriter.getContext()),
+  return AffineMap::get(/*dimCount=*/resultRank,
+                        /*symbolCount=*/0, sourceDims, rewriter.getContext());
+}
 
-      // Output indexing map.
-      rewriter.getMultiDimIdentityMap(resultRank)};
+// Broadcast the source value to all the outer dimensions of the result value.
+// If required, the element type is expanded using an arith.extsi operation.
+static mlir::Value linalgBroadcastAndMaybeExtSI(PatternRewriter &rewriter,
+                                                Location loc, Value source,
+                                                Value result) {
+  ShapedType resultTy = cast<ShapedType>(result.getType());
+  const int64_t resultRank = resultTy.getRank();
+  // Creating maps for the input and output of the broacast-like generic op.
+  SmallVector<AffineMap, 2> indexingMaps;
+  indexingMaps.push_back(getBroadcastingMap(rewriter, source, result));
+  indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
 
   // Build the broadcast-like operation as a linalg.generic.
   return rewriter
@@ -488,14 +494,6 @@ class DepthwiseConvConverter
                            weightShape[2], weightShape[3]},
                           resultETy);
 
-    // Broadcast the initial value to the output tensor before convolving.
-    SmallVector<AffineMap, 4> indexingMaps;
-    indexingMaps.push_back(AffineMap::get(
-        /*dimCount=*/resultRank, /*symbolCount=*/0,
-        {rewriter.getAffineDimExpr(3)}, rewriter.getContext()));
-    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
-    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
-
     auto resultZeroAttr = rewriter.getZeroAttr(resultETy);
     Value emptyTensor = rewriter.create<tensor::EmptyOp>(
         loc, linalgConvTy.getShape(), resultETy, filteredDims);
@@ -507,6 +505,13 @@ class DepthwiseConvConverter
 
     Value biasEmptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, resultTy.getShape(), resultETy, filteredDims);
+
+    // Broadcast the initial value to the output tensor before convolving.
+    SmallVector<AffineMap, 4> indexingMaps;
+    indexingMaps.push_back(getBroadcastingMap(rewriter, bias, biasEmptyTensor));
+    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
+    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
+
     if (!isQuantized) {
       Value conv = rewriter
                        .create<linalg::DepthwiseConv2DNhwcHwcmOp>(
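For contrast with the pre-existing per-channel path, these are the two bias indexing maps the new getBroadcastingMap helper yields, as pinned down by the lit tests in the second file; the alias names are illustrative only:

// Per-channel bias (e.g. tensor<33xf32>): index the last result dimension.
#map_channel = affine_map<(d0, d1, d2, d3) -> (d3)>
// Single-element bias (tensor<1xf32>): a constant map, always element 0.
#map_scalar  = affine_map<(d0, d1, d2, d3) -> (0)>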

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir

Lines changed: 30 additions & 0 deletions
@@ -702,6 +702,22 @@ func.func @depthwise_conv(%arg0 : tensor<1x7x5x3xf32>, %arg1 : tensor<3x1x3x11xf
 
 // -----
 
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (0)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+
+// CHECK-LABEL: @depthwise_conv_scalar_bias
+func.func @depthwise_conv_scalar_bias(%arg0 : tensor<1x7x5x3xf32>, %arg1 : tensor<3x1x3x11xf32>, %arg2 : tensor<1xf32>) -> () {
+  // CHECK: [[BIAS:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %{{.*}} : tensor<1xf32>, tensor<1x5x5x33xf32>) outs(%{{.*}} : tensor<1x5x5x33xf32>) {
+  // CHECK: ^bb0(%[[ARG3:[0-9a-zA-Z_]+]]: f32, %[[ARG4:[0-9a-zA-Z_]+]]: f32, %{{.*}}: f32):
+  // CHECK:   [[ADD:%.+]] = arith.addf %[[ARG3]], %[[ARG4]] : f32
+  // CHECK:   linalg.yield [[ADD]] : f32
+  // CHECK: } -> tensor<1x5x5x33xf32>
+  %2 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 { pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1> } : (tensor<1x7x5x3xf32>, tensor<3x1x3x11xf32>, tensor<1xf32>) -> tensor<1x5x5x33xf32>
+  return
+}
+
+// -----
+
 // CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d3)>
 // CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 
@@ -840,6 +856,20 @@ func.func @conv3d_f32(%input: tensor<1x49x48x47x27xf32>, %weights: tensor<28x3x4
 
 // -----
 
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (0)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+
+// CHECK-LABEL: @conv3d_scalar_bias_f32
+func.func @conv3d_scalar_bias_f32(%input: tensor<1x49x48x47x27xf32>, %weights: tensor<28x3x4x5x27xf32>, %bias: tensor<1xf32>) -> () {
+  // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x47x45x43x28xf32>
+  // CHECK: %[[BROADCAST:.+]] = linalg.generic
+  // CHECK-SAME: {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
+  %0 = tosa.conv3d %input, %weights, %bias {pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xf32>, tensor<28x3x4x5x27xf32>, tensor<1xf32>) -> tensor<1x47x45x43x28xf32>
+  return
+}
+
+// -----
+
 // CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d4)>
 // CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>