Commit b6ab04c

[mlir][arith] Fix arith maxnumf/minnumf folder (#114595)
Fix #114594

#### Context

[IEEE754-2019](https://ieeexplore.ieee.org/document/8766229) Sec 9.6 defines two minimum and two maximum operations. They are termed

- `maximum` and `maximumNumber`
- `minimum` and `minimumNumber`

In the arith dialect they are named `maximumf` and `maxnumf`, and `minimumf` and `minnumf`, respectively, so I use these names.

These operations differ only in how they handle NaN values. For `maximumf` and `minimumf`, if any operand is NaN, then the result is NaN, i.e., NaN is propagated. For `maxnumf` and `minnumf`, if one operand is NaN, then the other operand is returned, i.e., NaN is absorbed. The following identities hold:

```
maximumf(x, NaN) = maximumf(NaN, x) = NaN
maxnumf(x, NaN) = maxnumf(NaN, x) = x
```

(and the same for min).

#### Arith folders

In the following I am talking about the folders for the arith operations. The folders implement the following canonicalizations (`op` is one of `maximumf`, `maxnumf`, `minimumf`, `minnumf`):

1. `op(x, x)` folds to `x`.
2. For `op(x, y)`, if `y` folds to the neutral element of the `op`, then the `op` is folded to `x`.
   1. The neutral element of `maximumf` is -Infty.
   2. The neutral element of `minimumf` is +Infty.
   3. The neutral element of `maxnumf` and `minnumf` is NaN, as shown above.
3. For `op(x, y)`, if both `x` and `y` fold to constants `x'` and `y'`, then the `op` is folded and the result is calculated with a corresponding runtime function.

The folders are properly implemented for `maximumf` and `minimumf`, but the same implementations were copied for the respective `maxnumf` and `minnumf` functions. This means the neutral element used by the second pattern above is wrong:

- `maxnumf(x, -Infty)` is folded to `x`, but that is wrong, because if `x` is NaN then -Infty should be the result.
- `minnumf(x, +Infty)` is folded to `x`, but for the same reason the result should be +Infty when `x` is NaN.

This is fixed by using `NaN` as the neutral element for the `maxnumf` and `minnumf` ops.[^1]

Again because of a copy-paste mistake, the third pattern above, as implemented for `maxnumf`, is using `llvm::maximum` instead of `llvm::maxnum` to calculate the result when both arguments fold to a constant:

- `maxnumf(NaN, x')` would have been folded to `llvm::maximum(NaN, x')`, which is `NaN`, whereas the result should be `x'`.

The folder for `minnumf` already correctly uses `llvm::minnum`, but I fixed the one for `maxnumf` in this PR.

[^1]: This is, by the way, already correctly implemented in [`arith::getIdentityValueAttr`](https://github.com/oowekyala/llvm-project/blob/a821964e0320d1e35514ced149ec10ec06d7131a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp#L2493-L2498)
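The NaN identities and the neutral-element argument in the message above can be checked outside of MLIR. The following is a minimal standalone sketch, not code from this commit: it assumes `std::fmax` as a stand-in for the NaN-absorbing `maxnumf`, and uses a hand-written `maximumPropagating` as a stand-in for the NaN-propagating `maximumf`. The names `maximumPropagating`, `maxnumAbsorbing` and `checkNeutralElements` are hypothetical.

```cpp
#include <cassert>
#include <cmath>
#include <limits>

// Stand-in for arith.maximumf: if either operand is NaN, the result is NaN.
// Signed-zero ordering is ignored here for brevity.
inline float maximumPropagating(float a, float b) {
  if (std::isnan(a) || std::isnan(b))
    return std::numeric_limits<float>::quiet_NaN();
  return a > b ? a : b;
}

// Stand-in for arith.maxnumf: std::fmax returns the other operand when
// exactly one operand is NaN, so NaN is absorbed.
inline float maxnumAbsorbing(float a, float b) { return std::fmax(a, b); }

// Hypothetical helper that asserts the identities from the commit message.
inline void checkNeutralElements() {
  const float nan = std::numeric_limits<float>::quiet_NaN();
  const float negInf = -std::numeric_limits<float>::infinity();
  const float x = 1.5f;

  // NaN is propagated by maximumf and absorbed by maxnumf.
  assert(std::isnan(maximumPropagating(x, nan)));
  assert(maxnumAbsorbing(x, nan) == x);
  assert(maxnumAbsorbing(nan, x) == x);

  // -inf is neutral for maximumf, including when x is NaN.
  assert(maximumPropagating(x, negInf) == x);
  assert(std::isnan(maximumPropagating(nan, negInf)));

  // -inf is not neutral for maxnumf: with x = NaN the result is -inf, not x.
  assert(maxnumAbsorbing(nan, negInf) == negInf);
}
```

The last assertion is the case the old folder got wrong: rewriting `maxnumf(x, -Infty)` to `x` would yield NaN where the operation itself yields -Infty.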
1 parent 8c1bd97 commit b6ab04c

3 files changed: 25 additions and 14 deletions

mlir/include/mlir/IR/Matchers.h

Lines changed: 5 additions & 0 deletions

    @@ -417,6 +417,11 @@ inline detail::constant_float_predicate_matcher m_OneFloat() {
       }};
     }

    +/// Matches a constant scalar / vector splat / tensor splat float ones.
    +inline detail::constant_float_predicate_matcher m_NaNFloat() {
    +  return {[](const APFloat &value) { return value.isNaN(); }};
    +}
    +
     /// Matches a constant scalar / vector splat / tensor splat float positive
     /// infinity.
     inline detail::constant_float_predicate_matcher m_PosInfFloat() {

mlir/lib/Dialect/Arith/IR/ArithOps.cpp

Lines changed: 5 additions & 7 deletions

    @@ -1022,13 +1022,11 @@ OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) {
       if (getLhs() == getRhs())
         return getRhs();

    -  // maxnumf(x, -inf) -> x
    -  if (matchPattern(adaptor.getRhs(), m_NegInfFloat()))
    +  // maxnumf(x, NaN) -> x
    +  if (matchPattern(adaptor.getRhs(), m_NaNFloat()))
         return getLhs();

    -  return constFoldBinaryOp<FloatAttr>(
    -      adaptor.getOperands(),
    -      [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
    +  return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(), llvm::maxnum);
     }

     //===----------------------------------------------------------------------===//
    @@ -1108,8 +1106,8 @@ OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) {
       if (getLhs() == getRhs())
         return getRhs();

    -  // minnumf(x, +inf) -> x
    -  if (matchPattern(adaptor.getRhs(), m_PosInfFloat()))
    +  // minnumf(x, NaN) -> x
    +  if (matchPattern(adaptor.getRhs(), m_NaNFloat()))
         return getLhs();

       return constFoldBinaryOp<FloatAttr>(
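The three fold patterns of the corrected `MaxNumFOp::fold` can be modelled without any MLIR dependency. The sketch below is a toy model under stated assumptions, not the MLIR implementation: `Operand`, `FoldResult`, `foldMaxNum` and `checkFoldMaxNum` are hypothetical names, an operand is reduced to an integer identity plus an optional known constant, and `std::fmax` stands in for `llvm::maxnum`.

```cpp
#include <cassert>
#include <cmath>
#include <limits>
#include <optional>

// Hypothetical stand-in for a folder operand: an identity for the SSA value
// and, when the operand is a known constant, its value.
struct Operand {
  int id;
  std::optional<float> constant;
};

// Hypothetical fold outcome: no fold, forward an existing operand, or
// materialize a new constant.
struct FoldResult {
  enum Kind { NoFold, Forward, Constant };
  Kind kind;
  int forwardedId = -1;
  float value = 0.0f;
};

// Toy model of the corrected folder for maxnumf.
inline FoldResult foldMaxNum(const Operand &lhs, const Operand &rhs) {
  // Pattern 1: maxnumf(x, x) -> x
  if (lhs.id == rhs.id)
    return {FoldResult::Forward, rhs.id};
  // Pattern 2: maxnumf(x, NaN) -> x (NaN is the neutral element)
  if (rhs.constant && std::isnan(*rhs.constant))
    return {FoldResult::Forward, lhs.id};
  // Pattern 3: both operands constant -> NaN-absorbing maximum
  if (lhs.constant && rhs.constant)
    return {FoldResult::Constant, -1, std::fmax(*lhs.constant, *rhs.constant)};
  return {FoldResult::NoFold};
}

// Hypothetical helper exercising the cases discussed in the commit message.
inline void checkFoldMaxNum() {
  const float nan = std::numeric_limits<float>::quiet_NaN();
  const float negInf = -std::numeric_limits<float>::infinity();
  const Operand x{0, std::nullopt};  // non-constant value
  const Operand cNaN{1, nan};
  const Operand cNegInf{2, negInf};

  assert(foldMaxNum(x, x).kind == FoldResult::Forward);
  assert(foldMaxNum(x, cNaN).forwardedId == x.id);
  // maxnumf(x, -inf) must not fold: x could be NaN at runtime.
  assert(foldMaxNum(x, cNegInf).kind == FoldResult::NoFold);
  // maxnumf(NaN, -inf) folds to the constant -inf.
  assert(foldMaxNum(cNaN, cNegInf).value == negInf);
}
```

Like the folder in the diff, the model only inspects the right-hand operand for the neutral element. The CHECK lines in the test file show constant operands appearing on the right-hand side after canonicalization, which is presumably why checking one side suffices.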

mlir/test/Dialect/Arith/canonicalize.mlir

Lines changed: 15 additions & 7 deletions

    @@ -1917,31 +1917,39 @@ func.func @test_maximumf(%arg0 : f32) -> (f32, f32, f32) {
     // -----

     // CHECK-LABEL: @test_minnumf(
    -func.func @test_minnumf(%arg0 : f32) -> (f32, f32, f32) {
    +func.func @test_minnumf(%arg0 : f32) -> (f32, f32, f32, f32) {
       // CHECK-DAG: %[[C0:.+]] = arith.constant 0.0
    +  // CHECK-DAG: %[[INF:.+]] = arith.constant
       // CHECK-NEXT: %[[X:.+]] = arith.minnumf %arg0, %[[C0]]
    -  // CHECK-NEXT: return %[[X]], %arg0, %arg0
    +  // CHECK-NEXT: %[[Y:.+]] = arith.minnumf %arg0, %[[INF]]
    +  // CHECK-NEXT: return %[[X]], %arg0, %[[Y]], %arg0
       %c0 = arith.constant 0.0 : f32
       %inf = arith.constant 0x7F800000 : f32
    +  %nan = arith.constant 0x7FC00000 : f32
       %0 = arith.minnumf %c0, %arg0 : f32
       %1 = arith.minnumf %arg0, %arg0 : f32
       %2 = arith.minnumf %inf, %arg0 : f32
    -  return %0, %1, %2 : f32, f32, f32
    +  %3 = arith.minnumf %nan, %arg0 : f32
    +  return %0, %1, %2, %3 : f32, f32, f32, f32
     }

     // -----

     // CHECK-LABEL: @test_maxnumf(
    -func.func @test_maxnumf(%arg0 : f32) -> (f32, f32, f32) {
    -  // CHECK-DAG: %[[C0:.+]] = arith.constant
    +func.func @test_maxnumf(%arg0 : f32) -> (f32, f32, f32, f32) {
    +  // CHECK-DAG: %[[C0:.+]] = arith.constant 0.0
    +  // CHECK-DAG: %[[NINF:.+]] = arith.constant
       // CHECK-NEXT: %[[X:.+]] = arith.maxnumf %arg0, %[[C0]]
    -  // CHECK-NEXT: return %[[X]], %arg0, %arg0
    +  // CHECK-NEXT: %[[Y:.+]] = arith.maxnumf %arg0, %[[NINF]]
    +  // CHECK-NEXT: return %[[X]], %arg0, %[[Y]], %arg0
       %c0 = arith.constant 0.0 : f32
       %-inf = arith.constant 0xFF800000 : f32
    +  %nan = arith.constant 0x7FC00000 : f32
       %0 = arith.maxnumf %c0, %arg0 : f32
       %1 = arith.maxnumf %arg0, %arg0 : f32
       %2 = arith.maxnumf %-inf, %arg0 : f32
    -  return %0, %1, %2 : f32, f32, f32
    +  %3 = arith.maxnumf %nan, %arg0 : f32
    +  return %0, %1, %2, %3 : f32, f32, f32, f32
     }

     // -----
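The tests write their special values as raw `f32` bit patterns. As a quick cross-check of what those hexadecimal constants denote, here is a small standalone sketch that reinterprets them as IEEE 754 binary32 values. It assumes `float` is binary32 on the target platform; `floatFromBits` and `checkTestConstants` are hypothetical helper names.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <limits>

// Reinterprets a 32-bit pattern as a float.
// Assumes `float` is IEEE 754 binary32, as on mainstream platforms.
inline float floatFromBits(std::uint32_t bits) {
  static_assert(sizeof(float) == sizeof(std::uint32_t),
                "float must be 32 bits wide");
  float value;
  std::memcpy(&value, &bits, sizeof(value));
  return value;
}

// Hypothetical helper asserting the meaning of the constants in the tests.
inline void checkTestConstants() {
  const float inf = std::numeric_limits<float>::infinity();
  assert(floatFromBits(0x7F800000u) == inf);        // %inf:  +infinity
  assert(floatFromBits(0xFF800000u) == -inf);       // %-inf: -infinity
  assert(std::isnan(floatFromBits(0x7FC00000u)));   // %nan:  quiet NaN
}
```

With these values in mind, the updated CHECK lines read as follows: the operations on `%inf` and `%-inf` are now expected to remain in the output, while the new operations on `%nan` fold to `%arg0`.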
