Skip to content

Commit d33bad6

Browse files
authored
[mlir][vector] Add patterns to simplify chained reductions (#73048)
Chained reductions get created during vector unrolling. These patterns simplify them into a series of adds followed by a final reduction. This is preferred on GPU targets like SPIR-V/Vulkan, where a vector reduction gets lowered into subgroup operations that are generally more expensive than simple vector additions. For now, only the `add` combining kind is handled.
1 parent 8c02b34 commit d33bad6

File tree

4 files changed

+255
-0
lines changed

4 files changed

+255
-0
lines changed

mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,25 @@ void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
147147
void populateSinkVectorBroadcastPatterns(RewritePatternSet &patterns,
148148
PatternBenefit benefit = 1);
149149

150+
/// Patterns that fold chained vector reductions. These patterns assume that
/// elementwise operations (e.g., `arith.addf` with vector operands) are
/// cheaper than vector reduction.
/// Note that these patterns change the order of reduction which may not always
/// produce bit-identical results on some floating point inputs.
///
/// Example:
/// ```
/// %a = vector.reduction <add> %x, %acc
/// %b = vector.reduction <add> %y, %a
/// ```
/// is transformed into:
/// ```
/// %a = arith.addf %x, %y
/// %b = vector.reduction <add> %a, %acc
/// ```
void populateChainedVectorReductionFoldingPatterns(RewritePatternSet &patterns,
                                                   PatternBenefit benefit = 1);
168+
150169
/// Populate `patterns` with the following patterns.
151170
///
152171
/// [DecomposeDifferentRankInsertStridedSlice]

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1402,6 +1402,98 @@ struct FoldArithExtIntoContractionOp
14021402
}
14031403
};
14041404

1405+
/// Pattern to fold chained reduction to a series of vector additions and a
1406+
/// final reduction. This form should require fewer subgroup operations.
1407+
///
1408+
/// ```mlir
1409+
/// %a = vector.reduction <add> %x, %acc
1410+
/// %b = vector.reduction <add> %y, %a
1411+
/// ==>
1412+
/// %a = arith.addf %x, %y
1413+
/// %b = vector.reduction <add> %a, %acc
1414+
/// ```
1415+
struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
1416+
using OpRewritePattern::OpRewritePattern;
1417+
1418+
LogicalResult matchAndRewrite(vector::ReductionOp op,
1419+
PatternRewriter &rewriter) const override {
1420+
// TODO: Handle other combining kinds.
1421+
if (op.getKind() != vector::CombiningKind::ADD)
1422+
return failure();
1423+
1424+
// Accumulator is optional.
1425+
Value acc = op.getAcc();
1426+
if (!acc)
1427+
return failure();
1428+
1429+
if (!acc.getType().isIntOrFloat())
1430+
return failure();
1431+
1432+
auto parentReduction = acc.getDefiningOp<vector::ReductionOp>();
1433+
if (!parentReduction)
1434+
return failure();
1435+
1436+
Location loc = op.getLoc();
1437+
Value vAdd;
1438+
if (isa<IntegerType>(acc.getType())) {
1439+
vAdd = rewriter.createOrFold<arith::AddIOp>(
1440+
loc, parentReduction.getVector(), op.getVector());
1441+
} else {
1442+
vAdd = rewriter.create<arith::AddFOp>(loc, parentReduction.getVector(),
1443+
op.getVector());
1444+
}
1445+
rewriter.replaceOpWithNewOp<vector::ReductionOp>(op, op.getKind(), vAdd,
1446+
parentReduction.getAcc());
1447+
return success();
1448+
}
1449+
};
1450+
1451+
/// Pattern to eliminate redundant zero-constants added to reduction operands.
1452+
/// It's enough for there to be one initial zero value, so we can eliminate the
1453+
/// extra ones that feed into `vector.reduction <add>`. These get created by the
1454+
/// `ChainedReduction` pattern.
1455+
///
1456+
/// ```mlir
1457+
/// %a = arith.addf %x, %zero
1458+
/// %b = arith.addf %a, %y
1459+
/// %c = vector.reduction <add> %b, %acc
1460+
/// ==>
1461+
/// %b = arith.addf %a, %y
1462+
/// %c = vector.reduction <add> %b, %acc
1463+
/// ```
1464+
struct ReduceRedundantZero final : OpRewritePattern<vector::ReductionOp> {
1465+
using OpRewritePattern::OpRewritePattern;
1466+
1467+
LogicalResult matchAndRewrite(vector::ReductionOp op,
1468+
PatternRewriter &rewriter) const override {
1469+
// TODO: Handle other reduction kinds and their identity values.
1470+
if (op.getKind() != vector::CombiningKind::ADD)
1471+
return failure();
1472+
1473+
Type elemType = op.getSourceVectorType().getElementType();
1474+
// The integer case should be handled by `arith.addi` folders, only check
1475+
// for floats here.
1476+
if (!isa<FloatType>(elemType))
1477+
return failure();
1478+
1479+
auto vAdd = op.getVector().getDefiningOp<arith::AddFOp>();
1480+
if (!vAdd)
1481+
return failure();
1482+
auto addLhs = vAdd.getLhs().getDefiningOp<arith::AddFOp>();
1483+
if (!addLhs)
1484+
return failure();
1485+
1486+
if (!matchPattern(addLhs.getRhs(), m_AnyZeroFloat()))
1487+
return failure();
1488+
1489+
auto newAdd = rewriter.create<arith::AddFOp>(vAdd.getLoc(), addLhs.getLhs(),
1490+
vAdd.getRhs());
1491+
rewriter.replaceOpWithNewOp<vector::ReductionOp>(op, op.getKind(), newAdd,
1492+
op.getAcc());
1493+
return success();
1494+
}
1495+
};
1496+
14051497
} // namespace
14061498

