@@ -299,3 +299,80 @@ module attributes {transform.with_named_sequence} {
299
299
// CHECK-NOT: scf.for
300
300
// CHECK: transform.named_sequence
301
301
302
+ // -----
303
+
304
+ // Check that coalescing avoids generating unnecessary operations for trip-1 loops.
305
+ func.func @trip_one_loops(%arg0 : tensor<?x?xf32>, %arg1 : index, %arg2 : index) -> tensor<?x?xf32> {
306
+   %c0 = arith.constant 0 : index
307
+   %c1 = arith.constant 1 : index
308
+   %0 = scf.for %iv0 = %c0 to %c1 step %c1 iter_args(%iter0 = %arg0) -> tensor<?x?xf32> {  // 0 to 1 step 1: single iteration
309
+     %1 = scf.for %iv1 = %c0 to %c1 step %c1 iter_args(%iter1 = %iter0) -> tensor<?x?xf32> {  // single iteration
310
+       %2 = scf.for %iv2 = %c0 to %arg1 step %c1 iter_args(%iter2 = %iter1) -> tensor<?x?xf32> {  // upper bound is the dynamic %arg1
311
+         %3 = scf.for %iv3 = %c0 to %c1 step %c1 iter_args(%iter3 = %iter2) -> tensor<?x?xf32> {  // single iteration
312
+           %4 = scf.for %iv4 = %c0 to %arg2 step %c1 iter_args(%iter4 = %iter3) -> tensor<?x?xf32> {  // upper bound is the dynamic %arg2
313
+             %5 = "some_use"(%iter4, %iv0, %iv1, %iv2, %iv3, %iv4)  // op name must match the "some_use" pattern in the checks below
314
+               : (tensor<?x?xf32>, index, index, index, index, index) -> (tensor<?x?xf32>)
315
+             scf.yield %5 : tensor<?x?xf32>
316
+           }
317
+           scf.yield %4 : tensor<?x?xf32>
318
+         }
319
+         scf.yield %3 : tensor<?x?xf32>
320
+       }
321
+       scf.yield %2 : tensor<?x?xf32>
322
+     }
323
+     scf.yield %1 : tensor<?x?xf32>
324
+   } {coalesce}  // marker attribute selected by the transform sequence's match op
325
+   return %0 : tensor<?x?xf32>
326
+ }
327
+ module attributes {transform.with_named_sequence} {
328
+   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
329
+     %0 = transform.structured.match ops{["scf.for"]} attributes {coalesce} in %arg1 : (!transform.any_op) -> !transform.any_op  // select scf.for ops carrying the `coalesce` attribute
330
+     %1 = transform.cast %0 : !transform.any_op to !transform.op<"scf.for">
331
+     %2 = transform.loop.coalesce %1 : (!transform.op<"scf.for">) -> (!transform.op<"scf.for">)  // coalesce the matched loop nest
332
+     transform.yield
333
+   }
334
+ }
335
+ // CHECK-LABEL: func @trip_one_loops
336
+ // CHECK-SAME: , %[[ARG1:.+]]: index,
337
+ // CHECK-SAME: %[[ARG2:.+]]: index)
338
+ // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
339
+ // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
340
+ // CHECK: %[[UB:.+]] = arith.muli %[[ARG1]], %[[ARG2]]
341
+ // CHECK: scf.for %[[IV:.+]] = %[[C0]] to %[[UB]] step %[[C1]]
342
+ // CHECK: %[[IV1:.+]] = arith.remsi %[[IV]], %[[ARG2]]
343
+ // CHECK: %[[IV2:.+]] = arith.divsi %[[IV]], %[[ARG2]]
344
+ // CHECK: "some_use"(%{{[a-zA-Z0-9]+}}, %[[C0]], %[[C0]], %[[IV2]], %[[C0]], %[[IV1]])
345
+
346
+ // -----
347
+
348
+ // Check that no extra instructions are generated when all loops except one are unit-trip.
349
+ func.func @all_outer_trip_one(%arg0 : tensor<?x?xf32>, %arg1 : index) -> tensor<?x?xf32> {
350
+   %c0 = arith.constant 0 : index
351
+   %c1 = arith.constant 1 : index
352
+   %0 = scf.for %iv0 = %c0 to %c1 step %c1 iter_args(%iter0 = %arg0) -> tensor<?x?xf32> {  // 0 to 1 step 1: single iteration
353
+     %1 = scf.for %iv1 = %c0 to %c1 step %c1 iter_args(%iter1 = %iter0) -> tensor<?x?xf32> {  // single iteration
354
+       %2 = scf.for %iv2 = %c0 to %arg1 step %c1 iter_args(%iter2 = %iter1) -> tensor<?x?xf32> {  // upper bound is the dynamic %arg1
355
+         %3 = "some_use"(%iter2, %iv0, %iv1, %iv2)  // op name must match the "some_use" pattern in the checks below
356
+           : (tensor<?x?xf32>, index, index, index) -> (tensor<?x?xf32>)
357
+         scf.yield %3 : tensor<?x?xf32>
358
+       }
359
+       scf.yield %2 : tensor<?x?xf32>
360
+     }
361
+     scf.yield %1 : tensor<?x?xf32>
362
+   } {coalesce}  // marker attribute selected by the transform sequence's match op
363
+   return %0 : tensor<?x?xf32>
364
+ }
365
+ module attributes {transform.with_named_sequence} {
366
+   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
367
+     %0 = transform.structured.match ops{["scf.for"]} attributes {coalesce} in %arg1 : (!transform.any_op) -> !transform.any_op  // select scf.for ops carrying the `coalesce` attribute
368
+     %1 = transform.cast %0 : !transform.any_op to !transform.op<"scf.for">
369
+     %2 = transform.loop.coalesce %1 : (!transform.op<"scf.for">) -> (!transform.op<"scf.for">)  // coalesce the matched loop nest
370
+     transform.yield
371
+   }
372
+ }
373
+ // CHECK-LABEL: func @all_outer_trip_one
374
+ // CHECK-SAME: , %[[ARG1:.+]]: index)
375
+ // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
376
+ // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
377
+ // CHECK: scf.for %[[IV:.+]] = %[[C0]] to %[[ARG1]] step %[[C1]]
378
+ // CHECK: "some_use"(%{{[a-zA-Z0-9]+}}, %[[C0]], %[[C0]], %[[IV]])
0 commit comments