diff --git a/jaxtyping/_array_types.py b/jaxtyping/_array_types.py index d8d0a19..d77c496 100644 --- a/jaxtyping/_array_types.py +++ b/jaxtyping/_array_types.py @@ -224,12 +224,12 @@ def _check_shape( arg_memo: dict[str, Any], ) -> str: if cls.index_variadic is None: - if obj.ndim != len(cls.dims): - return f"this array has {obj.ndim} dimensions, not the {len(cls.dims)} expected by the type hint" # noqa: E501 + if len(obj.shape) != len(cls.dims): + return f"this array has {len(obj.shape)} dimensions, not the {len(cls.dims)} expected by the type hint" # noqa: E501 return _check_dims(cls.dims, obj.shape, single_memo, arg_memo) else: - if obj.ndim < len(cls.dims) - 1: - return f"this array has {obj.ndim} dimensions, which is fewer than {len(cls.dims) - 1} that is the minimum expected by the type hint" # noqa: E501 + if len(obj.shape) < len(cls.dims) - 1: + return f"this array has {len(obj.shape)} dimensions, which is fewer than {len(cls.dims) - 1} that is the minimum expected by the type hint" # noqa: E501 i = cls.index_variadic j = -(len(cls.dims) - i - 1) if j == 0: