Skip to content

[FXML-1871] Implement folding for constant TOSA casts #19

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Apr 6, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ void populateTosaDecomposeTransposeConv(MLIRContext *ctx,
RewritePatternSet &patterns);
void populateTosaDecomposeDepthwise(MLIRContext *ctx,
RewritePatternSet &patterns);
/// Adds the pattern that folds `tosa.cast` operations applied to constant
/// tensors to \p patterns.
void populateTosaFoldConstantCastPatterns(MLIRContext *ctx,
RewritePatternSet &patterns);
void populateTosaFoldConstantPowPatterns(MLIRContext *ctx,
RewritePatternSet &patterns);
void populateTosaFoldConstantReciprocalPatterns(MLIRContext *ctx,
Expand Down
9 changes: 7 additions & 2 deletions mlir/include/mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,15 @@ using DimensionType = ArrayRef<int64_t>;
/// Type for tensor offsets.
using OffsetType = size_t;

/// Rounding mode shared by the TOSA constant folding helpers: round to the
/// nearest representable value, ties to even.
static constexpr llvm::RoundingMode tosaRoundingMode =
APFloat::rmNearestTiesToEven;

/// Transform a tensor with the given transformation function.
template <class SrcValType, class TargetValType, class TargetType>
DenseElementsAttr applyElementWise(
const DenseElementsAttr &toTransform,
const std::function<llvm::APFloat(const llvm::APFloat &, Type)> &toApply);
const std::function<TargetValType(const SrcValType &, TargetType)> &toApply,
TargetType targetType);

/// Apply the given transformation function on the elements of the given
/// tensors. If the input tensors do not match \p targetType, broadcasting is
Expand Down Expand Up @@ -74,7 +79,7 @@ OffsetType getBroadcastedOffset(DimensionType desiredShape,
OffsetType offset);

/// Function to compute the reciprocal.
APFloat computeReciprocal(const APFloat &floatVal, Type floatTy);
APFloat computeReciprocal(const APFloat &floatVal, FloatType floatTy);

} // namespace tosa
} // namespace mlir
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRTosaTransforms
TosaDecomposeConv2D.cpp
TosaDecomposeDepthwise.cpp
TosaFoldCommon.cpp
TosaFoldConstantCast.cpp
TosaFoldConstantPow.cpp
TosaFoldConstantReciprocal.cpp
TosaFoldConstantRSQRT.cpp
Expand Down
50 changes: 37 additions & 13 deletions mlir/lib/Dialect/Tosa/Transforms/TosaFoldCommon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,29 +23,52 @@
using namespace mlir;
using namespace mlir::tosa;

namespace {
static constexpr llvm::RoundingMode reciprocalRoundingMode =
APFloat::rmNearestTiesToEven;
} // namespace