14071499
void mlir::vector::populateFoldArithExtensionPatterns(
@@ -1467,6 +1559,13 @@ void mlir::vector::populateSinkVectorBroadcastPatterns(
14671559
patterns.getContext(), benefit);
14681560
}
14691561

1562+
void mlir::vector::populateChainedVectorReductionFoldingPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<ChainedReduction>(ctx, benefit);
  // Run the zero-elimination cleanup at a strictly higher benefit so it fires
  // before further chained-reduction folding re-examines the IR.
  patterns.add<ReduceRedundantZero>(ctx,
                                    PatternBenefit(benefit.getBenefit() + 1));
}
1568+
14701569
//===----------------------------------------------------------------------===//
14711570
// TableGen'd enum attribute definitions
14721571
//===----------------------------------------------------------------------===//
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
// RUN: mlir-opt %s --test-vector-chained-reduction-folding-patterns | FileCheck %s
// Tests for folding chained `vector.reduction <add>` ops into elementwise
// vector adds followed by a single final reduction.

// CHECK-LABEL: func.func @reduce_1x_fp32(
// CHECK-SAME: %[[ARG0:.+]]: vector<8xf32>) -> f32 {
// CHECK-DAG: %[[CST:.+]] = arith.constant 0.0
// CHECK-NEXT: %[[RES:.+]] = vector.reduction <add>, %[[ARG0]], %[[CST]] : vector<8xf32> into f32
// CHECK-NEXT: return %[[RES]] : f32
func.func @reduce_1x_fp32(%arg0: vector<8xf32>) -> f32 {
  %cst0 = arith.constant 0.0 : f32
  %0 = vector.reduction <add>, %arg0, %cst0 : vector<8xf32> into f32
  return %0 : f32
}

// CHECK-LABEL: func.func @reduce_2x_fp32(
// CHECK-SAME: %[[ARG0:.+]]: vector<8xf32>, %[[ARG1:.+]]: vector<8xf32>) -> f32 {
// CHECK-DAG: %[[CST:.+]] = arith.constant 0.0
// CHECK-DAG: %[[ADD:.+]] = arith.addf %[[ARG0]], %[[ARG1]] : vector<8xf32>
// CHECK-NEXT: %[[RES:.+]] = vector.reduction <add>, %[[ADD]], %[[CST]] : vector<8xf32> into f32
// CHECK-NEXT: return %[[RES]] : f32
func.func @reduce_2x_fp32(%arg0: vector<8xf32>, %arg1: vector<8xf32>) -> f32 {
  %cst0 = arith.constant 0.0 : f32
  %0 = vector.reduction <add>, %arg0, %cst0 : vector<8xf32> into f32
  %1 = vector.reduction <add>, %arg1, %0 : vector<8xf32> into f32
  return %1 : f32
}

