Skip to content

Commit a8cf98a

Browse files
authored
Split out copyto for texture arrays and add more tests (#2719)
1 parent 0ee73ec commit a8cf98a

2 files changed

Lines changed: 67 additions & 10 deletions

File tree

src/texture.jl

Lines changed: 45 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -104,29 +104,65 @@ CuTextureArray(A::AbstractArray{T,N}) where {T,N} = CuTextureArray{T,N}(A)
104104

105105
## memory operations
106106

107-
function Base.copyto!(dst::CuTextureArray{T,1}, src::Union{Array{T,1}, CuArray{T,1}}) where {T}
107+
function Base.copyto!(dst::CuTextureArray{T,1}, src::Union{Array{T,1}, CuTextureArray{T,1}}) where {T}
108108
size(dst) == size(src) || throw(DimensionMismatch("source and destination sizes must match"))
109109
Base.unsafe_copyto!(pointer(dst), pointer(src), length(dst))
110110
return dst
111111
end
112112

113-
function Base.copyto!(dst::CuTextureArray{T,2}, src::Union{Array{T,2}, CuArray{T,2}}) where {T}
113+
function Base.copyto!(dst::CuTextureArray{T,1}, src::CuArray{T,1,M}) where {T, M}
114114
size(dst) == size(src) || throw(DimensionMismatch("source and destination sizes must match"))
115-
unsafe_copy2d!(pointer(dst), ArrayMemory,
116-
pointer(src), isa(src, Array) ? HostMemory : DeviceMemory,
117-
size(dst)...)
115+
if M <: Union{HostMemory, DeviceMemory}
116+
Base.unsafe_copyto!(pointer(dst), pointer(src; type=M), length(dst))
117+
else
118+
Base.unsafe_copyto!(pointer(dst), pointer(src), length(dst))
119+
end
118120
return dst
119121
end
120122

121-
function Base.copyto!(dst::CuTextureArray{T,3}, src::Union{Array{T,3}, CuArray{T,3}}) where {T}
123+
function Base.copyto!(dst::CuTextureArray{T,2}, src::Array{T,2}) where {T}
122124
size(dst) == size(src) || throw(DimensionMismatch("source and destination sizes must match"))
123-
unsafe_copy3d!(pointer(dst), ArrayMemory,
124-
pointer(src), isa(src, Array) ? HostMemory : DeviceMemory,
125-
size(dst)...)
125+
unsafe_copy2d!(pointer(dst), ArrayMemory, pointer(src), HostMemory, size(dst)...)
126126
return dst
127127
end
128128

129+
function Base.copyto!(dst::CuTextureArray{T,2}, src::CuArray{T,2,M}) where {T, M}
130+
size(dst) == size(src) || throw(DimensionMismatch("source and destination sizes must match"))
131+
if M <: Union{HostMemory, DeviceMemory}
132+
unsafe_copy2d!(pointer(dst), ArrayMemory, pointer(src; type=M), M, size(dst)...)
133+
else
134+
unsafe_copy2d!(pointer(dst), ArrayMemory, pointer(src), M, size(dst)...)
135+
end
136+
return dst
137+
end
129138

139+
function Base.copyto!(dst::CuTextureArray{T,2}, src::CuTextureArray{T,2}) where {T}
140+
size(dst) == size(src) || throw(DimensionMismatch("source and destination sizes must match"))
141+
unsafe_copy2d!(pointer(dst), ArrayMemory, pointer(src), ArrayMemory, size(dst)...)
142+
return dst
143+
end
144+
145+
function Base.copyto!(dst::CuTextureArray{T,3}, src::Array{T,3}) where {T}
146+
size(dst) == size(src) || throw(DimensionMismatch("source and destination sizes must match"))
147+
unsafe_copy3d!(pointer(dst), ArrayMemory, pointer(src), HostMemory, size(dst)...)
148+
return dst
149+
end
150+
151+
function Base.copyto!(dst::CuTextureArray{T,3}, src::CuArray{T,3,M}) where {T, M}
152+
size(dst) == size(src) || throw(DimensionMismatch("source and destination sizes must match"))
153+
if M <: Union{HostMemory, DeviceMemory}
154+
unsafe_copy3d!(pointer(dst), ArrayMemory, pointer(src; type=M), M, size(dst)...)
155+
else
156+
unsafe_copy3d!(pointer(dst), ArrayMemory, pointer(src), M, size(dst)...)
157+
end
158+
return dst
159+
end
160+
161+
function Base.copyto!(dst::CuTextureArray{T,3}, src::CuTextureArray{T,3}) where {T}
162+
size(dst) == size(src) || throw(DimensionMismatch("source and destination sizes must match"))
163+
unsafe_copy3d!(pointer(dst), ArrayMemory, pointer(src), ArrayMemory, size(dst)...)
164+
return dst
165+
end
130166

131167
#
132168
# Texture objects

test/base/texture.jl

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,14 +82,29 @@ end
8282
copyto!(texarr1D, a1D)
8383
tex1D = CuTexture(texarr1D)
8484
@test Array(fetch_all(tex1D)) == a1D
85+
@test sizeof(texarr1D) == sizeof(a1D)
86+
@test eltype(texarr1D) == Float32
87+
h_arr_1D = zeros(Float32, length(a1D))
88+
Base.unsafe_copyto!(pointer(h_arr_1D), pointer(texarr1D), length(h_arr_1D))
89+
@test h_arr_1D == a1D
90+
cu_arr_1D = CUDA.zeros(Float32, length(a1D))
91+
Base.unsafe_copyto!(pointer(cu_arr_1D), pointer(texarr1D), length(cu_arr_1D))
92+
@test Array(cu_arr_1D) == a1D
8593

8694
texarr2D = CuTextureArray(a2D)
8795
tex2D = CuTexture(texarr2D)
8896
@test Array(fetch_all(tex2D)) == a2D
8997
texarr2D_2 = CuTextureArray(texarr2D)
9098
tex2D_2 = CuTexture(texarr2D_2)
9199
@test Array(fetch_all(tex2D_2)) == a2D
92-
100+
texarr2D_3 = CuTextureArray{Float32, 2}(texarr2D)
101+
tex2D_2 = CuTexture(texarr2D_3)
102+
@test Array(fetch_all(tex2D_2)) == a2D
103+
copyto!(texarr2D_3, texarr2D_2)
104+
tex2D_2 = CuTexture(texarr2D_3)
105+
@test Array(fetch_all(tex2D_2)) == a2D
106+
@test sizeof(texarr2D) == sizeof(a2D)
107+
@test eltype(texarr2D) == Float32
93108

94109
tex2D_dir = CuTexture(CuTextureArray(a2D))
95110
@test Array(fetch_all(tex2D_dir)) == a2D
@@ -100,6 +115,12 @@ end
100115
texarr3D_2 = CuTextureArray(texarr3D)
101116
tex3D_2 = CuTexture(texarr3D_2)
102117
@test Array(fetch_all(tex3D_2)) == a3D
118+
texarr3D_3 = CuTextureArray{Float32, 3}(texarr3D)
119+
copyto!(texarr2D_3, texarr2D_2)
120+
tex3D_3 = CuTexture(texarr3D_3)
121+
@test Array(fetch_all(tex3D_3)) == a3D
122+
@test sizeof(texarr3D) == sizeof(a3D)
123+
@test eltype(texarr3D) == Float32
103124
end
104125

105126
@testset "CuTexture(::CuArray)" begin

0 commit comments

Comments
 (0)