Skip to content

Commit

Permalink
remove ValueAndGrad, use value_and_grad and value_and_grad_with_paylo…
Browse files Browse the repository at this point in the history
…ad instead
  • Loading branch information
minghuaw committed Sep 4, 2024
1 parent c662dd8 commit b31fb94
Showing 1 changed file with 3 additions and 62 deletions.
65 changes: 3 additions & 62 deletions mlx-rs/src/transforms/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -215,69 +215,12 @@ pub fn value_and_grad<'a, F>(
) -> impl FnMut(&[Array]) -> Result<(Vec<Array>, Vec<Array>), Exception> + 'a
where
F: FnMut(&[Array]) -> Vec<Array> + 'a,
ArgNums: IntoOption<&'a [i32]>,
{
// refining_impl_trait is fine here because we have restricted the Args and Output types
// in the generics.
#[allow(refining_impl_trait)]
fn value_and_grad(
self,
argument_numbers: ArgNums,
) -> impl FnMut(&[Array]) -> Result<(Vec<Array>, Vec<Array>), Exception> + 'a {
let argument_numbers = argument_numbers.into_option().unwrap_or(&[0]);
build_value_and_gradient(self, argument_numbers)
}
}

impl<'a, F, T> ValueAndGrad<'a, (&[Array], T), (), (Vec<Array>, Vec<Array>)> for F
where
F: FnMut((&[Array], T)) -> Vec<Array> + 'a,
T: Copy,
{
/// The `argument_numbers` parameter is not used in this implementation.
#[allow(refining_impl_trait)]
fn value_and_grad(
self,
_argument_numbers: (),
) -> impl FnMut((&[Array], T)) -> Result<(Vec<Array>, Vec<Array>), Exception> + 'a {
let mut f = self;
move |(parameters, arrays): (&[Array], T)| -> Result<(Vec<Array>, Vec<Array>), Exception> {
let inner = |params: &[Array]| -> Vec<Array> { (f)((params, arrays)) };
let argument_numbers = (0..parameters.len() as i32).collect::<Vec<_>>();

let closure = Closure::new(inner);
let c_value_and_grad = unsafe {
try_catch_c_ptr_expr! {
mlx_sys::mlx_value_and_grad(
closure.as_ptr(),
argument_numbers.as_ptr(),
argument_numbers.len(),
)
}
};

let result = value_and_gradient(c_value_and_grad, parameters.iter())?;
Ok(result)
}
}
}

/// Returns a function which computes the value and gradient of `f`.
pub fn value_and_grad<'a, F, Args, ArgNums, Output>(
f: F,
argument_numbers: ArgNums,
) -> impl FnMut(Args) -> Result<Output, Exception> + 'a
where
F: ValueAndGrad<'a, Args, ArgNums, Output>,
{
f.value_and_grad(argument_numbers)
let argument_numbers = argument_numbers.into_option().unwrap_or(&[0]);
build_value_and_gradient(f, argument_numbers)
}

/// Returns a function that computes the gradient and result of `f`, computing the gradient with
/// respect to the first argument.
///
/// Note that this allows any parameters `<T>` s they will not be part of the gradient.
pub fn value_and_grad_with_payload<'a, F, Arr, T>(
pub fn value_and_grad_with_payload<'a, F, T>(
mut f: F,
) -> impl FnMut((&[Array], T)) -> Result<(Vec<Array>, Vec<Array>), Exception> + 'a
where
Expand All @@ -304,8 +247,6 @@ where
}
}



pub trait Grad<'a, Args, Output> {
fn grad(
self,
Expand Down

0 comments on commit b31fb94

Please sign in to comment.