diff --git a/mlx-rs/src/transforms/mod.rs b/mlx-rs/src/transforms/mod.rs index 3f4cbe5a..73eb4b0a 100644 --- a/mlx-rs/src/transforms/mod.rs +++ b/mlx-rs/src/transforms/mod.rs @@ -215,69 +215,12 @@ pub fn value_and_grad<'a, F>( ) -> impl FnMut(&[Array]) -> Result<(Vec, Vec), Exception> + 'a where F: FnMut(&[Array]) -> Vec + 'a, - ArgNums: IntoOption<&'a [i32]>, { - // refining_impl_trait is fine here because we have restricted the Args and Output types - // in the generics. - #[allow(refining_impl_trait)] - fn value_and_grad( - self, - argument_numbers: ArgNums, - ) -> impl FnMut(&[Array]) -> Result<(Vec, Vec), Exception> + 'a { - let argument_numbers = argument_numbers.into_option().unwrap_or(&[0]); - build_value_and_gradient(self, argument_numbers) - } -} - -impl<'a, F, T> ValueAndGrad<'a, (&[Array], T), (), (Vec, Vec)> for F -where - F: FnMut((&[Array], T)) -> Vec + 'a, - T: Copy, -{ - /// The `argument_numbers` parameter is not used in this implementation. - #[allow(refining_impl_trait)] - fn value_and_grad( - self, - _argument_numbers: (), - ) -> impl FnMut((&[Array], T)) -> Result<(Vec, Vec), Exception> + 'a { - let mut f = self; - move |(parameters, arrays): (&[Array], T)| -> Result<(Vec, Vec), Exception> { - let inner = |params: &[Array]| -> Vec { (f)((params, arrays)) }; - let argument_numbers = (0..parameters.len() as i32).collect::>(); - - let closure = Closure::new(inner); - let c_value_and_grad = unsafe { - try_catch_c_ptr_expr! { - mlx_sys::mlx_value_and_grad( - closure.as_ptr(), - argument_numbers.as_ptr(), - argument_numbers.len(), - ) - } - }; - - let result = value_and_gradient(c_value_and_grad, parameters.iter())?; - Ok(result) - } - } -} - -/// Returns a function which computes the value and gradient of `f`. -pub fn value_and_grad<'a, F, Args, ArgNums, Output>( - f: F, - argument_numbers: ArgNums, -) -> impl FnMut(Args) -> Result + 'a -where - F: ValueAndGrad<'a, Args, ArgNums, Output>, -{ - f.value_and_grad(argument_numbers) + let argument_numbers = argument_numbers.into_option().unwrap_or(&[0]); + build_value_and_gradient(f, argument_numbers) } -/// Returns a function that computes the gradient and result of `f`, computing the gradient with -/// respect to the first argument. -/// -/// Note that this allows any parameters `` s they will not be part of the gradient. -pub fn value_and_grad_with_payload<'a, F, Arr, T>( +pub fn value_and_grad_with_payload<'a, F, T>( mut f: F, ) -> impl FnMut((&[Array], T)) -> Result<(Vec, Vec), Exception> + 'a where @@ -304,8 +247,6 @@ where } } - - pub trait Grad<'a, Args, Output> { fn grad( self,