Skip to content

Commit d4fd202

Browse files
author
mlevesquedion
authored
[mlir] Use arith max or min ops instead of cmp + select (#82178)
I believe the semantics should be the same, but this saves 1 op and simplifies the code. For example, the following two instructions: ``` %2 = cmp sgt %0, %1 %3 = select %2, %0, %1 ``` Are equivalent to: ``` %2 = maxsi %0 %1 ```
1 parent cb1fed3 commit d4fd202

File tree

13 files changed

+113
-218
lines changed

13 files changed

+113
-218
lines changed

mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,7 @@ using namespace mlir::affine;
3434
using namespace mlir::vector;
3535

3636
/// Given a range of values, emit the code that reduces them with "min" or "max"
37-
/// depending on the provided comparison predicate. The predicate defines which
38-
/// comparison to perform, "lt" for "min", "gt" for "max" and is used for the
39-
/// `cmpi` operation followed by the `select` operation:
40-
///
41-
/// %cond = arith.cmpi "predicate" %v0, %v1
42-
/// %result = select %cond, %v0, %v1
37+
/// depending on the provided comparison predicate, sgt for max and slt for min.
4338
///
4439
/// Multiple values are scanned in a linear sequence. This creates a data
4540
/// dependences that wouldn't exist in a tree reduction, but is easier to
@@ -48,13 +43,16 @@ static Value buildMinMaxReductionSeq(Location loc,
4843
arith::CmpIPredicate predicate,
4944
ValueRange values, OpBuilder &builder) {
5045
assert(!values.empty() && "empty min/max chain");
46+
assert(predicate == arith::CmpIPredicate::sgt ||
47+
predicate == arith::CmpIPredicate::slt);
5148

5249
auto valueIt = values.begin();
5350
Value value = *valueIt++;
5451
for (; valueIt != values.end(); ++valueIt) {
55-
auto cmpOp = builder.create<arith::CmpIOp>(loc, predicate, value, *valueIt);
56-
value = builder.create<arith::SelectOp>(loc, cmpOp.getResult(), value,
57-
*valueIt);
52+
if (predicate == arith::CmpIPredicate::sgt)
53+
value = builder.create<arith::MaxSIOp>(loc, value, *valueIt);
54+
else
55+
value = builder.create<arith::MinSIOp>(loc, value, *valueIt);
5856
}
5957

6058
return value;

mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -147,9 +147,7 @@ LogicalResult BroadcastOpConverter::matchAndRewrite(
147147
// Find the maximum rank
148148
Value maxRank = ranks.front();
149149
for (Value v : llvm::drop_begin(ranks, 1)) {
150-
Value rankIsGreater =
151-
lb.create<arith::CmpIOp>(arith::CmpIPredicate::ugt, v, maxRank);
152-
maxRank = lb.create<arith::SelectOp>(rankIsGreater, v, maxRank);
150+
maxRank = lb.create<arith::MaxUIOp>(v, maxRank);
153151
}
154152

155153
// Calculate the difference of ranks and the maximum rank for later offsets.
@@ -262,9 +260,7 @@ LogicalResult IsBroadcastableOpConverter::matchAndRewrite(
262260
// Find the maximum rank
263261
Value maxRank = ranks.front();
264262
for (Value v : llvm::drop_begin(ranks, 1)) {
265-
Value rankIsGreater =
266-
lb.create<arith::CmpIOp>(arith::CmpIPredicate::ugt, v, maxRank);
267-
maxRank = lb.create<arith::SelectOp>(rankIsGreater, v, maxRank);
263+
maxRank = lb.create<arith::MaxUIOp>(v, maxRank);
268264
}
269265

270266
// Calculate the difference of ranks and the maximum rank for later offsets.

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,8 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
6161
if (isa<tosa::AbsOp>(op) && isa<IntegerType>(elementTy)) {
6262
auto zero = rewriter.create<arith::ConstantOp>(
6363
loc, rewriter.getZeroAttr(elementTy));
64-
auto cmp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt,
65-
args[0], zero);
6664
auto neg = rewriter.create<arith::SubIOp>(loc, zero, args[0]);
67-
return rewriter.create<arith::SelectOp>(loc, cmp, args[0], neg);
65+
return rewriter.create<arith::MaxSIOp>(loc, args[0], neg);
6866
}
6967

7068
// tosa::AddOp
@@ -348,9 +346,7 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
348346
}
349347

350348
if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
351-
auto predicate = rewriter.create<arith::CmpIOp>(
352-
loc, arith::CmpIPredicate::sgt, args[0], args[1]);
353-
return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
349+
return rewriter.create<arith::MaxSIOp>(loc, args[0], args[1]);
354350
}
355351

356352
// tosa::MinimumOp
@@ -359,9 +355,7 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
359355
}
360356

