Skip to content

Commit 84ae8cb

Browse files
authored
[libc++] std::ranges::advance: avoid unneeded bounds checks when advancing iterator (#84126)
Currently, the bounds check in `std::ranges::advance(it, n, s)` is done _before_ `n` is checked. This results in one extra, unneeded bounds check. Thus, `std::ranges::advance(it, 1, s)` currently is _not_ simply equivalent to: ```c++ if (it != s) { ++it; } ``` This difference in behavior matters when the check involves some "expensive" logic. For example, the `==` operator of `std::istreambuf_iterator` may actually have to read the underlying `streambuf`. Swapping around the checks in the `while` results in the expected behavior.
1 parent 12c7371 commit 84ae8cb

File tree

3 files changed

+80
-22
lines changed

3 files changed

+80
-22
lines changed

libcxx/include/__iterator/advance.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,14 +170,14 @@ struct __fn {
170170
} else {
171171
// Otherwise, if `n` is non-negative, while `bool(i != bound_sentinel)` is true, increments `i` but at
172172
// most `n` times.
173-
while (__i != __bound_sentinel && __n > 0) {
173+
while (__n > 0 && __i != __bound_sentinel) {
174174
++__i;
175175
--__n;
176176
}
177177

178178
// Otherwise, while `bool(i != bound_sentinel)` is true, decrements `i` but at most `-n` times.
179179
if constexpr (bidirectional_iterator<_Ip> && same_as<_Ip, _Sp>) {
180-
while (__i != __bound_sentinel && __n < 0) {
180+
while (__n < 0 && __i != __bound_sentinel) {
181181
--__i;
182182
++__n;
183183
}

libcxx/test/std/iterators/iterator.primitives/range.iter.ops/range.iter.ops.advance/iterator_count_sentinel.pass.cpp

Lines changed: 59 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,12 @@
2121
#include "../types.h"
2222

2323
template <bool Count, typename It>
24-
constexpr void check_forward(int* first, int* last, std::iter_difference_t<It> n, int* expected) {
24+
constexpr void
25+
check_forward(int* first, int* last, std::iter_difference_t<It> n, int* expected, int expected_equals_count = -1) {
2526
using Difference = std::iter_difference_t<It>;
2627
Difference const M = (expected - first); // expected travel distance
28+
// `expected_equals_count` is only relevant when `Count` is true.
29+
assert(Count || expected_equals_count == -1);
2730

2831
{
2932
It it(first);
@@ -42,6 +45,7 @@ constexpr void check_forward(int* first, int* last, std::iter_difference_t<It> n
4245
// regardless of the iterator category.
4346
assert(it.stride_count() == M);
4447
assert(it.stride_displacement() == M);
48+
assert(it.equals_count() == expected_equals_count);
4549
}
4650
}
4751

@@ -74,9 +78,20 @@ constexpr void check_forward_sized_sentinel(int* first, int* last, std::iter_dif
7478
}
7579
}
7680

77-
template <typename It>
78-
constexpr void check_backward(int* first, int* last, std::iter_difference_t<It> n, int* expected) {
79-
static_assert(std::random_access_iterator<It>, "This test doesn't support non random access iterators");
81+
struct Expected {
82+
int stride_count;
83+
int stride_displacement;
84+
int equals_count;
85+
};
86+
87+
template <bool Count, typename It>
88+
constexpr void
89+
check_backward(int* first, int* last, std::iter_difference_t<It> n, int* expected, Expected expected_counts) {
90+
// Check preconditions for `advance` when called with negative `n`
91+
// (see [range.iter.op.advance]). In addition, allow `n == 0`.
92+
assert(n <= 0);
93+
static_assert(std::bidirectional_iterator<It>);
94+
8095
using Difference = std::iter_difference_t<It>;
8196
Difference const M = (expected - last); // expected travel distance (which is negative)
8297

@@ -92,9 +107,14 @@ constexpr void check_backward(int* first, int* last, std::iter_difference_t<It>
92107
{
93108
auto it = stride_counting_iterator(It(last));
94109
auto sent = stride_counting_iterator(It(first));
110+
static_assert(std::bidirectional_iterator<stride_counting_iterator<It>>);
111+
static_assert(Count == !std::sized_sentinel_for<It, It>);
112+
95113
(void)std::ranges::advance(it, n, sent);
96-
assert(it.stride_count() <= 1);
97-
assert(it.stride_displacement() <= 1);
114+
115+
assert(it.stride_count() == expected_counts.stride_count);
116+
assert(it.stride_displacement() == expected_counts.stride_displacement);
117+
assert(it.equals_count() == expected_counts.equals_count);
98118
}
99119
}
100120

@@ -171,13 +191,17 @@ constexpr bool test() {
171191

172192
{
173193
int* expected = n > size ? range + size : range + n;
194+
int equals_count = n > size ? size + 1 : n;
195+
196+
// clang-format off
174197
check_forward<false, cpp17_input_iterator<int*>>( range, range+size, n, expected);
175198
check_forward<false, cpp20_input_iterator<int*>>( range, range+size, n, expected);
176-
check_forward<true, forward_iterator<int*>>( range, range+size, n, expected);
177-
check_forward<true, bidirectional_iterator<int*>>(range, range+size, n, expected);
178-
check_forward<true, random_access_iterator<int*>>(range, range+size, n, expected);
179-
check_forward<true, contiguous_iterator<int*>>( range, range+size, n, expected);
180-
check_forward<true, int*>( range, range+size, n, expected);
199+
check_forward<true, forward_iterator<int*>>( range, range+size, n, expected, equals_count);
200+
check_forward<true, bidirectional_iterator<int*>>(range, range+size, n, expected, equals_count);
201+
check_forward<true, random_access_iterator<int*>>(range, range+size, n, expected, equals_count);
202+
check_forward<true, contiguous_iterator<int*>>( range, range+size, n, expected, equals_count);
203+
check_forward<true, int*>( range, range+size, n, expected, equals_count);
204+
// clang-format on
181205

182206
check_forward_sized_sentinel<cpp17_input_iterator<int*>>( range, range+size, n, expected);
183207
check_forward_sized_sentinel<cpp20_input_iterator<int*>>( range, range+size, n, expected);
@@ -188,14 +212,32 @@ constexpr bool test() {
188212
check_forward_sized_sentinel<int*>( range, range+size, n, expected);
189213
}
190214

215+
// Input and forward iterators are not tested as the backwards case does
216+
// not apply for them.
191217
{
192-
// Note that we can only test ranges::advance with a negative n for iterators that
193-
// are sized sentinels for themselves, because ranges::advance is UB otherwise.
194-
// In particular, that excludes bidirectional_iterators since those are not sized sentinels.
195218
int* expected = n > size ? range : range + size - n;
196-
check_backward<random_access_iterator<int*>>(range, range+size, -n, expected);
197-
check_backward<contiguous_iterator<int*>>( range, range+size, -n, expected);
198-
check_backward<int*>( range, range+size, -n, expected);
219+
{
220+
Expected expected_counts = {
221+
.stride_count = static_cast<int>(range + size - expected),
222+
.stride_displacement = -expected_counts.stride_count,
223+
.equals_count = n > size ? size + 1 : n,
224+
};
225+
226+
check_backward<true, bidirectional_iterator<int*>>(range, range + size, -n, expected, expected_counts);
227+
}
228+
{
229+
Expected expected_counts = {
230+
// If `n >= size`, the algorithm can just do `it = std::move(sent);`
231+
// instead of doing iterator arithmetic.
232+
.stride_count = (n >= size) ? 0 : 1,
233+
.stride_displacement = (n >= size) ? 0 : 1,
234+
.equals_count = 0,
235+
};
236+
237+
check_backward<false, random_access_iterator<int*>>(range, range + size, -n, expected, expected_counts);
238+
check_backward<false, contiguous_iterator<int*>>(range, range + size, -n, expected, expected_counts);
239+
check_backward<false, int*>(range, range + size, -n, expected, expected_counts);
240+
}
199241
}
200242
}
201243
}

libcxx/test/support/test_iterators.h

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -725,11 +725,14 @@ struct common_input_iterator {
725725
# endif // TEST_STD_VER >= 20
726726

727727
// Iterator adaptor that counts the number of times the iterator has had a successor/predecessor
728-
// operation called. Has two recorders:
728+
// operation or an equality comparison operation called. Has three recorders:
729729
// * `stride_count`, which records the total number of calls to an op++, op--, op+=, or op-=.
730730
// * `stride_displacement`, which records the displacement of the calls. This means that both
731731
// op++/op+= will increase the displacement counter by 1, and op--/op-= will decrease the
732732
// displacement counter by 1.
733+
// * `equals_count`, which records the total number of calls to an op== or op!=. If compared
734+
// against a sentinel object, that sentinel object must call the `record_equality_comparison`
735+
// function so that the comparison is counted correctly.
733736
template <class It>
734737
class stride_counting_iterator {
735738
public:
@@ -754,6 +757,8 @@ class stride_counting_iterator {
754757

755758
constexpr difference_type stride_displacement() const { return stride_displacement_; }
756759

760+
constexpr difference_type equals_count() const { return equals_count_; }
761+
757762
constexpr decltype(auto) operator*() const { return *It(base_); }
758763

759764
constexpr decltype(auto) operator[](difference_type n) const { return It(base_)[n]; }
@@ -838,10 +843,13 @@ class stride_counting_iterator {
838843
return base(x) - base(y);
839844
}
840845

846+
constexpr void record_equality_comparison() const { ++equals_count_; }
847+
841848
constexpr bool operator==(stride_counting_iterator const& other) const
842849
requires std::sentinel_for<It, It>
843850
{
844-
return It(base_) == It(other.base_);
851+
record_equality_comparison();
852+
return It(base_) == It(other.base_);
845853
}
846854

847855
friend constexpr bool operator<(stride_counting_iterator const& x, stride_counting_iterator const& y)
@@ -875,6 +883,7 @@ class stride_counting_iterator {
875883
decltype(base(std::declval<It>())) base_;
876884
difference_type stride_count_ = 0;
877885
difference_type stride_displacement_ = 0;
886+
mutable difference_type equals_count_ = 0;
878887
};
879888
template <class It>
880889
stride_counting_iterator(It) -> stride_counting_iterator<It>;
@@ -887,7 +896,14 @@ class sentinel_wrapper {
887896
public:
888897
explicit sentinel_wrapper() = default;
889898
constexpr explicit sentinel_wrapper(const It& it) : base_(base(it)) {}
890-
constexpr bool operator==(const It& other) const { return base_ == base(other); }
899+
constexpr bool operator==(const It& other) const {
900+
// If supported, record statistics about the equality operator call
901+
// inside `other`.
902+
if constexpr (requires { other.record_equality_comparison(); }) {
903+
other.record_equality_comparison();
904+
}
905+
return base_ == base(other);
906+
}
891907
friend constexpr It base(const sentinel_wrapper& s) { return It(s.base_); }
892908
private:
893909
decltype(base(std::declval<It>())) base_;

0 commit comments

Comments
 (0)