@@ -19,6 +19,29 @@ use crate::memory::UnifiedPointer;
1919use crate :: memory:: malloc:: { cuda_free_unified, cuda_malloc_unified} ;
2020use crate :: prelude:: Stream ;
2121
22+ #[ cfg( any( cuMemPrefetchAsync_v2, cuMemAdvise_v2) ) ]
23+ unsafe fn cu_mem_location (
24+ type_ : driver_sys:: CUmemLocationType ,
25+ id : std:: os:: raw:: c_int ,
26+ ) -> driver_sys:: CUmemLocation {
27+ let mut location = std:: mem:: MaybeUninit :: < driver_sys:: CUmemLocation > :: zeroed ( ) ;
28+ let location_ptr = location. as_mut_ptr ( ) ;
29+
30+ // Support both older bindgen output (`{ type_, id }`) and the newer
31+ // anonymous-union layout emitted from CUDA 13.2 headers.
32+ unsafe {
33+ ( * location_ptr) . type_ = type_;
34+ std:: ptr:: write (
35+ ( location_ptr. cast :: < u8 > ( ) )
36+ . add ( std:: mem:: size_of :: < driver_sys:: CUmemLocationType > ( ) )
37+ . cast :: < std:: os:: raw:: c_int > ( ) ,
38+ id,
39+ ) ;
40+
41+ location. assume_init ( )
42+ }
43+ }
44+
2245/// A pointer type for heap-allocation in CUDA unified memory.
2346///
2447/// See the [`module-level documentation`](../memory/index.html) for more information on unified
@@ -640,17 +663,13 @@ pub trait MemoryAdvise<T: DeviceCopy>: private::Sealed {
640663 let mem_size = std:: mem:: size_of_val ( slice) ;
641664
642665 unsafe {
643- let id = -1 ; // -1 is CU_DEVICE_CPU
644666 driver_sys:: cuMemPrefetchAsync (
645667 slice. as_ptr ( ) as driver_sys:: CUdeviceptr ,
646668 mem_size,
647669 #[ cfg( cuMemPrefetchAsync_v2) ]
648- driver_sys:: CUmemLocation {
649- type_ : driver_sys:: CUmemLocationType :: CU_MEM_LOCATION_TYPE_DEVICE ,
650- id,
651- } ,
670+ cu_mem_location ( driver_sys:: CUmemLocationType :: CU_MEM_LOCATION_TYPE_HOST , 0 ) ,
652671 #[ cfg( not( cuMemPrefetchAsync_v2) ) ]
653- id ,
672+ - 1 , // -1 is CU_DEVICE_CPU
654673 #[ cfg( cuMemPrefetchAsync_v2) ]
655674 0 , // flags for future use, must be 0 as of CUDA 13.0
656675 stream. as_inner ( ) ,
@@ -691,10 +710,7 @@ pub trait MemoryAdvise<T: DeviceCopy>: private::Sealed {
691710 slice. as_ptr ( ) as driver_sys:: CUdeviceptr ,
692711 mem_size,
693712 #[ cfg( cuMemPrefetchAsync_v2) ]
694- driver_sys:: CUmemLocation {
695- type_ : driver_sys:: CUmemLocationType :: CU_MEM_LOCATION_TYPE_DEVICE ,
696- id,
697- } ,
713+ cu_mem_location ( driver_sys:: CUmemLocationType :: CU_MEM_LOCATION_TYPE_DEVICE , id) ,
698714 #[ cfg( not( cuMemPrefetchAsync_v2) ) ]
699715 id,
700716 #[ cfg( cuMemPrefetchAsync_v2) ]
@@ -727,18 +743,14 @@ pub trait MemoryAdvise<T: DeviceCopy>: private::Sealed {
727743 } ;
728744
729745 unsafe {
730- let id = 0 ;
731746 driver_sys:: cuMemAdvise (
732747 slice. as_ptr ( ) as driver_sys:: CUdeviceptr ,
733748 mem_size,
734749 advice,
735750 #[ cfg( cuMemAdvise_v2) ]
736- driver_sys:: CUmemLocation {
737- type_ : driver_sys:: CUmemLocationType :: CU_MEM_LOCATION_TYPE_DEVICE ,
738- id,
739- } ,
751+ cu_mem_location ( driver_sys:: CUmemLocationType :: CU_MEM_LOCATION_TYPE_HOST , 0 ) ,
740752 #[ cfg( not( cuMemAdvise_v2) ) ]
741- id ,
753+ 0 ,
742754 )
743755 . to_result ( ) ?;
744756 }
@@ -775,9 +787,11 @@ pub trait MemoryAdvise<T: DeviceCopy>: private::Sealed {
775787 mem_size,
776788 driver_sys:: CUmem_advise :: CU_MEM_ADVISE_SET_PREFERRED_LOCATION ,
777789 #[ cfg( cuMemAdvise_v2) ]
778- driver_sys:: CUmemLocation {
779- type_ : driver_sys:: CUmemLocationType :: CU_MEM_LOCATION_TYPE_DEVICE ,
780- id,
790+ match preferred_location {
791+ Some ( _) => {
792+ cu_mem_location ( driver_sys:: CUmemLocationType :: CU_MEM_LOCATION_TYPE_DEVICE , id)
793+ }
794+ None => cu_mem_location ( driver_sys:: CUmemLocationType :: CU_MEM_LOCATION_TYPE_HOST , 0 ) ,
781795 } ,
782796 #[ cfg( not( cuMemAdvise_v2) ) ]
783797 id,
@@ -793,18 +807,14 @@ pub trait MemoryAdvise<T: DeviceCopy>: private::Sealed {
793807 let mem_size = std:: mem:: size_of_val ( slice) ;
794808
795809 unsafe {
796- let id = 0 ;
797810 driver_sys:: cuMemAdvise (
798811 slice. as_ptr ( ) as driver_sys:: CUdeviceptr ,
799812 mem_size,
800813 driver_sys:: CUmem_advise :: CU_MEM_ADVISE_UNSET_PREFERRED_LOCATION ,
801814 #[ cfg( cuMemAdvise_v2) ]
802- driver_sys:: CUmemLocation {
803- type_ : driver_sys:: CUmemLocationType :: CU_MEM_LOCATION_TYPE_DEVICE ,
804- id,
805- } ,
815+ cu_mem_location ( driver_sys:: CUmemLocationType :: CU_MEM_LOCATION_TYPE_HOST , 0 ) ,
806816 #[ cfg( not( cuMemAdvise_v2) ) ]
807- id ,
817+ 0 ,
808818 )
809819 . to_result ( ) ?;
810820 }
0 commit comments