Skip to content

Commit 6813780

Browse files
committed
fix complex and removed user specialization
1 parent 4b8e70d commit 6813780

File tree

2 files changed

+14
-46
lines changed

2 files changed

+14
-46
lines changed

libcudacxx/include/cuda/__type_traits/is_trivially_copyable.h

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@ _CCCL_BEGIN_NAMESPACE_CUDA
4242
template <typename _Tp, typename = void>
4343
inline constexpr bool __is_aggregate_trivially_copyable_v = false;
4444

45-
//! Users are allowed to specialize this variable template for their own types
4645
template <typename _Tp>
4746
inline constexpr bool __is_trivially_copyable_v =
4847
::cuda::std::is_trivially_copyable_v<_Tp> || ::cuda::std::__is_extended_floating_point_v<_Tp>
@@ -67,11 +66,17 @@ inline constexpr bool __is_trivially_copyable_v<::cuda::std::pair<_T1, _T2>> =
6766
template <typename... _Ts>
6867
inline constexpr bool __is_trivially_copyable_v<::cuda::std::tuple<_Ts...>> = (__is_trivially_copyable_v<_Ts> && ...);
6968

70-
template <typename _Tp>
71-
inline constexpr bool __is_trivially_copyable_v<::cuda::std::complex<_Tp>> = __is_trivially_copyable_v<_Tp>;
69+
template <>
70+
inline constexpr bool __is_trivially_copyable_v<::cuda::std::complex<::__half>> = true;
7271

73-
template <typename _Tp>
74-
inline constexpr bool __is_trivially_copyable_v<::cuda::complex<_Tp>> = __is_trivially_copyable_v<_Tp>;
72+
template <>
73+
inline constexpr bool __is_trivially_copyable_v<::cuda::std::complex<::__nv_bfloat16>> = true;
74+
75+
template <>
76+
inline constexpr bool __is_trivially_copyable_v<::cuda::complex<::__half>> = true;
77+
78+
template <>
79+
inline constexpr bool __is_trivially_copyable_v<::cuda::complex<::__nv_bfloat16>> = true;
7580

7681
// if all the previous conditions fail, check if the type is an aggregate and all its members are trivially copyable
7782
template <typename _Tp>
@@ -88,7 +93,6 @@ inline constexpr bool
8893
template <typename _Tp>
8994
inline constexpr bool is_trivially_copyable_v = __is_trivially_copyable_v<::cuda::std::remove_const_t<_Tp>>;
9095

91-
// defined as alias so users cannot specialize it (they should specialize the variable template instead)
9296
template <typename _Tp>
9397
using is_trivially_copyable = ::cuda::std::bool_constant<is_trivially_copyable_v<_Tp>>;
9498

libcudacxx/test/libcudacxx/cuda/type_traits/is_trivially_copyable.aggr.pass.cpp

Lines changed: 4 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -47,40 +47,6 @@ struct ArrayMember
4747
int values[2];
4848
};
4949

50-
//----------------------------------------------------------------------------------------------------------------------
51-
// custom type
52-
53-
struct CustomNonTrivialType
54-
{
55-
int x;
56-
57-
CustomNonTrivialType() = default;
58-
__host__ __device__ CustomNonTrivialType(const CustomNonTrivialType&) {}
59-
};
60-
61-
template <>
62-
constexpr bool cuda::is_trivially_copyable_v<CustomNonTrivialType> = true;
63-
64-
struct SingleMemberCustom
65-
{
66-
CustomNonTrivialType x;
67-
};
68-
69-
struct DerivedStructCustom : SingleMemberCustom
70-
{
71-
CustomNonTrivialType y;
72-
};
73-
74-
struct NestedStructCustom
75-
{
76-
CustomNonTrivialType z;
77-
};
78-
79-
struct ArrayMemberCustom
80-
{
81-
CustomNonTrivialType values[2];
82-
};
83-
8450
//----------------------------------------------------------------------------------------------------------------------
8551
// non trivially copyable type
8652

@@ -89,9 +55,9 @@ struct NonTriviallyCopyable
8955
__host__ __device__ NonTriviallyCopyable(const NonTriviallyCopyable&) {};
9056
};
9157

92-
struct RelaxedWithNonRelaxedMember
58+
struct AggregateWithNonTriviallyCopyableMember
9359
{
94-
CustomNonTrivialType x;
60+
int x;
9561
NonTriviallyCopyable y;
9662
};
9763

@@ -101,10 +67,8 @@ __host__ __device__ void test()
10167
test_is_trivially_copyable<DerivedStruct>();
10268
test_is_trivially_copyable<NesterStruct>();
10369
test_is_trivially_copyable<ArrayMember>();
104-
test_is_trivially_copyable<ArrayMember>();
105-
test_is_trivially_copyable<DerivedStructCustom>();
106-
test_is_trivially_copyable<NestedStructCustom>();
107-
test_is_not_trivially_copyable<RelaxedWithNonRelaxedMember>();
70+
test_is_not_trivially_copyable<NonTriviallyCopyable>();
71+
test_is_not_trivially_copyable<AggregateWithNonTriviallyCopyableMember>();
10872
}
10973

11074
int main(int, char**)

0 commit comments

Comments
 (0)