@@ -2467,6 +2467,12 @@ TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
2467
2467
: APFloat::getInf (semantic, /* Negative=*/ true );
2468
2468
return builder.getFloatAttr (resultType, identity);
2469
2469
}
2470
+ case AtomicRMWKind::maxnumf: {
2471
+ const llvm::fltSemantics &semantic =
2472
+ llvm::cast<FloatType>(resultType).getFloatSemantics ();
2473
+ APFloat identity = APFloat::getNaN (semantic, /* Negative=*/ true );
2474
+ return builder.getFloatAttr (resultType, identity);
2475
+ }
2470
2476
case AtomicRMWKind::addf:
2471
2477
case AtomicRMWKind::addi:
2472
2478
case AtomicRMWKind::maxu:
@@ -2489,6 +2495,12 @@ TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
2489
2495
2490
2496
return builder.getFloatAttr (resultType, identity);
2491
2497
}
2498
+ case AtomicRMWKind::minnumf: {
2499
+ const llvm::fltSemantics &semantic =
2500
+ llvm::cast<FloatType>(resultType).getFloatSemantics ();
2501
+ APFloat identity = APFloat::getNaN (semantic, /* Negative=*/ false );
2502
+ return builder.getFloatAttr (resultType, identity);
2503
+ }
2492
2504
case AtomicRMWKind::mins:
2493
2505
return builder.getIntegerAttr (
2494
2506
resultType, APInt::getSignedMaxValue (
@@ -2518,6 +2530,8 @@ std::optional<TypedAttr> mlir::arith::getNeutralElement(Operation *op) {
2518
2530
.Case ([](arith::MulFOp op) { return AtomicRMWKind::mulf; })
2519
2531
.Case ([](arith::MaximumFOp op) { return AtomicRMWKind::maximumf; })
2520
2532
.Case ([](arith::MinimumFOp op) { return AtomicRMWKind::minimumf; })
2533
+ .Case ([](arith::MaxNumFOp op) { return AtomicRMWKind::maxnumf; })
2534
+ .Case ([](arith::MinNumFOp op) { return AtomicRMWKind::minnumf; })
2521
2535
// Integer operations.
2522
2536
.Case ([](arith::AddIOp op) { return AtomicRMWKind::addi; })
2523
2537
.Case ([](arith::OrIOp op) { return AtomicRMWKind::ori; })
0 commit comments