Skip to content

Commit ba439ee

Browse files
committed
Fix ranges::equal for vector<bool> with small storage types
1 parent e30a5d6 commit ba439ee

File tree

6 files changed

+235
-40
lines changed

6 files changed

+235
-40
lines changed

libcxx/include/__algorithm/equal.h

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -54,24 +54,27 @@ __equal_unaligned(__bit_iterator<_Cp, _IsConst1> __first1,
5454
unsigned __clz_f = __bits_per_word - __first1.__ctz_;
5555
difference_type __dn = std::min(static_cast<difference_type>(__clz_f), __n);
5656
__n -= __dn;
57-
__storage_type __m = (~__storage_type(0) << __first1.__ctz_) & (~__storage_type(0) >> (__clz_f - __dn));
57+
__storage_type __m = std::__middle_mask<__storage_type>(__clz_f - __dn, __first1.__ctz_);
5858
__storage_type __b = *__first1.__seg_ & __m;
5959
unsigned __clz_r = __bits_per_word - __first2.__ctz_;
6060
__storage_type __ddn = std::min<__storage_type>(__dn, __clz_r);
61-
__m = (~__storage_type(0) << __first2.__ctz_) & (~__storage_type(0) >> (__clz_r - __ddn));
61+
__m = std::__middle_mask<__storage_type>(__clz_r - __ddn, __first2.__ctz_);
6262
if (__first2.__ctz_ > __first1.__ctz_) {
63-
if ((*__first2.__seg_ & __m) != (__b << (__first2.__ctz_ - __first1.__ctz_)))
63+
if (static_cast<__storage_type>(*__first2.__seg_ & __m) !=
64+
static_cast<__storage_type>(__b << (__first2.__ctz_ - __first1.__ctz_)))
6465
return false;
6566
} else {
66-
if ((*__first2.__seg_ & __m) != (__b >> (__first1.__ctz_ - __first2.__ctz_)))
67+
if (static_cast<__storage_type>(*__first2.__seg_ & __m) !=
68+
static_cast<__storage_type>(__b >> (__first1.__ctz_ - __first2.__ctz_)))
6769
return false;
6870
}
6971
__first2.__seg_ += (__ddn + __first2.__ctz_) / __bits_per_word;
7072
__first2.__ctz_ = static_cast<unsigned>((__ddn + __first2.__ctz_) % __bits_per_word);
7173
__dn -= __ddn;
7274
if (__dn > 0) {
73-
__m = ~__storage_type(0) >> (__bits_per_word - __dn);
74-
if ((*__first2.__seg_ & __m) != (__b >> (__first1.__ctz_ + __ddn)))
75+
__m = std::__trailing_mask<__storage_type>(__bits_per_word - __dn); // low __dn bits spill into the next word
76+
if (static_cast<__storage_type>(*__first2.__seg_ & __m) !=
77+
static_cast<__storage_type>(__b >> (__first1.__ctz_ + __ddn)))
7578
return false;
7679
__first2.__ctz_ = static_cast<unsigned>(__dn);
7780
}
@@ -81,29 +84,30 @@ __equal_unaligned(__bit_iterator<_Cp, _IsConst1> __first1,
8184
// __first1.__ctz_ == 0;
8285
// do middle words
8386
unsigned __clz_r = __bits_per_word - __first2.__ctz_;
84-
__storage_type __m = ~__storage_type(0) << __first2.__ctz_;
87+
__storage_type __m = std::__leading_mask<__storage_type>(__first2.__ctz_);
8588
for (; __n >= __bits_per_word; __n -= __bits_per_word, ++__first1.__seg_) {
8689
__storage_type __b = *__first1.__seg_;
87-
if ((*__first2.__seg_ & __m) != (__b << __first2.__ctz_))
90+
if (static_cast<__storage_type>(*__first2.__seg_ & __m) != static_cast<__storage_type>(__b << __first2.__ctz_))
8891
return false;
8992
++__first2.__seg_;
90-
if ((*__first2.__seg_ & ~__m) != (__b >> __clz_r))
93+
if (static_cast<__storage_type>(*__first2.__seg_ & static_cast<__storage_type>(~__m)) !=
94+
static_cast<__storage_type>(__b >> __clz_r))
9195
return false;
9296
}
9397
// do last word
9498
if (__n > 0) {
95-
__m = ~__storage_type(0) >> (__bits_per_word - __n);
99+
__m = std::__trailing_mask<__storage_type>(__bits_per_word - __n);
96100
__storage_type __b = *__first1.__seg_ & __m;
97101
__storage_type __dn = std::min(__n, static_cast<difference_type>(__clz_r));
98-
__m = (~__storage_type(0) << __first2.__ctz_) & (~__storage_type(0) >> (__clz_r - __dn));
99-
if ((*__first2.__seg_ & __m) != (__b << __first2.__ctz_))
102+
__m = std::__middle_mask<__storage_type>(__clz_r - __dn, __first2.__ctz_);
103+
if (static_cast<__storage_type>(*__first2.__seg_ & __m) != static_cast<__storage_type>(__b << __first2.__ctz_))
100104
return false;
101105
__first2.__seg_ += (__dn + __first2.__ctz_) / __bits_per_word;
102106
__first2.__ctz_ = static_cast<unsigned>((__dn + __first2.__ctz_) % __bits_per_word);
103107
__n -= __dn;
104108
if (__n > 0) {
105-
__m = ~__storage_type(0) >> (__bits_per_word - __n);
106-
if ((*__first2.__seg_ & __m) != (__b >> __dn))
109+
__m = std::__trailing_mask<__storage_type>(__bits_per_word - __n);
110+
if (static_cast<__storage_type>(*__first2.__seg_ & __m) != static_cast<__storage_type>(__b >> __dn))
107111
return false;
108112
}
109113
}
@@ -128,7 +132,7 @@ __equal_aligned(__bit_iterator<_Cp, _IsConst1> __first1,
128132
unsigned __clz = __bits_per_word - __first1.__ctz_;
129133
difference_type __dn = std::min(static_cast<difference_type>(__clz), __n);
130134
__n -= __dn;
131-
__storage_type __m = (~__storage_type(0) << __first1.__ctz_) & (~__storage_type(0) >> (__clz - __dn));
135+
__storage_type __m = std::__middle_mask<__storage_type>(__clz - __dn, __first1.__ctz_);
132136
if ((*__first2.__seg_ & __m) != (*__first1.__seg_ & __m))
133137
return false;
134138
++__first2.__seg_;
@@ -144,7 +148,7 @@ __equal_aligned(__bit_iterator<_Cp, _IsConst1> __first1,
144148
return false;
145149
// do last word
146150
if (__n > 0) {
147-
__storage_type __m = ~__storage_type(0) >> (__bits_per_word - __n);
151+
__storage_type __m = std::__trailing_mask<__storage_type>(__bits_per_word - __n);
148152
if ((*__first2.__seg_ & __m) != (*__first1.__seg_ & __m))
149153
return false;
150154
}

libcxx/include/__bit_reference

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,13 +82,20 @@ _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 _StorageType __trailing_mask
8282
return static_cast<_StorageType>(~static_cast<_StorageType>(0)) >> __clz;
8383
}
8484

85+
// Creates a mask of type `_StorageType` with a specified number of trailing zeros (__ctz) and sets all remaining
86+
// bits to one.
87+
template <class _StorageType>
88+
_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 _StorageType __leading_mask(unsigned __ctz) {
89+
static_assert(is_unsigned<_StorageType>::value, "__leading_mask only works with unsigned types");
90+
return static_cast<_StorageType>(~static_cast<_StorageType>(0)) << __ctz;
91+
}
92+
8593
// Creates a mask of type `_StorageType` with a specified number of leading zeros (__clz), a specified number of
8694
// trailing zeros (__ctz), and sets all bits in between to one.
8795
template <class _StorageType>
8896
_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 _StorageType __middle_mask(unsigned __clz, unsigned __ctz) {
8997
static_assert(is_unsigned<_StorageType>::value, "__middle_mask only works with unsigned types");
90-
return (static_cast<_StorageType>(~static_cast<_StorageType>(0)) << __ctz) &
91-
std::__trailing_mask<_StorageType>(__clz);
98+
return std::__leading_mask<_StorageType>(__ctz) & std::__trailing_mask<_StorageType>(__clz);
9299
}
93100

94101
// This function is designed to operate correctly even for smaller integral types like `uint8_t`, `uint16_t`,

libcxx/include/__fwd/bit_reference.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@ __fill_masked_range(_StoragePointer __word, unsigned __ctz, unsigned __clz, bool
3333
template <class _StorageType>
3434
_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 _StorageType __trailing_mask(unsigned __clz);
3535

36+
template <class _StorageType>
37+
_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 _StorageType __leading_mask(unsigned __ctz);
38+
3639
template <class _StorageType>
3740
_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 _StorageType __middle_mask(unsigned __clz, unsigned __ctz);
3841

libcxx/include/__vector/comparison.h

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -29,20 +29,6 @@ operator==(const vector<_Tp, _Allocator>& __x, const vector<_Tp, _Allocator>& __
2929
return __sz == __y.size() && std::equal(__x.begin(), __x.end(), __y.begin());
3030
}
3131

32-
// FIXME: Remove this `vector<bool>` overload once #126369 is resolved, reverting to the generic `operator==`
33-
// with `std::equal` for better performance.
34-
template <class _Allocator>
35-
_LIBCPP_CONSTEXPR_SINCE_CXX20 inline _LIBCPP_HIDE_FROM_ABI bool
36-
operator==(const vector<bool, _Allocator>& __x, const vector<bool, _Allocator>& __y) {
37-
const typename vector<bool, _Allocator>::size_type __sz = __x.size();
38-
if (__sz != __y.size())
39-
return false;
40-
for (typename vector<bool, _Allocator>::size_type __i = 0; __i < __sz; ++__i)
41-
if (__x[__i] != __y[__i])
42-
return false;
43-
return true;
44-
}
45-
4632
#if _LIBCPP_STD_VER <= 17
4733

4834
template <class _Tp, class _Allocator>

libcxx/test/std/algorithms/alg.nonmodifying/alg.equal/equal.pass.cpp

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,14 @@
2424
// MSVC warning C4244: 'argument': conversion from 'wchar_t' to 'const _Ty', possible loss of data
2525
// MSVC warning C4389: '==': signed/unsigned mismatch
2626
// ADDITIONAL_COMPILE_FLAGS(cl-style-warnings): /wd4242 /wd4244 /wd4389
27+
// XFAIL: FROZEN-CXX03-HEADERS-FIXME
2728

2829
#include <algorithm>
2930
#include <cassert>
3031
#include <functional>
3132
#include <vector>
3233

34+
#include "sized_allocator.h"
3335
#include "test_iterators.h"
3436
#include "test_macros.h"
3537
#include "type_algorithms.h"
@@ -173,6 +175,90 @@ TEST_CONSTEXPR_CXX20 bool test() {
173175
test_vector_bool<256>();
174176
}
175177

178+
// Make sure std::equal behaves properly with std::vector<bool> iterators with custom size types.
179+
// See issue: https://github.com/llvm/llvm-project/issues/126369.
180+
{
181+
//// Tests for std::equal with aligned bits
182+
183+
{ // Test the first (partial) word for uint8_t
184+
using Alloc = sized_allocator<bool, std::uint8_t, std::int8_t>;
185+
std::vector<bool, Alloc> in(6, true, Alloc(1));
186+
std::vector<bool, Alloc> expected(8, true, Alloc(1));
187+
assert(std::equal(in.begin() + 4, in.end(), expected.begin() + 4));
188+
}
189+
{ // Test the last word for uint8_t
190+
using Alloc = sized_allocator<bool, std::uint8_t, std::int8_t>;
191+
std::vector<bool, Alloc> in(12, true, Alloc(1));
192+
std::vector<bool, Alloc> expected(16, true, Alloc(1));
193+
assert(std::equal(in.begin(), in.end(), expected.begin()));
194+
}
195+
{ // Test middle words for uint8_t
196+
using Alloc = sized_allocator<bool, std::uint8_t, std::int8_t>;
197+
std::vector<bool, Alloc> in(24, true, Alloc(1));
198+
std::vector<bool, Alloc> expected(29, true, Alloc(1));
199+
assert(std::equal(in.begin(), in.end(), expected.begin()));
200+
}
201+
202+
{ // Test the first (partial) word for uint16_t
203+
using Alloc = sized_allocator<bool, std::uint16_t, std::int16_t>;
204+
std::vector<bool, Alloc> in(12, true, Alloc(1));
205+
std::vector<bool, Alloc> expected(16, true, Alloc(1));
206+
assert(std::equal(in.begin() + 4, in.end(), expected.begin() + 4));
207+
}
208+
{ // Test the last word for uint16_t
209+
using Alloc = sized_allocator<bool, std::uint16_t, std::int16_t>;
210+
std::vector<bool, Alloc> in(24, true, Alloc(1));
211+
std::vector<bool, Alloc> expected(32, true, Alloc(1));
212+
assert(std::equal(in.begin(), in.end(), expected.begin()));
213+
}
214+
{ // Test middle words for uint16_t
215+
using Alloc = sized_allocator<bool, std::uint16_t, std::int16_t>;
216+
std::vector<bool, Alloc> in(48, true, Alloc(1));
217+
std::vector<bool, Alloc> expected(55, true, Alloc(1));
218+
assert(std::equal(in.begin(), in.end(), expected.begin()));
219+
}
220+
221+
//// Tests for std::equal with unaligned bits
222+
223+
{ // Test the first (partial) word for uint8_t
224+
using Alloc = sized_allocator<bool, std::uint8_t, std::int8_t>;
225+
std::vector<bool, Alloc> in(6, true, Alloc(1));
226+
std::vector<bool, Alloc> expected(8, true, Alloc(1));
227+
assert(std::equal(in.begin() + 4, in.end(), expected.begin()));
228+
}
229+
{ // Test the last word for uint8_t
230+
using Alloc = sized_allocator<bool, std::uint8_t, std::int8_t>;
231+
std::vector<bool, Alloc> in(4, true, Alloc(1));
232+
std::vector<bool, Alloc> expected(8, true, Alloc(1));
233+
assert(std::equal(in.begin(), in.end(), expected.begin() + 3));
234+
}
235+
{ // Test middle words for uint8_t
236+
using Alloc = sized_allocator<bool, std::uint8_t, std::int8_t>;
237+
std::vector<bool, Alloc> in(16, true, Alloc(1));
238+
std::vector<bool, Alloc> expected(24, true, Alloc(1));
239+
assert(std::equal(in.begin(), in.end(), expected.begin() + 4));
240+
}
241+
242+
{ // Test the first (partial) word for uint16_t
243+
using Alloc = sized_allocator<bool, std::uint16_t, std::int16_t>;
244+
std::vector<bool, Alloc> in(12, true, Alloc(1));
245+
std::vector<bool, Alloc> expected(16, true, Alloc(1));
246+
assert(std::equal(in.begin() + 4, in.end(), expected.begin()));
247+
}
248+
{ // Test the last word for uint16_t
249+
using Alloc = sized_allocator<bool, std::uint16_t, std::int16_t>;
250+
std::vector<bool, Alloc> in(12, true, Alloc(1));
251+
std::vector<bool, Alloc> expected(16, true, Alloc(1));
252+
assert(std::equal(in.begin(), in.end(), expected.begin() + 3));
253+
}
254+
{ // Test the middle words for uint16_t
255+
using Alloc = sized_allocator<bool, std::uint16_t, std::int16_t>;
256+
std::vector<bool, Alloc> in(32, true, Alloc(1));
257+
std::vector<bool, Alloc> expected(64, true, Alloc(1));
258+
assert(std::equal(in.begin(), in.end(), expected.begin() + 4));
259+
}
260+
}
261+
176262
return true;
177263
}
178264

libcxx/test/std/algorithms/alg.nonmodifying/alg.equal/ranges.equal.pass.cpp

Lines changed: 117 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
#include <vector>
3232

3333
#include "almost_satisfies_types.h"
34+
#include "sized_allocator.h"
3435
#include "test_iterators.h"
3536
#include "test_macros.h"
3637

@@ -432,15 +433,123 @@ constexpr bool test() {
432433
assert(projCount == 6);
433434
}
434435
}
436+
}
437+
438+
{ // Test vector<bool>::iterator optimization
439+
test_vector_bool<8>();
440+
test_vector_bool<19>();
441+
test_vector_bool<32>();
442+
test_vector_bool<49>();
443+
test_vector_bool<64>();
444+
test_vector_bool<199>();
445+
test_vector_bool<256>();
446+
}
447+
448+
// Make sure std::ranges::equal behaves properly with std::vector<bool> iterators with custom size types.
449+
// See issue: https://github.com/llvm/llvm-project/issues/126369.
450+
{
451+
//// Tests for std::ranges::equal with aligned bits
452+
453+
{ // Test the first (partial) word for uint8_t
454+
using Alloc = sized_allocator<bool, std::uint8_t, std::int8_t>;
455+
std::vector<bool, Alloc> in(6, true, Alloc(1));
456+
std::vector<bool, Alloc> expected(8, true, Alloc(1));
457+
auto a = std::ranges::subrange(in.begin() + 4, in.end());
458+
auto b = std::ranges::subrange(expected.begin() + 4, expected.begin() + 4 + a.size());
459+
assert(std::ranges::equal(a, b));
460+
}
461+
{ // Test the last word for uint8_t
462+
using Alloc = sized_allocator<bool, std::uint8_t, std::int8_t>;
463+
std::vector<bool, Alloc> in(12, true, Alloc(1));
464+
std::vector<bool, Alloc> expected(16, true, Alloc(1));
465+
auto a = std::ranges::subrange(in.begin(), in.end());
466+
auto b = std::ranges::subrange(expected.begin(), expected.begin() + a.size());
467+
assert(std::ranges::equal(a, b));
468+
}
469+
{ // Test middle words for uint8_t
470+
using Alloc = sized_allocator<bool, std::uint8_t, std::int8_t>;
471+
std::vector<bool, Alloc> in(24, true, Alloc(1));
472+
std::vector<bool, Alloc> expected(29, true, Alloc(1));
473+
auto a = std::ranges::subrange(in.begin(), in.end());
474+
auto b = std::ranges::subrange(expected.begin(), expected.begin() + a.size());
475+
assert(std::ranges::equal(a, b));
476+
}
435477

436-
{ // Test vector<bool>::iterator optimization
437-
test_vector_bool<8>();
438-
test_vector_bool<19>();
439-
test_vector_bool<32>();
440-
test_vector_bool<49>();
441-
test_vector_bool<64>();
442-
test_vector_bool<199>();
443-
test_vector_bool<256>();
478+
{ // Test the first (partial) word for uint16_t
479+
using Alloc = sized_allocator<bool, std::uint16_t, std::int16_t>;
480+
std::vector<bool, Alloc> in(12, true, Alloc(1));
481+
std::vector<bool, Alloc> expected(16, true, Alloc(1));
482+
auto a = std::ranges::subrange(in.begin() + 4, in.end());
483+
auto b = std::ranges::subrange(expected.begin() + 4, expected.begin() + 4 + a.size());
484+
assert(std::ranges::equal(a, b));
485+
}
486+
{ // Test the last word for uint16_t
487+
using Alloc = sized_allocator<bool, std::uint16_t, std::int16_t>;
488+
std::vector<bool, Alloc> in(24, true, Alloc(1));
489+
std::vector<bool, Alloc> expected(32, true, Alloc(1));
490+
auto a = std::ranges::subrange(in.begin(), in.end());
491+
auto b = std::ranges::subrange(expected.begin(), expected.begin() + a.size());
492+
assert(std::ranges::equal(a, b));
493+
}
494+
{ // Test middle words for uint16_t
495+
using Alloc = sized_allocator<bool, std::uint16_t, std::int16_t>;
496+
std::vector<bool, Alloc> in(48, true, Alloc(1));
497+
std::vector<bool, Alloc> expected(55, true, Alloc(1));
498+
auto a = std::ranges::subrange(in.begin(), in.end());
499+
auto b = std::ranges::subrange(expected.begin(), expected.begin() + a.size());
500+
assert(std::ranges::equal(a, b));
501+
}
502+
503+
//// Tests for std::ranges::equal with unaligned bits
504+
505+
{ // Test the first (partial) word for uint8_t
506+
using Alloc = sized_allocator<bool, std::uint8_t, std::int8_t>;
507+
std::vector<bool, Alloc> in(6, true, Alloc(1));
508+
std::vector<bool, Alloc> expected(8, true, Alloc(1));
509+
auto a = std::ranges::subrange(in.begin() + 4, in.end());
510+
auto b = std::ranges::subrange(expected.begin(), expected.begin() + a.size());
511+
assert(std::ranges::equal(a, b));
512+
}
513+
{ // Test the last word for uint8_t
514+
using Alloc = sized_allocator<bool, std::uint8_t, std::int8_t>;
515+
std::vector<bool, Alloc> in(4, true, Alloc(1));
516+
std::vector<bool, Alloc> expected(8, true, Alloc(1));
517+
auto a = std::ranges::subrange(in.begin(), in.end());
518+
auto b = std::ranges::subrange(expected.begin() + 3, expected.begin() + 3 + a.size());
519+
assert(std::ranges::equal(a, b));
520+
}
521+
{ // Test middle words for uint8_t
522+
using Alloc = sized_allocator<bool, std::uint8_t, std::int8_t>;
523+
std::vector<bool, Alloc> in(16, true, Alloc(1));
524+
std::vector<bool, Alloc> expected(24, true, Alloc(1));
525+
auto a = std::ranges::subrange(in.begin(), in.end());
526+
auto b = std::ranges::subrange(expected.begin() + 4, expected.begin() + 4 + a.size());
527+
assert(std::ranges::equal(a, b));
528+
}
529+
530+
{ // Test the first (partial) word for uint16_t
531+
using Alloc = sized_allocator<bool, std::uint16_t, std::int16_t>;
532+
std::vector<bool, Alloc> in(12, true, Alloc(1));
533+
std::vector<bool, Alloc> expected(16, true, Alloc(1));
534+
auto a = std::ranges::subrange(in.begin() + 4, in.end());
535+
auto b = std::ranges::subrange(expected.begin(), expected.begin() + a.size());
536+
assert(std::ranges::equal(a, b));
537+
}
538+
{ // Test the last word for uint16_t
539+
using Alloc = sized_allocator<bool, std::uint16_t, std::int16_t>;
540+
std::vector<bool, Alloc> in(12, true, Alloc(1));
541+
std::vector<bool, Alloc> expected(16, true, Alloc(1));
542+
auto a = std::ranges::subrange(in.begin(), in.end());
543+
auto b = std::ranges::subrange(expected.begin() + 3, expected.begin() + 3 + a.size());
544+
assert(std::ranges::equal(a, b));
545+
}
546+
{ // Test the middle words for uint16_t
547+
using Alloc = sized_allocator<bool, std::uint16_t, std::int16_t>;
548+
std::vector<bool, Alloc> in(32, true, Alloc(1));
549+
std::vector<bool, Alloc> expected(64, true, Alloc(1));
550+
auto a = std::ranges::subrange(in.begin(), in.end());
551+
auto b = std::ranges::subrange(expected.begin() + 4, expected.begin() + 4 + a.size());
552+
assert(std::ranges::equal(a, b));
444553
}
445554
}
446555

0 commit comments

Comments
 (0)