-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[mlir][Tosa]: Add folder to ReciprocalOp of splat constant inputs #78137
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-tosa @llvm/pr-subscribers-mlir Author: Aviad Cohen (AviadCo) ChangesFull diff: https://github.com/llvm/llvm-project/pull/78137.diff 6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 3257ecd9d91f11..d8fc960563bf29 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1114,6 +1114,13 @@ def Tosa_ReciprocalOp : Tosa_ElementwiseOp<"reciprocal",
let results = (outs
Tosa_Tensor:$output
);
+
+ let extraClassDeclaration = [{
+ /// Computes reciprocal on a float element (input must be from float type).
+ static llvm::APFloat computeFloatElemOne(const llvm::APFloat &floatVal, FloatType floatTy);
+ }];
+
+ let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 26c39ff3523434..fb3cd378f2c84b 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
@@ -25,6 +26,7 @@
#include "mlir/Transforms/InliningUtils.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/APFloat.h"
+#include "llvm/ADT/APInt.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"
@@ -1036,3 +1038,20 @@ OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
getOperation()->setOperands(concatOperands);
return getResult();
}
+
+OpFoldResult tosa::ReciprocalOp::fold(FoldAdaptor adaptor) {
+ auto input = adaptor.getInput1();
+
+ auto inputAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(input);
+ // Fold splat inputs only.
+ if (!inputAttr || !inputAttr.isSplat())
+ return {};
+
+ auto shapeType = llvm::cast<ShapedType>(getType());
+ if (auto floatType = llvm::dyn_cast<FloatType>(inputAttr.getElementType())) {
+ auto floatVal = inputAttr.getSplatValue<APFloat>();
+ return DenseElementsAttr::get(shapeType, computeFloatElemOne(floatVal, floatType));
+ }
+
+ return {};
+}
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 661126f4df9976..a2af9ef0c069f2 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -25,6 +25,7 @@
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Transforms/InliningUtils.h"
+#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"
@@ -1778,6 +1779,14 @@ std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
return std::nullopt;
}
+APFloat tosa::ReciprocalOp::computeFloatElemOne(const APFloat &floatVal, FloatType floatTy) {
+ auto recipAttr = FloatAttr::get(floatTy, 1.0);
+ APFloat recip = recipAttr.getValue();
+ recip.divide(floatVal, llvm::APFloat::rmNearestTiesToEven);
+
+ return recip;
+}
+
// parse and print of IfOp refer to the implementation of SCF dialect.
ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
// Create the regions for 'then'.
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
index d35e911ebe63c4..6208b38900ebad 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
@@ -30,10 +31,6 @@ using namespace mlir::tosa;
namespace {
-/// Rounding mode to be used on floating point operations that require rounding.
-static constexpr llvm::RoundingMode tosaRoundingMode =
- llvm::APFloat::rmNearestTiesToEven;
-
/// Apply the given transformation \p toApply to every element of the tensor to
/// be transformed \p toTransform.
///
@@ -249,14 +246,6 @@ struct TosaFoldConstantReciprocal : public OpRewritePattern<ReciprocalOp> {
using OpRewritePattern::OpRewritePattern;
- static APFloat computeReciprocal(const APFloat &floatVal, FloatType floatTy) {
- auto recipAttr = FloatAttr::get(floatTy, 1.0);
- APFloat recip = recipAttr.getValue();
- recip.divide(floatVal, tosaRoundingMode);
-
- return recip;
- }
-
LogicalResult matchAndRewrite(ReciprocalOp recip,
PatternRewriter &rewriter) const override {
auto inputTensor = recip.getInput1();
@@ -281,7 +270,7 @@ struct TosaFoldConstantReciprocal : public OpRewritePattern<ReciprocalOp> {
// Create a new tensor with the updated values
auto newTensor = applyElementWise<APFloat, APFloat, FloatType>(
- inputValues, &computeReciprocal,
+ inputValues, &ReciprocalOp::computeFloatElemOne,
cast<FloatType>(inputValues.getElementType()));
// Replace the use of the reciprocal with the transformed tensor
diff --git a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
index ee428b201d0073..9fc864463d95bf 100644
--- a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
+++ b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/IR/BuiltinAttributes.h"
using namespace mlir;
using namespace mlir::tosa;
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index fd51d287bca058..de9d13b1453232 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -613,3 +613,16 @@ func.func nested @fold_tile_rank_zero() -> tensor<i32> {
%1 = tosa.tile %0 {multiples = array<i64>} : (tensor<i32>) -> tensor<i32>
return %1 : tensor<i32>
}
+
+// -----
+
+// CHECK-LABEL: @fold_reciprocal
+func.func nested @fold_reciprocal() -> tensor<3x600x1200xf32> {
+ // CHECK: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<8.620690e-03> : tensor<3x600x1200xf32>}> : () -> tensor<3x600x1200xf32>
+ // CHECK: return %[[VAL_0]] : tensor<3x600x1200xf32>
+ // CHECK: }
+ %0 = "tosa.const"(){ value = dense<116.0>: tensor<f32> }: () -> tensor<f32>
+ %1 = "tosa.cast"(%0) : (tensor<f32>) -> tensor<3x600x1200xf32>
+ %2 = "tosa.reciprocal"(%1): (tensor<3x600x1200xf32>) -> tensor<3x600x1200xf32>
+ return %2 : tensor<3x600x1200xf32>
+}
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
488ef72
to
a2ed9b7
Compare
a2ed9b7
to
6a20e5a
Compare
6a20e5a
to
215f7f8
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the review, answered your comments.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
looks good to me.
No description provided.