Skip to content

Commit

Permalink
allow using array!() with shape
Browse files Browse the repository at this point in the history
  • Loading branch information
minghuaw committed Sep 18, 2024
1 parent f978569 commit 6d20b76
Showing 1 changed file with 22 additions and 0 deletions.
22 changes: 22 additions & 0 deletions mlx-rs/src/macros/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,19 @@
/// [10, 11, 12]
/// ]
/// ]);
///
/// // Create a 2x2 array by specifying the shape
/// let a = array!([1, 2, 3, 4], shape=[2, 2]);
/// ```
#[macro_export]
macro_rules! array {
([$($x:expr),*], shape=[$($s:expr),*]) => {
{
let data = [$($x,)*];
let shape = [$($s,)*];
$crate::Array::from_slice(&data, &shape)
}
};
([$([$([$($x:expr),*]),*]),*]) => {
{
let arr = [$([$([$($x,)*],)*],)*];
Expand Down Expand Up @@ -133,4 +143,16 @@ mod tests {
assert_eq!(a.index((1, 1, 1)).item::<i32>(), 11);
assert_eq!(a.index((1, 1, 2)).item::<i32>(), 12);
}

#[test]
fn test_array_with_shape() {
let a = array!([1, 2, 3, 4], shape = [2, 2]);

assert!(a.ndim() == 2);
assert_eq!(a.shape(), &[2, 2]);
assert_eq!(a.index((0, 0)).item::<i32>(), 1);
assert_eq!(a.index((0, 1)).item::<i32>(), 2);
assert_eq!(a.index((1, 0)).item::<i32>(), 3);
assert_eq!(a.index((1, 1)).item::<i32>(), 4);
}
}

0 comments on commit 6d20b76

Please sign in to comment.