Skip to content

Commit 2e6deb1

Browse files
authored
[LoopInterchange] Fix overflow in cost calculation (#111807)
If the iteration count is really large, e.g. UINT_MAX, then the cost calculation can overflows and trigger an assert. So saturate the cost to INT_MAX if this is the case by using InstructionCost as a type which already supports this kind of overflow handling. This fixes #104761
1 parent 5cfa8ba commit 2e6deb1

File tree

3 files changed

+48
-6
lines changed

3 files changed

+48
-6
lines changed

llvm/include/llvm/Analysis/LoopCacheAnalysis.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
#include "llvm/Analysis/LoopAnalysisManager.h"
1818
#include "llvm/IR/PassManager.h"
19+
#include "llvm/Support/InstructionCost.h"
1920
#include <optional>
2021

2122
namespace llvm {
@@ -31,7 +32,7 @@ class ScalarEvolution;
3132
class SCEV;
3233
class TargetTransformInfo;
3334

34-
using CacheCostTy = int64_t;
35+
using CacheCostTy = InstructionCost;
3536
using LoopVectorTy = SmallVector<Loop *, 8>;
3637

3738
/// Represents a memory reference as a base pointer and a set of indexing
@@ -192,8 +193,6 @@ class CacheCost {
192193
using LoopCacheCostTy = std::pair<const Loop *, CacheCostTy>;
193194

194195
public:
195-
static CacheCostTy constexpr InvalidCost = -1;
196-
197196
/// Construct a CacheCost object for the loop nest described by \p Loops.
198197
/// The optional parameter \p TRT can be used to specify the max. distance
199198
/// between array elements accessed in a loop so that the elements are

llvm/lib/Analysis/LoopCacheAnalysis.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,8 @@ CacheCostTy IndexedReference::computeRefCost(const Loop &L,
328328
const SCEV *TripCount =
329329
computeTripCount(*AR->getLoop(), *Sizes.back(), SE);
330330
Type *WiderType = SE.getWiderType(RefCost->getType(), TripCount->getType());
331+
// For the multiplication result to fit, request a type twice as wide.
332+
WiderType = WiderType->getExtendedType();
331333
RefCost = SE.getMulExpr(SE.getNoopOrZeroExtend(RefCost, WiderType),
332334
SE.getNoopOrZeroExtend(TripCount, WiderType));
333335
}
@@ -338,14 +340,18 @@ CacheCostTy IndexedReference::computeRefCost(const Loop &L,
338340
assert(RefCost && "Expecting a valid RefCost");
339341

340342
// Attempt to fold RefCost into a constant.
343+
// CacheCostTy is a signed integer, but the tripcount value can be large
344+
// and may not fit, so saturate/limit the value to the maximum signed
345+
// integer value.
341346
if (auto ConstantCost = dyn_cast<SCEVConstant>(RefCost))
342-
return ConstantCost->getValue()->getZExtValue();
347+
return ConstantCost->getValue()->getLimitedValue(
348+
std::numeric_limits<int64_t>::max());
343349

344350
LLVM_DEBUG(dbgs().indent(4)
345351
<< "RefCost is not a constant! Setting to RefCost=InvalidCost "
346352
"(invalid value).\n");
347353

348-
return CacheCost::InvalidCost;
354+
return CacheCostTy::getInvalid();
349355
}
350356

351357
bool IndexedReference::tryDelinearizeFixedSize(
@@ -696,7 +702,7 @@ CacheCostTy
696702
CacheCost::computeLoopCacheCost(const Loop &L,
697703
const ReferenceGroupsTy &RefGroups) const {
698704
if (!L.isLoopSimplifyForm())
699-
return InvalidCost;
705+
return CacheCostTy::getInvalid();
700706

701707
LLVM_DEBUG(dbgs() << "Considering loop '" << L.getName()
702708
<< "' as innermost loop.\n");
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
; RUN: opt < %s -passes='print<loop-cache-cost>' -disable-output 2>&1 | FileCheck %s
2+
3+
; For a loop with a very large iteration count, make sure the cost
4+
; calculation does not overflow:
5+
;
6+
; void a(int b) {
7+
; for (int c;; c += b)
8+
; for (long d = 0; d < -3ULL; d += 2ULL)
9+
; A[c][d][d] = 0;
10+
; }
11+
12+
; CHECK: Loop 'outer.loop' has cost = 9223372036854775807
13+
; CHECK: Loop 'inner.loop' has cost = 9223372036854775807
14+
15+
@A = local_unnamed_addr global [11 x [11 x [11 x i32]]] zeroinitializer, align 16
16+
17+
define void @foo(i32 noundef %b) {
18+
entry:
19+
%0 = sext i32 %b to i64
20+
br label %outer.loop
21+
22+
outer.loop:
23+
%indvars.iv = phi i64 [ %indvars.iv.next, %outer.loop.cleanup ], [ 0, %entry ]
24+
br label %inner.loop
25+
26+
outer.loop.cleanup:
27+
%indvars.iv.next = add nsw i64 %indvars.iv, %0
28+
br label %outer.loop
29+
30+
inner.loop:
31+
%inner.iv = phi i64 [ 0, %outer.loop ], [ %add, %inner.loop ]
32+
%arrayidx3 = getelementptr inbounds [11 x [11 x [11 x i32]]], ptr @A, i64 0, i64 %indvars.iv, i64 %inner.iv, i64 %inner.iv
33+
store i32 0, ptr %arrayidx3, align 4
34+
%add = add nuw i64 %inner.iv, 2
35+
%cmp = icmp ult i64 %inner.iv, -5
36+
br i1 %cmp, label %inner.loop, label %outer.loop.cleanup
37+
}

0 commit comments

Comments
 (0)