Skip to content

Commit ac16b30

Browse files
committed
make patterns standalone
1 parent 0b8c7b1 commit ac16b30

File tree

8 files changed

+134
-85
lines changed

8 files changed

+134
-85
lines changed

mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -453,4 +453,15 @@ def ApplyVectorReductionToContractPatternsOp : Op<Transform_Dialect,
453453
let assemblyFormat = "attr-dict";
454454
}
455455

456+
def ApplyVectorPropagateExtractPatternsOp : Op<Transform_Dialect,
457+
"apply_patterns.vector.propagate_extract",
458+
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
459+
let description = [{
460+
Collect a set of patterns for propagating `vector.extract` through the
461+
vector ops.
462+
}];
463+
464+
let assemblyFormat = "attr-dict";
465+
}
466+
456467
#endif // VECTOR_TRANSFORM_OPS

mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -409,6 +409,9 @@ void populateVectorLinearizeShuffleLikeOpsPatterns(
409409
const TypeConverter &typeConverter, RewritePatternSet &patterns,
410410
ConversionTarget &target, unsigned targetBitWidth);
411411

412+
/// Populates patterns for propagating `vector.extract` through the vector ops.
413+
void populateVectorPropagateExtractsPatterns(RewritePatternSet &patterns);
414+
412415
} // namespace vector
413416
} // namespace mlir
414417

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 1 addition & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -2237,47 +2237,6 @@ class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
22372237
}
22382238
};
22392239

2240-
/// Pattern to rewrite a ExtractOp(Elementwise) -> Elementwise(ExtractOp).
2241-
class ExtractOpFromElemetwise final : public OpRewritePattern<ExtractOp> {
2242-
public:
2243-
using OpRewritePattern::OpRewritePattern;
2244-
2245-
LogicalResult matchAndRewrite(ExtractOp op,
2246-
PatternRewriter &rewriter) const override {
2247-
Operation *eltwise = op.getVector().getDefiningOp();
2248-
2249-
// Elementwise op with single result and `extract` is single user.
2250-
if (!eltwise || !OpTrait::hasElementwiseMappableTraits(eltwise) ||
2251-
eltwise->getNumResults() != 1 || !eltwise->hasOneUse())
2252-
return failure();
2253-
2254-
// Arguments and result types must match.
2255-
if (!llvm::all_equal(llvm::concat<Type>(eltwise->getOperandTypes(),
2256-
eltwise->getResultTypes())))
2257-
return failure();
2258-
2259-
Type dstType = op.getType();
2260-
2261-
OpBuilder::InsertionGuard g(rewriter);
2262-
rewriter.setInsertionPoint(eltwise);
2263-
2264-
IRMapping mapping;
2265-
Location loc = eltwise->getLoc();
2266-
for (auto &&[i, arg] : llvm::enumerate(eltwise->getOperands())) {
2267-
Value newArg =
2268-
rewriter.create<ExtractOp>(loc, arg, op.getMixedPosition());
2269-
mapping.map(arg, newArg);
2270-
}
2271-
2272-
Operation *newEltwise = rewriter.clone(*eltwise, mapping);
2273-
newEltwise->getResult(0).setType(dstType);
2274-
2275-
rewriter.replaceOp(op, newEltwise);
2276-
rewriter.eraseOp(eltwise);
2277-
return success();
2278-
}
2279-
};
2280-
22812240
// Folds extract(shape_cast(..)) into shape_cast when the total element count
22822241
// does not change.
22832242
LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
@@ -2350,8 +2309,7 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
23502309

23512310
void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
23522311
MLIRContext *context) {
2353-
results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask,
2354-
ExtractOpFromElemetwise>(context);
2312+
results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
23552313
results.add(foldExtractFromShapeCastToShapeCast);
23562314
results.add(foldExtractFromFromElements);
23572315
}

mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,11 @@ void transform::ApplyTransferToScfPatternsOp::populatePatterns(
204204
populateVectorToSCFConversionPatterns(patterns, vectorTransferToSCFOptions);
205205
}
206206

207+
void transform::ApplyVectorPropagateExtractPatternsOp::populatePatterns(
208+
RewritePatternSet &patterns) {
209+
vector::populateVectorPropagateExtractsPatterns(patterns);
210+
}
211+
207212
//===----------------------------------------------------------------------===//
208213
// Transform op registration
209214
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
2424
VectorTransforms.cpp
2525
VectorUnroll.cpp
2626
VectorMaskElimination.cpp
27+
VectorPropagateExtract.cpp
2728

