Skip to content

Commit 8388040

Browse files
[mlir][tosa] Add NaN Propagation Mode Support (#121951)
The TOSA-V1.0 specification adds "nan propagation" modes as attributes for several operators. Adjust the ODS definitions of the relevant operations to include this attribute. The defined modes are "PROPAGATE" and "IGNORE" and the PROPAGATE mode is set by default. MAXIMUM, MINIMUM, REDUCE_MAX, REDUCE_MIN, MAX_POOL, CLAMP, and ARGMAX support this attribute. Signed-off-by: Jack Frankland <[email protected]> Co-authored-by: TatWai Chong <[email protected]>
1 parent 08195f3 commit 8388040

File tree

5 files changed

+156
-24
lines changed

5 files changed

+156
-24
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,8 @@ def Tosa_ArgMaxOp : Tosa_InferShapedTypeOp<"argmax"> {
4242

4343
let arguments = (ins
4444
Tosa_Tensor: $input,
45-
I32Attr: $axis
45+
I32Attr: $axis,
46+
DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
4647
);
4748

4849
let results = (outs
@@ -287,7 +288,8 @@ def Tosa_MaxPool2dOp : Tosa_InferShapedTypeOp<"max_pool2d"> {
287288

288289
Tosa_IntArrayAttr2:$kernel,
289290
Tosa_IntArrayAttr2:$stride,
290-
Tosa_IntArrayAttr4:$pad
291+
Tosa_IntArrayAttr4:$pad,
292+
DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
291293
);
292294

293295
let results = (outs
@@ -388,7 +390,8 @@ def Tosa_ClampOp : Tosa_ElementwiseUnaryOp<"clamp"> {
388390
I64Attr:$min_int,
389391
I64Attr:$max_int,
390392
Tosa_FloatAttr:$min_fp,
391-
Tosa_FloatAttr:$max_fp
393+
Tosa_FloatAttr:$max_fp,
394+
DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
392395
);
393396

394397
let results = (outs
@@ -752,7 +755,8 @@ def Tosa_MaximumOp : Tosa_ElementwiseOp<"maximum", [
752755

753756
let arguments = (ins
754757
Tosa_Tensor:$input1,
755-
Tosa_Tensor:$input2
758+
Tosa_Tensor:$input2,
759+
DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
756760
);
757761

758762
let results = (outs
@@ -775,7 +779,8 @@ def Tosa_MinimumOp : Tosa_ElementwiseOp<"minimum", [
775779

776780
let arguments = (ins
777781
Tosa_Tensor:$input1,
778-
Tosa_Tensor:$input2
782+
Tosa_Tensor:$input2,
783+
DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
779784
);
780785

781786
let results = (outs
@@ -1382,7 +1387,8 @@ def Tosa_ReduceMaxOp : Tosa_InferTensorTypeOp<"reduce_max"> {
13821387

13831388
let arguments = (ins
13841389
Tosa_Tensor:$input,
1385-
I32Attr:$axis
1390+
I32Attr:$axis,
1391+
DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
13861392
);
13871393

13881394
let results = (outs
@@ -1417,7 +1423,8 @@ def Tosa_ReduceMinOp : Tosa_InferTensorTypeOp<"reduce_min"> {
14171423

14181424
let arguments = (ins
14191425
Tosa_Tensor:$input,
1420-
I32Attr:$axis
1426+
I32Attr:$axis,
1427+
DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
14211428
);
14221429

14231430
let results = (outs

mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,12 +205,20 @@ def Tosa_FloatAttr : Attr<CPred<"::llvm::isa<::mlir::FloatAttr>($_self)">,
205205
//===----------------------------------------------------------------------===//
206206
// Iterable attributes.
207207
//===----------------------------------------------------------------------===//
208+
// Defined in `section 3. Enumerations` of the TOSA specification.
209+
208210
// Supported regimes for tosa.resize.
209211
def Tosa_ResizeTypeAttr : StringBasedAttr<
210212
CPred<"::llvm::cast<StringAttr>($_self).getValue() == \"BILINEAR\" || " #
211213
"::llvm::cast<StringAttr>($_self).getValue() == \"NEAREST_NEIGHBOR\"">,
212214
"Supported resize/upsampling strategies">;
213215

216+
// Supported NaN propagation strategies.
217+
def Tosa_NanPropagationAttr : StringBasedAttr<
218+
CPred<"::llvm::cast<StringAttr>($_self).getValue() == \"PROPAGATE\" || " #
219+
"::llvm::cast<StringAttr>($_self).getValue() == \"IGNORE\"">,
220+
"Supported NaN propagation strategies">;
221+
214222
def Tosa_TensorTypeAttr : TypeAttrBase<"TensorType", "Tensor type attribute">;
215223

216224
// Tensor to buffer types.

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 68 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -339,33 +339,84 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
339339
}
340340
};
341341

342+
// Attempts the following transformation:
343+
//
344+
// For clamp bounds a, b, a', and b' (the integer and the floating-point bounds are each checked) such that [a, b] ∩ [a', b'] ≠ ∅ and input
345+
// tensor X the following identity holds:
346+
//
347+
// CLAMP(CLAMP(X, a, b), a', b') = CLAMP(X, max(a, a'), min(b, b'))
348+
//
349+
// subject to the following valid NaN propagation semantics:
350+
// --------------------------------------------
351+
// | OUTER CLAMP | INNER CLAMP | RESULT MODE |
352+
// |-------------|--------------|-------------|
353+
// | PROPAGATE | PROPAGATE | PROPAGATE |
354+
// | PROPAGATE | IGNORE | IGNORE |
355+
// | IGNORE | PROPAGATE | INVALID |
356+
// | IGNORE | IGNORE | IGNORE |
357+
// |------------------------------------------|
358+
342359
struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
343360
using OpRewritePattern<tosa::ClampOp>::OpRewritePattern;
344361

362+
// Helper structure to describe the range of a clamp operation.
363+
template <typename T>
364+
struct ClampRange {
365+
ClampRange(const T &start, const T &end) : start(start), end(end) {}
366+
T start;
367+
T end;
368+
369+
// Helper function to determine if two Clamp ranges intersect. The comparison is strict, so ranges that share only an endpoint are treated as non-intersecting.
370+
bool intersects(const ClampRange<T> &otherRange) {
371+
return start < otherRange.end && otherRange.start < end;
372+
}
373+
};
374+
345375
LogicalResult matchAndRewrite(tosa::ClampOp op,
346376
PatternRewriter &rewriter) const override {
347-
Value input = op.getInput();
348-
349-
Operation *definingOp = input.getDefiningOp();
350-
if (!definingOp)
377+
// Check the input to the CLAMP op is itself a CLAMP.
378+
auto clampOp =
379+
dyn_cast_if_present<tosa::ClampOp>(op.getInput().getDefiningOp());
380+
if (!clampOp)
351381
return failure();
352382

353-
if (tosa::ClampOp clampOp = dyn_cast<tosa::ClampOp>(definingOp)) {
354-
auto minFp = std::max(op.getMinFp(), clampOp.getMinFp()).convertToFloat();
355-
auto maxFp = std::min(op.getMaxFp(), clampOp.getMaxFp()).convertToFloat();
383+
// Check we have a valid NaN propagation combination.
384+
const auto opNanMode = op.getNanMode();
385+
const auto clampNanMode = clampOp.getNanMode();
386+
if (opNanMode == "IGNORE" && clampNanMode == "PROPAGATE")
387+
return failure();
356388

357-
auto minInt = std::max(op.getMinInt(), clampOp.getMinInt());
358-
auto maxInt = std::min(op.getMaxInt(), clampOp.getMaxInt());
389+
// Check we have intersecting ranges.
390+
const auto opMinInt = op.getMinInt();
391+
const auto opMaxInt = op.getMaxInt();
392+
const auto clampOpMinInt = clampOp.getMinInt();
393+
const auto clampOpMaxInt = clampOp.getMaxInt();
394+
ClampRange<std::int64_t> opRangeIntRange(opMinInt, opMaxInt);
395+
ClampRange<std::int64_t> clampRangeIntRange(clampOpMinInt, clampOpMaxInt);
396+
if (!opRangeIntRange.intersects(clampRangeIntRange))
397+
return failure();
359398

360-
rewriter.replaceOpWithNewOp<tosa::ClampOp>(
361-
op, op.getType(), clampOp.getInput(),
362-
rewriter.getI64IntegerAttr(minInt),
363-
rewriter.getI64IntegerAttr(maxInt), rewriter.getF32FloatAttr(minFp),
364-
rewriter.getF32FloatAttr(maxFp));
365-
return success();
366-
}
399+
const auto opMinFloat = op.getMinFp();
400+
const auto opMaxFloat = op.getMaxFp();
401+
const auto clampOpMinFloat = clampOp.getMinFp();
402+
const auto clampOpMaxFloat = clampOp.getMaxFp();
403+
ClampRange opRangeFloatRange(opMinFloat, opMaxFloat);
404+
ClampRange clampRangeFloatRange(clampOpMinFloat, clampOpMaxFloat);
405+
if (!opRangeFloatRange.intersects(clampRangeFloatRange))
406+
return failure();
367407

368-
return failure();
408+
// Run the transformation.
409+
const auto minFp = std::max(opMinFloat, clampOpMinFloat).convertToFloat();
410+
const auto maxFp = std::min(opMaxFloat, clampOpMaxFloat).convertToFloat();
411+
const auto minInt = std::max(opMinInt, clampOpMinInt);
412+
const auto maxInt = std::min(opMaxInt, clampOpMaxInt);
413+
rewriter.replaceOpWithNewOp<tosa::ClampOp>(
414+
op, op.getType(), clampOp.getInput(),
415+
rewriter.getI64IntegerAttr(minInt), rewriter.getI64IntegerAttr(maxInt),
416+
rewriter.getF32FloatAttr(minFp), rewriter.getF32FloatAttr(maxFp),
417+
rewriter.getStringAttr((opNanMode != clampNanMode) ? "IGNORE"
418+
: opNanMode));
419+
return success();
369420
}
370421
};
371422

mlir/test/Dialect/Tosa/canonicalize.mlir

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,58 @@ func.func @clamp_twice_is_single_clamp(%arg0: tensor<4xi8>) -> tensor<4xi8> {
138138

139139
// -----
140140

141+
// CHECK: @disjoint_clamp_twice_is_not_single_clamp(%[[INPUT:.*]]: tensor<4xi8>)
142+
func.func @disjoint_clamp_twice_is_not_single_clamp(%arg0: tensor<4xi8>) -> tensor<4xi8> {
143+
// CHECK: %[[CLAMP_1:.*]] = tosa.clamp %[[INPUT]] {max_fp = -5.000000e+00 : f32, max_int = -5 : i64, min_fp = -1.000000e+00 : f32, min_int = -10 : i64} : (tensor<4xi8>) -> tensor<4xi8>
144+
// CHECK-NEXT: tosa.clamp %[[CLAMP_1]] {max_fp = 5.000000e+00 : f32, max_int = 5 : i64, min_fp = 1.000000e+00 : f32, min_int = 1 : i64} : (tensor<4xi8>) -> tensor<4xi8>
145+
%0 = tosa.clamp %arg0 {max_fp = -5.0 : f32, max_int = -5 : i64, min_fp = -1.0 : f32, min_int = -10 : i64} : (tensor<4xi8>) -> tensor<4xi8>
146+
%1 = tosa.clamp %0 {max_fp = 5.0 : f32, max_int = 5 : i64, min_fp = 1.0 : f32, min_int = 1 : i64} : (tensor<4xi8>) -> tensor<4xi8>
147+
return %1 : tensor<4xi8>
148+
}
149+
150+
// -----
151+
152+
// CHECK-LABEL: @clamp_twice_with_nan_propagate_is_single_clamp
153+
func.func @clamp_twice_with_nan_propagate_is_single_clamp(%arg0: tensor<4xi8>) -> tensor<4xi8> {
154+
// CHECK: tosa.clamp %arg0 {max_fp = 3.000000e+00 : f32, max_int = 2 : i64, min_fp = -3.000000e+00 : f32, min_int = -2 : i64}
155+
%0 = tosa.clamp %arg0 {max_fp = 3.0 : f32, max_int = 4 : i64, min_fp = -5.0 : f32, min_int = -2 : i64, nan_mode = "PROPAGATE"} : (tensor<4xi8>) -> tensor<4xi8>
156+
%1 = tosa.clamp %0 {max_fp = 5.0 : f32, max_int = 2 : i64, min_fp = -3.0 : f32, min_int = -4 : i64, nan_mode = "PROPAGATE"} : (tensor<4xi8>) -> tensor<4xi8>
157+
return %1 : tensor<4xi8>
158+
}
159+
160+
// -----
161+
162+
// CHECK-LABEL: @clamp_twice_with_nan_ignore_is_single_clamp
163+
func.func @clamp_twice_with_nan_ignore_is_single_clamp(%arg0: tensor<4xi8>) -> tensor<4xi8> {
164+
// CHECK: tosa.clamp %arg0 {max_fp = 3.000000e+00 : f32, max_int = 2 : i64, min_fp = -3.000000e+00 : f32, min_int = -2 : i64, nan_mode = "IGNORE"}
165+
%0 = tosa.clamp %arg0 {max_fp = 3.0 : f32, max_int = 4 : i64, min_fp = -5.0 : f32, min_int = -2 : i64, nan_mode = "IGNORE"} : (tensor<4xi8>) -> tensor<4xi8>
166+
%1 = tosa.clamp %0 {max_fp = 5.0 : f32, max_int = 2 : i64, min_fp = -3.0 : f32, min_int = -4 : i64, nan_mode = "IGNORE"} : (tensor<4xi8>) -> tensor<4xi8>
167+
return %1 : tensor<4xi8>
168+
}
169+
170+
// -----
171+
172+
// CHECK-LABEL: @clamp_twice_with_nan_ignore_propagate_is_single_clamp
173+
func.func @clamp_twice_with_nan_ignore_propagate_is_single_clamp(%arg0: tensor<4xi8>) -> tensor<4xi8> {
174+
// CHECK: tosa.clamp %arg0 {max_fp = 3.000000e+00 : f32, max_int = 2 : i64, min_fp = -3.000000e+00 : f32, min_int = -2 : i64, nan_mode = "IGNORE"}
175+
%0 = tosa.clamp %arg0 {max_fp = 3.0 : f32, max_int = 4 : i64, min_fp = -5.0 : f32, min_int = -2 : i64, nan_mode = "IGNORE"} : (tensor<4xi8>) -> tensor<4xi8>
176+
%1 = tosa.clamp %0 {max_fp = 5.0 : f32, max_int = 2 : i64, min_fp = -3.0 : f32, min_int = -4 : i64, nan_mode = "PROPAGATE"} : (tensor<4xi8>) -> tensor<4xi8>
177+
return %1 : tensor<4xi8>
178+
}
179+
180+
// -----
181+
182+
// CHECK: @clamp_twice_with_nan_propagate_ignore_is_not_single_clamp(%[[INPUT:.*]]: tensor<4xi8>)
183+
func.func @clamp_twice_with_nan_propagate_ignore_is_not_single_clamp(%arg0: tensor<4xi8>) -> tensor<4xi8> {
184+
// CHECK: %[[CLAMP_1:.*]] = tosa.clamp %[[INPUT]] {max_fp = 3.000000e+00 : f32, max_int = 4 : i64, min_fp = -5.000000e+00 : f32, min_int = -2 : i64} : (tensor<4xi8>) -> tensor<4xi8>
185+
// CHECK-NEXT: tosa.clamp %[[CLAMP_1]] {max_fp = 5.000000e+00 : f32, max_int = 2 : i64, min_fp = -3.000000e+00 : f32, min_int = -4 : i64, nan_mode = "IGNORE"} : (tensor<4xi8>) -> tensor<4xi8>
186+
%0 = tosa.clamp %arg0 {max_fp = 3.0 : f32, max_int = 4 : i64, min_fp = -5.0 : f32, min_int = -2 : i64, nan_mode = "PROPAGATE"} : (tensor<4xi8>) -> tensor<4xi8>
187+
%1 = tosa.clamp %0 {max_fp = 5.0 : f32, max_int = 2 : i64, min_fp = -3.0 : f32, min_int = -4 : i64, nan_mode = "IGNORE"} : (tensor<4xi8>) -> tensor<4xi8>
188+
return %1 : tensor<4xi8>
189+
}
190+
191+
// -----
192+
141193
// CHECK-LABEL: @concat_fold
142194
func.func @concat_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
143195
// CHECK: return %arg0

mlir/test/Dialect/Tosa/ops.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,20 @@ func.func @test_clamp(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
180180
return %0 : tensor<13x21x3xf32>
181181
}
182182

183+
// -----
184+
// CHECK-LABEL: clamp_propagate
185+
func.func @test_clamp_propagate(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
186+
%0 = tosa.clamp %arg0 {min_fp = 0.0 : f32, max_fp = 1.0: f32, min_int = 0 : i64, max_int = 1 : i64, nan_mode = "PROPAGATE"} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
187+
return %0 : tensor<13x21x3xf32>
188+
}
189+
190+
// -----
191+
// CHECK-LABEL: clamp_ignore
192+
func.func @test_clamp_ignore(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
193+
%0 = tosa.clamp %arg0 {min_fp = 0.0 : f32, max_fp = 1.0: f32, min_int = 0 : i64, max_int = 1 : i64, nan_mode = "IGNORE"} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
194+
return %0 : tensor<13x21x3xf32>
195+
}
196+
183197
// -----
184198
// CHECK-LABEL: clamp_f16
185199
func.func @test_clamp_f16(%arg0: tensor<13x21x3xf16>) -> tensor<13x21x3xf16> {

0 commit comments

Comments
 (0)