@@ -94,8 +94,25 @@ static_assert(sizeof(__rtti_base) == sizeof(uint64_t) + sizeof(void*));
9494// Used to map an interface typeid to a pointer to the vtable for that interface.
9595struct __base_info
9696{
97+ using __cast_fn_t = auto (__rtti const *) noexcept -> __base_vptr;
98+
9799 ::cuda::std::__type_info_ptr __typeid_;
98- __base_vptr __vptr_;
100+ union
101+ {
102+ __cast_fn_t * __cast_fn_; // used when __basic_any_version >= 1,
103+ __base_vptr __vptr_v0_; // used when __basic_any_version == 0
104+ };
105+
106+ [[nodiscard]] _CCCL_API auto __get_vptr (__rtti const * __rtti_ptr, uint8_t __version) const noexcept -> __base_vptr
107+ {
108+ return __version >= 1 ? __cast_fn_ (__rtti_ptr) : __vptr_v0_;
109+ }
110+
111+ template <class _VTable , class _Interface >
112+ [[nodiscard]] _CCCL_API static auto __cast_fn_impl (__rtti const * __rtti_ptr) noexcept -> __base_vptr
113+ {
114+ return {static_cast <__vptr_for<_Interface>>(static_cast <_VTable const *>(__rtti_ptr))};
115+ }
99116};
100117
101118inline constexpr size_t __half_size_t_bits = sizeof (size_t ) * CHAR_BIT / 2 ;
@@ -164,35 +181,34 @@ struct __rtti : __rtti_base
164181 {
165182 if (&__id == __base_vptr_map_[__i].__typeid_ )
166183 {
167- return static_cast <__vptr_for<_Interface>>(__base_vptr_map_[__i].__vptr_ );
184+ return static_cast <__vptr_for<_Interface>>(__base_vptr_map_[__i].__get_vptr ( this , __version_) );
168185 }
169186 }
170187
171188 for (size_t __i = 0 ; __i < __nbr_interfaces_; ++__i)
172189 {
173190 if (__id == *__base_vptr_map_[__i].__typeid_ )
174191 {
175- return static_cast <__vptr_for<_Interface>>(__base_vptr_map_[__i].__vptr_ );
192+ return static_cast <__vptr_for<_Interface>>(__base_vptr_map_[__i].__get_vptr ( this , __version_) );
176193 }
177194 }
178195
179196 return nullptr ;
180197 }
181198
182- void (*__dtor_)(void *, bool ) noexcept ;
183- __object_metadata const * __object_info_;
199+ void (*__dtor_)(void *, bool ) noexcept = nullptr ;
200+ __object_metadata const * __object_info_ = nullptr ;
184201 ::cuda::std::__type_info_ptr __interface_typeid_ = nullptr ;
185- __base_info const * __base_vptr_map_;
202+ __base_info const * __base_vptr_map_ = nullptr ;
186203};
187204
188205template <size_t _NbrInterfaces>
189206struct __rtti_ex : __rtti
190207{
191- template <class _Tp , class _Super , class ... _Interfaces, class _VPtr >
192- _CCCL_API constexpr __rtti_ex (__tag<_Tp, _Super> __type, __tag<_Interfaces...> __ibases, _VPtr ) noexcept
208+ template <class _Tp , class _Super , class ... _Interfaces, class _VTable >
209+ _CCCL_API constexpr __rtti_ex (__tag<_Tp, _Super> __type, __tag<_Interfaces...> __ibases, _VTable const * ) noexcept
193210 : __rtti{__type, __ibases, __base_vptr_array}
194- , __base_vptr_array{
195- {&_CCCL_TYPEID (_Interfaces), static_cast <__vptr_for<_Interfaces>>(static_cast <_VPtr>(this ))}...}
211+ , __base_vptr_array{{&_CCCL_TYPEID (_Interfaces), {&__base_info::__cast_fn_impl<_VTable, _Interfaces>}}...}
196212 {}
197213
198214 __base_info __base_vptr_array[_NbrInterfaces];
@@ -225,8 +241,8 @@ template <class _SrcInterface, class _DstInterface>
225241 else
226242 {
227243 // ! Slow down-casts and cross-casts:
228- __rtti const * rtti = __src_vptr->__query_interface (__iunknown ());
229- return rtti ->__query_interface (_DstInterface ());
244+ __rtti const * __rtti_ptr = __src_vptr->__query_interface (__iunknown ());
245+ return __rtti_ptr ->__query_interface (_DstInterface ());
230246 }
231247}
232248
0 commit comments