Skip to content

Commit b675d2c

Browse files
Replace detail::merge::dispatch by CUB's public API
This includes using the new tuning API internally where possible. Fixes: #7955
1 parent 2f55b2e commit b675d2c

File tree

10 files changed

+214
-252
lines changed

10 files changed

+214
-252
lines changed

cub/benchmarks/bench/merge/keys.cu

Lines changed: 20 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -18,24 +18,10 @@
1818
// %RANGE% TUNE_ITEMS_PER_THREAD ipt 7:24:1
1919
// %RANGE% TUNE_THREADS_PER_BLOCK_POW2 tpb 6:10:1
2020

21-
#if !TUNE_BASE
22-
struct bench_policy_selector
23-
{
24-
_CCCL_API constexpr auto operator()(::cuda::arch_id /*arch*/) const -> cub::detail::merge::merge_policy
25-
{
26-
return cub::detail::merge::merge_policy{
27-
TUNE_THREADS_PER_BLOCK,
28-
cub::Nominal4BItemsToItems<KeyT>(TUNE_ITEMS_PER_THREAD),
29-
TUNE_LOAD_MODIFIER,
30-
TUNE_STORE_ALGORITHM,
31-
TUNE_USE_BL2SH};
32-
}
33-
};
34-
#endif // !TUNE_BASE
35-
36-
template <typename KeyT, typename OffsetT>
37-
void keys(nvbench::state& state, nvbench::type_list<KeyT, OffsetT>)
21+
template <typename KeyT>
22+
void keys(nvbench::state& state, nvbench::type_list<KeyT>)
3823
{
24+
using offset_t = int64_t;
3925
using compare_op_t = less_t;
4026

4127
// Retrieve axis parameters
@@ -46,7 +32,7 @@ void keys(nvbench::state& state, nvbench::type_list<KeyT, OffsetT>)
4632
const auto num_items_rhs = elements - num_items_lhs;
4733
auto [keys_lhs, keys_rhs] = generate_lhs_rhs<KeyT>(num_items_lhs, num_items_rhs, entropy);
4834

49-
thrust::device_vector<KeyT> keys_out(elements);
35+
thrust::device_vector<KeyT> keys_out(elements, thrust::no_init);
5036
KeyT* d_keys_lhs = thrust::raw_pointer_cast(keys_lhs.data());
5137
KeyT* d_keys_rhs = thrust::raw_pointer_cast(keys_rhs.data());
5238
KeyT* d_keys_out = thrust::raw_pointer_cast(keys_out.data());
@@ -56,51 +42,26 @@ void keys(nvbench::state& state, nvbench::type_list<KeyT, OffsetT>)
5642
state.add_global_memory_reads<KeyT>(elements);
5743
state.add_global_memory_writes<KeyT>(elements);
5844

59-
auto value_nullptr = static_cast<cub::NullType*>(nullptr);
60-
61-
// Allocate temporary storage:
62-
std::size_t temp_size{};
63-
cub::detail::merge::dispatch(
64-
nullptr,
65-
temp_size,
66-
d_keys_lhs,
67-
value_nullptr,
68-
static_cast<OffsetT>(num_items_lhs),
69-
d_keys_rhs,
70-
value_nullptr,
71-
static_cast<OffsetT>(num_items_rhs),
72-
d_keys_out,
73-
value_nullptr,
74-
compare_op_t{},
75-
cudaStream_t{}
45+
caching_allocator_t alloc;
46+
state.exec(nvbench::exec_tag::gpu | nvbench::exec_tag::no_batch, [&](nvbench::launch& launch) {
47+
auto env = cub_bench_env(
48+
alloc,
49+
launch
7650
#if !TUNE_BASE
77-
,
78-
bench_policy_selector{}
51+
,
52+
cuda::execution::__tune(policy_selector<key_t, value_t, offset_t>{})
7953
#endif // !TUNE_BASE
80-
);
81-
82-
thrust::device_vector<nvbench::uint8_t> temp(temp_size);
83-
auto* temp_storage = thrust::raw_pointer_cast(temp.data());
84-
85-
state.exec(nvbench::exec_tag::gpu | nvbench::exec_tag::no_batch, [&](nvbench::launch& launch) {
86-
cub::detail::merge::dispatch(
87-
temp_storage,
88-
temp_size,
54+
);
55+
_CCCL_TRY_CUDA_API(
56+
cub::DeviceMerge::MergeKeys,
57+
"MergePairs failed",
8958
d_keys_lhs,
90-
value_nullptr,
91-
static_cast<OffsetT>(num_items_lhs),
59+
static_cast<offset_t>(num_items_lhs),
9260
d_keys_rhs,
93-
value_nullptr,
94-
static_cast<OffsetT>(num_items_rhs),
61+
static_cast<offset_t>(num_items_rhs),
9562
d_keys_out,
96-
value_nullptr,
9763
compare_op_t{},
98-
launch.get_stream()
99-
#if !TUNE_BASE
100-
,
101-
bench_policy_selector{}
102-
#endif // !TUNE_BASE
103-
);
64+
env);
10465
});
10566
}
10667

@@ -110,8 +71,8 @@ using key_types = nvbench::type_list<TUNE_KeyT>;
11071
using key_types = fundamental_types;
11172
#endif // TUNE_KeyT
11273

113-
NVBENCH_BENCH_TYPES(keys, NVBENCH_TYPE_AXES(key_types, offset_types))
74+
NVBENCH_BENCH_TYPES(keys, NVBENCH_TYPE_AXES(key_types))
11475
.set_name("base")
115-
.set_type_axes_names({"KeyT{ct}", "OffsetT{ct}"})
76+
.set_type_axes_names({"KeyT{ct}"})
11677
.add_int64_power_of_two_axis("Elements{io}", nvbench::range(16, 28, 4))
11778
.add_string_axis("Entropy", {"1.000", "0.201"});

cub/benchmarks/bench/merge/merge_common.cuh

Lines changed: 9 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,6 @@
1313
#include <nvbench_helper.cuh>
1414

1515
#if !TUNE_BASE
16-
# define TUNE_THREADS_PER_BLOCK (1 << TUNE_THREADS_PER_BLOCK_POW2)
17-
# if TUNE_TRANSPOSE == 0
18-
# define TUNE_STORE_ALGORITHM cub::BLOCK_STORE_DIRECT
19-
# else // TUNE_TRANSPOSE == 1
20-
# define TUNE_STORE_ALGORITHM cub::BLOCK_STORE_WARP_TRANSPOSE
21-
# endif // TUNE_TRANSPOSE
22-
2316
# if TUNE_LOAD == 0
2417
# define TUNE_LOAD_MODIFIER cub::LOAD_DEFAULT
2518
# define TUNE_USE_BL2SH false
@@ -35,19 +28,17 @@
3528
# endif // TUNE_LOAD
3629

3730
template <typename KeyT>
38-
struct policy_hub_t
31+
struct bench_policy_selector
3932
{
40-
struct policy_t : cub::ChainedPolicy<500, policy_t, policy_t>
33+
_CCCL_API constexpr auto operator()(::cuda::arch_id /*arch*/) const -> cub::detail::merge::merge_policy
4134
{
42-
using merge_policy =
43-
cub::agent_policy_t<TUNE_THREADS_PER_BLOCK,
44-
cub::Nominal4BItemsToItems<KeyT>(TUNE_ITEMS_PER_THREAD),
45-
TUNE_LOAD_MODIFIER,
46-
TUNE_STORE_ALGORITHM,
47-
TUNE_USE_BL2SH>;
48-
};
49-
50-
using MaxPolicy = policy_t;
35+
return cub::detail::merge::merge_policy{
36+
(1 << TUNE_THREADS_PER_BLOCK_POW2),
37+
cub::Nominal4BItemsToItems<KeyT>(TUNE_ITEMS_PER_THREAD),
38+
TUNE_LOAD_MODIFIER,
39+
TUNE_TRANSPOSE == 0 ? cub::BLOCK_STORE_DIRECT : cub::BLOCK_STORE_WARP_TRANSPOSE,
40+
TUNE_USE_BL2SH};
41+
}
5142
};
5243
#endif // TUNE_BASE
5344

cub/benchmarks/bench/merge/pairs.cu

Lines changed: 23 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -18,24 +18,10 @@
1818
// %RANGE% TUNE_ITEMS_PER_THREAD ipt 7:24:1
1919
// %RANGE% TUNE_THREADS_PER_BLOCK_POW2 tpb 6:10:1
2020

21-
#if !TUNE_BASE
22-
struct bench_policy_selector
23-
{
24-
_CCCL_API constexpr auto operator()(::cuda::arch_id /*arch*/) const -> cub::detail::merge::merge_policy
25-
{
26-
return cub::detail::merge::merge_policy{
27-
TUNE_THREADS_PER_BLOCK,
28-
cub::Nominal4BItemsToItems<KeyT>(TUNE_ITEMS_PER_THREAD),
29-
TUNE_LOAD_MODIFIER,
30-
TUNE_STORE_ALGORITHM,
31-
TUNE_USE_BL2SH};
32-
}
33-
};
34-
#endif // !TUNE_BASE
35-
36-
template <typename KeyT, typename ValueT, typename OffsetT>
37-
void pairs(nvbench::state& state, nvbench::type_list<KeyT, ValueT, OffsetT>)
21+
template <typename KeyT, typename ValueT>
22+
void pairs(nvbench::state& state, nvbench::type_list<KeyT, ValueT>)
3823
{
24+
using offset_t = int64_t;
3925
using compare_op_t = less_t;
4026

4127
// Retrieve axis parameters
@@ -45,10 +31,10 @@ void pairs(nvbench::state& state, nvbench::type_list<KeyT, ValueT, OffsetT>)
4531
const auto num_items_lhs = elements / 2;
4632
const auto num_items_rhs = elements - num_items_lhs;
4733

48-
thrust::device_vector<KeyT> keys_out(elements);
49-
thrust::device_vector<ValueT> values_lhs(num_items_lhs);
50-
thrust::device_vector<ValueT> values_rhs(num_items_rhs);
51-
thrust::device_vector<ValueT> values_out(elements);
34+
thrust::device_vector<KeyT> keys_out(elements, thrust::no_init);
35+
thrust::device_vector<ValueT> values_lhs(num_items_lhs, thrust::no_init);
36+
thrust::device_vector<ValueT> values_rhs(num_items_rhs, thrust::no_init);
37+
thrust::device_vector<ValueT> values_out(elements, thrust::no_init);
5238

5339
auto [keys_lhs, keys_rhs] = generate_lhs_rhs<KeyT>(num_items_lhs, num_items_rhs, entropy);
5440

@@ -66,49 +52,29 @@ void pairs(nvbench::state& state, nvbench::type_list<KeyT, ValueT, OffsetT>)
6652
state.add_global_memory_writes<KeyT>(elements);
6753
state.add_global_memory_writes<ValueT>(elements);
6854

69-
// Allocate temporary storage:
70-
std::size_t temp_size{};
71-
cub::detail::merge::dispatch(
72-
nullptr,
73-
temp_size,
74-
d_keys_lhs,
75-
d_values_lhs,
76-
static_cast<OffsetT>(num_items_lhs),
77-
d_keys_rhs,
78-
d_values_rhs,
79-
static_cast<OffsetT>(num_items_rhs),
80-
d_keys_out,
81-
d_values_out,
82-
compare_op_t{},
83-
cudaStream_t{}
55+
caching_allocator_t alloc;
56+
state.exec(nvbench::exec_tag::gpu | nvbench::exec_tag::no_batch, [&](nvbench::launch& launch) {
57+
auto env = cub_bench_env(
58+
alloc,
59+
launch
8460
#if !TUNE_BASE
85-
,
86-
bench_policy_selector{}
61+
,
62+
cuda::execution::__tune(policy_selector<key_t, value_t, offset_t>{})
8763
#endif // !TUNE_BASE
88-
);
89-
90-
thrust::device_vector<nvbench::uint8_t> temp(temp_size);
91-
auto* temp_storage = thrust::raw_pointer_cast(temp.data());
92-
93-
state.exec(nvbench::exec_tag::gpu | nvbench::exec_tag::no_batch, [&](nvbench::launch& launch) {
94-
cub::detail::merge::dispatch(
95-
temp_storage,
96-
temp_size,
64+
);
65+
_CCCL_TRY_CUDA_API(
66+
cub::DeviceMerge::MergePairs,
67+
"MergePairs failed",
9768
d_keys_lhs,
9869
d_values_lhs,
99-
static_cast<OffsetT>(num_items_lhs),
70+
static_cast<offset_t>(num_items_lhs),
10071
d_keys_rhs,
10172
d_values_rhs,
102-
static_cast<OffsetT>(num_items_rhs),
73+
static_cast<offset_t>(num_items_rhs),
10374
d_keys_out,
10475
d_values_out,
10576
compare_op_t{},
106-
launch.get_stream()
107-
#if !TUNE_BASE
108-
,
109-
bench_policy_selector{}
110-
#endif // !TUNE_BASE
111-
);
77+
env);
11278
});
11379
}
11480

@@ -130,8 +96,8 @@ using value_types = nvbench::type_list<int8_t, int16_t, int32_t, int64_t
13096
>;
13197
#endif // TUNE_ValueT
13298

133-
NVBENCH_BENCH_TYPES(pairs, NVBENCH_TYPE_AXES(key_types, value_types, offset_types))
99+
NVBENCH_BENCH_TYPES(pairs, NVBENCH_TYPE_AXES(key_types, value_types))
134100
.set_name("base")
135-
.set_type_axes_names({"KeyT{ct}", "ValueT{ct}", "OffsetT{ct}"})
101+
.set_type_axes_names({"KeyT{ct}", "ValueT{ct}"})
136102
.add_int64_power_of_two_axis("Elements{io}", nvbench::range(16, 28, 4))
137103
.add_string_axis("Entropy", {"1.000", "0.201"});

cub/cub/detail/env_dispatch.cuh

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,18 @@ CUB_RUNTIME_FUNCTION static cudaError_t dispatch_with_env(EnvT env, AlgorithmCal
7474
return (error != cudaSuccess) ? error : deallocate_error;
7575
}
7676
//! @endcond
77+
78+
template <typename DefaultPolicySelector, typename EnvT, typename AlgorithmCallable>
79+
CUB_RUNTIME_FUNCTION static cudaError_t dispatch_with_env_and_tuning(EnvT env, AlgorithmCallable&& algorithm_callable)
80+
{
81+
return detail::dispatch_with_env(
82+
env, [&]([[maybe_unused]] auto tuning_env, void* d_temp_storage, size_t& temp_storage_bytes, cudaStream_t stream) {
83+
using policy_t = decltype(DefaultPolicySelector{}(::cuda::arch_id{}));
84+
using policy_selector =
85+
::cuda::std::execution::__query_result_or_t<decltype(tuning_env), policy_t, DefaultPolicySelector>;
86+
return algorithm_callable(policy_selector{}, d_temp_storage, temp_storage_bytes, stream);
87+
});
88+
}
7789
} // namespace detail
7890

7991
CUB_NAMESPACE_END

cub/cub/device/device_merge.cuh

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -192,8 +192,10 @@ struct DeviceMerge
192192
{
193193
_CCCL_NVTX_RANGE_SCOPE("cub::DeviceMerge::MergeKeys");
194194

195-
return detail::dispatch_with_env(
196-
env, [&]([[maybe_unused]] auto tuning, void* d_temp_storage, size_t& temp_storage_bytes, cudaStream_t stream) {
195+
using default_policy_selector =
196+
detail::merge::policy_selector_from_types<detail::it_value_t<KeyIteratorIn1>, NullType, int64_t>;
197+
return detail::dispatch_with_env_and_tuning<default_policy_selector>(
198+
env, [&](auto policy_selector, void* d_temp_storage, size_t& temp_storage_bytes, cudaStream_t stream) {
197199
return detail::merge::dispatch(
198200
d_temp_storage,
199201
temp_storage_bytes,
@@ -206,7 +208,8 @@ struct DeviceMerge
206208
keys_out,
207209
static_cast<NullType*>(nullptr),
208210
compare_op,
209-
stream);
211+
stream,
212+
policy_selector);
210213
});
211214
}
212215

@@ -413,9 +416,10 @@ struct DeviceMerge
413416
EnvT env = {})
414417
{
415418
_CCCL_NVTX_RANGE_SCOPE("cub::DeviceMerge::MergePairs");
416-
417-
return detail::dispatch_with_env(
418-
env, [&]([[maybe_unused]] auto tuning, void* d_temp_storage, size_t& temp_storage_bytes, cudaStream_t stream) {
419+
using default_policy_selector = detail::merge::
420+
policy_selector_from_types<detail::it_value_t<KeyIteratorIn1>, detail::it_value_t<ValueIteratorIn1>, int64_t>;
421+
return detail::dispatch_with_env_and_tuning<default_policy_selector>(
422+
env, [&](auto policy_selector, void* d_temp_storage, size_t& temp_storage_bytes, cudaStream_t stream) {
419423
return detail::merge::dispatch(
420424
d_temp_storage,
421425
temp_storage_bytes,
@@ -428,7 +432,8 @@ struct DeviceMerge
428432
keys_out,
429433
values_out,
430434
compare_op,
431-
stream);
435+
stream,
436+
policy_selector);
432437
});
433438
}
434439
};

0 commit comments

Comments
 (0)