x86_64 SSE2 fast-path for str.contains(&str) and short needles

Based on Wojciech Muła's "SIMD-friendly algorithms for substring searching"[0]

The two-way algorithm is Big-O efficient but it needs to preprocess the needle
to find a "criticla factorization" of it. This additional work is significant
for short needles. Additionally it mostly advances needle.len() bytes at a time.

The SIMD-based approach used here on the other hand can advance based on its
vector width, which can exceed the needle length. Except for pathological cases,
but due to being limited to small needles the worst case blowup is also small.

benchmarks taken on a Zen2:

```
16CGU, OLD:
test str::bench_contains_short_short                     ... bench:          27 ns/iter (+/- 1)
test str::bench_contains_short_long                      ... bench:         667 ns/iter (+/- 29)
test str::bench_contains_bad_naive                       ... bench:         131 ns/iter (+/- 2)
test str::bench_contains_bad_simd                        ... bench:         130 ns/iter (+/- 2)
test str::bench_contains_equal                           ... bench:         148 ns/iter (+/- 4)


16CGU, NEW:
test str::bench_contains_short_short                     ... bench:           8 ns/iter (+/- 0)
test str::bench_contains_short_long                      ... bench:         135 ns/iter (+/- 4)
test str::bench_contains_bad_naive                       ... bench:         130 ns/iter (+/- 2)
test str::bench_contains_bad_simd                        ... bench:         292 ns/iter (+/- 1)
test str::bench_contains_equal                           ... bench:           3 ns/iter (+/- 0)


1CGU, OLD:
test str::bench_contains_short_short                     ... bench:          30 ns/iter (+/- 0)
test str::bench_contains_short_long                      ... bench:         713 ns/iter (+/- 17)
test str::bench_contains_bad_naive                       ... bench:         131 ns/iter (+/- 3)
test str::bench_contains_bad_simd                        ... bench:         130 ns/iter (+/- 3)
test str::bench_contains_equal                           ... bench:         148 ns/iter (+/- 6)

1CGU, NEW:
test str::bench_contains_short_short                     ... bench:          10 ns/iter (+/- 0)
test str::bench_contains_short_long                      ... bench:         111 ns/iter (+/- 0)
test str::bench_contains_bad_naive                       ... bench:         135 ns/iter (+/- 3)
test str::bench_contains_bad_simd                        ... bench:         274 ns/iter (+/- 2)
test str::bench_contains_equal                           ... bench:           4 ns/iter (+/- 0)
```


[0] http://0x80.pl/articles/simd-strfind.html#sse-avx2
This commit is contained in:
The 8472 2022-10-30 21:47:04 +01:00
parent 467b299e53
commit 3d4a8482b9

View File

@ -39,6 +39,7 @@
)] )]
use crate::cmp; use crate::cmp;
use crate::cmp::Ordering;
use crate::fmt; use crate::fmt;
use crate::slice::memchr; use crate::slice::memchr;
@ -946,6 +947,27 @@ impl<'a, 'b> Pattern<'a> for &'b str {
haystack.as_bytes().starts_with(self.as_bytes()) haystack.as_bytes().starts_with(self.as_bytes())
} }
/// Checks whether the pattern matches anywhere in the haystack
#[inline]
fn is_contained_in(self, haystack: &'a str) -> bool {
if self.len() == 0 {
return true;
}
match self.len().cmp(&haystack.len()) {
Ordering::Less => {
#[cfg(all(target_arch = "x86_64", target_feature = "sse2"))]
if self.as_bytes().len() <= 8 {
return simd_contains(self, haystack);
}
self.into_searcher(haystack).next_match().is_some()
}
Ordering::Equal => self == haystack,
Ordering::Greater => false,
}
}
/// Removes the pattern from the front of haystack, if it matches. /// Removes the pattern from the front of haystack, if it matches.
#[inline] #[inline]
fn strip_prefix_of(self, haystack: &'a str) -> Option<&'a str> { fn strip_prefix_of(self, haystack: &'a str) -> Option<&'a str> {
@ -1684,3 +1706,83 @@ impl TwoWayStrategy for RejectAndMatch {
SearchStep::Match(a, b) SearchStep::Match(a, b)
} }
} }
#[cfg(all(target_arch = "x86_64", target_feature = "sse2"))]
#[inline]
fn simd_contains(needle: &str, haystack: &str) -> bool {
let needle = needle.as_bytes();
let haystack = haystack.as_bytes();
if needle.len() == 1 {
return haystack.contains(&needle[0]);
}
const CHUNK: usize = 16;
// do a naive search if if the haystack is too small to fit
if haystack.len() < CHUNK + needle.len() - 1 {
return haystack.windows(needle.len()).any(|c| c == needle);
}
use crate::arch::x86_64::{
__m128i, _mm_and_si128, _mm_cmpeq_epi8, _mm_loadu_si128, _mm_movemask_epi8, _mm_set1_epi8,
};
// SAFETY: no preconditions other than sse2 being available
let first: __m128i = unsafe { _mm_set1_epi8(needle[0] as i8) };
// SAFETY: no preconditions other than sse2 being available
let last: __m128i = unsafe { _mm_set1_epi8(*needle.last().unwrap() as i8) };
let check_mask = #[cold]
|idx, mut mask: u32| -> bool {
while mask != 0 {
let trailing = mask.trailing_zeros();
let offset = idx + trailing as usize + 1;
let sub = &haystack[offset..][..needle.len() - 2];
let trimmed_needle = &needle[1..needle.len() - 1];
if sub == trimmed_needle {
return true;
}
mask &= !(1 << trailing);
}
return false;
};
let test_chunk = |i| -> bool {
// SAFETY: this requires at least CHUNK bytes being readable at offset i
// that is ensured by the loop ranges (see comments below)
let a: __m128i = unsafe { _mm_loadu_si128(haystack.as_ptr().add(i) as *const _) };
let b: __m128i =
// SAFETY: this requires CHUNK + needle.len() - 1 bytes being readable at offset i
unsafe { _mm_loadu_si128(haystack.as_ptr().add(i + needle.len() - 1) as *const _) };
// SAFETY: no preconditions other than sse2 being available
let eq_first: __m128i = unsafe { _mm_cmpeq_epi8(first, a) };
// SAFETY: no preconditions other than sse2 being available
let eq_last: __m128i = unsafe { _mm_cmpeq_epi8(last, b) };
// SAFETY: no preconditions other than sse2 being available
let mask: u32 = unsafe { _mm_movemask_epi8(_mm_and_si128(eq_first, eq_last)) } as u32;
if mask != 0 {
return check_mask(i, mask);
}
return false;
};
let mut i = 0;
let mut result = false;
while !result && i + CHUNK + needle.len() <= haystack.len() {
result |= test_chunk(i);
i += CHUNK;
}
// process the tail that didn't fit into CHUNK-sized steps
// this simply repeats the same procedure but as right-aligned chunk instead
// of a left-aligned one. The last byte must be exactly flush with the string end so
// we don't miss a single byte or read out of bounds.
result |= test_chunk(haystack.len() + 1 - needle.len() - CHUNK);
return result;
}