Initial attempt at implementing mlx-nn
NN modules
#100
Conversation
mlx-nn
activation functions
This is just an initial attempt at implementing some of the neural net components. I think some feedback on the overall API design and ergonomics would be nice @dcvz

Ended up changing:
* remove mut ref and smart ptr on rust side
* use clone instead of mlx_retain
* removed OwnedOrRef
@@ -142,3 +145,63 @@ pub fn generate_test_cases(input: TokenStream) -> TokenStream {
    TokenStream::from(tests)
}

/// Derive the `ModuleParameters` trait for a struct. Mark a field with `#[param]` attribute to
Do you think we'll eventually want to have a `derive(Module)`?
I don't think so, because the `Module` implementation would include the `forward` method. The associated error type was an attempt to allow users to have more customized errors, but it's currently limited because the C FFI can only take an mlx `Exception`. One thing I think we could change in the `Module` trait is to provide a default no-op `training_mode` implementation, but I didn't do this because I felt users might forget to implement it if it's not mandatory.
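For reference, a minimal sketch of the trait shape this discussion implies (the `mlx_rs::Array` and `mlx_rs::error::Exception` paths and all signatures here are assumptions for illustration, not the PR's actual code):

use mlx_rs::error::Exception;
use mlx_rs::Array;

/// Hypothetical sketch of the `Module` trait as described above: a
/// required `forward` with an associated error type, plus a
/// `training_mode` hook shown with the optional default no-op body.
pub trait Module {
    /// Constrained to wrap the mlx `Exception`, since the C FFI can
    /// only surface that error type today.
    type Error: From<Exception>;

    fn forward(&mut self, x: &Array) -> Result<Array, Self::Error>;

    /// The default no-op floated above; keeping it mandatory instead
    /// forces implementers to consider training behavior.
    fn training_mode(&mut self, _mode: bool) {}
}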
mlx-nn/src/optimizers/rmsprop.rs
Outdated
/// # Panics
///
/// Panics if `alpha` is negative.
pub fn with_alpha(mut self, alpha: impl Into<Option<f32>>) -> Self {
Should we type this so it can't be negative? I'm always a bit torn on that kind of typing.
Is there an existing crate that does this for `f32`, or should we add a new type?
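If a new type is the route we take, it could be as small as this sketch (`NonNegativeF32` and its API are hypothetical names for illustration):

/// Hypothetical newtype that guarantees non-negativity at construction,
/// so `with_alpha` wouldn't need to document a panic at all.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct NonNegativeF32(f32);

impl NonNegativeF32 {
    /// Returns `None` for negative values (and NaN, which fails `>=`).
    pub fn new(value: f32) -> Option<Self> {
        if value >= 0.0 { Some(Self(value)) } else { None }
    }

    pub fn get(self) -> f32 {
        self.0
    }
}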
mlx-nn/src/optimizers/rmsprop.rs
Outdated
// let v = alpha * state + (1 - alpha) * square(gradient)
// return (parameter - learningRate * gradient / (sqrt(v) + eps), v)
We can get rid of this, right?
That was me trying to make sure the math is correct, but yeah, we could get rid of this.
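For the record, here is the update that comment was double-checking, written out as a self-contained scalar sketch (plain `f32` rather than the mlx array API):

/// Scalar form of the RMSProp update quoted above; `state` is the
/// running average of squared gradients (the `v` in the comment).
fn rmsprop_step(
    parameter: f32,
    gradient: f32,
    state: f32,
    learning_rate: f32,
    alpha: f32,
    eps: f32,
) -> (f32, f32) {
    // v = alpha * state + (1 - alpha) * square(gradient)
    let v = alpha * state + (1.0 - alpha) * gradient * gradient;
    // (parameter - learning_rate * gradient / (sqrt(v) + eps), v)
    (parameter - learning_rate * gradient / (v.sqrt() + eps), v)
}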
mlx-nn/src/utils.rs
Outdated
/// A custom type to indicate whether a `Module` should include a bias or not.
/// Defaults to `Yes`.
#[derive(Debug, Clone, Copy, Default)]
pub enum WithBias {
It will be nice to eventually be able to define bias as `Tensor<Shape>`... I wonder if there's some way we might already be able to do something similar with `Array`.
This is a part that I'm not too sure about. On the one hand, this is not consistent with how we handle optional args in other `Module`s or functions; on the other hand, the bias would either be `zeros` or `random::uniform`. But we could probably get rid of this and just define an associated `pub const` for the conv layers and linear layers (these are the only `Module`s that use this type right now).
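A sketch of that associated-const alternative (hypothetical names and layout, not the PR's actual code):

use mlx_rs::Array;

/// Sketch: the layer stores `Option<Array>` directly instead of going
/// through the `WithBias` enum.
pub struct Linear {
    pub weight: Array,
    pub bias: Option<Array>,
}

impl Linear {
    /// Hypothetical associated const standing in for `WithBias::Yes`
    /// as the default.
    pub const DEFAULT_WITH_BIAS: bool = true;
}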
Looks good overall! Let's align on the optionals so we can have one way of doing things, but I like the direction for them too.
I have a gut feeling we may have a better way to do some of the utils for the types we're using, but I'm happy to try and improve that once we start writing more examples and see where it could be more ergonomic.
/// Optional parameters for the `Conv1d` module.
#[derive(Debug, Clone, Default)]
pub struct Conv1dBuilder {
@dcvz What about this kind of builder pattern? We could get rid of the `WithBias` type this way, and we could probably apply it to all `Module`s that take optional args. For `cross_entropy`, we could make it a struct and then apply this approach?
Or we could have two different ways of handling optional args: use this builder pattern on all `Module` structs, and keep functions like `cross_entropy` as they are? (Edit: added a builder pattern impl of cross entropy, see that part for more details.)
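Condensed, the builder idea under discussion looks roughly like this (field names and defaults are illustrative, not the PR's exact code):

/// Illustrative builder: optional arguments accumulate on the builder,
/// required arguments stay on the final `build` call.
#[derive(Debug, Clone, Default)]
pub struct Conv1dBuilder {
    stride: Option<i32>,
    with_bias: Option<bool>,
}

impl Conv1dBuilder {
    pub fn stride(mut self, stride: i32) -> Self {
        self.stride = Some(stride);
        self
    }

    pub fn with_bias(mut self, with_bias: bool) -> Self {
        self.with_bias = Some(with_bias);
        self
    }

    pub fn build(self, input_dims: i32, output_dims: i32, kernel_size: i32) -> Conv1d {
        Conv1d {
            input_dims,
            output_dims,
            kernel_size,
            // unset optionals fall back to their defaults here
            stride: self.stride.unwrap_or(1),
            with_bias: self.with_bias.unwrap_or(true),
        }
    }
}

pub struct Conv1d {
    pub input_dims: i32,
    pub output_dims: i32,
    pub kernel_size: i32,
    pub stride: i32,
    pub with_bias: bool,
}

Usage would then read like `Conv1dBuilder::default().stride(2).build(16, 32, 3)`, which keeps one call site per optional and avoids flag enums like `WithBias` entirely.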
struct TestStruct {
    #[optional(default_value = TestStruct::DEFAULT_OPT_FIELD_1)]
    opt_field_1: i32,
    #[optional(default_value = TestStruct::DEFAULT_OPT_FIELD_2)]
Do you think it's still useful to have consts for defaults here? The fact that we can annotate means we can quickly see what those defaults are going to be without the indirection. Or are you worried about perf?
Also because this is a macro, IDEs won't always be able to CMD+click into the definition.
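Side by side, the two styles being weighed (hypothetical values; the derive machinery itself is omitted):

struct TestStruct;

impl TestStruct {
    // Style A: the default lives behind a named const. One extra hop
    // to read the value, but it's a plain Rust item IDEs can navigate.
    const DEFAULT_OPT_FIELD_1: i32 = 10;
}

// Style B: an inline literal in the attribute. The default is visible
// right at the field, but exists only inside the macro's input:
//
//     #[optional(default_value = 10)]
//     opt_field_1: i32,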
This is my initial attempt at implementing the neural net modules and optimizers. The following are included in this PR:
nn/activation
nn/convolution
nn/dropout
nn/linear
nn/sequential
nn/losses/
optimizer/sgd
optimizer/rmsprop