Skip to content

[mlir][arith] Fix arith maxnumf/minnumf folder #114595

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 27, 2024

Conversation

oowekyala
Copy link
Contributor

@oowekyala oowekyala commented Nov 1, 2024

Fix #114594

Context

IEEE754-2019 Sec 9.6 defines 2 minimum and 2 maximum operations. They are termed

  • maximum and maximumNumber
  • minimum and minimumNumber

In the arith dialect they are respectively named maximumf and maxnumf, minimumf and minnumf so I use these names.

These operations only differ in how they handle NaN values. For maximumf and minimumf, if any operand is NaN, then the result is NaN, ie, NaN is propagated. For maxnumf and minnumf, if any operand is NaN, then the other operand is returned, ie, NaN is absorbed. The following identities hold:

maximumf(x, NaN) = maximumf(NaN, x) = NaN
maxnumf(x, NaN) = maxnumf(NaN, x) = x

(and same for min).

Arith folders

In the following I am talking about the folders for the arith operations. The folders implement the following canonicalizations (op is one of maximumf, maxnumf, minimumf, minnumf):

  1. op(x, x) folds to x
  2. for op(x, y), if y folds to the neutral element of the op, then the op is folded to x.
    1. The neutral element of maximumf is -Infty
    2. The neutral element of minimumf is +Infty
    3. The neutral element of maxnumf and minnumf is NaN as shown above.
  3. for op(x, y), if both x and y fold to constants x' and y', then the op is folded and the result is calculated with a corresponding runtime function.

The folders are properly implemented for maximumf and minimumf, but the same implementations were copied for the respective maxnumf and minnumf functions. This means the neutral element of the second folder above is wrong:

  • maxnumf(x, -Infty) is folded to x, but that's wrong, because if x is NaN then -Infty should be the result
  • minnumf(x, +Infty) is folded to x, but same thing, the result should be +Infty when x is NaN.

This is fixed by using NaN as neutral element for the maxnumf and minnumf ops.1

Again because of copy paste mistake, the third pattern above is using llvm::maximum instead of llvm::maximumnum to calculate the result in case both arguments fold to a constant:

  • maxnumf(NaN, x') would have been folded to llvm::maximum(NaN, x') which is NaN, whereas the result should be x'.

This folder for minnumf already correctly uses llvm::minnum, but I fixed the one for maxnumf in this PR.

Footnotes

  1. this is by the way already correctly implemented in arith::getIdentityValueAttr

Copy link

github-actions bot commented Nov 1, 2024

Thank you for submitting a Pull Request (PR) to the LLVM Project!

This PR will be automatically labeled and the relevant teams will be notified.

If you wish to, you can add reviewers by using the "Reviewers" section on this page.

If this is not working for you, it is probably because you do not have write permissions for the repository. In which case you can instead tag reviewers by name in a comment by using @ followed by their GitHub username.

If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers.

If you have further questions, they may be answered by the LLVM GitHub User Guide.

You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums.

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir mlir:arith labels Nov 1, 2024
@llvmbot
Copy link
Member

llvmbot commented Nov 1, 2024

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-arith

@llvm/pr-subscribers-mlir-core

Author: Clément Fournier (oowekyala)

Changes

Fix #114594


Full diff: https://github.com/llvm/llvm-project/pull/114595.diff

3 Files Affected:

  • (modified) mlir/include/mlir/IR/Matchers.h (+5)
  • (modified) mlir/lib/Dialect/Arith/IR/ArithOps.cpp (+8-7)
  • (modified) mlir/test/Dialect/Arith/canonicalize.mlir (+15-7)
diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h
index 6fa5a47109d20d..d218206e50f8f1 100644
--- a/mlir/include/mlir/IR/Matchers.h
+++ b/mlir/include/mlir/IR/Matchers.h
@@ -417,6 +417,11 @@ inline detail::constant_float_predicate_matcher m_OneFloat() {
   }};
 }
 
