[breaking] feat: allow NCCL comm type to operate on views as well (#281)
* feat: allow NCCL comm type to operate on views as well

* doc: address clippy warnings for missing indents
dkales authored Aug 20, 2024
1 parent 6f99335 commit c5de356
Showing 6 changed files with 80 additions and 49 deletions.
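
For callers, the breaking change is that `Comm`'s point-to-point and collective methods are now generic over the `DevicePtr`/`DevicePtrMut` traits rather than taking `CudaSlice<T>` concretely, so owned slices and the views produced by `slice(..)`/`slice_mut(..)` are both accepted. A minimal sketch of the new call shape (assuming `comm`, `slice`, `out`, and `n` are set up as in the test added at the bottom of this diff):

```rust
// Owned slices still work exactly as before:
// comm.all_reduce(&slice, &mut out, &ReduceOp::Sum)?;

// New in this commit: views over (sub-)ranges are accepted too.
let send_view = slice.slice(0..n); // CudaView<'_, f32>
let mut recv_view = out.slice_mut(0..n); // CudaViewMut<'_, f32>
comm.all_reduce(&send_view, &mut recv_view, &ReduceOp::Sum)?;
```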
2 changes: 1 addition & 1 deletion src/cublaslt/safe.rs
@@ -12,7 +12,7 @@ use std::sync::Arc;
 ///
 /// 1. Create with [CudaBlasLT::new()]
 /// 2. Execute matmul kernel with matmul. f32 is supported. f16 and bf16 are supported
-/// if feature `half` is activated
+///    if feature `half` is activated
 ///
 /// Note: This maintains a instance of [`Arc<CudaDevice>`], so will prevent the device
 /// from being dropped. Kernels will be launched on the device device default stream.
4 changes: 2 additions & 2 deletions src/driver/mod.rs
@@ -11,7 +11,7 @@
 //! ```
 //!
 //! 2. Allocate device memory with host data with [CudaDevice::htod_copy()], [CudaDevice::alloc_zeros()],
-//! or [CudaDevice::htod_sync_copy()].
+//!    or [CudaDevice::htod_sync_copy()].
 //!
 //! You can also copy data to CudaSlice using [CudaDevice::htod_sync_copy_into()]
 //!
@@ -24,7 +24,7 @@
 //! ```
 //!
 //! 3. Transfer to host memory with [CudaDevice::sync_reclaim()], [CudaDevice::dtoh_sync_copy()],
-//! or [CudaDevice::dtoh_sync_copy_into()]
+//!    or [CudaDevice::dtoh_sync_copy_into()]
 //!
 //! ```rust
 //! # use cudarc::driver::*;
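
Steps 1-3 of the doc comment above compose into a short round trip; a minimal sketch using only the calls named there:

```rust
use cudarc::driver::{CudaDevice, DriverError};

fn main() -> Result<(), DriverError> {
    let dev = CudaDevice::new(0)?; // 1. construct the device
    let a = dev.htod_copy(vec![1.0f32; 4])?; // 2. host-to-device copy
    let z = dev.alloc_zeros::<f32>(4)?; // 2. zero-filled device allocation
    assert_eq!(dev.dtoh_sync_copy(&a)?, vec![1.0f32; 4]); // 3. copy back out
    assert_eq!(dev.sync_reclaim(z)?, vec![0.0f32; 4]); // 3. reclaim into a Vec
    Ok(())
}
```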
4 changes: 2 additions & 2 deletions src/driver/result.rs
@@ -981,10 +981,10 @@ pub mod event {
 ///
 /// 1. The cuda function must be a valid handle returned from a non-unloaded module.
 /// 2. This is asynchronous, so the results of calling this function happen
-/// at a later point after this function returns.
+///    at a later point after this function returns.
 /// 3. All parameters used for this kernel should have been allocated by stream (I think?)
 /// 4. The cuda kernel has mutable access to every parameter, that means every parameter
-/// can change at a later point after callign this function. *Even non-mutable references*.
+///    can change at a later point after callign this function. *Even non-mutable references*.
 #[inline]
 pub unsafe fn launch_kernel(
     f: sys::CUfunction,
4 changes: 2 additions & 2 deletions src/driver/safe/core.rs
@@ -29,9 +29,9 @@ use std::{collections::BTreeMap, marker::Unpin, pin::Pin, sync::Arc, vec::Vec};
 /// # Safety
 /// 1. impl [Drop] to call all the corresponding resource cleanup methods
 /// 2. Doesn't impl clone, so you can't have multiple device pointers
-/// hanging around.
+///    hanging around.
 /// 3. Any allocations enforce that self is an [Arc], meaning no allocation
-/// can outlive the [CudaDevice]
+///    can outlive the [CudaDevice]
 #[derive(Debug)]
 pub struct CudaDevice {
     pub(crate) cu_device: sys::CUdevice,
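
Point 3 in practice: every allocation carries its own `Arc<CudaDevice>`, so a buffer keeps the device alive even after the user's handle is dropped. A small sketch of that ownership rule (assuming a GPU at ordinal 0):

```rust
use cudarc::driver::CudaDevice;

let dev = CudaDevice::new(0).unwrap(); // Arc<CudaDevice>
let buf = dev.alloc_zeros::<f32>(16).unwrap(); // holds another Arc internally
drop(dev); // the device is NOT torn down here...
drop(buf); // ...only once the last allocation referencing it is gone
```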
6 changes: 3 additions & 3 deletions src/driver/safe/launch.rs
@@ -174,11 +174,11 @@ pub unsafe trait LaunchAsync<Params> {
 ///
 /// 1. `params` can be changed regardless of `&` or `&mut` usage.
 /// 2. `params` will be changed at some later point after the
-/// function returns because the kernel is executed async.
+///    function returns because the kernel is executed async.
 /// 3. There are no guaruntees that the `params`
-/// are the correct number/types/order for `func`.
+///    are the correct number/types/order for `func`.
 /// 4. Specifying the wrong values for [LaunchConfig] can result
-/// in accessing/modifying values past memory limits.
+///    in accessing/modifying values past memory limits.
 ///
 /// ## Asynchronous mutation
 ///
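
To make the contract above concrete, here is a hedged sketch that compiles a trivial kernel with NVRTC and launches it through `LaunchAsync`; the kernel source and the module/function names (`scale_mod`, `scale`) are illustrative, and nothing reads the buffers on the host until `dtoh_sync_copy` synchronizes the stream:

```rust
use cudarc::driver::{CudaDevice, LaunchAsync, LaunchConfig};
use cudarc::nvrtc::compile_ptx;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    const SRC: &str = r#"
        extern "C" __global__ void scale(float *out, const float *inp, float k, int n) {
            int i = blockIdx.x * blockDim.x + threadIdx.x;
            if (i < n) { out[i] = inp[i] * k; }
        }"#;
    let dev = CudaDevice::new(0)?;
    dev.load_ptx(compile_ptx(SRC)?, "scale_mod", &["scale"])?;
    let f = dev.get_func("scale_mod", "scale").unwrap();
    let inp = dev.htod_copy(vec![1.0f32; 128])?;
    let mut out = dev.alloc_zeros::<f32>(128)?;
    // Safety: the params match the kernel signature (points 3 and 4), and the
    // buffers are not touched on the host before the stream synchronizes.
    unsafe { f.launch(LaunchConfig::for_num_elems(128), (&mut out, &inp, 2.0f32, 128i32)) }?;
    assert_eq!(dev.dtoh_sync_copy(&out)?, vec![2.0f32; 128]); // syncs first
    Ok(())
}
```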
109 changes: 70 additions & 39 deletions src/nccl/safe.rs
@@ -1,5 +1,5 @@
 use super::{result, sys};
-use crate::driver::{CudaDevice, CudaSlice};
+use crate::driver::{CudaDevice, DevicePtr, DevicePtrMut};
 use std::mem::MaybeUninit;
 use std::ptr;
 use std::{sync::Arc, vec, vec::Vec};
@@ -204,15 +204,15 @@ impl Comm {
 }
 
 impl Comm {
-    pub fn send<T: NcclType>(
+    pub fn send<S: DevicePtr<T>, T: NcclType>(
         &self,
-        data: &CudaSlice<T>,
+        data: &S,
         peer: i32,
     ) -> Result<(), result::NcclError> {
         unsafe {
             result::send(
-                data.cu_device_ptr as *mut _,
-                data.len,
+                *data.device_ptr() as *mut _,
+                data.len(),
                 T::as_nccl_type(),
                 peer,
                 self.comm,
@@ -222,15 +222,15 @@ impl Comm {
         Ok(())
     }
 
-    pub fn recv<T: NcclType>(
+    pub fn recv<R: DevicePtrMut<T>, T: NcclType>(
         &self,
-        buff: &mut CudaSlice<T>,
+        buff: &mut R,
         peer: i32,
     ) -> Result<result::NcclStatus, result::NcclError> {
         unsafe {
             result::recv(
-                buff.cu_device_ptr as *mut _,
-                buff.len,
+                *buff.device_ptr_mut() as *mut _,
+                buff.len(),
                 T::as_nccl_type(),
                 peer,
                 self.comm,
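
With the new bounds, point-to-point transfers can address sub-ranges directly instead of staging through a separate `CudaSlice`. A sketch assuming a two-rank setup where rank 0 owns `data: CudaSlice<f32>` and rank 1 owns a preallocated `buf: CudaSlice<f32>` of the same length:

```rust
// Rank 0: send only the first half of `data`, as a borrowed view.
let half = data.slice(0..data.len() / 2);
comm.send(&half, 1)?;

// Rank 1: receive directly into a mutable view of `buf`.
let mut dst = buf.slice_mut(0..buf.len() / 2);
comm.recv(&mut dst, 0)?;
```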
@@ -239,21 +239,21 @@ impl Comm {
         }
     }
 
-    pub fn broadcast<T: NcclType>(
+    pub fn broadcast<S: DevicePtr<T>, R: DevicePtrMut<T>, T: NcclType>(
         &self,
-        sendbuff: &Option<CudaSlice<T>>,
-        recvbuff: &mut CudaSlice<T>,
+        sendbuff: &Option<S>,
+        recvbuff: &mut R,
         root: i32,
     ) -> Result<result::NcclStatus, result::NcclError> {
         unsafe {
             let send_ptr = match sendbuff {
-                Some(buffer) => buffer.cu_device_ptr as *mut _,
+                Some(buffer) => *buffer.device_ptr() as *mut _,
                 None => ptr::null(),
             };
             result::broadcast(
                 send_ptr,
-                recvbuff.cu_device_ptr as *mut _,
-                recvbuff.len,
+                *recvbuff.device_ptr_mut() as *mut _,
+                recvbuff.len(),
                 T::as_nccl_type(),
                 root,
                 self.comm,
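
Usage under the new `&Option<S>` signature: the root wraps its source buffer in `Some`, every other rank passes `None`, and all ranks receive. A sketch assuming `comm`, `dev`, `rank`, and `n` from a per-thread setup like the tests below:

```rust
let send = if rank == 0 {
    Some(dev.htod_copy(vec![42.0f32; n])?) // root provides the data
} else {
    None // non-root ranks contribute nothing
};
let mut recv = dev.alloc_zeros::<f32>(n)?;
comm.broadcast(&send, &mut recv, 0)?; // rank 0 is the root
```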
@@ -262,34 +262,34 @@ impl Comm {
         }
     }
 
-    pub fn all_gather<T: NcclType>(
+    pub fn all_gather<S: DevicePtr<T>, R: DevicePtrMut<T>, T: NcclType>(
         &self,
-        sendbuff: &CudaSlice<T>,
-        recvbuff: &mut CudaSlice<T>,
+        sendbuff: &S,
+        recvbuff: &mut R,
     ) -> Result<result::NcclStatus, result::NcclError> {
         unsafe {
             result::all_gather(
-                sendbuff.cu_device_ptr as *mut _,
-                recvbuff.cu_device_ptr as *mut _,
-                sendbuff.len,
+                *sendbuff.device_ptr() as *mut _,
+                *recvbuff.device_ptr_mut() as *mut _,
+                sendbuff.len(),
                 T::as_nccl_type(),
                 self.comm,
                 self.device.stream as *mut _,
             )
         }
     }
 
-    pub fn all_reduce<T: NcclType>(
+    pub fn all_reduce<S: DevicePtr<T>, R: DevicePtrMut<T>, T: NcclType>(
         &self,
-        sendbuff: &CudaSlice<T>,
-        recvbuff: &mut CudaSlice<T>,
+        sendbuff: &S,
+        recvbuff: &mut R,
         reduce_op: &ReduceOp,
     ) -> Result<result::NcclStatus, result::NcclError> {
         unsafe {
             result::all_reduce(
-                sendbuff.cu_device_ptr as *mut _,
-                recvbuff.cu_device_ptr as *mut _,
-                sendbuff.len,
+                *sendbuff.device_ptr() as *mut _,
+                *recvbuff.device_ptr_mut() as *mut _,
+                sendbuff.len(),
                 T::as_nccl_type(),
                 convert_to_nccl_reduce_op(reduce_op),
                 self.comm,
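
Note the size contract: the element count handed to NCCL is `sendbuff.len()`, and all-gather concatenates each rank's contribution in rank order, so the receive buffer must hold `n_devices * sendbuff.len()` elements. A sketch with the same assumed `comm`/`dev`/`rank`/`n`/`n_devices` setup:

```rust
let send = dev.htod_copy(vec![rank as f32; n])?;
// Rank-major output: [rank 0's n floats, rank 1's n floats, ...].
let mut recv = dev.alloc_zeros::<f32>(n_devices * n)?;
comm.all_gather(&send, &mut recv)?;
```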
@@ -298,18 +298,18 @@ impl Comm {
         }
     }
 
-    pub fn reduce<T: NcclType>(
+    pub fn reduce<S: DevicePtr<T>, R: DevicePtrMut<T>, T: NcclType>(
         &self,
-        sendbuff: &CudaSlice<T>,
-        recvbuff: &mut CudaSlice<T>,
+        sendbuff: &S,
+        recvbuff: &mut R,
         reduce_op: &ReduceOp,
         root: i32,
     ) -> Result<result::NcclStatus, result::NcclError> {
         unsafe {
             result::reduce(
-                sendbuff.cu_device_ptr as *mut _,
-                recvbuff.cu_device_ptr as *mut _,
-                sendbuff.len,
+                *sendbuff.device_ptr() as *mut _,
+                *recvbuff.device_ptr_mut() as *mut _,
+                sendbuff.len(),
                 T::as_nccl_type(),
                 convert_to_nccl_reduce_op(reduce_op),
                 root,
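
`reduce` is `all_reduce` with a single destination: every rank contributes, but per NCCL semantics only the root's receive buffer is guaranteed to hold the result. A sketch with the same assumed setup:

```rust
let send = dev.htod_copy(vec![1.0f32; n])?;
let mut recv = dev.alloc_zeros::<f32>(n)?; // meaningful only on the root
comm.reduce(&send, &mut recv, &ReduceOp::Sum, 0)?; // root = rank 0
```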
@@ -319,17 +319,17 @@ impl Comm {
         }
     }
 
-    pub fn reduce_scatter<T: NcclType>(
+    pub fn reduce_scatter<S: DevicePtr<T>, R: DevicePtrMut<T>, T: NcclType>(
         &self,
-        sendbuff: &CudaSlice<T>,
-        recvbuff: &mut CudaSlice<T>,
+        sendbuff: &S,
+        recvbuff: &mut R,
         reduce_op: &ReduceOp,
     ) -> Result<result::NcclStatus, result::NcclError> {
         unsafe {
             result::reduce_scatter(
-                sendbuff.cu_device_ptr as *mut _,
-                recvbuff.cu_device_ptr as *mut _,
-                recvbuff.len,
+                *sendbuff.device_ptr() as *mut _,
+                *recvbuff.device_ptr_mut() as *mut _,
+                recvbuff.len(),
                 T::as_nccl_type(),
                 convert_to_nccl_reduce_op(reduce_op),
                 self.comm,
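
`reduce_scatter` inverts the all-gather size contract; note that the count passed to NCCL here is `recvbuff.len()`, so each rank supplies `n_devices * recvbuff.len()` elements and receives its own reduced chunk. A sketch with the same assumed setup:

```rust
let send = dev.htod_copy(vec![1.0f32; n_devices * n])?;
let mut recv = dev.alloc_zeros::<f32>(n)?;
// Rank r receives the elementwise sum of every rank's r-th chunk.
comm.reduce_scatter(&send, &mut recv, &ReduceOp::Sum)?;
```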
@@ -385,4 +385,35 @@ mod tests {
             t.join().unwrap()
         }
     }
+
+    #[test]
+    fn test_all_reduce_views() {
+        let n = 2;
+        let n_devices = CudaDevice::count().unwrap() as usize;
+        let id = Id::new().unwrap();
+        let threads: Vec<_> = (0..n_devices)
+            .map(|i| {
+                println!("III {i}");
+                std::thread::spawn(move || {
+                    println!("Within thread {i}");
+                    let dev = CudaDevice::new(i).unwrap();
+                    let comm = Comm::from_rank(dev.clone(), i, n_devices, id).unwrap();
+                    let slice = dev.htod_copy(vec![(i + 1) as f32 * 1.0; n]).unwrap();
+                    let mut slice_receive = dev.alloc_zeros::<f32>(n).unwrap();
+                    let slice_view = slice.slice(..);
+                    let mut slice_receive_view = slice_receive.slice_mut(..);
+
+                    comm.all_reduce(&slice_view, &mut slice_receive_view, &ReduceOp::Sum)
+                        .unwrap();
+
+                    let out = dev.dtoh_sync_copy(&slice_receive).unwrap();
+
+                    assert_eq!(out, vec![(n_devices * (n_devices + 1)) as f32 / 2.0; n]);
+                })
+            })
+            .collect();
+        for t in threads {
+            t.join().unwrap()
+        }
+    }
 }
