Skip to content

[mlir][vector][affine] Allow --affine-super-vectorize to vectorize maxnumf/minnumf (2) #138730

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

oowekyala
Copy link
Contributor

arith.minnumf and maxnumf were imcompletely supported in the vectorization flow of the affine dialect. There was a todo in the code that is now fixed. In particular, this change allows the decomposed linalg.softmax operator, which uses maxnumf, to be vectorized properly.

Note: was opened originally as #118981, I closed that one by accident

@llvmbot
Copy link
Member

llvmbot commented May 6, 2025

@llvm/pr-subscribers-mlir-vector
@llvm/pr-subscribers-mlir-affine

@llvm/pr-subscribers-mlir

Author: Clément Fournier (oowekyala)

Changes

arith.minnumf and maxnumf were imcompletely supported in the vectorization flow of the affine dialect. There was a todo in the code that is now fixed. In particular, this change allows the decomposed linalg.softmax operator, which uses maxnumf, to be vectorized properly.

Note: was opened originally as #118981, I closed that one by accident


Full diff: https://github.com/llvm/llvm-project/pull/138730.diff

4 Files Affected:

  • (modified) mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp (+2)
  • (modified) mlir/lib/Dialect/Affine/IR/AffineOps.cpp (+3-3)
  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+11-5)
  • (modified) mlir/test/Dialect/Affine/SuperVectorize/vectorize_reduction.mlir (+54)
diff --git a/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp b/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp
index 6f79665c2bb60..6e12d3604a262 100644
--- a/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp
@@ -65,6 +65,8 @@ static Value getSupportedReduction(AffineForOp forOp, unsigned pos,
               [](arith::MinimumFOp) { return arith::AtomicRMWKind::minimumf; })
           .Case(
               [](arith::MaximumFOp) { return arith::AtomicRMWKind::maximumf; })
