Initial attempt at implementing mlx-nn NN modules (#100)
* add Module and Sequential

* added some activation fn

* added compiled activation functions

* removed swift references

* added more activation functions

* added docs

* added remaining activation functions

* renamed clone to deep_clone to avoid confusion and changed to public

* changed Module trait to take ref

* cargo fmt

* make argument_numbers optional

* improve doc

* use ValueAndGrad trait instead of just value_and_grad fn

* remove ValueAndGrad, use value_and_grad and value_and_grad_with_payload instead

* add docs

* Module::update is complicated by ownership

* Trust synchronization on the cpp side (#102)

* remove mut ref and smart ptr on rust side

* use clone instead of mlx_retain

* removed OwnedOrRef

* changed Module trait to take ref

* cargo fmt

* use ValueAndGrad trait instead of just value_and_grad fn

* remove ValueAndGrad, use value_and_grad and value_and_grad_with_payload instead

* change OwnedOrRef to Cow

* Change to ValueAndGrad trait to allow owned closure

* added nested

* change generic T to V

* moved nested from mlx-rs to mlx-nn

* initial impl of value_and_grad with HashMap

* use generic instead of hard coded HashMap

* use Rc<str> to avoid expensive clone

* impl value_and_grad_with_hashmap

* initial impl of model_value_and_grad

* added mlx-nn-module crate and ModuleParameters derive macro

* cargo clippy, fix, & fmt

* move closure inner into the returned closure

* added as_trainable_nested_value

* initial attempt at impl Linear and Bilinear

* tested value_and_grad

* change setter fn for optional fields to have "with_" prefix

* derive ModuleParameters for Sequential

* cargo clippy, fix, fmt

* Trust synchronization on the cpp side (#102)

* remove mut ref and smart ptr on rust side

* use clone instead of mlx_retain

* removed OwnedOrRef

* chore(example): add automatic_differentiation example (#103)

* chore(mlx-c): update to 0.0.9 (#105)

* cargo clippy, fix, fmt

* add example in ModuleParameters doc

* fixed error in doc example

* added Conv2d and Conv3d

* cargo fmt

* added train() to Module trait & added docs to mlx-nn-module

* wrap number literals in `array!()` to avoid SIGSEGV (see the sketch after the commit list)

* added remaining docs

* cargo fmt

* impl SGD

* added unit test for Sgd

* change setters for optional configs

* rename Sgd to SGD

* Revert "rename Sgd to SGD"

This reverts commit dbf6052.

* added RmsProp optimizer

* rename Module::train to Module::training_mode to avoid confusion

* rename mod sequential to mod container

* added panics section in doc

* add mlx_nn::error::Error

* add dropout mod

* cargo clippy, fix, fmt

* add unit tests for dropout

* ported convolution unit tests from the swift binding

* ported unit tests for linear

* fixed bug and ported unit tests for activation modules

* fix sigsegv caused by mixing borrow and owned value

* cargo fmt

* fix problem that was caused by clippy --fix

* remove lint attr because the cause is already resolved

* Use custom c shim for fallible closure (#116)

* add support for closure with error in grads

* compiled fn that should fail somehow doesn't fail after first failure

* Revert "compiled fn that should fail somehow doesn't fail after first failure"

This reverts commit 623ced1.

* cargo clippy, fix, fmt

* returns an empty mlx_vector_array to avoid invalid memory ref

* propagate err with c binding fallible closure

* compile fallible fn

* rename *_with_error to *_fallible

* cargo fmt

* Restrict use of ScalarOrArray trait to operator impl and ClipBound (#108)

* restrict use of ScalarOrArray trait to operator impl and ClipBound

* fix doctest

* Allow using named `shape` parameter in `array!()` macro (#109)

* allow using array!() with shape

* cargo fmt

* fix error in unit test

* simple c shim fails

* value_and_grad is getting memory error even for success case

* use custom shim c code for fallible closure

* change fallible suffix to prefix and added fallible_value_and_grad_with_hashmap

* revert back to use Exception for mlx-nn errors

* use fallible version of compile

* add cross_entropy

* added binary_cross_entropy

* added remaining losses

* added losses unit tests

* generate builder impl with option_builder proc macro

* unified regular and fallible value_and_grad

* moved mlx-nn-module into mlx-rs

* initialize mnist example crate

* change Exception::from to Exception::custom

* renamed mod optimizer to mod optimizers and do not re-export losses

* change eval to take ref instead of mut ref

* fix lifetime issue

* initial nn example

* suppress clippy::module_inception

* cargo fmt

* added eval_params and async_eval_params

* allow taking ref of options

* added smoke tests

* remove commented code

* added missing `#[option_builder]`

* new attempt at builder pattern on Module

* attempt new optional args handling on cross entropy

* attempt builder pattern on RmsProp

* impl basic generate_builder

* moved to builder pattern for activation, losses and optimizer

* removed WithBias and added builder pattern for conv and linear

* cargo clippy and fmt

* added builder for dropout modules

* cargo clippy and fmt

* removed unused generic builder

* fixed wrong default due to derive

* fix error caused by mixing operator and arith functions

* generate Default impl if no mandatory field & hide internal macro from doc

* added doc for GenerateBuilder macro

* moved internal macros into separate crate

---------

Co-authored-by: David Chavez <david@dcvz.io>
minghuaw and dcvz authored Oct 20, 2024
1 parent f828af3 commit 9279d3d
Showing 58 changed files with 6,812 additions and 239 deletions.
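One bullet in the commit list above ("wrap number literals in `array!()` to avoid SIGSEGV") is worth a concrete illustration before the file-by-file diff. A minimal sketch, assuming the `array!` macro and the `Array` arithmetic operators that appear in the diffs below; the `scale` function is illustrative only, not part of the commit:

use mlx_rs::{array, Array};

/// Scale an array by a constant factor.
fn scale(x: &Array) -> Array {
    // Promoting the literal with `array!()` means the multiplication sees
    // two `Array` operands, sidestepping the SIGSEGV the commit message
    // describes when a bare number literal was mixed into the expression.
    x * array!(2.0)
}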
17 changes: 15 additions & 2 deletions Cargo.toml
@@ -1,10 +1,23 @@

 [workspace]
-members = ["mlx-macros", "mlx-sys", "mlx-rs", "mlx-nn"]
+members = [
+    "mlx-macros",
+    "mlx-sys",
+    "mlx-rs",
+    "mlx-nn",
+    "examples/*", "mlx-internal-macros",
+]

 resolver = "2"

 [workspace.dependencies]
+# workspace local dependencies
 mlx-sys = { version = "0.0.9", path = "mlx-sys" }
 mlx-macros = { version = "0.1.0", path = "mlx-macros" }
-mlx-rs = { version = "0.14.0", path = "mlx-rs" }
+mlx-internal-macros = { version = "0.1.0", path = "mlx-internal-macros" }
+mlx-rs = { version = "0.14.0", path = "mlx-rs" }
 mlx-nn = { version = "0.1.0", path = "mlx-nn" }
+
+# external dependencies
+thiserror = "1"
+float_eq = "1"
8 changes: 8 additions & 0 deletions examples/mnist/Cargo.toml
@@ -0,0 +1,8 @@
[package]
name = "mnist"
version = "0.1.0"
edition = "2021"

[dependencies]
mlx-rs.workspace = true
mlx-nn.workspace = true
71 changes: 71 additions & 0 deletions examples/mnist/src/main.rs
@@ -0,0 +1,71 @@
use mlx_nn::{
    losses::{CrossEntropy, LossReduction},
    module_value_and_grad,
    optimizers::Optimizer,
};
use mlx_rs::{
    array,
    error::Exception,
    module::{Module, ModuleParameters},
    transforms::eval_params,
    Array,
};

/// MLP model
mod mlp;

/// Retrieves MNIST dataset
mod mnist;

#[derive(Clone)]
struct Loader {}

impl Iterator for Loader {
    type Item = (Array, Array);

    fn next(&mut self) -> Option<Self::Item> {
        todo!()
    }
}

fn load_training_data() -> Result<Loader, Box<dyn std::error::Error>> {
    todo!()
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let num_layers = 3;
    let input_dim = 784;
    let hidden_dim = 256;
    let output_dim = 10;
    let lr = 1e-2;
    let num_epochs = 10;

    let loader = load_training_data()?;
    let mut model = mlp::Mlp::new(num_layers, input_dim, hidden_dim, output_dim)?;

    let cross_entropy = CrossEntropy::builder()
        .reduction(LossReduction::Mean)
        .build()?;
    let loss_fn = |model: &mlp::Mlp, (x, y): (&Array, &Array)| -> Result<Array, Exception> {
        let y_pred = model.forward(x)?;
        cross_entropy.apply(y_pred, y)
    };
    let mut loss_and_grad_fn = module_value_and_grad(loss_fn);

    let mut optimizer = mlx_nn::optimizers::Sgd::new(lr);

    for epoch in 0..num_epochs {
        let mut loss_sum = array!(0.0);
        for (x, y) in loader.clone() {
            let (loss, grad) = loss_and_grad_fn(&mut model, (&x, &y))?;
            optimizer.update(&mut model, grad);
            eval_params(model.parameters())?;

            loss_sum += loss;
        }

        println!("Epoch: {}, Loss sum: {}", epoch, loss_sum);
    }

    Ok(())
}
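The data pipeline above is stubbed out with `todo!()`. For reference, a hedged sketch of one way `Loader` could be backed once mnist.rs lands, assuming batches are pre-decoded into `(images, labels)` pairs of `Array`s; the `batches` field and `from_batches` helper are hypothetical and not part of the commit:

use mlx_rs::Array;

#[derive(Clone)]
struct Loader {
    // Hypothetical: pre-decoded (images, labels) batches for one epoch.
    batches: Vec<(Array, Array)>,
    index: usize,
}

impl Loader {
    /// Hypothetical constructor over pre-built batches.
    fn from_batches(batches: Vec<(Array, Array)>) -> Self {
        Self { batches, index: 0 }
    }
}

impl Iterator for Loader {
    type Item = (Array, Array);

    fn next(&mut self) -> Option<Self::Item> {
        // Yield the next batch until the epoch is exhausted.
        let item = self.batches.get(self.index).cloned();
        self.index += 1;
        item
    }
}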
50 changes: 50 additions & 0 deletions examples/mnist/src/mlp.rs
@@ -0,0 +1,50 @@
use mlx_nn::{macros::ModuleParameters, module::Param, Linear, Relu, Sequential};
use mlx_rs::{error::Exception, module::Module, Array};

#[derive(Debug, ModuleParameters)]
pub struct Mlp {
    #[param]
    pub layers: Param<Sequential>,
}

impl Module for Mlp {
    type Error = Exception;

    fn forward(&self, x: &Array) -> Result<Array, Self::Error> {
        self.layers.forward(x)
    }

    fn training_mode(&mut self, mode: bool) {
        self.layers.training_mode(mode);
    }
}

impl Mlp {
    pub fn new(
        num_layers: usize,
        input_dim: i32,
        hidden_dim: i32,
        output_dim: i32,
    ) -> Result<Self, Exception> {
        let mut layers = Sequential::new();

        // Add the input layer
        layers = layers
            .append(Linear::new(input_dim, hidden_dim)?)
            .append(Relu);

        // Add the hidden layers
        for _ in 1..num_layers {
            layers = layers
                .append(Linear::new(hidden_dim, hidden_dim)?)
                .append(Relu);
        }

        // Add the output layer
        layers = layers.append(Linear::new(hidden_dim, output_dim)?);

        Ok(Self {
            layers: Param::new(layers),
        })
    }
}
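As a quick usage note, the model above plugs straight into the `Module` trait. A minimal smoke-test sketch, assuming an input constructor such as `Array::zeros` (the exact constructor name at this commit is an assumption):

use mlx_rs::{error::Exception, module::Module, Array};

fn smoke_test() -> Result<(), Exception> {
    let model = Mlp::new(3, 784, 256, 10)?;
    // Assumed constructor; any [1, 784] f32 array would do here.
    let x = Array::zeros::<f32>(&[1, 784])?;
    let y = model.forward(&x)?;
    assert_eq!(y.shape(), &[1, 10]);
    Ok(())
}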
1 change: 1 addition & 0 deletions examples/mnist/src/mnist.rs
@@ -0,0 +1 @@
// TODO
13 changes: 13 additions & 0 deletions mlx-internal-macros/Cargo.toml
@@ -0,0 +1,13 @@
[package]
name = "mlx-internal-macros"
version = "0.1.0"
edition = "2021"

[lib]
proc-macro = true

[dependencies]
syn = { version = "2.0.60", features = ["full"] }
quote = "1.0"
darling = "0.20"
proc-macro2 = "1.0"
199 changes: 199 additions & 0 deletions mlx-internal-macros/src/generate_builder.rs
@@ -0,0 +1,199 @@
use darling::FromAttributes;
use quote::quote;
use syn::{Attribute, Ident, ItemStruct, Path};

#[derive(Debug, FromAttributes)]
#[darling(attributes(generate_builder))]
struct StructAttr {
    generate_build_fn: Option<bool>,
}

#[derive(Debug, FromAttributes)]
#[darling(attributes(optional))]
struct FieldAttr {
    default_value: Option<Path>,
    skip: Option<bool>,
}

fn attrs_contains_optional(attrs: &[Attribute]) -> bool {
    attrs.iter().any(|attr| attr.path().is_ident("optional"))
}

pub(crate) fn expand_generate_builder(
    input: &ItemStruct,
) -> Result<proc_macro2::TokenStream, Box<dyn std::error::Error>> {
    let generate_build_fn = StructAttr::from_attributes(&input.attrs)?
        .generate_build_fn
        .unwrap_or(true);
    let struct_ident = &input.ident;
    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();

    let mut optional_field_idents = Vec::new();
    let mut optional_field_types = Vec::new();
    let mut optional_field_defaults = Vec::new();
    let mut optional_field_skip = Vec::new();
    let mut mandatory_field_idents = Vec::new();
    let mut mandatory_field_types = Vec::new();

    for field in input.fields.iter() {
        if attrs_contains_optional(&field.attrs) {
            let field_attr = FieldAttr::from_attributes(&field.attrs)?;
            let skip = field_attr.skip.unwrap_or(false);
            if skip && generate_build_fn {
                return Err("Skip is not allowed when build function is generated".into());
            }
            optional_field_skip.push(skip);

            optional_field_idents.push(
                field
                    .ident
                    .as_ref()
                    .ok_or("Only named fields are supported")?,
            );
            optional_field_types.push(&field.ty);
            if generate_build_fn {
                optional_field_defaults.push(field_attr.default_value);
            }
        } else {
            mandatory_field_idents.push(&field.ident);
            mandatory_field_types.push(&field.ty);
        }
    }

    let builder_ident = Ident::new(&format!("{}Builder", struct_ident), struct_ident.span());
    let modified_optional_field_types = optional_field_types
        .iter()
        .zip(optional_field_skip.iter())
        .map(|(field_type, skip)| {
            if !skip {
                quote! { Option<#field_type> }
            } else {
                quote! { #field_type }
            }
        });

    let builder_struct_doc = format!("Builder for [`{}`]", struct_ident);
    let field_doc = format!("See [`{}`] for more details", struct_ident);
    let builder_struct = quote! {
        #[doc = #builder_struct_doc]
        #[derive(Debug, Clone, Default)]
        pub struct #builder_ident #ty_generics #where_clause {
            #(
                #[doc = #field_doc]
                pub #optional_field_idents: #modified_optional_field_types,
            )*
        }
    };

    let builder_new_doc = format!("Create a new [`{}`]", builder_ident);
    let struct_builder_doc = format!(
        "Create a new [`{}`] builder with the default values",
        struct_ident
    );

    let builder_init = quote! {
        impl #impl_generics #builder_ident #ty_generics #where_clause {
            #[doc = #builder_new_doc]
            pub fn new() -> Self {
                Self::default()
            }
        }

        impl #impl_generics #struct_ident #ty_generics #where_clause {
            #[doc = #struct_builder_doc]
            pub fn builder() -> #builder_ident #ty_generics {
                #builder_ident::new()
            }
        }
    };

    let builder_setter_docs = optional_field_idents
        .iter()
        .zip(optional_field_skip.iter())
        .filter_map(|(field_ident, skip)| {
            if !skip {
                Some(format!("Set the value of `{:?}`", field_ident))
            } else {
                None
            }
        });
    let filtered_optional_field_idents = optional_field_idents
        .iter()
        .zip(optional_field_skip.iter())
        .filter_map(|(field_ident, skip)| if !skip { Some(field_ident) } else { None });
    let filtered_optional_field_types = optional_field_types
        .iter()
        .zip(optional_field_skip.iter())
        .filter_map(|(field_type, skip)| if !skip { Some(field_type) } else { None });

    let builder_setters = quote! {
        impl #impl_generics #builder_ident #ty_generics #where_clause {
            #(
                #[doc = #builder_setter_docs]
                pub fn #filtered_optional_field_idents(mut self, value: impl Into<#filtered_optional_field_types>) -> Self {
                    self.#filtered_optional_field_idents = Some(value.into());
                    self
                }
            )*
        }
    };

    let builder_build = if generate_build_fn {
        let builder_build_doc = format!("Build a new [`{}`]", struct_ident);
        let struct_new_doc = format!("Create a new [`{}`] with default values", struct_ident);
        let optional_field_defaults: Vec<_> = optional_field_defaults
            .iter()
            .map(|default| {
                default
                    .clone()
                    .ok_or("Default value must be supplied to generate build function")
            })
            .collect::<Result<Vec<_>, _>>()?;

        quote! {
            impl #impl_generics #builder_ident #ty_generics #where_clause {
                #[doc = #builder_build_doc]
                pub fn build(self, #(#mandatory_field_idents: #mandatory_field_types),*) -> #struct_ident #ty_generics {
                    #struct_ident {
                        #(
                            #mandatory_field_idents,
                        )*
                        #(
                            #optional_field_idents: self.#optional_field_idents.unwrap_or_else(|| #optional_field_defaults),
                        )*
                    }
                }
            }

            impl #impl_generics #struct_ident #ty_generics #where_clause {
                #[doc = #struct_new_doc]
                pub fn new(#(#mandatory_field_idents: #mandatory_field_types),*) -> Self {
                    Self::builder().build(#(#mandatory_field_idents),*)
                }
            }
        }
    } else {
        quote! {}
    };

    // Only implement Default trait if no mandatory fields are present
    let default_impl = if mandatory_field_idents.is_empty() && generate_build_fn {
        quote! {
            impl #impl_generics Default for #struct_ident #ty_generics #where_clause {
                fn default() -> Self {
                    Self::new()
                }
            }
        }
    } else {
        quote! {}
    };

    Ok(quote! {
        #builder_struct
        #builder_init
        #builder_setters
        #builder_build
        #default_impl
    })
}
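To make the expansion above concrete, here is a hedged sketch of the derive in use, matching the `generate_builder`/`optional` attributes parsed by this file. The `GenerateBuilder` name comes from the commit list; the export path, and the `Dropout`/`DEFAULT_P` names, are illustrative assumptions rather than code from this commit:

use mlx_internal_macros::GenerateBuilder;

#[derive(Debug, Clone, GenerateBuilder)]
struct Dropout {
    /// Optional field with a default; `build()`/`new()` fall back to it.
    #[optional(default_value = Dropout::DEFAULT_P)]
    p: f32,
}

impl Dropout {
    const DEFAULT_P: f32 = 0.5;
}

fn demo() {
    // Generated API: Dropout::builder(), a `p` setter taking impl Into<f32>,
    // and build() with no arguments since there are no mandatory fields.
    let custom = Dropout::builder().p(0.2f32).build();
    // With no mandatory fields, the macro also emits Default for Dropout.
    let default = Dropout::default();
    assert_eq!(custom.p, 0.2f32);
    assert_eq!(default.p, Dropout::DEFAULT_P);
}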