Skip to content

Commit f385f6c

Browse files
authored
[mlir][vector] Distribute all masked transfer reads whose permutation map contains no permutation or broadcast (llvm#73539)
The primary difficulty in distributing masked transfers arises when the permutation map permutes the vector: the distribution logic must then make sure the correct mask elements end up with the distributed transfer. Because this is only a problem when the map actually contains a permutation, the condition for distribution can be relaxed from requiring a minor identity map to requiring only that the map is an identity once its unused dimensions are compressed.
1 parent 0f18984 commit f385f6c

File tree

2 files changed

+27
-1
lines changed

2 files changed

+27
-1
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -837,7 +837,7 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
837837
// of which lane is responsible for which element is captured strictly
838838
// by shape information on the warp op, and thus requires materializing
839839
// the permutation in IR.
840-
if (!read.getPermutationMap().isMinorIdentity())
840+
if (!mlir::compressUnusedDims(read.getPermutationMap()).isIdentity())
841841
return failure();
842842
VectorType maskType =
843843
getDistributedType(read.getMaskType(), map, warpOp.getWarpSize());

mlir/test/Dialect/Vector/vector-warp-distribute.mlir

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1351,6 +1351,32 @@ func.func @warp_propagate_masked_transfer_read(%laneid: index, %src: memref<4096
13511351

13521352
// -----
13531353

1354+
func.func @warp_propagate_nontrivial_map_masked_transfer_read(%laneid: index, %src: memref<4096x4096xf32>, %index: index) -> vector<2xf32> {
1355+
%f0 = arith.constant 0.000000e+00 : f32
1356+
%c0 = arith.constant 0 : index
1357+
%r = vector.warp_execute_on_lane_0(%laneid)[64] -> (vector<2xf32>) {
1358+
%mask = "mask_def_0"() : () -> (vector<128xi1>)
1359+
%0 = vector.transfer_read %src[%index, %c0], %f0, %mask {in_bounds = [true], permutation_map = affine_map<(d0, d1) -> (d0)>} : memref<4096x4096xf32>, vector<128xf32>
1360+
vector.yield %0 : vector<128xf32>
1361+
}
1362+
return %r : vector<2xf32>
1363+
}
1364+
1365+
// CHECK-PROP-DAG: #[[$MAP0:.+]] = affine_map<()[s0, s1] -> (s0 + s1 * 2)>
1366+
// CHECK-PROP-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0)>
1367+
// CHECK-PROP-LABEL: func.func @warp_propagate_nontrivial_map_masked_transfer_read
1368+
// CHECK-PROP-SAME: %[[ARG0:.+]]: index, {{.*}}, %[[ARG2:.+]]: index
1369+
// CHECK-PROP: %[[C0:.*]] = arith.constant 0 : index
1370+
// CHECK-PROP: %[[R:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[64] -> (vector<2xi1>) {
1371+
// CHECK-PROP: %[[M0:.*]] = "mask_def_0"
1372+
// CHECK-PROP: vector.yield %[[M0]] : vector<128xi1>
1373+
// CHECK-PROP: }
1374+
// CHECK-PROP: %[[DIST_READ_IDX0:.+]] = affine.apply #[[$MAP0]]()[%[[ARG2]], %[[ARG0]]]
1375+
// CHECK-PROP: vector.transfer_read {{.*}}[%[[DIST_READ_IDX0]], %[[C0]]], {{.*}}, %[[R]]
1376+
// CHECK-PROP-SAME: permutation_map = #[[$MAP1]]} {{.*}} vector<2xf32>
1377+
1378+
// -----
1379+
13541380
func.func @warp_propagate_masked_transfer_read_shared_mask(%laneid: index, %src: memref<4096x4096xf32>, %index: index, %index2: index, %mask_ub: index) -> (vector<2xf32>, vector<2xf32>) {
13551381
%f0 = arith.constant 0.000000e+00 : f32
13561382
%c0 = arith.constant 0 : index

0 commit comments

Comments
 (0)