[mlir][vector] Add ElementwiseToOuterproduct #93664

Merged: 6 commits merged into llvm:main from hugo.ewiseToOuterP on Jun 21, 2024

Conversation

@nujaa (Contributor) commented May 29, 2024

1D multi-reductions are lowered to arith, which can prevent some optimisations. I propose ElementwiseToOuterproduct, which matches a series of ops and generates vector.outerproduct.
As part of ElementwiseToVectorOpsPatterns, it could allow fusing other elementwise ops into the vector dialect.
Originally discussed at https://discourse.llvm.org/t/on-improving-arm-sme-lowering-resilience-in-mlir/78543/24.

Quoting @MacDue:

```mlir
%lhsBcast = vector.broadcast %lhsCast : vector<[4]xf32> to vector<[4]x[4]xf32>
%lhsT = vector.transpose %lhsBcast, [1, 0] : vector<[4]x[4]xf32> to vector<[4]x[4]xf32>
%rhsBcast = vector.broadcast %rhs : vector<[4]xf32> to vector<[4]x[4]xf32>
%mul = arith.mulf %lhsT, %rhsBcast : vector<[4]x[4]xf32>
```

This can be rewritten as:

```mlir
%mul = vector.outerproduct $lhs, $rhs : vector<[4]xf32>, vector<[4]xf32>
```
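
For scope: the patch's doc comment also lists broadcast shapes that are deliberately not rewritten. Below is a sketch of one such case (the function name is illustrative); the LHS broadcast starts from a scalar rather than a rank-1 vector, so there is no operand suitable for vector.outerproduct and the pattern should leave the IR unchanged.

```mlir
// Not rewritten: the broadcast source is a scalar, not a rank-1 vector.
func.func @scalar_broadcast_not_rewritten(%lhs: f32, %rhs: vector<4xf32>) -> vector<4x4xf32> {
  %lhsBcast = vector.broadcast %lhs : f32 to vector<4x4xf32>
  %lhsT = vector.transpose %lhsBcast, [1, 0] : vector<4x4xf32> to vector<4x4xf32>
  %rhsBcast = vector.broadcast %rhs : vector<4xf32> to vector<4x4xf32>
  %mul = arith.mulf %lhsT, %rhsBcast : vector<4x4xf32>
  return %mul : vector<4x4xf32>
}
```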

CC @banach-space, @dcaballe.

@nujaa force-pushed the hugo.ewiseToOuterP branch from 0c7cc3a to dab57eb on May 29, 2024 10:41
@nujaa marked this pull request as ready for review on May 29, 2024 13:24
@llvmbot (Member) commented May 29, 2024

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Hugo Trachino (nujaa)

Full diff: https://github.com/llvm/llvm-project/pull/93664.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.h (+4)
  • (modified) mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td (+11)
  • (modified) mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp (+5)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (+75)
  • (modified) mlir/test/Dialect/Vector/transform-vector.mlir (+38)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index 4603953cb40fa..ac55433fadb2f 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -80,6 +80,10 @@ void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns,
 /// into vector contract for the backends with native support.
 void populateFoldArithExtensionPatterns(RewritePatternSet &patterns);
 
+/// Collect a set of patterns that fold elementwise op on vectors to the vector
+/// dialect.
+void populateElementwiseToVectorOpsPatterns(RewritePatternSet &patterns);
+
 /// Returns the integer type required for subscripts in the vector dialect.
 IntegerType getVectorSubscriptType(Builder &builder);
 
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index bc3c16d40520e..e1da09fba73a7 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -392,6 +392,17 @@ def ApplyFoldArithExtensionPatternsOp : Op<Transform_Dialect,
   let assemblyFormat = "attr-dict";
 }
 
+def ApplyFoldElementwiseToVectorPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.vector.elementwise_to_vector",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Collect a set of patterns that fold elementwise op on vectors to the vector 
+    dialect.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
 def ApplyVectorReductionToContractPatternsOp : Op<Transform_Dialect,
     "apply_patterns.vector.reduction_to_contract",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 61fd6bd972e3a..6e13749a66415 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -59,6 +59,11 @@ void transform::ApplyFoldArithExtensionPatternsOp::populatePatterns(
   vector::populateFoldArithExtensionPatterns(patterns);
 }
 
+void transform::ApplyFoldElementwiseToVectorPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  vector::populateElementwiseToVectorOpsPatterns(patterns);
+}
+
 void transform::ApplyVectorReductionToContractPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
   vector::populateVectorReductionToContractPatterns(patterns);
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index f29eba90c3ceb..d7ccfc4986068 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1795,6 +1795,75 @@ struct BreakDownVectorReduction final : OpRewritePattern<vector::ReductionOp> {
   unsigned maxNumElementsToExtract = 0;
 };
 
+/// Pattern aiming to fold a series of ops mulf(tr(broadcast(A)), broadcast(B))
+/// into vector.outerproduct(A, B) such as :
+/// ```mlir
+///  %lhsBcast = vector.broadcast %lhs : vector<4xi32> to vector<4x4xi32>
+///  %lhsT = vector.transpose %lhsBcast, [1, 0] : vector<4x4xi32> to
+///  vector<4x4xi32> %rhsBcast = vector.broadcast %rhs : vector<4xi32> to
+///  vector<4x4xi32> %mul = arith.muli %lhsT, %rhsBcast : vector<4x4xi32>
+///```
+/// Becomes :
+///```mlir
+///  %res = vector.outerproduct %lhs, %rhs : vector<4xi32>, vector<4xi32>
+///```
+/// Edge Cases where broadcast ops are not 1D to 2D as follow are not handled.
+/// %ex1 = vector.broadcast %lhsCast : vector<1x4xf32> to vector<4x4xf32>
+/// %ex2 = vector.broadcast %lhsCast : f32 to vector<4x4xf32>
+/// %ex3 = vector.broadcast %lhsCast : vector<1x1xf32> to vector<4x4xf32>
+
+template <typename MulOpType>
+struct ElementwiseToOuterproduct : public OpRewritePattern<MulOpType> {
+  using OpRewritePattern<MulOpType>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(MulOpType mulOp,
+                                PatternRewriter &rewriter) const override {
+    auto VT = llvm::cast<VectorType>(mulOp.getResult().getType());
+    if (!VT)
+      return failure();
+    if (VT.getRank() != 2)
+      return failure();
+
+    auto canonicalize = [&](Value OperandA,
+                            Value OperandB) -> vector::OuterProductOp {
+      vector::TransposeOp transposedLhs =
+          dyn_cast_or_null<vector::TransposeOp>(OperandA.getDefiningOp());
+      if (!transposedLhs)
+        return vector::OuterProductOp();
+      // Fail unless this is a true 2-D matrix transpose.
+      ArrayRef<int64_t> permutation = transposedLhs.getPermutation();
+      if (permutation[0] != 1 || permutation[1] != 0)
+        return vector::OuterProductOp();
+
+      // Fail in case it is not a 1-to-2 dimension to broadcast to avoid
+      // generating shape_casts/broadcasts which do not belong in this pattern.
+      vector::BroadcastOp broadcastedLhs = dyn_cast<vector::BroadcastOp>(
+          transposedLhs.getVector().getDefiningOp());
+      if (!broadcastedLhs ||
+          !broadcastedLhs.computeBroadcastedUnitDims().empty())
+        return vector::OuterProductOp();
+      // Avoid broadcast f32 or vector<f32> -> ResType
+      auto srcVT = dyn_cast<VectorType>(broadcastedLhs.getSourceType());
+      if (!srcVT || srcVT.getRank() != 1)
+        return vector::OuterProductOp();
+
+      vector::BroadcastOp broadcastedRhs =
+          dyn_cast<vector::BroadcastOp>(OperandB.getDefiningOp());
+      if (!broadcastedRhs || broadcastedRhs.getSourceType() != srcVT)
+        return vector::OuterProductOp();
+
+      return rewriter.replaceOpWithNewOp<vector::OuterProductOp>(
+          mulOp, VT, broadcastedLhs.getSource(), broadcastedRhs.getSource(),
+          Value(), vector::CombiningKind::ADD);
+    };
+    Value a = mulOp->getOperand(0), b = mulOp->getOperand(1);
+    vector::OuterProductOp outerP = canonicalize(a, b);
+    // Handle commutativity, the transposed op is the outerproduct LHS.
+    outerP = outerP ? outerP : canonicalize(b, a);
+    return outerP ? success() : failure();
+  }
+};
+
 } // namespace
 
 void mlir::vector::populateFoldArithExtensionPatterns(
@@ -1882,6 +1951,12 @@ void mlir::vector::populateBreakDownVectorReductionPatterns(
                                          maxNumElementsToExtract, benefit);
 }
 
+void mlir::vector::populateElementwiseToVectorOpsPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<ElementwiseToOuterproduct<arith::MulFOp>,
+               ElementwiseToOuterproduct<arith::MulIOp>>(patterns.getContext());
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd enum attribute definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/transform-vector.mlir b/mlir/test/Dialect/Vector/transform-vector.mlir
index 75b29e22b4d2c..c170486f6ce27 100644
--- a/mlir/test/Dialect/Vector/transform-vector.mlir
+++ b/mlir/test/Dialect/Vector/transform-vector.mlir
@@ -92,3 +92,41 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+// CHECK-LABEL: func.func @ewise_outerproduct
+//  CHECK-SAME:   %[[LHS:.*]]: vector<[4]xi32>,
+//  CHECK-SAME:   %[[RHS:.*]]: vector<[4]xi32>) -> vector<[4]x[4]xi32> {
+//       CHECK:     %[[RES:.*]] = vector.outerproduct %[[LHS]], %[[RHS]] : vector<[4]xi32>, vector<[4]xi32>
+//       CHECK:     return %[[RES]] : vector<[4]x[4]xi32>
+func.func @ewise_outerproduct(%lhs: vector<[4]xi32>, %rhs: vector<[4]xi32>) -> vector<[4]x[4]xi32> {
+  %lhsBcast = vector.broadcast %lhs : vector<[4]xi32> to vector<[4]x[4]xi32>
+  %lhsT = vector.transpose %lhsBcast, [1, 0] : vector<[4]x[4]xi32> to vector<[4]x[4]xi32>
+  %rhsBcast = vector.broadcast %rhs : vector<[4]xi32> to vector<[4]x[4]xi32>
+  %mul = arith.muli %lhsT, %rhsBcast : vector<[4]x[4]xi32>
+  return %mul: vector<[4]x[4]xi32>
+}
+
+// CHECK-LABEL: func.func @ewise_outerproduct_transposed_rhs
+//  CHECK-SAME:   %[[LHS:.*]]: vector<16xf32>,
+//  CHECK-SAME:   %[[RHS:.*]]: vector<16xf32>) -> vector<16x16xf32> {
+//       CHECK:     %[[RES:.*]] = vector.outerproduct %[[RHS]], %[[LHS]] : vector<16xf32>, vector<16xf32>
+//       CHECK:     return %[[RES]] : vector<16x16xf32>
+func.func @ewise_outerproduct_transposed_rhs(%lhs: vector<16xf32>, %rhs: vector<16xf32>) -> vector<16x16xf32> {
+  %rhsBcast = vector.broadcast %rhs : vector<16xf32> to vector<16x16xf32>
+  %rhsT = vector.transpose %rhsBcast, [1, 0] : vector<16x16xf32> to vector<16x16xf32>
+  %lhsBcast = vector.broadcast %lhs : vector<16xf32> to vector<16x16xf32>
+  %mul = arith.mulf %lhsBcast, %rhsT : vector<16x16xf32>
+  return %mul: vector<16x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.vector.elementwise_to_vector
+    } : !transform.any_op
+    transform.yield
+  }
+}

github-actions bot commented May 30, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@MaheshRavishankar (Contributor) commented:

This looks nice. cc @hanhanW and @qedawkins if this is useful for us to use in IREE.

@hanhanW (Contributor) left a comment:

Thanks for the patch! I think it is useful for the output of vectorization. We could have a complicated linalg op (broadcast + transpose + reduction), and it recovers the information we want at n-D vector level.
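
For reference, a minimal sketch (not taken from this PR; the op, shapes, and names are illustrative) of the kind of source-level op that, once vectorized, produces the broadcast + transpose + elementwise-mul sequence this pattern folds into vector.outerproduct:

```mlir
// Hypothetical outer product written as a linalg.generic. Vectorization turns
// this into broadcasts, a transpose, and an arith.mulf on 2-D vectors.
#lhs = affine_map<(i, j) -> (i)>
#rhs = affine_map<(i, j) -> (j)>
#out = affine_map<(i, j) -> (i, j)>
func.func @outerproduct_source(%lhs: tensor<4xf32>, %rhs: tensor<4xf32>,
                               %init: tensor<4x4xf32>) -> tensor<4x4xf32> {
  %0 = linalg.generic {indexing_maps = [#lhs, #rhs, #out],
                       iterator_types = ["parallel", "parallel"]}
      ins(%lhs, %rhs : tensor<4xf32>, tensor<4xf32>)
      outs(%init : tensor<4x4xf32>) {
  ^bb0(%a: f32, %b: f32, %acc: f32):
    %m = arith.mulf %a, %b : f32
    linalg.yield %m : f32
  } -> tensor<4x4xf32>
  return %0 : tensor<4x4xf32>
}
```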

Comment on lines 1847 to 1849
ArrayRef<int64_t> permutation = transposedLhs.getPermutation();
if (permutation[0] != 1 || permutation[1] != 0)
return vector::OuterProductOp();
Contributor:

I understand that it is weird to have transpose with 0D or 1D vector types, but it is still a valid input IR. So I'd suggest to check permutation.size() == 2 as well. Otherwise, we could crash in accessing permutation[0] and permutation[1].
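
For illustration, a hypothetical (but valid) input with a rank-1 transpose, which is the kind of IR the extra size check would guard against:

```mlir
// Valid IR: a rank-1 vector.transpose with a single-element permutation.
// Without a check on the permutation size (or the operand rank), reading
// permutation[1] would be out of bounds here.
func.func @rank1_transpose(%v: vector<4xf32>) -> vector<4xf32> {
  %t = vector.transpose %v, [0] : vector<4xf32> to vector<4xf32>
  return %t : vector<4xf32>
}
```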

@nujaa (Author):

We check rank == 2 a bit earlier.

if (resType.getRank() != 2)
      return failure();

@banach-space (Contributor) left a comment:

Thanks Hugo!!

Naming is hard, but let me try... To me, FoldArithToVecOuterProd would make more sense than ElementwiseToOuterproduct. I just feel that "Elementwise" is a very broad category and this new pattern is not really matching OpTrait::Elementwise.

More comments inline.

if (VT.getRank() != 2)
return failure();

auto canonicalize = [&](Value OperandA,
Contributor:

  1. I think that long lambdas (more than 2-3 lines) obscure readability. Why not create a helper hook outside this method? Others might disagree (looking at you @MacDue :) ).
  2. In any case, please document what this lambda does and use a less generic name. "canonicalize" is a very powerful term, and based on past experience I feel that it can mean different things depending on context. Also, we don't really want people hitting this when searching for "canonicalize" in MLIR.

@nujaa (Author):

TBH, it was quite inspired by the CanonicalizeContract pattern, which works the same way. Lambdas avoid having to pass the Location and rewriter around, which I find nice. Renamed and commented.

Contributor:

Lambdas allow not to pass Location and rewriter.

FYI, there is ImplicitLocOpBuilder, which lets you pass just a single builder (with no separate location). I don't find it useful, so I usually don't suggest it in my reviews.

https://github.com/llvm/llvm-project/blob/0ec567c370df86893a22bf59d2716f6e553ca63b/mlir/include/mlir/IR/ImplicitLocOpBuilder.h#L20-L23C7

Comment on lines 1798 to 1799
/// Pattern aiming to fold a series of ops mulf(tr(broadcast(A)), broadcast(B))
/// into vector.outerproduct(A, B) such as :
Contributor:

[nit] Keep it shorter. In particular, "series" might suggest that more than one occurrence of mulf(tr(broadcast(A)), broadcast(B)) is required:

Suggested change
/// Pattern aiming to fold a series of ops mulf(tr(broadcast(A)), broadcast(B))
/// into vector.outerproduct(A, B) such as :
/// Fold `mulf(tr(broadcast(A)), broadcast(B))` into `vector.outerproduct(A, B)`. Example:

Comment on lines 124 to 135
// CHECK-LABEL: func.func @ewise_outerproduct_different_sizes
// CHECK-SAME: %[[LHS:.*]]: vector<8xf32>,
// CHECK-SAME: %[[RHS:.*]]: vector<4xf32>) -> vector<8x4xf32> {
// CHECK: %[[RES:.*]] = vector.outerproduct %[[LHS]], %[[RHS]] : vector<8xf32>, vector<4xf32>
// CHECK: return %[[RES]] : vector<8x4xf32>
func.func @ewise_outerproduct_different_sizes(%lhs: vector<8xf32>, %rhs: vector<4xf32>) -> vector<8x4xf32> {
%lhsBcast = vector.broadcast %lhs : vector<8xf32> to vector<4x8xf32>
%lhsT = vector.transpose %lhsBcast, [1, 0] : vector<4x8xf32> to vector<8x4xf32>
%rhsBcast = vector.broadcast %rhs : vector<4xf32> to vector<8x4xf32>
%mul = arith.mulf %lhsT, %rhsBcast : vector<8x4xf32>
return %mul: vector<8x4xf32>
}
Contributor:

IMHO, this test doesn't really add much - why not use "different" sizes in the two tests above?

// CHECK-SAME: %[[RHS:.*]]: vector<[4]xi32>) -> vector<[4]x[4]xi32> {
// CHECK: %[[RES:.*]] = vector.outerproduct %[[LHS]], %[[RHS]] : vector<[4]xi32>, vector<[4]xi32>
// CHECK: return %[[RES]] : vector<[4]x[4]xi32>
func.func @ewise_outerproduct(%lhs: vector<[4]xi32>, %rhs: vector<[4]xi32>) -> vector<[4]x[4]xi32> {
Contributor:

Try to capture all unique features of a test in the corresponding function name.

Suggested change
func.func @ewise_outerproduct(%lhs: vector<[4]xi32>, %rhs: vector<[4]xi32>) -> vector<[4]x[4]xi32> {
func.func @ewise_outerproduct_trans_lhs_i32(%lhs: vector<[4]xi32>, %rhs: vector<[4]xi32>) -> vector<[4]x[4]xi32> {

@hanhanW (Contributor) left a comment:

The logic looks good to me, just a few nits about coding style. Thanks!

Comment on lines 1826 to 1828
if (!srcType || srcType.getRank() == 2)
return false;
return true;
Contributor:

How about collapsing them to return srcType && srcType.getRank() == 2?

Comment on lines 1844 to 1845
vector::TransposeOp transposedLhs =
dyn_cast_or_null<vector::TransposeOp>(operandA.getDefiningOp());
Contributor:

nit: use auto because the casting already spells the type; we can use operandA.getDefiningOp<vector::TransposeOp>().

https://www.llvm.org/docs/CodingStandards.html#use-auto-type-deduction-to-make-code-more-readable

Don’t “almost always” use auto, but do use auto with initializers like cast<Foo>(...) or other places where the type is already obvious from the context.

@nujaa merged commit 9f0aa05 into llvm:main on Jun 21, 2024
7 checks passed
AlexisPerry pushed a commit to llvm-project-tlp/llvm-project that referenced this pull request Jul 9, 2024
1D multi-reductions are lowered to arith, which can prevent some
optimisations. I propose `ElementwiseToOuterproduct`, which matches a series
of ops and generates `vector.outerproduct`.
As part of `ElementwiseToVectorOpsPatterns`, it could allow fusing other
elementwise ops into the vector dialect.
Originally discussed at
https://discourse.llvm.org/t/on-improving-arm-sme-lowering-resilience-in-mlir/78543/24.

Quoting @MacDue:
```
%lhsBcast = vector.broadcast %lhsCast : vector<[4]xf32> to vector<[4]x[4]xf32>
%lhsT = vector.transpose %lhsBcast, [1, 0] : vector<[4]x[4]xf32> to vector<[4]x[4]xf32>
%rhsBcast = vector.broadcast %rhs : vector<[4]xf32> to vector<[4]x[4]xf32>
%mul = arith.mulf %lhsT, %rhsBcast : vector<[4]x[4]xf32>
```

Can be rewritten as:

```
%mul = vector.outerproduct $lhs, $rhs : vector<[4]xf32>, vector<[4]xf32>
```

---------

Co-authored-by: Han-Chung Wang <[email protected]>
6 participants