Skip to content

Commit

Permalink
Hopefully fixing deserialization issues once and for all.
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Oct 25, 2024
1 parent a831de6 commit a1a9480
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 82 deletions.
31 changes: 14 additions & 17 deletions jaxtyping/_array_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,26 +320,23 @@ def _check_shape(
@ft.lru_cache(maxsize=None)
def _make_metaclass(base_metaclass):
class MetaAbstractArray(_MetaAbstractArray, base_metaclass):
def _get_props(cls):
props_tuple = (
cls.index_variadic,
cls.dims,
cls.array_type,
cls.dtypes,
cls.dim_str,
)
return props_tuple

# We have to use identity-based eq/hash behaviour. The reason for this is that
# when deserializing using cloudpickle (very common, it seems), that cloudpickle
# will actually attempt to put a partially constructed class in a dictionary.
# So if we start accessing `cls.index_variadic` and the like here, then that
# explodes.
# See
# https://github.com/patrick-kidger/jaxtyping/issues/198
# https://github.com/patrick-kidger/jaxtyping/issues/261
#
# This does mean that if you want to compare two array annotations for equality
# (e.g. this happens in jaxtyping's tests as part of checking correctness) then
# a custom equality function must be used -- we can't put it here.
def __eq__(cls, other):
if type(cls) is not type(other):
return False

return cls._get_props() == other._get_props()
return cls is other

def __hash__(cls):
# Does not use `_get_props` as these attributes don't necessarily exist
# during depickling. See #198.
return 0
return id(cls)

return MetaAbstractArray

Expand Down
1 change: 1 addition & 0 deletions test/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ cloudpickle
equinox
IPython
jax
numpy<2
pytest
pytest-asyncio
tensorflow
Expand Down
80 changes: 50 additions & 30 deletions test/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
torch = None

from jaxtyping import (
AbstractArray,
AbstractDtype,
AnnotationError,
Array,
Expand Down Expand Up @@ -528,6 +529,15 @@ class A:
A(3, jnp.zeros(4))


def _to_set(x) -> set[tuple]:
return {
(xi.index_variadic, xi.dims, xi.array_type, xi.dtypes, xi.dim_str)
if issubclass(xi, AbstractArray)
else xi
for xi in x
}


def test_arraylike(typecheck, getkey):
floatlike1 = Float32[ArrayLike, ""]
floatlike2 = Float[ArrayLike, ""]
Expand All @@ -536,41 +546,51 @@ def test_arraylike(typecheck, getkey):
assert get_origin(floatlike1) is Union
assert get_origin(floatlike2) is Union
assert get_origin(floatlike3) is Union
assert set(get_args(floatlike1)) == {
Float32[Array, ""],
Float32[np.ndarray, ""],
Float32[np.number, ""],
float,
}
assert set(get_args(floatlike2)) == {
Float[Array, ""],
Float[np.ndarray, ""],
Float[np.number, ""],
float,
}
assert set(get_args(floatlike3)) == {
Float32[Array, "4"],
Float32[np.ndarray, "4"],
}
assert _to_set(get_args(floatlike1)) == _to_set(
[
Float32[Array, ""],
Float32[np.ndarray, ""],
Float32[np.number, ""],
float,
]
)
assert _to_set(get_args(floatlike2)) == _to_set(
[
Float[Array, ""],
Float[np.ndarray, ""],
Float[np.number, ""],
float,
]
)
assert _to_set(get_args(floatlike3)) == _to_set(
[
Float32[Array, "4"],
Float32[np.ndarray, "4"],
]
)

shaped1 = Shaped[ArrayLike, ""]
shaped2 = Shaped[ArrayLike, "4"]
assert get_origin(shaped1) is Union
assert get_origin(shaped2) is Union
assert set(get_args(shaped1)) == {
Shaped[Array, ""],
Shaped[np.ndarray, ""],
Shaped[np.bool_, ""],
Shaped[np.number, ""],
bool,
int,
float,
complex,
}
assert set(get_args(shaped2)) == {
Shaped[Array, "4"],
Shaped[np.ndarray, "4"],
}
assert _to_set(get_args(shaped1)) == _to_set(
[
Shaped[Array, ""],
Shaped[np.ndarray, ""],
Shaped[np.bool_, ""],
Shaped[np.number, ""],
bool,
int,
float,
complex,
]
)
assert _to_set(get_args(shaped2)) == _to_set(
[
Shaped[Array, "4"],
Shaped[np.ndarray, "4"],
]
)


def test_subclass():
Expand Down
35 changes: 0 additions & 35 deletions test/test_equals.py

This file was deleted.

0 comments on commit a1a9480

Please sign in to comment.