[mlir][vector] Propagate vector.extract through elementwise ops #131462

Merged: 11 commits, Mar 25, 2025
Changes from 8 commits
@@ -453,4 +453,27 @@ def ApplyVectorReductionToContractPatternsOp : Op<Transform_Dialect,
let assemblyFormat = "attr-dict";
}

def ApplySinkVectorPatternsOp : Op<Transform_Dialect,
"apply_patterns.vector.sink_ops",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
let description = [{
Patterns that remove redundant Vector Ops by re-ordering them with,
e.g., elementwise Ops:
```
%at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
%bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
%r = arith.addf %at, %bt : vector<2x4xf32>
```
gets converted to:
```
%0 = arith.addf %a, %b : vector<4x2xf32>
%r = vector.transpose %0, [1, 0] : vector<4x2xf32> to vector<2x4xf32>
```
At the moment, these patterns are limited to vector.broadcast and
vector.transpose.
}];

let assemblyFormat = "attr-dict";
}

#endif // VECTOR_TRANSFORM_OPS
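
For reference, the new transform op is driven from an ordinary transform sequence. Below is a minimal sketch that mirrors the smoke test added later in this PR; the `transform.structured.match` on `func.func` is just one way of selecting the payload and is not part of the new op:

```mlir
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
    // Collect the payload functions and apply the sink patterns to them.
    %func = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op
    transform.apply_patterns to %func {
      transform.apply_patterns.vector.sink_ops
    } : !transform.any_op
    transform.yield
  }
}
```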
8 changes: 8 additions & 0 deletions mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -67,6 +67,9 @@ void transform::ApplyFoldElementwiseToVectorPatternsOp::populatePatterns(
void transform::ApplyVectorReductionToContractPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
vector::populateVectorReductionToContractPatterns(patterns);

// TODO: Now that there is a dedicated transform for
// `populateSinkVectorOpsPatterns`, remove it from here.
vector::populateSinkVectorOpsPatterns(patterns);
}

@@ -204,6 +207,11 @@ void transform::ApplyTransferToScfPatternsOp::populatePatterns(
populateVectorToSCFConversionPatterns(patterns, vectorTransferToSCFOptions);
}

void transform::ApplySinkVectorPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
vector::populateSinkVectorOpsPatterns(patterns);
}

//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//
61 changes: 59 additions & 2 deletions mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1043,6 +1043,63 @@ struct ReorderElementwiseOpsOnBroadcast final
}
};

/// Pattern to rewrite an ExtractOp(Elementwise) -> Elementwise(ExtractOp).
/// This may result in cleaner code when extracting a single value
/// from a multi-element vector, and it also helps canonicalize 1-element
/// vectors to scalars.
/// ```
/// %0 = arith.addf %arg0, %arg1 : vector<4xf32>
/// %1 = vector.extract %0[1] : f32 from vector<4xf32>
/// ```
/// Gets converted to:
/// ```
/// %0 = vector.extract %arg0[1] : f32 from vector<4xf32>
/// %1 = vector.extract %arg1[1] : f32 from vector<4xf32>
/// %2 = arith.addf %0, %1 : f32
/// ```
class ExtractOpFromElementwise final
: public OpRewritePattern<vector::ExtractOp> {
public:
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(vector::ExtractOp op,
PatternRewriter &rewriter) const override {
Operation *eltwise = op.getVector().getDefiningOp();

if (!eltwise || !OpTrait::hasElementwiseMappableTraits(eltwise))
return rewriter.notifyMatchFailure(op, "not an elementwise op");

if (eltwise->getNumResults() != 1)
return rewriter.notifyMatchFailure(op, "expected single result");

if (!eltwise->hasOneUse())
return rewriter.notifyMatchFailure(op, "expected single op use");

if (!llvm::all_equal(eltwise->getOperandTypes()))
return rewriter.notifyMatchFailure(op, "operand types are different");

Type dstType = op.getType();

OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(eltwise);

IRMapping mapping;
Location loc = eltwise->getLoc();
SmallVector<OpFoldResult> pos = op.getMixedPosition();
for (Value arg : eltwise->getOperands()) {
Value newArg = rewriter.create<vector::ExtractOp>(loc, arg, pos);
mapping.map(arg, newArg);
}

Operation *newEltwise = rewriter.clone(*eltwise, mapping);
newEltwise->getResult(0).setType(dstType);

rewriter.replaceOp(op, newEltwise);
rewriter.eraseOp(eltwise);
return success();
}
};

// Helper that returns a vector comparison that constructs a mask:
// mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
//
@@ -2111,8 +2168,8 @@ void mlir::vector::
void mlir::vector::populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
PatternBenefit benefit) {
patterns.add<ReorderElementwiseOpsOnTranspose, ReorderCastOpsOnBroadcast,
ReorderElementwiseOpsOnBroadcast>(patterns.getContext(),
benefit);
ReorderElementwiseOpsOnBroadcast, ExtractOpFromElementwise>(
patterns.getContext(), benefit);
}

void mlir::vector::populateChainedVectorReductionFoldingPatterns(
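For completeness, the sink patterns (now including `ExtractOpFromElementwise`) can also be exercised directly from C++. This is only a sketch, not code from this PR; it assumes the standard greedy driver entry point (`applyPatternsGreedily`, formerly `applyPatternsAndFoldGreedily`):

```cpp
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Populate the vector "sink" patterns and run them greedily on `op`.
static LogicalResult runSinkPatterns(Operation *op) {
  RewritePatternSet patterns(op->getContext());
  vector::populateSinkVectorOpsPatterns(patterns);
  return applyPatternsGreedily(op, std::move(patterns));
}
```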
55 changes: 24 additions & 31 deletions mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
@@ -59,24 +59,19 @@ func.func @vectorize_nd_tensor_extract_transfer_read_complex(%6: tensor<45x80x16


// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_transfer_read_complex(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<45x80x16xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: index, %[[VAL_4:.*]]: index,
// CHECK-SAME: %[[VAL_5:.*]]: tensor<1x4xf32>) -> tensor<1x4xf32> {
// CHECK-DAG: %[[VAL_6:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 79 : index
// CHECK: %[[VAL_11:.*]] = arith.addi %[[VAL_1]], %[[VAL_2]] : index
// CHECK: %[[VAL_13:.*]] = vector.broadcast %[[VAL_3]] : index to vector<4xindex>
// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_13]], %[[VAL_6]] : vector<4xindex>
// CHECK: %[[VAL_15:.*]] = vector.broadcast %[[VAL_4]] : index to vector<4xindex>
// CHECK: %[[VAL_16:.*]] = arith.addi %[[VAL_14]], %[[VAL_15]] : vector<4xindex>

// CHECK: %[[VAL_19:.*]] = vector.extract %[[VAL_16]][0] : index from vector<4xindex>

// CHECK: %[[VAL_20:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_11]], %[[VAL_10]], %[[VAL_19]]], %[[VAL_8]] {in_bounds = [true, true]} : tensor<45x80x16xf32>, vector<1x4xf32>
// CHECK: %[[VAL_21:.*]] = vector.transfer_write %[[VAL_20]], %[[VAL_5]]{{\[}}%[[VAL_9]], %[[VAL_9]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
// CHECK: return %[[VAL_21]] : tensor<1x4xf32>
// CHECK-SAME: %[[ARG0:.*]]: tensor<45x80x16xf32>,
// CHECK-SAME: %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index,
// CHECK-SAME: %[[ARG5:.*]]: tensor<1x4xf32>) -> tensor<1x4xf32> {

// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG: %[[C79:.*]] = arith.constant 79 : index
// CHECK: %[[ADD1:.*]] = arith.addi %[[ARG1]], %[[ARG2]] : index
// CHECK: %[[ADD2:.*]] = arith.addi %[[ARG3]], %[[ARG4]] : index

// CHECK: %[[READ:.*]] = vector.transfer_read %[[ARG0]]{{\[}}%[[ADD1]], %[[C79]], %[[ADD2]]], %[[CST]] {in_bounds = [true, true]} : tensor<45x80x16xf32>, vector<1x4xf32>
// CHECK: %[[WRITE:.*]] = vector.transfer_write %[[READ]], %[[ARG5]]{{\[}}%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
// CHECK: return %[[WRITE]] : tensor<1x4xf32>
// CHECK: }

// -----
@@ -98,19 +93,17 @@ func.func @vectorize_nd_tensor_extract_with_affine_apply_contiguous(%6: tensor<8
}

// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_with_affine_apply_contiguous(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<80x16xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: index,
// CHECK-SAME: %[[VAL_2:.*]]: tensor<1x4xf32>) -> tensor<1x4xf32> {
// CHECK-DAG: %[[VAL_3:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 79 : index
// CHECK: %[[VAL_8:.*]] = vector.broadcast %[[VAL_1]] : index to vector<4xindex>
// CHECK: %[[VAL_9:.*]] = arith.addi %[[VAL_8]], %[[VAL_3]] : vector<4xindex>
// CHECK: %[[VAL_10:.*]] = vector.extract %[[VAL_9]][0] : index from vector<4xindex>
// CHECK: %[[VAL_11:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_7]], %[[VAL_10]]], %[[VAL_5]] {in_bounds = [true, true]} : tensor<80x16xf32>, vector<1x4xf32>
// CHECK: %[[VAL_12:.*]] = vector.transfer_write %[[VAL_11]], %[[VAL_2]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
// CHECK: return %[[VAL_12]] : tensor<1x4xf32>
// CHECK-SAME: %[[ARG0:.*]]: tensor<80x16xf32>,
// CHECK-SAME: %[[ARG1:.*]]: index,
// CHECK-SAME: %[[ARG2:.*]]: tensor<1x4xf32>) -> tensor<1x4xf32> {

// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C79:.*]] = arith.constant 79 : index

// CHECK: %[[READ:.*]] = vector.transfer_read %[[ARG0]]{{\[}}%[[C79]], %[[ARG1]]], %[[CST]] {in_bounds = [true, true]} : tensor<80x16xf32>, vector<1x4xf32>
// CHECK: %[[WRITE:.*]] = vector.transfer_write %[[READ]], %[[ARG2]]{{\[}}%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
// CHECK: return %[[WRITE]] : tensor<1x4xf32>
// CHECK: }

// -----
13 changes: 13 additions & 0 deletions mlir/test/Dialect/Vector/vector-sink-transform.mlir
Contributor: Instead of adding a new file, could you try adding a new RUN line here: https://github.com/llvm/llvm-project/blob/main/mlir/test/Dialect/Vector/vector-sink.mlir. You will probably have to put the TD sequence into a separate file. Here's an example:

Contributor Author: done

@@ -0,0 +1,13 @@
// RUN: mlir-opt %s

// This is a smoke test for `transform.apply_patterns.vector.sink_ops` and this
// file is also used in `vector-sink.mlir`.
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.vector.sink_ops
} : !transform.any_op
transform.yield
}
}
78 changes: 78 additions & 0 deletions mlir/test/Dialect/Vector/vector-sink.mlir
@@ -1,4 +1,5 @@
// RUN: mlir-opt %s -test-vector-sink-patterns -split-input-file | FileCheck %s
// RUN: mlir-opt -transform-preload-library='transform-library-paths=%p/vector-sink-transform.mlir' -transform-interpreter -split-input-file %s | FileCheck %s

//-----------------------------------------------------------------------------
// [Pattern: ReorderElementwiseOpsOnBroadcast]
@@ -423,3 +424,80 @@ func.func @transpose_elementwise_diff_map_scalable(%a : vector<[4]x6x3x2xf32>, %
%r = arith.addf %at, %bt : vector<6x[4]x2x3xf32>
return %r : vector<6x[4]x2x3xf32>
}

// -----

Comment on lines +427 to +429

Contributor: Please add a block comment documenting which pattern is being tested.

Contributor Author: done

//-----------------------------------------------------------------------------
// [Pattern: ExtractOpFromElementwise]
//-----------------------------------------------------------------------------

// CHECK-LABEL: @extract_elementwise_scalar
// CHECK-SAME: (%[[ARG0:.*]]: vector<4xf32>, %[[ARG1:.*]]: vector<4xf32>)
func.func @extract_elementwise_scalar(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f32 {
// CHECK: %[[EXT0:.*]] = vector.extract %[[ARG0]][1] : f32 from vector<4xf32>
// CHECK: %[[EXT1:.*]] = vector.extract %[[ARG1]][1] : f32 from vector<4xf32>
// CHECK: %[[RES:.*]] = arith.addf %[[EXT0]], %[[EXT1]] : f32
// CHECK: return %[[RES]] : f32
%0 = arith.addf %arg0, %arg1 : vector<4xf32>
%1 = vector.extract %0[1] : f32 from vector<4xf32>
return %1 : f32
}

// CHECK-LABEL: @extract_elementwise_arg_res_different_types
// CHECK-SAME: (%[[ARG0:.*]]: vector<4xindex>)
func.func @extract_elementwise_arg_res_different_types(%arg0: vector<4xindex>) -> i64 {
// CHECK: %[[EXT:.*]] = vector.extract %[[ARG0]][1] : index from vector<4xindex>
// CHECK: %[[RES:.*]] = arith.index_cast %[[EXT]] : index to i64
// CHECK: return %[[RES]] : i64
%0 = arith.index_cast %arg0: vector<4xindex> to vector<4xi64>
%1 = vector.extract %0[1] : i64 from vector<4xi64>
return %1 : i64
}

// CHECK-LABEL: @extract_elementwise_vec
// CHECK-SAME: (%[[ARG0:.*]]: vector<2x4xf32>, %[[ARG1:.*]]: vector<2x4xf32>)
func.func @extract_elementwise_vec(%arg0: vector<2x4xf32>, %arg1: vector<2x4xf32>) -> vector<4xf32> {
// CHECK: %[[EXT0:.*]] = vector.extract %[[ARG0]][1] : vector<4xf32> from vector<2x4xf32>
// CHECK: %[[EXT1:.*]] = vector.extract %[[ARG1]][1] : vector<4xf32> from vector<2x4xf32>
// CHECK: %[[RES:.*]] = arith.addf %[[EXT0]], %[[EXT1]] : vector<4xf32>
// CHECK: return %[[RES]] : vector<4xf32>
%0 = arith.addf %arg0, %arg1 : vector<2x4xf32>
%1 = vector.extract %0[1] : vector<4xf32> from vector<2x4xf32>
return %1 : vector<4xf32>
}

// CHECK-LABEL: @negative_extract_elementwise_no_single_use
// CHECK-SAME: (%[[ARG0:.*]]: vector<4xf32>, %[[ARG1:.*]]: vector<4xf32>)
func.func @negative_extract_elementwise_no_single_use(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> (f32, vector<4xf32>) {
// Do not propagate the extract, as the elementwise op has other uses.
// CHECK: %[[ELT:.*]] = arith.addf %[[ARG0]], %[[ARG1]] : vector<4xf32>
// CHECK: %[[EXT:.*]] = vector.extract %[[ELT]][1] : f32 from vector<4xf32>
// CHECK: return %[[EXT]], %[[ELT]] : f32, vector<4xf32>
%0 = arith.addf %arg0, %arg1 : vector<4xf32>
%1 = vector.extract %0[1] : f32 from vector<4xf32>
return %1, %0 : f32, vector<4xf32>
}

// CHECK-LABEL: @negative_extract_elementwise_not_one_res
// CHECK-SAME: (%[[ARG0:.*]]: vector<4xi32>, %[[ARG1:.*]]: vector<4xi32>)
func.func @negative_extract_elementwise_not_one_res(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> i32 {
// Do not propagate the extract, as the elementwise op has more than one result.
// CHECK: %[[LOW:.*]], %[[HIGH:.*]] = arith.mulsi_extended %[[ARG0]], %[[ARG1]] : vector<4xi32>
// CHECK: %[[EXT:.*]] = vector.extract %[[LOW]][1] : i32 from vector<4xi32>
// CHECK: return %[[EXT]] : i32
%low, %hi = arith.mulsi_extended %arg0, %arg1 : vector<4xi32>
%1 = vector.extract %low[1] : i32 from vector<4xi32>
return %1 : i32
}

// CHECK-LABEL: @negative_extract_not_elementwise
// CHECK-SAME: (%[[ARG0:.*]]: vector<4xi64>)
func.func @negative_extract_not_elementwise(%arg0: vector<4xi64>) -> i64 {
// `test.increment` is not an elementwise op.
// CHECK: %[[INC:.*]] = test.increment %[[ARG0]] : vector<4xi64>
// CHECK: %[[RES:.*]] = vector.extract %[[INC]][1] : i64 from vector<4xi64>
// CHECK: return %[[RES]] : i64
%0 = test.increment %arg0: vector<4xi64>
%1 = vector.extract %0[1] : i64 from vector<4xi64>
return %1 : i64
}