Skip to content

Commit fbdd47b

Browse files
pashu123AlexisPerry
authored andcommitted
[mlir][linalg] Fix numerical issue with softmax (llvm#96090)
For more info: iree-org/iree#17670 (comment)
1 parent 839a61c commit fbdd47b

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2714,8 +2714,8 @@ FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
27142714
Value neutralForMaxFInit =
27152715
b.create<linalg::FillOp>(loc, Value{neutralForMaxF}, outputReduce)
27162716
.result();
2717-
Value max = reduce<arith::MaximumFOp>(b, loc, input, neutralForMaxFInit,
2718-
reductionDim);
2717+
Value max =
2718+
reduce<arith::MaxNumFOp>(b, loc, input, neutralForMaxFInit, reductionDim);
27192719

27202720
// Step 2: Subtract max from input and exponentiate.
27212721
Value numerator = buildSubAndExpOp(b, loc, input, max, output, reductionDim);

mlir/test/Dialect/Linalg/transform-op-decompose.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ func.func @softmax(%arg0: tensor<2x16x32xf32>, %dst: tensor<2x16x32xf32>) -> ten
215215
// CHECK: %[[D3:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]]], iterator_types = ["parallel",
216216
// CHECK-SAME: "parallel", "reduction"]} ins(%[[ARG0]] : tensor<2x16x32xf32>) outs(%[[D2]] : tensor<2x16xf32>) {
217217
// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
218-
// CHECK: %[[D8:.+]] = arith.maximumf %[[IN]], %[[OUT]] : f32
218+
// CHECK: %[[D8:.+]] = arith.maxnumf %[[IN]], %[[OUT]] : f32
219219
// CHECK: linalg.yield %[[D8]] : f32
220220
// CHECK: } -> tensor<2x16xf32>
221221
// CHECK: %[[D4:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP]]], iterator_types =

0 commit comments

Comments
 (0)