[MLIR][Vector] Implement TransferOpReduceRank as MaskableOpRewritePattern #92426

Merged
merged 3 commits into from
Jun 12, 2024
26 changes: 16 additions & 10 deletions mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -322,14 +322,20 @@ struct TransferWriteNonPermutationLowering
 /// %v = vector.transfer_read ...
 /// permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3)
 /// vector.broadcast %v
-struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
-  using OpRewritePattern::OpRewritePattern;
+struct TransferOpReduceRank
+    : public MaskableOpRewritePattern<vector::TransferReadOp> {
+  using MaskableOpRewritePattern::MaskableOpRewritePattern;

-  LogicalResult matchAndRewrite(vector::TransferReadOp op,
-                                PatternRewriter &rewriter) const override {
+  FailureOr<mlir::Value>
+  matchAndRewriteMaskableOp(vector::TransferReadOp op,
+                            MaskingOpInterface maskOp,
+                            PatternRewriter &rewriter) const override {
     // TODO: support 0-d corner case.
     if (op.getTransferRank() == 0)
       return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
+    // TODO: support masked case.
+    if (maskOp)
+      return rewriter.notifyMatchFailure(op, "Masked case not supported");

     AffineMap map = op.getPermutationMap();
     unsigned numLeadingBroadcast = 0;
@@ -369,9 +375,9 @@ struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
             op.getLoc(), originalVecType.getElementType(), op.getSource(),
             op.getIndices());
       }
-      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType,
-                                                       newRead);
-      return success();
+      return rewriter
+          .create<vector::BroadcastOp>(op.getLoc(), originalVecType, newRead)
+          .getVector();
     }

     SmallVector<int64_t> newShape(
@@ -393,9 +399,9 @@ struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
         op.getLoc(), newReadType, op.getSource(), op.getIndices(),
         AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
         newInBoundsAttr);
-    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType,
-                                                     newRead);
-    return success();
+    return rewriter
+        .create<vector::BroadcastOp>(op.getLoc(), originalVecType, newRead)
+        .getVector();
   }
};
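
The change above moves the final replacement out of the pattern: `matchAndRewriteMaskableOp` only builds and returns the new value, and the base class decides which op to replace. The following is a minimal standalone sketch of that contract, not MLIR code. `Op`, `Outcome`, `rewriteMaskable` and `applyPattern` are hypothetical names, `std::optional` stands in for `FailureOr`, and the sketch assumes the base class replaces the enclosing `vector.mask` when there is one and the op itself otherwise.

```cpp
#include <cassert>
#include <optional>
#include <string>

// Hypothetical stand-in for an operation. `maskedBy` names the enclosing
// vector.mask op, or is empty when the op is not masked.
struct Op {
  std::string name;
  std::optional<std::string> maskedBy;
};

// Outcome of one pattern application in this model.
struct Outcome {
  bool matched;
  std::string replacedRoot; // which op the driver replaced
  std::string newValue;     // the value it was replaced with
};

// Models matchAndRewriteMaskableOp: build and return the replacement value,
// or std::nullopt on match failure. It does not replace anything itself.
std::optional<std::string> rewriteMaskable(const Op &op, bool supportsMasking) {
  if (op.maskedBy && !supportsMasking)
    return std::nullopt; // mirrors "Masked case not supported"
  return "broadcast(" + op.name + ")";
}

// Models the driver (assumed behaviour): replace the mask op when there is
// one, otherwise the op itself.
Outcome applyPattern(const Op &op, bool supportsMasking) {
  std::optional<std::string> newValue = rewriteMaskable(op, supportsMasking);
  if (!newValue)
    return {false, "", ""};
  const std::string &root = op.maskedBy ? *op.maskedBy : op.name;
  return {true, root, *newValue};
}
```

With `supportsMasking` set to false, as in this PR, a masked read is left untouched, which is what the `CHECK-NOT: vector.broadcast` line in the new masked test expects.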

46 changes: 46 additions & 0 deletions mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
@@ -187,3 +187,49 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}

// -----


// CHECK: #[[MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, 0, d3)>
// CHECK: func.func @transfer_read_reduce_rank_scalable(
// CHECK-SAME: %[[ARG_0:.*]]: memref<?x?x?x?xf32>) -> vector<8x[4]x2x3xf32> {
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[TFR:.*]] = vector.transfer_read %arg0[%[[C0]], %[[C0]], %[[C0]], %[[C0]]]{{.*}} permutation_map = #[[MAP]]} : memref<?x?x?x?xf32>, vector<[4]x2x3xf32>
// CHECK: %[[BC:.*]] = vector.broadcast %[[TFR]] : vector<[4]x2x3xf32> to vector<8x[4]x2x3xf32>
// CHECK: return %[[BC]] : vector<8x[4]x2x3xf32>
func.func @transfer_read_reduce_rank_scalable(%mem: memref<?x?x?x?xf32>) -> vector<8x[4]x2x3xf32> {
Contributor

How is this test different to @permutation_with_mask_xfer_read_scalable? Both seem to check for vector.transfer_read from memref into a scalable vector?

Contributor Author

@nujaa, Jun 13, 2024

There is a transposition in the permutation map of permutation_with_mask_xfer_read_scalable. Hence it would also trigger TransferReadPermutationLowering, which IIRC would fail without #91987 if I masked it with vector.mask. The changes were not completely independent; this test is a bit more unit-test-like.

Contributor

There is a transposition in the permutation map of permutation_with_mask_xfer_read_scalable.

Cool, but right now that's very unclear unless you happen to know what to look for. Such a distinction should be captured either:

  • through comments, and/or
  • test function name.

(preferably both). That's currently not the case, so let's fix that. I'm suggesting the following plan of action:

  • add a big bold comment to document what's being tested here, something similar to what's on L94-L101
  • add no_transpose to the test functions names that you added.

I can take it from there.

I have two more high-level points. A bit tangential to this PR - more like TODOs for myself 😅

What test should be included in vector-transfer-permutation-lowering.mlir?

Btw, based on the file name and the patterns being tested here:

there should be no non-permutation examples in this file, right? To make things even more confusing, these patterns are also tested in:

  • transform-vector.mlir
  • vector-transfer-to-vector-load-store.mlir

😱 :) I'm on a mission to tidy this up a bit (while adding tests for scalable vectors).

Note on populateVectorTransferPermutationMapLoweringPatterns

Separately, it's a bit confusing that TransferOpReduceRank is added to populateVectorTransferPermutationMapLoweringPatterns:

void mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns
      .add<TransferReadPermutationLowering, TransferWritePermutationLowering,
           TransferOpReduceRank, TransferWriteNonPermutationLowering>(
          patterns.getContext(), benefit);
}

I think that it would be easier to "categorise" tests if the granularity was lower. Another PR though. I can handle this.

Contributor Author

Makes sense to me, it deserves refactoring and standardization. I could have done better on commenting and naming.

add a big bold comment to document what's being tested here, something similar to what's on L94-L101

Alternatively, separating the tests with a big bold comment for each pattern from the set that they actually test could make sense.

add no_transpose to the test functions names that you added.

I think this rather goes against the idea of having a permutation-specific file, as you mention here:

there should be no non-permutation examples in this file, right? To make things even more confusing, these patterns are also tested in:

But I think we should rather emphasize the fact that it is not a permutation lowering but a lowering based on the permutation map. Hence the no_transpose answer works.

these patterns are also tested in:

I would not worry about transform-vector.mlir, as it is a full pipeline test. Its goal is to show and test a complete lowering, which makes sense. It is a bit like the e2e examples you added for SME.
For vector-transfer-to-vector-load-store.mlir, I think this file deserves improvements too (such as removing the repetition of the same transform sequences). The concept is to combine it with lower_transfer to test the complete lowering of xfer ops. It is not very unit-test-like, but it shows how to properly lower xfer ops, which is a benefit. Mixed feelings. I think there should not be transfer_permutation_patterns-specific tests there, but I find it sensible to test this combination of patterns. We should maybe ask @ftynse.

it's a bit confusing that TransferOpReduceRank is added to populateVectorTransferPermutationMapLoweringPatterns:

Indeed, but after all it is a lowering based on the permutation map of a TransferRead. Hence, I think that rather than removing TransferOpReduceRank, the ideal solution is either to rename TransferOpReduceRank to make clear that it matches patterns in the permutation map, or to rename populateVectorTransferPermutationMapLoweringPatterns to reflect that it is not necessarily a permutation. This set of patterns is useful as a whole for the community as a lowering pass.

I would also add something. As @hanhanW mentioned in #93664 (review), we could get rid of /// ``` and standardize it in Doxygen descriptions for improved readability.

Contributor

in Doxygen descriptions for improved readability.

It is also helpful when I parse the comment in the code directly. I'd appreciate it if someone could help make it consistent. :)

Contributor

The syntax is mostly for markdown to me, so I don't know why people started adding such style to comments.

Contributor

But I think we should rather emphasize the fact that it is not a permutation lowering but a lowering based on the permutation map. Hence the no_transpose answer works.

That's spot-on and something that I missed. In fact, this is documented here:

/// Collect a set of transfer read/write lowering patterns that simplify the
/// permutation map (e.g., converting it to a minor identity map) by inserting
/// broadcasts and transposes. More specifically:
///
/// [TransferReadPermutationLowering]
/// Lower transfer_read op with permutation into a transfer_read with a
/// permutation map composed of leading zeros followed by a minor identity +
/// vector.transpose op.
/// Ex:
/// vector.transfer_read ...
/// permutation_map: (d0, d1, d2) -> (0, d1)
/// into:
/// %v = vector.transfer_read ...
/// permutation_map: (d0, d1, d2) -> (d1, 0)
/// vector.transpose %v, [1, 0]
///
/// vector.transfer_read ...
/// permutation_map: (d0, d1, d2, d3) -> (0, 0, 0, d1, d3)
/// into:
/// %v = vector.transfer_read ...
/// permutation_map: (d0, d1, d2, d3) -> (0, 0, d1, 0, d3)
/// vector.transpose %v, [0, 1, 3, 2, 4]
/// Note that an alternative is to transform it to linalg.transpose +
/// vector.transfer_read to do the transpose in memory instead.
///
/// [TransferWritePermutationLowering]
/// Lower transfer_write op with permutation into a transfer_write with a
/// minor identity permutation map. (transfer_write ops cannot have broadcasts.)
/// Ex:
/// vector.transfer_write %v ...
/// permutation_map: (d0, d1, d2) -> (d2, d0, d1)
/// into:
/// %tmp = vector.transpose %v, [2, 0, 1]
/// vector.transfer_write %tmp ...
/// permutation_map: (d0, d1, d2) -> (d0, d1, d2)
///
/// vector.transfer_write %v ...
/// permutation_map: (d0, d1, d2, d3) -> (d3, d2)
/// into:
/// %tmp = vector.transpose %v, [1, 0]
/// %v = vector.transfer_write %tmp ...
/// permutation_map: (d0, d1, d2, d3) -> (d2, d3)
///
/// [TransferOpReduceRank]
/// Lower transfer_read op with broadcast in the leading dimensions into
/// transfer_read of lower rank + vector.broadcast.
/// Ex: vector.transfer_read ...
/// permutation_map: (d0, d1, d2, d3) -> (0, d1, 0, d3)
/// into:
/// %v = vector.transfer_read ...
/// permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3)
/// vector.broadcast %v
void populateVectorTransferPermutationMapLoweringPatterns(
RewritePatternSet &patterns, PatternBenefit benefit = 1);
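
The [TransferOpReduceRank] entry above can be illustrated with a small standalone sketch of the shape arithmetic. It is not the MLIR implementation: `MapResult`, `ReducedRead` and `reduceRank` are hypothetical names, a broadcast result (spelled `0` in the map) is modelled as `std::nullopt`, scalable flags are ignored, and the sketch assumes the pattern does not apply when there is no leading broadcast.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// One permutation-map result: a dimension index (d1 -> 1), or std::nullopt
// for a broadcast, which the map spells as the constant 0.
using MapResult = std::optional<unsigned>;

struct ReducedRead {
  std::vector<MapResult> map;  // permutation map of the lower-rank read
  std::vector<int64_t> shape;  // vector shape of the lower-rank read
};

// Drops the leading broadcast results together with the matching leading
// vector dimensions. Returns std::nullopt when there is no leading broadcast.
// If every result is a broadcast, the returned map and shape are empty.
std::optional<ReducedRead> reduceRank(const std::vector<MapResult> &map,
                                      const std::vector<int64_t> &shape) {
  assert(map.size() == shape.size());
  std::size_t numLeadingBroadcast = 0;
  while (numLeadingBroadcast < map.size() && !map[numLeadingBroadcast])
    ++numLeadingBroadcast;
  if (numLeadingBroadcast == 0)
    return std::nullopt;
  auto offset = static_cast<std::ptrdiff_t>(numLeadingBroadcast);
  return ReducedRead{
      std::vector<MapResult>(map.begin() + offset, map.end()),
      std::vector<int64_t>(shape.begin() + offset, shape.end())};
}
```

For the map `(0, d1, 0, d3)` and a result shape of `8x4x2x3`, the sketch yields the map `(d1, 0, d3)` and the shape `4x2x3`; a `vector.broadcast` then restores the leading dimension of size 8. Only the leading broadcast is dropped, so the inner `0` stays in the reduced map.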

With this in mind, please ignore my comments under this heading:

Note on populateVectorTransferPermutationMapLoweringPatterns

In fact, sounds like vector-transfer-permutation-lowering.mlir is the right file for all patterns under populateVectorTransferPermutationMapLoweringPatterns.

Going back to the problem at hand ... This

Makes sense to me, it deserves refactoring and standardization.

and this:

Alternatively, separating the tests with a big bold comment for each pattern from the set that they actually test could make sense.

First step in this direction: #95529. Let me know if that makes sense and I will prepare more tests.

I could have done better on commenting and naming.

Not your fault, it's hard to navigate this ATM (this is also why it takes me ages to review things). Clearly our test coverage is not great either. In general, Vector needs some TLC 😅 I really appreciate you helping us here 🙏🏻

Contributor

First step in this direction: #95529. Let me know if that makes sense and I will prepare more tests.

Next one: #96033

  %c0 = arith.constant 0 : index
  %cst_0 = arith.constant 0.000000e+00 : f32
  %1 = vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst_0
    {in_bounds = [true, true, true, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (0, d1, 0, d3)>}
    : memref<?x?x?x?xf32>, vector<8x[4]x2x3xf32>
  return %1 : vector<8x[4]x2x3xf32>
}

// Masked case not supported.
// CHECK-LABEL: func.func @masked_transfer_read_reduce_rank(
// CHECK-SAME: %[[ARG_0:.*]]: memref<?x?x?x?xf32>,
// CHECK-SAME: %[[DIM:.*]]: index) -> vector<8x[4]x2x3xf32> {
// CHECK-NOT: vector.broadcast
// CHECK: %[[MASK:.*]] = vector.mask %0 { vector.transfer_read %arg0{{.*}} : memref<?x?x?x?xf32>, vector<8x[4]x2x3xf32> } : vector<[4]x3xi1> -> vector<8x[4]x2x3xf32>
func.func @masked_transfer_read_reduce_rank(%mem: memref<?x?x?x?xf32>, %dim: index) -> vector<8x[4]x2x3xf32> {
  %c0 = arith.constant 0 : index
  %cst_0 = arith.constant 0.000000e+00 : f32
  %mask = vector.create_mask %dim, %dim: vector<[4]x3xi1>
  %res = vector.mask %mask { vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst_0
    {in_bounds = [true, true, true, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (0, d1, 0, d3)>}
    : memref<?x?x?x?xf32>, vector<8x[4]x2x3xf32> } : vector<[4]x3xi1> -> vector<8x[4]x2x3xf32>
  return %res : vector<8x[4]x2x3xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
    %f = transform.structured.match ops{["func.func"]} in %module_op
      : (!transform.any_op) -> !transform.any_op
    transform.apply_patterns to %f {
      transform.apply_patterns.vector.transfer_permutation_patterns
    } : !transform.any_op
    transform.yield
  }
}
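
In the masked test above, the mask has type `vector<[4]x3xi1>` while the read produces `vector<8x[4]x2x3xf32>`. One way to read this is that the mask carries only the non-broadcast dimensions of the result. The following standalone sketch illustrates that reading for a map without transposition. It is an assumption drawn from the test, not the rule MLIR implements; `Dim`, `MapResult` and `maskShapeNoTranspose` are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// One permutation-map result: a dimension index, or std::nullopt for a
// broadcast (spelled as the constant 0 in the map).
using MapResult = std::optional<unsigned>;

// A vector dimension: its size and whether it is scalable, e.g. [4].
struct Dim {
  int64_t size;
  bool scalable;
};

// Keeps the vector dimensions whose map result is a real dimension and drops
// the broadcast ones. Only meant for maps without transposition (assumption).
std::vector<Dim> maskShapeNoTranspose(const std::vector<MapResult> &map,
                                      const std::vector<Dim> &vecShape) {
  assert(map.size() == vecShape.size());
  std::vector<Dim> mask;
  for (std::size_t i = 0; i < map.size(); ++i)
    if (map[i])
      mask.push_back(vecShape[i]);
  return mask;
}
```

Applied to the map `(0, d1, 0, d3)` and the shape `8x[4]x2x3`, the sketch keeps `[4]` and `3`, matching the mask type used in the test.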