@@ -374,38 +374,27 @@ module attributes {transform.with_named_sequence} {
374
374
375
375
// -----
376
376
377
- #map = affine_map <(d0 ) -> (d0 * 128 )>
378
377
module {
379
- func.func @fuse_tilable_consumer_nested_scf_loop (%arg0: tensor <256 x512 xf32 >, %arg1: tensor <512 x256 xf32 >, %arg2: tensor <256 x256 xf32 >) -> tensor <256 x256 xf32 > {
378
+ func.func @fuse_add_consumer_into_nested_scf_for (%arg0: tensor <256 x512 xf32 >, %arg1: tensor <512 x256 xf32 >, %arg2: tensor <256 x256 xf32 >) -> tensor <256 x256 xf32 > {
380
379
%c0 = arith.constant 0 : index
381
380
%c64 = arith.constant 64 : index
382
- %c128 = arith.constant 128 : index
381
+ %c256 = arith.constant 256 : index
383
382
%cst = arith.constant 0.000000e+00 : f32
384
383
%dest0 = tensor.empty () : tensor <256 x256 xf32 >
385
384
%dest1 = linalg.fill ins (%cst : f32 ) outs (%dest0 : tensor <256 x256 xf32 >) -> tensor <256 x256 xf32 >
386
- %1 = scf.forall (%arg3 , %arg4 ) in (2 , 2 ) shared_outs (%arg5 = %dest1 ) -> tensor <256 x256 xf32 > {
387
- %iv0 = affine.apply #map (%arg3 )
388
- %iv1 = affine.apply #map (%arg4 )
389
- %extracted_slice_1 = tensor.extract_slice %arg5 [%iv0 , %iv1 ] [128 , 128 ] [1 , 1 ] : tensor <256 x256 xf32 > to tensor <128 x128 xf32 >
390
- %extracted_slice_2 = tensor.extract_slice %arg0 [%iv0 , 0 ] [128 , 512 ] [1 , 1 ] : tensor <256 x512 xf32 > to tensor <128 x512 xf32 >
391
- %extracted_slice_3 = tensor.extract_slice %arg1 [0 , %iv1 ] [512 , 128 ] [1 , 1 ] : tensor <512 x256 xf32 > to tensor <512 x128 xf32 >
392
- %2 = scf.for %arg6 = %c0 to %c128 step %c64 iter_args (%arg7 = %extracted_slice_1 ) -> (tensor <128 x128 xf32 >) {
393
- %3 = scf.for %arg8 = %c0 to %c128 step %c64 iter_args (%arg9 = %arg7 ) -> (tensor <128 x128 xf32 >) {
394
- %extracted_slice_4 = tensor.extract_slice %arg9 [%arg6 , %arg8 ] [64 , 64 ] [1 , 1 ] : tensor <128 x128 xf32 > to tensor <64 x64 xf32 >
395
- %extracted_slice_5 = tensor.extract_slice %extracted_slice_2 [%arg6 , 0 ] [64 , 512 ] [1 , 1 ] : tensor <128 x512 xf32 > to tensor <64 x512 xf32 >
396
- %extracted_slice_6 = tensor.extract_slice %extracted_slice_3 [0 , %arg8 ] [512 , 64 ] [1 , 1 ] : tensor <512 x128 xf32 > to tensor <512 x64 xf32 >
397
- %4 = linalg.matmul ins (%extracted_slice_5 , %extracted_slice_6 : tensor <64 x512 xf32 >, tensor <512 x64 xf32 >) outs (%extracted_slice_4 : tensor <64 x64 xf32 >) -> tensor <64 x64 xf32 >
398
- %insert_slice = tensor.insert_slice %4 into %arg9 [%arg6 , %arg8 ] [64 , 64 ] [1 , 1 ] : tensor <64 x64 xf32 > into tensor <128 x128 xf32 >
399
- scf.yield %insert_slice : tensor <128 x128 xf32 >
400
- }
401
- scf.yield %3 : tensor <128 x128 xf32 >
402
- }
403
- scf.forall.in_parallel {
404
- tensor.parallel_insert_slice %2 into %arg5 [%iv0 , %iv1 ] [128 , 128 ] [1 , 1 ] : tensor <128 x128 xf32 > into tensor <256 x256 xf32 >
385
+ %1 = scf.for %arg3 = %c0 to %c256 step %c64 iter_args (%arg4 = %dest1 ) -> (tensor <256 x256 xf32 >) {
386
+ %2 = scf.for %arg5 = %c0 to %c256 step %c64 iter_args (%arg6 = %arg4 ) -> (tensor <256 x256 xf32 >) {
387
+ %extracted_slice_1 = tensor.extract_slice %arg6 [%arg3 , %arg5 ] [64 , 64 ] [1 , 1 ] : tensor <256 x256 xf32 > to tensor <64 x64 xf32 >
388
+ %extracted_slice_2 = tensor.extract_slice %arg0 [%arg3 , 0 ] [64 , 512 ] [1 , 1 ] : tensor <256 x512 xf32 > to tensor <64 x512 xf32 >
389
+ %extracted_slice_3 = tensor.extract_slice %arg1 [0 , %arg5 ] [512 , 64 ] [1 , 1 ] : tensor <512 x256 xf32 > to tensor <512 x64 xf32 >
390
+ %3 = linalg.matmul ins (%extracted_slice_2 , %extracted_slice_3 : tensor <64 x512 xf32 >, tensor <512 x64 xf32 >) outs (%extracted_slice_1 : tensor <64 x64 xf32 >) -> tensor <64 x64 xf32 >
391
+ %insert_slice = tensor.insert_slice %3 into %arg6 [%arg3 , %arg5 ] [64 , 64 ] [1 , 1 ] : tensor <64 x64 xf32 > into tensor <256 x256 xf32 >
392
+ scf.yield %insert_slice : tensor <256 x256 xf32 >
405
393
}
394
+ scf.yield %2 : tensor <256 x256 xf32 >
406
395
}
407
- %5 = linalg.add ins (%1 , %arg2 : tensor <256 x256 xf32 >, tensor <256 x256 xf32 >) outs (%dest0 : tensor <256 x256 xf32 >) -> tensor <256 x256 xf32 >
408
- return %5 : tensor <256 x256 xf32 >
396
+ %4 = linalg.add ins (%1 , %arg2 : tensor <256 x256 xf32 >, tensor <256 x256 xf32 >) outs (%dest0 : tensor <256 x256 xf32 >) -> tensor <256 x256 xf32 >
397
+ return %4 : tensor <256 x256 xf32 >
409
398
}
410
399
}
411
400
@@ -418,49 +407,33 @@ module attributes {transform.with_named_sequence} {
418
407
transform.yield
419
408
}
420
409
}
421
- // CHECK: #[[MAP0:.*]] = affine_map<(d0) -> (d0 * 128)>
422
- // CHECK: func.func @fuse_tilable_consumer_nested_scf_loop(
410
+ // CHECK: func.func @fuse_add_consumer_into_nested_scf_for(
423
411
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<256x512xf32>
424
412
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<512x256xf32>
425
413
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<256x256xf32>
426
414
// CHECK: %[[dest0:.*]] = tensor.empty() : tensor<256x256xf32>
427
415
// CHECK: %[[dest1:.*]] = linalg.fill
428
416
// CHECK-SAME: outs(%[[dest0]] :
429
- // CHECK: %[[FINAL_RESULT :.*]]:2 = scf.forall ( %[[IV1:.*]], %[[IV2:.*]]) in (2, 2)
430
- // CHECK-SAME: shared_outs (%[[FIRST_OUT_ARG0 :.*]] = %[[dest1]], %[[SECOND_OUT_ARG0 :.*]] = %[[dest0]])
417
+ // CHECK: %[[LOOP_RESULT1 :.*]]:2 = scf.for %[[IV1:.*]] = %[[C0]]
418
+ // CHECK-SAME: iter_args (%[[FIRST_OUT_ARG1 :.*]] = %[[dest1]], %[[SECOND_OUT_ARG1 :.*]] = %[[dest0]])
431
419
// CHECK-SAME: {
432
- // CHECK: %[[AFFINE_IV1:.*]] = affine.apply #[[MAP0]](%[[IV1]])
433
- // CHECK: %[[AFFINE_IV2:.*]] = affine.apply #[[MAP0]](%[[IV2]])
434
- // CHECK: %[[MAT_OUT_SLICE0:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG0]][%[[AFFINE_IV1]], %[[AFFINE_IV2]]] [128, 128] [1, 1]
435
- // CHECK: %[[INPUT_SLICE0:.*]] = tensor.extract_slice %[[ARG0]][%[[AFFINE_IV1]], 0] [128, 512] [1, 1]
436
- // CHECK: %[[WEIGHT_SLICE0:.*]] = tensor.extract_slice %[[ARG1]][0, %[[AFFINE_IV2]]] [512, 128] [1, 1]
437
- // CHECK: %[[ADD_OPERAND2_SLICE0:.*]] = tensor.extract_slice %[[ARG2]][%[[AFFINE_IV1]], %[[AFFINE_IV2]]] [128, 128] [1, 1]
438
- // CHECK: %[[ADD_OUT_SLICE0:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG0]][%[[AFFINE_IV1]], %[[AFFINE_IV2]]] [128, 128] [1, 1]
439
- // CHECK: %[[LOOP_RESULT1:.*]]:2 = scf.for %[[IV3:.*]] = %[[C0]]
440
- // CHECK-SAME: iter_args(%[[FIRST_OUT_ARG1:.*]] = %[[MAT_OUT_SLICE0]], %[[SECOND_OUT_ARG1:.*]] = %[[ADD_OUT_SLICE0]])
441
- // CHECK-SAME: {
442
- // CHECK: %[[LOOP_RESULT2:.*]]:2 = scf.for %[[IV4:.*]] = %[[C0]]
443
- // CHECK-SAME: iter_args(%[[FIRST_OUT_ARG2:.*]] = %[[FIRST_OUT_ARG1]], %[[SECOND_OUT_ARG2:.*]] = %[[SECOND_OUT_ARG1]])
444
- // CHECK-SAME: {
445
- // CHECK: %[[MAT_OUT_SLICE1:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG2]][%[[IV3]], %[[IV4]]] [64, 64] [1, 1]
446
- // CHECK: %[[INPUT_SLICE1:.*]] = tensor.extract_slice %[[INPUT_SLICE0]][%[[IV3]], 0] [64, 512] [1, 1]
447
- // CHECK: %[[WEIGHT_SLICE1:.*]] = tensor.extract_slice %[[WEIGHT_SLICE0]][0, %[[IV4]]] [512, 64] [1, 1]
420
+ // CHECK: %[[LOOP_RESULT2:.*]]:2 = scf.for %[[IV2:.*]] = %[[C0]]
421
+ // CHECK-SAME: iter_args(%[[FIRST_OUT_ARG2:.*]] = %[[FIRST_OUT_ARG1]], %[[SECOND_OUT_ARG2:.*]] = %[[SECOND_OUT_ARG1]])
422
+ // CHECK-SAME: {
423
+ // CHECK: %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1]
424
+ // CHECK: %[[INPUT_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[IV1]], 0] [64, 512] [1, 1]
425
+ // CHECK: %[[WEIGHT_SLICE:.*]] = tensor.extract_slice %[[ARG1]][0, %[[IV2]]] [512, 64] [1, 1]
448
426
// CHECK: %[[TILED_MAT_OUT:.*]] = linalg.matmul
449
- // CHECK-SAME: outs(%[[MAT_OUT_SLICE1 ]] :
450
- // CHECK: %[[INSERT_MAT:.*]] = tensor.insert_slice %[[TILED_MAT_OUT]] into %[[FIRST_OUT_ARG2]][%[[IV3 ]], %[[IV4 ]]] [64, 64] [1, 1]
451
- // CHECK: %[[ADD_OPERAND2_SLICE1 :.*]] = tensor.extract_slice %[[ADD_OPERAND2_SLICE0 ]][%[[IV3 ]], %[[IV4 ]]] [64, 64] [1, 1]
452
- // CHECK: %[[ADD_OUT_SLICE1 :.*]] = tensor.extract_slice %[[SECOND_OUT_ARG2]][%[[IV3 ]], %[[IV4 ]]] [64, 64] [1, 1]
427
+ // CHECK-SAME: outs(%[[MAT_OUT_SLICE ]] :
428
+ // CHECK: %[[INSERT_MAT:.*]] = tensor.insert_slice %[[TILED_MAT_OUT]] into %[[FIRST_OUT_ARG2]][%[[IV1 ]], %[[IV2 ]]] [64, 64] [1, 1]
429
+ // CHECK: %[[ADD_OPERAND2_SLICE :.*]] = tensor.extract_slice %[[ARG2 ]][%[[IV1 ]], %[[IV2 ]]] [64, 64] [1, 1]
430
+ // CHECK: %[[ADD_OUT_SLICE :.*]] = tensor.extract_slice %[[SECOND_OUT_ARG2]][%[[IV1 ]], %[[IV2 ]]] [64, 64] [1, 1]
453
431
// CHECK: %[[TILED_ADD_OUT:.*]] = linalg.add
454
- // CHECK-SAME: ins(%[[TILED_MAT_OUT]], %[[ADD_OPERAND2_SLICE1 ]] :
455
- // CHECK-SAME: outs(%[[ADD_OUT_SLICE1 ]] :
456
- // CHECK: %[[INSERT_ADD:.*]] = tensor.insert_slice %[[TILED_ADD_OUT]] into %[[SECOND_OUT_ARG2]][%[[IV3 ]], %[[IV4 ]]] [64, 64] [1, 1]
432
+ // CHECK-SAME: ins(%[[TILED_MAT_OUT]], %[[ADD_OPERAND2_SLICE ]] :
433
+ // CHECK-SAME: outs(%[[ADD_OUT_SLICE ]] :
434
+ // CHECK: %[[INSERT_ADD:.*]] = tensor.insert_slice %[[TILED_ADD_OUT]] into %[[SECOND_OUT_ARG2]][%[[IV1 ]], %[[IV2 ]]] [64, 64] [1, 1]
457
435
// CHECK: scf.yield %[[INSERT_MAT]], %[[INSERT_ADD]] :
458
- // CHECK: }
459
- // CHECK: scf.yield %[[LOOP_RESULT2]]#0, %[[LOOP_RESULT2]]#1 :
460
- // CHECK: }
461
- // CHECK: scf.forall.in_parallel {
462
- // CHECK: tensor.parallel_insert_slice %[[LOOP_RESULT1]]#1 into %[[SECOND_OUT_ARG0]][%[[AFFINE_IV1]], %[[AFFINE_IV2]]] [128, 128] [1, 1]
463
- // CHECK: tensor.parallel_insert_slice %[[LOOP_RESULT1]]#0 into %[[FIRST_OUT_ARG0]][%[[AFFINE_IV1]], %[[AFFINE_IV2]]] [128, 128] [1, 1]
464
- // CHECK: }
436
+ // CHECK: }
437
+ // CHECK: scf.yield %[[LOOP_RESULT2]]#0, %[[LOOP_RESULT2]]#1 :
465
438
// CHECK: }
466
- // CHECK: return %[[FINAL_RESULT ]]#1 :
439
+ // CHECK: return %[[LOOP_RESULT1 ]]#1 :
0 commit comments