@@ -407,3 +407,50 @@ module attributes {transform.with_named_sequence} {
407
407
transform.yield
408
408
}
409
409
}
410
+
411
+ // -----
412
+
413
+ // Checks we use nan as the neutral element for maxnumf op.
414
+ func.func @generic_split_maxnumf (%in: tensor <32 xf32 >, %out: tensor <f32 >) -> tensor <f32 > {
415
+ %r = linalg.generic {index ing_maps = [affine_map <(d0 ) -> (d0 )>,
416
+ affine_map <(d0 ) -> ()>],
417
+ iterator_types = [" reduction" ]}
418
+ ins (%in : tensor <32 xf32 >)
419
+ outs (%out : tensor <f32 >) {
420
+ ^bb0 (%arg1: f32 , %arg2: f32 ):
421
+ %y = arith.maxnumf %arg1 , %arg2 : f32
422
+ linalg.yield %y : f32
423
+ } -> tensor <f32 >
424
+ return %r : tensor <f32 >
425
+ }
426
+
427
+ // CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
428
+ // CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
429
+ // CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0) -> (d0)>
430
+ // CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0) -> ()>
431
+ // CHECK-LABEL: func @generic_split_maxnumf
432
+ // The float value 0xFFC00000 that is filled into the init tensor represents NaN.
433
+ // CHECK-DAG: %[[ID:.*]] = arith.constant 0xFFC00000 : f32
434
+ // CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1]] output_shape [8, 4] : tensor<32xf32> into tensor<8x4xf32>
435
+ // CHECK-DAG: %[[INI:.*]] = tensor.empty() : tensor<4xf32>
436
+ // CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<4xf32>) -> tensor<4xf32>
437
+ // CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "parallel"]}
438
+ // CHECK-SAME: ins(%[[I1]] : tensor<8x4xf32>) outs(%[[F]] : tensor<4xf32>) {
439
+ // CHECK: arith.maxnumf
440
+ // CHECK: linalg.yield
441
+ // CHECK: } -> tensor<4xf32>
442
+ // CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP3]]], iterator_types = ["reduction"]}
443
+ // CHECK-SAME: ins(%[[G]] : tensor<4xf32>) outs(%{{.*}} : tensor<f32>) {
444
+ // CHECK: arith.maxnumf {{.*}}
445
+ // CHECK: linalg.yield
446
+ // CHECK: } -> tensor<f32>
447
+ // CHECK: return %[[R]] : tensor<f32>
448
+
449
+ module attributes {transform.with_named_sequence } {
450
+ transform.named_sequence @__transform_main (%arg1: !transform.any_op {transform.readonly }) {
451
+ %0 = transform.structured.match ops {[" linalg.generic" ]} in %arg1 : (!transform.any_op ) -> !transform.any_op
452
+ %1:4 = transform.structured.split_reduction %0 { split_factor = 4 , insert_split_dimension = 0 , inner_parallel }
453
+ : (!transform.any_op ) -> (!transform.any_op , !transform.any_op , !transform.any_op , !transform.any_op )
454
+ transform.yield
455
+ }
456
+ }
0 commit comments