-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[MLIR][Vector] Implement transferXXPermutationLowering as MaskableOpRewritePattern #91987
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-vector Author: Hugo Trachino (nujaa) ChangesImplements Split of #90835 Full diff: https://github.com/llvm/llvm-project/pull/91987.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index b30b43d70bf0f..7f5703b635068 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -90,14 +90,19 @@ namespace {
/// Note that an alternative is to transform it to linalg.transpose +
/// vector.transfer_read to do the transpose in memory instead.
struct TransferReadPermutationLowering
- : public OpRewritePattern<vector::TransferReadOp> {
- using OpRewritePattern::OpRewritePattern;
+ : public MaskableOpRewritePattern<vector::TransferReadOp> {
+ using MaskableOpRewritePattern::MaskableOpRewritePattern;
- LogicalResult matchAndRewrite(vector::TransferReadOp op,
- PatternRewriter &rewriter) const override {
+ FailureOr<mlir::Value>
+ matchAndRewriteMaskableOp(vector::TransferReadOp op,
+ MaskingOpInterface maskOp,
+ PatternRewriter &rewriter) const override {
// TODO: support 0-d corner case.
if (op.getTransferRank() == 0)
return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
+ // TODO: Support transfer_read inside MaskOp case.
+ if (maskOp)
+ return rewriter.notifyMatchFailure(op, "Masked case not supported");
SmallVector<unsigned> permutation;
AffineMap map = op.getPermutationMap();
@@ -142,9 +147,9 @@ struct TransferReadPermutationLowering
// Transpose result of transfer_read.
SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
- rewriter.replaceOpWithNewOp<vector::TransposeOp>(op, newRead,
- transposePerm);
- return success();
+ return rewriter
+ .create<vector::TransposeOp>(op.getLoc(), newRead, transposePerm)
+ .getResult();
}
};
@@ -165,14 +170,19 @@ struct TransferReadPermutationLowering
/// %v = vector.transfer_write %tmp ...
/// permutation_map: (d0, d1, d2, d3) -> (d2, d3)
struct TransferWritePermutationLowering
- : public OpRewritePattern<vector::TransferWriteOp> {
- using OpRewritePattern::OpRewritePattern;
+ : public MaskableOpRewritePattern<vector::TransferWriteOp> {
+ using MaskableOpRewritePattern::MaskableOpRewritePattern;
- LogicalResult matchAndRewrite(vector::TransferWriteOp op,
- PatternRewriter &rewriter) const override {
+ FailureOr<mlir::Value>
+ matchAndRewriteMaskableOp(vector::TransferWriteOp op,
+ MaskingOpInterface maskOp,
+ PatternRewriter &rewriter) const override {
// TODO: support 0-d corner case.
if (op.getTransferRank() == 0)
return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
+ // TODO: Support transfer_write inside MaskOp case.
+ if (maskOp)
+ return rewriter.notifyMatchFailure(op, "Masked case not supported");
SmallVector<unsigned> permutation;
AffineMap map = op.getPermutationMap();
@@ -207,11 +217,11 @@ struct TransferWritePermutationLowering
op.getLoc(), op.getVector(), indices);
auto newMap = AffineMap::getMinorIdentityMap(
map.getNumDims(), map.getNumResults(), rewriter.getContext());
- rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
- op, newVec, op.getSource(), op.getIndices(), AffineMapAttr::get(newMap),
- op.getMask(), newInBoundsAttr);
-
- return success();
+ return rewriter
+ .create<vector::TransferWriteOp>(
+ op.getLoc(), newVec, op.getSource(), op.getIndices(),
+ AffineMapAttr::get(newMap), op.getMask(), newInBoundsAttr)
+ .getResult();
}
};
@@ -231,14 +241,19 @@ struct TransferWritePermutationLowering
/// vector<1x8x16xf32>
/// ```
struct TransferWriteNonPermutationLowering
- : public OpRewritePattern<vector::TransferWriteOp> {
- using OpRewritePattern::OpRewritePattern;
+ : public MaskableOpRewritePattern<vector::TransferWriteOp> {
+ using MaskableOpRewritePattern::MaskableOpRewritePattern;
- LogicalResult matchAndRewrite(vector::TransferWriteOp op,
- PatternRewriter &rewriter) const override {
+ FailureOr<mlir::Value>
+ matchAndRewriteMaskableOp(vector::TransferWriteOp op,
+ MaskingOpInterface maskOp,
+ PatternRewriter &rewriter) const override {
// TODO: support 0-d corner case.
if (op.getTransferRank() == 0)
return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
+ // TODO: Support transfer_write inside MaskOp case.
+ if (maskOp)
+ return rewriter.notifyMatchFailure(op, "Masked case not supported");
SmallVector<unsigned> permutation;
AffineMap map = op.getPermutationMap();
@@ -285,10 +300,11 @@ struct TransferWriteNonPermutationLowering
newInBoundsValues.push_back(op.isDimInBounds(i));
}
ArrayAttr newInBoundsAttr = rewriter.getBoolArrayAttr(newInBoundsValues);
- rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
- op, newVec, op.getSource(), op.getIndices(), AffineMapAttr::get(newMap),
- newMask, newInBoundsAttr);
- return success();
+ return rewriter
+ .create<vector::TransferWriteOp>(
+ op.getLoc(), newVec, op.getSource(), op.getIndices(),
+ AffineMapAttr::get(newMap), newMask, newInBoundsAttr)
+ .getResult();
}
};
diff --git a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
index e48af3cd7aace..a53e2a9e50ba2 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
@@ -46,6 +46,52 @@ func.func @permutation_with_mask_xfer_write_scalable(%arg0: vector<4x[8]xi16>, %
return
}
+// transfer_write in MaskOp case not supported.
+// CHECK-LABEL: func @masked_permutation_xfer_write_fixed_width
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[ARG_1:.*]]: vector<16xf32>,
+// CHECK-SAME: %[[IDX:.*]]: index,
+// CHECK-SAME: %[[MASK:.*]]: vector<16xi1>
+// CHECK: %[[RES:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[ARG_1]], %[[ARG_0]][%[[IDX]], %[[IDX]]] {{.*}} vector<16xf32>, tensor<?x?xf32> } : vector<16xi1> -> tensor<?x?xf32>
+// CHECK: return %[[RES]]
+func.func @masked_permutation_xfer_write_fixed_width(%t: tensor<?x?xf32>, %val: vector<16xf32>, %idx: index, %mask: vector<16xi1>) -> tensor<?x?xf32> {
+ %r = vector.mask %mask { vector.transfer_write %val, %t[%idx, %idx] {permutation_map = affine_map<(d0, d1) -> (d0)>} : vector<16xf32>, tensor<?x?xf32> } : vector<16xi1> -> tensor<?x?xf32>
+ return %r : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: func.func @masked_permutation_xfer_write_scalable(
+// CHECK-SAME: %[[ARG_0:.*]]: vector<4x[8]xi16>,
+// CHECK-SAME: %[[ARG_1:.*]]: tensor<?x?x?x?xf32>,
+// CHECK-SAME: %[[MASK:.*]]: vector<4x[8]xi1>)
+// CHECK-SAME: -> tensor<?x?x?x?xf32> {
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[R:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[ARG_0]], %[[ARG_1]][%c0, %c0, %c0, %c0] {in_bounds = [true, true], permutation_map = #[[MAP:.*]]} : vector<4x[8]xi16>, tensor<?x?x?x?xf32> } : vector<4x[8]xi1> -> tensor<?x?x?x?xf32>
+// CHECK: return %[[R]] : tensor<?x?x?x?xf32>
+func.func @masked_permutation_xfer_write_scalable(%arg0: vector<4x[8]xi16>, %t: tensor<?x?x?x?xf32>, %mask: vector<4x[8]xi1>) -> tensor<?x?x?x?xf32> {
+ %c0 = arith.constant 0 : index
+ %r = vector.mask %mask { vector.transfer_write %arg0, %t[%c0, %c0, %c0, %c0] {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+} : vector<4x[8]xi16>, tensor<?x?x?x?xf32> } : vector<4x[8]xi1> -> tensor<?x?x?x?xf32>
+
+ return %r : tensor<?x?x?x?xf32>
+}
+
+// transfer_write in MaskOp case not supported.
+// CHECK-LABEL: func @masked_non_permutation_xfer_write_fixed_width
+// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?x?x?xf32>
+// CHECK-SAME: %[[ARG1:.*]]: vector<14x8x16xf32>
+// CHECK-SAME: %[[IDX:.*]]: index) -> tensor<?x?x?x?xf32>
+func.func @masked_non_permutation_xfer_write_fixed_width(
+ %arg0 : tensor<?x?x?x?xf32>,
+ %v1 : vector<14x8x16xf32>, %dim : index) -> tensor<?x?x?x?xf32> {
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+ %c0 = arith.constant 0 : index
+ %mask = vector.create_mask %dim, %dim, %dim : vector<14x8x16xi1>
+ %0 = vector.mask %mask { vector.transfer_write %v1, %arg0[%c0, %c0, %c0, %c0] {in_bounds = [false, false, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>} : vector<14x8x16xf32>, tensor<?x?x?x?xf32> } : vector<14x8x16xi1> -> tensor<?x?x?x?xf32>
+ // CHECK: %[[masked1:.*]] = vector.mask %0 { vector.transfer_write %[[ARG1]], %[[ARG0]]{{.*}}permutation_map = #[[MAP:.*]]} : vector<14x8x16xf32>, tensor<?x?x?x?xf32> } : vector<14x8x16xi1> -> tensor<?x?x?x?xf32>
+
+ return %0 : tensor<?x?x?x?xf32>
+}
+
///----------------------------------------------------------------------------------------
/// vector.transfer_read
///----------------------------------------------------------------------------------------
@@ -101,6 +147,37 @@ func.func @permutation_with_mask_xfer_read_scalable(%mem: memref<?x?xf32>, %dim_
return %1 : vector<8x[4]x2xf32>
}
+// transfer_read in MaskOp case not supported.
+// CHECK-LABEL: func @masked_permutation_xfer_read_fixed_width
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<?x1xf32>,
+// CHECK-SAME: %[[ARG_1:.*]]: vector<4x1xi1>
+// CHECK: vector.mask %[[ARG_1]] { vector.transfer_read %[[ARG_0]]{{.*}}: tensor<?x1xf32>, vector<1x4x4xf32> } : vector<4x1xi1> -> vector<1x4x4xf32>
+func.func @masked_permutation_xfer_read_fixed_width(%arg0: tensor<?x1xf32>, %mask : vector<4x1xi1>) {
+ %cst = arith.constant 0.000000e+00 : f32
+ %c0 = arith.constant 0 : index
+ %3 = vector.mask %mask { vector.transfer_read %arg0[%c0, %c0], %cst {permutation_map = affine_map<(d0, d1) -> (d1, 0, d0)>} : tensor<?x1xf32>, vector<1x4x4xf32> } : vector<4x1xi1> -> vector<1x4x4xf32>
+ call @test.some_use(%3) : (vector<1x4x4xf32>) -> ()
+ return
+}
+func.func private @test.some_use(vector<1x4x4xf32>)
+
+// CHECK-LABEL: func.func @masked_permutation_xfer_read_scalable(
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[MASK:.*]]: vector<2x[4]xi1>) -> vector<8x[4]x2xf32> {
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[T_READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[ARG_0]][%[[C0]], %[[C0]]], %cst {in_bounds = [true, true, true], permutation_map = #[[MAP:.*]]} : tensor<?x?xf32>, vector<8x[4]x2xf32> } : vector<2x[4]xi1> -> vector<8x[4]x2xf32>
+// CHECK: return %[[T_READ]] : vector<8x[4]x2xf32>
+func.func @masked_permutation_xfer_read_scalable(%t: tensor<?x?xf32>, %mask : vector<2x[4]xi1>) -> vector<8x[4]x2xf32> {
+
+ %c0 = arith.constant 0 : index
+ %cst_0 = arith.constant 0.000000e+00 : f32
+
+ %1 = vector.mask %mask { vector.transfer_read %t[%c0, %c0], %cst_0
+ {in_bounds = [true, true, true], permutation_map = affine_map<(d0, d1) -> (0, d1, d0)>}
+ : tensor<?x?xf32>, vector<8x[4]x2xf32> } :vector<2x[4]xi1> -> vector<8x[4]x2xf32>
+ return %1 : vector<8x[4]x2xf32>
+}
+
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
%f = transform.structured.match ops{["func.func"]} in %module_op
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, I really like how the code is gradually becoming self-documenting :)
LGTM!
mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
Outdated
Show resolved
Hide resolved
6de7022
to
b62a36d
Compare
74894d7
to
c11da46
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey Hugo, sorry for the delay with this.
Having read this again, I am realising that I forgot about MemRef
semantics when implementing MaskableOpRewritePattern
- thanks for fixing that! I think that it would be good to capture that with some additional comments - see my suggestions inline. It would also be good to updated the summary accordingly (something along the lines:
Updates MaskableOpRewritePattern so that it works correctly with MemRefs.
Feel free to re-use and/or re-write.
AffineMapAttr::get(newMap), op.getMask(), newInBoundsAttr); | ||
if (newWrite.hasPureTensorSemantics()) | ||
return newWrite.getResult(); | ||
// In memref case, MaskableOpRewritePattern cannot replaceOp with result. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
// In memref case, MaskableOpRewritePattern cannot replaceOp with result. | |
// In the memref case there's no return value. Use empty value to signal success. |
if (rootOp->getNumResults() == 0 || *newOp == Value()) | ||
rewriter.eraseOp(rootOp); | ||
else | ||
rewriter.replaceOp(rootOp, *newOp); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IIUC, the only case that we are testing now is when "there's no return value" and "newOp is Value()". Hence I'm suggesting to replace ||
with &&
.
if (rootOp->getNumResults() == 0 || *newOp == Value()) | |
rewriter.eraseOp(rootOp); | |
else | |
rewriter.replaceOp(rootOp, *newOp); | |
// In the memref case there won't be a return value to replace. Instead, use an empty value to signal success. | |
if (rootOp->getNumResults() == 0 && *newOp == Value()) | |
rewriter.eraseOp(rootOp); | |
else | |
rewriter.replaceOp(rootOp, *newOp); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry for late answer, I have been thinking about it while implementing and did not come up with a solution I liked. With the weekend fresh mind, Here is my point returning Value()
means it did NOT fail. aka code updates happened but no value to give e.g. memref case.
if we split cases :
if *newOp == Value()
| if NumResult == 0 // simple case
| | rewriter.eraseOp(rootOp);
| else
| | // We have to replace something with a value with Value() so there might be uses of rootOp in the rest
| | // of the program if we try to erase it. So I suggest to raise an error.
| | raise Error();
if pattern returns a value:
| if NumResult == 1 // simple case
| | rewriter.replaceOp(rootOp, *newOp);
| else // We created ops with a value which should replace something without a value. We can't use it in the program. It will most likely be DCE-ed.
| | rewriter.eraseOp(rootOp);
Which can then be reduced to
if (failed(newOp))
return failure();
if NumResult == 0
rewriter.eraseOp(rootOp);
else
assert(*newOp != Value() && "Can't replace an op use with Value()");
rewriter.replaceOp(rootOp, *newOp);
return success()
As an additionnal point, technically, matchAndRewriteMaskableOp
could return a ValueRange as replaceOp takes a ValueRange as input. replaceOp will assert rootOp->getNumResults() != newOp.size()
. And will allow to handle cases where ops have multiple results. But I suggest as part of a separate patch.
Thanks for your comments. I interpreted your comments slightly differently. Feel free to debate or merge if you are satisfied. |
We are on the same page here, thanks for seeing this through! One thing that this discussion makes me question - should |
TransferWritePermutationLowering
,TransferReadPermutationLowering
andTransferWriteNonPermutationLowering
as a MaskableOpRewritePattern. Allowing to exit gracefully when such use of a xferOp is inside avector::MaskOp
Value()
as a return value formatchAndRewriteMaskableOp
now represents successful rewriting without value to replace the original op.Split of #90835