Skip to content

Commit d89a0a6

Browse files
authored
[mlir][Tosa]: Add folder to ReciprocalOp of splat constant inputs (#78137)
1 parent e8af89e commit d89a0a6

File tree

5 files changed

+60
-16
lines changed

5 files changed

+60
-16
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1114,6 +1114,17 @@ def Tosa_ReciprocalOp : Tosa_ElementwiseOp<"reciprocal",
11141114
let results = (outs
11151115
Tosa_Tensor:$output
11161116
);
1117+
1118+
let extraClassDeclaration = [{
1119+
/// Return the reciprocal result on the operand.
1120+
static inline APFloat calcOneElement(const APFloat &operand) {
1121+
APFloat recip = APFloat(operand.getSemantics(), 1);
1122+
recip.divide(operand, APFloat::rmNearestTiesToEven);
1123+
return recip;
1124+
}
1125+
}];
1126+
1127+
let hasFolder = 1;
11171128
}
11181129

11191130
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
1818
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
1919
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
20+
#include "mlir/IR/BuiltinTypeInterfaces.h"
2021
#include "mlir/IR/BuiltinTypes.h"
2122
#include "mlir/IR/DialectImplementation.h"
2223
#include "mlir/IR/Matchers.h"
@@ -25,6 +26,7 @@
2526
#include "mlir/Transforms/InliningUtils.h"
2627
#include "mlir/Transforms/RegionUtils.h"
2728
#include "llvm/ADT/APFloat.h"
29+
#include "llvm/ADT/APInt.h"
2830
#include "llvm/ADT/DenseMap.h"
2931
#include "llvm/ADT/TypeSwitch.h"
3032

@@ -1036,3 +1038,21 @@ OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
10361038
getOperation()->setOperands(concatOperands);
10371039
return getResult();
10381040
}
1041+
1042+
OpFoldResult tosa::ReciprocalOp::fold(FoldAdaptor adaptor) {
1043+
auto input = adaptor.getInput1();
1044+
1045+
auto inputAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(input);
1046+
// Fold splat inputs only.
1047+
if (!inputAttr || !inputAttr.isSplat())
1048+
return {};
1049+
1050+
auto shapeType = llvm::cast<ShapedType>(getType());
1051+
if (auto floatType = llvm::dyn_cast<FloatType>(inputAttr.getElementType())) {
1052+
auto floatVal = inputAttr.getSplatValue<APFloat>();
1053+
return DenseElementsAttr::get(shapeType,
1054+
ReciprocalOp::calcOneElement(floatVal));
1055+
}
1056+
1057+
return {};
1058+
}

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include "mlir/IR/TypeUtilities.h"
2626
#include "mlir/Interfaces/InferTypeOpInterface.h"
2727
#include "mlir/Transforms/InliningUtils.h"
28+
#include "llvm/ADT/APFloat.h"
2829
#include "llvm/ADT/DenseMap.h"
2930
#include "llvm/ADT/TypeSwitch.h"
3031

mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,6 @@ using namespace mlir::tosa;
3030

