Skip to content

[libc++] std::ranges::advance: avoid unneeded bounds checks when advancing iterator #84126

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Apr 2, 2024
4 changes: 2 additions & 2 deletions libcxx/include/__iterator/advance.h
Original file line number Diff line number Diff line change
Expand Up @@ -170,14 +170,14 @@ struct __fn {
} else {
// Otherwise, if `n` is non-negative, while `bool(i != bound_sentinel)` is true, increments `i` but at
// most `n` times.
while (__i != __bound_sentinel && __n > 0) {
while (__n > 0 && __i != __bound_sentinel) {
++__i;
--__n;
}

// Otherwise, while `bool(i != bound_sentinel)` is true, decrements `i` but at most `-n` times.
if constexpr (bidirectional_iterator<_Ip> && same_as<_Ip, _Sp>) {
while (__i != __bound_sentinel && __n < 0) {
while (__n < 0 && __i != __bound_sentinel) {
--__i;
++__n;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,12 @@
#include "../types.h"

template <bool Count, typename It>
constexpr void check_forward(int* first, int* last, std::iter_difference_t<It> n, int* expected) {
constexpr void
check_forward(int* first, int* last, std::iter_difference_t<It> n, int* expected, int expected_equals_count = -1) {
using Difference = std::iter_difference_t<It>;
Difference const M = (expected - first); // expected travel distance
// `expected_equals_count` is only relevant when `Count` is true.
assert(Count || expected_equals_count == -1);

{
It it(first);
Expand All @@ -42,6 +45,7 @@ constexpr void check_forward(int* first, int* last, std::iter_difference_t<It> n
// regardless of the iterator category.
assert(it.stride_count() == M);
assert(it.stride_displacement() == M);
assert(it.equals_count() == expected_equals_count);
}
}

Expand Down Expand Up @@ -74,9 +78,20 @@ constexpr void check_forward_sized_sentinel(int* first, int* last, std::iter_dif
}
}

template <typename It>
constexpr void check_backward(int* first, int* last, std::iter_difference_t<It> n, int* expected) {
static_assert(std::random_access_iterator<It>, "This test doesn't support non random access iterators");
struct Expected {
int stride_count;
int stride_displacement;
int equals_count;
};

template <bool Count, typename It>
constexpr void
check_backward(int* first, int* last, std::iter_difference_t<It> n, int* expected, Expected expected_counts) {
// Check preconditions for `advance` when called with negative `n`
// (see [range.iter.op.advance]). In addition, allow `n == 0`.
assert(n <= 0);
static_assert(std::bidirectional_iterator<It>);

using Difference = std::iter_difference_t<It>;
Difference const M = (expected - last); // expected travel distance (which is negative)

Expand All @@ -92,9 +107,14 @@ constexpr void check_backward(int* first, int* last, std::iter_difference_t<It>
{
auto it = stride_counting_iterator(It(last));
auto sent = stride_counting_iterator(It(first));
static_assert(std::bidirectional_iterator<stride_counting_iterator<It>>);
static_assert(Count == !std::sized_sentinel_for<It, It>);

(void)std::ranges::advance(it, n, sent);
assert(it.stride_count() <= 1);
assert(it.stride_displacement() <= 1);

assert(it.stride_count() == expected_counts.stride_count);
assert(it.stride_displacement() == expected_counts.stride_displacement);
assert(it.equals_count() == expected_counts.equals_count);
}
}

Expand Down Expand Up @@ -171,13 +191,17 @@ constexpr bool test() {

{
int* expected = n > size ? range + size : range + n;
int equals_count = n > size ? size + 1 : n;

// clang-format off
check_forward<false, cpp17_input_iterator<int*>>( range, range+size, n, expected);
check_forward<false, cpp20_input_iterator<int*>>( range, range+size, n, expected);
check_forward<true, forward_iterator<int*>>( range, range+size, n, expected);
check_forward<true, bidirectional_iterator<int*>>(range, range+size, n, expected);
check_forward<true, random_access_iterator<int*>>(range, range+size, n, expected);
check_forward<true, contiguous_iterator<int*>>( range, range+size, n, expected);
check_forward<true, int*>( range, range+size, n, expected);
check_forward<true, forward_iterator<int*>>( range, range+size, n, expected, equals_count);
check_forward<true, bidirectional_iterator<int*>>(range, range+size, n, expected, equals_count);
check_forward<true, random_access_iterator<int*>>(range, range+size, n, expected, equals_count);
check_forward<true, contiguous_iterator<int*>>( range, range+size, n, expected, equals_count);
check_forward<true, int*>( range, range+size, n, expected, equals_count);
// clang-format on

check_forward_sized_sentinel<cpp17_input_iterator<int*>>( range, range+size, n, expected);
check_forward_sized_sentinel<cpp20_input_iterator<int*>>( range, range+size, n, expected);
Expand All @@ -188,14 +212,32 @@ constexpr bool test() {
check_forward_sized_sentinel<int*>( range, range+size, n, expected);
}

// Input and forward iterators are not tested as the backwards case does
// not apply for them.
{
// Note that we can only test ranges::advance with a negative n for iterators that
// are sized sentinels for themselves, because ranges::advance is UB otherwise.
// In particular, that excludes bidirectional_iterators since those are not sized sentinels.
int* expected = n > size ? range : range + size - n;
check_backward<random_access_iterator<int*>>(range, range+size, -n, expected);
check_backward<contiguous_iterator<int*>>( range, range+size, -n, expected);
check_backward<int*>( range, range+size, -n, expected);
{
Expected expected_counts = {
.stride_count = static_cast<int>(range + size - expected),
.stride_displacement = -expected_counts.stride_count,
.equals_count = n > size ? size + 1 : n,
};

check_backward<true, bidirectional_iterator<int*>>(range, range + size, -n, expected, expected_counts);
}
{
Expected expected_counts = {
// If `n >= size`, the algorithm can just do `it = std::move(sent);`
// instead of doing iterator arithmetic.
.stride_count = (n >= size) ? 0 : 1,
.stride_displacement = (n >= size) ? 0 : 1,
.equals_count = 0,
};

check_backward<false, random_access_iterator<int*>>(range, range + size, -n, expected, expected_counts);
check_backward<false, contiguous_iterator<int*>>(range, range + size, -n, expected, expected_counts);
check_backward<false, int*>(range, range + size, -n, expected, expected_counts);
}
}
}
}
Expand Down
22 changes: 19 additions & 3 deletions libcxx/test/support/test_iterators.h
Original file line number Diff line number Diff line change
Expand Up @@ -725,11 +725,14 @@ struct common_input_iterator {
# endif // TEST_STD_VER >= 20

// Iterator adaptor that counts the number of times the iterator has had a successor/predecessor
// operation called. Has two recorders:
// operation or an equality comparison operation called. Has three recorders:
// * `stride_count`, which records the total number of calls to an op++, op--, op+=, or op-=.
// * `stride_displacement`, which records the displacement of the calls. This means that both
// op++/op+= will increase the displacement counter by 1, and op--/op-= will decrease the
// displacement counter by 1.
// * `equals_count`, which records the total number of calls to an op== or op!=. If compared
// against a sentinel object, that sentinel object must call the `record_equality_comparison`
// function so that the comparison is counted correctly.
template <class It>
class stride_counting_iterator {
public:
Expand All @@ -754,6 +757,8 @@ class stride_counting_iterator {

constexpr difference_type stride_displacement() const { return stride_displacement_; }

constexpr difference_type equals_count() const { return equals_count_; }

constexpr decltype(auto) operator*() const { return *It(base_); }

constexpr decltype(auto) operator[](difference_type n) const { return It(base_)[n]; }
Expand Down Expand Up @@ -838,10 +843,13 @@ class stride_counting_iterator {
return base(x) - base(y);
}

constexpr void record_equality_comparison() const { ++equals_count_; }

constexpr bool operator==(stride_counting_iterator const& other) const
requires std::sentinel_for<It, It>
{
return It(base_) == It(other.base_);
record_equality_comparison();
return It(base_) == It(other.base_);
}

friend constexpr bool operator<(stride_counting_iterator const& x, stride_counting_iterator const& y)
Expand Down Expand Up @@ -875,6 +883,7 @@ class stride_counting_iterator {
decltype(base(std::declval<It>())) base_;
difference_type stride_count_ = 0;
difference_type stride_displacement_ = 0;
mutable difference_type equals_count_ = 0;
};
template <class It>
stride_counting_iterator(It) -> stride_counting_iterator<It>;
Expand All @@ -887,7 +896,14 @@ class sentinel_wrapper {
public:
explicit sentinel_wrapper() = default;
constexpr explicit sentinel_wrapper(const It& it) : base_(base(it)) {}
constexpr bool operator==(const It& other) const { return base_ == base(other); }
constexpr bool operator==(const It& other) const {
// If supported, record statistics about the equality operator call
// inside `other`.
if constexpr (requires { other.record_equality_comparison(); }) {
other.record_equality_comparison();
}
return base_ == base(other);
}
friend constexpr It base(const sentinel_wrapper& s) { return It(s.base_); }
private:
decltype(base(std::declval<It>())) base_;
Expand Down