diff --git a/mlx-rs/examples/tutorial.rs b/mlx-rs/examples/tutorial.rs index c98814e4..fa263678 100644 --- a/mlx-rs/examples/tutorial.rs +++ b/mlx-rs/examples/tutorial.rs @@ -83,10 +83,10 @@ fn automatic_differentiation() { let x = Array::from(1.5); - let mut dfdx = calculate_grad(f, &x); + let dfdx = calculate_grad(f, &x); assert_eq!(dfdx.item::(), 2.0 * 1.5); - let mut dfdx2 = calculate_grad(|args| calculate_grad(f, args), &x); + let dfdx2 = calculate_grad(|args| calculate_grad(f, args), &x); assert_eq!(dfdx2.item::(), 2.0); } diff --git a/mlx-rs/src/array/operators.rs b/mlx-rs/src/array/operators.rs index 1b92a1a6..7f399a17 100644 --- a/mlx-rs/src/array/operators.rs +++ b/mlx-rs/src/array/operators.rs @@ -1,4 +1,4 @@ -use crate::{prelude::ScalarOrArray, Array, StreamOrDevice}; +use crate::{utils::ScalarOrArray, Array, StreamOrDevice}; use num_traits::Pow; use std::ops::{ Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Not, Rem, RemAssign, Sub, SubAssign, @@ -14,7 +14,7 @@ macro_rules! impl_binary_op { fn $method(self, rhs: T) -> Self::Output { paste::paste! { - self.[<$c_method _device>](rhs, StreamOrDevice::default()).unwrap() + self.[<$c_method _device>](rhs.into_owned_or_ref_array(), StreamOrDevice::default()).unwrap() } } } @@ -27,7 +27,7 @@ macro_rules! impl_binary_op { fn $method(self, rhs: T) -> Self::Output { paste::paste! { - self.[<$c_method _device>](rhs, StreamOrDevice::default()).unwrap() + self.[<$c_method _device>](rhs.into_owned_or_ref_array(), StreamOrDevice::default()).unwrap() } } } diff --git a/mlx-rs/src/error.rs b/mlx-rs/src/error.rs index ca9df378..f5dc9beb 100644 --- a/mlx-rs/src/error.rs +++ b/mlx-rs/src/error.rs @@ -43,8 +43,8 @@ impl Exception { } thread_local! { - pub static LAST_MLX_ERROR: Cell<*const c_char> = const { Cell::new(std::ptr::null()) }; - pub static IS_HANDLER_SET: Cell = const { Cell::new(false) }; + static LAST_MLX_ERROR: Cell<*const c_char> = const { Cell::new(std::ptr::null()) }; + static IS_HANDLER_SET: Cell = const { Cell::new(false) }; } #[no_mangle] diff --git a/mlx-rs/src/lib.rs b/mlx-rs/src/lib.rs index 2eff80af..a1fd7a5c 100644 --- a/mlx-rs/src/lib.rs +++ b/mlx-rs/src/lib.rs @@ -28,7 +28,6 @@ pub mod prelude { dtype::Dtype, ops::indexing::{Ellipsis, IndexMutOp, IndexOp, IntoStrideBy, NewAxis}, stream::StreamOrDevice, - utils::ScalarOrArray, }; pub use num_traits::Pow; diff --git a/mlx-rs/src/macros/array.rs b/mlx-rs/src/macros/array.rs index edbb8309..a0383ca8 100644 --- a/mlx-rs/src/macros/array.rs +++ b/mlx-rs/src/macros/array.rs @@ -41,9 +41,19 @@ /// [10, 11, 12] /// ] /// ]); +/// +/// // Create a 2x2 array by specifying the shape +/// let a = array!([1, 2, 3, 4], shape=[2, 2]); /// ``` #[macro_export] macro_rules! array { + ([$($x:expr),*], shape=[$($s:expr),*]) => { + { + let data = [$($x,)*]; + let shape = [$($s,)*]; + $crate::Array::from_slice(&data, &shape) + } + }; ([$([$([$($x:expr),*]),*]),*]) => { { let arr = [$([$([$($x,)*],)*],)*]; @@ -133,4 +143,16 @@ mod tests { assert_eq!(a.index((1, 1, 1)).item::(), 11); assert_eq!(a.index((1, 1, 2)).item::(), 12); } + + #[test] + fn test_array_with_shape() { + let a = array!([1, 2, 3, 4], shape = [2, 2]); + + assert!(a.ndim() == 2); + assert_eq!(a.shape(), &[2, 2]); + assert_eq!(a.index((0, 0)).item::(), 1); + assert_eq!(a.index((0, 1)).item::(), 2); + assert_eq!(a.index((1, 0)).item::(), 3); + assert_eq!(a.index((1, 1)).item::(), 4); + } } diff --git a/mlx-rs/src/macros/internal.rs b/mlx-rs/src/macros/internal.rs index b26c3c04..afc70b10 100644 --- a/mlx-rs/src/macros/internal.rs +++ b/mlx-rs/src/macros/internal.rs @@ -8,12 +8,33 @@ macro_rules! try_catch_c_ptr_expr { if c_ptr.is_null() { // SAFETY: there must be an error if the pointer is null return Err($crate::error::get_and_clear_last_mlx_error() + // .or($crate::error::take_last_mlx_closure_error()) .expect("A null pointer was returned, but no error was set.")); } c_ptr }}; } +macro_rules! try_catch_mlx_closure_error { + ($expr:expr) => {{ + if !$crate::error::is_mlx_error_handler_set() { + $crate::error::setup_mlx_error_handler(); + } + + let c_ptr = $expr; + // Always check for closure errors + if let Some(error) = $crate::error::get_and_clear_last_mlx_error() + // .or($crate::error::take_last_mlx_closure_error()) + { + return Err(error); + } + if c_ptr.is_null() { + panic!("A null pointer was returned, but no error was set."); + } + c_ptr + }}; +} + /// See `assertEqual` in the swift binding tests #[allow(unused_macros)] macro_rules! assert_array_all_close { diff --git a/mlx-rs/src/ops/arithmetic.rs b/mlx-rs/src/ops/arithmetic.rs index 3d90ef34..109c7fe3 100644 --- a/mlx-rs/src/ops/arithmetic.rs +++ b/mlx-rs/src/ops/arithmetic.rs @@ -1,10 +1,9 @@ use crate::array::Array; use crate::error::Exception; -use crate::prelude::ScalarOrArray; use crate::sealed::Sealed; use crate::stream::StreamOrDevice; -use crate::utils::{IntoOption, VectorArray}; +use crate::utils::{IntoOption, ScalarOrArray, VectorArray}; use crate::Stream; use mlx_macros::default_device; use smallvec::SmallVec; @@ -47,14 +46,14 @@ impl Array { /// // c_data == [5.0, 7.0, 9.0] /// ``` #[default_device] - pub fn add_device<'a>( + pub fn add_device( &self, - other: impl ScalarOrArray<'a>, + other: impl AsRef, stream: impl AsRef, ) -> Result { unsafe { let c_array = try_catch_c_ptr_expr! { - mlx_sys::mlx_add(self.c_array, other.into_owned_or_ref_array().as_ref().as_ptr(), stream.as_ref().as_ptr()) + mlx_sys::mlx_add(self.c_array, other.as_ref().as_ptr(), stream.as_ref().as_ptr()) }; Ok(Array::from_ptr(c_array)) } @@ -80,14 +79,14 @@ impl Array { /// // c_data == [-3.0, -3.0, -3.0] /// ``` #[default_device] - pub fn subtract_device<'a>( + pub fn subtract_device( &self, - other: impl ScalarOrArray<'a>, + other: impl AsRef, stream: impl AsRef, ) -> Result { unsafe { let c_array = try_catch_c_ptr_expr! { - mlx_sys::mlx_subtract(self.c_array, other.into_owned_or_ref_array().as_ref().as_ptr(), stream.as_ref().as_ptr()) + mlx_sys::mlx_subtract(self.c_array, other.as_ref().as_ptr(), stream.as_ref().as_ptr()) }; Ok(Array::from_ptr(c_array)) } @@ -133,14 +132,14 @@ impl Array { /// // c_data == [4.0, 10.0, 18.0] /// ``` #[default_device] - pub fn multiply_device<'a>( + pub fn multiply_device( &self, - other: impl ScalarOrArray<'a>, + other: impl AsRef, stream: impl AsRef, ) -> Result { unsafe { let c_array = try_catch_c_ptr_expr! { - mlx_sys::mlx_multiply(self.c_array, other.into_owned_or_ref_array().as_ref().as_ptr(), stream.as_ref().as_ptr()) + mlx_sys::mlx_multiply(self.c_array, other.as_ref().as_ptr(), stream.as_ref().as_ptr()) }; Ok(Array::from_ptr(c_array)) } @@ -166,14 +165,14 @@ impl Array { /// // c_data == [0.25, 0.4, 0.5] /// ``` #[default_device] - pub fn divide_device<'a>( + pub fn divide_device( &self, - other: impl ScalarOrArray<'a>, + other: impl AsRef, stream: impl AsRef, ) -> Result { unsafe { let c_array = try_catch_c_ptr_expr! { - mlx_sys::mlx_divide(self.c_array, other.into_owned_or_ref_array().as_ref().as_ptr(), stream.as_ref().as_ptr()) + mlx_sys::mlx_divide(self.c_array, other.as_ref().as_ptr(), stream.as_ref().as_ptr()) }; Ok(Array::from_ptr(c_array)) } @@ -199,14 +198,14 @@ impl Array { /// // c_data == [1.0, 8.0, 81.0] /// ``` #[default_device] - pub fn power_device<'a>( + pub fn power_device( &self, - other: impl ScalarOrArray<'a>, + other: impl AsRef, stream: impl AsRef, ) -> Result { unsafe { let c_array = try_catch_c_ptr_expr! { - mlx_sys::mlx_power(self.c_array, other.into_owned_or_ref_array().as_ref().as_ptr(), stream.as_ref().as_ptr()) + mlx_sys::mlx_power(self.c_array, other.as_ref().as_ptr(), stream.as_ref().as_ptr()) }; Ok(Array::from_ptr(c_array)) } @@ -232,14 +231,14 @@ impl Array { /// // c_data == [1.0, 3.0, 2.0] /// ``` #[default_device] - pub fn remainder_device<'a>( + pub fn remainder_device( &self, - other: impl ScalarOrArray<'a>, + other: impl AsRef, stream: impl AsRef, ) -> Result { unsafe { let c_array = try_catch_c_ptr_expr! { - mlx_sys::mlx_remainder(self.c_array, other.into_owned_or_ref_array().as_ref().as_ptr(), stream.as_ref().as_ptr()) + mlx_sys::mlx_remainder(self.c_array, other.as_ref().as_ptr(), stream.as_ref().as_ptr()) }; Ok(Array::from_ptr(c_array)) } @@ -344,14 +343,14 @@ impl Array { /// // c_data == [0.25, 0.4, 0.5] /// ``` #[default_device] - pub fn floor_divide_device<'a>( + pub fn floor_divide_device( &self, - other: impl ScalarOrArray<'a>, + other: impl AsRef, stream: impl AsRef, ) -> Result { unsafe { let c_array = try_catch_c_ptr_expr! { - mlx_sys::mlx_floor_divide(self.c_array, other.into_owned_or_ref_array().as_ref().as_ptr(), stream.as_ref().as_ptr()) + mlx_sys::mlx_floor_divide(self.c_array, other.as_ref().as_ptr(), stream.as_ref().as_ptr()) }; Ok(Array::from_ptr(c_array)) } @@ -558,14 +557,12 @@ pub fn acosh_device(a: &Array, stream: impl AsRef) -> Array { /// See [`Array::add`]. #[default_device] -pub fn add_device<'a, 'b>( - lhs: impl ScalarOrArray<'a>, - rhs: impl ScalarOrArray<'b>, +pub fn add_device( + lhs: impl AsRef, + rhs: impl AsRef, stream: impl AsRef, ) -> Result { - lhs.into_owned_or_ref_array() - .as_ref() - .add_device(rhs, stream) + lhs.as_ref().add_device(rhs, stream) } /// Element-wise inverse sine. @@ -731,14 +728,12 @@ pub fn degrees_device(a: &Array, stream: impl AsRef) -> Array { /// See [`Array::divide`]. #[default_device] -pub fn divide_device<'a, 'b>( - a: impl ScalarOrArray<'a>, - b: impl ScalarOrArray<'b>, +pub fn divide_device( + a: impl AsRef, + b: impl AsRef, stream: impl AsRef, ) -> Result { - a.into_owned_or_ref_array() - .as_ref() - .divide_device(b, stream) + a.as_ref().divide_device(b, stream) } /// Element-wise quotient and remainder. @@ -748,13 +743,13 @@ pub fn divide_device<'a, 'b>( /// /// Returns Ok((quotient, remainder)) if the operation was successful. #[default_device] -pub fn divmod_device<'a, 'b>( - a: impl ScalarOrArray<'a>, - b: impl ScalarOrArray<'b>, +pub fn divmod_device( + a: impl AsRef, + b: impl AsRef, stream: impl AsRef, ) -> Result<(Array, Array), Exception> { - let a_ptr = a.into_owned_or_ref_array().as_ref().as_ptr(); - let b_ptr = b.into_owned_or_ref_array().as_ref().as_ptr(); + let a_ptr = a.as_ref().as_ptr(); + let b_ptr = b.as_ref().as_ptr(); unsafe { let c_vec = try_catch_c_ptr_expr! { @@ -802,9 +797,9 @@ pub fn floor_device(a: &Array, stream: impl AsRef) -> Result( +pub fn floor_divide_device( a: &Array, - other: impl ScalarOrArray<'a>, + other: impl AsRef, stream: impl AsRef, ) -> Result { a.floor_divide_device(other, stream) @@ -841,13 +836,13 @@ pub fn log2_device(a: &Array, stream: impl AsRef) -> Array { /// /// The computation is is a numerically stable version of `log(exp(a) + exp(b))`. #[default_device] -pub fn log_add_exp_device<'a, 'b>( - a: impl ScalarOrArray<'a>, - b: impl ScalarOrArray<'b>, +pub fn log_add_exp_device( + a: impl AsRef, + b: impl AsRef, stream: impl AsRef, ) -> Result { - let a_ptr = a.into_owned_or_ref_array().as_ref().as_ptr(); - let b_ptr = b.into_owned_or_ref_array().as_ref().as_ptr(); + let a_ptr = a.as_ref().as_ptr(); + let b_ptr = b.as_ref().as_ptr(); unsafe { let c_array = try_catch_c_ptr_expr! { @@ -868,13 +863,13 @@ pub fn matmul_device(a: &Array, b: &Array, stream: impl AsRef) -> Result /// Take the element-wise max of two arrays with numpy-style broadcasting semantics. Either or both /// input arrays can also be scalars. #[default_device] -pub fn maximum_device<'a, 'b>( - a: impl ScalarOrArray<'a>, - b: impl ScalarOrArray<'b>, +pub fn maximum_device( + a: impl AsRef, + b: impl AsRef, stream: impl AsRef, ) -> Result { - let a_ptr = a.into_owned_or_ref_array().as_ref().as_ptr(); - let b_ptr = b.into_owned_or_ref_array().as_ref().as_ptr(); + let a_ptr = a.as_ref().as_ptr(); + let b_ptr = b.as_ref().as_ptr(); unsafe { let c_array = try_catch_c_ptr_expr! { @@ -889,13 +884,13 @@ pub fn maximum_device<'a, 'b>( /// Take the element-wise min of two arrays with numpy-style broadcasting semantics. Either or both /// input arrays can also be scalars. #[default_device] -pub fn minimum_device<'a, 'b>( - a: impl ScalarOrArray<'a>, - b: impl ScalarOrArray<'b>, +pub fn minimum_device( + a: impl AsRef, + b: impl AsRef, stream: impl AsRef, ) -> Result { - let a_ptr = a.into_owned_or_ref_array().as_ref().as_ptr(); - let b_ptr = b.into_owned_or_ref_array().as_ref().as_ptr(); + let a_ptr = a.as_ref().as_ptr(); + let b_ptr = b.as_ref().as_ptr(); unsafe { let c_array = try_catch_c_ptr_expr! { @@ -907,14 +902,12 @@ pub fn minimum_device<'a, 'b>( /// See [`Array::multiply`]. #[default_device] -pub fn multiply_device<'a, 'b>( - a: impl ScalarOrArray<'a>, - b: impl ScalarOrArray<'b>, +pub fn multiply_device( + a: impl AsRef, + b: impl AsRef, stream: impl AsRef, ) -> Result { - a.into_owned_or_ref_array() - .as_ref() - .multiply_device(b, stream) + a.as_ref().multiply_device(b, stream) } /// See [`Array::negative`]. @@ -925,9 +918,9 @@ pub fn negative_device(a: &Array, stream: impl AsRef) -> Result( +pub fn power_device( a: &Array, - b: impl ScalarOrArray<'a>, + b: impl AsRef, stream: impl AsRef, ) -> Result { a.power_device(b, stream) @@ -947,14 +940,12 @@ pub fn reciprocal_device(a: &Array, stream: impl AsRef) -> Array { /// See [`Array::remainder`]. #[default_device] -pub fn remainder_device<'a, 'b>( - a: impl ScalarOrArray<'a>, - b: impl ScalarOrArray<'b>, +pub fn remainder_device( + a: impl AsRef, + b: impl AsRef, stream: impl AsRef, ) -> Result { - a.into_owned_or_ref_array() - .as_ref() - .remainder_device(b, stream) + a.as_ref().remainder_device(b, stream) } /// See [`Array::round`]. @@ -1040,14 +1031,12 @@ pub fn square_device(a: &Array, stream: impl AsRef) -> Array { /// See [`Array::subtract`]. #[default_device] -pub fn subtract_device<'a, 'b>( - a: impl ScalarOrArray<'a>, - b: impl ScalarOrArray<'b>, +pub fn subtract_device( + a: impl AsRef, + b: impl AsRef, stream: impl AsRef, ) -> Result { - a.into_owned_or_ref_array() - .as_ref() - .subtract_device(b, stream) + a.as_ref().subtract_device(b, stream) } /// See [`Array::tan`]. @@ -1121,17 +1110,17 @@ pub fn block_masked_mm_device<'mo, 'lhs, 'rhs>( /// - `alpha`: Scaling factor for the matrix product of `a` and `b` (default: `1`) /// - `beta`: Scaling factor for `c` (default: `1`) #[default_device] -pub fn addmm_device<'c, 'a, 'b>( - c: impl ScalarOrArray<'c>, - a: impl ScalarOrArray<'a>, - b: impl ScalarOrArray<'b>, +pub fn addmm_device( + c: impl AsRef, + a: impl AsRef, + b: impl AsRef, alpha: impl Into>, beta: impl Into>, stream: impl AsRef, ) -> Result { - let c_ptr = c.into_owned_or_ref_array().as_ref().as_ptr(); - let a_ptr = a.into_owned_or_ref_array().as_ref().as_ptr(); - let b_ptr = b.into_owned_or_ref_array().as_ref().as_ptr(); + let c_ptr = c.as_ref().as_ptr(); + let a_ptr = a.as_ref().as_ptr(); + let b_ptr = b.as_ref().as_ptr(); let alpha = alpha.into().unwrap_or(1.0); let beta = beta.into().unwrap_or(1.0); @@ -1833,7 +1822,7 @@ mod tests { // Input is irregularly strided let x = broadcast_to(&Array::from_float(1.0), &[2, 2, 2]).unwrap(); let res = exp(&x); - let expected = Array::full::(&[2, 2, 2], 1.0f32.exp()).unwrap(); + let expected = Array::full::(&[2, 2, 2], array!(1.0f32.exp())).unwrap(); assert!(all_close(&res, &expected, None, None, None) .unwrap() .item::()); @@ -1898,7 +1887,7 @@ mod tests { // Input is irregularly strided let x = broadcast_to(&Array::from_float(1.0), &[2, 2, 2]).unwrap(); let res = sin(&x); - let expected = Array::full::(&[2, 2, 2], 1.0f32.sin()).unwrap(); + let expected = Array::full::(&[2, 2, 2], array!(1.0f32.sin())).unwrap(); assert!(all_close(&res, &expected, None, None, None) .unwrap() .item::()); @@ -1941,7 +1930,7 @@ mod tests { // Input is irregularly strided let x = broadcast_to(&Array::from_float(1.0), &[2, 2, 2]).unwrap(); let res = cos(&x); - let expected = Array::full::(&[2, 2, 2], 1.0f32.cos()).unwrap(); + let expected = Array::full::(&[2, 2, 2], array!(1.0f32.cos())).unwrap(); assert!(all_close(&res, &expected, None, None, None) .unwrap() .item::()); @@ -1972,7 +1961,7 @@ mod tests { // Input is irregularly strided let x = broadcast_to(&Array::from_float(std::f32::consts::PI / 2.0), &[2, 2, 2]).unwrap(); let res = degrees(&x); - let expected = Array::full::(&[2, 2, 2], 90.0).unwrap(); + let expected = Array::full::(&[2, 2, 2], array!(90.0)).unwrap(); assert!(all_close(&res, &expected, None, None, None) .unwrap() .item::()); @@ -2003,7 +1992,7 @@ mod tests { // Input is irregularly strided let x = broadcast_to(&Array::from_float(90.0), &[2, 2, 2]).unwrap(); let res = radians(&x); - let expected = Array::full::(&[2, 2, 2], std::f32::consts::PI / 2.0).unwrap(); + let expected = Array::full::(&[2, 2, 2], array!(std::f32::consts::PI / 2.0)).unwrap(); assert!(all_close(&res, &expected, None, None, None) .unwrap() .item::()); @@ -2032,7 +2021,7 @@ mod tests { // Input is irregularly strided let x = broadcast_to(&Array::from_float(1.0), &[2, 2, 2]).unwrap(); let res = log(&x); - let expected = Array::full::(&[2, 2, 2], 0.0).unwrap(); + let expected = Array::full::(&[2, 2, 2], array!(0.0)).unwrap(); assert!(all_close(&res, &expected, None, None, None) .unwrap() .item::()); @@ -2097,7 +2086,7 @@ mod tests { // Input is irregularly strided let x = broadcast_to(&Array::from_float(1.0), &[2, 2, 2]).unwrap(); let res = log1p(&x); - let expected = Array::full::(&[2, 2, 2], 1.0f32.ln_1p()).unwrap(); + let expected = Array::full::(&[2, 2, 2], array!(1.0f32.ln_1p())).unwrap(); assert!(all_close(&res, &expected, None, None, None) .unwrap() .item::()); @@ -2144,10 +2133,10 @@ mod tests { let x = array![2]; assert_eq!(square(&x).item::(), 4); - let x = Array::full::(&[3, 3], 2.0).unwrap(); + let x = Array::full::(&[3, 3], array!(2.0)).unwrap(); assert!(all_close( square(&x), - Array::full::(&[3, 3], 4.0).unwrap(), + Array::full::(&[3, 3], array!(4.0)).unwrap(), None, None, None @@ -2162,10 +2151,10 @@ mod tests { assert_eq!(sqrt(&x).item::(), 2.0); assert_eq!(rsqrt(&x).item::(), 0.5); - let x = Array::full::(&[3, 3], 9.0).unwrap(); + let x = Array::full::(&[3, 3], array!(9.0)).unwrap(); assert!(all_close( sqrt(&x), - Array::full::(&[3, 3], 3.0).unwrap(), + Array::full::(&[3, 3], array!(3.0)).unwrap(), None, None, None @@ -2188,10 +2177,10 @@ mod tests { assert_eq!(out.dtype(), Dtype::Float32); assert_eq!(out.item::(), 0.5); - let x = Array::full::(&[3, 3], 2.0).unwrap(); + let x = Array::full::(&[3, 3], array!(2.0)).unwrap(); assert!(all_close( reciprocal(&x), - Array::full::(&[3, 3], 0.5).unwrap(), + Array::full::(&[3, 3], array!(0.5)).unwrap(), None, None, None @@ -2249,7 +2238,7 @@ mod tests { let x = broadcast_to(&array!(1.0), &[10]).unwrap(); let y = broadcast_to(&array!(2.0), &[10]).unwrap(); let z = add(&x, &y).unwrap(); - assert_eq!(z, full::(&[10], 3.0).unwrap()); + assert_eq!(z, full::(&[10], array!(3.0)).unwrap()); let x = Array::from_slice(&[1.0, 2.0], &[1, 2]); let y = Array::from_slice(&[1.0, 2.0], &[2, 1]); @@ -2371,6 +2360,10 @@ mod tests { fn test_basic_clip() { let a = array!([1.0, 4.0, 3.0, 8.0, 5.0]); let expected = array!([2.0, 4.0, 3.0, 6.0, 5.0]); + let clipped = clip(&a, (array!(2.0), array!(6.0))).unwrap(); + assert_eq!(clipped, expected); + + // Test with scalar let clipped = clip(&a, (2.0, 6.0)).unwrap(); assert_eq!(clipped, expected); } @@ -2379,6 +2372,10 @@ mod tests { fn test_clip_with_only_min() { let a = array!([-1.0, 1.0, 0.0, 5.0]); let expected = array!([0.0, 1.0, 0.0, 5.0]); + let clipped = clip(&a, (array!(0.0), ())).unwrap(); + assert_eq!(clipped, expected); + + // Test with scalar let clipped = clip(&a, (0.0, ())).unwrap(); assert_eq!(clipped, expected); } @@ -2387,6 +2384,10 @@ mod tests { fn test_clip_with_only_max() { let a = array!([2.0, 3.0, 4.0, 5.0]); let expected = array!([2.0, 3.0, 4.0, 4.0]); + let clipped = clip(&a, ((), array!(4.0))).unwrap(); + assert_eq!(clipped, expected); + + // Test with scalar let clipped = clip(&a, ((), 4.0)).unwrap(); assert_eq!(clipped, expected); } diff --git a/mlx-rs/src/ops/conversion.rs b/mlx-rs/src/ops/conversion.rs index 327dd050..814b0d4d 100644 --- a/mlx-rs/src/ops/conversion.rs +++ b/mlx-rs/src/ops/conversion.rs @@ -268,7 +268,7 @@ mod tests { #[test] fn test_view() { let array = Array::from_slice(&[1i16, 2, 3], &[3]); - let mut new_array = array.view::(); + let new_array = array.view::(); assert_eq!(new_array.dtype(), Dtype::Int8); assert_eq!(new_array.shape(), &[6]); diff --git a/mlx-rs/src/ops/factory.rs b/mlx-rs/src/ops/factory.rs index c96e0e19..9b7aaea5 100644 --- a/mlx-rs/src/ops/factory.rs +++ b/mlx-rs/src/ops/factory.rs @@ -1,6 +1,5 @@ use crate::array::ArrayElement; use crate::error::Exception; -use crate::prelude::ScalarOrArray; use crate::Stream; use crate::{array::Array, stream::StreamOrDevice}; use mlx_macros::default_device; @@ -117,14 +116,14 @@ impl Array { /// # Example /// /// ```rust - /// use mlx_rs::{Array, StreamOrDevice}; + /// use mlx_rs::{Array, StreamOrDevice, array}; /// // create [5, 4] array filled with 7 - /// let r = Array::full_device::(&[5, 4], 7.0f32, StreamOrDevice::default()).unwrap(); + /// let r = Array::full_device::(&[5, 4], array!(7.0f32), StreamOrDevice::default()).unwrap(); /// ``` #[default_device] - pub fn full_device<'a, T: ArrayElement>( + pub fn full_device( shape: &[i32], - values: impl ScalarOrArray<'a>, + values: impl AsRef, stream: impl AsRef, ) -> Result { unsafe { @@ -132,7 +131,7 @@ impl Array { mlx_sys::mlx_full( shape.as_ptr(), shape.len(), - values.into_owned_or_ref_array().as_ref().as_ptr(), + values.as_ref().as_ptr(), T::DTYPE.into(), stream.as_ref().as_ptr(), ) @@ -392,9 +391,9 @@ pub fn eye_device( /// See [`Array::full`] #[default_device] -pub fn full_device<'a, T: ArrayElement>( +pub fn full_device( shape: &[i32], - values: impl ScalarOrArray<'a>, + values: impl AsRef, stream: impl AsRef, ) -> Result { Array::full_device::(shape, values, stream) @@ -474,7 +473,7 @@ pub fn tri_device( #[cfg(test)] mod tests { use super::*; - use crate::dtype::Dtype; + use crate::{array, dtype::Dtype}; use half::f16; #[test] @@ -518,7 +517,7 @@ mod tests { #[test] fn test_full_scalar() { - let array = Array::full::(&[2, 3], 7f32).unwrap(); + let array = Array::full::(&[2, 3], array!(7f32)).unwrap(); assert_eq!(array.shape(), &[2, 3]); assert_eq!(array.dtype(), Dtype::Float32); diff --git a/mlx-rs/src/ops/logical.rs b/mlx-rs/src/ops/logical.rs index f35a7ee3..fb920400 100644 --- a/mlx-rs/src/ops/logical.rs +++ b/mlx-rs/src/ops/logical.rs @@ -1,6 +1,5 @@ use crate::array::Array; use crate::error::Exception; -use crate::prelude::ScalarOrArray; use crate::stream::StreamOrDevice; use crate::utils::{axes_or_default_to_all, IntoOption}; use crate::Stream; @@ -28,16 +27,16 @@ impl Array { /// // c_data == [true, true, true] /// ``` #[default_device] - pub fn eq_device<'a>( + pub fn eq_device( &self, - other: impl ScalarOrArray<'a>, + other: impl AsRef, stream: impl AsRef, ) -> Result { unsafe { let c_array = try_catch_c_ptr_expr! { mlx_sys::mlx_equal( self.as_ptr(), - other.into_owned_or_ref_array().as_ref().as_ptr(), + other.as_ref().as_ptr(), stream.as_ref().as_ptr(), ) }; @@ -66,16 +65,16 @@ impl Array { /// // c_data == [true, true, true] /// ``` #[default_device] - pub fn le_device<'a>( + pub fn le_device( &self, - other: impl ScalarOrArray<'a>, + other: impl AsRef, stream: impl AsRef, ) -> Result { unsafe { let c_array = try_catch_c_ptr_expr! { mlx_sys::mlx_less_equal( self.as_ptr(), - other.into_owned_or_ref_array().as_ref().as_ptr(), + other.as_ref().as_ptr(), stream.as_ref().as_ptr(), ) }; @@ -104,16 +103,16 @@ impl Array { /// // c_data == [true, true, true] /// ``` #[default_device] - pub fn ge_device<'a>( + pub fn ge_device( &self, - other: impl ScalarOrArray<'a>, + other: impl AsRef, stream: impl AsRef, ) -> Result { unsafe { let c_array = try_catch_c_ptr_expr! { mlx_sys::mlx_greater_equal( self.c_array, - other.into_owned_or_ref_array().as_ref().as_ptr(), + other.as_ref().as_ptr(), stream.as_ref().as_ptr(), ) }; @@ -142,16 +141,16 @@ impl Array { /// // c_data == [false, false, false] /// ``` #[default_device] - pub fn ne_device<'a>( + pub fn ne_device( &self, - other: impl ScalarOrArray<'a>, + other: impl AsRef, stream: impl AsRef, ) -> Result { unsafe { let c_array = try_catch_c_ptr_expr! { mlx_sys::mlx_not_equal( self.c_array, - other.into_owned_or_ref_array().as_ref().as_ptr(), + other.as_ref().as_ptr(), stream.as_ref().as_ptr(), ) }; @@ -179,16 +178,16 @@ impl Array { /// // c_data == [false, false, false] /// ``` #[default_device] - pub fn lt_device<'a>( + pub fn lt_device( &self, - other: impl ScalarOrArray<'a>, + other: impl AsRef, stream: impl AsRef, ) -> Result { unsafe { let c_array = try_catch_c_ptr_expr! { mlx_sys::mlx_less( self.c_array, - other.into_owned_or_ref_array().as_ref().as_ptr(), + other.as_ref().as_ptr(), stream.as_ref().as_ptr(), ) }; @@ -216,16 +215,16 @@ impl Array { /// // c_data == [false, false, false] /// ``` #[default_device] - pub fn gt_device<'a>( + pub fn gt_device( &self, - other: impl ScalarOrArray<'a>, + other: impl AsRef, stream: impl AsRef, ) -> Result { unsafe { let c_array = try_catch_c_ptr_expr! { mlx_sys::mlx_greater( self.c_array, - other.into_owned_or_ref_array().as_ref().as_ptr(), + other.as_ref().as_ptr(), stream.as_ref().as_ptr(), ) }; @@ -253,16 +252,16 @@ impl Array { /// // c_data == [true, false, false] /// ``` #[default_device] - pub fn logical_and_device<'a>( + pub fn logical_and_device( &self, - other: impl ScalarOrArray<'a>, + other: impl AsRef, stream: impl AsRef, ) -> Result { unsafe { let c_array = try_catch_c_ptr_expr! { mlx_sys::mlx_logical_and( self.c_array, - other.into_owned_or_ref_array().as_ref().as_ptr(), + other.as_ref().as_ptr(), stream.as_ref().as_ptr(), ) }; @@ -290,16 +289,16 @@ impl Array { /// // c_data == [true, true, true] /// ``` #[default_device] - pub fn logical_or_device<'a>( + pub fn logical_or_device( &self, - other: impl ScalarOrArray<'a>, + other: impl AsRef, stream: impl AsRef, ) -> Result { unsafe { let c_array = try_catch_c_ptr_expr! { mlx_sys::mlx_logical_or( self.c_array, - other.into_owned_or_ref_array().as_ref().as_ptr(), + other.as_ref().as_ptr(), stream.as_ref().as_ptr(), ) }; @@ -348,18 +347,18 @@ impl Array { /// /// ```rust /// use num_traits::Pow; - /// use mlx_rs::Array; - /// let a = Array::from_slice(&[0., 1., 2., 3.], &[4]).sqrt(); - /// let b = Array::from_slice(&[0., 1., 2., 3.], &[4]).power(0.5).unwrap(); + /// use mlx_rs::array; + /// let a = array!([0., 1., 2., 3.]).sqrt(); + /// let b = array!([0., 1., 2., 3.]).power(array!(0.5)).unwrap(); /// let mut c = a.all_close(&b, None, None, None).unwrap(); /// /// let c_data: &[bool] = c.as_slice(); /// // c_data == [true] /// ``` #[default_device] - pub fn all_close_device<'a>( + pub fn all_close_device( &self, - other: impl ScalarOrArray<'a>, + other: impl AsRef, rtol: impl Into>, atol: impl Into>, equal_nan: impl Into>, @@ -369,7 +368,7 @@ impl Array { let c_array = try_catch_c_ptr_expr! { mlx_sys::mlx_allclose( self.c_array, - other.into_owned_or_ref_array().as_ref().as_ptr(), + other.as_ref().as_ptr(), rtol.into().unwrap_or(1e-5), atol.into().unwrap_or(1e-8), equal_nan.into().unwrap_or(false), @@ -438,16 +437,16 @@ impl Array { /// // c == [true] /// ``` #[default_device] - pub fn array_eq_device<'a>( + pub fn array_eq_device( &self, - other: impl ScalarOrArray<'a>, + other: impl AsRef, equal_nan: impl Into>, stream: impl AsRef, ) -> Array { unsafe { Array::from_ptr(mlx_sys::mlx_array_equal( self.c_array, - other.into_owned_or_ref_array().as_ref().as_ptr(), + other.as_ref().as_ptr(), equal_nan.into().unwrap_or(false), stream.as_ref().as_ptr(), )) @@ -511,24 +510,14 @@ pub fn any_device<'a>( /// See [`Array::logical_and`] #[default_device] -pub fn logical_and_device<'a>( - a: impl ScalarOrArray<'a>, - b: impl ScalarOrArray<'a>, -) -> Result { - a.into_owned_or_ref_array() - .as_ref() - .logical_and_device(b, StreamOrDevice::default()) +pub fn logical_and_device(a: impl AsRef, b: impl AsRef) -> Result { + a.as_ref().logical_and_device(b, StreamOrDevice::default()) } /// See [`Array::logical_or`] #[default_device] -pub fn logical_or_device<'a>( - a: impl ScalarOrArray<'a>, - b: impl ScalarOrArray<'a>, -) -> Result { - a.into_owned_or_ref_array() - .as_ref() - .logical_or_device(b, StreamOrDevice::default()) +pub fn logical_or_device(a: impl AsRef, b: impl AsRef) -> Result { + a.as_ref().logical_or_device(b, StreamOrDevice::default()) } /// See [`Array::logical_not`] @@ -539,16 +528,15 @@ pub fn logical_not_device(a: &Array, stream: impl AsRef) -> Array { /// See [`Array::all_close`] #[default_device] -pub fn all_close_device<'a>( - a: impl ScalarOrArray<'a>, - b: impl ScalarOrArray<'a>, +pub fn all_close_device( + a: impl AsRef, + b: impl AsRef, rtol: impl Into>, atol: impl Into>, equal_nan: impl Into>, stream: impl AsRef, ) -> Result { - a.into_owned_or_ref_array() - .as_ref() + a.as_ref() .all_close_device(b, rtol, atol, equal_nan, stream) } @@ -567,75 +555,73 @@ pub fn is_close_device( /// See [`Array::array_eq`] #[default_device] -pub fn array_eq_device<'a>( - a: impl ScalarOrArray<'a>, - b: impl ScalarOrArray<'a>, +pub fn array_eq_device( + a: impl AsRef, + b: impl AsRef, equal_nan: impl Into>, stream: impl AsRef, ) -> Array { - a.into_owned_or_ref_array() - .as_ref() - .array_eq_device(b, equal_nan, stream) + a.as_ref().array_eq_device(b, equal_nan, stream) } /// See [`Array::eq`] #[default_device] -pub fn eq_device<'a>( - a: impl ScalarOrArray<'a>, - b: impl ScalarOrArray<'a>, +pub fn eq_device( + a: impl AsRef, + b: impl AsRef, stream: impl AsRef, ) -> Result { - a.into_owned_or_ref_array().as_ref().eq_device(b, stream) + a.as_ref().eq_device(b, stream) } /// See [`Array::le`] #[default_device] -pub fn le_device<'a>( - a: impl ScalarOrArray<'a>, - b: impl ScalarOrArray<'a>, +pub fn le_device( + a: impl AsRef, + b: impl AsRef, stream: impl AsRef, ) -> Result { - a.into_owned_or_ref_array().as_ref().le_device(b, stream) + a.as_ref().le_device(b, stream) } /// See [`Array::ge`] #[default_device] -pub fn ge_device<'a>( - a: impl ScalarOrArray<'a>, - b: impl ScalarOrArray<'a>, +pub fn ge_device( + a: impl AsRef, + b: impl AsRef, stream: impl AsRef, ) -> Result { - a.into_owned_or_ref_array().as_ref().ge_device(b, stream) + a.as_ref().ge_device(b, stream) } /// See [`Array::ne`] #[default_device] -pub fn ne_device<'a>( - a: impl ScalarOrArray<'a>, - b: impl ScalarOrArray<'a>, +pub fn ne_device( + a: impl AsRef, + b: impl AsRef, stream: impl AsRef, ) -> Result { - a.into_owned_or_ref_array().as_ref().ne_device(b, stream) + a.as_ref().ne_device(b, stream) } /// See [`Array::lt`] #[default_device] -pub fn lt_device<'a>( - a: impl ScalarOrArray<'a>, - b: impl ScalarOrArray<'a>, +pub fn lt_device( + a: impl AsRef, + b: impl AsRef, stream: impl AsRef, ) -> Result { - a.into_owned_or_ref_array().as_ref().lt_device(b, stream) + a.as_ref().lt_device(b, stream) } /// See [`Array::gt`] #[default_device] -pub fn gt_device<'a>( - a: impl ScalarOrArray<'a>, - b: impl ScalarOrArray<'a>, +pub fn gt_device( + a: impl AsRef, + b: impl AsRef, stream: impl AsRef, ) -> Result { - a.into_owned_or_ref_array().as_ref().gt_device(b, stream) + a.as_ref().gt_device(b, stream) } // TODO: check if the functions below could throw an exception. @@ -687,18 +673,18 @@ pub fn is_neg_inf_device(array: &Array, stream: impl AsRef) -> Array { /// - a: input selected from where condition is non-zero or `true` /// - b: input selected from where condition is zero or `false` #[default_device] -pub fn r#where_device<'a, 'b>( +pub fn r#where_device( condition: &Array, - a: impl ScalarOrArray<'a>, - b: impl ScalarOrArray<'b>, + a: impl AsRef, + b: impl AsRef, stream: impl AsRef, ) -> Result { unsafe { let c_array = try_catch_c_ptr_expr! { mlx_sys::mlx_where( condition.c_array, - a.into_owned_or_ref_array().as_ref().as_ptr(), - b.into_owned_or_ref_array().as_ref().as_ptr(), + a.as_ref().as_ptr(), + b.as_ref().as_ptr(), stream.as_ref().as_ptr(), ) }; @@ -708,18 +694,18 @@ pub fn r#where_device<'a, 'b>( /// Alias for [`r#where`] #[default_device] -pub fn which_device<'a, 'b>( +pub fn which_device( condition: &Array, - a: impl ScalarOrArray<'a>, - b: impl ScalarOrArray<'b>, + a: impl AsRef, + b: impl AsRef, stream: impl AsRef, ) -> Result { unsafe { let c_array = try_catch_c_ptr_expr! { mlx_sys::mlx_where( condition.c_array, - a.into_owned_or_ref_array().as_ref().as_ptr(), - b.into_owned_or_ref_array().as_ref().as_ptr(), + a.as_ref().as_ptr(), + b.as_ref().as_ptr(), stream.as_ref().as_ptr(), ) }; @@ -937,7 +923,7 @@ mod tests { fn test_all_close() { let a = Array::from_slice(&[0., 1., 2., 3.], &[4]).sqrt(); let b = Array::from_slice(&[0., 1., 2., 3.], &[4]) - .power(0.5) + .power(array!(0.5)) .unwrap(); let c = a.all_close(&b, 1e-5, None, None).unwrap(); diff --git a/mlx-rs/src/transforms/compile.rs b/mlx-rs/src/transforms/compile.rs index 33b4cdef..b74ba935 100644 --- a/mlx-rs/src/transforms/compile.rs +++ b/mlx-rs/src/transforms/compile.rs @@ -35,16 +35,16 @@ pub fn disable_compile() { } } -pub trait Compile<'a, Args, Output>: Sized { +pub trait Compile<'a, Args, Output, Err>: Sized { fn compile( self, inputs: Option<&'a mut [Array]>, outputs: Option<&'a mut [Array]>, shapeless: bool, - ) -> impl CallMut<'a, Args, Output>; + ) -> impl CallMut<'a, Args, Output, Err>; } -impl<'a, F> Compile<'a, &'a [Array], Vec> for F +impl<'a, F> Compile<'a, &'a [Array], Vec, ()> for F where F: FnMut(&[Array]) -> Vec + 'static, { @@ -53,7 +53,7 @@ where inputs: Option<&'a mut [Array]>, outputs: Option<&'a mut [Array]>, shapeless: bool, - ) -> impl CallMut<'a, &'a [Array], Vec> { + ) -> impl CallMut<'a, &'a [Array], Vec, ()> { let id = type_id_to_usize(&self); let state = CompiledState { f: self, @@ -69,7 +69,32 @@ where } } -impl<'a, F> Compile<'a, &'a Array, Array> for F +impl<'a, F> Compile<'a, &'a [Array], Vec, Exception> for F +where + F: FnMut(&[Array]) -> Result, Exception> + 'static, +{ + fn compile( + self, + inputs: Option<&'a mut [Array]>, + outputs: Option<&'a mut [Array]>, + shapeless: bool, + ) -> impl CallMut<'a, &'a [Array], Vec, Exception> { + let id = type_id_to_usize(&self); + let state = CompiledState { + f: self, + inputs, + outputs, + shapeless, + id, + }; + Compiled { + f_marker: PhantomData::, + state, + } + } +} + +impl<'a, F> Compile<'a, &'a Array, Array, ()> for F where F: FnMut(&Array) -> Array + 'static, { @@ -78,7 +103,7 @@ where inputs: Option<&'a mut [Array]>, outputs: Option<&'a mut [Array]>, shapeless: bool, - ) -> impl CallMut<'a, &'a Array, Array> { + ) -> impl CallMut<'a, &'a Array, Array, ()> { let f = move |args: &[Array]| -> Vec { let result = (self)(&args[0]); vec![result] @@ -98,7 +123,36 @@ where } } -impl<'a, F> Compile<'a, (&'a Array, &'a Array), Array> for F +impl<'a, F> Compile<'a, &'a Array, Array, Exception> for F +where + F: FnMut(&Array) -> Result + 'static, +{ + fn compile( + mut self, + inputs: Option<&'a mut [Array]>, + outputs: Option<&'a mut [Array]>, + shapeless: bool, + ) -> impl CallMut<'a, &'a Array, Array, Exception> { + let f = move |args: &[Array]| -> Result, Exception> { + let result = (self)(&args[0])?; + Ok(vec![result]) + }; + let id = type_id_to_usize(&f); + let state = CompiledState { + f, + inputs, + outputs, + shapeless, + id, + }; + Compiled { + f_marker: PhantomData::, + state, + } + } +} + +impl<'a, F> Compile<'a, (&'a Array, &'a Array), Array, ()> for F where F: FnMut((&Array, &Array)) -> Array + 'static, { @@ -107,7 +161,7 @@ where inputs: Option<&'a mut [Array]>, outputs: Option<&'a mut [Array]>, shapeless: bool, - ) -> impl CallMut<'a, (&'a Array, &'a Array), Array> { + ) -> impl CallMut<'a, (&'a Array, &'a Array), Array, ()> { let f = move |args: &[Array]| -> Vec { let result = (self)((&args[0], &args[1])); vec![result] @@ -127,7 +181,36 @@ where } } -impl<'a, F> Compile<'a, (&'a Array, &'a Array, &'a Array), Array> for F +impl<'a, F> Compile<'a, (&'a Array, &'a Array), Array, Exception> for F +where + F: FnMut((&Array, &Array)) -> Result + 'static, +{ + fn compile( + mut self, + inputs: Option<&'a mut [Array]>, + outputs: Option<&'a mut [Array]>, + shapeless: bool, + ) -> impl CallMut<'a, (&'a Array, &'a Array), Array, Exception> { + let f = move |args: &[Array]| -> Result, Exception> { + let result = (self)((&args[0], &args[1]))?; + Ok(vec![result]) + }; + let id = type_id_to_usize(&f); + let state = CompiledState { + f, + inputs, + outputs, + shapeless, + id, + }; + Compiled { + f_marker: PhantomData::, + state, + } + } +} + +impl<'a, F> Compile<'a, (&'a Array, &'a Array, &'a Array), Array, ()> for F where F: FnMut((&Array, &Array, &Array)) -> Array + 'static, { @@ -136,7 +219,7 @@ where inputs: Option<&'a mut [Array]>, outputs: Option<&'a mut [Array]>, shapeless: bool, - ) -> impl CallMut<'a, (&'a Array, &'a Array, &'a Array), Array> { + ) -> impl CallMut<'a, (&'a Array, &'a Array, &'a Array), Array, ()> { let f = move |args: &[Array]| -> Vec { let result = (self)((&args[0], &args[1], &args[2])); vec![result] @@ -156,7 +239,36 @@ where } } -pub trait CallMut<'a, Args, Output> { +impl<'a, F> Compile<'a, (&'a Array, &'a Array, &'a Array), Array, Exception> for F +where + F: FnMut((&Array, &Array, &Array)) -> Result + 'static, +{ + fn compile( + mut self, + inputs: Option<&'a mut [Array]>, + outputs: Option<&'a mut [Array]>, + shapeless: bool, + ) -> impl CallMut<'a, (&'a Array, &'a Array, &'a Array), Array, Exception> { + let f = move |args: &[Array]| -> Result, Exception> { + let result = (self)((&args[0], &args[1], &args[2]))?; + Ok(vec![result]) + }; + let id = type_id_to_usize(&f); + let state = CompiledState { + f, + inputs, + outputs, + shapeless, + id, + }; + Compiled { + f_marker: PhantomData::, + state, + } + } +} + +pub trait CallMut<'a, Args, Output, Err> { fn call_mut(&mut self, args: Args) -> Result; } @@ -166,7 +278,7 @@ pub struct Compiled<'a, F, G> { state: CompiledState<'a, G>, } -impl<'a, F, G> CallMut<'a, &'a [Array], Vec> for Compiled<'a, F, G> +impl<'a, F, G> CallMut<'a, &'a [Array], Vec, ()> for Compiled<'a, F, G> where F: FnMut(&[Array]) -> Vec + 'a, G: FnMut(&[Array]) -> Vec + 'a, @@ -176,7 +288,17 @@ where } } -impl<'a, F, G> CallMut<'a, &'a Array, Array> for Compiled<'a, F, G> +impl<'a, F, G> CallMut<'a, &'a [Array], Vec, Exception> for Compiled<'a, F, G> +where + F: FnMut(&[Array]) -> Result, Exception> + 'a, + G: FnMut(&[Array]) -> Result, Exception> + 'a, +{ + fn call_mut(&mut self, args: &[Array]) -> Result, Exception> { + self.state.call_mut_fallible(args) + } +} + +impl<'a, F, G> CallMut<'a, &'a Array, Array, ()> for Compiled<'a, F, G> where F: FnMut(&Array) -> Array + 'a, G: FnMut(&[Array]) -> Vec + 'a, @@ -189,7 +311,20 @@ where } } -impl<'a, F, G> CallMut<'a, (&'a Array, &'a Array), Array> for Compiled<'a, F, G> +impl<'a, F, G> CallMut<'a, &'a Array, Array, Exception> for Compiled<'a, F, G> +where + F: FnMut(&Array) -> Result + 'a, + G: FnMut(&[Array]) -> Result, Exception> + 'a, +{ + fn call_mut(&mut self, args: &Array) -> Result { + // Is there any way to avoid this shallow clone? + let args = &[args.clone()]; + let result = self.state.call_mut_fallible(args)?; + Ok(result.into_iter().next().unwrap()) + } +} + +impl<'a, F, G> CallMut<'a, (&'a Array, &'a Array), Array, ()> for Compiled<'a, F, G> where F: FnMut((&Array, &Array)) -> Array + 'a, G: FnMut(&[Array]) -> Vec + 'a, @@ -202,7 +337,20 @@ where } } -impl<'a, F, G> CallMut<'a, (&'a Array, &'a Array, &'a Array), Array> for Compiled<'a, F, G> +impl<'a, F, G> CallMut<'a, (&'a Array, &'a Array), Array, Exception> for Compiled<'a, F, G> +where + F: FnMut((&Array, &Array)) -> Result + 'a, + G: FnMut(&[Array]) -> Result, Exception> + 'a, +{ + fn call_mut(&mut self, args: (&Array, &Array)) -> Result { + // Is there any way to avoid this shallow clone? + let args = &[args.0.clone(), args.1.clone()]; + let result = self.state.call_mut_fallible(args)?; + Ok(result.into_iter().next().unwrap()) + } +} + +impl<'a, F, G> CallMut<'a, (&'a Array, &'a Array, &'a Array), Array, ()> for Compiled<'a, F, G> where F: FnMut((&Array, &Array, &Array)) -> Array + 'a, G: FnMut(&[Array]) -> Vec + 'a, @@ -215,6 +363,20 @@ where } } +impl<'a, F, G> CallMut<'a, (&'a Array, &'a Array, &'a Array), Array, Exception> + for Compiled<'a, F, G> +where + F: FnMut((&Array, &Array, &Array)) -> Result + 'a, + G: FnMut(&[Array]) -> Result, Exception> + 'a, +{ + fn call_mut(&mut self, args: (&Array, &Array, &Array)) -> Result { + // Is there any way to avoid this shallow clone? + let args = &[args.0.clone(), args.1.clone(), args.2.clone()]; + let result = self.state.call_mut_fallible(args)?; + Ok(result.into_iter().next().unwrap()) + } +} + #[derive(Debug)] struct CompiledState<'a, F> where @@ -227,6 +389,71 @@ where id: usize, } +#[inline] +fn call_mut_inner( + inner_closure: Closure, + fun_id: usize, + shapeless: bool, + state_inputs: Rc>>, + state_outputs: Rc>>, + args: &[Array], +) -> Result, Exception> { + // note: this will use the cached compile (via the id) + // but will be able to re-evaluate with fresh state if needed + let compiled = unsafe { + let constants = &[]; + let c_closure = try_catch_mlx_closure_error! { + mlx_detail_compile( + inner_closure.as_ptr(), + fun_id, + shapeless, + constants.as_ptr(), + 0, + ) + }; + Closure::from_ptr(c_closure) + }; + + let inner_inputs_vector = match state_inputs.borrow().as_ref() { + Some(s) => VectorArray::from_iter(args.iter().chain(s.iter())), + None => VectorArray::from_iter(args.iter()), + }; + + // will compile the function (if needed) and evaluate the + // compiled graph + let result_vector = unsafe { + let c_vector = try_catch_mlx_closure_error! { + mlx_closure_apply(compiled.as_ptr(), inner_inputs_vector.as_ptr()) + }; + VectorArray::from_ptr(c_vector) + }; + let result_plus_state_output: Vec = result_vector.into_values(); + + // push the stateOutput into the state + if let Some(outputs) = state_outputs.borrow_mut().as_mut() { + let result_plus_state_output_len = result_plus_state_output.len(); + let state_output_len = outputs.len(); + let suffix_len = result_plus_state_output_len - state_output_len; + for (s, new_values) in outputs + .iter_mut() + .zip(result_plus_state_output[suffix_len..].iter()) + { + update_by_replace_with_ref_to_new_array(s, new_values); + } + } + + let result_len = result_plus_state_output.len() + - state_outputs + .borrow() + .as_ref() + .map(|x| x.len()) + .unwrap_or(0); + Ok(result_plus_state_output + .into_iter() + .take(result_len) + .collect()) +} + impl<'a, F> CompiledState<'a, F> { fn call_mut(&mut self, args: &[Array]) -> Result, Exception> where @@ -285,60 +512,84 @@ impl<'a, F> CompiledState<'a, F> { let inner_closure = Closure::new(inner); - // note: this will use the cached compile (via the id) - // but will be able to re-evaluate with fresh state if needed - let compiled = unsafe { - let constants = &[]; - let c_closure = try_catch_c_ptr_expr! { - mlx_detail_compile( - inner_closure.as_ptr(), - self.id, - self.shapeless, - constants.as_ptr(), - 0, - ) - }; - Closure::from_ptr(c_closure) - }; + call_mut_inner( + inner_closure, + self.id, + self.shapeless, + state_inputs, + state_outputs, + args, + ) + } - let inner_inputs_vector = match state_inputs.borrow().as_ref() { - Some(s) => VectorArray::from_iter(args.iter().chain(s.iter())), - None => VectorArray::from_iter(args.iter()), - }; + fn call_mut_fallible(&mut self, args: &[Array]) -> Result, Exception> + where + F: FnMut(&[Array]) -> Result, Exception> + 'a, + { + let args_len = args.len(); + let state_inputs = Rc::new(RefCell::new(&mut self.inputs)); + let state_outputs = Rc::new(RefCell::new(&mut self.outputs)); + let f = &mut self.f; - // will compile the function (if needed) and evaluate the - // compiled graph - let result_vector = unsafe { - let c_vector = try_catch_c_ptr_expr! { - mlx_closure_apply(compiled.as_ptr(), inner_inputs_vector.as_ptr()) - }; - VectorArray::from_ptr(c_vector) - }; - let result_plus_state_output: Vec = result_vector.into_values(); - - // push the stateOutput into the state - if let Some(outputs) = state_outputs.borrow_mut().as_mut() { - let result_plus_state_output_len = result_plus_state_output.len(); - let state_output_len = outputs.len(); - let suffix_len = result_plus_state_output_len - state_output_len; - for (s, new_values) in outputs - .iter_mut() - .zip(result_plus_state_output[suffix_len..].iter()) - { - update_by_replace_with_ref_to_new_array(s, new_values); + let state_inputs_clone = Rc::clone(&state_inputs); + let state_outputs_clone = Rc::clone(&state_outputs); + let inner = move |tracers: &[Array]| -> Result, Exception> { + // put the tracers in their appropriate places: + // - arguments to the function + // - inner state + + let tracer_args = &tracers[..args_len]; + + // save a snapshot of the inner state + let saved_state_inputs: Option> = state_inputs_clone + .borrow() + .as_ref() + .map(|inputs| inputs.iter().map(Clone::clone).collect()); + + // replace the inner state with the tracers + if let Some(inputs) = state_inputs_clone.borrow_mut().as_mut() { + for (s, tracer) in inputs.iter_mut().zip(tracers.iter().skip(args_len)) { + update_by_replace_with_ref_to_new_array(s, tracer); + } } - } - let result_len = result_plus_state_output.len() - - state_outputs + // call the function with the tracer arguments and the state holding tracers + let mut result = (f)(tracer_args); + + // recapture the state as it may have changed + let state_output_tracers: Option> = state_outputs_clone .borrow() .as_ref() - .map(|x| x.len()) - .unwrap_or(0); - Ok(result_plus_state_output - .into_iter() - .take(result_len) - .collect()) + .map(|outputs| outputs.iter().map(Clone::clone).collect()); + + // put the original values back in the state + if let Some(inputs) = state_inputs_clone.borrow_mut().as_mut() { + for (s, saved) in inputs.iter_mut().zip(saved_state_inputs.unwrap()) { + update_by_replace_with_ref_to_new_array(s, &saved); + } + } + + // return the result of the function and the state + if let Some(mut state_output_tracers) = state_output_tracers { + result = result.map(|mut r| { + r.append(&mut state_output_tracers); + r + }); + } + + result + }; + + let inner_closure = Closure::new_fallible(inner); + + call_mut_inner( + inner_closure, + self.id, + self.shapeless, + state_inputs, + state_outputs, + args, + ) } } @@ -377,16 +628,17 @@ fn update_by_replace_with_ref_to_new_array(src: &mut Array, new_array: &Array) { /// Please refer to the [swift binding /// documentation](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/compilation) /// for more information. -pub fn compile<'a, F, Args, Output>( +pub fn compile<'a, F, Args, Output, Err>( f: F, shapeless: Option, inputs: Option<&'a mut [Array]>, outputs: Option<&'a mut [Array]>, ) -> impl FnMut(Args) -> Result + 'a where - F: Compile<'a, Args, Output> + 'static, + F: Compile<'a, Args, Output, Err> + 'static, Args: 'a, Output: 'a, + Err: 'a, { let shapeless = shapeless.unwrap_or(false); let mut compiled = f.compile(inputs, outputs, shapeless); @@ -403,7 +655,12 @@ pub fn clear_cache() { mod tests { use core::panic; - use crate::{ops::ones, Array}; + use crate::{ + array, + error::Exception, + ops::{multiply, ones}, + Array, + }; use super::compile; @@ -471,6 +728,55 @@ mod tests { assert_eq!(&r1, &r3); } + #[test] + fn test_compile_with_error() { + let f = |inputs: &[Array]| -> Result, Exception> { + multiply(&inputs[0], &inputs[1]).map(|x| vec![x]) + }; + + // Success case + let i1 = ones::(&[20, 20]).unwrap(); + let i2 = ones::(&[20, 20]).unwrap(); + let args = [i1, i2]; + + // evaluate directly + let r1 = f(&args).unwrap().drain(0..1).next().unwrap(); + + // evaluate compiled + let mut compiled = compile(f, None, None, None); + let r2 = compiled(&args).unwrap().drain(0..1).next().unwrap(); + + assert_eq!(&r1, &r2); + + let r3 = compiled(&args).unwrap().drain(0..1).next().unwrap(); + assert_eq!(&r1, &r3); + + // Error case + let a = array!([1.0, 2.0, 3.0]); + let b = array!([4.0, 5.0]); + let args = [a, b]; + + // The cache is keyed by function pointer and argument shapes + let c = array!([4.0, 5.0, 6.0]); + let d = array!([7.0, 8.0]); + let another_args = [c, d]; + + // evaluate directly + let result = f(&args); + assert!(result.is_err()); + + // evaluate compiled + let mut compiled = compile(f, None, None, None); + let result = compiled(&args); + assert!(result.is_err()); + + let result = compiled(&args); + assert!(result.is_err()); + + let result = compiled(&another_args); + assert!(result.is_err()); + } + #[test] fn test_compile_with_one_arg() { let f = |x: &Array| x * x; diff --git a/mlx-rs/src/transforms/mod.rs b/mlx-rs/src/transforms/mod.rs index 065e2db4..a2ed62bd 100644 --- a/mlx-rs/src/transforms/mod.rs +++ b/mlx-rs/src/transforms/mod.rs @@ -43,6 +43,34 @@ pub fn async_eval<'a>(outputs: impl IntoIterator) -> Resul get_and_clear_last_mlx_error().map_or(Ok(()), Err) } +#[inline] +fn jvp_inner( + closure: Closure<'_>, + primals: &[Array], + tangents: &[Array], +) -> Result<(Vec, Vec), Exception> { + let c_primals = VectorArray::from_iter(primals.iter()); + let c_tangents = VectorArray::from_iter(tangents.iter()); + + let vector_pair = unsafe { + let c_vector_pair = try_catch_mlx_closure_error! { + mlx_sys::mlx_jvp( + closure.as_ptr(), + c_primals.as_ptr(), + c_tangents.as_ptr(), + ) + }; + VectorVectorArray::from_ptr(c_vector_pair) + }; + + let vector_pair_values: SmallVec<[VectorArray; 2]> = vector_pair.into_values(); + let mut iter = vector_pair_values.into_iter(); + let v1 = iter.next().unwrap().into_values(); + let v2 = iter.next().unwrap().into_values(); + + Ok((v1, v2)) +} + /// Compute the Jacobian-vector product. /// /// This computes the product of the Jacobian of a function `f` evaluated at `primals` with the @@ -69,16 +97,37 @@ where F: FnMut(&[Array]) -> Vec + 'a, { let closure = Closure::new(f); + jvp_inner(closure, primals, tangents) +} +/// Similar to [`jvp`] but handles closures that can return an error. +pub fn jvp_fallible<'a, F>( + f: F, + primals: &[Array], + tangents: &[Array], +) -> Result<(Vec, Vec), Exception> +where + F: FnMut(&[Array]) -> Result, Exception> + 'a, +{ + let closure = Closure::new_fallible(f); + jvp_inner(closure, primals, tangents) +} + +#[inline] +fn vjp_inner( + closure: Closure<'_>, + primals: &[Array], + cotangents: &[Array], +) -> Result<(Vec, Vec), Exception> { let c_primals = VectorArray::from_iter(primals.iter()); - let c_tangents = VectorArray::from_iter(tangents.iter()); + let c_cotangents = VectorArray::from_iter(cotangents.iter()); let vector_pair = unsafe { - let c_vector_pair = try_catch_c_ptr_expr! { - mlx_sys::mlx_jvp( + let c_vector_pair = try_catch_mlx_closure_error! { + mlx_sys::mlx_vjp( closure.as_ptr(), c_primals.as_ptr(), - c_tangents.as_ptr(), + c_cotangents.as_ptr(), ) }; VectorVectorArray::from_ptr(c_vector_pair) @@ -117,27 +166,20 @@ where F: FnMut(&[Array]) -> Vec + 'a, { let closure = Closure::new(f); + vjp_inner(closure, primals, cotangents) +} - let c_primals = VectorArray::from_iter(primals.iter()); - let c_cotangents = VectorArray::from_iter(cotangents.iter()); - - let vector_pair = unsafe { - let c_vector_pair = try_catch_c_ptr_expr! { - mlx_sys::mlx_vjp( - closure.as_ptr(), - c_primals.as_ptr(), - c_cotangents.as_ptr(), - ) - }; - VectorVectorArray::from_ptr(c_vector_pair) - }; - - let vector_pair_values: SmallVec<[VectorArray; 2]> = vector_pair.into_values(); - let mut iter = vector_pair_values.into_iter(); - let v1 = iter.next().unwrap().into_values(); - let v2 = iter.next().unwrap().into_values(); - - Ok((v1, v2)) +/// Similar to [`vjp`] but handles closures that can return an error. +pub fn vjp_fallible<'a, F>( + f: F, + primals: &[Array], + cotangents: &[Array], +) -> Result<(Vec, Vec), Exception> +where + F: FnMut(&[Array]) -> Result, Exception> + 'a, +{ + let closure = Closure::new_fallible(f); + vjp_inner(closure, primals, cotangents) } fn value_and_gradient( @@ -147,7 +189,7 @@ fn value_and_gradient( let input_vector = VectorArray::from_iter(arrays); let vector_pair = unsafe { - let c_vector_pair = try_catch_c_ptr_expr! { + let c_vector_pair = try_catch_mlx_closure_error! { mlx_closure_value_and_grad_apply(value_and_grad, input_vector.as_ptr()) }; VectorVectorArray::from_ptr(c_vector_pair) @@ -161,17 +203,14 @@ fn value_and_gradient( Ok((values_vec, gradients_vec)) } -fn build_gradient<'a, F>( - f: F, +#[inline] +fn build_gradient_inner<'a>( + closure: Closure<'a>, argument_numbers: &'a [i32], -) -> impl FnMut(&[Array]) -> Result, Exception> + 'a -where - F: FnMut(&[Array]) -> Vec + 'a, -{ - let closure = Closure::new(f); +) -> impl FnMut(&[Array]) -> Result, Exception> + 'a { move |arrays: &[Array]| -> Result, Exception> { unsafe { - let c_value_and_grad = try_catch_c_ptr_expr! { + let c_value_and_grad = try_catch_mlx_closure_error! { mlx_sys::mlx_value_and_grad( closure.as_ptr(), argument_numbers.as_ptr(), @@ -185,17 +224,35 @@ where } } -fn build_value_and_gradient<'a, F>( +fn build_gradient<'a, F>( f: F, argument_numbers: &'a [i32], -) -> impl FnMut(&[Array]) -> Result<(Vec, Vec), Exception> + 'a +) -> impl FnMut(&[Array]) -> Result, Exception> + 'a where F: FnMut(&[Array]) -> Vec + 'a, { let closure = Closure::new(f); + build_gradient_inner(closure, argument_numbers) +} + +fn build_gradient_fallible<'a, F>( + f: F, + argument_numbers: &'a [i32], +) -> impl FnMut(&[Array]) -> Result, Exception> + 'a +where + F: FnMut(&[Array]) -> Result, Exception> + 'a, +{ + let closure = Closure::new_fallible(f); + build_gradient_inner(closure, argument_numbers) +} + +fn build_value_and_gradient_inner<'a>( + closure: Closure<'a>, + argument_numbers: &'a [i32], +) -> impl FnMut(&[Array]) -> Result<(Vec, Vec), Exception> + 'a { move |arrays: &[Array]| { let c_value_and_grad = unsafe { - try_catch_c_ptr_expr! { + try_catch_mlx_closure_error! { mlx_sys::mlx_value_and_grad( closure.as_ptr(), argument_numbers.as_ptr(), @@ -208,6 +265,28 @@ where } } +fn build_value_and_gradient<'a, F>( + f: F, + argument_numbers: &'a [i32], +) -> impl FnMut(&[Array]) -> Result<(Vec, Vec), Exception> + 'a +where + F: FnMut(&[Array]) -> Vec + 'a, +{ + let closure = Closure::new(f); + build_value_and_gradient_inner(closure, argument_numbers) +} + +fn build_value_and_gradient_fallible<'a, F>( + f: F, + argument_numbers: &'a [i32], +) -> impl FnMut(&[Array]) -> Result<(Vec, Vec), Exception> + 'a +where + F: FnMut(&[Array]) -> Result, Exception> + 'a, +{ + let closure = Closure::new_fallible(f); + build_value_and_gradient_inner(closure, argument_numbers) +} + /// Returns a function which computes the value and gradient of `f`. pub fn value_and_grad<'a, F>( f: F, @@ -219,14 +298,24 @@ where build_value_and_gradient(f, argument_numbers) } -pub trait Grad<'a, Args, Output> { +pub fn value_and_grad_fallible<'a, F>( + f: F, + argument_numbers: &'a [i32], +) -> impl FnMut(&[Array]) -> Result<(Vec, Vec), Exception> + 'a +where + F: FnMut(&[Array]) -> Result, Exception> + 'a, +{ + build_value_and_gradient_fallible(f, argument_numbers) +} + +pub trait Grad<'a, Args, Output, Err> { fn grad( self, argument_numbers: &'a [i32], ) -> impl FnMut(Args) -> Result + 'a; } -impl<'a, F> Grad<'a, &[Array], Vec> for F +impl<'a, F> Grad<'a, &[Array], Vec, ()> for F where F: FnMut(&[Array]) -> Vec + 'a, { @@ -241,7 +330,20 @@ where } } -impl<'a, F> Grad<'a, &Array, Array> for F +impl<'a, F> Grad<'a, &[Array], Vec, Exception> for F +where + F: FnMut(&[Array]) -> Result, Exception> + 'a, +{ + #[allow(refining_impl_trait)] + fn grad( + self, + argument_numbers: &'a [i32], + ) -> impl FnMut(&[Array]) -> Result, Exception> + 'a { + build_gradient_fallible(self, argument_numbers) + } +} + +impl<'a, F> Grad<'a, &Array, Array, ()> for F where F: FnMut(&Array) -> Array + 'a, { @@ -260,7 +362,28 @@ where } } -impl<'a, F> Grad<'a, &[Array], Array> for F +impl<'a, F> Grad<'a, &Array, Array, Exception> for F +where + F: FnMut(&Array) -> Result + 'a, +{ + #[allow(refining_impl_trait)] + fn grad( + mut self, + argument_numbers: &'a [i32], + ) -> impl FnMut(&Array) -> Result + 'a { + let f = move |args: &[Array]| -> Result, Exception> { + self(&args[0]).map(|res| vec![res]) + }; + let mut g = build_gradient_fallible(f, argument_numbers); + move |args: &Array| -> Result { + let args_clone = &[args.clone()]; + let result = g(args_clone)?; + Ok(result.into_iter().next().unwrap()) + } + } +} + +impl<'a, F> Grad<'a, &[Array], Array, ()> for F where F: FnMut(&[Array]) -> Array + 'a, { @@ -278,7 +401,27 @@ where } } -impl<'a, F> Grad<'a, &Array, Vec> for F +impl<'a, F> Grad<'a, &[Array], Array, Exception> for F +where + F: FnMut(&[Array]) -> Result + 'a, +{ + #[allow(refining_impl_trait)] + fn grad( + mut self, + argument_numbers: &'a [i32], + ) -> impl FnMut(&[Array]) -> Result + 'a { + let f = move |args: &[Array]| -> Result, Exception> { + self(args).map(|res| vec![res]) + }; + let mut g = build_gradient_fallible(f, argument_numbers); + move |args: &[Array]| -> Result { + let result = g(args)?; + Ok(result.into_iter().next().unwrap()) + } + } +} + +impl<'a, F> Grad<'a, &Array, Vec, ()> for F where F: FnMut(&Array) -> Vec + 'a, { @@ -297,24 +440,41 @@ where } } +impl<'a, F> Grad<'a, &Array, Vec, Exception> for F +where + F: FnMut(&Array) -> Result, Exception> + 'a, +{ + #[allow(refining_impl_trait)] + fn grad( + mut self, + argument_numbers: &'a [i32], + ) -> impl FnMut(&Array) -> Result, Exception> + 'a { + let f = move |args: &[Array]| -> Result, Exception> { self(&args[0]) }; + let mut g = build_gradient_fallible(f, argument_numbers); + move |args: &Array| -> Result, Exception> { + let args_clone = &[args.clone()]; + let result = g(args_clone)?; + Ok(result) + } + } +} + /// Returns a function which computes the gradient of `f`. -pub fn grad<'a, F, Args, Output>( +pub fn grad<'a, F, Args, Output, Err>( f: F, argument_numbers: &'a [i32], ) -> impl FnMut(Args) -> Result + 'a where - F: Grad<'a, Args, Output>, + F: Grad<'a, Args, Output, Err>, { f.grad(argument_numbers) } #[cfg(test)] mod tests { - use crate::{ - array, - transforms::{grad, jvp, value_and_grad, vjp}, - Array, - }; + use crate::{array, error::Exception, Array}; + + use super::*; // The unit tests below are adapted from the mlx c++ codebase @@ -328,6 +488,27 @@ mod tests { assert_eq!(dout[0].item::(), 4.0f32); } + #[test] + fn test_jvp_with_error() { + let f = |inputs: &[Array]| -> Result, Exception> { + inputs[0].add(&inputs[1]).map(|res| vec![res]) + }; + + // Success case + let x = array!(1.0f32); + let y = array!(1.0f32); + let (out, dout) = jvp_fallible(f, &[x, y], &[array!(1.0f32), array!(3.0f32)]).unwrap(); + assert_eq!(out[0].item::(), 2.0f32); + assert_eq!(dout[0].item::(), 4.0f32); + + // Error case + // Use non-broadcastable shapes + let a = array!([1.0, 2.0, 3.0]); + let b = array!([4.0, 5.0]); + let result = jvp_fallible(f, &[a, b], &[array!(1.0f32), array!(3.0f32)]); + assert!(result.is_err()); + } + #[test] fn test_vjp() { let f = |inputs: &[Array]| -> Vec { vec![&inputs[0] + &inputs[1]] }; @@ -341,7 +522,30 @@ mod tests { } #[test] - fn test_grad() { + fn test_vjp_with_error() { + let f = |inputs: &[Array]| -> Result, Exception> { + inputs[0].add(&inputs[1]).map(|res| vec![res]) + }; + + // Success case + let x = array!(1.0f32); + let y = array!(1.0f32); + let primals = vec![x, y]; + let cotangents = vec![array!(1.0f32)]; + let (out, dout) = vjp_fallible(f, &primals, &cotangents).unwrap(); + assert_eq!(out[0].item::(), 2.0f32); + assert_eq!(dout[0].item::(), 1.0f32); + + // Error case + // Use non-broadcastable shapes + let a = array!([1.0, 2.0, 3.0]); + let b = array!([4.0, 5.0]); + let result = vjp_fallible(f, &[a, b], &[array!(1.0f32)]); + assert!(result.is_err()); + } + + #[test] + fn test_value_and_grad() { let x = &[Array::from_float(1.0)]; let fun = |argin: &[Array]| -> Vec { vec![&argin[0] + 1.0] }; let argnums = &[0]; @@ -357,4 +561,25 @@ mod tests { assert_eq!(z[0].item::(), 1.0); assert_eq!(d2fdx2[0].item::(), 0.0); } + + #[test] + fn test_value_and_grad_with_error() { + let fun = |argin: &[Array]| -> Result, Exception> { + argin[0].add(array!(1.0)).map(|res| vec![res]) + }; + + // Success case + let argnums = &[0]; + let x = array!(1.0f32); + let y = array!(1.0f32); + let result = value_and_grad_fallible(fun, argnums)(&[x, y]); + assert!(result.is_ok()); + + // Error case + // Use non-broadcastable shapes + let a = array!([1.0, 2.0, 3.0]); + let b = array!([4.0, 5.0]); + let result = value_and_grad_fallible(fun, argnums)(&[a, b]); + assert!(result.is_err()); + } } diff --git a/mlx-rs/src/utils.rs b/mlx-rs/src/utils.rs index 04ca88cb..27279b98 100644 --- a/mlx-rs/src/utils.rs +++ b/mlx-rs/src/utils.rs @@ -2,7 +2,7 @@ use std::{ffi::NulError, marker::PhantomData, os::raw::c_void}; use mlx_sys::{mlx_closure, mlx_vector_array}; -use crate::{complex64, Array, FromNested}; +use crate::{complex64, error::Exception, Array, FromNested}; /// Helper method to get a string representation of an mlx object. pub(crate) fn mlx_describe(ptr: *mut ::std::os::raw::c_void) -> Option { @@ -167,10 +167,6 @@ impl<'a> ScalarOrArray<'a> for &'a Array { } } -// We can't replace the following four impls with `FromScalar` trait bound -// because compiler would complain about conflicting implementations with -// `FromNested`. - impl ScalarOrArray<'static> for bool { type Array = Array; @@ -228,6 +224,10 @@ impl<'a> Closure<'a> { } } + pub(crate) fn as_ptr(&self) -> mlx_closure { + self.c_closure + } + pub(crate) fn new(closure: F) -> Self where F: FnMut(&[Array]) -> Vec + 'a, @@ -239,8 +239,15 @@ impl<'a> Closure<'a> { } } - pub(crate) fn as_ptr(&self) -> mlx_closure { - self.c_closure + pub(crate) fn new_fallible(closure: F) -> Self + where + F: FnMut(&[Array]) -> Result, Exception> + 'a, + { + let c_closure = new_mlx_fallible_closure(closure); + Self { + c_closure, + lt_marker: PhantomData, + } } } @@ -267,6 +274,23 @@ where } } +fn new_mlx_fallible_closure<'a, F>(closure: F) -> mlx_sys::mlx_closure +where + F: FnMut(&[Array]) -> Result, Exception> + 'a, +{ + let boxed = Box::new(closure); + let raw = Box::into_raw(boxed); + let payload = raw as *mut std::ffi::c_void; + + unsafe { + mlx_sys::mlx_fallible_closure_new_with_payload( + Some(trampoline_fallible::), + payload, + Some(noop_dtor), + ) + } +} + /// Function to create a new (+1 reference) mlx_vector_array from a vector of Array fn new_mlx_vector_array(arrays: Vec) -> mlx_sys::mlx_vector_array { unsafe { @@ -309,6 +333,34 @@ where } } +extern "C" fn trampoline_fallible<'a, F>( + vector_array: mlx_sys::mlx_vector_array, + payload: *mut std::ffi::c_void, +) -> mlx_sys::mlx_vector_array_result +where + F: FnMut(&[Array]) -> Result, Exception> + 'a, +{ + use std::ffi::CString; + + unsafe { + let raw_closure: *mut F = payload as *mut _; + let mut closure = Box::from_raw(raw_closure); + let arrays = mlx_vector_array_values(vector_array); + let result = closure(&arrays); + match result { + Ok(result) => { + let c_result = new_mlx_vector_array(result); + mlx_sys::mlx_vector_array_result_new_ok(c_result) + } + Err(exception) => { + let what = CString::new(exception.what).unwrap(); + let mlx_string = mlx_sys::mlx_string_new(what.as_ptr()); + mlx_sys::mlx_vector_array_result_new_err(mlx_string) + } + } + } +} + extern "C" fn noop_dtor(_data: *mut std::ffi::c_void) {} pub(crate) struct VectorVectorArray { diff --git a/mlx-sys/Cargo.toml b/mlx-sys/Cargo.toml index 38ce7a7c..69163b14 100644 --- a/mlx-sys/Cargo.toml +++ b/mlx-sys/Cargo.toml @@ -29,3 +29,4 @@ metal = [] [build-dependencies] bindgen = "0.69.4" cmake = "0.1.31" +cc = "1" \ No newline at end of file diff --git a/mlx-sys/build.rs b/mlx-sys/build.rs index d6ac6ae7..32c6c81c 100644 --- a/mlx-sys/build.rs +++ b/mlx-sys/build.rs @@ -2,9 +2,9 @@ extern crate cmake; use cmake::Config; use std::env; -use std::path::PathBuf; +use std::path::{Path, PathBuf}; -fn main() { +fn build_and_link_mlx_c() -> PathBuf { let mut config = Config::new("src/mlx-c"); config.very_verbose(true); config.define("CMAKE_INSTALL_PREFIX", "."); @@ -53,6 +53,32 @@ fn main() { println!("cargo:rustc-link-lib=framework=Accelerate"); } + dst +} + +fn build_shim(mlx_c_dst: impl AsRef) { + cc::Build::new() + .cpp(true) + .flag("-std=c++17") + .flag("-Wno-deprecated-copy") + .flag("-Wno-unused-parameter") + .include("src/mlx-c") + .include("src/shim") + .include(mlx_c_dst.as_ref().join("build/include")) + .file("src/shim/result.cpp") + .file("src/shim/closure.cpp") + .compile("libmlxc_shim.a"); + + // Rebuild if the shim changes + println!("cargo:rerun-if-changed=src/shim/shim.cpp"); + + println!("cargo:rustc-link-lib=static=mlxc_shim"); +} + +fn main() { + let mlx_c_dst = build_and_link_mlx_c(); + build_shim(&mlx_c_dst); + // generate bindings let bindings = bindgen::Builder::default() .header("src/mlx-c/mlx/c/mlx.h") @@ -60,6 +86,8 @@ fn main() { .header("src/mlx-c/mlx/c/error.h") .header("src/mlx-c/mlx/c/transforms_impl.h") .clang_arg("-Isrc/mlx-c") + .header("src/shim/shim.h") + .clang_arg("-Isrc/shim") .parse_callbacks(Box::new(bindgen::CargoCallbacks::new())) .generate() .expect("Unable to generate bindings"); diff --git a/mlx-sys/src/shim/closure.cpp b/mlx-sys/src/shim/closure.cpp new file mode 100644 index 00000000..775e5a0c --- /dev/null +++ b/mlx-sys/src/shim/closure.cpp @@ -0,0 +1,36 @@ +#include + +#include "mlx/c/object.h" +#include "mlx/c/string.h" +#include "mlx/c/error.h" +#include "mlx/c/private/array.h" +#include "mlx/c/private/closure.h" +#include "mlx/c/private/string.h" +#include "mlx/c/private/utils.h" + +#include "closure.h" +#include "result.h" + +extern "C" mlx_closure mlx_fallible_closure_new_with_payload( + mlx_vector_array_result (*fun)(const mlx_vector_array, void*), + void* payload, + void (*dtor)(void*) +) { + auto cpp_payload = std::shared_ptr(payload, dtor); + auto cpp_closure = [fun, cpp_payload](const std::vector& input) { + auto c_input = new mlx_vector_array_(input); + auto c_res = fun(c_input, cpp_payload.get()); + mlx_free(c_input); + if (mlx_vector_array_result_is_err(&c_res)) { + auto err = mlx_vector_array_result_into_err(c_res); + std::string msg = std::move(err->ctx); + mlx_free(err); + throw std::runtime_error(msg); + } + auto c_ok = mlx_vector_array_result_into_ok(c_res); + auto res = c_ok->ctx; + mlx_free(c_ok); + return res; + }; + MLX_TRY_CATCH(return new mlx_closure_(cpp_closure), return nullptr); +} \ No newline at end of file diff --git a/mlx-sys/src/shim/closure.h b/mlx-sys/src/shim/closure.h new file mode 100644 index 00000000..ea2b87ce --- /dev/null +++ b/mlx-sys/src/shim/closure.h @@ -0,0 +1,22 @@ +#ifndef MLX_C_SHIM_CLOSURE_H +#define MLX_C_SHIM_CLOSURE_H + +#include "mlx/c/array.h" +#include "mlx/c/closure.h" + +#include "result.h" + +#ifdef __cplusplus +extern "C" { +#endif + +mlx_closure mlx_fallible_closure_new_with_payload( + mlx_vector_array_result (*fun)(const mlx_vector_array, void*), + void* payload, + void (*dtor)(void*)); + +#ifdef __cplusplus +} +#endif + +#endif // MLX_C_SHIM_CLOSURE_H \ No newline at end of file diff --git a/mlx-sys/src/shim/result.cpp b/mlx-sys/src/shim/result.cpp new file mode 100644 index 00000000..219e03eb --- /dev/null +++ b/mlx-sys/src/shim/result.cpp @@ -0,0 +1,41 @@ +#include "result.h" + +extern "C" mlx_vector_array_result mlx_vector_array_result_new_ok( + mlx_vector_array ok) { + mlx_vector_array_result result; + result.tag = mlx_result_tag_ok; + result.ok = ok; + return result; +} + +extern "C" mlx_vector_array_result mlx_vector_array_result_new_err( + mlx_string err) { + mlx_vector_array_result result; + result.tag = mlx_result_tag_err; + result.err = err; + return result; +} + +extern "C" bool mlx_vector_array_result_is_ok(mlx_vector_array_result* result) { + return result->tag == mlx_result_tag_ok; +} + +extern "C" bool mlx_vector_array_result_is_err( + mlx_vector_array_result* result) { + return result->tag == mlx_result_tag_err; +} + +extern "C" mlx_result_tag mlx_vector_array_result_get_tag( + mlx_vector_array_result* result) { + return result->tag; +} + +mlx_vector_array mlx_vector_array_result_into_ok( + mlx_vector_array_result result) { + return result.ok; +} + +extern "C" mlx_string mlx_vector_array_result_into_err( + mlx_vector_array_result result) { + return result.err; +} \ No newline at end of file diff --git a/mlx-sys/src/shim/result.h b/mlx-sys/src/shim/result.h new file mode 100644 index 00000000..4001f8b7 --- /dev/null +++ b/mlx-sys/src/shim/result.h @@ -0,0 +1,39 @@ +#ifndef MLC_X_SHIM_RESULT_H +#define MLC_X_SHIM_RESULT_H + +#include "mlx/c/array.h" +#include "mlx/c/string.h" + +#ifdef __cplusplus +#include + +extern "C" { +#endif + +typedef enum { + mlx_result_tag_ok, + mlx_result_tag_err, +} mlx_result_tag; + +typedef struct { + mlx_result_tag tag; + union { + mlx_vector_array ok; + mlx_string err; + }; +} mlx_vector_array_result; + +mlx_vector_array_result mlx_vector_array_result_new_ok(mlx_vector_array ok); +mlx_vector_array_result mlx_vector_array_result_new_err(mlx_string err); +mlx_result_tag mlx_vector_array_result_get_tag(mlx_vector_array_result* result); +bool mlx_vector_array_result_is_ok(mlx_vector_array_result* result); +bool mlx_vector_array_result_is_err(mlx_vector_array_result* result); +mlx_vector_array mlx_vector_array_result_into_ok( + mlx_vector_array_result result); +mlx_string mlx_vector_array_result_into_err(mlx_vector_array_result result); + +#ifdef __cplusplus +} +#endif + +#endif // MLC_X_SHIM_RESULT_H \ No newline at end of file diff --git a/mlx-sys/src/shim/shim.h b/mlx-sys/src/shim/shim.h new file mode 100644 index 00000000..9f93745e --- /dev/null +++ b/mlx-sys/src/shim/shim.h @@ -0,0 +1,7 @@ +#ifndef MLX_C_SHIM_H +#define MLX_C_SHIM_H + +#include "result.h" +#include "closure.h" + +#endif // MLX_C_SHIM_H \ No newline at end of file