Skip to content

Commit 631e227

Browse files
committed
[mlir][arith] Add neutral element support to arith.maxnumf/arith.minnumf
arith.maxnumf and arith.minnumf return the other operand when exactly one operand is NaN, so NaN acts as their neutral element and is used as the identity value.
1 parent bbe40b9 commit 631e227

File tree

2 files changed

+61
-0
lines changed

2 files changed

+61
-0
lines changed

mlir/lib/Dialect/Arith/IR/ArithOps.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2467,6 +2467,12 @@ TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
24672467
: APFloat::getInf(semantic, /*Negative=*/true);
24682468
return builder.getFloatAttr(resultType, identity);
24692469
}
2470+
case AtomicRMWKind::maxnumf: {
2471+
const llvm::fltSemantics &semantic =
2472+
llvm::cast<FloatType>(resultType).getFloatSemantics();
2473+
APFloat identity = APFloat::getNaN(semantic, /*Negative=*/true);
2474+
return builder.getFloatAttr(resultType, identity);
2475+
}
24702476
case AtomicRMWKind::addf:
24712477
case AtomicRMWKind::addi:
24722478
case AtomicRMWKind::maxu:
@@ -2489,6 +2495,12 @@ TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
24892495

24902496
return builder.getFloatAttr(resultType, identity);
24912497
}
2498+
case AtomicRMWKind::minnumf: {
2499+
const llvm::fltSemantics &semantic =
2500+
llvm::cast<FloatType>(resultType).getFloatSemantics();
2501+
APFloat identity = APFloat::getNaN(semantic, /*Negative=*/false);
2502+
return builder.getFloatAttr(resultType, identity);
2503+
}
24922504
case AtomicRMWKind::mins:
24932505
return builder.getIntegerAttr(
24942506
resultType, APInt::getSignedMaxValue(
@@ -2518,6 +2530,8 @@ std::optional<TypedAttr> mlir::arith::getNeutralElement(Operation *op) {
25182530
.Case([](arith::MulFOp op) { return AtomicRMWKind::mulf; })
25192531
.Case([](arith::MaximumFOp op) { return AtomicRMWKind::maximumf; })
25202532
.Case([](arith::MinimumFOp op) { return AtomicRMWKind::minimumf; })
2533+
.Case([](arith::MaxNumFOp op) { return AtomicRMWKind::maxnumf; })
2534+
.Case([](arith::MinNumFOp op) { return AtomicRMWKind::minnumf; })
25212535
// Integer operations.
25222536
.Case([](arith::AddIOp op) { return AtomicRMWKind::addi; })
25232537
.Case([](arith::OrIOp op) { return AtomicRMWKind::ori; })

mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -407,3 +407,50 @@ module attributes {transform.with_named_sequence} {
407407
transform.yield
408408
}
409409
}
410+
411+
// -----
412+
413+
// Checks that NaN is used as the neutral element for the maxnumf op.
414+
func.func @generic_split_maxnumf(%in: tensor<32xf32>, %out: tensor<f32>) -> tensor<f32> {
415+
%r = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
416+
affine_map<(d0) -> ()>],
417+
iterator_types = ["reduction"]}
418+
ins(%in : tensor<32xf32>)
419+
outs(%out : tensor<f32>) {
420+
^bb0(%arg1: f32, %arg2: f32):
421+
%y = arith.maxnumf %arg1, %arg2 : f32
422+
linalg.yield %y : f32
423+
} -> tensor<f32>
424+
return %r : tensor<f32>
425+
}
426+
427+
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
428+
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
429+
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0) -> (d0)>
430+
// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0) -> ()>
431+
// CHECK-LABEL: func @generic_split_maxnumf
432+
// The value 0xFFC00000 filled into the init tensor is the f32 bit pattern of a quiet NaN with the sign bit set.
433+
// CHECK-DAG: %[[ID:.*]] = arith.constant 0xFFC00000 : f32
434+
// CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1]] output_shape [8, 4] : tensor<32xf32> into tensor<8x4xf32>
435+
// CHECK-DAG: %[[INI:.*]] = tensor.empty() : tensor<4xf32>
436+
// CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<4xf32>) -> tensor<4xf32>
437+
// CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "parallel"]}
438+
// CHECK-SAME: ins(%[[I1]] : tensor<8x4xf32>) outs(%[[F]] : tensor<4xf32>) {
439+
// CHECK: arith.maxnumf
440+
// CHECK: linalg.yield
441+
// CHECK: } -> tensor<4xf32>
442+
// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP3]]], iterator_types = ["reduction"]}
443+
// CHECK-SAME: ins(%[[G]] : tensor<4xf32>) outs(%{{.*}} : tensor<f32>) {
444+
// CHECK: arith.maxnumf {{.*}}
445+
// CHECK: linalg.yield
446+
// CHECK: } -> tensor<f32>
447+
// CHECK: return %[[R]] : tensor<f32>
448+
449+
module attributes {transform.with_named_sequence} {
450+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
451+
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
452+
%1:4 = transform.structured.split_reduction %0 { split_factor = 4, insert_split_dimension = 0, inner_parallel}
453+
: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
454+
transform.yield
455+
}
456+
}

0 commit comments

Comments
 (0)