From c662dd8f298dfeef8bcb33dfbda7fe369774d516 Mon Sep 17 00:00:00 2001 From: minghuaw Date: Mon, 2 Sep 2024 20:40:14 -0700 Subject: [PATCH] use ValueAndGrad trait instead of just value_and_grad fn --- mlx-rs/src/transforms/mod.rs | 57 ++++++++++++++++++++++++++++++++++-- 1 file changed, 55 insertions(+), 2 deletions(-) diff --git a/mlx-rs/src/transforms/mod.rs b/mlx-rs/src/transforms/mod.rs index 972da553..3f4cbe5a 100644 --- a/mlx-rs/src/transforms/mod.rs +++ b/mlx-rs/src/transforms/mod.rs @@ -215,9 +215,62 @@ pub fn value_and_grad<'a, F>( ) -> impl FnMut(&[Array]) -> Result<(Vec, Vec), Exception> + 'a where F: FnMut(&[Array]) -> Vec + 'a, + ArgNums: IntoOption<&'a [i32]>, { - let argument_numbers = argument_numbers.into_option().unwrap_or(&[0]); - build_value_and_gradient(f, argument_numbers) + // refining_impl_trait is fine here because we have restricted the Args and Output types + // in the generics. + #[allow(refining_impl_trait)] + fn value_and_grad( + self, + argument_numbers: ArgNums, + ) -> impl FnMut(&[Array]) -> Result<(Vec, Vec), Exception> + 'a { + let argument_numbers = argument_numbers.into_option().unwrap_or(&[0]); + build_value_and_gradient(self, argument_numbers) + } +} + +impl<'a, F, T> ValueAndGrad<'a, (&[Array], T), (), (Vec, Vec)> for F +where + F: FnMut((&[Array], T)) -> Vec + 'a, + T: Copy, +{ + /// The `argument_numbers` parameter is not used in this implementation. + #[allow(refining_impl_trait)] + fn value_and_grad( + self, + _argument_numbers: (), + ) -> impl FnMut((&[Array], T)) -> Result<(Vec, Vec), Exception> + 'a { + let mut f = self; + move |(parameters, arrays): (&[Array], T)| -> Result<(Vec, Vec), Exception> { + let inner = |params: &[Array]| -> Vec { (f)((params, arrays)) }; + let argument_numbers = (0..parameters.len() as i32).collect::>(); + + let closure = Closure::new(inner); + let c_value_and_grad = unsafe { + try_catch_c_ptr_expr! { + mlx_sys::mlx_value_and_grad( + closure.as_ptr(), + argument_numbers.as_ptr(), + argument_numbers.len(), + ) + } + }; + + let result = value_and_gradient(c_value_and_grad, parameters.iter())?; + Ok(result) + } + } +} + +/// Returns a function which computes the value and gradient of `f`. +pub fn value_and_grad<'a, F, Args, ArgNums, Output>( + f: F, + argument_numbers: ArgNums, +) -> impl FnMut(Args) -> Result + 'a +where + F: ValueAndGrad<'a, Args, ArgNums, Output>, +{ + f.value_and_grad(argument_numbers) } /// Returns a function that computes the gradient and result of `f`, computing the gradient with