Skip to content

Commit a54102a

Browse files
authored
[llvm] Add support for zero-width integers in MathExtras.h (#87193)
MLIR uses zero-width integers, but also re-uses integer logic from LLVM to avoid duplication. This creates issues when LLVM logic is used in MLIR on integers which can be zero-width. In order to avoid special-casing the bitwidth-related logic in MLIR, this PR adds support for zero-width integers in LLVM's MathExtras (and consequently APInt). While most of the logic in theory works the same way out of the box, because bitshifting right by the entire bitwidth in C++ is undefined behavior instead of being zero, some special cases had to be added. Fortunately, it seems like the performance penalty is small. In x86, this usually yields the addition of a predicated conditional move. I checked that no branch is inserted in Arm too. This happens to fix a crash in `arith.extsi` canonicalization in MLIR. I think a follow-up PR to add tests for i0 in arith would be beneficial.
1 parent 43c26bb commit a54102a

File tree

3 files changed

+47
-22
lines changed

3 files changed

+47
-22
lines changed

llvm/include/llvm/Support/MathExtras.h

Lines changed: 36 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,9 @@ template <typename T> T maskTrailingOnes(unsigned N) {
6666
static_assert(std::is_unsigned_v<T>, "Invalid type!");
6767
const unsigned Bits = CHAR_BIT * sizeof(T);
6868
assert(N <= Bits && "Invalid bit index");
69-
return N == 0 ? 0 : (T(-1) >> (Bits - N));
69+
if (N == 0)
70+
return 0;
71+
return T(-1) >> (Bits - N);
7072
}
7173

7274
/// Create a bitmask with the N left-most bits set to 1, and all other
@@ -149,6 +151,8 @@ constexpr inline uint64_t Make_64(uint32_t High, uint32_t Low) {
149151

150152
/// Checks if an integer fits into the given bit width.
151153
template <unsigned N> constexpr inline bool isInt(int64_t x) {
154+
if constexpr (N == 0)
155+
return 0 == x;
152156
if constexpr (N == 8)
153157
return static_cast<int8_t>(x) == x;
154158
if constexpr (N == 16)
@@ -164,15 +168,15 @@ template <unsigned N> constexpr inline bool isInt(int64_t x) {
164168
/// Checks if a signed integer is an N bit number shifted left by S.
165169
template <unsigned N, unsigned S>
166170
constexpr inline bool isShiftedInt(int64_t x) {
167-
static_assert(
168-
N > 0, "isShiftedInt<0> doesn't make sense (refers to a 0-bit number.");
171+
static_assert(S < 64, "isShiftedInt<N, S> with S >= 64 is too much.");
169172
static_assert(N + S <= 64, "isShiftedInt<N, S> with N + S > 64 is too wide.");
170173
return isInt<N + S>(x) && (x % (UINT64_C(1) << S) == 0);
171174
}
172175

173176
/// Checks if an unsigned integer fits into the given bit width.
174177
template <unsigned N> constexpr inline bool isUInt(uint64_t x) {
175-
static_assert(N > 0, "isUInt<0> doesn't make sense");
178+
if constexpr (N == 0)
179+
return 0 == x;
176180
if constexpr (N == 8)
177181
return static_cast<uint8_t>(x) == x;
178182
if constexpr (N == 16)
@@ -188,39 +192,46 @@ template <unsigned N> constexpr inline bool isUInt(uint64_t x) {
188192
/// Checks if a unsigned integer is an N bit number shifted left by S.
189193
template <unsigned N, unsigned S>
190194
constexpr inline bool isShiftedUInt(uint64_t x) {
191-
static_assert(
192-
N > 0, "isShiftedUInt<0> doesn't make sense (refers to a 0-bit number)");
195+
static_assert(S < 64, "isShiftedUInt<N, S> with S >= 64 is too much.");
193196
static_assert(N + S <= 64,
194197
"isShiftedUInt<N, S> with N + S > 64 is too wide.");
195-
// Per the two static_asserts above, S must be strictly less than 64. So
196-
// 1 << S is not undefined behavior.
198+
// S must be strictly less than 64. So 1 << S is not undefined behavior.
197199
return isUInt<N + S>(x) && (x % (UINT64_C(1) << S) == 0);
198200
}
199201

200202
/// Gets the maximum value for a N-bit unsigned integer.
201203
inline uint64_t maxUIntN(uint64_t N) {
202-
assert(N > 0 && N <= 64 && "integer width out of range");
204+
assert(N <= 64 && "integer width out of range");
203205

204206
// uint64_t(1) << 64 is undefined behavior, so we can't do
205207
// (uint64_t(1) << N) - 1
206208
// without checking first that N != 64. But this works and doesn't have a
207-
// branch.
209+
// branch for N != 0.
210+
// Unfortunately, shifting a uint64_t right by 64 bit is undefined
211+
// behavior, so the condition on N == 0 is necessary. Fortunately, most
212+
// optimizers do not emit branches for this check.
213+
if (N == 0)
214+
return 0;
208215
return UINT64_MAX >> (64 - N);
209216
}
210217

211218
/// Gets the minimum value for a N-bit signed integer.
212219
inline int64_t minIntN(int64_t N) {
213-
assert(N > 0 && N <= 64 && "integer width out of range");
220+
assert(N <= 64 && "integer width out of range");
214221

222+
if (N == 0)
223+
return 0;
215224
return UINT64_C(1) + ~(UINT64_C(1) << (N - 1));
216225
}
217226

218227
/// Gets the maximum value for a N-bit signed integer.
219228
inline int64_t maxIntN(int64_t N) {
220-
assert(N > 0 && N <= 64 && "integer width out of range");
229+
assert(N <= 64 && "integer width out of range");
221230

222231
// This relies on two's complement wraparound when N == 64, so we convert to
223232
// int64_t only at the very end to avoid UB.
233+
if (N == 0)
234+
return 0;
224235
return (UINT64_C(1) << (N - 1)) - 1;
225236
}
226237

@@ -432,34 +443,38 @@ inline uint64_t alignDown(uint64_t Value, uint64_t Align, uint64_t Skew = 0) {
432443
}
433444

434445
/// Sign-extend the number in the bottom B bits of X to a 32-bit integer.
435-
/// Requires 0 < B <= 32.
446+
/// Requires B <= 32.
436447
template <unsigned B> constexpr inline int32_t SignExtend32(uint32_t X) {
437-
static_assert(B > 0, "Bit width can't be 0.");
438448
static_assert(B <= 32, "Bit width out of range.");
449+
if constexpr (B == 0)
450+
return 0;
439451
return int32_t(X << (32 - B)) >> (32 - B);
440452
}
441453

442454
/// Sign-extend the number in the bottom B bits of X to a 32-bit integer.
443-
/// Requires 0 < B <= 32.
455+
/// Requires B <= 32.
444456
inline int32_t SignExtend32(uint32_t X, unsigned B) {
445-
assert(B > 0 && "Bit width can't be 0.");
446457
assert(B <= 32 && "Bit width out of range.");
458+
if (B == 0)
459+
return 0;
447460
return int32_t(X << (32 - B)) >> (32 - B);
448461
}
449462

450463
/// Sign-extend the number in the bottom B bits of X to a 64-bit integer.
451-
/// Requires 0 < B <= 64.
464+
/// Requires B <= 64.
452465
template <unsigned B> constexpr inline int64_t SignExtend64(uint64_t x) {
453-
static_assert(B > 0, "Bit width can't be 0.");
454466
static_assert(B <= 64, "Bit width out of range.");
467+
if constexpr (B == 0)
468+
return 0;
455469
return int64_t(x << (64 - B)) >> (64 - B);
456470
}
457471

458472
/// Sign-extend the number in the bottom B bits of X to a 64-bit integer.
459-
/// Requires 0 < B <= 64.
473+
/// Requires B <= 64.
460474
inline int64_t SignExtend64(uint64_t X, unsigned B) {
461-
assert(B > 0 && "Bit width can't be 0.");
462475
assert(B <= 64 && "Bit width out of range.");
476+
if (B == 0)
477+
return 0;
463478
return int64_t(X << (64 - B)) >> (64 - B);
464479
}
465480

@@ -564,7 +579,6 @@ SaturatingMultiplyAdd(T X, T Y, T A, bool *ResultOverflowed = nullptr) {
564579
/// Use this rather than HUGE_VALF; the latter causes warnings on MSVC.
565580
extern const float huge_valf;
566581

567-
568582
/// Add two signed integers, computing the two's complement truncated result,
569583
/// returning true if overflow occurred.
570584
template <typename T>
@@ -644,6 +658,6 @@ std::enable_if_t<std::is_signed_v<T>, T> MulOverflow(T X, T Y, T &Result) {
644658
return UX > (static_cast<U>(std::numeric_limits<T>::max())) / UY;
645659
}
646660

647-
} // End llvm namespace
661+
} // namespace llvm
648662

649663
#endif

llvm/unittests/ADT/APIntTest.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2797,6 +2797,9 @@ TEST(APIntTest, sext) {
27972797
EXPECT_EQ(63U, i32_neg1.countl_one());
27982798
EXPECT_EQ(0U, i32_neg1.countr_zero());
27992799
EXPECT_EQ(63U, i32_neg1.popcount());
2800+
2801+
EXPECT_EQ(APInt(32u, 0), APInt(0u, 0).sext(32));
2802+
EXPECT_EQ(APInt(64u, 0), APInt(0u, 0).sext(64));
28002803
}
28012804

28022805
TEST(APIntTest, trunc) {

llvm/unittests/Support/MathExtrasTest.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,27 +41,34 @@ TEST(MathExtras, onesMask) {
4141
TEST(MathExtras, isIntN) {
4242
EXPECT_TRUE(isIntN(16, 32767));
4343
EXPECT_FALSE(isIntN(16, 32768));
44+
EXPECT_TRUE(isUIntN(0, 0));
45+
EXPECT_FALSE(isUIntN(0, 1));
46+
EXPECT_FALSE(isUIntN(0, -1));
4447
}
4548

4649
TEST(MathExtras, isUIntN) {
4750
EXPECT_TRUE(isUIntN(16, 65535));
4851
EXPECT_FALSE(isUIntN(16, 65536));
4952
EXPECT_TRUE(isUIntN(1, 0));
5053
EXPECT_TRUE(isUIntN(6, 63));
54+
EXPECT_TRUE(isUIntN(0, 0));
55+
EXPECT_FALSE(isUIntN(0, 1));
5156
}
5257

5358
TEST(MathExtras, maxIntN) {
5459
EXPECT_EQ(32767, maxIntN(16));
5560
EXPECT_EQ(2147483647, maxIntN(32));
5661
EXPECT_EQ(std::numeric_limits<int32_t>::max(), maxIntN(32));
5762
EXPECT_EQ(std::numeric_limits<int64_t>::max(), maxIntN(64));
63+
EXPECT_EQ(0, maxIntN(0));
5864
}
5965

6066
TEST(MathExtras, minIntN) {
6167
EXPECT_EQ(-32768LL, minIntN(16));
6268
EXPECT_EQ(-64LL, minIntN(7));
6369
EXPECT_EQ(std::numeric_limits<int32_t>::min(), minIntN(32));
6470
EXPECT_EQ(std::numeric_limits<int64_t>::min(), minIntN(64));
71+
EXPECT_EQ(0, minIntN(0));
6572
}
6673

6774
TEST(MathExtras, maxUIntN) {
@@ -70,6 +77,7 @@ TEST(MathExtras, maxUIntN) {
7077
EXPECT_EQ(0xffffffffffffffffULL, maxUIntN(64));
7178
EXPECT_EQ(1ULL, maxUIntN(1));
7279
EXPECT_EQ(0x0fULL, maxUIntN(4));
80+
EXPECT_EQ(0, maxUIntN(0));
7381
}
7482

7583
TEST(MathExtras, reverseBits) {

0 commit comments

Comments
 (0)