Skip to content

[libc++] std::ranges::advance: avoid unneeded bounds checks when advancing iterator #84126

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Apr 2, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,12 @@
#include "../types.h"

template <bool Count, typename It>
constexpr void check_forward(int* first, int* last, std::iter_difference_t<It> n, int* expected) {
constexpr void check_forward(
int* first, int* last, std::iter_difference_t<It> n, int* expected, std::ptrdiff_t expected_equals_count) {
using Difference = std::iter_difference_t<It>;
Difference const M = (expected - first); // expected travel distance
// `expected_equals_count` is only relevant when `Count` is true.
assert(Count || expected_equals_count == -1);

{
It it(first);
Expand All @@ -42,18 +45,7 @@ constexpr void check_forward(int* first, int* last, std::iter_difference_t<It> n
// regardless of the iterator category.
assert(it.stride_count() == M);
assert(it.stride_displacement() == M);
if (n == 0) {
assert(it.equals_count() == 0);
} else {
if (n > M) {
// We "hit" the bound, so there is one extra equality check.
assert(it.equals_count() == M + 1);
} else {
assert(it.equals_count() == M);
}
// In any case, there must not be more than `n` bounds checks.
assert(it.equals_count() <= n);
}
assert(it.equals_count() == expected_equals_count);
}
}

Expand Down Expand Up @@ -86,10 +78,17 @@ constexpr void check_forward_sized_sentinel(int* first, int* last, std::iter_dif
}
}

template <typename It>
constexpr void check_backward(int* first, int* last, std::iter_difference_t<It> n, int* expected) {
// Check preconditions for `advance` when called with negative `n`:
// <https://eel.is/c++draft/iterators#range.iter.op.advance-5>
template <bool Count, typename It>
constexpr void check_backward(
int* first,
int* last,
std::iter_difference_t<It> n,
int* expected,
std::ptrdiff_t expected_stride_count,
std::ptrdiff_t expected_stride_displacement,
std::ptrdiff_t expected_equals_count) {
// Check preconditions for `advance` when called with negative `n`
// (see [range.iter.op.advance]).
assert(n < 0);
static_assert(std::bidirectional_iterator<It>);

Expand All @@ -109,33 +108,13 @@ constexpr void check_backward(int* first, int* last, std::iter_difference_t<It>
auto it = stride_counting_iterator(It(last));
auto sent = stride_counting_iterator(It(first));
static_assert(std::bidirectional_iterator<stride_counting_iterator<It>>);
static_assert(Count == !std::sized_sentinel_for<It, It>);

(void)std::ranges::advance(it, n, sent);

if constexpr (std::sized_sentinel_for<It, It>) {
if (expected == first) {
// In this case, the algorithm can just do `it = std::move(sent);`
// instead of doing iterator arithmetic:
// <https://eel.is/c++draft/iterators#range.iter.op.advance-4.1>
assert(it.stride_count() == 0);
assert(it.stride_displacement() == 0);
} else {
assert(it.stride_count() == 1);
assert(it.stride_displacement() == 1);
}
assert(it.equals_count() == 0);
} else {
assert(it.stride_count() == -M);
assert(it.stride_displacement() == M);
if (-n > -M) {
// We "hit" the bound, so there is one extra equality check.
assert(it.equals_count() == -M + 1);
} else {
assert(it.equals_count() == -M);
}
// In any case, there must not be more than `-n` bounds checks.
assert(it.equals_count() <= -n);
}
assert(it.stride_count() == expected_stride_count);
assert(it.stride_displacement() == expected_stride_displacement);
assert(it.equals_count() == expected_equals_count);
}
}

Expand Down Expand Up @@ -212,13 +191,17 @@ constexpr bool test() {

{
int* expected = n > size ? range + size : range + n;
check_forward<false, cpp17_input_iterator<int*>>( range, range+size, n, expected);
check_forward<false, cpp20_input_iterator<int*>>( range, range+size, n, expected);
check_forward<true, forward_iterator<int*>>( range, range+size, n, expected);
check_forward<true, bidirectional_iterator<int*>>(range, range+size, n, expected);
check_forward<true, random_access_iterator<int*>>(range, range+size, n, expected);
check_forward<true, contiguous_iterator<int*>>( range, range+size, n, expected);
check_forward<true, int*>( range, range+size, n, expected);
std::ptrdiff_t equals_count = n > size ? size + 1 : n;

// clang-format off
check_forward<false, cpp17_input_iterator<int*>>( range, range+size, n, expected, -1);
check_forward<false, cpp20_input_iterator<int*>>( range, range+size, n, expected, -1);
check_forward<true, forward_iterator<int*>>( range, range+size, n, expected, equals_count);
check_forward<true, bidirectional_iterator<int*>>(range, range+size, n, expected, equals_count);
check_forward<true, random_access_iterator<int*>>(range, range+size, n, expected, equals_count);
check_forward<true, contiguous_iterator<int*>>( range, range+size, n, expected, equals_count);
check_forward<true, int*>( range, range+size, n, expected, equals_count);
// clang-format on

check_forward_sized_sentinel<cpp17_input_iterator<int*>>( range, range+size, n, expected);
check_forward_sized_sentinel<cpp20_input_iterator<int*>>( range, range+size, n, expected);
Expand All @@ -229,15 +212,34 @@ constexpr bool test() {
check_forward_sized_sentinel<int*>( range, range+size, n, expected);
}

// Exclude the `n == 0` case for the backwards checks.
// Exclude the `n == 0` case for the backwards checks (this is tested by
// the forward tests above).
// Input and forward iterators are not tested as the backwards case does
// not apply for them.
if (n > 0) {
int* expected = n > size ? range : range + size - n;
check_backward<bidirectional_iterator<int*>>(range, range + size, -n, expected);
check_backward<random_access_iterator<int*>>(range, range+size, -n, expected);
check_backward<contiguous_iterator<int*>>( range, range+size, -n, expected);
check_backward<int*>( range, range+size, -n, expected);
{
std::ptrdiff_t stride_count = range + size - expected;
std::ptrdiff_t stride_displacement = -stride_count;
std::ptrdiff_t equals_count = n > size ? size + 1 : n;

check_backward<true, bidirectional_iterator<int*>>(
range, range + size, -n, expected, stride_count, stride_displacement, equals_count);
}
{
// If `n >= size`, the algorithm can just do `it = std::move(sent);`
// instead of doing iterator arithmetic.
std::ptrdiff_t stride_count = (n >= size) ? 0 : 1;
std::ptrdiff_t stride_displacement = (n >= size) ? 0 : 1;
std::ptrdiff_t equals_count = 0;

check_backward<false, random_access_iterator<int*>>(
range, range + size, -n, expected, stride_count, stride_displacement, equals_count);
check_backward<false, contiguous_iterator<int*>>(
range, range + size, -n, expected, stride_count, stride_displacement, equals_count);
check_backward<false, int*>(
range, range + size, -n, expected, stride_count, stride_displacement, equals_count);
}
}
}
}
Expand Down