+/// Matches a constant scalar / vector splat / tensor splat float ones.
+inline detail::constant_float_predicate_matcher m_NaNFloat() {
+  return {[](const APFloat &value) { return value.isNaN(); }};
+}
+
 /// Matches a constant scalar / vector splat / tensor splat float positive
 /// infinity.
 inline detail::constant_float_predicate_matcher m_PosInfFloat() {
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 254f54d9e459e1..7734911e1e01a7 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1014,13 +1014,14 @@ OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) {
   if (getLhs() == getRhs())
     return getRhs();
 
-  // maxnumf(x, -inf) -> x
-  if (matchPattern(adaptor.getRhs(), m_NegInfFloat()))
+  // maxnumf(x, NaN) -> x
+  if (matchPattern(adaptor.getRhs(), m_NaNFloat()))
     return getLhs();
 
-  return constFoldBinaryOp<FloatAttr>(
-      adaptor.getOperands(),
-      [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
+  return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(),
+                                      [](const APFloat &a, const APFloat &b) {
+                                        return llvm::maximumnum(a, b);
+                                      });
 }
 
 //===----------------------------------------------------------------------===//
@@ -1100,8 +1101,8 @@ OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) {
   if (getLhs() == getRhs())
     return getRhs();
 
-  // minnumf(x, +inf) -> x
-  if (matchPattern(adaptor.getRhs(), m_PosInfFloat()))
+  // minnumf(x, NaN) -> x
+  if (matchPattern(adaptor.getRhs(), m_NaNFloat()))
     return getLhs();
 
   return constFoldBinaryOp<FloatAttr>(
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index a386a178b78995..84f2b0f113a0c7 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -1905,31 +1905,39 @@ func.func @test_maximumf(%arg0 : f32) -> (f32, f32, f32) {
 // -----
 
 // CHECK-LABEL: @test_minnumf(
-func.func @test_minnumf(%arg0 : f32) -> (f32, f32, f32) {
+func.func @test_minnumf(%arg0 : f32) -> (f32, f32, f32, f32) {
   // CHECK-DAG:   %[[C0:.+]] = arith.constant 0.0
+  // CHECK-DAG:   %[[INF:.+]] = arith.constant
   // CHECK-NEXT:  %[[X:.+]] = arith.minnumf %arg0, %[[C0]]
-  // CHECK-NEXT:  return %[[X]], %arg0, %arg0
+  // CHECK-NEXT:  %[[Y:.+]] = arith.minnumf %arg0, %[[INF]]
+  // CHECK-NEXT:   return %[[X]], %arg0, %[[Y]], %arg0
   %c0 = arith.constant 0.0 : f32
   %inf = arith.constant 0x7F800000 : f32
+  %nan = arith.constant 0x7FC00000 : f32
   %0 = arith.minnumf %c0, %arg0 : f32
   %1 = arith.minnumf %arg0, %arg0 : f32
   %2 = arith.minnumf %inf, %arg0 : f32
-  return %0, %1, %2 : f32, f32, f32
+  %3 = arith.minnumf %nan, %arg0 : f32
+  return %0, %1, %2, %3 : f32, f32, f32, f32
 }
 
 // -----
 
 // CHECK-LABEL: @test_maxnumf(
-func.func @test_maxnumf(%arg0 : f32) -> (f32, f32, f32) {
-  // CHECK-DAG:   %[[C0:.+]] = arith.constant
+func.func @test_maxnumf(%arg0 : f32) -> (f32, f32, f32, f32) {
+  // CHECK-DAG:   %[[C0:.+]] = arith.constant 0.0
+  // CHECK-DAG:   %[[NINF:.+]] = arith.constant
   // CHECK-NEXT:  %[[X:.+]] = arith.maxnumf %arg0, %[[C0]]
-  // CHECK-NEXT:   return %[[X]], %arg0, %arg0
+  // CHECK-NEXT:  %[[Y:.+]] = arith.maxnumf %arg0, %[[NINF]]
+  // CHECK-NEXT:   return %[[X]], %arg0, %[[Y]], %arg0
   %c0 = arith.constant 0.0 : f32
   %-inf = arith.constant 0xFF800000 : f32
+  %nan = arith.constant 0x7FC00000 : f32
   %0 = arith.maxnumf %c0, %arg0 : f32
   %1 = arith.maxnumf %arg0, %arg0 : f32
   %2 = arith.maxnumf %-inf, %arg0 : f32
-  return %0, %1, %2 : f32, f32, f32
+  %3 = arith.maxnumf %nan, %arg0 : f32
+  return %0, %1, %2, %3 : f32, f32, f32, f32
 }
 
 // -----

Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add an explanation that clarifies what was broken, how it was fixed, and reference the relevant spec. These min/max intrinsic can be very trickly and there are many of them...

@oowekyala
Copy link
Contributor Author

@kuhar I updated the PR description

Just finishing this and I noticed another edge case for more tedium...

There are in fact 2 different implementations for IEEE754-2019's maximumNumber and minimumNumber in LLVM's APFloat.h: maximumnum and maxnum (and same for minimum). They are not documented as such, but they differ in that one returns a quiet NaN if both operands are NaN (even if one is signalling), and the other returns one operand (which could be signalling).

I'm not sure which of those to use. The IEEE spec says the following:

  • maximumNumber(x, y) is x if x > y, y if y > x, and the number if one operand is a number and the other is a NaN. For this operation, +0 compares greater than −0. If x = y and signs are the same it is either x or y. If both operands are NaNs, a quiet NaN is returned, according to 6.2. If either operand is a signaling NaN, an invalid operation exception is signaled, but unless both operands are NaNs, the signaling NaN is otherwise ignored and not converted to a quiet NaN as stated in 6.2 for other operations.

The spec of the arith ops does not specify anything regarding the handling of signalling NaNs (and in fact does not say either that it follows IEEE754-2019 spec). Should the part about "signalling an invalid operation exception" be implemented in arith? And if so, how? Should we fold to a poison value? Or just prevent folding in this case and let lower-level dialects handle this?

If we don't care about signalling NaNs then I think we should use the APFloat functions that produce quiet NaNs consistently.

Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIRC, we don't care about signaling NaNs but should otherwise distinguish between the 3 versions of max/min.

@joker-eph
Copy link
Collaborator

IIRC, we don't care about signaling NaNs but should otherwise distinguish between the 3 versions of max/min.

Wouldn't users like Flang care about signaling NaNs?

@kuhar
Copy link
Member

kuhar commented Nov 4, 2024

I meant this in the context of constants in the IR -- IIRC there's no way to represent a signaling NaN?

@kuhar
Copy link
Member

kuhar commented Nov 4, 2024

A relevant thread: https://discourse.llvm.org/t/semantics-of-nan/66729

Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM but it would be nice to have someone more familiar with float semantics confirm the signaling NaN behavior

@krzysz00
Copy link
Contributor

I think we can land this

@joker-eph joker-eph merged commit b6ab04c into llvm:main Nov 27, 2024
8 checks passed
Copy link

@oowekyala Congratulations on having your first Pull Request (PR) merged into the LLVM Project!

Your changes will be combined with recent changes from other authors, then tested by our build bots. If there is a problem with a build, you may receive a report in an email or a comment on this PR.

Please check whether problems have been caused by your change specifically, as the builds can include changes from many authors. It is not uncommon for your change to be included in a build that fails due to someone else's changes, or infrastructure issues.

How to do this, and the rest of the post-merge process, is covered in detail here.

If your change does cause a problem, it may be reverted, or you can revert it yourself. This is a normal part of LLVM development. You can fix your changes and open a new PR to merge them again.

If you don't get any reports, no action is required from you. Your changes are working as expected, well done!

@oowekyala oowekyala deleted the fix-arith-folder branch November 28, 2024 09:59
CoTinker pushed a commit that referenced this pull request Jan 13, 2025
The decomposition of `linalg.softmax` uses `maxnumf`, but the identity
element that is used in the generated code is the one for `maximumf`.
They are not the same, as the identity for `maxnumf` is `NaN`, while the
one of `maximumf` is `-Infty`. This is wrong and prevents the maxnumf
from being folded.

Related to #114595, which fixed the folder for maxnumf.
kazutakahirata pushed a commit to kazutakahirata/llvm-project that referenced this pull request Jan 13, 2025
The decomposition of `linalg.softmax` uses `maxnumf`, but the identity
element that is used in the generated code is the one for `maximumf`.
They are not the same, as the identity for `maxnumf` is `NaN`, while the
one of `maximumf` is `-Infty`. This is wrong and prevents the maxnumf
from being folded.

Related to llvm#114595, which fixed the folder for maxnumf.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:arith mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[mlir][arith] Arith folder for maxnumf/minnumf has wrong NaN behaviour
5 participants