From c5de356ada7dd3ec9e70dc22d247f8b347714085 Mon Sep 17 00:00:00 2001 From: Daniel Kales <11509575+dkales@users.noreply.github.com> Date: Tue, 20 Aug 2024 14:58:26 +0000 Subject: [PATCH] [breaking] feat: allow NCCL comm type to operate on views as well (#281) * feat: allow NCCL comm type to operate on views as well * doc: address clippy warnings for missing indents --- src/cublaslt/safe.rs | 2 +- src/driver/mod.rs | 4 +- src/driver/result.rs | 4 +- src/driver/safe/core.rs | 4 +- src/driver/safe/launch.rs | 6 +-- src/nccl/safe.rs | 109 ++++++++++++++++++++++++-------------- 6 files changed, 80 insertions(+), 49 deletions(-) diff --git a/src/cublaslt/safe.rs b/src/cublaslt/safe.rs index d1de3ca6..df0b184f 100644 --- a/src/cublaslt/safe.rs +++ b/src/cublaslt/safe.rs @@ -12,7 +12,7 @@ use std::sync::Arc; /// /// 1. Create with [CudaBlasLT::new()] /// 2. Execute matmul kernel with matmul. f32 is supported. f16 and bf16 are supported -/// if feature `half` is activated +/// if feature `half` is activated /// /// Note: This maintains a instance of [`Arc`], so will prevent the device /// from being dropped. Kernels will be launched on the device device default stream. diff --git a/src/driver/mod.rs b/src/driver/mod.rs index a02ba4c9..8bf6847a 100644 --- a/src/driver/mod.rs +++ b/src/driver/mod.rs @@ -11,7 +11,7 @@ //! ``` //! //! 2. Allocate device memory with host data with [CudaDevice::htod_copy()], [CudaDevice::alloc_zeros()], -//! or [CudaDevice::htod_sync_copy()]. +//! or [CudaDevice::htod_sync_copy()]. //! //! You can also copy data to CudaSlice using [CudaDevice::htod_sync_copy_into()] //! @@ -24,7 +24,7 @@ //! ``` //! //! 3. Transfer to host memory with [CudaDevice::sync_reclaim()], [CudaDevice::dtoh_sync_copy()], -//! or [CudaDevice::dtoh_sync_copy_into()] +//! or [CudaDevice::dtoh_sync_copy_into()] //! //! ```rust //! # use cudarc::driver::*; diff --git a/src/driver/result.rs b/src/driver/result.rs index 3d8b5c51..fbec114f 100644 --- a/src/driver/result.rs +++ b/src/driver/result.rs @@ -981,10 +981,10 @@ pub mod event { /// /// 1. The cuda function must be a valid handle returned from a non-unloaded module. /// 2. This is asynchronous, so the results of calling this function happen -/// at a later point after this function returns. +/// at a later point after this function returns. /// 3. All parameters used for this kernel should have been allocated by stream (I think?) /// 4. The cuda kernel has mutable access to every parameter, that means every parameter -/// can change at a later point after callign this function. *Even non-mutable references*. +/// can change at a later point after callign this function. *Even non-mutable references*. #[inline] pub unsafe fn launch_kernel( f: sys::CUfunction, diff --git a/src/driver/safe/core.rs b/src/driver/safe/core.rs index 95b8311e..23d5920c 100644 --- a/src/driver/safe/core.rs +++ b/src/driver/safe/core.rs @@ -29,9 +29,9 @@ use std::{collections::BTreeMap, marker::Unpin, pin::Pin, sync::Arc, vec::Vec}; /// # Safety /// 1. impl [Drop] to call all the corresponding resource cleanup methods /// 2. Doesn't impl clone, so you can't have multiple device pointers -/// hanging around. +/// hanging around. /// 3. Any allocations enforce that self is an [Arc], meaning no allocation -/// can outlive the [CudaDevice] +/// can outlive the [CudaDevice] #[derive(Debug)] pub struct CudaDevice { pub(crate) cu_device: sys::CUdevice, diff --git a/src/driver/safe/launch.rs b/src/driver/safe/launch.rs index afa6d997..329417fd 100644 --- a/src/driver/safe/launch.rs +++ b/src/driver/safe/launch.rs @@ -174,11 +174,11 @@ pub unsafe trait LaunchAsync { /// /// 1. `params` can be changed regardless of `&` or `&mut` usage. /// 2. `params` will be changed at some later point after the - /// function returns because the kernel is executed async. + /// function returns because the kernel is executed async. /// 3. There are no guaruntees that the `params` - /// are the correct number/types/order for `func`. + /// are the correct number/types/order for `func`. /// 4. Specifying the wrong values for [LaunchConfig] can result - /// in accessing/modifying values past memory limits. + /// in accessing/modifying values past memory limits. /// /// ## Asynchronous mutation /// diff --git a/src/nccl/safe.rs b/src/nccl/safe.rs index 863180f1..596c0755 100644 --- a/src/nccl/safe.rs +++ b/src/nccl/safe.rs @@ -1,5 +1,5 @@ use super::{result, sys}; -use crate::driver::{CudaDevice, CudaSlice}; +use crate::driver::{CudaDevice, DevicePtr, DevicePtrMut}; use std::mem::MaybeUninit; use std::ptr; use std::{sync::Arc, vec, vec::Vec}; @@ -204,15 +204,15 @@ impl Comm { } impl Comm { - pub fn send( + pub fn send, T: NcclType>( &self, - data: &CudaSlice, + data: &S, peer: i32, ) -> Result<(), result::NcclError> { unsafe { result::send( - data.cu_device_ptr as *mut _, - data.len, + *data.device_ptr() as *mut _, + data.len(), T::as_nccl_type(), peer, self.comm, @@ -222,15 +222,15 @@ impl Comm { Ok(()) } - pub fn recv( + pub fn recv, T: NcclType>( &self, - buff: &mut CudaSlice, + buff: &mut R, peer: i32, ) -> Result { unsafe { result::recv( - buff.cu_device_ptr as *mut _, - buff.len, + *buff.device_ptr_mut() as *mut _, + buff.len(), T::as_nccl_type(), peer, self.comm, @@ -239,21 +239,21 @@ impl Comm { } } - pub fn broadcast( + pub fn broadcast, R: DevicePtrMut, T: NcclType>( &self, - sendbuff: &Option>, - recvbuff: &mut CudaSlice, + sendbuff: &Option, + recvbuff: &mut R, root: i32, ) -> Result { unsafe { let send_ptr = match sendbuff { - Some(buffer) => buffer.cu_device_ptr as *mut _, + Some(buffer) => *buffer.device_ptr() as *mut _, None => ptr::null(), }; result::broadcast( send_ptr, - recvbuff.cu_device_ptr as *mut _, - recvbuff.len, + *recvbuff.device_ptr_mut() as *mut _, + recvbuff.len(), T::as_nccl_type(), root, self.comm, @@ -262,16 +262,16 @@ impl Comm { } } - pub fn all_gather( + pub fn all_gather, R: DevicePtrMut, T: NcclType>( &self, - sendbuff: &CudaSlice, - recvbuff: &mut CudaSlice, + sendbuff: &S, + recvbuff: &mut R, ) -> Result { unsafe { result::all_gather( - sendbuff.cu_device_ptr as *mut _, - recvbuff.cu_device_ptr as *mut _, - sendbuff.len, + *sendbuff.device_ptr() as *mut _, + *recvbuff.device_ptr_mut() as *mut _, + sendbuff.len(), T::as_nccl_type(), self.comm, self.device.stream as *mut _, @@ -279,17 +279,17 @@ impl Comm { } } - pub fn all_reduce( + pub fn all_reduce, R: DevicePtrMut, T: NcclType>( &self, - sendbuff: &CudaSlice, - recvbuff: &mut CudaSlice, + sendbuff: &S, + recvbuff: &mut R, reduce_op: &ReduceOp, ) -> Result { unsafe { result::all_reduce( - sendbuff.cu_device_ptr as *mut _, - recvbuff.cu_device_ptr as *mut _, - sendbuff.len, + *sendbuff.device_ptr() as *mut _, + *recvbuff.device_ptr_mut() as *mut _, + sendbuff.len(), T::as_nccl_type(), convert_to_nccl_reduce_op(reduce_op), self.comm, @@ -298,18 +298,18 @@ impl Comm { } } - pub fn reduce( + pub fn reduce, R: DevicePtrMut, T: NcclType>( &self, - sendbuff: &CudaSlice, - recvbuff: &mut CudaSlice, + sendbuff: &S, + recvbuff: &mut R, reduce_op: &ReduceOp, root: i32, ) -> Result { unsafe { result::reduce( - sendbuff.cu_device_ptr as *mut _, - recvbuff.cu_device_ptr as *mut _, - sendbuff.len, + *sendbuff.device_ptr() as *mut _, + *recvbuff.device_ptr_mut() as *mut _, + sendbuff.len(), T::as_nccl_type(), convert_to_nccl_reduce_op(reduce_op), root, @@ -319,17 +319,17 @@ impl Comm { } } - pub fn reduce_scatter( + pub fn reduce_scatter, R: DevicePtrMut, T: NcclType>( &self, - sendbuff: &CudaSlice, - recvbuff: &mut CudaSlice, + sendbuff: &S, + recvbuff: &mut R, reduce_op: &ReduceOp, ) -> Result { unsafe { result::reduce_scatter( - sendbuff.cu_device_ptr as *mut _, - recvbuff.cu_device_ptr as *mut _, - recvbuff.len, + *sendbuff.device_ptr() as *mut _, + *recvbuff.device_ptr_mut() as *mut _, + recvbuff.len(), T::as_nccl_type(), convert_to_nccl_reduce_op(reduce_op), self.comm, @@ -385,4 +385,35 @@ mod tests { t.join().unwrap() } } + + #[test] + fn test_all_reduce_views() { + let n = 2; + let n_devices = CudaDevice::count().unwrap() as usize; + let id = Id::new().unwrap(); + let threads: Vec<_> = (0..n_devices) + .map(|i| { + println!("III {i}"); + std::thread::spawn(move || { + println!("Within thread {i}"); + let dev = CudaDevice::new(i).unwrap(); + let comm = Comm::from_rank(dev.clone(), i, n_devices, id).unwrap(); + let slice = dev.htod_copy(vec![(i + 1) as f32 * 1.0; n]).unwrap(); + let mut slice_receive = dev.alloc_zeros::(n).unwrap(); + let slice_view = slice.slice(..); + let mut slice_receive_view = slice_receive.slice_mut(..); + + comm.all_reduce(&slice_view, &mut slice_receive_view, &ReduceOp::Sum) + .unwrap(); + + let out = dev.dtoh_sync_copy(&slice_receive).unwrap(); + + assert_eq!(out, vec![(n_devices * (n_devices + 1)) as f32 / 2.0; n]); + }) + }) + .collect(); + for t in threads { + t.join().unwrap() + } + } }