Skip to content

Commit 9a42732

Browse files
author
git apple-llvm automerger
committed
Merge commit 'f385f6c93b33' from llvm.org/main into next
2 parents 40dbce7 + f385f6c commit 9a42732

File tree

2 files changed

+27
-1
lines changed

2 files changed

+27
-1
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -837,7 +837,7 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
837837
// of which lane is responsible for which element is captured strictly
838838
// by shape information on the warp op, and thus requires materializing
839839
// the permutation in IR.
840-
if (!read.getPermutationMap().isMinorIdentity())
840+
if (!mlir::compressUnusedDims(read.getPermutationMap()).isIdentity())
841841
return failure();
842842
VectorType maskType =
843843
getDistributedType(read.getMaskType(), map, warpOp.getWarpSize());

mlir/test/Dialect/Vector/vector-warp-distribute.mlir

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1351,6 +1351,32 @@ func.func @warp_propagate_masked_transfer_read(%laneid: index, %src: memref<4096
13511351

13521352
// -----
13531353

1354+
func.func @warp_propagate_nontrivial_map_masked_transfer_read(%laneid: index, %src: memref<4096x4096xf32>, %index: index) -> vector<2xf32> {
1355+
%f0 = arith.constant 0.000000e+00 : f32
1356+
%c0 = arith.constant 0 : index
1357+
%r = vector.warp_execute_on_lane_0(%laneid)[64] -> (vector<2xf32>) {
1358+
%mask = "mask_def_0"() : () -> (vector<128xi1>)
1359+
%0 = vector.transfer_read %src[%index, %c0], %f0, %mask {in_bounds = [true], permutation_map = affine_map<(d0, d1) -> (d0)>} : memref<4096x4096xf32>, vector<128xf32>
1360+
vector.yield %0 : vector<128xf32>
1361+
}
1362+
return %r : vector<2xf32>
1363+
}
1364+
1365+
// CHECK-PROP-DAG: #[[$MAP0:.+]] = affine_map<()[s0, s1] -> (s0 + s1 * 2)>
1366+
// CHECK-PROP-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0)>
1367+
// CHECK-PROP-LABEL: func.func @warp_propagate_nontrivial_map_masked_transfer_read
1368+
// CHECK-PROP-SAME: %[[ARG0:.+]]: index, {{.*}}, %[[ARG2:.+]]: index
1369+
// CHECK-PROP: %[[C0:.*]] = arith.constant 0 : index
1370+
// CHECK-PROP: %[[R:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[64] -> (vector<2xi1>) {
1371+
// CHECK-PROP: %[[M0:.*]] = "mask_def_0"
1372+
// CHECK-PROP: vector.yield %[[M0]] : vector<128xi1>
1373+
// CHECK-PROP: }
1374+
// CHECK-PROP: %[[DIST_READ_IDX0:.+]] = affine.apply #[[$MAP0]]()[%[[ARG2]], %[[ARG0]]]
1375+
// CHECK-PROP: vector.transfer_read {{.*}}[%[[DIST_READ_IDX0]], %[[C0]]], {{.*}}, %[[R]]
1376+
// CHECK-PROP-SAME: permutation_map = #[[$MAP1]]} {{.*}} vector<2xf32>
1377+
1378+
// -----
1379+
13541380
func.func @warp_propagate_masked_transfer_read_shared_mask(%laneid: index, %src: memref<4096x4096xf32>, %index: index, %index2: index, %mask_ub: index) -> (vector<2xf32>, vector<2xf32>) {
13551381
%f0 = arith.constant 0.000000e+00 : f32
13561382
%c0 = arith.constant 0 : index

0 commit comments

Comments
 (0)