added binary_cross_entropy
minghuaw committed Oct 5, 2024
1 parent d2951f3 commit 74a7a06
Showing 1 changed file with 110 additions and 3 deletions.
113 changes: 110 additions & 3 deletions mlx-nn/src/losses.rs
@@ -1,7 +1,7 @@
use mlx_rs::{
    array,
    error::Exception,
-    ops::{indexing::take_along_axis, log_sum_exp, multiply, sum},
+    ops::{clip, indexing::take_along_axis, log, log_add_exp, log_sum_exp, multiply, sum},
    Array,
};

@@ -114,20 +114,83 @@
    };

    if let Some(weights) = weight {
-        assert_eq!(weights.shape(), loss.shape());
+        // assert_eq!(weights.shape(), loss.shape()); // TODO: is this necessary?
        loss = multiply(loss, weights)?;
    }

    reduction.reduce(loss)
}

/// Optional parameters for the `binary_cross_entropy` function.
#[derive(Debug, Clone, Default)]
pub struct BinaryCrossEntropyOptions<'a> {
    /// Optional weights for each target
    pub weights: Option<&'a Array>,

    /// Whether the inputs are logits. Defaults to
    /// [`BinaryCrossEntropyOptions::DEFAULT_WITH_LOGITS`] if `None`
    pub with_logits: Option<bool>,

    /// Reduction type. Defaults to [`BinaryCrossEntropyOptions::DEFAULT_REDUCTION`] if `None`
    pub reduction: Option<LossReduction>,
}

impl<'a> BinaryCrossEntropyOptions<'a> {
    /// Default value for the `with_logits` parameter.
    pub const DEFAULT_WITH_LOGITS: bool = true;

    /// Default value for the `reduction` parameter.
    pub const DEFAULT_REDUCTION: LossReduction = LossReduction::None;
}

/// Computes the binary cross entropy loss.
///
/// # Params
///
/// - `logits`: unnormalized predicted logits (treated as probabilities when
///   `with_logits` is set to `false` in the options)
/// - `targets`: binary target values in {0, 1}
/// - `options`: optional parameters. See [`BinaryCrossEntropyOptions`] for more details
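///
/// # Example
///
/// A minimal usage sketch (illustrative values, not from the original source):
///
/// ```rust,ignore
/// let logits = array!([0.1, -0.5, 2.0]);
/// let targets = array!([1.0, 0.0, 1.0]);
/// let options = BinaryCrossEntropyOptions {
///     reduction: Some(LossReduction::Mean),
///     ..Default::default()
/// };
/// let loss = binary_cross_entropy(&logits, &targets, options).unwrap();
/// ```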
pub fn binary_cross_entropy(
    logits: impl AsRef<Array>,
    targets: impl AsRef<Array>,
    options: BinaryCrossEntropyOptions<'_>,
) -> Result<Array, Exception> {
    let logits = logits.as_ref();
    let targets = targets.as_ref();
    let weights = options.weights;
    let with_logits = options
        .with_logits
        .unwrap_or(BinaryCrossEntropyOptions::DEFAULT_WITH_LOGITS);
    let reduction = options
        .reduction
        .unwrap_or(BinaryCrossEntropyOptions::DEFAULT_REDUCTION);

    // With logits x and targets t, `log_add_exp(0, x) - t * x` equals
    // log(1 + e^x) - t*x, a numerically stable form of
    // -(t*log(sigmoid(x)) + (1 - t)*log(1 - sigmoid(x))).
    // On the probabilities path, log(p) and log(1 - p) are clipped at -100
    // so that inputs of exactly 0 or 1 still produce a finite loss.
    let mut loss = if with_logits {
        log_add_exp(array!(0.0), logits)?.subtract(targets.multiply(logits)?)?
    } else {
        let log_inputs_clip = clip(&log(logits), (-100.0, ()))?;
        let log_inputs_inverse_clip =
            clip(&log(&array!(1.0).subtract(logits)?), (-100.0, ()))?;
        -(targets.multiply(log_inputs_clip)?.add(
            array!(1.0)
                .subtract(targets)?
                .multiply(log_inputs_inverse_clip)?,
        )?)
    };

    if let Some(weights) = weights {
        // assert_eq!(weights.shape(), loss.shape()); // TODO: is this necessary?
        loss = multiply(loss, weights)?;
    }

    reduction.reduce(loss)
}

// The following unit tests are adapted from the python API at: mlx/python/tests/test_losses.py
#[cfg(test)]
mod tests {
    use mlx_rs::{array, assert_array_eq, ops::is_nan};

    use super::*;

    #[test]
    fn test_cross_entropy() {
        // No weights, no label smoothing
@@ -213,4 +276,48 @@ mod tests {
        let loss = cross_entropy(logits, probs, options).unwrap();
        assert_array_eq!(loss, expected);
    }

    #[test]
    fn test_binary_cross_entropy_with_logits_as_inputs() {
        let logits = array!([0.105361, 0.223144, 1.20397, 0.916291]);
        let targets = array!([0.0, 0.0, 1.0, 1.0]);

        // Test with reduction 'none'
        let options = BinaryCrossEntropyOptions {
            reduction: Some(LossReduction::None),
            ..Default::default()
        };
        let loss_none = binary_cross_entropy(&logits, &targets, options).unwrap();
        let expected_none = array!([0.747215, 0.810930, 0.262365, 0.336472]);
        assert_array_eq!(loss_none, expected_none);

        // Test with reduction 'mean'
        let options = BinaryCrossEntropyOptions {
            reduction: Some(LossReduction::Mean),
            ..Default::default()
        };
        let loss_mean = binary_cross_entropy(&logits, &targets, options).unwrap();
        let expected_mean = expected_none.mean(None, None).unwrap();
        assert_array_eq!(loss_mean, expected_mean);

        // Test with reduction 'sum'
        let options = BinaryCrossEntropyOptions {
            reduction: Some(LossReduction::Sum),
            ..Default::default()
        };
        let loss = binary_cross_entropy(&logits, &targets, options).unwrap();
        let expected = expected_none.sum(None, None).unwrap();
        assert_array_eq!(loss, expected);

        // With weights, no label smoothing
        let weights = array!([1.0, 2.0, 1.0, 2.0]);
        let expected = array!([0.747215, 1.62186, 0.262365, 0.672944]);
        let options = BinaryCrossEntropyOptions {
            weights: Some(&weights),
            reduction: Some(LossReduction::None),
            ..Default::default()
        };
        let loss = binary_cross_entropy(&logits, &targets, options).unwrap();
        assert_array_eq!(loss, expected);
    }
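
    // A sketch added for illustration (not part of the original commit):
    // exercises the probabilities path (`with_logits: Some(false)`), where
    // the loss is -(t*ln(p) + (1-t)*ln(1-p)). The expected values follow
    // from ln(0.9) ≈ -0.105361 and ln(0.8) ≈ -0.223144.
    #[test]
    fn test_binary_cross_entropy_with_probs_as_inputs() {
        let probs = array!([0.1, 0.2, 0.8, 0.9]);
        let targets = array!([0.0, 0.0, 1.0, 1.0]);

        let options = BinaryCrossEntropyOptions {
            with_logits: Some(false),
            reduction: Some(LossReduction::None),
            ..Default::default()
        };
        let loss = binary_cross_entropy(&probs, &targets, options).unwrap();
        let expected = array!([0.105361, 0.223144, 0.223144, 0.105361]);
        assert_array_eq!(loss, expected);
    }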
}