// CHECK-LABEL: func.func @reduce_2x_no_acc_fp32(
// CHECK-SAME: %[[ARG0:.+]]: vector<8xf32>, %[[ARG1:.+]]: vector<8xf32>) -> f32 {
// CHECK: %[[ADD:.+]] = arith.addf %[[ARG0]], %[[ARG1]] : vector<8xf32>
// CHECK-NEXT: %[[RES:.+]] = vector.reduction <add>, %[[ADD]] : vector<8xf32> into f32
// CHECK-NEXT: return %[[RES]] : f32
func.func @reduce_2x_no_acc_fp32(%arg0: vector<8xf32>, %arg1: vector<8xf32>) -> f32 {
  %0 = vector.reduction <add>, %arg0 : vector<8xf32> into f32
  %1 = vector.reduction <add>, %arg1, %0 : vector<8xf32> into f32
  return %1 : f32
}

// CHECK-LABEL: func.func @reduce_2x_zero_add_fp32(
// CHECK-SAME: %[[ARG0:.+]]: vector<8xf32>, %[[ARG1:.+]]: vector<8xf32>) -> f32 {
// CHECK: %[[ADD:.+]] = arith.addf %[[ARG0]], %[[ARG1]] : vector<8xf32>
// CHECK-NEXT: %[[RES:.+]] = vector.reduction <add>, %[[ADD]] : vector<8xf32> into f32
// CHECK-NEXT: return %[[RES]] : f32
func.func @reduce_2x_zero_add_fp32(%arg0: vector<8xf32>, %arg1: vector<8xf32>) -> f32 {
  %cst0 = arith.constant dense<0.0> : vector<8xf32>
  %x = arith.addf %arg0, %cst0 : vector<8xf32>
  %0 = vector.reduction <add>, %x : vector<8xf32> into f32
  %1 = vector.reduction <add>, %arg1, %0 : vector<8xf32> into f32
  return %1 : f32
}

// CHECK-LABEL: func.func @reduce_3x_fp32(
// CHECK-SAME: %[[ARG0:.+]]: vector<8xf32>, %[[ARG1:.+]]: vector<8xf32>,
// CHECK-SAME: %[[ARG2:.+]]: vector<8xf32>) -> f32 {
// CHECK-DAG: %[[CST:.+]] = arith.constant 0.0
// CHECK-DAG: %[[ADD0:.+]] = arith.addf %[[ARG1]], %[[ARG2]] : vector<8xf32>
// CHECK-DAG: %[[ADD1:.+]] = arith.addf %[[ARG0]], %[[ADD0]] : vector<8xf32>
// CHECK-NEXT: %[[RES:.+]] = vector.reduction <add>, %[[ADD1]], %[[CST]] : vector<8xf32> into f32
// CHECK-NEXT: return %[[RES]] : f32
func.func @reduce_3x_fp32(%arg0: vector<8xf32>, %arg1: vector<8xf32>,
                          %arg2: vector<8xf32>) -> f32 {
  %cst0 = arith.constant 0.0 : f32
  %0 = vector.reduction <add>, %arg0, %cst0 : vector<8xf32> into f32
  %1 = vector.reduction <add>, %arg1, %0 : vector<8xf32> into f32
  %2 = vector.reduction <add>, %arg2, %1 : vector<8xf32> into f32
  return %2 : f32
}

// CHECK-LABEL: func.func @reduce_1x_i32(
// CHECK-SAME: %[[ARG0:.+]]: vector<8xi32>) -> i32 {
// CHECK-DAG: %[[CST:.+]] = arith.constant 0
// CHECK-NEXT: %[[RES:.+]] = vector.reduction <add>, %[[ARG0]], %[[CST]] : vector<8xi32> into i32
// CHECK-NEXT: return %[[RES]] : i32
func.func @reduce_1x_i32(%arg0: vector<8xi32>) -> i32 {
  %cst0 = arith.constant 0 : i32
  %0 = vector.reduction <add>, %arg0, %cst0 : vector<8xi32> into i32
  return %0 : i32
}

