Skip to content

Commit 4d3fe16

Browse files
committed
nix flake
1 parent 9089bd4 commit 4d3fe16

3 files changed

Lines changed: 138 additions & 25 deletions

File tree

crates/cust/src/memory/unified.rs

Lines changed: 35 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,29 @@ use crate::memory::UnifiedPointer;
1919
use crate::memory::malloc::{cuda_free_unified, cuda_malloc_unified};
2020
use crate::prelude::Stream;
2121

22+
/// Builds a `driver_sys::CUmemLocation` from a location type and an id.
///
/// The value starts out zero-initialised, so every byte other than the `type_` field and the
/// id stays zero. The id is written by byte offset (directly after a
/// `CUmemLocationType`-sized prefix) instead of through a named field, so the same code
/// compiles against bindings that expose `{ type_, id }` and against bindings that place the id
/// inside an anonymous union.
///
/// # Safety
///
/// The returned value is only meaningful if `CUmemLocation` stores `type_` first and the id
/// immediately after it, and if an otherwise all-zero bit pattern is valid for the struct.
/// Size and alignment of the id slot are checked at compile time; the field order is not.
#[cfg(any(cuMemPrefetchAsync_v2, cuMemAdvise_v2))]
unsafe fn cu_mem_location(
    type_: driver_sys::CUmemLocationType,
    id: std::os::raw::c_int,
) -> driver_sys::CUmemLocation {
    // Byte offset at which the id is stored: right behind the location type.
    const ID_OFFSET: usize = std::mem::size_of::<driver_sys::CUmemLocationType>();

    // Compile-time guard: the raw write below must land inside the struct and be aligned for
    // `c_int`. A bindings change that violates this fails the build instead of corrupting memory.
    const _: () = assert!(
        std::mem::size_of::<driver_sys::CUmemLocation>()
            >= ID_OFFSET + std::mem::size_of::<std::os::raw::c_int>()
            && ID_OFFSET % std::mem::align_of::<std::os::raw::c_int>() == 0
            && std::mem::align_of::<driver_sys::CUmemLocation>()
                % std::mem::align_of::<std::os::raw::c_int>()
                == 0,
        "CUmemLocation cannot hold a location type directly followed by a c_int id"
    );

    let mut location = std::mem::MaybeUninit::<driver_sys::CUmemLocation>::zeroed();
    let location_ptr = location.as_mut_ptr();

    // Support both older bindgen output (`{ type_, id }`) and the newer
    // anonymous-union layout emitted from CUDA 13.2 headers.
    //
    // SAFETY: `location_ptr` points at zeroed storage owned by this frame that is valid for
    // writes of the whole struct. `type_` is written through a raw field pointer, so no
    // reference to not-yet-initialised memory is created. The id write is in bounds and aligned
    // per the compile-time assertion above. `assume_init` relies on the remaining zero bytes
    // being a valid bit pattern for the struct, as stated in the function's safety contract.
    unsafe {
        std::ptr::addr_of_mut!((*location_ptr).type_).write(type_);
        location_ptr
            .cast::<u8>()
            .add(ID_OFFSET)
            .cast::<std::os::raw::c_int>()
            .write(id);

        location.assume_init()
    }
}
44+
2245
/// A pointer type for heap-allocation in CUDA unified memory.
2346
///
2447
/// See the [`module-level documentation`](../memory/index.html) for more information on unified
@@ -640,17 +663,13 @@ pub trait MemoryAdvise<T: DeviceCopy>: private::Sealed {
640663
let mem_size = std::mem::size_of_val(slice);
641664

642665
unsafe {
643-
let id = -1; // -1 is CU_DEVICE_CPU
644666
driver_sys::cuMemPrefetchAsync(
645667
slice.as_ptr() as driver_sys::CUdeviceptr,
646668
mem_size,
647669
#[cfg(cuMemPrefetchAsync_v2)]
648-
driver_sys::CUmemLocation {
649-
type_: driver_sys::CUmemLocationType::CU_MEM_LOCATION_TYPE_DEVICE,
650-
id,
651-
},
670+
cu_mem_location(driver_sys::CUmemLocationType::CU_MEM_LOCATION_TYPE_HOST, 0),
652671
#[cfg(not(cuMemPrefetchAsync_v2))]
653-
id,
672+
-1, // -1 is CU_DEVICE_CPU
654673
#[cfg(cuMemPrefetchAsync_v2)]
655674
0, // flags for future use, must be 0 as of CUDA 13.0
656675
stream.as_inner(),
@@ -691,10 +710,7 @@ pub trait MemoryAdvise<T: DeviceCopy>: private::Sealed {
691710
slice.as_ptr() as driver_sys::CUdeviceptr,
692711
mem_size,
693712
#[cfg(cuMemPrefetchAsync_v2)]
694-
driver_sys::CUmemLocation {
695-
type_: driver_sys::CUmemLocationType::CU_MEM_LOCATION_TYPE_DEVICE,
696-
id,
697-
},
713+
cu_mem_location(driver_sys::CUmemLocationType::CU_MEM_LOCATION_TYPE_DEVICE, id),
698714
#[cfg(not(cuMemPrefetchAsync_v2))]
699715
id,
700716
#[cfg(cuMemPrefetchAsync_v2)]
@@ -727,18 +743,14 @@ pub trait MemoryAdvise<T: DeviceCopy>: private::Sealed {
727743
};
728744

729745
unsafe {
730-
let id = 0;
731746
driver_sys::cuMemAdvise(
732747
slice.as_ptr() as driver_sys::CUdeviceptr,
733748
mem_size,
734749
advice,
735750
#[cfg(cuMemAdvise_v2)]
736-
driver_sys::CUmemLocation {
737-
type_: driver_sys::CUmemLocationType::CU_MEM_LOCATION_TYPE_DEVICE,
738-
id,
739-
},
751+
cu_mem_location(driver_sys::CUmemLocationType::CU_MEM_LOCATION_TYPE_HOST, 0),
740752
#[cfg(not(cuMemAdvise_v2))]
741-
id,
753+
0,
742754
)
743755
.to_result()?;
744756
}
@@ -775,9 +787,11 @@ pub trait MemoryAdvise<T: DeviceCopy>: private::Sealed {
775787
mem_size,
776788
driver_sys::CUmem_advise::CU_MEM_ADVISE_SET_PREFERRED_LOCATION,
777789
#[cfg(cuMemAdvise_v2)]
778-
driver_sys::CUmemLocation {
779-
type_: driver_sys::CUmemLocationType::CU_MEM_LOCATION_TYPE_DEVICE,
780-
id,
790+
match preferred_location {
791+
Some(_) => {
792+
cu_mem_location(driver_sys::CUmemLocationType::CU_MEM_LOCATION_TYPE_DEVICE, id)
793+
}
794+
None => cu_mem_location(driver_sys::CUmemLocationType::CU_MEM_LOCATION_TYPE_HOST, 0),
781795
},
782796
#[cfg(not(cuMemAdvise_v2))]
783797
id,
@@ -793,18 +807,14 @@ pub trait MemoryAdvise<T: DeviceCopy>: private::Sealed {
793807
let mem_size = std::mem::size_of_val(slice);
794808

795809
unsafe {
796-
let id = 0;
797810
driver_sys::cuMemAdvise(
798811
slice.as_ptr() as driver_sys::CUdeviceptr,
799812
mem_size,
800813
driver_sys::CUmem_advise::CU_MEM_ADVISE_UNSET_PREFERRED_LOCATION,
801814
#[cfg(cuMemAdvise_v2)]
802-
driver_sys::CUmemLocation {
803-
type_: driver_sys::CUmemLocationType::CU_MEM_LOCATION_TYPE_DEVICE,
804-
id,
805-
},
815+
cu_mem_location(driver_sys::CUmemLocationType::CU_MEM_LOCATION_TYPE_HOST, 0),
806816
#[cfg(not(cuMemAdvise_v2))]
807-
id,
817+
0,
808818
)
809819
.to_result()?;
810820
}

flake.lock

Lines changed: 48 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

flake.nix

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# Development shell that pairs the pinned Rust toolchain and LLVM 19 with a CUDA toolkit
# installed outside the Nix store (see `cudaRoot`).
{
  inputs = {
    nixpkgs.url = "github:NixOS/nixpkgs/nixpkgs-unstable";
    rust-overlay = {
      url = "github:oxalica/rust-overlay";
      # Make the overlay use the same nixpkgs as this flake.
      inputs.nixpkgs.follows = "nixpkgs";
    };
  };

  outputs = { nixpkgs, rust-overlay, ... }:
    let
      system = "x86_64-linux";
      cudaRoot = "/usr/local/cuda-13.2";

      pkgs = import nixpkgs {
        inherit system;
        overlays = [ rust-overlay.overlays.default ];
      };
      inherit (pkgs) lib;

      llvm19 = pkgs.llvmPackages_19;
      # Toolchain is read from the repository's rust-toolchain.toml.
      rustToolchain = pkgs.rust-bin.fromRustupToolchainFile ./rust-toolchain.toml;
    in
    {
      devShells.${system}.default = pkgs.mkShell {
        packages =
          [ rustToolchain ]
          ++ (with pkgs; [
            pkg-config
            openssl
            cmake
            ninja
            ncurses
          ])
          ++ [
            llvm19.clang
            llvm19.libclang
            (lib.getDev llvm19.llvm)
            pkgs.stdenv.cc.cc.lib
          ];

        # Every CUDA location variable points at the same toolkit root.
        CUDA_HOME = cudaRoot;
        CUDA_ROOT = cudaRoot;
        CUDA_PATH = cudaRoot;
        CUDA_TOOLKIT_ROOT_DIR = cudaRoot;
        CUDA_LIBRARY_PATH =
          "${cudaRoot}/targets/x86_64-linux/lib:${cudaRoot}/lib64:${cudaRoot}/lib64/stubs";

        LLVM_CONFIG_19 = "${lib.getDev llvm19.llvm}/bin/llvm-config";
        LIBCLANG_PATH = "${lib.getLib llvm19.libclang}/lib";

        # Put the CUDA and NVVM binaries first on PATH, extend the loader path, and print a
        # short banner.
        shellHook = ''
          export PATH="${cudaRoot}/bin:${cudaRoot}/nvvm/bin:$PATH"
          export LD_LIBRARY_PATH="${cudaRoot}/nvvm/lib64:${cudaRoot}/lib64:${pkgs.ncurses.out}/lib:${pkgs.stdenv.cc.cc.lib}/lib''${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}"

          echo "rust-cuda llvm19 shell"
          echo " CUDA_HOME=$CUDA_HOME"
          echo " LLVM_CONFIG_19=$LLVM_CONFIG_19"
        '';
      };
    };
}

0 commit comments

Comments
 (0)