[libc++] Optimize std::minmax_element #135495

Open: wsehjk wants to merge 7 commits into base: main

Conversation

wsehjk
Contributor

@wsehjk wsehjk commented Apr 12, 2025

This PR closes #112397.
The method is inspired by a find-and-locate approach. I slice the input into fixed-size blocks and, for each block, update the __max_element, __max_block_start, and __max_block_end variables. At the end, the code iterates over the block between __max_block_start and __max_block_end to locate the maximum element. The position of the minimum element is found the same way.

However, the benchmark results are not as promising as expected. This may be because I'm testing on macOS, which only supports __sse__.

name                                                   old          new    speedup
BM_std_minmax_element<char>/1                      0.382 ns       1.77 ns     0.22x
BM_std_minmax_element<char>/2                      0.764 ns       1.76 ns     0.43x
BM_std_minmax_element<char>/3                       1.41 ns       2.35 ns     0.60x
BM_std_minmax_element<char>/4                       1.53 ns       2.52 ns     0.61x
BM_std_minmax_element<char>/64                      24.4 ns       7.83 ns     3.12x
BM_std_minmax_element<char>/512                      199 ns       21.3 ns     9.34x
BM_std_minmax_element<char>/1024                     394 ns       37.6 ns    10.48x
BM_std_minmax_element<char>/4000                    1529 ns        132 ns    11.58x
BM_std_minmax_element<char>/4096                    1567 ns        135 ns    11.61x
BM_std_minmax_element<char>/5500                    2103 ns        180 ns    11.68x
BM_std_minmax_element<char>/64000                  24458 ns       2033 ns    12.03x
BM_std_minmax_element<char>/65536                  25048 ns       2082 ns    12.03x
BM_std_minmax_element<char>/70000                  26744 ns       2224 ns    12.03x
BM_std_minmax_element<short>/1                     0.381 ns       1.76 ns     0.22x
BM_std_minmax_element<short>/2                     0.762 ns       1.76 ns     0.43x
BM_std_minmax_element<short>/3                      1.53 ns       2.52 ns     0.61x
BM_std_minmax_element<short>/4                      1.37 ns       2.52 ns     0.54x
BM_std_minmax_element<short>/64                     20.9 ns       7.07 ns     2.96x
BM_std_minmax_element<short>/512                     194 ns       35.3 ns     5.50x
BM_std_minmax_element<short>/1024                    393 ns       67.7 ns     5.81x
BM_std_minmax_element<short>/4000                   1546 ns        263 ns     5.88x
BM_std_minmax_element<short>/4096                   1584 ns        267 ns     5.93x
BM_std_minmax_element<short>/5500                   2170 ns        355 ns     6.11x
BM_std_minmax_element<short>/64000                 24814 ns       4106 ns     6.04x
BM_std_minmax_element<short>/65536                 25310 ns       4325 ns     5.85x
BM_std_minmax_element<short>/70000                 27017 ns       4663 ns     5.79x
BM_std_minmax_element<int>/1                       0.380 ns       1.76 ns     0.22x
BM_std_minmax_element<int>/2                       0.759 ns       1.76 ns     0.43x
BM_std_minmax_element<int>/3                        1.52 ns       2.26 ns     0.67x
BM_std_minmax_element<int>/4                        1.52 ns       3.02 ns     0.50x
BM_std_minmax_element<int>/64                       24.3 ns       10.1 ns     2.41x
BM_std_minmax_element<int>/512                       197 ns       67.2 ns     2.93x
BM_std_minmax_element<int>/1024                      391 ns        132 ns     2.96x
BM_std_minmax_element<int>/4000                     1517 ns        517 ns     2.93x
BM_std_minmax_element<int>/4096                     1556 ns        529 ns     2.94x
BM_std_minmax_element<int>/5500                     2141 ns        708 ns     3.02x
BM_std_minmax_element<int>/64000                   24771 ns       8700 ns     2.85x
BM_std_minmax_element<int>/65536                   25702 ns       9042 ns     2.84x
BM_std_minmax_element<int>/70000                   28174 ns       9715 ns     2.90x
BM_std_minmax_element<long long>/1                 0.392 ns       1.77 ns     0.22x
BM_std_minmax_element<long long>/2                 0.761 ns       2.54 ns     0.30x
BM_std_minmax_element<long long>/3                  1.35 ns       2.38 ns     0.57x
BM_std_minmax_element<long long>/4                  1.47 ns       3.18 ns     0.46x
BM_std_minmax_element<long long>/64                 20.9 ns       22.9 ns     0.91x
BM_std_minmax_element<long long>/512                 167 ns        178 ns     0.94x
BM_std_minmax_element<long long>/1024                332 ns        362 ns     0.92x
BM_std_minmax_element<long long>/4000               1307 ns       1391 ns     0.94x
BM_std_minmax_element<long long>/4096               1299 ns       1417 ns     0.92x
BM_std_minmax_element<long long>/5500               1740 ns       1907 ns     0.91x
BM_std_minmax_element<long long>/64000             20228 ns      22897 ns     0.88x
BM_std_minmax_element<long long>/65536             20709 ns      22715 ns     0.91x
BM_std_minmax_element<long long>/70000             22114 ns      24215 ns     0.91x

@wsehjk wsehjk requested a review from a team as a code owner April 12, 2025 14:47
@llvmbot llvmbot added the libc++ libc++ C++ Standard Library. Not GNU libstdc++. Not libc++abi. label Apr 12, 2025
@llvmbot
Member

llvmbot commented Apr 12, 2025

@llvm/pr-subscribers-libcxx

Author: Leslie (wsehjk)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/135495.diff

1 Files Affected:

  • (modified) libcxx/include/__algorithm/minmax_element.h (+49-3)
diff --git a/libcxx/include/__algorithm/minmax_element.h b/libcxx/include/__algorithm/minmax_element.h
index dc0c3a818cd57..9f6ca60267e42 100644
--- a/libcxx/include/__algorithm/minmax_element.h
+++ b/libcxx/include/__algorithm/minmax_element.h
@@ -15,6 +15,7 @@
 #include <__iterator/iterator_traits.h>
 #include <__type_traits/invoke.h>
 #include <__type_traits/is_callable.h>
+#include <__type_traits/is_integral.h>
 #include <__utility/pair.h>
 
 #if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
@@ -38,9 +39,10 @@ class _MinmaxElementLessFunc {
   }
 };
 
-template <class _Iter, class _Sent, class _Proj, class _Comp>
-_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 pair<_Iter, _Iter>
-__minmax_element_impl(_Iter __first, _Sent __last, _Comp& __comp, _Proj& __proj) {
+template<class _Iter, class _Sent, class _Proj, class _Comp>
+_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 pair<_Iter, _Iter> 
+__minmax_element_loop(_Iter __first, _Sent __last, _Comp& __comp, _Proj& __proj) {
+  __builtin_printf("Debug: __minmax_element_impl called, %d\n", __LINE__);  // no iostream needed
   auto __less = _MinmaxElementLessFunc<_Comp, _Proj>(__comp, __proj);
 
   pair<_Iter, _Iter> __result(__first, __first);
@@ -78,6 +80,50 @@ __minmax_element_impl(_Iter __first, _Sent __last, _Comp& __comp, _Proj& __proj)
   return __result;
 }
 
+
+// template<class _Tp>
+// typename std::iterator_traits<_Iter>::value_type
+// __minmax_element_vectorized(_Tp __first, _Tp __last) {
+
+// }
+
+
+template <class _Iter, class _Proj, class _Comp,
+          __enable_if_t<is_integral_v<typename std::iterator_traits<_Iter>::value_type>
+          && __is_identity<_Proj>::value && __desugars_to_v<__less_tag, _Comp, _Iter, _Iter>,
+          int> = 0
+          >
+_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 pair<_Iter, _Iter>
+__minmax_element_impl(_Iter __first, _Iter __last, _Comp& __comp, _Proj& __proj) {
+  if (__libcpp_is_constant_evaluated()) {
+    return __minmax_element_loop(__first, __last, __comp, __proj);
+  } else {
+
+  }
+}
+
+template <class _Iter, class _Proj, class _Comp,
+          __enable_if_t<!is_integral_v<typename std::iterator_traits<_Iter>::value_type>
+          && __can_map_to_integer_v<typename std::iterator_traits<_Iter>::value_type> 
+          && __libcpp_is_trivially_equality_comparable<typename std::iterator_traits<_Iter>::value_type, typename std::iterator_traits<_Iter>::value_type>::value
+          && __is_identity<_Proj>::value && __desugars_to_v<__less_tag, _Comp, _Iter, _Iter>,
+          int> = 0
+          >
+_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 pair<_Iter, _Iter>
+__minmax_element_impl(_Iter __first, _Iter __last, _Comp& __comp, _Proj& __proj) {
+  if (__libcpp_is_constant_evaluated()) {
+    return __minmax_element_loop(__first, __last, __comp, __proj);
+  } else {
+
+  }
+}
+
+template <class _Iter, class _Sent, class _Proj, class _Comp>
+_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 pair<_Iter, _Iter>
+__minmax_element_impl(_Iter __first, _Sent __last, _Comp& __comp, _Proj& __proj) {
+  return std::__minmax_element_loop(__first, __last, __comp, __proj);
+}
+
 template <class _ForwardIterator, class _Compare>
 [[__nodiscard__]] _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 pair<_ForwardIterator, _ForwardIterator>
 minmax_element(_ForwardIterator __first, _ForwardIterator __last, _Compare __comp) {

@wsehjk wsehjk marked this pull request as draft April 12, 2025 14:48

github-actions bot commented Apr 12, 2025

⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff HEAD~1 HEAD --extensions h -- libcxx/include/__algorithm/minmax_element.h
View the diff from clang-format here.
diff --git a/libcxx/include/__algorithm/minmax_element.h b/libcxx/include/__algorithm/minmax_element.h
index 67287b413..8894e0003 100644
--- a/libcxx/include/__algorithm/minmax_element.h
+++ b/libcxx/include/__algorithm/minmax_element.h
@@ -41,8 +41,8 @@ public:
   }
 };
 
-template<class _Iter, class _Sent, class _Proj, class _Comp>
-_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 pair<_Iter, _Iter> 
+template <class _Iter, class _Sent, class _Proj, class _Comp>
+_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 pair<_Iter, _Iter>
 __minmax_element_loop(_Iter __first, _Sent __last, _Comp& __comp, _Proj& __proj) {
   auto __less = _MinmaxElementLessFunc<_Comp, _Proj>(__comp, __proj);
 
@@ -82,8 +82,8 @@ __minmax_element_loop(_Iter __first, _Sent __last, _Comp& __comp, _Proj& __proj)
 }
 
 #if _LIBCPP_VECTORIZE_ALGORITHMS
-template<class _Iter>
-_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 pair<_Iter, _Iter> 
+template <class _Iter>
+_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 pair<_Iter, _Iter>
 __minmax_element_vectorized(_Iter __first, _Iter __last) {
   using __value_type              = __iter_value_type<_Iter>;
   constexpr size_t __unroll_count = 4;
@@ -100,63 +100,63 @@ __minmax_element_vectorized(_Iter __first, _Iter __last) {
   __value_type __max_element = *__first;
 
   _Iter __min_block_start = __first;
-  _Iter __min_block_end = __last + 1;
+  _Iter __min_block_end   = __last + 1;
   _Iter __max_block_start = __first;
-  _Iter __max_block_end = __last + 1;
-  
-  while(static_cast<size_t>(__last - __first) >= __unroll_count * __vec_size) [[__likely__]]{
+  _Iter __max_block_end   = __last + 1;
+
+  while (static_cast<size_t>(__last - __first) >= __unroll_count * __vec_size) [[__likely__]] {
     __vec_type __vec[__unroll_count];
-    for(size_t __i = 0; __i < __unroll_count; ++__i) {
+    for (size_t __i = 0; __i < __unroll_count; ++__i) {
       __vec[__i] = std::__load_vector<__vec_type>(__first + __i * __vec_size);
       // block min
       auto __block_min_element = __builtin_reduce_min(__vec[__i]);
       if (__block_min_element < __min_element) {
-        __min_element = __block_min_element;
+        __min_element     = __block_min_element;
         __min_block_start = __first + __i * __vec_size;
-        __min_block_end = __first + (__i + 1) * __vec_size;
+        __min_block_end   = __first + (__i + 1) * __vec_size;
       }
       // block max
       auto __block_max_element = __builtin_reduce_max(__vec[__i]);
       if (__block_max_element >= __max_element) {
-        __max_element = __block_max_element;
+        __max_element     = __block_max_element;
         __max_block_start = __first + __i * __vec_size;
-        __max_block_end = __first + (__i + 1) * __vec_size;
+        __max_block_end   = __first + (__i + 1) * __vec_size;
       }
     }
     __first += __unroll_count * __vec_size;
   }
 
-  // remaining vectors 
-  while(static_cast<size_t>(__last - __first) >=  __vec_size) {
-      __vec_type __vec = std::__load_vector<__vec_type>(__first);
-      auto __block_min_element = __builtin_reduce_min(__vec);
-      if (__block_min_element < __min_element) {
-        __min_element = __block_min_element;
-        __min_block_start = __first;
-        __min_block_end = __first + __vec_size;
-      }
-      // max
-      auto __block_max_element = __builtin_reduce_max(__vec);
-      if (__block_max_element >= __max_element) {
-        __max_element = __block_max_element;
-        __max_block_start = __first;
-        __max_block_end = __first + __vec_size;
-      }
-      __first += __vec_size;
+  // remaining vectors
+  while (static_cast<size_t>(__last - __first) >= __vec_size) {
+    __vec_type __vec         = std::__load_vector<__vec_type>(__first);
+    auto __block_min_element = __builtin_reduce_min(__vec);
+    if (__block_min_element < __min_element) {
+      __min_element     = __block_min_element;
+      __min_block_start = __first;
+      __min_block_end   = __first + __vec_size;
+    }
+    // max
+    auto __block_max_element = __builtin_reduce_max(__vec);
+    if (__block_max_element >= __max_element) {
+      __max_element     = __block_max_element;
+      __max_block_start = __first;
+      __max_block_end   = __first + __vec_size;
+    }
+    __first += __vec_size;
   }
 
   if (__last > __first) {
-    auto __epilogue = std::__minmax_element_loop(__first, __last, __comp, __proj);
+    auto __epilogue                     = std::__minmax_element_loop(__first, __last, __comp, __proj);
     __value_type __epilogue_min_element = *__epilogue.first;
     __value_type __epilogue_max_element = *__epilogue.second;
     if (__epilogue_min_element < __min_element && __epilogue_max_element >= __max_element) {
       return __epilogue;
     } else if (__epilogue_min_element < __min_element) {
-      __min_element = __epilogue_min_element;
+      __min_element     = __epilogue_min_element;
       __min_block_start = __epilogue.first;
       __min_block_end   = __epilogue.first; // this is global min_element
     } else if (__epilogue_max_element >= __max_element) {
-      __max_element = __epilogue_max_element;
+      __max_element     = __epilogue_max_element;
       __max_block_start = __epilogue.second;
       __max_block_end   = __epilogue.second; // this is global max_element
     }
@@ -179,14 +179,13 @@ __minmax_element_vectorized(_Iter __first, _Iter __last) {
   return {__min_block_start, __max_block_start};
 }
 
-template <class _Iter, class _Proj, class _Comp,
-          __enable_if_t
-          <is_integral_v<__iter_value_type<_Iter>>
-          && is_same_v<__iterator_category_type<_Iter>, random_access_iterator_tag>
-          && __is_identity<_Proj>::value
-          && __desugars_to_v<__less_tag, _Comp, _Iter, _Iter>,
-          int> = 0
-          >
+template <class _Iter,
+          class _Proj,
+          class _Comp,
+          __enable_if_t<is_integral_v<__iter_value_type<_Iter>> &&
+                            is_same_v<__iterator_category_type<_Iter>, random_access_iterator_tag> &&
+                            __is_identity<_Proj>::value && __desugars_to_v<__less_tag, _Comp, _Iter, _Iter>,
+                        int> = 0 >
 _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 pair<_Iter, _Iter>
 __minmax_element_impl(_Iter __first, _Iter __last, _Comp& __comp, _Proj& __proj) {
   if (__libcpp_is_constant_evaluated()) {
@@ -200,9 +199,9 @@ __minmax_element_impl(_Iter __first, _Iter __last, _Comp& __comp, _Proj& __proj)
 //           __enable_if_t
 //           <!is_integral_v<__iter_value_type<_Iter>>
 //           && is_same_v<__iterator_category_type<_Iter>, random_access_iterator_tag>
-//           && __can_map_to_integer_v<__iter_value_type<_Iter>> 
+//           && __can_map_to_integer_v<__iter_value_type<_Iter>>
 //           && __libcpp_is_trivially_equality_comparable<__iter_value_type<_Iter>, __iter_value_type<_Iter>>::value
-//           && __is_identity<_Proj>::value 
+//           && __is_identity<_Proj>::value
 //           && __desugars_to_v<__less_tag, _Comp, _Iter, _Iter>,
 //           int> = 0
 //         >

@wsehjk wsehjk marked this pull request as ready for review April 13, 2025 17:08
@wsehjk
Contributor Author

wsehjk commented Apr 19, 2025

Hi @hiraditya @philnik777, could you please review my code? Thanks

@hiraditya
Collaborator

hiraditya commented Apr 20, 2025

It seems you are getting a speedup on large sizes (>64, etc.). Maybe use the default algorithm for small sizes and switch to the new implementation otherwise?
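A minimal sketch of such a size-based dispatch, for illustration only. The names `kSmallInputThreshold`, `minmax_scalar`, `minmax_large`, and `minmax_dispatch` are hypothetical and not part of libc++; the cutoff of 64 is an assumption taken from the first benchmark size that shows a speedup, and `minmax_large` is only a stand-in for the vectorized path.

```cpp
#include <algorithm>
#include <cstddef>
#include <utility>

// Hypothetical cutoff; a real value would have to come from benchmarking.
constexpr std::size_t kSmallInputThreshold = 64;

// Plain scalar loop: returns the first minimum and the last maximum,
// which is what std::minmax_element returns.
inline std::pair<const int*, const int*> minmax_scalar(const int* first, const int* last) {
  std::pair<const int*, const int*> result(first, first);
  if (first == last)
    return result;
  for (const int* it = first + 1; it != last; ++it) {
    if (*it < *result.first)
      result.first = it; // strict comparison keeps the first minimum
    if (!(*it < *result.second))
      result.second = it; // non-strict comparison keeps the last maximum
  }
  return result;
}

// Stand-in for the vectorized implementation (assumption: it would return
// the same iterators as std::minmax_element).
inline std::pair<const int*, const int*> minmax_large(const int* first, const int* last) {
  return std::minmax_element(first, last);
}

// Dispatch on input size: short ranges avoid the vector setup cost.
inline std::pair<const int*, const int*> minmax_dispatch(const int* first, const int* last) {
  if (static_cast<std::size_t>(last - first) < kSmallInputThreshold)
    return minmax_scalar(first, last);
  return minmax_large(first, last);
}
```

Whether a branch like this pays off depends on how the check itself affects the smallest inputs, which the benchmark above does not measure.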

@wsehjk
Contributor Author

wsehjk commented Apr 23, 2025

> seems like you are getting speedup on large sizes (>64 etc). Maybe use the default algorithm for small sizes and switch to new implementations otherwise?

No, the speedups for char, short, int, and long long are about 10x, 6x, 3x, and 0.91x respectively. The vector size for each integer type is not fixed; it is determined by the platform. Please check simd_utils.h. I'm testing on macOS, which only supports SSE. On x86 platforms the speedup could be better, but I haven't had time to test it.
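The per-type pattern in the benchmark would be consistent with the number of elements that fit into one vector register. A small sketch, assuming 128-bit (16-byte) vectors, which is the width SSE provides; `kVectorBytes` and `lanes_per_vector` are illustrative names and not libc++ APIs.

```cpp
#include <cstddef>
#include <cstdint>

// Assumption: 16-byte vector registers. Wider registers would give
// proportionally more lanes per vector.
constexpr std::size_t kVectorBytes = 16;

// Number of elements of type T processed by one vector operation.
template <class T>
constexpr std::size_t lanes_per_vector() {
  return kVectorBytes / sizeof(T);
}

static_assert(lanes_per_vector<std::int8_t>() == 16, "1-byte elements: 16 lanes");
static_assert(lanes_per_vector<std::int16_t>() == 8, "2-byte elements: 8 lanes");
static_assert(lanes_per_vector<std::int32_t>() == 4, "4-byte elements: 4 lanes");
static_assert(lanes_per_vector<std::int64_t>() == 2, "8-byte elements: 2 lanes");
```

With only two lanes for 8-byte elements, there is little parallelism left to offset the per-block reduction and bookkeeping, which may explain why the long long rows show no speedup.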

@philnik777
Contributor

I think the main problem is that you're currently reducing in every single iteration. If we search for the minimum and maximum element by line instead, I think the performance would be significantly better, since we'd be able to reduce only once in the end instead.
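One possible reading of this suggestion, taking "by line" to mean per vector lane, is sketched below. It is not the PR's code. It simulates vector lanes with `std::array` so that it compiles anywhere; `kLanes` and `minmax_index_lanewise` are hypothetical names. The loop keeps one running minimum and maximum per lane, performs the horizontal reduction once after the loop, and then recovers the positions in a second pass.

```cpp
#include <algorithm>
#include <array>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical lane count; a real implementation would use the native vector width.
constexpr std::size_t kLanes = 8;

// Returns {index of the first minimum, index of the last maximum},
// the same elements std::minmax_element selects.
// Precondition: data is non-empty.
inline std::pair<std::size_t, std::size_t> minmax_index_lanewise(const std::vector<int>& data) {
  std::array<int, kLanes> lane_min;
  std::array<int, kLanes> lane_max;
  lane_min.fill(data[0]);
  lane_max.fill(data[0]);

  std::size_t i = 0;
  // Main loop: element-wise min/max per lane, with no horizontal reduction.
  for (; i + kLanes <= data.size(); i += kLanes) {
    for (std::size_t lane = 0; lane < kLanes; ++lane) {
      lane_min[lane] = std::min(lane_min[lane], data[i + lane]);
      lane_max[lane] = std::max(lane_max[lane], data[i + lane]);
    }
  }

  // Single horizontal reduction after the loop.
  int min_value = *std::min_element(lane_min.begin(), lane_min.end());
  int max_value = *std::max_element(lane_max.begin(), lane_max.end());

  // Scalar tail for the elements that do not fill a whole vector.
  for (; i < data.size(); ++i) {
    min_value = std::min(min_value, data[i]);
    max_value = std::max(max_value, data[i]);
  }

  // Second pass to recover positions: first occurrence of the minimum,
  // last occurrence of the maximum.
  const std::size_t min_pos =
      static_cast<std::size_t>(std::find(data.begin(), data.end(), min_value) - data.begin());
  const std::size_t from_back =
      static_cast<std::size_t>(std::find(data.rbegin(), data.rend(), max_value) - data.rbegin());
  return {min_pos, data.size() - 1 - from_back};
}
```

The second pass costs extra reads, so this is a trade-off and not a guaranteed win. An alternative would track a per-lane index alongside each per-lane value and avoid the second pass.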

@wsehjk
Contributor Author

wsehjk commented Apr 27, 2025

> I think the main problem is that you're currently reducing in every single iteration. If we search for the minimum and maximum element by line instead, I think the performance would be significantly better, since we'd be able to reduce only once in the end instead.

Hi, I don't quite get it. I'm reducing in every block to get the block_min_element and its block position, but how could you search the minimum by line and reduce only once? Can you elaborate a little bit?

@ldionne ldionne changed the title optimize minmax_element [libc++] Optimize std::minmax_element May 6, 2025
Labels
libc++ libc++ C++ Standard Library. Not GNU libstdc++. Not libc++abi. performance
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Vectorize minmax_element.
5 participants