// CHECK-LABEL: func.func @reduce_2x_i32(
// CHECK-SAME: %[[ARG0:.+]]: vector<8xi32>, %[[ARG1:.+]]: vector<8xi32>) -> i32 {
// CHECK-DAG: %[[CST:.+]] = arith.constant 0
// CHECK-DAG: %[[ADD:.+]] = arith.addi %[[ARG0]], %[[ARG1]] : vector<8xi32>
// CHECK-NEXT: %[[RES:.+]] = vector.reduction <add>, %[[ADD]], %[[CST]] : vector<8xi32> into i32
// CHECK-NEXT: return %[[RES]] : i32
func.func @reduce_2x_i32(%arg0: vector<8xi32>, %arg1: vector<8xi32>) -> i32 {
  %cst0 = arith.constant 0 : i32
  %0 = vector.reduction <add>, %arg0, %cst0 : vector<8xi32> into i32
  %1 = vector.reduction <add>, %arg1, %0 : vector<8xi32> into i32
  return %1 : i32
}

// CHECK-LABEL: func.func @reduce_2x_no_acc_i32(
// CHECK-SAME: %[[ARG0:.+]]: vector<8xi32>, %[[ARG1:.+]]: vector<8xi32>) -> i32 {
// CHECK: %[[ADD:.+]] = arith.addi %[[ARG0]], %[[ARG1]] : vector<8xi32>
// CHECK-NEXT: %[[RES:.+]] = vector.reduction <add>, %[[ADD]] : vector<8xi32> into i32
// CHECK-NEXT: return %[[RES]] : i32
func.func @reduce_2x_no_acc_i32(%arg0: vector<8xi32>, %arg1: vector<8xi32>) -> i32 {
  %0 = vector.reduction <add>, %arg0 : vector<8xi32> into i32
  %1 = vector.reduction <add>, %arg1, %0 : vector<8xi32> into i32
  return %1 : i32
}

// CHECK-LABEL: func.func @reduce_2x_zero_add_i32(
// CHECK-SAME: %[[ARG0:.+]]: vector<8xi32>, %[[ARG1:.+]]: vector<8xi32>) -> i32 {
// CHECK-DAG: %[[CST:.+]] = arith.constant 0
// CHECK-DAG: %[[ADD:.+]] = arith.addi %[[ARG0]], %[[ARG1]] : vector<8xi32>
// CHECK-NEXT: %[[RES:.+]] = vector.reduction <add>, %[[ADD]], %[[CST]] : vector<8xi32> into i32
// CHECK-NEXT: return %[[RES]] : i32
func.func @reduce_2x_zero_add_i32(%arg0: vector<8xi32>, %arg1: vector<8xi32>) -> i32 {
  %cst0 = arith.constant 0 : i32
  %cstV = arith.constant dense<0> : vector<8xi32>
  %x = arith.addi %arg0, %cstV : vector<8xi32>
  %0 = vector.reduction <add>, %x, %cst0 : vector<8xi32> into i32
  %1 = vector.reduction <add>, %arg1, %0 : vector<8xi32> into i32
  return %1 : i32
}

mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,25 @@ struct TestVectorReduceToContractPatternsPatterns
420420
}
421421
};
422422

423+
/// Test-only pass that applies the chained `vector.reduction` folding
/// patterns (`populateChainedVectorReductionFoldingPatterns`) greedily to a
/// single function.
struct TestVectorChainedReductionFoldingPatterns
    : public PassWrapper<TestVectorChainedReductionFoldingPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestVectorChainedReductionFoldingPatterns)

  StringRef getArgument() const final {
    return "test-vector-chained-reduction-folding-patterns";
  }
  StringRef getDescription() const final {
    return "Test patterns to fold chained vector reductions";
  }
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateChainedVectorReductionFoldingPatterns(patterns);
    // Best-effort application: failure to converge is deliberately ignored in
    // this test pass.
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};
441+
423442
struct TestFlattenVectorTransferPatterns
424443
: public PassWrapper<TestFlattenVectorTransferPatterns,
425444
OperationPass<func::FuncOp>> {
@@ -773,6 +792,8 @@ void registerTestVectorLowerings() {
773792

774793
PassRegistration<TestVectorReduceToContractPatternsPatterns>();
775794

795+
PassRegistration<TestVectorChainedReductionFoldingPatterns>();
796+
776797
PassRegistration<TestFlattenVectorTransferPatterns>();
777798

778799
PassRegistration<TestVectorScanLowering>();

0 commit comments

Comments
 (0)