Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Breaking][Waiting to merge] &Option<T> -> Option<&T> in nccl broadcast. Add Option argument to nccl reduce. #292

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 22 additions & 7 deletions src/nccl/safe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -239,17 +239,23 @@ impl Comm {
}
}

/// Broadcasts a value from `root` rank to every other ranks `recvbuff`.
/// sendbuff is ignored on ranks other than `root`, so you can pass `None`
/// on non-root ranks.
///
/// sendbuff must be Some on root rank!
pub fn broadcast<S: DevicePtr<T>, R: DevicePtrMut<T>, T: NcclType>(
&self,
sendbuff: &Option<S>,
sendbuff: Option<&S>,
recvbuff: &mut R,
root: i32,
) -> Result<result::NcclStatus, result::NcclError> {
debug_assert!(sendbuff.is_some() || self.rank != root as usize);
let send_ptr = match sendbuff {
Some(buffer) => *buffer.device_ptr() as *mut _,
None => ptr::null(),
};
unsafe {
let send_ptr = match sendbuff {
Some(buffer) => *buffer.device_ptr() as *mut _,
None => ptr::null(),
};
result::broadcast(
send_ptr,
*recvbuff.device_ptr_mut() as *mut _,
Expand Down Expand Up @@ -298,17 +304,26 @@ impl Comm {
}
}

/// Reduces the sendbuff from all ranks into the recvbuff on the
/// `root` rank.
///
/// recvbuff must be Some on the root rank!
pub fn reduce<S: DevicePtr<T>, R: DevicePtrMut<T>, T: NcclType>(
&self,
sendbuff: &S,
recvbuff: &mut R,
recvbuff: Option<&mut R>,
reduce_op: &ReduceOp,
root: i32,
) -> Result<result::NcclStatus, result::NcclError> {
debug_assert!(recvbuff.is_some() || self.rank != root as usize);
let recv_ptr = match recvbuff {
Some(buffer) => *buffer.device_ptr_mut() as *mut _,
None => ptr::null_mut(),
};
unsafe {
result::reduce(
*sendbuff.device_ptr() as *mut _,
*recvbuff.device_ptr_mut() as *mut _,
recv_ptr,
sendbuff.len(),
T::as_nccl_type(),
convert_to_nccl_reduce_op(reduce_op),
Expand Down
Loading