[MLIR][Math] add canonicalize-f32-promotion pass #92482

Closed
Changes from 2 commits
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -17,6 +17,7 @@ namespace math {
#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
#define GEN_PASS_DECL_MATHUPLIFTTOFMA
#define GEN_PASS_DECL_MATHLEGALIZETOF32
#define GEN_PASS_DECL_MATHCANONICALIZEF32PROMOTION
#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
#define GEN_PASS_REGISTRATION
#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
47 changes: 47 additions & 0 deletions mlir/include/mlir/Dialect/Math/Transforms/Passes.td
@@ -36,4 +36,51 @@ def MathLegalizeToF32 : Pass<"math-legalize-to-f32"> {
let dependentDialects = ["math::MathDialect", "arith::ArithDialect"];
}

def MathCanonicalizeF32Promotion : Pass<"math-canonicalize-f32-promotion"> {
let summary = "Eliminate redundant truncf/extf pairs";
let description = [{
The `legalize-to-f32` pass promotes every op on its illegal-op list to f32.
When several such illegal ops appear in sequence, `legalize-to-f32` inserts a
redundant `arith.truncf`/`arith.extf` pair between each of them.

This pass eliminates those redundant truncf/extf pairs to improve
performance.

However, this pass may introduce numerical differences, as the intermediate
`f32->bf16` rounding
Collaborator:

Would these be valid patterns as canonicalization when fast-math is enabled?

Author:

This pass is independent of fast-math. If fast-math is enabled and we have such redundant truncf/extf pairs, they will still be removed from the IR.

Collaborator:

I think you missed my point: I was asking about THE canonicalization patterns, not a separate pass.

Contributor:

I don't think they're valid under fastmath. See discussion at 3cf8535 for a particular case of this in a different context

Contributor:

Given the context of how that ultimately got resolved, what I'd want to see is updating the math legalizer to strip out trunc/ext pairs that it creates at time of creation.
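For illustration only, here is a minimal sketch of what stripping the pairs at creation time could look like. The helper name and structure are hypothetical, not the actual LegalizeToF32 code: when a legalization pattern needs an f32 operand, it can look through a truncf back to the original f32 value instead of emitting a fresh extf.

```cpp
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

// Hypothetical helper, for illustration only (scalar f32 case; a real
// implementation would also handle shaped types): reuse the f32 value feeding
// a truncf instead of materializing extf(truncf(x)).
static Value promoteOperandToF32(OpBuilder &builder, Location loc, Value v) {
  Type f32Ty = builder.getF32Type();
  if (v.getType() == f32Ty)
    return v;
  // If `v` was just truncated from f32, reuse the original f32 value.
  if (auto trunc = v.getDefiningOp<arith::TruncFOp>())
    if (trunc.getIn().getType() == f32Ty)
      return trunc.getIn();
  return builder.create<arith::ExtFOp>(loc, f32Ty, v);
}
```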

Collaborator:

> I don't think they're valid under fastmath.

Can you elaborate? I don't follow.
Assuming you meant to send a link to #88486 (comment) instead, there you wrote:

> This rewrite has caused per-element result errors around 1e-2 (if I remember right 16.25 vs 16.3125 or the like)

But numerical differences are actually in scope for fast-math...

Contributor @krzysz00 (May 17, 2024):

No, I did mean to send a link to the commit - the discussion moved there.

What I mean to say is that

%y = arith.truncf %arg0 : f32 to f16
%z = arith.extf %y : f16 to f32
return %z : f32

can't be simplified to return %arg0, even under fastmath, because that is an explicit, user-specified desire to lose precision.

However, rewrites (like arith-emulate-unsupported-floats, which probably should also get this intermediate preservation treatment) which would introduce such a colliding truncf/extf pair are allowed to not do that and keep a higher intermediate precision. This is allowed without fastmath.

These are, to my current understanding, the C, and thus the LLVM, semantics.

Contributor:

If we didn't have this wider ecosystem that knows what exactly the fastmath flags mean, I'd buy that you could do this rewrite everywhere under contract ... but I'm pretty sure that flag doesn't allow eliminating arbitrary trunc/ext pairs from input

Collaborator @joker-eph (May 17, 2024):

> can't be simplified to return %arg0, even under fastmath, because that is an explicit, user-specified desire to lose precision.

Sorry, but I don't quite see why fast-math does not allow this in LLVM IR?
I don't really buy the "user-specified desire" argument, because these IR patterns show up in LLVM after quite a churn of optimization (possibly just through inlining, etc.).
If the user really wants something, they shouldn't decorate it with fast-math! (which expresses exactly the opposite desire from what you're describing)

Contributor:

I think the main thing with LLVM IR is that fptrunc and fpext can't carry fastmath flags, last I checked.

Now, if they could, I'd absolutely permit

func.func @f(%arg0: f32) -> f32 {
    %0 = arith.truncf contract %arg0 : f32 to f16
    %1 = arith.extf contract %0 : f16 to f32
    return %1 : f32
}

to fold to

func.func @f(%arg0: f32) -> f32 {
    return %arg0 : f32
}

So, from an MLIR perspective, if we gave extend and truncate the ability to carry fastmath flags (if they don't already have it), it'd make a lot of sense to

  1. Add that folding, and
  2. Change the various unsupported-float emulation/legalization passes to (optionally - or maybe by default) stick a `contract` on their truncates and extensions (see the sketch below).
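A rough sketch of what that could look like, assuming `arith.truncf`/`arith.extf` accepted a fastmath attribute (hypothetical syntax; the thread itself notes this may not exist today): the legalizer tags the casts it creates, so only flag-carrying pairs become foldable while user-written casts stay untouched.

```mlir
// Hypothetical output of legalize-to-f32 with `contract` attached to the
// casts it creates (syntax assumed, not current upstream behavior).
func.func @bf16_absf_sin(%arg0: bf16) -> bf16 {
  %0 = arith.extf %arg0 fastmath<contract> : bf16 to f32
  %1 = math.absf %0 : f32
  %2 = arith.truncf %1 fastmath<contract> : f32 to bf16
  %3 = arith.extf %2 fastmath<contract> : bf16 to f32
  %4 = math.sin %3 : f32
  // Folding %3 = extf(truncf(%1)) -> %1 would then be justified by the flags,
  // while a pair written by the user without `contract` would be preserved.
  %5 = arith.truncf %4 fastmath<contract> : f32 to bf16
  return %5 : bf16
}
```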


Example:

```mlir
// the initial func
func.func @bf16_sin_vector(%arg0: vector<32xbf16>) -> vector<32xbf16> {
  %0 = math.absf %arg0 : vector<32xbf16>
  %1 = math.sin %0 : vector<32xbf16>
  return %1 : vector<32xbf16>
}
// after legalize-to-f32
func.func @bf16_sin_vector(%arg0: vector<32xbf16>) -> vector<32xbf16> {
  %0 = arith.extf %arg0 : vector<32xbf16> to vector<32xf32>
  %1 = math.absf %0 : vector<32xf32>
  %2 = arith.truncf %1 : vector<32xf32> to vector<32xbf16>
  %3 = arith.extf %2 : vector<32xbf16> to vector<32xf32>
  %4 = math.sin %3 : vector<32xf32>
  %5 = arith.truncf %4 : vector<32xf32> to vector<32xbf16>
  return %5 : vector<32xbf16>
}
// after canonicalize-f32-promotion
func.func @bf16_sin_vector(%arg0: vector<32xbf16>) -> vector<32xbf16> {
  %0 = arith.extf %arg0 : vector<32xbf16> to vector<32xf32>
  %1 = math.absf %0 : vector<32xf32>
  %2 = math.sin %1 : vector<32xf32>
  %3 = arith.truncf %2 : vector<32xf32> to vector<32xbf16>
  return %3 : vector<32xbf16>
}
```

}];
let dependentDialects = ["math::MathDialect", "arith::ArithDialect"];
}

#endif // MLIR_DIALECT_MATH_TRANSFORMS_PASSES
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRMathTransforms
AlgebraicSimplification.cpp
ExpandPatterns.cpp
LegalizeToF32.cpp
CanonicalizeF32Promotion.cpp
PolynomialApproximation.cpp
UpliftToFMA.cpp

72 changes: 72 additions & 0 deletions mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp
@@ -0,0 +1,72 @@
//===- CanonicalizeF32Promotion.cpp - Remove redundant extf/truncf pairs -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the removal of redundant truncf/extf pairs inserted by
// the LegalizeToF32 pass.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Math/Transforms/Passes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir::math {
#define GEN_PASS_DEF_MATHCANONICALIZEF32PROMOTION
#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
} // namespace mlir::math

using namespace mlir;

namespace {

struct CanonicalizeF32PromotionRewritePattern final
    : OpRewritePattern<arith::ExtFOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(arith::ExtFOp op,
                                PatternRewriter &rewriter) const final {
    // Only match an extf whose operand comes directly from a truncf.
    auto innerTruncOp = op.getOperand().getDefiningOp<arith::TruncFOp>();
    if (!innerTruncOp)
      return failure();
    Value truncInput = innerTruncOp.getOperand();
    Type outerType = op.getType();
    Type intermediateType = innerTruncOp.getType();
    Type innerType = truncInput.getType();
    if (auto shapedOuter = dyn_cast<ShapedType>(outerType)) {
      outerType = shapedOuter.getElementType();
      intermediateType = cast<ShapedType>(intermediateType).getElementType();
      innerType = cast<ShapedType>(innerType).getElementType();
    }
    // Only fold f32 -> (f16|bf16) -> f32 round trips produced by promotion.
    if (!outerType.isF32() ||
        !(intermediateType.isF16() || intermediateType.isBF16()) ||
        !innerType.isF32())
      return failure();
    rewriter.replaceOp(op, truncInput);
    return success();
  }
};

struct MathCanonicalizeF32Promotion final
    : math::impl::MathCanonicalizeF32PromotionBase<
          MathCanonicalizeF32Promotion> {
  using MathCanonicalizeF32PromotionBase::MathCanonicalizeF32PromotionBase;
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    patterns.insert<CanonicalizeF32PromotionRewritePattern>(&getContext());
    FrozenRewritePatternSet patternSet(std::move(patterns));
    if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet)))
      signalPassFailure();
  }
};

} // namespace
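As a concrete illustration of the type guards above (a hand-written example, not one of the PR's test cases): the f32->bf16->f32 pair is folded, while a pair whose outer and inner types are f64 does not satisfy the guards and is left alone.

```mlir
func.func @guards(%a: f32, %b: f64) -> (f32, f64) {
  // Matches the guard (f32 -> bf16 -> f32): the extf is replaced by %a and
  // the truncf becomes dead.
  %0 = arith.truncf %a : f32 to bf16
  %1 = arith.extf %0 : bf16 to f32
  // Does not match (outer/inner types are f64): the pair is preserved.
  %2 = arith.truncf %b : f64 to f32
  %3 = arith.extf %2 : f32 to f64
  return %1, %3 : f32, f64
}
```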
74 changes: 74 additions & 0 deletions mlir/test/Dialect/Math/canonicalize-f32-promotion.mlir
@@ -0,0 +1,74 @@
// RUN: mlir-opt %s --split-input-file -math-legalize-to-f32 -math-canonicalize-f32-promotion | FileCheck %s

// CHECK-LABEL: @sequences
// CHECK-SAME: ([[ARG0:%.+]]: bf16)
// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
// CHECK: return [[TRUNCF]] : bf16
func.func @sequences(%arg0: bf16) -> bf16 {
%0 = math.absf %arg0 : bf16
%1 = math.sin %0 : bf16
return %1 : bf16
}

// CHECK-LABEL: @eliminatecastoncastf16
// CHECK: return [[arg0:%.+]] : f32
func.func @eliminatecastoncastf16(%arg0: f32) -> f32 {
%0 = arith.truncf %arg0 : f32 to f16
%1 = arith.extf %0 : f16 to f32
return %1 : f32
}

// CHECK-LABEL: @eliminatecastoncastbf16
// CHECK: return [[arg0:%.+]] : f32
func.func @eliminatecastoncastbf16(%arg0: f32) -> f32 {
%0 = arith.truncf %arg0 : f32 to bf16
%1 = arith.extf %0 : bf16 to f32
return %1 : f32
}

// CHECK-LABEL: @bf16_sin_vector
// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>)
// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
// CHECK: return [[TRUNCF]] : vector<32x32x32xbf16>
func.func @bf16_sin_vector(%arg0: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> {
%0 = math.absf %arg0 : vector<32x32x32xbf16>
%1 = math.sin %0 : vector<32x32x32xbf16>
return %1 : vector<32x32x32xbf16>
}

// CHECK-LABEL: @f16_sin_vector
// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xf16>)
// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
// CHECK: return [[TRUNCF]] : vector<32x32x32xf16>
func.func @f16_sin_vector(%arg0: vector<32x32x32xf16>) -> vector<32x32x32xf16> {
%0 = math.absf %arg0 : vector<32x32x32xf16>
%1 = math.sin %0 : vector<32x32x32xf16>
return %1 : vector<32x32x32xf16>
}

// CHECK-LABEL: @bf16_branch_vector
// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>)
// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
// CHECK: [[TRUNCF0:%.+]] = arith.truncf [[SIN]]
// CHECK: [[COS:%.+]] = math.cos [[ABSF]]
// CHECK: [[TRUNCF1:%.+]] = arith.truncf [[COS]]
// CHECK: [[ADDF:%.+]] = arith.addf
// CHECK: return [[ADDF]] : vector<32x32x32xbf16>
func.func @bf16_branch_vector(%arg0: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> {
%0 = math.absf %arg0 : vector<32x32x32xbf16>
%1 = math.sin %0 : vector<32x32x32xbf16>
%2 = math.cos %0 : vector<32x32x32xbf16>
%3 = arith.addf %1, %2 : vector<32x32x32xbf16>
return %3 : vector<32x32x32xbf16>
}