diff --git a/mlx-nn/src/losses.rs b/mlx-nn/src/losses.rs
index 8961231b..483ec928 100644
--- a/mlx-nn/src/losses.rs
+++ b/mlx-nn/src/losses.rs
@@ -1,7 +1,7 @@
 use mlx_rs::{
     array,
     error::Exception,
-    ops::{indexing::take_along_axis, log_sum_exp, multiply, sum},
+    ops::{clip, indexing::take_along_axis, log, log_add_exp, log_sum_exp, multiply, sum},
     Array,
 };
 
@@ -114,20 +114,83 @@ pub fn cross_entropy(
     };
 
     if let Some(weights) = weight {
-        assert_eq!(weights.shape(), loss.shape());
+        // assert_eq!(weights.shape(), loss.shape()); // TODO: is this necessary?
         loss = multiply(loss, weights)?;
     }
 
     reduction.reduce(loss)
 }
 
+/// Optional parameters for the `binary_cross_entropy` function.
+#[derive(Debug, Clone, Default)]
+pub struct BinaryCrossEntropyOptions<'a> {
+    /// Optional weights for each target
+    pub weights: Option<&'a Array>,
+
+    /// Whether the inputs are logits
+    pub with_logits: Option<bool>,
+
+    /// Reduction type. Defaults to [`BinaryCrossEntropyOptions::DEFAULT_REDUCTION`] if `None`
+    pub reduction: Option<LossReduction>,
+}
+
+impl<'a> BinaryCrossEntropyOptions<'a> {
+    /// Default value for the with_logits parameter.
+    pub const DEFAULT_WITH_LOGITS: bool = true;
+
+    /// Default value for the reduction parameter.
+    pub const DEFAULT_REDUCTION: LossReduction = LossReduction::None;
+}
+
+/// Computes the binary cross entropy loss.
+///
+/// # Params
+///
+/// - `logits`: unnormalized predicted logits
+/// - `targets`: binary target values in {0, 1}
+/// - `options`: optional parameters. See [`BinaryCrossEntropyOptions`] for more details
+pub fn binary_cross_entropy(
+    logits: impl AsRef<Array>,
+    targets: impl AsRef<Array>,
+    options: BinaryCrossEntropyOptions<'_>,
+) -> Result<Array, Exception> {
+    let logits = logits.as_ref();
+    let targets = targets.as_ref();
+    let weights = options.weights;
+    let with_logits = options
+        .with_logits
+        .unwrap_or(BinaryCrossEntropyOptions::DEFAULT_WITH_LOGITS);
+    let reduction = options
+        .reduction
+        .unwrap_or(BinaryCrossEntropyOptions::DEFAULT_REDUCTION);
+
+    let mut loss = if with_logits {
+        log_add_exp(array!(0.0), logits)?.subtract(targets.multiply(logits)?)?
+    } else {
+        let log_inputs_clip = clip(&log(logits), (-100.0, ()))?;
+        let log_inputs_inverse_clip = clip(&log(&array!(1.0).subtract(logits)?), (-100.0, ()))?;
+        -(targets.multiply(log_inputs_clip)?.add(
+            array!(1.0)
+                .subtract(targets)?
+                .multiply(log_inputs_inverse_clip)?,
+        )?)
+    };
+
+    if let Some(weights) = weights {
+        // assert_eq!(weights.shape(), loss.shape()); // TODO: is this necessary?
+        loss = multiply(loss, weights)?;
+    }
+
+    reduction.reduce(loss)
+}
+
+// The following unit tests are adapted from the python API at: mlx/python/tests/test_losses.py
 #[cfg(test)]
 mod tests {
     use mlx_rs::{array, assert_array_eq, ops::is_nan};
 
     use super::*;
 
-    // The following unit test is adapted from the python API at: mlx/python/tests/test_losses.py
     #[test]
     fn test_cross_entropy() {
         // No weights, no label smoothing
@@ -213,4 +276,48 @@ mod tests {
         let loss = cross_entropy(logits, probs, options).unwrap();
         assert_array_eq!(loss, expected);
     }
+
+    #[test]
+    fn test_binary_cross_entropy_with_logits_as_inputs() {
+        let logits = array!([0.105361, 0.223144, 1.20397, 0.916291]);
+        let targets = array!([0.0, 0.0, 1.0, 1.0]);
+
+        // Test with reduction 'none'
+        let options = BinaryCrossEntropyOptions {
+            reduction: Some(LossReduction::None),
+            ..Default::default()
+        };
+        let loss_none = binary_cross_entropy(&logits, &targets, options).unwrap();
+        let expected_none = array!([0.747215, 0.810930, 0.262365, 0.336472]);
+        assert_array_eq!(loss_none, expected_none);
+
+        // Test with reduction 'mean'
+        let options = BinaryCrossEntropyOptions {
+            reduction: Some(LossReduction::Mean),
+            ..Default::default()
+        };
+        let loss_mean = binary_cross_entropy(&logits, &targets, options).unwrap();
+        let expected_mean = expected_none.mean(None, None).unwrap();
+        assert_array_eq!(loss_mean, expected_mean);
+
+        // Test with reduction 'sum'
+        let options = BinaryCrossEntropyOptions {
+            reduction: Some(LossReduction::Sum),
+            ..Default::default()
+        };
+        let loss = binary_cross_entropy(&logits, &targets, options).unwrap();
+        let expected = expected_none.sum(None, None).unwrap();
+        assert_array_eq!(loss, expected);
+
+        // With weights, no label smoothing
+        let weights = array!([1.0, 2.0, 1.0, 2.0]);
+        let expected = array!([0.747215, 1.62186, 0.262365, 0.672944]);
+        let options = BinaryCrossEntropyOptions {
+            weights: Some(&weights),
+            reduction: Some(LossReduction::None),
+            ..Default::default()
+        };
+        let loss = binary_cross_entropy(&logits, &targets, options).unwrap();
+        assert_array_eq!(loss, expected);
+    }
 }
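For context, the `with_logits` branch uses the identity log_add_exp(0, x) = log(1 + exp(x)) = softplus(x), so the loss reduces to softplus(logits) - targets * logits, the numerically stable form of binary cross entropy on raw logits; the -100.0 clamp in the probability branch guards log(0) when inputs are already probabilities. A minimal usage sketch follows, assuming the function is called from outside this module; the import paths and the location of `LossReduction` are assumptions, not part of the diff:

    use mlx_rs::array;
    use mlx_nn::losses::{binary_cross_entropy, BinaryCrossEntropyOptions, LossReduction};

    fn main() {
        // Raw (unnormalized) logits and binary targets, as in the tests above.
        let logits = array!([0.105361, 0.223144, 1.20397, 0.916291]);
        let targets = array!([0.0, 0.0, 1.0, 1.0]);

        // `with_logits` defaults to true, so only the reduction is overridden here.
        let options = BinaryCrossEntropyOptions {
            reduction: Some(LossReduction::Mean),
            ..Default::default()
        };

        // Returns Result<Array, Exception>; unwrap for brevity as the tests do.
        let loss = binary_cross_entropy(&logits, &targets, options).unwrap();
        println!("mean BCE: {:?}", loss);
    }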