2829
ADDITIONAL_HEADER_DIRS
2930
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Vector/Transforms
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
//===- VectorPropagateExtract.cpp - vector.extract propagation - ----------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements patterns for vector.extract propagation.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
14+
15+
using namespace mlir;
16+
17+
namespace {
18+
19+
/// Pattern to rewrite a ExtractOp(Elementwise) -> Elementwise(ExtractOp).
20+
class ExtractOpFromElementwise final
21+
: public OpRewritePattern<vector::ExtractOp> {
22+
public:
23+
using OpRewritePattern::OpRewritePattern;
24+
25+
LogicalResult matchAndRewrite(vector::ExtractOp op,
26+
PatternRewriter &rewriter) const override {
27+
Operation *eltwise = op.getVector().getDefiningOp();
28+
29+
// Elementwise op with single result and `extract` is single user.
30+
if (!eltwise || !OpTrait::hasElementwiseMappableTraits(eltwise) ||
31+
eltwise->getNumResults() != 1 || !eltwise->hasOneUse())
32+
return failure();
33+
34+
// Arguments and result types must match.
35+
if (!llvm::all_equal(llvm::concat<Type>(eltwise->getOperandTypes(),
36+
eltwise->getResultTypes())))
37+
return failure();
38+
39+
Type dstType = op.getType();
40+
41+
OpBuilder::InsertionGuard g(rewriter);
42+
rewriter.setInsertionPoint(eltwise);
43+
44+
IRMapping mapping;
45+
Location loc = eltwise->getLoc();
46+
for (auto &&[i, arg] : llvm::enumerate(eltwise->getOperands())) {
47+
Value newArg =
48+
rewriter.create<vector::ExtractOp>(loc, arg, op.getMixedPosition());
49+
mapping.map(arg, newArg);
50+
}
51+
52+
Operation *newEltwise = rewriter.clone(*eltwise, mapping);
53+
newEltwise->getResult(0).setType(dstType);
54+
55+
rewriter.replaceOp(op, newEltwise);
56+
rewriter.eraseOp(eltwise);
57+
return success();
58+
}
59+
};
60+
61+
} // namespace
62+
63+
void mlir::vector::populateVectorPropagateExtractsPatterns(
64+
RewritePatternSet &patterns) {
65+
patterns.add<ExtractOpFromElementwise>(patterns.getContext());
66+
}

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 0 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -244,48 +244,6 @@ func.func @extract_from_non_constant_create_mask(%dim0: index) -> vector<[2]xi1>
244244

245245
// -----
246246

