[MLIR][Math] add canonicalize-f32-promotion pass #92482

Closed
8 changes: 8 additions & 0 deletions mlir/include/mlir/Dialect/Math/Transforms/Passes.td
@@ -34,6 +34,14 @@ def MathLegalizeToF32 : Pass<"math-legalize-to-f32"> {
that is an operation frequently implemented at low precisions.
}];
let dependentDialects = ["math::MathDialect", "arith::ArithDialect"];
let options = [
Option<"useCanonicalizeF32Promotion", "use-canonicalize-f32-promotion", "bool",
/*default=*/"true",
"Eliminate redundant truncf/extf pairs to improve performance. "
"This may introduce numerical differences because the f32->bf16 "
"rounding is eliminated.">
];

}

#endif // MLIR_DIALECT_MATH_TRANSFORMS_PASSES
36 changes: 36 additions & 0 deletions mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
@@ -19,6 +19,7 @@
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir::math {
#define GEN_PASS_DEF_MATHLEGALIZETOF32
@@ -37,6 +38,8 @@ struct LegalizeToF32RewritePattern final : ConversionPattern {

struct LegalizeToF32Pass final
: mlir::math::impl::MathLegalizeToF32Base<LegalizeToF32Pass> {
LegalizeToF32Pass() = default;
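// Forward the options to the generated base class so that they take effect.
LegalizeToF32Pass(const mlir::math::MathLegalizeToF32Options &options)
    : MathLegalizeToF32Base(options) {}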
void runOnOperation() override;
};
} // namespace
@@ -97,6 +100,29 @@ void mlir::math::populateLegalizeToF32Patterns(RewritePatternSet &patterns,
patterns.getContext());
}

struct CanonicalizeF32PromotionRewritePattern final
    : OpRewritePattern<arith::ExtFOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(arith::ExtFOp op,
                                PatternRewriter &rewriter) const final {
    // Match extf(truncf(x)) where x is f32, the intermediate type is f16 or
    // bf16, and the extf result is f32 again.
    auto innerTruncOp = op.getOperand().getDefiningOp<arith::TruncFOp>();
    if (!innerTruncOp)
      return failure();
    Value truncInput = innerTruncOp.getOperand();
    Type outerTy = getElementTypeOrSelf(op.getType());
    Type intermediateTy = getElementTypeOrSelf(innerTruncOp.getType());
    Type innerTy = getElementTypeOrSelf(truncInput.getType());
    // Report failure when nothing is rewritten, so the driver does not treat
    // an unchanged op as a successful match.
    if (!outerTy.isF32() || !innerTy.isF32() ||
        !(intermediateTy.isF16() || intermediateTy.isBF16()))
      return failure();
    rewriter.replaceOp(op, truncInput);
    return success();
  }
};

void LegalizeToF32Pass::runOnOperation() {
Operation *op = getOperation();
MLIRContext &ctx = getContext();
@@ -109,4 +135,14 @@ void LegalizeToF32Pass::runOnOperation() {
math::populateLegalizeToF32Patterns(patterns, typeConverter);
if (failed(applyPartialConversion(op, target, std::move(patterns))))
return signalPassFailure();

if (useCanonicalizeF32Promotion) {
Contributor

I don't really like this approach.

How about going up to matchAndRewrite() and doing this:

SmallVector<Value> extendedWidthOperands(operands);
for (auto [extended, original] :
     llvm::zip_equal(extendedWidthOperands, op->getOperands())) {
  // Match a trunc/ext pair. The inelegant version is:
  if (auto truncOp = extended.getDefiningOp<arith::TruncFOp>()) {
    auto maybeOriginal = truncOp.getIn().getDefiningOp<arith::ExtFOp>();
    if (maybeOriginal && maybeOriginal.getIn() == original)
      extended = original;
  }
}
convertOpResultTypes(..., extendedWidthOperands, ...);

Now, you don't need a pass option, and all you're doing is "if this is the extension of the truncation of my original argument, use that original argument instead".

Author

  1. The pass option is for users who are concerned about the numerical difference. With the option, they can easily switch the optimization on or off.
  2. I really appreciate your one-stage approach, which modifies matchAndRewrite() directly to decide whether to insert extf / truncf when the current op is visited. However, if users create extf / truncf explicitly in the IR, those op pairs cannot be optimized this way. I think the two-stage approach handles that case much more easily.
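
The numerical difference mentioned in point 1 can be illustrated outside MLIR. The following is a minimal standalone sketch, not code from this PR: `roundTripBF16` is a hypothetical helper that emulates an `arith.truncf` to bf16 followed by an `arith.extf` back to f32, assuming round-to-nearest-even and ignoring NaN and overflow. It uses an `exp(sin(x))` chain rather than the `absf`/`sin` chain from the tests, because `absf` of a bf16 value is exactly representable and would hide the effect.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>

// Round an f32 value to bf16 precision (round-to-nearest-even on the upper
// 16 bits) and widen it back to f32, mimicking arith.truncf followed by
// arith.extf. NaN and overflow handling are omitted for brevity.
float roundTripBF16(float x) {
  std::uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  std::uint32_t bias = 0x7FFFu + ((bits >> 16) & 1u);
  bits = (bits + bias) & 0xFFFF0000u;
  float y;
  std::memcpy(&y, &bits, sizeof(y));
  return y;
}

// exp(sin(x)) with the intermediate rounded to bf16, as happens when the
// truncf/extf pair between the two ops is kept.
float expOfSinWithPair(float x) {
  return std::exp(roundTripBF16(std::sin(x)));
}

// The same chain with the pair eliminated: the intermediate stays in f32.
float expOfSinWithoutPair(float x) {
  return std::exp(std::sin(x));
}
```

For an input such as `1.5f`, which is itself exactly representable in bf16, the two functions return different results because `sin(1.5f)` is not representable in bf16. That gap is the behavior the pass option lets users opt out of.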

Contributor

I propose the one-stage approach precisely because it doesn't optimize explicit truncf / extf pairs.

Explicitly rewriting away all truncf/extf pairs shouldn't be hiding in a type legalization. The legalization can, using the one stage approach, refrain from creating such pairs to improve numerical precision, but it should not eliminate existing ones.
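
The distinction drawn in this thread can be sketched on a toy IR. Everything below is hypothetical and is not the MLIR API: `ToyOp`, `legalizeOneStage`, and `bypassExtOfTrunc` are illustrative names, each op has at most one operand, and dead ops are left in place instead of being erased. `legalizeOneStage` models a legalization that avoids creating an extf/truncf pair between two ops it widens itself, while `bypassExtOfTrunc` models a cleanup that bypasses every such pair, including ones the user wrote.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Toy stand-in for SSA IR (hypothetical): every op has one optional operand,
// given as the index of its producer, and a result type.
enum class Kind { Arg, TruncF, ExtF, Math };
enum class Ty { BF16, F32 };

struct ToyOp {
  Kind kind;
  int operand; // index of the producing op, or -1 for Arg
  Ty type;     // result type
};

// Append an op and return its index.
static int emit(std::vector<ToyOp> &ops, Kind kind, int operand, Ty type) {
  ops.push_back({kind, operand, type});
  return static_cast<int>(ops.size()) - 1;
}

// One-stage idea: widen bf16 Math ops to f32, and when an operand was
// produced by a Math op this function just widened, reuse that f32 value
// instead of emitting extf(truncf(...)). User-written casts are copied as-is.
std::vector<ToyOp> legalizeOneStage(const std::vector<ToyOp> &in) {
  std::vector<ToyOp> out;
  std::vector<int> narrow(in.size(), -1); // new index of the original-typed value
  std::vector<int> wide(in.size(), -1);   // new index of its f32 form, if created here
  for (std::size_t i = 0; i < in.size(); ++i) {
    const ToyOp &op = in[i];
    if (op.kind != Kind::Math || op.type == Ty::F32) {
      int operand = op.operand < 0 ? -1 : narrow[op.operand];
      narrow[i] = emit(out, op.kind, operand, op.type);
      continue;
    }
    int wideOperand = wide[op.operand];
    if (wideOperand < 0)
      wideOperand = emit(out, Kind::ExtF, narrow[op.operand], Ty::F32);
    wide[i] = emit(out, Kind::Math, wideOperand, Ty::F32);
    narrow[i] = emit(out, Kind::TruncF, wide[i], op.type);
  }
  return out;
}

// Two-stage idea: after legalization, rewire every user of extf(truncf(x))
// to x when x is f32, regardless of who created the pair.
int bypassExtOfTrunc(std::vector<ToyOp> &ops) {
  int bypassed = 0;
  for (ToyOp &user : ops) {
    if (user.operand < 0)
      continue;
    const ToyOp &ext = ops[user.operand];
    if (ext.kind != Kind::ExtF || ext.type != Ty::F32)
      continue;
    const ToyOp &trunc = ops[ext.operand];
    if (trunc.kind != Kind::TruncF || ops[trunc.operand].type != Ty::F32)
      continue;
    user.operand = trunc.operand;
    ++bypassed;
  }
  return bypassed;
}

// Count the ops of one kind.
int countKind(const std::vector<ToyOp> &ops, Kind kind) {
  int n = 0;
  for (const ToyOp &op : ops)
    n += op.kind == kind ? 1 : 0;
  return n;
}
```

On a bf16 chain of two Math ops, `legalizeOneStage` emits a single ExtF and feeds the second Math op from the first one's f32 result. On a program containing a user-written truncf/extf pair, it copies the pair unchanged, whereas `bypassExtOfTrunc` rewires the user past it. That second behavior is the one the reviewer argues should not be hidden inside a type legalization.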

Author

Maybe we can apply the one-stage approach in the legalization pass and create a separate pass for uses such as graph simplification. @ZhennanQin

@krzysz00 May I ask what the difference is between existing truncf / extf pairs and auto-generated ones? Why can we only eliminate the truncf / extf pairs generated by legalize-to-f32, but not those from any other pass? Could you please provide a usage scenario?

RewritePatternSet canoPatterns(&getContext());
canoPatterns.insert<CanonicalizeF32PromotionRewritePattern>(&getContext());
FrozenRewritePatternSet frozenCanoPatterns(std::move(canoPatterns));
// Try the truncf/extf elimination pattern on every extf under the root op.
op->walk([&](arith::ExtFOp extOp) {
  if (failed(applyOpPatternsAndFold({extOp}, frozenCanoPatterns)))
    extOp->emitError("failed to remove implicit rounding");
});
}
}
84 changes: 71 additions & 13 deletions mlir/test/Dialect/Math/legalize-to-f32.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --split-input-file -math-legalize-to-f32 | FileCheck %s
+// RUN: mlir-opt %s --split-input-file -math-legalize-to-f32=use-canonicalize-f32-promotion=true | FileCheck %s

// CHECK-LABEL: @sin
// CHECK-SAME: ([[ARG0:%.+]]: f16)
@@ -70,16 +70,74 @@ func.func @fastmath(%arg0: f16) -> f16 {
}

// CHECK-LABEL: @sequences
-// CHECK-SAME: ([[ARG0:%.+]]: f16)
-// CHECK: [[EXTF0:%.+]] = arith.extf [[ARG0]]
-// CHECK: [[ABSF:%.+]] = math.absf [[EXTF0]]
-// CHECK: [[TRUNCF0:%.+]] = arith.truncf [[ABSF]]
-// CHECK: [[EXTF1:%.+]] = arith.extf [[TRUNCF0]]
-// CHECK: [[SIN:%.+]] = math.sin [[EXTF1]]
-// CHECK: [[TRUNCF1:%.+]] = arith.truncf [[SIN]]
-// CHECK: return [[TRUNCF1]] : f16
-func.func @sequences(%arg0: f16) -> f16 {
-%0 = math.absf %arg0 : f16
-%1 = math.sin %0 : f16
-return %1 : f16
+// CHECK-SAME: ([[ARG0:%.+]]: bf16)
+// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
+// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
+// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
+// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
+// CHECK: return [[TRUNCF]] : bf16
+func.func @sequences(%arg0: bf16) -> bf16 {
+%0 = math.absf %arg0 : bf16
+%1 = math.sin %0 : bf16
+return %1 : bf16
}

// CHECK-LABEL: @eliminatecastoncastf16
// CHECK-SAME: ([[ARG0:%.+]]: f32)
// CHECK: return [[ARG0]] : f32
func.func @eliminatecastoncastf16(%arg0: f32) -> f32 {
%0 = arith.truncf %arg0 : f32 to f16
%1 = arith.extf %0 : f16 to f32
return %1 : f32
}

// CHECK-LABEL: @eliminatecastoncastbf16
// CHECK-SAME: ([[ARG0:%.+]]: f32)
// CHECK: return [[ARG0]] : f32
func.func @eliminatecastoncastbf16(%arg0: f32) -> f32 {
%0 = arith.truncf %arg0 : f32 to bf16
%1 = arith.extf %0 : bf16 to f32
return %1 : f32
}

// CHECK-LABEL: @bf16_sin_vector
// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>)
// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
// CHECK: return [[TRUNCF]] : vector<32x32x32xbf16>
func.func @bf16_sin_vector(%arg0: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> {
%0 = math.absf %arg0 : vector<32x32x32xbf16>
%1 = math.sin %0 : vector<32x32x32xbf16>
return %1 : vector<32x32x32xbf16>
}

// CHECK-LABEL: @f16_sin_vector
// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xf16>)
// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
// CHECK: return [[TRUNCF]] : vector<32x32x32xf16>
func.func @f16_sin_vector(%arg0: vector<32x32x32xf16>) -> vector<32x32x32xf16> {
%0 = math.absf %arg0 : vector<32x32x32xf16>
%1 = math.sin %0 : vector<32x32x32xf16>
return %1 : vector<32x32x32xf16>
}

// CHECK-LABEL: @bf16_branch_vector
// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>)
// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
// CHECK: [[TRUNCF0:%.+]] = arith.truncf [[SIN]]
// CHECK: [[COS:%.+]] = math.cos [[ABSF]]
// CHECK: [[TRUNCF1:%.+]] = arith.truncf [[COS]]
// CHECK: [[ADDF:%.+]] = arith.addf
// CHECK: return [[ADDF]] : vector<32x32x32xbf16>
func.func @bf16_branch_vector(%arg0: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> {
%0 = math.absf %arg0 : vector<32x32x32xbf16>
%1 = math.sin %0 : vector<32x32x32xbf16>
%2 = math.cos %0 : vector<32x32x32xbf16>
%3 = arith.addf %1, %2 : vector<32x32x32xbf16>
return %3 : vector<32x32x32xbf16>
}