Skip to content

[FXML-1871] Implement folding for constant TOSA casts #19

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Apr 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ void populateTosaDecomposeTransposeConv(MLIRContext *ctx,
RewritePatternSet &patterns);
void populateTosaDecomposeDepthwise(MLIRContext *ctx,
RewritePatternSet &patterns);
void populateTosaFoldConstantCastPatterns(MLIRContext *ctx,
RewritePatternSet &patterns,
bool enableIntCastFolding);
void populateTosaFoldConstantPowPatterns(MLIRContext *ctx,
RewritePatternSet &patterns);
void populateTosaFoldConstantReciprocalPatterns(MLIRContext *ctx,
Expand Down
8 changes: 6 additions & 2 deletions mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
//===-- Passes.td - TOSA pass declarations ----*- tablegen -*-===//
//===-- Passes.td - TOSA pass declarations -----------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
Expand All @@ -22,6 +22,10 @@ def TosaLayerwiseConstantFoldPass : Pass<"tosa-layerwise-constant-fold", "func::
}];

let constructor = "createTosaLayerwiseConstantFoldPass()";
let options = [
Option<"enableIntCastFolding", "enable-cast-folding-int-input", "bool",
"true", "Enable folding for casts from integer types">
];
}

def TosaInferShapes : Pass<"tosa-infer-shapes", "func::FuncOp"> {
Expand Down Expand Up @@ -56,7 +60,7 @@ def TosaOptionalDecompositions
: Pass<"tosa-optional-decompositions", "func::FuncOp"> {
let summary = "Applies Tosa operations optional decompositions";
let description = [{
Pass to apply the Tosa operations decompositions
Pass to apply the Tosa operations decompositions
exposed as populate functions in include/mlir/Dialect/Tosa/Transforms/Passes.h
}];

Expand Down
9 changes: 7 additions & 2 deletions mlir/include/mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,15 @@ using DimensionType = ArrayRef<int64_t>;
/// Type for tensor offsets.
using OffsetType = size_t;

static constexpr llvm::RoundingMode tosaRoundingMode =
APFloat::rmNearestTiesToEven;

/// Transform a tensor with the given transformation function.
template <class SrcValType, class TargetValType, class TargetType>
DenseElementsAttr applyElementWise(
const DenseElementsAttr &toTransform,
const std::function<llvm::APFloat(const llvm::APFloat &, Type)> &toApply);
const std::function<TargetValType(const SrcValType &, TargetType)> &toApply,
TargetType targetType);

/// Apply the given transformation function on the elements of the given
/// tensors. If the input tensors do not match \p targetType, broadcasting is
Expand Down Expand Up @@ -74,7 +79,7 @@ OffsetType getBroadcastedOffset(DimensionType desiredShape,
OffsetType offset);

/// Function to compute the reciprocal.
APFloat computeReciprocal(const APFloat &floatVal, Type floatTy);
APFloat computeReciprocal(const APFloat &floatVal, FloatType floatTy);

} // namespace tosa
} // namespace mlir
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRTosaTransforms
TosaDecomposeConv2D.cpp
TosaDecomposeDepthwise.cpp
TosaFoldCommon.cpp
TosaFoldConstantCast.cpp
TosaFoldConstantPow.cpp
TosaFoldConstantReciprocal.cpp
TosaFoldConstantRSQRT.cpp
Expand Down
50 changes: 37 additions & 13 deletions mlir/lib/Dialect/Tosa/Transforms/TosaFoldCommon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,29 +23,52 @@
using namespace mlir;
using namespace mlir::tosa;

namespace {
static constexpr llvm::RoundingMode reciprocalRoundingMode =
APFloat::rmNearestTiesToEven;
} // namespace

