@@ -407,3 +407,95 @@ module attributes {transform.with_named_sequence} {
407
407
transform.yield
408
408
}
409
409
}
410
+
411
+ // -----
412
+ // Checks we use nan as the neutral element for maxnumf op.
413
+ func.func @generic_split_maxnumf (%in: tensor <32 xf32 >, %out: tensor <f32 >) -> tensor <f32 > {
414
+ %r = linalg.generic {index ing_maps = [affine_map <(d0 ) -> (d0 )>,
415
+ affine_map <(d0 ) -> ()>],
416
+ iterator_types = [" reduction" ]}
417
+ ins (%in : tensor <32 xf32 >)
418
+ outs (%out : tensor <f32 >) {
419
+ ^bb0 (%arg1: f32 , %arg2: f32 ):
420
+ %y = arith.maxnumf %arg1 , %arg2 : f32
421
+ linalg.yield %y : f32
422
+ } -> tensor <f32 >
423
+ return %r : tensor <f32 >
424
+ }
425
+
426
+ // CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
427
+ // CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
428
+ // CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0) -> (d0)>
429
+ // CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0) -> ()>
430
+ // CHECK-LABEL: func @generic_split_maxnumf
431
+ // The float value 0xFFC00000 that is filled into the init tensor represents negative NaN.
432
+ // CHECK-DAG: %[[ID:.*]] = arith.constant 0xFFC00000 : f32
433
+ // CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1]] output_shape [8, 4] : tensor<32xf32> into tensor<8x4xf32>
434
+ // CHECK-DAG: %[[INI:.*]] = tensor.empty() : tensor<4xf32>
435
+ // CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<4xf32>) -> tensor<4xf32>
436
+ // CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "parallel"]}
437
+ // CHECK-SAME: ins(%[[I1]] : tensor<8x4xf32>) outs(%[[F]] : tensor<4xf32>) {
438
+ // CHECK: arith.maxnumf
439
+ // CHECK: linalg.yield
440
+ // CHECK: } -> tensor<4xf32>
441
+ // CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP3]]], iterator_types = ["reduction"]}
442
+ // CHECK-SAME: ins(%[[G]] : tensor<4xf32>) outs(%{{.*}} : tensor<f32>) {
443
+ // CHECK: arith.maxnumf {{.*}}
444
+ // CHECK: linalg.yield
445
+ // CHECK: } -> tensor<f32>
446
+ // CHECK: return %[[R]] : tensor<f32>
447
+
448
+ module attributes {transform.with_named_sequence } {
449
+ transform.named_sequence @__transform_main (%arg1: !transform.any_op {transform.readonly }) {
450
+ %0 = transform.structured.match ops {[" linalg.generic" ]} in %arg1 : (!transform.any_op ) -> !transform.any_op
451
+ %1:4 = transform.structured.split_reduction %0 { split_factor = 4 , insert_split_dimension = 0 , inner_parallel }
452
+ : (!transform.any_op ) -> (!transform.any_op , !transform.any_op , !transform.any_op , !transform.any_op )
453
+ transform.yield
454
+ }
455
+ }
456
+
457
+ // -----
458
+ // Checks we use nan as the neutral element for minnumf op.
459
+ func.func @generic_split_minnumf (%in: tensor <32 xf32 >, %out: tensor <f32 >) -> tensor <f32 > {
460
+ %r = linalg.generic {index ing_maps = [affine_map <(d0 ) -> (d0 )>,
461
+ affine_map <(d0 ) -> ()>],
462
+ iterator_types = [" reduction" ]}
463
+ ins (%in : tensor <32 xf32 >)
464
+ outs (%out : tensor <f32 >) {
465
+ ^bb0 (%arg1: f32 , %arg2: f32 ):
466
+ %y = arith.minnumf %arg1 , %arg2 : f32
467
+ linalg.yield %y : f32
468
+ } -> tensor <f32 >
469
+ return %r : tensor <f32 >
470
+ }
471
+
472
+ // CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
473
+ // CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
474
+ // CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0) -> (d0)>
475
+ // CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0) -> ()>
476
+ // CHECK-LABEL: func @generic_split_minnumf
477
+ // The float value 0x7FC00000 that is filled into the init tensor represents positive NaN.
478
+ // CHECK-DAG: %[[ID:.*]] = arith.constant 0x7FC00000 : f32
479
+ // CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1]] output_shape [8, 4] : tensor<32xf32> into tensor<8x4xf32>
480
+ // CHECK-DAG: %[[INI:.*]] = tensor.empty() : tensor<4xf32>
481
+ // CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<4xf32>) -> tensor<4xf32>
482
+ // CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "parallel"]}
483
+ // CHECK-SAME: ins(%[[I1]] : tensor<8x4xf32>) outs(%[[F]] : tensor<4xf32>) {
484
+ // CHECK: arith.minnumf
485
+ // CHECK: linalg.yield
486
+ // CHECK: } -> tensor<4xf32>
487
+ // CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP3]]], iterator_types = ["reduction"]}
488
+ // CHECK-SAME: ins(%[[G]] : tensor<4xf32>) outs(%{{.*}} : tensor<f32>) {
489
+ // CHECK: arith.minnumf {{.*}}
490
+ // CHECK: linalg.yield
491
+ // CHECK: } -> tensor<f32>
492
+ // CHECK: return %[[R]] : tensor<f32>
493
+
494
+ module attributes {transform.with_named_sequence } {
495
+ transform.named_sequence @__transform_main (%arg1: !transform.any_op {transform.readonly }) {
496
+ %0 = transform.structured.match ops {[" linalg.generic" ]} in %arg1 : (!transform.any_op ) -> !transform.any_op
497
+ %1:4 = transform.structured.split_reduction %0 { split_factor = 4 , insert_split_dimension = 0 , inner_parallel }
498
+ : (!transform.any_op ) -> (!transform.any_op , !transform.any_op , !transform.any_op , !transform.any_op )
499
+ transform.yield
500
+ }
501
+ }
0 commit comments