Skip to content

Commit 5116464

Browse files
committed
Allow specifying a difference type for cuda::counting_iterator
1 parent d5b1958 commit 5116464

File tree

15 files changed

+870
-152
lines changed

15 files changed

+870
-152
lines changed

libcudacxx/include/cuda/__fwd/iterator.h

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,13 @@
2222
#endif // no system header
2323

2424
#include <cuda/__fwd/random.h>
25+
#include <cuda/std/__concepts/arithmetic.h>
2526
#include <cuda/std/__concepts/copyable.h>
27+
#include <cuda/std/__cstddef/types.h>
2628
#include <cuda/std/__iterator/concepts.h>
2729
#include <cuda/std/__iterator/incrementable_traits.h>
2830
#include <cuda/std/__type_traits/enable_if.h>
31+
#include <cuda/std/__type_traits/type_identity.h>
2932
#include <cuda/std/cstdint>
3033

3134
#include <cuda/std/__cccl/prologue.h>
@@ -35,13 +38,46 @@ _CCCL_BEGIN_NAMESPACE_CUDA
3538
template <class _Tp, class _Index = ::cuda::std::ptrdiff_t>
3639
class constant_iterator;
3740

41+
template <class _Tp>
42+
[[nodiscard]] _CCCL_API _CCCL_CONSTEVAL auto __get_wider_signed() noexcept
43+
{
44+
if constexpr (sizeof(_Tp) < sizeof(int))
45+
{
46+
return ::cuda::std::type_identity<int>{};
47+
}
48+
else if constexpr (sizeof(_Tp) < sizeof(long))
49+
{
50+
return ::cuda::std::type_identity<long>{};
51+
}
52+
#if _CCCL_HAS_INT128()
53+
else if constexpr (sizeof(_Tp) < sizeof(long long))
54+
{
55+
return ::cuda::std::type_identity<::cuda::std::ptrdiff_t>{};
56+
}
57+
else // if constexpr (sizeof(_Start) < sizeof(__int128_t))
58+
{
59+
return ::cuda::std::type_identity<__int128_t>{};
60+
}
61+
#else // ^^^ _CCCL_HAS_INT128() ^^^ / vvv !_CCCL_HAS_INT128() vvv
62+
else // if constexpr (sizeof(_Start) < sizeof(long long))
63+
{
64+
return ::cuda::std::type_identity<long long>{};
65+
}
66+
#endif // _CCCL_HAS_INT128()
67+
}
68+
69+
template <class _IntT>
70+
using _IotaDiffT = typename decltype(::cuda::__get_wider_signed<_IntT>())::type;
71+
3872
#if _CCCL_HAS_CONCEPTS()
39-
template <::cuda::std::weakly_incrementable _Start>
73+
template <::cuda::std::weakly_incrementable _Start, ::cuda::std::signed_integral _DiffT = _IotaDiffT<_Start>>
4074
requires ::cuda::std::copyable<_Start>
4175
#else // ^^^ _CCCL_HAS_CONCEPTS() ^^^ / vvv !_CCCL_HAS_CONCEPTS() vvv
4276
template <class _Start,
77+
class _DiffT = _IotaDiffT<_Start>,
4378
::cuda::std::enable_if_t<::cuda::std::weakly_incrementable<_Start>, int> = 0,
44-
::cuda::std::enable_if_t<::cuda::std::copyable<_Start>, int> = 0>
79+
::cuda::std::enable_if_t<::cuda::std::copyable<_Start>, int> = 0,
80+
::cuda::std::enable_if_t<::cuda::std::signed_integral<_DiffT>, int> = 0>
4581
#endif // ^^^ !_CCCL_HAS_CONCEPTS() ^^^
4682
class counting_iterator;
4783

libcudacxx/include/cuda/__iterator/counting_iterator.h

Lines changed: 103 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -66,49 +66,6 @@ _CCCL_BEGIN_NAMESPACE_CUDA
6666