+          .Case([](arith::MinNumFOp) { return arith::AtomicRMWKind::minnumf; })
+          .Case([](arith::MaxNumFOp) { return arith::AtomicRMWKind::maxnumf; })
           .Case([](arith::MinSIOp) { return arith::AtomicRMWKind::mins; })
           .Case([](arith::MaxSIOp) { return arith::AtomicRMWKind::maxs; })
           .Case([](arith::MinUIOp) { return arith::AtomicRMWKind::minu; })
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 65f85444e70db..0c9dc8ad4e6bf 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -3949,8 +3949,9 @@ static bool isResultTypeMatchAtomicRMWKind(Type resultType,
   case arith::AtomicRMWKind::muli:
     return isa<IntegerType>(resultType);
   case arith::AtomicRMWKind::maximumf:
-    return isa<FloatType>(resultType);
   case arith::AtomicRMWKind::minimumf:
+  case arith::AtomicRMWKind::maxnumf:
+  case arith::AtomicRMWKind::minnumf:
     return isa<FloatType>(resultType);
   case arith::AtomicRMWKind::maxs: {
     auto intType = llvm::dyn_cast<IntegerType>(resultType);
@@ -3972,9 +3973,8 @@ static bool isResultTypeMatchAtomicRMWKind(Type resultType,
     return isa<IntegerType>(resultType);
   case arith::AtomicRMWKind::andi:
     return isa<IntegerType>(resultType);
-  default:
-    return false;
   }
+  llvm_unreachable("exhaustive switch");
 }
 
 LogicalResult AffineParallelOp::verify() {
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index f9c7fb7799eb0..74f6ae94c5cf4 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -690,12 +690,18 @@ Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
   case arith::AtomicRMWKind::ori:
     return builder.create<vector::ReductionOp>(vector.getLoc(),
                                                CombiningKind::OR, vector);
-  // TODO: Add remaining reduction operations.
-  default:
-    (void)emitOptionalError(loc, "Reduction operation type not supported");
-    break;
+  case arith::AtomicRMWKind::maxnumf:
+    return builder.create<vector::ReductionOp>(vector.getLoc(),
+                                               CombiningKind::MAXNUMF, vector);
+  case arith::AtomicRMWKind::minnumf:
+    return builder.create<vector::ReductionOp>(vector.getLoc(),
+                                               CombiningKind::MINNUMF, vector);
+  case arith::AtomicRMWKind::assign:
+    (void)emitOptionalError(loc,
+                            "Reduction operation type not supported (assign)");
+    return nullptr;
   }
-  return nullptr;
+  llvm_unreachable("exhaustive switch");
 }
 
 std::optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() {
diff --git a/mlir/test/Dialect/Affine/SuperVectorize/vectorize_reduction.mlir b/mlir/test/Dialect/Affine/SuperVectorize/vectorize_reduction.mlir
index b616632a6fe24..00323d2853997 100644
--- a/mlir/test/Dialect/Affine/SuperVectorize/vectorize_reduction.mlir
+++ b/mlir/test/Dialect/Affine/SuperVectorize/vectorize_reduction.mlir
@@ -83,6 +83,60 @@ func.func @vecdim_reduction_maxf(%in: memref<256x512xf32>, %out: memref<256xf32>
 
 // -----
 
+func.func @vecdim_reduction_minnumf(%in: memref<256x512xf32>, %out: memref<256xf32>) {
+ %cst = arith.constant 0x7FC00000 : f32
+ affine.for %i = 0 to 256 {
+   %final_red = affine.for %j = 0 to 512 iter_args(%red_iter = %cst) -> (f32) {
+     %ld = affine.load %in[%i, %j] : memref<256x512xf32>
+     %min = arith.minnumf %red_iter, %ld : f32
+     affine.yield %min : f32
+   }
+   affine.store %final_red, %out[%i] : memref<256xf32>
+ }
+ return
+}
+
+// CHECK-LABEL: @vecdim_reduction_minnumf
+// CHECK:       affine.for %{{.*}} = 0 to 256 {
+// CHECK:         %[[vmax:.*]] = arith.constant dense<0x7FC00000> : vector<128xf32>
+// CHECK:         %[[vred:.*]] = affine.for %{{.*}} = 0 to 512 step 128 iter_args(%[[red_iter:.*]] = %[[vmax]]) -> (vector<128xf32>) {
+// CHECK:           %[[ld:.*]] = vector.transfer_read %{{.*}} : memref<256x512xf32>, vector<128xf32>
+// CHECK:           %[[min:.*]] = arith.minnumf %[[red_iter]], %[[ld]] : vector<128xf32>
+// CHECK:           affine.yield %[[min]] : vector<128xf32>
+// CHECK:         }
+// CHECK:         %[[final_min:.*]] = vector.reduction <minnumf>, %[[vred:.*]] : vector<128xf32> into f32
+// CHECK:         affine.store %[[final_min]], %{{.*}} : memref<256xf32>
+// CHECK:       }
+
+// -----
+
+func.func @vecdim_reduction_maxnumf(%in: memref<256x512xf32>, %out: memref<256xf32>) {
+ %cst = arith.constant 0xFFC00000 : f32
+ affine.for %i = 0 to 256 {
+   %final_red = affine.for %j = 0 to 512 iter_args(%red_iter = %cst) -> (f32) {
+     %ld = affine.load %in[%i, %j] : memref<256x512xf32>
+     %max = arith.maxnumf %red_iter, %ld : f32
+     affine.yield %max : f32
+   }
+   affine.store %final_red, %out[%i] : memref<256xf32>
+ }
+ return
+}
+
+// CHECK-LABEL: @vecdim_reduction_maxnumf
+// CHECK:       affine.for %{{.*}} = 0 to 256 {
+// CHECK:         %[[vmin:.*]] = arith.constant dense<0xFFC00000> : vector<128xf32>
+// CHECK:         %[[vred:.*]] = affine.for %{{.*}} = 0 to 512 step 128 iter_args(%[[red_iter:.*]] = %[[vmin]]) -> (vector<128xf32>) {
+// CHECK:           %[[ld:.*]] = vector.transfer_read %{{.*}} : memref<256x512xf32>, vector<128xf32>
+// CHECK:           %[[max:.*]] = arith.maxnumf %[[red_iter]], %[[ld]] : vector<128xf32>
+// CHECK:           affine.yield %[[max]] : vector<128xf32>
+// CHECK:         }
+// CHECK:         %[[final_max:.*]] = vector.reduction <maxnumf>, %[[vred:.*]] : vector<128xf32> into f32
+// CHECK:         affine.store %[[final_max]], %{{.*}} : memref<256xf32>
+// CHECK:       }
+
+// -----
+
 func.func @vecdim_reduction_minsi(%in: memref<256x512xi32>, %out: memref<256xi32>) {
  %cst = arith.constant 2147483647 : i32
  affine.for %i = 0 to 256 {

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants