@@ -110,7 +110,7 @@ struct LinalgOpTilingInterface
110
110
}));
111
111
}
112
112
113
- // Instantiate the tiled implementation of the operation.
113
+ // / Instantiate the tiled implementation of the operation.
114
114
FailureOr<TilingResult>
115
115
getTiledImplementation (Operation *op, OpBuilder &b,
116
116
ArrayRef<OpFoldResult> offsets,
@@ -132,14 +132,66 @@ struct LinalgOpTilingInterface
132
132
return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults ())};
133
133
}
134
134
135
- // Return the details of the output tile generated by the tiled
136
- // implementation.
135
+ void
136
+ getMappedOffsetAndSize (LinalgOp linalgOp, OpBuilder &b, AffineMap indexingMap,
137
+ ArrayRef<OpFoldResult> offsets,
138
+ ArrayRef<OpFoldResult> sizes,
139
+ SmallVectorImpl<OpFoldResult> &mappedOffsets,
140
+ SmallVectorImpl<OpFoldResult> &mappedSizes) const {
141
+ unsigned numLoops = linalgOp.getNumLoops ();
142
+ auto tilingInterfaceOp = cast<TilingInterface>(linalgOp.getOperation ());
143
+ mappedOffsets.resize (numLoops);
144
+ mappedSizes.resize (numLoops);
145
+ if (!indexingMap.isPermutation ()) {
146
+ SmallVector<Range> iterationDomain =
147
+ tilingInterfaceOp.getIterationDomain (b);
148
+ for (const auto &&[index , value] : llvm::enumerate (iterationDomain)) {
149
+ mappedOffsets[index ] = value.offset ;
150
+ mappedSizes[index ] = value.size ;
151
+ }
152
+ }
153
+ for (const auto &&[index , value] :
154
+ llvm::enumerate (indexingMap.getResults ())) {
155
+ unsigned dimPosition = cast<AffineDimExpr>(value).getPosition ();
156
+ mappedOffsets[dimPosition] = offsets[index ];
157
+ mappedSizes[dimPosition] = sizes[index ];
158
+ }
159
+ }
160
+
161
+ // / Return the details of the output tile generated by the tiled
162
+ // / implementation.
163
+ LogicalResult getIterationDomainTileFromOperandTile (
164
+ Operation *op, OpBuilder &b, unsigned operandNumber,
165
+ ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
166
+ SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
167
+ SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
168
+ auto linalgOp = cast<LinalgOp>(op);
169
+
170
+ // Check that the indexing map used for the operand is a projected
171
+ // permutation. This could be relaxed with a more general approach that can
172
+ // map the offsets and sizes from the operand to iteration space tiles
173
+ // (filling in full extent for dimensions not used to access the result).
174
+ AffineMap indexingMap =
175
+ linalgOp.getMatchingIndexingMap (&op->getOpOperand (operandNumber));
176
+ if (!indexingMap.isProjectedPermutation ()) {
177
+ return emitError (op->getLoc (),
178
+ " unhandled get iter domain position when operand is not "
179
+ " accessed using a permuted projection" );
180
+ }
181
+
182
+ getMappedOffsetAndSize (linalgOp, b, indexingMap, offsets, sizes,
183
+ iterDomainOffsets, iterDomainSizes);
184
+ return success ();
185
+ }
186
+
187
+ // / Return the details of the output tile generated by the tiled
188
+ // / implementation.
137
189
LogicalResult
138
190
getResultTilePosition (Operation *op, OpBuilder &b, unsigned resultNumber,
139
191
ArrayRef<OpFoldResult> offsets,
140
192
ArrayRef<OpFoldResult> sizes,
141
- SmallVector <OpFoldResult> &resultOffsets,
142
- SmallVector <OpFoldResult> &resultSizes) const {
193
+ SmallVectorImpl <OpFoldResult> &resultOffsets,
194
+ SmallVectorImpl <OpFoldResult> &resultSizes) const {
143
195
Location loc = op->getLoc ();
144
196
LinalgOp linalgOp = cast<LinalgOp>(op);
145
197
@@ -160,6 +212,21 @@ struct LinalgOpTilingInterface
160
212
return success ();
161
213
}
162
214
215
+ FailureOr<TilingResult> getTiledImplementationFromOperandTile (
216
+ Operation *op, OpBuilder &b, unsigned operandNumber,
217
+ ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
218
+ SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
219
+ auto tilingInterfaceOp = cast<TilingInterface>(op);
220
+ if (failed (tilingInterfaceOp.getIterationDomainTileFromOperandTile (
221
+ b, operandNumber, offsets, sizes, mappedOffsets, mappedSizes))) {
222
+ return emitError (
223
+ op->getLoc (),
224
+ " unable to obtain the iter domain position of the operation." );
225
+ }
226
+ return tilingInterfaceOp.getTiledImplementation (b, mappedOffsets,
227
+ mappedSizes);
228
+ }
229
+
163
230
FailureOr<TilingResult>
164
231
generateResultTileValue (Operation *op, OpBuilder &b, unsigned resultNumber,
165
232
ArrayRef<OpFoldResult> offsets,
@@ -177,29 +244,16 @@ struct LinalgOpTilingInterface
177
244
" unhandled tiled implementation generation when result is not "
178
245
" accessed using a permuted projection" );
179
246
}
180
-
181
- auto numLoops = linalgOp.getNumLoops ();
247
+ SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
248
+ getMappedOffsetAndSize (linalgOp, b, indexingMap, offsets, sizes,
249
+ mappedOffsets, mappedSizes);
182
250
auto tilingInterfaceOp = cast<TilingInterface>(op);
183
- SmallVector<OpFoldResult> iterationTileOffsets (numLoops),
184
- iterationTileSizes (numLoops);
185
- if (!indexingMap.isPermutation ()) {
186
- SmallVector<Range> iterationDomain =
187
- tilingInterfaceOp.getIterationDomain (b);
188
- for (const auto &range : llvm::enumerate (iterationDomain)) {
189
- iterationTileOffsets[range.index ()] = range.value ().offset ;
190
- iterationTileSizes[range.index ()] = range.value ().size ;
191
- }
192
- }
193
- for (const auto &resultExpr : llvm::enumerate (indexingMap.getResults ())) {
194
- unsigned dimPosition =
195
- cast<AffineDimExpr>(resultExpr.value ()).getPosition ();
196
- iterationTileOffsets[dimPosition] = offsets[resultExpr.index ()];
197
- iterationTileSizes[dimPosition] = sizes[resultExpr.index ()];
198
- }
199
-
200
251
FailureOr<TilingResult> tilingResult =
201
- tilingInterfaceOp.getTiledImplementation (b, iterationTileOffsets,
202
- iterationTileSizes);
252
+ tilingInterfaceOp.getTiledImplementation (b, mappedOffsets, mappedSizes);
253
+
254
+ if (failed (tilingResult))
255
+ return failure ();
256
+
203
257
if (tilingResult->tiledOps .size () != 1 )
204
258
return op->emitOpError (" failed to generate tiled implementation" );
205
259
0 commit comments