Skip to content

Commit 71051de

Browse files
authored
[MemCpyOpt] Fix infinite loop in memset+memcpy fold (llvm#98638)
For the case where the memcpy size is zero, this transform is a complex no-op. This can lead to an infinite loop when the size is zero in a way that BasicAA understands, because BasicAA can then still prove that dst and dst + src_size are MustAlias, so the transform can be applied again to its own output. I've tried to mitigate this before using the isZeroSize() check, but we can hit cases where InstSimplify doesn't understand that the size is zero, but BasicAA does. As such, this bites the bullet and adds an explicit isKnownNonZero() check to guard against no-op transforms. Fixes llvm#98610.
1 parent 9ad72df commit 71051de

File tree

5 files changed

+115
-28
lines changed

5 files changed

+115
-28
lines changed

llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp

Lines changed: 10 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1296,6 +1296,15 @@ bool MemCpyOptPass::processMemSetMemCpyDependence(MemCpyInst *MemCpy,
12961296
if (!BAA.isMustAlias(MemSet->getDest(), MemCpy->getDest()))
12971297
return false;
12981298

1299+
// Don't perform the transform if src_size may be zero. In that case, the
1300+
// transform is essentially a complex no-op and may lead to an infinite
1301+
// loop if BasicAA is smart enough to understand that dst and dst + src_size
1302+
// are still MustAlias after the transform.
1303+
Value *SrcSize = MemCpy->getLength();
1304+
if (!isKnownNonZero(SrcSize,
1305+
SimplifyQuery(MemCpy->getDataLayout(), DT, AC, MemCpy)))
1306+
return false;
1307+
12991308
// Check that src and dst of the memcpy aren't the same. While memcpy
13001309
// operands cannot partially overlap, exact equality is allowed.
13011310
if (isModSet(BAA.getModRefInfo(MemCpy, MemoryLocation::getForSource(MemCpy))))
@@ -1312,7 +1321,6 @@ bool MemCpyOptPass::processMemSetMemCpyDependence(MemCpyInst *MemCpy,
13121321
// Use the same i8* dest as the memcpy, killing the memset dest if different.
13131322
Value *Dest = MemCpy->getRawDest();
13141323
Value *DestSize = MemSet->getLength();
1315-
Value *SrcSize = MemCpy->getLength();
13161324

13171325
if (mayBeVisibleThroughUnwinding(Dest, MemSet, MemCpy))
13181326
return false;
@@ -1726,8 +1734,7 @@ bool MemCpyOptPass::processMemCpy(MemCpyInst *M, BasicBlock::iterator &BBI) {
17261734
return true;
17271735
}
17281736

1729-
// If the size is zero, remove the memcpy. This also prevents infinite loops
1730-
// in processMemSetMemCpyDependence, which is a no-op for zero-length memcpys.
1737+
// If the size is zero, remove the memcpy.
17311738
if (isZeroSize(M->getLength())) {
17321739
++BBI;
17331740
eraseInstruction(M);

llvm/test/Transforms/MemCpyOpt/memcpy-zero-size.ll

Lines changed: 15 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -34,3 +34,18 @@ define void @pr64886(i64 %len, ptr noalias %p) {
3434
call void @llvm.memcpy.p0.p0.i64(ptr inttoptr (i64 -1 to ptr), ptr %p, i64 poison, i1 false)
3535
ret void
3636
}
37+
38+
define void @pr98610(ptr %p, ptr noalias %p2) {
39+
; CHECK-LABEL: @pr98610(
40+
; CHECK-NEXT: call void @llvm.memset.p0.i64(ptr [[P:%.*]], i8 0, i64 1, i1 false)
41+
; CHECK-NEXT: [[ZERO_EXT:%.*]] = zext i32 0 to i64
42+
; CHECK-NEXT: [[MUL:%.*]] = mul i64 [[ZERO_EXT]], 1
43+
; CHECK-NEXT: call void @llvm.memcpy.p0.p0.i64(ptr [[P]], ptr [[P2:%.*]], i64 [[MUL]], i1 false)
44+
; CHECK-NEXT: ret void
45+
;
46+
call void @llvm.memset.p0.i64(ptr %p, i8 0, i64 1, i1 false)
47+
%zero.ext = zext i32 0 to i64
48+
%mul = mul i64 %zero.ext, 1
49+
call void @llvm.memcpy.p0.p0.i64(ptr %p, ptr %p2, i64 %mul, i1 false)
50+
ret void
51+
}

llvm/test/Transforms/MemCpyOpt/memset-memcpy-dbgloc.ll

Lines changed: 9 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -8,15 +8,20 @@ declare void @llvm.memset.p0.i64(ptr nocapture, i8, i64, i1)
88
declare void @llvm.memcpy.p0.p0.i64(ptr nocapture, ptr nocapture readonly, i64, i1)
99

1010
define void @test_constant(i64 %src_size, ptr %dst, i64 %dst_size, i8 %c) !dbg !5 {
11-
; CHECK-LABEL: @test_constant(
12-
; CHECK-NEXT: [[TMP1:%.*]] = icmp ule i64 [[DST_SIZE:%.*]], [[SRC_SIZE:%.*]], !dbg [[DBG11:![0-9]+]]
11+
; CHECK-LABEL: define void @test_constant(
12+
; CHECK-SAME: i64 [[SRC_SIZE:%.*]], ptr [[DST:%.*]], i64 [[DST_SIZE:%.*]], i8 [[C:%.*]]) !dbg [[DBG5:![0-9]+]] {
13+
; CHECK-NEXT: [[NON_ZERO:%.*]] = icmp ne i64 [[SRC_SIZE]], 0
14+
; CHECK-NEXT: call void @llvm.assume(i1 [[NON_ZERO]])
15+
; CHECK-NEXT: [[TMP1:%.*]] = icmp ule i64 [[DST_SIZE]], [[SRC_SIZE]], !dbg [[DBG11:![0-9]+]]
1316
; CHECK-NEXT: [[TMP2:%.*]] = sub i64 [[DST_SIZE]], [[SRC_SIZE]], !dbg [[DBG11]]
1417
; CHECK-NEXT: [[TMP3:%.*]] = select i1 [[TMP1]], i64 0, i64 [[TMP2]], !dbg [[DBG11]]
15-
; CHECK-NEXT: [[TMP4:%.*]] = getelementptr i8, ptr [[DST:%.*]], i64 [[SRC_SIZE]], !dbg [[DBG11]]
16-
; CHECK-NEXT: call void @llvm.memset.p0.i64(ptr align 1 [[TMP4]], i8 [[C:%.*]], i64 [[TMP3]], i1 false), !dbg [[DBG11]]
18+
; CHECK-NEXT: [[TMP4:%.*]] = getelementptr i8, ptr [[DST]], i64 [[SRC_SIZE]], !dbg [[DBG11]]
19+
; CHECK-NEXT: call void @llvm.memset.p0.i64(ptr align 1 [[TMP4]], i8 [[C]], i64 [[TMP3]], i1 false), !dbg [[DBG11]]
1720
; CHECK-NEXT: call void @llvm.memcpy.p0.p0.i64(ptr [[DST]], ptr @C, i64 [[SRC_SIZE]], i1 false), !dbg [[DBG12:![0-9]+]]
1821
; CHECK-NEXT: ret void, !dbg [[DBG13:![0-9]+]]
1922
;
23+
%non.zero = icmp ne i64 %src_size, 0
24+
call void @llvm.assume(i1 %non.zero)
2025
call void @llvm.memset.p0.i64(ptr %dst, i8 %c, i64 %dst_size, i1 false), !dbg !11
2126
call void @llvm.memcpy.p0.p0.i64(ptr %dst, ptr @C, i64 %src_size, i1 false), !dbg !12
2227
ret void, !dbg !13

0 commit comments

Comments
 (0)