6767
//! @cond
6868

69-
template <class _Int>
70-
struct __get_wider_signed
71-
{
72-
_CCCL_API inline static auto __call() noexcept
73-
{
74-
if constexpr (sizeof(_Int) < sizeof(short))
75-
{
76-
return ::cuda::std::type_identity<short>{};
77-
}
78-
else if constexpr (sizeof(_Int) < sizeof(int))
79-
{
80-
return ::cuda::std::type_identity<int>{};
81-
}
82-
else if constexpr (sizeof(_Int) < sizeof(long))
83-
{
84-
return ::cuda::std::type_identity<long>{};
85-
}
86-
#if _CCCL_HAS_INT128()
87-
else if constexpr (sizeof(_Int) < sizeof(long long))
88-
{
89-
return ::cuda::std::type_identity<long long>{};
90-
}
91-
else // if constexpr (sizeof(_Int) < sizeof(__int128_t))
92-
{
93-
return ::cuda::std::type_identity<__int128_t>{};
94-
}
95-
#else // ^^^ _CCCL_HAS_INT128() ^^^ / vvv !_CCCL_HAS_INT128() vvv
96-
else // if constexpr (sizeof(_Int) < sizeof(long long))
97-
{
98-
return ::cuda::std::type_identity<long long>{};
99-
}
100-
#endif // _CCCL_HAS_INT128()
101-
}
102-
103-
using type = typename decltype(__call())::type;
104-
};
105-
106-
template <class _Start>
107-
using _IotaDiffT = typename ::cuda::std::conditional_t<
108-
(!::cuda::std::integral<_Start> || sizeof(::cuda::std::iter_difference_t<_Start>) > sizeof(_Start)),
109-
::cuda::std::type_identity<::cuda::std::iter_difference_t<_Start>>,
110-
__get_wider_signed<_Start>>::type;
111-
11269
template <class _Iter>
11370
_CCCL_CONCEPT __decrementable = _CCCL_REQUIRES_EXPR((_Iter), _Iter __iter)(
11471
requires(::cuda::std::incrementable<_Iter>), _Same_as(_Iter&)(--__iter), _Same_as(_Iter)(__iter--));
@@ -164,12 +121,14 @@ struct __counting_iterator_category<_Tp, ::cuda::std::enable_if_t<::cuda::std::i
164121
//! std::copy(iter, iter + vec.size(), vec.begin());
165122
//! @endcode
166123
#if _CCCL_HAS_CONCEPTS()
167-
template <::cuda::std::weakly_incrementable _Start>
124+
template <::cuda::std::weakly_incrementable _Start, ::cuda::std::signed_integral _DiffT>
168125
requires ::cuda::std::copyable<_Start>
169126
#else // ^^^ _CCCL_HAS_CONCEPTS() ^^^ / vvv !_CCCL_HAS_CONCEPTS() vvv
170127
template <class _Start,
128+
class _DiffT,
171129
::cuda::std::enable_if_t<::cuda::std::weakly_incrementable<_Start>, int>,
172-
::cuda::std::enable_if_t<::cuda::std::copyable<_Start>, int>>
130+
::cuda::std::enable_if_t<::cuda::std::copyable<_Start>, int>,
131+
::cuda::std::enable_if_t<::cuda::std::signed_integral<_DiffT>, int>>
173132
#endif // ^^^ !_CCCL_HAS_CONCEPTS() ^^^
174133
class counting_iterator : public __counting_iterator_category<_Start>
175134
{
@@ -187,13 +146,30 @@ class counting_iterator : public __counting_iterator_category<_Start>
187146
/*Else*/ ::cuda::std::input_iterator_tag>>>;
188147

189148
using value_type = _Start;
190-
using difference_type = _IotaDiffT<_Start>;
149+
using difference_type = _DiffT;
191150

192151
// Those are technically not to spec, but pre-ranges iterator_traits do not work properly with iterators that do not
193152
// define all 5 aliases, see https://en.cppreference.com/w/cpp/iterator/iterator_traits.html
194153
using reference = _Start;
195154
using pointer = void;
196155

156+
// Needed for comparison operators and constructors, because the other side might have a
157+
// different difference type so we cannot reach into their private members. Usually you solve
158+
// this with the power of friendship, but since this class uses concepts or SFINAE, spelling
159+
// out the friendship is a faff.
160+
//
161+
// We also cannot use operator*() here to get the value because that imposes the additional
162+
// burden of requiring _Start to be copy-constructible which is not needed for comparisons.
163+
[[nodiscard]] _CCCL_API constexpr const _Start& __get_value() const noexcept
164+
{
165+
return __value_;
166+
}
167+
168+
[[nodiscard]] _CCCL_API constexpr _Start& __get_value() noexcept
169+
{
170+
return __value_;
171+
}
172+
197173
#if _CCCL_HAS_CONCEPTS()
198174
_CCCL_HIDE_FROM_ABI counting_iterator()
199175
requires ::cuda::std::default_initializable<_Start>
@@ -211,6 +187,43 @@ class counting_iterator : public __counting_iterator_category<_Start>
211187
: __value_(::cuda::std::move(__value))
212188
{}
213189

190+
constexpr counting_iterator(const counting_iterator&) = default;
191+
constexpr counting_iterator(counting_iterator&&) = default;
192+
constexpr counting_iterator& operator=(const counting_iterator&) = default;
193+
constexpr counting_iterator& operator=(counting_iterator&&) = default;
194+
195+
//! @brief Creates a @c counting_iterator from another @c counting_iterator of a different
196+
//! difference type.
197+
//! @param __other The @c counting_iterator to copy from.
198+
_CCCL_TEMPLATE(class _DiffT2)
199+
_CCCL_REQUIRES((!::cuda::std::same_as<_DiffT, _DiffT2>) )
200+
_CCCL_API constexpr explicit counting_iterator(const counting_iterator<_Start, _DiffT2>& __other) noexcept(
201+
::cuda::std::is_nothrow_copy_constructible_v<_Start>)
202+
: __value_(__other.__get_value())
203+
{}
204+
205+
//! @brief Creates a @c counting_iterator from another @c counting_iterator of a different
206+
//! difference type.
207+
//! @param __other The @c counting_iterator to move from.
208+
_CCCL_TEMPLATE(class _DiffT2)
209+
_CCCL_REQUIRES((!::cuda::std::same_as<_DiffT, _DiffT2>) )
210+
_CCCL_API constexpr explicit counting_iterator(counting_iterator<_Start, _DiffT2>&& __other) noexcept(
211+
::cuda::std::is_nothrow_move_constructible_v<_Start>)
212+
: __value_(::cuda::std::move(__other.__get_value()))
213+
{}
214+
215+
//! @brief Assignment between counting iterators of differing difference types is explicitly
216+
//! deleted. If such a conversion is intended, use the copy or move constructors to convert.
217+
_CCCL_TEMPLATE(class _DiffT2)
218+
_CCCL_REQUIRES((!::cuda::std::same_as<_DiffT, _DiffT2>) )
219+
_CCCL_API constexpr counting_iterator& operator=(const counting_iterator<_Start, _DiffT2>&) = delete;
220+
221+
//! @brief Assignment between counting iterators of differing difference types is explicitly
222+
//! deleted. If such a conversion is intended, use the copy or move constructors
223+
_CCCL_TEMPLATE(class _DiffT2)
224+
_CCCL_REQUIRES((!::cuda::std::same_as<_DiffT, _DiffT2>) )
225+
_CCCL_API constexpr counting_iterator& operator=(counting_iterator<_Start, _DiffT2>&&) = delete;
226+
214227
//! @brief Returns the value currently stored in the @c counting_iterator
215228
[[nodiscard]] _CCCL_API constexpr _Start operator*() const
216229
noexcept(::cuda::std::is_nothrow_copy_constructible_v<_Start>)
@@ -394,81 +407,78 @@ class counting_iterator : public __counting_iterator_category<_Start>
394407
}
395408
}
396409

397-
//! @brief Compares two @c counting_iterator for equality.
398-
//! @return True if the stored values compare equal
399-
_CCCL_TEMPLATE(class _Start2 = _Start)
400-
_CCCL_REQUIRES(::cuda::std::equality_comparable<_Start2>)
410+
_CCCL_TEMPLATE(class _Start2, class _DiffT2)
411+
_CCCL_REQUIRES(::cuda::std::equality_comparable_with<_Start, _Start2>)
401412
[[nodiscard]] _CCCL_API friend constexpr bool
402-
operator==(const counting_iterator& __x, const counting_iterator& __y) noexcept(
403-
noexcept(::cuda::std::declval<const _Start2&>() == ::cuda::std::declval<const _Start2&>()))
413+
operator==(const counting_iterator& __x, const counting_iterator<_Start2, _DiffT2>& __y) noexcept(
414+
noexcept(::cuda::std::declval<const _Start&>() == ::cuda::std::declval<const _Start2&>()))
404415
{
405-
return __x.__value_ == __y.__value_;
416+
return __x.__value_ == __y.__get_value();
406417
}
407418

408419
#if _CCCL_STD_VER <= 2017
409420
//! @brief Compares two @c counting_iterator for inequality.
410421
//! @return True if the stored values do not compare equal
411-
_CCCL_TEMPLATE(class _Start2 = _Start)
412-
_CCCL_REQUIRES(::cuda::std::equality_comparable<_Start2>)
422+
_CCCL_TEMPLATE(class _Start2, class _DiffT2)
423+
_CCCL_REQUIRES(::cuda::std::equality_comparable_with<_Start, _Start2>)
413424
[[nodiscard]] _CCCL_API friend constexpr bool
414-
operator!=(const counting_iterator& __x, const counting_iterator& __y) noexcept(
415-
noexcept(::cuda::std::declval<const _Start2&>() != ::cuda::std::declval<const _Start2&>()))
425+
operator!=(const counting_iterator& __x, const counting_iterator<_Start2, _DiffT2>& __y) noexcept(
426+
noexcept(::cuda::std::declval<const _Start&>() != ::cuda::std::declval<const _Start2&>()))
416427
{
417-
return __x.__value_ != __y.__value_;
428+
return __x.__value_ != __y.__get_value();
418429
}
419430
#endif // _CCCL_STD_VER <= 2017
420431

421432
#if _LIBCUDACXX_HAS_SPACESHIP_OPERATOR()
422433
//! @brief Three-way compares two @c counting_iterator.
423434
//! @return The three-way comparison of the stored values
435+
template <class _Start2, class _DiffT2>
424436
[[nodiscard]] _CCCL_API friend constexpr auto
425-
operator<=>(const counting_iterator& __x, const counting_iterator& __y) noexcept(
426-
noexcept(::cuda::std::declval<const _Start2&>() <=> ::cuda::std::declval<const _Start2&>()))
427-
requires ::cuda::std::totally_ordered<_Start> && ::cuda::std::three_way_comparable<_Start>
437+
operator<=>(const counting_iterator& __x, const counting_iterator<_Start2, _DiffT2>& __y) noexcept(
438+
noexcept(::cuda::std::declval<const _Start&>() <=> ::cuda::std::declval<const _Start2&>()))
439+
requires ::cuda::std::totally_ordered_with<_Start, _Start2>
440+
&& ::cuda::std::three_way_comparable_with<_Start, _Start2>
428441
{
429-
return __x.__value_ <=> __y.__value_;
442+
return __x.__value_ <=> __y.__get_value();
430443
}
431444
#else // ^^^ _LIBCUDACXX_HAS_SPACESHIP_OPERATOR() ^^^ / vvv !_LIBCUDACXX_HAS_SPACESHIP_OPERATOR() vvv
432445
//! @brief Compares two @c counting_iterator for less than.
433446
//! @return True if stored values compare less than
434-
_CCCL_TEMPLATE(class _Start2 = _Start)
435-
_CCCL_REQUIRES(::cuda::std::totally_ordered<_Start2>)
447+
_CCCL_TEMPLATE(class _Start2, class _DiffT2)
448+
_CCCL_REQUIRES(::cuda::std::totally_ordered_with<_Start, _Start2>)
436449
[[nodiscard]] _CCCL_API friend constexpr bool
437-
operator<(const counting_iterator& __x, const counting_iterator& __y) noexcept(
438-
noexcept(::cuda::std::declval<const _Start2&>() < ::cuda::std::declval<const _Start2&>()))
450+
operator<(const counting_iterator& __x, const counting_iterator<_Start2, _DiffT2>& __y) noexcept(
451+
noexcept(::cuda::std::declval<const _Start&>() < ::cuda::std::declval<const _Start2&>()))
439452
{
440-
return __x.__value_ < __y.__value_;
453+
return __x.__value_ < __y.__get_value();
441454
}
442455

