Skip to content

Commit 48b7bac

Browse files
mrowqaalexcrichton
authored andcommitted
x86: implemented _mm{,256}_maskstore_epi{32,64} (#155)
* x86: implemented maskloads for avx2 * x86: added docs and tests for avx2 maskloads * x86: refactor - changed `a` to `mem_addr` in avx2 mask loads for consistency * x86: implemented _mm{,256}_maskstore_epi{32,64}
1 parent 87fec0f commit 48b7bac

File tree

1 file changed

+175
-8
lines changed

1 file changed

+175
-8
lines changed

src/x86/avx2.rs

+175-8
Original file line numberDiff line numberDiff line change
@@ -683,14 +683,85 @@ pub unsafe fn _mm256_maddubs_epi16(a: u8x32, b: u8x32) -> i16x16 {
683683
pmaddubsw(a, b)
684684
}
685685

686-
// TODO _mm_maskload_epi32 (int const* mem_addr, __m128i mask)
687-
// TODO _mm256_maskload_epi32 (int const* mem_addr, __m256i mask)
688-
// TODO _mm_maskload_epi64 (__int64 const* mem_addr, __m128i mask)
689-
// TODO _mm256_maskload_epi64 (__int64 const* mem_addr, __m256i mask)
690-
// TODO _mm_maskstore_epi32 (int* mem_addr, __m128i mask, __m128i a)
691-
// TODO _mm256_maskstore_epi32 (int* mem_addr, __m256i mask, __m256i a)
692-
// TODO _mm_maskstore_epi64 (__int64* mem_addr, __m128i mask, __m128i a)
693-
// TODO _mm256_maskstore_epi64 (__int64* mem_addr, __m256i mask, __m256i a)
686+
/// Load packed 32-bit integers from memory pointed by `mem_addr` using `mask`
687+
/// (elements are zeroed out when the highest bit is not set in the
688+
/// corresponding element).
689+
#[inline(always)]
690+
#[target_feature = "+avx2"]
691+
#[cfg_attr(test, assert_instr(vpmaskmovd))]
692+
pub unsafe fn _mm_maskload_epi32(mem_addr: *const i32, mask: i32x4) -> i32x4 {
693+
maskloadd(mem_addr as *const i8, mask)
694+
}
695+
696+
/// Load packed 32-bit integers from memory pointed by `mem_addr` using `mask`
697+
/// (elements are zeroed out when the highest bit is not set in the
698+
/// corresponding element).
699+
#[inline(always)]
700+
#[target_feature = "+avx2"]
701+
#[cfg_attr(test, assert_instr(vpmaskmovd))]
702+
pub unsafe fn _mm256_maskload_epi32(mem_addr: *const i32, mask: i32x8) -> i32x8 {
703+
maskloadd256(mem_addr as *const i8, mask)
704+
}
705+
706+
/// Load packed 64-bit integers from memory pointed by `mem_addr` using `mask`
707+
/// (elements are zeroed out when the highest bit is not set in the
708+
/// corresponding element).
709+
#[inline(always)]
710+
#[target_feature = "+avx2"]
711+
#[cfg_attr(test, assert_instr(vpmaskmovq))]
712+
pub unsafe fn _mm_maskload_epi64(mem_addr: *const i64, mask: i64x2) -> i64x2 {
713+
maskloadq(mem_addr as *const i8, mask)
714+
}
715+
716+
/// Load packed 64-bit integers from memory pointed by `mem_addr` using `mask`
717+
/// (elements are zeroed out when the highest bit is not set in the
718+
/// corresponding element).
719+
#[inline(always)]
720+
#[target_feature = "+avx2"]
721+
#[cfg_attr(test, assert_instr(vpmaskmovq))]
722+
pub unsafe fn _mm256_maskload_epi64(mem_addr: *const i64, mask: i64x4) -> i64x4 {
723+
maskloadq256(mem_addr as *const i8, mask)
724+
}
725+
726+
/// Store packed 32-bit integers from `a` into memory pointed by `mem_addr`
727+
/// using `mask` (elements are not stored when the highest bit is not set
728+
/// in the corresponding element).
729+
#[inline(always)]
730+
#[target_feature = "+avx2"]
731+
#[cfg_attr(test, assert_instr(vpmaskmovd))]
732+
pub unsafe fn _mm_maskstore_epi32(mem_addr: *mut i32, mask: i32x4, a: i32x4) {
733+
maskstored(mem_addr as *mut i8, mask, a)
734+
}
735+
736+
/// Store packed 32-bit integers from `a` into memory pointed by `mem_addr`
737+
/// using `mask` (elements are not stored when the highest bit is not set
738+
/// in the corresponding element).
739+
#[inline(always)]
740+
#[target_feature = "+avx2"]
741+
#[cfg_attr(test, assert_instr(vpmaskmovd))]
742+
pub unsafe fn _mm256_maskstore_epi32(mem_addr: *mut i32, mask: i32x8, a: i32x8) {
743+
maskstored256(mem_addr as *mut i8, mask, a)
744+
}
745+
746+
/// Store packed 64-bit integers from `a` into memory pointed by `mem_addr`
747+
/// using `mask` (elements are not stored when the highest bit is not set
748+
/// in the corresponding element).
749+
#[inline(always)]
750+
#[target_feature = "+avx2"]
751+
#[cfg_attr(test, assert_instr(vpmaskmovq))]
752+
pub unsafe fn _mm_maskstore_epi64(mem_addr: *mut i64, mask: i64x2, a: i64x2) {
753+
maskstoreq(mem_addr as *mut i8, mask, a)
754+
}
755+
756+
/// Store packed 64-bit integers from `a` into memory pointed by `mem_addr`
757+
/// using `mask` (elements are not stored when the highest bit is not set
758+
/// in the corresponding element).
759+
#[inline(always)]
760+
#[target_feature = "+avx2"]
761+
#[cfg_attr(test, assert_instr(vpmaskmovq))]
762+
pub unsafe fn _mm256_maskstore_epi64(mem_addr: *mut i64, mask: i64x4, a: i64x4) {
763+
maskstoreq256(mem_addr as *mut i8, mask, a)
764+
}
694765

695766
/// Compare packed 16-bit integers in `a` and `b`, and return the packed
696767
/// maximum values.
@@ -1852,6 +1923,22 @@ extern "C" {
18521923
fn pmaddwd(a: i16x16, b: i16x16) -> i32x8;
18531924
#[link_name = "llvm.x86.avx2.pmadd.ub.sw"]
18541925
fn pmaddubsw(a: u8x32, b: u8x32) -> i16x16;
1926+
#[link_name = "llvm.x86.avx2.maskload.d"]
1927+
fn maskloadd(mem_addr: *const i8, mask: i32x4) -> i32x4;
1928+
#[link_name = "llvm.x86.avx2.maskload.d.256"]
1929+
fn maskloadd256(mem_addr: *const i8, mask: i32x8) -> i32x8;
1930+
#[link_name = "llvm.x86.avx2.maskload.q"]
1931+
fn maskloadq(mem_addr: *const i8, mask: i64x2) -> i64x2;
1932+
#[link_name = "llvm.x86.avx2.maskload.q.256"]
1933+
fn maskloadq256(mem_addr: *const i8, mask: i64x4) -> i64x4;
1934+
#[link_name = "llvm.x86.avx2.maskstore.d"]
1935+
fn maskstored(mem_addr: *mut i8, mask: i32x4, a: i32x4);
1936+
#[link_name = "llvm.x86.avx2.maskstore.d.256"]
1937+
fn maskstored256(mem_addr: *mut i8, mask: i32x8, a: i32x8);
1938+
#[link_name = "llvm.x86.avx2.maskstore.q"]
1939+
fn maskstoreq(mem_addr: *mut i8, mask: i64x2, a: i64x2);
1940+
#[link_name = "llvm.x86.avx2.maskstore.q.256"]
1941+
fn maskstoreq256(mem_addr: *mut i8, mask: i64x4, a: i64x4);
18551942
#[link_name = "llvm.x86.avx2.pmaxs.w"]
18561943
fn pmaxsw(a: i16x16, b: i16x16) -> i16x16;
18571944
#[link_name = "llvm.x86.avx2.pmaxs.d"]
@@ -2546,6 +2633,86 @@ mod tests {
25462633
assert_eq!(r, e);
25472634
}
25482635

2636+
#[simd_test = "avx2"]
2637+
unsafe fn _mm_maskload_epi32() {
2638+
let nums = [1, 2, 3, 4];
2639+
let a = &nums as *const i32;
2640+
let mask = i32x4::new(-1, 0, 0, -1);
2641+
let r = avx2::_mm_maskload_epi32(a, mask);
2642+
let e = i32x4::new(1, 0, 0, 4);
2643+
assert_eq!(r, e);
2644+
}
2645+
2646+
#[simd_test = "avx2"]
2647+
unsafe fn _mm256_maskload_epi32() {
2648+
let nums = [1, 2, 3, 4, 5, 6, 7, 8];
2649+
let a = &nums as *const i32;
2650+
let mask = i32x8::new(-1, 0, 0, -1, 0, -1, -1, 0);
2651+
let r = avx2::_mm256_maskload_epi32(a, mask);
2652+
let e = i32x8::new(1, 0, 0, 4, 0, 6, 7, 0);
2653+
assert_eq!(r, e);
2654+
}
2655+
2656+
#[simd_test = "avx2"]
2657+
unsafe fn _mm_maskload_epi64() {
2658+
let nums = [1_i64, 2_i64];
2659+
let a = &nums as *const i64;
2660+
let mask = i64x2::new(0, -1);
2661+
let r = avx2::_mm_maskload_epi64(a, mask);
2662+
let e = i64x2::new(0, 2);
2663+
assert_eq!(r, e);
2664+
}
2665+
2666+
#[simd_test = "avx2"]
2667+
unsafe fn _mm256_maskload_epi64() {
2668+
let nums = [1_i64, 2_i64, 3_i64, 4_i64];
2669+
let a = &nums as *const i64;
2670+
let mask = i64x4::new(0, -1, -1, 0);
2671+
let r = avx2::_mm256_maskload_epi64(a, mask);
2672+
let e = i64x4::new(0, 2, 3, 0);
2673+
assert_eq!(r, e);
2674+
}
2675+
2676+
#[simd_test = "avx2"]
2677+
unsafe fn _mm_maskstore_epi32() {
2678+
let a = i32x4::new(1, 2, 3, 4);
2679+
let mut arr = [-1, -1, -1, -1];
2680+
let mask = i32x4::new(-1, 0, 0, -1);
2681+
avx2::_mm_maskstore_epi32(arr.as_mut_ptr(), mask, a);
2682+
let e = [1, -1, -1, 4];
2683+
assert_eq!(arr, e);
2684+
}
2685+
2686+
#[simd_test = "avx2"]
2687+
unsafe fn _mm256_maskstore_epi32() {
2688+
let a = i32x8::new(1, 0x6d726f, 3, 42, 0x777161, 6, 7, 8);
2689+
let mut arr = [-1, -1, -1, 0x776173, -1, 0x68657265, -1, -1];
2690+
let mask = i32x8::new(-1, 0, 0, -1, 0, -1, -1, 0);
2691+
avx2::_mm256_maskstore_epi32(arr.as_mut_ptr(), mask, a);
2692+
let e = [1, -1, -1, 42, -1, 6, 7, -1];
2693+
assert_eq!(arr, e);
2694+
}
2695+
2696+
#[simd_test = "avx2"]
2697+
unsafe fn _mm_maskstore_epi64() {
2698+
let a = i64x2::new(1_i64, 2_i64);
2699+
let mut arr = [-1_i64, -1_i64];
2700+
let mask = i64x2::new(0, -1);
2701+
avx2::_mm_maskstore_epi64(arr.as_mut_ptr(), mask, a);
2702+
let e = [-1, 2];
2703+
assert_eq!(arr, e);
2704+
}
2705+
2706+
#[simd_test = "avx2"]
2707+
unsafe fn _mm256_maskstore_epi64() {
2708+
let a = i64x4::new(1_i64, 2_i64, 3_i64, 4_i64);
2709+
let mut arr = [-1_i64, -1_i64, -1_i64, -1_i64];
2710+
let mask = i64x4::new(0, -1, -1, 0);
2711+
avx2::_mm256_maskstore_epi64(arr.as_mut_ptr(), mask, a);
2712+
let e = [-1, 2, 3, -1];
2713+
assert_eq!(arr, e);
2714+
}
2715+
25492716
#[simd_test = "avx2"]
25502717
unsafe fn _mm256_max_epi16() {
25512718
let a = i16x16::splat(2);

0 commit comments

Comments
 (0)