Skip to content

Commit

Permalink
cargo clippy and fmt
Browse files Browse the repository at this point in the history
  • Loading branch information
minghuaw committed Oct 20, 2024
1 parent 1fb76af commit f7e1355
Show file tree
Hide file tree
Showing 7 changed files with 200 additions and 147 deletions.
4 changes: 3 additions & 1 deletion examples/mnist/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
let loader = load_training_data()?;
let mut model = mlp::Mlp::new(num_layers, input_dim, hidden_dim, output_dim)?;

let cross_entropy = CrossEntropy::builder().reduction(LossReduction::Mean).build()?;
let cross_entropy = CrossEntropy::builder()
.reduction(LossReduction::Mean)
.build()?;
let loss_fn = |model: &mlp::Mlp, (x, y): (&Array, &Array)| -> Result<Array, Exception> {
let y_pred = model.forward(x)?;
cross_entropy.apply(y_pred, y)
Expand Down
50 changes: 24 additions & 26 deletions mlx-macros/src/generate_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,7 @@ pub(crate) fn expand_generate_builder(
);
optional_field_types.push(&field.ty);
if generate_build_fn {
optional_field_defaults.push(
field_attr
.default_value
);
optional_field_defaults.push(field_attr.default_value);
}
} else {
mandatory_field_idents.push(&field.ident);
Expand All @@ -67,12 +64,14 @@ pub(crate) fn expand_generate_builder(
let modified_optional_field_types = optional_field_types
.iter()
.zip(optional_field_skip.iter())
.map(|(field_type, skip)| if !skip {
quote! { Option<#field_type> }
} else {
quote! { #field_type }
.map(|(field_type, skip)| {
if !skip {
quote! { Option<#field_type> }
} else {
quote! { #field_type }
}
});

let builder_struct_doc = format!("Builder for [`{}`]", struct_ident);
let field_doc = format!("See [`{}`] for more details", struct_ident);
let builder_struct = quote! {
Expand Down Expand Up @@ -111,27 +110,21 @@ pub(crate) fn expand_generate_builder(
let builder_setter_docs = optional_field_idents
.iter()
.zip(optional_field_skip.iter())
.filter_map(|(field_ident, skip)| if !skip {
Some(format!("Set the value of `{:?}`", field_ident))
} else {
None
.filter_map(|(field_ident, skip)| {
if !skip {
Some(format!("Set the value of `{:?}`", field_ident))
} else {
None
}
});
let filtered_optional_field_idents = optional_field_idents
.iter()
.zip(optional_field_skip.iter())
.filter_map(|(field_ident, skip)| if !skip {
Some(field_ident)
} else {
None
});
.filter_map(|(field_ident, skip)| if !skip { Some(field_ident) } else { None });
let filtered_optional_field_types = optional_field_types
.iter()
.zip(optional_field_skip.iter())
.filter_map(|(field_type, skip)| if !skip {
Some(field_type)
} else {
None
});
.filter_map(|(field_type, skip)| if !skip { Some(field_type) } else { None });

let builder_setters = quote! {
impl #impl_generics #builder_ident #ty_generics #where_clause {
Expand All @@ -148,9 +141,14 @@ pub(crate) fn expand_generate_builder(
let builder_build = if generate_build_fn {
let builder_build_doc = format!("Build a new [`{}`]", struct_ident);
let struct_new_doc = format!("Create a new [`{}`] with default values", struct_ident);
let optional_field_defaults: Vec<_> = optional_field_defaults.iter().map(|default| {
default.clone().ok_or_else(|| "Default value must be supplied to generate build function")
}).collect::<Result<Vec<_>, _>>()?;
let optional_field_defaults: Vec<_> = optional_field_defaults
.iter()
.map(|default| {
default
.clone()
.ok_or("Default value must be supplied to generate build function")
})
.collect::<Result<Vec<_>, _>>()?;

quote! {
impl #impl_generics #builder_ident #ty_generics #where_clause {
Expand Down
8 changes: 4 additions & 4 deletions mlx-nn/src/activation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -613,9 +613,7 @@ impl PreluBuilder {
let count = self.count.unwrap_or(Prelu::DEFAULT_COUNT);
let value = self.value.unwrap_or(Prelu::DEFAULT_VALUE);
let weight = Param::new(mlx_rs::ops::full::<f32>(&[count], &array!(value))?);
Ok(Prelu {
weight,
})
Ok(Prelu { weight })
}
}

Expand All @@ -639,7 +637,9 @@ impl Prelu {

/// Creates a new Prelu module with the default values.
pub fn new() -> Prelu {
PreluBuilder::default().build().expect("Default value should be valid")
PreluBuilder::default()
.build()
.expect("Default value should be valid")
}
}

Expand Down
16 changes: 13 additions & 3 deletions mlx-nn/src/convolution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,12 @@ impl Conv2dBuilder {
let weight = uniform::<_, f32>(
-scale,
scale,
&[output_channels, kernel_size.0, kernel_size.1, input_channels],
&[
output_channels,
kernel_size.0,
kernel_size.1,
input_channels,
],
None,
)?;
let bias = if with_bias {
Expand All @@ -221,7 +226,6 @@ impl Conv2dBuilder {
}
}


/// Applies a 2-dimensional convolution over the multi-channel input image.
///
/// The channels are expected to be last i.e. the input shape should be `NHWC` where:
Expand Down Expand Up @@ -355,7 +359,13 @@ impl Conv3dBuilder {
let weight = uniform::<_, f32>(
-scale,
scale,
&[output_channels, kernel_size.0, kernel_size.1, kernel_size.2, input_channels],
&[
output_channels,
kernel_size.0,
kernel_size.1,
kernel_size.2,
input_channels,
],
None,
)?;
let bias = if with_bias {
Expand Down
25 changes: 18 additions & 7 deletions mlx-nn/src/linear.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,12 @@ impl LinearBuilder {
mlx_rs::random::uniform::<_, f32>(-scale, scale, &[output_dims, input_dims], None)?;

let bias = if with_bias {
Some(
mlx_rs::random::uniform::<_, f32>(-scale, scale, &[output_dims], None)?
)
Some(mlx_rs::random::uniform::<_, f32>(
-scale,
scale,
&[output_dims],
None,
)?)
} else {
None
};
Expand Down Expand Up @@ -115,7 +118,12 @@ impl BilinearBuilder {
}

/// Builds a new [`Bilinear`] layer.
pub fn build(self, input_dims_1: i32, input_dims_2: i32, output_dims: i32) -> Result<Bilinear, Exception> {
pub fn build(
self,
input_dims_1: i32,
input_dims_2: i32,
output_dims: i32,
) -> Result<Bilinear, Exception> {
let with_bias = self.with_bias.unwrap_or(Bilinear::DEFAULT_WITH_BIAS);

let scale = f32::sqrt(1.0 / (input_dims_1 as f32));
Expand All @@ -127,9 +135,12 @@ impl BilinearBuilder {
)?;

let bias = if with_bias {
Some(
mlx_rs::random::uniform::<_, f32>(-scale, scale, &[output_dims], None)?
)
Some(mlx_rs::random::uniform::<_, f32>(
-scale,
scale,
&[output_dims],
None,
)?)
} else {
None
};
Expand Down
Loading

0 comments on commit f7e1355

Please sign in to comment.