443456
//! @brief Compares two @c counting_iterator for greater than.
444457
//! @return True if stored values compare greater than
445-
_CCCL_TEMPLATE(class _Start2 = _Start)
446-
_CCCL_REQUIRES(::cuda::std::totally_ordered<_Start2>)
458+
_CCCL_TEMPLATE(class _Start2, class _DiffT2)
459+
_CCCL_REQUIRES(::cuda::std::totally_ordered_with<_Start, _Start2>)
447460
[[nodiscard]] _CCCL_API friend constexpr bool
448-
operator>(const counting_iterator& __x, const counting_iterator& __y) noexcept(
449-
noexcept(::cuda::std::declval<const _Start2&>() < ::cuda::std::declval<const _Start2&>()))
461+
operator>(const counting_iterator& __x, const counting_iterator<_Start2, _DiffT2>& __y) noexcept(noexcept(__y < __x))
450462
{
451463
return __y < __x;
452464
}
453465

454466
//! @brief Compares two @c counting_iterator for less equal.
455467
//! @return True if stored values compare less equal
456-
_CCCL_TEMPLATE(class _Start2 = _Start)
457-
_CCCL_REQUIRES(::cuda::std::totally_ordered<_Start2>)
468+
_CCCL_TEMPLATE(class _Start2, class _DiffT2)
469+
_CCCL_REQUIRES(::cuda::std::totally_ordered_with<_Start, _Start2>)
458470
[[nodiscard]] _CCCL_API friend constexpr bool
459-
operator<=(const counting_iterator& __x, const counting_iterator& __y) noexcept(
460-
noexcept(::cuda::std::declval<const _Start2&>() < ::cuda::std::declval<const _Start2&>()))
471+
operator<=(const counting_iterator& __x, const counting_iterator<_Start2, _DiffT2>& __y) noexcept(noexcept(__y < __x))
461472
{
462473
return !(__y < __x);
463474
}
464475

465476
//! @brief Compares two @c counting_iterator for greater equal.
466477
//! @return True if stored values compare greater equal
467-
_CCCL_TEMPLATE(class _Start2 = _Start)
468-
_CCCL_REQUIRES(::cuda::std::totally_ordered<_Start2>)
478+
_CCCL_TEMPLATE(class _Start2, class _DiffT2)
479+
_CCCL_REQUIRES(::cuda::std::totally_ordered_with<_Start, _Start2>)
469480
[[nodiscard]] _CCCL_API friend constexpr bool
470-
operator>=(const counting_iterator& __x, const counting_iterator& __y) noexcept(
471-
noexcept(::cuda::std::declval<const _Start2&>() < ::cuda::std::declval<const _Start2&>()))
481+
operator>=(const counting_iterator& __x, const counting_iterator<_Start2, _DiffT2>& __y) noexcept(noexcept(__x < __y))
472482
{
473483
return !(__x < __y);
474484
}
@@ -494,36 +504,36 @@ _CCCL_BEGIN_NAMESPACE_STD
494504

