Implement fallback to smaller vector size for swizzle_dyn #433

Open · wants to merge 1 commit into base: master
151 changes: 139 additions & 12 deletions crates/core_simd/src/swizzle_dyn.rs
@@ -15,7 +15,7 @@ where
/// A planned compiler improvement will enable using `#[target_feature]` instead.
#[inline]
pub fn swizzle_dyn(self, idxs: Simd<u8, N>) -> Self {
#![allow(unused_imports, unused_unsafe)]
#![allow(unused_imports, unused_unsafe, unreachable_patterns)]
Member: huh...? 🤔

#[cfg(all(
any(target_arch = "aarch64", target_arch = "arm64ec"),
target_endian = "little"
@@ -57,8 +57,6 @@ where
target_endian = "little"
))]
16 => transize(vqtbl1q_u8, self, idxs),
#[cfg(all(target_feature = "avx2", not(target_feature = "avx512vbmi")))]
32 => transize(avx2_pshufb, self, idxs),
#[cfg(all(target_feature = "avx512vl", target_feature = "avx512vbmi"))]
32 => {
// Unlike vpshufb, vpermb doesn't zero out values in the result based on the index high bit
@@ -71,6 +69,8 @@
};
transize(swizzler, self, idxs)
}
#[cfg(all(target_feature = "avx2", not(target_feature = "avx512vbmi")))]
32 => transize(avx2_pshufb, self, idxs),
// Notable absence: avx512bw pshufb shuffle
#[cfg(all(target_feature = "avx512vl", target_feature = "avx512vbmi"))]
64 => {
@@ -84,20 +84,147 @@
};
transize(swizzler, self, idxs)
}
_ => {
let mut array = [0; N];
for (i, k) in idxs.to_array().into_iter().enumerate() {
if (k as usize) < N {
array[i] = self[k as usize];
};
}
array.into()
}
#[cfg(any(
all(
any(
target_arch = "aarch64",
target_arch = "arm64ec",
all(target_arch = "arm", target_feature = "v7")
),
target_feature = "neon",
target_endian = "little"
Comment on lines +92 to +95
Member: conflicts hereabouts

),
target_feature = "ssse3",
target_feature = "simd128"
))]
_ => dispatch_compat(self, idxs),
_ => swizzle_dyn_scalar(self, idxs),
Comment on lines +100 to +101

Member: I do not like this new structure in the match. Some allowances are required in this code without making it totally illegible, but this requires allowing a lint that should not be allowed. It can be avoided by guarding the first `_` arm with `if cfg!(..)` so that `swizzle_dyn_scalar` only catches the true base case.

Contributor (author): I did consider that approach; however, unless I'm missing something, it would have to duplicate all of the previous cfgs in negated form, which seems a lot uglier than having a silently unreachable pattern in some build configurations. I don't have an issue with changing it if you prefer it that way.
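
For illustration, a minimal sketch of the reviewer's suggestion (not code from this PR; nightly `portable_simd` only, `HAS_VECTOR_SWIZZLE` is a hypothetical stand-in whose cfg! list would have to mirror every accelerated arm in the match, and the PR's `swizzle_dyn_scalar` and `dispatch_compat` are assumed to be in scope):

```rust
use std::simd::{LaneCount, Simd, SupportedLaneCount};

// Hypothetical: would need to cover every feature combination that has a
// vectorized arm in swizzle_dyn's match.
const HAS_VECTOR_SWIZZLE: bool = cfg!(any(
    all(target_arch = "aarch64", target_endian = "little"),
    target_feature = "ssse3",
    target_feature = "simd128",
));

fn fallback<const N: usize>(bytes: Simd<u8, N>, idxs: Simd<u8, N>) -> Simd<u8, N>
where
    LaneCount<N>: SupportedLaneCount,
{
    match N {
        // the guard keeps the scalar arm as the one true base case...
        _ if !HAS_VECTOR_SWIZZLE => swizzle_dyn_scalar(bytes, idxs),
        // ...so the dispatch arm never silently shadows it
        _ => dispatch_compat(bytes, idxs),
    }
}
```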

}
}
}
}

#[inline(always)]
fn swizzle_dyn_scalar<const N: usize>(bytes: Simd<u8, N>, idxs: Simd<u8, N>) -> Simd<u8, N>
where
LaneCount<N>: SupportedLaneCount,
{
let mut array = [0; N];
for (i, k) in idxs.to_array().into_iter().enumerate() {
if (k as usize) < N {
array[i] = bytes[k as usize];
};
}
array.into()
}

/// Dispatch to swizzle_dyn_compat and swizzle_dyn_zext according to N.
/// Should only be called if there exists some power-of-two size for which
/// the target architecture has a vectorized swizzle_dyn (e.g. pshufb, vqtbl).
#[inline(always)]
fn dispatch_compat<const N: usize>(bytes: Simd<u8, N>, idxs: Simd<u8, N>) -> Simd<u8, N>
where
LaneCount<N>: SupportedLaneCount,
{
#![allow(
dead_code,
unused_unsafe,
unreachable_patterns,
Member: should not need this

Suggested change: remove `unreachable_patterns,`
non_contiguous_range_endpoints
)]

// SAFETY: only unsafe usage is transize, see comment on transize
unsafe {
match N {
5..16 => swizzle_dyn_zext::<N, 16>(bytes, idxs),
// only arm actually has 8-byte swizzle_dyn
#[cfg(all(
any(
target_arch = "aarch64",
target_arch = "arm64ec",
all(target_arch = "arm", target_feature = "v7")
),
target_feature = "neon",
target_endian = "little"
))]
16 => transize(swizzle_dyn_compat::<16, 8>, bytes, idxs),
17..32 => swizzle_dyn_zext::<N, 32>(bytes, idxs),
32 => transize(swizzle_dyn_compat::<32, 16>, bytes, idxs),
33..64 => swizzle_dyn_zext::<N, 64>(bytes, idxs),
64 => transize(swizzle_dyn_compat::<64, 32>, bytes, idxs),
_ => swizzle_dyn_scalar(bytes, idxs),
}
}
}
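
As a usage sketch of the dispatch (nightly `portable_simd`, assuming this PR is applied, on an AVX2 target without AVX-512VBMI): N = 48 has no dedicated arm, so it takes `33..64 => swizzle_dyn_zext::<48, 64>`, and the 64-lane swizzle in turn takes `64 => swizzle_dyn_compat::<64, 32>`, ending in four 32-lane `avx2_pshufb` calls rather than the scalar loop. The values below are illustrative:

```rust
#![feature(portable_simd)]
use std::simd::Simd;

fn main() {
    let bytes = Simd::<u8, 48>::from_array(core::array::from_fn(|i| i as u8));
    let mut idx = core::array::from_fn(|i| (47 - i) as u8); // reverse the vector
    idx[0] = 200; // out of range: this lane must come back as 0
    let out = bytes.swizzle_dyn(Simd::from_array(idx));
    assert_eq!(out[0], 0); // zeroed by the out-of-range index
    assert_eq!(out[1], 46); // bytes[46]
}
```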

/// Implement swizzle_dyn for N by temporarily zero extending to N_EXT.
#[inline(always)]
#[allow(unused)]
fn swizzle_dyn_zext<const N: usize, const N_EXT: usize>(
bytes: Simd<u8, N>,
idxs: Simd<u8, N>,
) -> Simd<u8, N>
where
LaneCount<N>: SupportedLaneCount,
LaneCount<N_EXT>: SupportedLaneCount,
{
assert!(N_EXT.is_power_of_two(), "N_EXT should be power of two!");
assert!(N < N_EXT, "N_EXT should be larger than N");
Simd::swizzle_dyn(bytes.resize::<N_EXT>(0), idxs.resize::<N_EXT>(0)).resize::<N>(0)
}
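
A brief check of why the zero padding is safe (illustrative values, assuming N = 12 zero-extends to N_EXT = 16 on a target with a 16-lane swizzle): an index in N..N_EXT selects one of the zeroed padding lanes, and an index >= N_EXT is zeroed by the extended swizzle itself, so both cases agree with the scalar rule that out-of-range indices produce 0.

```rust
#![feature(portable_simd)]
use std::simd::Simd;

fn main() {
    let bytes = Simd::<u8, 12>::from_array(core::array::from_fn(|i| i as u8 + 1));
    // 13 lands in the zero padding (12 <= 13 < 16); 100 is zeroed by the
    // 16-lane swizzle itself. Both lanes must come back as 0.
    let idxs = Simd::from_array([0, 5, 11, 13, 100, 0, 0, 0, 0, 0, 0, 0]);
    let out = bytes.swizzle_dyn(idxs);
    assert_eq!(out.to_array()[..5], [1, 6, 12, 0, 0]);
}
```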

/// "Downgrades" a swizzle_dyn op on N lanes to 4 swizzle_dyn ops on N/2 lanes.
///
/// This only makes sense if swizzle_dyn actually has a vectorized implementation for a lower size (N/2, N/4, N/8, etc).
/// e.g. on x86, swizzle_dyn_compat for N=64 can be efficient if we have at least ssse3 for pshufb.
///
/// If there is no vectorized implementation at any smaller size, the recursion
/// does four half-size swizzles per level, O(N^2) work in total, and will be
/// slower than the scalar implementation.
#[inline(always)]
#[allow(unused)]
fn swizzle_dyn_compat<const N: usize, const HALF_N: usize>(
bytes: Simd<u8, N>,
idxs: Simd<u8, N>,
) -> Simd<u8, N>
where
LaneCount<N>: SupportedLaneCount,
LaneCount<HALF_N>: SupportedLaneCount,
{
use crate::simd::cmp::SimdPartialOrd;
assert!(N.is_power_of_two(), "doesn't work for non-power-of-two N");
assert!(N < u8::MAX as usize, "doesn't work for N >= 256");
assert_eq!(N / 2, HALF_N, "HALF_N must equal N divided by two");

let mid = Simd::splat(HALF_N as u8);

    // unset the "mid" bit from the indices, e.g. 8..15 -> 0..7 for HALF_N = 8,
    // or 16..31 -> 0..15 for HALF_N = 16, ensuring that a half-swizzle on the
    // higher half of `bytes` will select the correct indices;
    // since N is a power of two, any zeroing indices will remain zeroing
let idxs_trunc = idxs & !mid;

let idx_lo = Simd::<u8, HALF_N>::from_slice(&idxs_trunc[..HALF_N]);
let idx_hi = Simd::<u8, HALF_N>::from_slice(&idxs_trunc[HALF_N..]);

let bytes_lo = Simd::<u8, HALF_N>::from_slice(&bytes[..HALF_N]);
let bytes_hi = Simd::<u8, HALF_N>::from_slice(&bytes[HALF_N..]);

macro_rules! half_swizzle {
($bytes:ident) => {{
let lo = Simd::swizzle_dyn($bytes, idx_lo);
let hi = Simd::swizzle_dyn($bytes, idx_hi);

let mut res = [0; N];
res[..HALF_N].copy_from_slice(&lo[..]);
res[HALF_N..].copy_from_slice(&hi[..]);
Simd::from_array(res)
}};
}

let result_lo = half_swizzle!(bytes_lo);
let result_hi = half_swizzle!(bytes_hi);
idxs.simd_lt(mid).select(result_lo, result_hi)
}
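
To make the bit trick concrete, here is a scalar model of the decomposition for N = 32, HALF_N = 16 (illustration only; it mirrors the vector code above rather than reproducing it): index 20 truncates to 20 & !16 = 4 and, being >= 16, selects bytes[16 + 4] = bytes[20] from the high-half result, while a zeroing index such as 52 truncates to 36, which both 16-lane half-swizzles still zero.

```rust
// Scalar model of swizzle_dyn_compat for N = 32, HALF_N = 16 (illustrative).
fn model(bytes: &[u8; 32], idxs: &[u8; 32]) -> [u8; 32] {
    let mut out = [0u8; 32];
    for (i, &k) in idxs.iter().enumerate() {
        let kt = k & !16; // unset the "mid" bit: 16..31 -> 0..15
        // the final select: low indices read the low half, high ones the high half
        let half = if k < 16 { &bytes[..16] } else { &bytes[16..] };
        // a 16-lane swizzle zeroes any (truncated) index >= 16
        out[i] = if kt < 16 { half[kt as usize] } else { 0 };
    }
    out
}

fn main() {
    let bytes: [u8; 32] = core::array::from_fn(|i| i as u8);
    let mut idxs = [0u8; 32];
    idxs[0] = 20; // 20 & !16 = 4; 20 >= 16 -> high half -> bytes[20] = 20
    idxs[1] = 52; // 52 & !16 = 36, still >= 16 -> zeroed (52 >= N)
    let out = model(&bytes, &idxs);
    assert_eq!(out[0], 20);
    assert_eq!(out[1], 0);
}
```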

/// "vpshufb like it was meant to be" on AVX2
///
/// # Safety