Skip to content

Commit

Permalink
use ValueAndGrad trait instead of just value_and_grad fn
Browse files Browse the repository at this point in the history
  • Loading branch information
minghuaw committed Sep 4, 2024
1 parent 53198a3 commit c662dd8
Showing 1 changed file with 55 additions and 2 deletions.
57 changes: 55 additions & 2 deletions mlx-rs/src/transforms/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -215,9 +215,62 @@ pub fn value_and_grad<'a, F>(
) -> impl FnMut(&[Array]) -> Result<(Vec<Array>, Vec<Array>), Exception> + 'a
where
F: FnMut(&[Array]) -> Vec<Array> + 'a,
ArgNums: IntoOption<&'a [i32]>,
{
let argument_numbers = argument_numbers.into_option().unwrap_or(&[0]);
build_value_and_gradient(f, argument_numbers)
// refining_impl_trait is fine here because we have restricted the Args and Output types
// in the generics.
#[allow(refining_impl_trait)]
fn value_and_grad(
self,
argument_numbers: ArgNums,
) -> impl FnMut(&[Array]) -> Result<(Vec<Array>, Vec<Array>), Exception> + 'a {
let argument_numbers = argument_numbers.into_option().unwrap_or(&[0]);
build_value_and_gradient(self, argument_numbers)
}
}

impl<'a, F, T> ValueAndGrad<'a, (&[Array], T), (), (Vec<Array>, Vec<Array>)> for F
where
F: FnMut((&[Array], T)) -> Vec<Array> + 'a,
T: Copy,
{
/// The `argument_numbers` parameter is not used in this implementation.
#[allow(refining_impl_trait)]
fn value_and_grad(
self,
_argument_numbers: (),
) -> impl FnMut((&[Array], T)) -> Result<(Vec<Array>, Vec<Array>), Exception> + 'a {
let mut f = self;
move |(parameters, arrays): (&[Array], T)| -> Result<(Vec<Array>, Vec<Array>), Exception> {
let inner = |params: &[Array]| -> Vec<Array> { (f)((params, arrays)) };
let argument_numbers = (0..parameters.len() as i32).collect::<Vec<_>>();

let closure = Closure::new(inner);
let c_value_and_grad = unsafe {
try_catch_c_ptr_expr! {
mlx_sys::mlx_value_and_grad(
closure.as_ptr(),
argument_numbers.as_ptr(),
argument_numbers.len(),
)
}
};

let result = value_and_gradient(c_value_and_grad, parameters.iter())?;
Ok(result)
}
}
}

/// Returns a function which computes the value and gradient of `f`.
pub fn value_and_grad<'a, F, Args, ArgNums, Output>(
f: F,
argument_numbers: ArgNums,
) -> impl FnMut(Args) -> Result<Output, Exception> + 'a
where
F: ValueAndGrad<'a, Args, ArgNums, Output>,
{
f.value_and_grad(argument_numbers)
}

/// Returns a function that computes the gradient and result of `f`, computing the gradient with
Expand Down

0 comments on commit c662dd8

Please sign in to comment.