Skip to content

Commit 71a1558

Browse files
committed
Encode core::str::CharSearcher::utf8_size as enum
1 parent e927184 commit 71a1558

File tree

1 file changed

+55
-13
lines changed

1 file changed

+55
-13
lines changed

library/core/src/str/pattern.rs

+55-13
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,45 @@ pub trait DoubleEndedSearcher<'a>: ReverseSearcher<'a> {}
348348
// Impl for char
349349
/////////////////////////////////////////////////////////////////////////////
350350

351+
#[derive(Clone, Copy, Debug)]
352+
enum Utf8Size {
353+
// Values are indexes, so `- 1`
354+
One = 0,
355+
Two = 1,
356+
Three = 2,
357+
Four = 3,
358+
}
359+
360+
impl Utf8Size {
361+
fn new(size: usize) -> Option<Self> {
362+
match size {
363+
1 => Some(Self::One),
364+
2 => Some(Self::Two),
365+
3 => Some(Self::Three),
366+
4 => Some(Self::Four),
367+
_ => None,
368+
}
369+
}
370+
371+
// # Safety
372+
//
373+
// `size` must be more than `0` and less than `5`
374+
unsafe fn new_unchecked(size: usize) -> Self {
375+
// SAFETY: Invariant held by caller
376+
unsafe { Self::new(size).unwrap_unchecked() }
377+
}
378+
379+
/// Returns the size value
380+
fn get_raw(self) -> usize {
381+
(self as usize) + 1
382+
}
383+
384+
fn index(self, arr: &[u8; 4]) -> &u8 {
385+
// SAFETY: max value is 3, which indexes to the 4th element.
386+
unsafe { arr.get_unchecked(self as usize) }
387+
}
388+
}
389+
351390
/// Associated type for `<char as Pattern<'a>>::Searcher`.
352391
#[derive(Clone, Debug)]
353392
pub struct CharSearcher<'a> {
@@ -368,9 +407,8 @@ pub struct CharSearcher<'a> {
368407
/// The character being searched for
369408
needle: char,
370409

371-
// safety invariant: `utf8_size` must be less than 5
372410
/// The number of bytes `needle` takes up when encoded in utf8.
373-
utf8_size: usize,
411+
utf8_size: Utf8Size,
374412
/// A utf8 encoded copy of the `needle`
375413
utf8_encoded: [u8; 4],
376414
}
@@ -413,8 +451,7 @@ unsafe impl<'a> Searcher<'a> for CharSearcher<'a> {
413451
// get the haystack after the last character found
414452
let bytes = self.haystack.as_bytes().get(self.finger..self.finger_back)?;
415453
// the last byte of the utf8 encoded needle
416-
// SAFETY: we have an invariant that `utf8_size < 5`
417-
let last_byte = unsafe { *self.utf8_encoded.get_unchecked(self.utf8_size - 1) };
454+
let last_byte = *self.utf8_size.index(&self.utf8_encoded);
418455
if let Some(index) = memchr::memchr(last_byte, bytes) {
419456
// The new finger is the index of the byte we found,
420457
// plus one, since we memchr'd for the last byte of the character.
@@ -434,10 +471,12 @@ unsafe impl<'a> Searcher<'a> for CharSearcher<'a> {
434471
// find something. When we find something the `finger` will be set
435472
// to a UTF8 boundary.
436473
self.finger += index + 1;
437-
if self.finger >= self.utf8_size {
438-
let found_char = self.finger - self.utf8_size;
474+
475+
let utf8_size = self.utf8_size.get_raw();
476+
if self.finger >= utf8_size {
477+
let found_char = self.finger - utf8_size;
439478
if let Some(slice) = self.haystack.as_bytes().get(found_char..self.finger) {
440-
if slice == &self.utf8_encoded[0..self.utf8_size] {
479+
if slice == &self.utf8_encoded[0..utf8_size] {
441480
return Some((found_char, self.finger));
442481
}
443482
}
@@ -481,8 +520,7 @@ unsafe impl<'a> ReverseSearcher<'a> for CharSearcher<'a> {
481520
// get the haystack up to but not including the last character searched
482521
let bytes = haystack.get(self.finger..self.finger_back)?;
483522
// the last byte of the utf8 encoded needle
484-
// SAFETY: we have an invariant that `utf8_size < 5`
485-
let last_byte = unsafe { *self.utf8_encoded.get_unchecked(self.utf8_size - 1) };
523+
let last_byte = *self.utf8_size.index(&self.utf8_encoded);
486524
if let Some(index) = memchr::memrchr(last_byte, bytes) {
487525
// we searched a slice that was offset by self.finger,
488526
// add self.finger to recoup the original index
@@ -493,14 +531,15 @@ unsafe impl<'a> ReverseSearcher<'a> for CharSearcher<'a> {
493531
// char in the paradigm of reverse iteration). For
494532
// multibyte chars we need to skip down by the number of more
495533
// bytes they have than ASCII
496-
let shift = self.utf8_size - 1;
534+
let utf8_size = self.utf8_size.get_raw();
535+
let shift = utf8_size - 1;
497536
if index >= shift {
498537
let found_char = index - shift;
499-
if let Some(slice) = haystack.get(found_char..(found_char + self.utf8_size)) {
500-
if slice == &self.utf8_encoded[0..self.utf8_size] {
538+
if let Some(slice) = haystack.get(found_char..(found_char + utf8_size)) {
539+
if slice == &self.utf8_encoded[0..utf8_size] {
501540
// move finger to before the character found (i.e., at its start index)
502541
self.finger_back = found_char;
503-
return Some((self.finger_back, self.finger_back + self.utf8_size));
542+
return Some((self.finger_back, self.finger_back + utf8_size));
504543
}
505544
}
506545
}
@@ -543,6 +582,9 @@ impl<'a> Pattern<'a> for char {
543582
fn into_searcher(self, haystack: &'a str) -> Self::Searcher {
544583
let mut utf8_encoded = [0; 4];
545584
let utf8_size = self.encode_utf8(&mut utf8_encoded).len();
585+
586+
// SAFETY: utf8_size is above 0 and below 5
587+
let utf8_size = unsafe { Utf8Size::new_unchecked(utf8_size) };
546588
CharSearcher {
547589
haystack,
548590
finger: 0,

0 commit comments

Comments
 (0)