Skip to content

Commit b323864

Browse files
Max191ScottTodd
authored andcommitted
Revert "[mlir][linalg] Fix numerical issue with softmax (llvm#96090)"
This reverts commit fa06668.
1 parent bbd4af5 commit b323864

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2719,8 +2719,8 @@ FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
27192719
Value neutralForMaxFInit =
27202720
b.create<linalg::FillOp>(loc, Value{neutralForMaxF}, outputReduce)
27212721
.result();
2722-
Value max =
2723-
reduce<arith::MaxNumFOp>(b, loc, input, neutralForMaxFInit, reductionDim);
2722+
Value max = reduce<arith::MaximumFOp>(b, loc, input, neutralForMaxFInit,
2723+
reductionDim);
27242724

27252725
// Step 2: Subtract max from input and exponentiate.
27262726
Value numerator = buildSubAndExpOp(b, loc, input, max, output, reductionDim);

mlir/test/Dialect/Linalg/transform-op-decompose.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ func.func @softmax(%arg0: tensor<2x16x32xf32>, %dst: tensor<2x16x32xf32>) -> ten
215215
// CHECK: %[[D3:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]]], iterator_types = ["parallel",
216216
// CHECK-SAME: "parallel", "reduction"]} ins(%[[ARG0]] : tensor<2x16x32xf32>) outs(%[[D2]] : tensor<2x16xf32>) {
217217
// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
218-
// CHECK: %[[D8:.+]] = arith.maxnumf %[[IN]], %[[OUT]] : f32
218+
// CHECK: %[[D8:.+]] = arith.maximumf %[[IN]], %[[OUT]] : f32
219219
// CHECK: linalg.yield %[[D8]] : f32
220220
// CHECK: } -> tensor<2x16xf32>
221221
// CHECK: %[[D4:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP]]], iterator_types =

0 commit comments

Comments
 (0)