diff --git a/mlx-nn-module/src/module.rs b/mlx-nn-module/src/module.rs
index 1a39e0a6..10329a92 100644
--- a/mlx-nn-module/src/module.rs
+++ b/mlx-nn-module/src/module.rs
@@ -32,7 +32,7 @@ pub trait Module: ModuleParameters {
     /// Training mode only applies to certain layers. For example, dropout layers applies a random
     /// mask in training mode, but is the identity in evaluation mode. Implementations of nested
     /// modules should propagate the training mode to all child modules.
-    fn train(&mut self, mode: bool);
+    fn training_mode(&mut self, mode: bool);
 }
 
 /// Trait for accessing and updating module parameters.
diff --git a/mlx-nn/src/activation.rs b/mlx-nn/src/activation.rs
index 052d1668..8e5e2faf 100644
--- a/mlx-nn/src/activation.rs
+++ b/mlx-nn/src/activation.rs
@@ -293,7 +293,7 @@ impl Module for Glu {
         glu(x, self.axis)
     }
 
-    fn train(&mut self, _: bool) {}
+    fn training_mode(&mut self, _: bool) {}
 }
 
 /// Applies the element-wise logistic sigmoid.
@@ -330,7 +330,7 @@ impl Module for Sigmoid {
         Ok(sigmoid(x))
     }
 
-    fn train(&mut self, _: bool) {}
+    fn training_mode(&mut self, _: bool) {}
 }
 
 /// Applies the Mish function, element-wise.
@@ -367,7 +367,7 @@ impl Module for Mish {
         mish(x)
     }
 
-    fn train(&mut self, _: bool) {}
+    fn training_mode(&mut self, _: bool) {}
 }
 
 /// Applies the Rectified Linear Unit.
@@ -400,7 +400,7 @@ impl Module for Relu {
         relu(x)
     }
 
-    fn train(&mut self, _: bool) {}
+    fn training_mode(&mut self, _: bool) {}
 }
 
 /// Applies the Leaky Rectified Linear Unit.
@@ -440,7 +440,7 @@ impl Module for LeakyReLU {
         leaky_relu(x, self.neg_slope)
     }
 
-    fn train(&mut self, _: bool) {}
+    fn training_mode(&mut self, _: bool) {}
 }
 
 /// Applies the Rectified Linear Unit 6.
@@ -473,7 +473,7 @@ impl Module for Relu6 {
         relu6(x)
     }
 
-    fn train(&mut self, _: bool) {}
+    fn training_mode(&mut self, _: bool) {}
 }
 
 /// Applies the Softmax function.
@@ -516,7 +516,7 @@ impl Module for Softmax {
         }
     }
 
-    fn train(&mut self, _: bool) {}
+    fn training_mode(&mut self, _: bool) {}
 }
 
 /// Applies the Softplus function.
@@ -549,7 +549,7 @@ impl Module for Softplus {
         softplus(x)
     }
 
-    fn train(&mut self, _: bool) {}
+    fn training_mode(&mut self, _: bool) {}
 }
 
 /// Applies the Softsign function.
@@ -582,7 +582,7 @@ impl Module for Softsign {
         softsign(x)
     }
 
-    fn train(&mut self, _: bool) {}
+    fn training_mode(&mut self, _: bool) {}
 }
 
 /// Applies the Continuously Differentiable Exponential Linear Unit.
@@ -623,7 +623,7 @@ impl Module for Celu {
         celu(x, self.alpha)
     }
 
-    fn train(&mut self, _: bool) {}
+    fn training_mode(&mut self, _: bool) {}
 }
 
 /// Applies the Sigmoid Linear Unit. Also known as Swish.
@@ -656,7 +656,7 @@ impl Module for Silu {
         silu(x)
     }
 
-    fn train(&mut self, _: bool) {}
+    fn training_mode(&mut self, _: bool) {}
 }
 
 /// Applies the Log Softmax function.
@@ -696,7 +696,7 @@ impl Module for LogSoftmax {
         log_softmax(x, self.axis)
     }
 
-    fn train(&mut self, _: bool) {}
+    fn training_mode(&mut self, _: bool) {}
 }
 
 /// Applies the Log Sigmoid function.
@@ -729,7 +729,7 @@ impl Module for LogSigmoid {
         log_sigmoid(x)
     }
 
-    fn train(&mut self, _: bool) {}
+    fn training_mode(&mut self, _: bool) {}
 }
 
 /// Applies the element-wise parametric ReLU.
@@ -760,7 +760,7 @@ impl Module for Prelu {
         prelu(x, &self.alpha)
     }
 
-    fn train(&mut self, _: bool) {}
+    fn training_mode(&mut self, _: bool) {}
 }
 
 /// Variants of Gaussian Error Linear Units function.
@@ -820,7 +820,7 @@ impl Module for Gelu {
         }
     }
 
-    fn train(&mut self, _: bool) {}
+    fn training_mode(&mut self, _: bool) {}
 }
 
 /// Applies the hyperbolic tangent function
@@ -847,7 +847,7 @@ impl Module for Tanh {
         Ok(mlx_rs::ops::tanh(x))
     }
 
-    fn train(&mut self, _: bool) {}
+    fn training_mode(&mut self, _: bool) {}
 }
 
 /// Applies the hardswish function, element-wise
@@ -878,7 +878,7 @@ impl Module for HardSwish {
         hard_swish(x)
     }
 
-    fn train(&mut self, _: bool) {}
+    fn training_mode(&mut self, _: bool) {}
 }
 
 /// Applies the Step Activation Function.
@@ -921,7 +921,7 @@ impl Module for Step {
         step(x, self.threshold)
     }
 
-    fn train(&mut self, _: bool) {}
+    fn training_mode(&mut self, _: bool) {}
 }
 
 /// Applies the Scaled Exponential Linear Unit.
@@ -954,7 +954,7 @@ impl Module for Selu {
         selu(x)
     }
 
-    fn train(&mut self, _: bool) {}
+    fn training_mode(&mut self, _: bool) {}
 }
 
 /* -------------------------------------------------------------------------- */
diff --git a/mlx-nn/src/convolution.rs b/mlx-nn/src/convolution.rs
index c6f073c1..39eca931 100644
--- a/mlx-nn/src/convolution.rs
+++ b/mlx-nn/src/convolution.rs
@@ -107,7 +107,7 @@ impl Module for Conv1d {
         Ok(y)
     }
 
-    fn train(&mut self, _: bool) {}
+    fn training_mode(&mut self, _: bool) {}
 }
 
 /// Applies a 2-dimensional convolution over the multi-channel input image.
@@ -218,7 +218,7 @@ impl Module for Conv2d {
         Ok(y)
     }
 
-    fn train(&mut self, _: bool) {}
+    fn training_mode(&mut self, _: bool) {}
 }
 
 /// Applies a 3-dimensional convolution over the multi-channel input image.
@@ -332,5 +332,5 @@ impl Module for Conv3d {
         Ok(y)
     }
 
-    fn train(&mut self, _: bool) {}
+    fn training_mode(&mut self, _: bool) {}
 }
diff --git a/mlx-nn/src/linear.rs b/mlx-nn/src/linear.rs
index c648a00f..e7901359 100644
--- a/mlx-nn/src/linear.rs
+++ b/mlx-nn/src/linear.rs
@@ -77,7 +77,7 @@ impl Module for Linear {
         }
     }
 
-    fn train(&mut self, _: bool) {}
+    fn training_mode(&mut self, _: bool) {}
 }
 
 /// Applies a bilinear transformation to the inputs.
@@ -157,5 +157,5 @@ impl Module for Bilinear {
         Ok(y)
     }
 
-    fn train(&mut self, _: bool) {}
+    fn training_mode(&mut self, _: bool) {}
 }
diff --git a/mlx-nn/src/sequential.rs b/mlx-nn/src/sequential.rs
index 45631f16..a700f01f 100644
--- a/mlx-nn/src/sequential.rs
+++ b/mlx-nn/src/sequential.rs
@@ -35,8 +35,10 @@ impl Module for Sequential {
         }
     }
 
-    fn train(&mut self, mode: bool) {
-        self.layers.iter_mut().for_each(|layer| layer.train(mode));
+    fn training_mode(&mut self, mode: bool) {
+        self.layers
+            .iter_mut()
+            .for_each(|layer| layer.training_mode(mode));
     }
 }
 
diff --git a/mlx-nn/tests/test_module.rs b/mlx-nn/tests/test_module.rs
index 3fd20396..07a218b2 100644
--- a/mlx-nn/tests/test_module.rs
+++ b/mlx-nn/tests/test_module.rs
@@ -22,7 +22,7 @@ impl Module for M {
         self.linear.forward(x)
     }
 
-    fn train(&mut self, _mode: bool) {}
+    fn training_mode(&mut self, _mode: bool) {}
 }
 
 #[test]
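
For downstream code the change is mechanical: rename `train` to `training_mode` and keep forwarding the flag to child modules, as `Sequential` does above. The standalone sketch below illustrates the record-or-propagate pattern the `Module` doc comment describes; it uses a pared-down local trait rather than the crate's actual `Module`/`ModuleParameters` traits, and the `TrainingMode`, `Dropout`, and `Block` names are hypothetical — only the method name and signature match this PR.

// Standalone illustration of the training-mode propagation pattern.
// `TrainingMode`, `Dropout`, and `Block` are made up for this sketch and
// are not part of mlx-nn; only the method name and signature match the PR.
trait TrainingMode {
    fn training_mode(&mut self, mode: bool);
}

struct Dropout {
    training: bool,
}

impl TrainingMode for Dropout {
    // A leaf layer whose behavior differs between training and evaluation
    // records the flag.
    fn training_mode(&mut self, mode: bool) {
        self.training = mode;
    }
}

struct Block {
    dropout: Dropout,
}

impl TrainingMode for Block {
    // A container propagates the flag to every child, like Sequential above.
    fn training_mode(&mut self, mode: bool) {
        self.dropout.training_mode(mode);
    }
}

fn main() {
    let mut block = Block {
        dropout: Dropout { training: true },
    };
    block.training_mode(false); // switch the whole module tree to evaluation
    assert!(!block.dropout.training);
}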