@@ -110,7 +110,7 @@ struct LinalgOpTilingInterface
110
110
}));
111
111
}
112
112
113
- // Instantiate the tiled implementation of the operation.
113
+ /// Instantiate the tiled implementation of the operation.
114
114
FailureOr<TilingResult>
115
115
getTiledImplementation (Operation *op, OpBuilder &b,
116
116
ArrayRef<OpFoldResult> offsets,
@@ -132,8 +132,63 @@ struct LinalgOpTilingInterface
132
132
return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults ())};
133
133
}
134
134
135
- // Return the details of the output tile generated by the tiled
136
- // implementation.
135
+ // / Utility to fetch the offsets and sizes when applied as per the indexing
136
+ // / map of the linalg op. This helps in fusing the linalg op as a consumer of
137
+ // / a given slice op.
138
+ void
139
+ getMappedOffsetAndSize (LinalgOp linalgOp, OpBuilder &b, AffineMap indexingMap,
140
+ ArrayRef<OpFoldResult> offsets,
141
+ ArrayRef<OpFoldResult> sizes,
142
+ SmallVectorImpl<OpFoldResult> &mappedOffsets,
143
+ SmallVectorImpl<OpFoldResult> &mappedSizes) const {
144
+ unsigned numLoops = linalgOp.getNumLoops ();
145
+ auto tilingInterfaceOp = cast<TilingInterface>(linalgOp.getOperation ());
146
+ mappedOffsets.resize (numLoops);
147
+ mappedSizes.resize (numLoops);
148
+ if (!indexingMap.isPermutation ()) {
149
+ SmallVector<Range> iterationDomain =
150
+ tilingInterfaceOp.getIterationDomain (b);
151
+ for (const auto &&[index , value] : llvm::enumerate (iterationDomain)) {
152
+ mappedOffsets[index ] = value.offset ;
153
+ mappedSizes[index ] = value.size ;
154
+ }
155
+ }
156
+ for (const auto &&[index , value] :
157
+ llvm::enumerate (indexingMap.getResults ())) {
158
+ unsigned dimPosition = cast<AffineDimExpr>(value).getPosition ();
159
+ mappedOffsets[dimPosition] = offsets[index ];
160
+ mappedSizes[dimPosition] = sizes[index ];
161
+ }
162
+ }
163
+
164
+ // / Method to return the position of the result tile computed by the tiled
165
+ // / operation.
166
+ LogicalResult getIterationDomainTileFromOperandTile (
167
+ Operation *op, OpBuilder &b, unsigned operandNumber,
168
+ ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
169
+ SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
170
+ SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
171
+ auto linalgOp = cast<LinalgOp>(op);
172
+
173
+ // Check that the indexing map used for the operand is a projected
174
+ // permutation. This could be relaxed with a more general approach that can
175
+ // map the offsets and sizes from the operand to iteration space tiles
176
+ // (filling in full extent for dimensions not used to access the result).
177
+ AffineMap indexingMap =
178
+ linalgOp.getMatchingIndexingMap (&op->getOpOperand (operandNumber));
179
+ if (!indexingMap.isProjectedPermutation ()) {
180
+ return op->emitError ()
181
+ << " unhandled get iter domain position when operand is not "
182
+ " accessed using a permuted projection" ;
183
+ }
184
+
185
+ getMappedOffsetAndSize (linalgOp, b, indexingMap, offsets, sizes,
186
+ iterDomainOffsets, iterDomainSizes);
187
+ return success ();
188
+ }
189
+
190
+ /// Return the details of the output tile generated by the tiled
191
+ /// implementation.
137
192
LogicalResult
138
193
getResultTilePosition (Operation *op, OpBuilder &b, unsigned resultNumber,
139
194
ArrayRef<OpFoldResult> offsets,
@@ -177,29 +232,16 @@ struct LinalgOpTilingInterface
177
232
" unhandled tiled implementation generation when result is not "
178
233
" accessed using a permuted projection" );
179
234
}
180
-
181
- auto numLoops = linalgOp.getNumLoops ();
235
+ SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
236
+ getMappedOffsetAndSize (linalgOp, b, indexingMap, offsets, sizes,
237
+ mappedOffsets, mappedSizes);
182
238
auto tilingInterfaceOp = cast<TilingInterface>(op);
183
- SmallVector<OpFoldResult> iterationTileOffsets (numLoops),
184
- iterationTileSizes (numLoops);
185
- if (!indexingMap.isPermutation ()) {
186
- SmallVector<Range> iterationDomain =
187
- tilingInterfaceOp.getIterationDomain (b);
188
- for (const auto &range : llvm::enumerate (iterationDomain)) {
189
- iterationTileOffsets[range.index ()] = range.value ().offset ;
190
- iterationTileSizes[range.index ()] = range.value ().size ;
191
- }
192
- }
193
- for (const auto &resultExpr : llvm::enumerate (indexingMap.getResults ())) {
194
- unsigned dimPosition =
195
- cast<AffineDimExpr>(resultExpr.value ()).getPosition ();
196
- iterationTileOffsets[dimPosition] = offsets[resultExpr.index ()];
197
- iterationTileSizes[dimPosition] = sizes[resultExpr.index ()];
198
- }
199
-
200
239
FailureOr<TilingResult> tilingResult =
201
- tilingInterfaceOp.getTiledImplementation (b, iterationTileOffsets,
202
- iterationTileSizes);
240
+ tilingInterfaceOp.getTiledImplementation (b, mappedOffsets, mappedSizes);
241
+
242
+ if (failed (tilingResult))
243
+ return failure ();
244
+
203
245
if (tilingResult->tiledOps .size () != 1 )
204
246
return op->emitOpError (" failed to generate tiled implementation" );
205
247
@@ -208,6 +250,20 @@ struct LinalgOpTilingInterface
208
250
SmallVector<Value>{tilingResult->tiledValues [resultNumber]}};
209
251
}
210
252
253
+ // / Method to generate the tiled implementation of an operation from the tile
254
+ // / of the operand.
255
+ FailureOr<TilingResult> getTiledImplementationFromOperandTile (
256
+ Operation *op, OpBuilder &b, unsigned operandNumber,
257
+ ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
258
+ SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
259
+ if (failed (getIterationDomainTileFromOperandTile (
260
+ op, b, operandNumber, offsets, sizes, mappedOffsets,
261
+ mappedSizes))) {
262
+ return failure ();
263
+ }
264
+ return getTiledImplementation (op, b, mappedOffsets, mappedSizes);
265
+ }
266
+
211
267
LogicalResult generateScalarImplementation (Operation *op, OpBuilder &builder,
212
268
Location loc,
213
269
ValueRange ivs) const {
0 commit comments