-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[MLIR][Math] add canonicalize-f32-promotion pass #92482
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
crazydemo
wants to merge
6
commits into
llvm:main
from
crazydemo:zhangyan/canonicalize_f32_promotion
Closed
Changes from 2 commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
c4dd5ad
add canonicalize-f32-promotion pass
02be4d6
add branch case
07ca29d
use single walk rather than greedy rewrite
224e714
add canonical option in legalize-to-f32
bf4d202
remove single canonicalize pass
cbbfdb3
Merge branch 'main' into zhangyan/canonicalize_f32_promotion
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
72 changes: 72 additions & 0 deletions
72
mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
//===- CanonicalizeF32Promotion.cpp - Remove redundant extf/truncf pairs -===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
// | ||
// This file implements removing redundant extf/truncf pairs inserted from | ||
// LegalizeToF32. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "mlir/Dialect/Arith/IR/Arith.h" | ||
#include "mlir/Dialect/Math/IR/Math.h" | ||
#include "mlir/Dialect/Math/Transforms/Passes.h" | ||
#include "mlir/IR/PatternMatch.h" | ||
#include "mlir/IR/TypeUtilities.h" | ||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" | ||
|
||
namespace mlir::math { | ||
#define GEN_PASS_DEF_MATHCANONICALIZEF32PROMOTION | ||
#include "mlir/Dialect/Math/Transforms/Passes.h.inc" | ||
} // namespace mlir::math | ||
|
||
using namespace mlir; | ||
|
||
namespace { | ||
|
||
struct CanonicalizeF32PromotionRewritePattern final | ||
: OpRewritePattern<arith::ExtFOp> { | ||
using OpRewritePattern::OpRewritePattern; | ||
LogicalResult matchAndRewrite(arith::ExtFOp op, | ||
PatternRewriter &rewriter) const final { | ||
if (auto innertruncop = op.getOperand().getDefiningOp<arith::TruncFOp>()) { | ||
if (auto truncinput = innertruncop.getOperand()) { | ||
auto outter_type = op.getType(); | ||
auto intermediate_type = innertruncop.getType(); | ||
auto inner_type = truncinput.getType(); | ||
if (outter_type.isa<ShapedType>()) { | ||
outter_type = op.getType().cast<ShapedType>().getElementType(); | ||
intermediate_type = | ||
innertruncop.getType().cast<ShapedType>().getElementType(); | ||
inner_type = truncinput.getType().cast<ShapedType>().getElementType(); | ||
} | ||
if (outter_type.isF32() && | ||
(intermediate_type.isF16() || intermediate_type.isBF16()) && | ||
inner_type.isF32()) { | ||
rewriter.replaceOp(op, {truncinput}); | ||
} | ||
} else | ||
return failure(); | ||
} else | ||
return failure(); | ||
return success(); | ||
} | ||
}; | ||
|
||
struct MathCanonicalizeF32Promotion final | ||
: math::impl::MathCanonicalizeF32PromotionBase< | ||
MathCanonicalizeF32Promotion> { | ||
using MathCanonicalizeF32PromotionBase::MathCanonicalizeF32PromotionBase; | ||
void runOnOperation() override { | ||
RewritePatternSet patterns(&getContext()); | ||
patterns.insert<CanonicalizeF32PromotionRewritePattern>(&getContext()); | ||
FrozenRewritePatternSet patternSet(std::move(patterns)); | ||
if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet))) | ||
joker-eph marked this conversation as resolved.
Show resolved
Hide resolved
|
||
signalPassFailure(); | ||
} | ||
}; | ||
|
||
} // namespace |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
// RUN: mlir-opt %s --split-input-file -math-legalize-to-f32 -math-canonicalize-f32-promotion | FileCheck %s | ||
|
||
// CHECK-LABEL: @sequences | ||
// CHECK-SAME: ([[ARG0:%.+]]: bf16) | ||
// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]] | ||
// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]] | ||
// CHECK: [[SIN:%.+]] = math.sin [[ABSF]] | ||
// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]] | ||
// CHECK: return [[TRUNCF]] : bf16 | ||
func.func @sequences(%arg0: bf16) -> bf16 { | ||
%0 = math.absf %arg0 : bf16 | ||
%1 = math.sin %0 : bf16 | ||
return %1 : bf16 | ||
} | ||
|
||
// CHECK-LABEL: @eliminatecastoncastf16 | ||
// CHECK: return [[arg0:%.+]] : f32 | ||
func.func @eliminatecastoncastf16(%arg0: f32) -> f32 { | ||
%0 = arith.truncf %arg0 : f32 to f16 | ||
%1 = arith.extf %0 : f16 to f32 | ||
return %1 : f32 | ||
} | ||
|
||
// CHECK-LABEL: @eliminatecastoncastbf16 | ||
// CHECK: return [[arg0:%.+]] : f32 | ||
func.func @eliminatecastoncastbf16(%arg0: f32) -> f32 { | ||
%0 = arith.truncf %arg0 : f32 to bf16 | ||
%1 = arith.extf %0 : bf16 to f32 | ||
return %1 : f32 | ||
} | ||
|
||
// CHECK-LABEL: @bf16_sin_vector | ||
// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>) | ||
// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]] | ||
// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]] | ||
// CHECK: [[SIN:%.+]] = math.sin [[ABSF]] | ||
// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]] | ||
// CHECK: return [[TRUNCF]] : vector<32x32x32xbf16> | ||
func.func @bf16_sin_vector(%arg0: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> { | ||
%0 = math.absf %arg0 : vector<32x32x32xbf16> | ||
%1 = math.sin %0 : vector<32x32x32xbf16> | ||
return %1 : vector<32x32x32xbf16> | ||
} | ||
|
||
// CHECK-LABEL: @f16_sin_vector | ||
// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xf16>) | ||
// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]] | ||
// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]] | ||
// CHECK: [[SIN:%.+]] = math.sin [[ABSF]] | ||
// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]] | ||
// CHECK: return [[TRUNCF]] : vector<32x32x32xf16> | ||
func.func @f16_sin_vector(%arg0: vector<32x32x32xf16>) -> vector<32x32x32xf16> { | ||
%0 = math.absf %arg0 : vector<32x32x32xf16> | ||
%1 = math.sin %0 : vector<32x32x32xf16> | ||
return %1 : vector<32x32x32xf16> | ||
} | ||
|
||
// CHECK-LABEL: @bf16_branch_vector | ||
// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>) | ||
// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]] | ||
// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]] | ||
// CHECK: [[SIN:%.+]] = math.sin [[ABSF]] | ||
// CHECK: [[TRUNCF0:%.+]] = arith.truncf [[SIN]] | ||
// CHECK: [[COS:%.+]] = math.cos [[ABSF]] | ||
// CHECK: [[TRUNCF1:%.+]] = arith.truncf [[COS]] | ||
// CHECK: [[ADDF:%.+]] = arith.addf | ||
// CHECK: return [[ADDF]] : vector<32x32x32xbf16> | ||
func.func @bf16_branch_vector(%arg0: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> { | ||
%0 = math.absf %arg0 : vector<32x32x32xbf16> | ||
%1 = math.sin %0 : vector<32x32x32xbf16> | ||
%2 = math.cos %0 : vector<32x32x32xbf16> | ||
%3 = arith.addf %1, %2 : vector<32x32x32xbf16> | ||
return %3 : vector<32x32x32xbf16> | ||
} |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would these be valid patterns as canonicalization when fast-math is enabled?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This pass is independent from fast-math. If the fast-math is enabled, and we have such redundant truncf/extf pairs, they will still be removed from the IR.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think you missed my point: I was asking about THE canonicalization patterns, not a separate pass.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think they're valid under fastmath. See discussion at 3cf8535 for a particular case of this in a different context
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Given the context of how that ultimately got resolved, what I'd want to see is updating the math legalizer to strip out trunc/ext pairs that it creates at time of creation.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you elaborate? I don't follow.
Assuming you meant to send a link to #88486 (comment) instead, there you wrote:
But numerical differences is actually in scope for fast-math...
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No, I did mean to send a link to the commit - the discussion moved there.
What I mean to say is that
can't be simplified to
return %arg0
, even under fastmath, because that is an explicit, user-specified desire to lose precision.However, rewrites (like
arith-emulate-unsupported-floats
, which probably should also get this intermediate preservation treatment) which would introduce such a colliding truncf/extf pair are allowed to not do that and keep a higher intermediate precision. This is allowed without fastmath.These are, to my current understanding, the C, and thus the LLVM, semantics.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we didn't have this wider ecosystem that knows what exactly the fastmath flags mean, I'd buy that you could do this rewrite everywhere under
contract
... but I'm pretty sure that flag doesn't allow eliminating arbitrary trunc/ext pairs from inputUh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry but I don't quite see why fast-math does not allow this in LLVM IR?
I don't really buy the "user-specified desired" because patterns of IR are happening in LLVM after quite a churn of optimization (possibly only through inlining, etc.).
If the user really want something, they shouldn't decorate these with fast math! (which is the point of expressing the express opposite desire than what you're saying actually)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the main thing with LLVM IR is hat
fptrunc
andfpext
can't carry fastmath flags, last I checked.Now, if they could, I'd absolutely permit
to fold to
So, from an MLIR perspective, if we give extend and truncate the ability to fastmath (if they don't already have it) it'd make a lot of sense to
contract
on their truncates and extensions