|
| 1 | +use crate::gpu_only; |
| 2 | +/// Asynchronously copies one `T` from global to shared memory, caching the source in L1/L2. |
| 3 | +/// |
| 4 | +/// The copy is issued but not guaranteed to be complete when this function returns. |
| 5 | +/// Use [`cp_async_commit_group`] and [`cp_async_wait_group`] to synchronize. |
| 6 | +/// |
| 7 | +/// # Size |
| 8 | +/// |
| 9 | +/// `T` must be exactly 4, 8, or 16 bytes. Any other size is a compile-time error. |
| 10 | +/// |
| 11 | +/// # Safety |
| 12 | +/// |
| 13 | +/// - `dst` must point into shared memory. |
| 14 | +/// - `src` must point into global memory. |
| 15 | +/// - Both pointers must be aligned to `size_of::<T>()`. |
| 16 | +/// - The pointed-to memory must be valid for the duration of the copy (until the |
| 17 | +/// corresponding [`cp_async_wait_group`] or [`cp_async_wait_all`] returns). |
| 18 | +#[gpu_only] |
| 19 | +pub unsafe fn cp_async_ca<T>(dst: *mut T, src: *const T) { |
| 20 | + const { |
| 21 | + let size = core::mem::size_of::<T>(); |
| 22 | + assert!( |
| 23 | + size == 4 || size == 8 || size == 16, |
| 24 | + "cp_async requires T to be exactly 4, 8, or 16 bytes" |
| 25 | + ); |
| 26 | + } |
| 27 | + // cp.async requires dst to be a 32-bit shared memory address. |
| 28 | + // Generic pointers must be explicitly converted: cvta.to.shared (64-bit) |
| 29 | + // then truncated to 32-bit, since shared memory is 32-bit addressable. |
| 30 | + unsafe { |
| 31 | + match core::mem::size_of::<T>() { |
| 32 | + 4 => core::arch::asm!( |
| 33 | + "cvta.to.shared.u64 {tmp}, {dst};", |
| 34 | + "cvt.u32.u64 {smem}, {tmp};", |
| 35 | + "cp.async.ca.shared.global [{smem}], [{src}], 4;", |
| 36 | + dst = in(reg64) dst, |
| 37 | + src = in(reg64) src, |
| 38 | + tmp = out(reg64) _, |
| 39 | + smem = out(reg32) _, |
| 40 | + ), |
| 41 | + 8 => core::arch::asm!( |
| 42 | + "cvta.to.shared.u64 {tmp}, {dst};", |
| 43 | + "cvt.u32.u64 {smem}, {tmp};", |
| 44 | + "cp.async.ca.shared.global [{smem}], [{src}], 8;", |
| 45 | + dst = in(reg64) dst, |
| 46 | + src = in(reg64) src, |
| 47 | + tmp = out(reg64) _, |
| 48 | + smem = out(reg32) _, |
| 49 | + ), |
| 50 | + 16 => core::arch::asm!( |
| 51 | + "cvta.to.shared.u64 {tmp}, {dst};", |
| 52 | + "cvt.u32.u64 {smem}, {tmp};", |
| 53 | + "cp.async.ca.shared.global [{smem}], [{src}], 16;", |
| 54 | + dst = in(reg64) dst, |
| 55 | + src = in(reg64) src, |
| 56 | + tmp = out(reg64) _, |
| 57 | + smem = out(reg32) _, |
| 58 | + ), |
| 59 | + _ => unreachable!(), |
| 60 | + } |
| 61 | + } |
| 62 | +} |
| 63 | + |
| 64 | +/// Asynchronously copies one `T` from global to shared memory, caching only in L2 |
| 65 | +/// (bypasses L1). Only valid for 16-byte types. |
| 66 | +/// |
| 67 | +/// Prefer this over [`cp_async_ca`] for streaming access patterns where the data |
| 68 | +/// will not be reused, to avoid polluting L1. |
| 69 | +/// |
| 70 | +/// # Safety |
| 71 | +/// |
| 72 | +/// - `dst` must point into shared memory. |
| 73 | +/// - `src` must point into global memory. |
| 74 | +/// - Both pointers must be 16-byte aligned. |
| 75 | +/// - `T` must be exactly 16 bytes — enforced at compile time. |
| 76 | +#[gpu_only] |
| 77 | +pub unsafe fn cp_async_cg<T>(dst: *mut T, src: *const T) { |
| 78 | + const { |
| 79 | + assert!( |
| 80 | + core::mem::size_of::<T>() == 16, |
| 81 | + "cp_async_cg requires T to be exactly 16 bytes (.cg cache operator only supports 16-byte copies)" |
| 82 | + ); |
| 83 | + } |
| 84 | + unsafe { |
| 85 | + core::arch::asm!( |
| 86 | + "cvta.to.shared.u64 {tmp}, {dst};", |
| 87 | + "cvt.u32.u64 {smem}, {tmp};", |
| 88 | + "cp.async.cg.shared.global [{smem}], [{src}], 16;", |
| 89 | + dst = in(reg64) dst, |
| 90 | + src = in(reg64) src, |
| 91 | + tmp = out(reg64) _, |
| 92 | + smem = out(reg32) _, |
| 93 | + ) |
| 94 | + } |
| 95 | +} |
| 96 | + |
| 97 | +/// Seals all `cp.async` operations issued since the last `cp_async_commit_group` (or |
| 98 | +/// program start) into a named group. Groups are completed in FIFO order. |
| 99 | +/// |
| 100 | +/// Must be called before [`cp_async_wait_group`] to define group boundaries. |
| 101 | +#[gpu_only] |
| 102 | +pub fn cp_async_commit_group() { |
| 103 | + unsafe { core::arch::asm!("cp.async.commit_group;") } |
| 104 | +} |
| 105 | + |
| 106 | +/// Waits until there are at most `N` committed `cp.async` groups still in flight. |
| 107 | +/// |
| 108 | +/// - `N = 0`: waits for all groups — equivalent to [`cp_async_wait_all`]. |
| 109 | +/// - `N = 1`: waits for all but the most recently committed group, allowing one |
| 110 | +/// prefetch to remain in flight while computing. |
| 111 | +/// |
| 112 | +/// Must be paired with [`cp_async_commit_group`] to define which copies belong to |
| 113 | +/// each group. |
| 114 | +#[gpu_only] |
| 115 | +pub fn cp_async_wait_group<const N: u32>() { |
| 116 | + unsafe { core::arch::asm!("cp.async.wait_group {0};", const N) } |
| 117 | +} |
| 118 | + |
| 119 | +/// Waits for all outstanding `cp.async` copies to complete. |
| 120 | +/// |
| 121 | +/// Equivalent to `cp_async_wait_group::<0>()`. |
| 122 | +#[gpu_only] |
| 123 | +pub fn cp_async_wait_all() { |
| 124 | + unsafe { core::arch::asm!("cp.async.wait_all;") } |
| 125 | +} |
0 commit comments