@@ -132,6 +132,59 @@ struct LinalgOpTilingInterface
132
132
return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults ())};
133
133
}
134
134
135
+ void getMappedOffsetAndSize (Operation *op, OpBuilder &b,
136
+ AffineMap indexingMap,
137
+ ArrayRef<OpFoldResult> offsets,
138
+ ArrayRef<OpFoldResult> sizes,
139
+ SmallVector<OpFoldResult> &mappedOffsets,
140
+ SmallVector<OpFoldResult> &mappedSizes) const {
141
+ auto linalgOp = cast<LinalgOp>(op);
142
+ auto numLoops = linalgOp.getNumLoops ();
143
+ auto tilingInterfaceOp = cast<TilingInterface>(op);
144
+ mappedOffsets.resize (numLoops);
145
+ mappedSizes.resize (numLoops);
146
+ if (!indexingMap.isPermutation ()) {
147
+ SmallVector<Range> iterationDomain =
148
+ tilingInterfaceOp.getIterationDomain (b);
149
+ for (const auto &range : llvm::enumerate (iterationDomain)) {
150
+ mappedOffsets[range.index ()] = range.value ().offset ;
151
+ mappedSizes[range.index ()] = range.value ().size ;
152
+ }
153
+ }
154
+ for (const auto &resultExpr : llvm::enumerate (indexingMap.getResults ())) {
155
+ unsigned dimPosition =
156
+ cast<AffineDimExpr>(resultExpr.value ()).getPosition ();
157
+ mappedOffsets[dimPosition] = offsets[resultExpr.index ()];
158
+ mappedSizes[dimPosition] = sizes[resultExpr.index ()];
159
+ }
160
+ }
161
+
162
+ // Return the details of the output tile generated by the tiled
163
+ // implementation.
164
+ LogicalResult getIterDomainTilePositionFromOperandPosition (
165
+ Operation *op, OpBuilder &b, unsigned operandNumber,
166
+ ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
167
+ SmallVector<OpFoldResult> &iterDomainOffsets,
168
+ SmallVector<OpFoldResult> &iterDomainSizes) const {
169
+ auto linalgOp = cast<LinalgOp>(op);
170
+
171
+ // Check that the indexing map used for the operand is a projected
172
+ // permutation. This could be relaxed with a more general approach that can
173
+ // map the offsets and sizes from the operand to iteration space tiles
174
+ // (filling in full extent for dimensions not used to access the result).
175
+ AffineMap indexingMap =
176
+ linalgOp.getMatchingIndexingMap (&op->getOpOperand (operandNumber));
177
+ if (!indexingMap.isProjectedPermutation ()) {
178
+ return op->emitOpError (
179
+ " unhandled get iter domain position when operand is not "
180
+ " accessed using a permuted projection" );
181
+ }
182
+
183
+ getMappedOffsetAndSize (op, b, indexingMap, offsets, sizes,
184
+ iterDomainOffsets, iterDomainSizes);
185
+ return success ();
186
+ }
187
+
135
188
// Return the details of the output tile generated by the tiled
136
189
// implementation.
137
190
LogicalResult
@@ -160,6 +213,20 @@ struct LinalgOpTilingInterface
160
213
return success ();
161
214
}
162
215
216
+ FailureOr<TilingResult> getTiledImplementationFromOperandPosition (
217
+ Operation *op, OpBuilder &b, unsigned operandNumber,
218
+ ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
219
+ SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
220
+ auto tilingInterfaceOp = cast<TilingInterface>(op);
221
+ if (failed (tilingInterfaceOp.getIterDomainTilePositionFromOperandPosition (
222
+ b, operandNumber, offsets, sizes, mappedOffsets, mappedSizes))) {
223
+ return op->emitOpError (
224
+ " unable to obtain the iter domain position of the operation." );
225
+ }
226
+ return tilingInterfaceOp.getTiledImplementation (b, mappedOffsets,
227
+ mappedSizes);
228
+ }
229
+
163
230
FailureOr<TilingResult>
164
231
generateResultTileValue (Operation *op, OpBuilder &b, unsigned resultNumber,
165
232
ArrayRef<OpFoldResult> offsets,
@@ -177,29 +244,16 @@ struct LinalgOpTilingInterface
177
244
" unhandled tiled implementation generation when result is not "
178
245
" accessed using a permuted projection" );
179
246
}
180
-
181
- auto numLoops = linalgOp.getNumLoops ();
247
+ SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
248
+ getMappedOffsetAndSize (op, b, indexingMap, offsets, sizes, mappedOffsets,
249
+ mappedSizes);
182
250
auto tilingInterfaceOp = cast<TilingInterface>(op);
183
- SmallVector<OpFoldResult> iterationTileOffsets (numLoops),
184
- iterationTileSizes (numLoops);
185
- if (!indexingMap.isPermutation ()) {
186
- SmallVector<Range> iterationDomain =
187
- tilingInterfaceOp.getIterationDomain (b);
188
- for (const auto &range : llvm::enumerate (iterationDomain)) {
189
- iterationTileOffsets[range.index ()] = range.value ().offset ;
190
- iterationTileSizes[range.index ()] = range.value ().size ;
191
- }
192
- }
193
- for (const auto &resultExpr : llvm::enumerate (indexingMap.getResults ())) {
194
- unsigned dimPosition =
195
- cast<AffineDimExpr>(resultExpr.value ()).getPosition ();
196
- iterationTileOffsets[dimPosition] = offsets[resultExpr.index ()];
197
- iterationTileSizes[dimPosition] = sizes[resultExpr.index ()];
198
- }
199
-
200
251
FailureOr<TilingResult> tilingResult =
201
- tilingInterfaceOp.getTiledImplementation (b, iterationTileOffsets,
202
- iterationTileSizes);
252
+ tilingInterfaceOp.getTiledImplementation (b, mappedOffsets, mappedSizes);
253
+
254
+ if (failed (tilingResult))
255
+ return failure ();
256
+
203
257
if (tilingResult->tiledOps .size () != 1 )
204
258
return op->emitOpError (" failed to generate tiled implementation" );
205
259
0 commit comments