Skip to content

Commit d36525b

Browse files
committed
add async copy opcodes
1 parent 946c91f commit d36525b

2 files changed

Lines changed: 126 additions & 0 deletions

File tree

crates/cuda_std/src/async_copy.rs

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
use crate::gpu_only;
2+
/// Asynchronously copies one `T` from global to shared memory, caching the source in L1/L2.
3+
///
4+
/// The copy is issued but not guaranteed to be complete when this function returns.
5+
/// Use [`cp_async_commit_group`] and [`cp_async_wait_group`] to synchronize.
6+
///
7+
/// # Size
8+
///
9+
/// `T` must be exactly 4, 8, or 16 bytes. Any other size is a compile-time error.
10+
///
11+
/// # Safety
12+
///
13+
/// - `dst` must point into shared memory.
14+
/// - `src` must point into global memory.
15+
/// - Both pointers must be aligned to `size_of::<T>()`.
16+
/// - The pointed-to memory must be valid for the duration of the copy (until the
17+
/// corresponding [`cp_async_wait_group`] or [`cp_async_wait_all`] returns).
18+
#[gpu_only]
19+
pub unsafe fn cp_async_ca<T>(dst: *mut T, src: *const T) {
20+
const {
21+
let size = core::mem::size_of::<T>();
22+
assert!(
23+
size == 4 || size == 8 || size == 16,
24+
"cp_async requires T to be exactly 4, 8, or 16 bytes"
25+
);
26+
}
27+
// cp.async requires dst to be a 32-bit shared memory address.
28+
// Generic pointers must be explicitly converted: cvta.to.shared (64-bit)
29+
// then truncated to 32-bit, since shared memory is 32-bit addressable.
30+
unsafe {
31+
match core::mem::size_of::<T>() {
32+
4 => core::arch::asm!(
33+
"cvta.to.shared.u64 {tmp}, {dst};",
34+
"cvt.u32.u64 {smem}, {tmp};",
35+
"cp.async.ca.shared.global [{smem}], [{src}], 4;",
36+
dst = in(reg64) dst,
37+
src = in(reg64) src,
38+
tmp = out(reg64) _,
39+
smem = out(reg32) _,
40+
),
41+
8 => core::arch::asm!(
42+
"cvta.to.shared.u64 {tmp}, {dst};",
43+
"cvt.u32.u64 {smem}, {tmp};",
44+
"cp.async.ca.shared.global [{smem}], [{src}], 8;",
45+
dst = in(reg64) dst,
46+
src = in(reg64) src,
47+
tmp = out(reg64) _,
48+
smem = out(reg32) _,
49+
),
50+
16 => core::arch::asm!(
51+
"cvta.to.shared.u64 {tmp}, {dst};",
52+
"cvt.u32.u64 {smem}, {tmp};",
53+
"cp.async.ca.shared.global [{smem}], [{src}], 16;",
54+
dst = in(reg64) dst,
55+
src = in(reg64) src,
56+
tmp = out(reg64) _,
57+
smem = out(reg32) _,
58+
),
59+
_ => unreachable!(),
60+
}
61+
}
62+
}
63+
64+
/// Asynchronously copies one `T` from global to shared memory, caching only in L2
65+
/// (bypasses L1). Only valid for 16-byte types.
66+
///
67+
/// Prefer this over [`cp_async_ca`] for streaming access patterns where the data
68+
/// will not be reused, to avoid polluting L1.
69+
///
70+
/// # Safety
71+
///
72+
/// - `dst` must point into shared memory.
73+
/// - `src` must point into global memory.
74+
/// - Both pointers must be 16-byte aligned.
75+
/// - `T` must be exactly 16 bytes — enforced at compile time.
76+
#[gpu_only]
77+
pub unsafe fn cp_async_cg<T>(dst: *mut T, src: *const T) {
78+
const {
79+
assert!(
80+
core::mem::size_of::<T>() == 16,
81+
"cp_async_cg requires T to be exactly 16 bytes (.cg cache operator only supports 16-byte copies)"
82+
);
83+
}
84+
unsafe {
85+
core::arch::asm!(
86+
"cvta.to.shared.u64 {tmp}, {dst};",
87+
"cvt.u32.u64 {smem}, {tmp};",
88+
"cp.async.cg.shared.global [{smem}], [{src}], 16;",
89+
dst = in(reg64) dst,
90+
src = in(reg64) src,
91+
tmp = out(reg64) _,
92+
smem = out(reg32) _,
93+
)
94+
}
95+
}
96+
97+
/// Seals all `cp.async` operations issued since the last `cp_async_commit_group` (or
98+
/// program start) into a named group. Groups are completed in FIFO order.
99+
///
100+
/// Must be called before [`cp_async_wait_group`] to define group boundaries.
101+
#[gpu_only]
102+
pub fn cp_async_commit_group() {
103+
unsafe { core::arch::asm!("cp.async.commit_group;") }
104+
}
105+
106+
/// Waits until there are at most `N` committed `cp.async` groups still in flight.
107+
///
108+
/// - `N = 0`: waits for all groups — equivalent to [`cp_async_wait_all`].
109+
/// - `N = 1`: waits for all but the most recently committed group, allowing one
110+
/// prefetch to remain in flight while computing.
111+
///
112+
/// Must be paired with [`cp_async_commit_group`] to define which copies belong to
113+
/// each group.
114+
#[gpu_only]
115+
pub fn cp_async_wait_group<const N: u32>() {
116+
unsafe { core::arch::asm!("cp.async.wait_group {0};", const N) }
117+
}
118+
119+
/// Waits for all outstanding `cp.async` copies to complete.
120+
///
121+
/// Equivalent to `cp_async_wait_group::<0>()`.
122+
#[gpu_only]
123+
pub fn cp_async_wait_all() {
124+
unsafe { core::arch::asm!("cp.async.wait_all;") }
125+
}

crates/cuda_std/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ pub mod mem;
4242
pub mod misc;
4343
// WIP
4444
// pub mod rt;
45+
pub mod async_copy;
4546
pub mod atomic;
4647
pub mod ptr;
4748
pub mod shared;

0 commit comments

Comments
 (0)