Skip to content

[mlir][vector] Update CombineContractBroadcastMask #140050

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Conversation

banach-space
Copy link
Contributor

This patch updates CombineContractBroadcastMask to inherit from
MaskableOpRewritePattern, enabling it to handle masked
vector.contract operations. The pattern rewrites:

  %a = vector.broadcast %a_bc
  %res vector.contract %a_bc, %b, ...

into:

  // Move the broadcast into vector.contract (by updating the indexing
  // maps)
  %res vector.contract %a, %b, ...

The main challenge is supporting cases where the pattern drops a leading
unit dimension. For example:

func.func @contract_broadcast_unit_dim_reduction_masked(
    %arg0 : vector<8x4xi32>,
    %arg1 : vector<8x4xi32>,
    %arg2 : vector<8x8xi32>,
    %mask: vector<1x8x8x4xi1>) -> vector<8x8xi32> {

  %0 = vector.broadcast %arg0 : vector<8x4xi32> to vector<1x8x4xi32>
  %1 = vector.broadcast %arg1 : vector<8x4xi32> to vector<1x8x4xi32>
  %result = vector.mask %mask {
    vector.contract {
      indexing_maps = [#map0, #map1, #map2],
      iterator_types = ["reduction", "parallel", "parallel", "reduction"],
      kind = #vector.kind<add>
    } %0, %1, %arg2 : vector<1x8x4xi32>, vector<1x8x4xi32> into vector<8x8xi32>
  } : vector<1x8x8x4xi1> -> vector<8x8xi32>

  return %result : vector<8x8xi32>
}

Here, the leading unit dimension is dropped. To handle this, the mask is
cast to the correct shape using a vector.shape_cast:

func.func @contract_broadcast_unit_dim_reduction_masked(
    %arg0: vector<8x4xi32>,
    %arg1: vector<8x4xi32>,
    %arg2: vector<8x8xi32>,
    %arg3: vector<1x8x8x4xi1>) -> vector<8x8xi32> {

  %mask_sc = vector.shape_cast %arg3 : vector<1x8x8x4xi1> to vector<8x8x4xi1>
  %res = vector.mask %mask_sc {
    vector.contract {
      indexing_maps = [#map, #map1, #map2],
      iterator_types = ["parallel", "parallel", "reduction"],
      kind = #vector.kind<add>
    } %arg0, %arg1, %mask_sc : vector<8x4xi32>, vector<8x4xi32> into vector<8x8xi32>
  } : vector<8x8x4xi1> -> vector<8x8xi32>

  return %res : vector<8x8xi32>
}

While this isn't ideal — since it introduces a vector.shape_cast that
must be cleaned up later — it reflects the best we can do once the input
reaches CombineContractBroadcastMask. A more robust solution may
involve simplifying the input earlier. I am leaving that as a TODO for
myself to explore this further. Posting this now to unblock downstream
work.

LIMITATIONS

Currently, this pattern assumes:

  • Only leading dimensions are dropped in the mask.
  • All dropped dimensions must be unit-sized.

TODO: Check whether any other cases are possible.

This patch updates `CombineContractBroadcastMask` to inherit from
`MaskableOpRewritePattern`, enabling it to handle masked
`vector.contract` operations. The pattern rewrites:
```mlir
  %a = vector.broadcast %a_bc
  %res vector.contract %a_bc, %b, ...
```

into:
```mlir
  // Move the broadcast into vector.contract (by updating the indexing
  // maps)
  %res vector.contract %a, %b, ...
```

The main challenge is supporting cases where the pattern drops a leading
unit dimension. For example:
```mlir
func.func @contract_broadcast_unit_dim_reduction_masked(
    %arg0 : vector<8x4xi32>,
    %arg1 : vector<8x4xi32>,
    %arg2 : vector<8x8xi32>,
    %mask: vector<1x8x8x4xi1>) -> vector<8x8xi32> {

  %0 = vector.broadcast %arg0 : vector<8x4xi32> to vector<1x8x4xi32>
  %1 = vector.broadcast %arg1 : vector<8x4xi32> to vector<1x8x4xi32>
  %result = vector.mask %mask {
    vector.contract {
      indexing_maps = [#map0, #map1, #map2],
      iterator_types = ["reduction", "parallel", "parallel", "reduction"],
      kind = #vector.kind<add>
    } %0, %1, %arg2 : vector<1x8x4xi32>, vector<1x8x4xi32> into vector<8x8xi32>
  } : vector<1x8x8x4xi1> -> vector<8x8xi32>

  return %result : vector<8x8xi32>
}
```

Here, the leading unit dimension is dropped. To handle this, the mask is
cast to the correct shape using a `vector.shape_cast`:

```mlir
func.func @contract_broadcast_unit_dim_reduction_masked(
    %arg0: vector<8x4xi32>,
    %arg1: vector<8x4xi32>,
    %arg2: vector<8x8xi32>,
    %arg3: vector<1x8x8x4xi1>) -> vector<8x8xi32> {

  %mask_sc = vector.shape_cast %arg3 : vector<1x8x8x4xi1> to vector<8x8x4xi1>
  %res = vector.mask %mask_sc {
    vector.contract {
      indexing_maps = [#map, #map1, #map2],
      iterator_types = ["parallel", "parallel", "reduction"],
      kind = #vector.kind<add>
    } %arg0, %arg1, %mask_sc : vector<8x4xi32>, vector<8x4xi32> into vector<8x8xi32>
  } : vector<8x8x4xi1> -> vector<8x8xi32>

  return %res : vector<8x8xi32>
}
```

While this isn't ideal — since it introduces a `vector.shape_cast` that
must be cleaned up later — it reflects the best we can do once the input
reaches `CombineContractBroadcastMask`. A more robust solution may
involve simplifying the input earlier. I am leaving that as  a TODO for
myself to explore this further. Posting this now to unblock downstream
work.

LIMITATIONS

Currently, this pattern assumes:
* Only leading dimensions are dropped in the mask.
* All dropped dimensions must be unit-sized.

TODO: Check whether any other cases are possible.
@llvmbot
Copy link
Member

llvmbot commented May 15, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: Andrzej Warzyński (banach-space)

Changes

This patch updates CombineContractBroadcastMask to inherit from
MaskableOpRewritePattern, enabling it to handle masked
vector.contract operations. The pattern rewrites:

  %a = vector.broadcast %a_bc
  %res vector.contract %a_bc, %b, ...

into:

  // Move the broadcast into vector.contract (by updating the indexing
  // maps)
  %res vector.contract %a, %b, ...

The main challenge is supporting cases where the pattern drops a leading
unit dimension. For example:

func.func @<!-- -->contract_broadcast_unit_dim_reduction_masked(
    %arg0 : vector&lt;8x4xi32&gt;,
    %arg1 : vector&lt;8x4xi32&gt;,
    %arg2 : vector&lt;8x8xi32&gt;,
    %mask: vector&lt;1x8x8x4xi1&gt;) -&gt; vector&lt;8x8xi32&gt; {

  %0 = vector.broadcast %arg0 : vector&lt;8x4xi32&gt; to vector&lt;1x8x4xi32&gt;
  %1 = vector.broadcast %arg1 : vector&lt;8x4xi32&gt; to vector&lt;1x8x4xi32&gt;
  %result = vector.mask %mask {
    vector.contract {
      indexing_maps = [#map0, #map1, #map2],
      iterator_types = ["reduction", "parallel", "parallel", "reduction"],
      kind = #vector.kind&lt;add&gt;
    } %0, %1, %arg2 : vector&lt;1x8x4xi32&gt;, vector&lt;1x8x4xi32&gt; into vector&lt;8x8xi32&gt;
  } : vector&lt;1x8x8x4xi1&gt; -&gt; vector&lt;8x8xi32&gt;

  return %result : vector&lt;8x8xi32&gt;
}

Here, the leading unit dimension is dropped. To handle this, the mask is
cast to the correct shape using a vector.shape_cast:

func.func @<!-- -->contract_broadcast_unit_dim_reduction_masked(
    %arg0: vector&lt;8x4xi32&gt;,
    %arg1: vector&lt;8x4xi32&gt;,
    %arg2: vector&lt;8x8xi32&gt;,
    %arg3: vector&lt;1x8x8x4xi1&gt;) -&gt; vector&lt;8x8xi32&gt; {

  %mask_sc = vector.shape_cast %arg3 : vector&lt;1x8x8x4xi1&gt; to vector&lt;8x8x4xi1&gt;
  %res = vector.mask %mask_sc {
    vector.contract {
      indexing_maps = [#map, #map1, #map2],
      iterator_types = ["parallel", "parallel", "reduction"],
      kind = #vector.kind&lt;add&gt;
    } %arg0, %arg1, %mask_sc : vector&lt;8x4xi32&gt;, vector&lt;8x4xi32&gt; into vector&lt;8x8xi32&gt;
  } : vector&lt;8x8x4xi1&gt; -&gt; vector&lt;8x8xi32&gt;

  return %res : vector&lt;8x8xi32&gt;
}

While this isn't ideal — since it introduces a vector.shape_cast that
must be cleaned up later — it reflects the best we can do once the input
reaches CombineContractBroadcastMask. A more robust solution may
involve simplifying the input earlier. I am leaving that as a TODO for
myself to explore this further. Posting this now to unblock downstream
work.

LIMITATIONS

Currently, this pattern assumes:

  • Only leading dimensions are dropped in the mask.
  • All dropped dimensions must be unit-sized.

TODO: Check whether any other cases are possible.


Full diff: https://github.com/llvm/llvm-project/pull/140050.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (+157-94)
  • (modified) mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir (+75-2)
<!DOCTYPE html>
<html>
  <head>
    <meta http-equiv="Content-type" content="text/html; charset=utf-8">
    <meta http-equiv="Content-Security-Policy" content="default-src 'none'; base-uri 'self'; connect-src 'self'; form-action 'self'; img-src 'self' data:; script-src 'self'; style-src 'unsafe-inline'">
    <meta content="origin" name="referrer">
    <title>Rate limit &middot; GitHub</title>
    <meta name="viewport" content="width=device-width">
    <style type="text/css" media="screen">
      body {
        background-color: #f6f8fa;
        color: #24292e;
        font-family: -apple-system,BlinkMacSystemFont,Segoe UI,Helvetica,Arial,sans-serif,Apple Color Emoji,Segoe UI Emoji,Segoe UI Symbol;
        font-size: 14px;
        line-height: 1.5;
        margin: 0;
      }

      .container { margin: 50px auto; max-width: 600px; text-align: center; padding: 0 24px; }

      a { color: #0366d6; text-decoration: none; }
      a:hover { text-decoration: underline; }

      h1 { line-height: 60px; font-size: 48px; font-weight: 300; margin: 0px; text-shadow: 0 1px 0 #fff; }
      p { color: rgba(0, 0, 0, 0.5); margin: 20px 0 40px; }

      ul { list-style: none; margin: 25px 0; padding: 0; }
      li { display: table-cell; font-weight: bold; width: 1%; }

      .logo { display: inline-block; margin-top: 35px; }
      .logo-img-2x { display: none; }
      @media
      only screen and (-webkit-min-device-pixel-ratio: 2),
      only screen and (   min--moz-device-pixel-ratio: 2),
      only screen and (     -o-min-device-pixel-ratio: 2/1),
      only screen and (        min-device-pixel-ratio: 2),
      only screen and (                min-resolution: 192dpi),
      only screen and (                min-resolution: 2dppx) {
        .logo-img-1x { display: none; }
        .logo-img-2x { display: inline-block; }
      }

      #suggestions {
        margin-top: 35px;
        color: #ccc;
      }
      #suggestions a {
        color: #666666;
        font-weight: 200;
        font-size: 14px;
        margin: 0 10px;
      }

    </style>
  </head>
  <body>

    <div class="container">

      <h1>Whoa there!</h1>
      <p>You have exceeded a secondary rate limit.<br><br>
        Please wait a few minutes before you try again;<br>
        in some cases this may take up to an hour.
      </p>
      <div id="suggestions">
        <a href="https://support.github.com/contact">Contact Support</a> &mdash;
        <a href="https://githubstatus.com">GitHub Status</a> &mdash;
        <a href="https://twitter.com/githubstatus">@githubstatus</a>
      </div>

      <a href="/" class="logo logo-img-1x">
        <img width="32" height="32" title="" alt="" src="">
      </a>

      <a href="/" class="logo logo-img-2x">
        <img width="32" height="32" title="" alt="" src="">
      </a>
    </div>
  </body>
</html>

Copy link
Contributor

@hanhanW hanhanW left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Just a few nits and a question about scalable vector support.

* Add tests for scalable vectors
* Capitalize all LIT variables used for maps
* Fix punctuation
Copy link
Contributor

@hanhanW hanhanW left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, just one final nit! LGTM

@dcaballe
Copy link
Contributor

A more robust solution may involve simplifying the input earlier.

It really looks like the pattern is trying to do an in-place simplification that should have happened before and adding some complexity to achieve that. Would it possible to just reject these cases in the pattern when the broadcasts to be folded are actually "no-ops"? This is pointing at "we need to remove unnecessary unit dims before calling this pattern" kind of requirement...

@banach-space
Copy link
Contributor Author

Hey Diego - your comments are spot on and very much aligned with what we've been thinking.

Would it possible to just reject these cases in the pattern when the broadcasts to be folded are actually "no-ops"? This is pointing at "we need to remove unnecessary unit dims before calling this pattern" kind of requirement...

In practice, quite a bit happens before we hit this pattern. My goal here was to minimally extend CombineContractBroadcastMask to unblock us, and then separately investigate a more principled solution.

For context, here’s the IR we're seeing - this is just the part that extracts arguments for the masked vector.contract:

  // MASK
  %13 = vector.create_mask %c1, %c1, %c1, %c2, %dim, %c8 : vector<1x1x1x2x[4]x8xi1>
  %mask = vector.extract %13[0, 0] : vector<1x2x[4]x8xi1> from vector<1x1x1x2x[4]x8xi1>
 
  // LHS - %4 comes from an xfer Op
  %rhs = vector.extract %4[0, 0] : vector<2x[4]xi32> from vector<1x1x2x[4]xi32>

  // RHS - %2 comes from an xfer Op
  %10 = vector.extract %2[0, 0, 0] : vector<2x[4]x8xi8> from vector<1x1x1x2x[4]x8xi8>
  %11 = arith.extsi %10 : vector<2x[4]x8xi8> to vector<2x[4]x8xi32>
  %rhs = vector.broadcast %11 : vector<2x[4]x8xi32> to vector<1x2x[4]x8xi32>

My thinking is that introducing vector.shape_cast creates something that is easy to correct for (with a different pattern):

  // MASK
  %13 = vector.create_mask %c1, %c1, %c1, %c2, %dim, %c8 : vector<1x1x1x2x[4]x8xi1>
  %mask = vector.extract %13[0, 0] : vector<1x2x[4]x8xi1> from vector<1x1x1x2x[4]x8xi1>
  %mask_sc = vector.shape_cast %mask  vector<1x2x[4]x8xi1> to  vector<2x[4]x8xi1>

Indeed, the code above could be simplified as:

  // MASK
  %13 = vector.create_mask %c2, %dim, %c8 : vector<1x1x1x2x[4]x8xi1>

…but I don’t think we want to be re-writing arbitrary vector.create_mask directly within CombineContractBroadcastMask, right?

I agree it would be better if the IR were cleaned up before reaching this point - I'll look into whether earlier patterns could be improved. So far, it looked a bit tricky, but there may be low-hanging fruit.

It really looks like the pattern is trying to do an in-place simplification that should have happened before and adding some complexity to achieve that.

Totally hear you - and I should mention: GitHub makes the diff look more invasive than it is. This patch mainly just wraps CombineContractBroadcastMask in MaskableOpRewritePattern to support masks. So while the diff appears large, the actual change is pretty minimal 😅

Let me know what you think - it’s possible I’m missing a simpler approach here.🤔

@hanhanW
Copy link
Contributor

hanhanW commented May 20, 2025

So while the diff appears large, the actual change is pretty minimal 😅

Here is my tip to review such change: you can use hide whitespace feature to make visualization better.

image

Swap masked and scalable tests (thanks Hanhan)
Copy link
Contributor

@dcaballe dcaballe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for elaborating. Ok, let's move on!

FailureOr<Value> combineContractAndBroadcast(vector::ContractionOp contractOp,
MaskingOpInterface maskingOp,
PatternRewriter &rewriter) {
SmallVector<AffineMap> maps =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

auto

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note, I am not touching this code - this particular GitHub diff claims otherwise. However, if you change the view, things will be clearer 😅 Since I am not touching this code, I'd rather leave it as is.

Value rhs = contractOp.getRhs();
size_t index = 0;
bool changed = false;
for (Value *operand : {&lhs, &rhs}) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

curious... what is this doing and why are we using a Value *?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is this doing
Iterates over &lhs and &rhs (which are Values).

why are we using a Value *

On L327

    *operand = broadcast.getSource();

Apologies if I am stating the obvious, I wasn't sure what specifically you are asking about. Also, note that I am not touching this code - this particular GitHub diff claims otherwise. However, if you change the view, things will be clearer 😅

@banach-space banach-space merged commit e22508e into llvm:main May 27, 2025
11 checks passed
@banach-space banach-space deleted the andrzej/vector/update_CombineContractBroadcast branch May 27, 2025 12:34
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants