Commit 0981dca

[mlir][arith] Add neutral element support to arith.maxnumf/arith.minnumf (#93278)
For maxnumf and minnumf, when one operand is NaN the result is the other operand, so NaN serves as their neutral element.
1 parent 14dc97d · commit 0981dca
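
The claim is easy to sanity-check outside MLIR: the C standard library's fmax/fmin have the same NaN semantics as arith.maxnumf/arith.minnumf (a quiet-NaN operand loses to the other operand). A minimal sketch in plain C++, not part of this commit:

    #include <cmath>
    #include <cstdio>
    #include <limits>

    int main() {
      float nan = std::numeric_limits<float>::quiet_NaN();
      // fmax/fmin return the non-NaN operand, so NaN never "wins":
      std::printf("fmax(NaN, 2) = %g\n", std::fmax(nan, 2.0f)); // prints 2
      std::printf("fmin(NaN, 2) = %g\n", std::fmin(nan, 2.0f)); // prints 2

      // Hence an accumulator seeded with NaN leaves the reduction unchanged:
      const float vals[] = {3.0f, -1.0f, 7.0f};
      float acc = nan;
      for (float v : vals)
        acc = std::fmax(acc, v);
      std::printf("reduction seeded with NaN: %g\n", acc); // prints 7
      return 0;
    }

Seeding an accumulator with NaN therefore never changes the result, which is exactly the property a neutral element needs.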

2 files changed, 106 insertions(+), 0 deletions(-)

mlir/lib/Dialect/Arith/IR/ArithOps.cpp

Lines changed: 14 additions & 0 deletions
@@ -2467,6 +2467,12 @@ TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
                : APFloat::getInf(semantic, /*Negative=*/true);
     return builder.getFloatAttr(resultType, identity);
   }
+  case AtomicRMWKind::maxnumf: {
+    const llvm::fltSemantics &semantic =
+        llvm::cast<FloatType>(resultType).getFloatSemantics();
+    APFloat identity = APFloat::getNaN(semantic, /*Negative=*/true);
+    return builder.getFloatAttr(resultType, identity);
+  }
   case AtomicRMWKind::addf:
   case AtomicRMWKind::addi:
   case AtomicRMWKind::maxu:
@@ -2489,6 +2495,12 @@ TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,

     return builder.getFloatAttr(resultType, identity);
   }
+  case AtomicRMWKind::minnumf: {
+    const llvm::fltSemantics &semantic =
+        llvm::cast<FloatType>(resultType).getFloatSemantics();
+    APFloat identity = APFloat::getNaN(semantic, /*Negative=*/false);
+    return builder.getFloatAttr(resultType, identity);
+  }
   case AtomicRMWKind::mins:
     return builder.getIntegerAttr(
         resultType, APInt::getSignedMaxValue(
@@ -2518,6 +2530,8 @@ std::optional<TypedAttr> mlir::arith::getNeutralElement(Operation *op) {
           .Case([](arith::MulFOp op) { return AtomicRMWKind::mulf; })
           .Case([](arith::MaximumFOp op) { return AtomicRMWKind::maximumf; })
           .Case([](arith::MinimumFOp op) { return AtomicRMWKind::minimumf; })
+          .Case([](arith::MaxNumFOp op) { return AtomicRMWKind::maxnumf; })
+          .Case([](arith::MinNumFOp op) { return AtomicRMWKind::minnumf; })
           // Integer operations.
           .Case([](arith::AddIOp op) { return AtomicRMWKind::addi; })
           .Case([](arith::OrIOp op) { return AtomicRMWKind::ori; })
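
The two /*Negative=*/ flags above pick a negative quiet NaN for maxnumf and a positive one for minnumf. On IEEE-754 single precision these encode as 0xFFC00000 and 0x7FC00000, the constants the FileCheck lines in the test below look for. A standalone way to confirm the bit patterns, assuming an IEEE-754 float (which this sketch relies on):

    #include <bit>
    #include <cinttypes>
    #include <cstdint>
    #include <cstdio>
    #include <limits>

    int main() {
      float qnan = std::numeric_limits<float>::quiet_NaN();
      // On an IEEE-754 target these print 0x7FC00000 and 0xFFC00000,
      // matching the arith.constant values the tests expect.
      std::printf("+qNaN: 0x%08" PRIX32 "\n", std::bit_cast<std::uint32_t>(qnan));
      std::printf("-qNaN: 0x%08" PRIX32 "\n", std::bit_cast<std::uint32_t>(-qnan));
      return 0;
    }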

mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir

Lines changed: 92 additions & 0 deletions
@@ -407,3 +407,95 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+// Checks we use nan as the neutral element for maxnumf op.
+func.func @generic_split_maxnumf(%in: tensor<32xf32>, %out: tensor<f32>) -> tensor<f32> {
+  %r = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
+                                        affine_map<(d0) -> ()>],
+                       iterator_types = ["reduction"]}
+    ins(%in : tensor<32xf32>)
+    outs(%out : tensor<f32>) {
+  ^bb0(%arg1: f32, %arg2: f32):
+    %y = arith.maxnumf %arg1, %arg2 : f32
+    linalg.yield %y : f32
+  } -> tensor<f32>
+  return %r : tensor<f32>
+}
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0) -> (d0)>
+// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0) -> ()>
+// CHECK-LABEL: func @generic_split_maxnumf
+// The float value 0xFFC00000 that is filled into the init tensor represents negative NaN.
+// CHECK-DAG: %[[ID:.*]] = arith.constant 0xFFC00000 : f32
+// CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}{{\[}}[0, 1]] output_shape [8, 4] : tensor<32xf32> into tensor<8x4xf32>
+// CHECK-DAG: %[[INI:.*]] = tensor.empty() : tensor<4xf32>
+// CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<4xf32>) -> tensor<4xf32>
+// CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "parallel"]}
+// CHECK-SAME: ins(%[[I1]] : tensor<8x4xf32>) outs(%[[F]] : tensor<4xf32>) {
+// CHECK: arith.maxnumf
+// CHECK: linalg.yield
+// CHECK: } -> tensor<4xf32>
+// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP3]]], iterator_types = ["reduction"]}
+// CHECK-SAME: ins(%[[G]] : tensor<4xf32>) outs(%{{.*}} : tensor<f32>) {
+// CHECK: arith.maxnumf {{.*}}
+// CHECK: linalg.yield
+// CHECK: } -> tensor<f32>
+// CHECK: return %[[R]] : tensor<f32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1:4 = transform.structured.split_reduction %0 { split_factor = 4, insert_split_dimension = 0, inner_parallel}
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+// Checks we use nan as the neutral element for minnumf op.
+func.func @generic_split_minnumf(%in: tensor<32xf32>, %out: tensor<f32>) -> tensor<f32> {
+  %r = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
+                                        affine_map<(d0) -> ()>],
+                       iterator_types = ["reduction"]}
+    ins(%in : tensor<32xf32>)
+    outs(%out : tensor<f32>) {
+  ^bb0(%arg1: f32, %arg2: f32):
+    %y = arith.minnumf %arg1, %arg2 : f32
+    linalg.yield %y : f32
+  } -> tensor<f32>
+  return %r : tensor<f32>
+}
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0) -> (d0)>
+// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0) -> ()>
+// CHECK-LABEL: func @generic_split_minnumf
+// The float value 0x7FC00000 that is filled into the init tensor represents positive NaN.
+// CHECK-DAG: %[[ID:.*]] = arith.constant 0x7FC00000 : f32
+// CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}{{\[}}[0, 1]] output_shape [8, 4] : tensor<32xf32> into tensor<8x4xf32>
+// CHECK-DAG: %[[INI:.*]] = tensor.empty() : tensor<4xf32>
+// CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<4xf32>) -> tensor<4xf32>
+// CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "parallel"]}
+// CHECK-SAME: ins(%[[I1]] : tensor<8x4xf32>) outs(%[[F]] : tensor<4xf32>) {
+// CHECK: arith.minnumf
+// CHECK: linalg.yield
+// CHECK: } -> tensor<4xf32>
+// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP3]]], iterator_types = ["reduction"]}
+// CHECK-SAME: ins(%[[G]] : tensor<4xf32>) outs(%{{.*}} : tensor<f32>) {
+// CHECK: arith.minnumf {{.*}}
+// CHECK: linalg.yield
+// CHECK: } -> tensor<f32>
+// CHECK: return %[[R]] : tensor<f32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1:4 = transform.structured.split_reduction %0 { split_factor = 4, insert_split_dimension = 0, inner_parallel}
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
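
The split the tests check can be mimicked numerically: reshape the 32 elements to 8x4, reduce the size-8 dimension into four NaN-seeded partial accumulators, then combine the partials. A plain-C++ sketch with made-up input data, using std::fmax in place of arith.maxnumf (not the transform itself):

    #include <algorithm>
    #include <array>
    #include <cmath>
    #include <cstdio>
    #include <limits>

    int main() {
      // Arbitrary input standing in for tensor<32xf32>.
      std::array<float, 32> in{};
      for (int i = 0; i < 32; ++i)
        in[i] = static_cast<float>((i * 7) % 13);

      // linalg.fill with the neutral element: tensor<4xf32> of NaN.
      float nan = std::numeric_limits<float>::quiet_NaN();
      std::array<float, 4> partial;
      partial.fill(nan);

      // First linalg.generic: reduce d0 (size 8), keep d1 (size 4) parallel.
      for (int d0 = 0; d0 < 8; ++d0)
        for (int d1 = 0; d1 < 4; ++d1)
          partial[d1] = std::fmax(partial[d1], in[d0 * 4 + d1]);

      // Second linalg.generic: combine the four partial results.
      float split = nan;
      for (float p : partial)
        split = std::fmax(split, p);

      float direct = *std::max_element(in.begin(), in.end());
      std::printf("split = %g, direct = %g\n", split, direct); // both 12
      return 0;
    }

Because fmax discards NaN operands, the NaN fill is overwritten by the first real element in each column, and the split reduction agrees with the direct one.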
