Add safe methods set_pointer_mode and get_pointer_mode #291

Merged 2 commits on Sep 6, 2024

@MathisWellmann (Contributor) commented Sep 5, 2024

Adds two safe methods to CudaBlas: set_pointer_mode and get_pointer_mode.

There is also a test to ensure it works as expected.

This is important to have, because some cuBLAS functions require CUBLAS_POINTER_MODE_DEVICE
to be set when device memory is passed in as the result buffer.
If the cublasPointerMode_t is left at the default CUBLAS_POINTER_MODE_HOST in that case,
the call crashes the process with SIGSEGV: invalid memory reference.
I ran into this while trying to use the cublas<t>asum() function (https://docs.nvidia.com/cuda/cublas/index.html#cublas-t-asum).
Here is an example that illustrates the importance of setting the value properly:

use std::sync::Arc;

use cudarc::{
    cublas::{sys::cublasPointerMode_t, CudaBlas},
    driver::{CudaViewMut, DevicePtr, DevicePtrMut, DeviceSlice},
};

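// `Result` is this project's own alias; the old `cublas_get_pointer_mode` helper
// is no longer needed now that `CudaBlas::get_pointer_mode` exists.
use crate::Result;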

/// Compute the absolute sum of a slice using CUBLAS.
/// ref: <https://docs.nvidia.com/cuda/cublas/index.html#cublas-t-asum>
pub fn cublas_asum<'a, X>(
    cublas: Arc<CudaBlas>,
    vals: &X,
    d_dest: &mut CudaViewMut<'a, f32>,
) -> Result<()>
where
    X: DevicePtr<f32>,
{
    let n = vals.len();
    assert!(n > 0);
    assert_eq!(
        d_dest.len(),
        1,
        "There must be exactly one slot in the slice for the output"
    );
    debug_assert_eq!(
        cublas.get_pointer_mode()?,
        cublasPointerMode_t::CUBLAS_POINTER_MODE_DEVICE,
        "The cublas pointer mode must be device as we intend to place the result into GPU memory"
    );
    unsafe {
        cudarc::cublas::sys::lib()
            .cublasSasum_v2(
                *cublas.handle(),
                n as _,
                *vals.device_ptr() as *const f32,
                1,
                *d_dest.device_ptr_mut() as *mut f32,
            )
            .result()?
    };
    Ok(())
}

#[cfg(test)]
mod tests {
    use cudarc::driver::CudaDevice;
    use rand::{thread_rng, Rng};

    use super::*;

    #[test]
    fn compare_cublas_asum() {
        let device = CudaDevice::new_with_stream(0).unwrap();
        let cublas = Arc::new(CudaBlas::new(device.clone()).unwrap());
        // IMPORTANT LINE: without it the call segfaults, as we pass device memory as the `asum()` result.
        cublas.set_pointer_mode(cublasPointerMode_t::CUBLAS_POINTER_MODE_DEVICE).unwrap();

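        let mut rng = thread_rng();
        let vals = Vec::from_iter((0..1_000).map(|_| rng.gen::<f32>()));
        // Host-side reference value, computed before `vals` is moved to the device.
        let expected: f32 = vals.iter().map(|v| v.abs()).sum();

        let d_vals = device.htod_copy(vals).unwrap();
        let mut d_dest = device.alloc_zeros::<f32>(1).unwrap();
        cublas_asum(cublas, &d_vals, &mut d_dest.slice_mut(..)).unwrap();
        let h_sum = device.sync_reclaim(d_dest).unwrap()[0];
        // Assumed tolerance: GPU and CPU may sum in different orders, so compare loosely.
        assert!((h_sum - expected).abs() <= 1e-3 * expected);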
    }
}

Happy to include the test example in the PR as well if desired.
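
For readers without a GPU at hand, the pointer-mode rule described above can be sketched as a small pure-Rust model. This is not cudarc's or cuBLAS's API: `ModelBlas`, `ResultSlot`, and `ModelError` are hypothetical names invented for illustration, the getter and setter are infallible here (the real methods return a `Result`), and the model returns an error where real cuBLAS would crash. It also assumes that any mismatch between mode and result location is an error; the description above only reports the host-mode case.

```rust
use std::cell::Cell;

/// Mirrors the two values of cuBLAS's `cublasPointerMode_t`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum PointerMode {
    Host,
    Device,
}

/// Hypothetical: where the caller wants the scalar result written.
/// `Device` stands in for a one-element device buffer.
#[derive(Debug, PartialEq)]
enum ResultSlot {
    Host(f32),
    Device(f32),
}

/// Hypothetical error; real cuBLAS would segfault instead of reporting this.
#[derive(Debug, PartialEq, Eq)]
enum ModelError {
    PointerModeMismatch { mode: PointerMode },
}

/// Hypothetical stand-in for a cuBLAS handle. New handles start in host mode,
/// matching the default described in the PR text.
struct ModelBlas {
    mode: Cell<PointerMode>,
}

impl ModelBlas {
    fn new() -> Self {
        Self { mode: Cell::new(PointerMode::Host) }
    }

    fn set_pointer_mode(&self, mode: PointerMode) {
        self.mode.set(mode);
    }

    fn get_pointer_mode(&self) -> PointerMode {
        self.mode.get()
    }

    /// Sum of absolute values, written into `dest` only if the current
    /// pointer mode matches the location of `dest`.
    fn asum(&self, x: &[f32], dest: &mut ResultSlot) -> Result<(), ModelError> {
        let sum: f32 = x.iter().map(|v| v.abs()).sum();
        match (self.mode.get(), dest) {
            (PointerMode::Host, ResultSlot::Host(out)) => {
                *out = sum;
                Ok(())
            }
            (PointerMode::Device, ResultSlot::Device(out)) => {
                *out = sum;
                Ok(())
            }
            (mode, _) => Err(ModelError::PointerModeMismatch { mode }),
        }
    }
}

fn main() {
    let blas = ModelBlas::new();
    let x = [1.0_f32, -2.0, 3.5];

    // Host mode with a host result slot works out of the box.
    let mut h_dest = ResultSlot::Host(0.0);
    blas.asum(&x, &mut h_dest).unwrap();

    // Host mode with a device result slot is rejected by the model.
    let mut d_dest = ResultSlot::Device(0.0);
    if let Err(e) = blas.asum(&x, &mut d_dest) {
        println!("rejected: {:?}", e);
    }

    // After switching the mode, the same call succeeds.
    blas.set_pointer_mode(PointerMode::Device);
    blas.asum(&x, &mut d_dest).unwrap();
    println!("{:?} {:?} {:?}", blas.get_pointer_mode(), h_dest, d_dest);
}
```

The point of the model is the ordering: the mode is set on the handle first, and only then is a device result buffer passed in, which is the same ordering the test above relies on.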

PS: Setting CUBLAS_POINTER_MODE_DEVICE also cuts execution time by roughly half, as I can show with a benchmark. I'm not sure why the gain is so large, but it shows up for the dot() function as well, even though the inputs and outputs are in device memory both times. For dot(), execution time drops from about 9 µs to about 4 µs across a whole range of slice lengths, including (but not limited to) 8192 elements. I assume it has to do with async memory copies.
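
A timing comparison like the one mentioned could be structured roughly as follows. This is only a sketch: `mean_time` is a hypothetical helper, not part of cudarc, and the closure here does CPU work as a stand-in. For real GPU calls the closure would need to synchronize the stream before returning, since otherwise only the launch overhead is measured.

```rust
use std::time::{Duration, Instant};

/// Hypothetical helper: mean wall-clock time of `f` over `iters` timed runs,
/// after one untimed warm-up run.
fn mean_time<F: FnMut()>(iters: u32, mut f: F) -> Duration {
    assert!(iters > 0, "need at least one timed iteration");
    f(); // warm-up, not timed
    let start = Instant::now();
    for _ in 0..iters {
        f();
    }
    start.elapsed() / iters
}

fn main() {
    // CPU stand-in for the kernel under test: absolute sum over 8192 values.
    let data: Vec<f32> = (0..8192).map(|i| i as f32 - 4096.0).collect();
    let mut sink = 0.0_f32;
    let per_call = mean_time(1_000, || {
        // With a GPU library, call the routine here and then synchronize.
        sink += data.iter().map(|v| v.abs()).sum::<f32>();
    });
    println!("mean per call: {:?} (sink = {})", per_call, sink);
}
```

Running the same harness once per pointer mode, with everything else unchanged, would make the two numbers directly comparable.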

@MathisWellmann MathisWellmann marked this pull request as ready for review September 5, 2024 18:33
@coreylowman coreylowman self-requested a review September 6, 2024 18:09
@coreylowman coreylowman self-assigned this Sep 6, 2024
@coreylowman coreylowman merged commit 64c8c62 into coreylowman:main Sep 6, 2024
14 checks passed
@coreylowman (Owner) commented:

Thanks!
