Skip to content

Commit

Permalink
Add support for numpy struct dtype
Browse files Browse the repository at this point in the history
  • Loading branch information
alexfanqi authored and patrick-kidger committed Jun 13, 2024
1 parent 59aeeec commit ad95c10
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 0 deletions.
1 change: 1 addition & 0 deletions jaxtyping/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
AbstractArray as AbstractArray,
AbstractDtype as AbstractDtype,
get_array_name_format as get_array_name_format,
make_numpy_struct_dtype as make_numpy_struct_dtype,
set_array_name_format as set_array_name_format,
)
from ._config import config as config
Expand Down
38 changes: 38 additions & 0 deletions jaxtyping/_array_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,10 @@ def _check_dims(
return ""


def _dtype_is_numpy_struct_array(dtype):
return dtype.type.__name__ == "void" and dtype is not np.dtype(np.void)


class _MetaAbstractArray(type):
_skip_instancecheck: bool = False

Expand All @@ -177,6 +181,9 @@ def __instancecheck_str__(cls, obj: Any) -> str:
if hasattr(obj.dtype, "type") and hasattr(obj.dtype.type, "__name__"):
# JAX, numpy
dtype = obj.dtype.type.__name__
# numpy structured array is strictly a subtype of np.void
if _dtype_is_numpy_struct_array(obj.dtype):
dtype = str(obj.dtype)
elif hasattr(obj.dtype, "as_numpy_dtype"):
# TensorFlow
dtype = obj.dtype.as_numpy_dtype.__name__
Expand Down Expand Up @@ -755,3 +762,34 @@ class _Cls(AbstractDtype):
Shaped = _make_dtype(_any_dtype, "Shaped")

Key = _make_dtype(_prng_key, "Key")


def make_numpy_struct_dtype(dtype: np.dtype, name: str):
"""Creates a type annotation for [numpy structured array](https://numpy.org/doc/stable/user/basics.rec.html#structured-arrays)
It does exact match on the name, order, and dtype of all its fields.
!!! Example
```python
label_t = np.dtype([('first', np.uint8), ('second', np.int8)])
Label = make_numpy_struct_dtype(label_t, 'Label')
```
after that, you can use it just like any AbstractDtype
```python
a: Label[np.ndarray, 'a b'] = np.array([[(1, 0), (0, 1)]], dtype=label_t)
```
**Arguments:**
- `dtype`: The numpy dtype that the returned annotation matches
- `name`: The python class name for the returned dtype annotation
**Returns:**
A type annotation with classname `name` and matching exactly `dtype`.
It can be used like any usual subclasses of AbstractDtypes.
"""
if not (isinstance(dtype, np.dtype) and _dtype_is_numpy_struct_array(dtype)):
raise ValueError(f"Expecting a numpy structured array dtype, not {dtype}")
return _make_dtype(str(dtype), name)
18 changes: 18 additions & 0 deletions test/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,24 @@ def test_dtypes():
assert key == val.__name__


def test_numpy_struct_dtype():
from jaxtyping import make_numpy_struct_dtype

dtype1 = np.dtype([("first", np.uint8), ("second", bool)])
Dtype1 = make_numpy_struct_dtype(dtype1, "Dtype1")
arr = np.array([0, False], dtype=dtype1)

assert isinstance(arr, Dtype1[np.ndarray, "_"])

dtype2 = np.dtype([("third", np.uint8), ("second", bool)])
Dtype2 = make_numpy_struct_dtype(dtype2, "Dtype2")
assert not isinstance(arr, Dtype2[np.ndarray, "_"])

dtype3 = np.dtype([("second", bool), ("first", np.uint8)])
Dtype3 = make_numpy_struct_dtype(dtype3, "Dtype3")
assert not isinstance(arr, Dtype3[np.ndarray, "_"])


def test_return(jaxtyp, typecheck, getkey):
@jaxtyp(typecheck)
def g(x: Float[Array, "b c"]) -> Float[Array, "c b"]:
Expand Down

0 comments on commit ad95c10

Please sign in to comment.