@@ -88,15 +88,14 @@ linalgIntBroadcastExtSIAdd(PatternRewriter &rewriter, Location loc, Value bias,
88
88
.getResult (0 );
89
89
}
90
90
91
- // Broadcast the source value to all the outer dimensions of the result value.
92
- // If required, the element type is expanded using an arith.extsi operation.
93
- static mlir::Value linalgBroadcastAndMaybeExtSI (PatternRewriter &rewriter,
94
- Location loc, Value source,
95
- Value result) {
91
+ // Construct the affine map that a linalg generic would use to broadcast the
92
+ // source tensor into the shape of the result tensor.
93
+ static AffineMap getBroadcastingMap (PatternRewriter &rewriter, Value source,
94
+ Value result) {
96
95
ShapedType resultTy = cast<ShapedType>(result.getType ());
97
96
ShapedType sourceTy = cast<ShapedType>(source.getType ());
98
- int64_t resultRank = resultTy.getRank ();
99
- int64_t sourceRank = sourceTy.getRank ();
97
+ const int64_t resultRank = resultTy.getRank ();
98
+ const int64_t sourceRank = sourceTy.getRank ();
100
99
101
100
// The source tensor is broadcast to all the outer dimensions of the
102
101
// result tensor.
@@ -115,14 +114,21 @@ static mlir::Value linalgBroadcastAndMaybeExtSI(PatternRewriter &rewriter,
115
114
}
116
115
}
117
116
118
- // Creating maps for the input and output of the broacast-like generic op.
119
- SmallVector<AffineMap, 2 > indexingMaps = {
120
- // Broadcast the last dimension of the bias to all output dimensions.
121
- AffineMap::get (/* dimCount=*/ resultRank,
122
- /* symbolCount=*/ 0 , sourceDims, rewriter.getContext ()),
117
+ return AffineMap::get (/* dimCount=*/ resultRank,
118
+ /* symbolCount=*/ 0 , sourceDims, rewriter.getContext ());
119
+ }
123
120
124
- // Output indexing map.
125
- rewriter.getMultiDimIdentityMap (resultRank)};
121
+ // Broadcast the source value to all the outer dimensions of the result value.
122
+ // If required, the element type is expanded using an arith.extsi operation.
123
+ static mlir::Value linalgBroadcastAndMaybeExtSI (PatternRewriter &rewriter,
124
+ Location loc, Value source,
125
+ Value result) {
126
+ ShapedType resultTy = cast<ShapedType>(result.getType ());
127
+ const int64_t resultRank = resultTy.getRank ();
128
+ // Creating maps for the input and output of the broadcast-like generic op.
129
+ SmallVector<AffineMap, 2 > indexingMaps;
130
+ indexingMaps.push_back (getBroadcastingMap (rewriter, source, result));
131
+ indexingMaps.push_back (rewriter.getMultiDimIdentityMap (resultRank));
126
132
127
133
// Build the broadcast-like operation as a linalg.generic.
128
134
return rewriter
@@ -488,14 +494,6 @@ class DepthwiseConvConverter
488
494
weightShape[2 ], weightShape[3 ]},
489
495
resultETy);
490
496
491
- // Broadcast the initial value to the output tensor before convolving.
492
- SmallVector<AffineMap, 4 > indexingMaps;
493
- indexingMaps.push_back (AffineMap::get (
494
- /* dimCount=*/ resultRank, /* symbolCount=*/ 0 ,
495
- {rewriter.getAffineDimExpr (3 )}, rewriter.getContext ()));
496
- indexingMaps.push_back (rewriter.getMultiDimIdentityMap (resultRank));
497
- indexingMaps.push_back (rewriter.getMultiDimIdentityMap (resultRank));
498
-
499
497
auto resultZeroAttr = rewriter.getZeroAttr (resultETy);
500
498
Value emptyTensor = rewriter.create <tensor::EmptyOp>(
501
499
loc, linalgConvTy.getShape (), resultETy, filteredDims);
@@ -507,6 +505,13 @@ class DepthwiseConvConverter
507
505
508
506
Value biasEmptyTensor = rewriter.create <tensor::EmptyOp>(
509
507
loc, resultTy.getShape (), resultETy, filteredDims);
508
+
509
+ // Broadcast the initial value to the output tensor before convolving.
510
+ SmallVector<AffineMap, 4 > indexingMaps;
511
+ indexingMaps.push_back (getBroadcastingMap (rewriter, bias, biasEmptyTensor));
512
+ indexingMaps.push_back (rewriter.getMultiDimIdentityMap (resultRank));
513
+ indexingMaps.push_back (rewriter.getMultiDimIdentityMap (resultRank));
514
+
510
515
if (!isQuantized) {
511
516
Value conv = rewriter
512
517
.create <linalg::DepthwiseConv2DNhwcHwcmOp>(
0 commit comments