Skip to content

Implement folding for the TOSA power operation #15

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 28, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ void populateTosaDecomposeTransposeConv(MLIRContext *ctx,
RewritePatternSet &patterns);
void populateTosaDecomposeDepthwise(MLIRContext *ctx,
RewritePatternSet &patterns);
/// Add the pattern that folds tosa.pow operations with constant operands to
/// \p patterns.
void populateTosaFoldConstantPowPatterns(MLIRContext *ctx,
RewritePatternSet &patterns);
void populateTosaFoldConstantReciprocalPatterns(MLIRContext *ctx,
RewritePatternSet &patterns);
void populateTosaFoldConstantRSQRTPatterns(MLIRContext *ctx,
Expand Down
41 changes: 36 additions & 5 deletions mlir/include/mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,33 +13,64 @@
#define MLIR_DIALECT_TOSA_TRANSFORMS_TOSA_FOLD_COMMON_H

#include <llvm/ADT/APFloat.h>
#include <llvm/ADT/ArrayRef.h>
#include <functional>
#include <mlir/Dialect/Tosa/IR/TosaOps.h>
#include <mlir/IR/PatternMatch.h>

namespace mlir {
namespace tosa {

/// Type that represents tensor dimensions. The folding helpers declared in
/// this header use it both for tensor shapes and for multi-dimensional
/// element indices.
using DimensionType = ArrayRef<int64_t>;

/// Type for tensor offsets, i.e. the linear position of an element inside a
/// tensor as computed by indexToOffset.
using OffsetType = size_t;

/// Transform a tensor with the given transformation function.
///
/// \p toApply is called once per element of \p toTransform and receives the
/// element value together with the element type of \p toTransform. The
/// returned attribute has the same type as \p toTransform.
DenseElementsAttr applyElementWise(
const DenseElementsAttr &toTransform,
const std::function<llvm::APFloat(const llvm::APFloat &, Type)> &toApply);

/// Apply the given transformation function on the elements of the given
/// tensors. If the input tensors do not match \p targetType, broadcasting is
/// applied.
///
/// \p toApply receives the element of the first tensor as its first argument
/// and the element of the second tensor as its second argument. The returned
/// attribute has type \p targetType.
DenseElementsAttr applyElementWise(
const DenseElementsAttr &, const DenseElementsAttr &, TensorType targetType,
const std::function<APFloat(const APFloat &, const APFloat &)> &toApply);

/// Function that checks if \p toCheck is a dense TOSA constant float tensor.
/// NOTE(review): presumably a failed check is reported as a match failure on
/// \p location through the rewriter - confirm against the implementation.
LogicalResult notifyIfNotConstantFloatTosaTensor(TypedValue<TensorType> toCheck,
TosaOp location,
PatternRewriter &);

/// Function that checks if \p toCheck is a dense TOSA constant tensor.
LogicalResult notifyIfNoTosaDenseConstantTensor(TypedValue<TensorType> toCheck,
TosaOp location,
PatternRewriter &);

/// Function that checks if the type contained in \p toCheck is float.
LogicalResult notifyIfNotFloat(TypedValue<TensorType> toCheck, TosaOp location,
PatternRewriter &);

/// Compute the offset in \p shape which corresponds to the given \p index.
/// The last dimension varies fastest; \p index is read at every position of
/// \p shape, so it must have at least as many entries as \p shape.
OffsetType indexToOffset(DimensionType shape, DimensionType index);

/// Compute the index into \p shape which corresponds to the given \p offset.
/// The result has one entry per dimension of \p shape.
SmallVector<int64_t> offsetToIndex(DimensionType shape, OffsetType offset);

/// Given an \p index into \p desiredShape, compute the corresponding index into
/// \p toBeBroadcasted. Dimensions whose size differs between the two shapes
/// are mapped to position 0.
SmallVector<int64_t> getBroadcastedIndex(DimensionType desiredShape,
DimensionType toBeBroadcasted,
DimensionType index);
/// Given an \p offset into \p desiredShape, compute the corresponding offset
/// into \p toBeBroadcasted.
OffsetType getBroadcastedOffset(DimensionType desiredShape,
DimensionType toBeBroadcasted,
OffsetType offset);

/// Function to compute the reciprocal. The Type argument is the float type
/// used to build the constant 1.0 the computation starts from.
APFloat computeReciprocal(const APFloat &, Type);

} // namespace tosa
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRTosaTransforms
TosaDecomposeConv2D.cpp
TosaDecomposeDepthwise.cpp
TosaFoldCommon.cpp
TosaFoldConstantPow.cpp
TosaFoldConstantReciprocal.cpp
TosaFoldConstantRSQRT.cpp
TosaFoldConstantTranspose.cpp
Expand Down
93 changes: 90 additions & 3 deletions mlir/lib/Dialect/Tosa/Transforms/TosaFoldCommon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@

#include "mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include <algorithm>
#include <llvm/ADT/APFloat.h>
#include <llvm/ADT/SmallVector.h>
#include <mlir/IR/BuiltinAttributes.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/Matchers.h>
Expand All @@ -34,16 +36,52 @@ DenseElementsAttr mlir::tosa::applyElementWise(
// all of them to avoid dynamic resizing
transformedValues.reserve(toTransform.getNumElements());
for (auto val : toTransform.getValues<llvm::APFloat>()) {
auto recipVal = toApply(val, toTransform.getElementType());
transformedValues.push_back(recipVal);
auto transformedVal = toApply(val, toTransform.getElementType());
transformedValues.push_back(transformedVal);
}

// Replace the current tensor with one containing the computed reciprocals
// Replace the current tensor with one containing the computed values
auto newTensor =
DenseElementsAttr::get(toTransform.getType(), transformedValues);
return newTensor;
}

/// Apply \p toApply to every pair of elements taken from \p first and
/// \p second and collect the results in a new tensor of type \p targetType.
///
/// For each linear offset into the target shape, the matching offsets into
/// \p first and \p second are computed with getBroadcastedOffset, so inputs
/// whose shape differs from the target shape are broadcast.
///
/// The number of produced elements is the product of the dimension sizes of
/// \p targetType, so the target shape has to be static.
DenseElementsAttr mlir::tosa::applyElementWise(
    const DenseElementsAttr &first, const DenseElementsAttr &second,
    TensorType targetType,
    const std::function<APFloat(const APFloat &, const APFloat &)> &toApply) {
  // Dimension sizes are int64_t. Accumulate their product in 64 bits as well;
  // deducing the accumulator from the literal `1` would make it an `int` and
  // narrow the element count.
  auto targetShape = targetType.getShape();
  int64_t targetSize = 1;
  for (int64_t dimSize : targetShape) {
    targetSize *= dimSize;
  }

  // We already know the amount of values we will insert, reserve space for
  // all of them to avoid dynamic resizing.
  SmallVector<APFloat> transformedValues;
  transformedValues.reserve(targetSize);

  // Apply the given function to each pair of values from the input tensors.
  // Make sure to broadcast the offsets properly.
  auto firstValues = first.getValues<APFloat>();
  auto firstShape = first.getType().getShape();
  auto secondValues = second.getValues<APFloat>();
  auto secondShape = second.getType().getShape();
  for (int64_t offset = 0; offset < targetSize; ++offset) {
    OffsetType offsetInFirst =
        getBroadcastedOffset(targetShape, firstShape, offset);
    OffsetType offsetInSecond =
        getBroadcastedOffset(targetShape, secondShape, offset);
    transformedValues.push_back(
        toApply(firstValues[offsetInFirst], secondValues[offsetInSecond]));
  }

  // Generate a tensor with the computed values.
  return DenseElementsAttr::get(targetType, transformedValues);
}

LogicalResult
mlir::tosa::notifyIfNotConstantFloatTosaTensor(TypedValue<TensorType> toCheck,
TosaOp location,
Expand Down Expand Up @@ -91,6 +129,55 @@ LogicalResult mlir::tosa::notifyIfNotFloat(TypedValue<TensorType> toCheck,
"TOSA spec only allows floats");
}

/// Flatten the multi-dimensional \p index into a linear offset for a tensor
/// of the given \p shape. The dimensions are folded from the first to the
/// last one, so the last dimension varies fastest.
OffsetType mlir::tosa::indexToOffset(DimensionType shape, DimensionType index) {
  OffsetType linear = 0;
  size_t dim = 0;
  for (const auto &extent : shape) {
    // Shift what has been accumulated so far by the size of this dimension,
    // then add the position inside it.
    linear = linear * extent + index[dim];
    ++dim;
  }
  return linear;
}

/// Convert the linear \p offset into a multi-dimensional index for a tensor of
/// the given \p shape. This is the inverse of indexToOffset.
///
/// The returned index has one entry per dimension of \p shape. Every
/// dimension size has to be non-zero, as it is used as a divisor.
SmallVector<int64_t> mlir::tosa::offsetToIndex(DimensionType shape,
                                               OffsetType offset) {
  const size_t rank = shape.size();
  // The rank of the index is equal to the rank of the shape, so the result can
  // be sized up front and filled in place.
  SmallVector<int64_t> resultIndex(rank);
  // Peel off the index values from the last dimension to the first one. The
  // counter stays a size_t so the rank is never narrowed to a smaller signed
  // type.
  for (size_t i = rank; i-- > 0;) {
    resultIndex[i] = offset % shape[i];
    offset /= shape[i];
  }
  return resultIndex;
}

/// Translate \p index, which addresses an element of a tensor with shape
/// \p desiredShape, into the index of the element it maps to in a tensor with
/// shape \p toBeBroadcasted.
///
/// Dimensions in which both shapes agree keep their index value; all other
/// dimensions are mapped to position 0. All three arrays are read at every
/// position of \p desiredShape, so they are expected to have the same rank.
SmallVector<int64_t>
mlir::tosa::getBroadcastedIndex(DimensionType desiredShape,
                                DimensionType toBeBroadcasted,
                                DimensionType index) {
  SmallVector<int64_t> broadCasted;
  broadCasted.reserve(desiredShape.size());
  for (size_t i = 0; i < desiredShape.size(); i++) {
    // Keep the full 64-bit width of the index value. Deducing the type from
    // the literal `0` would make this an `int` and truncate large indices.
    const int64_t toInsert =
        toBeBroadcasted[i] == desiredShape[i] ? index[i] : 0;
    broadCasted.push_back(toInsert);
  }
  return broadCasted;
}

/// Translate a linear \p offset into a tensor of shape \p desiredShape into
/// the linear offset of the matching element of a tensor with shape
/// \p toBeBroadcasted.
OffsetType mlir::tosa::getBroadcastedOffset(DimensionType desiredShape,
                                            DimensionType toBeBroadcasted,
                                            OffsetType offset) {
  // Step 1: expand the offset into a full index of the desired shape.
  const SmallVector<int64_t> targetIndex = offsetToIndex(desiredShape, offset);
  // Step 2: map that index onto the shape of the tensor to broadcast.
  const SmallVector<int64_t> sourceIndex =
      getBroadcastedIndex(desiredShape, toBeBroadcasted, targetIndex);
  // Step 3: flatten the mapped index again.
  return indexToOffset(toBeBroadcasted, sourceIndex);
}

APFloat mlir::tosa::computeReciprocal(const APFloat &floatVal, Type floatTy) {
auto recipAttr = FloatAttr::get(floatTy, 1.0);
APFloat recip = recipAttr.getValue();
Expand Down
110 changes: 110 additions & 0 deletions mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantPow.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
//===- TosaFoldConstantPow.cpp --------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Fold TOSA Pow operation on constant data
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include <cmath>
#include <llvm/ADT/APFloat.h>
#include <llvm/ADT/FloatingPointMode.h>
#include <llvm/ADT/SmallVector.h>
#include <mlir/IR/BuiltinAttributes.h>
#include <mlir/Support/LogicalResult.h>

using namespace mlir;
using namespace mlir::tosa;

namespace {

/// Rewrite pattern that replaces a tosa.pow whose two operands are constant
/// float tensors by a tosa.const holding the element-wise result.
struct TosaFoldConstantPow : public OpRewritePattern<PowOp> {

  using OpRewritePattern::OpRewritePattern;

  /// Compute \p base raised to the power of \p exp. The result uses the
  /// float semantics of \p base.
  static APFloat computePower(const APFloat &base, const APFloat &exp) {
    // All of the following cases yield NaN:
    //  * one of the inputs is NaN (NaN is propagated),
    //  * 0.0**0.0, which TOSA defines as NaN,
    //  * a negative, non-zero base combined with a non-integer exponent.
    if (base.isNaN() || exp.isNaN() || (base.isZero() && exp.isZero()) ||
        (base.isNegative() && !base.isZero() && !exp.isInteger())) {
      return APFloat::getNaN(base.getSemantics());
    }

    // Actually compute base**exp. Special cases for [-]infinity and [-]0 are
    // already handled in accordance with the TOSA spec.
    const float hostBase = base.convertToFloat();
    const float hostExp = exp.convertToFloat();
    APFloat result(std::pow(hostBase, hostExp));

    // Bring the host float result back to the semantics of the inputs.
    bool lostPrecision = false;
    result.convert(base.getSemantics(), APFloat::rmNearestTiesToEven,
                   &lostPrecision);
    return result;
  }

  /// Fold \p powOp if both of its inputs are constant float tensors and, unless
  /// both are splat, at least one of them has a single use.
  LogicalResult matchAndRewrite(PowOp powOp,
                                PatternRewriter &rewriter) const override {
    auto baseOp = powOp.getInput1();
    auto expOp = powOp.getInput2();

    // Both operands have to be constant float tensors; the base is checked
    // before the exponent.
    if (failed(notifyIfNotConstantFloatTosaTensor(baseOp, powOp, rewriter))) {
      return failure();
    }
    if (failed(notifyIfNotConstantFloatTosaTensor(expOp, powOp, rewriter))) {
      return failure();
    }

    // Extract the tensor values
    DenseElementsAttr baseValues;
    DenseElementsAttr expValues;
    matchPattern(baseOp, m_Constant(&baseValues));
    matchPattern(expOp, m_Constant(&expValues));

    // If both tensors are splat, the number of users does not matter.
    // Otherwise at least one of the constant input tensors must be replaceable
    // (i.e. only has a single user).
    const bool bothSplat = isa<SplatElementsAttr>(baseValues) &&
                           isa<SplatElementsAttr>(expValues);
    const bool anyReplaceable = baseOp.hasOneUse() || expOp.hasOneUse();
    if (!bothSplat && !anyReplaceable) {
      return rewriter.notifyMatchFailure(
          powOp, "Currently, pows will only be folded if at least one input "
                 "tensor only has a single user");
    }

    auto foldedTensor =
        applyElementWise(baseValues, expValues, powOp.getType(), &computePower);
    rewriter.replaceOpWithNewOp<ConstOp>(powOp, foldedTensor.getType(),
                                         foldedTensor);

    return success();
  }
};

} // namespace

/// Add the TosaFoldConstantPow pattern, constructed with \p ctx, to
/// \p patterns.
void mlir::tosa::populateTosaFoldConstantPowPatterns(
MLIRContext *ctx, RewritePatternSet &patterns) {
patterns.add<TosaFoldConstantPow>(ctx);
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ struct TosaLayerwiseConstantFoldPass
RewritePatternSet patterns(ctx);
auto func = getOperation();

mlir::tosa::populateTosaFoldConstantPowPatterns(ctx, patterns);
mlir::tosa::populateTosaFoldConstantReciprocalPatterns(ctx, patterns);
mlir::tosa::populateTosaFoldConstantRSQRTPatterns(ctx, patterns);
mlir::tosa::populateTosaFoldConstantTransposePatterns(ctx, patterns);
Expand Down
Loading