Skip to content

Commit aa2a96a

Browse files
[mlir][TilingInterface] Move TilingInterface tests to use transform dialect ops. (#77204)
In the process a couple of test transform dialect ops are added just for testing. These operations are not intended to use as full flushed out of transformation ops, but are rather operations added for testing. A separate operation is added to `LinalgTransformOps.td` to convert a `TilingInterface` operation to loops using the `generateScalarImplementation` method implemented by the operation. Eventually this and other operations related to tiling using the `TilingInterface` need to move to a better place (i.e. out of `Linalg` dialect)
1 parent 3699811 commit aa2a96a

File tree

15 files changed

+856
-758
lines changed

15 files changed

+856
-758
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,10 @@ def FuseOp : Op<Transform_Dialect, "structured.fuse",
293293
let results = (outs TransformHandleTypeInterface:$transformed,
294294
Variadic<TransformHandleTypeInterface>:$loops);
295295

296-
let hasCustomAssemblyFormat = 1;
296+
let assemblyFormat = [{
297+
$target ($tile_sizes^)? (`interchange` $tile_interchange^)?
298+
attr-dict `:` functional-type(operands, results)
299+
}];
297300
let hasVerifier = 1;
298301
}
299302

@@ -1269,6 +1272,33 @@ def ScalarizeOp : Op<Transform_Dialect, "structured.scalarize",
12691272
}];
12701273
}
12711274

1275+
def ConvertToLoopsOp : Op<Transform_Dialect, "structured.convert_to_loops",
1276+
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
1277+
TransformOpInterface, TransformEachOpTrait,
1278+
ReportTrackingListenerFailuresOpTrait]> {
1279+
let description = [{
1280+
For operations that implement the `TilingInterface`, and implement
1281+
the `generateScalarImplementation` method, lowers the operation to
1282+
loops. This operation does not return any handles.
1283+
}];
1284+
1285+
let arguments = (ins TransformHandleTypeInterface:$target);
1286+
let results = (outs);
1287+
1288+
let assemblyFormat = [{
1289+
$target attr-dict `:` type($target)
1290+
}];
1291+
1292+
let extraClassDeclaration = [{
1293+
::mlir::DiagnosedSilenceableFailure applyToOne(
1294+
::mlir::transform::TransformRewriter &rewriter,
1295+
::mlir::TilingInterface target,
1296+
::mlir::transform::ApplyToEachResultList &results,
1297+
::mlir::transform::TransformState &state);
1298+
}];
1299+
}
1300+
1301+
12721302
//===----------------------------------------------------------------------===//
12731303
// DecomposeInterfaceOp
12741304
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 22 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -492,38 +492,6 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter,
492492
: DiagnosedSilenceableFailure::success();
493493
}
494494

495-
ParseResult transform::FuseOp::parse(OpAsmParser &parser,
496-
OperationState &result) {
497-
OpAsmParser::UnresolvedOperand targetOperand;
498-
if (parser.parseOperand(targetOperand) ||
499-
parser.parseOptionalAttrDict(result.attributes))
500-
return failure();
501-
502-
FunctionType trailingType;
503-
SMLoc typeLoc;
504-
if (parser.getCurrentLocation(&typeLoc) ||
505-
parser.parseColonType(trailingType)) {
506-
return failure();
507-
}
508-
if (trailingType.getNumInputs() != 1)
509-
return parser.emitError(typeLoc) << "expected one input type";
510-
511-
result.addTypes(trailingType.getResults());
512-
if (parser.resolveOperand(targetOperand, trailingType.getInput(0),
513-
result.operands))
514-
return failure();
515-
return success();
516-
}
517-
518-
void transform::FuseOp::print(OpAsmPrinter &p) {
519-
p << ' ';
520-
p << getTarget();
521-
p.printOptionalAttrDict((*this)->getAttrs());
522-
p << " : ";
523-
p.printFunctionalType(TypeRange(getOperand().getType()),
524-
getResults().getTypes());
525-
}
526-
527495
LogicalResult transform::FuseOp::verify() {
528496
SmallVector<int64_t> permutation =
529497
extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
@@ -2111,6 +2079,22 @@ transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter,
21112079
return DiagnosedSilenceableFailure::success();
21122080
}
21132081

2082+
//===----------------------------------------------------------------------===//
2083+
// ConvertToLoopsOp
2084+
//===----------------------------------------------------------------------===//
2085+
2086+
DiagnosedSilenceableFailure transform::ConvertToLoopsOp::applyToOne(
2087+
transform::TransformRewriter &rewriter, TilingInterface target,
2088+
transform::ApplyToEachResultList &results,
2089+
transform::TransformState &state) {
2090+
rewriter.setInsertionPoint(target);
2091+
FailureOr<SmallVector<scf::ForOp>> loops =
2092+
scf::lowerToLoopsUsingSCFForOp(rewriter, target);
2093+
if (failed(loops))
2094+
return emitDefaultDefiniteFailure(target);
2095+
return DiagnosedSilenceableFailure::success();
2096+
}
2097+
21142098
//===----------------------------------------------------------------------===//
21152099
// RewriteInDestinationPassingStyleOp
21162100
//===----------------------------------------------------------------------===//
@@ -2620,7 +2604,12 @@ transform::TileUsingForOp::apply(transform::TransformRewriter &rewriter,
26202604
}
26212605

26222606
scf::SCFTilingOptions tilingOptions;
2623-
if (!tileSizes.empty()) {
2607+
if (tileSizes.empty()) {
2608+
tilingOptions.setTileSizeComputationFunction(
2609+
[](OpBuilder &, Operation *) -> SmallVector<OpFoldResult> {
2610+
return {};
2611+
});
2612+
} else {
26242613
tilingOptions.setTileSizeComputationFunction([&, index = i](OpBuilder &b,
26252614
Operation *) {
26262615
SmallVector<OpFoldResult> sizes;

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -283,10 +283,7 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
283283
// 1. Get the range of the loops that are represented by the operation.
284284
SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
285285
size_t numLoops = iterationDomain.size();
286-
if (numLoops == 0) {
287-
return rewriter.notifyMatchFailure(
288-
op, "unable to tile op with no iteration domain");
289-
}
286+
290287
// 2. Materialize the tile sizes. Enforce the convention that "tiling by zero"
291288
// skips tiling a particular dimension. This convention is significantly
292289
// simpler to handle instead of adjusting affine maps to account for missing

mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir

Lines changed: 84 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,27 @@
1-
// RUN: mlir-opt -test-tiling-interface=lower-to-scalar-using-scf-for -split-input-file %s | FileCheck %s
1+
// RUN: mlir-opt -transform-interpreter -split-input-file -canonicalize -cse %s | FileCheck %s
22

33
func.func @gemm(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>,
44
%arg2 : memref<?x?xf32>) {
55
linalg.matmul ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>)
66
outs(%arg2 : memref<?x?xf32>)
77
return
88
}
9+
10+
module attributes {transform.with_named_sequence} {
11+
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
12+
%matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1
13+
: (!transform.any_op) -> !transform.any_op
14+
transform.structured.convert_to_loops %matmul : !transform.any_op
15+
transform.yield
16+
}
17+
}
918
// CHECK-LABEL: func @gemm
1019
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?xf32>
1120
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: memref<?x?xf32>
1221
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: memref<?x?xf32>
1322
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
14-
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
1523
// CHECK-DAG: %[[M:.+]] = memref.dim %[[ARG0]], %[[C0]]
24+
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
1625
// CHECK-DAG: %[[K:.+]] = memref.dim %[[ARG0]], %[[C1]]
1726
// CHECK-DAG: %[[N:.+]] = memref.dim %[[ARG1]], %[[C1]]
1827
// CHECK: scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[M]] step %[[C1]]
@@ -51,6 +60,15 @@ func.func @indexed_generic(%arg0 : memref<200x300xi32>, %arg1 : memref<300xi16>,
5160
}
5261
return
5362
}
63+
64+
module attributes {transform.with_named_sequence} {
65+
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
66+
%generic = transform.structured.match ops{["linalg.generic"]} in %arg1
67+
: (!transform.any_op) -> !transform.any_op
68+
transform.structured.convert_to_loops %generic : !transform.any_op
69+
transform.yield
70+
}
71+
}
5472
// CHECK-LABEL: func @indexed_generic
5573
// CHECK-SAME: %[[ARG0:.+]]: memref<200x300xi32>
5674
// CHECK-SAME: %[[ARG1:.+]]: memref<300xi16>
@@ -87,8 +105,18 @@ func.func @conv_strides_and_dilation(%arg0 : memref<?x?x?x?xf32>, %arg1 : memref
87105
outs(%arg2 : memref<?x?x?x?xf32>)
88106
return
89107
}
90-
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1 + d4 * 3)>
91-
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d2 * 2 + d5 * 4)>
108+
109+
module attributes {transform.with_named_sequence} {
110+
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
111+
%conv = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1
112+
: (!transform.any_op) -> !transform.any_op
113+
transform.structured.convert_to_loops %conv : !transform.any_op
114+
transform.yield
115+
}
116+
}
117+
118+
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 3)>
119+
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0 * 2 + d1 * 4)>
92120
// CHECK: func @conv_strides_and_dilation(
93121
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
94122
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
@@ -111,8 +139,8 @@ func.func @conv_strides_and_dilation(%arg0 : memref<?x?x?x?xf32>, %arg1 : memref
111139
// CHECK: scf.for %[[IV4:[a-zA-Z0-9]+]] = %[[C0]] to %[[H]] step %[[C1]]
112140
// CHECK: scf.for %[[IV5:[a-zA-Z0-9]+]] = %[[C0]] to %[[W]] step %[[C1]]
113141
// CHECK: scf.for %[[IV6:[a-zA-Z0-9]+]] = %[[C0]] to %[[C]] step %[[C1]]
114-
// CHECK-DAG: %[[I:.+]] = affine.apply #[[MAP0]](%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]], %[[IV4]], %[[IV5]], %[[IV6]])
115-
// CHECK-DAG: %[[J:.+]] = affine.apply #[[MAP1]](%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]], %[[IV4]], %[[IV5]], %[[IV6]])
142+
// CHECK-DAG: %[[I:.+]] = affine.apply #[[MAP0]](%[[IV1]], %[[IV4]])
143+
// CHECK-DAG: %[[J:.+]] = affine.apply #[[MAP1]](%[[IV2]], %[[IV5]])
116144
// CHECK-DAG: %[[T9:.+]] = memref.load %[[ARG0]][%[[IV0]], %[[I]], %[[J]], %[[IV6]]]
117145
// CHECK-DAG: %[[T10:.+]] = memref.load %[[ARG1]][%[[IV4]], %[[IV5]], %[[IV6]], %[[IV3]]]
118146
// CHECK-DAG: %[[T11:.+]] = memref.load %[[ARG2]][%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]]
@@ -131,8 +159,18 @@ func.func @pool_strides_and_dilation(%arg0 : memref<?x?x?x?xf32>, %arg1 : memref
131159
outs(%arg2 : memref<?x?x?x?xf32>)
132160
return
133161
}
134-
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1 + d4 * 3)>
135-
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2 * 2 + d5 * 4)>
162+
163+
module attributes {transform.with_named_sequence} {
164+
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
165+
%pool = transform.structured.match ops{["linalg.pooling_nhwc_max"]} in %arg1
166+
: (!transform.any_op) -> !transform.any_op
167+
transform.structured.convert_to_loops %pool : !transform.any_op
168+
transform.yield
169+
}
170+
}
171+
172+
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 3)>
173+
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0 * 2 + d1 * 4)>
136174
// CHECK: func @pool_strides_and_dilation
137175
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
138176
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: memref<?x?xf32>
@@ -153,8 +191,8 @@ func.func @pool_strides_and_dilation(%arg0 : memref<?x?x?x?xf32>, %arg1 : memref
153191
// CHECK: scf.for %[[IV3:[a-zA-Z0-9]+]] = %[[C0]] to %[[C]] step %[[C1]]
154192
// CHECK: scf.for %[[IV4:[a-zA-Z0-9]+]] = %[[C0]] to %[[H]] step %[[C1]]
155193
// CHECK: scf.for %[[IV5:[a-zA-Z0-9]+]] = %[[C0]] to %[[W]] step %[[C1]]
156-
// CHECK-DAG: %[[I:.+]] = affine.apply #[[MAP0]](%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]], %[[IV4]], %[[IV5]])
157-
// CHECK-DAG: %[[J:.+]] = affine.apply #[[MAP1]](%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]], %[[IV4]], %[[IV5]])
194+
// CHECK-DAG: %[[I:.+]] = affine.apply #[[MAP0]](%[[IV1]], %[[IV4]])
195+
// CHECK-DAG: %[[J:.+]] = affine.apply #[[MAP1]](%[[IV2]], %[[IV5]])
158196
// CHECK-DAG: %[[T8:.+]] = memref.load %[[ARG0]][%[[IV0]], %[[I]], %[[J]], %[[IV3]]]
159197
// CHECK-DAG: %[[T9:.+]] = memref.load %[[ARG2]][%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]]
160198
// CHECK: %[[T10:.+]] = arith.maximumf %[[T9]], %[[T8]]
@@ -172,6 +210,15 @@ func.func @map(%lhs: memref<64xf32>,
172210
}
173211
return
174212
}
213+
214+
module attributes {transform.with_named_sequence} {
215+
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
216+
%map = transform.structured.match ops{["linalg.map"]} in %arg1
217+
: (!transform.any_op) -> !transform.any_op
218+
transform.structured.convert_to_loops %map : !transform.any_op
219+
transform.yield
220+
}
221+
}
175222
// CHECK-LABEL: func.func @map(
176223
// CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: memref<64xf32>,
177224
// CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: memref<64xf32>,
@@ -195,6 +242,15 @@ func.func @transpose(%arg0: memref<16x32x64xf32>,
195242
outs(%arg1 : memref<32x64x16xf32>) permutation = [1, 2, 0]
196243
return
197244
}
245+
246+
module attributes {transform.with_named_sequence} {
247+
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
248+
%transpose = transform.structured.match ops{["linalg.transpose"]} in %arg1
249+
: (!transform.any_op) -> !transform.any_op
250+
transform.structured.convert_to_loops %transpose : !transform.any_op
251+
transform.yield
252+
}
253+
}
198254
// CHECK-LABEL: func.func @transpose(
199255
// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]: memref<16x32x64xf32>,
200256
// CHECK-SAME: %[[OUT:[a-zA-Z0-9]+]]: memref<32x64x16xf32>)
@@ -223,6 +279,15 @@ func.func @reduce(%arg0: memref<16x32x64xf32>,
223279
}
224280
return
225281
}
282+
283+
module attributes {transform.with_named_sequence} {
284+
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
285+
%reduce = transform.structured.match ops{["linalg.reduce"]} in %arg1
286+
: (!transform.any_op) -> !transform.any_op
287+
transform.structured.convert_to_loops %reduce : !transform.any_op
288+
transform.yield
289+
}
290+
}
226291
// CHECK-LABEL: func.func @reduce(
227292
// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]: memref<16x32x64xf32>,
228293
// CHECK-SAME: %[[OUT:[a-zA-Z0-9]+]]: memref<16x64xf32>
@@ -251,6 +316,15 @@ func.func @broadcast(%input: memref<8x32xf32>,
251316
dimensions = [1]
252317
func.return
253318
}
319+
320+
module attributes {transform.with_named_sequence} {
321+
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
322+
%broadcast = transform.structured.match ops{["linalg.broadcast"]} in %arg1
323+
: (!transform.any_op) -> !transform.any_op
324+
transform.structured.convert_to_loops %broadcast : !transform.any_op
325+
transform.yield
326+
}
327+
}
254328
// CHECK-LABEL: func.func @broadcast(
255329
// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]: memref<8x32xf32>,
256330
// CHECK-SAME: %[[OUT:[a-zA-Z0-9]+]]: memref<8x16x32xf32>

0 commit comments

Comments
 (0)