Skip to content

Commit ac66d87

Browse files
jacquesguanjacquesguan
jacquesguan
authored and
jacquesguan
committed
[mlir][Math] Add constant folder for RoundEvenOp.
This patch uses roundeven/roundevenf of libm to fold RoundEvenOp of constant. Differential Revision: https://reviews.llvm.org/D133344
1 parent 7359314 commit ac66d87

File tree

3 files changed

+37
-0
lines changed

3 files changed

+37
-0
lines changed

mlir/include/mlir/Dialect/Math/IR/MathOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -769,6 +769,7 @@ def Math_RoundEvenOp : Math_FloatUnaryOp<"roundeven"> {
769769
%a = math.roundeven %b : f64
770770
```
771771
}];
772+
let hasFolder = 1;
772773
}
773774

774775
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Math/IR/MathOps.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -415,6 +415,24 @@ OpFoldResult math::TanhOp::fold(ArrayRef<Attribute> operands) {
415415
});
416416
}
417417

418+
//===----------------------------------------------------------------------===//
419+
// RoundEvenOp folder
420+
//===----------------------------------------------------------------------===//
421+
422+
OpFoldResult math::RoundEvenOp::fold(ArrayRef<Attribute> operands) {
423+
return constFoldUnaryOpConditional<FloatAttr>(
424+
operands, [](const APFloat &a) -> Optional<APFloat> {
425+
switch (a.getSizeInBits(a.getSemantics())) {
426+
case 64:
427+
return APFloat(roundeven(a.convertToDouble()));
428+
case 32:
429+
return APFloat(roundevenf(a.convertToFloat()));
430+
default:
431+
return {};
432+
}
433+
});
434+
}
435+
418436
/// Materialize an integer or floating point constant.
419437
Operation *math::MathDialect::materializeConstant(OpBuilder &builder,
420438
Attribute value, Type type,

mlir/test/Dialect/Math/canonicalize.mlir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -375,3 +375,21 @@ func.func @cos_fold_vec() -> (vector<4xf32>) {
375375
%0 = math.cos %v1 : vector<4xf32>
376376
return %0 : vector<4xf32>
377377
}
378+
379+
// CHECK-LABEL: @roundeven_fold
380+
// CHECK-NEXT: %[[cst:.+]] = arith.constant 2.000000e+00 : f32
381+
// CHECK-NEXT: return %[[cst]]
382+
func.func @roundeven_fold() -> f32 {
383+
%c = arith.constant 1.5 : f32
384+
%r = math.roundeven %c : f32
385+
return %r : f32
386+
}
387+
388+
// CHECK-LABEL: @roundeven_fold_vec
389+
// CHECK-NEXT: %[[cst:.+]] = arith.constant dense<[0.000000e+00, -0.000000e+00, 2.000000e+00, -2.000000e+00]> : vector<4xf32>
390+
// CHECK-NEXT: return %[[cst]]
391+
func.func @roundeven_fold_vec() -> (vector<4xf32>) {
392+
%v1 = arith.constant dense<[0.5, -0.5, 1.5, -1.5]> : vector<4xf32>
393+
%0 = math.roundeven %v1 : vector<4xf32>
394+
return %0 : vector<4xf32>
395+
}

0 commit comments

Comments
 (0)