1
- // RUN: mlir-opt -test-tiling-interface=lower-to-scalar-using-scf-for -split-input-file %s | FileCheck %s
1
+ // RUN: mlir-opt -transform-interpreter -split-input-file -canonicalize -cse %s | FileCheck %s
2
2
3
3
func.func @gemm (%arg0 : memref <?x?xf32 >, %arg1 : memref <?x?xf32 >,
4
4
%arg2 : memref <?x?xf32 >) {
5
5
linalg.matmul ins (%arg0 , %arg1 : memref <?x?xf32 >, memref <?x?xf32 >)
6
6
outs (%arg2 : memref <?x?xf32 >)
7
7
return
8
8
}
9
+
10
+ module attributes {transform.with_named_sequence } {
11
+ transform.named_sequence @__transform_main (%arg1 : !transform.any_op {transform.readonly }) {
12
+ %matmul = transform.structured.match ops {[" linalg.matmul" ]} in %arg1
13
+ : (!transform.any_op ) -> !transform.any_op
14
+ transform.structured.convert_to_loops %matmul : !transform.any_op
15
+ transform.yield
16
+ }
17
+ }
9
18
// CHECK-LABEL: func @gemm
10
19
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?xf32>
11
20
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: memref<?x?xf32>
12
21
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: memref<?x?xf32>
13
22
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
14
- // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
15
23
// CHECK-DAG: %[[M:.+]] = memref.dim %[[ARG0]], %[[C0]]
24
+ // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
16
25
// CHECK-DAG: %[[K:.+]] = memref.dim %[[ARG0]], %[[C1]]
17
26
// CHECK-DAG: %[[N:.+]] = memref.dim %[[ARG1]], %[[C1]]
18
27
// CHECK: scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[M]] step %[[C1]]
@@ -51,6 +60,15 @@ func.func @indexed_generic(%arg0 : memref<200x300xi32>, %arg1 : memref<300xi16>,
51
60
}
52
61
return
53
62
}
63
+
64
+ module attributes {transform.with_named_sequence } {
65
+ transform.named_sequence @__transform_main (%arg1 : !transform.any_op {transform.readonly }) {
66
+ %generic = transform.structured.match ops {[" linalg.generic" ]} in %arg1
67
+ : (!transform.any_op ) -> !transform.any_op
68
+ transform.structured.convert_to_loops %generic : !transform.any_op
69
+ transform.yield
70
+ }
71
+ }
54
72
// CHECK-LABEL: func @indexed_generic
55
73
// CHECK-SAME: %[[ARG0:.+]]: memref<200x300xi32>
56
74
// CHECK-SAME: %[[ARG1:.+]]: memref<300xi16>
@@ -87,8 +105,18 @@ func.func @conv_strides_and_dilation(%arg0 : memref<?x?x?x?xf32>, %arg1 : memref
87
105
outs (%arg2 : memref <?x?x?x?xf32 >)
88
106
return
89
107
}
90
- // CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1 + d4 * 3)>
91
- // CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d2 * 2 + d5 * 4)>
108
+
109
+ module attributes {transform.with_named_sequence } {
110
+ transform.named_sequence @__transform_main (%arg1 : !transform.any_op {transform.readonly }) {
111
+ %conv = transform.structured.match ops {[" linalg.conv_2d_nhwc_hwcf" ]} in %arg1
112
+ : (!transform.any_op ) -> !transform.any_op
113
+ transform.structured.convert_to_loops %conv : !transform.any_op
114
+ transform.yield
115
+ }
116
+ }
117
+
118
+ // CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 3)>
119
+ // CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0 * 2 + d1 * 4)>
92
120
// CHECK: func @conv_strides_and_dilation(
93
121
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
94
122
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
@@ -111,8 +139,8 @@ func.func @conv_strides_and_dilation(%arg0 : memref<?x?x?x?xf32>, %arg1 : memref
111
139
// CHECK: scf.for %[[IV4:[a-zA-Z0-9]+]] = %[[C0]] to %[[H]] step %[[C1]]
112
140
// CHECK: scf.for %[[IV5:[a-zA-Z0-9]+]] = %[[C0]] to %[[W]] step %[[C1]]
113
141
// CHECK: scf.for %[[IV6:[a-zA-Z0-9]+]] = %[[C0]] to %[[C]] step %[[C1]]
114
- // CHECK-DAG: %[[I:.+]] = affine.apply #[[MAP0]](%[[IV0]], %[[ IV1]], %[[IV2]], %[[IV3]], %[[ IV4]], %[[IV5]], %[[IV6 ]])
115
- // CHECK-DAG: %[[J:.+]] = affine.apply #[[MAP1]](%[[IV0]], %[[IV1]], %[[ IV2]], %[[IV3]], %[[IV4]], %[[ IV5]], %[[IV6 ]])
142
+ // CHECK-DAG: %[[I:.+]] = affine.apply #[[MAP0]](%[[IV1]], %[[IV4]])
143
+ // CHECK-DAG: %[[J:.+]] = affine.apply #[[MAP1]](%[[IV2]], %[[IV5]])
116
144
// CHECK-DAG: %[[T9:.+]] = memref.load %[[ARG0]][%[[IV0]], %[[I]], %[[J]], %[[IV6]]]
117
145
// CHECK-DAG: %[[T10:.+]] = memref.load %[[ARG1]][%[[IV4]], %[[IV5]], %[[IV6]], %[[IV3]]]
118
146
// CHECK-DAG: %[[T11:.+]] = memref.load %[[ARG2]][%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]]
@@ -131,8 +159,18 @@ func.func @pool_strides_and_dilation(%arg0 : memref<?x?x?x?xf32>, %arg1 : memref
131
159
outs (%arg2 : memref <?x?x?x?xf32 >)
132
160
return
133
161
}
134
- // CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1 + d4 * 3)>
135
- // CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2 * 2 + d5 * 4)>
162
+
163
+ module attributes {transform.with_named_sequence } {
164
+ transform.named_sequence @__transform_main (%arg1 : !transform.any_op {transform.readonly }) {
165
+ %pool = transform.structured.match ops {[" linalg.pooling_nhwc_max" ]} in %arg1
166
+ : (!transform.any_op ) -> !transform.any_op
167
+ transform.structured.convert_to_loops %pool : !transform.any_op
168
+ transform.yield
169
+ }
170
+ }
171
+
172
+ // CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 3)>
173
+ // CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0 * 2 + d1 * 4)>
136
174
// CHECK: func @pool_strides_and_dilation
137
175
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
138
176
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: memref<?x?xf32>
@@ -153,8 +191,8 @@ func.func @pool_strides_and_dilation(%arg0 : memref<?x?x?x?xf32>, %arg1 : memref
153
191
// CHECK: scf.for %[[IV3:[a-zA-Z0-9]+]] = %[[C0]] to %[[C]] step %[[C1]]
154
192
// CHECK: scf.for %[[IV4:[a-zA-Z0-9]+]] = %[[C0]] to %[[H]] step %[[C1]]
155
193
// CHECK: scf.for %[[IV5:[a-zA-Z0-9]+]] = %[[C0]] to %[[W]] step %[[C1]]
156
- // CHECK-DAG: %[[I:.+]] = affine.apply #[[MAP0]](%[[IV0]], %[[ IV1]], %[[IV2]], %[[IV3]], %[[ IV4]], %[[IV5 ]])
157
- // CHECK-DAG: %[[J:.+]] = affine.apply #[[MAP1]](%[[IV0]], %[[IV1]], %[[ IV2]], %[[IV3]], %[[IV4 ]], %[[IV5]])
194
+ // CHECK-DAG: %[[I:.+]] = affine.apply #[[MAP0]](%[[IV1]], %[[IV4]])
195
+ // CHECK-DAG: %[[J:.+]] = affine.apply #[[MAP1]](%[[IV2]], %[[IV5]])
158
196
// CHECK-DAG: %[[T8:.+]] = memref.load %[[ARG0]][%[[IV0]], %[[I]], %[[J]], %[[IV3]]]
159
197
// CHECK-DAG: %[[T9:.+]] = memref.load %[[ARG2]][%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]]
160
198
// CHECK: %[[T10:.+]] = arith.maximumf %[[T9]], %[[T8]]
@@ -172,6 +210,15 @@ func.func @map(%lhs: memref<64xf32>,
172
210
}
173
211
return
174
212
}
213
+
214
+ module attributes {transform.with_named_sequence } {
215
+ transform.named_sequence @__transform_main (%arg1 : !transform.any_op {transform.readonly }) {
216
+ %map = transform.structured.match ops {[" linalg.map" ]} in %arg1
217
+ : (!transform.any_op ) -> !transform.any_op
218
+ transform.structured.convert_to_loops %map : !transform.any_op
219
+ transform.yield
220
+ }
221
+ }
175
222
// CHECK-LABEL: func.func @map(
176
223
// CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: memref<64xf32>,
177
224
// CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: memref<64xf32>,
@@ -195,6 +242,15 @@ func.func @transpose(%arg0: memref<16x32x64xf32>,
195
242
outs (%arg1 : memref <32 x64 x16 xf32 >) permutation = [1 , 2 , 0 ]
196
243
return
197
244
}
245
+
246
+ module attributes {transform.with_named_sequence } {
247
+ transform.named_sequence @__transform_main (%arg1 : !transform.any_op {transform.readonly }) {
248
+ %transpose = transform.structured.match ops {[" linalg.transpose" ]} in %arg1
249
+ : (!transform.any_op ) -> !transform.any_op
250
+ transform.structured.convert_to_loops %transpose : !transform.any_op
251
+ transform.yield
252
+ }
253
+ }
198
254
// CHECK-LABEL: func.func @transpose(
199
255
// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]: memref<16x32x64xf32>,
200
256
// CHECK-SAME: %[[OUT:[a-zA-Z0-9]+]]: memref<32x64x16xf32>)
@@ -223,6 +279,15 @@ func.func @reduce(%arg0: memref<16x32x64xf32>,
223
279
}
224
280
return
225
281
}
282
+
283
+ module attributes {transform.with_named_sequence } {
284
+ transform.named_sequence @__transform_main (%arg1 : !transform.any_op {transform.readonly }) {
285
+ %reduce = transform.structured.match ops {[" linalg.reduce" ]} in %arg1
286
+ : (!transform.any_op ) -> !transform.any_op
287
+ transform.structured.convert_to_loops %reduce : !transform.any_op
288
+ transform.yield
289
+ }
290
+ }
226
291
// CHECK-LABEL: func.func @reduce(
227
292
// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]: memref<16x32x64xf32>,
228
293
// CHECK-SAME: %[[OUT:[a-zA-Z0-9]+]]: memref<16x64xf32>
@@ -251,6 +316,15 @@ func.func @broadcast(%input: memref<8x32xf32>,
251
316
dimensions = [1 ]
252
317
func.return
253
318
}
319
+
320
+ module attributes {transform.with_named_sequence } {
321
+ transform.named_sequence @__transform_main (%arg1 : !transform.any_op {transform.readonly }) {
322
+ %broadcast = transform.structured.match ops {[" linalg.broadcast" ]} in %arg1
323
+ : (!transform.any_op ) -> !transform.any_op
324
+ transform.structured.convert_to_loops %broadcast : !transform.any_op
325
+ transform.yield
326
+ }
327
+ }
254
328
// CHECK-LABEL: func.func @broadcast(
255
329
// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]: memref<8x32xf32>,
256
330
// CHECK-SAME: %[[OUT:[a-zA-Z0-9]+]]: memref<8x16x32xf32>
0 commit comments