[mlir] Use arith max or min ops instead of cmp + select #82178

Merged 2 commits on Feb 21, 2024

mlevesquedion
Contributor

I believe the semantics should be the same, but this saves 1 op and simplifies the code.

For example, the following two instructions:

%2 = arith.cmpi sgt, %0, %1
%3 = arith.select %2, %0, %1

are equivalent to:

%3 = arith.maxsi %0, %1
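
The equivalence claimed above can be spot-checked outside MLIR. The following is a minimal sketch in plain C++, assuming 8-bit signed values as a stand-in for the integer types the ops accept; the helper names (`cmpSelectMax`, `maxsiModel`, `formsAgreeOnAllInt8`) are hypothetical and are not part of the patch.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Models the two-op form: a signed greater-than compare feeding a select.
inline int8_t cmpSelectMax(int8_t a, int8_t b) {
  bool cond = a > b;    // cmp sgt %0, %1
  return cond ? a : b;  // select %2, %0, %1
}

// Models the single-op form: a signed integer maximum.
inline int8_t maxsiModel(int8_t a, int8_t b) { return std::max(a, b); }

// Returns true when both forms agree on every pair of 8-bit signed inputs.
inline bool formsAgreeOnAllInt8() {
  for (int a = INT8_MIN; a <= INT8_MAX; ++a) {
    for (int b = INT8_MIN; b <= INT8_MAX; ++b) {
      int8_t x = static_cast<int8_t>(a);
      int8_t y = static_cast<int8_t>(b);
      if (cmpSelectMax(x, y) != maxsiModel(x, y))
        return false;
    }
  }
  return true;
}
```

The check is exhaustive only for the 8-bit model; it illustrates the argument rather than proving it for every MLIR integer type.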

@llvmbot
Member

llvmbot commented Feb 18, 2024

@llvm/pr-subscribers-mlir-linalg
@llvm/pr-subscribers-mlir-scf

@llvm/pr-subscribers-mlir-tosa

Author: mlevesquedion (mlevesquedion)

Changes

Patch is 44.05 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/82178.diff

13 Files Affected:

  • (modified) mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp (+8-9)
  • (modified) mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp (+2-6)
  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+5-15)
  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp (+2-7)
  • (modified) mlir/lib/Dialect/SCF/Utils/Utils.cpp (+2-4)
  • (modified) mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp (+2-7)
  • (modified) mlir/test/Conversion/AffineToStandard/lower-affine.mlir (+16-26)
  • (modified) mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir (+1-2)
  • (modified) mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir (+6-12)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir (+15-29)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir (+30-53)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+21-40)
  • (modified) mlir/test/Transforms/parametric-tiling.mlir (+4-8)
diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
index 15ad6d8cdf629d..98cdfc63252711 100644
--- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
+++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
@@ -34,12 +34,7 @@ using namespace mlir::affine;
 using namespace mlir::vector;
 
 /// Given a range of values, emit the code that reduces them with "min" or "max"
-/// depending on the provided comparison predicate.  The predicate defines which
-/// comparison to perform, "lt" for "min", "gt" for "max" and is used for the
-/// `cmpi` operation followed by the `select` operation:
-///
-///   %cond   = arith.cmpi "predicate" %v0, %v1
-///   %result = select %cond, %v0, %v1
+/// depending on the provided comparison predicate, sgt for max and slt for min.
 ///
 /// Multiple values are scanned in a linear sequence.  This creates a data
 /// dependences that wouldn't exist in a tree reduction, but is easier to
@@ -48,13 +43,17 @@ static Value buildMinMaxReductionSeq(Location loc,
                                      arith::CmpIPredicate predicate,
                                      ValueRange values, OpBuilder &builder) {
   assert(!values.empty() && "empty min/max chain");
+  assert(predicate == arith::CmpIPredicate::sgt ||
+         predicate == arith::CmpIPredicate::slt);
 
   auto valueIt = values.begin();
   Value value = *valueIt++;
   for (; valueIt != values.end(); ++valueIt) {
-    auto cmpOp = builder.create<arith::CmpIOp>(loc, predicate, value, *valueIt);
-    value = builder.create<arith::SelectOp>(loc, cmpOp.getResult(), value,
-                                            *valueIt);
+    if (predicate == arith::CmpIPredicate::sgt) {
+      value = builder.create<arith::MaxSIOp>(loc, value, *valueIt);
+    } else {
+      value = builder.create<arith::MinSIOp>(loc, value, *valueIt);
+    }
   }
 
   return value;
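
The linear reduction in `buildMinMaxReductionSeq` can be modeled without MLIR to see both the result and the op count. This is a rough sketch in plain C++, assuming `int64_t` values in place of `index`; `Predicate`, `Reduction`, and both functions are hypothetical stand-ins, not the builder API.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Stand-in for the two predicates the patched function accepts.
enum class Predicate { sgt, slt };

struct Reduction {
  int64_t value;   // reduced result
  int opsEmitted;  // number of modeled ops created by the loop
};

// Old form: one compare plus one select per additional value.
inline Reduction reduceWithCmpSelect(Predicate p,
                                     const std::vector<int64_t> &values) {
  assert(!values.empty() && "empty min/max chain");
  Reduction r{values.front(), 0};
  for (std::size_t i = 1; i < values.size(); ++i) {
    bool cond = (p == Predicate::sgt) ? r.value > values[i]
                                      : r.value < values[i];
    r.value = cond ? r.value : values[i];
    r.opsEmitted += 2;
  }
  return r;
}

// New form: a single max or min per additional value.
inline Reduction reduceWithMinMax(Predicate p,
                                  const std::vector<int64_t> &values) {
  assert(!values.empty() && "empty min/max chain");
  Reduction r{values.front(), 0};
  for (std::size_t i = 1; i < values.size(); ++i) {
    r.value = (p == Predicate::sgt) ? std::max(r.value, values[i])
                                    : std::min(r.value, values[i]);
    r.opsEmitted += 1;
  }
  return r;
}
```

For seven inputs the model emits 12 ops in the old form and 6 in the new one, which matches the shape of the `min_reduction_tree` test update in `lower-affine.mlir`.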
diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
index a3e51aeed0735a..de649f730ee9d7 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -147,9 +147,7 @@ LogicalResult BroadcastOpConverter::matchAndRewrite(
   // Find the maximum rank
   Value maxRank = ranks.front();
   for (Value v : llvm::drop_begin(ranks, 1)) {
-    Value rankIsGreater =
-        lb.create<arith::CmpIOp>(arith::CmpIPredicate::ugt, v, maxRank);
-    maxRank = lb.create<arith::SelectOp>(rankIsGreater, v, maxRank);
+    maxRank = lb.create<arith::MaxUIOp>(v, maxRank);
   }
 
   // Calculate the difference of ranks and the maximum rank for later offsets.
@@ -262,9 +260,7 @@ LogicalResult IsBroadcastableOpConverter::matchAndRewrite(
   // Find the maximum rank
   Value maxRank = ranks.front();
   for (Value v : llvm::drop_begin(ranks, 1)) {
-    Value rankIsGreater =
-        lb.create<arith::CmpIOp>(arith::CmpIPredicate::ugt, v, maxRank);
-    maxRank = lb.create<arith::SelectOp>(rankIsGreater, v, maxRank);
+    maxRank = lb.create<arith::MaxUIOp>(v, maxRank);
   }
 
   // Calculate the difference of ranks and the maximum rank for later offsets.
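
The rank computation used an unsigned predicate (`ugt`), so the replacement is `MaxUIOp` rather than `MaxSIOp`. The sketch below, in plain C++ with 8-bit values as an assumed stand-in for `index`, illustrates why the signedness has to be carried over; all function names are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Models the old form: unsigned greater-than compare feeding a select.
inline uint8_t cmpSelectUMax(uint8_t v, uint8_t maxRank) {
  bool rankIsGreater = v > maxRank;
  return rankIsGreater ? v : maxRank;
}

// Models an unsigned integer maximum.
inline uint8_t maxuiModel(uint8_t a, uint8_t b) { return std::max(a, b); }

// Models a signed maximum applied to the same bit patterns.
inline uint8_t maxsiOnBits(uint8_t a, uint8_t b) {
  int8_t sa = static_cast<int8_t>(a);
  int8_t sb = static_cast<int8_t>(b);
  return static_cast<uint8_t>(std::max(sa, sb));
}

// Returns true when the ugt + select form matches the unsigned max everywhere.
inline bool ugtSelectMatchesMaxui() {
  for (int a = 0; a <= UINT8_MAX; ++a)
    for (int b = 0; b <= UINT8_MAX; ++b)
      if (cmpSelectUMax(static_cast<uint8_t>(a), static_cast<uint8_t>(b)) !=
          maxuiModel(static_cast<uint8_t>(a), static_cast<uint8_t>(b)))
        return false;
  return true;
}
```

In this model the signed and unsigned maxima disagree once the top bit is set (for example on the bit patterns `0xF0` and `0x01`), which is the reason for keeping the unsigned op here.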
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index f4f6dadfb37166..7eb32ebe3228fb 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -61,10 +61,8 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
   if (isa<tosa::AbsOp>(op) && isa<IntegerType>(elementTy)) {
     auto zero = rewriter.create<arith::ConstantOp>(
         loc, rewriter.getZeroAttr(elementTy));
-    auto cmp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt,
-                                              args[0], zero);
     auto neg = rewriter.create<arith::SubIOp>(loc, zero, args[0]);
-    return rewriter.create<arith::SelectOp>(loc, cmp, args[0], neg);
+    return rewriter.create<arith::MaxSIOp>(loc, args[0], neg);
   }
 
   // tosa::AddOp
@@ -348,9 +346,7 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
   }
 
   if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
-    auto predicate = rewriter.create<arith::CmpIOp>(
-        loc, arith::CmpIPredicate::sgt, args[0], args[1]);
-    return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
+    return rewriter.create<arith::MaxSIOp>(loc, args[0], args[1]);
   }
 
   // tosa::MinimumOp
@@ -359,9 +355,7 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
   }
 
   if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
-    auto predicate = rewriter.create<arith::CmpIOp>(
-        loc, arith::CmpIPredicate::slt, args[0], args[1]);
-    return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
+    return rewriter.create<arith::MinSIOp>(loc, args[0], args[1]);
   }
 
   // tosa::CeilOp
@@ -1000,9 +994,7 @@ static Value createLinalgBodyCalculationForReduceOp(Operation *op,
   }
 
   if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy)) {
-    auto predicate = rewriter.create<arith::CmpIOp>(
-        loc, arith::CmpIPredicate::slt, args[0], args[1]);
-    return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
+    return rewriter.create<arith::MinSIOp>(loc, args[0], args[1]);
   }
 
   if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy)) {
@@ -1010,9 +1002,7 @@ static Value createLinalgBodyCalculationForReduceOp(Operation *op,
   }
 
   if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy)) {
-    auto predicate = rewriter.create<arith::CmpIOp>(
-        loc, arith::CmpIPredicate::sgt, args[0], args[1]);
-    return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
+    return rewriter.create<arith::MaxSIOp>(loc, args[0], args[1]);
   }
 
   if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
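
The integer `tosa::AbsOp` lowering is the least obvious rewrite in this file, since the old select compared against zero while the new op compares `args[0]` with its negation. A minimal sketch in plain C++, assuming 8-bit two's-complement values with wrapping subtraction; `wrappingNeg`, `absOld`, `absNew`, and `absFormsAgreeOnAllInt8` are hypothetical names.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Models `0 - x` with two's-complement wraparound.
inline int8_t wrappingNeg(int8_t x) {
  return static_cast<int8_t>(
      static_cast<uint8_t>(0u - static_cast<uint8_t>(x)));
}

// Old form: compare against zero, then select between x and its negation.
inline int8_t absOld(int8_t x) {
  bool cmp = x > 0;
  int8_t neg = wrappingNeg(x);
  return cmp ? x : neg;
}

// New form: signed maximum of x and its negation.
inline int8_t absNew(int8_t x) { return std::max(x, wrappingNeg(x)); }

// Returns true when both forms agree on every 8-bit signed input.
inline bool absFormsAgreeOnAllInt8() {
  for (int v = INT8_MIN; v <= INT8_MAX; ++v)
    if (absOld(static_cast<int8_t>(v)) != absNew(static_cast<int8_t>(v)))
      return false;
  return true;
}
```

In this model both forms also agree on the minimum value, where the negation wraps back to itself.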
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 607a603cca810f..3f39cbf03a9a80 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -845,10 +845,7 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
             auto padVal = rewriter.create<arith::ConstantIndexOp>(loc, pad);
             Value dpos = rewriter.create<arith::SubIOp>(loc, pos, padVal);
 
-            Value cmp = rewriter.create<arith::CmpIOp>(
-                loc, arith::CmpIPredicate::slt, dpos, zero);
-            Value offset =
-                rewriter.create<arith::SelectOp>(loc, cmp, dpos, zero);
+            Value offset = rewriter.create<arith::MinSIOp>(loc, dpos, zero);
             return rewriter.create<arith::AddIOp>(loc, valid, offset)
                 ->getResult(0);
           };
@@ -868,9 +865,7 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
             // Determine how much padding was included.
             val = padFn(val, left, pad[i * 2]);
             val = padFn(val, right, pad[i * 2 + 1]);
-            Value cmp = rewriter.create<arith::CmpIOp>(
-                loc, arith::CmpIPredicate::slt, val, one);
-            return rewriter.create<arith::SelectOp>(loc, cmp, one, val);
+            return rewriter.create<arith::MaxSIOp>(loc, one, val);
           };
 
           // Compute the indices from either end.
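
In the `AvgPool2dConverter` hunks the select operands were not in the same order as the compare operands, so the second rewrite turns an `slt` compare into a max. The sketch below models both hunks in plain C++ with `int64_t` in place of `index`; the function names are hypothetical and only mirror the arithmetic shown in the diff.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Models padFn: valid + min(pos - pad, 0).
inline int64_t padFnModel(int64_t valid, int64_t pos, int64_t pad) {
  int64_t dpos = pos - pad;
  int64_t offset = std::min<int64_t>(dpos, 0);
  return valid + offset;
}

// Old form of the final step: cmp slt val, one; select cmp, one, val.
inline int64_t atLeastOneOld(int64_t val) {
  bool cmp = val < 1;
  return cmp ? 1 : val;
}

// New form of the final step: maxsi one, val.
inline int64_t atLeastOneNew(int64_t val) {
  return std::max<int64_t>(1, val);
}
```

Selecting `one` when `val < one` is a maximum, not a minimum, which is why `MaxSIOp` replaces an `slt` compare in that hunk.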
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 536c02feca1bd5..502d7e197a6f6b 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -791,10 +791,8 @@ static Loops stripmineSink(scf::ForOp forOp, Value factor,
     // Insert newForOp before the terminator of `t`.
     auto b = OpBuilder::atBlockTerminator((t.getBody()));
     Value stepped = b.create<arith::AddIOp>(t.getLoc(), iv, forOp.getStep());
-    Value less = b.create<arith::CmpIOp>(t.getLoc(), arith::CmpIPredicate::slt,
-                                         forOp.getUpperBound(), stepped);
-    Value ub = b.create<arith::SelectOp>(t.getLoc(), less,
-                                         forOp.getUpperBound(), stepped);
+    Value ub =
+        b.create<arith::MinSIOp>(t.getLoc(), forOp.getUpperBound(), stepped);
 
     // Splice [begin, begin + nOps - 1) into `newForOp` and replace uses.
     auto newForOp = b.create<scf::ForOp>(t.getLoc(), iv, ub, originalStep);
diff --git a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
index ee428b201d0073..4fc97115064f33 100644
--- a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
+++ b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
@@ -39,13 +39,8 @@ Value mlir::tosa::clampFloatHelper(Location loc, Value arg, Value min,
 
 Value mlir::tosa::clampIntHelper(Location loc, Value arg, Value min, Value max,
                                  OpBuilder &rewriter) {
-  auto smallerThanMin =
-      rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, arg, min);
-  auto minOrArg =
-      rewriter.create<arith::SelectOp>(loc, smallerThanMin, min, arg);
-  auto largerThanMax =
-      rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, max, arg);
-  return rewriter.create<arith::SelectOp>(loc, largerThanMax, max, minOrArg);
+  auto minOrArg = rewriter.create<arith::MaxSIOp>(loc, min, arg);
+  return rewriter.create<arith::MinSIOp>(loc, max, minOrArg);
 }
 
 bool mlir::tosa::validIntegerRange(IntegerType ty, int64_t value) {
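
The `clampIntHelper` rewrite is worth a closer look, because the old second compare used `arg` while the new `MinSIOp` uses `minOrArg`. The sketch below models both versions in plain C++ on 8-bit signed values (an assumption made so the check can be exhaustive); `clampOld`, `clampNew`, and `clampFormsAgreeWhenOrdered` are hypothetical names.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Old form: two compare + select pairs, both comparing against arg.
inline int8_t clampOld(int8_t arg, int8_t lo, int8_t hi) {
  bool smallerThanMin = arg < lo;
  int8_t minOrArg = smallerThanMin ? lo : arg;
  bool largerThanMax = hi < arg;
  return largerThanMax ? hi : minOrArg;
}

// New form: maxsi followed by minsi.
inline int8_t clampNew(int8_t arg, int8_t lo, int8_t hi) {
  int8_t minOrArg = std::max(lo, arg);
  return std::min(hi, minOrArg);
}

// Returns true when both forms agree for every arg and every lo <= hi.
inline bool clampFormsAgreeWhenOrdered() {
  for (int lo = INT8_MIN; lo <= INT8_MAX; ++lo)
    for (int hi = lo; hi <= INT8_MAX; ++hi)
      for (int arg = INT8_MIN; arg <= INT8_MAX; ++arg)
        if (clampOld(static_cast<int8_t>(arg), static_cast<int8_t>(lo),
                     static_cast<int8_t>(hi)) !=
            clampNew(static_cast<int8_t>(arg), static_cast<int8_t>(lo),
                     static_cast<int8_t>(hi)))
          return false;
  return true;
}
```

In this model the two forms agree whenever `min <= max`. With inverted bounds they can differ (for example `arg = 2`, `min = 5`, `max = 3` gives 5 in the old form and 3 in the new one), so the equivalence appears to rely on callers passing ordered bounds.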
diff --git a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
index 92608135d24b08..00d7b6b8d65f67 100644
--- a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
+++ b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
@@ -371,16 +371,14 @@ func.func @if_for() {
 // CHECK-NEXT:   %[[c1:.*]] = arith.constant 1 : index
 // CHECK-NEXT:   for %{{.*}} = %[[c0]] to %[[c42]] step %[[c1]] {
 // CHECK-NEXT:     %[[cm1:.*]] = arith.constant -1 : index
-// CHECK-NEXT:     %[[a:.*]] = arith.muli %{{.*}}, %[[cm1]] : index
-// CHECK-NEXT:     %[[b:.*]] = arith.addi %[[a]], %{{.*}} : index
-// CHECK-NEXT:     %[[c:.*]] = arith.cmpi sgt, %{{.*}}, %[[b]] : index
-// CHECK-NEXT:     %[[d:.*]] = arith.select %[[c]], %{{.*}}, %[[b]] : index
+// CHECK-NEXT:     %[[mul0:.*]] = arith.muli %{{.*}}, %[[cm1]] : index
+// CHECK-NEXT:     %[[add0:.*]] = arith.addi %[[mul0]], %{{.*}} : index
+// CHECK-NEXT:     %[[max:.*]] = arith.maxsi %{{.*}}, %[[add0]] : index
 // CHECK-NEXT:     %[[c10:.*]] = arith.constant 10 : index
-// CHECK-NEXT:     %[[e:.*]] = arith.addi %{{.*}}, %[[c10]] : index
-// CHECK-NEXT:     %[[f:.*]] = arith.cmpi slt, %{{.*}}, %[[e]] : index
-// CHECK-NEXT:     %[[g:.*]] = arith.select %[[f]], %{{.*}}, %[[e]] : index
+// CHECK-NEXT:     %[[add1:.*]] = arith.addi %{{.*}}, %[[c10]] : index
+// CHECK-NEXT:     %[[min:.*]] = arith.minsi %{{.*}}, %[[add1]] : index
 // CHECK-NEXT:     %[[c1_0:.*]] = arith.constant 1 : index
-// CHECK-NEXT:     for %{{.*}} = %[[d]] to %[[g]] step %[[c1_0]] {
+// CHECK-NEXT:     for %{{.*}} = %[[max]] to %[[min]] step %[[c1_0]] {
 // CHECK-NEXT:       call @body2(%{{.*}}, %{{.*}}) : (index, index) -> ()
 // CHECK-NEXT:     }
 // CHECK-NEXT:   }
@@ -397,25 +395,19 @@ func.func @loop_min_max(%N : index) {
 
 #map_7_values = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4, d5, d6)>
 
-// Check that the "min" (cmpi slt + select) reduction sequence is emitted
+// Check that the "min" reduction sequence is emitted
 // correctly for an affine map with 7 results.
 
 // CHECK-LABEL: func @min_reduction_tree
 // CHECK-NEXT:   %[[c0:.*]] = arith.constant 0 : index
-// CHECK-NEXT:   %[[c01:.+]] = arith.cmpi slt, %{{.*}}, %{{.*}} : index
-// CHECK-NEXT:   %[[r01:.+]] = arith.select %[[c01]], %{{.*}}, %{{.*}} : index
-// CHECK-NEXT:   %[[c012:.+]] = arith.cmpi slt, %[[r01]], %{{.*}} : index
-// CHECK-NEXT:   %[[r012:.+]] = arith.select %[[c012]], %[[r01]], %{{.*}} : index
-// CHECK-NEXT:   %[[c0123:.+]] = arith.cmpi slt, %[[r012]], %{{.*}} : index
-// CHECK-NEXT:   %[[r0123:.+]] = arith.select %[[c0123]], %[[r012]], %{{.*}} : index
-// CHECK-NEXT:   %[[c01234:.+]] = arith.cmpi slt, %[[r0123]], %{{.*}} : index
-// CHECK-NEXT:   %[[r01234:.+]] = arith.select %[[c01234]], %[[r0123]], %{{.*}} : index
-// CHECK-NEXT:   %[[c012345:.+]] = arith.cmpi slt, %[[r01234]], %{{.*}} : index
-// CHECK-NEXT:   %[[r012345:.+]] = arith.select %[[c012345]], %[[r01234]], %{{.*}} : index
-// CHECK-NEXT:   %[[c0123456:.+]] = arith.cmpi slt, %[[r012345]], %{{.*}} : index
-// CHECK-NEXT:   %[[r0123456:.+]] = arith.select %[[c0123456]], %[[r012345]], %{{.*}} : index
+// CHECK-NEXT:   %[[min:.+]] = arith.minsi %{{.*}}, %{{.*}} : index
+// CHECK-NEXT:   %[[min_0:.+]] = arith.minsi %[[min]], %{{.*}} : index
+// CHECK-NEXT:   %[[min_1:.+]] = arith.minsi %[[min_0]], %{{.*}} : index
+// CHECK-NEXT:   %[[min_2:.+]] = arith.minsi %[[min_1]], %{{.*}} : index
+// CHECK-NEXT:   %[[min_3:.+]] = arith.minsi %[[min_2]], %{{.*}} : index
+// CHECK-NEXT:   %[[min_4:.+]] = arith.minsi %[[min_3]], %{{.*}} : index
 // CHECK-NEXT:   %[[c1:.*]] = arith.constant 1 : index
-// CHECK-NEXT:   for %{{.*}} = %[[c0]] to %[[r0123456]] step %[[c1]] {
+// CHECK-NEXT:   for %{{.*}} = %[[c0]] to %[[min_4]] step %[[c1]] {
 // CHECK-NEXT:     call @body(%{{.*}}) : (index) -> ()
 // CHECK-NEXT:   }
 // CHECK-NEXT:   return
@@ -690,8 +682,7 @@ func.func @affine_min(%arg0: index, %arg1: index) -> index{
   // CHECK: %[[Cm2:.*]] = arith.constant -1
   // CHECK: %[[neg2:.*]] = arith.muli %[[ARG0]], %[[Cm2:.*]]
   // CHECK: %[[second:.*]] = arith.addi %[[ARG1]], %[[neg2]]
-  // CHECK: %[[cmp:.*]] = arith.cmpi slt, %[[first]], %[[second]]
-  // CHECK: arith.select %[[cmp]], %[[first]], %[[second]]
+  // CHECK: arith.minsi %[[first]], %[[second]]
   %0 = affine.min affine_map<(d0,d1) -> (d0 - d1, d1 - d0)>(%arg0, %arg1)
   return %0 : index
 }
@@ -705,8 +696,7 @@ func.func @affine_max(%arg0: index, %arg1: index) -> index{
   // CHECK: %[[Cm2:.*]] = arith.constant -1
   // CHECK: %[[neg2:.*]] = arith.muli %[[ARG0]], %[[Cm2:.*]]
   // CHECK: %[[second:.*]] = arith.addi %[[ARG1]], %[[neg2]]
-  // CHECK: %[[cmp:.*]] = arith.cmpi sgt, %[[first]], %[[second]]
-  // CHECK: arith.select %[[cmp]], %[[first]], %[[second]]
+  // CHECK: arith.maxsi %[[first]], %[[second]]
   %0 = affine.max affine_map<(d0,d1) -> (d0 - d1, d1 - d0)>(%arg0, %arg1)
   return %0 : index
 }
diff --git a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
index eb45112b117c0d..87d613986c7c3f 100644
--- a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
@@ -554,8 +554,7 @@ func.func @collapse_shape_dynamic(%arg0 : memref<1x2x?xf32>) -> memref<1x?xf32>
 // CHECK:           %[[SIZE1_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[FINAL_SIZE1]] : i64 to index
 // CHECK:           %[[FINAL_SIZE1:.*]] = builtin.unrealized_conversion_cast %[[SIZE1_TO_IDX]] : index to i64
 // CHECK:           %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64
-// CHECK:           %[[IS_MIN_STRIDE1:.*]] = llvm.icmp "slt" %[[STRIDE1]], %[[C1]] : i64
-// CHECK:           %[[MIN_STRIDE1:.*]] = llvm.select %[[IS_MIN_STRIDE1]], %[[STRIDE1]], %[[C1]] : i1, i64
+// CHECK:           %[[MIN_STRIDE1:.*]] = llvm.intr.smin(%[[STRIDE1]], %[[C1]]) : (i64, i64) -> i64
 // CHECK:           %[[MIN_STRIDE1_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[MIN_STRIDE1]] : i64 to index
 // CHECK:           %[[MIN_STRIDE1:.*]] = builtin.unrealized_conversion_cast %[[MIN_STRIDE1_TO_IDX]] : index to i64
 // CHECK:           %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
index cb3af973daee20..3b73c513b7955f 100644
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -377,10 +377,8 @@ func.func @try_is_broadcastable (%a : tensor<2xindex>, %b : tensor<3xindex>, %c
 // CHECK:           %[[RANK0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<2xindex>
 // CHECK:           %[[RANK1:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<3xindex>
 // CHECK:           %[[RANK2:.*]] = tensor.dim %[[ARG2]], %[[C0]] : tensor<2xindex>
-// CHECK:           %[[CMP0:.*]] = arith.cmpi ugt, %[[RANK1]], %[[RANK0]] : index
-// CHECK:           %[[LARGER_DIM:.*]] = arith.select %[[CMP0]], %[[RANK1]], %[[RANK0]] : index
-// CHECK:           %[[CMP1:.*]] = arith.cmpi ugt, %[[RANK2]], %[[LARGER_DIM]] : index
-// CHECK:           %[[MAX_RANK:.*]] = arith.select %[[CMP1]], %[[RANK2]], %[[LARGER_DIM]] : index
+// CHECK:           %[[MAX0:.*]] = arith.maxui %[[RANK1]], %[[RANK0]] : index
+// CHECK:           %[[MAX_RANK:.*]] = arith.maxui %[[RANK2]], %[[MAX0]] : index
 // CHECK:           %[[DIM_DIFF0:.*]] = arith.subi %[[MAX_RANK]], %[[RANK0]] : index
 // CHECK:           %[[DIM_DIFF1:.*]] = arith.subi %[[MAX_RANK]], %[[RANK1]] : index
 // CHECK:           %[[DIM_DIFF2:.*]] = arith.subi %[[MAX_RANK]], %[[RANK2]] : index
@@ -467,10 +465,8 @@ func.func @broadcast(%a : tensor<2xindex>, %b : tensor<3xindex>, %c : tensor<2xi
 // CHECK:           %[[RANK0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<2xindex>
 // CHECK:           %[[RANK1:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<3xindex>
 // CHECK:           %[[RANK2:.*]] = tensor.dim %[[ARG2]], %[[C0]] : tensor<2xindex>
-// CHECK:           %[[CMP0:.*]] = arith.cmpi ugt, %[[RANK1]], %[[RANK0]] : index
-// CHECK:           %[[LARGER_DIM:.*]] = arith.select %[[CMP0]], %[[RANK1]], %[[RANK0]] : index
-// CHECK:           %[[CMP1:.*]] = arith.cmpi ugt, %[[RANK2]], %[[LARGER_DIM]] : index
-// CHECK:           %[[MAX_RANK:.*]] = arith.select %[[CMP1]], %[[RANK2]], %[[LARGER_DIM]] : index
+// CHECK:           %[[MAX0:.*]] = arith.maxui %[[RANK1]], %[[RANK0]] : index
+// CHECK:           %[[MAX_RANK:.*]] = arith.maxui %[[RANK2]], %[[MAX0]] : index
 // CHECK:           %[[DIM_DIFF0:.*]] = arith.subi %[[MAX_RANK]], %[[RANK0]] : index
 // CHECK:           %[[DIM_DIFF1:.*]] = arith.subi %[[MAX_RANK]], %[[RANK1]] : index
 // CHECK:           %[[DIM_DIFF2:.*]] = arith.subi %[[MAX_RANK]], %[[RANK2]] : index
@@ -559,10 +555,8 @@ func.func @broadcast_3_shapes_different_extents(%a : tensor<2xindex>,
 // CHECK:           %[[RANK0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<2xindex>
 // CHECK:           %[[RANK1:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<3xindex>
 // CHECK:           %[[RANK2:.*]] = tensor.dim %[[ARG2]], %[[C0]] : tensor<2xindex>
-// CHECK:           %[[CMP0:.*]] = arith.cmpi ugt, %[[RANK1]], %[[RANK0]] : index
-// CHECK:           %[[LARGER_DIM:.*]] = arith.select %[[CMP0]], %[[RANK1]], %[[RANK0]] : index
-// CHECK:           %[[CMP1:.*]] = arith.cmpi ugt, %[[RANK2]], %[[LARGER_DIM]] : index
-// CHECK:           %[[MAX_RANK:.*]] = arith.select %[[CMP1]], %[[RANK2]], %[[LARGER_DIM]] : index
+// CHECK:           %[[MAX0:.*]] = arith.maxui %[[RANK1]], %[[RANK0]] : index
+// CHECK:           %[[MAX_RANK:.*]] = arith.maxui %[[RANK2]], %[[MAX0]] : index
 // CHECK:           %[[DIM_DIFF0:.*]] = arith.subi %[[MAX_RANK]], %[[RANK0]] : index
 // CHECK:           %[[DIM_DIFF1:.*]] = arith.subi %[[MAX_RANK]], %[[RANK1]] : index
 // CHECK:           %[[DIM_DIFF2:.*]] = arith.subi %[[MAX_RANK]], %[[RANK2]] : index
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index 51ebcad0797807..e64903671e599f 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -263,16 +263,13 @@ func.func @avg_pool_f32(%arg0: tensor<1x6x34x62xf32>) -> (tensor<1x5x33x62xf32>)
   // CHECK:   %[[SRC_END:.+]] = arith.muli %[[END]], %[[STRIDE]]
   // CHECK:   %[[PAD_START:.+]] = arith.constant 1
   // CHECK:   %[[START_SUB:.+]] = arith.subi %[[SRC_START]], %[[PAD_START]]
-  // CHECK:   %[[CMP:.+]] = arith.cmpi slt, %[[START_SUB]], %[[ZERO]]
-  // CHECK:   %...
[truncated]

@llvmbot
Member

llvmbot commented Feb 18, 2024

@llvm/pr-subscribers-mlir

diff --git a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
index 92608135d24b08..00d7b6b8d65f67 100644
--- a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
+++ b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
@@ -371,16 +371,14 @@ func.func @if_for() {
 // CHECK-NEXT:   %[[c1:.*]] = arith.constant 1 : index
 // CHECK-NEXT:   for %{{.*}} = %[[c0]] to %[[c42]] step %[[c1]] {
 // CHECK-NEXT:     %[[cm1:.*]] = arith.constant -1 : index
-// CHECK-NEXT:     %[[a:.*]] = arith.muli %{{.*}}, %[[cm1]] : index
-// CHECK-NEXT:     %[[b:.*]] = arith.addi %[[a]], %{{.*}} : index
-// CHECK-NEXT:     %[[c:.*]] = arith.cmpi sgt, %{{.*}}, %[[b]] : index
-// CHECK-NEXT:     %[[d:.*]] = arith.select %[[c]], %{{.*}}, %[[b]] : index
+// CHECK-NEXT:     %[[mul0:.*]] = arith.muli %{{.*}}, %[[cm1]] : index
+// CHECK-NEXT:     %[[add0:.*]] = arith.addi %[[mul0]], %{{.*}} : index
+// CHECK-NEXT:     %[[max:.*]] = arith.maxsi %{{.*}}, %[[add0]] : index
 // CHECK-NEXT:     %[[c10:.*]] = arith.constant 10 : index
-// CHECK-NEXT:     %[[e:.*]] = arith.addi %{{.*}}, %[[c10]] : index
-// CHECK-NEXT:     %[[f:.*]] = arith.cmpi slt, %{{.*}}, %[[e]] : index
-// CHECK-NEXT:     %[[g:.*]] = arith.select %[[f]], %{{.*}}, %[[e]] : index
+// CHECK-NEXT:     %[[add1:.*]] = arith.addi %{{.*}}, %[[c10]] : index
+// CHECK-NEXT:     %[[min:.*]] = arith.minsi %{{.*}}, %[[add1]] : index
 // CHECK-NEXT:     %[[c1_0:.*]] = arith.constant 1 : index
-// CHECK-NEXT:     for %{{.*}} = %[[d]] to %[[g]] step %[[c1_0]] {
+// CHECK-NEXT:     for %{{.*}} = %[[max]] to %[[min]] step %[[c1_0]] {
 // CHECK-NEXT:       call @body2(%{{.*}}, %{{.*}}) : (index, index) -> ()
 // CHECK-NEXT:     }
 // CHECK-NEXT:   }
@@ -397,25 +395,19 @@ func.func @loop_min_max(%N : index) {
 
 #map_7_values = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4, d5, d6)>
 
-// Check that the "min" (cmpi slt + select) reduction sequence is emitted
+// Check that the "min" reduction sequence is emitted
 // correctly for an affine map with 7 results.
 
 // CHECK-LABEL: func @min_reduction_tree
 // CHECK-NEXT:   %[[c0:.*]] = arith.constant 0 : index
-// CHECK-NEXT:   %[[c01:.+]] = arith.cmpi slt, %{{.*}}, %{{.*}} : index
-// CHECK-NEXT:   %[[r01:.+]] = arith.select %[[c01]], %{{.*}}, %{{.*}} : index
-// CHECK-NEXT:   %[[c012:.+]] = arith.cmpi slt, %[[r01]], %{{.*}} : index
-// CHECK-NEXT:   %[[r012:.+]] = arith.select %[[c012]], %[[r01]], %{{.*}} : index
-// CHECK-NEXT:   %[[c0123:.+]] = arith.cmpi slt, %[[r012]], %{{.*}} : index
-// CHECK-NEXT:   %[[r0123:.+]] = arith.select %[[c0123]], %[[r012]], %{{.*}} : index
-// CHECK-NEXT:   %[[c01234:.+]] = arith.cmpi slt, %[[r0123]], %{{.*}} : index
-// CHECK-NEXT:   %[[r01234:.+]] = arith.select %[[c01234]], %[[r0123]], %{{.*}} : index
-// CHECK-NEXT:   %[[c012345:.+]] = arith.cmpi slt, %[[r01234]], %{{.*}} : index
-// CHECK-NEXT:   %[[r012345:.+]] = arith.select %[[c012345]], %[[r01234]], %{{.*}} : index
-// CHECK-NEXT:   %[[c0123456:.+]] = arith.cmpi slt, %[[r012345]], %{{.*}} : index
-// CHECK-NEXT:   %[[r0123456:.+]] = arith.select %[[c0123456]], %[[r012345]], %{{.*}} : index
+// CHECK-NEXT:   %[[min:.+]] = arith.minsi %{{.*}}, %{{.*}} : index
+// CHECK-NEXT:   %[[min_0:.+]] = arith.minsi %[[min]], %{{.*}} : index
+// CHECK-NEXT:   %[[min_1:.+]] = arith.minsi %[[min_0]], %{{.*}} : index
+// CHECK-NEXT:   %[[min_2:.+]] = arith.minsi %[[min_1]], %{{.*}} : index
+// CHECK-NEXT:   %[[min_3:.+]] = arith.minsi %[[min_2]], %{{.*}} : index
+// CHECK-NEXT:   %[[min_4:.+]] = arith.minsi %[[min_3]], %{{.*}} : index
 // CHECK-NEXT:   %[[c1:.*]] = arith.constant 1 : index
-// CHECK-NEXT:   for %{{.*}} = %[[c0]] to %[[r0123456]] step %[[c1]] {
+// CHECK-NEXT:   for %{{.*}} = %[[c0]] to %[[min_4]] step %[[c1]] {
 // CHECK-NEXT:     call @body(%{{.*}}) : (index) -> ()
 // CHECK-NEXT:   }
 // CHECK-NEXT:   return
@@ -690,8 +682,7 @@ func.func @affine_min(%arg0: index, %arg1: index) -> index{
   // CHECK: %[[Cm2:.*]] = arith.constant -1
   // CHECK: %[[neg2:.*]] = arith.muli %[[ARG0]], %[[Cm2:.*]]
   // CHECK: %[[second:.*]] = arith.addi %[[ARG1]], %[[neg2]]
-  // CHECK: %[[cmp:.*]] = arith.cmpi slt, %[[first]], %[[second]]
-  // CHECK: arith.select %[[cmp]], %[[first]], %[[second]]
+  // CHECK: arith.minsi %[[first]], %[[second]]
   %0 = affine.min affine_map<(d0,d1) -> (d0 - d1, d1 - d0)>(%arg0, %arg1)
   return %0 : index
 }
@@ -705,8 +696,7 @@ func.func @affine_max(%arg0: index, %arg1: index) -> index{
   // CHECK: %[[Cm2:.*]] = arith.constant -1
   // CHECK: %[[neg2:.*]] = arith.muli %[[ARG0]], %[[Cm2:.*]]
   // CHECK: %[[second:.*]] = arith.addi %[[ARG1]], %[[neg2]]
-  // CHECK: %[[cmp:.*]] = arith.cmpi sgt, %[[first]], %[[second]]
-  // CHECK: arith.select %[[cmp]], %[[first]], %[[second]]
+  // CHECK: arith.maxsi %[[first]], %[[second]]
   %0 = affine.max affine_map<(d0,d1) -> (d0 - d1, d1 - d0)>(%arg0, %arg1)
   return %0 : index
 }
diff --git a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
index eb45112b117c0d..87d613986c7c3f 100644
--- a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
@@ -554,8 +554,7 @@ func.func @collapse_shape_dynamic(%arg0 : memref<1x2x?xf32>) -> memref<1x?xf32>
 // CHECK:           %[[SIZE1_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[FINAL_SIZE1]] : i64 to index
 // CHECK:           %[[FINAL_SIZE1:.*]] = builtin.unrealized_conversion_cast %[[SIZE1_TO_IDX]] : index to i64
 // CHECK:           %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64
-// CHECK:           %[[IS_MIN_STRIDE1:.*]] = llvm.icmp "slt" %[[STRIDE1]], %[[C1]] : i64
-// CHECK:           %[[MIN_STRIDE1:.*]] = llvm.select %[[IS_MIN_STRIDE1]], %[[STRIDE1]], %[[C1]] : i1, i64
+// CHECK:           %[[MIN_STRIDE1:.*]] = llvm.intr.smin(%[[STRIDE1]], %[[C1]]) : (i64, i64) -> i64
 // CHECK:           %[[MIN_STRIDE1_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[MIN_STRIDE1]] : i64 to index
 // CHECK:           %[[MIN_STRIDE1:.*]] = builtin.unrealized_conversion_cast %[[MIN_STRIDE1_TO_IDX]] : index to i64
 // CHECK:           %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
index cb3af973daee20..3b73c513b7955f 100644
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -377,10 +377,8 @@ func.func @try_is_broadcastable (%a : tensor<2xindex>, %b : tensor<3xindex>, %c
 // CHECK:           %[[RANK0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<2xindex>
 // CHECK:           %[[RANK1:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<3xindex>
 // CHECK:           %[[RANK2:.*]] = tensor.dim %[[ARG2]], %[[C0]] : tensor<2xindex>
-// CHECK:           %[[CMP0:.*]] = arith.cmpi ugt, %[[RANK1]], %[[RANK0]] : index
-// CHECK:           %[[LARGER_DIM:.*]] = arith.select %[[CMP0]], %[[RANK1]], %[[RANK0]] : index
-// CHECK:           %[[CMP1:.*]] = arith.cmpi ugt, %[[RANK2]], %[[LARGER_DIM]] : index
-// CHECK:           %[[MAX_RANK:.*]] = arith.select %[[CMP1]], %[[RANK2]], %[[LARGER_DIM]] : index
+// CHECK:           %[[MAX0:.*]] = arith.maxui %[[RANK1]], %[[RANK0]] : index
+// CHECK:           %[[MAX_RANK:.*]] = arith.maxui %[[RANK2]], %[[MAX0]] : index
 // CHECK:           %[[DIM_DIFF0:.*]] = arith.subi %[[MAX_RANK]], %[[RANK0]] : index
 // CHECK:           %[[DIM_DIFF1:.*]] = arith.subi %[[MAX_RANK]], %[[RANK1]] : index
 // CHECK:           %[[DIM_DIFF2:.*]] = arith.subi %[[MAX_RANK]], %[[RANK2]] : index
@@ -467,10 +465,8 @@ func.func @broadcast(%a : tensor<2xindex>, %b : tensor<3xindex>, %c : tensor<2xi
 // CHECK:           %[[RANK0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<2xindex>
 // CHECK:           %[[RANK1:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<3xindex>
 // CHECK:           %[[RANK2:.*]] = tensor.dim %[[ARG2]], %[[C0]] : tensor<2xindex>
-// CHECK:           %[[CMP0:.*]] = arith.cmpi ugt, %[[RANK1]], %[[RANK0]] : index
-// CHECK:           %[[LARGER_DIM:.*]] = arith.select %[[CMP0]], %[[RANK1]], %[[RANK0]] : index
-// CHECK:           %[[CMP1:.*]] = arith.cmpi ugt, %[[RANK2]], %[[LARGER_DIM]] : index
-// CHECK:           %[[MAX_RANK:.*]] = arith.select %[[CMP1]], %[[RANK2]], %[[LARGER_DIM]] : index
+// CHECK:           %[[MAX0:.*]] = arith.maxui %[[RANK1]], %[[RANK0]] : index
+// CHECK:           %[[MAX_RANK:.*]] = arith.maxui %[[RANK2]], %[[MAX0]] : index
 // CHECK:           %[[DIM_DIFF0:.*]] = arith.subi %[[MAX_RANK]], %[[RANK0]] : index
 // CHECK:           %[[DIM_DIFF1:.*]] = arith.subi %[[MAX_RANK]], %[[RANK1]] : index
 // CHECK:           %[[DIM_DIFF2:.*]] = arith.subi %[[MAX_RANK]], %[[RANK2]] : index
@@ -559,10 +555,8 @@ func.func @broadcast_3_shapes_different_extents(%a : tensor<2xindex>,
 // CHECK:           %[[RANK0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<2xindex>
 // CHECK:           %[[RANK1:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<3xindex>
 // CHECK:           %[[RANK2:.*]] = tensor.dim %[[ARG2]], %[[C0]] : tensor<2xindex>
-// CHECK:           %[[CMP0:.*]] = arith.cmpi ugt, %[[RANK1]], %[[RANK0]] : index
-// CHECK:           %[[LARGER_DIM:.*]] = arith.select %[[CMP0]], %[[RANK1]], %[[RANK0]] : index
-// CHECK:           %[[CMP1:.*]] = arith.cmpi ugt, %[[RANK2]], %[[LARGER_DIM]] : index
-// CHECK:           %[[MAX_RANK:.*]] = arith.select %[[CMP1]], %[[RANK2]], %[[LARGER_DIM]] : index
+// CHECK:           %[[MAX0:.*]] = arith.maxui %[[RANK1]], %[[RANK0]] : index
+// CHECK:           %[[MAX_RANK:.*]] = arith.maxui %[[RANK2]], %[[MAX0]] : index
 // CHECK:           %[[DIM_DIFF0:.*]] = arith.subi %[[MAX_RANK]], %[[RANK0]] : index
 // CHECK:           %[[DIM_DIFF1:.*]] = arith.subi %[[MAX_RANK]], %[[RANK1]] : index
 // CHECK:           %[[DIM_DIFF2:.*]] = arith.subi %[[MAX_RANK]], %[[RANK2]] : index
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index 51ebcad0797807..e64903671e599f 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -263,16 +263,13 @@ func.func @avg_pool_f32(%arg0: tensor<1x6x34x62xf32>) -> (tensor<1x5x33x62xf32>)
   // CHECK:   %[[SRC_END:.+]] = arith.muli %[[END]], %[[STRIDE]]
   // CHECK:   %[[PAD_START:.+]] = arith.constant 1
   // CHECK:   %[[START_SUB:.+]] = arith.subi %[[SRC_START]], %[[PAD_START]]
-  // CHECK:   %[[CMP:.+]] = arith.cmpi slt, %[[START_SUB]], %[[ZERO]]
-  // CHECK:   %...
[truncated]

@makslevental
Contributor

@RoboTux @banach-space @matthias-springer @Mogball @ThomasRaoux I tagged you because you're (roughly) the most recent editors (according to git blame). Feel free to remove yourself or ignore this if you're not up for reviewing.

Contributor

@dcaballe dcaballe left a comment

LGTM. I'm having déjà vu, though... as if somebody else had already tried something similar in the Affine to Standard lowering and we realized that there were some semantic differences/corner cases... Could you please do a search on Phabricator to be sure that there isn't anything similar?
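
One corner case that can be checked by hand is the `clampIntHelper` rewrite in this patch. The sketch below models the old cmp + select sequence and the new max/min sequence on plain C++ integers; the function names are hypothetical and are not MLIR APIs. Under that model the two forms agree whenever `min <= max` and differ only when the bounds are inverted. Whether inverted bounds can actually reach this helper is not established here, and this is not necessarily the case the reviewer is recalling.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Hypothetical model of the old clampIntHelper: two cmpi slt + select pairs.
int32_t clampViaSelect(int32_t arg, int32_t lo, int32_t hi) {
  int32_t minOrArg = (arg < lo) ? lo : arg; // select(arg < min, min, arg)
  return (hi < arg) ? hi : minOrArg;        // select(max < arg, max, minOrArg)
}

// Hypothetical model of the new clampIntHelper: maxsi followed by minsi.
int32_t clampViaMinMax(int32_t arg, int32_t lo, int32_t hi) {
  return std::min(hi, std::max(lo, arg));
}

// Returns true if both models agree for every (arg, lo, hi) in
// [-range, range] with lo <= hi.
bool clampsAgreeWhenOrdered(int32_t range) {
  for (int32_t lo = -range; lo <= range; ++lo)
    for (int32_t hi = lo; hi <= range; ++hi)
      for (int32_t arg = -range; arg <= range; ++arg)
        if (clampViaSelect(arg, lo, hi) != clampViaMinMax(arg, lo, hi))
          return false;
  return true;
}
```

With inverted bounds, for example `arg = 0`, `min = 5`, `max = 3`, the select-based model returns 5 while the max/min model returns 3.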

@mlevesquedion
Contributor Author

LGTM. I'm having déjà vu, though... as if somebody else had already tried something similar in the Affine to Standard lowering and we realized that there were some semantic differences/corner cases... Could you please do a search on Phabricator to be sure that there isn't anything similar?

Thanks for taking the time to review! I'm actually not sure how to search on Phabricator. When I search "llvm phabricator" on Google I land on this page: https://reviews.llvm.org/, which does not seem to provide a way to search old changes. Do you know how to search on Phabricator? I asked on Discourse: https://discourse.llvm.org/t/how-to-search-old-pull-requests-on-phabricator/77133. (I also tried a search on GitHub and got no hits.)

Alternatively, are there additional checks/tests that I could add/do to validate this change?
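
One lightweight sanity check, separate from the lit tests, is to compare the two forms exhaustively at a small bit width. The sketch below is only a model on plain C++ integers with hypothetical helper names; it does not exercise the MLIR ops themselves, the `index` type, or vector operands. It checks that `select(a > b, a, b)` and `select(a < b, a, b)` match `std::max` and `std::min` for every pair of 8-bit values, signed and unsigned.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <limits>

// Models select(cmpi sgt/ugt a, b; a, b) for an integer type T.
template <typename T> T maxViaSelect(T a, T b) { return (a > b) ? a : b; }

// Models select(cmpi slt/ult a, b; a, b) for an integer type T.
template <typename T> T minViaSelect(T a, T b) { return (a < b) ? a : b; }

// Returns true if the select-based forms match std::max/std::min for every
// pair of values of T. Intended for 8-bit types so the loop stays small.
template <typename T> bool selectMatchesMinMax() {
  const int lo = std::numeric_limits<T>::min();
  const int hi = std::numeric_limits<T>::max();
  for (int i = lo; i <= hi; ++i) {
    for (int j = lo; j <= hi; ++j) {
      const T a = static_cast<T>(i);
      const T b = static_cast<T>(j);
      if (maxViaSelect(a, b) != std::max(a, b))
        return false;
      if (minViaSelect(a, b) != std::min(a, b))
        return false;
    }
  }
  return true;
}
```

Instantiating the check with `std::int8_t` corresponds to the `maxsi`/`minsi` rewrites, and `std::uint8_t` to the `maxui` rewrite in the shape lowering.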

Contributor

@dcaballe dcaballe left a comment

Thanks for checking. I also gave it a try and couldn't find anything related to affine min/max, so it could be related to other ops... Let's go with this.

@mlevesquedion
Contributor Author

Thanks for the reviews, folks! I've addressed the review feedback. Can you merge this for me? (I don't have commit access.)

@dcaballe dcaballe merged commit d4fd202 into llvm:main Feb 21, 2024
@mlevesquedion mlevesquedion deleted the use-max-min-instead-of-cmp-select branch February 21, 2024 20:56