361357
if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
362-
auto predicate = rewriter.create<arith::CmpIOp>(
363-
loc, arith::CmpIPredicate::slt, args[0], args[1]);
364-
return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
358+
return rewriter.create<arith::MinSIOp>(loc, args[0], args[1]);
365359
}
366360

367361
// tosa::CeilOp
@@ -1000,19 +994,15 @@ static Value createLinalgBodyCalculationForReduceOp(Operation *op,
1000994
}
1001995

1002996
if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy)) {
1003-
auto predicate = rewriter.create<arith::CmpIOp>(
1004-
loc, arith::CmpIPredicate::slt, args[0], args[1]);
1005-
return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
997+
return rewriter.create<arith::MinSIOp>(loc, args[0], args[1]);
1006998
}
1007999

10081000
if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy)) {
10091001
return rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
10101002
}
10111003

10121004
if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy)) {
1013-
auto predicate = rewriter.create<arith::CmpIOp>(
1014-
loc, arith::CmpIPredicate::sgt, args[0], args[1]);
1015-
return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
1005+
return rewriter.create<arith::MaxSIOp>(loc, args[0], args[1]);
10161006
}
10171007

10181008
if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))

mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -845,10 +845,7 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
845845
auto padVal = rewriter.create<arith::ConstantIndexOp>(loc, pad);
846846
Value dpos = rewriter.create<arith::SubIOp>(loc, pos, padVal);
847847

848-
Value cmp = rewriter.create<arith::CmpIOp>(
849-
loc, arith::CmpIPredicate::slt, dpos, zero);
850-
Value offset =
851-
rewriter.create<arith::SelectOp>(loc, cmp, dpos, zero);
848+
Value offset = rewriter.create<arith::MinSIOp>(loc, dpos, zero);
852849
return rewriter.create<arith::AddIOp>(loc, valid, offset)
853850
->getResult(0);
854851
};
@@ -868,9 +865,7 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
868865
// Determine how much padding was included.
869866
val = padFn(val, left, pad[i * 2]);
870867
val = padFn(val, right, pad[i * 2 + 1]);
871-
Value cmp = rewriter.create<arith::CmpIOp>(
872-
loc, arith::CmpIPredicate::slt, val, one);
873-
return rewriter.create<arith::SelectOp>(loc, cmp, one, val);
868+
return rewriter.create<arith::MaxSIOp>(loc, one, val);
874869
};
875870

876871
// Compute the indices from either end.

mlir/lib/Dialect/SCF/Utils/Utils.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -791,10 +791,8 @@ static Loops stripmineSink(scf::ForOp forOp, Value factor,
791791
// Insert newForOp before the terminator of `t`.
792792
auto b = OpBuilder::atBlockTerminator((t.getBody()));
793793
Value stepped = b.create<arith::AddIOp>(t.getLoc(), iv, forOp.getStep());
794-
Value less = b.create<arith::CmpIOp>(t.getLoc(), arith::CmpIPredicate::slt,
795-
forOp.getUpperBound(), stepped);
796-
Value ub = b.create<arith::SelectOp>(t.getLoc(), less,
797-
forOp.getUpperBound(), stepped);
794+
Value ub =
795+
b.create<arith::MinSIOp>(t.getLoc(), forOp.getUpperBound(), stepped);
798796

799797
// Splice [begin, begin + nOps - 1) into `newForOp` and replace uses.
800798
auto newForOp = b.create<scf::ForOp>(t.getLoc(), iv, ub, originalStep);

mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,8 @@ Value mlir::tosa::clampFloatHelper(Location loc, Value arg, Value min,
3939

4040
Value mlir::tosa::clampIntHelper(Location loc, Value arg, Value min, Value max,
4141
OpBuilder &rewriter) {
42-
auto smallerThanMin =
43-
rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, arg, min);
44-
auto minOrArg =
45-
rewriter.create<arith::SelectOp>(loc, smallerThanMin, min, arg);
46-
auto largerThanMax =
47-
rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, max, arg);
48-
return rewriter.create<arith::SelectOp>(loc, largerThanMax, max, minOrArg);
42+
auto minOrArg = rewriter.create<arith::MaxSIOp>(loc, min, arg);
43+
return rewriter.create<arith::MinSIOp>(loc, max, minOrArg);
4944
}
5045

