Skip to content

Commit c4bf949

Browse files
authored
[mlir][linalg] Implement TilingInterface for winograd operators (#96184)
To support conv2d input data of arbitrary size, implement TilingInterface for the winograd operations. Before converting winograd operations into nested loops with matrix multiplication, tile the conv2d input into a supported size first. Add a transform operation, structured.decompose_winograd_op, to decompose winograd operations. Before applying the transform op, use tile_using_for to tile the input data into a supported size. The test case shows how to tile and decompose winograd operations.
1 parent 2fe59d5 commit c4bf949

File tree

8 files changed

+1330
-40
lines changed

8 files changed

+1330
-40
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td

Lines changed: 135 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -154,8 +154,13 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax",
154154
let hasVerifier = 1;
155155
}
156156

157-
def Linalg_WinogradFilterTransformOp :
158-
Linalg_Op<"winograd_filter_transform", [AllElementTypesMatch<["filter", "output"]>]> {
157+
def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform",
158+
[AllElementTypesMatch<["filter", "output"]>,
159+
DeclareOpInterfaceMethods<TilingInterface,
160+
["getIterationDomain",
161+
"getLoopIteratorTypes",
162+
"getResultTilePosition",
163+
"getTiledImplementation"]>]> {
159164
let summary = "Winograd filter transform operator";
160165
let description = [{
161166
Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
@@ -190,11 +195,42 @@ def Linalg_WinogradFilterTransformOp :
190195
`outs` `(` $output `:` type($output) `)`
191196
`->` type($result)
192197
}];
198+
let extraClassDeclaration = [{
199+
ShapedType getFilterOperandType() {
200+
return cast<ShapedType>(getFilter().getType());
201+
}
202+
ShapedType getOutputOperandType() {
203+
return cast<ShapedType>(getOutput().getType());
204+
}
205+
int64_t getFilterOperandRank() {
206+
return getFilterOperandType().getRank();
207+
}
208+
int64_t getOutputOperandRank() {
209+
return getOutputOperandType().getRank();
210+
}
211+
int64_t getFilterFDim() {
212+
return 0;
213+
}
214+
int64_t getFilterHDim() {
215+
return 1;
216+
}
217+
int64_t getFilterWDim() {
218+
return 2;
219+
}
220+
int64_t getFilterCDim() {
221+
return 3;
222+
}
223+
}];
193224
let hasVerifier = 1;
194225
}
195226

196-
def Linalg_WinogradInputTransformOp :
197-
Linalg_Op<"winograd_input_transform", [AllElementTypesMatch<["input", "output"]>]> {
227+
def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform",
228+
[AllElementTypesMatch<["input", "output"]>,
229+
DeclareOpInterfaceMethods<TilingInterface,
230+
["getIterationDomain",
231+
"getLoopIteratorTypes",
232+
"getResultTilePosition",
233+
"getTiledImplementation"]>]> {
198234
let summary = "Winograd input transform operator";
199235
let description = [{
200236
Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
@@ -229,11 +265,60 @@ def Linalg_WinogradInputTransformOp :
229265
`outs` `(` $output `:` type($output) `)`
230266
`->` type($result)
231267
}];
268+
let extraClassDeclaration = [{
269+
ShapedType getInputOperandType() {
270+
return cast<ShapedType>(getInput().getType());
271+
}
272+
ShapedType getOutputOperandType() {
273+
return cast<ShapedType>(getOutput().getType());
274+
}
275+
int64_t getInputOperandRank() {
276+
return getInputOperandType().getRank();
277+
}
278+
int64_t getOutputOperandRank() {
279+
return getOutputOperandType().getRank();
280+
}
281+
int64_t getInputNDim() {
282+
return 0;
283+
}
284+
int64_t getInputHDim() {
285+
return 1;
286+
}
287+
int64_t getInputWDim() {
288+
return 2;
289+
}
290+
int64_t getInputCDim() {
291+
return 3;
292+
}
293+
int64_t getOutputAlphaHDim() {
294+
return 0;
295+
}
296+
int64_t getOutputAlphaWDim() {
297+
return 1;
298+
}
299+
int64_t getOutputTileHDim() {
300+
return 2;
301+
}
302+
int64_t getOutputTileWDim() {
303+
return 3;
304+
}
305+
int64_t getOutputNDim() {
306+
return 4;
307+
}
308+
int64_t getOutputCDim() {
309+
return 5;
310+
}
311+
}];
232312
let hasVerifier = 1;
233313
}
234314

235-
def Linalg_WinogradOutputTransformOp :
236-
Linalg_Op<"winograd_output_transform", [AllElementTypesMatch<["value", "output"]>]> {
315+
def Linalg_WinogradOutputTransformOp : Linalg_Op<"winograd_output_transform",
316+
[AllElementTypesMatch<["value", "output"]>,
317+
DeclareOpInterfaceMethods<TilingInterface,
318+
["getIterationDomain",
319+
"getLoopIteratorTypes",
320+
"getResultTilePosition",
321+
"getTiledImplementation"]>]> {
237322
let summary = "Winograd output transform operator";
238323
let description = [{
239324
Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
@@ -268,6 +353,50 @@ def Linalg_WinogradOutputTransformOp :
268353
`outs` `(` $output `:` type($output) `)`
269354
`->` type($result)
270355
}];
356+
let extraClassDeclaration = [{
357+
ShapedType getValueOperandType() {
358+
return cast<ShapedType>(getValue().getType());
359+
}
360+
ShapedType getOutputOperandType() {
361+
return cast<ShapedType>(getOutput().getType());
362+
}
363+
int64_t getValueOperandRank() {
364+
return getValueOperandType().getRank();
365+
}
366+
int64_t getOutputOperandRank() {
367+
return getOutputOperandType().getRank();
368+
}
369+
int64_t getValueAlphaHDim() {
370+
return 0;
371+
}
372+
int64_t getValueAlphaWDim() {
373+
return 1;
374+
}
375+
int64_t getValueTileHDim() {
376+
return 2;
377+
}
378+
int64_t getValueTileWDim() {
379+
return 3;
380+
}
381+
int64_t getValueNDim() {
382+
return 4;
383+
}
384+
int64_t getValueFDim() {
385+
return 5;
386+
}
387+
int64_t getOutputNDim() {
388+
return 0;
389+
}
390+
int64_t getOutputHDim() {
391+
return 1;
392+
}
393+
int64_t getOutputWDim() {
394+
return 2;
395+
}
396+
int64_t getOutputFDim() {
397+
return 3;
398+
}
399+
}];
271400
let hasVerifier = 1;
272401
}
273402

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2697,4 +2697,41 @@ def WinogradConv2DOp : Op<Transform_Dialect,
26972697
}];
26982698
}
26992699

