Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [dev] - XXX. XX, XXXX

### Added
* `dpctl.SyclQueue.copy` and `dpctl.SyclQueue.copy_async` methods [gh-2273](https://github.com/IntelPython/dpctl/pull/2273)

### Changed

Expand Down
12 changes: 12 additions & 0 deletions dpctl/_backend.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -511,6 +511,18 @@ cdef extern from "syclinterface/dpctl_sycl_queue_interface.h":
size_t Count,
const DPCTLSyclEventRef *depEvents,
size_t depEventsCount)
# C API copy routine without dependencies: takes the queue, a destination
# pointer, a source pointer and a count, and returns an event reference.
# Callers in _sycl_queue.pyx pass a byte count as ``Count`` and treat a NULL
# return value as failure.
cdef DPCTLSyclEventRef DPCTLQueue_CopyData(
const DPCTLSyclQueueRef Q,
void *Dest,
const void *Src,
size_t Count)
# Variant of the C API copy routine that additionally receives an array of
# ``depEventsCount`` event references in ``depEvents``.
# NOTE(review): presumably the copy is ordered after those events -- confirm
# against the syclinterface header.
cdef DPCTLSyclEventRef DPCTLQueue_CopyDataWithEvents(
const DPCTLSyclQueueRef Q,
void *Dest,
const void *Src,
size_t Count,
const DPCTLSyclEventRef *depEvents,
size_t depEventsCount)
cdef DPCTLSyclEventRef DPCTLQueue_Memset(
const DPCTLSyclQueueRef Q,
void *Dest,
Expand Down
4 changes: 4 additions & 0 deletions dpctl/_sycl_queue.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,10 @@ cdef public api class SyclQueue (_SyclQueue) [
cdef DPCTLSyclQueueRef get_queue_ref(self)
cpdef memcpy(self, dest, src, size_t count)
cpdef SyclEvent memcpy_async(self, dest, src, size_t count, list dEvents=*)
# Synchronous copy of ``count`` bytes from ``src`` to ``dest``; ``dtype`` is
# optional (``=*`` means the default value is supplied by the implementation).
cpdef copy(self, dest, src, size_t count, str dtype=*)
# Asynchronous copy returning a SyclEvent; ``dEvents`` and ``dtype`` are
# optional (``=*`` means the default values are supplied by the
# implementation).
cpdef SyclEvent copy_async(
self, dest, src, size_t count, list dEvents=*, str dtype=*
)
cpdef prefetch(self, ptr, size_t count=*)
cpdef mem_advise(self, ptr, size_t count, int mem)
cpdef SyclEvent submit_barrier(self, dependent_events=*)
Expand Down
177 changes: 173 additions & 4 deletions dpctl/_sycl_queue.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ from ._backend cimport ( # noqa: E211
DPCTLFilterSelector_Create,
DPCTLQueue_AreEq,
DPCTLQueue_Copy,
DPCTLQueue_CopyData,
DPCTLQueue_CopyDataWithEvents,
DPCTLQueue_Create,
DPCTLQueue_Delete,
DPCTLQueue_GetBackend,
Expand Down Expand Up @@ -459,13 +461,46 @@ cdef bint _is_buffer(object o):
return PyObject_CheckBuffer(o)


cdef DPCTLSyclEventRef _memcpy_impl(
# Function pointer typedefs for the C API queue copy functions.
#
# ``queue_copy_fn`` is the signature of the dependency-free routines
# (queue, destination pointer, source pointer, count) -> event reference.
# Both DPCTLQueue_Memcpy and DPCTLQueue_CopyData are passed through it.
ctypedef DPCTLSyclEventRef (*queue_copy_fn)(
const DPCTLSyclQueueRef, void*, const void*, size_t
)

# ``queue_copy_with_events_fn`` is the signature of the routines that also
# take an array of dependency event references and its length.
# Both DPCTLQueue_MemcpyWithEvents and DPCTLQueue_CopyDataWithEvents are
# passed through it.
ctypedef DPCTLSyclEventRef (*queue_copy_with_events_fn)(
const DPCTLSyclQueueRef, void*, const void*, size_t,
const DPCTLSyclEventRef*, size_t
)


cdef size_t _get_dtype_size(str dtype) except *:
    """
    Return the element size, in bytes, for a numpy-style dtype string.

    Recognized strings are the signed/unsigned integer codes ``i1``, ``u1``,
    ``i2``, ``u2``, ``i4``, ``u4``, ``i8``, ``u8`` and the floating-point
    codes ``f4`` and ``f8``.

    Raises:
        ValueError: if ``dtype`` is not one of the recognized strings.
    """
    # Guard-style dispatch: each membership test covers one element width.
    if dtype in ("i1", "u1"):
        return 1
    if dtype in ("i2", "u2"):
        return 2
    if dtype in ("i4", "u4", "f4"):
        return 4
    if dtype in ("i8", "u8", "f8"):
        return 8

    raise ValueError(
        f"Unrecognized dtype '{dtype}'. "
        "Expected one of: i1, u1, i2, u2, i4, u4, i8, u8, f4, f8"
    )


cdef DPCTLSyclEventRef _copy_memcpy_impl(
SyclQueue q,
object dst,
object src,
size_t byte_count,
DPCTLSyclEventRef *dep_events,
size_t dep_events_count
size_t dep_events_count,
queue_copy_fn copy_fn,
queue_copy_with_events_fn copy_with_events_fn
) except *:
cdef void *c_dst_ptr = NULL
cdef void *c_src_ptr = NULL
Expand Down Expand Up @@ -512,9 +547,9 @@ cdef DPCTLSyclEventRef _memcpy_impl(
)

if dep_events_count == 0 or dep_events is NULL:
ERef = DPCTLQueue_Memcpy(q._queue_ref, c_dst_ptr, c_src_ptr, byte_count)
ERef = copy_fn(q._queue_ref, c_dst_ptr, c_src_ptr, byte_count)
else:
ERef = DPCTLQueue_MemcpyWithEvents(
ERef = copy_with_events_fn(
q._queue_ref,
c_dst_ptr,
c_src_ptr,
Expand All @@ -531,6 +566,45 @@ cdef DPCTLSyclEventRef _memcpy_impl(
return ERef


cdef DPCTLSyclEventRef _memcpy_impl(
    SyclQueue q,
    object dst,
    object src,
    size_t byte_count,
    DPCTLSyclEventRef *dep_events,
    size_t dep_events_count
) except *:
    """
    Submit a memcpy of ``byte_count`` bytes from ``src`` to ``dst`` on ``q``.

    Thin wrapper that delegates to ``_copy_memcpy_impl`` and selects the
    ``DPCTLQueue_Memcpy`` / ``DPCTLQueue_MemcpyWithEvents`` pair of C API
    functions. The event reference produced by the shared implementation is
    returned unchanged, and any exception it raises propagates.
    """
    cdef DPCTLSyclEventRef event_ref = _copy_memcpy_impl(
        q,
        dst,
        src,
        byte_count,
        dep_events,
        dep_events_count,
        DPCTLQueue_Memcpy,
        DPCTLQueue_MemcpyWithEvents,
    )
    return event_ref


cdef DPCTLSyclEventRef _copy_impl(
    SyclQueue q,
    object dst,
    object src,
    size_t byte_count,
    DPCTLSyclEventRef *dep_events,
    size_t dep_events_count,
    str dtype=None
) except *:
    """
    Submit a copy of ``byte_count`` bytes from ``src`` to ``dst`` on ``q``.

    When ``dtype`` is given it is used for validation only: ``byte_count``
    must be a whole multiple of the element size of ``dtype``. The copy
    itself is always delegated to ``_copy_memcpy_impl`` with the
    ``DPCTLQueue_CopyData`` / ``DPCTLQueue_CopyDataWithEvents`` pair of C API
    functions.

    Raises:
        ValueError: if ``dtype`` is not recognized, or if ``byte_count`` is
            not a multiple of the element size of ``dtype``.
    """
    cdef size_t itemsize = 0

    if dtype is not None:
        itemsize = _get_dtype_size(dtype)
        # A non-zero remainder means the byte count splits an element.
        if byte_count % itemsize:
            raise ValueError(
                f"byte_count ({byte_count}) must be a multiple of dtype "
                f"element size ({itemsize} bytes for '{dtype}')"
            )

    return _copy_memcpy_impl(
        q,
        dst,
        src,
        byte_count,
        dep_events,
        dep_events_count,
        DPCTLQueue_CopyData,
        DPCTLQueue_CopyDataWithEvents,
    )


cdef class _SyclQueue:
""" Barebone data owner class used by SyclQueue.
"""
Expand Down Expand Up @@ -1421,6 +1495,101 @@ cdef class SyclQueue(_SyclQueue):

return SyclEvent._create(ERef)

cpdef copy(self, dest, src, size_t count, str dtype=None):
    """Copy ``count`` bytes from ``src`` to ``dest`` and wait for completion.

    The copy is submitted through the ``DPCTLQueue_CopyData`` C API
    function, i.e. as a byte-wise ``sycl::queue::copy``. ``dtype`` does not
    change how the data is copied; it only enables a check on ``count``.

    This is the synchronizing counterpart of
    :meth:`dpctl.SyclQueue.copy_async`.

    Args:
        dest:
            Destination USM object or Python object supporting
            writable buffer protocol.
        src:
            Source USM object or Python object supporting buffer
            protocol.
        count (int):
            Number of bytes to copy.
        dtype (str, optional):
            Data type string (e.g., 'i4', 'f8'). If provided, ``count``
            must be a multiple of the element size of this type.
            Supported types: i1, u1, i2, u2, i4, u4, i8, u8, f4, f8.

    Raises:
        ValueError:
            If ``dtype`` is not recognized or ``count`` is not a
            multiple of its element size.
        RuntimeError:
            If the copy operation could not be submitted.
    """
    cdef DPCTLSyclEventRef event_ref = _copy_impl(
        <SyclQueue>self, dest, src, count, NULL, 0, dtype
    )

    if event_ref is NULL:
        raise RuntimeError(
            "SyclQueue.copy operation encountered an error"
        )

    # Release the GIL while blocking on the device-side operation.
    with nogil:
        DPCTLEvent_Wait(event_ref)
    DPCTLEvent_Delete(event_ref)

cpdef SyclEvent copy_async(
    self, dest, src, size_t count, list dEvents=None, str dtype=None
):
    """Copy ``count`` bytes from ``src`` to ``dest`` asynchronously.

    The copy is submitted through the ``DPCTLQueue_CopyData`` /
    ``DPCTLQueue_CopyDataWithEvents`` C API functions, i.e. as a byte-wise
    ``sycl::queue::copy``. ``dtype`` does not change how the data is
    copied; it only enables a check on ``count``.

    Args:
        dest:
            Destination USM object or Python object supporting
            writable buffer protocol.
        src:
            Source USM object or Python object supporting buffer
            protocol.
        count (int):
            Number of bytes to copy.
        dEvents (List[dpctl.SyclEvent], optional):
            Events that this copy depends on. ``None`` and an empty
            list both mean "no dependencies".
        dtype (str, optional):
            Data type string (e.g., 'i4', 'f8'). If provided, ``count``
            must be a multiple of the element size of this type.
            Supported types: i1, u1, i2, u2, i4, u4, i8, u8, f4, f8.

    Returns:
        dpctl.SyclEvent:
            Event associated with the copy operation.

    Raises:
        TypeError:
            If ``dEvents`` contains an object that is not a
            :class:`dpctl.SyclEvent`.
        ValueError:
            If ``dtype`` is not recognized or ``count`` is not a
            multiple of its element size.
        MemoryError:
            If the temporary array of dependency events cannot be
            allocated.
        RuntimeError:
            If the copy operation could not be submitted.
    """
    cdef DPCTLSyclEventRef ERef = NULL
    cdef DPCTLSyclEventRef *depEvents = NULL
    cdef size_t nDE = 0

    if dEvents is not None:
        nDE = len(dEvents)

    # Allocate only for a non-empty list: malloc(0) may legally return
    # NULL, which would otherwise be misreported as MemoryError.
    if nDE > 0:
        depEvents = (
            <DPCTLSyclEventRef*>malloc(nDE * sizeof(DPCTLSyclEventRef))
        )
        if depEvents is NULL:
            raise MemoryError()

    try:
        if nDE > 0:
            for idx, de in enumerate(dEvents):
                if not isinstance(de, SyclEvent):
                    raise TypeError(
                        "A sequence of dpctl.SyclEvent is expected"
                    )
                depEvents[idx] = (<SyclEvent>de).get_event_ref()
        # _copy_impl may raise (e.g. ValueError from dtype validation);
        # the finally clause keeps depEvents from leaking in that case.
        ERef = _copy_impl(
            <SyclQueue>self, dest, src, count, depEvents, nDE, dtype
        )
    finally:
        # free(NULL) is a no-op, so this is safe when nothing was allocated.
        free(depEvents)

    if ERef is NULL:
        raise RuntimeError(
            "SyclQueue.copy operation encountered an error"
        )

    return SyclEvent._create(ERef)

cpdef prefetch(self, mem, size_t count=0):
cdef void *ptr
cdef DPCTLSyclEventRef ERef = NULL
Expand Down
Loading
Loading