
rename Module::train to Module::training_mode to avoid confusion
minghuaw committed Sep 18, 2024
1 parent 9e25f62 commit 6c9cc26
Showing 6 changed files with 30 additions and 28 deletions.
2 changes: 1 addition & 1 deletion mlx-nn-module/src/module.rs
@@ -32,7 +32,7 @@ pub trait Module: ModuleParameters {
/// Training mode only applies to certain layers. For example, a dropout layer applies a random
/// mask in training mode, but is the identity in evaluation mode. Implementations of nested
/// modules should propagate the training mode to all child modules.
-    fn train(&mut self, mode: bool);
+    fn training_mode(&mut self, mode: bool);
}

/// Trait for accessing and updating module parameters.
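
For context, the contract in the doc comment above can be illustrated with a small standalone sketch. The trait and types below are simplified stand-ins for illustration only, not the actual mlx-rs API:

```rust
// Simplified stand-in for the mlx-nn `Module` trait; only the renamed
// method is shown.
trait Module {
    fn training_mode(&mut self, mode: bool);
}

// A dropout-like leaf layer whose behavior depends on the flag.
struct Dropout {
    training: bool,
}

impl Module for Dropout {
    fn training_mode(&mut self, mode: bool) {
        self.training = mode;
    }
}

// A nested module propagates the flag to all of its children.
struct Block {
    dropout: Dropout,
}

impl Module for Block {
    fn training_mode(&mut self, mode: bool) {
        self.dropout.training_mode(mode);
    }
}

fn main() {
    let mut block = Block { dropout: Dropout { training: true } };
    block.training_mode(false); // switch the whole tree to evaluation mode
    assert!(!block.dropout.training);
}
```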
38 changes: 19 additions & 19 deletions mlx-nn/src/activation.rs
@@ -293,7 +293,7 @@ impl Module for Glu {
glu(x, self.axis)
}

-    fn train(&mut self, _: bool) {}
+    fn training_mode(&mut self, _: bool) {}
}

/// Applies the element-wise logistic sigmoid.
@@ -330,7 +330,7 @@ impl Module for Sigmoid {
Ok(sigmoid(x))
}

-    fn train(&mut self, _: bool) {}
+    fn training_mode(&mut self, _: bool) {}
}

/// Applies the Mish function, element-wise.
@@ -367,7 +367,7 @@ impl Module for Mish {
mish(x)
}

-    fn train(&mut self, _: bool) {}
+    fn training_mode(&mut self, _: bool) {}
}

/// Applies the Rectified Linear Unit.
@@ -400,7 +400,7 @@ impl Module for Relu {
relu(x)
}

-    fn train(&mut self, _: bool) {}
+    fn training_mode(&mut self, _: bool) {}
}

/// Applies the Leaky Rectified Linear Unit.
@@ -440,7 +440,7 @@ impl Module for LeakyReLU {
leaky_relu(x, self.neg_slope)
}

-    fn train(&mut self, _: bool) {}
+    fn training_mode(&mut self, _: bool) {}
}

/// Applies the Rectified Linear Unit 6.
@@ -473,7 +473,7 @@ impl Module for Relu6 {
relu6(x)
}

-    fn train(&mut self, _: bool) {}
+    fn training_mode(&mut self, _: bool) {}
}

/// Applies the Softmax function.
@@ -516,7 +516,7 @@ impl Module for Softmax {
}
}

-    fn train(&mut self, _: bool) {}
+    fn training_mode(&mut self, _: bool) {}
}

/// Applies the Softplus function.
@@ -549,7 +549,7 @@ impl Module for Softplus {
softplus(x)
}

-    fn train(&mut self, _: bool) {}
+    fn training_mode(&mut self, _: bool) {}
}

/// Applies the Softsign function.
@@ -582,7 +582,7 @@ impl Module for Softsign {
softsign(x)
}

-    fn train(&mut self, _: bool) {}
+    fn training_mode(&mut self, _: bool) {}
}

/// Applies the Continuously Differentiable Exponential Linear Unit.
@@ -623,7 +623,7 @@ impl Module for Celu {
celu(x, self.alpha)
}

-    fn train(&mut self, _: bool) {}
+    fn training_mode(&mut self, _: bool) {}
}

/// Applies the Sigmoid Linear Unit. Also known as Swish.
@@ -656,7 +656,7 @@ impl Module for Silu {
silu(x)
}

-    fn train(&mut self, _: bool) {}
+    fn training_mode(&mut self, _: bool) {}
}

/// Applies the Log Softmax function.
@@ -696,7 +696,7 @@ impl Module for LogSoftmax {
log_softmax(x, self.axis)
}

-    fn train(&mut self, _: bool) {}
+    fn training_mode(&mut self, _: bool) {}
}

/// Applies the Log Sigmoid function.
@@ -729,7 +729,7 @@ impl Module for LogSigmoid {
log_sigmoid(x)
}

-    fn train(&mut self, _: bool) {}
+    fn training_mode(&mut self, _: bool) {}
}

/// Applies the element-wise parametric ReLU.
@@ -760,7 +760,7 @@ impl Module for Prelu {
prelu(x, &self.alpha)
}

-    fn train(&mut self, _: bool) {}
+    fn training_mode(&mut self, _: bool) {}
}

/// Variants of Gaussian Error Linear Units function.
@@ -820,7 +820,7 @@ impl Module for Gelu {
}
}

-    fn train(&mut self, _: bool) {}
+    fn training_mode(&mut self, _: bool) {}
}

/// Applies the hyperbolic tangent function
@@ -847,7 +847,7 @@ impl Module for Tanh {
Ok(mlx_rs::ops::tanh(x))
}

-    fn train(&mut self, _: bool) {}
+    fn training_mode(&mut self, _: bool) {}
}

/// Applies the hardswish function, element-wise
@@ -878,7 +878,7 @@ impl Module for HardSwish {
hard_swish(x)
}

-    fn train(&mut self, _: bool) {}
+    fn training_mode(&mut self, _: bool) {}
}

/// Applies the Step Activation Function.
@@ -921,7 +921,7 @@ impl Module for Step {
step(x, self.threshold)
}

-    fn train(&mut self, _: bool) {}
+    fn training_mode(&mut self, _: bool) {}
}

/// Applies the Scaled Exponential Linear Unit.
@@ -954,7 +954,7 @@ impl Module for Selu {
selu(x)
}

-    fn train(&mut self, _: bool) {}
+    fn training_mode(&mut self, _: bool) {}
}

/* -------------------------------------------------------------------------- */
6 changes: 3 additions & 3 deletions mlx-nn/src/convolution.rs
@@ -107,7 +107,7 @@ impl Module for Conv1d {
Ok(y)
}

-    fn train(&mut self, _: bool) {}
+    fn training_mode(&mut self, _: bool) {}
}

/// Applies a 2-dimensional convolution over the multi-channel input image.
@@ -218,7 +218,7 @@ impl Module for Conv2d {
Ok(y)
}

-    fn train(&mut self, _: bool) {}
+    fn training_mode(&mut self, _: bool) {}
}

/// Applies a 3-dimensional convolution over the multi-channel input image.
@@ -332,5 +332,5 @@ impl Module for Conv3d {
Ok(y)
}

-    fn train(&mut self, _: bool) {}
+    fn training_mode(&mut self, _: bool) {}
}
4 changes: 2 additions & 2 deletions mlx-nn/src/linear.rs
@@ -77,7 +77,7 @@ impl Module for Linear {
}
}

-    fn train(&mut self, _: bool) {}
+    fn training_mode(&mut self, _: bool) {}
}

/// Applies a bilinear transformation to the inputs.
@@ -157,5 +157,5 @@ impl Module for Bilinear {
Ok(y)
}

-    fn train(&mut self, _: bool) {}
+    fn training_mode(&mut self, _: bool) {}
}
6 changes: 4 additions & 2 deletions mlx-nn/src/sequential.rs
@@ -35,8 +35,10 @@ impl Module for Sequential {
}
}

-    fn train(&mut self, mode: bool) {
-        self.layers.iter_mut().for_each(|layer| layer.train(mode));
+    fn training_mode(&mut self, mode: bool) {
+        self.layers
+            .iter_mut()
+            .for_each(|layer| layer.training_mode(mode));
}
}

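
`Sequential` is the one implementation in this diff where the flag actually has to fan out to children. The same pattern as a standalone sketch with boxed trait objects (again simplified stand-ins rather than the real mlx-nn types):

```rust
trait Module {
    fn training_mode(&mut self, mode: bool);
}

// Stateless layers can ignore the flag, mirroring the no-op impls above.
struct Identity;
impl Module for Identity {
    fn training_mode(&mut self, _: bool) {}
}

struct Sequential {
    layers: Vec<Box<dyn Module>>,
}

impl Module for Sequential {
    fn training_mode(&mut self, mode: bool) {
        // Forward the flag to every child, as in the diff above.
        self.layers
            .iter_mut()
            .for_each(|layer| layer.training_mode(mode));
    }
}

fn main() {
    let mut seq = Sequential {
        layers: vec![Box::new(Identity), Box::new(Identity)],
    };
    seq.training_mode(false); // every layer receives the flag
}
```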
2 changes: 1 addition & 1 deletion mlx-nn/tests/test_module.rs
@@ -22,7 +22,7 @@ impl Module for M {
self.linear.forward(x)
}

-    fn train(&mut self, _mode: bool) {}
+    fn training_mode(&mut self, _mode: bool) {}
}

#[test]