495505
//! counting_iterator is a C++20 iterator, so it does not play well with legacy STL features like std::distance
496506
//! To work around that specialize those functions for counting_iterator
497-
template <class _Diff, class _Start>
507+
template <class _Diff, class _Start, class _DiffT2>
498508
_CCCL_HOST_API constexpr void
499-
advance(::cuda::counting_iterator<_Start>& __iter, _Diff __diff) noexcept(::cuda::std::__integer_like<_Start>)
509+
advance(::cuda::counting_iterator<_Start, _DiffT2>& __iter, _Diff __diff) noexcept(::cuda::std::__integer_like<_Start>)
500510
{
501511
::cuda::std::advance(__iter, ::cuda::std::move(__diff));
502512
}
503513

504-
template <class _Start>
505-
[[nodiscard]] _CCCL_HOST_API constexpr ::cuda::_IotaDiffT<_Start>
506-
distance(::cuda::counting_iterator<_Start> __first,
507-
::cuda::counting_iterator<_Start> __last) noexcept(::cuda::std::__integer_like<_Start>)
514+
template <class _Start, class _DiffT>
515+
[[nodiscard]] _CCCL_HOST_API constexpr typename ::cuda::counting_iterator<_Start, _DiffT>::difference_type
516+
distance(::cuda::counting_iterator<_Start, _DiffT> __first,
517+
::cuda::counting_iterator<_Start, _DiffT> __last) noexcept(::cuda::std::__integer_like<_Start>)
508518
{
509519
return ::cuda::std::distance(::cuda::std::move(__first), ::cuda::std::move(__last));
510520
}
511521

512-
template <class _Start>
513-
[[nodiscard]] _CCCL_HOST_API constexpr ::cuda::counting_iterator<_Start>
514-
next(::cuda::counting_iterator<_Start> __iter,
515-
::cuda::_IotaDiffT<_Start> __n = 1) noexcept(::cuda::std::__integer_like<_Start>)
522+
template <class _Start, class _DiffT>
523+
[[nodiscard]] _CCCL_HOST_API constexpr ::cuda::counting_iterator<_Start, _DiffT>
524+
next(::cuda::counting_iterator<_Start, _DiffT> __iter,
525+
::cuda::std::type_identity_t<_DiffT> __n = 1) noexcept(::cuda::std::__integer_like<_Start>)
516526
{
517527
_CCCL_ASSERT(__n >= 0 || ::cuda::__decrementable<_Start>,
518528
"Attempt to std::next(it, n) with negative n on a non-bidirectional iterator");
519529
::cuda::std::advance(__iter, __n);
520530
return __iter;
521531
}
522532

523-
template <class _Start>
524-
[[nodiscard]] _CCCL_HOST_API constexpr ::cuda::counting_iterator<_Start>
525-
prev(::cuda::counting_iterator<_Start> __iter,
526-
::cuda::_IotaDiffT<_Start> __n = 1) noexcept(::cuda::std::__integer_like<_Start>)
533+
template <class _Start, class _DiffT>
534+
[[nodiscard]] _CCCL_HOST_API constexpr ::cuda::counting_iterator<_Start, _DiffT>
535+
prev(::cuda::counting_iterator<_Start, _DiffT> __iter,
536+
::cuda::std::type_identity_t<_DiffT> __n = 1) noexcept(::cuda::std::__integer_like<_Start>)
527537
{
528538
_CCCL_ASSERT(__n >= 0 || ::cuda::__decrementable<_Start>, "Attempt to std::prev(it, +n) on a non-bidi iterator");
529539
::cuda::std::advance(__iter, -__n);

0 commit comments

Comments
 (0)