diff --git a/Cargo.toml b/Cargo.toml index a27ce2e9..8b9f28e6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,10 +1,23 @@ [workspace] -members = ["mlx-macros", "mlx-sys", "mlx-rs", "mlx-nn"] +members = [ + "mlx-macros", + "mlx-sys", + "mlx-rs", + "mlx-nn", + "examples/*", "mlx-internal-macros", +] resolver = "2" [workspace.dependencies] +# workspace local dependencies mlx-sys = { version = "0.0.9", path = "mlx-sys" } mlx-macros = { version = "0.1.0", path = "mlx-macros" } -mlx-rs = { version = "0.14.0", path = "mlx-rs" } \ No newline at end of file +mlx-internal-macros = { version = "0.1.0", path = "mlx-internal-macros" } +mlx-rs = { version = "0.14.0", path = "mlx-rs" } +mlx-nn = { version = "0.1.0", path = "mlx-nn" } + +# external dependencies +thiserror = "1" +float_eq = "1" diff --git a/examples/mnist/Cargo.toml b/examples/mnist/Cargo.toml new file mode 100644 index 00000000..13991682 --- /dev/null +++ b/examples/mnist/Cargo.toml @@ -0,0 +1,8 @@ +[package] +name = "mnist" +version = "0.1.0" +edition = "2021" + +[dependencies] +mlx-rs.workspace = true +mlx-nn.workspace = true \ No newline at end of file diff --git a/examples/mnist/src/main.rs b/examples/mnist/src/main.rs new file mode 100644 index 00000000..eca357a2 --- /dev/null +++ b/examples/mnist/src/main.rs @@ -0,0 +1,71 @@ +use mlx_nn::{ + losses::{CrossEntropy, LossReduction}, + module_value_and_grad, + optimizers::Optimizer, +}; +use mlx_rs::{ + array, + error::Exception, + module::{Module, ModuleParameters}, + transforms::eval_params, + Array, +}; + +/// MLP model +mod mlp; + +/// Retrieves MNIST dataset +mod mnist; + +#[derive(Clone)] +struct Loader {} + +impl Iterator for Loader { + type Item = (Array, Array); + + fn next(&mut self) -> Option { + todo!() + } +} + +fn load_training_data() -> Result> { + todo!() +} + +fn main() -> Result<(), Box> { + let num_layers = 3; + let input_dim = 784; + let hidden_dim = 256; + let output_dim = 10; + let lr = 1e-2; + let num_epochs = 10; + + let loader = load_training_data()?; + let mut model = mlp::Mlp::new(num_layers, input_dim, hidden_dim, output_dim)?; + + let cross_entropy = CrossEntropy::builder() + .reduction(LossReduction::Mean) + .build()?; + let loss_fn = |model: &mlp::Mlp, (x, y): (&Array, &Array)| -> Result { + let y_pred = model.forward(x)?; + cross_entropy.apply(y_pred, y) + }; + let mut loss_and_grad_fn = module_value_and_grad(loss_fn); + + let mut optimizer = mlx_nn::optimizers::Sgd::new(lr); + + for _ in 0..num_epochs { + let mut loss_sum = array!(0.0); + for (x, y) in loader.clone() { + let (loss, grad) = loss_and_grad_fn(&mut model, (&x, &y))?; + optimizer.update(&mut model, grad); + eval_params(model.parameters())?; + + loss_sum += loss; + } + + println!("Epoch: {}, Loss sum: {}", num_epochs, loss_sum); + } + + Ok(()) +} diff --git a/examples/mnist/src/mlp.rs b/examples/mnist/src/mlp.rs new file mode 100644 index 00000000..bdc87a7d --- /dev/null +++ b/examples/mnist/src/mlp.rs @@ -0,0 +1,50 @@ +use mlx_nn::{macros::ModuleParameters, module::Param, Linear, Relu, Sequential}; +use mlx_rs::{error::Exception, module::Module, Array}; + +#[derive(Debug, ModuleParameters)] +pub struct Mlp { + #[param] + pub layers: Param, +} + +impl Module for Mlp { + type Error = Exception; + + fn forward(&self, x: &Array) -> Result { + self.layers.forward(x) + } + + fn training_mode(&mut self, mode: bool) { + self.layers.training_mode(mode); + } +} + +impl Mlp { + pub fn new( + num_layers: usize, + input_dim: i32, + hidden_dim: i32, + output_dim: i32, + ) -> Result { + let mut layers = 
Sequential::new(); + + // Add the input layer + layers = layers + .append(Linear::new(input_dim, hidden_dim)?) + .append(Relu); + + // Add the hidden layers + for _ in 1..num_layers { + layers = layers + .append(Linear::new(hidden_dim, hidden_dim)?) + .append(Relu); + } + + // Add the output layer + layers = layers.append(Linear::new(hidden_dim, output_dim)?); + + Ok(Self { + layers: Param::new(layers), + }) + } +} diff --git a/examples/mnist/src/mnist.rs b/examples/mnist/src/mnist.rs new file mode 100644 index 00000000..70b786d1 --- /dev/null +++ b/examples/mnist/src/mnist.rs @@ -0,0 +1 @@ +// TODO diff --git a/mlx-internal-macros/Cargo.toml b/mlx-internal-macros/Cargo.toml new file mode 100644 index 00000000..ccf21cec --- /dev/null +++ b/mlx-internal-macros/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "mlx-internal-macros" +version = "0.1.0" +edition = "2021" + +[lib] +proc-macro = true + +[dependencies] +syn = { version = "2.0.60", features = ["full"] } +quote = "1.0" +darling = "0.20" +proc-macro2 = "1.0" \ No newline at end of file diff --git a/mlx-internal-macros/src/generate_builder.rs b/mlx-internal-macros/src/generate_builder.rs new file mode 100644 index 00000000..d0ef27a9 --- /dev/null +++ b/mlx-internal-macros/src/generate_builder.rs @@ -0,0 +1,199 @@ +use darling::FromAttributes; +use quote::quote; +use syn::{Attribute, Ident, ItemStruct, Path}; + +#[derive(Debug, FromAttributes)] +#[darling(attributes(generate_builder))] +struct StructAttr { + generate_build_fn: Option, +} + +#[derive(Debug, FromAttributes)] +#[darling(attributes(optional))] +struct FieldAttr { + default_value: Option, + skip: Option, +} + +fn attrs_contains_optional(attrs: &[Attribute]) -> bool { + attrs.iter().any(|attr| attr.path().is_ident("optional")) +} + +pub(crate) fn expand_generate_builder( + input: &ItemStruct, +) -> Result> { + let generate_build_fn = StructAttr::from_attributes(&input.attrs)? + .generate_build_fn + .unwrap_or(true); + let struct_ident = &input.ident; + let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); + + let mut optional_field_idents = Vec::new(); + let mut optional_field_types = Vec::new(); + let mut optional_field_defaults = Vec::new(); + let mut optional_field_skip = Vec::new(); + let mut mandatory_field_idents = Vec::new(); + let mut mandatory_field_types = Vec::new(); + + for field in input.fields.iter() { + if attrs_contains_optional(&field.attrs) { + let field_attr = FieldAttr::from_attributes(&field.attrs)?; + let skip = field_attr.skip.unwrap_or(false); + if skip && generate_build_fn { + return Err("Skip is not allowed when build function is generated".into()); + } + optional_field_skip.push(skip); + + optional_field_idents.push( + field + .ident + .as_ref() + .ok_or("Only named fields are supported")?, + ); + optional_field_types.push(&field.ty); + if generate_build_fn { + optional_field_defaults.push(field_attr.default_value); + } + } else { + mandatory_field_idents.push(&field.ident); + mandatory_field_types.push(&field.ty); + } + } + + let builder_ident = Ident::new(&format!("{}Builder", struct_ident), struct_ident.span()); + let modified_optional_field_types = optional_field_types + .iter() + .zip(optional_field_skip.iter()) + .map(|(field_type, skip)| { + if !skip { + quote! { Option<#field_type> } + } else { + quote! { #field_type } + } + }); + + let builder_struct_doc = format!("Builder for [`{}`]", struct_ident); + let field_doc = format!("See [`{}`] for more details", struct_ident); + let builder_struct = quote! 
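+    // Illustration (hypothetical struct): for `struct Conv { #[optional] stride: i32, weight: Array }`
+    // this emits `pub struct ConvBuilder { pub stride: Option<i32> }`; a field marked `skip = true`
+    // keeps its original, non-`Option` type in the builder, and mandatory fields are not included.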
{ + #[doc = #builder_struct_doc] + #[derive(Debug, Clone, Default)] + pub struct #builder_ident #ty_generics #where_clause { + #( + #[doc = #field_doc] + pub #optional_field_idents: #modified_optional_field_types, + )* + } + }; + + let builder_new_doc = format!("Create a new [`{}`]", builder_ident); + let struct_builder_doc = format!( + "Create a new [`{}`] builder with the default values", + struct_ident + ); + + let builder_init = quote! { + impl #impl_generics #builder_ident #ty_generics #where_clause { + #[doc = #builder_new_doc] + pub fn new() -> Self { + Self::default() + } + } + + impl #impl_generics #struct_ident #ty_generics #where_clause { + #[doc = #struct_builder_doc] + pub fn builder() -> #builder_ident #ty_generics { + #builder_ident::new() + } + } + }; + + let builder_setter_docs = optional_field_idents + .iter() + .zip(optional_field_skip.iter()) + .filter_map(|(field_ident, skip)| { + if !skip { + Some(format!("Set the value of `{:?}`", field_ident)) + } else { + None + } + }); + let filtered_optional_field_idents = optional_field_idents + .iter() + .zip(optional_field_skip.iter()) + .filter_map(|(field_ident, skip)| if !skip { Some(field_ident) } else { None }); + let filtered_optional_field_types = optional_field_types + .iter() + .zip(optional_field_skip.iter()) + .filter_map(|(field_type, skip)| if !skip { Some(field_type) } else { None }); + + let builder_setters = quote! { + impl #impl_generics #builder_ident #ty_generics #where_clause { + #( + #[doc = #builder_setter_docs] + pub fn #filtered_optional_field_idents(mut self, value: impl Into<#filtered_optional_field_types>) -> Self { + self.#filtered_optional_field_idents = Some(value.into()); + self + } + )* + } + }; + + let builder_build = if generate_build_fn { + let builder_build_doc = format!("Build a new [`{}`]", struct_ident); + let struct_new_doc = format!("Create a new [`{}`] with default values", struct_ident); + let optional_field_defaults: Vec<_> = optional_field_defaults + .iter() + .map(|default| { + default + .clone() + .ok_or("Default value must be supplied to generate build function") + }) + .collect::, _>>()?; + + quote! { + impl #impl_generics #builder_ident #ty_generics #where_clause { + #[doc = #builder_build_doc] + pub fn build(self, #(#mandatory_field_idents: #mandatory_field_types),*) -> #struct_ident #ty_generics { + #struct_ident { + #( + #mandatory_field_idents, + )* + #( + #optional_field_idents: self.#optional_field_idents.unwrap_or_else(|| #optional_field_defaults), + )* + } + } + } + + impl #impl_generics #struct_ident #ty_generics #where_clause { + #[doc = #struct_new_doc] + pub fn new(#(#mandatory_field_idents: #mandatory_field_types),*) -> Self { + Self::builder().build(#(#mandatory_field_idents),*) + } + } + } + } else { + quote! {} + }; + + // Only implement Default trait if no mandatory fields are present + let default_impl = if mandatory_field_idents.is_empty() && generate_build_fn { + quote! { + impl #impl_generics Default for #struct_ident #ty_generics #where_clause { + fn default() -> Self { + Self::new() + } + } + } + } else { + quote! {} + }; + + Ok(quote! 
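+    // Final expansion: the builder struct, the `Builder::new()`/`Struct::builder()` constructors,
+    // the setters, the optional `build()`/`new()` pair, and the optional `Default` impl.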
{ + #builder_struct + #builder_init + #builder_setters + #builder_build + #default_impl + }) +} diff --git a/mlx-internal-macros/src/lib.rs b/mlx-internal-macros/src/lib.rs new file mode 100644 index 00000000..b08cc826 --- /dev/null +++ b/mlx-internal-macros/src/lib.rs @@ -0,0 +1,193 @@ +extern crate proc_macro; +use darling::FromMeta; +use proc_macro::TokenStream; +use quote::{format_ident, quote}; +use syn::punctuated::Punctuated; +use syn::{parse_macro_input, parse_quote, DeriveInput, FnArg, ItemFn, ItemStruct, Pat}; + +mod generate_builder; + +#[derive(Debug, FromMeta)] +enum DeviceType { + Cpu, + Gpu, +} + +#[derive(Debug)] +struct DefaultDeviceInput { + device: DeviceType, +} + +impl FromMeta for DefaultDeviceInput { + fn from_meta(meta: &syn::Meta) -> darling::Result { + let syn::Meta::NameValue(meta_name_value) = meta else { + return Err(darling::Error::unsupported_format( + "expected a name-value attribute", + )); + }; + + let ident = meta_name_value.path.get_ident().unwrap(); + assert_eq!(ident, "device", "expected `device`"); + + let device = DeviceType::from_expr(&meta_name_value.value)?; + + Ok(DefaultDeviceInput { device }) + } +} + +#[doc(hidden)] +#[proc_macro_attribute] +pub fn default_device(attr: TokenStream, item: TokenStream) -> TokenStream { + let input = if !attr.is_empty() { + let meta = syn::parse_macro_input!(attr as syn::Meta); + Some(DefaultDeviceInput::from_meta(&meta).unwrap()) + } else { + None + }; + + let mut input_fn = parse_macro_input!(item as ItemFn); + let original_fn = input_fn.clone(); + + // Ensure function name convention + if !input_fn.sig.ident.to_string().contains("_device") { + panic!("Function name must end with '_device'"); + } + let new_fn_name = format_ident!("{}", &input_fn.sig.ident.to_string().replace("_device", "")); + input_fn.sig.ident = new_fn_name; + + // Filter out the `stream` parameter and reconstruct the Punctuated collection + let filtered_inputs = input_fn + .sig + .inputs + .iter() + .filter(|arg| match arg { + FnArg::Typed(pat_typed) => { + if let Pat::Ident(pat_ident) = &*pat_typed.pat { + pat_ident.ident != "stream" + } else { + true + } + } + _ => true, + }) + .cloned() + .collect::>(); + + input_fn.sig.inputs = Punctuated::from_iter(filtered_inputs); + + // Prepend default stream initialization + let default_stream_stmt = match input.map(|input| input.device) { + Some(DeviceType::Cpu) => parse_quote! { + let stream = StreamOrDevice::cpu(); + }, + Some(DeviceType::Gpu) => parse_quote! { + let stream = StreamOrDevice::gpu(); + }, + None => parse_quote! { + let stream = StreamOrDevice::default(); + }, + }; + input_fn.block.stmts.insert(0, default_stream_stmt); + + // Combine the original and modified functions into the output + let expanded = quote! { + #original_fn + + #input_fn + }; + + TokenStream::from(expanded) +} + +#[doc(hidden)] +#[proc_macro_derive(GenerateDtypeTestCases)] +pub fn generate_test_cases(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as DeriveInput); + let name = &input.ident; + + let tests = quote! { + /// MLX's rules for promoting two dtypes. 
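+        /// Rows are indexed by the left operand and columns by the right operand
+        /// (both via `Dtype as usize`), matching the column labels in the comment below.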
+ #[rustfmt::skip] + const TYPE_RULES: [[Dtype; 13]; 13] = [ + // bool uint8 uint16 uint32 uint64 int8 int16 int32 int64 float16 float32 bfloat16 complex64 + [Dtype::Bool, Dtype::Uint8, Dtype::Uint16, Dtype::Uint32, Dtype::Uint64, Dtype::Int8, Dtype::Int16, Dtype::Int32, Dtype::Int64, Dtype::Float16, Dtype::Float32, Dtype::Bfloat16, Dtype::Complex64], // bool + [Dtype::Uint8, Dtype::Uint8, Dtype::Uint16, Dtype::Uint32, Dtype::Uint64, Dtype::Int16, Dtype::Int16, Dtype::Int32, Dtype::Int64, Dtype::Float16, Dtype::Float32, Dtype::Bfloat16, Dtype::Complex64], // uint8 + [Dtype::Uint16, Dtype::Uint16, Dtype::Uint16, Dtype::Uint32, Dtype::Uint64, Dtype::Int32, Dtype::Int32, Dtype::Int32, Dtype::Int64, Dtype::Float16, Dtype::Float32, Dtype::Bfloat16, Dtype::Complex64], // uint16 + [Dtype::Uint32, Dtype::Uint32, Dtype::Uint32, Dtype::Uint32, Dtype::Uint64, Dtype::Int64, Dtype::Int64, Dtype::Int64, Dtype::Int64, Dtype::Float16, Dtype::Float32, Dtype::Bfloat16, Dtype::Complex64], // uint32 + [Dtype::Uint64, Dtype::Uint64, Dtype::Uint64, Dtype::Uint64, Dtype::Uint64, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float16, Dtype::Float32, Dtype::Bfloat16, Dtype::Complex64], // uint64 + [Dtype::Int8, Dtype::Int16, Dtype::Int32, Dtype::Int64, Dtype::Float32, Dtype::Int8, Dtype::Int16, Dtype::Int32, Dtype::Int64, Dtype::Float16, Dtype::Float32, Dtype::Bfloat16, Dtype::Complex64], // int8 + [Dtype::Int16, Dtype::Int16, Dtype::Int32, Dtype::Int64, Dtype::Float32, Dtype::Int16, Dtype::Int16, Dtype::Int32, Dtype::Int64, Dtype::Float16, Dtype::Float32, Dtype::Bfloat16, Dtype::Complex64], // int16 + [Dtype::Int32, Dtype::Int32, Dtype::Int32, Dtype::Int64, Dtype::Float32, Dtype::Int32, Dtype::Int32, Dtype::Int32, Dtype::Int64, Dtype::Float16, Dtype::Float32, Dtype::Bfloat16, Dtype::Complex64], // int32 + [Dtype::Int64, Dtype::Int64, Dtype::Int64, Dtype::Int64, Dtype::Float32, Dtype::Int64, Dtype::Int64, Dtype::Int64, Dtype::Int64, Dtype::Float16, Dtype::Float32, Dtype::Bfloat16, Dtype::Complex64], // int64 + [Dtype::Float16, Dtype::Float16, Dtype::Float16, Dtype::Float16, Dtype::Float16, Dtype::Float16, Dtype::Float16, Dtype::Float16, Dtype::Float16, Dtype::Float16, Dtype::Float32, Dtype::Float32, Dtype::Complex64], // float16 + [Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Complex64], // float32 + [Dtype::Bfloat16, Dtype::Bfloat16, Dtype::Bfloat16, Dtype::Bfloat16, Dtype::Bfloat16, Dtype::Bfloat16, Dtype::Bfloat16, Dtype::Bfloat16, Dtype::Bfloat16, Dtype::Float32, Dtype::Float32, Dtype::Bfloat16, Dtype::Complex64], // bfloat16 + [Dtype::Complex64, Dtype::Complex64, Dtype::Complex64, Dtype::Complex64, Dtype::Complex64, Dtype::Complex64, Dtype::Complex64, Dtype::Complex64, Dtype::Complex64, Dtype::Complex64, Dtype::Complex64, Dtype::Complex64, Dtype::Complex64], // complex64 + ]; + + #[cfg(test)] + mod generated_tests { + use super::*; + use strum::IntoEnumIterator; + use pretty_assertions::assert_eq; + + #[test] + fn test_all_combinations() { + for a in #name::iter() { + for b in #name::iter() { + let result = a.promote_with(b); + let expected = TYPE_RULES[a as usize][b as usize]; + assert_eq!(result, expected, "{}", format!("Failed promotion test for {:?} and {:?}", a, b)); + } + } + } + } + }; + + TokenStream::from(tests) +} + +/// This is for internal use only +/// +/// The struct that this macro is applied to should NOT derive the 
`Default` trait.
+///
+/// This macro takes the following attributes:
+///
+/// - `generate_builder`: This attribute should be applied on the struct.
+///
+///   ## Arguments
+///
+///   - `generate_build_fn = <bool>`: It defaults to `true` if not specified. If `true`, it will:
+///     1. generate a `<Struct>Builder::build` function that takes the mandatory fields as arguments
+///        and returns the struct.
+///     2. generate a `<Struct>::new` function that takes the mandatory fields as arguments and
+///        returns the struct. This is a convenience function that simply calls
+///        `<Struct>Builder::new().build(...)`. Additionally, if there is NO mandatory field, it
+///        will implement the `Default` trait for the struct.
+///
+/// - `optional`: This attribute should be applied on the field. It indicates that the field is
+///   optional. Behaviour of the generated builder struct depends on the argument of this attribute.
+///
+///   ## Arguments
+///
+///   - `skip = <bool>`: Default `false`. If `true`, the macro will NOT generate a setter for this
+///     field in the builder struct. It will also NOT wrap the field in an `Option` in the struct,
+///     and this field will remain as its original type in the builder struct. It is the user's
+///     responsibility to implement the setter for this field in the builder struct.
+///
+///     The `build` function cannot be generated if any field is marked as `skip = true`, and an
+///     error will be shown in that case.
+///
+///   - `default_value = <path>`: This argument is required if no field is marked as `skip = true`.
+///     It specifies the default value for the field. The value should be a `Path` (something that
+///     is interpreted as a `syn::Path`) to a constant or an enum variant.
+///
+///     If either `generate_build_fn` is `false` or any field is marked as `skip = true`, this
+///     argument is not required.
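+///
+/// ## Example
+///
+/// A minimal sketch mirroring `tests/test_generate_builder.rs`; the struct and
+/// constant below are illustrative only.
+///
+/// ```rust, ignore
+/// use mlx_internal_macros::GenerateBuilder;
+///
+/// #[derive(Debug, Default, GenerateBuilder)]
+/// struct TestStruct {
+///     #[optional(default_value = TestStruct::DEFAULT_OPT_FIELD_1)]
+///     opt_field_1: i32,
+///     mandatory_field_1: i32,
+/// }
+///
+/// impl TestStruct {
+///     pub const DEFAULT_OPT_FIELD_1: i32 = 1;
+/// }
+///
+/// // Generated API: `TestStruct::builder()`, a setter per optional field, and a
+/// // `build(..)` (plus `TestStruct::new(..)`) taking the mandatory fields in order.
+/// let built = TestStruct::builder().opt_field_1(2).build(4);
+/// assert_eq!(built.opt_field_1, 2);
+/// assert_eq!(built.mandatory_field_1, 4);
+/// ```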
+#[doc(hidden)] +#[proc_macro_derive(GenerateBuilder, attributes(generate_builder, optional))] +pub fn derive_generate_builder(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as ItemStruct); + let builder = generate_builder::expand_generate_builder(&input).unwrap(); + TokenStream::from(builder) +} diff --git a/mlx-internal-macros/tests/test_generate_builder.rs b/mlx-internal-macros/tests/test_generate_builder.rs new file mode 100644 index 00000000..d10840fb --- /dev/null +++ b/mlx-internal-macros/tests/test_generate_builder.rs @@ -0,0 +1,29 @@ +use mlx_internal_macros::GenerateBuilder; + +#[derive(Debug, Default, GenerateBuilder)] +struct TestStruct { + #[optional(default_value = TestStruct::DEFAULT_OPT_FIELD_1)] + opt_field_1: i32, + #[optional(default_value = TestStruct::DEFAULT_OPT_FIELD_2)] + opt_field_2: i32, + mandatory_field_1: i32, + mandatory_field_2: i32, +} + +impl TestStruct { + pub const DEFAULT_OPT_FIELD_1: i32 = 1; + pub const DEFAULT_OPT_FIELD_2: i32 = 2; +} + +#[test] +fn build_test_struct() { + let test_struct = TestStruct::builder() + .opt_field_1(2) + .opt_field_2(3) + .build(4, 5); + + assert_eq!(test_struct.opt_field_1, 2); + assert_eq!(test_struct.opt_field_2, 3); + assert_eq!(test_struct.mandatory_field_1, 4); + assert_eq!(test_struct.mandatory_field_2, 5); +} diff --git a/mlx-macros/Cargo.toml b/mlx-macros/Cargo.toml index 37093c89..d9d5876d 100644 --- a/mlx-macros/Cargo.toml +++ b/mlx-macros/Cargo.toml @@ -18,3 +18,4 @@ proc-macro = true syn = { version = "2.0.60", features = ["full"] } quote = "1.0" darling = "0.20" +proc-macro2 = "1.0" diff --git a/mlx-macros/src/lib.rs b/mlx-macros/src/lib.rs index 0a854459..e5a501d6 100644 --- a/mlx-macros/src/lib.rs +++ b/mlx-macros/src/lib.rs @@ -1,144 +1,59 @@ extern crate proc_macro; -use darling::FromMeta; use proc_macro::TokenStream; -use quote::{format_ident, quote}; -use syn::punctuated::Punctuated; -use syn::{parse_macro_input, parse_quote, DeriveInput, FnArg, ItemFn, Pat}; - -#[derive(Debug, FromMeta)] -enum DeviceType { - Cpu, - Gpu, -} - -#[derive(Debug)] -struct DefaultDeviceInput { - device: DeviceType, -} - -impl FromMeta for DefaultDeviceInput { - fn from_meta(meta: &syn::Meta) -> darling::Result { - let syn::Meta::NameValue(meta_name_value) = meta else { - return Err(darling::Error::unsupported_format( - "expected a name-value attribute", - )); - }; - - let ident = meta_name_value.path.get_ident().unwrap(); - assert_eq!(ident, "device", "expected `device`"); - - let device = DeviceType::from_expr(&meta_name_value.value)?; - - Ok(DefaultDeviceInput { device }) - } -} - -#[proc_macro_attribute] -pub fn default_device(attr: TokenStream, item: TokenStream) -> TokenStream { - let input = if !attr.is_empty() { - let meta = syn::parse_macro_input!(attr as syn::Meta); - Some(DefaultDeviceInput::from_meta(&meta).unwrap()) - } else { - None - }; - - let mut input_fn = parse_macro_input!(item as ItemFn); - let original_fn = input_fn.clone(); - - // Ensure function name convention - if !input_fn.sig.ident.to_string().contains("_device") { - panic!("Function name must end with '_device'"); - } - let new_fn_name = format_ident!("{}", &input_fn.sig.ident.to_string().replace("_device", "")); - input_fn.sig.ident = new_fn_name; - - // Filter out the `stream` parameter and reconstruct the Punctuated collection - let filtered_inputs = input_fn - .sig - .inputs - .iter() - .filter(|arg| match arg { - FnArg::Typed(pat_typed) => { - if let Pat::Ident(pat_ident) = &*pat_typed.pat { - pat_ident.ident != 
"stream" - } else { - true - } - } - _ => true, - }) - .cloned() - .collect::>(); - - input_fn.sig.inputs = Punctuated::from_iter(filtered_inputs); - - // Prepend default stream initialization - let default_stream_stmt = match input.map(|input| input.device) { - Some(DeviceType::Cpu) => parse_quote! { - let stream = StreamOrDevice::cpu(); - }, - Some(DeviceType::Gpu) => parse_quote! { - let stream = StreamOrDevice::gpu(); - }, - None => parse_quote! { - let stream = StreamOrDevice::default(); - }, - }; - input_fn.block.stmts.insert(0, default_stream_stmt); - - // Combine the original and modified functions into the output - let expanded = quote! { - #original_fn - - #input_fn - }; - - TokenStream::from(expanded) -} - -#[proc_macro_derive(GenerateDtypeTestCases)] -pub fn generate_test_cases(input: TokenStream) -> TokenStream { +use quote::quote; +use syn::{parse_macro_input, DeriveInput}; + +mod module_parameters; + +/// Derive the `ModuleParameters` trait for a struct. Mark a field with `#[param]` attribute to +/// include it in the parameters. The field type must implement the `Parameter` trait defined in +/// `mlx-nn-module` crate. +/// +/// Make sure to include `mlx-nn-module` as a dependency in your `Cargo.toml`. +/// +/// # Example +/// +/// ```rust, ignore +/// use mlx_macros::ModuleParameters; +/// use mlx_rs::module::{ModuleParameters, Param}; +/// +/// #[derive(ModuleParameters)] +/// struct Example { +/// #[param] +/// regular: Param, +/// +/// #[param] +/// optional: Param>, +/// +/// #[param] +/// nested: Param, +/// +/// #[param] +/// vec_nested: Param>, +/// +/// #[param] +/// trait_object: Param>, +/// +/// #[param] +/// trait_object_vec: Param>>, +/// } +/// +/// #[derive(ModuleParameters)] +/// struct Inner { +/// #[param] +/// a: Param, +/// } +/// ``` +#[proc_macro_derive(ModuleParameters, attributes(param))] +pub fn derive_module_parameters(input: TokenStream) -> TokenStream { let input = parse_macro_input!(input as DeriveInput); - let name = &input.ident; - - let tests = quote! { - /// MLX's rules for promoting two dtypes. 
- #[rustfmt::skip] - const TYPE_RULES: [[Dtype; 13]; 13] = [ - // bool uint8 uint16 uint32 uint64 int8 int16 int32 int64 float16 float32 bfloat16 complex64 - [Dtype::Bool, Dtype::Uint8, Dtype::Uint16, Dtype::Uint32, Dtype::Uint64, Dtype::Int8, Dtype::Int16, Dtype::Int32, Dtype::Int64, Dtype::Float16, Dtype::Float32, Dtype::Bfloat16, Dtype::Complex64], // bool - [Dtype::Uint8, Dtype::Uint8, Dtype::Uint16, Dtype::Uint32, Dtype::Uint64, Dtype::Int16, Dtype::Int16, Dtype::Int32, Dtype::Int64, Dtype::Float16, Dtype::Float32, Dtype::Bfloat16, Dtype::Complex64], // uint8 - [Dtype::Uint16, Dtype::Uint16, Dtype::Uint16, Dtype::Uint32, Dtype::Uint64, Dtype::Int32, Dtype::Int32, Dtype::Int32, Dtype::Int64, Dtype::Float16, Dtype::Float32, Dtype::Bfloat16, Dtype::Complex64], // uint16 - [Dtype::Uint32, Dtype::Uint32, Dtype::Uint32, Dtype::Uint32, Dtype::Uint64, Dtype::Int64, Dtype::Int64, Dtype::Int64, Dtype::Int64, Dtype::Float16, Dtype::Float32, Dtype::Bfloat16, Dtype::Complex64], // uint32 - [Dtype::Uint64, Dtype::Uint64, Dtype::Uint64, Dtype::Uint64, Dtype::Uint64, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float16, Dtype::Float32, Dtype::Bfloat16, Dtype::Complex64], // uint64 - [Dtype::Int8, Dtype::Int16, Dtype::Int32, Dtype::Int64, Dtype::Float32, Dtype::Int8, Dtype::Int16, Dtype::Int32, Dtype::Int64, Dtype::Float16, Dtype::Float32, Dtype::Bfloat16, Dtype::Complex64], // int8 - [Dtype::Int16, Dtype::Int16, Dtype::Int32, Dtype::Int64, Dtype::Float32, Dtype::Int16, Dtype::Int16, Dtype::Int32, Dtype::Int64, Dtype::Float16, Dtype::Float32, Dtype::Bfloat16, Dtype::Complex64], // int16 - [Dtype::Int32, Dtype::Int32, Dtype::Int32, Dtype::Int64, Dtype::Float32, Dtype::Int32, Dtype::Int32, Dtype::Int32, Dtype::Int64, Dtype::Float16, Dtype::Float32, Dtype::Bfloat16, Dtype::Complex64], // int32 - [Dtype::Int64, Dtype::Int64, Dtype::Int64, Dtype::Int64, Dtype::Float32, Dtype::Int64, Dtype::Int64, Dtype::Int64, Dtype::Int64, Dtype::Float16, Dtype::Float32, Dtype::Bfloat16, Dtype::Complex64], // int64 - [Dtype::Float16, Dtype::Float16, Dtype::Float16, Dtype::Float16, Dtype::Float16, Dtype::Float16, Dtype::Float16, Dtype::Float16, Dtype::Float16, Dtype::Float16, Dtype::Float32, Dtype::Float32, Dtype::Complex64], // float16 - [Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Complex64], // float32 - [Dtype::Bfloat16, Dtype::Bfloat16, Dtype::Bfloat16, Dtype::Bfloat16, Dtype::Bfloat16, Dtype::Bfloat16, Dtype::Bfloat16, Dtype::Bfloat16, Dtype::Bfloat16, Dtype::Float32, Dtype::Float32, Dtype::Bfloat16, Dtype::Complex64], // bfloat16 - [Dtype::Complex64, Dtype::Complex64, Dtype::Complex64, Dtype::Complex64, Dtype::Complex64, Dtype::Complex64, Dtype::Complex64, Dtype::Complex64, Dtype::Complex64, Dtype::Complex64, Dtype::Complex64, Dtype::Complex64, Dtype::Complex64], // complex64 - ]; + let module_param_impl = module_parameters::expand_module_parameters(&input).unwrap(); - #[cfg(test)] - mod generated_tests { - use super::*; - use strum::IntoEnumIterator; - use pretty_assertions::assert_eq; - - #[test] - fn test_all_combinations() { - for a in #name::iter() { - for b in #name::iter() { - let result = a.promote_with(b); - let expected = TYPE_RULES[a as usize][b as usize]; - assert_eq!(result, expected, "{}", format!("Failed promotion test for {:?} and {:?}", a, b)); - } - } - } - } + let output = quote! 
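+    // Wrap the generated impl in an unnamed `const _` block and alias the crate as
+    // `_mlx_rs` so the expansion works regardless of how the caller imports `mlx_rs`.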
{ + const _: () = { + extern crate mlx_rs as _mlx_rs; + #module_param_impl + }; }; - - TokenStream::from(tests) + TokenStream::from(output) } diff --git a/mlx-macros/src/module_parameters.rs b/mlx-macros/src/module_parameters.rs new file mode 100644 index 00000000..c9c732b3 --- /dev/null +++ b/mlx-macros/src/module_parameters.rs @@ -0,0 +1,77 @@ +use syn::{DataStruct, DeriveInput, Generics, Ident}; + +pub(crate) fn expand_module_parameters( + input: &DeriveInput, +) -> Result { + let struct_ident = &input.ident; + let generics = &input.generics; + match &input.data { + syn::Data::Struct(data) => { + expand_module_parameters_for_struct(struct_ident, generics, data) + } + _ => Err(syn::Error::new_spanned( + input, + "ModuleParameters can only be derived for structs", + )), + } +} + +fn expand_module_parameters_for_struct( + ident: &Ident, + generics: &Generics, + data: &DataStruct, +) -> Result { + let fields = match &data.fields { + syn::Fields::Named(fields) => { + // filter out fields with #[param] + fields + .named + .iter() + .filter(|field| field.attrs.iter().any(|attr| attr.path().is_ident("param"))) + .collect() + } + syn::Fields::Unit => vec![], + syn::Fields::Unnamed(_) => { + return Err(syn::Error::new_spanned( + ident, + "ModuleParameters cannot be derived for structs with unnamed fields", + )) + } + }; + + Ok(impl_module_parameters_for_struct(ident, generics, fields)) +} + +fn impl_module_parameters_for_struct( + ident: &Ident, + generics: &Generics, + fields: Vec<&syn::Field>, +) -> proc_macro2::TokenStream { + let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); + let field_names: Vec<_> = fields.iter().map(|field| &field.ident).collect(); + quote::quote! { + impl #impl_generics _mlx_rs::module::ModuleParameters for #ident #ty_generics #where_clause { + fn parameters(&self) -> _mlx_rs::module::ModuleParamRef<'_> { + let mut parameters = _mlx_rs::nested::NestedHashMap::new(); + #(parameters.insert(stringify!(#field_names), _mlx_rs::module::Parameter::as_nested_value(&self.#field_names));)* + parameters + } + + fn parameters_mut(&mut self) -> _mlx_rs::module::ModuleParamMut<'_> { + let mut parameters = _mlx_rs::nested::NestedHashMap::new(); + #(parameters.insert(stringify!(#field_names), _mlx_rs::module::Parameter::as_nested_value_mut(&mut self.#field_names));)* + parameters + } + + fn trainable_parameters(&self) -> _mlx_rs::module::ModuleParamRef<'_> { + let mut parameters = _mlx_rs::nested::NestedHashMap::new(); + #( + if let Some(field) = _mlx_rs::module::Parameter::as_trainable_nested_value(&self.#field_names) { + parameters.insert(stringify!(#field_names), field); + } + )* + parameters + } + } + } +} diff --git a/mlx-nn/Cargo.toml b/mlx-nn/Cargo.toml index 72b9a564..101e5769 100644 --- a/mlx-nn/Cargo.toml +++ b/mlx-nn/Cargo.toml @@ -4,4 +4,10 @@ version = "0.1.0" edition = "2021" [dependencies] -mlx-rs = { workspace = true } \ No newline at end of file +mlx-rs.workspace = true +mlx-macros.workspace = true +mlx-internal-macros.workspace = true +thiserror.workspace = true + +[dev-dependencies] +float_eq.workspace = true \ No newline at end of file diff --git a/mlx-nn/src/activation.rs b/mlx-nn/src/activation.rs new file mode 100644 index 00000000..896d071b --- /dev/null +++ b/mlx-nn/src/activation.rs @@ -0,0 +1,1512 @@ +use std::f32::consts::PI; + +use mlx_internal_macros::GenerateBuilder; +use mlx_macros::ModuleParameters; +use mlx_rs::module::{Module, Param}; +use mlx_rs::{ + array, + error::Exception, + ops::{abs, exp, log_sum_exp, maximum, minimum, 
multiply, which}, + transforms::compile::compile, + Array, +}; + +/// Applies the element-wise sigmoid logistic sigmoid. +/// +/// For details, please see +/// [this documentation](https://ml-explore.github.io/mlx/build/html/python/_autosummary/mlx.core.sigmoid.html) +/// +/// This is: +/// +/// ```rust, ignore +/// sigmoid(x) +/// ``` +pub fn sigmoid(x: impl AsRef) -> Array { + mlx_rs::ops::sigmoid(x.as_ref()) +} + +/// Applies the Rectified Linear Unit. +/// +/// This is: +/// +/// ```rust, ignore +/// maximum(x, 0) +/// ``` +pub fn relu(x: impl AsRef) -> Result { + mlx_rs::ops::maximum(x.as_ref(), &array!(0)) +} + +/// Applies the Leaky Rectified Linear Unit. +/// +/// `neg_slope` is default to 0.01 if not provided. +/// +/// This is: +/// +/// ```rust, ignore +/// maximum(neg_slope * x, x) +/// ``` +pub fn leaky_relu( + x: impl AsRef, + neg_slope: impl Into>, +) -> Result { + let neg_slope = array!(neg_slope.into().unwrap_or(0.01)); + // We have to use this indirection, otherwise the compiler cannot + // infer the lifetime of the value returned by the closure properly + compiled_leaky_relu(x.as_ref(), &neg_slope) +} + +/// Applies the Log Softmax function. +/// +/// This is: +/// +/// ```rust, ignore +/// x - log_sum_exp(x, axis, true) +/// ``` +pub fn log_softmax(x: impl AsRef, axis: impl Into>) -> Result { + let x = x.as_ref(); + let axis = axis.into().unwrap_or(-1); + x.subtract(log_sum_exp(x, &[axis], true)?) +} + +/// Applies the Exponential Linear Unit. +/// +/// This is: +/// +/// ```rust, ignore +/// which(x.gt(0), x, alpha * (exp(x) - 1)) +/// ``` +/// +/// # Params +/// +/// - `x`: The input array +/// - `alpha`: Default to 1.0 if not provided +pub fn elu(x: impl AsRef, alpha: impl Into>) -> Result { + let alpha = array!(alpha.into().unwrap_or(1.0)); + // We have to use this indirection, otherwise the compiler cannot + // infer the lifetime of the value returned by the closure properly + compiled_elu(x.as_ref(), &alpha) +} + +/// Applies the Rectified Linear Unit 6. +/// +/// This is: +/// +/// ```rust, ignore +/// minimum(maximum(x, 0), 6) +/// ``` +pub fn relu6(x: impl AsRef) -> Result { + compiled_relu6(x.as_ref()) +} + +/// Applies the Exponential Linear Unit. +/// +/// This is: +/// +/// ```rust, ignore +/// log_add_exp(x, 0) +/// ``` +pub fn softplus(x: impl AsRef) -> Result { + mlx_rs::ops::log_add_exp(x.as_ref(), &array!(0)) +} + +/// Applies the Softsign function. +/// +/// This is: +/// +/// ```rust, ignore +/// x / (1 + abs(x)) +/// ``` +pub fn softsign(x: impl AsRef) -> Result { + compiled_softsign(x.as_ref()) +} + +/// Applies the Continuously Differentiable Exponential Linear Unit. +/// +/// This is: +/// +/// ```rust, ignore +/// maximum(x, 0) + alpha * (exp(minimum(x, 0) / alpha) - 1) +/// ``` +pub fn celu(x: impl AsRef, alpha: impl Into>) -> Result { + let alpha = array!(alpha.into().unwrap_or(1.0)); + // We have to use this indirection, otherwise the compiler cannot + // infer the lifetime of the value returned by the closure properly + compiled_celu(x.as_ref(), &alpha) +} + +/// Applies the Sigmoid Linear Unit. Also known as Swish. +/// +/// This is: +/// +/// ```rust, ignore +/// x * sigmoid(x) +/// ``` +pub fn silu(x: impl AsRef) -> Result { + compiled_silu(x.as_ref()) +} + +/// Applies the Log Sigmoid function. +/// +/// This is: +/// +/// ```rust, ignore +/// -softplus(-x) +/// ``` +pub fn log_sigmoid(x: impl AsRef) -> Result { + compiled_log_sigmoid(x.as_ref()) +} + +/// Applies the Gaussian Error Linear Units function. 
+/// +/// This is: +/// +/// ```rust, ignore +/// x * (1 + erf(x / 2.sqrt())) / 2 +/// ``` +pub fn gelu(x: impl AsRef) -> Result { + compiled_gelu(x.as_ref()) +} + +/// An approximation to Gaussian Error Linear Unit. +/// +/// This is: +/// +/// ```rust, ignore +/// 0.5 * x * (1 + tanh(sqrt(2 / PI) * (x + 0.044715 * x ** 3))) +/// ``` +pub fn gelu_approximate(x: impl AsRef) -> Result { + compiled_gelu_approximate(x.as_ref()) +} + +/// A fast approximation to Gaussian Error Linear Unit. +/// +/// This is: +/// +/// ```rust, ignore +/// x * sigmoid(1.773 * x) +/// ``` +pub fn gelu_fast_approximate(x: impl AsRef) -> Result { + compiled_gelu_fast_approximate(x.as_ref()) +} + +/// Applies the gated linear unit function. +/// +/// This function splits the `axis` dimension of the input into two halves +/// (`a` and `b`) and applies `a * sigmoid(b)`. +pub fn glu(x: impl AsRef, axis: impl Into>) -> Result { + let split = x.as_ref().split_equal(2, axis)?; + let (a, b) = (&split[0], &split[1]); + Ok(a * sigmoid(b)) +} + +/// Applies the Step Activation Function. +/// +/// This function implements a binary step activation, where the output is set +/// to 1 if the input is greater than a specified threshold, and 0 otherwise. +/// +/// This is: +/// +/// ```rust, ignore +/// r#where(x.gt(threshold), 1, 0) +/// ``` +pub fn step(x: impl AsRef, threshold: impl Into>) -> Result { + let threshold = array!(threshold.into().unwrap_or(0.0)); + mlx_rs::ops::r#where(&x.as_ref().gt(threshold)?, &array!(1), &array!(0)) +} + +/// Applies the Scaled Exponential Linear Unit. +/// +/// This is: +/// +/// ```rust, ignore +/// elu(x, 1.67326) * 1.0507 +/// ``` +pub fn selu(x: impl AsRef) -> Result { + compiled_selu(x.as_ref()) +} + +/// Applies the element-wise parametric ReLU. +/// +/// This is: +/// +/// ```rust, ignore +/// maximum(0, x) + alpha * minimum(0, x) +/// ``` +pub fn prelu(x: impl AsRef, alpha: impl AsRef) -> Result { + compiled_prelu(x.as_ref(), alpha.as_ref()) +} + +/// Applies the Mish function, element-wise. +/// +/// Mish: A Self Regularized Non-Monotonic Neural Activation Function. +/// +/// Reference: [https://arxiv.org/abs/1908.08681](https://arxiv.org/abs/1908.08681) +/// +/// This is: +/// +/// ```rust, ignore +/// x * tanh(softplus(x)) +/// ``` +pub fn mish(x: impl AsRef) -> Result { + compiled_mish(x.as_ref()) +} + +/// Applies the hardswish function, element-wise. +/// +/// This is: +/// +/// ```rust, ignore +/// x * minimum(maximum(x + 3, 0), 6) / 6 +/// ``` +pub fn hard_swish(x: impl AsRef) -> Result { + compiled_hard_swish(x.as_ref()) +} + +/// Applies the gated linear unit function. +/// +/// This splits the `axis` dimension of the input into two halves +/// (`a` and `b`) and applies `a * sigmoid(b)`. +#[derive(Debug, Clone, ModuleParameters, GenerateBuilder)] +pub struct Glu { + /// The axis to split the input tensor. Default to [`Glu::DEFAULT_AXIS`] if not provided. + #[optional(default_value = Glu::DEFAULT_AXIS)] + pub axis: i32, +} + +impl Glu { + /// The default axis value. + pub const DEFAULT_AXIS: i32 = -1; +} + +impl Module for Glu { + type Error = Exception; + + fn forward(&self, x: &Array) -> Result { + glu(x, self.axis).map_err(Into::into) + } + + fn training_mode(&mut self, _: bool) {} +} + +/// Applies the element-wise logistic sigmoid. 
+/// +/// For details, please see +/// [this documentation](https://ml-explore.github.io/mlx/build/html/python/_autosummary/mlx.core.sigmoid.html) +/// +/// This is: +/// +/// ```rust, ignore +/// sigmoid(x) +/// ``` +#[derive(Debug, Clone, ModuleParameters)] +pub struct Sigmoid; + +impl Module for Sigmoid { + type Error = Exception; + + fn forward(&self, x: &Array) -> Result { + Ok(sigmoid(x)) + } + + fn training_mode(&mut self, _: bool) {} +} + +/// Applies the Mish function, element-wise. +/// +/// Mish: A Self Regularized Non-Monotonic Neural Activation Function. +/// +/// Reference: [https://arxiv.org/abs/1908.08681](https://arxiv.org/abs/1908.08681) +/// +/// This is: +/// +/// ```rust, ignore +/// x * tanh(softplus(x)) +/// ``` +#[derive(Debug, Clone, ModuleParameters)] +pub struct Mish; + +impl Module for Mish { + type Error = Exception; + + fn forward(&self, x: &Array) -> Result { + mish(x).map_err(Into::into) + } + + fn training_mode(&mut self, _: bool) {} +} + +/// Applies the Rectified Linear Unit. +/// +/// This is: +/// +/// ```rust, ignore +/// maximum(x, 0) +/// ``` +#[derive(Debug, Clone, ModuleParameters)] +pub struct Relu; + +impl Module for Relu { + type Error = Exception; + + fn forward(&self, x: &Array) -> Result { + relu(x).map_err(Into::into) + } + + fn training_mode(&mut self, _: bool) {} +} + +/// Applies the Leaky Rectified Linear Unit. +/// +/// This is: +/// +/// ```rust, ignore +/// maximum(neg_slope * x, x) +/// ``` +#[derive(Debug, Clone, ModuleParameters, GenerateBuilder)] +pub struct LeakyRelu { + /// The negative slope. Default to [`LeakyReLU::`] if not provided. + #[optional(default_value = LeakyRelu::DEFAULT_NEG_SLOPE)] + pub neg_slope: f32, +} + +impl LeakyRelu { + /// The default negative slope value. + pub const DEFAULT_NEG_SLOPE: f32 = 0.01; +} + +impl Module for LeakyRelu { + type Error = Exception; + + fn forward(&self, x: &Array) -> Result { + leaky_relu(x, self.neg_slope).map_err(Into::into) + } + + fn training_mode(&mut self, _: bool) {} +} + +/// Applies the Rectified Linear Unit 6. +/// +/// This is: +/// +/// ```rust, ignore +/// minimum(&maximum(x, 0).unwrap(), 6).unwrap() +/// ``` +#[derive(Debug, Clone, ModuleParameters)] +pub struct Relu6; + +impl Module for Relu6 { + type Error = Exception; + + fn forward(&self, x: &Array) -> Result { + relu6(x).map_err(Into::into) + } + + fn training_mode(&mut self, _: bool) {} +} + +/// Applies the Softmax function. +/// +/// This is: +/// +/// ```rust, ignore +/// softmax(&x, None, None) +/// ``` +#[derive(Debug, Clone, ModuleParameters, GenerateBuilder)] +pub struct Softmax { + /// The axis to apply the softmax. + #[optional(default_value = Softmax::DEFAULT_AXIS)] + pub axis: i32, +} + +impl Softmax { + /// The default axis value. + pub const DEFAULT_AXIS: i32 = -1; +} + +impl Module for Softmax { + type Error = Exception; + + fn forward(&self, x: &Array) -> Result { + Ok(mlx_rs::ops::softmax(x, &[self.axis], None)) + } + + fn training_mode(&mut self, _: bool) {} +} + +/// Applies the Softplus function. +/// +/// This is: +/// +/// ```rust, ignore +/// log_add_exp(x, 0) +/// ``` +#[derive(Debug, Clone, ModuleParameters)] +pub struct Softplus; + +impl Module for Softplus { + type Error = Exception; + + fn forward(&self, x: &Array) -> Result { + softplus(x).map_err(Into::into) + } + + fn training_mode(&mut self, _: bool) {} +} + +/// Applies the Softsign function. 
+/// +/// This is: +/// +/// ```rust, ignore +/// x / (array!(1) + abs(x) +/// ``` +#[derive(Debug, Clone, ModuleParameters)] +pub struct Softsign; + +impl Module for Softsign { + type Error = Exception; + + fn forward(&self, x: &Array) -> Result { + softsign(x).map_err(Into::into) + } + + fn training_mode(&mut self, _: bool) {} +} + +/// Applies the Continuously Differentiable Exponential Linear Unit. +/// +/// This is: +/// +/// ```rust, ignore +/// maximum(x, 0.0).unwrap() +/// + alpha * (exp(&(minimum(x, 0.0).unwrap() / alpha)) - 1) +/// ``` +#[derive(Debug, Clone, ModuleParameters, GenerateBuilder)] +pub struct Celu { + /// The alpha value. Default to [`Celu::DEFAULT_ALPHA`] if not provided. + #[optional(default_value = Celu::DEFAULT_ALPHA)] + pub alpha: f32, +} + +impl Celu { + /// The default alpha value. + pub const DEFAULT_ALPHA: f32 = 1.0; +} + +impl Module for Celu { + type Error = Exception; + + fn forward(&self, x: &Array) -> Result { + celu(x, self.alpha).map_err(Into::into) + } + + fn training_mode(&mut self, _: bool) {} +} + +/// Applies the Sigmoid Linear Unit. Also known as Swish. +/// +/// This is: +/// +/// ```rust, ignore +/// x * sigmoid(x) +/// ``` +#[derive(Debug, Clone, ModuleParameters)] +pub struct Silu; + +impl Module for Silu { + type Error = Exception; + + fn forward(&self, x: &Array) -> Result { + silu(x).map_err(Into::into) + } + + fn training_mode(&mut self, _: bool) {} +} + +/// Applies the Log Softmax function. +/// +/// This is: +/// +/// ```rust, ignore +/// x - log_sum_exp(x, axis, true) +/// ``` +#[derive(Debug, Clone, ModuleParameters, GenerateBuilder)] +pub struct LogSoftmax { + /// The axis value. Default to [`LogSoftmax::DEFAULT_AXIS`] if not provided. + #[optional(default_value = LogSoftmax::DEFAULT_AXIS)] + pub axis: i32, +} + +impl LogSoftmax { + /// The default axis value. + pub const DEFAULT_AXIS: i32 = -1; +} + +impl Module for LogSoftmax { + type Error = Exception; + + fn forward(&self, x: &Array) -> Result { + log_softmax(x, self.axis).map_err(Into::into) + } + + fn training_mode(&mut self, _: bool) {} +} + +/// Applies the Log Sigmoid function. +/// +/// This is: +/// +/// ```rust, ignore +/// -softplus(-x) +/// ``` +#[derive(Debug, Clone, ModuleParameters)] +pub struct LogSigmoid; + +impl Module for LogSigmoid { + type Error = Exception; + + fn forward(&self, x: &Array) -> Result { + log_sigmoid(x).map_err(Into::into) + } + + fn training_mode(&mut self, _: bool) {} +} + +/// Applies the element-wise parametric ReLU. +/// +/// This is: +/// +/// ```rust, ignore +/// maximum(0, x) + alpha * minimum(0, x) +/// ``` +#[derive(Debug, Clone, ModuleParameters)] +pub struct Prelu { + /// The alpha value. See [`prelu`] for more details. + #[param] + pub weight: Param, // TODO: double check if this is trainable +} + +/// The builder for the Prelu module. +#[derive(Debug, Clone, Default)] +pub struct PreluBuilder { + /// The count. Default to [`Prelu::DEFAULT_COUNT`] if not provided. + pub count: Option, + + /// The value. Default to [`Prelu::DEFAULT_VALUE`] if not provided. + pub value: Option, +} + +impl PreluBuilder { + /// Sets the count value. + pub fn count(mut self, count: i32) -> Self { + self.count = Some(count); + self + } + + /// Sets the value. + pub fn value(mut self, value: f32) -> Self { + self.value = Some(value); + self + } + + /// Builds the Prelu module. 
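+    ///
+    /// The learnable `weight` is created with shape `[count]` and filled with `value`.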
+ pub fn build(self) -> Result { + let count = self.count.unwrap_or(Prelu::DEFAULT_COUNT); + let value = self.value.unwrap_or(Prelu::DEFAULT_VALUE); + let weight = Param::new(mlx_rs::ops::full::(&[count], &array!(value))?); + Ok(Prelu { weight }) + } +} + +impl Default for Prelu { + fn default() -> Self { + Prelu::new() + } +} + +impl Prelu { + /// The default count value. + pub const DEFAULT_COUNT: i32 = 1; + + /// The default value. + pub const DEFAULT_VALUE: f32 = 0.25; + + /// Creates a new PreluBuilder. + pub fn builder() -> PreluBuilder { + PreluBuilder::default() + } + + /// Creates a new Prelu module with the default values. + pub fn new() -> Prelu { + PreluBuilder::default() + .build() + .expect("Default value should be valid") + } +} + +impl Module for Prelu { + type Error = Exception; + + fn forward(&self, x: &Array) -> Result { + prelu(x, &self.weight).map_err(Into::into) + } + + fn training_mode(&mut self, _: bool) {} +} + +/// Variants of Gaussian Error Linear Units function. +#[derive(Debug, Clone, Copy, Default)] +pub enum GeluApprox { + /// Uses [`gelu`] + #[default] + None, + + /// Uses [`gelu_approximate`] + Precise, + + /// Uses [`gelu_fast_approximate`] + Fast, +} + +/// Applies the Gaussian Error Linear Units function. +/// +/// There are three variants: +/// +/// - `GeluApprox::None`: Uses [`gelu`]. This is the default. +/// - `GeluApprox::Precise`: Uses [`gelu_approximate`] +/// - `GeluApprox::Fast`: Uses [`gelu_fast_approximate`] +#[derive(Debug, Clone, ModuleParameters, GenerateBuilder)] +pub struct Gelu { + /// The approximation to use. Default to `GeluApprox::None` if not provided. + #[optional(default_value = GeluApprox::None)] + pub approximate: GeluApprox, +} + +impl Module for Gelu { + type Error = Exception; + + fn forward(&self, x: &Array) -> Result { + match self.approximate { + GeluApprox::None => gelu(x).map_err(Into::into), + GeluApprox::Precise => gelu_approximate(x).map_err(Into::into), + GeluApprox::Fast => gelu_fast_approximate(x).map_err(Into::into), + } + } + + fn training_mode(&mut self, _: bool) {} +} + +/// Applies the hyperbolic tangent function +#[derive(Debug, Clone, ModuleParameters)] +pub struct Tanh; + +impl Module for Tanh { + type Error = Exception; + + fn forward(&self, x: &Array) -> Result { + Ok(mlx_rs::ops::tanh(x)) + } + + fn training_mode(&mut self, _: bool) {} +} + +/// Applies the hardswish function, element-wise +/// +/// This is: +/// +/// ```rust, ignore +/// x * minimum(maximum(x + 3, 0), 6) / 6 +/// ``` +#[derive(Debug, Clone, ModuleParameters)] +pub struct HardSwish; + +impl Module for HardSwish { + type Error = Exception; + + fn forward(&self, x: &Array) -> Result { + hard_swish(x).map_err(Into::into) + } + + fn training_mode(&mut self, _: bool) {} +} + +/// Applies the Step Activation Function. +/// +/// This function implements a binary step activation, where the output is set +/// to 1 if the input is greater than a specified threshold, and 0 otherwise. +/// +/// This is: +/// +/// ```rust, ignore +/// r#where(x.gt(threshold), 1, 0) +/// ``` +#[derive(Debug, Clone, ModuleParameters, GenerateBuilder)] +pub struct Step { + /// The threshold value. Default to [`Step::DEFAULT_THRESHOLD`] if not provided. + #[optional(default_value = Step::DEFAULT_THRESHOLD)] + pub threshold: f32, +} + +impl Step { + /// The default threshold value. 
+ pub const DEFAULT_THRESHOLD: f32 = 0.0; +} + +impl Module for Step { + type Error = Exception; + + fn forward(&self, x: &Array) -> Result { + step(x, self.threshold).map_err(Into::into) + } + + fn training_mode(&mut self, _: bool) {} +} + +/// Applies the Scaled Exponential Linear Unit. +/// +/// This is: +/// +/// ```rust, ignore +/// elu(x, 1.67326) * 1.0507 +/// ``` +#[derive(Debug, Clone, ModuleParameters)] +pub struct Selu; + +impl Module for Selu { + type Error = Exception; + + fn forward(&self, x: &Array) -> Result { + selu(x).map_err(Into::into) + } + + fn training_mode(&mut self, _: bool) {} +} + +/* -------------------------------------------------------------------------- */ +/* Compiled activation functions */ +/* -------------------------------------------------------------------------- */ + +#[inline] +fn compiled_leaky_relu(x: &Array, neg_slope: &Array) -> Result { + let f = |(x_, neg_slope_): (&Array, &Array)| { + // This will not panic because a scalar can always be broadcasted to any shape + let a = multiply(neg_slope_, x_)?; + maximum(&a, x_) + }; + let mut compiled = compile(f, Some(true), None, None); + compiled((x, neg_slope)) +} + +#[inline] +fn compiled_elu(x: &Array, alpha: &Array) -> Result { + let f = |(x_, alpha_): (&Array, &Array)| { + which(&x_.gt(&array!(0.0))?, x_, alpha_ * (exp(x_) - array!(1.0))) + }; + let mut compiled = compile(f, Some(true), None, None); + compiled((x, alpha)) +} + +#[inline] +fn compiled_relu6(x: &Array) -> Result { + let f = |x_: &Array| minimum(maximum(x_, &array!(0.0))?, &array!(6.0)); + let mut compiled = compile(f, Some(true), None, None); + compiled(x) +} + +#[inline] +fn compiled_softsign(x: &Array) -> Result { + let f = |x_: &Array| x_ / (array!(1.0) + abs(x_)); + let mut compiled = compile(f, Some(true), None, None); + compiled(x) +} + +#[inline] +fn compiled_celu(x: &Array, alpha: &Array) -> Result { + let f = |(x_, alpha_): (&Array, &Array)| { + maximum(x_, &array!(0.0))? + .add(alpha_.multiply(exp(&(minimum(x_, &array!(0.0))? / alpha_)) - array!(1.0))?) 
+ }; + let mut compiled = compile(f, Some(true), None, None); + compiled((x, alpha)) +} + +#[inline] +fn compiled_silu(x: &Array) -> Result { + let f = |x_: &Array| x_ * sigmoid(x_); + let mut compiled = compile(f, Some(true), None, None); + compiled(x) +} + +#[inline] +fn compiled_log_sigmoid(x: &Array) -> Result { + let f = |x_: &Array| Ok(-softplus(&(-x_))?); + let mut compiled = compile(f, Some(true), None, None); + compiled(x) +} + +#[inline] +fn compiled_gelu(x: &Array) -> Result { + use mlx_rs::ops::erf; + let f = |x_: &Array| x_ * (array!(1) + erf(&(x_ / array!(2f32.sqrt())))) / array!(2.0); + let mut compiled = compile(f, Some(true), None, None); + compiled(x) +} + +#[inline] +fn compiled_gelu_approximate(x: &Array) -> Result { + use mlx_rs::ops::{sqrt, tanh}; + + let f = move |x_: &Array| { + // 0.5 * x * (1 + tanh(sqrt(2 / Float.pi) * (x + 0.044715 * x ** 3))) + array!(0.5).multiply(x_)?.multiply( + array!(1.0).add(tanh( + &(sqrt(&array!(2.0 / PI)) + .multiply(x_ + array!(0.044715).multiply(x_.power(&array!(3))?)?)?), + ))?, + ) + }; + let mut compiled = compile(f, Some(true), None, None); + compiled(x) +} + +#[inline] +fn compiled_gelu_fast_approximate(x: &Array) -> Result { + let f = |x_: &Array| x_ * sigmoid(&(array!(1.773) * x_)); + let mut compiled = compile(f, Some(true), None, None); + compiled(x) +} + +#[inline] +fn compiled_selu(x: &Array) -> Result { + let f = |x_: &Array| elu(x_, 1.67326)?.multiply(array!(1.0507)); + let mut compiled = compile(f, Some(true), None, None); + compiled(x) +} + +#[inline] +fn compiled_prelu(x: &Array, alpha: &Array) -> Result { + let f = |(x_, alpha_): (&Array, &Array)| { + maximum(&array!(0.0), x_)?.add(alpha_ * minimum(&array!(0.0), x_)?) + }; + let mut compiled = compile(f, Some(true), None, None); + compiled((x, alpha)) +} + +#[inline] +fn compiled_mish(x: &Array) -> Result { + use mlx_rs::ops::tanh; + + let f = |x_: &Array| x_.multiply(tanh(&softplus(x_)?)); + let mut compiled = compile(f, Some(true), None, None); + compiled(x) +} + +#[inline] +fn compiled_hard_swish(x: &Array) -> Result { + let f = |x_: &Array| { + let max_x_plus_3 = maximum(&(x_ + array!(3.0)), &array!(0.0))?; + x_.multiply(minimum(&max_x_plus_3, &array!(6.0))?)? 
+ .divide(&array!(6.0)) + }; + let mut compiled = compile(f, Some(true), None, None); + compiled(x) +} + +// The following tests are ported from the swift binding: +// mlx-swift/Tests/MLXTests/IntegrationTests.swift +#[cfg(test)] +mod tests { + use float_eq::assert_float_eq; + use mlx_rs::{random::uniform, Dtype}; + + use super::*; + + #[test] + fn test_glu() { + mlx_rs::random::seed(850); + let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap(); + assert_eq!(a.shape(), &[2, 8, 16]); + assert_eq!(a.dtype(), Dtype::Float32); + assert_float_eq!( + a.mean(None, None).unwrap().item::(), + 0.547_252_66, + abs <= 0.010_945_053 + ); + assert_float_eq!( + a.sum(None, None).unwrap().item::(), + 140.096_68, + abs <= 2.801_933_5 + ); + let result = Glu::default().forward(&a).unwrap(); + assert_eq!(result.shape(), &[2, 8, 8]); + assert_eq!(result.dtype(), Dtype::Float32); + assert_float_eq!( + result.mean(None, None).unwrap().item::(), + 0.333_276_75, + abs <= 0.006_665_535 + ); + assert_float_eq!( + result.sum(None, None).unwrap().item::(), + 42.659_424, + abs <= 0.853_188_46 + ); + } + + #[test] + fn test_sigmoid() { + mlx_rs::random::seed(589); + let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap(); + assert_eq!(a.shape(), &[2, 8, 16]); + assert_eq!(a.dtype(), Dtype::Float32); + assert_float_eq!( + a.mean(None, None).unwrap().item::(), + 0.529_697_9, + abs <= 0.010_593_958 + ); + assert_float_eq!( + a.sum(None, None).unwrap().item::(), + 135.602_66, + abs <= 2.712_053_3 + ); + let result = Sigmoid.forward(&a).unwrap(); + assert_eq!(result.shape(), &[2, 8, 16]); + assert_eq!(result.dtype(), Dtype::Float32); + assert_float_eq!( + result.mean(None, None).unwrap().item::(), + 0.627_014, + abs <= 0.012_540_28 + ); + assert_float_eq!( + result.sum(None, None).unwrap().item::(), + 160.515_58, + abs <= 3.210_311_7 + ); + } + + #[test] + fn test_mish() { + mlx_rs::random::seed(122); + let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap(); + assert_eq!(a.shape(), &[2, 8, 16]); + assert_eq!(a.dtype(), Dtype::Float32); + assert_float_eq!( + a.mean(None, None).unwrap().item::(), + 0.501_719_8, + abs <= 0.010_034_395 + ); + assert_float_eq!( + a.sum(None, None).unwrap().item::(), + 128.440_26, + abs <= 2.568_805_2 + ); + let result = Mish.forward(&a).unwrap(); + assert_eq!(result.shape(), &[2, 8, 16]); + assert_eq!(result.dtype(), Dtype::Float32); + assert_float_eq!( + result.mean(None, None).unwrap().item::(), + 0.395_375_73, + abs <= 0.007_907_514 + ); + assert_float_eq!( + result.sum(None, None).unwrap().item::(), + 101.216_19, + abs <= 2.024_323_7 + ); + } + + #[test] + fn test_relu() { + mlx_rs::random::seed(400); + let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap(); + assert_eq!(a.shape(), &[2, 8, 16]); + assert_eq!(a.dtype(), Dtype::Float32); + assert_float_eq!( + a.mean(None, None).unwrap().item::(), + 0.478_322_74, + abs <= 0.009_566_455 + ); + assert_float_eq!( + a.sum(None, None).unwrap().item::(), + 122.450_62, + abs <= 2.449_012_5 + ); + let result = Relu.forward(&a).unwrap(); + assert_eq!(result.shape(), &[2, 8, 16]); + assert_eq!(result.dtype(), Dtype::Float32); + assert_float_eq!( + result.mean(None, None).unwrap().item::(), + 0.478_322_74, + abs <= 0.009_566_455 + ); + assert_float_eq!( + result.sum(None, None).unwrap().item::(), + 122.450_62, + abs <= 2.449_012_5 + ); + } + + #[test] + fn test_leaky_relu() { + mlx_rs::random::seed(93); + let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap(); + assert_eq!(a.shape(), &[2, 8, 16]); + 
assert_eq!(a.dtype(), Dtype::Float32); + assert_float_eq!( + a.mean(None, None).unwrap().item::(), + 0.499_930_68, + abs <= 0.009_998_614 + ); + assert_float_eq!( + a.sum(None, None).unwrap().item::(), + 127.982_254, + abs <= 2.559_645_2 + ); + let result = LeakyRelu::default().forward(&a).unwrap(); + assert_eq!(result.shape(), &[2, 8, 16]); + assert_eq!(result.dtype(), Dtype::Float32); + assert_float_eq!( + result.mean(None, None).unwrap().item::(), + 0.499_930_68, + abs <= 0.009_998_614 + ); + assert_float_eq!( + result.sum(None, None).unwrap().item::(), + 127.982_254, + abs <= 2.559_645_2 + ); + } + + #[test] + fn test_relu6() { + mlx_rs::random::seed(379); + let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap(); + assert_eq!(a.shape(), &[2, 8, 16]); + assert_eq!(a.dtype(), Dtype::Float32); + assert_float_eq!( + a.mean(None, None).unwrap().item::(), + 0.493_258_66, + abs <= 0.009_865_173 + ); + assert_float_eq!( + a.sum(None, None).unwrap().item::(), + 126.274_216, + abs <= 2.525_484_3 + ); + let result = Relu6.forward(&a).unwrap(); + assert_eq!(result.shape(), &[2, 8, 16]); + assert_eq!(result.dtype(), Dtype::Float32); + assert_float_eq!( + result.mean(None, None).unwrap().item::(), + 0.493_258_66, + abs <= 0.009_865_173 + ); + assert_float_eq!( + result.sum(None, None).unwrap().item::(), + 126.274_216, + abs <= 2.525_484_3 + ); + } + + #[test] + fn test_softmax() { + mlx_rs::random::seed(853); + let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap(); + assert_eq!(a.shape(), &[2, 8, 16]); + assert_eq!(a.dtype(), Dtype::Float32); + assert_float_eq!( + a.mean(None, None).unwrap().item::(), + 0.514_396_3, + abs <= 0.010_287_926_5 + ); + assert_float_eq!( + a.sum(None, None).unwrap().item::(), + 131.685_46, + abs <= 2.633_709_2 + ); + let result = Softmax::default().forward(&a).unwrap(); + assert_eq!(result.shape(), &[2, 8, 16]); + assert_eq!(result.dtype(), Dtype::Float32); + assert_float_eq!( + result.mean(None, None).unwrap().item::(), + 0.062_499_996, + abs <= 0.001_25 + ); + assert_float_eq!( + result.sum(None, None).unwrap().item::(), + 15.999_999, + abs <= 0.32 + ); + } + + #[test] + fn test_softplus() { + mlx_rs::random::seed(118); + let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap(); + assert_eq!(a.shape(), &[2, 8, 16]); + assert_eq!(a.dtype(), Dtype::Float32); + assert_float_eq!( + a.mean(None, None).unwrap().item::(), + 0.498_981_42, + abs <= 0.009_979_628 + ); + assert_float_eq!( + a.sum(None, None).unwrap().item::(), + 127.739_24, + abs <= 2.554_784_8 + ); + let result = Softplus.forward(&a).unwrap(); + assert_eq!(result.shape(), &[2, 8, 16]); + assert_eq!(result.dtype(), Dtype::Float32); + assert_float_eq!( + result.mean(None, None).unwrap().item::(), + 0.982_857_76, + abs <= 0.019_657_155 + ); + assert_float_eq!( + result.sum(None, None).unwrap().item::(), + 251.611_59, + abs <= 5.032_232 + ); + } + + #[test] + fn test_softsign() { + mlx_rs::random::seed(37); + let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap(); + assert_eq!(a.shape(), &[2, 8, 16]); + assert_eq!(a.dtype(), Dtype::Float32); + assert_float_eq!( + a.mean(None, None).unwrap().item::(), + 0.506_551_27, + abs <= 0.010_131_026 + ); + assert_float_eq!( + a.sum(None, None).unwrap().item::(), + 129.677_12, + abs <= 2.593_542_6 + ); + let result = Softsign.forward(&a).unwrap(); + assert_eq!(result.shape(), &[2, 8, 16]); + assert_eq!(result.dtype(), Dtype::Float32); + assert_float_eq!( + result.mean(None, None).unwrap().item::(), + 0.314_089_83, + abs <= 0.006_281_797 + ); + 
assert_float_eq!( + result.sum(None, None).unwrap().item::(), + 80.407, + abs <= 1.608_14 + ); + } + + #[test] + fn test_celu() { + mlx_rs::random::seed(620); + let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap(); + assert_eq!(a.shape(), &[2, 8, 16]); + assert_eq!(a.dtype(), Dtype::Float32); + assert_float_eq!( + a.mean(None, None).unwrap().item::(), + 0.466_748_18, + abs <= 0.009_334_964 + ); + assert_float_eq!( + a.sum(None, None).unwrap().item::(), + 119.487_53, + abs <= 2.389_750_7 + ); + let result = Celu::default().forward(&a).unwrap(); + assert_eq!(result.shape(), &[2, 8, 16]); + assert_eq!(result.dtype(), Dtype::Float32); + assert_float_eq!( + result.mean(None, None).unwrap().item::(), + 0.466_748_18, + abs <= 0.009_334_964 + ); + assert_float_eq!( + result.sum(None, None).unwrap().item::(), + 119.487_53, + abs <= 2.389_750_7 + ); + } + + #[test] + fn test_silu() { + mlx_rs::random::seed(22); + let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap(); + assert_eq!(a.shape(), &[2, 8, 16]); + assert_eq!(a.dtype(), Dtype::Float32); + assert_float_eq!( + a.mean(None, None).unwrap().item::(), + 0.502_970_6, + abs <= 0.010_059_412 + ); + assert_float_eq!( + a.sum(None, None).unwrap().item::(), + 128.760_47, + abs <= 2.575_209_4 + ); + let result = Silu.forward(&a).unwrap(); + assert_eq!(result.shape(), &[2, 8, 16]); + assert_eq!(result.dtype(), Dtype::Float32); + assert_float_eq!( + result.mean(None, None).unwrap().item::(), + 0.331_970_93, + abs <= 0.006_639_418_7 + ); + assert_float_eq!( + result.sum(None, None).unwrap().item::(), + 84.984_56, + abs <= 1.699_691_2 + ); + } + + #[test] + fn test_log_softmax() { + mlx_rs::random::seed(199); + let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap(); + assert_eq!(a.shape(), &[2, 8, 16]); + assert_eq!(a.dtype(), Dtype::Float32); + assert_float_eq!( + a.mean(None, None).unwrap().item::(), + 0.527_843_7, + abs <= 0.010_556_874 + ); + assert_float_eq!( + a.sum(None, None).unwrap().item::(), + 135.127_99, + abs <= 2.702_559_7 + ); + let result = LogSoftmax::default().forward(&a).unwrap(); + assert_eq!(result.shape(), &[2, 8, 16]); + assert_eq!(result.dtype(), Dtype::Float32); + assert_float_eq!( + result.mean(None, None).unwrap().item::(), + -2.810_954_6, + abs <= 0.056_219_09 + ); + assert_float_eq!( + result.sum(None, None).unwrap().item::(), + -719.604_4, + abs <= 14.392_087 + ); + } + + #[test] + fn test_log_sigmoid() { + mlx_rs::random::seed(984); + let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap(); + assert_eq!(a.shape(), &[2, 8, 16]); + assert_eq!(a.dtype(), Dtype::Float32); + assert_float_eq!( + a.mean(None, None).unwrap().item::(), + 0.510_977_7, + abs <= 0.010_219_553_5 + ); + assert_float_eq!( + a.sum(None, None).unwrap().item::(), + 130.810_29, + abs <= 2.616_205_7 + ); + let result = LogSigmoid.forward(&a).unwrap(); + assert_eq!(result.shape(), &[2, 8, 16]); + assert_eq!(result.dtype(), Dtype::Float32); + assert_float_eq!( + result.mean(None, None).unwrap().item::(), + -0.479_598_55, + abs <= 0.009_591_971 + ); + assert_float_eq!( + result.sum(None, None).unwrap().item::(), + -122.777_23, + abs <= 2.455_544_5 + ); + } + + #[test] + fn test_prelu() { + mlx_rs::random::seed(993); + let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap(); + assert_eq!(a.shape(), &[2, 8, 16]); + assert_eq!(a.dtype(), Dtype::Float32); + assert_float_eq!( + a.mean(None, None).unwrap().item::(), + 0.496_651_44, + abs <= 0.009_933_028 + ); + assert_float_eq!( + a.sum(None, None).unwrap().item::(), + 
127.142_77, + abs <= 2.542_855_3 + ); + let result = Prelu::default().forward(&a).unwrap(); + assert_eq!(result.shape(), &[2, 8, 16]); + assert_eq!(result.dtype(), Dtype::Float32); + assert_float_eq!( + result.mean(None, None).unwrap().item::(), + 0.496_651_44, + abs <= 0.009_933_028 + ); + assert_float_eq!( + result.sum(None, None).unwrap().item::(), + 127.142_77, + abs <= 2.542_855_3 + ); + } + + #[test] + fn test_gelu() { + mlx_rs::random::seed(189); + let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap(); + assert_eq!(a.shape(), &[2, 8, 16]); + assert_eq!(a.dtype(), Dtype::Float32); + assert_float_eq!( + a.mean(None, None).unwrap().item::(), + 0.492_950_32, + abs <= 0.009_859_007 + ); + assert_float_eq!( + a.sum(None, None).unwrap().item::(), + 126.195_28, + abs <= 2.523_905_8 + ); + let result = Gelu::default().forward(&a).unwrap(); + assert_eq!(result.shape(), &[2, 8, 16]); + assert_eq!(result.dtype(), Dtype::Float32); + assert_float_eq!( + result.mean(None, None).unwrap().item::(), + 0.365_638_38, + abs <= 0.007_312_767_7 + ); + assert_float_eq!( + result.sum(None, None).unwrap().item::(), + 93.603_424, + abs <= 1.872_068_5 + ); + } + + #[test] + fn test_tanh() { + mlx_rs::random::seed(735); + let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap(); + assert_eq!(a.shape(), &[2, 8, 16]); + assert_eq!(a.dtype(), Dtype::Float32); + assert_float_eq!( + a.mean(None, None).unwrap().item::(), + 0.474_122_7, + abs <= 0.009_482_454_5 + ); + assert_float_eq!( + a.sum(None, None).unwrap().item::(), + 121.375_41, + abs <= 2.427_508_4 + ); + let result = Tanh.forward(&a).unwrap(); + assert_eq!(result.shape(), &[2, 8, 16]); + assert_eq!(result.dtype(), Dtype::Float32); + assert_float_eq!( + result.mean(None, None).unwrap().item::(), + 0.413_079_68, + abs <= 0.008_261_594 + ); + assert_float_eq!( + result.sum(None, None).unwrap().item::(), + 105.748_4, + abs <= 2.114_968 + ); + } + + #[test] + fn test_hardswish() { + mlx_rs::random::seed(126); + let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap(); + assert_eq!(a.shape(), &[2, 8, 16]); + assert_eq!(a.dtype(), Dtype::Float32); + assert_float_eq!( + a.mean(None, None).unwrap().item::(), + 0.491_892_46, + abs <= 0.009_837_849 + ); + assert_float_eq!( + a.sum(None, None).unwrap().item::(), + 125.924_47, + abs <= 2.518_489_4 + ); + let result = HardSwish.forward(&a).unwrap(); + assert_eq!(result.shape(), &[2, 8, 16]); + assert_eq!(result.dtype(), Dtype::Float32); + assert_float_eq!( + result.mean(None, None).unwrap().item::(), + 0.299_602_24, + abs <= 0.005_992_044_7 + ); + assert_float_eq!( + result.sum(None, None).unwrap().item::(), + 76.698_17, + abs <= 1.533_963_4 + ); + } + + #[test] + fn test_step() { + mlx_rs::random::seed(490); + let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap(); + assert_eq!(a.shape(), &[2, 8, 16]); + assert_eq!(a.dtype(), Dtype::Float32); + assert_float_eq!( + a.mean(None, None).unwrap().item::(), + 0.479_360_64, + abs <= 0.009_587_212_5 + ); + assert_float_eq!( + a.sum(None, None).unwrap().item::(), + 122.716_324, + abs <= 2.454_326_4 + ); + let result = Step::default().forward(&a).unwrap(); + assert_eq!(result.shape(), &[2, 8, 16]); + assert_eq!(result.dtype(), Dtype::Int32); + assert_float_eq!( + result.mean(None, None).unwrap().item::(), + 1.0, + abs <= 0.02 + ); + assert_float_eq!( + result.sum(None, None).unwrap().item::(), + 256.0, + abs <= 5.12 + ); + } + + #[test] + fn test_selu() { + mlx_rs::random::seed(215); + let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], 
None).unwrap(); + assert_eq!(a.shape(), &[2, 8, 16]); + assert_eq!(a.dtype(), Dtype::Float32); + assert_float_eq!( + a.mean(None, None).unwrap().item::(), + 0.493_026_8, + abs <= 0.009_860_536 + ); + assert_float_eq!( + a.sum(None, None).unwrap().item::(), + 126.214_86, + abs <= 2.524_297_2 + ); + let result = Selu.forward(&a).unwrap(); + assert_eq!(result.shape(), &[2, 8, 16]); + assert_eq!(result.dtype(), Dtype::Float32); + assert_float_eq!( + result.mean(None, None).unwrap().item::(), + 0.518_023_2, + abs <= 0.010_360_463_5 + ); + assert_float_eq!( + result.sum(None, None).unwrap().item::(), + 132.613_94, + abs <= 2.652_278_7 + ); + } +} diff --git a/mlx-nn/src/container.rs b/mlx-nn/src/container.rs new file mode 100644 index 00000000..2e336a19 --- /dev/null +++ b/mlx-nn/src/container.rs @@ -0,0 +1,75 @@ +use std::borrow::Cow; + +use mlx_macros::ModuleParameters; +use mlx_rs::module::{Module, Param}; +use mlx_rs::{error::Exception, Array}; + +/// Marker trait for items that can be used in a `Sequential` module. +/// +/// It is implemented for all types that implement [`Module`] and [`std::fmt::Debug`]. +pub trait SequentialModuleItem: Module + std::fmt::Debug {} + +impl SequentialModuleItem for T +where + T: Module + std::fmt::Debug, + Err: std::error::Error + 'static, +{ +} + +/// A sequential layer. +/// +/// It calls each layer in sequence. +#[derive(Debug, ModuleParameters)] +pub struct Sequential { + /// The layers to be called in sequence. + #[param] + pub layers: Param>>>, +} + +impl Module for Sequential { + type Error = Exception; + + fn forward(&self, x: &Array) -> Result { + let mut x = Cow::Borrowed(x); + + for layer in &self.layers.value { + x = Cow::Owned(layer.forward(x.as_ref())?); + } + + match x { + Cow::Owned(array) => Ok(array), + Cow::Borrowed(array) => Ok(array.deep_clone()), + } + } + + fn training_mode(&mut self, mode: bool) { + self.layers + .iter_mut() + .for_each(|layer| layer.training_mode(mode)); + } +} + +impl Default for Sequential { + fn default() -> Self { + Self::new() + } +} + +impl Sequential { + /// Creates a new [`Sequential`] module. + pub fn new() -> Self { + Self { + layers: Param::new(Vec::new()), + } + } + + /// Appends a layer to the sequential module. + pub fn append(mut self, layer: M) -> Self + where + M: Module + std::fmt::Debug + 'static, + Err: std::error::Error + 'static, + { + self.layers.push(Box::new(layer)); + self + } +} diff --git a/mlx-nn/src/convolution.rs b/mlx-nn/src/convolution.rs new file mode 100644 index 00000000..6fb83a37 --- /dev/null +++ b/mlx-nn/src/convolution.rs @@ -0,0 +1,538 @@ +use mlx_macros::ModuleParameters; +use mlx_rs::module::{Module, Param}; +use mlx_rs::{ + error::Exception, + ops::{conv1d, conv2d, zeros}, + random::uniform, + Array, +}; + +use crate::utils::{IntOrPair, IntOrTriple}; + +/// Builder for the `Conv1d` module. +#[derive(Debug, Clone, Default)] +pub struct Conv1dBuilder { + /// If `true`, add a learnable bias to the output. Default to [`Conv1d::DEFAULT_WITH_BIAS`] if not + /// specified. + pub with_bias: Option, + + /// Padding. Default to [`Conv1d::DEFAULT_PADDING`] if not specified. + pub padding: Option, + + /// Stride. Default to [`Conv1d::DEFAULT_STRIDE`] if not specified. + pub stride: Option, +} + +impl Conv1dBuilder { + /// Creates a new `Conv1dBuilder`. + pub fn new() -> Self { + Self::default() + } + + /// Sets the `with_bias` parameter. + pub fn with_bias(mut self, with_bias: impl Into>) -> Self { + self.with_bias = with_bias.into(); + self + } + + /// Sets the `padding` parameter. 
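To illustrate the `Sequential` container introduced in `mlx-nn/src/container.rs` above, here is a minimal usage sketch; it is illustration only (not part of the diff) and assumes the `Linear` and `Relu` layers re-exported from the crate root.

```rust
use mlx_nn::{Linear, Relu, Sequential};
use mlx_rs::{error::Exception, module::Module, random::uniform};

fn main() -> Result<(), Exception> {
    // Each `append` boxes the layer behind the `SequentialModuleItem` trait object;
    // `forward` then threads the input through the layers in order.
    let mut model = Sequential::new()
        .append(Linear::new(16, 32)?)
        .append(Relu)
        .append(Linear::new(32, 4)?);

    let x = uniform::<_, f32>(0.0, 1.0, &[8, 16], None)?;
    let y = model.forward(&x)?;
    assert_eq!(y.shape(), &[8, 4]);

    // `training_mode` is forwarded to every contained layer.
    model.training_mode(false);
    Ok(())
}
```

Because the layers are stored as boxed trait objects, heterogeneous layer types can be mixed freely, at the cost of dynamic dispatch per layer call.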
+ pub fn padding(mut self, padding: impl Into>) -> Self { + self.padding = padding.into(); + self + } + + /// Sets the `stride` parameter. + pub fn stride(mut self, stride: impl Into>) -> Self { + self.stride = stride.into(); + self + } + + /// Builds a new `Conv1d` module. + pub fn build( + self, + input_channels: i32, + output_channels: i32, + kernel_size: i32, + ) -> Result { + let with_bias = self.with_bias.unwrap_or(Conv1d::DEFAULT_WITH_BIAS); + let padding = self.padding.unwrap_or(Conv1d::DEFAULT_PADDING); + let stride = self.stride.unwrap_or(Conv1d::DEFAULT_STRIDE); + + let scale = f32::sqrt(1.0f32 / (input_channels * kernel_size) as f32); + let weight = uniform::<_, f32>( + -scale, + scale, + &[output_channels, kernel_size, input_channels], + None, + )?; + let bias = if with_bias { + Some(zeros::(&[output_channels])?) + } else { + None + }; + + Ok(Conv1d { + weight: Param::new(weight), + bias: Param::new(bias), + padding, + stride, + }) + } +} + +/// Applies a 1-dimensional convolution over the multi-channel input sequence. +/// +/// The channels are expected to be last i.e. the input shape should be `NLC` where: +/// +/// - `N` is the batch dimension +/// - `L` is the sequence length +/// - `C` is the number of input channels +#[derive(Debug, Clone, ModuleParameters)] +pub struct Conv1d { + /// The weight of the convolution layer. + #[param] + pub weight: Param, + + /// The bias of the convolution layer. + #[param] + pub bias: Param>, + + /// Padding. Default to 0 if not specified. + pub padding: i32, + + /// Stride. Default to 1 if not specified. + pub stride: i32, +} + +impl Conv1d { + /// Default value for `with_bias` if not specified. + pub const DEFAULT_WITH_BIAS: bool = true; + + /// Default value for `padding` if not specified. + pub const DEFAULT_PADDING: i32 = 0; + + /// Default value for `stride` if not specified. + pub const DEFAULT_STRIDE: i32 = 1; + + /// Creates a new `Conv1dBuilder`. + pub fn builder() -> Conv1dBuilder { + Conv1dBuilder::new() + } + + /// Creates a new Conv1d module with all optional parameters set to their default values. + pub fn new( + input_channels: i32, + output_channels: i32, + kernel_size: i32, + ) -> Result { + Conv1dBuilder::new().build(input_channels, output_channels, kernel_size) + } +} + +impl Module for Conv1d { + type Error = Exception; + + fn forward(&self, x: &Array) -> Result { + let mut y = conv1d( + x, + self.weight.as_ref(), + self.stride, + self.padding, + None, + None, + )?; + if let Some(bias) = &self.bias.value { + y += bias; + } + Ok(y) + } + + fn training_mode(&mut self, _: bool) {} +} + +/// Builder for the `Conv2d` module. +#[derive(Debug, Clone, Default)] +pub struct Conv2dBuilder { + /// If `true`, add a learnable bias to the output. Default to [`Conv2d::DEFAULT_WITH_BIAS`] if not + /// specified. + with_bias: Option, + + /// Padding. Default to [`Conv2d::DEFAULT_PADDING`] if not specified. + padding: Option<(i32, i32)>, + + /// Stride. Default to [`Conv2d::DEFAULT_STRIDE`] if not specified. + stride: Option<(i32, i32)>, +} + +impl Conv2dBuilder { + /// Creates a new `Conv2dBuilder`. + pub fn new() -> Self { + Self::default() + } + + /// Sets the `with_bias` parameter. + pub fn with_bias(mut self, with_bias: impl Into>) -> Self { + self.with_bias = with_bias.into(); + self + } + + /// Sets the `padding` parameter. + pub fn padding(mut self, padding: impl Into>) -> Self { + self.padding = padding.into(); + self + } + + /// Sets the `stride` parameter. 
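A small sketch of how the `Conv1d` builder above might be used, assuming a channels-last `NLC` input as documented; the concrete shapes are only for illustration.

```rust
use mlx_nn::Conv1d;
use mlx_rs::{error::Exception, module::Module, random::uniform};

fn main() -> Result<(), Exception> {
    // Channels-last input: (batch N = 4, length L = 32, channels C = 3).
    let x = uniform::<_, f32>(0.0, 1.0, &[4, 32, 3], None)?;

    // 3 input channels -> 8 output channels, kernel size 5, with explicit
    // stride and padding set through the builder.
    let conv = Conv1d::builder()
        .stride(2)
        .padding(2)
        .build(3, 8, 5)?;

    let y = conv.forward(&x)?;
    // Output length is floor((L + 2 * padding - kernel_size) / stride) + 1 = 16.
    assert_eq!(y.shape(), &[4, 16, 8]);
    Ok(())
}
```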
+ pub fn stride(mut self, stride: impl Into>) -> Self { + self.stride = stride.into(); + self + } + + /// Builds a new `Conv2d` module. + pub fn build( + self, + input_channels: i32, + output_channels: i32, + kernel_size: (i32, i32), + ) -> Result { + let with_bias = self.with_bias.unwrap_or(Conv2d::DEFAULT_WITH_BIAS); + let padding = self.padding.unwrap_or(Conv2d::DEFAULT_PADDING); + let stride = self.stride.unwrap_or(Conv2d::DEFAULT_STRIDE); + + let scale = f32::sqrt(1.0f32 / (input_channels * kernel_size.0 * kernel_size.1) as f32); + let weight = uniform::<_, f32>( + -scale, + scale, + &[ + output_channels, + kernel_size.0, + kernel_size.1, + input_channels, + ], + None, + )?; + let bias = if with_bias { + Some(zeros::(&[output_channels])?) + } else { + None + }; + + Ok(Conv2d { + weight: Param::new(weight), + bias: Param::new(bias), + padding, + stride, + }) + } +} + +/// Applies a 2-dimensional convolution over the multi-channel input image. +/// +/// The channels are expected to be last i.e. the input shape should be `NHWC` where: +/// +/// - `N` is the batch dimension +/// - `H` is the input image height +/// - `W` is the input image width +/// - `C` is the number of input channels +#[derive(Debug, Clone, ModuleParameters)] +pub struct Conv2d { + /// The weight of the convolution layer. + #[param] + pub weight: Param, + + /// The bias of the convolution layer. + #[param] + pub bias: Param>, + + /// Padding. Default to `(0, 0)` if not specified. + pub padding: (i32, i32), + + /// Stride. Default to `(1, 1)` if not specified. + pub stride: (i32, i32), +} + +impl Conv2d { + /// Default value for `with_bias` if not specified. + pub const DEFAULT_WITH_BIAS: bool = true; + + /// Default value for `padding` if not specified. + pub const DEFAULT_PADDING: (i32, i32) = (0, 0); + + /// Default value for `stride` if not specified. + pub const DEFAULT_STRIDE: (i32, i32) = (1, 1); + + /// Creates a new `Conv2dBuilder`. + pub fn builder() -> Conv2dBuilder { + Conv2dBuilder::new() + } + + /// Creates a new 2-dimensional convolution layer. + /// + /// # Params + /// + /// - `input_channels`: number of input channels + /// - `output_channels`: number of output channels + /// - `kernel_size`: size of the convolution filters + pub fn new( + input_channels: i32, + output_channels: i32, + kernel_size: impl IntOrPair, + ) -> Result { + let kernel_size = kernel_size.into_pair(); + Conv2dBuilder::new().build(input_channels, output_channels, kernel_size) + } +} + +impl Module for Conv2d { + type Error = Exception; + + fn forward(&self, x: &Array) -> Result { + let mut y = conv2d( + x, + self.weight.as_ref(), + self.stride, + self.padding, + None, + None, + )?; + if let Some(bias) = &self.bias.value { + y += bias; + } + Ok(y) + } + + fn training_mode(&mut self, _: bool) {} +} + +/// Builder for the `Conv3d` module. +#[derive(Debug, Clone, Default)] +pub struct Conv3dBuilder { + /// If `true`, add a learnable bias to the output. Default to [`Conv3d::DEFAULT_WITH_BIAS`] if not + /// specified. + with_bias: Option, + + /// Padding. Default to [`Conv3d::DEFAULT_PADDING`] if not specified. + padding: Option<(i32, i32, i32)>, + + /// Stride. Default to [`Conv3d::DEFAULT_STRIDE`] if not specified. + stride: Option<(i32, i32, i32)>, +} + +impl Conv3dBuilder { + /// Creates a new `Conv3dBuilder`. + pub fn new() -> Self { + Self::default() + } + + /// Sets the `with_bias` parameter. 
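The `Conv2d` path mirrors `Conv1d` but with an `NHWC` input and pair-valued stride/padding; a brief sketch with assumed shapes follows.

```rust
use mlx_nn::Conv2d;
use mlx_rs::{error::Exception, module::Module, random::uniform};

fn main() -> Result<(), Exception> {
    // Channels-last image batch: (N = 2, H = 16, W = 16, C = 3).
    let x = uniform::<_, f32>(0.0, 1.0, &[2, 16, 16, 3], None)?;

    // 3 -> 8 channels with a 3x3 kernel; stride and padding keep their
    // defaults of (1, 1) and (0, 0).
    let conv = Conv2d::new(3, 8, (3, 3))?;
    let y = conv.forward(&x)?;

    // Without padding, each spatial dimension shrinks by kernel_size - 1.
    assert_eq!(y.shape(), &[2, 14, 14, 8]);
    Ok(())
}
```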
+ pub fn with_bias(mut self, with_bias: impl Into>) -> Self { + self.with_bias = with_bias.into(); + self + } + + /// Sets the `padding` parameter. + pub fn padding(mut self, padding: impl Into>) -> Self { + self.padding = padding.into(); + self + } + + /// Sets the `stride` parameter. + pub fn stride(mut self, stride: impl Into>) -> Self { + self.stride = stride.into(); + self + } + + /// Builds a new `Conv3d` module. + pub fn build( + self, + input_channels: i32, + output_channels: i32, + kernel_size: (i32, i32, i32), + ) -> Result { + let with_bias = self.with_bias.unwrap_or(Conv3d::DEFAULT_WITH_BIAS); + let padding = self.padding.unwrap_or(Conv3d::DEFAULT_PADDING); + let stride = self.stride.unwrap_or(Conv3d::DEFAULT_STRIDE); + + let scale = f32::sqrt( + 1.0 / (input_channels * kernel_size.0 * kernel_size.1 * kernel_size.2) as f32, + ); + let weight = uniform::<_, f32>( + -scale, + scale, + &[ + output_channels, + kernel_size.0, + kernel_size.1, + kernel_size.2, + input_channels, + ], + None, + )?; + let bias = if with_bias { + Some(zeros::(&[output_channels])?) + } else { + None + }; + + Ok(Conv3d { + weight: Param::new(weight), + bias: Param::new(bias), + padding, + stride, + }) + } +} + +/// Applies a 3-dimensional convolution over the multi-channel input image. +/// +/// The channels are expected to be last i.e. the input shape should be `NHWC` where: +/// +/// - `N` is the batch dimension +/// - `H` is the input image height +/// - `W` is the input image width +/// - `C` is the number of input channels +#[derive(Debug, Clone, ModuleParameters)] +pub struct Conv3d { + /// The weight of the convolution layer. + #[param] + pub weight: Param, + + /// The bias of the convolution layer. + #[param] + pub bias: Param>, + + /// Padding. Default to `(0, 0, 0)` if not specified. + pub padding: (i32, i32, i32), + + /// Stride. Default to `(1, 1, 1)` if not specified. + pub stride: (i32, i32, i32), +} + +impl Conv3d { + /// Default value for `with_bias` if not specified. + pub const DEFAULT_WITH_BIAS: bool = true; + + /// Default value for `padding` if not specified. + pub const DEFAULT_PADDING: (i32, i32, i32) = (0, 0, 0); + + /// Default value for `stride` if not specified. + pub const DEFAULT_STRIDE: (i32, i32, i32) = (1, 1, 1); + + /// Creates a new `Conv3dBuilder`. + pub fn builder() -> Conv3dBuilder { + Conv3dBuilder::new() + } + + /// Creates a new 3-dimensional convolution layer. 
+ /// + /// # Params + /// + /// - `input_channels`: number of input channels + /// - `output_channels`: number of output channels + /// - `kernel_size`: size of the convolution filters + pub fn new( + input_channels: i32, + output_channels: i32, + kernel_size: impl IntOrTriple, + ) -> Result { + let kernel_size = kernel_size.into_triple(); + Conv3dBuilder::new().build(input_channels, output_channels, kernel_size) + } +} + +impl Module for Conv3d { + type Error = Exception; + + fn forward(&self, x: &Array) -> Result { + let mut y = mlx_rs::ops::conv3d( + x, + self.weight.as_ref(), + self.stride, + self.padding, + None, + None, + )?; + if let Some(bias) = &self.bias.value { + y += bias; + } + Ok(y) + } + + fn training_mode(&mut self, _: bool) {} +} + +// The following tests are ported from the swift bindings: +// mlx-swift/Tests/MLXTests/IntegrationTests.swift +#[cfg(test)] +mod tests { + use float_eq::assert_float_eq; + use mlx_rs::module::Module; + use mlx_rs::{random::uniform, Dtype}; + + use crate::Conv1d; + + #[test] + fn test_conv1d() { + mlx_rs::random::seed(819); + let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap(); + assert_eq!(a.shape(), &[2, 8, 16]); + assert_eq!(a.dtype(), Dtype::Float32); + assert_float_eq!( + a.mean(None, None).unwrap().item::(), + 0.512_987_5, + abs <= 0.010_259_75 + ); + assert_float_eq!( + a.sum(None, None).unwrap().item::(), + 131.324_8, + abs <= 2.626_496 + ); + let result = Conv1d::new(16, 2, 8).unwrap().forward(&a).unwrap(); + assert_eq!(result.shape(), &[2, 1, 2]); + assert_eq!(result.dtype(), Dtype::Float32); + assert_float_eq!( + result.mean(None, None).unwrap().item::(), + 0.264_865_2, + abs <= 0.005_297_303_7 + ); + assert_float_eq!( + result.sum(None, None).unwrap().item::(), + 1.059_460_8, + abs <= 0.021_189_215 + ); + } + + #[test] + fn test_conv2d() { + mlx_rs::random::seed(62); + let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 8, 4], None).unwrap(); + assert_eq!(a.shape(), &[2, 8, 8, 4]); + assert_eq!(a.dtype(), Dtype::Float32); + assert_float_eq!( + a.mean(None, None).unwrap().item::(), + 0.522_504_27, + abs <= 0.010_450_086 + ); + assert_float_eq!( + a.sum(None, None).unwrap().item::(), + 267.522_2, + abs <= 5.350_444 + ); + let result = crate::Conv2d::new(4, 2, (8, 8)) + .unwrap() + .forward(&a) + .unwrap(); + assert_eq!(result.shape(), &[2, 1, 1, 2]); + assert_eq!(result.dtype(), Dtype::Float32); + assert_float_eq!( + result.mean(None, None).unwrap().item::(), + -0.279_321_5, + abs <= 0.005_586_43 + ); + assert_float_eq!( + result.sum(None, None).unwrap().item::(), + -1.117_286, + abs <= 0.022_345_72 + ); + } +} diff --git a/mlx-nn/src/dropout.rs b/mlx-nn/src/dropout.rs new file mode 100644 index 00000000..4cc9a720 --- /dev/null +++ b/mlx-nn/src/dropout.rs @@ -0,0 +1,366 @@ +use mlx_macros::ModuleParameters; +use mlx_rs::module::Module; +use mlx_rs::{array, error::Exception, ops::multiply, random::bernoulli}; + +use crate::error::DropoutBuildError; + +macro_rules! impl_dropout_builder { + ($builder_name:ident, $target_name:ident, $default_p:expr, $default_training:expr) => { + /// Builder for [`$target_name`]. + #[derive(Debug, Clone, Default)] + pub struct $builder_name { + /// Probability of zeroing an element. + p: Option, + } + + impl $builder_name { + /// Creates a new dropout builder. + pub fn new() -> Self { + Self::default() + } + + /// Sets the probability of zeroing an element. + pub fn p(mut self, p: impl Into) -> Self { + self.p = Some(p.into()); + self + } + + /// Builds a dropout layer. 
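For `Conv3d`, the sketch below assumes the usual channels-last volumetric layout `NDHWC` (batch, depth, height, width, channels) and the tuple form accepted by `IntOrTriple`; both are assumptions for illustration rather than assertions about the diff.

```rust
use mlx_nn::Conv3d;
use mlx_rs::{error::Exception, module::Module, random::uniform};

fn main() -> Result<(), Exception> {
    // Assumed NDHWC input: (N = 1, D = 8, H = 8, W = 8, C = 4).
    let x = uniform::<_, f32>(0.0, 1.0, &[1, 8, 8, 8, 4], None)?;

    // 4 -> 2 channels with a 3x3x3 kernel and default stride/padding.
    let conv = Conv3d::new(4, 2, (3, 3, 3))?;
    let y = conv.forward(&x)?;

    // With padding (0, 0, 0), each of D, H and W shrinks by kernel_size - 1.
    assert_eq!(y.shape(), &[1, 6, 6, 6, 2]);
    Ok(())
}
```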
+ pub fn build(self) -> Result<$target_name, DropoutBuildError> { + let p = self.p.unwrap_or($default_p); + + if !(0.0..1.0).contains(&p) { + return Err(DropoutBuildError::InvalidProbability); + } + + Ok($target_name { + one_minus_p: 1.0 - p, + training: $default_training, + }) + } + } + + impl $target_name { + /// Creates a builder for the dropout layer. + pub fn builder() -> $builder_name { + $builder_name::new() + } + + /// Creates a new dropout layer with the default parameters. + pub fn new() -> Self { + $builder_name::new() + .build() + .expect("Default values are valid") + } + } + + impl Default for $target_name { + fn default() -> Self { + Self::new() + } + } + }; +} + +impl_dropout_builder!( + DropoutBuilder, + Dropout, + Dropout::DEFAULT_P, + Dropout::DEFAULT_TRAINING +); + +/// Randomly zero a portion of the elements during training. +/// +/// The remaining elements are multiplied with `1 / (1-p)` where +/// `p` is the probability of zeroing an element. This is done so the +/// expected value of a given element will remain the same. +#[derive(Debug, Clone, ModuleParameters)] +pub struct Dropout { + /// `1-p`, where `p` is the probability of zeroing an element. `p` is default to + /// [`Dropout::DEFAULT_P`] if not specified. + pub one_minus_p: f32, + + /// Whether the layer is in training mode. Default to [`Dropout::DEFAULT_TRAINING`] if not + /// specified. + pub training: bool, +} + +impl Dropout { + /// Default value for the probability of zeroing an element. + pub const DEFAULT_P: f32 = 0.5; + + /// Default value for the training mode. + pub const DEFAULT_TRAINING: bool = true; +} + +impl Module for Dropout { + type Error = Exception; + + fn forward(&self, x: &mlx_rs::Array) -> Result { + if self.one_minus_p == 1.0 || !self.training { + return Ok(x.clone()); + } + + let p1 = array!(self.one_minus_p); + let mask = bernoulli(&p1, x.shape(), None)?; + multiply(multiply(array!(1.0 / self.one_minus_p), mask)?, x).map_err(Into::into) + } + + fn training_mode(&mut self, mode: bool) { + self.training = mode; + } +} + +impl_dropout_builder!( + Dropout2dBuilder, + Dropout2d, + Dropout2d::DEFAULT_P, + Dropout2d::DEFAULT_TRAINING +); + +/// Apply 2D channel-wise dropout during training. +/// +/// Randomly zero out entire channels independently with probability `p`. +/// This layer expects the channels to be last, i.e. the input shape should be +/// `NWHC` or `WHC` where:`N` is the batch dimension,`H` is the input +/// image height,`W` is the input image width, and`C` is the number of +/// input channels +/// +/// The remaining channels are scaled by `1 / (1-p)` to +/// maintain the expected value of each element. Unlike traditional dropout, +/// which zeros individual entries, this layer zeros entire channels. This is +/// beneficial for early convolution layers where adjacent pixels are +/// correlated. In such case, traditional dropout may not effectively +/// regularize activations. For more details, see [1]. +/// +/// [1]: Thompson, J., Goroshin, R., Jain, A., LeCun, Y. and Bregler C., 2015. +/// Efficient Object Localization Using Convolutional Networks. CVPR 2015. +#[derive(Debug, Clone, ModuleParameters)] +pub struct Dropout2d { + /// `1-p`, where `p` is the probability of zeroing a channel. `p` is default to + /// [`Dropout2d::DEFAULT_P`] if not specified. + pub one_minus_p: f32, + + /// Whether the layer is in training mode. Default to [`Dropout2d::DEFAULT_TRAINING`] if not + /// specified. Default to [`Dropout2d::DEFAULT_TRAINING`] if not specified. 
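A usage sketch for the `Dropout` layer generated by the macro above, showing both the training-time scaling and the identity behavior after `training_mode(false)`; shapes are illustrative only.

```rust
use mlx_nn::Dropout;
use mlx_rs::{module::Module, random::uniform};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let x = uniform::<_, f32>(0.0, 1.0, &[4, 8], None)?;

    // p = 0.2: in training mode roughly 20% of the elements are zeroed and the
    // survivors are scaled by 1 / (1 - p) = 1.25 to keep the expected value.
    let mut dropout = Dropout::builder().p(0.2f32).build()?;
    let noisy = dropout.forward(&x)?;
    println!("training-mode mean: {}", noisy.mean(None, None)?.item::<f32>());

    // Outside of training the layer is an identity and returns the input unchanged.
    dropout.training_mode(false);
    let passthrough = dropout.forward(&x)?;
    assert_eq!(passthrough.shape(), x.shape());
    Ok(())
}
```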
+ pub training: bool, +} + +impl Dropout2d { + /// Default value for the probability of zeroing a channel. + pub const DEFAULT_P: f32 = 0.5; + + /// Default value for the training mode. + pub const DEFAULT_TRAINING: bool = true; +} + +impl Module for Dropout2d { + type Error = Exception; + + fn forward(&self, x: &mlx_rs::Array) -> Result { + let ndim = x.ndim(); + + if ndim != 3 && ndim != 4 { + return Err(Exception::custom("Expecting 3D or 4D input")); + } + + if self.one_minus_p == 1.0 || !self.training { + return Ok(x.clone()); + } + + // Dropout is applied on the whole channel + // 3D input: (1, 1, C) + // 4D input: (B, 1, 1, C) + + let mut mask_shape = x.shape().to_vec(); + let len = mask_shape.len(); + mask_shape[len - 2] = 1; + mask_shape[len - 3] = 1; + + let p1 = array!(self.one_minus_p); + let mask = bernoulli(&p1, &mask_shape, None)?; + + multiply(multiply(array!(1.0 / self.one_minus_p), mask)?, x).map_err(Into::into) + } + + fn training_mode(&mut self, mode: bool) { + self.training = mode; + } +} + +impl_dropout_builder!( + Dropout3dBuilder, + Dropout3d, + Dropout3d::DEFAULT_P, + Dropout3d::DEFAULT_TRAINING +); + +/// Apply 3D channel-wise dropout during training. +/// +/// Randomly zero out entire channels independently with probability `p`. +/// This layer expects the channels to be last, i.e., the input shape should be +/// `NDHWC` or `DHWC` where: `N` is the batch dimension, `D` is the depth, +/// `H` is the input image height, `W` is the input image width, and `C` is +/// the number of input channels. +/// +/// The remaining channels are scaled by `1 / (1-p)` to +/// maintain the expected value of each element. Unlike traditional dropout, +/// which zeros individual entries, this layer zeros entire channels. This is +/// often beneficial for convolutional layers processing 3D data, like in +/// medical imaging or video processing. +#[derive(Debug, Clone, ModuleParameters)] +pub struct Dropout3d { + /// `1-p`, where `p` is the probability of zeroing a channel. `p` is default to + /// [`Dropout3d::DEFAULT_P`] if not specified. + pub one_minus_p: f32, + + /// Whether the layer is in training mode. Default to [`Dropout3d::DEFAULT_TRAINING`] if not + /// specified. + pub training: bool, +} + +impl Dropout3d { + /// Default value for the probability of zeroing a channel. + pub const DEFAULT_P: f32 = 0.5; + + /// Default value for the training mode. 
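The channel-wise behavior of `Dropout2d` is easiest to see on an all-ones input: every `(H, W)` plane of a channel is either zeroed or scaled by `1 / (1 - p)`. The sketch assumes a `ones` constructor analogous to the `zeros` used elsewhere in this diff.

```rust
use mlx_nn::Dropout2d;
use mlx_rs::{error::Exception, module::Module, ops::ones};

fn main() -> Result<(), Exception> {
    // Channels-last image batch: (N = 2, H = 4, W = 4, C = 8), filled with ones.
    let x = ones::<f32>(&[2, 4, 4, 8])?;

    // With the default p = 0.5 the Bernoulli mask has shape (N, 1, 1, C), so a
    // whole channel is either dropped or scaled by 1 / (1 - p) = 2.
    let dropout = Dropout2d::new();
    let y = dropout.forward(&x)?;
    assert_eq!(y.shape(), &[2, 4, 4, 8]);
    println!("every element is either 0 or 2: {}", y);
    Ok(())
}
```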
+ pub const DEFAULT_TRAINING: bool = true; +} + +impl Module for Dropout3d { + type Error = Exception; + + fn forward(&self, x: &mlx_rs::Array) -> Result { + let ndim = x.ndim(); + + if ndim != 4 && ndim != 5 { + return Err(Exception::custom("Expecting 4D or 5D input")); + } + + if self.one_minus_p == 1.0 || !self.training { + return Ok(x.clone()); + } + + // Dropout is applied on the whole channel + // 4D input: (1, 1, 1, C) + // 5D input: (B, 1, 1, 1, C) + + let mut mask_shape = x.shape().to_vec(); + let len = mask_shape.len(); + mask_shape[len - 2] = 1; + mask_shape[len - 3] = 1; + mask_shape[len - 4] = 1; + + let p1 = array!(self.one_minus_p); + let mask = bernoulli(&p1, &mask_shape, None)?; + + multiply(multiply(array!(1.0 / self.one_minus_p), mask)?, x).map_err(Into::into) + } + + fn training_mode(&mut self, mode: bool) { + self.training = mode; + } +} + +// The following tests were ported from the swift binding: +// mlx-swift/Tests/MLXTests/IntegrationTests.swift +#[cfg(test)] +mod tests { + use float_eq::assert_float_eq; + use mlx_rs::random::uniform; + + use super::*; + + #[test] + fn test_dropout() { + mlx_rs::random::seed(959); + let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap(); + assert_eq!(a.shape(), &[2, 8, 16]); + assert_eq!(a.dtype(), mlx_rs::Dtype::Float32); + assert_float_eq!( + a.mean(None, None).unwrap().item::(), + 0.511_429_2, + abs <= 0.010_228_584 + ); + assert_float_eq!( + a.sum(None, None).unwrap().item::(), + 130.925_87, + abs <= 2.618_517_4 + ); + let result = Dropout::new().forward(&a).unwrap(); + assert_eq!(result.shape(), &[2, 8, 16]); + assert_eq!(result.dtype(), mlx_rs::Dtype::Float32); + assert_float_eq!( + result.mean(None, None).unwrap().item::(), + 0.477_913_62, + abs <= 0.009_558_273 + ); + assert_float_eq!( + result.sum(None, None).unwrap().item::(), + 122.345_89, + abs <= 2.446_917_8 + ); + } + + #[test] + fn test_dropout2d() { + mlx_rs::random::seed(695); + let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap(); + assert_eq!(a.shape(), &[2, 8, 16]); + assert_eq!(a.dtype(), mlx_rs::Dtype::Float32); + assert_float_eq!( + a.mean(None, None).unwrap().item::(), + 0.457_839_9, + abs <= 0.009_156_798 + ); + assert_float_eq!( + a.sum(None, None).unwrap().item::(), + 117.207_016, + abs <= 2.344_140_3 + ); + let result = Dropout2d::new().forward(&a).unwrap(); + assert_eq!(result.shape(), &[2, 8, 16]); + assert_eq!(result.dtype(), mlx_rs::Dtype::Float32); + assert_float_eq!( + result.mean(None, None).unwrap().item::(), + 0.368_284_34, + abs <= 0.007_365_687 + ); + assert_float_eq!( + result.sum(None, None).unwrap().item::(), + 94.280_79, + abs <= 1.885_615_8 + ); + } + + #[test] + fn test_dropout3d() { + mlx_rs::random::seed(23); + let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 8, 4], None).unwrap(); + assert_eq!(a.shape(), &[2, 8, 8, 4]); + assert_eq!(a.dtype(), mlx_rs::Dtype::Float32); + assert_float_eq!( + a.mean(None, None).unwrap().item::(), + 0.500_606_2, + abs <= 0.010_012_124 + ); + assert_float_eq!( + a.sum(None, None).unwrap().item::(), + 256.310_36, + abs <= 5.126_207_4 + ); + let result = Dropout3d::new().forward(&a).unwrap(); + assert_eq!(result.shape(), &[2, 8, 8, 4]); + assert_eq!(result.dtype(), mlx_rs::Dtype::Float32); + assert_float_eq!( + result.mean(None, None).unwrap().item::(), + 0.237_284_15, + abs <= 0.004_745_683 + ); + assert_float_eq!( + result.sum(None, None).unwrap().item::(), + 121.489_49, + abs <= 2.429_789_8 + ); + } +} diff --git a/mlx-nn/src/error.rs b/mlx-nn/src/error.rs new file mode 100644 index 
00000000..6914f65b --- /dev/null +++ b/mlx-nn/src/error.rs @@ -0,0 +1,38 @@ +//! Custom error types for mlx-nn + +use mlx_rs::error::Exception; +use thiserror::Error; + +/// Error with building a cross-entropy loss function +#[derive(Debug, Clone, PartialEq, Error)] +pub enum CrossEntropyBuildError { + /// Label smoothing factor must be in the range [0, 1) + #[error("Label smoothing factor must be in the range [0, 1)")] + InvalidLabelSmoothingFactor, +} + +impl From for Exception { + fn from(value: CrossEntropyBuildError) -> Self { + Exception::custom(format!("{}", value)) + } +} + +/// Error with building a RmsProp optimizer +#[derive(Debug, Clone, PartialEq, Error)] +pub enum RmsPropBuildError { + /// Alpha must be non-negative + #[error("alpha must be non-negative")] + NegativeAlpha, + + /// Epsilon must be non-negative + #[error("epsilon must be non-negative")] + NegativeEpsilon, +} + +/// Error with building a dropout layer +#[derive(Debug, Clone, PartialEq, Error)] +pub enum DropoutBuildError { + /// Dropout probability must be in the range [0, 1) + #[error("Dropout probability must be in the range [0, 1)")] + InvalidProbability, +} diff --git a/mlx-nn/src/lib.rs b/mlx-nn/src/lib.rs index 8b137891..1abc7dc5 100644 --- a/mlx-nn/src/lib.rs +++ b/mlx-nn/src/lib.rs @@ -1 +1,31 @@ +#![deny(missing_docs, missing_debug_implementations)] +//! Neural network support for MLX +//! +//! All modules provide a `new()` function that take mandatory parameters and other methods +//! to set optional parameters. + +pub mod error; +pub mod losses; +pub mod macros; +pub mod optimizers; +pub mod utils; + +mod activation; +mod container; +mod convolution; +mod dropout; +mod linear; +mod value_and_grad; + +pub use activation::*; +pub use container::*; +pub use convolution::*; +pub use dropout::*; +pub use linear::*; +pub use value_and_grad::*; + +/// Re-export of the `mlx-nn-module` crate. +pub mod module { + pub use mlx_rs::module::*; +} diff --git a/mlx-nn/src/linear.rs b/mlx-nn/src/linear.rs new file mode 100644 index 00000000..4cc1ca1d --- /dev/null +++ b/mlx-nn/src/linear.rs @@ -0,0 +1,252 @@ +use std::iter::once; + +use mlx_rs::{error::Exception, Array}; + +use crate::{ + macros::ModuleParameters, + module::{Module, Param}, +}; + +/// Builder for [`Linear`] module +#[derive(Debug, Clone, Default)] +pub struct LinearBuilder { + /// Whether to include bias in the linear layer. Default to [`Linear::DEFAULT_WITH_BIAS`]. + pub with_bias: Option, +} + +impl LinearBuilder { + /// Creates a new [`LinearBuilder`]. + pub fn new() -> Self { + Self::default() + } + + /// Sets the `with_bias` field. + pub fn with_bias(mut self, with_bias: impl Into>) -> Self { + self.with_bias = with_bias.into(); + self + } + + /// Builds a new [`Linear`] layer. + pub fn build(self, input_dims: i32, output_dims: i32) -> Result { + let with_bias = self.with_bias.unwrap_or(Linear::DEFAULT_WITH_BIAS); + + let scale = f32::sqrt(1.0 / (input_dims as f32)); + let weight = + mlx_rs::random::uniform::<_, f32>(-scale, scale, &[output_dims, input_dims], None)?; + + let bias = if with_bias { + Some(mlx_rs::random::uniform::<_, f32>( + -scale, + scale, + &[output_dims], + None, + )?) + } else { + None + }; + + Ok(Linear { + weight: Param::new(weight), + bias: Param::new(bias), + }) + } +} + +/// Applies an affine transformation to the input. +#[derive(Debug, Clone, ModuleParameters)] +pub struct Linear { + /// The weight of the linear layer. + #[param] + pub weight: Param, + + /// The bias of the linear layer. 
+ #[param] + pub bias: Param>, +} + +impl Linear { + /// Default value for `with_bias` + pub const DEFAULT_WITH_BIAS: bool = true; + + /// Creates a new [`LinearBuilder`]. + pub fn builder() -> LinearBuilder { + LinearBuilder::new() + } + + /// Creates a new [`Linear`] layer. + pub fn new(input_dims: i32, output_dims: i32) -> Result { + LinearBuilder::new().build(input_dims, output_dims) + } + + /// Returns the shape of the linear layer. + pub fn shape(&self) -> (i32, i32) { + let weight_shape = self.weight.as_ref().shape(); + (weight_shape[0], weight_shape[1]) + } +} + +impl Module for Linear { + type Error = Exception; + + fn forward(&self, x: &Array) -> Result { + match &self.bias.value { + Some(bias) => mlx_rs::ops::addmm(bias, x, self.weight.value.t(), None, None), + None => mlx_rs::ops::matmul(x, &self.weight.value.t()), + } + } + + fn training_mode(&mut self, _: bool) {} +} + +/// Builder for [`Bilinear`] module +#[derive(Debug, Clone, Default)] +pub struct BilinearBuilder { + /// Whether to include bias in the bilinear layer. Default to [Bilinear::DEFAULT_WITH_BIAS]. + with_bias: Option, +} + +impl BilinearBuilder { + /// Creates a new [`BilinearBuilder`]. + pub fn new() -> Self { + Self { with_bias: None } + } + + /// Sets the `with_bias` field. + pub fn with_bias(mut self, with_bias: impl Into>) -> Self { + self.with_bias = with_bias.into(); + self + } + + /// Builds a new [`Bilinear`] layer. + pub fn build( + self, + input_dims_1: i32, + input_dims_2: i32, + output_dims: i32, + ) -> Result { + let with_bias = self.with_bias.unwrap_or(Bilinear::DEFAULT_WITH_BIAS); + + let scale = f32::sqrt(1.0 / (input_dims_1 as f32)); + let weights = mlx_rs::random::uniform::<_, f32>( + -scale, + scale, + &[output_dims, input_dims_2, input_dims_1], + None, + )?; + + let bias = if with_bias { + Some(mlx_rs::random::uniform::<_, f32>( + -scale, + scale, + &[output_dims], + None, + )?) + } else { + None + }; + + Ok(Bilinear { + weights: Param::new(weights), + bias: Param::new(bias), + }) + } +} + +/// Applies a bilinear transformation to the inputs. +#[derive(Debug, Clone, ModuleParameters)] +pub struct Bilinear { + /// The weight of the bilinear layer. + #[param] + pub weights: Param, + + /// The bias of the bilinear layer. + #[param] + pub bias: Param>, +} + +impl Bilinear { + /// Default value for `with_bias` + pub const DEFAULT_WITH_BIAS: bool = true; + + /// Creates a new [`BilinearBuilder`]. + pub fn builder() -> BilinearBuilder { + BilinearBuilder::new() + } + + /// Creates a new [`Bilinear`] layer. 
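A short sketch of the `Linear` layer defined above: it applies `y = x W^T + b`, and `shape()` reports the weight matrix as `(output_dims, input_dims)`. The builder form shows how to drop the bias.

```rust
use mlx_nn::Linear;
use mlx_rs::{error::Exception, module::Module, random::uniform};

fn main() -> Result<(), Exception> {
    // Affine map with W drawn uniformly from [-1/sqrt(16), 1/sqrt(16)].
    let linear = Linear::new(16, 4)?;
    assert_eq!(linear.shape(), (4, 16));

    // Leading dimensions are broadcast through the matmul.
    let x = uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None)?;
    let y = linear.forward(&x)?;
    assert_eq!(y.shape(), &[2, 8, 4]);

    // The builder form disables the bias term entirely.
    let no_bias = Linear::builder().with_bias(false).build(16, 4)?;
    assert!(no_bias.bias.value.is_none());
    Ok(())
}
```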
+ pub fn new(input_dims_1: i32, input_dims_2: i32, output_dims: i32) -> Result { + BilinearBuilder::new().build(input_dims_1, input_dims_2, output_dims) + } +} + +impl Module for Bilinear { + type Error = Exception; + + fn forward(&self, x: &Array) -> Result { + let shape = self.weights.shape(); + let (out, in2, in1) = (shape[0], shape[1], shape[2]); + let x_shape = &x.shape()[..x.shape().len() - 1]; + let x1 = x.reshape(&[-1, in1])?; + let x2 = x.reshape(&[-1, 1, in2])?; + + // perform the bilinear transform + let w = self.weights.reshape(&[out * in2, in1])?; + let mut y = mlx_rs::ops::matmul(&x1, &w.t())?; + y = y.reshape(&[-1, out, in2])?.swap_axes(-2, -1)?; + y = mlx_rs::ops::matmul(&x2, &y)?; + y = y.squeeze(&[1])?; + + // reset the shape + let new_shape = x_shape.iter().cloned().chain(once(out)).collect::>(); + y = y.reshape(&new_shape)?; + + if let Some(bias) = &self.bias.value { + y = mlx_rs::ops::add(&y, bias)?; + } + + Ok(y) + } + + fn training_mode(&mut self, _: bool) {} +} + +// The following tests are ported from the swift binding: +// mlx-swift/Tests/MLXTests/IntegrationTests.swift +#[cfg(test)] +mod tests { + use float_eq::assert_float_eq; + use mlx_rs::{random::uniform, Dtype}; + + use super::*; + + #[test] + fn test_linear() { + mlx_rs::random::seed(744); + let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap(); + assert_eq!(a.shape(), &[2, 8, 16]); + assert_eq!(a.dtype(), Dtype::Float32); + assert_float_eq!( + a.mean(None, None).unwrap().item::(), + 0.508_688_57, + abs <= 0.010_173_771_5 + ); + assert_float_eq!( + a.sum(None, None).unwrap().item::(), + 130.224_27, + abs <= 2.604_485_5 + ); + let result = Linear::new(16, 5).unwrap().forward(&a).unwrap(); + assert_eq!(result.shape(), &[2, 8, 5]); + assert_eq!(result.dtype(), Dtype::Float32); + assert_float_eq!( + result.mean(None, None).unwrap().item::(), + 0.104_193_09, + abs <= 0.002_083_861_7 + ); + assert_float_eq!( + result.sum(None, None).unwrap().item::(), + 8.335_447, + abs <= 0.166_708_95 + ); + } +} diff --git a/mlx-nn/src/losses.rs b/mlx-nn/src/losses.rs new file mode 100644 index 00000000..32b2577e --- /dev/null +++ b/mlx-nn/src/losses.rs @@ -0,0 +1,1373 @@ +//! Loss functions + +use mlx_internal_macros::GenerateBuilder; +use mlx_rs::{ + array, + error::Exception, + ops::{ + abs, clip, exp, indexing::take_along_axis, log, log_add_exp, log_sum_exp, maximum, minimum, + multiply, power, r#where, sqrt, square, sum, + }, + Array, +}; + +use crate::error::CrossEntropyBuildError; + +#[inline] +fn check_shape( + left: &Array, + right: &Array, + left_ident: &str, + right_ident: &str, +) -> Result<(), Exception> { + if left.shape() != right.shape() { + return Err(Exception::custom(format!( + "The shape of the {} ({:?}) does not match the shape of the {} ({:?})", + left_ident, + left.shape(), + right_ident, + right.shape() + ))); + } + Ok(()) +} + +/// Different types of loss reductions +#[derive(Debug, Clone, Copy)] +pub enum LossReduction { + /// No reduction is applied. + None, + /// The sum of the output will be computed. + Sum, + /// The mean of the output will be computed. + Mean, +} + +impl LossReduction { + /// Reduces the loss according to the reduction type. + pub fn reduce(&self, loss: Array) -> Result { + match self { + LossReduction::None => Ok(loss), + LossReduction::Sum => Ok(loss.sum(None, None)?), + LossReduction::Mean => Ok(loss.mean(None, None)?), + } + } +} + +/// Cross entropy loss function. 
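Since every loss below funnels through `LossReduction::reduce`, here is a tiny sketch of the three variants; `Array::from_slice` is assumed as the literal-array constructor.

```rust
use mlx_nn::losses::LossReduction;
use mlx_rs::{error::Exception, Array};

fn main() -> Result<(), Exception> {
    // A hypothetical per-element loss tensor.
    let loss = Array::from_slice(&[0.25f32, 0.75, 1.0, 2.0], &[4]);

    // `None` returns the loss unchanged; `Sum` and `Mean` collapse it to a scalar.
    let summed = LossReduction::Sum.reduce(loss.clone())?;
    let mean = LossReduction::Mean.reduce(loss)?;
    assert_eq!(summed.item::<f32>(), 4.0);
    assert_eq!(mean.item::<f32>(), 1.0);
    Ok(())
}
```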
+#[derive(Debug, Clone, GenerateBuilder)] +#[generate_builder(generate_build_fn = false)] +pub struct CrossEntropy<'a> { + /// Weights for each target + #[optional(skip = true)] + pub weights: Option<&'a Array>, + + /// The axis over which to compute softmax. Default to [`CrossEntropy::DEFAULT_AXIS`] + #[optional(default_value = CrossEntropy::DEFAULT_AXIS)] + pub axis: i32, + + /// The label smoothing factor, range [0, 1). Default to + /// [`CrossEntropy::DEFAULT_LABEL_SMOOTHING`] + #[optional(default_value = CrossEntropy::DEFAULT_LABEL_SMOOTHING)] + pub label_smoothing: f32, + + /// Reduction type. Default to [`CrossEntropy::DEFAULT_REDUCTION`] + #[optional(default_value = CrossEntropy::DEFAULT_REDUCTION)] + pub reduction: LossReduction, +} + +impl<'a> CrossEntropyBuilder<'a> { + /// Sets the `weight` field. + pub fn weights(mut self, weights: impl Into>) -> Self { + self.weights = weights.into(); + self + } + + /// Build a [`CrossEntropy`] loss function. + pub fn build(self) -> Result, CrossEntropyBuildError> { + let axis = self.axis.unwrap_or(CrossEntropy::DEFAULT_AXIS); + let label_smoothing = self + .label_smoothing + .unwrap_or(CrossEntropy::DEFAULT_LABEL_SMOOTHING); + let reduction = self.reduction.unwrap_or(CrossEntropy::DEFAULT_REDUCTION); + + if !(0.0..1.0).contains(&label_smoothing) { + return Err(CrossEntropyBuildError::InvalidLabelSmoothingFactor); + } + + Ok(CrossEntropy { + weights: self.weights, + axis, + label_smoothing, + reduction, + }) + } +} + +impl<'a> Default for CrossEntropy<'a> { + fn default() -> Self { + Self::new() + } +} + +impl<'a> CrossEntropy<'a> { + /// Default value for the `axis` parameter. + pub const DEFAULT_AXIS: i32 = -1; + + /// Default value for the `label_smoothing` parameter. + pub const DEFAULT_LABEL_SMOOTHING: f32 = 0.0; + + /// Default value for the `reduction` parameter. + pub const DEFAULT_REDUCTION: LossReduction = LossReduction::None; + + /// Create a new [`CrossEntropy`] with all optional parameters set to their default values. + pub fn new() -> Self { + Self::builder().build().expect("Default values are valid") + } + + /// Apply the cross entropy loss function on the given logits and targets. + /// + /// # Params + /// + /// - `logits`: unnormalized predicted logits + /// - `targets`: target values, as class indices + pub fn apply( + &self, + logits: impl AsRef, + targets: impl AsRef, + ) -> Result { + let logits = logits.as_ref(); + let targets = targets.as_ref(); + + let target_as_probs = targets.ndim() == logits.ndim(); + + let score = if target_as_probs { + sum(&logits.multiply(targets)?, &[self.axis], None)? + } else { + take_along_axis(logits, &targets.expand_dims(&[-1])?, self.axis)?.squeeze(&[-1])? + }; + let log_sum_exp_logits = log_sum_exp(logits, &[self.axis], None)?; + + let mut loss = if self.label_smoothing > 0.0 { + // adjust the true class score with label smoothing + let adjusted_score = multiply(array!(1.0 - self.label_smoothing), score)?; + + // calculate the mean logit across the classes for smoothed loss + let mean_logits = logits.mean(&[self.axis], None)?; + let smoothed_loss = -multiply(mean_logits, array!(self.label_smoothing))?; + + // combine the adjusted score and smoothed loss with the logsumexp logits + log_sum_exp_logits + .subtract(adjusted_score)? + .add(smoothed_loss)? + } else { + log_sum_exp_logits.subtract(score)? 
+ }; + + if let Some(weights) = self.weights { + check_shape(weights, &loss, "weights", "loss")?; + loss = multiply(loss, weights)?; + } + + self.reduction.reduce(loss) + } +} + +/// Binary cross entropy loss. +/// +/// By default, this function takes the pre-sigmoid logits, which results in a faster +/// and more precise loss. For improved numerical stability when `inputs_are_logits` is true, +/// the loss calculation clips the input probabilities (in log-space) to a minimum value +/// of `-100`. +#[derive(Debug, Clone, GenerateBuilder)] +#[generate_builder(generate_build_fn = false)] +pub struct BinaryCrossEntropy<'a> { + /// Optional weights for each target + #[optional(skip = true)] + pub weights: Option<&'a Array>, + + /// Whether the inputs are logits. Default to + /// [`BinaryCrossEntropy::DEFAULT_INPUTS_ARE_LOGITS`] + #[optional] + pub inputs_are_logits: bool, + + /// Reduction type. Default to [`BinaryCrossEntropy::DEFAULT_REDUCTION`] + #[optional] + pub reduction: LossReduction, +} + +impl<'a> BinaryCrossEntropyBuilder<'a> { + /// Sets the `weights` field. + pub fn weights(mut self, weights: impl Into>) -> Self { + self.weights = weights.into(); + self + } + + /// Build a [`BinaryCrossEntropy`] loss function. + pub fn build(self) -> BinaryCrossEntropy<'a> { + BinaryCrossEntropy { + weights: self.weights, + inputs_are_logits: self + .inputs_are_logits + .unwrap_or(BinaryCrossEntropy::DEFAULT_INPUTS_ARE_LOGITS), + reduction: self + .reduction + .unwrap_or(BinaryCrossEntropy::DEFAULT_REDUCTION), + } + } +} + +impl<'a> Default for BinaryCrossEntropy<'a> { + fn default() -> Self { + Self::new() + } +} + +impl<'a> BinaryCrossEntropy<'a> { + /// Default value for the `with_logits` parameter. + pub const DEFAULT_INPUTS_ARE_LOGITS: bool = true; + + /// Default value for the `reduction` parameter. + pub const DEFAULT_REDUCTION: LossReduction = LossReduction::None; + + /// Create a new [`BinaryCrossEntropy`] with all optional parameters set to their default values. + pub fn new() -> Self { + Self::builder().build() + } + + /// Apply the binary cross entropy loss function on the given logits and targets. + /// + /// # Params + /// + /// - `logits`: unnormalized predicted logits + /// - `targets`: binary target values in {0, 1} + pub fn apply( + &self, + logits: impl AsRef, + targets: impl AsRef, + ) -> Result { + let logits = logits.as_ref(); + let targets = targets.as_ref(); + let weights = self.weights; + let inputs_are_logits = self.inputs_are_logits; + let reduction = self.reduction; + + let mut loss = if inputs_are_logits { + log_add_exp(array!(0.0), logits)?.subtract(targets.multiply(logits)?)? + } else { + let log_inputs_clip = clip(&log(logits), (-100.0, ()))?; + let log_inputs_inverse_clip = clip(&log(&array!(1.0).subtract(logits)?), (-100.0, ()))?; + -(targets.multiply(log_inputs_clip)?.add( + array!(1.0) + .subtract(targets)? + .multiply(log_inputs_inverse_clip)?, + )?) + }; + + if let Some(weights) = weights { + check_shape(weights, &loss, "weights", "loss")?; + loss = multiply(loss, weights)?; + } + + reduction.reduce(loss) + } +} + +/// Computes the L1 loss +#[derive(Debug, Clone, GenerateBuilder)] +pub struct L1Loss { + /// Reduction type. Default to [`L1loss::DEFAULT_REDUCTION`] + #[optional(default_value = L1Loss::DEFAULT_REDUCTION)] + pub reduction: LossReduction, +} + +impl L1Loss { + /// Default value for the `reduction` parameter. + pub const DEFAULT_REDUCTION: LossReduction = LossReduction::Mean; + + /// Compute the L1 loss. 
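A sketch of the `BinaryCrossEntropy` loss defined above in its default logits mode, where each element contributes `log(1 + exp(logit)) - target * logit`. It assumes the generated builder setter for `reduction` and the `Array::from_slice` constructor; neither appears verbatim in this diff.

```rust
use mlx_nn::losses::{BinaryCrossEntropy, LossReduction};
use mlx_rs::{error::Exception, Array};

fn main() -> Result<(), Exception> {
    // Pre-sigmoid logits and binary targets in {0, 1}.
    let logits = Array::from_slice(&[2.0f32, -1.5, 0.5, -3.0], &[4]);
    let targets = Array::from_slice(&[1.0f32, 0.0, 1.0, 0.0], &[4]);

    // Default path: inputs are logits, so the loss uses log_add_exp for stability.
    let bce = BinaryCrossEntropy::builder()
        .reduction(LossReduction::Mean)
        .build();
    let loss = bce.apply(&logits, &targets)?;
    println!("mean BCE: {}", loss.item::<f32>());
    Ok(())
}
```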
+ /// + /// # Params + /// + /// - `predictions`: predicted values + /// - `targets`: target values + pub fn apply( + &self, + predictions: impl AsRef, + targets: impl AsRef, + ) -> Result { + let predictions = predictions.as_ref(); + let targets = targets.as_ref(); + let reduction = self.reduction; + + check_shape(predictions, targets, "predictions", "targets")?; + let loss = predictions.subtract(targets)?.abs(); + reduction.reduce(loss) + } +} + +/// Computes the mean squared error loss. +#[derive(Debug, Clone, GenerateBuilder)] +pub struct MseLoss { + /// Reduction type. Default to [`MseLoss::DEFAULT_REDUCTION`] + #[optional(default_value = MseLoss::DEFAULT_REDUCTION)] + pub reduction: LossReduction, +} + +impl MseLoss { + /// Default value for the reduction parameter. + pub const DEFAULT_REDUCTION: LossReduction = LossReduction::Mean; + + /// Compute the mean squared error loss. + /// + /// # Params + /// + /// - `predictions`: predicted values + /// - `targets`: target values + pub fn apply( + &self, + predictions: impl AsRef, + targets: impl AsRef, + ) -> Result { + let predictions = predictions.as_ref(); + let targets = targets.as_ref(); + let reduction = self.reduction; + + check_shape(predictions, targets, "predictions", "targets")?; + let loss = predictions.subtract(targets)?.square(); + reduction.reduce(loss) + } +} + +/// Computes the negative log likelihood loss. +#[derive(Debug, Clone, GenerateBuilder)] +pub struct NllLoss { + /// distribution axis. Default to [`NllLoss::DEFAULT_AXIS`] + #[optional(default_value = NllLoss::DEFAULT_AXIS)] + pub axis: i32, + + /// Reduction type. Default to [`NllLoss::DEFAULT_REDUCTION`] + #[optional(default_value = NllLoss::DEFAULT_REDUCTION)] + pub reduction: LossReduction, +} + +impl NllLoss { + /// Default value for the `axis` parameter. + pub const DEFAULT_AXIS: i32 = -1; + + /// Default value for the `reduction` parameter. + pub const DEFAULT_REDUCTION: LossReduction = LossReduction::None; + + /// Compute the negative log likelihood loss. + /// + /// # Params + /// + /// - `inputs`: predicted distribution in log space + /// - `targets`: target values + pub fn apply( + &self, + inputs: impl AsRef, + targets: impl AsRef, + ) -> Result { + let inputs = inputs.as_ref(); + let targets = targets.as_ref(); + let axis = self.axis; + let reduction = self.reduction; + + let loss = -take_along_axis(inputs, &targets.expand_dims(&[-1])?, axis)?.squeeze(&[-1])?; + reduction.reduce(loss) + } +} + +/// Compute the negative log likelihood loss for a Gaussian distribution. +#[derive(Debug, Clone, GenerateBuilder)] +pub struct GaussianNllLoss { + /// Whether to include the constant term in the loss calculation. Default to + /// [`GaussianNllLoss::DEFAULT_FULL`] + #[optional(default_value = GaussianNllLoss::DEFAULT_FULL)] + pub full: bool, + + /// Small positive constant for numerical stability. Default to + /// [`GaussianNllLoss::DEFAULT_EPS`] + #[optional(default_value = GaussianNllLoss::DEFAULT_EPS)] + pub eps: f32, + + /// Reduction type. Default to [`GaussianNllLoss::DEFAULT_REDUCTION`] + #[optional(default_value = GaussianNllLoss::DEFAULT_REDUCTION)] + pub reduction: LossReduction, +} + +impl GaussianNllLoss { + /// Default value for the `full` parameter. + pub const DEFAULT_FULL: bool = false; + + /// Default value for the `eps` parameter. + pub const DEFAULT_EPS: f32 = 1e-6; + + /// Default value for the `reduction` parameter. 
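The regression losses above have public fields, so they can be constructed directly; a small worked example with `L1Loss` and `MseLoss` (array literals via the assumed `Array::from_slice`, as before).

```rust
use mlx_nn::losses::{L1Loss, LossReduction, MseLoss};
use mlx_rs::{error::Exception, Array};

fn main() -> Result<(), Exception> {
    let predictions = Array::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[4]);
    let targets = Array::from_slice(&[1.5f32, 2.0, 2.0, 6.0], &[4]);

    let l1 = L1Loss { reduction: LossReduction::Mean };
    let mse = MseLoss { reduction: LossReduction::Mean };

    // L1:  mean(|p - t|)   = (0.5 + 0.0 + 1.0 + 2.0) / 4 = 0.875
    // MSE: mean((p - t)^2) = (0.25 + 0.0 + 1.0 + 4.0) / 4 = 1.3125
    println!("l1  = {}", l1.apply(&predictions, &targets)?.item::<f32>());
    println!("mse = {}", mse.apply(&predictions, &targets)?.item::<f32>());
    Ok(())
}
```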
+ pub const DEFAULT_REDUCTION: LossReduction = LossReduction::None; + + /// Compute the negative log likelihood loss for a Gaussian distribution. + /// + /// # Params + /// + /// - `inputs`: The predicted expectation of the Gaussian distribution. + /// - `targets`: The target values (samples from the Gaussian distribution). + /// - `vars`: The predicted variance of the Gaussian distribution. + pub fn apply( + &self, + inputs: impl AsRef, + targets: impl AsRef, + vars: impl AsRef, + ) -> Result { + let inputs = inputs.as_ref(); + let targets = targets.as_ref(); + let vars = vars.as_ref(); + let full = self.full; + let eps = self.eps; + let reduction = self.reduction; + + check_shape(inputs, targets, "inputs", "targets")?; + check_shape(inputs, vars, "inputs", "vars")?; + + let vars = maximum(vars, array!(eps))?; + let mut loss = + array!(0.5) * (log(&vars).add(square(&targets.subtract(inputs)?).divide(&vars)?)?); + + if full { + let pi = array!(std::f32::consts::PI); + loss = loss.add(array!(0.5).multiply(log(&array!(2.0).multiply(pi)?))?)?; + } + + reduction.reduce(loss) + } +} + +/// Compute the Kullback-Leibler divergence loss. +/// +/// Computes the following when the `reduction` is `LossReduction::None`: +/// +/// ```rust, ignore +/// sum(exp(targets) * (targets - inputs), axis, None) +/// ``` +#[derive(Debug, Clone, GenerateBuilder)] +pub struct KlDivLoss { + /// The distribution axis. Default to [`KlDivLoss::DEFAULT_AXIS`] + #[optional(default_value = KlDivLoss::DEFAULT_AXIS)] + pub axis: i32, + + /// Reduction type. Default to [`KlDivLoss::DEFAULT_REDUCTION`] + #[optional(default_value = KlDivLoss::DEFAULT_REDUCTION)] + pub reduction: LossReduction, +} + +impl KlDivLoss { + /// Default value for the `axis` parameter. + pub const DEFAULT_AXIS: i32 = -1; + + /// Default value for the `reduction` parameter. + pub const DEFAULT_REDUCTION: LossReduction = LossReduction::None; + + /// Compute the Kullback-Leibler divergence loss. + /// + /// # Params + /// + /// - `inputs`: Log probabilities for the predicted distribution. + /// - `targets`: Log probabilities for the target distribution. + pub fn apply( + &self, + inputs: impl AsRef, + targets: impl AsRef, + ) -> Result { + let inputs = inputs.as_ref(); + let targets = targets.as_ref(); + let axis = self.axis; + let reduction = self.reduction; + + let loss = sum( + &exp(targets).multiply(targets.subtract(inputs)?)?, + &[axis], + None, + )?; + reduction.reduce(loss) + } +} + +/// Computes the smooth L1 loss. +/// +/// The smooth L1 loss is a variant of the L1 loss which replaces the absolute +/// difference with a squared difference when the absolute difference is less +/// than `beta`. +#[derive(Debug, Clone, GenerateBuilder)] +pub struct SmoothL1Loss { + /// The threshold after which the loss changes from the squared to the absolute difference. + /// Default to [`SmoothL1Loss::DEFAULT_BETA`] + #[optional(default_value = SmoothL1Loss::DEFAULT_BETA)] + pub beta: f32, + + /// Reduction type. Default to [`SmoothL1Loss::DEFAULT_REDUCTION`] + #[optional(default_value = SmoothL1Loss::DEFAULT_REDUCTION)] + pub reduction: LossReduction, +} + +impl SmoothL1Loss { + /// Default value for the `beta` parameter. + pub const DEFAULT_BETA: f32 = 1.0; + + /// Default value for the `reduction` parameter. + pub const DEFAULT_REDUCTION: LossReduction = LossReduction::Mean; + + /// Compute the smooth L1 loss. 
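Because `KlDivLoss` expects log-probabilities on both sides, a sketch helps: with `targets = log p` and `inputs = log q`, the reduced value is `KL(p || q)`. Array literals again use the assumed `Array::from_slice`.

```rust
use mlx_nn::losses::{KlDivLoss, LossReduction};
use mlx_rs::{error::Exception, ops::log, Array};

fn main() -> Result<(), Exception> {
    // Two distributions over the last axis.
    let p = Array::from_slice(&[0.25f32, 0.25, 0.5], &[1, 3]);
    let q = Array::from_slice(&[0.5f32, 0.25, 0.25], &[1, 3]);
    let target_log = log(&p); // log of the target distribution
    let input_log = log(&q);  // log of the predicted distribution

    let kl = KlDivLoss { axis: -1, reduction: LossReduction::Mean };
    // KL(p || q) = sum_i p_i * (log p_i - log q_i) ~= 0.1733 here.
    let loss = kl.apply(&input_log, &target_log)?;
    println!("KL(p || q) = {}", loss.item::<f32>());
    Ok(())
}
```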
+ /// + /// # Params + /// + /// - `predictions`: predicted values + /// - `targets`: target values + pub fn apply( + &self, + predictions: impl AsRef, + targets: impl AsRef, + ) -> Result { + let predictions = predictions.as_ref(); + let targets = targets.as_ref(); + let beta = self.beta; + let reduction = self.reduction; + + check_shape(predictions, targets, "predictions", "targets")?; + let diff = predictions.subtract(targets)?; + let loss = r#where( + &diff.lt(array!(beta))?, + array!(0.5).multiply(square(&diff))?.divide(&array!(beta))?, + abs(&diff).subtract(array!(0.5).multiply(array!(beta))?)?, + )?; + reduction.reduce(loss) + } +} + +/// Computes the triplet loss for a set of anchor, positive, and negative samples. Margin is +/// represented with alpha in the math section. +#[derive(Debug, Clone, GenerateBuilder)] +pub struct TripletLoss { + /// Distribution axis. Default to [`TripletLoss::DEFAULT_AXIS`] + #[optional(default_value = TripletLoss::DEFAULT_AXIS)] + pub axis: i32, + + /// The norm degree for pairwise distance. Default to [`TripletLoss::DEFAULT_P`] + #[optional(default_value = TripletLoss::DEFAULT_P)] + pub p: f32, + + /// Margin for the triplet loss. Default to [`TripletLoss::DEFAULT_MARGIN`] + #[optional(default_value = TripletLoss::DEFAULT_MARGIN)] + pub margin: f32, + + /// Small positive constant for numerical stability. Default to [`TripletLoss::DEFAULT_EPS`] + #[optional(default_value = TripletLoss::DEFAULT_EPS)] + pub eps: f32, + + /// Reduction type. Default to [`TripletLoss::DEFAULT_REDUCTION`] + #[optional(default_value = TripletLoss::DEFAULT_REDUCTION)] + pub reduction: LossReduction, +} + +impl TripletLoss { + /// Default value for the `axis` parameter. + pub const DEFAULT_AXIS: i32 = -1; + + /// Default value for the `p` parameter. + pub const DEFAULT_P: f32 = 2.0; + + /// Default value for the `margin` parameter. + pub const DEFAULT_MARGIN: f32 = 1.0; + + /// Default value for the `eps` parameter. + pub const DEFAULT_EPS: f32 = 1e-6; + + /// Default value for the `reduction` parameter. + pub const DEFAULT_REDUCTION: LossReduction = LossReduction::None; + + /// Computes the triplet loss for a set of anchor, positive, and negative samples. Margin is + /// represented with alpha in the math section. + /// + /// # Params + /// + /// - `anchors`: The anchor samples + /// - `positives`: The positive samples + /// - `neonatives`: The negative samples + pub fn apply( + &self, + anchors: impl AsRef, + positives: impl AsRef, + negatives: impl AsRef, + ) -> Result { + let anchors = anchors.as_ref(); + let positives = positives.as_ref(); + let negatives = negatives.as_ref(); + let axis = self.axis; + let p = self.p; + let margin = self.margin; + let eps = self.eps; + let reduction = self.reduction; + + let eps = array!(eps); + let p = array!(p); + let margin = array!(margin); + + let pos = sqrt( + &power(&anchors.subtract(positives)?, &p)? + .sum(&[axis], None)? + .add(&eps)?, + ); + let neg = sqrt( + &power(&anchors.subtract(negatives)?, &p)? + .sum(&[axis], None)? + .add(&eps)?, + ); + let loss = maximum(pos.subtract(neg)?.add(margin)?, array!(0.0))?; + reduction.reduce(loss) + } +} + +/// Compute the hinge loss. +#[derive(Debug, Clone, GenerateBuilder)] +pub struct HingeLoss { + /// Reduction type. Default to [`HingeLoss::DEFAULT_REDUCTION`] + #[optional(default_value = HingeLoss::DEFAULT_REDUCTION)] + pub reduction: LossReduction, +} + +impl HingeLoss { + /// Default value for the `reduction` parameter. 
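A small sketch of `TripletLoss` on a single triplet of embeddings, constructed via the public fields; `Array::from_slice` is assumed as before and the numbers are illustrative.

```rust
use mlx_nn::losses::{LossReduction, TripletLoss};
use mlx_rs::{error::Exception, Array};

fn main() -> Result<(), Exception> {
    // One triplet of 3-dimensional embeddings (batch size 1).
    let anchors = Array::from_slice(&[0.0f32, 0.0, 0.0], &[1, 3]);
    let positives = Array::from_slice(&[0.0f32, 1.0, 0.0], &[1, 3]);
    let negatives = Array::from_slice(&[3.0f32, 0.0, 0.0], &[1, 3]);

    let triplet = TripletLoss {
        axis: -1,
        p: 2.0,
        margin: 1.0,
        eps: 1e-6,
        reduction: LossReduction::Mean,
    };
    // max(||a - p||_2 - ||a - n||_2 + margin, 0) = max(1 - 3 + 1, 0) = 0 (up to eps).
    let loss = triplet.apply(&anchors, &positives, &negatives)?;
    println!("triplet loss = {}", loss.item::<f32>());
    Ok(())
}
```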
+ pub const DEFAULT_REDUCTION: LossReduction = LossReduction::None; + + /// Compute the hinge loss. + /// + /// # Params + /// + /// - `inputs`: predicted values + /// - `targets`: target values, -1 or 1 + pub fn apply( + &self, + inputs: impl AsRef, + targets: impl AsRef, + ) -> Result { + let inputs = inputs.as_ref(); + let targets = targets.as_ref(); + let reduction = self.reduction; + + let a = array!(1.0).subtract(inputs.multiply(targets)?)?; + let b = array!(0.0); + let loss = maximum(a, b)?; + reduction.reduce(loss) + } +} + +/// Compute the Huber loss. +#[derive(Debug, Clone, GenerateBuilder)] +pub struct HuberLoss { + /// The threshold at which to change between L1 and L2 loss. Default to + /// [`HuberLoss::DEFAULT_DELTA`] + #[optional(default_value = HuberLoss::DEFAULT_DELTA)] + pub delta: f32, + + /// Reduction type. Default to [`HuberLoss::DEFAULT_REDUCTION`] + #[optional(default_value = HuberLoss::DEFAULT_REDUCTION)] + pub reduction: LossReduction, +} + +impl HuberLoss { + /// Default value for the `delta` parameter. + pub const DEFAULT_DELTA: f32 = 1.0; + + /// Default value for the `reduction` parameter. + pub const DEFAULT_REDUCTION: LossReduction = LossReduction::None; + + /// Compute the Huber loss. + /// + /// # Params + /// + /// - `inputs`: predicted values + /// - `targets`: target values + pub fn apply( + &self, + inputs: impl AsRef, + targets: impl AsRef, + ) -> Result { + let inputs = inputs.as_ref(); + let targets = targets.as_ref(); + let delta = self.delta; + let reduction = self.reduction; + + let errors = inputs.subtract(targets)?; + let abs_errors = errors.abs(); + let quadratic = minimum(&abs_errors, array!(delta))?; + let linear = abs_errors.subtract(&quadratic)?; + let loss = array!(0.5) + .multiply(square(&quadratic))? + .add(array!(delta).multiply(linear)?)?; + reduction.reduce(loss) + } +} + +/// Computes the log cosh loss between inputs and targets. +/// +/// Logcosh acts like L2 loss for small errors, ensuring stable gradients, +/// and like the L1 loss for large errors, reducing sensitivity to outliers. This +/// dual behavior offers a balanced, robust approach for regression tasks. +#[derive(Debug, Clone, GenerateBuilder)] +pub struct LogCoshLoss { + /// Reduction type. Default to [`LogCoshLoss::DEFAULT_REDUCTION`] + #[optional(default_value = LogCoshLoss::DEFAULT_REDUCTION)] + pub reduction: LossReduction, +} + +impl LogCoshLoss { + /// Default value for the `reduction` parameter. + pub const DEFAULT_REDUCTION: LossReduction = LossReduction::None; + + /// Computes the log cosh loss between inputs and targets. + /// + /// # Params + /// + /// - `inputs`: predicted values + /// - `targets`: target values + pub fn apply( + &self, + inputs: impl AsRef, + targets: impl AsRef, + ) -> Result { + let inputs = inputs.as_ref(); + let targets = targets.as_ref(); + let reduction = self.reduction; + + let errors = inputs.subtract(targets)?; + let neg_errors = errors.negative()?; + let loss = log_add_exp(errors, neg_errors)?.subtract(log(&array!(2.0)))?; + reduction.reduce(loss) + } +} + +/// Computes the cosine similarity loss. +#[derive(Debug, Clone, GenerateBuilder)] +pub struct CosineSimilarityLoss { + /// Embedding axis. Default to [`CosineSimilarityLoss::DEFAULT_AXIS`] + #[optional(default_value = CosineSimilarityLoss::DEFAULT_AXIS)] + pub axis: i32, + + /// minimum value of the denominator used for numerical stability. 
Default to + /// [`CosineSimilarityLoss::DEFAULT_EPS`] + #[optional(default_value = CosineSimilarityLoss::DEFAULT_EPS)] + pub eps: f32, + + /// Reduction type. Default to [`CosineSimilarityLoss::DEFAULT_REDUCTION`] + #[optional(default_value = CosineSimilarityLoss::DEFAULT_REDUCTION)] + pub reduction: LossReduction, +} + +impl CosineSimilarityLoss { + /// Default value for the `axis` parameter. + pub const DEFAULT_AXIS: i32 = -1; + + /// Default value for the `eps` parameter. + pub const DEFAULT_EPS: f32 = 1e-8; + + /// Default value for the `reduction` parameter. + pub const DEFAULT_REDUCTION: LossReduction = LossReduction::None; + + /// Computes the cosine similarity loss. + /// + /// # Params + /// + /// - `x1`: first array + /// - `x2`: second array + pub fn apply(&self, x1: impl AsRef, x2: impl AsRef) -> Result { + let x1 = x1.as_ref(); + let x2 = x2.as_ref(); + let axis = self.axis; + let eps = self.eps; + let reduction = self.reduction; + + fn l2_loss(a: &Array, axis: i32) -> Result { + if a.dtype().is_complex() { + Ok(sqrt(&sum(&abs(a).square(), &[axis], None)?)) + } else { + Ok(sqrt(&sum(&a.square(), &[axis], None)?)) + } + } + + let x1_norm = l2_loss(x1, axis)?; + let x2_norm = l2_loss(x2, axis)?; + + let num = sum(&x1.multiply(x2)?, &[axis], None)?; + let den = maximum(x1_norm.multiply(x2_norm)?, array!(eps))?; + let loss = num.divide(&den)?; + + reduction.reduce(loss) + } +} + +/// Computes the margin ranking loss. +#[derive(Debug, Clone, GenerateBuilder)] +pub struct MarginRankingLoss { + /// The margin by which the scores should be separated. Default to + /// [`MarginRankingLoss::DEFAULT_MARGIN`] + #[optional(default_value = MarginRankingLoss::DEFAULT_MARGIN)] + pub margin: f32, + + /// Reduction type. Default to [`MarginRankingLoss::DEFAULT_REDUCTION`] + #[optional(default_value = MarginRankingLoss::DEFAULT_REDUCTION)] + pub reduction: LossReduction, +} + +impl MarginRankingLoss { + /// Default value for the `margin` parameter. + pub const DEFAULT_MARGIN: f32 = 0.0; + + /// Default value for the `reduction` parameter. + pub const DEFAULT_REDUCTION: LossReduction = LossReduction::None; + + /// Computes the margin ranking loss. + /// + /// # Params + /// + /// - `inputs1`: Scores for the first input. + /// - `inputs2`: Scores for the second input. + /// - `targets`: Labels indicating whether samples in `inputs1` should be ranked higher than samples + /// in `inputs2`. Values should be 1 or -1. 
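+    ///
+    /// In terms of the implementation below, the element-wise loss is
+    /// `maximum(0, margin - targets * (inputs1 - inputs2))`.
+    ///
+    /// A minimal usage sketch (illustrative only; `scores_a`, `scores_b`, and `labels`
+    /// are assumed to be arrays of matching shape, with `labels` containing 1 or -1):
+    ///
+    /// ```rust, ignore
+    /// let loss_fn = MarginRankingLoss::builder()
+    ///     .margin(0.5)
+    ///     .reduction(LossReduction::Mean)
+    ///     .build();
+    /// let loss = loss_fn.apply(&scores_a, &scores_b, &labels)?;
+    /// ```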
+ pub fn apply( + &self, + inputs1: impl AsRef, + inputs2: impl AsRef, + targets: impl AsRef, + ) -> Result { + let inputs1 = inputs1.as_ref(); + let inputs2 = inputs2.as_ref(); + let targets = targets.as_ref(); + let margin = self.margin; + let reduction = self.reduction; + + check_shape(inputs1, inputs2, "inputs1", "inputs2")?; + check_shape(inputs1, targets, "inputs1", "targets")?; + + let margin = array!(margin); + let diff = inputs1.subtract(inputs2)?; + let loss = maximum( + array!(0.0), + targets.multiply(diff)?.negative()?.add(margin)?, + )?; + reduction.reduce(loss) + } +} + +#[cfg(test)] +#[allow(clippy::approx_constant)] +mod tests { + use float_eq::assert_float_eq; + use mlx_rs::{array, assert_array_eq, ops::is_nan}; + + use super::*; + + // The following unit tests are adapted from the python API at: mlx/python/tests/test_losses.py + + #[test] + fn test_cross_entropy() { + // No weights, no label smoothing + let logits = array!([[0.0, f32::NEG_INFINITY], [f32::NEG_INFINITY, 0.0]]); + let indices = array!([0, 1]); + let expected = array!([0.0, 0.0]); + let loss = CrossEntropy::default().apply(&logits, indices).unwrap(); + assert_array_eq!(loss, expected); + + let probs = array!([[1.0, 0.0], [0.0, 1.0]]); + let cross_entropy = CrossEntropy::builder() + .reduction(LossReduction::None) + .build() + .unwrap(); + let loss = cross_entropy.apply(logits, probs).unwrap(); + assert!(is_nan(&loss).all(None, None).unwrap().item::()); + + // With weights, no label smoothing + let logits = array!([[2.0, -1.0], [-1.0, 2.0]]); + let indices = array!([0, 1]); + let weights = array!([1.0, 2.0]); + let expected = array!([0.04858735, 0.0971747]); + let cross_entropy = CrossEntropy::builder() + .weights(&weights) + .reduction(LossReduction::None) + .build() + .unwrap(); + let loss = cross_entropy.apply(&logits, indices).unwrap(); + assert_array_eq!(loss, expected); + + let probs = array!([[1.0, 0.0], [0.0, 1.0]]); + let cross_entropy = CrossEntropy::builder() + .weights(&weights) + .reduction(LossReduction::None) + .build() + .unwrap(); + let loss = cross_entropy.apply(logits, probs).unwrap(); + assert_array_eq!(loss, expected); + + // No weights, with label smoothing + let logits = array!([[2.0, -1.0], [-1.0, 2.0]]); + let indices = array!([0, 1]); + let expected = array!([0.498587, 0.498587]); + let cross_entropy = CrossEntropy::builder() + .label_smoothing(0.3) + .reduction(LossReduction::None) + .build() + .unwrap(); + let loss = cross_entropy.apply(&logits, indices).unwrap(); + assert_array_eq!(loss, expected); + + let probs = array!([[1.0, 0.0], [0.0, 1.0]]); + let cross_entropy = CrossEntropy::builder() + .label_smoothing(0.3) + .reduction(LossReduction::None) + .build() + .unwrap(); + let loss = cross_entropy.apply(logits, probs).unwrap(); + assert_array_eq!(loss, expected); + + // With weights and label smoothing + let logits = array!([[2.0, -1.0], [-1.0, 2.0]]); + let indices = array!([0, 1]); + let weights = array!([1.0, 2.0]); + let expected = array!([0.49858734, 0.9971747]); + let cross_entropy = CrossEntropy::builder() + .weights(&weights) + .label_smoothing(0.3) + .reduction(LossReduction::None) + .build() + .unwrap(); + let loss = cross_entropy.apply(&logits, indices).unwrap(); + assert_array_eq!(loss, expected); + + let probs = array!([[1.0, 0.0], [0.0, 1.0]]); + let cross_entropy = CrossEntropy::builder() + .weights(&weights) + .label_smoothing(0.3) + .reduction(LossReduction::None) + .build() + .unwrap(); + let loss = cross_entropy.apply(logits, probs).unwrap(); + 
assert_array_eq!(loss, expected); + } + + #[test] + fn test_binary_cross_entropy_with_logits_as_inputs() { + let logits = array!([0.105361, 0.223144, 1.20397, 0.916291]); + let targets = array!([0.0, 0.0, 1.0, 1.0]); + + // Test with reduction 'none' + let binary_cross_entropy = BinaryCrossEntropy::builder() + .reduction(LossReduction::None) + .build(); + let loss_none = binary_cross_entropy.apply(&logits, &targets).unwrap(); + let expected_none = array!([0.747215, 0.810930, 0.262365, 0.336472]); + assert_array_eq!(loss_none, expected_none); + + // Test with reduction 'mean' + let binary_cross_entropy = BinaryCrossEntropy::builder() + .reduction(LossReduction::Mean) + .build(); + let loss_mean = binary_cross_entropy.apply(&logits, &targets).unwrap(); + let expected_mean = expected_none.mean(None, None).unwrap(); + assert_array_eq!(loss_mean, expected_mean); + + // Test with reduction 'sum' + let binary_cross_entropy = BinaryCrossEntropy::builder() + .reduction(LossReduction::Sum) + .build(); + let loss = binary_cross_entropy.apply(&logits, &targets).unwrap(); + let expected = expected_none.sum(None, None).unwrap(); + assert_array_eq!(loss, expected); + + // With weights, no label smoothing + let weights = array!([1.0, 2.0, 1.0, 2.0]); + let expected = array!([0.747215, 1.62186, 0.262365, 0.672944]); + let binary_cross_entropy = BinaryCrossEntropy::builder() + .weights(&weights) + .reduction(LossReduction::None) + .build(); + let loss = binary_cross_entropy.apply(&logits, &targets).unwrap(); + assert_array_eq!(loss, expected); + } + + #[test] + fn test_binary_cross_entropy_with_probs_as_inputs() { + let probs = array!([0.5, 0.6, 0.7, 0.8]); + let targets = array!([0.0, 0.0, 1.0, 1.0]); + + // Test with reduction 'none' + let binary_cross_entropy = BinaryCrossEntropy::builder() + .inputs_are_logits(false) + .reduction(LossReduction::None) + .build(); + let loss_none = binary_cross_entropy.apply(&probs, &targets).unwrap(); + let expected_none = array!([0.693147, 0.916291, 0.356675, 0.223144]); + assert_array_eq!(loss_none, expected_none); + + // Test with reduction 'mean' + let binary_cross_entropy = BinaryCrossEntropy::builder() + .inputs_are_logits(false) + .reduction(LossReduction::Mean) + .build(); + let loss_mean = binary_cross_entropy.apply(&probs, &targets).unwrap(); + let expected_mean = expected_none.mean(None, None).unwrap(); + assert_array_eq!(loss_mean, expected_mean); + + // Test with reduction 'sum' + let binary_cross_entropy = BinaryCrossEntropy::builder() + .inputs_are_logits(false) + .reduction(LossReduction::Sum) + .build(); + let loss = binary_cross_entropy.apply(&probs, &targets).unwrap(); + let expected = expected_none.sum(None, None).unwrap(); + assert_array_eq!(loss, expected); + } + + #[test] + fn test_binary_cross_entropy_with_tiny_probs_as_inputs() { + let tiny_prob = 1e-59; + let probs = array!([0.0, tiny_prob, 1.0 - tiny_prob, 1.0]); + let targets = array!([0.0, 0.0, 1.0, 1.0]); + + // Test with reduction 'none' + let binary_cross_entropy = BinaryCrossEntropy::builder() + .inputs_are_logits(false) + .reduction(LossReduction::None) + .build(); + let loss_none = binary_cross_entropy.apply(&probs, &targets).unwrap(); + let expected_none = array!([0.0, tiny_prob, tiny_prob, 0.0]); + assert_array_eq!(loss_none, expected_none); + + // Test with reduction 'mean' + let binary_cross_entropy = BinaryCrossEntropy::builder() + .inputs_are_logits(false) + .reduction(LossReduction::Mean) + .build(); + let loss_mean = binary_cross_entropy.apply(&probs, &targets).unwrap(); + let 
expected_mean = expected_none.mean(None, None).unwrap(); + assert_array_eq!(loss_mean, expected_mean); + + // Test with reduction 'sum' + let binary_cross_entropy = BinaryCrossEntropy::builder() + .inputs_are_logits(false) + .reduction(LossReduction::Sum) + .build(); + let loss = binary_cross_entropy.apply(&probs, &targets).unwrap(); + let expected = expected_none.sum(None, None).unwrap(); + assert_array_eq!(loss, expected); + } + + #[test] + fn test_l1_loss() { + let predictions = array!([0.5, 0.2, 0.9, 0.0]); + let targets = array!([0.5, 0.2, 0.9, 0.0]); + + let expected_none = array!([0.0, 0.0, 0.0, 0.0]); + let expected_sum = expected_none.sum(None, None).unwrap(); + let expected_mean = expected_none.mean(None, None).unwrap(); + + let l1_loss = L1Loss::builder().reduction(LossReduction::None).build(); + let loss_none = l1_loss.apply(&predictions, &targets).unwrap(); + assert_array_eq!(loss_none, expected_none); + + let l1_loss = L1Loss::builder().reduction(LossReduction::Sum).build(); + let loss_sum = l1_loss.apply(&predictions, &targets).unwrap(); + assert_array_eq!(loss_sum, expected_sum); + + let l1_loss = L1Loss::builder().reduction(LossReduction::Mean).build(); + let loss_mean = l1_loss.apply(&predictions, &targets).unwrap(); + assert_array_eq!(loss_mean, expected_mean); + } + + #[test] + fn test_mse_loss() { + let predictions = array!([0.5, 0.2, 0.9, 0.0]); + let targets = array!([0.7, 0.1, 0.8, 0.2]); + + let expected_none = array!([0.04, 0.01, 0.01, 0.04]); + let expected_mean = expected_none.mean(None, None).unwrap(); + let expected_sum = expected_none.sum(None, None).unwrap(); + + let mse_loss = MseLoss::builder().reduction(LossReduction::None).build(); + let loss_none = mse_loss.apply(&predictions, &targets).unwrap(); + assert_array_eq!(loss_none, expected_none); + + let mse_loss = MseLoss::builder().reduction(LossReduction::Mean).build(); + let loss_mean = mse_loss.apply(&predictions, &targets).unwrap(); + assert_array_eq!(loss_mean, expected_mean); + + let mse_loss = MseLoss::builder().reduction(LossReduction::Sum).build(); + let loss_sum = mse_loss.apply(&predictions, &targets).unwrap(); + assert_array_eq!(loss_sum, expected_sum); + } + + #[test] + fn test_smooth_l1_loss() { + let predictions = array!([1.5, 2.5, 0.5, 3.5]); + let targets = array!([1.0, 2.0, 0.5, 2.5]); + let beta = 1.0; + + let expected_none = array!([0.125, 0.125, 0.0, 0.5]); + let expected_sum = expected_none.sum(None, None).unwrap(); + let expected_mean = expected_none.mean(None, None).unwrap(); + + let smooth_l1_loss = SmoothL1Loss::builder() + .beta(beta) + .reduction(LossReduction::None) + .build(); + let loss_none = smooth_l1_loss.apply(&predictions, &targets).unwrap(); + assert_array_eq!(loss_none, expected_none); + + let smooth_l1_loss = SmoothL1Loss::builder() + .beta(beta) + .reduction(LossReduction::Sum) + .build(); + let loss_sum = smooth_l1_loss.apply(&predictions, &targets).unwrap(); + assert_array_eq!(loss_sum, expected_sum); + + let smooth_l1_loss = SmoothL1Loss::builder() + .beta(beta) + .reduction(LossReduction::Mean) + .build(); + let loss_mean = smooth_l1_loss.apply(&predictions, &targets).unwrap(); + assert_array_eq!(loss_mean, expected_mean); + } + + #[test] + fn test_nll_loss() { + let logits = array!([[0.0, f32::NEG_INFINITY], [f32::NEG_INFINITY, 0.0]]); + let targets = array!([0, 1]); + + let expected_none = array!([0.0, 0.0]); + let expected_sum = expected_none.sum(None, None).unwrap(); + let expected_mean = expected_none.mean(None, None).unwrap(); + + let nll_loss = 
NllLoss::builder().reduction(LossReduction::None).build(); + let loss_none = nll_loss.apply(&logits, &targets).unwrap(); + assert_array_eq!(loss_none, expected_none); + + let nll_loss = NllLoss::builder().reduction(LossReduction::Mean).build(); + let loss_mean = nll_loss.apply(&logits, &targets).unwrap(); + assert_array_eq!(loss_mean, expected_mean); + + let nll_loss = NllLoss::builder().reduction(LossReduction::Sum).build(); + let loss_sum = nll_loss.apply(&logits, &targets).unwrap(); + assert_array_eq!(loss_sum, expected_sum); + } + + #[test] + fn test_gaussian_nll_loss() { + let inputs = array!([[0.1, 0.2], [0.3, 0.4]]); + let targets = array!([[0.2, 0.1], [0.1, 0.2]]); + let vars = array!([[0.1, 0.2], [0.3, 0.4]]); + + // Test with reduction 'none', full=False + let gaussian_nll_loss = GaussianNllLoss::builder() + .full(false) + .reduction(LossReduction::None) + .build(); + let loss_none = gaussian_nll_loss.apply(&inputs, &targets, &vars).unwrap(); + let expected_none = array!([[-1.101293, -0.779719], [-0.535320, -0.408145]]); + assert_array_eq!(loss_none, expected_none); + + // Test with reduction 'mean', full=False + let gaussian_nll_loss = GaussianNllLoss::builder() + .full(false) + .reduction(LossReduction::Mean) + .build(); + let loss_mean = gaussian_nll_loss.apply(&inputs, &targets, &vars).unwrap(); + let expected_mean = expected_none.mean(None, None).unwrap(); + assert_array_eq!(loss_mean, expected_mean); + + // Test with reduction 'sum', full=False + let gaussian_nll_loss = GaussianNllLoss::builder() + .full(false) + .reduction(LossReduction::Sum) + .build(); + let loss_sum = gaussian_nll_loss.apply(&inputs, &targets, &vars).unwrap(); + let expected_sum = expected_none.sum(None, None).unwrap(); + assert_array_eq!(loss_sum, expected_sum); + + // Test with reduction='none', full=True + let gaussian_nll_loss = GaussianNllLoss::builder() + .full(true) + .reduction(LossReduction::None) + .build(); + let loss_none_full = gaussian_nll_loss.apply(&inputs, &targets, &vars).unwrap(); + let expected_none_full = array!([[-0.182354, 0.139220], [0.383619, 0.510793]]); + assert_array_eq!(loss_none_full, expected_none_full); + + // Test with reduction='mean', full=True + let gaussian_nll_loss = GaussianNllLoss::builder() + .full(true) + .reduction(LossReduction::Mean) + .build(); + let loss_mean_full = gaussian_nll_loss.apply(&inputs, &targets, &vars).unwrap(); + let expected_mean_full = expected_none_full.mean(None, None).unwrap(); + assert_array_eq!(loss_mean_full, expected_mean_full); + + // Test with reduction='sum', full=True + let gaussian_nll_loss = GaussianNllLoss::builder() + .full(true) + .reduction(LossReduction::Sum) + .build(); + let loss_sum_full = gaussian_nll_loss.apply(&inputs, &targets, &vars).unwrap(); + let expected_sum_full = expected_none_full.sum(None, None).unwrap(); + assert_array_eq!(loss_sum_full, expected_sum_full); + } + + #[test] + fn test_kl_div_loss() { + let p_logits = array!([[0.5, 0.5], [0.8, 0.2]]).log(); + let q_logits = array!([[0.5, 0.5], [0.2, 0.8]]).log(); + + // Test with reduction 'none' + let kl_div_loss = KlDivLoss::builder().reduction(LossReduction::None).build(); + let loss_none = kl_div_loss.apply(&p_logits, &q_logits).unwrap(); + let expected_none = array!([0.0, 0.831777]); + assert_array_eq!(loss_none, expected_none); + + // Test with reduction 'mean' + let kl_div_loss = KlDivLoss::builder().reduction(LossReduction::Mean).build(); + let loss_mean = kl_div_loss.apply(&p_logits, &q_logits).unwrap(); + let expected_mean = expected_none.mean(None, 
None).unwrap(); + assert_array_eq!(loss_mean, expected_mean); + + // Test with reduction 'sum' + let kl_div_loss = KlDivLoss::builder().reduction(LossReduction::Sum).build(); + let loss_sum = kl_div_loss.apply(&p_logits, &q_logits).unwrap(); + let expected_sum = expected_none.sum(None, None).unwrap(); + assert_array_eq!(loss_sum, expected_sum); + } + + #[test] + fn test_triplet_loss() { + let anchors = array!([[1, 2, 3], [1, 2, 3]]); + let positives = array!([[4, 5, 6], [0, -1, 2]]); + let negatives = array!([[7, 8, 9], [3, 2, 3]]); + + // Test with reduction 'none' + let triplet_loss = TripletLoss::builder() + .reduction(LossReduction::None) + .build(); + let loss_none = triplet_loss + .apply(&anchors, &positives, &negatives) + .unwrap(); + let expected_none = array!([0.0, 2.31662]); + assert_array_eq!(loss_none, expected_none); + + // Test with reduction 'mean' + let triplet_loss = TripletLoss::builder() + .reduction(LossReduction::Mean) + .build(); + let loss_mean = triplet_loss + .apply(&anchors, &positives, &negatives) + .unwrap(); + let expected_mean = expected_none.mean(None, None).unwrap(); + assert_array_eq!(loss_mean, expected_mean); + + // Test with reduction 'sum' + let triplet_loss = TripletLoss::builder().reduction(LossReduction::Sum).build(); + let loss_sum = triplet_loss + .apply(&anchors, &positives, &negatives) + .unwrap(); + let expected_sum = expected_none.sum(None, None).unwrap(); + assert_array_eq!(loss_sum, expected_sum); + } + + #[test] + fn test_hinge_loss() { + let inputs = array!([[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]); + let targets = array!([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]); + let hinge_loss = HingeLoss::builder().reduction(LossReduction::Mean).build(); + let loss = hinge_loss.apply(&inputs, &targets).unwrap(); + assert_eq!(loss.item::(), 1.0); + } + + #[test] + fn test_huber_loss() { + let inputs = array!([[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]); + let targets = array!([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]); + let huber_loss = HuberLoss::builder().reduction(LossReduction::Mean).build(); + let loss = huber_loss.apply(&inputs, &targets).unwrap(); + assert_eq!(loss.item::(), 0.5); + } + + #[test] + fn test_log_cosh_loss() { + let inputs = array!([[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]); + let targets = array!([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]); + let log_cosh_loss = LogCoshLoss::builder() + .reduction(LossReduction::Mean) + .build(); + let loss = log_cosh_loss.apply(&inputs, &targets).unwrap(); + assert_float_eq!(loss.item::(), 0.433781, abs <= 1e-6); + } + + #[test] + fn test_cosine_similarity_loss() { + let embeddings1 = array!([[0.5, 0.5, 0.2, 0.9], [0.1, 0.3, 0.5, 0.5]]); + let embeddings2 = array!([[0.6, 0.4, 0.3, 0.8], [0.2, 0.5, 0.6, 0.4]]); + + // Test with reduction 'none' + let cosine_similarity_loss = CosineSimilarityLoss::builder() + .reduction(LossReduction::None) + .build(); + let loss_none = cosine_similarity_loss + .apply(&embeddings1, &embeddings2) + .unwrap(); + let expected_none = array!([0.985344, 0.961074]); + assert_array_eq!(loss_none, expected_none); + + // Test with reduction 'mean' + let cosine_similarity_loss = CosineSimilarityLoss::builder() + .reduction(LossReduction::Mean) + .build(); + let loss_mean = cosine_similarity_loss + .apply(&embeddings1, &embeddings2) + .unwrap(); + let expected_mean = expected_none.mean(None, None).unwrap(); + assert_array_eq!(loss_mean, expected_mean); + + // Test with reduction 'sum' + let cosine_similarity_loss = CosineSimilarityLoss::builder() + 
.reduction(LossReduction::Sum) + .build(); + let loss_sum = cosine_similarity_loss + .apply(&embeddings1, &embeddings2) + .unwrap(); + let expected_sum = expected_none.sum(None, None).unwrap(); + assert_array_eq!(loss_sum, expected_sum); + } + + #[test] + fn test_margin_ranking_loss() { + let inputs1 = array!([-0.573409, -0.765166, -0.0638]); + let inputs2 = array!([0.75596, 0.225763, 0.256995]); + let targets = array!([1, 1, -1]); + + // Test with no margin + let margin_ranking_loss = MarginRankingLoss::builder() + .reduction(LossReduction::None) + .build(); + let loss = margin_ranking_loss + .apply(&inputs1, &inputs2, &targets) + .unwrap(); + let expected = array!([1.329369, 0.990929, 0.0]); + assert_array_eq!(loss, expected); + + // Test with margin + let margin_ranking_loss = MarginRankingLoss::builder() + .margin(0.5) + .reduction(LossReduction::None) + .build(); + let loss = margin_ranking_loss + .apply(&inputs1, &inputs2, &targets) + .unwrap(); + let expected = array!([1.829369, 1.490929, 0.179205]); + assert_array_eq!(loss, expected); + } +} diff --git a/mlx-nn/src/macros.rs b/mlx-nn/src/macros.rs new file mode 100644 index 00000000..a63a8dfc --- /dev/null +++ b/mlx-nn/src/macros.rs @@ -0,0 +1,3 @@ +//! Macro re-exports. + +pub use mlx_macros::ModuleParameters; diff --git a/mlx-nn/src/optimizers/mod.rs b/mlx-nn/src/optimizers/mod.rs new file mode 100644 index 00000000..adce1d90 --- /dev/null +++ b/mlx-nn/src/optimizers/mod.rs @@ -0,0 +1,218 @@ +//! Trait and implementations for optimizers. + +use std::rc::Rc; + +use mlx_rs::module::{FlattenedModuleParam, ModuleParameters}; + +mod rmsprop; +mod sgd; + +use mlx_rs::Array; +pub use rmsprop::*; +pub use sgd::*; + +type OptimizerState = FlattenedModuleParam; + +/// Trait for optimizers. +pub trait Optimizer { + /// Update a single parameter with the given gradient. + fn update_single(&mut self, key: Rc, gradient: Array, parameter: &mut Array); + + /// Apply the gradients to the parameters of the model and update the model with the new + /// parameters. 
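+    ///
+    /// A typical training step looks roughly like the sketch below (illustrative only;
+    /// `model`, a `loss_and_grad_fn` built with `module_value_and_grad`, and the batch
+    /// `(x, y)` are assumed to exist):
+    ///
+    /// ```rust, ignore
+    /// let (loss, grads) = loss_and_grad_fn(&mut model, (&x, &y))?;
+    /// optimizer.update(&mut model, grads);
+    /// eval_params(model.parameters())?;
+    /// ```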
+ fn update(&mut self, model: &mut M, gradients: FlattenedModuleParam) + where + M: ModuleParameters, + { + let mut parameters = model.parameters_mut().flatten(); + + for (key, gradient) in gradients { + if let Some(parameter) = parameters.get_mut(&key) { + self.update_single(key, gradient, parameter); + } + } + } +} + +#[cfg(test)] +mod optim_test_util { + use mlx_macros::ModuleParameters; + use mlx_rs::module::{FlattenedModuleParam, ModuleParameters, Param}; + use mlx_rs::{ + ops::{ones, zeros}, + Array, + }; + + #[derive(Debug, ModuleParameters)] + pub(super) struct First { + #[param] + pub a: Param, + + #[param] + pub b: Param, + } + + #[derive(Debug, ModuleParameters)] + pub(super) struct Model { + #[param] + pub first: Param, + + #[param] + pub second: Param, + } + + pub(super) type GradsMap = FlattenedModuleParam; + + pub(super) fn create_default_test_model_and_grads() -> (Model, GradsMap) { + let first = First { + a: Param::new(zeros::(&[10]).unwrap()), + b: Param::new(zeros::(&[1]).unwrap()), + }; + let model = Model { + first: Param::new(first), + second: Param::new(zeros::(&[1]).unwrap()), + }; + + let grads_map: GradsMap = model + .parameters() + .flatten() + .iter() + .map(|(k, v)| { + let g = ones::(v.shape()).unwrap(); + (k.clone(), g) + }) + .collect(); + + (model, grads_map) + } + + pub(super) const ATOL: f64 = 1e-5; +} + +#[cfg(test)] +mod tests { + use mlx_macros::ModuleParameters; + use mlx_rs::{ + array, + error::Exception, + module::{Module, ModuleParameters, Param}, + random::uniform, + transforms::{eval, eval_params}, + Array, + }; + + use crate::{ + losses::{LossReduction, MseLoss}, + module_value_and_grad, + }; + + use super::*; + + /// A helper model for testing optimizers. + /// + /// This is adapted from the swift binding tests in `mlx-swift/Tests/MLXTests/OptimizerTests.swift`. + #[derive(Debug, ModuleParameters)] + struct LinearFunctionModel { + #[param] + pub m: Param, + + #[param] + pub b: Param, + } + + impl Module for LinearFunctionModel { + type Error = Exception; + + fn forward(&self, x: &Array) -> Result { + self.m.multiply(x)?.add(&self.b) + } + + fn training_mode(&mut self, _mode: bool) {} + } + + impl LinearFunctionModel { + pub fn new() -> Result { + let m = uniform::<_, f32>(-5.0, 5.0, None, None)?; + let b = uniform::<_, f32>(-5.0, 5.0, None, None)?; + Ok(Self { + m: Param::new(m), + b: Param::new(b), + }) + } + } + + pub fn train(f: F, steps: usize) -> Result> + where + F: FnOnce() -> O, + O: Optimizer, + { + let mut optimizer = f(); + + let mse_loss = MseLoss::builder().reduction(LossReduction::Mean).build(); + let loss = |model: &LinearFunctionModel, (x, y): (&Array, &Array)| { + mse_loss.apply(model.forward(x)?, y) + }; + + // TODO: check compiled model once we have it + let mut model = LinearFunctionModel::new()?; + eval_params(model.parameters())?; + + let m = array!(0.25); + let b = array!(7.0); + + let mut lg = module_value_and_grad(loss); + + let mut last_loss = None; + for _ in 0..steps { + // println!("target: b = {}, m = {}", b, m); + // println!("parameters: {:?}", model.parameters()); + + // generate random training data along with the ground truth. + // notice that the shape is [B, 1] where B is the batch + // dimension -- this allows us to train on 10 samples simultaneously + let x = uniform::<_, f32>(-5.0, 5.0, &[10, 1], None)?; + let y = &m * &x + &b; + eval([&x, &y])?; + + // compute the loss and gradients. 
use the optimizer + // to adjust the parameters closer to the target + let (loss, g) = lg(&mut model, (&x, &y))?; + optimizer.update(&mut model, g); + + eval_params(model.parameters())?; + + last_loss = Some(loss); + } + + Ok(last_loss.unwrap()) + } + + const NUM_TRIALS: usize = 3; + + #[test] + fn test_sgd_converges() { + let mut total_loss = 0.0; + for _ in 0..NUM_TRIALS { + let loss = train(|| Sgd::new(0.1), 30).unwrap(); + total_loss += loss.item::(); + } + // It sometimes doesn't converge that fast, so we take the average loss + // across multiple trials + let avg_loss = total_loss / NUM_TRIALS as f32; + assert!(avg_loss < 0.1, "avg loss: {}", avg_loss); + } + + #[test] + fn test_rmsprop_converges() { + let mut total_loss = 0.0; + for _ in 0..NUM_TRIALS { + // RMSProp doesn't seem to converge as fast as SGD + let loss = train(|| RmsProp::new(0.1), 100).unwrap(); + total_loss += loss.item::(); + } + // It sometimes doesn't converge that fast, so we take the average loss + // across multiple trials + let avg_loss = total_loss / NUM_TRIALS as f32; + assert!(avg_loss < 0.1, "avg loss: {}", avg_loss); + } +} diff --git a/mlx-nn/src/optimizers/rmsprop.rs b/mlx-nn/src/optimizers/rmsprop.rs new file mode 100644 index 00000000..44446f09 --- /dev/null +++ b/mlx-nn/src/optimizers/rmsprop.rs @@ -0,0 +1,146 @@ +use std::rc::Rc; + +use mlx_internal_macros::GenerateBuilder; +use mlx_rs::{ + array, + ops::{sqrt, square}, + Array, +}; + +use crate::{error::RmsPropBuildError, utils::get_mut_or_insert_with}; + +use super::*; + +/// The RMSprop optimizer [1]. +/// +/// [1]: Tieleman, T. and Hinton, G. 2012. Lecture 6.5-rmsprop, coursera: Neural networks for +/// machine learning +#[derive(Debug, Clone, GenerateBuilder)] +#[generate_builder(generate_build_fn = false)] +pub struct RmsProp { + /// Learning rate + pub lr: f32, + + /// The smoothing constant. Default to [`RmsProp::DEFAULT_ALPHA`] if not specified. + #[optional] + pub alpha: f32, + + /// The epsilon added to the denominator to improve numerical stability. Default to + /// [`RmsProp::DEFAULT_EPSILON`] if not specified. + #[optional] + pub epsilon: f32, + + /// Inner state + pub state: OptimizerState, +} + +impl RmsPropBuilder { + /// Builds a new [`RmsProp`]. + /// + /// # Params + /// + /// - `lr`: The learning rate. + pub fn build(self, lr: f32) -> Result { + let alpha = self.alpha.unwrap_or(RmsProp::DEFAULT_ALPHA); + let epsilon = self.epsilon.unwrap_or(RmsProp::DEFAULT_EPSILON); + + if alpha < 0.0 { + return Err(RmsPropBuildError::NegativeAlpha); + } + + if epsilon < 0.0 { + return Err(RmsPropBuildError::NegativeEpsilon); + } + + Ok(RmsProp { + lr, + alpha, + epsilon, + state: OptimizerState::new(), + }) + } +} + +impl RmsProp { + /// Default alpha if not specified. + pub const DEFAULT_ALPHA: f32 = 0.99; + + /// Default epsilon if not specified. + pub const DEFAULT_EPSILON: f32 = 1e-8; + + /// Creates a new `RmsProp` optimizer with all optional params set to their default values. + /// + /// # Params + /// + /// - `lr`: The learning rate. 
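+    ///
+    /// The update applied by `update_single` below is `v = alpha * v + (1 - alpha) * g^2`
+    /// followed by `parameter = parameter - lr * g / (sqrt(v) + epsilon)`.
+    ///
+    /// A minimal construction sketch (illustrative only), using the generated builder to
+    /// override the optional parameters:
+    ///
+    /// ```rust, ignore
+    /// let mut optimizer = RmsProp::builder().alpha(0.99).build(1e-2)?;
+    /// ```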
+ pub fn new(lr: f32) -> Self { + Self::builder().build(lr).expect("Default values are valid") + } +} + +impl Optimizer for RmsProp { + fn update_single(&mut self, key: Rc, gradient: Array, parameter: &mut Array) { + let state = get_mut_or_insert_with(&mut self.state, &key, || array!(0.0)); + + let lr = array!(self.lr); + let alpha = array!(self.alpha); + let eps = array!(self.epsilon); + + let v = &alpha * &*state + (array!(1.0) - &alpha) * square(&gradient); + let new_param = &*parameter - &lr * &gradient / (sqrt(&v) + &eps); + + *parameter = new_param; + *state = v; + } +} + +#[cfg(test)] +mod tests { + + use mlx_rs::assert_array_eq; + use mlx_rs::ops::ones; + + use super::optim_test_util::*; + use super::*; + + // This unit test is adapted from the python unit test `test_rmsprop` in + // `tests/test_optimizer.py`. + #[test] + fn test_rmsprop() { + const LR: f32 = 1e-2; + const ALPHA: f32 = 0.99; + + let (mut model, gradients) = create_default_test_model_and_grads(); + + let mut optim = RmsProp::builder().alpha(ALPHA).build(LR).unwrap(); + optim.update(&mut model, gradients); + + let expected_first_a = ones::(&[10]).unwrap() * -0.1; + let expected_first_b = ones::(&[1]).unwrap() * -0.1; + let expected_second = ones::(&[1]).unwrap() * -0.1; + + assert_array_eq!(model.first.a.as_ref(), expected_first_a, ATOL); + assert_array_eq!(model.first.b.as_ref(), expected_first_b, ATOL); + assert_array_eq!(model.second.as_ref(), expected_second, ATOL); + + let expected_state_first_a = ones::(&[10]).unwrap() * 0.01; + let expected_state_first_b = ones::(&[1]).unwrap() * 0.01; + let expected_state_second = ones::(&[1]).unwrap() * 0.01; + + assert_array_eq!( + optim.state.get("first.a").unwrap(), + expected_state_first_a, + ATOL + ); + assert_array_eq!( + optim.state.get("first.b").unwrap(), + expected_state_first_b, + ATOL + ); + assert_array_eq!( + optim.state.get("second").unwrap(), + expected_state_second, + ATOL + ); + } +} diff --git a/mlx-nn/src/optimizers/sgd.rs b/mlx-nn/src/optimizers/sgd.rs new file mode 100644 index 00000000..0eb6000c --- /dev/null +++ b/mlx-nn/src/optimizers/sgd.rs @@ -0,0 +1,158 @@ +use std::rc::Rc; + +use mlx_internal_macros::GenerateBuilder; +use mlx_rs::{array, Array}; + +use crate::utils::get_mut_or_insert_with; + +use super::*; + +/// Stochastic gradient descent optimizer. +#[derive(Debug, Clone, GenerateBuilder)] +#[generate_builder(generate_build_fn = false)] +pub struct Sgd { + /// Learning rate + pub lr: f32, + + /// Momentum strength. Default to [`Sgd::DEFAULT_MOMENTUM`] if not specified. + #[optional] + pub momentum: f32, + + /// Weight decay (L2 penalty). Default to [`Sgd::DEFAULT_WEIGHT_DECAY`] if not specified. + #[optional] + pub weight_decay: f32, + + /// Dampening for momentum. Default to [`Sgd::DEFAULT_DAMPENING`] if not specified. + #[optional] + pub dampening: f32, + + /// Enables nesterov momentum. Default to [`Sgd::DEFAULT_NESTEROV`] if not specified. + #[optional] + pub nesterov: bool, + + /// Inner state + pub state: OptimizerState, +} + +impl SgdBuilder { + /// Builds a new [`Sgd`]. 
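+    ///
+    /// # Params
+    ///
+    /// - `lr`: The learning rate.
+    ///
+    /// A minimal usage sketch (illustrative only), mirroring the unit test below:
+    ///
+    /// ```rust, ignore
+    /// let mut optimizer = Sgd::builder().momentum(0.9).build(1e-2);
+    /// ```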
+ pub fn build(self, lr: f32) -> Sgd { + let momentum = self.momentum.unwrap_or(Sgd::DEFAULT_MOMENTUM); + let weight_decay = self.weight_decay.unwrap_or(Sgd::DEFAULT_WEIGHT_DECAY); + let dampening = self.dampening.unwrap_or(Sgd::DEFAULT_DAMPENING); + let nesterov = self.nesterov.unwrap_or(Sgd::DEFAULT_NESTEROV); + + Sgd { + lr, + momentum, + weight_decay, + dampening, + nesterov, + state: OptimizerState::new(), + } + } +} + +impl Sgd { + /// Default momentum if not specified. + pub const DEFAULT_MOMENTUM: f32 = 0.0; + + /// Default weight decay if not specified. + pub const DEFAULT_WEIGHT_DECAY: f32 = 0.0; + + /// Default dampening if not specified. + pub const DEFAULT_DAMPENING: f32 = 0.0; + + /// Default nesterov if not specified. + pub const DEFAULT_NESTEROV: bool = false; + + /// Creates a new `Sgd` optimizer. + pub fn new(lr: f32) -> Self { + Self::builder().build(lr) + } +} + +impl Optimizer for Sgd { + /// Apply SGD to a single parameter. Returns the updated parameter and the updated state. + #[inline] + fn update_single(&mut self, key: Rc, mut gradient: Array, parameter: &mut Array) { + let state = get_mut_or_insert_with(&mut self.state, &key, || array!(0.0)); + + // Apply weight decay + if self.weight_decay != 0.0 { + gradient = &gradient + array!(self.weight_decay) * &*parameter; + } + + let lr = array!(self.lr); + + // Apply momentum + if self.momentum <= 0.0 { + *parameter = &*parameter - &lr * &gradient; + return; + } + + let momentum = array!(self.momentum); + let mut v = &*state * &momentum; + if self.dampening > 0.0 { + v = &v + (&array!(1.0 - self.dampening) * &gradient); + } else { + v = &v + &gradient; + } + + match self.nesterov { + true => { + let update = gradient + (&momentum * &v); + *parameter = &*parameter - &lr * update; + *state = v; + } + false => { + let update = &v; + *parameter = &*parameter - &lr * update; + *state = v; + } + } + } +} + +#[cfg(test)] +mod tests { + use mlx_rs::assert_array_eq; + use mlx_rs::ops::ones; + + use super::optim_test_util::*; + use super::*; + + // This unit test is adapted from the python unit test `test_sgd` in + // `mlx/python/tests/test_optimizers.py` + #[test] + fn test_sgd() { + let (mut model, gradients) = create_default_test_model_and_grads(); + + let mut optim = Sgd::builder().momentum(0.9).build(1e-2); + optim.update(&mut model, gradients); + + let expected_first_a = ones::(&[10]).unwrap() * -0.01; + let expected_first_b = ones::(&[1]).unwrap() * -0.01; + let expected_second = ones::(&[1]).unwrap() * -0.01; + + assert_array_eq!(model.first.a.as_ref(), expected_first_a, ATOL); + assert_array_eq!(model.first.b.as_ref(), expected_first_b, ATOL); + assert_array_eq!(model.second.as_ref(), expected_second, ATOL); + + let expected_state_first_a = ones::(&[10]).unwrap(); + let expected_state_first_b = ones::(&[1]).unwrap(); + let expected_state_second = ones::(&[1]).unwrap(); + + assert_array_eq!( + optim.state["first.a"].as_ref(), + expected_state_first_a, + ATOL + ); + assert_array_eq!( + optim.state["first.b"].as_ref(), + expected_state_first_b, + ATOL + ); + assert_array_eq!(optim.state["second"].as_ref(), expected_state_second, ATOL); + } +} diff --git a/mlx-nn/src/utils.rs b/mlx-nn/src/utils.rs new file mode 100644 index 00000000..0f4cd651 --- /dev/null +++ b/mlx-nn/src/utils.rs @@ -0,0 +1,54 @@ +//! Utility types and functions. + +use std::rc::Rc; + +use mlx_rs::module::FlattenedModuleParam; +use mlx_rs::Array; + +/// A convenience trait to convert a single value or a pair of values into a pair of values. 
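+///
+/// For example (based on the impls below), `3.into_pair()` yields `(3, 3)` and
+/// `(2, 4).into_pair()` yields `(2, 4)`.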
+pub trait IntOrPair { + /// Converts the value into a pair of values. + fn into_pair(self) -> (i32, i32); +} + +impl IntOrPair for i32 { + fn into_pair(self) -> (i32, i32) { + (self, self) + } +} + +impl IntOrPair for (i32, i32) { + fn into_pair(self) -> (i32, i32) { + self + } +} + +/// A convenience trait to convert a single value or a triple of values into a triple of values. +pub trait IntOrTriple { + /// Converts the value into a triple of values. + fn into_triple(self) -> (i32, i32, i32); +} + +impl IntOrTriple for i32 { + fn into_triple(self) -> (i32, i32, i32) { + (self, self, self) + } +} + +impl IntOrTriple for (i32, i32, i32) { + fn into_triple(self) -> (i32, i32, i32) { + self + } +} + +pub(crate) fn get_mut_or_insert_with<'a>( + map: &'a mut FlattenedModuleParam, + key: &Rc, + f: impl FnOnce() -> Array, +) -> &'a mut Array { + if !map.contains_key(key) { + map.insert(key.clone(), f()); + } + + map.get_mut(key).unwrap() +} diff --git a/mlx-nn/src/value_and_grad.rs b/mlx-nn/src/value_and_grad.rs new file mode 100644 index 00000000..bc3fd412 --- /dev/null +++ b/mlx-nn/src/value_and_grad.rs @@ -0,0 +1,216 @@ +use mlx_rs::module::update_flattened_parameters; +use mlx_rs::{error::Exception, Array}; + +use crate::module::{FlattenedModuleParam, FlattenedModuleParamRef, Module}; + +fn trainable_params(model: &impl Module) -> FlattenedModuleParam { + model + .trainable_parameters() + .flatten() + .into_iter() + .map(|(k, v)| (k, v.clone())) + .collect() +} + +/// Helper trait for [`value_and_grad`] +pub trait IntoModuleValueAndGrad<'a, M, Args, Val, Err> +where + M: Module + 'a, + Args: Clone, +{ + /// Computes the valud and gradient of the passed function `f(model, args)` with regard to the + /// model's trainable parameters. + fn into_module_value_and_grad( + self, + ) -> impl FnMut(&mut M, Args) -> Result<(Val, FlattenedModuleParam), Exception> + 'a; +} + +impl<'a, F, M, Args> IntoModuleValueAndGrad<'a, M, Args, Vec, ()> for F +where + M: Module + 'a, + F: FnMut(&M, Args) -> Vec + 'a, + Args: Clone, +{ + fn into_module_value_and_grad( + mut self, + ) -> impl FnMut(&mut M, Args) -> Result<(Vec, FlattenedModuleParam), Exception> + 'a + { + move |model, arrays| { + let trainable_parameters = trainable_params(model); + let inner = |parameters: FlattenedModuleParamRef, arrays: Args| -> Vec { + let flattened_parameters = parameters.into_iter().map(|(k, v)| (k, v.clone())); + update_flattened_parameters(model, flattened_parameters); + + self(model, arrays) + }; + let mut vg = mlx_rs::transforms::value_and_grad_with_hashmap(inner); + + let (v, g) = vg(trainable_parameters, arrays)?; + Ok((v, g)) + } + } +} + +impl<'a, F, M, Args> IntoModuleValueAndGrad<'a, M, Args, Vec, Exception> for F +where + M: Module + 'a, + F: FnMut(&M, Args) -> Result, Exception> + 'a, + Args: Clone, +{ + fn into_module_value_and_grad( + mut self, + ) -> impl FnMut(&mut M, Args) -> Result<(Vec, FlattenedModuleParam), Exception> + 'a + { + move |model, arrays| { + let trainable_parameters = trainable_params(model); + let inner = |parameters: FlattenedModuleParamRef, + arrays: Args| + -> Result, Exception> { + let flattened_parameters = parameters.into_iter().map(|(k, v)| (k, v.clone())); + update_flattened_parameters(model, flattened_parameters); + + self(model, arrays) + }; + let mut vg = mlx_rs::transforms::value_and_grad_with_hashmap(inner); + + let (v, g) = vg(trainable_parameters, arrays)?; + Ok((v, g)) + } + } +} + +impl<'a, F, M, Args> IntoModuleValueAndGrad<'a, M, Args, Array, ()> for F +where + M: Module + 
'a, + F: FnMut(&M, Args) -> Array + 'a, + Args: Clone, +{ + fn into_module_value_and_grad( + mut self, + ) -> impl FnMut(&mut M, Args) -> Result<(Array, FlattenedModuleParam), Exception> + 'a { + move |model, arrays| { + let trainable_parameters = trainable_params(model); + let inner = |parameters: FlattenedModuleParamRef, arrays: Args| -> Vec { + let flattened_parameters = parameters.into_iter().map(|(k, v)| (k, v.clone())); + update_flattened_parameters(model, flattened_parameters); + + vec![self(model, arrays)] + }; + let mut vg = mlx_rs::transforms::value_and_grad_with_hashmap(inner); + + let (v, g) = vg(trainable_parameters, arrays)?; + let v = v.into_iter().next().expect("Expected a single value"); + Ok((v, g)) + } + } +} + +impl<'a, F, M, Args> IntoModuleValueAndGrad<'a, M, Args, Array, Exception> for F +where + M: Module + 'a, + F: FnMut(&M, Args) -> Result + 'a, + Args: Clone, +{ + fn into_module_value_and_grad( + mut self, + ) -> impl FnMut(&mut M, Args) -> Result<(Array, FlattenedModuleParam), Exception> + 'a { + move |model, arrays| { + let trainable_parameters = trainable_params(model); + let inner = |parameters: FlattenedModuleParamRef, + arrays: Args| + -> Result, Exception> { + let flattened_parameters = parameters.into_iter().map(|(k, v)| (k, v.clone())); + update_flattened_parameters(model, flattened_parameters); + + self(model, arrays).map(|v| vec![v]) + }; + let mut vg = mlx_rs::transforms::value_and_grad_with_hashmap(inner); + + let (v, g) = vg(trainable_parameters, arrays)?; + let v = v.into_iter().next().expect("Expected a single value"); + Ok((v, g)) + } + } +} + +/// Transform the passed function `f(model, args)` to a function that computes the gradients of `f` +/// with regard to the model's trainable parameters and also its value. 
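+///
+/// A minimal usage sketch (illustrative only; a `Linear` model `model` and an input
+/// array `x` are assumed to exist, as in the tests below):
+///
+/// ```rust, ignore
+/// let loss_fn = |model: &Linear, x: &Array| -> Result<Array, Exception> {
+///     model.forward(x)?.sum(None, None)
+/// };
+/// let mut loss_and_grad_fn = module_value_and_grad(loss_fn);
+/// let (loss, grads) = loss_and_grad_fn(&mut model, &x)?;
+/// ```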
+pub fn module_value_and_grad<'a, F, M, Args, Val, Err>( + f: F, +) -> impl FnMut(&mut M, Args) -> Result<(Val, FlattenedModuleParam), Exception> + 'a +where + M: Module + 'a, + F: IntoModuleValueAndGrad<'a, M, Args, Val, Err>, + Args: Clone, +{ + f.into_module_value_and_grad() +} + +#[cfg(test)] +mod tests { + use mlx_rs::module::Module; + use mlx_rs::{array, error::Exception, Array}; + + use crate::Linear; + + use super::*; + + // The unit test below is adapted from `test_compiled_optimizer` in + // `mlx/python/tests/test_optimizers.py`` + #[test] + fn test_module_value_and_grad() { + let mut model = Linear::new(2, 2).unwrap(); + let x = mlx_rs::random::uniform::<_, f32>(1.0, 2.0, &[2, 2], None).unwrap(); + + let loss = |model: &Linear, x: &Array| -> Vec { + vec![model.forward(x).unwrap().sum(None, None).unwrap()] + }; + + let mut vg = module_value_and_grad(loss); + let (v, g) = vg(&mut model, &x).unwrap(); + + assert_ne!(v[0].sum(None, None).unwrap(), array!(0.0)); + assert_ne!(g["weight"].sum(None, None).unwrap(), array!(0.0)); + assert_ne!(g["bias"].sum(None, None).unwrap(), array!(0.0)); + } + + #[test] + fn test_fallible_module_value_and_grad() { + let mut model = Linear::new(2, 2).unwrap(); + let x = mlx_rs::random::uniform::<_, f32>(1.0, 2.0, &[2, 2], None).unwrap(); + + let loss = |model: &Linear, x: &Array| -> Result, Exception> { + Ok(vec![model.forward(x)?.sum(None, None)?]) + }; + + let mut vg = module_value_and_grad(loss); + let (v, g) = vg(&mut model, &x).unwrap(); + + assert_ne!(v[0].sum(None, None).unwrap(), array!(0.0)); + assert_ne!(g["weight"].sum(None, None).unwrap(), array!(0.0)); + assert_ne!(g["bias"].sum(None, None).unwrap(), array!(0.0)); + } + + #[test] + fn test_module_value_and_grad_with_two_args() { + let mut model = Linear::new(2, 2).unwrap(); + let x = mlx_rs::random::uniform::<_, f32>(1.0, 2.0, &[2, 2], None).unwrap(); + let y = mlx_rs::ops::ones::(x.shape()).unwrap(); + + let loss = |model: &Linear, (x, y): (&Array, &Array)| -> Result, Exception> { + model + .forward(x)? + .subtract(y)? 
+ .square() + .sum(None, None) + .map(|v| vec![v]) + }; + + let mut vg = module_value_and_grad(loss); + let (v, g) = vg(&mut model, (&x, &y)).unwrap(); + + assert_ne!(v[0].sum(None, None).unwrap(), array!(0.0)); + assert_ne!(g["weight"].sum(None, None).unwrap(), array!(0.0)); + assert_ne!(g["bias"].sum(None, None).unwrap(), array!(0.0)); + } +} diff --git a/mlx-nn/tests/test_module.rs b/mlx-nn/tests/test_module.rs new file mode 100644 index 00000000..37549de5 --- /dev/null +++ b/mlx-nn/tests/test_module.rs @@ -0,0 +1,36 @@ +use mlx_macros::ModuleParameters; +use mlx_nn::Linear; +use mlx_rs::module::{Module, Param}; +use mlx_rs::{error::Exception, Array}; + +#[derive(Debug, Clone, ModuleParameters)] +struct M { + #[param] + linear: Param, +} + +impl M { + pub fn new() -> Self { + Self { + linear: Param::new(Linear::new(5, 5).unwrap()), + } + } +} + +impl Module for M { + type Error = Exception; + + fn forward(&self, x: &Array) -> Result { + self.linear.forward(x) + } + + fn training_mode(&mut self, _mode: bool) {} +} + +#[test] +fn test_nested_module() { + let m = M::new(); + let x = mlx_rs::random::uniform::<_, f32>(1.0, 2.0, &[1, 5], None).unwrap(); + let y = m.forward(&x).unwrap(); + assert_ne!(y.sum(None, None).unwrap(), mlx_rs::array!(0.0)); +} diff --git a/mlx-nn/tests/test_module_parameters.rs b/mlx-nn/tests/test_module_parameters.rs new file mode 100644 index 00000000..37e74410 --- /dev/null +++ b/mlx-nn/tests/test_module_parameters.rs @@ -0,0 +1,213 @@ +use mlx_macros::ModuleParameters; +use mlx_rs::module::{ModuleParameters, Param, Parameter}; +use mlx_rs::{array, Array}; + +#[derive(ModuleParameters)] +pub struct StructModule { + #[param] + a: Param, + + #[param] + b: Param, + + #[param] + c: Param>, +} + +#[derive(ModuleParameters)] +pub struct UnitStructModule; + +#[test] +fn test_module_parameters() { + let m = StructModule { + a: Param::new(array!(1.0)), + b: Param::new(array!(2.0)), + c: Param::new(None), + }; + + let flattened = m.parameters().flatten(); + assert_eq!(flattened.len(), 2); + assert_eq!(flattened["a"], &array!(1.0)); + assert_eq!(flattened["b"], &array!(2.0)); + + let m = StructModule { + a: Param::new(array!(1.0)), + b: Param::new(array!(2.0)), + c: Param::new(Some(array!(3.0))), + }; + + let flattened = m.parameters().flatten(); + assert_eq!(flattened.len(), 3); + assert_eq!(flattened["a"], &array!(1.0)); + assert_eq!(flattened["b"], &array!(2.0)); + assert_eq!(flattened["c"], &array!(3.0)); +} + +#[test] +fn test_module_parameters_mut() { + let mut m = StructModule { + a: Param::new(array!(1.0)), + b: Param::new(array!(2.0)), + c: Param::new(None), + }; + + let flattened = m.parameters_mut().flatten(); + assert_eq!(flattened.len(), 2); + assert_eq!(flattened["a"], &mut array!(1.0)); + assert_eq!(flattened["b"], &mut array!(2.0)); + + let mut m = StructModule { + a: Param::new(array!(1.0)), + b: Param::new(array!(2.0)), + c: Param::new(Some(array!(3.0))), + }; + + let flattened = m.parameters_mut().flatten(); + assert_eq!(flattened.len(), 3); + assert_eq!(flattened["a"], &mut array!(1.0)); + assert_eq!(flattened["b"], &mut array!(2.0)); + assert_eq!(flattened["c"], &mut array!(3.0)); +} + +#[test] +fn test_module_trainable_parameters_all_trainable() { + let m = StructModule { + a: Param::new(array!(1.0)), + b: Param::new(array!(2.0)), + c: Param::new(None), + }; + + let flattened = m.trainable_parameters().flatten(); + assert_eq!(flattened.len(), 2); + assert_eq!(flattened["a"], &array!(1.0)); + assert_eq!(flattened["b"], &array!(2.0)); + + let m = 
StructModule { + a: Param::new(array!(1.0)), + b: Param::new(array!(2.0)), + c: Param::new(Some(array!(3.0))), + }; + + let flattened = m.trainable_parameters().flatten(); + assert_eq!(flattened.len(), 3); + assert_eq!(flattened["a"], &array!(1.0)); + assert_eq!(flattened["b"], &array!(2.0)); + assert_eq!(flattened["c"], &array!(3.0)); +} + +#[test] +fn test_module_trainable_parameters_partial_freeze() { + let mut m = StructModule { + a: Param::new(array!(1.0)), + b: Param::new(array!(2.0)), + c: Param::new(None), + }; + + // Freeze one parameter that is not optional + m.a.freeze(); + + let flattened = m.trainable_parameters().flatten(); + assert_eq!(flattened.len(), 1); + assert_eq!(flattened["b"], &array!(2.0)); + + // Now freeze the optional parameter + m.c.freeze(); + + let flattened = m.trainable_parameters().flatten(); + assert_eq!(flattened.len(), 1); + assert_eq!(flattened["b"], &array!(2.0)); + + // Unfreeze the non-optional parameter + m.a.unfreeze(); + + let flattened = m.trainable_parameters().flatten(); + assert_eq!(flattened.len(), 2); + assert_eq!(flattened["a"], &array!(1.0)); + assert_eq!(flattened["b"], &array!(2.0)); + + // Set the optional parameter to Some but still frozen + m.c.value = Some(array!(3.0)); + + let flattened = m.trainable_parameters().flatten(); + assert_eq!(flattened.len(), 2); + assert_eq!(flattened["a"], &array!(1.0)); + assert_eq!(flattened["b"], &array!(2.0)); + + // Unfreeze the optional parameter + m.c.unfreeze(); + + let flattened = m.trainable_parameters().flatten(); + assert_eq!(flattened.len(), 3); + assert_eq!(flattened["a"], &array!(1.0)); + assert_eq!(flattened["b"], &array!(2.0)); + assert_eq!(flattened["c"], &array!(3.0)); +} + +#[test] +fn test_unit_struct_module_parameters() { + let m = UnitStructModule; + + let flattened = m.parameters().flatten(); + assert_eq!(flattened.len(), 0); +} + +#[test] +fn test_unit_struct_module_parameters_mut() { + let mut m = UnitStructModule; + + let flattened = m.parameters_mut().flatten(); + assert_eq!(flattened.len(), 0); +} + +#[test] +fn test_unit_struct_module_trainable_parameters() { + let m = UnitStructModule; + + let flattened = m.trainable_parameters().flatten(); + assert_eq!(flattened.len(), 0); +} + +#[derive(ModuleParameters)] +struct StructModuleWithNested { + #[param] + a: Param, + + #[param] + nested: Param, +} + +#[test] +fn test_nested_module_parameters() { + let m = StructModuleWithNested { + a: Param::new(array!(1.0)), + nested: Param::new(StructModule { + a: Param::new(array!(2.0)), + b: Param::new(array!(3.0)), + c: Param::new(None), + }), + }; + + let flattened = m.parameters().flatten(); + assert_eq!(flattened.len(), 3); + assert_eq!(flattened["a"], &array!(1.0)); + assert_eq!(flattened["nested.a"], &array!(2.0)); + assert_eq!(flattened["nested.b"], &array!(3.0)); +} + +#[test] +fn test_nested_module_parameters_mut() { + let mut m = StructModuleWithNested { + a: Param::new(array!(1.0)), + nested: Param::new(StructModule { + a: Param::new(array!(2.0)), + b: Param::new(array!(3.0)), + c: Param::new(None), + }), + }; + + let flattened = m.parameters_mut().flatten(); + assert_eq!(flattened.len(), 3); + assert_eq!(flattened["a"], &mut array!(1.0)); + assert_eq!(flattened["nested.a"], &mut array!(2.0)); + assert_eq!(flattened["nested.b"], &mut array!(3.0)); +} diff --git a/mlx-rs/Cargo.toml b/mlx-rs/Cargo.toml index c7bcd87c..073e894c 100644 --- a/mlx-rs/Cargo.toml +++ b/mlx-rs/Cargo.toml @@ -22,22 +22,22 @@ targets = [ ] [dependencies] -mlx-sys = { workspace = true } -mlx-macros = { 
workspace = true } +mlx-sys.workspace = true +mlx-internal-macros.workspace = true half = "2" -mach-sys = "0.5.4" +mach-sys = "0.5" num-complex = "0.4" -num_enum = "0.7.2" -num-traits = "0.2.18" -paste = "1.0.14" +num_enum = "0.7" +num-traits = "0.2" +paste = "1" smallvec = "1" strum = { version = "0.26", features = ["derive"] } -thiserror = "1.0.58" +thiserror.workspace = true libc = "0.2" [dev-dependencies] pretty_assertions = "1.4.0" -float_eq = "1" +float_eq.workspace = true [features] default = ["accelerate", "metal"] diff --git a/mlx-rs/src/array/mod.rs b/mlx-rs/src/array/mod.rs index b519edd9..40043433 100644 --- a/mlx-rs/src/array/mod.rs +++ b/mlx-rs/src/array/mod.rs @@ -7,7 +7,7 @@ use crate::{ sealed::Sealed, Stream, StreamOrDevice, }; -use mlx_macros::default_device; +use mlx_internal_macros::default_device; use mlx_sys::mlx_array; use num_complex::Complex; use std::ffi::c_void; diff --git a/mlx-rs/src/dtype.rs b/mlx-rs/src/dtype.rs index 142b399b..6e7acf2e 100644 --- a/mlx-rs/src/dtype.rs +++ b/mlx-rs/src/dtype.rs @@ -1,4 +1,4 @@ -use mlx_macros::GenerateDtypeTestCases; +use mlx_internal_macros::GenerateDtypeTestCases; use strum::EnumIter; /// Array element type diff --git a/mlx-rs/src/error.rs b/mlx-rs/src/error.rs index f5dc9beb..83727de2 100644 --- a/mlx-rs/src/error.rs +++ b/mlx-rs/src/error.rs @@ -40,6 +40,19 @@ impl Exception { pub fn what(&self) -> &str { &self.what } + + /// Creates a new exception with the given message. + pub fn custom(what: impl Into) -> Self { + Self { what: what.into() } + } +} + +impl From<&str> for Exception { + fn from(what: &str) -> Self { + Self { + what: what.to_string(), + } + } } thread_local! { diff --git a/mlx-rs/src/fft/fftn.rs b/mlx-rs/src/fft/fftn.rs index 72432c0c..1123ca36 100644 --- a/mlx-rs/src/fft/fftn.rs +++ b/mlx-rs/src/fft/fftn.rs @@ -1,4 +1,4 @@ -use mlx_macros::default_device; +use mlx_internal_macros::default_device; use crate::{array::Array, error::Exception, stream::StreamOrDevice, utils::IntoOption, Stream}; diff --git a/mlx-rs/src/fft/rfftn.rs b/mlx-rs/src/fft/rfftn.rs index 3fb058a7..6215fd37 100644 --- a/mlx-rs/src/fft/rfftn.rs +++ b/mlx-rs/src/fft/rfftn.rs @@ -1,4 +1,4 @@ -use mlx_macros::default_device; +use mlx_internal_macros::default_device; use crate::{error::Exception, utils::IntoOption, Array, Stream, StreamOrDevice}; diff --git a/mlx-rs/src/lib.rs b/mlx-rs/src/lib.rs index a1fd7a5c..f5aed84b 100644 --- a/mlx-rs/src/lib.rs +++ b/mlx-rs/src/lib.rs @@ -10,6 +10,8 @@ mod dtype; pub mod error; pub mod fft; pub mod linalg; +pub mod module; +pub mod nested; pub mod ops; pub mod random; mod stream; diff --git a/mlx-rs/src/linalg.rs b/mlx-rs/src/linalg.rs index 8a4caaa7..21f1abf4 100644 --- a/mlx-rs/src/linalg.rs +++ b/mlx-rs/src/linalg.rs @@ -1,7 +1,7 @@ use crate::error::Exception; use crate::utils::{IntoOption, MlxString, VectorArray}; use crate::{Array, Stream, StreamOrDevice}; -use mlx_macros::default_device; +use mlx_internal_macros::default_device; use smallvec::SmallVec; use std::f64; diff --git a/mlx-rs/src/macros/assert.rs b/mlx-rs/src/macros/assert.rs index 4b0b6fac..25287939 100644 --- a/mlx-rs/src/macros/assert.rs +++ b/mlx-rs/src/macros/assert.rs @@ -1,5 +1,8 @@ #[macro_export] macro_rules! 
assert_array_eq { + ($value:expr, $expected:expr) => { + assert_array_eq!($value, $expected, None); + }; ($value:expr, $expected:expr, $atol:expr) => { assert_eq!($value.shape(), $expected.shape(), "Shapes are not equal"); let assert = $value.all_close(&$expected, $atol, $atol, None); diff --git a/mlx-rs/src/module/mod.rs b/mlx-rs/src/module/mod.rs new file mode 100644 index 00000000..a1bd04f3 --- /dev/null +++ b/mlx-rs/src/module/mod.rs @@ -0,0 +1,12 @@ +/// This crate defines the traits for neural network modules and parameters. +/// +/// This is to separate the trait definitions from the implementations, which are in the `mlx-nn` +/// crate. This also allows using the `mlx_macros::ModuleParameters` derive macro in crates other +/// than `mlx-nn`. + +#[allow(clippy::module_inception)] +mod module; +mod param; + +pub use module::*; +pub use param::*; diff --git a/mlx-rs/src/module/module.rs b/mlx-rs/src/module/module.rs new file mode 100644 index 00000000..c9bbe355 --- /dev/null +++ b/mlx-rs/src/module/module.rs @@ -0,0 +1,124 @@ +use std::{collections::HashMap, rc::Rc}; + +use crate::{nested::NestedHashMap, Array}; + +/// Type alias for owned module parameters. +pub type ModuleParam = NestedHashMap<&'static str, Array>; + +/// Type alias for borrowed module parameters. +pub type ModuleParamRef<'a> = NestedHashMap<&'static str, &'a Array>; + +/// Type alias for mutably borrowed module parameters. +pub type ModuleParamMut<'a> = NestedHashMap<&'static str, &'a mut Array>; + +/// Type alias for flattened module parameters. +pub type FlattenedModuleParam = HashMap, Array>; + +/// Type alias for borrowed flattened module parameters. +pub type FlattenedModuleParamRef<'a> = HashMap, &'a Array>; + +/// Type alias for mutably borrowed flattened module parameters. +pub type FlattenedModuleParamMut<'a> = HashMap, &'a mut Array>; + +/// Trait for a neural network module. +pub trait Module: ModuleParameters { + /// Error type for the module. + type Error: std::error::Error; + + /// Forward pass of the module. + fn forward(&self, x: &Array) -> Result; + + /// Set whether the module is in training mode. + /// + /// Training mode only applies to certain layers. For example, dropout layers applies a random + /// mask in training mode, but is the identity in evaluation mode. Implementations of nested + /// modules should propagate the training mode to all child modules. + fn training_mode(&mut self, mode: bool); +} + +/// Trait for accessing and updating module parameters. +pub trait ModuleParameters { + /// Get references to the module parameters. + fn parameters(&self) -> ModuleParamRef<'_>; + + /// Get mutable references to the module parameters. + fn parameters_mut(&mut self) -> ModuleParamMut<'_>; + + /// Get references to the trainable parameters. A parameter is trainable if it is NOT frozen. + fn trainable_parameters(&self) -> ModuleParamRef<'_>; + + /// Update the module parameters. + fn update(&mut self, parameters: ModuleParam) { + let flattened_parameters = parameters.flatten(); + update_flattened_parameters(self, flattened_parameters) + } + + /// Update the module parameters from a flattened representation. + fn update_flattened(&mut self, flattened_parameters: FlattenedModuleParam) { + update_flattened_parameters(self, flattened_parameters) + } +} + +/// Update the module parameters from an iterator of flattened parameters. 
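+///
+/// Keys are expected to be `.`-delimited parameter paths (for example `"nested.a"`),
+/// as produced by calling `flatten()` on [`ModuleParameters::parameters`]. Entries
+/// whose key does not match an existing parameter are silently ignored.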
+pub fn update_flattened_parameters<M, I>(module: &mut M, flattened_parameters: I) +where + M: ModuleParameters + ?Sized, + I: IntoIterator<Item = (Rc<str>, Array)>, +{ + let mut flattened_self_parameters = module.parameters_mut().flatten(); + + for (key, value) in flattened_parameters { + if let Some(self_value) = flattened_self_parameters.get_mut(&key) { + **self_value = value.to_owned(); + } + } +} + +impl<T> ModuleParameters for Box<T> +where + T: ModuleParameters + ?Sized, +{ + fn parameters(&self) -> ModuleParamRef<'_> { + self.as_ref().parameters() + } + + fn parameters_mut(&mut self) -> ModuleParamMut<'_> { + self.as_mut().parameters_mut() + } + + fn trainable_parameters(&self) -> ModuleParamRef<'_> { + self.as_ref().trainable_parameters() + } +} + +impl<T> ModuleParameters for Vec<T> +where + T: ModuleParameters, +{ + fn parameters(&self) -> ModuleParamRef<'_> { + let mut parameters = NestedHashMap::new(); + self.iter().for_each(|module| { + let module_parameters = module.parameters(); + parameters.entries.extend(module_parameters.entries); + }); + parameters + } + + fn parameters_mut(&mut self) -> ModuleParamMut<'_> { + let mut parameters = NestedHashMap::new(); + self.iter_mut().for_each(|module| { + let module_parameters = module.parameters_mut(); + parameters.entries.extend(module_parameters.entries); + }); + parameters + } + + fn trainable_parameters(&self) -> ModuleParamRef<'_> { + let mut parameters = NestedHashMap::new(); + self.iter().for_each(|module| { + let module_parameters = module.trainable_parameters(); + parameters.entries.extend(module_parameters.entries); + }); + parameters + } +} diff --git a/mlx-rs/src/module/param.rs b/mlx-rs/src/module/param.rs new file mode 100644 index 00000000..c9c818a0 --- /dev/null +++ b/mlx-rs/src/module/param.rs @@ -0,0 +1,179 @@ +use std::{ + collections::HashMap, + ops::{Deref, DerefMut}, +}; + +use crate::{nested::NestedValue, Array}; + +use super::ModuleParameters; + +/// Trait for a module parameter. +pub trait Parameter { + /// Freeze the parameter. + fn freeze(&mut self); + + /// Unfreeze the parameter. + fn unfreeze(&mut self); + + /// Check if the parameter is frozen. + fn is_frozen(&self) -> bool; + + /// Get the parameter as a nested value. + fn as_nested_value<'a>(&self) -> NestedValue<&'a str, &Array>; + + /// Get the parameter as a mutable nested value. + fn as_nested_value_mut<'a>(&mut self) -> NestedValue<&'a str, &mut Array>; + + /// Get the parameter as a nested value if it is trainable. + fn as_trainable_nested_value<'a>(&self) -> Option<NestedValue<&'a str, &Array>>; +} + +/// A simple wrapper for a module parameter. +#[derive(Debug, Clone)] +pub struct Param<T> { + /// The value of the parameter. + pub value: T, + + /// Whether the parameter is frozen.
+ pub is_frozen: bool, +} + +impl Param { + /// Create a new `Param` + pub fn new(value: T) -> Self { + Self { + value, + is_frozen: false, + } + } +} + +impl From for Param { + fn from(inner: T) -> Self { + Self::new(inner) + } +} + +impl Deref for Param { + type Target = T; + + fn deref(&self) -> &Self::Target { + &self.value + } +} + +impl DerefMut for Param { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.value + } +} + +impl AsRef for Param { + fn as_ref(&self) -> &T { + &self.value + } +} + +impl AsMut for Param { + fn as_mut(&mut self) -> &mut T { + &mut self.value + } +} + +impl Parameter for Param { + fn freeze(&mut self) { + self.is_frozen = true; + } + + fn unfreeze(&mut self) { + self.is_frozen = false; + } + + fn is_frozen(&self) -> bool { + self.is_frozen + } + + fn as_nested_value<'a>(&self) -> NestedValue<&'a str, &Array> { + NestedValue::Value(&self.value) + } + + fn as_nested_value_mut<'a>(&mut self) -> NestedValue<&'a str, &mut Array> { + NestedValue::Value(&mut self.value) + } + + fn as_trainable_nested_value<'a>(&self) -> Option> { + match self.is_frozen { + true => None, + false => Some(NestedValue::Value(&self.value)), + } + } +} + +impl Parameter for Param> { + fn freeze(&mut self) { + self.is_frozen = true; + } + + fn unfreeze(&mut self) { + self.is_frozen = false; + } + + fn is_frozen(&self) -> bool { + self.is_frozen + } + + fn as_nested_value<'a>(&self) -> NestedValue<&'a str, &Array> { + match &self.value { + Some(array) => NestedValue::Value(array), + // An empty map entry will be ignored during flattening + None => NestedValue::Map(HashMap::with_capacity(0)), + } + } + + fn as_nested_value_mut<'a>(&mut self) -> NestedValue<&'a str, &mut Array> { + match &mut self.value { + Some(array) => NestedValue::Value(array), + // An empty map entry will be ignored during flattening + None => NestedValue::Map(HashMap::with_capacity(0)), + } + } + + fn as_trainable_nested_value<'a>(&self) -> Option> { + match self.is_frozen { + true => None, + false => self.value.as_ref().map(NestedValue::Value), + } + } +} + +impl Parameter for Param +where + T: ModuleParameters, +{ + fn freeze(&mut self) { + self.is_frozen = true; + } + + fn unfreeze(&mut self) { + self.is_frozen = false; + } + + fn is_frozen(&self) -> bool { + self.is_frozen + } + + fn as_nested_value<'a>(&self) -> NestedValue<&'a str, &Array> { + self.parameters().into() + } + + fn as_nested_value_mut<'a>(&mut self) -> NestedValue<&'a str, &mut Array> { + self.parameters_mut().into() + } + + fn as_trainable_nested_value<'a>(&self) -> Option> { + match self.is_frozen { + true => None, + false => Some(self.trainable_parameters().into()), + } + } +} diff --git a/mlx-rs/src/nested.rs b/mlx-rs/src/nested.rs new file mode 100644 index 00000000..c47b509c --- /dev/null +++ b/mlx-rs/src/nested.rs @@ -0,0 +1,182 @@ +use std::{collections::HashMap, fmt::Display, rc::Rc}; + +const DELIMITER: char = '.'; + +#[derive(Debug, Clone)] +pub enum NestedValue { + Value(T), + Map(HashMap>), +} + +impl NestedValue { + pub fn flatten(self, prefix: &str) -> HashMap, V> + where + K: Display, + { + match self { + NestedValue::Value(array) => { + let mut map = HashMap::new(); + map.insert(prefix.into(), array); + map + } + NestedValue::Map(entries) => entries + .into_iter() + .flat_map(|(key, value)| value.flatten(&format!("{}{}{}", prefix, DELIMITER, key))) + .collect(), + } + } +} + +#[derive(Debug, Clone)] +pub struct NestedHashMap { + pub entries: HashMap>, +} + +impl From> for NestedValue { + fn from(map: NestedHashMap) -> 
Self { + NestedValue::Map(map.entries) + } +} + +impl Default for NestedHashMap { + fn default() -> Self { + Self::new() + } +} + +impl NestedHashMap { + pub fn new() -> Self { + Self { + entries: HashMap::new(), + } + } + + pub fn insert(&mut self, key: K, value: NestedValue) + where + K: Eq + std::hash::Hash, + { + self.entries.insert(key, value); + } + + pub fn flatten(self) -> HashMap, V> + where + K: AsRef + Display, + { + self.entries + .into_iter() + .flat_map(|(key, value)| value.flatten(key.as_ref())) + .collect() + } +} + +#[cfg(test)] +mod tests { + use crate::array; + + use super::*; + + #[test] + fn test_flatten_nested_hash_map_of_owned_arrays() { + let first_entry = NestedValue::Value(array!([1, 2, 3])); + let second_entry = NestedValue::Map({ + let mut map = HashMap::new(); + map.insert("a", NestedValue::Value(array!([4, 5, 6]))); + map.insert("b", NestedValue::Value(array!([7, 8, 9]))); + map + }); + + let map = NestedHashMap { + entries: { + let mut map = HashMap::new(); + map.insert("first", first_entry); + map.insert("second", second_entry); + map + }, + }; + + let flattened = map.flatten(); + + assert_eq!(flattened.len(), 3); + assert_eq!(flattened["first"], array!([1, 2, 3])); + assert_eq!(flattened["second.a"], array!([4, 5, 6])); + assert_eq!(flattened["second.b"], array!([7, 8, 9])); + } + + #[test] + fn test_flatten_nested_hash_map_of_borrowed_arrays() { + let first_entry_content = array!([1, 2, 3]); + let first_entry = NestedValue::Value(&first_entry_content); + + let second_entry_content_a = array!([4, 5, 6]); + let second_entry_content_b = array!([7, 8, 9]); + let second_entry = NestedValue::Map({ + let mut map = HashMap::new(); + map.insert("a", NestedValue::Value(&second_entry_content_a)); + map.insert("b", NestedValue::Value(&second_entry_content_b)); + map + }); + + let map = NestedHashMap { + entries: { + let mut map = HashMap::new(); + map.insert("first", first_entry); + map.insert("second", second_entry); + map + }, + }; + + let flattened = map.flatten(); + + assert_eq!(flattened.len(), 3); + assert_eq!(flattened["first"], &first_entry_content); + assert_eq!(flattened["second.a"], &second_entry_content_a); + assert_eq!(flattened["second.b"], &second_entry_content_b); + } + + #[test] + fn test_flatten_nested_hash_map_of_mut_borrowed_arrays() { + let mut first_entry_content = array!([1, 2, 3]); + let first_entry = NestedValue::Value(&mut first_entry_content); + + let mut second_entry_content_a = array!([4, 5, 6]); + let mut second_entry_content_b = array!([7, 8, 9]); + let second_entry = NestedValue::Map({ + let mut map = HashMap::new(); + map.insert("a", NestedValue::Value(&mut second_entry_content_a)); + map.insert("b", NestedValue::Value(&mut second_entry_content_b)); + map + }); + + let map = NestedHashMap { + entries: { + let mut map = HashMap::new(); + map.insert("first", first_entry); + map.insert("second", second_entry); + map + }, + }; + + let flattened = map.flatten(); + + assert_eq!(flattened.len(), 3); + assert_eq!(flattened["first"], &mut array!([1, 2, 3])); + assert_eq!(flattened["second.a"], &mut array!([4, 5, 6])); + assert_eq!(flattened["second.b"], &mut array!([7, 8, 9])); + } + + #[test] + fn test_flatten_empty_nested_hash_map() { + let map = NestedHashMap::<&str, i32>::new(); + let flattened = map.flatten(); + + assert!(flattened.is_empty()); + + // Insert another empty map + let mut map = NestedHashMap::<&str, i32>::new(); + let empty_map = NestedValue::Map(HashMap::new()); + map.insert("empty", empty_map); + + let flattened = 
map.flatten(); + assert!(flattened.is_empty()); + } +} diff --git a/mlx-rs/src/ops/arithmetic.rs b/mlx-rs/src/ops/arithmetic.rs index 109c7fe3..2afa9878 100644 --- a/mlx-rs/src/ops/arithmetic.rs +++ b/mlx-rs/src/ops/arithmetic.rs @@ -5,7 +5,7 @@ use crate::stream::StreamOrDevice; use crate::utils::{IntoOption, ScalarOrArray, VectorArray}; use crate::Stream; -use mlx_macros::default_device; +use mlx_internal_macros::default_device; use smallvec::SmallVec; impl Array { @@ -2203,7 +2203,7 @@ mod tests { assert_eq!(z.item::(), 3.0); // Chain a few adds: - let mut out = x.clone(); + let mut out = x.deep_clone(); for _ in 0..10 { out = add(&out, &x).unwrap(); } @@ -2501,8 +2501,8 @@ mod tests { // Check that we can eval on both outputs let x = array![1.0]; let y = array![2.0]; - let (mut quo, mut rem) = divmod(&x, &y).unwrap(); - eval([&mut quo, &mut rem]).unwrap(); + let (quo, rem) = divmod(&x, &y).unwrap(); + eval([&quo, &rem]).unwrap(); assert_eq!(quo.item::(), 0.0); assert_eq!(rem.item::(), 1.0); @@ -2518,16 +2518,16 @@ mod tests { let (quo, _) = divmod(&x, &y).unwrap(); vec![quo] }; - eval(out_holder.iter_mut()).unwrap(); + eval(out_holder.iter()).unwrap(); assert_eq!(out_holder[0].item::(), 0.0); // Check that we can still eval when the other output goes out of scope out_holder.clear(); - let mut out_holder = { + let out_holder = { let (_, rem) = divmod(&x, &y).unwrap(); vec![rem] }; - eval(out_holder.iter_mut()).unwrap(); + eval(out_holder.iter()).unwrap(); assert_eq!(out_holder[0].item::(), 1.0); } } diff --git a/mlx-rs/src/ops/conversion.rs b/mlx-rs/src/ops/conversion.rs index 814b0d4d..1742fd18 100644 --- a/mlx-rs/src/ops/conversion.rs +++ b/mlx-rs/src/ops/conversion.rs @@ -1,4 +1,4 @@ -use mlx_macros::default_device; +use mlx_internal_macros::default_device; use crate::{Array, ArrayElement, Dtype, Stream, StreamOrDevice}; diff --git a/mlx-rs/src/ops/convolution.rs b/mlx-rs/src/ops/convolution.rs index fc8c84a6..1fc109e0 100644 --- a/mlx-rs/src/ops/convolution.rs +++ b/mlx-rs/src/ops/convolution.rs @@ -1,7 +1,7 @@ use crate::error::Exception; use crate::utils::IntoOption; use crate::{Array, Stream, StreamOrDevice}; -use mlx_macros::default_device; +use mlx_internal_macros::default_device; /// General convolution over an input with several channels returning an error if the inputs are invalid. /// @@ -72,10 +72,10 @@ pub fn conv_general_device<'a>( /// /// - array: input array of shape `&[N, H, C_in]` /// - weight: weight array of shape `&[C_out, H, C_in]` -/// - stride: kernel stride -/// - padding: input padding -/// - dilation: kernel dilation -/// - groups: input feature groups +/// - stride: kernel stride. Default to 1 if not specified. +/// - padding: input padding. Default to 0 if not specified. +/// - dilation: kernel dilation. Default to 1 if not specified. +/// - groups: input feature groups. Default to 1 if not specified. #[default_device] pub fn conv1d_device( array: &Array, @@ -115,10 +115,10 @@ pub fn conv1d_device( /// /// - array: input array of shape `[N, H, W, C_in]` /// - weight: weight array of shape `[C_out, H, W, C_in]` -/// - stride: kernel stride -/// - padding: input padding -/// - dilation: kernel dilation -/// - groups: input feature groups +/// - stride: kernel stride. Default to (1, 1) if not specified. +/// - padding: input padding. Default to (0, 0) if not specified. +/// - dilation: kernel dilation. Default to (1, 1) if not specified. +/// - groups: input feature groups. Default to 1 if not specified. 
#[default_device] pub fn conv2d_device( array: &Array, @@ -152,6 +152,45 @@ pub fn conv2d_device( } } +/// 3D convolution over an input with several channels. +/// +/// Only the default `groups=1` is currently supported. +#[default_device] +pub fn conv3d_device( + array: &Array, + weight: &Array, + stride: impl Into>, + padding: impl Into>, + dilation: impl Into>, + groups: impl Into>, + stream: impl AsRef, +) -> Result { + let stride = stride.into().unwrap_or((1, 1, 1)); + let padding = padding.into().unwrap_or((0, 0, 0)); + let dilation = dilation.into().unwrap_or((1, 1, 1)); + + unsafe { + let c_array = try_catch_c_ptr_expr! { + mlx_sys::mlx_conv3d( + array.as_ptr(), + weight.as_ptr(), + stride.0, + stride.1, + stride.2, + padding.0, + padding.1, + padding.2, + dilation.0, + dilation.1, + dilation.2, + groups.into().unwrap_or(1), + stream.as_ref().as_ptr(), + ) + }; + Ok(Array::from_ptr(c_array)) + } +} + // TODO: Implement convolve once we have `reshape` and `slice` #[cfg(test)] @@ -212,6 +251,41 @@ mod tests { assert_eq!(result.as_slice::(), &[expected_output]); } + #[test] + fn test_conv3d() { + // Define a 2x2x2 input with one channel + let input_data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]; + let input_shape = [1, 2, 2, 2, 1]; // [N, D, H, W, C] + let input_array = Array::from_slice(&input_data, &input_shape); + + // Define a 2x2x2 kernel with one input channel and one output channel + let weight_data = [1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0]; + let weight_shape = [1, 2, 2, 2, 1]; // [C_out, D_k, H_k, W_k, C_in] + let weight_array = Array::from_slice(&weight_data, &weight_shape); + + // Perform the convolution with no padding and stride of 1 + let result = conv3d( + &input_array, + &weight_array, + Some((1, 1, 1)), // stride + Some((0, 0, 0)), // padding + Some((1, 1, 1)), // dilation + Some(1), // groups + ) + .unwrap(); + + // Expected result is the convolution of a 2x2x2 filter over a 2x2x2 input with valid padding, resulting in a single output value + let expected_output = 1.0 * 1.0 + + 2.0 * 0.0 + + 3.0 * 0.0 + + 4.0 * 1.0 + + 5.0 * 0.0 + + 6.0 * 1.0 + + 7.0 * 1.0 + + 8.0 * 0.0; // = 1*1 + 4*1 + 6*1 + 7*1 = 18 + assert_eq!(result.as_slice::(), &[expected_output]); + } + #[test] fn test_conv_wrong_dimensions() { let input_data = [1.0, 2.0, 3.0, 4.0]; diff --git a/mlx-rs/src/ops/cumulative.rs b/mlx-rs/src/ops/cumulative.rs index c56effd8..d069eaea 100644 --- a/mlx-rs/src/ops/cumulative.rs +++ b/mlx-rs/src/ops/cumulative.rs @@ -1,6 +1,6 @@ use crate::error::Exception; use crate::{Array, Stream, StreamOrDevice}; -use mlx_macros::default_device; +use mlx_internal_macros::default_device; impl Array { /// Return the cumulative maximum of the elements along the given axis returning an error if the inputs are invalid. 
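// Reviewer sketch (not part of the diff): the "single output value" expectation in
// `test_conv3d` above follows from the standard convolution output-size formula
// out = (in + 2*pad - dilation*(kernel - 1) - 1) / stride + 1. The helper below is
// plain Rust with no mlx-rs dependency and only sanity-checks that arithmetic.
fn conv_out_dim(input: i64, kernel: i64, stride: i64, padding: i64, dilation: i64) -> i64 {
    (input + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1
}

fn main() {
    // 2x2x2 input, 2x2x2 kernel, stride 1, no padding, dilation 1:
    // every spatial dimension collapses to 1, so the output holds one value per channel.
    for dim in [2, 2, 2] {
        assert_eq!(conv_out_dim(dim, 2, 1, 0, 1), 1);
    }
    println!("expected conv3d output shape: [1, 1, 1, 1, 1]"); // [N, D, H, W, C_out]
}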
diff --git a/mlx-rs/src/ops/factory.rs b/mlx-rs/src/ops/factory.rs index 9b7aaea5..e6c74148 100644 --- a/mlx-rs/src/ops/factory.rs +++ b/mlx-rs/src/ops/factory.rs @@ -2,7 +2,7 @@ use crate::array::ArrayElement; use crate::error::Exception; use crate::Stream; use crate::{array::Array, stream::StreamOrDevice}; -use mlx_macros::default_device; +use mlx_internal_macros::default_device; use num_traits::NumCast; impl Array { diff --git a/mlx-rs/src/ops/indexing/index_impl.rs b/mlx-rs/src/ops/indexing/index_impl.rs index 4e35b2a3..c4a7e3d2 100644 --- a/mlx-rs/src/ops/indexing/index_impl.rs +++ b/mlx-rs/src/ops/indexing/index_impl.rs @@ -764,7 +764,7 @@ fn get_item( use ArrayIndexOp::*; match index.index_op() { - Ellipsis => Ok(src.clone()), + Ellipsis => Ok(src.deep_clone()), TakeIndex { index } => get_item_index(src, index, 0, stream), TakeArray { indices } => get_item_array(src, &indices, 0, stream), Slice(range) => get_item_slice(src, range, stream), diff --git a/mlx-rs/src/ops/indexing/mod.rs b/mlx-rs/src/ops/indexing/mod.rs index d1a4b5bf..2c79334d 100644 --- a/mlx-rs/src/ops/indexing/mod.rs +++ b/mlx-rs/src/ops/indexing/mod.rs @@ -101,7 +101,7 @@ use std::{ ops::{Bound, RangeBounds}, }; -use mlx_macros::default_device; +use mlx_internal_macros::default_device; use crate::{error::Exception, Array, Stream, StreamOrDevice}; diff --git a/mlx-rs/src/ops/logical.rs b/mlx-rs/src/ops/logical.rs index fb920400..15ce9924 100644 --- a/mlx-rs/src/ops/logical.rs +++ b/mlx-rs/src/ops/logical.rs @@ -3,7 +3,7 @@ use crate::error::Exception; use crate::stream::StreamOrDevice; use crate::utils::{axes_or_default_to_all, IntoOption}; use crate::Stream; -use mlx_macros::default_device; +use mlx_internal_macros::default_device; impl Array { /// Element-wise equality returning an error if the arrays are not broadcastable. diff --git a/mlx-rs/src/ops/other.rs b/mlx-rs/src/ops/other.rs index 7745ac78..122ac111 100644 --- a/mlx-rs/src/ops/other.rs +++ b/mlx-rs/src/ops/other.rs @@ -1,4 +1,4 @@ -use mlx_macros::default_device; +use mlx_internal_macros::default_device; use crate::{error::Exception, Array, Stream, StreamOrDevice}; diff --git a/mlx-rs/src/ops/quantization.rs b/mlx-rs/src/ops/quantization.rs index d5d7a74f..36cb102c 100644 --- a/mlx-rs/src/ops/quantization.rs +++ b/mlx-rs/src/ops/quantization.rs @@ -2,7 +2,7 @@ // quantizedMatmul(_:_:scales:biases:transpose:groupSize:bits:stream:) // dequantized(_:scales:biases:groupSize:bits:stream:) -use mlx_macros::default_device; +use mlx_internal_macros::default_device; use smallvec::SmallVec; use crate::{error::Exception, utils::VectorArray, Array, Stream, StreamOrDevice}; diff --git a/mlx-rs/src/ops/reduction.rs b/mlx-rs/src/ops/reduction.rs index 3bd84b02..c0d90763 100644 --- a/mlx-rs/src/ops/reduction.rs +++ b/mlx-rs/src/ops/reduction.rs @@ -3,7 +3,7 @@ use crate::error::Exception; use crate::stream::StreamOrDevice; use crate::utils::{axes_or_default_to_all, IntoOption}; use crate::Stream; -use mlx_macros::default_device; +use mlx_internal_macros::default_device; impl Array { /// An `and` reduction over the given axes returning an error if the axes are invalid. 
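// Reviewer sketch (not part of the diff): the clone() -> deep_clone() swaps in the
// arithmetic and indexing tests above suggest Array::clone is now a cheap copy of the
// underlying mlx handle, while Array::deep_clone yields an independent array; that
// distinction is an assumption here, not something this diff states. The sketch only
// mirrors the updated test usage with APIs that appear elsewhere in the diff
// (Array::from_slice, Array::add, Array::as_slice).
use mlx_rs::Array;

fn main() {
    let x = Array::from_slice(&[1.0f32, 2.0, 3.0], &[3]);

    // Start an accumulation chain from an independent copy, as the updated
    // arithmetic test does, so `x` itself is never written through.
    let mut out = x.deep_clone();
    for _ in 0..10 {
        out = out.add(&x).unwrap();
    }
    assert_eq!(out.as_slice::<f32>(), &[11.0, 22.0, 33.0]);
}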
diff --git a/mlx-rs/src/ops/shapes.rs b/mlx-rs/src/ops/shapes.rs index 7f0e6681..d405af1b 100644 --- a/mlx-rs/src/ops/shapes.rs +++ b/mlx-rs/src/ops/shapes.rs @@ -1,6 +1,6 @@ use std::borrow::Cow; -use mlx_macros::default_device; +use mlx_internal_macros::default_device; use smallvec::SmallVec; use crate::{ diff --git a/mlx-rs/src/ops/sort.rs b/mlx-rs/src/ops/sort.rs index d904f6d2..8b0cbb4c 100644 --- a/mlx-rs/src/ops/sort.rs +++ b/mlx-rs/src/ops/sort.rs @@ -1,6 +1,6 @@ //! Implements bindings for the sorting ops. -use mlx_macros::default_device; +use mlx_internal_macros::default_device; use crate::{error::Exception, Array, Stream, StreamOrDevice}; diff --git a/mlx-rs/src/random.rs b/mlx-rs/src/random.rs index a85c02b1..ae91ce9e 100644 --- a/mlx-rs/src/random.rs +++ b/mlx-rs/src/random.rs @@ -2,7 +2,7 @@ use crate::prelude::IndexOp; use crate::utils::IntoOption; use crate::{error::Exception, Array, ArrayElement, Stream, StreamOrDevice}; use mach_sys::mach_time; -use mlx_macros::default_device; +use mlx_internal_macros::default_device; use std::borrow::Cow; use std::sync::{Mutex, OnceLock}; diff --git a/mlx-rs/src/transforms/mod.rs b/mlx-rs/src/transforms/mod.rs index a2ed62bd..aac1ecef 100644 --- a/mlx-rs/src/transforms/mod.rs +++ b/mlx-rs/src/transforms/mod.rs @@ -1,3 +1,5 @@ +use std::{collections::HashMap, rc::Rc}; + use mlx_sys::{mlx_closure_value_and_grad, mlx_closure_value_and_grad_apply}; use smallvec::SmallVec; @@ -5,14 +7,15 @@ use crate::{ error::{ get_and_clear_last_mlx_error, is_mlx_error_handler_set, setup_mlx_error_handler, Exception, }, - utils::{Closure, VectorArray, VectorVectorArray}, + module::ModuleParamRef, + utils::{Closure, IntoOption, VectorArray, VectorVectorArray}, Array, }; pub mod compile; /// Evaluate an iterator of [`Array`]s. -pub fn eval<'a>(outputs: impl IntoIterator) -> Result<(), Exception> { +pub fn eval<'a>(outputs: impl IntoIterator) -> Result<(), Exception> { if !is_mlx_error_handler_set() { setup_mlx_error_handler(); } @@ -26,10 +29,17 @@ pub fn eval<'a>(outputs: impl IntoIterator) -> Result<(), get_and_clear_last_mlx_error().map_or(Ok(()), Err) } +/// Evaluate a module's parameters. +/// +/// This is a convenience function that flattens the parameters and evaluates them. +pub fn eval_params(params: ModuleParamRef<'_>) -> Result<(), Exception> { + eval(params.flatten().values().copied()) +} + /// Asynchronously evaluate an iterator of [`Array`]s. /// /// Please note that this is not a rust async function. -pub fn async_eval<'a>(outputs: impl IntoIterator) -> Result<(), Exception> { +pub fn async_eval<'a>(outputs: impl IntoIterator) -> Result<(), Exception> { if !is_mlx_error_handler_set() { setup_mlx_error_handler(); } @@ -43,6 +53,13 @@ pub fn async_eval<'a>(outputs: impl IntoIterator) -> Resul get_and_clear_last_mlx_error().map_or(Ok(()), Err) } +/// Asynchronously evaluate a module's parameters. +/// +/// This is a convenience function that flattens the parameters and evaluates them. +pub fn async_eval_params(params: ModuleParamRef<'_>) -> Result<(), Exception> { + async_eval(params.flatten().values().copied()) +} + #[inline] fn jvp_inner( closure: Closure<'_>, @@ -101,7 +118,7 @@ where } /// Similar to [`jvp`] but handles closures that can return an error. -pub fn jvp_fallible<'a, F>( +pub fn fallible_jvp<'a, F>( f: F, primals: &[Array], tangents: &[Array], @@ -170,7 +187,7 @@ where } /// Similar to [`vjp`] but handles closures that can return an error. 
-pub fn vjp_fallible<'a, F>( +pub fn fallible_vjp<'a, F>( f: F, primals: &[Array], cotangents: &[Array], @@ -231,11 +248,12 @@ fn build_gradient<'a, F>( where F: FnMut(&[Array]) -> Vec + 'a, { + let argument_numbers = argument_numbers.into_option().unwrap_or(&[0]); let closure = Closure::new(f); build_gradient_inner(closure, argument_numbers) } -fn build_gradient_fallible<'a, F>( +fn build_fallible_gradient<'a, F>( f: F, argument_numbers: &'a [i32], ) -> impl FnMut(&[Array]) -> Result, Exception> + 'a @@ -276,7 +294,7 @@ where build_value_and_gradient_inner(closure, argument_numbers) } -fn build_value_and_gradient_fallible<'a, F>( +fn build_fallible_value_and_gradient<'a, F>( f: F, argument_numbers: &'a [i32], ) -> impl FnMut(&[Array]) -> Result<(Vec, Vec), Exception> + 'a @@ -287,72 +305,193 @@ where build_value_and_gradient_inner(closure, argument_numbers) } +pub trait IntoValueAndGrad<'a, Err> { + fn into_value_and_grad( + self, + argument_numbers: impl IntoOption<&'a [i32]>, + ) -> impl FnMut(&[Array]) -> Result<(Vec, Vec), Exception> + 'a; +} + +impl<'a, F> IntoValueAndGrad<'a, ()> for F +where + F: FnMut(&[Array]) -> Vec + 'a, +{ + // refining_impl_trait is fine here because we have restricted the Args and Output types + // in the generics. + #[allow(refining_impl_trait)] + fn into_value_and_grad( + self, + argument_numbers: impl IntoOption<&'a [i32]>, + ) -> impl FnMut(&[Array]) -> Result<(Vec, Vec), Exception> + 'a { + let argument_numbers = argument_numbers.into_option().unwrap_or(&[0]); + build_value_and_gradient(self, argument_numbers) + } +} + +impl<'a, F> IntoValueAndGrad<'a, Exception> for F +where + F: FnMut(&[Array]) -> Result, Exception> + 'a, +{ + #[allow(refining_impl_trait)] + fn into_value_and_grad( + self, + argument_numbers: impl IntoOption<&'a [i32]>, + ) -> impl FnMut(&[Array]) -> Result<(Vec, Vec), Exception> + 'a { + let argument_numbers = argument_numbers.into_option().unwrap_or(&[0]); + build_fallible_value_and_gradient(self, argument_numbers) + } +} + /// Returns a function which computes the value and gradient of `f`. -pub fn value_and_grad<'a, F>( +pub fn value_and_grad<'a, F, Err>( f: F, - argument_numbers: &'a [i32], + argument_numbers: impl IntoOption<&'a [i32]>, ) -> impl FnMut(&[Array]) -> Result<(Vec, Vec), Exception> + 'a where - F: FnMut(&[Array]) -> Vec + 'a, + F: IntoValueAndGrad<'a, Err> + 'a, +{ + f.into_value_and_grad(argument_numbers) +} + +pub type HashMapGrad = HashMap, Array>; + +macro_rules! value_and_grad_with_hashmap { + ($inner_ret:ty, $cls_new:ident, $f:ident, $args_ty:ty) => { + move |parameters: HashMap, Arr>, + arrays: $args_ty| + -> Result<(Vec, HashMapGrad), Exception> { + let (flattened_keys, flattened_values): (Vec<_>, Vec<_>) = + parameters.into_iter().unzip(); + + let inner = |flattened_arrays: &[Array]| -> $inner_ret { + let parameters = flattened_keys + .iter() + .cloned() + .zip(flattened_arrays) + .collect(); + ($f)(parameters, arrays.clone()) + }; + + let argument_numbers = (0..flattened_values.len() as i32).collect::>(); + + let closure = Closure::$cls_new(inner); + let c_value_and_grad = unsafe { + try_catch_c_ptr_expr! 
{ + mlx_sys::mlx_value_and_grad( + closure.as_ptr(), + argument_numbers.as_ptr(), + argument_numbers.len(), + ) + } + }; + + let (value, grads) = + value_and_gradient(c_value_and_grad, flattened_values.into_iter())?; + + let grads_map = flattened_keys.iter().cloned().zip(grads).collect(); + + Ok((value, grads_map)) + } + }; +} + +pub trait IntoValueAndGradWithHashMap<'a, Arr, Args, Err> +where + Arr: AsRef, + Args: Clone, +{ + fn into_value_and_grad_with_hashmap( + self, + ) -> impl FnMut(HashMap, Arr>, Args) -> Result<(Vec, HashMapGrad), Exception> + 'a; +} + +impl<'a, F, Arr, Args> IntoValueAndGradWithHashMap<'a, Arr, Args, ()> for F +where + F: FnMut(HashMap, &Array>, Args) -> Vec + 'a, + Arr: AsRef, + Args: Clone, +{ + fn into_value_and_grad_with_hashmap( + mut self, + ) -> impl FnMut(HashMap, Arr>, Args) -> Result<(Vec, HashMapGrad), Exception> + 'a + { + value_and_grad_with_hashmap!(Vec, new, self, Args) + } +} + +impl<'a, F, Arr, Args> IntoValueAndGradWithHashMap<'a, Arr, Args, Exception> for F +where + F: FnMut(HashMap, &Array>, Args) -> Result, Exception> + 'a, + Arr: AsRef, + Args: Clone, { - build_value_and_gradient(f, argument_numbers) + fn into_value_and_grad_with_hashmap( + mut self, + ) -> impl FnMut(HashMap, Arr>, Args) -> Result<(Vec, HashMapGrad), Exception> + 'a + { + value_and_grad_with_hashmap!(Result, Exception>, new_fallible, self, Args) + } } -pub fn value_and_grad_fallible<'a, F>( +pub fn value_and_grad_with_hashmap<'a, F, Arr, Args, Err>( f: F, - argument_numbers: &'a [i32], -) -> impl FnMut(&[Array]) -> Result<(Vec, Vec), Exception> + 'a +) -> impl FnMut(HashMap, Arr>, Args) -> Result<(Vec, HashMapGrad), Exception> + 'a where - F: FnMut(&[Array]) -> Result, Exception> + 'a, + F: IntoValueAndGradWithHashMap<'a, Arr, Args, Err> + 'a, + Arr: AsRef, + Args: Clone, { - build_value_and_gradient_fallible(f, argument_numbers) + f.into_value_and_grad_with_hashmap() } -pub trait Grad<'a, Args, Output, Err> { - fn grad( +pub trait IntoGrad<'a, Args, Output, Err> { + fn into_grad( self, - argument_numbers: &'a [i32], + argument_numbers: impl IntoOption<&'a [i32]>, ) -> impl FnMut(Args) -> Result + 'a; } -impl<'a, F> Grad<'a, &[Array], Vec, ()> for F +impl<'a, F> IntoGrad<'a, &[Array], Vec, ()> for F where F: FnMut(&[Array]) -> Vec + 'a, { // refining_impl_trait is fine here because we have restricted the Args and Output types // in the generics. 
#[allow(refining_impl_trait)] - fn grad( + fn into_grad( self, - argument_numbers: &'a [i32], + argument_numbers: impl IntoOption<&'a [i32]>, ) -> impl FnMut(&[Array]) -> Result, Exception> + 'a { + let argument_numbers = argument_numbers.into_option().unwrap_or(&[0]); build_gradient(self, argument_numbers) } } -impl<'a, F> Grad<'a, &[Array], Vec, Exception> for F +impl<'a, F> IntoGrad<'a, &[Array], Vec, Exception> for F where F: FnMut(&[Array]) -> Result, Exception> + 'a, { #[allow(refining_impl_trait)] - fn grad( + fn into_grad( self, - argument_numbers: &'a [i32], + argument_numbers: impl IntoOption<&'a [i32]>, ) -> impl FnMut(&[Array]) -> Result, Exception> + 'a { - build_gradient_fallible(self, argument_numbers) + let argument_numbers = argument_numbers.into_option().unwrap_or(&[0]); + build_fallible_gradient(self, argument_numbers) } } -impl<'a, F> Grad<'a, &Array, Array, ()> for F +impl<'a, F> IntoGrad<'a, &Array, Array, ()> for F where F: FnMut(&Array) -> Array + 'a, { #[allow(refining_impl_trait)] - fn grad( + fn into_grad( mut self, - argument_numbers: &'a [i32], + argument_numbers: impl IntoOption<&'a [i32]>, ) -> impl FnMut(&Array) -> Result + 'a { let f = move |args: &[Array]| -> Vec { vec![self(&args[0])] }; + let argument_numbers = argument_numbers.into_option().unwrap_or(&[0]); let mut g = build_gradient(f, argument_numbers); move |args: &Array| -> Result { let args_clone = &[args.clone()]; @@ -362,19 +501,20 @@ where } } -impl<'a, F> Grad<'a, &Array, Array, Exception> for F +impl<'a, F> IntoGrad<'a, &Array, Array, Exception> for F where F: FnMut(&Array) -> Result + 'a, { #[allow(refining_impl_trait)] - fn grad( + fn into_grad( mut self, - argument_numbers: &'a [i32], + argument_numbers: impl IntoOption<&'a [i32]>, ) -> impl FnMut(&Array) -> Result + 'a { let f = move |args: &[Array]| -> Result, Exception> { self(&args[0]).map(|res| vec![res]) }; - let mut g = build_gradient_fallible(f, argument_numbers); + let argument_numbers = argument_numbers.into_option().unwrap_or(&[0]); + let mut g = build_fallible_gradient(f, argument_numbers); move |args: &Array| -> Result { let args_clone = &[args.clone()]; let result = g(args_clone)?; @@ -383,16 +523,17 @@ where } } -impl<'a, F> Grad<'a, &[Array], Array, ()> for F +impl<'a, F> IntoGrad<'a, &[Array], Array, ()> for F where F: FnMut(&[Array]) -> Array + 'a, { #[allow(refining_impl_trait)] - fn grad( + fn into_grad( mut self, - argument_numbers: &'a [i32], + argument_numbers: impl IntoOption<&'a [i32]>, ) -> impl FnMut(&[Array]) -> Result + 'a { let f = move |args: &[Array]| -> Vec { vec![self(args)] }; + let argument_numbers = argument_numbers.into_option().unwrap_or(&[0]); let mut g = build_gradient(f, argument_numbers); move |args: &[Array]| -> Result { let result = g(args)?; @@ -401,19 +542,20 @@ where } } -impl<'a, F> Grad<'a, &[Array], Array, Exception> for F +impl<'a, F> IntoGrad<'a, &[Array], Array, Exception> for F where F: FnMut(&[Array]) -> Result + 'a, { #[allow(refining_impl_trait)] - fn grad( + fn into_grad( mut self, - argument_numbers: &'a [i32], + argument_numbers: impl IntoOption<&'a [i32]>, ) -> impl FnMut(&[Array]) -> Result + 'a { let f = move |args: &[Array]| -> Result, Exception> { self(args).map(|res| vec![res]) }; - let mut g = build_gradient_fallible(f, argument_numbers); + let argument_numbers = argument_numbers.into_option().unwrap_or(&[0]); + let mut g = build_fallible_gradient(f, argument_numbers); move |args: &[Array]| -> Result { let result = g(args)?; Ok(result.into_iter().next().unwrap()) @@ -421,16 
+563,17 @@ where } } -impl<'a, F> Grad<'a, &Array, Vec, ()> for F +impl<'a, F> IntoGrad<'a, &Array, Vec, ()> for F where F: FnMut(&Array) -> Vec + 'a, { #[allow(refining_impl_trait)] - fn grad( + fn into_grad( mut self, - argument_numbers: &'a [i32], + argument_numbers: impl IntoOption<&'a [i32]>, ) -> impl FnMut(&Array) -> Result, Exception> + 'a { let f = move |args: &[Array]| -> Vec { self(&args[0]) }; + let argument_numbers = argument_numbers.into_option().unwrap_or(&[0]); let mut g = build_gradient(f, argument_numbers); move |args: &Array| -> Result, Exception> { let args_clone = &[args.clone()]; @@ -440,17 +583,18 @@ where } } -impl<'a, F> Grad<'a, &Array, Vec, Exception> for F +impl<'a, F> IntoGrad<'a, &Array, Vec, Exception> for F where F: FnMut(&Array) -> Result, Exception> + 'a, { #[allow(refining_impl_trait)] - fn grad( + fn into_grad( mut self, - argument_numbers: &'a [i32], + argument_numbers: impl IntoOption<&'a [i32]>, ) -> impl FnMut(&Array) -> Result, Exception> + 'a { let f = move |args: &[Array]| -> Result, Exception> { self(&args[0]) }; - let mut g = build_gradient_fallible(f, argument_numbers); + let argument_numbers = argument_numbers.into_option().unwrap_or(&[0]); + let mut g = build_fallible_gradient(f, argument_numbers); move |args: &Array| -> Result, Exception> { let args_clone = &[args.clone()]; let result = g(args_clone)?; @@ -462,20 +606,28 @@ where /// Returns a function which computes the gradient of `f`. pub fn grad<'a, F, Args, Output, Err>( f: F, - argument_numbers: &'a [i32], + argument_numbers: impl IntoOption<&'a [i32]>, ) -> impl FnMut(Args) -> Result + 'a where - F: Grad<'a, Args, Output, Err>, + F: IntoGrad<'a, Args, Output, Err>, { - f.grad(argument_numbers) + f.into_grad(argument_numbers) } #[cfg(test)] mod tests { - use crate::{array, error::Exception, Array}; + use std::{collections::HashMap, rc::Rc}; + + use crate::{ + array, + transforms::{grad, jvp, value_and_grad, vjp}, + Array, + }; use super::*; + use super::value_and_grad_with_hashmap; + // The unit tests below are adapted from the mlx c++ codebase #[test] @@ -497,7 +649,7 @@ mod tests { // Success case let x = array!(1.0f32); let y = array!(1.0f32); - let (out, dout) = jvp_fallible(f, &[x, y], &[array!(1.0f32), array!(3.0f32)]).unwrap(); + let (out, dout) = fallible_jvp(f, &[x, y], &[array!(1.0f32), array!(3.0f32)]).unwrap(); assert_eq!(out[0].item::(), 2.0f32); assert_eq!(dout[0].item::(), 4.0f32); @@ -505,7 +657,7 @@ mod tests { // Use non-broadcastable shapes let a = array!([1.0, 2.0, 3.0]); let b = array!([4.0, 5.0]); - let result = jvp_fallible(f, &[a, b], &[array!(1.0f32), array!(3.0f32)]); + let result = fallible_jvp(f, &[a, b], &[array!(1.0f32), array!(3.0f32)]); assert!(result.is_err()); } @@ -532,7 +684,7 @@ mod tests { let y = array!(1.0f32); let primals = vec![x, y]; let cotangents = vec![array!(1.0f32)]; - let (out, dout) = vjp_fallible(f, &primals, &cotangents).unwrap(); + let (out, dout) = fallible_vjp(f, &primals, &cotangents).unwrap(); assert_eq!(out[0].item::(), 2.0f32); assert_eq!(dout[0].item::(), 1.0f32); @@ -540,7 +692,7 @@ mod tests { // Use non-broadcastable shapes let a = array!([1.0, 2.0, 3.0]); let b = array!([4.0, 5.0]); - let result = vjp_fallible(f, &[a, b], &[array!(1.0f32)]); + let result = fallible_vjp(f, &[a, b], &[array!(1.0f32)]); assert!(result.is_err()); } @@ -562,6 +714,28 @@ mod tests { assert_eq!(d2fdx2[0].item::(), 0.0); } + #[test] + fn test_value_and_grad_hash_map() { + let f = |parameters: HashMap, &Array>, _: i32| -> Vec { + 
vec![parameters["x"] * parameters["y"]] + }; + + let x = array!(1.5f32); + let y = array!(2.0f32); + let parameters = vec![("x", &x), ("y", &y)] + .into_iter() + .map(|(k, v)| (k.into(), v)) + .collect(); + + let mut vg = value_and_grad_with_hashmap(f); + + let (value, grad) = vg(parameters, 0).unwrap(); + + assert_eq!(value[0].item::(), 1.5 * 2.0); + assert_eq!(grad["x"].item::(), 2.0); + assert_eq!(grad["y"].item::(), 1.5); + } + #[test] fn test_value_and_grad_with_error() { let fun = |argin: &[Array]| -> Result, Exception> { @@ -572,14 +746,14 @@ mod tests { let argnums = &[0]; let x = array!(1.0f32); let y = array!(1.0f32); - let result = value_and_grad_fallible(fun, argnums)(&[x, y]); + let result = value_and_grad(fun, argnums)(&[x, y]); assert!(result.is_ok()); // Error case // Use non-broadcastable shapes let a = array!([1.0, 2.0, 3.0]); let b = array!([4.0, 5.0]); - let result = value_and_grad_fallible(fun, argnums)(&[a, b]); + let result = value_and_grad(fun, argnums)(&[a, b]); assert!(result.is_err()); } } diff --git a/mlx-rs/src/utils.rs b/mlx-rs/src/utils.rs index 27279b98..c4572510 100644 --- a/mlx-rs/src/utils.rs +++ b/mlx-rs/src/utils.rs @@ -144,6 +144,12 @@ impl<'a, T, const N: usize> IntoOption<&'a [T]> for &'a [T; N] { } } +impl<'a, T> IntoOption<&'a [T]> for &'a Vec { + fn into_option(self) -> Option<&'a [T]> { + Some(self) + } +} + pub trait ScalarOrArray<'a> { type Array: AsRef + 'a;
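// Reviewer sketch (not part of the diff): end-to-end shape of the reworked transforms
// API after the fallible/infallible merge. Everything below uses only calls that appear
// in the updated tests above (grad, value_and_grad, Array::add, array!, item), plus the
// assumption that `&[i32; N]` argument numbers keep flowing through `IntoOption` as shown.
use mlx_rs::{array, error::Exception, transforms::{grad, value_and_grad}, Array};

fn main() -> Result<(), Exception> {
    // Infallible closure: d/dx (x * x) at x = 3 is 6.
    let square = |xs: &[Array]| -> Vec<Array> { vec![&xs[0] * &xs[0]] };
    let mut dsquare = grad(square, &[0]);
    let gx = dsquare(&[array!(3.0f32)])?;
    assert_eq!(gx[0].item::<f32>(), 6.0);

    // A fallible closure goes through the same `value_and_grad` entry point now that
    // `value_and_grad_fallible` has been folded into it.
    let sum = |xs: &[Array]| -> Result<Vec<Array>, Exception> {
        xs[0].add(&xs[1]).map(|res| vec![res])
    };
    let (value, grads) = value_and_grad(sum, &[0, 1])(&[array!(1.0f32), array!(2.0f32)])?;
    assert_eq!(value[0].item::<f32>(), 3.0);
    assert_eq!(grads.len(), 2); // one gradient per argument number
    Ok(())
}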