-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[mlir][arith] Fix arith maxnumf/minnumf folder #114595
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Thank you for submitting a Pull Request (PR) to the LLVM Project! This PR will be automatically labeled and the relevant teams will be notified. If you wish to, you can add reviewers by using the "Reviewers" section on this page. If this is not working for you, it is probably because you do not have write permissions for the repository. In which case you can instead tag reviewers by name in a comment by using If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers. If you have further questions, they may be answered by the LLVM GitHub User Guide. You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums. |
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-core Author: Clément Fournier (oowekyala) ChangesFix #114594 Full diff: https://github.com/llvm/llvm-project/pull/114595.diff 3 Files Affected:
diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h
index 6fa5a47109d20d..d218206e50f8f1 100644
--- a/mlir/include/mlir/IR/Matchers.h
+++ b/mlir/include/mlir/IR/Matchers.h
@@ -417,6 +417,11 @@ inline detail::constant_float_predicate_matcher m_OneFloat() {
}};
}
+/// Matches a constant scalar / vector splat / tensor splat float ones.
+inline detail::constant_float_predicate_matcher m_NaNFloat() {
+ return {[](const APFloat &value) { return value.isNaN(); }};
+}
+
/// Matches a constant scalar / vector splat / tensor splat float positive
/// infinity.
inline detail::constant_float_predicate_matcher m_PosInfFloat() {
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 254f54d9e459e1..7734911e1e01a7 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1014,13 +1014,14 @@ OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) {
if (getLhs() == getRhs())
return getRhs();
- // maxnumf(x, -inf) -> x
- if (matchPattern(adaptor.getRhs(), m_NegInfFloat()))
+ // maxnumf(x, NaN) -> x
+ if (matchPattern(adaptor.getRhs(), m_NaNFloat()))
return getLhs();
- return constFoldBinaryOp<FloatAttr>(
- adaptor.getOperands(),
- [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
+ return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(),
+ [](const APFloat &a, const APFloat &b) {
+ return llvm::maximumnum(a, b);
+ });
}
//===----------------------------------------------------------------------===//
@@ -1100,8 +1101,8 @@ OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) {
if (getLhs() == getRhs())
return getRhs();
- // minnumf(x, +inf) -> x
- if (matchPattern(adaptor.getRhs(), m_PosInfFloat()))
+ // minnumf(x, NaN) -> x
+ if (matchPattern(adaptor.getRhs(), m_NaNFloat()))
return getLhs();
return constFoldBinaryOp<FloatAttr>(
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index a386a178b78995..84f2b0f113a0c7 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -1905,31 +1905,39 @@ func.func @test_maximumf(%arg0 : f32) -> (f32, f32, f32) {
// -----
// CHECK-LABEL: @test_minnumf(
-func.func @test_minnumf(%arg0 : f32) -> (f32, f32, f32) {
+func.func @test_minnumf(%arg0 : f32) -> (f32, f32, f32, f32) {
// CHECK-DAG: %[[C0:.+]] = arith.constant 0.0
+ // CHECK-DAG: %[[INF:.+]] = arith.constant
// CHECK-NEXT: %[[X:.+]] = arith.minnumf %arg0, %[[C0]]
- // CHECK-NEXT: return %[[X]], %arg0, %arg0
+ // CHECK-NEXT: %[[Y:.+]] = arith.minnumf %arg0, %[[INF]]
+ // CHECK-NEXT: return %[[X]], %arg0, %[[Y]], %arg0
%c0 = arith.constant 0.0 : f32
%inf = arith.constant 0x7F800000 : f32
+ %nan = arith.constant 0x7FC00000 : f32
%0 = arith.minnumf %c0, %arg0 : f32
%1 = arith.minnumf %arg0, %arg0 : f32
%2 = arith.minnumf %inf, %arg0 : f32
- return %0, %1, %2 : f32, f32, f32
+ %3 = arith.minnumf %nan, %arg0 : f32
+ return %0, %1, %2, %3 : f32, f32, f32, f32
}
// -----
// CHECK-LABEL: @test_maxnumf(
-func.func @test_maxnumf(%arg0 : f32) -> (f32, f32, f32) {
- // CHECK-DAG: %[[C0:.+]] = arith.constant
+func.func @test_maxnumf(%arg0 : f32) -> (f32, f32, f32, f32) {
+ // CHECK-DAG: %[[C0:.+]] = arith.constant 0.0
+ // CHECK-DAG: %[[NINF:.+]] = arith.constant
// CHECK-NEXT: %[[X:.+]] = arith.maxnumf %arg0, %[[C0]]
- // CHECK-NEXT: return %[[X]], %arg0, %arg0
+ // CHECK-NEXT: %[[Y:.+]] = arith.maxnumf %arg0, %[[NINF]]
+ // CHECK-NEXT: return %[[X]], %arg0, %[[Y]], %arg0
%c0 = arith.constant 0.0 : f32
%-inf = arith.constant 0xFF800000 : f32
+ %nan = arith.constant 0x7FC00000 : f32
%0 = arith.maxnumf %c0, %arg0 : f32
%1 = arith.maxnumf %arg0, %arg0 : f32
%2 = arith.maxnumf %-inf, %arg0 : f32
- return %0, %1, %2 : f32, f32, f32
+ %3 = arith.maxnumf %nan, %arg0 : f32
+ return %0, %1, %2, %3 : f32, f32, f32, f32
}
// -----
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add an explanation that clarifies what was broken, how it was fixed, and reference the relevant spec. These min/max intrinsic can be very trickly and there are many of them...
@kuhar I updated the PR description Just finishing this and I noticed another edge case for more tedium... There are in fact 2 different implementations for IEEE754-2019's I'm not sure which of those to use. The IEEE spec says the following:
The spec of the arith ops does not specify anything regarding the handling of signalling NaNs (and in fact does not say either that it follows IEEE754-2019 spec). Should the part about "signalling an invalid operation exception" be implemented in arith? And if so, how? Should we fold to a poison value? Or just prevent folding in this case and let lower-level dialects handle this? If we don't care about signalling NaNs then I think we should use the APFloat functions that produce quiet NaNs consistently. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IIRC, we don't care about signaling NaNs but should otherwise distinguish between the 3 versions of max/min.
Wouldn't users like Flang care about signaling NaNs? |
I meant this in the context of constants in the IR -- IIRC there's no way to represent a signaling NaN? |
A relevant thread: https://discourse.llvm.org/t/semantics-of-nan/66729 |
a821964
to
4d4a8b1
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM but it would be nice to have someone more familiar with float semantics confirm the signaling NaN behavior
I think we can land this |
@oowekyala Congratulations on having your first Pull Request (PR) merged into the LLVM Project! Your changes will be combined with recent changes from other authors, then tested by our build bots. If there is a problem with a build, you may receive a report in an email or a comment on this PR. Please check whether problems have been caused by your change specifically, as the builds can include changes from many authors. It is not uncommon for your change to be included in a build that fails due to someone else's changes, or infrastructure issues. How to do this, and the rest of the post-merge process, is covered in detail here. If your change does cause a problem, it may be reverted, or you can revert it yourself. This is a normal part of LLVM development. You can fix your changes and open a new PR to merge them again. If you don't get any reports, no action is required from you. Your changes are working as expected, well done! |
The decomposition of `linalg.softmax` uses `maxnumf`, but the identity element that is used in the generated code is the one for `maximumf`. They are not the same, as the identity for `maxnumf` is `NaN`, while the one of `maximumf` is `-Infty`. This is wrong and prevents the maxnumf from being folded. Related to #114595, which fixed the folder for maxnumf.
The decomposition of `linalg.softmax` uses `maxnumf`, but the identity element that is used in the generated code is the one for `maximumf`. They are not the same, as the identity for `maxnumf` is `NaN`, while the one of `maximumf` is `-Infty`. This is wrong and prevents the maxnumf from being folded. Related to llvm#114595, which fixed the folder for maxnumf.
Fix #114594
Context
IEEE754-2019 Sec 9.6 defines 2 minimum and 2 maximum operations. They are termed
maximum
andmaximumNumber
minimum
andminimumNumber
In the arith dialect they are respectively named
maximumf
andmaxnumf
,minimumf
andminnumf
so I use these names.These operations only differ in how they handle NaN values. For
maximumf
andminimumf
, if any operand is NaN, then the result is NaN, ie, NaN is propagated. Formaxnumf
andminnumf
, if any operand is NaN, then the other operand is returned, ie, NaN is absorbed. The following identities hold:(and same for min).
Arith folders
In the following I am talking about the folders for the arith operations. The folders implement the following canonicalizations (
op
is one of maximumf, maxnumf, minimumf, minnumf):op(x, x)
folds tox
op(x, y)
, ify
folds to the neutral element of theop
, then theop
is folded tox
.maximumf
is -Inftyminimumf
is +Inftymaxnumf
andminnumf
is NaN as shown above.op(x, y)
, if bothx
andy
fold to constantsx'
andy'
, then theop
is folded and the result is calculated with a corresponding runtime function.The folders are properly implemented for
maximumf
andminimumf
, but the same implementations were copied for the respectivemaxnumf
andminnumf
functions. This means the neutral element of the second folder above is wrong:maxnumf(x, -Infty)
is folded tox
, but that's wrong, because ifx
is NaN then -Infty should be the resultminnumf(x, +Infty)
is folded tox
, but same thing, the result should be +Infty whenx
is NaN.This is fixed by using
NaN
as neutral element for themaxnumf
andminnumf
ops.1Again because of copy paste mistake, the third pattern above is using
llvm::maximum
instead ofllvm::maximumnum
to calculate the result in case both arguments fold to a constant:maxnumf(NaN, x')
would have been folded tollvm::maximum(NaN, x')
which isNaN
, whereas the result should bex'
.This folder for
minnumf
already correctly usesllvm::minnum
, but I fixed the one formaxnumf
in this PR.Footnotes
this is by the way already correctly implemented in
arith::getIdentityValueAttr
↩