template <class SrcValType, class TargetValType, class TargetType>
DenseElementsAttr mlir::tosa::applyElementWise(
const DenseElementsAttr &toTransform,
const std::function<llvm::APFloat(const llvm::APFloat &, Type)> &toApply) {
llvm::SmallVector<llvm::APFloat, 1> transformedValues;
const std::function<TargetValType(const SrcValType &, TargetType)> &toApply,
TargetType targetType) {
SmallVector<TargetValType> transformedValues;
// We already know the amount of values we will insert, reserve space for
// all of them to avoid dynamic resizing
transformedValues.reserve(toTransform.getNumElements());
for (auto val : toTransform.getValues<llvm::APFloat>()) {
auto transformedVal = toApply(val, toTransform.getElementType());
for (auto val : toTransform.getValues<SrcValType>()) {
auto transformedVal = toApply(val, targetType);
transformedValues.push_back(transformedVal);
}

auto inShape = toTransform.getType();
auto outTy = inShape.cloneWith(None, targetType);

// Replace the current tensor with one containing the computed values
auto newTensor =
DenseElementsAttr::get(toTransform.getType(), transformedValues);
auto newTensor = DenseElementsAttr::get(outTy, transformedValues);
return newTensor;
}

// Explicit instantiations of the unary element-wise helper. The template is
// defined in this file only, so each source/target value type combination
// used by the folding patterns has to be instantiated here.

// float -> float
template DenseElementsAttr
mlir::tosa::applyElementWise<APFloat, APFloat, FloatType>(
const DenseElementsAttr &toTransform,
const std::function<APFloat(const APFloat &, FloatType)> &toApply,
FloatType targetType);

// int -> float
template DenseElementsAttr
mlir::tosa::applyElementWise<APInt, APFloat, FloatType>(
const DenseElementsAttr &toTransform,
const std::function<APFloat(const APInt &, FloatType)> &toApply,
FloatType targetType);

// float -> int
template DenseElementsAttr
mlir::tosa::applyElementWise<APFloat, APInt, IntegerType>(
const DenseElementsAttr &toTransform,
const std::function<APInt(const APFloat &, IntegerType)> &toApply,
IntegerType targetType);

// int -> int
template DenseElementsAttr
mlir::tosa::applyElementWise<APInt, APInt, IntegerType>(
const DenseElementsAttr &toTransform,
const std::function<APInt(const APInt &, IntegerType)> &toApply,
IntegerType targetType);

DenseElementsAttr mlir::tosa::applyElementWise(
const DenseElementsAttr &first, const DenseElementsAttr &second,
TensorType targetType,
Expand Down Expand Up @@ -182,10 +205,11 @@ OffsetType mlir::tosa::getBroadcastedOffset(DimensionType desiredShape,
return indexToOffset(toBeBroadcastedShape, indexBroadcasted);
}

APFloat mlir::tosa::computeReciprocal(const APFloat &floatVal, Type floatTy) {
APFloat mlir::tosa::computeReciprocal(const APFloat &floatVal,
FloatType floatTy) {
auto recipAttr = FloatAttr::get(floatTy, 1.0);
APFloat recip = recipAttr.getValue();
recip.divide(floatVal, reciprocalRoundingMode);
recip.divide(floatVal, tosaRoundingMode);

return recip;
}
194 changes: 194 additions & 0 deletions mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantCast.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
//===- TosaFoldConstantCast.cpp -------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Fold TOSA cast operation on constant data
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include <llvm/ADT/APFloat.h>
#include <llvm/ADT/APInt.h>
#include <llvm/ADT/APSInt.h>
#include <llvm/Support/Debug.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/Support/LogicalResult.h>

using namespace mlir;
using namespace mlir::tosa;

namespace {

/// Rewrite pattern that folds a `tosa.cast` whose input is a dense constant
/// tensor into a new `tosa.const` holding the converted values.
///
/// Supported element type combinations are signless/signed integer and float
/// in any direction; casts involving unsigned integers are rejected.
struct TosaFoldConstantCast : public OpRewritePattern<CastOp> {

  using OpRewritePattern::OpRewritePattern;

  /// Convert the integer \p toConvert to a float of type \p targetType.
  static APFloat convertIntToFloat(const APInt &toConvert,
                                   FloatType targetType) {
    APFloat res(targetType.getFloatSemantics());
    // A 1-bit source is a boolean: `true` has to become 1.0. Interpreting the
    // single set bit as a signed value would yield -1.0 instead.
    const bool isSigned = toConvert.getBitWidth() != 1;
    res.convertFromAPInt(toConvert, isSigned, tosaRoundingMode);
    return res;
  }

  /// Convert the float \p toConvert to the float semantics of \p targetType.
  static APFloat convertFloatToFloat(const APFloat &toConvert,
                                     FloatType targetType) {
    APFloat res(toConvert);
    // Losing precision is expected for narrowing casts; the flag is ignored.
    bool didLosePrecision = false;
    res.convert(targetType.getFloatSemantics(), tosaRoundingMode,
                &didLosePrecision);
    return res;
  }

  /// Convert the float \p toConvert to an integer of type \p targetType,
  /// clipping the result to the value range of the target type.
  static APInt convertFloatToInt(const APFloat &toConvert,
                                 IntegerType targetType) {
    const unsigned targetWidth = targetType.getIntOrFloatBitWidth();
    // Converting NaN to an integer results in an unpredictable value. Pick 0.
    if (toConvert.isNaN()) {
      return APInt::getZero(targetWidth);
    }

    // Make sure to properly translate booleans
    if (targetWidth == 1) {
      return toConvert.isZero() ? APInt::getZero(1) : APInt::getAllOnes(1);
    }

    // Use the built-in functionality of APFloats to convert to integers.
    // The intermediate integer is at least as wide as the source float, so
    // the result of this conversion might still be outside of the target
    // integer range.
    const unsigned floatSize =
        APFloat::getSizeInBits(toConvert.getSemantics());
    APSInt converted(std::max(floatSize, targetWidth),
                     targetType.isUnsigned());
    bool ignored = false;
    // Use the rounding mode shared by all TOSA folders.
    toConvert.convertToInteger(converted, tosaRoundingMode, &ignored);

    // Clip to allowed range.
    if (targetWidth < floatSize) {
      if (targetType.isUnsigned()) {
        return converted.truncUSat(targetWidth);
      }
      return converted.truncSSat(targetWidth);
    }
    return converted;
  }

  /// Convert the integer \p toConvert to an integer of type \p targetType by
  /// extending or truncating it.
  static APInt convertIntToInt(const APInt &toConvert, IntegerType targetType) {
    const unsigned targetWidth = targetType.getWidth();
    // Make sure to properly translate booleans
    if (targetWidth == 1) {
      return toConvert.isZero() ? APInt::getZero(1) : APInt::getAllOnes(1);
    }
    // A 1-bit source is a boolean: `true` has to become 1, so it must be
    // zero-extended. Sign-extending the single set bit would yield -1.
    if (toConvert.getBitWidth() == 1 || targetType.isUnsigned()) {
      return toConvert.zextOrTrunc(targetWidth);
    }
    return toConvert.sextOrTrunc(targetWidth);
  }

  /// Emit a warning on \p location if \p elements holds float values
  /// containing at least one NaN and the cast produces integers.
  static void warnAboutNaNToIntCast(DenseElementsAttr elements, CastOp location,
                                    PatternRewriter &rewriter) {
    // This is only relevant if the input values are float
    if (!isa<FloatType>(elements.getElementType())) {
      return;
    }
    // Check if it is a float to integer conversion
    auto resultType = location.getOutput().getType();
    if (!isa<IntegerType>(cast<TensorType>(resultType).getElementType())) {
      return;
    }

    // Report encountered NaNs
    auto checkNan = [](const APFloat &val) { return val.isNaN(); };
    if (llvm::any_of(elements.getValues<APFloat>(), checkNan)) {
      location->emitWarning(
          "Float tensor is casted to integer and it contains NaN values. The "
          "cast results in an unspecified value.");
    }
  }

  /// Fold \p tosaCast if its input is a dense constant which is either a
  /// splat or only used by this cast. Returns failure (with a match-failure
  /// note) otherwise, or if the element types are not supported.
  LogicalResult matchAndRewrite(CastOp tosaCast,
                                PatternRewriter &rewriter) const override {
    auto inputTensor = tosaCast.getInput();

    // If the input tensor is not constant, we cannot fold it.
    if (failed(notifyIfNoTosaDenseConstantTensor(inputTensor, tosaCast,
                                                 rewriter))) {
      return failure();
    }

    auto fromType = inputTensor.getType().getElementType();
    auto toType = tosaCast.getOutput().getType().getElementType();

    // Do not continue with an unset attribute if the constant cannot be
    // extracted.
    DenseElementsAttr elements;
    if (!matchPattern(inputTensor, m_Constant(&elements))) {
      return rewriter.notifyMatchFailure(
          tosaCast, "Could not extract the constant input values.");
    }

    // Only fold splat tensors and those used only once to avoid duplicating
    // them.
    if (!inputTensor.hasOneUse() && !isa<SplatElementsAttr>(elements)) {
      return rewriter.notifyMatchFailure(tosaCast,
                                         "Currently, casts will only be folded "
                                         "if its input only has a single user");
    }

    // Report a match failure for unexpected types
    if (!toType.isIntOrFloat() || !fromType.isIntOrFloat()) {
      return rewriter.notifyMatchFailure(
          tosaCast, "Only casts from/to int/float are supported.");
    }

    auto isUnsigned = [](Type toCheck) {
      auto intTy = dyn_cast<IntegerType>(toCheck);
      return intTy && intTy.isUnsigned();
    };
    if (isUnsigned(toType) || isUnsigned(fromType)) {
      // TOSA casts currently don't support unsigned integers.
      // To support them here, one could use APSInt instead of APInt;
      // however, this causes trouble with `getValues`, which does not support
      // APSInt currently.
      return rewriter.notifyMatchFailure(
          tosaCast, "Cast folding from/to unsigned integers is not supported.");
    }

    // Issue a warning if we convert float -> int and NaNs are present; the
    // result value is unspecified in that case. This is done only once all
    // match-failure checks have passed, so the warning is tied to an actual
    // fold and is not emitted for casts that are left untouched.
    warnAboutNaNToIntCast(elements, tosaCast, rewriter);

    DenseElementsAttr res;
    if (auto intOutTy = dyn_cast<IntegerType>(toType)) {
      if (isa<FloatType>(fromType)) {
        res = applyElementWise<APFloat, APInt, IntegerType>(
            elements, &convertFloatToInt, intOutTy);
      } else {
        assert(isa<IntegerType>(fromType));
        res = applyElementWise<APInt, APInt, IntegerType>(
            elements, &convertIntToInt, intOutTy);
      }
    } else {
      assert(isa<FloatType>(toType));
      auto floatOutTy = cast<FloatType>(toType);
      if (isa<FloatType>(fromType)) {
        res = applyElementWise<APFloat, APFloat, FloatType>(
            elements, &convertFloatToFloat, floatOutTy);
      } else {
        assert(isa<IntegerType>(fromType));
        res = applyElementWise<APInt, APFloat, FloatType>(
            elements, &convertIntToFloat, floatOutTy);
      }
    }

    rewriter.replaceOpWithNewOp<ConstOp>(tosaCast, res.getType(), res);
    return success();
  }
};

} // namespace

/// Register the constant cast folding pattern (TosaFoldConstantCast) in
/// \p patterns, constructing it with the context \p ctx.
void mlir::tosa::populateTosaFoldConstantCastPatterns(
MLIRContext *ctx, RewritePatternSet &patterns) {
patterns.add<TosaFoldConstantCast>(ctx);
}
7 changes: 5 additions & 2 deletions mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantRSQRT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ struct TosaFoldConstantRSQRT : public OpRewritePattern<RsqrtOp> {

using OpRewritePattern::OpRewritePattern;

static APFloat computeRSQRT(const APFloat &apFloatVal, Type floatTy) {
static APFloat computeRSQRT(const APFloat &apFloatVal, FloatType floatTy) {
// The result for negative values (apart from zero) is always NaN
if (apFloatVal.isNegative() && !apFloatVal.isNegZero()) {
return APFloat::getNaN(apFloatVal.getSemantics());
Expand Down Expand Up @@ -72,7 +72,9 @@ struct TosaFoldConstantRSQRT : public OpRewritePattern<RsqrtOp> {
}

// Create a new tensor with the updated values
auto newTensor = applyElementWise(inputValues, &computeRSQRT);
auto newTensor = applyElementWise<APFloat, APFloat, FloatType>(
inputValues, &computeRSQRT,
cast<FloatType>(inputValues.getElementType()));

// Replace the use of the reciprocal square root with the transformed tensor
rewriter.replaceOpWithNewOp<ConstOp>(rsqrt, newTensor.getType(), newTensor);
Expand All @@ -84,6 +86,7 @@ struct TosaFoldConstantRSQRT : public OpRewritePattern<RsqrtOp> {
} // namespace

void mlir::tosa::populateTosaFoldConstantRSQRTPatterns(

MLIRContext *ctx, RewritePatternSet &patterns) {
patterns.add<TosaFoldConstantRSQRT>(ctx);
}
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@ struct TosaFoldConstantReciprocal : public OpRewritePattern<ReciprocalOp> {
}

// Create a new tensor with the updated values
auto newTensor = applyElementWise(inputValues, &computeReciprocal);
auto newTensor = applyElementWise<APFloat, APFloat, FloatType>(
inputValues, &computeReciprocal,
cast<FloatType>(inputValues.getElementType()));

// Replace the use of the reciprocal with the transformed tensor
rewriter.replaceOpWithNewOp<ConstOp>(recip, newTensor.getType(), newTensor);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ struct TosaLayerwiseConstantFoldPass
RewritePatternSet patterns(ctx);
auto func = getOperation();

mlir::tosa::populateTosaFoldConstantCastPatterns(ctx, patterns);
mlir::tosa::populateTosaFoldConstantPowPatterns(ctx, patterns);
mlir::tosa::populateTosaFoldConstantReciprocalPatterns(ctx, patterns);
mlir::tosa::populateTosaFoldConstantRSQRTPatterns(ctx, patterns);
Expand Down
Loading