Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions platforms/cuda/hal/include/xsched/cuda/hal/common/cuda_command.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,16 @@ class CudaKernelLaunchExCommand : public CudaKernelCommand
virtual CUresult Launch(CUstream stream) override;
};

class CudaRuntimeLaunchCommand : public CudaCommand
{
public:
CudaRuntimeLaunchCommand() : CudaCommand(preempt::kCommandPropertyDeactivatable) {}
virtual ~CudaRuntimeLaunchCommand() = default;

private:
virtual CUresult Launch(CUstream stream) override { (void)stream; return CUDA_SUCCESS; }
};

// host function
class CudaHostFuncCommand : public CudaCommand
{
Expand Down
12 changes: 10 additions & 2 deletions platforms/cuda/hal/src/arch/arch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,12 @@ std::shared_ptr<HwQueue> xsched::cuda::CudaQueueCreate(CUstream stream)
CUdevice dev;
CUcontext stream_ctx;
CUcontext current_ctx;
CUDA_ASSERT(Driver::StreamGetCtx(stream, &stream_ctx));
CUDA_ASSERT(Driver::CtxGetCurrent(&current_ctx));
if (stream == nullptr) {
stream_ctx = current_ctx;
} else {
CUDA_ASSERT(Driver::StreamGetCtx(stream, &stream_ctx));
}
XASSERT(current_ctx == stream_ctx,
"create CudaQueue failed: current context (%p) does not match stream context (%p)",
current_ctx, stream_ctx);
Expand Down Expand Up @@ -95,8 +99,12 @@ CUresult xsched::cuda::DirectLaunch(std::shared_ptr<CudaKernelCommand> kernel, C
CUdevice dev;
CUcontext stream_ctx;
CUcontext current_ctx;
CUDA_ASSERT(Driver::StreamGetCtx(stream, &stream_ctx));
CUDA_ASSERT(Driver::CtxGetCurrent(&current_ctx));
if (stream == nullptr) {
stream_ctx = current_ctx;
} else {
CUDA_ASSERT(Driver::StreamGetCtx(stream, &stream_ctx));
}
XASSERT(current_ctx == stream_ctx,
"direct launch kernel failed: current context (%p) does not match stream context (%p)",
current_ctx, stream_ctx);
Expand Down
4 changes: 4 additions & 0 deletions platforms/cuda/hal/src/common/cuda_command.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
#include "xsched/cuda/hal/common/cuda_assert.h"
#include "xsched/cuda/hal/common/cuda_command.h"

#include <dlfcn.h>

using namespace xsched::cuda;

CudaCommand::CudaCommand(preempt::XCommandProperties props): HwCommand(props)
Expand Down Expand Up @@ -133,3 +135,5 @@ CUresult CudaEventWaitCommand::Launch(CUstream stream)
if (!event_) return CUDA_SUCCESS; // already waited in BeforeLaunch()
return Driver::StreamWaitEvent(stream, event_, flags_);
}

struct dim3 { unsigned int x, y, z; };
16 changes: 10 additions & 6 deletions platforms/cuda/hal/src/level1/cuda_queue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@ CudaQueueLv1::CudaQueueLv1(CUstream stream): kStream(stream)
CUcontext stream_context = nullptr;
CUcontext current_context = nullptr;
CUDA_ASSERT(Driver::CtxGetCurrent(&current_context));
CUDA_ASSERT(Driver::StreamGetCtx(stream, &stream_context));
if (stream == nullptr) {
stream_context = current_context;
} else {
CUDA_ASSERT(Driver::StreamGetCtx(stream, &stream_context));
}
XASSERT(current_context == stream_context,
"create CudaQueueLv1 failed: current context (%p) does not match stream context (%p)",
current_context, stream_context);
Expand All @@ -31,7 +35,11 @@ CudaQueueLv1::CudaQueueLv1(CUstream stream): kStream(stream)
xdevice_ = MakeDevice(kDeviceTypeGPU, XDeviceId(MakePciId(dom, bus, dev, 0)));

// get stream flags
CUDA_ASSERT(Driver::StreamGetFlags(stream, &stream_flags_));
if (stream == nullptr) {
stream_flags_ = 0; // Default stream
} else {
CUDA_ASSERT(Driver::StreamGetFlags(stream, &stream_flags_));
}

// make sure no commands are running on stream_
CUDA_ASSERT(Driver::StreamSynchronize(kStream));
Expand Down Expand Up @@ -65,10 +73,6 @@ EXPORT_C_FUNC XResult CudaQueueCreate(HwQueueHandle *hwq, CUstream stream)
XWARN("CudaQueueCreate failed: hwq is nullptr");
return kXSchedErrorInvalidValue;
}
if (stream == nullptr) {
XWARN("CudaQueueCreate failed: does not support default stream");
return kXSchedErrorNotSupported;
}

