diff --git a/examples/mnist/src/main.rs b/examples/mnist/src/main.rs
index 0d673649..eca357a2 100644
--- a/examples/mnist/src/main.rs
+++ b/examples/mnist/src/main.rs
@@ -43,7 +43,9 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
     let loader = load_training_data()?;
     let mut model = mlp::Mlp::new(num_layers, input_dim, hidden_dim, output_dim)?;
 
-    let cross_entropy = CrossEntropy::builder().reduction(LossReduction::Mean).build()?;
+    let cross_entropy = CrossEntropy::builder()
+        .reduction(LossReduction::Mean)
+        .build()?;
     let loss_fn = |model: &mlp::Mlp, (x, y): (&Array, &Array)| -> Result<Array, Exception> {
         let y_pred = model.forward(x)?;
         cross_entropy.apply(y_pred, y)
diff --git a/mlx-macros/src/generate_builder.rs b/mlx-macros/src/generate_builder.rs
index 75c17b25..47c6c942 100644
--- a/mlx-macros/src/generate_builder.rs
+++ b/mlx-macros/src/generate_builder.rs
@@ -52,10 +52,7 @@ pub(crate) fn expand_generate_builder(
             );
             optional_field_types.push(&field.ty);
             if generate_build_fn {
-                optional_field_defaults.push(
-                    field_attr
-                        .default_value
-                );
+                optional_field_defaults.push(field_attr.default_value);
             }
         } else {
             mandatory_field_idents.push(&field.ident);
@@ -67,12 +64,14 @@ pub(crate) fn expand_generate_builder(
     let modified_optional_field_types = optional_field_types
         .iter()
         .zip(optional_field_skip.iter())
-        .map(|(field_type, skip)| if !skip {
-            quote! { Option<#field_type> }
-        } else {
-            quote! { #field_type }
+        .map(|(field_type, skip)| {
+            if !skip {
+                quote! { Option<#field_type> }
+            } else {
+                quote! { #field_type }
+            }
         });
-    
+
     let builder_struct_doc = format!("Builder for [`{}`]", struct_ident);
     let field_doc = format!("See [`{}`] for more details", struct_ident);
     let builder_struct = quote! {
@@ -111,27 +110,21 @@ pub(crate) fn expand_generate_builder(
     let builder_setter_docs = optional_field_idents
         .iter()
         .zip(optional_field_skip.iter())
-        .filter_map(|(field_ident, skip)| if !skip {
-            Some(format!("Set the value of `{:?}`", field_ident))
-        } else {
-            None
+        .filter_map(|(field_ident, skip)| {
+            if !skip {
+                Some(format!("Set the value of `{:?}`", field_ident))
+            } else {
+                None
+            }
         });
     let filtered_optional_field_idents = optional_field_idents
         .iter()
         .zip(optional_field_skip.iter())
-        .filter_map(|(field_ident, skip)| if !skip {
-            Some(field_ident)
-        } else {
-            None
-        });
+        .filter_map(|(field_ident, skip)| if !skip { Some(field_ident) } else { None });
     let filtered_optional_field_types = optional_field_types
         .iter()
         .zip(optional_field_skip.iter())
-        .filter_map(|(field_type, skip)| if !skip {
-            Some(field_type)
-        } else {
-            None
-        });
+        .filter_map(|(field_type, skip)| if !skip { Some(field_type) } else { None });
 
     let builder_setters = quote! {
         impl #impl_generics #builder_ident #ty_generics #where_clause {
@@ -148,9 +141,14 @@ pub(crate) fn expand_generate_builder(
     let builder_build = if generate_build_fn {
         let builder_build_doc = format!("Build a new [`{}`]", struct_ident);
        let struct_new_doc = format!("Create a new [`{}`] with default values", struct_ident);
-        let optional_field_defaults: Vec<_> = optional_field_defaults.iter().map(|default| {
-            default.clone().ok_or_else(|| "Default value must be supplied to generate build function")
-        }).collect::<Result<Vec<_>, _>>()?;
+        let optional_field_defaults: Vec<_> = optional_field_defaults
+            .iter()
+            .map(|default| {
+                default
+                    .clone()
+                    .ok_or("Default value must be supplied to generate build function")
+            })
+            .collect::<Result<Vec<_>, _>>()?;
 
         quote! {
             impl #impl_generics #builder_ident #ty_generics #where_clause {
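For context, `#[derive(GenerateBuilder)]` expands to the ordinary owned-builder pattern that the hunks above reformat. A minimal hand-written sketch of the shape of the generated code, for a hypothetical struct with one optional field (the struct, the field, and the default value are illustrative, not the macro's literal output):

    /// Hypothetical target struct with one optional field.
    struct Dropout {
        p: f32,
    }

    /// Builder for [`Dropout`]
    #[derive(Default)]
    struct DropoutBuilder {
        /// See [`Dropout`] for more details
        p: Option<f32>,
    }

    impl DropoutBuilder {
        /// Set the value of `p`
        fn p(mut self, p: f32) -> Self {
            self.p = Some(p);
            self
        }

        /// Build a new [`Dropout`], falling back to the declared default
        /// for any field that was never set.
        fn build(self) -> Dropout {
            Dropout {
                p: self.p.unwrap_or(0.5),
            }
        }
    }

    fn main() {
        let layer = DropoutBuilder::default().p(0.2).build();
        assert_eq!(layer.p, 0.2);
    }

The real macro's `build` is fallible (the mnist example calls `.build()?`), and per the error string above it rejects optional fields that lack a declared default; the sketch omits both for brevity.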
diff --git a/mlx-nn/src/activation.rs b/mlx-nn/src/activation.rs
index 56bc4b11..6f87c8d2 100644
--- a/mlx-nn/src/activation.rs
+++ b/mlx-nn/src/activation.rs
@@ -613,9 +613,7 @@ impl PreluBuilder {
         let count = self.count.unwrap_or(Prelu::DEFAULT_COUNT);
         let value = self.value.unwrap_or(Prelu::DEFAULT_VALUE);
         let weight = Param::new(mlx_rs::ops::full::<f32>(&[count], &array!(value))?);
-        Ok(Prelu {
-            weight,
-        })
+        Ok(Prelu { weight })
     }
 }
 
@@ -639,7 +637,9 @@ impl Prelu {
 
     /// Creates a new Prelu module with the default values.
     pub fn new() -> Prelu {
-        PreluBuilder::default().build().expect("Default value should be valid")
+        PreluBuilder::default()
+            .build()
+            .expect("Default value should be valid")
     }
 }
diff --git a/mlx-nn/src/convolution.rs b/mlx-nn/src/convolution.rs
index 59234d8f..6fb83a37 100644
--- a/mlx-nn/src/convolution.rs
+++ b/mlx-nn/src/convolution.rs
@@ -203,7 +203,12 @@ impl Conv2dBuilder {
         let weight = uniform::<_, f32>(
             -scale,
             scale,
-            &[output_channels, kernel_size.0, kernel_size.1, input_channels],
+            &[
+                output_channels,
+                kernel_size.0,
+                kernel_size.1,
+                input_channels,
+            ],
             None,
         )?;
         let bias = if with_bias {
@@ -221,7 +226,6 @@
     }
 }
 
-
 /// Applies a 2-dimensional convolution over the multi-channel input image.
 ///
 /// The channels are expected to be last i.e. the input shape should be `NHWC` where:
@@ -355,7 +359,13 @@ impl Conv3dBuilder {
         let weight = uniform::<_, f32>(
             -scale,
             scale,
-            &[output_channels, kernel_size.0, kernel_size.1, kernel_size.2, input_channels],
+            &[
+                output_channels,
+                kernel_size.0,
+                kernel_size.1,
+                kernel_size.2,
+                input_channels,
+            ],
             None,
         )?;
         let bias = if with_bias {
diff --git a/mlx-nn/src/linear.rs b/mlx-nn/src/linear.rs
index 3e6e3c2a..4cc1ca1d 100644
--- a/mlx-nn/src/linear.rs
+++ b/mlx-nn/src/linear.rs
@@ -35,9 +35,12 @@ impl LinearBuilder {
             mlx_rs::random::uniform::<_, f32>(-scale, scale, &[output_dims, input_dims], None)?;
 
         let bias = if with_bias {
-            Some(
-                mlx_rs::random::uniform::<_, f32>(-scale, scale, &[output_dims], None)?
-            )
+            Some(mlx_rs::random::uniform::<_, f32>(
+                -scale,
+                scale,
+                &[output_dims],
+                None,
+            )?)
         } else {
             None
         };
@@ -115,7 +118,12 @@ impl BilinearBuilder {
     }
 
     /// Builds a new [`Bilinear`] layer.
-    pub fn build(self, input_dims_1: i32, input_dims_2: i32, output_dims: i32) -> Result<Bilinear, Exception> {
+    pub fn build(
+        self,
+        input_dims_1: i32,
+        input_dims_2: i32,
+        output_dims: i32,
+    ) -> Result<Bilinear, Exception> {
         let with_bias = self.with_bias.unwrap_or(Bilinear::DEFAULT_WITH_BIAS);
         let scale = f32::sqrt(1.0 / (input_dims_1 as f32));
 
@@ -127,9 +135,12 @@ impl BilinearBuilder {
         )?;
 
         let bias = if with_bias {
-            Some(
-                mlx_rs::random::uniform::<_, f32>(-scale, scale, &[output_dims], None)?
-            )
+            Some(mlx_rs::random::uniform::<_, f32>(
+                -scale,
+                scale,
+                &[output_dims],
+                None,
+            )?)
         } else {
             None
         };
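The `scale` threaded through the `uniform` calls in the convolution and linear hunks above is the usual fan-in bound: weights are drawn from U(-scale, scale) with scale = 1/sqrt(fan_in), as in the `f32::sqrt(1.0 / (input_dims_1 as f32))` line in the Bilinear hunk. A self-contained check of that arithmetic (plain Rust, no mlx-rs calls):

    /// Uniform initialization bound used by the layer builders: 1 / sqrt(fan_in).
    fn init_bound(fan_in: u32) -> f32 {
        f32::sqrt(1.0 / fan_in as f32)
    }

    fn main() {
        // A layer with 256 input features draws weights from U(-1/16, 1/16).
        assert!((init_bound(256) - 0.0625).abs() < 1e-6);
    }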
diff --git a/mlx-nn/src/losses.rs b/mlx-nn/src/losses.rs
index 42dc30bf..06040363 100644
--- a/mlx-nn/src/losses.rs
+++ b/mlx-nn/src/losses.rs
@@ -174,7 +174,7 @@ impl<'a> CrossEntropy<'a> {
 }
 
 /// Binary cross entropy loss.
-/// 
+///
 /// By default, this function takes the pre-sigmoid logits, which results in a faster
 /// and more precise loss. For improved numerical stability when `inputs_are_logits` is true,
 /// the loss calculation clips the input probabilities (in log-space) to a minimum value
@@ -207,8 +207,12 @@ impl<'a> BinaryCrossEntropyBuilder<'a> {
     pub fn build(self) -> BinaryCrossEntropy<'a> {
         BinaryCrossEntropy {
             weights: self.weights,
-            inputs_are_logits: self.inputs_are_logits.unwrap_or(BinaryCrossEntropy::DEFAULT_INPUTS_ARE_LOGITS),
-            reduction: self.reduction.unwrap_or(BinaryCrossEntropy::DEFAULT_REDUCTION),
+            inputs_are_logits: self
+                .inputs_are_logits
+                .unwrap_or(BinaryCrossEntropy::DEFAULT_INPUTS_ARE_LOGITS),
+            reduction: self
+                .reduction
+                .unwrap_or(BinaryCrossEntropy::DEFAULT_REDUCTION),
         }
     }
 }
@@ -232,12 +236,16 @@ impl<'a> BinaryCrossEntropy<'a> {
     }
 
     /// Apply the binary cross entropy loss function on the given logits and targets.
-    /// 
+    ///
     /// # Params
-    /// 
+    ///
     /// - `logits`: unnormalized predicted logits
     /// - `targets`: binary target values in {0, 1}
-    pub fn apply(&self, logits: impl AsRef<Array>, targets: impl AsRef<Array>) -> Result<Array, Exception> {
+    pub fn apply(
+        &self,
+        logits: impl AsRef<Array>,
+        targets: impl AsRef<Array>,
+    ) -> Result<Array, Exception> {
         let logits = logits.as_ref();
         let targets = targets.as_ref();
         let weights = self.weights;
@@ -278,12 +286,16 @@ impl L1Loss {
     pub const DEFAULT_REDUCTION: LossReduction = LossReduction::Mean;
 
     /// Compute the L1 loss.
-    /// 
+    ///
     /// # Params
-    /// 
+    ///
     /// - `predictions`: predicted values
     /// - `targets`: target values
-    pub fn apply(&self, predictions: impl AsRef<Array>, targets: impl AsRef<Array>) -> Result<Array, Exception> {
+    pub fn apply(
+        &self,
+        predictions: impl AsRef<Array>,
+        targets: impl AsRef<Array>,
+    ) -> Result<Array, Exception> {
         let predictions = predictions.as_ref();
         let targets = targets.as_ref();
         let reduction = self.reduction;
@@ -313,12 +325,16 @@ impl MseLoss {
     pub const DEFAULT_REDUCTION: LossReduction = LossReduction::Mean;
 
     /// Compute the mean squared error loss.
-    /// 
+    ///
     /// # Params
-    /// 
+    ///
     /// - `predictions`: predicted values
     /// - `targets`: target values
-    pub fn apply(&self, predictions: impl AsRef<Array>, targets: impl AsRef<Array>) -> Result<Array, Exception> {
+    pub fn apply(
+        &self,
+        predictions: impl AsRef<Array>,
+        targets: impl AsRef<Array>,
+    ) -> Result<Array, Exception> {
         let predictions = predictions.as_ref();
         let targets = targets.as_ref();
         let reduction = self.reduction;
@@ -355,12 +371,16 @@ impl NllLoss {
     pub const DEFAULT_REDUCTION: LossReduction = LossReduction::None;
 
     /// Compute the negative log likelihood loss.
-    /// 
+    ///
     /// # Params
-    /// 
+    ///
     /// - `inputs`: predicted distribution in log space
     /// - `targets`: target values
-    pub fn apply(&self, inputs: impl AsRef<Array>, targets: impl AsRef<Array>) -> Result<Array, Exception> {
+    pub fn apply(
+        &self,
+        inputs: impl AsRef<Array>,
+        targets: impl AsRef<Array>,
+    ) -> Result<Array, Exception> {
         let inputs = inputs.as_ref();
         let targets = targets.as_ref();
         let axis = self.axis;
@@ -406,9 +426,9 @@ impl GaussianNllLoss {
     pub const DEFAULT_REDUCTION: LossReduction = LossReduction::None;
 
     /// Compute the negative log likelihood loss for a Gaussian distribution.
-    /// 
+    ///
     /// # Params
-    /// 
+    ///
     /// - `inputs`: The predicted expectation of the Gaussian distribution.
     /// - `targets`: The target values (samples from the Gaussian distribution).
     /// - `vars`: The predicted variance of the Gaussian distribution.
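The next hunk only rewraps the core Gaussian NLL expression, which is easier to check in scalar form: loss = 0.5 * (ln(var) + (target - input)^2 / var), with the variance clamped to at least `eps`. When `full` is set, a constant term is added as well; the hunk only shows the pi constant being introduced, so the conventional 0.5 * ln(2*pi) used below is an assumption. A scalar reference sketch:

    use std::f32::consts::PI;

    /// Scalar form of the Gaussian NLL that the array code below computes.
    fn gaussian_nll(input: f32, target: f32, var: f32, eps: f32, full: bool) -> f32 {
        let var = var.max(eps); // mirrors `maximum(vars, array!(eps))`
        let mut loss = 0.5 * (var.ln() + (target - input).powi(2) / var);
        if full {
            // Assumed constant term; see the lead-in above.
            loss += 0.5 * (2.0 * PI).ln();
        }
        loss
    }

    fn main() {
        // Unit variance and zero error: the non-constant part vanishes.
        assert!(gaussian_nll(0.0, 0.0, 1.0, 1e-6, false).abs() < 1e-6);
    }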
@@ -429,7 +449,8 @@ impl GaussianNllLoss {
         check_shape(inputs, vars, "inputs", "vars")?;
         let vars = maximum(vars, array!(eps))?;
-        let mut loss = array!(0.5) * (log(&vars).add(square(&targets.subtract(inputs)?).divide(&vars)?)?);
+        let mut loss =
+            array!(0.5) * (log(&vars).add(square(&targets.subtract(inputs)?).divide(&vars)?)?);
 
         if full {
             let pi = array!(std::f32::consts::PI);
@@ -441,7 +462,7 @@
 }
 
 /// Compute the Kullback-Leibler divergence loss.
-/// 
+///
 /// Computes the following when the `reduction` is `LossReduction::None`:
 ///
 /// ```rust, ignore
@@ -472,18 +493,26 @@ impl KlDivLoss {
     pub const DEFAULT_REDUCTION: LossReduction = LossReduction::None;
 
     /// Compute the Kullback-Leibler divergence loss.
-    /// 
+    ///
     /// # Params
-    /// 
+    ///
     /// - `inputs`: Log probabilities for the predicted distribution.
     /// - `targets`: Log probabilities for the target distribution.
-    pub fn apply(&self, inputs: impl AsRef<Array>, targets: impl AsRef<Array>) -> Result<Array, Exception> {
+    pub fn apply(
+        &self,
+        inputs: impl AsRef<Array>,
+        targets: impl AsRef<Array>,
+    ) -> Result<Array, Exception> {
         let inputs = inputs.as_ref();
         let targets = targets.as_ref();
         let axis = self.axis;
         let reduction = self.reduction;
-        let loss = sum(&exp(targets).multiply(targets.subtract(inputs)?)?, &[axis], None)?;
+        let loss = sum(
+            &exp(targets).multiply(targets.subtract(inputs)?)?,
+            &[axis],
+            None,
+        )?;
         reduction.reduce(loss)
     }
 }
@@ -519,12 +548,16 @@ impl SmoothL1Loss {
     pub const DEFAULT_REDUCTION: LossReduction = LossReduction::Mean;
 
     /// Compute the smooth L1 loss.
-    /// 
+    ///
     /// # Params
-    /// 
+    ///
     /// - `predictions`: predicted values
     /// - `targets`: target values
-    pub fn apply(&self, predictions: impl AsRef<Array>, targets: impl AsRef<Array>) -> Result<Array, Exception> {
+    pub fn apply(
+        &self,
+        predictions: impl AsRef<Array>,
+        targets: impl AsRef<Array>,
+    ) -> Result<Array, Exception> {
         let predictions = predictions.as_ref();
         let targets = targets.as_ref();
         let beta = self.beta;
@@ -614,7 +647,7 @@ impl TripletLoss {
         let eps = array!(eps);
         let p = array!(p);
         let margin = array!(margin);
-        
+
         let pos = sqrt(
             &power(&anchors.subtract(positives)?, &p)?
                 .sum(&[axis], None)?
@@ -649,12 +682,16 @@ impl HingeLoss {
     pub const DEFAULT_REDUCTION: LossReduction = LossReduction::None;
 
     /// Compute the hinge loss.
-    /// 
+    ///
     /// # Params
-    /// 
+    ///
     /// - `inputs`: predicted values
     /// - `targets`: target values, -1 or 1
-    pub fn apply(&self, inputs: impl AsRef<Array>, targets: impl AsRef<Array>) -> Result<Array, Exception> {
+    pub fn apply(
+        &self,
+        inputs: impl AsRef<Array>,
+        targets: impl AsRef<Array>,
+    ) -> Result<Array, Exception> {
         let inputs = inputs.as_ref();
         let targets = targets.as_ref();
         let reduction = self.reduction;
@@ -693,12 +730,16 @@ impl HuberLoss {
     pub const DEFAULT_REDUCTION: LossReduction = LossReduction::None;
 
     /// Compute the Huber loss.
-    /// 
+    ///
     /// # Params
-    /// 
+    ///
     /// - `inputs`: predicted values
     /// - `targets`: target values
-    pub fn apply(&self, inputs: impl AsRef<Array>, targets: impl AsRef<Array>) -> Result<Array, Exception> {
+    pub fn apply(
+        &self,
+        inputs: impl AsRef<Array>,
+        targets: impl AsRef<Array>,
+    ) -> Result<Array, Exception> {
         let inputs = inputs.as_ref();
         let targets = targets.as_ref();
         let delta = self.delta;
@@ -740,10 +781,14 @@ impl LogCoshLoss {
     /// Computes the log cosh loss between inputs and targets.
     ///
     /// # Params
-    /// 
+    ///
     /// - `inputs`: predicted values
     /// - `targets`: target values
-    pub fn apply(&self, inputs: impl AsRef<Array>, targets: impl AsRef<Array>) -> Result<Array, Exception> {
+    pub fn apply(
+        &self,
+        inputs: impl AsRef<Array>,
+        targets: impl AsRef<Array>,
+    ) -> Result<Array, Exception> {
         let inputs = inputs.as_ref();
         let targets = targets.as_ref();
         let reduction = self.reduction;
@@ -753,7 +798,6 @@
         let loss = log_add_exp(errors, neg_errors)?.subtract(log(&array!(2.0)))?;
         reduction.reduce(loss)
     }
-
 }
 
 /// Computes the cosine similarity loss.
@@ -790,9 +834,9 @@ impl CosineSimilarityLoss {
     pub const DEFAULT_REDUCTION: LossReduction = LossReduction::None;
 
     /// Computes the cosine similarity loss.
-    /// 
+    ///
     /// # Params
-    /// 
+    ///
     /// - `x1`: first array
     /// - `x2`: second array
     pub fn apply(&self, x1: impl AsRef<Array>, x2: impl AsRef<Array>) -> Result<Array, Exception> {
@@ -801,7 +845,7 @@ impl CosineSimilarityLoss {
         let x1 = x1.as_ref();
         let x2 = x2.as_ref();
         let axis = self.axis;
         let eps = self.eps;
         let reduction = self.reduction;
-        
+
         fn l2_loss(a: &Array, axis: i32) -> Result<Array, Exception> {
             if a.dtype().is_complex() {
                 Ok(sqrt(&sum(&abs(a).square(), &[axis], None)?))
@@ -809,19 +853,18 @@ impl CosineSimilarityLoss {
             } else {
                 Ok(sqrt(&sum(&a.square(), &[axis], None)?))
             }
         }
-        
+
         let x1_norm = l2_loss(x1, axis)?;
         let x2_norm = l2_loss(x2, axis)?;
-        
+
         let num = sum(&x1.multiply(x2)?, &[axis], None)?;
         let den = maximum(x1_norm.multiply(x2_norm)?, array!(eps))?;
         let loss = num.divide(&den)?;
-        
+
         reduction.reduce(loss)
     }
 }
 
-
 /// Computes the margin ranking loss.
 #[derive(Debug, Clone, GenerateBuilder)]
 pub struct MarginRankingLoss {
@@ -849,9 +892,9 @@ impl MarginRankingLoss {
     pub const DEFAULT_REDUCTION: LossReduction = LossReduction::None;
 
     /// Computes the margin ranking loss.
-    /// 
+    ///
     /// # Params
-    /// 
+    ///
     /// - `inputs1`: Scores for the first input.
     /// - `inputs2`: Scores for the second input.
     /// - `targets`: Labels indicating whether samples in `inputs1` should be ranked higher than samples
@@ -922,7 +965,8 @@ mod tests {
         let cross_entropy = CrossEntropy::builder()
             .weights(&weights)
             .reduction(LossReduction::None)
-            .build().unwrap();
+            .build()
+            .unwrap();
         let loss = cross_entropy.apply(logits, probs).unwrap();
         assert_array_eq!(loss, expected);
 
@@ -933,7 +977,8 @@ mod tests {
         let cross_entropy = CrossEntropy::builder()
             .label_smoothing(0.3)
             .reduction(LossReduction::None)
-            .build().unwrap();
+            .build()
+            .unwrap();
         let loss = cross_entropy.apply(&logits, indices).unwrap();
         assert_array_eq!(loss, expected);
 
@@ -941,7 +986,8 @@ mod tests {
         let cross_entropy = CrossEntropy::builder()
             .label_smoothing(0.3)
             .reduction(LossReduction::None)
-            .build().unwrap();
+            .build()
+            .unwrap();
         let loss = cross_entropy.apply(logits, probs).unwrap();
         assert_array_eq!(loss, expected);
 
@@ -954,7 +1000,8 @@ mod tests {
         let cross_entropy = CrossEntropy::builder()
             .weights(&weights)
             .label_smoothing(0.3)
             .reduction(LossReduction::None)
-            .build().unwrap();
+            .build()
+            .unwrap();
         let loss = cross_entropy.apply(&logits, indices).unwrap();
         assert_array_eq!(loss, expected);
 
@@ -963,7 +1010,8 @@ mod tests {
         let cross_entropy = CrossEntropy::builder()
             .weights(&weights)
             .label_smoothing(0.3)
             .reduction(LossReduction::None)
-            .build().unwrap();
+            .build()
+            .unwrap();
         let loss = cross_entropy.apply(logits, probs).unwrap();
         assert_array_eq!(loss, expected);
     }
@@ -1084,21 +1132,15 @@ mod tests {
         let expected_sum = expected_none.sum(None, None).unwrap();
         let expected_mean = expected_none.mean(None, None).unwrap();
 
-        let l1_loss = L1Loss::builder()
-            .reduction(LossReduction::None)
-            .build();
+        let l1_loss = L1Loss::builder().reduction(LossReduction::None).build();
         let loss_none = l1_loss.apply(&predictions, &targets).unwrap();
         assert_array_eq!(loss_none, expected_none);
 
-        let l1_loss = L1Loss::builder()
-            .reduction(LossReduction::Sum)
-            .build();
+        let l1_loss = L1Loss::builder().reduction(LossReduction::Sum).build();
         let loss_sum = l1_loss.apply(&predictions, &targets).unwrap();
         assert_array_eq!(loss_sum, expected_sum);
 
-        let l1_loss = L1Loss::builder()
-            .reduction(LossReduction::Mean)
-            .build();
+        let l1_loss = L1Loss::builder().reduction(LossReduction::Mean).build();
         let loss_mean = l1_loss.apply(&predictions, &targets).unwrap();
         assert_array_eq!(loss_mean, expected_mean);
     }
@@ -1112,21 +1154,15 @@ mod tests {
         let expected_mean = expected_none.mean(None, None).unwrap();
         let expected_sum = expected_none.sum(None, None).unwrap();
 
-        let mse_loss = MseLoss::builder()
-            .reduction(LossReduction::None)
-            .build();
+        let mse_loss = MseLoss::builder().reduction(LossReduction::None).build();
         let loss_none = mse_loss.apply(&predictions, &targets).unwrap();
         assert_array_eq!(loss_none, expected_none);
 
-        let mse_loss = MseLoss::builder()
-            .reduction(LossReduction::Mean)
-            .build();
+        let mse_loss = MseLoss::builder().reduction(LossReduction::Mean).build();
         let loss_mean = mse_loss.apply(&predictions, &targets).unwrap();
         assert_array_eq!(loss_mean, expected_mean);
 
-        let mse_loss = MseLoss::builder()
-            .reduction(LossReduction::Sum)
-            .build();
+        let mse_loss = MseLoss::builder().reduction(LossReduction::Sum).build();
         let loss_sum = mse_loss.apply(&predictions, &targets).unwrap();
         assert_array_eq!(loss_sum, expected_sum);
     }
@@ -1172,21 +1208,15 @@ mod tests {
         let expected_sum = expected_none.sum(None, None).unwrap();
         let expected_mean = expected_none.mean(None, None).unwrap();
 
-        let nll_loss = NllLoss::builder()
-            .reduction(LossReduction::None)
-            .build();
+        let nll_loss = NllLoss::builder().reduction(LossReduction::None).build();
         let loss_none = nll_loss.apply(&logits, &targets).unwrap();
         assert_array_eq!(loss_none, expected_none);
 
-        let nll_loss = NllLoss::builder()
-            .reduction(LossReduction::Mean)
-            .build();
+        let nll_loss = NllLoss::builder().reduction(LossReduction::Mean).build();
         let loss_mean = nll_loss.apply(&logits, &targets).unwrap();
         assert_array_eq!(loss_mean, expected_mean);
 
-        let nll_loss = NllLoss::builder()
-            .reduction(LossReduction::Sum)
-            .build();
+        let nll_loss = NllLoss::builder().reduction(LossReduction::Sum).build();
         let loss_sum = nll_loss.apply(&logits, &targets).unwrap();
         assert_array_eq!(loss_sum, expected_sum);
     }
@@ -1258,25 +1288,19 @@ mod tests {
         let q_logits = array!([[0.5, 0.5], [0.2, 0.8]]).log();
 
         // Test with reduction 'none'
-        let kl_div_loss = KlDivLoss::builder()
-            .reduction(LossReduction::None)
-            .build();
+        let kl_div_loss = KlDivLoss::builder().reduction(LossReduction::None).build();
         let loss_none = kl_div_loss.apply(&p_logits, &q_logits).unwrap();
         let expected_none = array!([0.0, 0.831777]);
         assert_array_eq!(loss_none, expected_none);
 
         // Test with reduction 'mean'
-        let kl_div_loss = KlDivLoss::builder()
-            .reduction(LossReduction::Mean)
-            .build();
+        let kl_div_loss = KlDivLoss::builder().reduction(LossReduction::Mean).build();
         let loss_mean = kl_div_loss.apply(&p_logits, &q_logits).unwrap();
         let expected_mean = expected_none.mean(None, None).unwrap();
         assert_array_eq!(loss_mean, expected_mean);
 
         // Test with reduction 'sum'
-        let kl_div_loss = KlDivLoss::builder()
-            .reduction(LossReduction::Sum)
-            .build();
+        let kl_div_loss = KlDivLoss::builder().reduction(LossReduction::Sum).build();
         let loss_sum = kl_div_loss.apply(&p_logits, &q_logits).unwrap();
         let expected_sum = expected_none.sum(None, None).unwrap();
         assert_array_eq!(loss_sum, expected_sum);
@@ -1292,7 +1316,9 @@ mod tests {
         let triplet_loss = TripletLoss::builder()
             .reduction(LossReduction::None)
             .build();
-        let loss_none = triplet_loss.apply(&anchors, &positives, &negatives).unwrap();
+        let loss_none = triplet_loss
+            .apply(&anchors, &positives, &negatives)
+            .unwrap();
         let expected_none = array!([0.0, 2.31662]);
         assert_array_eq!(loss_none, expected_none);
 
@@ -1300,15 +1326,17 @@ mod tests {
         let triplet_loss = TripletLoss::builder()
             .reduction(LossReduction::Mean)
             .build();
-        let loss_mean = triplet_loss.apply(&anchors, &positives, &negatives).unwrap();
+        let loss_mean = triplet_loss
+            .apply(&anchors, &positives, &negatives)
+            .unwrap();
         let expected_mean = expected_none.mean(None, None).unwrap();
         assert_array_eq!(loss_mean, expected_mean);
 
         // Test with reduction 'sum'
-        let triplet_loss = TripletLoss::builder()
-            .reduction(LossReduction::Sum)
-            .build();
-        let loss_sum = triplet_loss.apply(&anchors, &positives, &negatives).unwrap();
+        let triplet_loss = TripletLoss::builder().reduction(LossReduction::Sum).build();
+        let loss_sum = triplet_loss
+            .apply(&anchors, &positives, &negatives)
+            .unwrap();
         let expected_sum = expected_none.sum(None, None).unwrap();
         assert_array_eq!(loss_sum, expected_sum);
     }
@@ -1317,9 +1345,7 @@ fn test_hinge_loss() {
         let inputs = array!([[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]);
         let targets = array!([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]);
-        let hinge_loss = HingeLoss::builder()
-            .reduction(LossReduction::Mean)
-            .build();
+        let hinge_loss = HingeLoss::builder().reduction(LossReduction::Mean).build();
         let loss = hinge_loss.apply(&inputs, &targets).unwrap();
         assert_eq!(loss.item::<f32>(), 1.0);
     }
@@ -1328,9 +1354,7 @@ mod tests {
     fn test_huber_loss() {
         let inputs = array!([[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]);
         let targets = array!([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]);
-        let huber_loss = HuberLoss::builder()
-            .reduction(LossReduction::Mean)
-            .build();
+        let huber_loss = HuberLoss::builder().reduction(LossReduction::Mean).build();
         let loss = huber_loss.apply(&inputs, &targets).unwrap();
         assert_eq!(loss.item::<f32>(), 0.5);
     }
@@ -1355,7 +1379,9 @@ mod tests {
         let cosine_similarity_loss = CosineSimilarityLoss::builder()
             .reduction(LossReduction::None)
             .build();
-        let loss_none = cosine_similarity_loss.apply(&embeddings1, &embeddings2).unwrap();
+        let loss_none = cosine_similarity_loss
+            .apply(&embeddings1, &embeddings2)
+            .unwrap();
         let expected_none = array!([0.985344, 0.961074]);
         assert_array_eq!(loss_none, expected_none);
 
@@ -1363,7 +1389,9 @@ mod tests {
         let cosine_similarity_loss = CosineSimilarityLoss::builder()
             .reduction(LossReduction::Mean)
             .build();
-        let loss_mean = cosine_similarity_loss.apply(&embeddings1, &embeddings2).unwrap();
+        let loss_mean = cosine_similarity_loss
+            .apply(&embeddings1, &embeddings2)
+            .unwrap();
         let expected_mean = expected_none.mean(None, None).unwrap();
         assert_array_eq!(loss_mean, expected_mean);
 
@@ -1371,7 +1399,9 @@ mod tests {
         let cosine_similarity_loss = CosineSimilarityLoss::builder()
             .reduction(LossReduction::Sum)
             .build();
-        let loss_sum = cosine_similarity_loss.apply(&embeddings1, &embeddings2).unwrap();
+        let loss_sum = cosine_similarity_loss
+            .apply(&embeddings1, &embeddings2)
+            .unwrap();
         let expected_sum = expected_none.sum(None, None).unwrap();
         assert_array_eq!(loss_sum, expected_sum);
     }
@@ -1386,7 +1416,9 @@ mod tests {
         let margin_ranking_loss = MarginRankingLoss::builder()
             .reduction(LossReduction::None)
             .build();
-        let loss = margin_ranking_loss.apply(&inputs1, &inputs2, &targets).unwrap();
+        let loss = margin_ranking_loss
+            .apply(&inputs1, &inputs2, &targets)
+            .unwrap();
         let expected = array!([1.329369, 0.990929, 0.0]);
         assert_array_eq!(loss, expected);
 
@@ -1395,7 +1427,9 @@ mod tests {
         let margin_ranking_loss = MarginRankingLoss::builder()
             .margin(0.5)
             .reduction(LossReduction::None)
             .build();
-        let loss = margin_ranking_loss.apply(&inputs1, &inputs2, &targets).unwrap();
+        let loss = margin_ranking_loss
+            .apply(&inputs1, &inputs2, &targets)
+            .unwrap();
         let expected = array!([1.829369, 1.490929, 0.179205]);
         assert_array_eq!(loss, expected);
     }
diff --git a/mlx-nn/src/optimizers/mod.rs b/mlx-nn/src/optimizers/mod.rs
index 34626ddd..adce1d90 100644
--- a/mlx-nn/src/optimizers/mod.rs
+++ b/mlx-nn/src/optimizers/mod.rs
@@ -148,9 +148,7 @@ mod tests {
     {
         let mut optimizer = f();
 
-        let mse_loss = MseLoss::builder()
-            .reduction(LossReduction::Mean)
-            .build();
+        let mse_loss = MseLoss::builder().reduction(LossReduction::Mean).build();
         let loss = |model: &LinearFunctionModel, (x, y): (&Array, &Array)| {
             mse_loss.apply(model.forward(x)?, y)
         };
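None of the hunks above change behavior; the builder-style loss API reads the same before and after the reformat. For reference, a sketch of the call pattern the reformatted mnist example exercises (import paths are assumed from the mlx-rs/mlx-nn crates, and the errors returned by `build` and `apply` are taken to be convertible to `Box<dyn Error>` as in the example's `main`):

    use mlx_nn::losses::{CrossEntropy, LossReduction};
    use mlx_rs::array;

    fn main() -> Result<(), Box<dyn std::error::Error>> {
        let cross_entropy = CrossEntropy::builder()
            .reduction(LossReduction::Mean)
            .build()?;
        let logits = array!([[2.0, -1.0], [-1.0, 2.0]]);
        let targets = array!([0, 1]);
        let loss = cross_entropy.apply(&logits, &targets)?;
        println!("cross entropy = {:?}", loss);
        Ok(())
    }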