Skip to content

Commit

Permalink
chore(example): add automatic_differentiation example (#103)
Browse files Browse the repository at this point in the history
  • Loading branch information
dcvz authored Sep 4, 2024
1 parent 3f62189 commit 4b01ce6
Showing 1 changed file with 20 additions and 0 deletions.
20 changes: 20 additions & 0 deletions mlx-rs/examples/tutorial.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use mlx_rs::transforms::grad;
use mlx_rs::{Array, Dtype};

fn scalar_basics() {
Expand Down Expand Up @@ -71,7 +72,26 @@ fn array_basics() {
println!("{}", z); // implicit evaluation
}

fn automatic_differentiation() {
fn f(x: &Array) -> Array {
x.square()
}

fn calculate_grad(func: impl Fn(&Array) -> Array, arg: &Array) -> Array {
grad(&func, &[0])(arg).unwrap()
}

let x = Array::from(1.5);

let mut dfdx = calculate_grad(f, &x);

Check warning on line 86 in mlx-rs/examples/tutorial.rs

View workflow job for this annotation

GitHub Actions / tests (stable)

variable does not need to be mutable

Check warning on line 86 in mlx-rs/examples/tutorial.rs

View workflow job for this annotation

GitHub Actions / tests (1.75.0)

variable does not need to be mutable
assert_eq!(dfdx.item::<f32>(), 2.0 * 1.5);

let mut dfdx2 = calculate_grad(|args| calculate_grad(f, args), &x);

Check warning on line 89 in mlx-rs/examples/tutorial.rs

View workflow job for this annotation

GitHub Actions / tests (stable)

variable does not need to be mutable

Check warning on line 89 in mlx-rs/examples/tutorial.rs

View workflow job for this annotation

GitHub Actions / tests (1.75.0)

variable does not need to be mutable
assert_eq!(dfdx2.item::<f32>(), 2.0);
}

fn main() {
scalar_basics();
array_basics();
automatic_differentiation();
}

0 comments on commit 4b01ce6

Please sign in to comment.