
Initial attempt at implementing mlx-nn NN modules #100

Merged — 114 commits merged into main from api/mlx-nn-impl on Oct 20, 2024

Conversation

minghuaw (Collaborator) commented Aug 2, 2024

This is my initial attempt at implementing the neural net modules and optimizers. The following are included in this PR:

  • nn/activation
  • nn/convolution
  • nn/dropout
  • nn/linear
  • nn/sequential
  • nn/losses/
  • optimizer/sgd
  • optimizer/rmsprop

@minghuaw minghuaw marked this pull request as draft August 2, 2024 19:34
@minghuaw minghuaw changed the title from "Implement neural net layers" to "Implement activation functions" Aug 11, 2024
@minghuaw minghuaw changed the title from "Implement activation functions" to "Implement mlx-nn activation functions" Aug 11, 2024
@minghuaw minghuaw marked this pull request as ready for review August 11, 2024 01:36
@minghuaw minghuaw requested a review from dcvz August 11, 2024 01:36
minghuaw (Collaborator, Author) commented Aug 11, 2024

This is just an initial attempt at implementing some of the neural net components. I think some feedback on the overall API design and ergonomics would be nice, @dcvz.

minghuaw (Collaborator, Author):
Ended up changing the Module trait to take &Array. The only cost (so far) is that an empty Sequential would end up deep-cloning the input array, but this should make overall usage more flexible, as most (if not all) ops only require a reference to the input.
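
For reference, a minimal sketch of a trait shape along those lines, with the forward pass borrowing its input; the trait in this PR may differ, and the import path and method names here are assumptions for illustration:

```rust
use mlx_rs::Array; // assuming the core crate exposes its array type at the root

// Hedged sketch of a Module trait whose forward pass takes `&Array`.
pub trait Module {
    /// Error type returned by the forward pass (the C FFI currently surfaces
    /// an mlx Exception, so custom error types are limited in practice).
    type Error;

    /// Forward pass; borrowing the input lets most ops avoid cloning it.
    fn forward(&self, x: &Array) -> Result<Array, Self::Error>;

    /// Switch training-specific behaviour (e.g. dropout) on or off.
    fn training_mode(&mut self, mode: bool);
}
```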

@minghuaw minghuaw marked this pull request as draft September 4, 2024 18:12
@@ -142,3 +145,63 @@ pub fn generate_test_cases(input: TokenStream) -> TokenStream {

TokenStream::from(tests)
}

/// Derive the `ModuleParameters` trait for a struct. Mark a field with `#[param]` attribute to
Contributor:

Do you think we'll eventually want to have a derive(Module)?

Collaborator (Author):

I don't think so, because the Module implementation would include the forward method. The associated error type was an attempt to allow the user to have more customized errors, but it is currently limited because the C FFI can only take an mlx Exception. One thing we could change in the Module trait is to provide a default no-op training_mode implementation, but I didn't do this because I felt users might forget about it if it's not mandatory.
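
To illustrate the split being described (parameter plumbing derived, forward written by hand), something roughly like the sketch below; the layer, its fields, the import path for the derive, and the stand-in forward body are all assumptions made up for the example:

```rust
use mlx_macros::ModuleParameters; // assumed export of the derive in mlx-macros
use mlx_rs::{error::Exception, Array};

// Hypothetical layer: parameters come from the derive, forward is manual.
#[derive(ModuleParameters)]
pub struct Linear {
    #[param]
    weight: Array,
    #[param]
    bias: Option<Array>,
}

impl Module for Linear {
    type Error = Exception;

    fn forward(&self, x: &Array) -> Result<Array, Self::Error> {
        // Stand-in body for the sketch; the real layer would do matmul + bias.
        Ok(x.clone())
    }

    fn training_mode(&mut self, _mode: bool) {
        // Linear has no train/eval distinction, so this is a no-op.
    }
}
```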

/// # Panics
///
/// Panics if `alpha` is negative.
pub fn with_alpha(mut self, alpha: impl Into<Option<f32>>) -> Self {
Contributor:

Should we type this so it can't be negative? I'm always a bit torn on that kind of typing.

Collaborator (Author):

Is there an existing crate that does this for f32 or should we add a new type?
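
If we go the newtype route, a tiny hand-rolled wrapper would be enough; this is a hypothetical sketch, not an existing crate:

```rust
// Hypothetical newtype: rejects negative (and NaN) values at construction,
// so `with_alpha` could take it instead of panicking.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct NonNegativeF32(f32);

impl NonNegativeF32 {
    pub fn new(value: f32) -> Option<Self> {
        if value >= 0.0 { Some(Self(value)) } else { None }
    }

    pub fn get(self) -> f32 {
        self.0
    }
}
```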

Comment on lines 78 to 79
// let v = alpha * state + (1 - alpha) * square(gradient)
// return (parameter - learningRate * gradient / (sqrt(v) + eps), v)
Contributor:

We can get rid of this, right?

Collaborator (Author):

That was me trying to make sure the math is correct, but yeah, we could get rid of this.
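
For anyone following along, the commented-out pseudocode is the usual RMSprop step; a scalar Rust rendering of the same math (the real optimizer applies this element-wise to Arrays):

```rust
// Scalar illustration of the update in the comment above.
fn rmsprop_step(
    parameter: f32,
    gradient: f32,
    state: f32,
    learning_rate: f32,
    alpha: f32,
    eps: f32,
) -> (f32, f32) {
    // v = alpha * state + (1 - alpha) * gradient^2
    let v = alpha * state + (1.0 - alpha) * gradient * gradient;
    // parameter - learning_rate * gradient / (sqrt(v) + eps)
    let new_parameter = parameter - learning_rate * gradient / (v.sqrt() + eps);
    (new_parameter, v)
}
```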

/// A custom type to indicate whether a `Module` should include a bias or not.
/// Default to `Yes`.
#[derive(Debug, Clone, Copy, Default)]
pub enum WithBias {
Contributor:

It will be nice to eventually be able to define bias as Tensor<Shape>. I wonder if there's some way we might already be able to do something similar with Array.

Collaborator (Author):

This is a part that I'm not too sure about. On the one hand, this is not consistent with how we handle optional args in other Modules or functions; on the other hand, the bias would either be zeros or random::uniform. But we could probably get rid of this and just define an associated pub const for the conv layers and linear layers (these are the only Modules that use this type right now).
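
A rough sketch of that associated-const direction (the layer and const names are illustrative only, not the crate's actual API):

```rust
// Instead of a WithBias enum, each layer that takes an optional bias could
// expose its default as an associated const.
pub struct Conv1d;

impl Conv1d {
    /// Whether a bias is created when the caller doesn't specify one.
    pub const DEFAULT_WITH_BIAS: bool = true;
}
```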

dcvz (Contributor) left a comment:

Looks good overall! Let's align on the optionals so we can have one way of doing things, but I like the direction for them too.

I have a gut feeling we may have a better way to do some of the utils for the types we're using, but I'm happy to try to improve that once we start writing more examples and see where it could be more ergonomic.


/// Optional parameters for the `Conv1d` module.
#[derive(Debug, Clone, Default)]
pub struct Conv1dBuilder {
Collaborator (Author):

@dcvz What about this kind of builder pattern? We could get rid of the WithBias type this way, and we could probably apply it to all Modules that take optional args. For cross_entropy, we could make it a struct and then apply this approach?

minghuaw (Collaborator, Author) commented Oct 15, 2024:

Or we could have two different ways of handling optional args: use this builder pattern on all Module structs, and keep functions like cross_entropy as they are. I've added a builder-pattern impl of cross entropy; see that part for more details.
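
As a hedged sketch of the builder idea, where unset fields fall back to defaults at build time; the fields, default values, and the simplified Conv1d placeholder are assumptions for illustration, not the crate's actual API:

```rust
/// Simplified placeholder; the real Conv1d holds weight/bias Arrays.
pub struct Conv1d {
    pub stride: i32,
    pub padding: i32,
    pub with_bias: bool,
}

/// Optional parameters collected by a builder instead of a WithBias-style type.
#[derive(Debug, Clone, Default)]
pub struct Conv1dBuilder {
    stride: Option<i32>,
    padding: Option<i32>,
    with_bias: Option<bool>,
}

impl Conv1dBuilder {
    pub fn with_bias(mut self, with_bias: impl Into<Option<bool>>) -> Self {
        self.with_bias = with_bias.into();
        self
    }

    /// Unset fields fall back to their defaults.
    pub fn build(self) -> Conv1d {
        Conv1d {
            stride: self.stride.unwrap_or(1),
            padding: self.padding.unwrap_or(0),
            with_bias: self.with_bias.unwrap_or(true),
        }
    }
}
```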

struct TestStruct {
    #[optional(default_value = TestStruct::DEFAULT_OPT_FIELD_1)]
    opt_field_1: i32,
    #[optional(default_value = TestStruct::DEFAULT_OPT_FIELD_2)]
Contributor:

Do you think it's still useful to have consts for defaults here? The fact that we can annotate means we can quickly see what those defaults are going to be without the indirection. Or are you worried about perf?

Contributor:

Also because this is a macro, IDEs won't always be able to CMD+click into the definition.
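
To make the trade-off concrete, this is roughly what the two default styles reduce to once the macro has expanded; hand-written here with no macro involved, and the values are placeholders, not the crate's real defaults:

```rust
struct TestStruct {
    opt_field_1: i32,
    opt_field_2: i32,
}

impl TestStruct {
    const DEFAULT_OPT_FIELD_1: i32 = 1; // placeholder value
}

#[derive(Default)]
struct TestStructBuilder {
    opt_field_1: Option<i32>,
    opt_field_2: Option<i32>,
}

impl TestStructBuilder {
    fn build(self) -> TestStruct {
        TestStruct {
            // const indirection vs. a literal written at the annotation site;
            // either way the builder falls back to the default when unset.
            opt_field_1: self.opt_field_1.unwrap_or(TestStruct::DEFAULT_OPT_FIELD_1),
            opt_field_2: self.opt_field_2.unwrap_or(2),
        }
    }
}
```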

mlx-nn/src/activation.rs — outdated, resolved
mlx-macros/src/lib.rs — outdated, resolved
@minghuaw minghuaw merged commit 9279d3d into main Oct 20, 2024
3 checks passed
@minghuaw minghuaw deleted the api/mlx-nn-impl branch October 20, 2024 21:00