Skip to content

Commit 07f67d7

Browse files
vchuravy and maleadt
authored
Support disabling implicit synchronization (#2662)
Co-authored-by: Tim Besard <tim.besard@gmail.com>
1 parent 430b7d6 commit 07f67d7

3 files changed

Lines changed: 40 additions & 4 deletions

File tree

src/array.jl

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -472,10 +472,35 @@ function Base.unsafe_convert(::Type{CuDeviceArray{T,N,AS.Global}}, a::DenseCuArr
472472
end
473473

474474

475-
## memory copying
475+
## synchronization
476476

477477
synchronize(x::CuArray) = synchronize(x.data[])
478478

479+
"""
480+
enable_synchronization!(arr::CuArray, enable::Bool=true)
481+
482+
By default `CuArray`s are implicitly synchronized when they are accessed on different CUDA
483+
devices or streams. This may be unwanted when e.g. using disjoint slices of memory across
484+
different tasks. This function allows enabling or disabling this behavior.
485+
486+
!!! warning
487+
488+
Disabling implicit synchronization affects _all_ `CuArray`s that are referring to the
489+
same underlying memory. Unsafe use of this API _will_ result in data corruption.
490+
491+
This API is only provided as an escape hatch, and should not be used without careful
492+
consideration. If automatic synchronization is generally problematic for your use case,
493+
it is recommended to figure out a better model instead and file an issue or pull request.
494+
For more details see [this discussion](https://github.com/JuliaGPU/CUDA.jl/issues/2617).
495+
"""
496+
function enable_synchronization!(arr::CuArray, enable::Bool=true)
497+
arr.data[].synchronizing = enable
498+
return arr
499+
end
500+
501+
502+
## memory copying
503+
479504
if VERSION >= v"1.11.0-DEV.753"
480505
function typetagdata(a::Array, i=1)
481506
ptr_or_offset = Int(a.ref.ptr_or_offset)

src/memory.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -503,16 +503,20 @@ mutable struct Managed{M}
503503
# which stream is currently using the memory.
504504
stream::CuStream
505505

506+
# whether accessing this memory can cause implicit synchronization
507+
synchronizing::Bool
508+
506509
# whether there are outstanding operations that haven't been synchronized
507510
dirty::Bool
508511

509512
# whether the memory has been captured in a way that would make the dirty bit unreliable
510513
captured::Bool
511514

512-
function Managed(mem::AbstractMemory; stream=CUDA.stream(), dirty=true, captured=false)
515+
function Managed(mem::AbstractMemory; stream = CUDA.stream(), synchronizing = true,
516+
dirty = true, captured = false)
513517
# NOTE: memory starts as dirty, because stream-ordered allocations are only
514518
# guaranteed to be physically allocated at a synchronization event.
515-
new{typeof(mem)}(mem, stream, dirty, captured)
519+
new{typeof(mem)}(mem, stream, synchronizing, dirty, captured)
516520
end
517521
end
518522

@@ -524,7 +528,7 @@ function synchronize(managed::Managed)
524528
managed.dirty = false
525529
end
526530
function maybe_synchronize(managed::Managed)
527-
if managed.dirty || managed.captured
531+
if managed.synchronizing && (managed.dirty || managed.captured)
528532
synchronize(managed)
529533
end
530534
end

test/base/array.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,13 @@ using ChainRulesCore: add!!, is_inplaceable_destination
5151
end
5252
end
5353

54+
@testset "synchronization" begin
55+
a = CUDA.zeros(2, 2)
56+
synchronize(a)
57+
CUDA.enable_synchronization!(a, false)
58+
CUDA.enable_synchronization!(a)
59+
end
60+
5461
@testset "unsafe_wrap" begin
5562
# managed memory -> CuArray
5663
for a in [cu([1]; device=true), cu([1]; unified=true)]

0 commit comments

Comments
 (0)