Skip to content

Commit

Permalink
[DeviceAPI] Support "GetCurrentStream" (#16689)
Browse files Browse the repository at this point in the history
This PR introduces a new function `GetCurrentStream` to the device API,
which returns the current stream of the given device.

Meanwhile, this PR updates the "CreateStream" of CUDA to create
a non-blocking stream, so that the execution on this stream can
overlap with the execution of other streams.

This PR also changes the `GPUCopy` of the CUDA device API to always
use `cudaMemcpyAsync`.
  • Loading branch information
MasterJH5574 authored Mar 9, 2024
1 parent ab56026 commit 48992a4
Show file tree
Hide file tree
Showing 12 changed files with 48 additions and 7 deletions.
6 changes: 6 additions & 0 deletions include/tvm/runtime/device_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,12 @@ class TVM_DLL DeviceAPI {
* \param stream The stream to be set.
*/
virtual void SetStream(Device dev, TVMStreamHandle stream) {}
/*!
 * \brief Get the current stream of the given device.
 *
 * The base-class definition returns nullptr; device backends override this
 * method to report the stream they currently use.
 *
 * \param dev The device whose current stream is queried.
 * \return The current stream of the device.
 */
virtual TVMStreamHandle GetCurrentStream(Device dev);
/*!
* \brief Synchronize 2 streams of execution.
*
Expand Down
2 changes: 2 additions & 0 deletions src/runtime/c_runtime_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,8 @@ TVMStreamHandle DeviceAPI::CreateStream(Device dev) { return nullptr; }

/*!
 * \brief Base-class FreeStream: performs no work.
 * \param dev The device the stream belongs to (unused here).
 * \param stream The stream handle to free (unused here).
 */
void DeviceAPI::FreeStream(Device dev, TVMStreamHandle stream) {
  // Nothing to release in the base class.
}

/*!
 * \brief Base-class GetCurrentStream: always reports a null stream handle.
 * \param dev The device being queried (unused here).
 * \return nullptr.
 */
TVMStreamHandle DeviceAPI::GetCurrentStream(Device dev) {
  // The base class tracks no stream, so there is nothing to report.
  return nullptr;
}

/*!
 * \brief Base-class SyncStreamFromTo: performs no work.
 * \param dev The device the streams belong to (unused here).
 * \param event_src The source stream handle (unused here).
 * \param event_dst The destination stream handle (unused here).
 */
void DeviceAPI::SyncStreamFromTo(Device dev, TVMStreamHandle event_src,
                                 TVMStreamHandle event_dst) {
  // No synchronization is performed in the base class.
}

Expand Down
12 changes: 6 additions & 6 deletions src/runtime/cuda/cuda_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ class CUDADeviceAPI final : public DeviceAPI {
/*!
 * \brief Create a new non-blocking CUDA stream on the given device.
 *
 * The device is selected with cudaSetDevice before the stream is created.
 * The stream is created with the cudaStreamNonBlocking flag so that work
 * submitted to it can overlap with work on other streams.
 *
 * \param dev The device on which the stream is created.
 * \return The newly created stream, as an opaque TVMStreamHandle.
 */
TVMStreamHandle CreateStream(Device dev) {
  CUDA_CALL(cudaSetDevice(dev.device_id));
  cudaStream_t retval;
  // Create exactly one stream. A preceding cudaStreamCreate(&retval) would
  // have its handle overwritten by this call and the first stream leaked.
  CUDA_CALL(cudaStreamCreateWithFlags(&retval, cudaStreamNonBlocking));
  return static_cast<TVMStreamHandle>(retval);
}

Expand Down Expand Up @@ -225,6 +225,10 @@ class CUDADeviceAPI final : public DeviceAPI {
CUDAThreadEntry::ThreadLocal()->stream = static_cast<cudaStream_t>(stream);
}

/*!
 * \brief Return the stream stored in CUDAThreadEntry::ThreadLocal().
 * \param dev The device being queried (not consulted by this implementation).
 * \return The stored stream, as an opaque TVMStreamHandle.
 */
TVMStreamHandle GetCurrentStream(Device dev) final {
  auto current_stream = CUDAThreadEntry::ThreadLocal()->stream;
  return static_cast<TVMStreamHandle>(current_stream);
}

/*!
 * \brief Allocate workspace memory from the pool held by CUDAThreadEntry::ThreadLocal().
 * \param dev The device to allocate on.
 * \param size The requested size, forwarded to the pool.
 * \param type_hint Data type hint (not forwarded to the pool).
 * \return Pointer returned by the pool's AllocWorkspace.
 */
void* AllocWorkspace(Device dev, size_t size, DLDataType type_hint) final {
  auto entry = CUDAThreadEntry::ThreadLocal();
  return entry->pool.AllocWorkspace(dev, size);
}
Expand All @@ -243,11 +247,7 @@ class CUDADeviceAPI final : public DeviceAPI {
private:
/*!
 * \brief Copy `size` bytes from `from` to `to` using cudaMemcpyAsync.
 * \param from Source pointer.
 * \param to Destination pointer.
 * \param size Number of bytes to copy.
 * \param kind Copy direction, forwarded unchanged to cudaMemcpyAsync.
 * \param stream Stream on which the copy is issued; may be nullptr, in which
 *        case it is passed through to cudaMemcpyAsync as-is.
 */
static void GPUCopy(const void* from, void* to, size_t size, cudaMemcpyKind kind,
                    cudaStream_t stream) {
  // Issue the copy exactly once, always through the asynchronous API. The
  // former "stream != nullptr ? cudaMemcpyAsync : cudaMemcpy" branch is gone;
  // keeping it next to this call would perform every transfer twice.
  CUDA_CALL(cudaMemcpyAsync(to, from, size, kind, stream));
}
};

Expand Down
1 change: 1 addition & 0 deletions src/runtime/metal/metal_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ class MetalWorkspace final : public DeviceAPI {
// Stream and workspace overrides of DeviceAPI for the Metal backend.
void FreeStream(Device dev, TVMStreamHandle stream) final;
void StreamSync(Device dev, TVMStreamHandle stream) final;
// Stores `stream` in MetalThreadEntry::ThreadLocal()->stream[dev.device_id].
void SetStream(Device dev, TVMStreamHandle stream) final;
// Returns MetalThreadEntry::ThreadLocal()->stream[dev.device_id] after
// checking that dev.device_id is smaller than the number of devices.
TVMStreamHandle GetCurrentStream(Device dev) final;
// Allocates from the pool held by MetalThreadEntry::ThreadLocal().
void* AllocWorkspace(Device dev, size_t size, DLDataType type_hint) final;
void FreeWorkspace(Device dev, void* data) final;
// NOTE(review): presumably resets the per-device default streams; the
// definition is not visible here - confirm against metal_device_api.mm.
void ReinitializeDefaultStreams();
Expand Down
5 changes: 5 additions & 0 deletions src/runtime/metal/metal_device_api.mm
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,11 @@ int GetWarpSize(id<MTLDevice> dev) {
MetalThreadEntry::ThreadLocal()->stream[dev.device_id] = stream;
}

/*!
 * \brief Return the stream recorded for `dev` in MetalThreadEntry::ThreadLocal().
 * \param dev The device being queried; its device_id must be smaller than
 *        devices.size(), otherwise the ICHECK fails.
 * \return The stream stored at index dev.device_id.
 */
TVMStreamHandle MetalWorkspace::GetCurrentStream(Device dev) {
  ICHECK_LT(dev.device_id, devices.size()) << "Invalid device id " << dev.device_id;
  auto entry = MetalThreadEntry::ThreadLocal();
  return entry->stream[dev.device_id];
}

/*!
 * \brief Allocate workspace memory from the pool held by MetalThreadEntry::ThreadLocal().
 * \param dev The device to allocate on.
 * \param size The requested size, forwarded to the pool.
 * \param type_hint Data type hint (not forwarded to the pool).
 * \return Pointer returned by the pool's AllocWorkspace.
 */
void* MetalWorkspace::AllocWorkspace(Device dev, size_t size, DLDataType type_hint) {
  auto entry = MetalThreadEntry::ThreadLocal();
  return entry->pool.AllocWorkspace(dev, size);
}
Expand Down
1 change: 1 addition & 0 deletions src/runtime/minrpc/rpc_reference.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ enum class RPCCode : int {
kDevCreateStream,
kDevFreeStream,
kDevSetStream,
kDevGetCurrentStream,
};

/*!
Expand Down
4 changes: 4 additions & 0 deletions src/runtime/rocm/rocm_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,10 @@ class ROCMDeviceAPI final : public DeviceAPI {
ROCMThreadEntry::ThreadLocal()->stream = static_cast<hipStream_t>(stream);
}

/*!
 * \brief Return the stream stored in ROCMThreadEntry::ThreadLocal().
 * \param dev The device being queried (not consulted by this implementation).
 * \return The stored stream, as an opaque TVMStreamHandle.
 */
TVMStreamHandle GetCurrentStream(Device dev) final {
  auto current_stream = ROCMThreadEntry::ThreadLocal()->stream;
  return static_cast<TVMStreamHandle>(current_stream);
}

/*!
 * \brief Allocate workspace memory from the pool held by ROCMThreadEntry::ThreadLocal().
 * \param dev The device to allocate on.
 * \param size The requested size, forwarded to the pool.
 * \param type_hint Data type hint (not forwarded to the pool).
 * \return Pointer returned by the pool's AllocWorkspace.
 */
void* AllocWorkspace(Device dev, size_t size, DLDataType type_hint) final {
  auto entry = ROCMThreadEntry::ThreadLocal();
  return entry->pool.AllocWorkspace(dev, size);
}
Expand Down
7 changes: 6 additions & 1 deletion src/runtime/rpc/rpc_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -126,11 +126,16 @@ class RPCDeviceAPI final : public DeviceAPI {
GetSess(dev)->GetDeviceAPI(remote_dev)->StreamSync(remote_dev, stream);
}

/*!
 * \brief Set the current stream on the remote device behind an RPC session.
 *
 * The RPC session mask is removed from `dev`, and the call is forwarded to
 * the device API that the owning session provides for the resulting device.
 *
 * \param dev The RPC-masked device.
 * \param stream The stream to be set on the remote device.
 */
void SetStream(Device dev, TVMStreamHandle stream) final {
  auto remote_dev = RemoveRPCSessionMask(dev);
  GetSess(dev)->GetDeviceAPI(remote_dev)->SetStream(remote_dev, stream);
}

/*!
 * \brief Query the current stream of the remote device behind an RPC session.
 *
 * The RPC session mask is removed from `dev`, and the query is forwarded to
 * the device API that the owning session provides for the resulting device.
 *
 * \param dev The RPC-masked device.
 * \return The stream reported by the remote device API.
 */
TVMStreamHandle GetCurrentStream(Device dev) final {
  auto remote_dev = RemoveRPCSessionMask(dev);
  auto* remote_api = GetSess(dev)->GetDeviceAPI(remote_dev);
  return remote_api->GetCurrentStream(remote_dev);
}

protected:
void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset,
size_t num_bytes, Device dev_from, Device dev_to, DLDataType type_hint,
Expand Down
12 changes: 12 additions & 0 deletions src/runtime/rpc/rpc_endpoint.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1006,6 +1006,11 @@ void RPCDevSetStream(RPCSession* handler, TVMArgs args, TVMRetValue* rv) {
handler->GetDeviceAPI(dev)->SetStream(dev, stream);
}

/*!
 * \brief Syscall handler for kDevGetCurrentStream.
 * \param handler The session whose device API answers the query.
 * \param args Packed arguments; args[0] is the device.
 * \param rv Receives the current stream handle of that device.
 */
void RPCDevGetCurrentStream(RPCSession* handler, TVMArgs args, TVMRetValue* rv) {
  Device dev = args[0];
  auto* device_api = handler->GetDeviceAPI(dev);
  TVMStreamHandle current_stream = device_api->GetCurrentStream(dev);
  *rv = current_stream;
}

void RPCEndpoint::EventHandler::HandleSyscall(RPCCode code) {
// Event handler sit at clean state at this point.
switch (code) {
Expand Down Expand Up @@ -1043,6 +1048,9 @@ void RPCEndpoint::EventHandler::HandleSyscall(RPCCode code) {
case RPCCode::kDevSetStream:
SysCallHandler(RPCDevSetStream);
break;
case RPCCode::kDevGetCurrentStream:
SysCallHandler(RPCDevGetCurrentStream);
break;
case RPCCode::kCopyAmongRemote:
SysCallHandler(RPCCopyAmongRemote);
break;
Expand Down Expand Up @@ -1188,6 +1196,10 @@ class RPCClientSession : public RPCSession, public DeviceAPI {
endpoint_->SysCallRemote(RPCCode::kDevSetStream, dev, stream);
}

/*!
 * \brief Ask the remote endpoint for the current stream of `dev`.
 * \param dev The device being queried.
 * \return The value returned by the kDevGetCurrentStream remote syscall,
 *         converted to a TVMStreamHandle.
 */
TVMStreamHandle GetCurrentStream(Device dev) final {
  TVMStreamHandle remote_stream = endpoint_->SysCallRemote(RPCCode::kDevGetCurrentStream, dev);
  return remote_stream;
}

/*!
 * \brief Return the device API for `dev`, which is this session object itself.
 * \param dev The device (not consulted).
 * \param allow_missing Not consulted; the result is never null.
 * \return `this`.
 */
DeviceAPI* GetDeviceAPI(Device dev, bool allow_missing) final {
  // RPCClientSession derives from DeviceAPI, so it serves every device itself.
  return this;
}

/*!
 * \brief Report whether this session is local.
 * \return Always false.
 */
bool IsLocalSession() const final {
  return false;
}
Expand Down
2 changes: 2 additions & 0 deletions src/runtime/vulkan/vulkan_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,8 @@ void VulkanDeviceAPI::SetStream(Device dev, TVMStreamHandle stream) {
ICHECK_EQ(stream, static_cast<void*>(nullptr));
}

/*!
 * \brief Return the current stream of a Vulkan device, which is always null.
 * \param dev The device being queried (not consulted).
 * \return nullptr.
 */
TVMStreamHandle VulkanDeviceAPI::GetCurrentStream(Device dev) {
  // SetStream only accepts a null stream, so null is the only possible value.
  return nullptr;
}

void VulkanDeviceAPI::CopyDataFromTo(const void* from, size_t from_offset, void* to,
size_t to_offset, size_t size, Device dev_from, Device dev_to,
DLDataType type_hint, TVMStreamHandle stream) {
Expand Down
1 change: 1 addition & 0 deletions src/runtime/vulkan/vulkan_device_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ class VulkanDeviceAPI final : public DeviceAPI {
// Stream-related overrides of DeviceAPI for the Vulkan backend.
void SyncStreamFromTo(Device dev, TVMStreamHandle event_src, TVMStreamHandle event_dst) final;
void StreamSync(Device dev, TVMStreamHandle stream) final;
// The definition checks (ICHECK_EQ) that `stream` is nullptr.
void SetStream(Device dev, TVMStreamHandle stream) final;
// Always returns nullptr.
TVMStreamHandle GetCurrentStream(Device dev) final;

protected:
void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size,
Expand Down
2 changes: 2 additions & 0 deletions web/emcc/webgpu_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,8 @@ class WebGPUDeviceAPI : public DeviceAPI {

/*!
 * \brief Not supported by the WebGPU backend; raises a fatal log when called.
 * \param dev The device (unused).
 * \param stream The stream (unused).
 */
void SetStream(Device dev, TVMStreamHandle stream) final {
  LOG(FATAL) << "Not implemented";
}

/*!
 * \brief Not supported by the WebGPU backend; raises a fatal log when called.
 * \param dev The device (unused).
 * \return Never returns normally as long as LOG(FATAL) does not return;
 *         otherwise nullptr.
 */
TVMStreamHandle GetCurrentStream(Device dev) final {
  LOG(FATAL) << "Not implemented";
  // Explicit return so this non-void function never flows off its end (which
  // would be undefined behavior) should LOG(FATAL) ever return.
  return nullptr;
}

/*!
 * \brief Allocate workspace memory from the pool held by WebGPUThreadEntry::ThreadLocal().
 * \param dev The device to allocate on.
 * \param size The requested size, forwarded to the pool.
 * \param type_hint Data type hint (not forwarded to the pool).
 * \return Pointer returned by the pool's AllocWorkspace.
 */
void* AllocWorkspace(Device dev, size_t size, DLDataType type_hint) final {
  auto entry = WebGPUThreadEntry::ThreadLocal();
  return entry->pool.AllocWorkspace(dev, size);
}
Expand Down

0 comments on commit 48992a4

Please sign in to comment.