2700+
def DecomposeWinogradOp : Op<Transform_Dialect,
2701+
"structured.decompose_winograd_op",
2702+
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
2703+
TransformOpInterface, TransformEachOpTrait,
2704+
ReportTrackingListenerFailuresOpTrait]> {
2705+
let description = [{
2706+
Decompose winograd operations. It will convert filter, input and output
2707+
transform operations into a combination of scf, tensor, and linalg
2708+
equivalent operations. Before applying this transform operation, users
2709+
need to tile winograd transform operations into supported sizes.
2710+
2711+
#### Return modes:
2712+
2713+
This operation fails if `target` is unsupported. Otherwise, the operation
2714+
succeeds and returns a handle of the sequence that replaces the original
2715+
operations.
2716+
}];
2717+
2718+
let arguments = (ins TransformHandleTypeInterface:$target);
2719+
let results = (outs TransformHandleTypeInterface:$transformed);
2720+
2721+
let assemblyFormat =
2722+
"$target attr-dict `:` functional-type($target, results)";
2723+
2724+
let builders = [
2725+
OpBuilder<(ins "Value":$target)>
2726+
];
2727+
2728+
let extraClassDeclaration = [{
2729+
::mlir::DiagnosedSilenceableFailure applyToOne(
2730+
::mlir::transform::TransformRewriter &rewriter,
2731+
::mlir::Operation *target,
2732+
::mlir::transform::ApplyToEachResultList &results,
2733+
::mlir::transform::TransformState &state);
2734+
}];
2735+
}
2736+
27002737
#endif // LINALG_TRANSFORM_OPS

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1316,6 +1316,63 @@ FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
13161316
linalg::Conv2DNhwcFhwcOp op, int64_t m,
13171317
int64_t r);
13181318