HwQueueHandle hwq_h = GetHwQueueHandle(stream);
auto res = HwQueueManager::Add(hwq_h, [&]() { return xsched::cuda::CudaQueueCreate(stream); });
Expand Down
31 changes: 21 additions & 10 deletions platforms/cuda/shim/include/xsched/cuda/shim/shim.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include "xsched/preempt/hal/hw_queue.h"
#include "xsched/preempt/xqueue/xqueue.h"
#include "xsched/cuda/hal.h"
#include "xsched/cuda/hal/common/cuda.h"
#include "xsched/cuda/hal/common/driver.h"
#include "xsched/cuda/hal/common/handle.h"
Expand All @@ -13,27 +14,37 @@ namespace xsched::cuda
#define CUDA_SHIM_FUNC(name, cmd, ...) \
inline CUresult X##name(FOR_EACH_PAIR_COMMA(DECLARE_PARAM, __VA_ARGS__), CUstream stream) \
{ \
if (stream == 0) { \
WaitBlockingXQueues(); \
return Driver::name(FOR_EACH_PAIR_COMMA(DECLARE_ARG, __VA_ARGS__), stream); \
} \
auto xq = xsched::preempt::HwQueueManager::GetXQueue(GetHwQueueHandle(stream)); \
if (xq == nullptr) return Driver::name(FOR_EACH_PAIR_COMMA(DECLARE_ARG, __VA_ARGS__), stream); \
auto hw_cmd = std::make_shared<cmd>(FOR_EACH_PAIR_COMMA(DECLARE_ARG, __VA_ARGS__)); \
xq->Submit(hw_cmd); \
return CUDA_SUCCESS; \
if (xq == nullptr) { \
xsched::preempt::XQueueManager::AutoCreate([&](HwQueueHandle *hwq) -> XResult { \
return CudaQueueCreate(hwq, stream); \
}); \
xq = xsched::preempt::HwQueueManager::GetXQueue(GetHwQueueHandle(stream)); \
} \
if (xq != nullptr) { \
/* Use a dummy token for accounting, avoid double-execution of the real memory command */ \
auto token = std::make_shared<CudaRuntimeLaunchCommand>(); \
xq->Submit(token); \
} \
return Driver::name(FOR_EACH_PAIR_COMMA(DECLARE_ARG, __VA_ARGS__), stream); \
}

void WaitBlockingXQueues();

////////////////////////////// kernel related //////////////////////////////
CUresult XLaunchKernel(CUfunction f, unsigned int gdx, unsigned int gdy, unsigned int gdz, unsigned int bdx, unsigned int bdy, unsigned int bdz, unsigned int shmem, CUstream stream, void **params, void **extra);
CUresult XLaunchKernelRuntime(const void *func, unsigned int gdx, unsigned int gdy, unsigned int gdz, unsigned int bdx, unsigned int bdy, unsigned int bdz, size_t shmem, void **args, CUstream stream);
CUresult XLaunchKernelEx(const CUlaunchConfig *config, CUfunction f, void **params, void **extra);
CUresult XLaunchHostFunc(CUstream stream, CUhostFn fn, void *data);

////////////////////////////// memory related //////////////////////////////
CUDA_SHIM_FUNC(MemcpyHtoDAsync_v2, CudaMemcpyHtoDV2Command, CUdeviceptr, dstDevice, const void *, srcHost, size_t, ByteCount);
CUDA_SHIM_FUNC(MemcpyDtoHAsync_v2, CudaMemcpyDtoHV2Command, void *, dstHost, CUdeviceptr, srcDevice, size_t, ByteCount);
CUresult XMemcpyHtoD_v2(CUdeviceptr dstDevice, const void *srcHost, size_t ByteCount);
CUresult XMemcpyDtoH_v2(void *dstHost, CUdeviceptr srcDevice, size_t ByteCount);
CUresult XMemcpyHtoD(CUdeviceptr_v1 dstDevice, const void *srcHost, unsigned int ByteCount);
CUresult XMemcpyDtoH(void *dstHost, CUdeviceptr_v1 srcDevice, unsigned int ByteCount);
CUresult XMemcpyHtoDAsync_v2(CUdeviceptr dstDevice, const void *srcHost, size_t ByteCount, CUstream hStream);
CUresult XMemcpyDtoHAsync_v2(void *dstHost, CUdeviceptr srcDevice, size_t ByteCount, CUstream hStream);

CUDA_SHIM_FUNC(MemcpyDtoDAsync_v2, CudaMemcpyDtoDV2Command, CUdeviceptr, dstDevice, CUdeviceptr, srcDevice, size_t, ByteCount);
CUDA_SHIM_FUNC(Memcpy2DAsync_v2, CudaMemcpy2DV2Command, const CUDA_MEMCPY2D *, pCopy);
CUDA_SHIM_FUNC(Memcpy3DAsync_v2, CudaMemcpy3DV2Command, const CUDA_MEMCPY3D *, pCopy);
Expand Down
Loading