Skip to content

Commit d9d175d

Browse files
authored
Merge pull request #19 from Xilinx/tina.tosacastfolding
[FXML-1871] Implement folding for constant TOSA casts
2 parents dddeab1 + 12b5f8d commit d9d175d

File tree

11 files changed

+545
-20
lines changed

11 files changed

+545
-20
lines changed

mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ void populateTosaDecomposeTransposeConv(MLIRContext *ctx,
2929
RewritePatternSet &patterns);
3030
void populateTosaDecomposeDepthwise(MLIRContext *ctx,
3131
RewritePatternSet &patterns);
32+
void populateTosaFoldConstantCastPatterns(MLIRContext *ctx,
33+
RewritePatternSet &patterns,
34+
bool enableIntCastFolding);
3235
void populateTosaFoldConstantPowPatterns(MLIRContext *ctx,
3336
RewritePatternSet &patterns);
3437
void populateTosaFoldConstantReciprocalPatterns(MLIRContext *ctx,

mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//===-- Passes.td - TOSA pass declarations ----*- tablegen -*-===//
1+
//===-- Passes.td - TOSA pass declarations -----------------*- tablegen -*-===//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
@@ -22,6 +22,10 @@ def TosaLayerwiseConstantFoldPass : Pass<"tosa-layerwise-constant-fold", "func::
2222
}];
2323

2424
let constructor = "createTosaLayerwiseConstantFoldPass()";
25+
let options = [
26+
Option<"enableIntCastFolding", "enable-cast-folding-int-input", "bool",
27+
"true", "Enable folding for casts from integer types">
28+
];
2529
}
2630

2731
def TosaInferShapes : Pass<"tosa-infer-shapes", "func::FuncOp"> {
@@ -56,7 +60,7 @@ def TosaOptionalDecompositions
5660
: Pass<"tosa-optional-decompositions", "func::FuncOp"> {
5761
let summary = "Applies Tosa operations optional decompositions";
5862
let description = [{
59-
Pass to apply the Tosa operations decompositions
63+
Pass to apply the Tosa operations decompositions
6064
exposed as populate functions in include/mlir/Dialect/Tosa/Transforms/Passes.h
6165
}];
6266

mlir/include/mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,15 @@ using DimensionType = ArrayRef<int64_t>;
2727
/// Type for tensor offsets.
2828
using OffsetType = size_t;
2929

30+
static constexpr llvm::RoundingMode tosaRoundingMode =
31+
APFloat::rmNearestTiesToEven;
32+
3033
/// Transform a tensor with the given transformation function.
34+
template <class SrcValType, class TargetValType, class TargetType>
3135
DenseElementsAttr applyElementWise(
3236
const DenseElementsAttr &toTransform,
33-
const std::function<llvm::APFloat(const llvm::APFloat &, Type)> &toApply);
37+
const std::function<TargetValType(const SrcValType &, TargetType)> &toApply,
38+
TargetType targetType);
3439

3540
/// Apply the given transformation function on the elements of the given
3641
/// tensors. If the input tensors do not match \p targetType, broadcasting is
@@ -74,7 +79,7 @@ OffsetType getBroadcastedOffset(DimensionType desiredShape,
7479
OffsetType offset);
7580

7681
/// Function to compute the reciprocal.
77-
APFloat computeReciprocal(const APFloat &floatVal, Type floatTy);
82+
APFloat computeReciprocal(const APFloat &floatVal, FloatType floatTy);
7883

7984
} // namespace tosa
8085
} // namespace mlir

mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRTosaTransforms
33
TosaDecomposeConv2D.cpp
44
TosaDecomposeDepthwise.cpp
55
TosaFoldCommon.cpp
6+
TosaFoldConstantCast.cpp
67
TosaFoldConstantPow.cpp
78
TosaFoldConstantReciprocal.cpp
89
TosaFoldConstantRSQRT.cpp

mlir/lib/Dialect/Tosa/Transforms/TosaFoldCommon.cpp

Lines changed: 37 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -23,29 +23,52 @@
2323
using namespace mlir;
2424
using namespace mlir::tosa;
2525

