-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[mlir][arith] Add neutral element support to arith.maxnumf/arith.minnumf #93278
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir-arith Author: donald chen (cxy-1993) ChangesFor maxnumf and minnumf, the result of calculations involving NaN will be another value, so their neutral element is set to NaN. Full diff: https://github.com/llvm/llvm-project/pull/93278.diff 1 Files Affected:
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index a0b50251c6b67..5797c5681a5fd 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -2467,6 +2467,12 @@ TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
: APFloat::getInf(semantic, /*Negative=*/true);
return builder.getFloatAttr(resultType, identity);
}
+ case AtomicRMWKind::maxnumf: {
+ const llvm::fltSemantics &semantic =
+ llvm::cast<FloatType>(resultType).getFloatSemantics();
+ APFloat identity = APFloat::getNaN(semantic, /*Negative=*/true);
+ return builder.getFloatAttr(resultType, identity);
+ }
case AtomicRMWKind::addf:
case AtomicRMWKind::addi:
case AtomicRMWKind::maxu:
@@ -2489,6 +2495,12 @@ TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
return builder.getFloatAttr(resultType, identity);
}
+ case AtomicRMWKind::minnumf: {
+ const llvm::fltSemantics &semantic =
+ llvm::cast<FloatType>(resultType).getFloatSemantics();
+ APFloat identity = APFloat::getNaN(semantic, /*Negative=*/false);
+ return builder.getFloatAttr(resultType, identity);
+ }
case AtomicRMWKind::mins:
return builder.getIntegerAttr(
resultType, APInt::getSignedMaxValue(
@@ -2518,6 +2530,8 @@ std::optional<TypedAttr> mlir::arith::getNeutralElement(Operation *op) {
.Case([](arith::MulFOp op) { return AtomicRMWKind::mulf; })
.Case([](arith::MaximumFOp op) { return AtomicRMWKind::maximumf; })
.Case([](arith::MinimumFOp op) { return AtomicRMWKind::minimumf; })
+ .Case([](arith::MaxNumFOp op) { return AtomicRMWKind::maxnumf; })
+ .Case([](arith::MinNumFOp op) { return AtomicRMWKind::minnumf; })
// Integer operations.
.Case([](arith::AddIOp op) { return AtomicRMWKind::addi; })
.Case([](arith::OrIOp op) { return AtomicRMWKind::ori; })
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm a little bit confused -- the change seems to only be related to atomic operations but the PR description doesn't mention this. Aren't there any other places that rely on neutral elements for arith ops? There are no tests either.
This PR primarily supports the neutral element for MaxNumFOp and MinNumFOp, and adds initial values for the corresponding atomicKind without modifying the atomic operations. The current application scenario for the neutral element is to provide an initial value when tile the reduction axis in a linalg op. I did not add new tests because the implementation itself is straightforward and simple. Moreover, adding a linalg reduce test is quite far from this function, and the linalg reduce also doesn't test all the arithmetic operations for reduction. I'm not sure if you would accept this. If you insist on adding tests, I can add a linalg reduce unit test for this patch. |
Thanks for the context. Having unit tests would help here IMO. Right now it's not clear to me how to exercise this code and step through it. |
I have added tests according to your review feedback. Please help me review this code again when you have time, thanks. @kuhar |
For maxnumf and minnumf, the result of calculations involving NaN will be another value, so their neutral element is set to NaN.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Thanks for adding the tests.
I do not have write access. Could you please help me merge this patch @kuhar |
Sure, no problem. LLVM strongly prefers to have a public email address added to your github instead of the default noreply github address. This is so that folks can reach you if there are issues with your commits. Could you add one to your github profile? |
Thanks for the reminder. I had previously set my email to private with no action. It's now visible. |
For maxnumf and minnumf, the result of calculations involving NaN will be another value, so their neutral element is set to NaN.