247-
// CHECK-LABEL: @extract_elementwise
248-
// CHECK-SAME: (%[[ARG0:.*]]: vector<4xf32>, %[[ARG1:.*]]: vector<4xf32>)
249-
func.func @extract_elementwise(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f32 {
250-
// CHECK: %[[EXT0:.*]] = vector.extract %[[ARG0]][1] : f32 from vector<4xf32>
251-
// CHECK: %[[EXT1:.*]] = vector.extract %[[ARG1]][1] : f32 from vector<4xf32>
252-
// CHECK: %[[RES:.*]] = arith.addf %[[EXT0]], %[[EXT1]] : f32
253-
// CHECK: return %[[RES]] : f32
254-
%0 = arith.addf %arg0, %arg1 : vector<4xf32>
255-
%1 = vector.extract %0[1] : f32 from vector<4xf32>
256-
return %1 : f32
257-
}
258-
259-
// -----
260-
261-
// CHECK-LABEL: @extract_vec_elementwise
262-
// CHECK-SAME: (%[[ARG0:.*]]: vector<2x4xf32>, %[[ARG1:.*]]: vector<2x4xf32>)
263-
func.func @extract_vec_elementwise(%arg0: vector<2x4xf32>, %arg1: vector<2x4xf32>) -> vector<4xf32> {
264-
// CHECK: %[[EXT0:.*]] = vector.extract %[[ARG0]][1] : vector<4xf32> from vector<2x4xf32>
265-
// CHECK: %[[EXT1:.*]] = vector.extract %[[ARG1]][1] : vector<4xf32> from vector<2x4xf32>
266-
// CHECK: %[[RES:.*]] = arith.addf %[[EXT0]], %[[EXT1]] : vector<4xf32>
267-
// CHECK: return %[[RES]] : vector<4xf32>
268-
%0 = arith.addf %arg0, %arg1 : vector<2x4xf32>
269-
%1 = vector.extract %0[1] : vector<4xf32> from vector<2x4xf32>
270-
return %1 : vector<4xf32>
271-
}
272-
273-
// -----
274-
275-
// CHECK-LABEL: @extract_elementwise_use
276-
// CHECK-SAME: (%[[ARG0:.*]]: vector<4xf32>, %[[ARG1:.*]]: vector<4xf32>)
277-
func.func @extract_elementwise_use(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> (f32, vector<4xf32>) {
278-
// Dop not propagate extract, as elementwise has other uses
279-
// CHECK: %[[ELT:.*]] = arith.addf %[[ARG0]], %[[ARG1]] : vector<4xf32>
280-
// CHECK: %[[EXT:.*]] = vector.extract %[[ELT]][1] : f32 from vector<4xf32>
281-
// CHECK: return %[[EXT]], %[[ELT]] : f32, vector<4xf32>
282-
%0 = arith.addf %arg0, %arg1 : vector<4xf32>
283-
%1 = vector.extract %0[1] : f32 from vector<4xf32>
284-
return %1, %0 : f32, vector<4xf32>
285-
}
286-
287-
// -----
288-
289247
// CHECK-LABEL: constant_mask_transpose_to_transposed_constant_mask
290248
func.func @constant_mask_transpose_to_transposed_constant_mask() -> (vector<2x3x4xi1>, vector<4x2x3xi1>) {
291249
// CHECK: vector.constant_mask [1, 2, 3] : vector<2x3x4xi1>
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
2+
3+
// CHECK-LABEL: @extract_elementwise
4+
// CHECK-SAME: (%[[ARG0:.*]]: vector<4xf32>, %[[ARG1:.*]]: vector<4xf32>)
5+
func.func @extract_elementwise(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f32 {
6+
// CHECK: %[[EXT0:.*]] = vector.extract %[[ARG0]][1] : f32 from vector<4xf32>
7+
// CHECK: %[[EXT1:.*]] = vector.extract %[[ARG1]][1] : f32 from vector<4xf32>
8+
// CHECK: %[[RES:.*]] = arith.addf %[[EXT0]], %[[EXT1]] : f32
9+
// CHECK: return %[[RES]] : f32
10+
%0 = arith.addf %arg0, %arg1 : vector<4xf32>
11+
%1 = vector.extract %0[1] : f32 from vector<4xf32>
12+
return %1 : f32
13+
}
14+
15+
// CHECK-LABEL: @extract_vec_elementwise
16+
// CHECK-SAME: (%[[ARG0:.*]]: vector<2x4xf32>, %[[ARG1:.*]]: vector<2x4xf32>)
17+
func.func @extract_vec_elementwise(%arg0: vector<2x4xf32>, %arg1: vector<2x4xf32>) -> vector<4xf32> {
18+
// CHECK: %[[EXT0:.*]] = vector.extract %[[ARG0]][1] : vector<4xf32> from vector<2x4xf32>
19+
// CHECK: %[[EXT1:.*]] = vector.extract %[[ARG1]][1] : vector<4xf32> from vector<2x4xf32>
20+
// CHECK: %[[RES:.*]] = arith.addf %[[EXT0]], %[[EXT1]] : vector<4xf32>
21+
// CHECK: return %[[RES]] : vector<4xf32>
22+
%0 = arith.addf %arg0, %arg1 : vector<2x4xf32>
23+
%1 = vector.extract %0[1] : vector<4xf32> from vector<2x4xf32>
24+
return %1 : vector<4xf32>
25+
}
26+
27+
// CHECK-LABEL: @extract_elementwise_use
28+
// CHECK-SAME: (%[[ARG0:.*]]: vector<4xf32>, %[[ARG1:.*]]: vector<4xf32>)
29+
func.func @extract_elementwise_use(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> (f32, vector<4xf32>) {
30+
// Do not propagate extract, as elementwise has other uses.
31+
// CHECK: %[[ELT:.*]] = arith.addf %[[ARG0]], %[[ARG1]] : vector<4xf32>
32+
// CHECK: %[[EXT:.*]] = vector.extract %[[ELT]][1] : f32 from vector<4xf32>
33+
// CHECK: return %[[EXT]], %[[ELT]] : f32, vector<4xf32>
34+
%0 = arith.addf %arg0, %arg1 : vector<4xf32>
35+
%1 = vector.extract %0[1] : f32 from vector<4xf32>
36+
return %1, %0 : f32, vector<4xf32>
37+
}
38+
39+
module attributes {transform.with_named_sequence} {
40+
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
41+
%func = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op
42+
transform.apply_patterns to %func {
43+
transform.apply_patterns.vector.propagate_extract
44+
} : !transform.any_op
45+
transform.yield
46+
}
47+
}

0 commit comments

Comments
 (0)