Skip to content

Commit fb0aa3f

Browse files
committed
[mlir][arith] Add neutral element support to arith.maxnumf/arith.minnumf
For maxnumf and minnumf, the result of calculations involving NaN will be another value, so their neutral element is set to NaN.
1 parent bbe40b9 commit fb0aa3f

File tree

1 file changed

+14
-0
lines changed

1 file changed

+14
-0
lines changed

mlir/lib/Dialect/Arith/IR/ArithOps.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2467,6 +2467,12 @@ TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
24672467
: APFloat::getInf(semantic, /*Negative=*/true);
24682468
return builder.getFloatAttr(resultType, identity);
24692469
}
2470+
case AtomicRMWKind::maxnumf: {
2471+
const llvm::fltSemantics &semantic =
2472+
llvm::cast<FloatType>(resultType).getFloatSemantics();
2473+
APFloat identity = APFloat::getNaN(semantic, /*Negative=*/true);
2474+
return builder.getFloatAttr(resultType, identity);
2475+
}
24702476
case AtomicRMWKind::addf:
24712477
case AtomicRMWKind::addi:
24722478
case AtomicRMWKind::maxu:
@@ -2489,6 +2495,12 @@ TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
24892495

24902496
return builder.getFloatAttr(resultType, identity);
24912497
}
2498+
case AtomicRMWKind::minnumf: {
2499+
const llvm::fltSemantics &semantic =
2500+
llvm::cast<FloatType>(resultType).getFloatSemantics();
2501+
APFloat identity = APFloat::getNaN(semantic, /*Negative=*/false);
2502+
return builder.getFloatAttr(resultType, identity);
2503+
}
24922504
case AtomicRMWKind::mins:
24932505
return builder.getIntegerAttr(
24942506
resultType, APInt::getSignedMaxValue(
@@ -2518,6 +2530,8 @@ std::optional<TypedAttr> mlir::arith::getNeutralElement(Operation *op) {
25182530
.Case([](arith::MulFOp op) { return AtomicRMWKind::mulf; })
25192531
.Case([](arith::MaximumFOp op) { return AtomicRMWKind::maximumf; })
25202532
.Case([](arith::MinimumFOp op) { return AtomicRMWKind::minimumf; })
2533+
.Case([](arith::MaxNumFOp op) { return AtomicRMWKind::maxnumf; })
2534+
.Case([](arith::MinNumFOp op) { return AtomicRMWKind::minnumf; })
25212535
// Integer operations.
25222536
.Case([](arith::AddIOp op) { return AtomicRMWKind::addi; })
25232537
.Case([](arith::OrIOp op) { return AtomicRMWKind::ori; })

0 commit comments

Comments
 (0)