Skip to content

Commit 77d4358

Browse files
committed
MathExtras: avoid unnecessarily widening types
Several multi-argument functions unnecessarily widen types beyond the argument types. Templatize the functions, and use std::common_type_t to avoid this, thereby optimizing the functions. While at it, address usage issues raised in #95087. One requirement of this patch is to add overflow checks; consequently, one caller in LoopVectorize and one in AMDGPUBaseInfo are manually widened.
1 parent ebb5385 commit 77d4358

File tree

7 files changed

+117
-39
lines changed

7 files changed

+117
-39
lines changed

llvm/include/llvm/Support/MathExtras.h

Lines changed: 92 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,22 @@
2323
#include <type_traits>
2424

2525
namespace llvm {
26+
/// Some template parameter helpers to optimize for bitwidth, for functions that
27+
/// take multiple arguments.
28+
29+
// We can't verify signedness, since callers rely on implicit coercions to
30+
// signed/unsigned.
31+
template <typename T, typename U>
32+
using enableif_int =
33+
std::enable_if_t<std::is_integral_v<T> && std::is_integral_v<U>>;
34+
35+
// Use std::common_type_t to widen only up to the widest argument.
36+
template <typename T, typename U, typename = enableif_int<T, U>>
37+
using common_uint =
38+
std::common_type_t<std::make_unsigned_t<T>, std::make_unsigned_t<U>>;
39+
template <typename T, typename U, typename = enableif_int<T, U>>
40+
using common_sint =
41+
std::common_type_t<std::make_signed_t<T>, std::make_signed_t<U>>;
2642

2743
/// Mathematical constants.
2844
namespace numbers {
@@ -346,7 +362,8 @@ inline unsigned Log2_64_Ceil(uint64_t Value) {
346362

347363
/// A and B are either alignments or offsets. Return the minimum alignment that
348364
/// may be assumed after adding the two together.
349-
constexpr inline uint64_t MinAlign(uint64_t A, uint64_t B) {
365+
template <typename U, typename V, typename T = common_uint<U, V>>
366+
constexpr T MinAlign(U A, V B) {
350367
// The largest power of 2 that divides both A and B.
351368
//
352369
// Replace "-Value" by "1+~Value" in the following commented code to avoid
@@ -375,7 +392,7 @@ inline uint64_t PowerOf2Ceil(uint64_t A) {
375392
return UINT64_C(1) << Log2_64_Ceil(A);
376393
}
377394

378-
/// Returns the next integer (mod 2**64) that is greater than or equal to
395+
/// Returns the next integer (mod 2**nbits) that is greater than or equal to
379396
/// \p Value and is a multiple of \p Align. \p Align must be non-zero.
380397
///
381398
/// Examples:
@@ -385,18 +402,46 @@ inline uint64_t PowerOf2Ceil(uint64_t A) {
385402
/// alignTo(~0LL, 8) = 0
386403
/// alignTo(321, 255) = 510
387404
/// \endcode
388-
inline uint64_t alignTo(uint64_t Value, uint64_t Align) {
405+
template <typename U, typename V, typename T = common_uint<U, V>>
406+
constexpr T alignTo(U Value, V Align) {
407+
assert(Align != 0u && "Align can't be 0.");
408+
// If Value is negative, wrap will occur in the cast.
409+
if (Value > 0)
410+
assert(static_cast<T>(Value) <=
411+
std::numeric_limits<T>::max() - (Align - 1) &&
412+
"alignTo would overflow");
413+
return (Value + Align - 1) / Align * Align;
414+
}
415+
416+
// Fallback when arguments aren't integral.
417+
constexpr inline uint64_t alignTo(uint64_t Value, uint64_t Align) {
389418
assert(Align != 0u && "Align can't be 0.");
390419
return (Value + Align - 1) / Align * Align;
391420
}
392421

393-
inline uint64_t alignToPowerOf2(uint64_t Value, uint64_t Align) {
422+
template <typename U, typename V, typename T = common_uint<U, V>>
423+
constexpr T alignToPowerOf2(U Value, V Align) {
394424
assert(Align != 0 && (Align & (Align - 1)) == 0 &&
395425
"Align must be a power of 2");
426+
// If Value is negative, wrap will occur in the cast.
427+
if (Value > 0)
428+
assert(static_cast<T>(Value) <=
429+
std::numeric_limits<T>::max() - (Align - 1) &&
430+
"alignToPowerOf2 would overflow");
396431
// Replace unary minus to avoid compilation error on Windows:
397432
// "unary minus operator applied to unsigned type, result still unsigned"
398-
uint64_t negAlign = (~Align) + 1;
399-
return (Value + Align - 1) & negAlign;
433+
T NegAlign = (~Align) + 1;
434+
return (Value + Align - 1) & NegAlign;
435+
}
436+
437+
// Fallback when arguments aren't integral.
438+
constexpr inline uint64_t alignToPowerOf2(uint64_t Value, uint64_t Align) {
439+
assert(Align != 0 && (Align & (Align - 1)) == 0 &&
440+
"Align must be a power of 2");
441+
// Replace unary minus to avoid compilation error on Windows:
442+
// "unary minus operator applied to unsigned type, result still unsigned"
443+
uint64_t NegAlign = (~Align) + 1;
444+
return (Value + Align - 1) & NegAlign;
400445
}
401446

402447
/// If non-zero \p Skew is specified, the return value will be a minimal integer
@@ -411,64 +456,86 @@ inline uint64_t alignToPowerOf2(uint64_t Value, uint64_t Align) {
411456
/// alignTo(~0LL, 8, 3) = 3
412457
/// alignTo(321, 255, 42) = 552
413458
/// \endcode
414-
inline uint64_t alignTo(uint64_t Value, uint64_t Align, uint64_t Skew) {
459+
template <typename U, typename V, typename W,
460+
typename T = common_uint<common_uint<U, V>, W>>
461+
constexpr T alignTo(U Value, V Align, W Skew) {
415462
assert(Align != 0u && "Align can't be 0.");
416463
Skew %= Align;
417464
return alignTo(Value - Skew, Align) + Skew;
418465
}
419466

420467
/// Returns the next integer (mod 2**64) that is greater than or equal to
421468
/// \p Value and is a multiple of \c Align. \c Align must be non-zero.
422-
template <uint64_t Align> constexpr inline uint64_t alignTo(uint64_t Value) {
469+
template <uint64_t Align> constexpr uint64_t alignTo(uint64_t Value) {
423470
static_assert(Align != 0u, "Align must be non-zero");
424471
return (Value + Align - 1) / Align * Align;
425472
}
426473

427-
/// Returns the integer ceil(Numerator / Denominator). Unsigned integer version.
428-
inline uint64_t divideCeil(uint64_t Numerator, uint64_t Denominator) {
474+
/// Returns the integer ceil(Numerator / Denominator). Unsigned version.
475+
template <typename U, typename V, typename T = common_uint<U, V>>
476+
constexpr T divideCeil(U Numerator, V Denominator) {
429477
return alignTo(Numerator, Denominator) / Denominator;
430478
}
431479

432-
/// Returns the integer ceil(Numerator / Denominator). Signed integer version.
433-
inline int64_t divideCeilSigned(int64_t Numerator, int64_t Denominator) {
480+
// Fallback when arguments aren't integral.
481+
constexpr inline uint64_t divideCeil(uint64_t Numerator, uint64_t Denominator) {
482+
return alignTo(Numerator, Denominator) / Denominator;
483+
}
484+
485+
/// Returns the integer ceil(Numerator / Denominator). Signed version.
486+
/// Guaranteed to never overflow.
487+
template <typename U, typename V, typename T = common_sint<U, V>>
488+
constexpr T divideCeilSigned(U Numerator, V Denominator) {
434489
assert(Denominator && "Division by zero");
435490
if (!Numerator)
436491
return 0;
437492
// C's integer division rounds towards 0.
438-
int64_t X = (Denominator > 0) ? -1 : 1;
493+
T X = (Denominator > 0) ? -1 : 1;
439494
bool SameSign = (Numerator > 0) == (Denominator > 0);
440495
return SameSign ? ((Numerator + X) / Denominator) + 1
441496
: Numerator / Denominator;
442497
}
443498

444-
/// Returns the integer floor(Numerator / Denominator). Signed integer version.
445-
inline int64_t divideFloorSigned(int64_t Numerator, int64_t Denominator) {
499+
/// Returns the integer floor(Numerator / Denominator). Signed version.
500+
/// Guaranteed to never overflow.
501+
template <typename U, typename V, typename T = common_sint<U, V>>
502+
constexpr T divideFloorSigned(U Numerator, V Denominator) {
446503
assert(Denominator && "Division by zero");
447504
if (!Numerator)
448505
return 0;
449506
// C's integer division rounds towards 0.
450-
int64_t X = (Denominator > 0) ? -1 : 1;
507+
T X = (Denominator > 0) ? -1 : 1;
451508
bool SameSign = (Numerator > 0) == (Denominator > 0);
452509
return SameSign ? Numerator / Denominator
453510
: -((-Numerator + X) / Denominator) - 1;
454511
}
455512

456513
/// Returns the remainder of the Euclidean division of LHS by RHS. Result is
457-
/// always non-negative.
458-
inline int64_t mod(int64_t Numerator, int64_t Denominator) {
514+
/// always non-negative. Signed version. Guaranteed to never overflow.
515+
template <typename U, typename V, typename T = common_sint<U, V>>
516+
constexpr T mod(U Numerator, V Denominator) {
459517
assert(Denominator >= 1 && "Mod by non-positive number");
460-
int64_t Mod = Numerator % Denominator;
518+
T Mod = Numerator % Denominator;
461519
return Mod < 0 ? Mod + Denominator : Mod;
462520
}
463521

464522
/// Returns the integer nearest(Numerator / Denominator).
465-
inline uint64_t divideNearest(uint64_t Numerator, uint64_t Denominator) {
523+
template <typename U, typename V, typename T = common_uint<U, V>>
524+
constexpr T divideNearest(U Numerator, V Denominator) {
525+
// If Value is negative, wrap will occur in the cast.
526+
if (Numerator > 0)
527+
assert(static_cast<T>(Numerator) <=
528+
std::numeric_limits<T>::max() - (Denominator / 2) &&
529+
"divideNearest would overflow");
466530
return (Numerator + (Denominator / 2)) / Denominator;
467531
}
468532

469-
/// Returns the largest uint64_t less than or equal to \p Value and is
470-
/// \p Skew mod \p Align. \p Align must be non-zero
471-
inline uint64_t alignDown(uint64_t Value, uint64_t Align, uint64_t Skew = 0) {
533+
/// Returns the largest unsigned integer less than or equal to \p Value and is
534+
/// \p Skew mod \p Align. \p Align must be non-zero. Guaranteed to never
535+
/// overflow.
536+
template <typename U, typename V, typename W = uint8_t,
537+
typename T = common_uint<common_uint<U, V>, W>>
538+
constexpr T alignDown(U Value, V Align, W Skew = 0) {
472539
assert(Align != 0u && "Align can't be 0.");
473540
Skew %= Align;
474541
return (Value - Skew) / Align * Align + Skew;
@@ -512,8 +579,8 @@ inline int64_t SignExtend64(uint64_t X, unsigned B) {
512579

513580
/// Subtract two unsigned integers, X and Y, of type T and return the absolute
514581
/// value of the result.
515-
template <typename T>
516-
std::enable_if_t<std::is_unsigned_v<T>, T> AbsoluteDifference(T X, T Y) {
582+
template <typename U, typename V, typename T = common_uint<U, V>>
583+
constexpr T AbsoluteDifference(U X, V Y) {
517584
return X > Y ? (X - Y) : (Y - X);
518585
}
519586

llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -986,7 +986,9 @@ unsigned getMaxFlatWorkGroupSize(const MCSubtargetInfo *STI) {
986986

987987
unsigned getWavesPerWorkGroup(const MCSubtargetInfo *STI,
988988
unsigned FlatWorkGroupSize) {
989-
return divideCeil(FlatWorkGroupSize, getWavefrontSize(STI));
989+
// divideCeil will overflow, unless FlatWorkGroupSize is cast.
990+
return divideCeil(static_cast<uint64_t>(FlatWorkGroupSize),
991+
getWavefrontSize(STI));
990992
}
991993

992994
unsigned getSGPRAllocGranule(const MCSubtargetInfo *STI) {

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4803,7 +4803,8 @@ bool LoopVectorizationPlanner::isMoreProfitable(
48034803
// different VFs we can use this to compare the total loop-body cost
48044804
// expected after vectorization.
48054805
if (CM.foldTailByMasking())
4806-
return VectorCost * divideCeil(MaxTripCount, VF);
4806+
// divideCeil will overflow, unless MaxTripCount is cast.
4807+
return VectorCost * divideCeil(static_cast<uint64_t>(MaxTripCount), VF);
48074808
return VectorCost * (MaxTripCount / VF) + ScalarCost * (MaxTripCount % VF);
48084809
};
48094810

llvm/unittests/Support/MathExtrasTest.cpp

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -189,8 +189,10 @@ TEST(MathExtras, AlignTo) {
189189
EXPECT_EQ(8u, alignTo(5, 8));
190190
EXPECT_EQ(24u, alignTo(17, 8));
191191
EXPECT_EQ(0u, alignTo(~0LL, 8));
192-
EXPECT_EQ(static_cast<uint64_t>(std::numeric_limits<uint32_t>::max()) + 1,
193-
alignTo(std::numeric_limits<uint32_t>::max(), 2));
192+
#ifndef NDEBUG
193+
EXPECT_DEATH(alignTo(std::numeric_limits<uint32_t>::max(), 2),
194+
"alignTo would overflow");
195+
#endif
194196

195197
EXPECT_EQ(7u, alignTo(5, 8, 7));
196198
EXPECT_EQ(17u, alignTo(17, 8, 1));
@@ -204,8 +206,10 @@ TEST(MathExtras, AlignToPowerOf2) {
204206
EXPECT_EQ(8u, alignToPowerOf2(5, 8));
205207
EXPECT_EQ(24u, alignToPowerOf2(17, 8));
206208
EXPECT_EQ(0u, alignToPowerOf2(~0LL, 8));
207-
EXPECT_EQ(static_cast<uint64_t>(std::numeric_limits<uint32_t>::max()) + 1,
208-
alignToPowerOf2(std::numeric_limits<uint32_t>::max(), 2));
209+
#ifndef NDEBUG
210+
EXPECT_DEATH(alignToPowerOf2(std::numeric_limits<uint32_t>::max(), 2),
211+
"alignToPowerOf2 would overflow");
212+
#endif
209213
}
210214

211215
TEST(MathExtras, AlignDown) {
@@ -459,15 +463,20 @@ TEST(MathExtras, DivideNearest) {
459463
EXPECT_EQ(divideNearest(14, 3), 5u);
460464
EXPECT_EQ(divideNearest(15, 3), 5u);
461465
EXPECT_EQ(divideNearest(0, 3), 0u);
462-
EXPECT_EQ(divideNearest(std::numeric_limits<uint32_t>::max(), 2),
463-
2147483648u);
466+
#ifndef NDEBUG
467+
EXPECT_DEATH(divideNearest(std::numeric_limits<uint32_t>::max(), 2),
468+
"divideNearest would overflow");
469+
#endif
464470
}
465471

466472
TEST(MathExtras, DivideCeil) {
467473
EXPECT_EQ(divideCeil(14, 3), 5u);
468474
EXPECT_EQ(divideCeil(15, 3), 5u);
469475
EXPECT_EQ(divideCeil(0, 3), 0u);
470-
EXPECT_EQ(divideCeil(std::numeric_limits<uint32_t>::max(), 2), 2147483648u);
476+
#ifndef NDEBUG
477+
EXPECT_DEATH(divideCeil(std::numeric_limits<uint32_t>::max(), 2),
478+
"alignTo would overflow");
479+
#endif
471480

472481
EXPECT_EQ(divideCeilSigned(14, 3), 5);
473482
EXPECT_EQ(divideCeilSigned(15, 3), 5);

mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ inline int64_t shardDimension(int64_t dimSize, int64_t shardCount) {
114114
return ShapedType::kDynamic;
115115

116116
assert(dimSize % shardCount == 0);
117-
return llvm::divideCeilSigned(dimSize, shardCount);
117+
return dimSize / shardCount;
118118
}
119119

120120
// Get the size of an unsharded dimension.

mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -365,7 +365,7 @@ void UnrankedMemRefDescriptor::computeSizes(
365365
Value two = createIndexAttrConstant(builder, loc, indexType, 2);
366366
Value indexSize = createIndexAttrConstant(
367367
builder, loc, indexType,
368-
llvm::divideCeilSigned(typeConverter.getIndexTypeBitwidth(), 8));
368+
llvm::divideCeil(typeConverter.getIndexTypeBitwidth(), 8));
369369

370370
sizes.reserve(sizes.size() + values.size());
371371
for (auto [desc, addressSpace] : llvm::zip(values, addressSpaces)) {
@@ -378,8 +378,7 @@ void UnrankedMemRefDescriptor::computeSizes(
378378
// to data layout) into the unranked descriptor.
379379
Value pointerSize = createIndexAttrConstant(
380380
builder, loc, indexType,
381-
llvm::divideCeilSigned(typeConverter.getPointerBitwidth(addressSpace),
382-
8));
381+
llvm::divideCeil(typeConverter.getPointerBitwidth(addressSpace), 8));
383382
Value doublePointerSize =
384383
builder.create<LLVM::MulOp>(loc, indexType, two, pointerSize);
385384

mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -971,7 +971,7 @@ struct MemorySpaceCastOpLowering
971971
resultUnderlyingDesc, resultElemPtrType);
972972

973973
int64_t bytesToSkip =
974-
2 * llvm::divideCeilSigned(
974+
2 * llvm::divideCeil(
975975
getTypeConverter()->getPointerBitwidth(resultAddrSpace), 8);
976976
Value bytesToSkipConst = rewriter.create<LLVM::ConstantOp>(
977977
loc, getIndexType(), rewriter.getIndexAttr(bytesToSkip));

0 commit comments

Comments (0)