1319+
/// Rewrite linalg.winograd_filter_transform. The data layout of the filter is
1320+
/// FHWC. The transformation matrix is 2-dimensional, so we need to extract the
1321+
/// H x W slice from FHWC first. We generate 2 levels of loops to iterate on F
1322+
/// and C. After the rewrite, we get:
1323+
///
1324+
///   scf.for %f = lo_f to hi_f step 1
1325+
///     scf.for %c = lo_c to hi_c step 1
1326+
///       %extracted = extract filter<h x w> from filter<f x h x w x c>
1327+
///       %ret = linalg.matmul G, %extracted
1328+
///       %ret = linalg.matmul %ret, GT
1329+
///       %inserted = insert %ret into filter<h x w x c x f>
///
/// The FailureOr return type indicates that the rewrite can fail; the exact
/// failure conditions are not visible in this header.
/// NOTE(review): on success the returned Operation * is presumably the
/// operation that replaces `op` - confirm against the implementation.
1330+
FailureOr<Operation *>
1331+
decomposeWinogradFilterTransformOp(RewriterBase &rewriter,
1332+
linalg::WinogradFilterTransformOp op);
1333+
1334+
/// Rewrite linalg.winograd_input_transform. The data layout of the input is
1335+
/// NHWC. The transformation matrix is 2-dimensional, so we need to extract the
1336+
/// H x W slice from NHWC first. We generate 4 levels of loops to iterate on
1337+
/// tileH, tileW, N, and C. After the rewrite, we get:
1338+
///
1339+
///   scf.for %h = 0 to tileH step 1
1340+
///     scf.for %w = 0 to tileW step 1
1341+
///       scf.for %n = 0 to N step 1
1342+
///         scf.for %c = 0 to C step 1
1343+
///           %extracted = extract a slice<alphaH x alphaW> from
1344+
///                        %input<N x H x W x C>
1345+
///                        at [%n, (%h x m), (%w x m), %c]
1346+
///           %ret = linalg.matmul BT, %extracted
1347+
///           %ret = linalg.matmul %ret, B
1348+
///           %inserted = insert %ret<alphaH x alphaW> into
1349+
///                       %output<alphaH x alphaW x tileH x tileW x N x C>
1350+
///                       at [0, 0, %h, %w, %n, %c]
///
/// NOTE(review): `m` is not defined in this comment; presumably it is the
/// same `m` that winogradConv2D takes as a parameter - confirm.
/// The FailureOr return type indicates that the rewrite can fail; the exact
/// failure conditions are not visible in this header.
1351+
FailureOr<Operation *>
1352+
decomposeWinogradInputTransformOp(RewriterBase &rewriter,
1353+
linalg::WinogradInputTransformOp op);
1354+
1355+
/// Rewrite linalg.winograd_output_transform. The data layout of the output is
1356+
/// HWNF. The transformation matrix is 2-dimensional, so we need to extract the
1357+
/// H x W slice from HWNF first. We generate 4 levels of loops to iterate on
1358+
/// tileH, tileW, N, and F. After the transformation, we get:
1359+
///
1360+
///   scf.for %h = 0 to tileH step 1
1361+
///     scf.for %w = 0 to tileW step 1
1362+
///       scf.for %n = 0 to N step 1
1363+
///         scf.for %f = 0 to F step 1
1364+
///           %extracted = extract a slice<alphaH x alphaW> from
1365+
///                        %input<alphaH x alphaW x tileH x tileW x N x F>
1366+
///                        at [0, 0, %h, %w, %n, %f]
1367+
///           %ret = linalg.matmul AT, %extracted
1368+
///           %ret = linalg.matmul %ret, A
1369+
///           %inserted = insert %ret<alphaH x alphaW> into
1370+
///                       output<N x H x W x F>
1371+
///                       at [%n, (%h x m), (%w x m), %f]
///
/// NOTE(review): the pseudo-code above reads from a 6-D
/// <alphaH x alphaW x tileH x tileW x N x F> tensor and writes into an
/// <N x H x W x F> tensor, so "HWNF" in the first paragraph presumably refers
/// to the tensor being read rather than the final result - confirm.
/// NOTE(review): `m` is not defined in this comment; presumably it is the
/// same `m` that winogradConv2D takes as a parameter - confirm.
/// The FailureOr return type indicates that the rewrite can fail; the exact
/// failure conditions are not visible in this header.
1372+
FailureOr<Operation *>
1373+
decomposeWinogradOutputTransformOp(RewriterBase &rewriter,
1374+
linalg::WinogradOutputTransformOp op);
1375+
13191376
//===----------------------------------------------------------------------===//
13201377
// Rewrite patterns wrapping transformations.
13211378
// TODO: every single such pattern should be a close to noop wrapper around a

0 commit comments

Comments
 (0)