-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[mlir][vector] Update CombineContractBroadcastMask
#140050
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][vector] Update CombineContractBroadcastMask
#140050
Conversation
This patch updates `CombineContractBroadcastMask` to inherit from `MaskableOpRewritePattern`, enabling it to handle masked `vector.contract` operations. The pattern rewrites: ```mlir %a = vector.broadcast %a_bc %res vector.contract %a_bc, %b, ... ``` into: ```mlir // Move the broadcast into vector.contract (by updating the indexing // maps) %res vector.contract %a, %b, ... ``` The main challenge is supporting cases where the pattern drops a leading unit dimension. For example: ```mlir func.func @contract_broadcast_unit_dim_reduction_masked( %arg0 : vector<8x4xi32>, %arg1 : vector<8x4xi32>, %arg2 : vector<8x8xi32>, %mask: vector<1x8x8x4xi1>) -> vector<8x8xi32> { %0 = vector.broadcast %arg0 : vector<8x4xi32> to vector<1x8x4xi32> %1 = vector.broadcast %arg1 : vector<8x4xi32> to vector<1x8x4xi32> %result = vector.mask %mask { vector.contract { indexing_maps = [#map0, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add> } %0, %1, %arg2 : vector<1x8x4xi32>, vector<1x8x4xi32> into vector<8x8xi32> } : vector<1x8x8x4xi1> -> vector<8x8xi32> return %result : vector<8x8xi32> } ``` Here, the leading unit dimension is dropped. To handle this, the mask is cast to the correct shape using a `vector.shape_cast`: ```mlir func.func @contract_broadcast_unit_dim_reduction_masked( %arg0: vector<8x4xi32>, %arg1: vector<8x4xi32>, %arg2: vector<8x8xi32>, %arg3: vector<1x8x8x4xi1>) -> vector<8x8xi32> { %mask_sc = vector.shape_cast %arg3 : vector<1x8x8x4xi1> to vector<8x8x4xi1> %res = vector.mask %mask_sc { vector.contract { indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add> } %arg0, %arg1, %mask_sc : vector<8x4xi32>, vector<8x4xi32> into vector<8x8xi32> } : vector<8x8x4xi1> -> vector<8x8xi32> return %res : vector<8x8xi32> } ``` While this isn't ideal — since it introduces a `vector.shape_cast` that must be cleaned up later — it reflects the best we can do once the input reaches `CombineContractBroadcastMask`. A more robust solution may involve simplifying the input earlier. I am leaving that as a TODO for myself to explore this further. Posting this now to unblock downstream work. LIMITATIONS Currently, this pattern assumes: * Only leading dimensions are dropped in the mask. * All dropped dimensions must be unit-sized. TODO: Check whether any other cases are possible.
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-vector Author: Andrzej Warzyński (banach-space) ChangesThis patch updates %a = vector.broadcast %a_bc
%res vector.contract %a_bc, %b, ... into: // Move the broadcast into vector.contract (by updating the indexing
// maps)
%res vector.contract %a, %b, ... The main challenge is supporting cases where the pattern drops a leading func.func @<!-- -->contract_broadcast_unit_dim_reduction_masked(
%arg0 : vector<8x4xi32>,
%arg1 : vector<8x4xi32>,
%arg2 : vector<8x8xi32>,
%mask: vector<1x8x8x4xi1>) -> vector<8x8xi32> {
%0 = vector.broadcast %arg0 : vector<8x4xi32> to vector<1x8x4xi32>
%1 = vector.broadcast %arg1 : vector<8x4xi32> to vector<1x8x4xi32>
%result = vector.mask %mask {
vector.contract {
indexing_maps = [#map0, #map1, #map2],
iterator_types = ["reduction", "parallel", "parallel", "reduction"],
kind = #vector.kind<add>
} %0, %1, %arg2 : vector<1x8x4xi32>, vector<1x8x4xi32> into vector<8x8xi32>
} : vector<1x8x8x4xi1> -> vector<8x8xi32>
return %result : vector<8x8xi32>
} Here, the leading unit dimension is dropped. To handle this, the mask is func.func @<!-- -->contract_broadcast_unit_dim_reduction_masked(
%arg0: vector<8x4xi32>,
%arg1: vector<8x4xi32>,
%arg2: vector<8x8xi32>,
%arg3: vector<1x8x8x4xi1>) -> vector<8x8xi32> {
%mask_sc = vector.shape_cast %arg3 : vector<1x8x8x4xi1> to vector<8x8x4xi1>
%res = vector.mask %mask_sc {
vector.contract {
indexing_maps = [#map, #map1, #map2],
iterator_types = ["parallel", "parallel", "reduction"],
kind = #vector.kind<add>
} %arg0, %arg1, %mask_sc : vector<8x4xi32>, vector<8x4xi32> into vector<8x8xi32>
} : vector<8x8x4xi1> -> vector<8x8xi32>
return %res : vector<8x8xi32>
} While this isn't ideal — since it introduces a LIMITATIONS Currently, this pattern assumes:
TODO: Check whether any other cases are possible. Full diff: https://github.com/llvm/llvm-project/pull/140050.diff 2 Files Affected:
<!DOCTYPE html>
<html>
<head>
<meta http-equiv="Content-type" content="text/html; charset=utf-8">
<meta http-equiv="Content-Security-Policy" content="default-src 'none'; base-uri 'self'; connect-src 'self'; form-action 'self'; img-src 'self' data:; script-src 'self'; style-src 'unsafe-inline'">
<meta content="origin" name="referrer">
<title>Rate limit · GitHub</title>
<meta name="viewport" content="width=device-width">
<style type="text/css" media="screen">
body {
background-color: #f6f8fa;
color: #24292e;
font-family: -apple-system,BlinkMacSystemFont,Segoe UI,Helvetica,Arial,sans-serif,Apple Color Emoji,Segoe UI Emoji,Segoe UI Symbol;
font-size: 14px;
line-height: 1.5;
margin: 0;
}
.container { margin: 50px auto; max-width: 600px; text-align: center; padding: 0 24px; }
a { color: #0366d6; text-decoration: none; }
a:hover { text-decoration: underline; }
h1 { line-height: 60px; font-size: 48px; font-weight: 300; margin: 0px; text-shadow: 0 1px 0 #fff; }
p { color: rgba(0, 0, 0, 0.5); margin: 20px 0 40px; }
ul { list-style: none; margin: 25px 0; padding: 0; }
li { display: table-cell; font-weight: bold; width: 1%; }
.logo { display: inline-block; margin-top: 35px; }
.logo-img-2x { display: none; }
@media
only screen and (-webkit-min-device-pixel-ratio: 2),
only screen and ( min--moz-device-pixel-ratio: 2),
only screen and ( -o-min-device-pixel-ratio: 2/1),
only screen and ( min-device-pixel-ratio: 2),
only screen and ( min-resolution: 192dpi),
only screen and ( min-resolution: 2dppx) {
.logo-img-1x { display: none; }
.logo-img-2x { display: inline-block; }
}
#suggestions {
margin-top: 35px;
color: #ccc;
}
#suggestions a {
color: #666666;
font-weight: 200;
font-size: 14px;
margin: 0 10px;
}
</style>
</head>
<body>
<div class="container">
<h1>Whoa there!</h1>
<p>You have exceeded a secondary rate limit.<br><br>
Please wait a few minutes before you try again;<br>
in some cases this may take up to an hour.
</p>
<div id="suggestions">
<a href="https://support.github.com/contact">Contact Support</a> —
<a href="https://githubstatus.com">GitHub Status</a> —
<a href="https://twitter.com/githubstatus">@githubstatus</a>
</div>
<a href="/" class="logo logo-img-1x">
<img width="32" height="32" title="" alt="" src="data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAACAAAAAgCAYAAABzenr0AAAAGXRFWHRTb2Z0d2FyZQBBZG9iZSBJbWFnZVJlYWR5ccllPAAAAyRpVFh0WE1MOmNvbS5hZG9iZS54bXAAAAAAADw/eHBhY2tldCBiZWdpbj0i77u/IiBpZD0iVzVNME1wQ2VoaUh6cmVTek5UY3prYzlkIj8+IDx4OnhtcG1ldGEgeG1sbnM6eD0iYWRvYmU6bnM6bWV0YS8iIHg6eG1wdGs9IkFkb2JlIFhNUCBDb3JlIDUuMy1jMDExIDY2LjE0NTY2MSwgMjAxMi8wMi8wNi0xNDo1NjoyNyAgICAgICAgIj4gPHJkZjpSREYgeG1sbnM6cmRmPSJodHRwOi8vd3d3LnczLm9yZy8xOTk5LzAyLzIyLXJkZi1zeW50YXgtbnMjIj4gPHJkZjpEZXNjcmlwdGlvbiByZGY6YWJvdXQ9IiIgeG1sbnM6eG1wPSJodHRwOi8vbnMuYWRvYmUuY29tL3hhcC8xLjAvIiB4bWxuczp4bXBNTT0iaHR0cDovL25zLmFkb2JlLmNvbS94YXAvMS4wL21tLyIgeG1sbnM6c3RSZWY9Imh0dHA6Ly9ucy5hZG9iZS5jb20veGFwLzEuMC9zVHlwZS9SZXNvdXJjZVJlZiMiIHhtcDpDcmVhdG9yVG9vbD0iQWRvYmUgUGhvdG9zaG9wIENTNiAoTWFjaW50b3NoKSIgeG1wTU06SW5zdGFuY2VJRD0ieG1wLmlpZDpFMTZCRDY3REIzRjAxMUUyQUQzREIxQzRENUFFNUM5NiIgeG1wTU06RG9jdW1lbnRJRD0ieG1wLmRpZDpFMTZCRDY3RUIzRjAxMUUyQUQzREIxQzRENUFFNUM5NiI+IDx4bXBNTTpEZXJpdmVkRnJvbSBzdFJlZjppbnN0YW5jZUlEPSJ4bXAuaWlkOkUxNkJENjdCQjNGMDExRTJBRDNEQjFDNEQ1QUU1Qzk2IiBzdFJlZjpkb2N1bWVudElEPSJ4bXAuZGlkOkUxNkJENjdDQjNGMDExRTJBRDNEQjFDNEQ1QUU1Qzk2Ii8+IDwvcmRmOkRlc2NyaXB0aW9uPiA8L3JkZjpSREY+IDwveDp4bXBtZXRhPiA8P3hwYWNrZXQgZW5kPSJyIj8+SM9MCAAAA+5JREFUeNrEV11Ik1EY3s4+ddOp29Q5b0opCgKFsoKoi5Kg6CIhuwi6zLJLoYLopq4qsKKgi4i6CYIoU/q5iDAKs6syoS76IRWtyJ+p7cdt7sf1PGOD+e0c3dygAx/67ZzzPM95/877GYdHRg3ZjMXFxepQKNS6sLCwJxqNNuFpiMfjVs4ZjUa/pmmjeD6VlJS8NpvNT4QQ7mxwjSsJiEQim/1+/9lgMHgIr5ohuxG1WCw9Vqv1clFR0dCqBODElV6v90ogEDjGdYbVjXhpaendioqK07CIR7ZAqE49PT09BPL2PMgTByQGsYiZlQD4uMXtdr+JxWINhgINYhGT2MsKgMrm2dnZXgRXhaHAg5jEJodUAHxux4LudHJE9RdEdA+i3Juz7bGHe4mhE9FNrgwBCLirMFV9Okh5eflFh8PR5nK5nDabrR2BNJlKO0T35+Li4n4+/J+/JQCxhmu5h3uJoXNHPbmWZAHMshWB8l5/ipqammaAf0zPDDx1ONV3vurdidqwAQL+pEc8sLcAe1CCvQ3YHxIW8Pl85xSWNC1hADDIv0rIE/o4J0k3kww4xSlwIhcq3EFFOm7KN/hUGOQkt0CFa5WpNJlMvxBEz/IVQAxg/ZRZl9wiHA63yDYieM7DnLP5CiAGsC7I5sgtYKJGWe2A8seFqgFJrJjEPY1Cn3pJ8/9W1e5VWsFDTEmFrBcoDhZJEQkXuhICMyKpjhahqN21hRYATKfUOlDmkygrR4o4C0VOLGJKrOITKB4jijzdXygBKixyC5TDQdnk/Pz8qRw6oOWGlsTKGOQW6OH6FBWsyePxdOXLTgxiyebILZCjz+GLgMIKnXNzc49YMlcRdHXcSwxFVgTInQhC9G33UhNoJLuqq6t345p9y3eUy8OTk5PjAHuI9uo4b07FBaOhsu0A4Unc+T1TU1Nj3KsSSE5yJ65jqF2DDd8QqWYmAZrIM2VlZTdnZmb6AbpdV9V6ec9znf5Q7HjYumdRE0JOp3MjitO4SFa+cZz8Umqe3TCbSLvdfkR/kWDdNQl5InuTcysOcpFT35ZrbBxx4p3JAHlZVVW1D/634VRt+FvLBgK/v5LV9WS+10xMTEwtRw7XvqOL+e2Q8V3AYIOIAXQ26/heWVnZCVfcyKHg2CBgTpmPmjYM8l24GyaUHyaIh7XwfR9ErE8qHoDfn2LTNAVC0HX6MFcBIP8Bi+6F6cdW/DICkANRfx99fEYFQ7Nph5i/uQiA214gno7K+guhaiKg9gC62+M8eR7XsBsYJ4ilam60Fb7r7uAj8wFyuwM1oIOWgfmDy6RXEEQzJMPe23DXrVS7rtyD3Df8z/FPgAEAzWU5Ku59ZAUAAAAASUVORK5CYII=">
</a>
<a href="/" class="logo logo-img-2x">
<img width="32" height="32" title="" alt="" src="data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAEAAAABACAYAAACqaXHeAAAAGXRFWHRTb2Z0d2FyZQBBZG9iZSBJbWFnZVJlYWR5ccllPAAAAyRpVFh0WE1MOmNvbS5hZG9iZS54bXAAAAAAADw/eHBhY2tldCBiZWdpbj0i77u/IiBpZD0iVzVNME1wQ2VoaUh6cmVTek5UY3prYzlkIj8+IDx4OnhtcG1ldGEgeG1sbnM6eD0iYWRvYmU6bnM6bWV0YS8iIHg6eG1wdGs9IkFkb2JlIFhNUCBDb3JlIDUuMy1jMDExIDY2LjE0NTY2MSwgMjAxMi8wMi8wNi0xNDo1NjoyNyAgICAgICAgIj4gPHJkZjpSREYgeG1sbnM6cmRmPSJodHRwOi8vd3d3LnczLm9yZy8xOTk5LzAyLzIyLXJkZi1zeW50YXgtbnMjIj4gPHJkZjpEZXNjcmlwdGlvbiByZGY6YWJvdXQ9IiIgeG1sbnM6eG1wPSJodHRwOi8vbnMuYWRvYmUuY29tL3hhcC8xLjAvIiB4bWxuczp4bXBNTT0iaHR0cDovL25zLmFkb2JlLmNvbS94YXAvMS4wL21tLyIgeG1sbnM6c3RSZWY9Imh0dHA6Ly9ucy5hZG9iZS5jb20veGFwLzEuMC9zVHlwZS9SZXNvdXJjZVJlZiMiIHhtcDpDcmVhdG9yVG9vbD0iQWRvYmUgUGhvdG9zaG9wIENTNiAoTWFjaW50b3NoKSIgeG1wTU06SW5zdGFuY2VJRD0ieG1wLmlpZDpEQUM1QkUxRUI0MUMxMUUyQUQzREIxQzRENUFFNUM5NiIgeG1wTU06RG9jdW1lbnRJRD0ieG1wLmRpZDpEQUM1QkUxRkI0MUMxMUUyQUQzREIxQzRENUFFNUM5NiI+IDx4bXBNTTpEZXJpdmVkRnJvbSBzdFJlZjppbnN0YW5jZUlEPSJ4bXAuaWlkOkUxNkJENjdGQjNGMDExRTJBRDNEQjFDNEQ1QUU1Qzk2IiBzdFJlZjpkb2N1bWVudElEPSJ4bXAuZGlkOkUxNkJENjgwQjNGMDExRTJBRDNEQjFDNEQ1QUU1Qzk2Ii8+IDwvcmRmOkRlc2NyaXB0aW9uPiA8L3JkZjpSREY+IDwveDp4bXBtZXRhPiA8P3hwYWNrZXQgZW5kPSJyIj8+hfPRaQAAB6lJREFUeNrsW2mME2UYbodtt+2222u35QheoCCYGBQligIJgkZJNPzgigoaTEj8AdFEMfADfyABkgWiiWcieK4S+QOiHAYUj2hMNKgYlEujpNttu9vttbvdw+chU1K6M535pt3ubHCSyezR+b73eb73+t7vrfXsufOW4bz6+vom9/b23ovnNNw34b5xYGAgODg46Mbt4mesVmsWd1qSpHhdXd2fuP/Afcput5/A88xwymcdBgLqenp6FuRyuWV4zu/v759QyWBjxoz5t76+/gun09mK5xFyakoCAPSaTCazNpvNPoYVbh6O1YKGRF0u13sNDQ27QMzfpiAAKj0lnU6/gBVfAZW2WWpwwVzy0IgP3G73FpjI6REhAGA9qVRqA1b9mVoBVyIC2tDi8Xg24+dUzQiAbS/s7Ox8G2o/3mKCC+Zw0efzPQEfcVjYrARX3dbV1bUtHo8fMgt42f+Mp0yUTVQbdWsAHVsikdiHkHaPxcQXQufXgUBgMRxme9U0AAxfH4vFvjM7eF6UkbJS5qoQwEQGA57Ac5JllFyUVZZ5ckUEgMVxsK2jlSYzI+QXJsiyjzNEAJyJAzb/KQa41jJKL8pODMQiTEAymXw5n8/P0IjD3bh7Rgog59aanxiIRTVvV/oj0tnHca/WMrVwODwB3raTGxzkBg/gnZVapFV62Wy2n5AO70HM/5wbJ0QnXyQSaVPDIuNZzY0V3ntHMwxiwHA0Gj2Np7ecIBDgaDAYXKCQJM1DhrgJ3nhulcPbl8j4NmHe46X/g60fwbz3aewjkqFQaAqebWU1AOqyQwt8Id6qEHMc97zu7u7FGGsn7HAiVuosVw7P35C1nccdgSCxop1dHeZswmfHMnxBo6ZTk+jN8dl/vF7vWofDsa+MLN9oEUBMxOb3+1eoEsBVw6Zmua49r8YmhAKDiEPcMwBsxMiqQ+ixzPFxZyqRpXARG/YOr1ObFJ0gUskXBbamcR1OKmMUvDxHRAu8/LmY3jFLMUpFqz9HxG65smYJdyKyECOxDiEAe/p1gjF2oonivZAsxVgl2daa4EQWCW6J55qFAFFZiJWYLxNQy2qOSUzGRsyXCUDIeliwAHEO4WSlWQBRFoZakXcKmCXmyXAKs0Ve9vl8q42WoIYpJU4hV3hKcNs8m9gl7p/xQ73eF5kB4j5mNrWmTJRNwAzqiV1CxjVTZCIkEq+Z1bZFZSN2CenmVAFVy4Plz8xKAGWjjAKFk6lCBMDR/MJjLLMSQNm43xAiQKTaA+9/wewhDjL+JVI1kkTSSOTcKbMTwPqESAot6dn6Fr1gHwVJju6IRuyiByPuUUBAg5DGkAgBmxlvdgIEK9gDkohdY/BJo4CAG0R8miRSsGABkgVQs4KXu098IgUXSSRsFAoKZiVAVDY2WUiiPTjYRi41KwGisrGsLtlsth8Fiwnz2fBkQvWfRtlE3iF2yW63/yCacXZ1dW02GwGyTFaRd4idJnCKHRaCxYRHoG5LTKT6SyiToP1fJHbmAYPYRR0UnZQtMnA6s0zg+GZBlt0Gdo7EPHgpE3Q6nZ8YyLhc8Xj8MJh/aKTAY+5FPAKHLE7RdwuYJZmNwzyCMkBCYyKROJBMJl9B/PXXCjjmCmDOVzH3fiPpObEWGqoKe4EBl8v1hlqsdLvd23mkxHM9pc9kMpmno9HoeTii7ewbHEZPPx1ztLS1tV3AnGuMjiNjvbQFuHw6zDo5By7dTPAQNBgMLrRarTkSls1mnwT7uwp9virx9QzbW/HuV/j5d/b+6jniKlllP8lkeONJDk+dq9GsQTnC4fB1heO0K47Hwe7WdDr9nAKgXwOBwHI+C45Htj1d6sd429TUNEcmUdc+PRaLHcvn87dXW4ugzdsaGxufL94NFv9zi1J7GVbhlvb2dnaJ3SVrxfc+n2+NTsZ7/H7/Mr3g5XdSIHyJSH1PZ+7fToyl2+ErqilgZ4NaLYB9goVGaHjR93Hv1ZrU4XDsFT20kH3PObzbWk0CgG1jacVIUnAQb9F+VexyLMzkpcLv0IJV7AHQIOCAUYHx7v5qgScmYHtTqSAyZLEJTK22Bie4iq3xsqpm4SAf9Hq9a2DnJ4uLK3SEULcdRvp3i3zHySqpficxEdsQc1NrlYXXvR+O7qASSezXB+h1SuUomgg9LL8BUoV4749EIolKh+EiqWmqVEZlDgHks2pxHw7xTqUQw9J5NcAXOK10AGIoZ6Zli6JY6Z1Q461KoZ4NiKLHarW+KDsxlDUPHZ5zPQZqUVDPJsTqb5n9malbpAh8C2XXDLl62+WZIDFRUlNVOiwencnNU3aQEkL+cDMSoLvZo2fQB7AJssNAuFuvorlDVVkkg2I87+jo2K2QAVphDrfyViK5VqtO34OkaxXCp+7drdDBCAdubm6eidX+2WwqT5komwh4YQLk+H4aE93h8Xg2gvHekQZOGSgLZTLyDTLJ4Lx9/KZWKBSainT4Iy3FqQBfnUZR42PKQFksBr9QKVXCPusD3OiA/RkQ5kP8qV/Jl1WywAp/6+dcmPM2zL1UrUahe4JqfnWWKXIul3uUbfP8njAFLW1OFr3gdFtZ72cNH+PtQT7/brW+NXqJAHh0y9V8/U/A1U7AfwIMAD7mS3pCbuWJAAAAAElFTkSuQmCC">
</a>
</div>
</body>
</html>
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Just a few nits and a question about scalable vector support.
* Add tests for scalable vectors * Capitalize all LIT variables used for maps * Fix punctuation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, just one final nit! LGTM
It really looks like the pattern is trying to do an in-place simplification that should have happened before and adding some complexity to achieve that. Would it possible to just reject these cases in the pattern when the broadcasts to be folded are actually "no-ops"? This is pointing at "we need to remove unnecessary unit dims before calling this pattern" kind of requirement... |
Hey Diego - your comments are spot on and very much aligned with what we've been thinking.
In practice, quite a bit happens before we hit this pattern. My goal here was to minimally extend For context, here’s the IR we're seeing - this is just the part that extracts arguments for the masked // MASK
%13 = vector.create_mask %c1, %c1, %c1, %c2, %dim, %c8 : vector<1x1x1x2x[4]x8xi1>
%mask = vector.extract %13[0, 0] : vector<1x2x[4]x8xi1> from vector<1x1x1x2x[4]x8xi1>
// LHS - %4 comes from an xfer Op
%rhs = vector.extract %4[0, 0] : vector<2x[4]xi32> from vector<1x1x2x[4]xi32>
// RHS - %2 comes from an xfer Op
%10 = vector.extract %2[0, 0, 0] : vector<2x[4]x8xi8> from vector<1x1x1x2x[4]x8xi8>
%11 = arith.extsi %10 : vector<2x[4]x8xi8> to vector<2x[4]x8xi32>
%rhs = vector.broadcast %11 : vector<2x[4]x8xi32> to vector<1x2x[4]x8xi32> My thinking is that introducing // MASK
%13 = vector.create_mask %c1, %c1, %c1, %c2, %dim, %c8 : vector<1x1x1x2x[4]x8xi1>
%mask = vector.extract %13[0, 0] : vector<1x2x[4]x8xi1> from vector<1x1x1x2x[4]x8xi1>
%mask_sc = vector.shape_cast %mask vector<1x2x[4]x8xi1> to vector<2x[4]x8xi1> Indeed, the code above could be simplified as: // MASK
%13 = vector.create_mask %c2, %dim, %c8 : vector<1x1x1x2x[4]x8xi1> …but I don’t think we want to be re-writing arbitrary I agree it would be better if the IR were cleaned up before reaching this point - I'll look into whether earlier patterns could be improved. So far, it looked a bit tricky, but there may be low-hanging fruit.
Totally hear you - and I should mention: GitHub makes the diff look more invasive than it is. This patch mainly just wraps Let me know what you think - it’s possible I’m missing a simpler approach here.🤔 |
Swap masked and scalable tests (thanks Hanhan)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for elaborating. Ok, let's move on!
FailureOr<Value> combineContractAndBroadcast(vector::ContractionOp contractOp, | ||
MaskingOpInterface maskingOp, | ||
PatternRewriter &rewriter) { | ||
SmallVector<AffineMap> maps = |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
auto
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note, I am not touching this code - this particular GitHub diff claims otherwise. However, if you change the view, things will be clearer 😅 Since I am not touching this code, I'd rather leave it as is.
Value rhs = contractOp.getRhs(); | ||
size_t index = 0; | ||
bool changed = false; | ||
for (Value *operand : {&lhs, &rhs}) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
curious... what is this doing and why are we using a Value *
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what is this doing
Iterates over&lhs
and&rhs
(which areValue
s).
why are we using a Value *
On L327
*operand = broadcast.getSource();
Apologies if I am stating the obvious, I wasn't sure what specifically you are asking about. Also, note that I am not touching this code - this particular GitHub diff claims otherwise. However, if you change the view, things will be clearer 😅
This patch updates
CombineContractBroadcastMask
to inherit fromMaskableOpRewritePattern
, enabling it to handle maskedvector.contract
operations. The pattern rewrites:into:
The main challenge is supporting cases where the pattern drops a leading
unit dimension. For example:
Here, the leading unit dimension is dropped. To handle this, the mask is
cast to the correct shape using a
vector.shape_cast
:While this isn't ideal — since it introduces a
vector.shape_cast
thatmust be cleaned up later — it reflects the best we can do once the input
reaches
CombineContractBroadcastMask
. A more robust solution mayinvolve simplifying the input earlier. I am leaving that as a TODO for
myself to explore this further. Posting this now to unblock downstream
work.
LIMITATIONS
Currently, this pattern assumes:
TODO: Check whether any other cases are possible.