Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Initial attempt at implementing
mlx-nn
NN modules (#100)
* add Module and Sequential * added some activation fn * added compiled activation functions * removed swift references * added more activation functions * added docs * added remaining activation functions * renamed clone to deep_clone to avoid confusion and changed to public * changed Module trait to take ref * cargo fmt * make argument_numbers optional * improve doc * use ValueAndGrad trait instead of just value_and_grad fn * remove ValueAndGrad, use value_and_grad and value_and_grad_with_payload instead * add docs * Module::update is complicated by ownership * Trust synchronization on the cpp side (#102) * remove mut ref and smart ptr on rust side * use clone instead of mlx_retain * removed OwnedOrRef * changed Module trait to take ref * cargo fmt * use ValueAndGrad trait instead of just value_and_grad fn * remove ValueAndGrad, use value_and_grad and value_and_grad_with_payload instead * change OwnedOrRef to Cow * Change to ValueAndGrad trait to allow owned closure * added nested * change generic T to V * moved nested from mlx-rs to mlx-nn * initial impl of value_and_grad with HashMap * use generic instead of hard coded HashMap * use Rc<str> to avoid expensive clone * impl value_and_grad_with_hashmap * initial impl of model_value_and_grad * added mlx-nn-module crate and ModuleParameters derive macro * cargo clippy, fix, & fmt * move closure inner into the returned closure * added as_trainable_nested_value * initial attempt at impl Linear and Bilinear * tested value_and_grad * change setter fn for optional fields to have "with_" prefix * derive ModuleParameters for Sequential * cargo clippy, fix, fmt * Trust synchronization on the cpp side (#102) * remove mut ref and smart ptr on rust side * use clone instead of mlx_retain * removed OwnedOrRef * chore(example): add automatic_differentiation example (#103) * chore(mlx-c): update to 0.0.9 (#105) * cargo clippy, fix, fmt * add example in ModuleParameters doc * fixed error in doc example * added Conv2d and Conv3d * cargo fmt * added train() to Module trait & added docs to mlx-nn-module * wrap number literals in `array!()` to avoid SIGSEGV * added remaining docs * cargo fmt * impl SGD * added unit test for Sgd * change setters for optional configs * rename Sgd to SGD * Revert "rename Sgd to SGD" This reverts commit dbf6052. * added RmsProp optimizer * rename Module::train to Module::training_mode to avoid confusion * rename mod sequential to mod container * added panics section in doc * add mlx_nn::error::Error * add dropout mod * cargo clippy, fix, fmt * add unit tests for dropout * ported convolution unit tests from the swift binding * ported unit tests for linear * fixed bug and ported unit tests for activation modules * fix sigsegv caused by mixing borrow and owned value * cargo fmt * fix problem that was caused by clippy --fix * remove lint attr because the cause is already resolved * Use custom c shim for fallible closure (#116) * add support for closure with error in grads * compiled fn that should fail somehow doesn't fail after first failure * Revert "compiled fn that should fail somehow doesn't fail after first failure" This reverts commit 623ced1. * cargo clippy, fix, fmt * returns an empty mlx_vector_array to avoid invalid memory ref * propagate err with c binding fallible closure * compile fallible fn * rename *_with_error to *_fallible * cargo fmt * Restrict use of ScalarOrArray trait to operator impl and ClipBound (#108) * restrict use of ScalarOrArray trait to operator impl and ClipBound * fix doctest * Allow using named `shape` parameter in `array!()` macro (#109) * allow using array!() with shape * cargo fmt * fix error in unit test * simple c shim fails * value_and_grad is getting memory error even for success case * use custom shim c code for fallible closure * change fallible suffix to prefix and added fallible_value_and_grad_with_hashmap * revert back to use Exception for mlx-nn errors * use fallible version of compile * add cross_entropy * added binary_cross_entropy * added remaining losses * added losses unit tests * generate builder impl with option_builder proc macro * unified regular and fallible value_and_grad * moved mlx-nn-module into mlx-rs * initialize mnist example crate * change Exception::from to Excepion::custom * renamed mod optimizer to mod optimizers and do not re-export losses * change eval to take ref instead of mut ref * fix lifetime issue * initial nn example * suppress clippy::module_inception * cargo fmt * added eval_params and async_eval_params * allow taking ref of options * added smoke tests * remove commented code * added missing `#[option_builder]` * new attempt at builder pattern on Module * attempt new optional args handling on cross entropy * attemp builder pattern on RmsProp * impl basic generate_builder * moved to builder pattern for activation, losses and optimizer * removed WithBias and added builder pattern for conv and linear * cargo clippy and fmt * added builder for dropout modules * cargo clippy and fmt * removed unused generic builder * fixed wrong default due to derive * fix error caused by mixing operator and arith functions * generate Default impl if no mandatory field & hide internal macro from doc * added doc for GenerateBuilder macro * moved internal macros into separate crate --------- Co-authored-by: David Chavez <david@dcvz.io>
- Loading branch information