Skip to content

Commit bef24a4

Browse files
authored
Rollup merge of rust-lang#123778 - jhorstmann:optimize-upper-lower-auto-vectorization, r=the8472
Improve autovectorization of to_lowercase / to_uppercase functions Refactor the code in the `convert_while_ascii` helper function to make it more suitable for auto-vectorization and also process the full ascii prefix of the string. The generic case conversion logic will only be invoked starting from the first non-ascii character. The runtime on a microbenchmark with a small ascii-only input decreases from ~55ns to ~18ns per iteration. The new implementation also reduces the amount of unsafe code and encapsulates all unsafe inside the helper function. Fixes rust-lang#123712
2 parents bf750f5 + 623b9c8 commit bef24a4

File tree

4 files changed

+124
-52
lines changed

4 files changed

+124
-52
lines changed

library/alloc/benches/str.rs

+2
Original file line numberDiff line numberDiff line change
@@ -347,3 +347,5 @@ make_test!(rsplitn_space_char, s, s.rsplitn(10, ' ').count());
347347

348348
make_test!(split_space_str, s, s.split(" ").count());
349349
make_test!(split_ad_str, s, s.split("ad").count());
350+
351+
make_test!(to_lowercase, s, s.to_lowercase());

library/alloc/src/str.rs

+73-52
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
use core::borrow::{Borrow, BorrowMut};
1111
use core::iter::FusedIterator;
1212
use core::mem;
13+
use core::mem::MaybeUninit;
1314
use core::ptr;
1415
use core::str::pattern::{DoubleEndedSearcher, Pattern, ReverseSearcher, Searcher};
1516
use core::unicode::conversions;
@@ -367,14 +368,9 @@ impl str {
367368
without modifying the original"]
368369
#[stable(feature = "unicode_case_mapping", since = "1.2.0")]
369370
pub fn to_lowercase(&self) -> String {
370-
let out = convert_while_ascii(self.as_bytes(), u8::to_ascii_lowercase);
371+
let (mut s, rest) = convert_while_ascii(self, u8::to_ascii_lowercase);
371372

372-
// Safety: we know this is a valid char boundary since
373-
// out.len() is only progressed if ascii bytes are found
374-
let rest = unsafe { self.get_unchecked(out.len()..) };
375-
376-
// Safety: We have written only valid ASCII to our vec
377-
let mut s = unsafe { String::from_utf8_unchecked(out) };
373+
let prefix_len = s.len();
378374

379375
for (i, c) in rest.char_indices() {
380376
if c == 'Σ' {
@@ -383,8 +379,7 @@ impl str {
383379
// in `SpecialCasing.txt`,
384380
// so hard-code it rather than have a generic "condition" mechanism.
385381
// See https://github.com/rust-lang/rust/issues/26035
386-
let out_len = self.len() - rest.len();
387-
let sigma_lowercase = map_uppercase_sigma(&self, i + out_len);
382+
let sigma_lowercase = map_uppercase_sigma(self, prefix_len + i);
388383
s.push(sigma_lowercase);
389384
} else {
390385
match conversions::to_lower(c) {
@@ -460,14 +455,7 @@ impl str {
460455
without modifying the original"]
461456
#[stable(feature = "unicode_case_mapping", since = "1.2.0")]
462457
pub fn to_uppercase(&self) -> String {
463-
let out = convert_while_ascii(self.as_bytes(), u8::to_ascii_uppercase);
464-
465-
// Safety: we know this is a valid char boundary since
466-
// out.len() is only progressed if ascii bytes are found
467-
let rest = unsafe { self.get_unchecked(out.len()..) };
468-
469-
// Safety: We have written only valid ASCII to our vec
470-
let mut s = unsafe { String::from_utf8_unchecked(out) };
458+
let (mut s, rest) = convert_while_ascii(self, u8::to_ascii_uppercase);
471459

472460
for c in rest.chars() {
473461
match conversions::to_upper(c) {
@@ -616,50 +604,83 @@ pub unsafe fn from_boxed_utf8_unchecked(v: Box<[u8]>) -> Box<str> {
616604
unsafe { Box::from_raw(Box::into_raw(v) as *mut str) }
617605
}
618606

619-
/// Converts the bytes while the bytes are still ascii.
607+
/// Converts leading ascii bytes in `s` by calling the `convert` function.
608+
///
620609
/// For better average performance, this happens in chunks of `2*size_of::<usize>()`.
621-
/// Returns a vec with the converted bytes.
610+
///
611+
/// Returns a tuple of the converted prefix and the remainder starting from
612+
/// the first non-ascii character.
622613
#[inline]
623614
#[cfg(not(test))]
624615
#[cfg(not(no_global_oom_handling))]
625-
fn convert_while_ascii(b: &[u8], convert: fn(&u8) -> u8) -> Vec<u8> {
626-
let mut out = Vec::with_capacity(b.len());
616+
fn convert_while_ascii(s: &str, convert: fn(&u8) -> u8) -> (String, &str) {
617+
// Process the input in chunks of 16 bytes to enable auto-vectorization.
618+
// Previously the chunk size depended on the size of `usize`,
619+
// but on 32-bit platforms with sse or neon is also the better choice.
620+
// The only downside on other platforms would be a bit more loop-unrolling.
621+
const N: usize = 16;
622+
623+
let mut slice = s.as_bytes();
624+
let mut out = Vec::with_capacity(slice.len());
625+
let mut out_slice = out.spare_capacity_mut();
626+
627+
let mut ascii_prefix_len = 0_usize;
628+
let mut is_ascii = [false; N];
629+
630+
while slice.len() >= N {
631+
// Safety: checked in loop condition
632+
let chunk = unsafe { slice.get_unchecked(..N) };
633+
// Safety: out_slice has at least same length as input slice and gets sliced with the same offsets
634+
let out_chunk = unsafe { out_slice.get_unchecked_mut(..N) };
635+
636+
for j in 0..N {
637+
is_ascii[j] = chunk[j] <= 127;
638+
}
627639

628-
const USIZE_SIZE: usize = mem::size_of::<usize>();
629-
const MAGIC_UNROLL: usize = 2;
630-
const N: usize = USIZE_SIZE * MAGIC_UNROLL;
631-
const NONASCII_MASK: usize = usize::from_ne_bytes([0x80; USIZE_SIZE]);
640+
// Auto-vectorization for this check is a bit fragile, sum and comparing against the chunk
641+
// size gives the best result, specifically a pmovmsk instruction on x86.
642+
// There is a codegen test in `issue-123712-str-to-lower-autovectorization.rs` which should
643+
// be updated when this method is changed.
644+
// See also https://github.com/llvm/llvm-project/issues/96395
645+
if is_ascii.iter().map(|x| *x as u8).sum::<u8>() as usize != N {
646+
break;
647+
}
632648

633-
let mut i = 0;
634-
unsafe {
635-
while i + N <= b.len() {
636-
// Safety: we have checks the sizes `b` and `out` to know that our
637-
let in_chunk = b.get_unchecked(i..i + N);
638-
let out_chunk = out.spare_capacity_mut().get_unchecked_mut(i..i + N);
639-
640-
let mut bits = 0;
641-
for j in 0..MAGIC_UNROLL {
642-
// read the bytes 1 usize at a time (unaligned since we haven't checked the alignment)
643-
// safety: in_chunk is valid bytes in the range
644-
bits |= in_chunk.as_ptr().cast::<usize>().add(j).read_unaligned();
645-
}
646-
// if our chunks aren't ascii, then return only the prior bytes as init
647-
if bits & NONASCII_MASK != 0 {
648-
break;
649-
}
649+
for j in 0..N {
650+
out_chunk[j] = MaybeUninit::new(convert(&chunk[j]));
651+
}
650652

651-
// perform the case conversions on N bytes (gets heavily autovec'd)
652-
for j in 0..N {
653-
// safety: in_chunk and out_chunk is valid bytes in the range
654-
let out = out_chunk.get_unchecked_mut(j);
655-
out.write(convert(in_chunk.get_unchecked(j)));
656-
}
653+
ascii_prefix_len += N;
654+
slice = unsafe { slice.get_unchecked(N..) };
655+
out_slice = unsafe { out_slice.get_unchecked_mut(N..) };
656+
}
657657

658-
// mark these bytes as initialised
659-
i += N;
658+
// handle the remainder as individual bytes
659+
while slice.len() > 0 {
660+
let byte = slice[0];
661+
if byte > 127 {
662+
break;
660663
}
661-
out.set_len(i);
664+
// Safety: out_slice has same length as input slice and gets sliced with the same offsets
665+
unsafe {
666+
*out_slice.get_unchecked_mut(0) = MaybeUninit::new(convert(&byte));
667+
}
668+
ascii_prefix_len += 1;
669+
slice = unsafe { slice.get_unchecked(1..) };
670+
out_slice = unsafe { out_slice.get_unchecked_mut(1..) };
662671
}
663672

664-
out
673+
unsafe {
674+
// SAFETY: ascii_prefix_len bytes have been initialized above
675+
out.set_len(ascii_prefix_len);
676+
677+
// SAFETY: We have written only valid ascii to the output vec
678+
let ascii_string = String::from_utf8_unchecked(out);
679+
680+
// SAFETY: we know this is a valid char boundary
681+
// since we only skipped over leading ascii bytes
682+
let rest = core::str::from_utf8_unchecked(slice);
683+
684+
(ascii_string, rest)
685+
}
665686
}

library/alloc/tests/str.rs

+3
Original file line numberDiff line numberDiff line change
@@ -1849,7 +1849,10 @@ fn to_lowercase() {
18491849
assert_eq!("ΑΣ''Α".to_lowercase(), "ασ''α");
18501850

18511851
// https://github.com/rust-lang/rust/issues/124714
1852+
// input lengths around the boundary of the chunk size used by the ascii prefix optimization
1853+
assert_eq!("abcdefghijklmnoΣ".to_lowercase(), "abcdefghijklmnoς");
18521854
assert_eq!("abcdefghijklmnopΣ".to_lowercase(), "abcdefghijklmnopς");
1855+
assert_eq!("abcdefghijklmnopqΣ".to_lowercase(), "abcdefghijklmnopqς");
18531856

18541857
// a really long string that has it's lowercase form
18551858
// even longer. this tests that implementations don't assume
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
//@ compile-flags: -Copt-level=3
2+
#![crate_type = "lib"]
3+
4+
/// Ensure that the ascii-prefix loop for `str::to_lowercase` and `str::to_uppercase` uses vector
5+
/// instructions. Since these methods do not get inlined, the relevant code is duplicated here and
6+
/// should be updated when the implementation changes.
7+
// CHECK-LABEL: @lower_while_ascii
8+
// CHECK: [[A:%[0-9]]] = load <16 x i8>
9+
// CHECK-NEXT: [[B:%[0-9]]] = icmp slt <16 x i8> [[A]], zeroinitializer
10+
// CHECK-NEXT: [[C:%[0-9]]] = bitcast <16 x i1> [[B]] to i16
11+
#[no_mangle]
12+
pub fn lower_while_ascii(mut input: &[u8], mut output: &mut [u8]) -> usize {
13+
// Process the input in chunks to enable auto-vectorization.
14+
const N: usize = 16;
15+
16+
output = &mut output[..input.len()];
17+
18+
let mut ascii_prefix_len = 0_usize;
19+
let mut is_ascii = [false; N];
20+
21+
while input.len() >= N {
22+
let chunk = unsafe { input.get_unchecked(..N) };
23+
let out_chunk = unsafe { output.get_unchecked_mut(..N) };
24+
25+
for j in 0..N {
26+
is_ascii[j] = chunk[j] <= 127;
27+
}
28+
29+
// auto-vectorization for this check is a bit fragile,
30+
// sum and comparing against the chunk size gives the best result,
31+
// specifically a pmovmsk instruction on x86.
32+
if is_ascii.iter().map(|x| *x as u8).sum::<u8>() as usize != N {
33+
break;
34+
}
35+
36+
for j in 0..N {
37+
out_chunk[j] = chunk[j].to_ascii_lowercase();
38+
}
39+
40+
ascii_prefix_len += N;
41+
input = unsafe { input.get_unchecked(N..) };
42+
output = unsafe { output.get_unchecked_mut(N..) };
43+
}
44+
45+
ascii_prefix_len
46+
}

0 commit comments

Comments
 (0)