[mlir][vector] Add ElementwiseToOuterproduct #93664
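@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir
Author: Hugo Trachino (nujaa)
Changes
1D multi-reductions are lowered to arith, which can prevent some optimisations. I propose `ElementwiseToOuterproduct`, which matches a series of ops to generate `vector.outerproduct`. As part of some `ElementwiseToVectorOpsPatterns`, it could allow fusing other elementwise ops into the vector dialect. Originally discussed at https://discourse.llvm.org/t/on-improving-arm-sme-lowering-resilience-in-mlir/78543/24.
Quoting @MacDue: a `vector.broadcast` of the LHS followed by a `vector.transpose` with permutation `[1, 0]`, multiplied elementwise (`arith.mulf`) by a `vector.broadcast` of the RHS, can be rewritten as a single `vector.outerproduct` of the two 1-D vectors.
CC @banach-space, @dcaballe.
Full diff: https://github.com/llvm/llvm-project/pull/93664.diff
5 Files Affected: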
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index 4603953cb40fa..ac55433fadb2f 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -80,6 +80,10 @@ void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns,
/// into vector contract for the backends with native support.
void populateFoldArithExtensionPatterns(RewritePatternSet &patterns);
+/// Collect a set of patterns that fold elementwise op on vectors to the vector
+/// dialect.
+void populateElementwiseToVectorOpsPatterns(RewritePatternSet &patterns);
+
/// Returns the integer type required for subscripts in the vector dialect.
IntegerType getVectorSubscriptType(Builder &builder);
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index bc3c16d40520e..e1da09fba73a7 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -392,6 +392,17 @@ def ApplyFoldArithExtensionPatternsOp : Op<Transform_Dialect,
let assemblyFormat = "attr-dict";
}
+def ApplyFoldElementwiseToVectorPatternsOp : Op<Transform_Dialect,
+ "apply_patterns.vector.elementwise_to_vector",
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ let description = [{
+ Collect a set of patterns that fold elementwise op on vectors to the vector
+ dialect.
+ }];
+
+ let assemblyFormat = "attr-dict";
+}
+
def ApplyVectorReductionToContractPatternsOp : Op<Transform_Dialect,
"apply_patterns.vector.reduction_to_contract",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 61fd6bd972e3a..6e13749a66415 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -59,6 +59,11 @@ void transform::ApplyFoldArithExtensionPatternsOp::populatePatterns(
vector::populateFoldArithExtensionPatterns(patterns);
}
+void transform::ApplyFoldElementwiseToVectorPatternsOp::populatePatterns(
+ RewritePatternSet &patterns) {
+ vector::populateElementwiseToVectorOpsPatterns(patterns);
+}
+
void transform::ApplyVectorReductionToContractPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
vector::populateVectorReductionToContractPatterns(patterns);
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index f29eba90c3ceb..d7ccfc4986068 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1795,6 +1795,75 @@ struct BreakDownVectorReduction final : OpRewritePattern<vector::ReductionOp> {
unsigned maxNumElementsToExtract = 0;
};
+/// Pattern aiming to fold a series of ops mulf(tr(broadcast(A)), broadcast(B))
+/// into vector.outerproduct(A, B) such as :
+/// ```mlir
+/// %lhsBcast = vector.broadcast %lhs : vector<4xi32> to vector<4x4xi32>
+/// %lhsT = vector.transpose %lhsBcast, [1, 0] : vector<4x4xi32> to vector<4x4xi32>
+/// %rhsBcast = vector.broadcast %rhs : vector<4xi32> to vector<4x4xi32>
+/// %mul = arith.muli %lhsT, %rhsBcast : vector<4x4xi32>
+///```
+/// Becomes :
+///```mlir
+/// %res = vector.outerproduct %lhs, %rhs : vector<4xi32>, vector<4xi32>
+///```
+/// Edge Cases where broadcast ops are not 1D to 2D as follow are not handled.
+/// %ex1 = vector.broadcast %lhsCast : vector<1x4xf32> to vector<4x4xf32>
+/// %ex2 = vector.broadcast %lhsCast : f32 to vector<4x4xf32>
+/// %ex3 = vector.broadcast %lhsCast : vector<1x1xf32> to vector<4x4xf32>
+
+template <typename MulOpType>
+struct ElementwiseToOuterproduct : public OpRewritePattern<MulOpType> {
+ using OpRewritePattern<MulOpType>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(MulOpType mulOp,
+ PatternRewriter &rewriter) const override {
+ auto VT = llvm::dyn_cast<VectorType>(mulOp.getResult().getType());
+ if (!VT)
+ return failure();
+ if (VT.getRank() != 2)
+ return failure();
+
+ auto canonicalize = [&](Value OperandA,
+ Value OperandB) -> vector::OuterProductOp {
+ vector::TransposeOp transposedLhs =
+ dyn_cast_or_null<vector::TransposeOp>(OperandA.getDefiningOp());
+ if (!transposedLhs)
+ return vector::OuterProductOp();
+ // Fail unless this is a true 2-D matrix transpose.
+ ArrayRef<int64_t> permutation = transposedLhs.getPermutation();
+ if (permutation[0] != 1 || permutation[1] != 0)
+ return vector::OuterProductOp();
+
+ // Fail in case it is not a 1-to-2 dimension to broadcast to avoid
+ // generating shape_casts/broadcasts which do not belong in this pattern.
+ auto broadcastedLhs =
+ transposedLhs.getVector().getDefiningOp<vector::BroadcastOp>();
+ if (!broadcastedLhs ||
+ !broadcastedLhs.computeBroadcastedUnitDims().empty())
+ return vector::OuterProductOp();
+ // Avoid broadcast f32 or vector<f32> -> ResType
+ auto srcVT = dyn_cast<VectorType>(broadcastedLhs.getSourceType());
+ if (!srcVT || srcVT.getRank() != 1)
+ return vector::OuterProductOp();
+
+ vector::BroadcastOp broadcastedRhs =
+ OperandB.getDefiningOp<vector::BroadcastOp>();
+ if (!broadcastedRhs || broadcastedRhs.getSourceType() != srcVT)
+ return vector::OuterProductOp();
+
+ return rewriter.replaceOpWithNewOp<vector::OuterProductOp>(
+ mulOp, VT, broadcastedLhs.getSource(), broadcastedRhs.getSource(),
+ Value(), vector::CombiningKind::ADD);
+ };
+ Value a = mulOp->getOperand(0), b = mulOp->getOperand(1);
+ vector::OuterProductOp outerP = canonicalize(a, b);
+ // Handle commutativity, the transposed op is the outerproduct LHS.
+ outerP = outerP ? outerP : canonicalize(b, a);
+ return outerP ? success() : failure();
+ }
+};
+
} // namespace
void mlir::vector::populateFoldArithExtensionPatterns(
@@ -1882,6 +1951,12 @@ void mlir::vector::populateBreakDownVectorReductionPatterns(
maxNumElementsToExtract, benefit);
}
+void mlir::vector::populateElementwiseToVectorOpsPatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<ElementwiseToOuterproduct<arith::MulFOp>,
+ ElementwiseToOuterproduct<arith::MulIOp>>(patterns.getContext());
+}
+
//===----------------------------------------------------------------------===//
// TableGen'd enum attribute definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/transform-vector.mlir b/mlir/test/Dialect/Vector/transform-vector.mlir
index 75b29e22b4d2c..c170486f6ce27 100644
--- a/mlir/test/Dialect/Vector/transform-vector.mlir
+++ b/mlir/test/Dialect/Vector/transform-vector.mlir
@@ -92,3 +92,41 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+// CHECK-LABEL: func.func @ewise_outerproduct
+// CHECK-SAME: %[[LHS:.*]]: vector<[4]xi32>,
+// CHECK-SAME: %[[RHS:.*]]: vector<[4]xi32>) -> vector<[4]x[4]xi32> {
+// CHECK: %[[RES:.*]] = vector.outerproduct %[[LHS]], %[[RHS]] : vector<[4]xi32>, vector<[4]xi32>
+// CHECK: return %[[RES]] : vector<[4]x[4]xi32>
+func.func @ewise_outerproduct(%lhs: vector<[4]xi32>, %rhs: vector<[4]xi32>) -> vector<[4]x[4]xi32> {
+ %lhsBcast = vector.broadcast %lhs : vector<[4]xi32> to vector<[4]x[4]xi32>
+ %lhsT = vector.transpose %lhsBcast, [1, 0] : vector<[4]x[4]xi32> to vector<[4]x[4]xi32>
+ %rhsBcast = vector.broadcast %rhs : vector<[4]xi32> to vector<[4]x[4]xi32>
+ %mul = arith.muli %lhsT, %rhsBcast : vector<[4]x[4]xi32>
+ return %mul: vector<[4]x[4]xi32>
+}
+
+// CHECK-LABEL: func.func @ewise_outerproduct_transposed_rhs
+// CHECK-SAME: %[[LHS:.*]]: vector<16xf32>,
+// CHECK-SAME: %[[RHS:.*]]: vector<16xf32>) -> vector<16x16xf32> {
+// CHECK: %[[RES:.*]] = vector.outerproduct %[[RHS]], %[[LHS]] : vector<16xf32>, vector<16xf32>
+// CHECK: return %[[RES]] : vector<16x16xf32>
+func.func @ewise_outerproduct_transposed_rhs(%lhs: vector<16xf32>, %rhs: vector<16xf32>) -> vector<16x16xf32> {
+ %rhsBcast = vector.broadcast %rhs : vector<16xf32> to vector<16x16xf32>
+ %rhsT = vector.transpose %rhsBcast, [1, 0] : vector<16x16xf32> to vector<16x16xf32>
+ %lhsBcast = vector.broadcast %lhs : vector<16xf32> to vector<16x16xf32>
+ %mul = arith.mulf %lhsBcast, %rhsT : vector<16x16xf32>
+ return %mul: vector<16x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.vector.elementwise_to_vector
+ } : !transform.any_op
+ transform.yield
+ }
+}
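The operand-order handling in `matchAndRewrite` above can be sketched outside MLIR. This is a minimal toy model, not the patch's code: `Node`, `matchOrdered` and `matchOuterProduct` are hypothetical names, and the toy ignores the shape checks (rank, unit dims) that the real pattern performs.

```cpp
#include <cassert>
#include <memory>
#include <optional>
#include <string>
#include <utility>

// Hypothetical toy expression node standing in for an MLIR Value together
// with its defining op. Not part of MLIR.
struct Node {
  enum class Kind { Arg, Broadcast, Transpose, Mul };
  Kind kind;
  std::string name;          // Used by Arg nodes only.
  std::shared_ptr<Node> lhs; // Operand 0 (the source of Broadcast/Transpose).
  std::shared_ptr<Node> rhs; // Operand 1 (Mul only).
};
using NodePtr = std::shared_ptr<Node>;
using Match = std::optional<std::pair<NodePtr, NodePtr>>;

NodePtr arg(std::string name) {
  return std::make_shared<Node>(
      Node{Node::Kind::Arg, std::move(name), nullptr, nullptr});
}
NodePtr broadcast(NodePtr src) {
  return std::make_shared<Node>(
      Node{Node::Kind::Broadcast, "", std::move(src), nullptr});
}
NodePtr transpose(NodePtr src) {
  return std::make_shared<Node>(
      Node{Node::Kind::Transpose, "", std::move(src), nullptr});
}
NodePtr mul(NodePtr a, NodePtr b) {
  return std::make_shared<Node>(
      Node{Node::Kind::Mul, "", std::move(a), std::move(b)});
}

// Matches a == transpose(broadcast(X)) and b == broadcast(Y), returning
// (X, Y). Every step tolerates operands without a defining op (Arg nodes).
Match matchOrdered(const NodePtr &a, const NodePtr &b) {
  if (!a || a->kind != Node::Kind::Transpose)
    return std::nullopt;
  const NodePtr &bcastA = a->lhs;
  if (!bcastA || bcastA->kind != Node::Kind::Broadcast)
    return std::nullopt;
  if (!b || b->kind != Node::Kind::Broadcast)
    return std::nullopt;
  return std::make_pair(bcastA->lhs, b->lhs);
}

// Multiplication commutes, so both operand orders are tried. The transposed
// operand always becomes the outer-product LHS.
Match matchOuterProduct(const NodePtr &mulNode) {
  if (!mulNode || mulNode->kind != Node::Kind::Mul)
    return std::nullopt;
  if (Match m = matchOrdered(mulNode->lhs, mulNode->rhs))
    return m;
  return matchOrdered(mulNode->rhs, mulNode->lhs);
}
```

With `l = arg("lhs")` and `r = arg("rhs")`, matching `mul(broadcast(l), transpose(broadcast(r)))` yields the pair `(r, l)`, which mirrors the swapped operands expected by the `@ewise_outerproduct_transposed_rhs` test.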
✅ With the latest revision this PR passed the C/C++ code formatter.
This looks nice. cc @hanhanW and @qedawkins if this is useful for us to use in IREE.
Thanks for the patch! I think it is useful for the output of vectorization. We could have a complicated linalg op (broadcast + transpose + reduction), and it recovers the information we want at n-D vector level.
ArrayRef<int64_t> permutation = transposedLhs.getPermutation();
if (permutation[0] != 1 || permutation[1] != 0)
  return vector::OuterProductOp();
I understand that it is weird to have a transpose with 0-D or 1-D vector types, but it is still valid input IR. So I'd suggest checking `permutation.size() == 2` as well. Otherwise, we could crash when accessing `permutation[0]` and `permutation[1]`.
We check rank == 2 a bit earlier.
if (resType.getRank() != 2)
return failure();
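The defensive check suggested in this thread could look like the sketch below. It uses `std::vector` in place of `ArrayRef<int64_t>`, and `isTranspose2D` is a hypothetical helper name, not something in the patch.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Returns true only for the 2-D permutation [1, 0]. The size check comes
// first, so the indexing stays in bounds even for 0-D or 1-D transposes.
bool isTranspose2D(const std::vector<std::int64_t> &permutation) {
  return permutation.size() == 2 && permutation[0] == 1 &&
         permutation[1] == 0;
}
```

Because `&&` short-circuits, the explicit size check makes the helper safe on its own, independently of any rank check done earlier on the result type.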
Co-authored-by: Han-Chung Wang <[email protected]>
Thanks Hugo!!
Naming is hard, but let me try ... To me, `FoldArithToVecOuterProd` would make more sense than `ElementwiseToOuterproduct`. I just feel that "Elementwise" is a very broad category and this new pattern is not really matching `OpTrait::Elementwise`.
More comments inline.
if (VT.getRank() != 2)
  return failure();

auto canonicalize = [&](Value OperandA,
- I think that long lambdas (more than 2-3 lines) obscure readability. Why not create a helper hook outside this method? Others might disagree (looking at you @MacDue :) ).
- In any case, please document what this lambda does and use a less generic name. "canonicalize" is a very powerful term and, based on past experience, I feel that it can mean different things depending on context. Also, we don't really want people hitting this when searching for "canonicalize" in MLIR.
TBH, it was largely inspired by the `CanonicalizeContract` pattern, which works the same way. Lambdas avoid having to pass the Location and rewriter, which I find nice. Renamed and commented.
> Lambdas allow not to pass Location and rewriter.
FYI, there is `ImplicitLocOpBuilder`, which lets you pass just a single builder (but not a location). I don't find it useful, so I usually don't suggest it in my reviews.
/// Pattern aiming to fold a series of ops mulf(tr(broadcast(A)), broadcast(B))
/// into vector.outerproduct(A, B) such as :
[nit] Keep it shorter. In particular, "series" might suggest that more than one occurrence of `mulf(tr(broadcast(A)), broadcast(B))` is required:
-/// Pattern aiming to fold a series of ops mulf(tr(broadcast(A)), broadcast(B))
-/// into vector.outerproduct(A, B) such as :
+/// Fold `mulf(tr(broadcast(A)), broadcast(B))` into `vector.outerproduct(A, B)`. Example:
// CHECK-LABEL: func.func @ewise_outerproduct_different_sizes
// CHECK-SAME: %[[LHS:.*]]: vector<8xf32>,
// CHECK-SAME: %[[RHS:.*]]: vector<4xf32>) -> vector<8x4xf32> {
// CHECK: %[[RES:.*]] = vector.outerproduct %[[LHS]], %[[RHS]] : vector<8xf32>, vector<4xf32>
// CHECK: return %[[RES]] : vector<8x4xf32>
func.func @ewise_outerproduct_different_sizes(%lhs: vector<8xf32>, %rhs: vector<4xf32>) -> vector<8x4xf32> {
  %lhsBcast = vector.broadcast %lhs : vector<8xf32> to vector<4x8xf32>
  %lhsT = vector.transpose %lhsBcast, [1, 0] : vector<4x8xf32> to vector<8x4xf32>
  %rhsBcast = vector.broadcast %rhs : vector<4xf32> to vector<8x4xf32>
  %mul = arith.mulf %lhsT, %rhsBcast : vector<8x4xf32>
  return %mul: vector<8x4xf32>
}
IMHO, this test doesn't really add much - why not use "different" sizes in the two tests above?
// CHECK-SAME: %[[RHS:.*]]: vector<[4]xi32>) -> vector<[4]x[4]xi32> {
// CHECK: %[[RES:.*]] = vector.outerproduct %[[LHS]], %[[RHS]] : vector<[4]xi32>, vector<[4]xi32>
// CHECK: return %[[RES]] : vector<[4]x[4]xi32>
func.func @ewise_outerproduct(%lhs: vector<[4]xi32>, %rhs: vector<[4]xi32>) -> vector<[4]x[4]xi32> {
Try to capture all unique features of a test in the corresponding function name.
-func.func @ewise_outerproduct(%lhs: vector<[4]xi32>, %rhs: vector<[4]xi32>) -> vector<[4]x[4]xi32> {
+func.func @ewise_outerproduct_trans_lhs_i32(%lhs: vector<[4]xi32>, %rhs: vector<[4]xi32>) -> vector<[4]x[4]xi32> {
The logic looks good to me, just a few nits about coding style. Thanks!
if (!srcType || srcType.getRank() == 2)
  return false;
return true;
How about collapsing them to `return srcType && srcType.getRank() != 2`?
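When collapsing the early-return form quoted above, the rank comparison has to be negated (De Morgan). The sketch below checks that the two forms agree. `MaybeRank` is a hypothetical stand-in in which `std::nullopt` models a failed `dyn_cast<VectorType>`; the function names are illustrative only.

```cpp
#include <cassert>
#include <optional>

// Hypothetical stand-in: std::nullopt models a source that is not a
// VectorType; otherwise the value is the vector rank.
using MaybeRank = std::optional<int>;

// Early-return form, as in the quoted snippet.
bool isValidSourceVerbose(MaybeRank srcRank) {
  if (!srcRank || *srcRank == 2)
    return false;
  return true;
}

// Collapsed form: !(A || B) is (!A && !B), so `== 2` becomes `!= 2`.
bool isValidSourceCollapsed(MaybeRank srcRank) {
  return srcRank && *srcRank != 2;
}
```

Both functions reject a missing type and a rank-2 source, and accept every other rank.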
vector::TransposeOp transposedLhs =
    dyn_cast_or_null<vector::TransposeOp>(operandA.getDefiningOp());
nit: use auto because the cast already spells out the type; we can also use `operandA.getDefiningOp<vector::TransposeOp>()`.
https://www.llvm.org/docs/CodingStandards.html#use-auto-type-deduction-to-make-code-more-readable
> Don’t “almost always” use auto, but do use auto with initializers like `cast<Foo>(...)` or other places where the type is already obvious from the context.
1D multi-reductions are lowered to arith, which can prevent some optimisations. I propose `ElementwiseToOuterproduct`, which matches a series of ops to generate `vector.outerproduct`. As part of some `ElementwiseToVectorOpsPatterns`, it could allow fusing other elementwise ops into the vector dialect.
Originally discussed at https://discourse.llvm.org/t/on-improving-arm-sme-lowering-resilience-in-mlir/78543/24.
Quoting @MacDue:
```
%lhsBcast = vector.broadcast %lhsCast : vector<[4]xf32> to vector<[4]x[4]xf32>
%lhsT = vector.transpose %lhsBcast, [1, 0] : vector<[4]x[4]xf32> to vector<[4]x[4]xf32>
%rhsBcast = vector.broadcast %rhs : vector<[4]xf32> to vector<[4]x[4]xf32>
%mul = arith.mulf %lhsT, %rhsBcast : vector<[4]x[4]xf32>
```
Can be rewritten as:
```
%mul = vector.outerproduct %lhsCast, %rhs : vector<[4]xf32>, vector<[4]xf32>
```
---------
Co-authored-by: Han-Chung Wang <[email protected]>
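The equivalence stated in the commit message can be checked numerically with a small standalone sketch. It models vectors as `std::vector<int>` rather than MLIR values, assumes non-empty inputs, and all function names (`broadcastRows`, `transpose2D`, `viaBroadcastTransposeMul`, `outerProduct`) are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<int>>;

// Models a 1-D to 2-D vector.broadcast: stacks `rows` copies of `v`,
// so result[i][j] == v[j].
Matrix broadcastRows(const std::vector<int> &v, std::size_t rows) {
  return Matrix(rows, v);
}

// Models vector.transpose with permutation [1, 0].
Matrix transpose2D(const Matrix &m) {
  if (m.empty())
    return {};
  Matrix t(m[0].size(), std::vector<int>(m.size()));
  for (std::size_t i = 0; i < m.size(); ++i)
    for (std::size_t j = 0; j < m[i].size(); ++j)
      t[j][i] = m[i][j];
  return t;
}

// The matched form: mul(transpose(broadcast(lhs)), broadcast(rhs)).
// Assumes both inputs are non-empty.
Matrix viaBroadcastTransposeMul(const std::vector<int> &lhs,
                                const std::vector<int> &rhs) {
  // lhs (size M) is broadcast to N x M, then transposed to M x N.
  Matrix result = transpose2D(broadcastRows(lhs, rhs.size()));
  // rhs (size N) is broadcast directly to M x N.
  Matrix rhsBcast = broadcastRows(rhs, lhs.size());
  assert(result.size() == rhsBcast.size());
  for (std::size_t i = 0; i < result.size(); ++i)
    for (std::size_t j = 0; j < result[i].size(); ++j)
      result[i][j] *= rhsBcast[i][j];
  return result;
}

// The replacement form: result[i][j] = lhs[i] * rhs[j].
Matrix outerProduct(const std::vector<int> &lhs,
                    const std::vector<int> &rhs) {
  Matrix result(lhs.size(), std::vector<int>(rhs.size()));
  for (std::size_t i = 0; i < lhs.size(); ++i)
    for (std::size_t j = 0; j < rhs.size(); ++j)
      result[i][j] = lhs[i] * rhs[j];
  return result;
}
```

For `lhs = {1, 2, 3}` and `rhs = {10, 20}`, both functions produce the same 3x2 matrix, which also illustrates the non-square case raised in review.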