@@ -445,8 +445,9 @@ namespace __reduce_by_key {
445
445
{
446
446
if (segment_flags[ITEM])
447
447
{
448
- storage.raw_exchange [segment_indices[ITEM] -
449
- num_tile_segments_prefix] = scatter_items[ITEM];
448
+ int idx = static_cast <int >(segment_indices[ITEM] -
449
+ num_tile_segments_prefix);
450
+ storage.raw_exchange [idx] = scatter_items[ITEM];
450
451
}
451
452
}
452
453
@@ -786,7 +787,7 @@ namespace __reduce_by_key {
786
787
// so just assign one tile per block
787
788
//
788
789
int tile_idx = blockIdx.x ;
789
- Size tile_offset = tile_idx * ITEMS_PER_TILE;
790
+ Size tile_offset = static_cast < Size >( tile_idx) * ITEMS_PER_TILE;
790
791
Size num_remaining = num_items - tile_offset;
791
792
792
793
if (num_remaining > ITEMS_PER_TILE)
@@ -962,7 +963,8 @@ namespace __reduce_by_key {
962
963
return status;
963
964
}
964
965
965
- template <typename Derived,
966
+ template <typename Size ,
967
+ typename Derived,
966
968
typename KeysInputIt,
967
969
typename ValuesInputIt,
968
970
typename KeysOutputIt,
@@ -971,24 +973,23 @@ namespace __reduce_by_key {
971
973
typename ReductionOp>
972
974
THRUST_RUNTIME_FUNCTION
973
975
pair<KeysOutputIt, ValuesOutputIt>
974
- reduce_by_key (execution_policy<Derived>& policy,
975
- KeysInputIt keys_first,
976
- KeysInputIt keys_last ,
977
- ValuesInputIt values_first,
978
- KeysOutputIt keys_output,
979
- ValuesOutputIt values_output,
980
- EqualityOp equality_op,
981
- ReductionOp reduction_op)
976
+ reduce_by_key_dispatch (execution_policy<Derived>& policy,
977
+ KeysInputIt keys_first,
978
+ Size num_items ,
979
+ ValuesInputIt values_first,
980
+ KeysOutputIt keys_output,
981
+ ValuesOutputIt values_output,
982
+ EqualityOp equality_op,
983
+ ReductionOp reduction_op)
982
984
{
983
- typedef int size_type;
984
-
985
- size_type num_items = static_cast <size_type>(thrust::distance (keys_first, keys_last));
986
985
size_t temp_storage_bytes = 0 ;
987
986
cudaStream_t stream = cuda_cub::stream (policy);
988
987
bool debug_sync = THRUST_DEBUG_SYNC_FLAG;
989
988
990
989
if (num_items == 0 )
990
+ {
991
991
return thrust::make_pair (keys_output, values_output);
992
+ }
992
993
993
994
cudaError_t status;
994
995
status = doit_step (NULL ,
@@ -997,15 +998,15 @@ namespace __reduce_by_key {
997
998
values_first,
998
999
keys_output,
999
1000
values_output,
1000
- reinterpret_cast <size_type *>(NULL ),
1001
+ reinterpret_cast <Size *>(NULL ),
1001
1002
equality_op,
1002
1003
reduction_op,
1003
1004
num_items,
1004
1005
stream,
1005
1006
debug_sync);
1006
1007
cuda_cub::throw_on_error (status, " reduce_by_key failed on 1st step" );
1007
1008
1008
- size_t allocation_sizes[2 ] = {sizeof (size_type ), temp_storage_bytes};
1009
+ size_t allocation_sizes[2 ] = {sizeof (Size ), temp_storage_bytes};
1009
1010
void * allocations[2 ] = {NULL , NULL };
1010
1011
1011
1012
size_t storage_size = 0 ;
@@ -1026,8 +1027,8 @@ namespace __reduce_by_key {
1026
1027
allocation_sizes);
1027
1028
cuda_cub::throw_on_error (status, " reduce failed on 2nd alias_storage" );
1028
1029
1029
- size_type * d_num_runs_out
1030
- = thrust::detail::aligned_reinterpret_cast<size_type *>(allocations[0 ]);
1030
+ Size * d_num_runs_out
1031
+ = thrust::detail::aligned_reinterpret_cast<Size *>(allocations[0 ]);
1031
1032
1032
1033
status = doit_step (allocations[1 ],
1033
1034
temp_storage_bytes,
@@ -1054,6 +1055,49 @@ namespace __reduce_by_key {
1054
1055
);
1055
1056
}
1056
1057
1058
+ template <typename Derived,
1059
+ typename KeysInputIt,
1060
+ typename ValuesInputIt,
1061
+ typename KeysOutputIt,
1062
+ typename ValuesOutputIt,
1063
+ typename EqualityOp,
1064
+ typename ReductionOp>
1065
+ THRUST_RUNTIME_FUNCTION
1066
+ pair<KeysOutputIt, ValuesOutputIt>
1067
+ reduce_by_key (execution_policy<Derived>& policy,
1068
+ KeysInputIt keys_first,
1069
+ KeysInputIt keys_last,
1070
+ ValuesInputIt values_first,
1071
+ KeysOutputIt keys_output,
1072
+ ValuesOutputIt values_output,
1073
+ EqualityOp equality_op,
1074
+ ReductionOp reduction_op)
1075
+ {
1076
+ using size_type = typename iterator_traits<KeysInputIt>::difference_type;
1077
+
1078
+ size_type num_items = thrust::distance (keys_first, keys_last);
1079
+
1080
+ if (num_items == 0 )
1081
+ {
1082
+ return thrust::make_pair (keys_output, values_output);
1083
+ }
1084
+
1085
+ pair<KeysOutputIt, ValuesOutputIt> result{};
1086
+ THRUST_INDEX_TYPE_DISPATCH (result,
1087
+ reduce_by_key_dispatch,
1088
+ num_items,
1089
+ (policy,
1090
+ keys_first,
1091
+ num_items_fixed,
1092
+ values_first,
1093
+ keys_output,
1094
+ values_output,
1095
+ equality_op,
1096
+ reduction_op));
1097
+
1098
+ return result;
1099
+ }
1100
+
1057
1101
} // namespace __reduce_by_key
1058
1102
1059
1103
// -------------------------
0 commit comments