-
Notifications
You must be signed in to change notification settings - Fork 88
Implement fallback to smaller vector size for swizzle_dyn #433
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
|
@@ -15,7 +15,7 @@ where | |||
/// A planned compiler improvement will enable using `#[target_feature]` instead. | ||||
#[inline] | ||||
pub fn swizzle_dyn(self, idxs: Simd<u8, N>) -> Self { | ||||
#![allow(unused_imports, unused_unsafe)] | ||||
#![allow(unused_imports, unused_unsafe, unreachable_patterns)] | ||||
#[cfg(all( | ||||
any(target_arch = "aarch64", target_arch = "arm64ec"), | ||||
target_endian = "little" | ||||
|
@@ -57,8 +57,6 @@ where | |||
target_endian = "little" | ||||
))] | ||||
16 => transize(vqtbl1q_u8, self, idxs), | ||||
#[cfg(all(target_feature = "avx2", not(target_feature = "avx512vbmi")))] | ||||
32 => transize(avx2_pshufb, self, idxs), | ||||
#[cfg(all(target_feature = "avx512vl", target_feature = "avx512vbmi"))] | ||||
32 => { | ||||
// Unlike vpshufb, vpermb doesn't zero out values in the result based on the index high bit | ||||
|
@@ -71,6 +69,8 @@ where | |||
}; | ||||
transize(swizzler, self, idxs) | ||||
} | ||||
#[cfg(all(target_feature = "avx2", not(target_feature = "avx512vbmi")))] | ||||
32 => transize(avx2_pshufb, self, idxs), | ||||
// Notable absence: avx512bw pshufb shuffle | ||||
#[cfg(all(target_feature = "avx512vl", target_feature = "avx512vbmi"))] | ||||
64 => { | ||||
|
@@ -84,20 +84,147 @@ where | |||
}; | ||||
transize(swizzler, self, idxs) | ||||
} | ||||
_ => { | ||||
let mut array = [0; N]; | ||||
for (i, k) in idxs.to_array().into_iter().enumerate() { | ||||
if (k as usize) < N { | ||||
array[i] = self[k as usize]; | ||||
}; | ||||
} | ||||
array.into() | ||||
} | ||||
#[cfg(any( | ||||
all( | ||||
any( | ||||
target_arch = "aarch64", | ||||
target_arch = "arm64ec", | ||||
all(target_arch = "arm", target_feature = "v7") | ||||
), | ||||
target_feature = "neon", | ||||
target_endian = "little" | ||||
Comment on lines
+92
to
+95
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. conflicts hereabouts |
||||
), | ||||
target_feature = "ssse3", | ||||
target_feature = "simd128" | ||||
))] | ||||
_ => dispatch_compat(self, idxs), | ||||
_ => swizzle_dyn_scalar(self, idxs), | ||||
Comment on lines
+100
to
+101
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I do not like this new structure in the match. Some allowances are required in this code without making it totally illegible, but this requires allowing a lint that should not be allowed. It can be avoided by making the first [… the rest of this comment is missing from the page capture] There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I did consider that approach; however, unless I'm missing something, it would have to effectively duplicate all of the previous if-cfgs in a negated form, which seems to be a lot uglier than just having a silently unreachable pattern in some build configurations. I don't have an issue with changing it if you prefer it like that. |
||||
} | ||||
} | ||||
} | ||||
} | ||||
|
||||
#[inline(always)] | ||||
fn swizzle_dyn_scalar<const N: usize>(bytes: Simd<u8, N>, idxs: Simd<u8, N>) -> Simd<u8, N> | ||||
where | ||||
LaneCount<N>: SupportedLaneCount, | ||||
{ | ||||
let mut array = [0; N]; | ||||
for (i, k) in idxs.to_array().into_iter().enumerate() { | ||||
if (k as usize) < N { | ||||
array[i] = bytes[k as usize]; | ||||
}; | ||||
} | ||||
array.into() | ||||
} | ||||
|
||||
/// Dispatch to swizzle_dyn_compat and swizzle_dyn_zext according to N. | ||||
/// Should only be called if there exists some power-of-two size for which | ||||
/// the target architecture has a vectorized swizzle_dyn (e.g. pshufb, vqtbl). | ||||
#[inline(always)] | ||||
fn dispatch_compat<const N: usize>(bytes: Simd<u8, N>, idxs: Simd<u8, N>) -> Simd<u8, N> | ||||
where | ||||
LaneCount<N>: SupportedLaneCount, | ||||
{ | ||||
#![allow( | ||||
dead_code, | ||||
unused_unsafe, | ||||
unreachable_patterns, | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should not need this
Suggested change
|
||||
non_contiguous_range_endpoints | ||||
)] | ||||
|
||||
// SAFETY: only unsafe usage is transize, see comment on transize | ||||
unsafe { | ||||
match N { | ||||
5..16 => swizzle_dyn_zext::<N, 16>(bytes, idxs), | ||||
// only arm actually has 8-byte swizzle_dyn | ||||
#[cfg(all( | ||||
any( | ||||
target_arch = "aarch64", | ||||
target_arch = "arm64ec", | ||||
all(target_arch = "arm", target_feature = "v7") | ||||
), | ||||
target_feature = "neon", | ||||
target_endian = "little" | ||||
))] | ||||
16 => transize(swizzle_dyn_compat::<16, 8>, bytes, idxs), | ||||
17..32 => swizzle_dyn_zext::<N, 32>(bytes, idxs), | ||||
32 => transize(swizzle_dyn_compat::<32, 16>, bytes, idxs), | ||||
33..64 => swizzle_dyn_zext::<N, 64>(bytes, idxs), | ||||
64 => transize(swizzle_dyn_compat::<64, 32>, bytes, idxs), | ||||
_ => swizzle_dyn_scalar(bytes, idxs), | ||||
} | ||||
} | ||||
} | ||||
|
||||
/// Implement swizzle_dyn for N by temporarily zero extending to N_EXT. | ||||
#[inline(always)] | ||||
#[allow(unused)] | ||||
fn swizzle_dyn_zext<const N: usize, const N_EXT: usize>( | ||||
bytes: Simd<u8, N>, | ||||
idxs: Simd<u8, N>, | ||||
) -> Simd<u8, N> | ||||
where | ||||
LaneCount<N>: SupportedLaneCount, | ||||
LaneCount<N_EXT>: SupportedLaneCount, | ||||
{ | ||||
assert!(N_EXT.is_power_of_two(), "N_EXT should be power of two!"); | ||||
assert!(N < N_EXT, "N_EXT should be larger than N"); | ||||
Simd::swizzle_dyn(bytes.resize::<N_EXT>(0), idxs.resize::<N_EXT>(0)).resize::<N>(0) | ||||
} | ||||
|
||||
/// "Downgrades" a swizzle_dyn op on N lanes to 4 swizzle_dyn ops on N/2 lanes. | ||||
/// | ||||
/// This only makes sense if swizzle_dyn actually has a vectorized implementation for a lower size (N/2, N/4, N/8, etc). | ||||
/// e.g. on x86, swizzle_dyn_compat for N=64 can be efficient if we have at least ssse3 for pshufb | ||||
/// | ||||
/// If there is no vectorized implementation for a lower size, | ||||
/// this runs in N*logN time and will be slower than the scalar implementation. | ||||
#[inline(always)] | ||||
#[allow(unused)] | ||||
fn swizzle_dyn_compat<const N: usize, const HALF_N: usize>( | ||||
bytes: Simd<u8, N>, | ||||
idxs: Simd<u8, N>, | ||||
) -> Simd<u8, N> | ||||
where | ||||
LaneCount<N>: SupportedLaneCount, | ||||
LaneCount<HALF_N>: SupportedLaneCount, | ||||
{ | ||||
use crate::simd::cmp::SimdPartialOrd; | ||||
assert!(N.is_power_of_two(), "doesn't work for non-power-of-two N"); | ||||
assert!(N < u8::MAX as usize, "doesn't work for N >= 256"); | ||||
assert_eq!(N / 2, HALF_N, "HALF_N must equal N divided by two"); | ||||
|
||||
let mid = Simd::splat(HALF_N as u8); | ||||
|
||||
// unset the "mid" bit from the indices, e.g. 8..15 -> 0..7, 16..31 -> 8..15, | ||||
// ensuring that a half-swizzle on the higher half of `bytes` will select the correct indices | ||||
// since N is a power of two, any zeroing indices will remain zeroing | ||||
let idxs_trunc = idxs & !mid; | ||||
|
||||
let idx_lo = Simd::<u8, HALF_N>::from_slice(&idxs_trunc[..HALF_N]); | ||||
let idx_hi = Simd::<u8, HALF_N>::from_slice(&idxs_trunc[HALF_N..]); | ||||
|
||||
let bytes_lo = Simd::<u8, HALF_N>::from_slice(&bytes[..HALF_N]); | ||||
let bytes_hi = Simd::<u8, HALF_N>::from_slice(&bytes[HALF_N..]); | ||||
|
||||
macro_rules! half_swizzle { | ||||
($bytes:ident) => {{ | ||||
let lo = Simd::swizzle_dyn($bytes, idx_lo); | ||||
let hi = Simd::swizzle_dyn($bytes, idx_hi); | ||||
|
||||
let mut res = [0; N]; | ||||
res[..HALF_N].copy_from_slice(&lo[..]); | ||||
res[HALF_N..].copy_from_slice(&hi[..]); | ||||
Simd::from_array(res) | ||||
}}; | ||||
} | ||||
|
||||
let result_lo = half_swizzle!(bytes_lo); | ||||
let result_hi = half_swizzle!(bytes_hi); | ||||
idxs.simd_lt(mid).select(result_lo, result_hi) | ||||
} | ||||
|
||||
/// "vpshufb like it was meant to be" on AVX2 | ||||
/// | ||||
/// # Safety | ||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
huh...? 🤔