26-
namespace {
27-
static constexpr llvm::RoundingMode reciprocalRoundingMode =
28-
APFloat::rmNearestTiesToEven;
29-
} // namespace
30-
26+
template <class SrcValType, class TargetValType, class TargetType>
3127
DenseElementsAttr mlir::tosa::applyElementWise(
3228
const DenseElementsAttr &toTransform,
33-
const std::function<llvm::APFloat(const llvm::APFloat &, Type)> &toApply) {
34-
llvm::SmallVector<llvm::APFloat, 1> transformedValues;
29+
const std::function<TargetValType(const SrcValType &, TargetType)> &toApply,
30+
TargetType targetType) {
31+
SmallVector<TargetValType> transformedValues;
3532
// We already know the amount of values we will insert, reserve space for
3633
// all of them to avoid dynamic resizing
3734
transformedValues.reserve(toTransform.getNumElements());
38-
for (auto val : toTransform.getValues<llvm::APFloat>()) {
39-
auto transformedVal = toApply(val, toTransform.getElementType());
35+
for (auto val : toTransform.getValues<SrcValType>()) {
36+
auto transformedVal = toApply(val, targetType);
4037
transformedValues.push_back(transformedVal);
4138
}
4239

40+
auto inShape = toTransform.getType();
41+
auto outTy = inShape.cloneWith(None, targetType);
42+
4343
// Replace the current tensor with one containing the computed values
44-
auto newTensor =
45-
DenseElementsAttr::get(toTransform.getType(), transformedValues);
44+
auto newTensor = DenseElementsAttr::get(outTy, transformedValues);
4645
return newTensor;
4746
}
4847

48+
template DenseElementsAttr
49+
mlir::tosa::applyElementWise<APFloat, APFloat, FloatType>(
50+
const DenseElementsAttr &toTransform,
51+
const std::function<APFloat(const APFloat &, FloatType)> &toApply,
52+
FloatType targetType);
53+
54+
template DenseElementsAttr
55+
mlir::tosa::applyElementWise<APInt, APFloat, FloatType>(
56+
const DenseElementsAttr &toTransform,
57+
const std::function<APFloat(const APInt &, FloatType)> &toApply,
58+
FloatType targetType);
59+
60+
template DenseElementsAttr
61+
mlir::tosa::applyElementWise<APFloat, APInt, IntegerType>(
62+
const DenseElementsAttr &toTransform,
63+
const std::function<APInt(const APFloat &, IntegerType)> &toApply,
64+
IntegerType targetType);
65+
66+
template DenseElementsAttr
67+
mlir::tosa::applyElementWise<APInt, APInt, IntegerType>(
68+
const DenseElementsAttr &toTransform,
69+
const std::function<APInt(const APInt &, IntegerType)> &toApply,
70+
IntegerType targetType);
71+
4972
DenseElementsAttr mlir::tosa::applyElementWise(
5073
const DenseElementsAttr &first, const DenseElementsAttr &second,
5174
TensorType targetType,
@@ -182,10 +205,11 @@ OffsetType mlir::tosa::getBroadcastedOffset(DimensionType desiredShape,
182205
return indexToOffset(toBeBroadcastedShape, indexBroadcasted);
183206
}
184207

185-
APFloat mlir::tosa::computeReciprocal(const APFloat &floatVal, Type floatTy) {
208+
APFloat mlir::tosa::computeReciprocal(const APFloat &floatVal,
209+
FloatType floatTy) {
186210
auto recipAttr = FloatAttr::get(floatTy, 1.0);
187211
APFloat recip = recipAttr.getValue();
188-
recip.divide(floatVal, reciprocalRoundingMode);
212+
recip.divide(floatVal, tosaRoundingMode);
189213

190214
return recip;
191215
}
Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
//===- TosaFoldConstantCast.cpp -------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Fold TOSA cast operation on constant data
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
14+
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
15+
#include "mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h"
16+
#include "mlir/IR/Matchers.h"
17+
#include "mlir/Pass/Pass.h"
18+
#include <llvm/ADT/APFloat.h>
19+
#include <llvm/ADT/APInt.h>
20+
#include <llvm/ADT/APSInt.h>
21+
#include <mlir/IR/BuiltinTypes.h>
22+
#include <mlir/IR/MLIRContext.h>
23+
#include <mlir/Support/LogicalResult.h>
24+
25+
using namespace mlir;
26+
using namespace mlir::tosa;
27+
28+
namespace {
29+
30+
struct TosaFoldConstantCast : public OpRewritePattern<CastOp> {
31+
32+
using OpRewritePattern::OpRewritePattern;
33+
34+
static APFloat convertIntToFloat(const APInt &toConvert,
35+
FloatType targetType) {
36+
APFloat res(targetType.getFloatSemantics());
37+
res.convertFromAPInt(toConvert, true /* isSigned */, tosaRoundingMode);
38+
return res;
39+
}
40+
41+
static APFloat convertFloatToFloat(const APFloat &toConvert,
42+
FloatType targetType) {
43+
APFloat res(toConvert);
44+
bool didLosePrecision;
45+
res.convert(targetType.getFloatSemantics(), tosaRoundingMode,
46+
&didLosePrecision);
47+
return res;
48+
}
49+
50+
static APInt convertFloatToInt(const APFloat &toConvert,
51+
IntegerType targetType) {
52+
auto targetWidth = targetType.getIntOrFloatBitWidth();
53+
// Converting NaN to an integer results in an unpredictable value. Pick 0.
54+
if (toConvert.isNaN()) {
55+
return APInt::getZero(targetWidth);
56+
}
57+
58+
// Make sure to properly translate booleans
59+
if (targetWidth == 1) {
60+
return toConvert.isZero() ? APInt::getZero(1) : APInt::getAllOnes(1);
61+
}
62+
63+
// Use the built-in functionality of APFloats to convert to integers.
64+
// The result of this conversion should be an integer which might still be
65+
// outside of the target integer range.
66+
auto floatSize = APFloat::getSizeInBits(toConvert.getSemantics());
67+
APSInt converted(std::max(floatSize, targetWidth), targetType.isUnsigned());
68+
bool ignored = false;
69+
toConvert.convertToInteger(converted, APFloat::rmNearestTiesToEven,
70+
&ignored);
71+
// Clip to allowed range.
72+
if (targetWidth < floatSize) {
73+
if (targetType.isUnsigned()) {
74+
return converted.truncUSat(targetWidth);
75+
}
76+
return converted.truncSSat(targetWidth);
77+
}
78+
return converted;
79+
}
80+
81+
static APInt convertIntToInt(const APInt &toConvert, IntegerType targetType) {
82+
// Make sure to properly translate booleans
83+
if (targetType.getWidth() == 1) {
84+
return toConvert.isZero() ? APInt::getZero(1) : APInt::getAllOnes(1);
85+
}
86+
if (targetType.isUnsigned()) {
87+
return toConvert.zextOrTrunc(targetType.getIntOrFloatBitWidth());
88+
}
89+
return toConvert.sextOrTrunc(targetType.getIntOrFloatBitWidth());
90+
}
91+
92+
static void warnAboutNaNToIntCast(DenseElementsAttr elements, CastOp location,
93+
PatternRewriter &rewriter) {
94+
// This is only relevant if the input values are float
95+
if (!isa<FloatType>(elements.getElementType())) {
96+
return;
97+
}
98+
// Check if it is an float to integer conversion
99+
auto resultType = location.getOutput().getType();
100+
if (!isa<IntegerType>(cast<TensorType>(resultType).getElementType())) {
101+
return;
102+
}
103+
104+
// Report encountered NaNs
105+
auto checkNan = [](const APFloat &val) { return val.isNaN(); };
106+
if (any_of(elements.getValues<APFloat>(), checkNan)) {
107+
location->emitWarning(
108+
"Float tensor is casted to integer and it contains NaN values. The "
109+
"cast results in an unspecified value.");
110+
}
111+
}
112+
113+
LogicalResult matchAndRewrite(CastOp tosaCast,
114+
PatternRewriter &rewriter) const override {
115+
auto inputTensor = tosaCast.getInput();
116+
117+
// If the input tensor is not constant, we cannot fold it.
118+
if (failed(notifyIfNoTosaDenseConstantTensor(inputTensor, tosaCast,
119+
rewriter))) {
120+
return failure();
121+
}
122+
123+
auto fromType = inputTensor.getType().getElementType();
124+
auto toType = tosaCast.getOutput().getType().getElementType();
125+
126+
DenseElementsAttr elements;
127+
matchPattern(inputTensor, m_Constant(&elements));
128+
129+
// Issue a warning if we convert float -> int and NaNs are present; the
130+
// result value is unspecified in that case
131+
warnAboutNaNToIntCast(elements, tosaCast, rewriter);
132+
133+
// Only fold splat tensors and those used only once to avoid duplicating
134+
// them.
135+
if (!inputTensor.hasOneUse() && !isa<SplatElementsAttr>(elements)) {
136+
return rewriter.notifyMatchFailure(tosaCast,
137+
"Currently, casts will only be folded "
138+
"if its input only has a single user");
139+
}
140+
141+
// Report a match failure for unexpected types
142+
if (!toType.isIntOrFloat() || !fromType.isIntOrFloat()) {
143+
return rewriter.notifyMatchFailure(
144+
tosaCast, "Only casts from/to int/float are supported.");
145+
}
146+
147+
auto isUnsigned = [](Type toCheck) {
148+
return isa<IntegerType>(toCheck) &&
149+
cast<IntegerType>(toCheck).isUnsigned();
150+
};
151+
auto typesToCheck = {toType, fromType};
152+
if (llvm::any_of(typesToCheck, isUnsigned)) {
153+
// TOSA casts currently don't support unsigned integers.
154+
// To support them by here, one could use APSInt instead of APInts,
155+
// however, this causes trouble with `getValues` which does not support
156+
// APSInts currently.
157+
return rewriter.notifyMatchFailure(
158+
tosaCast, "Cast folding from/to unsigned integers is not supported.");
159+
}
160+
161+
DenseElementsAttr res;
162+
if (auto intOutTy = dyn_cast<IntegerType>(toType)) {
163+
if (isa<FloatType>(fromType)) {
164+
res = applyElementWise<APFloat, APInt, IntegerType>(
165+
elements, &convertFloatToInt, intOutTy);
166+
} else {
167+
assert(isa<IntegerType>(fromType));
168+
res = applyElementWise<APInt, APInt, IntegerType>(
169+
elements, &convertIntToInt, intOutTy);
170+
}
171+
} else {
172+
assert(isa<FloatType>(toType));
173+
auto floatOutTy = cast<FloatType>(toType);
174+
if (isa<FloatType>(fromType)) {
175+
res = applyElementWise<APFloat, APFloat, FloatType>(
176+
elements, &convertFloatToFloat, floatOutTy);
177+
} else {
178+
assert(isa<IntegerType>(fromType));
179+
res = applyElementWise<APInt, APFloat, FloatType>(
180+
elements, &convertIntToFloat, floatOutTy);
181+
}
182+
}
183+
184+
rewriter.replaceOpWithNewOp<ConstOp>(tosaCast, res.getType(), res);
185+
return success();
186+
}
187+
};
188+
189+
struct TosaFoldConstantFloatCasts : TosaFoldConstantCast {
190+
191+
TosaFoldConstantFloatCasts(MLIRContext *ctx) : TosaFoldConstantCast(ctx) {}
192+
193+
LogicalResult matchAndRewrite(CastOp tosaCast,
194+
PatternRewriter &rewriter) const override {
195+
if (isa<IntegerType>(tosaCast.getInput().getType().getElementType())) {
196+
return rewriter.notifyMatchFailure(
197+
tosaCast, "Folding casts from int is currently disabled.");
198+
}
199+
200+
return TosaFoldConstantCast::matchAndRewrite(tosaCast, rewriter);
201+
}
202+
};
203+
204+
} // namespace
205+
206+
void mlir::tosa::populateTosaFoldConstantCastPatterns(
207+
MLIRContext *ctx, RewritePatternSet &patterns, bool enableIntCastFolding) {
208+
if (enableIntCastFolding) {
209+
patterns.add<TosaFoldConstantCast>(ctx);
210+
} else {
211+
patterns.add<TosaFoldConstantFloatCasts>(ctx);
212+
}
213+
}

mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantRSQRT.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ struct TosaFoldConstantRSQRT : public OpRewritePattern<RsqrtOp> {
3232

3333
using OpRewritePattern::OpRewritePattern;
3434

35-
static APFloat computeRSQRT(const APFloat &apFloatVal, Type floatTy) {
35+
static APFloat computeRSQRT(const APFloat &apFloatVal, FloatType floatTy) {
3636
// The result for negative values (apart from zero) is always NaN
3737
if (apFloatVal.isNegative() && !apFloatVal.isNegZero()) {
3838
return APFloat::getNaN(apFloatVal.getSemantics());
@@ -72,7 +72,9 @@ struct TosaFoldConstantRSQRT : public OpRewritePattern<RsqrtOp> {
7272
}
7373

7474
// Create a new tensor with the updated values
75-
auto newTensor = applyElementWise(inputValues, &computeRSQRT);
75+
auto newTensor = applyElementWise<APFloat, APFloat, FloatType>(
76+
inputValues, &computeRSQRT,
77+
cast<FloatType>(inputValues.getElementType()));
7678

7779
// Replace the use of the reciprocal with the transformed tensor
7880
rewriter.replaceOpWithNewOp<ConstOp>(rsqrt, newTensor.getType(), newTensor);
@@ -84,6 +86,7 @@ struct TosaFoldConstantRSQRT : public OpRewritePattern<RsqrtOp> {
8486
} // namespace
8587

8688
void mlir::tosa::populateTosaFoldConstantRSQRTPatterns(
89+
8790
MLIRContext *ctx, RewritePatternSet &patterns) {
8891
patterns.add<TosaFoldConstantRSQRT>(ctx);
8992
}

0 commit comments

Comments
 (0)