Skip to content
This repository was archived by the owner on Mar 21, 2024. It is now read-only.

Commit f3c28da

Browse files
authored
Merge pull request #1721 from senior-zero/fix-main/github/scan_intermediate_type
Fixing scan accumulator types for NVIDIA/cub#511
2 parents f2ba086 + 20ba21c commit f3c28da

File tree

4 files changed

+27
-12
lines changed

4 files changed

+27
-12
lines changed

thrust/system/cuda/detail/async/exclusive_scan.h

+4 -2
Original file line number | Diff line number | Diff line change
@@ -79,12 +79,14 @@ async_exclusive_scan_n(execution_policy<DerivedPolicy>& policy,
7979
OutputIt,
8080
BinaryOp,
8181
InputValueT,
82-
thrust::detail::int32_t>;
82+
thrust::detail::int32_t,
83+
InitialValueType>;
8384
using Dispatch64 = cub::DispatchScan<ForwardIt,
8485
OutputIt,
8586
BinaryOp,
8687
InputValueT,
87-
thrust::detail::int64_t>;
88+
thrust::detail::int64_t,
89+
InitialValueType>;
8890

8991
InputValueT init_value(init);
9092

thrust/system/cuda/detail/async/inclusive_scan.h

+5 -2
Original file line number | Diff line number | Diff line change
@@ -72,16 +72,19 @@ async_inclusive_scan_n(execution_policy<DerivedPolicy>& policy,
7272
OutputIt out,
7373
BinaryOp op)
7474
{
75+
using AccumT = typename thrust::iterator_traits<ForwardIt>::value_type;
7576
using Dispatch32 = cub::DispatchScan<ForwardIt,
7677
OutputIt,
7778
BinaryOp,
7879
cub::NullType,
79-
thrust::detail::int32_t>;
80+
thrust::detail::int32_t,
81+
AccumT>;
8082
using Dispatch64 = cub::DispatchScan<ForwardIt,
8183
OutputIt,
8284
BinaryOp,
8385
cub::NullType,
84-
thrust::detail::int64_t>;
86+
thrust::detail::int64_t,
87+
AccumT>;
8588

8689
auto const device_alloc = get_async_device_allocator(policy);
8790
unique_eager_event ev;

thrust/system/cuda/detail/scan.h

+9 -4
Original file line number | Diff line number | Diff line change
@@ -60,16 +60,19 @@ OutputIt inclusive_scan_n_impl(thrust::cuda_cub::execution_policy<Derived> &poli
6060
OutputIt result,
6161
ScanOp scan_op)
6262
{
63+
using AccumT = typename thrust::iterator_traits<InputIt>::value_type;
6364
using Dispatch32 = cub::DispatchScan<InputIt,
6465
OutputIt,
6566
ScanOp,
6667
cub::NullType,
67-
thrust::detail::int32_t>;
68+
thrust::detail::int32_t,
69+
AccumT>;
6870
using Dispatch64 = cub::DispatchScan<InputIt,
6971
OutputIt,
7072
ScanOp,
7173
cub::NullType,
72-
thrust::detail::int64_t>;
74+
thrust::detail::int64_t,
75+
AccumT>;
7376

7477
cudaStream_t stream = thrust::cuda_cub::stream(policy);
7578
cudaError_t status;
@@ -141,12 +144,14 @@ OutputIt exclusive_scan_n_impl(thrust::cuda_cub::execution_policy<Derived> &poli
141144
OutputIt,
142145
ScanOp,
143146
InputValueT,
144-
thrust::detail::int32_t>;
147+
thrust::detail::int32_t,
148+
InitValueT>;
145149
using Dispatch64 = cub::DispatchScan<InputIt,
146150
OutputIt,
147151
ScanOp,
148152
InputValueT,
149-
thrust::detail::int64_t>;
153+
thrust::detail::int64_t,
154+
InitValueT>;
150155

151156
cudaStream_t stream = thrust::cuda_cub::stream(policy);
152157
cudaError_t status;

thrust/system/cuda/detail/scan_by_key.h

+9 -4
Original file line number | Diff line number | Diff line change
@@ -87,6 +87,7 @@ ValuesOutIt inclusive_scan_by_key_n(
8787
thrust::detail::try_unwrap_contiguous_iterator_return_t<ValuesInIt>;
8888
using ValuesOutUnwrapIt =
8989
thrust::detail::try_unwrap_contiguous_iterator_return_t<ValuesOutIt>;
90+
using AccumT = typename thrust::iterator_traits<ValuesInUnwrapIt>::value_type;
9091

9192
auto keys_unwrap = thrust::detail::try_unwrap_contiguous_iterator(keys);
9293
auto values_unwrap = thrust::detail::try_unwrap_contiguous_iterator(values);
@@ -98,14 +99,16 @@ ValuesOutIt inclusive_scan_by_key_n(
9899
EqualityOpT,
99100
ScanOpT,
100101
cub::NullType,
101-
thrust::detail::int32_t>;
102+
thrust::detail::int32_t,
103+
AccumT>;
102104
using Dispatch64 = cub::DispatchScanByKey<KeysInUnwrapIt,
103105
ValuesInUnwrapIt,
104106
ValuesOutUnwrapIt,
105107
EqualityOpT,
106108
ScanOpT,
107109
cub::NullType,
108-
thrust::detail::int64_t>;
110+
thrust::detail::int64_t,
111+
AccumT>;
109112

110113
cudaStream_t stream = thrust::cuda_cub::stream(policy);
111114
cudaError_t status{};
@@ -209,14 +212,16 @@ ValuesOutIt exclusive_scan_by_key_n(
209212
EqualityOpT,
210213
ScanOpT,
211214
InitValueT,
212-
thrust::detail::int32_t>;
215+
thrust::detail::int32_t,
216+
InitValueT>;
213217
using Dispatch64 = cub::DispatchScanByKey<KeysInUnwrapIt,
214218
ValuesInUnwrapIt,
215219
ValuesOutUnwrapIt,
216220
EqualityOpT,
217221
ScanOpT,
218222
InitValueT,
219-
thrust::detail::int64_t>;
223+
thrust::detail::int64_t,
224+
InitValueT>;
220225

221226
cudaStream_t stream = thrust::cuda_cub::stream(policy);
222227
cudaError_t status{};

0 commit comments

Comments
 (0)