Skip to content

x86: implemented _mm{,256}_maskload_epi{32,64} #152

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 88 additions & 4 deletions src/x86/avx2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -683,10 +683,46 @@ pub unsafe fn _mm256_maddubs_epi16(a: u8x32, b: u8x32) -> i16x16 {
pmaddubsw(a, b)
}

// TODO _mm_maskload_epi32 (int const* mem_addr, __m128i mask)
// TODO _mm256_maskload_epi32 (int const* mem_addr, __m256i mask)
// TODO _mm_maskload_epi64 (__int64 const* mem_addr, __m128i mask)
// TODO _mm256_maskload_epi64 (__int64 const* mem_addr, __m256i mask)
/// Loads packed 32-bit integers from the memory pointed to by `mem_addr`,
/// filtered through `mask`: any element whose corresponding mask element
/// does not have its highest bit set is zeroed out in the result.
#[inline(always)]
#[target_feature = "+avx2"]
#[cfg_attr(test, assert_instr(vpmaskmovd))]
pub unsafe fn _mm_maskload_epi32(mem_addr: *const i32, mask: i32x4) -> i32x4 {
    // The `maskloadd` binding is declared with a `*const i8` address.
    let base = mem_addr as *const i8;
    maskloadd(base, mask)
}

/// Loads packed 32-bit integers from the memory pointed to by `mem_addr`,
/// filtered through `mask`: any element whose corresponding mask element
/// does not have its highest bit set is zeroed out in the result.
#[inline(always)]
#[target_feature = "+avx2"]
#[cfg_attr(test, assert_instr(vpmaskmovd))]
pub unsafe fn _mm256_maskload_epi32(mem_addr: *const i32, mask: i32x8) -> i32x8 {
    // The `maskloadd256` binding is declared with a `*const i8` address.
    let base = mem_addr as *const i8;
    maskloadd256(base, mask)
}

/// Loads packed 64-bit integers from the memory pointed to by `mem_addr`,
/// filtered through `mask`: any element whose corresponding mask element
/// does not have its highest bit set is zeroed out in the result.
#[inline(always)]
#[target_feature = "+avx2"]
#[cfg_attr(test, assert_instr(vpmaskmovq))]
pub unsafe fn _mm_maskload_epi64(mem_addr: *const i64, mask: i64x2) -> i64x2 {
    // The `maskloadq` binding is declared with a `*const i8` address.
    let base = mem_addr as *const i8;
    maskloadq(base, mask)
}

/// Loads packed 64-bit integers from the memory pointed to by `mem_addr`,
/// filtered through `mask`: any element whose corresponding mask element
/// does not have its highest bit set is zeroed out in the result.
#[inline(always)]
#[target_feature = "+avx2"]
#[cfg_attr(test, assert_instr(vpmaskmovq))]
pub unsafe fn _mm256_maskload_epi64(mem_addr: *const i64, mask: i64x4) -> i64x4 {
    // The `maskloadq256` binding is declared with a `*const i8` address.
    let base = mem_addr as *const i8;
    maskloadq256(base, mask)
}

// TODO _mm_maskstore_epi32 (int* mem_addr, __m128i mask, __m128i a)
// TODO _mm256_maskstore_epi32 (int* mem_addr, __m256i mask, __m256i a)
// TODO _mm_maskstore_epi64 (__int64* mem_addr, __m128i mask, __m128i a)
Expand Down Expand Up @@ -1761,6 +1797,14 @@ extern "C" {
fn pmaddwd(a: i16x16, b: i16x16) -> i32x8;
#[link_name = "llvm.x86.avx2.pmadd.ub.sw"]
fn pmaddubsw(a: u8x32, b: u8x32) -> i16x16;
#[link_name = "llvm.x86.avx2.maskload.d"]
fn maskloadd(mem_addr: *const i8, mask: i32x4) -> i32x4;
#[link_name = "llvm.x86.avx2.maskload.d.256"]
fn maskloadd256(mem_addr: *const i8, mask: i32x8) -> i32x8;
#[link_name = "llvm.x86.avx2.maskload.q"]
fn maskloadq(mem_addr: *const i8, mask: i64x2) -> i64x2;
#[link_name = "llvm.x86.avx2.maskload.q.256"]
fn maskloadq256(mem_addr: *const i8, mask: i64x4) -> i64x4;
#[link_name = "llvm.x86.avx2.pmaxs.w"]
fn pmaxsw(a: i16x16, b: i16x16) -> i16x16;
#[link_name = "llvm.x86.avx2.pmaxs.d"]
Expand Down Expand Up @@ -2455,6 +2499,46 @@ mod tests {
assert_eq!(r, e);
}

#[simd_test = "avx2"]
unsafe fn _mm_maskload_epi32() {
    // Only the first and last elements have their mask element set to -1
    // (highest bit set), so only those are expected to be loaded.
    let values: [i32; 4] = [1, 2, 3, 4];
    let select = i32x4::new(-1, 0, 0, -1);
    let loaded = avx2::_mm_maskload_epi32(values.as_ptr(), select);
    assert_eq!(loaded, i32x4::new(1, 0, 0, 4));
}

#[simd_test = "avx2"]
unsafe fn _mm256_maskload_epi32() {
    // Elements whose mask element is 0 are expected to come back as zero;
    // those with -1 (highest bit set) carry the value from memory.
    let values: [i32; 8] = [1, 2, 3, 4, 5, 6, 7, 8];
    let select = i32x8::new(-1, 0, 0, -1, 0, -1, -1, 0);
    let loaded = avx2::_mm256_maskload_epi32(values.as_ptr(), select);
    assert_eq!(loaded, i32x8::new(1, 0, 0, 4, 0, 6, 7, 0));
}

#[simd_test = "avx2"]
unsafe fn _mm_maskload_epi64() {
    // Only the second element is selected by the mask.
    let values: [i64; 2] = [1, 2];
    let select = i64x2::new(0, -1);
    let loaded = avx2::_mm_maskload_epi64(values.as_ptr(), select);
    assert_eq!(loaded, i64x2::new(0, 2));
}

#[simd_test = "avx2"]
unsafe fn _mm256_maskload_epi64() {
    // Only the two middle elements are selected by the mask.
    let values: [i64; 4] = [1, 2, 3, 4];
    let select = i64x4::new(0, -1, -1, 0);
    let loaded = avx2::_mm256_maskload_epi64(values.as_ptr(), select);
    assert_eq!(loaded, i64x4::new(0, 2, 3, 0));
}

#[simd_test = "avx2"]
unsafe fn _mm256_max_epi16() {
let a = i16x16::splat(2);
Expand Down