Auto merge of #130223 - LaihoE:faster_str_replace, r=thomcc

optimize str.replace

Adds a fast path for str.replace for the ascii to ascii case. This allows for autovectorizing the code. Also should this instead be done with specialization? This way we could remove one branch. I think it is the kind of branch that is easy to predict though.

Benchmark for the fast path (replace all "a" with "b" in the rust wikipedia article, using criterion) :
| N        | Speedup | Time New (ns) | Time Old (ns) |
|----------|---------|---------------|---------------|
| 2        | 2.03    | 13.567        | 27.576        |
| 8        | 1.73    | 17.478        | 30.259        |
| 11       | 2.46    | 18.296        | 45.055        |
| 16       | 2.71    | 17.181        | 46.526        |
| 37       | 4.43    | 18.526        | 81.997        |
| 64       | 8.54    | 18.670        | 159.470       |
| 200      | 9.82    | 29.634        | 291.010       |
| 2000     | 24.34   | 81.114        | 1974.300      |
| 20000    | 30.61   | 598.520       | 18318.000     |
| 1000000  | 29.31   | 33458.000     | 980540.000    |
This commit is contained in:
bors 2024-10-17 16:20:02 +00:00
commit 86bd45979a
3 changed files with 63 additions and 2 deletions

View file

@ -20,7 +20,7 @@ pub use core::str::SplitInclusive;
pub use core::str::SplitWhitespace; pub use core::str::SplitWhitespace;
#[stable(feature = "rust1", since = "1.0.0")] #[stable(feature = "rust1", since = "1.0.0")]
pub use core::str::pattern; pub use core::str::pattern;
use core::str::pattern::{DoubleEndedSearcher, Pattern, ReverseSearcher, Searcher}; use core::str::pattern::{DoubleEndedSearcher, Pattern, ReverseSearcher, Searcher, Utf8Pattern};
#[stable(feature = "rust1", since = "1.0.0")] #[stable(feature = "rust1", since = "1.0.0")]
pub use core::str::{Bytes, CharIndices, Chars, from_utf8, from_utf8_mut}; pub use core::str::{Bytes, CharIndices, Chars, from_utf8, from_utf8_mut};
#[stable(feature = "str_escape", since = "1.34.0")] #[stable(feature = "str_escape", since = "1.34.0")]
@ -269,6 +269,18 @@ impl str {
#[stable(feature = "rust1", since = "1.0.0")] #[stable(feature = "rust1", since = "1.0.0")]
#[inline] #[inline]
pub fn replace<P: Pattern>(&self, from: P, to: &str) -> String { pub fn replace<P: Pattern>(&self, from: P, to: &str) -> String {
// Fast path for ASCII to ASCII case.
if let Some(from_byte) = match from.as_utf8_pattern() {
Some(Utf8Pattern::StringPattern([from_byte])) => Some(*from_byte),
Some(Utf8Pattern::CharPattern(c)) => c.as_ascii().map(|ascii_char| ascii_char.to_u8()),
_ => None,
} {
if let [to_byte] = to.as_bytes() {
return unsafe { replace_ascii(self.as_bytes(), from_byte, *to_byte) };
}
}
let mut result = String::new(); let mut result = String::new();
let mut last_end = 0; let mut last_end = 0;
for (start, part) in self.match_indices(from) { for (start, part) in self.match_indices(from) {
@ -686,3 +698,14 @@ pub fn convert_while_ascii(s: &str, convert: fn(&u8) -> u8) -> (String, &str) {
(ascii_string, rest) (ascii_string, rest)
} }
} }
#[inline]
#[cfg(not(test))]
#[cfg(not(no_global_oom_handling))]
#[allow(dead_code)]
/// Faster implementation of string replacement for ASCII to ASCII cases.
/// Should produce fast vectorized code.
unsafe fn replace_ascii(utf8_bytes: &[u8], from: u8, to: u8) -> String {
let result: Vec<u8> = utf8_bytes.iter().map(|b| if *b == from { to } else { *b }).collect();
// SAFETY: We replaced ascii with ascii on valid utf8 strings.
unsafe { String::from_utf8_unchecked(result) }
}

View file

@ -53,7 +53,7 @@ use core::ops::AddAssign;
#[cfg(not(no_global_oom_handling))] #[cfg(not(no_global_oom_handling))]
use core::ops::Bound::{Excluded, Included, Unbounded}; use core::ops::Bound::{Excluded, Included, Unbounded};
use core::ops::{self, Range, RangeBounds}; use core::ops::{self, Range, RangeBounds};
use core::str::pattern::Pattern; use core::str::pattern::{Pattern, Utf8Pattern};
use core::{fmt, hash, ptr, slice}; use core::{fmt, hash, ptr, slice};
#[cfg(not(no_global_oom_handling))] #[cfg(not(no_global_oom_handling))]
@ -2436,6 +2436,11 @@ impl<'b> Pattern for &'b String {
{ {
self[..].strip_suffix_of(haystack) self[..].strip_suffix_of(haystack)
} }
#[inline]
fn as_utf8_pattern(&self) -> Option<Utf8Pattern<'_>> {
Some(Utf8Pattern::StringPattern(self.as_bytes()))
}
} }
macro_rules! impl_eq { macro_rules! impl_eq {

View file

@ -160,6 +160,19 @@ pub trait Pattern: Sized {
None None
} }
} }
/// Returns the pattern as utf-8 bytes if possible.
fn as_utf8_pattern(&self) -> Option<Utf8Pattern<'_>>;
}
/// Result of calling [`Pattern::as_utf8_pattern()`].
/// Can be used for inspecting the contents of a [`Pattern`] in cases
/// where the underlying representation can be represented as UTF-8.
#[derive(Copy, Clone, Eq, PartialEq, Debug)]
pub enum Utf8Pattern<'a> {
/// Type returned by String and str types.
StringPattern(&'a [u8]),
/// Type returned by char types.
CharPattern(char),
} }
// Searcher // Searcher
@ -599,6 +612,11 @@ impl Pattern for char {
{ {
self.encode_utf8(&mut [0u8; 4]).strip_suffix_of(haystack) self.encode_utf8(&mut [0u8; 4]).strip_suffix_of(haystack)
} }
#[inline]
fn as_utf8_pattern(&self) -> Option<Utf8Pattern<'_>> {
Some(Utf8Pattern::CharPattern(*self))
}
} }
///////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////
@ -657,6 +675,11 @@ impl<C: MultiCharEq> Pattern for MultiCharEqPattern<C> {
fn into_searcher(self, haystack: &str) -> MultiCharEqSearcher<'_, C> { fn into_searcher(self, haystack: &str) -> MultiCharEqSearcher<'_, C> {
MultiCharEqSearcher { haystack, char_eq: self.0, char_indices: haystack.char_indices() } MultiCharEqSearcher { haystack, char_eq: self.0, char_indices: haystack.char_indices() }
} }
#[inline]
fn as_utf8_pattern(&self) -> Option<Utf8Pattern<'_>> {
None
}
} }
unsafe impl<'a, C: MultiCharEq> Searcher<'a> for MultiCharEqSearcher<'a, C> { unsafe impl<'a, C: MultiCharEq> Searcher<'a> for MultiCharEqSearcher<'a, C> {
@ -747,6 +770,11 @@ macro_rules! pattern_methods {
{ {
($pmap)(self).strip_suffix_of(haystack) ($pmap)(self).strip_suffix_of(haystack)
} }
#[inline]
fn as_utf8_pattern(&self) -> Option<Utf8Pattern<'_>> {
None
}
}; };
} }
@ -1022,6 +1050,11 @@ impl<'b> Pattern for &'b str {
None None
} }
} }
#[inline]
fn as_utf8_pattern(&self) -> Option<Utf8Pattern<'_>> {
Some(Utf8Pattern::StringPattern(self.as_bytes()))
}
} }
///////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////