Skip to content

Commit b5e47d2

Browse files
authored
[mlir][vector] Add extra check on distribute types to avoid crashes (#102952)
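This PR addresses the issue detailed in iree-org/iree#17948. The crash occurs when one of the distributed types computed for the values escaping the scf.for region is null (`Type{}`); the WarpOpScfForOp pattern now checks for this and returns failure() instead of crashing the compiler. --------- Signed-off-by: Bangtian Liu <[email protected]>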
1 parent abc1acf commit b5e47d2

File tree

2 files changed

+37
-0
lines changed

2 files changed

+37
-0
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1689,6 +1689,9 @@ struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
16891689
}
16901690
});
16911691

1692+
if (llvm::is_contained(distTypes, Type{}))
1693+
return failure();
1694+
16921695
SmallVector<size_t> newRetIndices;
16931696
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
16941697
rewriter, warpOp, escapingValues.getArrayRef(), distTypes,

mlir/test/Dialect/Vector/vector-warp-distribute.mlir

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -620,6 +620,40 @@ func.func @vector_reduction(%laneid: index) -> (f32) {
620620

621621
// -----
622622

623+
// CHECK-PROP-LABEL: func @warp_distribute(
624+
// CHECK-PROP-SAME: %[[ID:[a-zA-Z0-9]+]]
625+
// CHECK-PROP-SAME: %[[SRC:[a-zA-Z0-9]+]]
626+
// CHECK-PROP-SAME: %[[DEST:[a-zA-Z0-9]+]]
627+
// CHECK-PROP: vector.warp_execute_on_lane_0(%[[ID]])[32]
628+
// CHECK-PROP-NEXT: "some_def"() : () -> vector<4096xf32>
629+
// CHECK-PROP-NEXT: %{{.*}} = vector.reduction
630+
// CHECK-PROP: %[[DEF:.*]] = arith.divf %{{.*}}, %{{.*}} : vector<1xf32>
631+
// CHECK-PROP-NOT: vector.warp_execute_on_lane_0
632+
// CHECK-PROP: scf.for
633+
// CHECK-PROP: %{{.*}} = arith.subf %{{.*}}, %[[DEF]] : vector<1xf32>
634+
func.func @warp_distribute(%arg0: index, %src: memref<128xf32>, %dest: memref<128xf32>){
635+
%cst = arith.constant 0.000000e+00 : f32
636+
%c0 = arith.constant 0 : index
637+
%c1 = arith.constant 1 : index
638+
%c128 = arith.constant 128 : index
639+
%f0 = arith.constant 0.000000e+00 : f32
640+
vector.warp_execute_on_lane_0(%arg0)[32]{
641+
%cst_1 = arith.constant dense<2.621440e+05> : vector<1xf32>
642+
%0 = "some_def"() : () -> (vector<4096xf32>)
643+
%1 = vector.reduction <add>, %0, %cst : vector<4096xf32> into f32
644+
%2 = vector.broadcast %1 : f32 to vector<1xf32>
645+
%3 = arith.divf %2, %cst_1 : vector<1xf32>
646+
scf.for %arg1 = %c0 to %c128 step %c1 {
647+
%4 = vector.transfer_read %src[%arg1], %f0 {in_bounds = [true]} : memref<128xf32>, vector<1xf32>
648+
%5 = arith.subf %4, %3 : vector<1xf32>
649+
vector.transfer_write %5, %dest[%arg1] : vector<1xf32>, memref<128xf32>
650+
}
651+
}
652+
return
653+
}
654+
655+
// -----
656+
623657
func.func @vector_reduction(%laneid: index, %m0: memref<4x2x32xf32>, %m1: memref<f32>) {
624658
%c0 = arith.constant 0: index
625659
%f0 = arith.constant 0.0: f32

0 commit comments

Comments
 (0)