5146
bool mlir::tosa::validIntegerRange(IntegerType ty, int64_t value) {

mlir/test/Conversion/AffineToStandard/lower-affine.mlir

Lines changed: 16 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -371,16 +371,14 @@ func.func @if_for() {
371371
// CHECK-NEXT: %[[c1:.*]] = arith.constant 1 : index
372372
// CHECK-NEXT: for %{{.*}} = %[[c0]] to %[[c42]] step %[[c1]] {
373373
// CHECK-NEXT: %[[cm1:.*]] = arith.constant -1 : index
374-
// CHECK-NEXT: %[[a:.*]] = arith.muli %{{.*}}, %[[cm1]] : index
375-
// CHECK-NEXT: %[[b:.*]] = arith.addi %[[a]], %{{.*}} : index
376-
// CHECK-NEXT: %[[c:.*]] = arith.cmpi sgt, %{{.*}}, %[[b]] : index
377-
// CHECK-NEXT: %[[d:.*]] = arith.select %[[c]], %{{.*}}, %[[b]] : index
374+
// CHECK-NEXT: %[[mul0:.*]] = arith.muli %{{.*}}, %[[cm1]] : index
375+
// CHECK-NEXT: %[[add0:.*]] = arith.addi %[[mul0]], %{{.*}} : index
376+
// CHECK-NEXT: %[[max:.*]] = arith.maxsi %{{.*}}, %[[add0]] : index
378377
// CHECK-NEXT: %[[c10:.*]] = arith.constant 10 : index
379-
// CHECK-NEXT: %[[e:.*]] = arith.addi %{{.*}}, %[[c10]] : index
380-
// CHECK-NEXT: %[[f:.*]] = arith.cmpi slt, %{{.*}}, %[[e]] : index
381-
// CHECK-NEXT: %[[g:.*]] = arith.select %[[f]], %{{.*}}, %[[e]] : index
378+
// CHECK-NEXT: %[[add1:.*]] = arith.addi %{{.*}}, %[[c10]] : index
379+
// CHECK-NEXT: %[[min:.*]] = arith.minsi %{{.*}}, %[[add1]] : index
382380
// CHECK-NEXT: %[[c1_0:.*]] = arith.constant 1 : index
383-
// CHECK-NEXT: for %{{.*}} = %[[d]] to %[[g]] step %[[c1_0]] {
381+
// CHECK-NEXT: for %{{.*}} = %[[max]] to %[[min]] step %[[c1_0]] {
384382
// CHECK-NEXT: call @body2(%{{.*}}, %{{.*}}) : (index, index) -> ()
385383
// CHECK-NEXT: }
386384
// CHECK-NEXT: }
@@ -397,25 +395,19 @@ func.func @loop_min_max(%N : index) {
397395

398396
#map_7_values = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4, d5, d6)>
399397

400-
// Check that the "min" (cmpi slt + select) reduction sequence is emitted
398+
// Check that the "min" reduction sequence is emitted
401399
// correctly for an affine map with 7 results.
402400

403401
// CHECK-LABEL: func @min_reduction_tree
404402
// CHECK-NEXT: %[[c0:.*]] = arith.constant 0 : index
405-
// CHECK-NEXT: %[[c01:.+]] = arith.cmpi slt, %{{.*}}, %{{.*}} : index
406-
// CHECK-NEXT: %[[r01:.+]] = arith.select %[[c01]], %{{.*}}, %{{.*}} : index
407-
// CHECK-NEXT: %[[c012:.+]] = arith.cmpi slt, %[[r01]], %{{.*}} : index
408-
// CHECK-NEXT: %[[r012:.+]] = arith.select %[[c012]], %[[r01]], %{{.*}} : index
409-
// CHECK-NEXT: %[[c0123:.+]] = arith.cmpi slt, %[[r012]], %{{.*}} : index
410-
// CHECK-NEXT: %[[r0123:.+]] = arith.select %[[c0123]], %[[r012]], %{{.*}} : index
411-
// CHECK-NEXT: %[[c01234:.+]] = arith.cmpi slt, %[[r0123]], %{{.*}} : index
412-
// CHECK-NEXT: %[[r01234:.+]] = arith.select %[[c01234]], %[[r0123]], %{{.*}} : index
413-
// CHECK-NEXT: %[[c012345:.+]] = arith.cmpi slt, %[[r01234]], %{{.*}} : index
414-
// CHECK-NEXT: %[[r012345:.+]] = arith.select %[[c012345]], %[[r01234]], %{{.*}} : index
415-
// CHECK-NEXT: %[[c0123456:.+]] = arith.cmpi slt, %[[r012345]], %{{.*}} : index
416-
// CHECK-NEXT: %[[r0123456:.+]] = arith.select %[[c0123456]], %[[r012345]], %{{.*}} : index
403+
// CHECK-NEXT: %[[min:.+]] = arith.minsi %{{.*}}, %{{.*}} : index
404+
// CHECK-NEXT: %[[min_0:.+]] = arith.minsi %[[min]], %{{.*}} : index
405+
// CHECK-NEXT: %[[min_1:.+]] = arith.minsi %[[min_0]], %{{.*}} : index
406+
// CHECK-NEXT: %[[min_2:.+]] = arith.minsi %[[min_1]], %{{.*}} : index
407+
// CHECK-NEXT: %[[min_3:.+]] = arith.minsi %[[min_2]], %{{.*}} : index
408+
// CHECK-NEXT: %[[min_4:.+]] = arith.minsi %[[min_3]], %{{.*}} : index
417409
// CHECK-NEXT: %[[c1:.*]] = arith.constant 1 : index
418-
// CHECK-NEXT: for %{{.*}} = %[[c0]] to %[[r0123456]] step %[[c1]] {
410+
// CHECK-NEXT: for %{{.*}} = %[[c0]] to %[[min_4]] step %[[c1]] {
419411
// CHECK-NEXT: call @body(%{{.*}}) : (index) -> ()
420412
// CHECK-NEXT: }
421413
// CHECK-NEXT: return
@@ -690,8 +682,7 @@ func.func @affine_min(%arg0: index, %arg1: index) -> index{
690682
// CHECK: %[[Cm2:.*]] = arith.constant -1
691683
// CHECK: %[[neg2:.*]] = arith.muli %[[ARG0]], %[[Cm2:.*]]
692684
// CHECK: %[[second:.*]] = arith.addi %[[ARG1]], %[[neg2]]
693-
// CHECK: %[[cmp:.*]] = arith.cmpi slt, %[[first]], %[[second]]
694-
// CHECK: arith.select %[[cmp]], %[[first]], %[[second]]
685+
// CHECK: arith.minsi %[[first]], %[[second]]
695686
%0 = affine.min affine_map<(d0,d1) -> (d0 - d1, d1 - d0)>(%arg0, %arg1)
696687
return %0 : index
697688
}
@@ -705,8 +696,7 @@ func.func @affine_max(%arg0: index, %arg1: index) -> index{
705696
// CHECK: %[[Cm2:.*]] = arith.constant -1
706697
// CHECK: %[[neg2:.*]] = arith.muli %[[ARG0]], %[[Cm2:.*]]
707698
// CHECK: %[[second:.*]] = arith.addi %[[ARG1]], %[[neg2]]
708-
// CHECK: %[[cmp:.*]] = arith.cmpi sgt, %[[first]], %[[second]]
709-
// CHECK: arith.select %[[cmp]], %[[first]], %[[second]]
699+
// CHECK: arith.maxsi %[[first]], %[[second]]
710700
%0 = affine.max affine_map<(d0,d1) -> (d0 - d1, d1 - d0)>(%arg0, %arg1)
711701
return %0 : index
712702
}

mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -554,8 +554,7 @@ func.func @collapse_shape_dynamic(%arg0 : memref<1x2x?xf32>) -> memref<1x?xf32>
554554
// CHECK: %[[SIZE1_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[FINAL_SIZE1]] : i64 to index
555555
// CHECK: %[[FINAL_SIZE1:.*]] = builtin.unrealized_conversion_cast %[[SIZE1_TO_IDX]] : index to i64
556556
// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64
557-
// CHECK: %[[IS_MIN_STRIDE1:.*]] = llvm.icmp "slt" %[[STRIDE1]], %[[C1]] : i64
558-
// CHECK: %[[MIN_STRIDE1:.*]] = llvm.select %[[IS_MIN_STRIDE1]], %[[STRIDE1]], %[[C1]] : i1, i64
557+
// CHECK: %[[MIN_STRIDE1:.*]] = llvm.intr.smin(%[[STRIDE1]], %[[C1]]) : (i64, i64) -> i64
559558
// CHECK: %[[MIN_STRIDE1_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[MIN_STRIDE1]] : i64 to index
560559
// CHECK: %[[MIN_STRIDE1:.*]] = builtin.unrealized_conversion_cast %[[MIN_STRIDE1_TO_IDX]] : index to i64
561560
// CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>

mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -377,10 +377,8 @@ func.func @try_is_broadcastable (%a : tensor<2xindex>, %b : tensor<3xindex>, %c
377377
// CHECK: %[[RANK0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<2xindex>
378378
// CHECK: %[[RANK1:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<3xindex>
379379
// CHECK: %[[RANK2:.*]] = tensor.dim %[[ARG2]], %[[C0]] : tensor<2xindex>
380-
// CHECK: %[[CMP0:.*]] = arith.cmpi ugt, %[[RANK1]], %[[RANK0]] : index
381-
// CHECK: %[[LARGER_DIM:.*]] = arith.select %[[CMP0]], %[[RANK1]], %[[RANK0]] : index
382-
// CHECK: %[[CMP1:.*]] = arith.cmpi ugt, %[[RANK2]], %[[LARGER_DIM]] : index
383-
// CHECK: %[[MAX_RANK:.*]] = arith.select %[[CMP1]], %[[RANK2]], %[[LARGER_DIM]] : index
380+
// CHECK: %[[MAX0:.*]] = arith.maxui %[[RANK1]], %[[RANK0]] : index
381+
// CHECK: %[[MAX_RANK:.*]] = arith.maxui %[[RANK2]], %[[MAX0]] : index
384382
// CHECK: %[[DIM_DIFF0:.*]] = arith.subi %[[MAX_RANK]], %[[RANK0]] : index
385383
// CHECK: %[[DIM_DIFF1:.*]] = arith.subi %[[MAX_RANK]], %[[RANK1]] : index
386384
// CHECK: %[[DIM_DIFF2:.*]] = arith.subi %[[MAX_RANK]], %[[RANK2]] : index
@@ -467,10 +465,8 @@ func.func @broadcast(%a : tensor<2xindex>, %b : tensor<3xindex>, %c : tensor<2xi
467465
// CHECK: %[[RANK0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<2xindex>
468466
// CHECK: %[[RANK1:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<3xindex>
469467
// CHECK: %[[RANK2:.*]] = tensor.dim %[[ARG2]], %[[C0]] : tensor<2xindex>
470-
// CHECK: %[[CMP0:.*]] = arith.cmpi ugt, %[[RANK1]], %[[RANK0]] : index
471-
// CHECK: %[[LARGER_DIM:.*]] = arith.select %[[CMP0]], %[[RANK1]], %[[RANK0]] : index
472-
// CHECK: %[[CMP1:.*]] = arith.cmpi ugt, %[[RANK2]], %[[LARGER_DIM]] : index
473-
// CHECK: %[[MAX_RANK:.*]] = arith.select %[[CMP1]], %[[RANK2]], %[[LARGER_DIM]] : index
468+
// CHECK: %[[MAX0:.*]] = arith.maxui %[[RANK1]], %[[RANK0]] : index
469+
// CHECK: %[[MAX_RANK:.*]] = arith.maxui %[[RANK2]], %[[MAX0]] : index
474470
// CHECK: %[[DIM_DIFF0:.*]] = arith.subi %[[MAX_RANK]], %[[RANK0]] : index
475471
// CHECK: %[[DIM_DIFF1:.*]] = arith.subi %[[MAX_RANK]], %[[RANK1]] : index
476472
// CHECK: %[[DIM_DIFF2:.*]] = arith.subi %[[MAX_RANK]], %[[RANK2]] : index
@@ -559,10 +555,8 @@ func.func @broadcast_3_shapes_different_extents(%a : tensor<2xindex>,
559555
// CHECK: %[[RANK0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<2xindex>
560556
// CHECK: %[[RANK1:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<3xindex>
561557
// CHECK: %[[RANK2:.*]] = tensor.dim %[[ARG2]], %[[C0]] : tensor<2xindex>
562-
// CHECK: %[[CMP0:.*]] = arith.cmpi ugt, %[[RANK1]], %[[RANK0]] : index
563-
// CHECK: %[[LARGER_DIM:.*]] = arith.select %[[CMP0]], %[[RANK1]], %[[RANK0]] : index
564-
// CHECK: %[[CMP1:.*]] = arith.cmpi ugt, %[[RANK2]], %[[LARGER_DIM]] : index
565-
// CHECK: %[[MAX_RANK:.*]] = arith.select %[[CMP1]], %[[RANK2]], %[[LARGER_DIM]] : index
558+
// CHECK: %[[MAX0:.*]] = arith.maxui %[[RANK1]], %[[RANK0]] : index
559+
// CHECK: %[[MAX_RANK:.*]] = arith.maxui %[[RANK2]], %[[MAX0]] : index
566560
// CHECK: %[[DIM_DIFF0:.*]] = arith.subi %[[MAX_RANK]], %[[RANK0]] : index
567561
// CHECK: %[[DIM_DIFF1:.*]] = arith.subi %[[MAX_RANK]], %[[RANK1]] : index
568562
// CHECK: %[[DIM_DIFF2:.*]] = arith.subi %[[MAX_RANK]], %[[RANK2]] : index

0 commit comments

Comments
 (0)