Skip to content

[AutoBump] Merge with fixes of 8388040f (Jan 23) (19) #557

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: bump_to_08195f31
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 14 additions & 7 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ def Tosa_ArgMaxOp : Tosa_InferShapedTypeOp<"argmax"> {

let arguments = (ins
Tosa_Tensor: $input,
I32Attr: $axis
I32Attr: $axis,
DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
);

let results = (outs
Expand Down Expand Up @@ -287,7 +288,8 @@ def Tosa_MaxPool2dOp : Tosa_InferShapedTypeOp<"max_pool2d"> {

Tosa_IntArrayAttr2:$kernel,
Tosa_IntArrayAttr2:$stride,
Tosa_IntArrayAttr4:$pad
Tosa_IntArrayAttr4:$pad,
DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
);

let results = (outs
Expand Down Expand Up @@ -388,7 +390,8 @@ def Tosa_ClampOp : Tosa_ElementwiseUnaryOp<"clamp"> {
I64Attr:$min_int,
I64Attr:$max_int,
Tosa_FloatAttr:$min_fp,
Tosa_FloatAttr:$max_fp
Tosa_FloatAttr:$max_fp,
DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
);

let results = (outs
Expand Down Expand Up @@ -752,7 +755,8 @@ def Tosa_MaximumOp : Tosa_ElementwiseOp<"maximum", [

let arguments = (ins
Tosa_Tensor:$input1,
Tosa_Tensor:$input2
Tosa_Tensor:$input2,
DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
);

let results = (outs
Expand All @@ -777,7 +781,8 @@ def Tosa_MinimumOp : Tosa_ElementwiseOp<"minimum", [

let arguments = (ins
Tosa_Tensor:$input1,
Tosa_Tensor:$input2
Tosa_Tensor:$input2,
DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
);

let results = (outs
Expand Down Expand Up @@ -1390,7 +1395,8 @@ def Tosa_ReduceMaxOp : Tosa_InferTensorTypeOp<"reduce_max"> {

let arguments = (ins
Tosa_Tensor:$input,
I32Attr:$axis
I32Attr:$axis,
DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
);

let results = (outs
Expand Down Expand Up @@ -1430,7 +1436,8 @@ def Tosa_ReduceMinOp : Tosa_InferTensorTypeOp<"reduce_min"> {

let arguments = (ins
Tosa_Tensor:$input,
I32Attr:$axis
I32Attr:$axis,
DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
);

let results = (outs
Expand Down
8 changes: 8 additions & 0 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -208,12 +208,20 @@ def Tosa_FloatAttr : Attr<CPred<"::llvm::isa<::mlir::FloatAttr>($_self)">,
//===----------------------------------------------------------------------===//
// Iterable attributes.
//===----------------------------------------------------------------------===//
// Defined in `section 3. Enumerations` of the TOSA specification.

// Supported regimes for tosa.resize.
// String-based attribute restricted to exactly "BILINEAR" or
// "NEAREST_NEIGHBOR"; any other string fails verification.
def Tosa_ResizeTypeAttr : StringBasedAttr<
CPred<"::llvm::cast<StringAttr>($_self).getValue() == \"BILINEAR\" || " #
"::llvm::cast<StringAttr>($_self).getValue() == \"NEAREST_NEIGHBOR\"">,
"Supported resize/upsampling strategies">;

// Supported NaN propagation strategies.
// String-based attribute restricted to exactly "PROPAGATE" or "IGNORE",
// mirroring `section 3. Enumerations` of the TOSA specification; used as the
// `nan_mode` attribute on min/max-style ops (clamp, maximum, reduce_max, ...).
def Tosa_NanPropagationAttr : StringBasedAttr<
CPred<"::llvm::cast<StringAttr>($_self).getValue() == \"PROPAGATE\" || " #
"::llvm::cast<StringAttr>($_self).getValue() == \"IGNORE\"">,
"Supported NaN propagation strategies">;

def Tosa_TensorTypeAttr : TypeAttrBase<"TensorType", "Tensor type attribute">;

// Tensor to buffer types.
Expand Down
85 changes: 68 additions & 17 deletions mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -658,33 +658,84 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
}
};

// Attempts the following transformation:
//
// For integers a, b, a', and b' such that [a, b] ∩ [a', b'] ≠ ∅ and input
// tensor X the following identity holds:
//
// CLAMP(CLAMP(X, a, b), a', b') = CLAMP(X, max(a, a'), min(b, b'))
//
// subject to the following valid NaN propagation semantics:
// --------------------------------------------
// | OUTER CLAMP | INNER CLAMP | RESULT MODE |
// |-------------|--------------|-------------|
// | PROPAGATE | PROPAGATE | PROPAGATE |
// | PROPAGATE | IGNORE | IGNORE |
// | IGNORE | PROPAGATE | INVALID |
// | IGNORE | IGNORE | IGNORE |
// |------------------------------------------|

struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
  using OpRewritePattern<tosa::ClampOp>::OpRewritePattern;

  // Helper structure to describe the inclusive range [start, end] of a clamp
  // operation over one value domain (integer or floating-point).
  template <typename T>
  struct ClampRange {
    ClampRange(const T &start, const T &end) : start(start), end(end) {}
    T start;
    T end;

    // Helper function to determine if two Clamp ranges intersect.
    bool intersects(const ClampRange<T> &otherRange) {
      return start < otherRange.end && otherRange.start < end;
    }
  };

  LogicalResult matchAndRewrite(tosa::ClampOp op,
                                PatternRewriter &rewriter) const override {
    // Check the input to the CLAMP op is itself a CLAMP.
    auto clampOp =
        dyn_cast_if_present<tosa::ClampOp>(op.getInput().getDefiningOp());
    if (!clampOp)
      return failure();

    // Check we have a valid NaN propagation combination. Folding an outer
    // IGNORE clamp into an inner PROPAGATE clamp would change the result for
    // NaN inputs, so that combination must be rejected (see the table above).
    const auto opNanMode = op.getNanMode();
    const auto clampNanMode = clampOp.getNanMode();
    if (opNanMode == "IGNORE" && clampNanMode == "PROPAGATE")
      return failure();

    // Check we have intersecting integer ranges.
    const auto opMinInt = op.getMinInt();
    const auto opMaxInt = op.getMaxInt();
    const auto clampOpMinInt = clampOp.getMinInt();
    const auto clampOpMaxInt = clampOp.getMaxInt();
    ClampRange<std::int64_t> opRangeIntRange(opMinInt, opMaxInt);
    ClampRange<std::int64_t> clampRangeIntRange(clampOpMinInt, clampOpMaxInt);
    if (!opRangeIntRange.intersects(clampRangeIntRange))
      return failure();

    // Check we have intersecting floating-point ranges.
    const auto opMinFloat = op.getMinFp();
    const auto opMaxFloat = op.getMaxFp();
    const auto clampOpMinFloat = clampOp.getMinFp();
    const auto clampOpMaxFloat = clampOp.getMaxFp();
    ClampRange opRangeFloatRange(opMinFloat, opMaxFloat);
    ClampRange clampRangeFloatRange(clampOpMinFloat, clampOpMaxFloat);
    if (!opRangeFloatRange.intersects(clampRangeFloatRange))
      return failure();

    // Run the transformation: replace both clamps with a single clamp over the
    // intersection of the two ranges. The folded NaN mode is the shared mode
    // when both clamps agree; the only mixed combination reaching this point
    // (PROPAGATE over IGNORE) folds to IGNORE.
    const auto minFp = std::max(opMinFloat, clampOpMinFloat).convertToFloat();
    const auto maxFp = std::min(opMaxFloat, clampOpMaxFloat).convertToFloat();
    const auto minInt = std::max(opMinInt, clampOpMinInt);
    const auto maxInt = std::min(opMaxInt, clampOpMaxInt);
    rewriter.replaceOpWithNewOp<tosa::ClampOp>(
        op, {op->getLoc(), clampOp->getLoc()}, op.getType(), clampOp.getInput(),
        rewriter.getI64IntegerAttr(minInt), rewriter.getI64IntegerAttr(maxInt),
        rewriter.getF32FloatAttr(minFp), rewriter.getF32FloatAttr(maxFp),
        rewriter.getStringAttr((opNanMode != clampNanMode) ? "IGNORE"
                                                           : opNanMode));
    return success();
  }
};

Expand Down
52 changes: 52 additions & 0 deletions mlir/test/Dialect/Tosa/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,58 @@ func.func @concat_fold_zero_size(%arg0: tensor<?x0xf32>, %arg1: tensor<?x1xf32>,

// -----

// The integer ranges of the two clamps ([-10, -5] and [1, 5]) are disjoint,
// so the clamp-clamp canonicalization must leave both ops in place.
// CHECK: @disjoint_clamp_twice_is_not_single_clamp(%[[INPUT:.*]]: tensor<4xi8>)
func.func @disjoint_clamp_twice_is_not_single_clamp(%arg0: tensor<4xi8>) -> tensor<4xi8> {
// CHECK: %[[CLAMP_1:.*]] = tosa.clamp %[[INPUT]] {max_fp = -5.000000e+00 : f32, max_int = -5 : i64, min_fp = -1.000000e+00 : f32, min_int = -10 : i64} : (tensor<4xi8>) -> tensor<4xi8>
// CHECK-NEXT: tosa.clamp %[[CLAMP_1]] {max_fp = 5.000000e+00 : f32, max_int = 5 : i64, min_fp = 1.000000e+00 : f32, min_int = 1 : i64} : (tensor<4xi8>) -> tensor<4xi8>
%0 = tosa.clamp %arg0 {max_fp = -5.0 : f32, max_int = -5 : i64, min_fp = -1.0 : f32, min_int = -10 : i64} : (tensor<4xi8>) -> tensor<4xi8>
%1 = tosa.clamp %0 {max_fp = 5.0 : f32, max_int = 5 : i64, min_fp = 1.0 : f32, min_int = 1 : i64} : (tensor<4xi8>) -> tensor<4xi8>
return %1 : tensor<4xi8>
}

// -----

// CHECK-LABEL: @clamp_twice_with_nan_propagate_is_single_clamp
// Both clamps use nan_mode = "PROPAGATE" (the default), so they fold into a
// single clamp over the intersection of the ranges: int [-2, 2], fp [-3, 3].
// The default nan_mode is not printed on the folded op.
func.func @clamp_twice_with_nan_propagate_is_single_clamp(%arg0: tensor<4xi8>) -> tensor<4xi8> {
// CHECK: tosa.clamp %arg0 {max_fp = 3.000000e+00 : f32, max_int = 2 : i64, min_fp = -3.000000e+00 : f32, min_int = -2 : i64}
%0 = tosa.clamp %arg0 {max_fp = 3.0 : f32, max_int = 4 : i64, min_fp = -5.0 : f32, min_int = -2 : i64, nan_mode = "PROPAGATE"} : (tensor<4xi8>) -> tensor<4xi8>
%1 = tosa.clamp %0 {max_fp = 5.0 : f32, max_int = 2 : i64, min_fp = -3.0 : f32, min_int = -4 : i64, nan_mode = "PROPAGATE"} : (tensor<4xi8>) -> tensor<4xi8>
return %1 : tensor<4xi8>
}

// -----

// CHECK-LABEL: @clamp_twice_with_nan_ignore_is_single_clamp
// Both clamps use nan_mode = "IGNORE", so they fold into a single clamp that
// keeps nan_mode = "IGNORE".
func.func @clamp_twice_with_nan_ignore_is_single_clamp(%arg0: tensor<4xi8>) -> tensor<4xi8> {
// CHECK: tosa.clamp %arg0 {max_fp = 3.000000e+00 : f32, max_int = 2 : i64, min_fp = -3.000000e+00 : f32, min_int = -2 : i64, nan_mode = "IGNORE"}
%0 = tosa.clamp %arg0 {max_fp = 3.0 : f32, max_int = 4 : i64, min_fp = -5.0 : f32, min_int = -2 : i64, nan_mode = "IGNORE"} : (tensor<4xi8>) -> tensor<4xi8>
%1 = tosa.clamp %0 {max_fp = 5.0 : f32, max_int = 2 : i64, min_fp = -3.0 : f32, min_int = -4 : i64, nan_mode = "IGNORE"} : (tensor<4xi8>) -> tensor<4xi8>
return %1 : tensor<4xi8>
}

// -----

// CHECK-LABEL: @clamp_twice_with_nan_ignore_propagate_is_single_clamp
// Inner clamp (%0) is "IGNORE", outer clamp (%1) is "PROPAGATE": a valid
// combination that folds into a single clamp with nan_mode = "IGNORE".
func.func @clamp_twice_with_nan_ignore_propagate_is_single_clamp(%arg0: tensor<4xi8>) -> tensor<4xi8> {
// CHECK: tosa.clamp %arg0 {max_fp = 3.000000e+00 : f32, max_int = 2 : i64, min_fp = -3.000000e+00 : f32, min_int = -2 : i64, nan_mode = "IGNORE"}
%0 = tosa.clamp %arg0 {max_fp = 3.0 : f32, max_int = 4 : i64, min_fp = -5.0 : f32, min_int = -2 : i64, nan_mode = "IGNORE"} : (tensor<4xi8>) -> tensor<4xi8>
%1 = tosa.clamp %0 {max_fp = 5.0 : f32, max_int = 2 : i64, min_fp = -3.0 : f32, min_int = -4 : i64, nan_mode = "PROPAGATE"} : (tensor<4xi8>) -> tensor<4xi8>
return %1 : tensor<4xi8>
}

// -----

// Inner clamp (%0) is "PROPAGATE", outer clamp (%1) is "IGNORE": the invalid
// NaN-mode combination, so the canonicalization must not fold the two clamps.
// CHECK: @clamp_twice_with_nan_propagate_ignore_is_not_single_clamp(%[[INPUT:.*]]: tensor<4xi8>)
func.func @clamp_twice_with_nan_propagate_ignore_is_not_single_clamp(%arg0: tensor<4xi8>) -> tensor<4xi8> {
// CHECK: %[[CLAMP_1:.*]] = tosa.clamp %[[INPUT]] {max_fp = 3.000000e+00 : f32, max_int = 4 : i64, min_fp = -5.000000e+00 : f32, min_int = -2 : i64} : (tensor<4xi8>) -> tensor<4xi8>
// CHECK-NEXT: tosa.clamp %[[CLAMP_1]] {max_fp = 5.000000e+00 : f32, max_int = 2 : i64, min_fp = -3.000000e+00 : f32, min_int = -4 : i64, nan_mode = "IGNORE"} : (tensor<4xi8>) -> tensor<4xi8>
%0 = tosa.clamp %arg0 {max_fp = 3.0 : f32, max_int = 4 : i64, min_fp = -5.0 : f32, min_int = -2 : i64, nan_mode = "PROPAGATE"} : (tensor<4xi8>) -> tensor<4xi8>
%1 = tosa.clamp %0 {max_fp = 5.0 : f32, max_int = 2 : i64, min_fp = -3.0 : f32, min_int = -4 : i64, nan_mode = "IGNORE"} : (tensor<4xi8>) -> tensor<4xi8>
return %1 : tensor<4xi8>
}

// -----

// CHECK-LABEL: @concat_fold
func.func @concat_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
// CHECK: return %arg0
Expand Down
14 changes: 14 additions & 0 deletions mlir/test/Dialect/Tosa/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,20 @@ func.func @test_clamp(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
return %0 : tensor<13x21x3xf32>
}

// -----
// CHECK-LABEL: clamp_propagate
// Round-trip test: tosa.clamp accepts an explicit nan_mode = "PROPAGATE".
func.func @test_clamp_propagate(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
%0 = tosa.clamp %arg0 {min_fp = 0.0 : f32, max_fp = 1.0: f32, min_int = 0 : i64, max_int = 1 : i64, nan_mode = "PROPAGATE"} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
return %0 : tensor<13x21x3xf32>
}

// -----
// CHECK-LABEL: clamp_ignore
// Round-trip test: tosa.clamp accepts an explicit nan_mode = "IGNORE".
func.func @test_clamp_ignore(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
%0 = tosa.clamp %arg0 {min_fp = 0.0 : f32, max_fp = 1.0: f32, min_int = 0 : i64, max_int = 1 : i64, nan_mode = "IGNORE"} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
return %0 : tensor<13x21x3xf32>
}

// -----
// CHECK-LABEL: clamp_f16
func.func @test_clamp_f16(%arg0: tensor<13x21x3xf16>) -> tensor<13x21x3xf16> {
Expand Down