3131
namespace {
3232

33-
/// Rounding mode to be used on floating point operations that require rounding.
34-
static constexpr llvm::RoundingMode tosaRoundingMode =
35-
llvm::APFloat::rmNearestTiesToEven;
36-
3733
/// Apply the given transformation \p toApply to every element of the tensor to
3834
/// be transformed \p toTransform.
3935
///
@@ -44,14 +40,14 @@ static constexpr llvm::RoundingMode tosaRoundingMode =
4440
template <class SrcValType, class TargetValType, class TargetType>
4541
DenseElementsAttr applyElementWise(
4642
const DenseElementsAttr &toTransform,
47-
const std::function<TargetValType(const SrcValType &, TargetType)> &toApply,
43+
const std::function<TargetValType(const SrcValType &)> &toApply,
4844
TargetType targetType) {
4945
SmallVector<TargetValType> transformedValues;
5046
// We already know the amount of values we will insert, reserve space for
5147
// all of them to avoid dynamic resizing
5248
transformedValues.reserve(toTransform.getNumElements());
5349
for (auto val : toTransform.getValues<SrcValType>()) {
54-
auto transformedVal = toApply(val, targetType);
50+
auto transformedVal = toApply(val);
5551
transformedValues.push_back(transformedVal);
5652
}
5753

@@ -64,7 +60,7 @@ DenseElementsAttr applyElementWise(
6460

6561
template DenseElementsAttr applyElementWise<APFloat, APFloat, FloatType>(
6662
const DenseElementsAttr &toTransform,
67-
const std::function<APFloat(const APFloat &, FloatType)> &toApply,
63+
const std::function<APFloat(const APFloat &)> &toApply,
6864
FloatType targetType);
6965

7066
/// Function that checks if the type contained in \p toCheck is float.
@@ -249,14 +245,6 @@ struct TosaFoldConstantReciprocal : public OpRewritePattern<ReciprocalOp> {
249245

250246
using OpRewritePattern::OpRewritePattern;
251247

252-
static APFloat computeReciprocal(const APFloat &floatVal, FloatType floatTy) {
253-
auto recipAttr = FloatAttr::get(floatTy, 1.0);
254-
APFloat recip = recipAttr.getValue();
255-
recip.divide(floatVal, tosaRoundingMode);
256-
257-
return recip;
258-
}
259-
260248
LogicalResult matchAndRewrite(ReciprocalOp recip,
261249
PatternRewriter &rewriter) const override {
262250
auto inputTensor = recip.getInput1();
@@ -281,7 +269,7 @@ struct TosaFoldConstantReciprocal : public OpRewritePattern<ReciprocalOp> {
281269

282270
// Create a new tensor with the updated values
283271
auto newTensor = applyElementWise<APFloat, APFloat, FloatType>(
284-
inputValues, &computeReciprocal,
272+
inputValues, &ReciprocalOp::calcOneElement,
285273
cast<FloatType>(inputValues.getElementType()));
286274

287275
// Replace the use of the reciprocal with the transformed tensor

mlir/test/Dialect/Tosa/canonicalize.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -613,3 +613,27 @@ func.func nested @fold_tile_rank_zero() -> tensor<i32> {
613613
%1 = tosa.tile %0 {multiples = array<i64>} : (tensor<i32>) -> tensor<i32>
614614
return %1 : tensor<i32>
615615
}
616+
617+
// -----
618+
619+
// CHECK-LABEL: @fold_reciprocal
620+
func.func nested @fold_reciprocal() -> tensor<3x600x1200xf32> {
621+
// CHECK: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<8.620690e-03> : tensor<3x600x1200xf32>}> : () -> tensor<3x600x1200xf32>
622+
// CHECK: return %[[VAL_0]] : tensor<3x600x1200xf32>
623+
// CHECK: }
624+
%0 = "tosa.const"(){ value = dense<116.0>: tensor<f32> }: () -> tensor<f32>
625+
%1 = "tosa.cast"(%0) : (tensor<f32>) -> tensor<3x600x1200xf32>
626+
%2 = "tosa.reciprocal"(%1): (tensor<3x600x1200xf32>) -> tensor<3x600x1200xf32>
627+
return %2 : tensor<3x600x1200xf32>
628+
}
629+
630+
// -----
631+
632+
// CHECK-LABEL: @do_not_fold_reciprocal_int
633+
func.func nested @do_not_fold_reciprocal_int() -> tensor<3x600x1200xi32> {
634+
// CHECK: tosa.reciprocal
635+
%0 = "tosa.const"(){ value = dense<11>: tensor<i32> }: () -> tensor<i32>
636+
%1 = "tosa.cast"(%0) : (tensor<i32>) -> tensor<3x600x1200xi32>
637+
%2 = "tosa.reciprocal"(%1): (tensor<3x600x1200xi32>) -> tensor<3x600x1200xi32>
638+
return %2 : tensor<3x600x1200xi32>
639+
}

0 commit comments

Comments
 (0)