Skip to content

Stop peeling the last iteration of the loop in Vec::resize_with #104818

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Nov 27, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions library/alloc/src/vec/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2874,13 +2874,12 @@ impl<T, A: Allocator> Vec<T, A> {
);
self.reserve(additional);
unsafe {
let mut ptr = self.as_mut_ptr().add(self.len());
let ptr = self.as_mut_ptr();
let mut local_len = SetLenOnDrop::new(&mut self.len);
iterator.for_each(move |element| {
ptr::write(ptr, element);
ptr = ptr.add(1);
// Since the loop executes user code which can panic we have to bump the pointer
// after each step.
ptr::write(ptr.add(local_len.current_len()), element);
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks pointless, but seems like it reduces register pressure since we have to maintain the local_len for panic safety anyway.

Before:

.LBB4_3:
	mov	qword ptr [rbp - 16], rax
	mov	qword ptr [rbp - 24], rcx    ; Note Spill
	call	make_thing
	mov	rcx, qword ptr [rbp - 24]
	mov	dword ptr [rcx], eax
	add	rcx, 4
	mov	rax, qword ptr [rbp - 16]
	dec	rax
	cmp	rsi, rax
	jne	.LBB4_3

After:

.LBB6_3:
	mov	qword ptr [rbp - 16], rax
	call	make_thing
	mov	rdx, qword ptr [rbp - 16]
	lea	rcx, [rdx + 1]
	mov	dword ptr [rbx + 4*rdx], eax    ; Note addressing mode
	mov	rax, rcx
	dec	rsi
	jne	.LBB6_3

(And LLVM understands that kind of loop very well too, since it's a for i in A..B { v[i] = foo(); } loop that's common all over the place.)

// Since the loop executes user code which can panic we have to update
// the length every step to correctly drop what we've written.
// NB can't overflow since we would have had to alloc the address space
local_len.increment_len(1);
});
Expand Down
5 changes: 5 additions & 0 deletions library/alloc/src/vec/set_len_on_drop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@ impl<'a> SetLenOnDrop<'a> {
/// Adds `increment` to the locally tracked length (`local_len`).
pub(super) fn increment_len(&mut self, increment: usize) {
self.local_len += increment;
}

/// Returns the current value of the locally tracked length (`local_len`).
#[inline]
pub(super) fn current_len(&self) -> usize {
self.local_len
}
}

impl Drop for SetLenOnDrop<'_> {
Expand Down
21 changes: 20 additions & 1 deletion library/core/src/iter/adapters/take.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ where
#[inline]
fn try_fold<Acc, Fold, R>(&mut self, init: Acc, fold: Fold) -> R
where
Self: Sized,
Fold: FnMut(Acc, Self::Item) -> R,
R: Try<Output = Acc>,
{
Expand All @@ -100,6 +99,26 @@ where

impl_fold_via_try_fold! { fold -> try_fold }

#[inline]
fn for_each<F: FnMut(Self::Item)>(mut self, f: F) {
// The default implementation would use a unit accumulator, so we can
// avoid a stateful closure by folding over the remaining number
// of items we wish to return instead.
/// Wraps `callback` into a fold step whose accumulator is the number of
/// items still allowed after the current one.
///
/// Each invocation hands the item to `callback` and then decrements the
/// budget. Once the budget is already `0` the step yields `None`, which
/// makes a surrounding `try_fold` stop.
fn check<'a, Item>(
    mut callback: impl FnMut(Item) + 'a,
) -> impl FnMut(usize, Item) -> Option<usize> + 'a {
    move |budget, item| {
        callback(item);
        budget.checked_sub(1)
    }
}
Comment on lines +107 to +114
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A pity that the duplication-by-unused-generics issues still haven't been fixed 😮‍💨


let remaining = self.n;
if remaining > 0 {
self.iter.try_fold(remaining - 1, check(f));
}
}

#[inline]
#[rustc_inherit_overflow_checks]
fn advance_by(&mut self, n: usize) -> Result<(), usize> {
Expand Down
17 changes: 17 additions & 0 deletions library/core/src/iter/sources/repeat_with.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::iter::{FusedIterator, TrustedLen};
use crate::ops::Try;

/// Creates a new iterator that repeats elements of type `A` endlessly by
/// applying the provided closure, the repeater, `F: FnMut() -> A`.
Expand Down Expand Up @@ -89,6 +90,22 @@ impl<A, F: FnMut() -> A> Iterator for RepeatWith<F> {
fn size_hint(&self) -> (usize, Option<usize>) {
(usize::MAX, None)
}

/// Repeatedly calls the repeater and feeds each produced value, together
/// with the running accumulator, into `fold`.
///
/// The loop has no natural end: it only returns when `fold` yields a
/// short-circuiting value that `?` propagates to the caller.
#[inline]
fn try_fold<Acc, Fold, R>(&mut self, mut init: Acc, mut fold: Fold) -> R
where
    Fold: FnMut(Acc, Self::Item) -> R,
    R: Try<Output = Acc>,
{
    // Overriding the default is an optimization aid rather than a necessity:
    // it spares the optimizer from proving that `next` never returns `None`,
    // and it makes it obvious that `?` is the sole exit from the loop.
    loop {
        init = fold(init, (self.repeater)())?;
    }
}
}

#[stable(feature = "iterator_repeat_with", since = "1.28.0")]
Expand Down
20 changes: 20 additions & 0 deletions library/core/tests/iter/adapters/take.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,3 +146,23 @@ fn test_take_try_folds() {
assert_eq!(iter.try_for_each(Err), Err(2));
assert_eq!(iter.try_for_each(Err), Ok(()));
}

/// Checks that `take(n).for_each(..)` on a `by_ref()` iterator consumes
/// exactly the items it yields from the underlying iterator and no more.
#[test]
fn test_byref_take_consumed_items() {
    let mut source = 10..90;

    // (items requested, items expected to be visited, range left afterwards)
    let cases = [(0, 0, 10..90), (10, 10, 20..90), (100, 70, 90..90)];

    for (requested, visited, leftover) in cases {
        let mut seen = 0;
        source.by_ref().take(requested).for_each(|_| seen += 1);
        assert_eq!(seen, visited);
        assert_eq!(source, leftover);
    }
}
7 changes: 7 additions & 0 deletions src/test/codegen/repeat-trusted-len.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,10 @@ pub fn repeat_take_collect() -> Vec<u8> {
// CHECK: call void @llvm.memset.{{.+}}({{i8\*|ptr}} {{.*}}align 1{{.*}} %{{[0-9]+}}, i8 42, i{{[0-9]+}} 100000, i1 false)
iter::repeat(42).take(100000).collect()
}

// Codegen test: the closure always returns 13, so collecting 12345 of its
// results into a `Vec<u8>` is expected to lower to a single `llvm.memset`
// call, as asserted by the FileCheck directive inside the function body.
// CHECK-LABEL: @repeat_with_take_collect
#[no_mangle]
pub fn repeat_with_take_collect() -> Vec<u8> {
// CHECK: call void @llvm.memset.{{.+}}({{i8\*|ptr}} {{.*}}align 1{{.*}} %{{[0-9]+}}, i8 13, i{{[0-9]+}} 12345, i1 false)
iter::repeat_with(|| 13).take(12345).collect()
}