template <class SrcValType, class TargetValType, class TargetType>
DenseElementsAttr mlir::tosa::applyElementWise(
const DenseElementsAttr &toTransform,
const std::function<llvm::APFloat(const llvm::APFloat &, Type)> &toApply) {
llvm::SmallVector<llvm::APFloat, 1> transformedValues;
const std::function<TargetValType(const SrcValType &, TargetType)> &toApply,
TargetType targetType) {
SmallVector<TargetValType> transformedValues;
// We already know the amount of values we will insert, reserve space for
// all of them to avoid dynamic resizing
transformedValues.reserve(toTransform.getNumElements());
for (auto val : toTransform.getValues<llvm::APFloat>()) {
auto transformedVal = toApply(val, toTransform.getElementType());
for (auto val : toTransform.getValues<SrcValType>()) {
auto transformedVal = toApply(val, targetType);
transformedValues.push_back(transformedVal);
}

auto inShape = toTransform.getType();
auto outTy = inShape.cloneWith(None, targetType);

// Replace the current tensor with one containing the computed values
auto newTensor =
DenseElementsAttr::get(toTransform.getType(), transformedValues);
auto newTensor = DenseElementsAttr::get(outTy, transformedValues);
return newTensor;
}

template DenseElementsAttr
mlir::tosa::applyElementWise<APFloat, APFloat, FloatType>(
const DenseElementsAttr &toTransform,
const std::function<APFloat(const APFloat &, FloatType)> &toApply,
FloatType targetType);

template DenseElementsAttr
mlir::tosa::applyElementWise<APInt, APFloat, FloatType>(
const DenseElementsAttr &toTransform,
const std::function<APFloat(const APInt &, FloatType)> &toApply,
FloatType targetType);

template DenseElementsAttr
mlir::tosa::applyElementWise<APFloat, APInt, IntegerType>(
const DenseElementsAttr &toTransform,
const std::function<APInt(const APFloat &, IntegerType)> &toApply,
IntegerType targetType);

template DenseElementsAttr
mlir::tosa::applyElementWise<APInt, APInt, IntegerType>(
const DenseElementsAttr &toTransform,
const std::function<APInt(const APInt &, IntegerType)> &toApply,
IntegerType targetType);

DenseElementsAttr mlir::tosa::applyElementWise(
const DenseElementsAttr &first, const DenseElementsAttr &second,
TensorType targetType,
Expand Down Expand Up @@ -182,10 +205,11 @@ OffsetType mlir::tosa::getBroadcastedOffset(DimensionType desiredShape,
return indexToOffset(toBeBroadcastedShape, indexBroadcasted);
}

APFloat mlir::tosa::computeReciprocal(const APFloat &floatVal, Type floatTy) {
APFloat mlir::tosa::computeReciprocal(const APFloat &floatVal,
FloatType floatTy) {
auto recipAttr = FloatAttr::get(floatTy, 1.0);
APFloat recip = recipAttr.getValue();
recip.divide(floatVal, reciprocalRoundingMode);
recip.divide(floatVal, tosaRoundingMode);

return recip;
}
213 changes: 213 additions & 0 deletions mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantCast.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
//===- TosaFoldConstantCast.cpp -------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Fold TOSA cast operation on constant data
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include <llvm/ADT/APFloat.h>
#include <llvm/ADT/APInt.h>
#include <llvm/ADT/APSInt.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/MLIRContext.h>
#include <mlir/Support/LogicalResult.h>

using namespace mlir;
using namespace mlir::tosa;

namespace {

/// Rewrite pattern that folds a `tosa.cast` applied to a dense constant into a
/// new `tosa.const` holding the element-wise converted values.
///
/// Folding is only attempted when the constant operand is a splat or has a
/// single use, and only between float and signed/signless integer element
/// types; every other case is reported as a match failure.
struct TosaFoldConstantCast : public OpRewritePattern<CastOp> {

  using OpRewritePattern::OpRewritePattern;

  /// Convert the integer \p toConvert to a float with the semantics of
  /// \p targetType, rounding with `tosaRoundingMode`.
  ///
  /// 1-bit values are booleans and map to 0.0 / 1.0; interpreting them as
  /// signed would turn `true` into -1.0. All other widths are interpreted as
  /// signed integers.
  static APFloat convertIntToFloat(const APInt &toConvert,
                                   FloatType targetType) {
    APFloat res(targetType.getFloatSemantics());
    const bool isSigned = toConvert.getBitWidth() != 1;
    res.convertFromAPInt(toConvert, isSigned, tosaRoundingMode);
    return res;
  }

  /// Convert the float \p toConvert to the float semantics of \p targetType,
  /// rounding with `tosaRoundingMode`. Precision loss is not reported.
  static APFloat convertFloatToFloat(const APFloat &toConvert,
                                     FloatType targetType) {
    APFloat res(toConvert);
    bool didLosePrecision = false;
    res.convert(targetType.getFloatSemantics(), tosaRoundingMode,
                &didLosePrecision);
    return res;
  }

  /// Convert the float \p toConvert to an integer of type \p targetType.
  ///
  /// NaN is mapped to 0, a 1-bit target yields "is non-zero", and all other
  /// values are rounded with `tosaRoundingMode` and saturated to the range of
  /// the target type.
  static APInt convertFloatToInt(const APFloat &toConvert,
                                 IntegerType targetType) {
    auto targetWidth = targetType.getIntOrFloatBitWidth();
    // Converting NaN to an integer results in an unpredictable value. Pick 0.
    if (toConvert.isNaN()) {
      return APInt::getZero(targetWidth);
    }

    // Make sure to properly translate booleans
    if (targetWidth == 1) {
      return toConvert.isZero() ? APInt::getZero(1) : APInt::getAllOnes(1);
    }

    // Use the built-in functionality of APFloats to convert to integers.
    // The intermediate integer is at least as wide as the float, so the result
    // of this conversion might still be outside of the target integer range.
    auto floatSize = APFloat::getSizeInBits(toConvert.getSemantics());
    APSInt converted(std::max(floatSize, targetWidth), targetType.isUnsigned());
    bool ignored = false;
    // Use the shared rounding mode, like the other conversion helpers.
    toConvert.convertToInteger(converted, tosaRoundingMode, &ignored);
    // Clip to allowed range.
    if (targetWidth < floatSize) {
      if (targetType.isUnsigned()) {
        return converted.truncUSat(targetWidth);
      }
      return converted.truncSSat(targetWidth);
    }
    return converted;
  }

  /// Convert the integer \p toConvert to the integer type \p targetType.
  ///
  /// A 1-bit target yields "is non-zero". A 1-bit source is a boolean and is
  /// zero-extended so that `true` becomes 1 rather than -1. Otherwise the
  /// value is truncated, or extended according to the signedness of
  /// \p targetType.
  static APInt convertIntToInt(const APInt &toConvert, IntegerType targetType) {
    const unsigned targetWidth = targetType.getIntOrFloatBitWidth();
    // Make sure to properly translate booleans
    if (targetWidth == 1) {
      return toConvert.isZero() ? APInt::getZero(1) : APInt::getAllOnes(1);
    }
    // Boolean inputs must not be sign-extended: true has to fold to 1.
    if (targetType.isUnsigned() || toConvert.getBitWidth() == 1) {
      return toConvert.zextOrTrunc(targetWidth);
    }
    return toConvert.sextOrTrunc(targetWidth);
  }

  /// Emit a warning on \p location if \p elements holds float values
  /// containing at least one NaN and the cast produces an integer tensor.
  /// \p rewriter is currently unused.
  static void warnAboutNaNToIntCast(DenseElementsAttr elements, CastOp location,
                                    PatternRewriter &rewriter) {
    // This is only relevant if the input values are float
    if (!isa<FloatType>(elements.getElementType())) {
      return;
    }
    // Check if it is a float to integer conversion
    auto resultType = location.getOutput().getType();
    if (!isa<IntegerType>(cast<TensorType>(resultType).getElementType())) {
      return;
    }

    // Report encountered NaNs
    auto checkNan = [](const APFloat &val) { return val.isNaN(); };
    if (llvm::any_of(elements.getValues<APFloat>(), checkNan)) {
      location->emitWarning(
          "Float tensor is casted to integer and it contains NaN values. The "
          "cast results in an unspecified value.");
    }
  }

  /// Try to replace \p tosaCast by a constant holding the converted values of
  /// its constant operand. Returns failure (with a match-failure note) if the
  /// operand is not a dense constant, is a non-splat with several uses, or if
  /// an unsupported element type is involved.
  LogicalResult matchAndRewrite(CastOp tosaCast,
                                PatternRewriter &rewriter) const override {
    auto inputTensor = tosaCast.getInput();

    // If the input tensor is not constant, we cannot fold it.
    if (failed(notifyIfNoTosaDenseConstantTensor(inputTensor, tosaCast,
                                                 rewriter))) {
      return failure();
    }

    auto fromType = inputTensor.getType().getElementType();
    auto toType = tosaCast.getOutput().getType().getElementType();

    // Never continue with an unset attribute, even if the check above is
    // expected to guarantee that the constant can be matched.
    DenseElementsAttr elements;
    if (!matchPattern(inputTensor, m_Constant(&elements))) {
      return rewriter.notifyMatchFailure(
          tosaCast, "Input is not a dense constant tensor.");
    }

    // Issue a warning if we convert float -> int and NaNs are present; the
    // result value is unspecified in that case
    warnAboutNaNToIntCast(elements, tosaCast, rewriter);

    // Only fold splat tensors and those used only once to avoid duplicating
    // them.
    if (!inputTensor.hasOneUse() && !isa<SplatElementsAttr>(elements)) {
      return rewriter.notifyMatchFailure(tosaCast,
                                         "Currently, casts will only be folded "
                                         "if its input only has a single user");
    }

    // Report a match failure for unexpected types
    if (!toType.isIntOrFloat() || !fromType.isIntOrFloat()) {
      return rewriter.notifyMatchFailure(
          tosaCast, "Only casts from/to int/float are supported.");
    }

    auto isUnsigned = [](Type toCheck) {
      return isa<IntegerType>(toCheck) &&
             cast<IntegerType>(toCheck).isUnsigned();
    };
    auto typesToCheck = {toType, fromType};
    if (llvm::any_of(typesToCheck, isUnsigned)) {
      // TOSA casts currently don't support unsigned integers.
      // To support them here, one could use APSInt instead of APInt;
      // however, this causes trouble with `getValues`, which does not support
      // APSInt currently.
      return rewriter.notifyMatchFailure(
          tosaCast, "Cast folding from/to unsigned integers is not supported.");
    }

    // Dispatch on the (source, target) element type combination.
    DenseElementsAttr res;
    if (auto intOutTy = dyn_cast<IntegerType>(toType)) {
      if (isa<FloatType>(fromType)) {
        res = applyElementWise<APFloat, APInt, IntegerType>(
            elements, &convertFloatToInt, intOutTy);
      } else {
        assert(isa<IntegerType>(fromType));
        res = applyElementWise<APInt, APInt, IntegerType>(
            elements, &convertIntToInt, intOutTy);
      }
    } else {
      assert(isa<FloatType>(toType));
      auto floatOutTy = cast<FloatType>(toType);
      if (isa<FloatType>(fromType)) {
        res = applyElementWise<APFloat, APFloat, FloatType>(
            elements, &convertFloatToFloat, floatOutTy);
      } else {
        assert(isa<IntegerType>(fromType));
        res = applyElementWise<APInt, APFloat, FloatType>(
            elements, &convertIntToFloat, floatOutTy);
      }
    }

    rewriter.replaceOpWithNewOp<ConstOp>(tosaCast, res.getType(), res);
    return success();
  }
};

/// Variant of TosaFoldConstantCast that refuses to fold casts whose input has
/// an integer element type and defers to the base pattern for everything else.
struct TosaFoldConstantFloatCasts : TosaFoldConstantCast {

  TosaFoldConstantFloatCasts(MLIRContext *ctx) : TosaFoldConstantCast(ctx) {}

  LogicalResult matchAndRewrite(CastOp tosaCast,
                                PatternRewriter &rewriter) const override {
    const Type inputElemTy = tosaCast.getInput().getType().getElementType();

    // Everything that does not start from an integer is handled by the
    // generic cast folding.
    if (!isa<IntegerType>(inputElemTy))
      return TosaFoldConstantCast::matchAndRewrite(tosaCast, rewriter);

    return rewriter.notifyMatchFailure(
        tosaCast, "Folding casts from int is currently disabled.");
  }
};

} // namespace

/// Register the constant cast folding pattern in \p patterns. When
/// \p enableIntCastFolding is false, the variant that rejects casts from
/// integer inputs is registered instead of the generic one.
void mlir::tosa::populateTosaFoldConstantCastPatterns(
    MLIRContext *ctx, RewritePatternSet &patterns, bool enableIntCastFolding) {
  if (!enableIntCastFolding) {
    patterns.add<TosaFoldConstantFloatCasts>(ctx);
    return;
  }
  patterns.add<TosaFoldConstantCast>(ctx);
}
7 changes: 5 additions & 2 deletions mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantRSQRT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ struct TosaFoldConstantRSQRT : public OpRewritePattern<RsqrtOp> {

using OpRewritePattern::OpRewritePattern;

static APFloat computeRSQRT(const APFloat &apFloatVal, Type floatTy) {
static APFloat computeRSQRT(const APFloat &apFloatVal, FloatType floatTy) {
// The result for negative values (apart from zero) is always NaN
if (apFloatVal.isNegative() && !apFloatVal.isNegZero()) {
return APFloat::getNaN(apFloatVal.getSemantics());
Expand Down Expand Up @@ -72,7 +72,9 @@ struct TosaFoldConstantRSQRT : public OpRewritePattern<RsqrtOp> {
}

// Create a new tensor with the updated values
auto newTensor = applyElementWise(inputValues, &computeRSQRT);
auto newTensor = applyElementWise<APFloat, APFloat, FloatType>(
inputValues, &computeRSQRT,
cast<FloatType>(inputValues.getElementType()));

// Replace the use of the reciprocal with the transformed tensor
rewriter.replaceOpWithNewOp<ConstOp>(rsqrt, newTensor.getType(), newTensor);
Expand All @@ -84,6 +86,7 @@ struct TosaFoldConstantRSQRT : public OpRewritePattern<RsqrtOp> {
} // namespace

void mlir::tosa::populateTosaFoldConstantRSQRTPatterns(

MLIRContext *ctx, RewritePatternSet &patterns) {
patterns.add<TosaFoldConstantRSQRT>(ctx);
}
Loading