Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/dev' into test-generators
Browse files Browse the repository at this point in the history
  • Loading branch information
knyazer committed Feb 17, 2024
2 parents a73e3e0 + c876407 commit 01d11fc
Show file tree
Hide file tree
Showing 9 changed files with 144 additions and 76 deletions.
4 changes: 4 additions & 0 deletions docs/api/advanced-features.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
members:
false

## Printing axis bindings

::: jaxtyping.print_bindings

## Introspection

If you're writing your own type hint parser, then you may wish to detect if some Python object is a jaxtyping-provided type.
Expand Down
4 changes: 2 additions & 2 deletions docs/api/array.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,9 @@ The dtype should be any one of (all imported from `jaxtyping`):
- Of particular precision: `Complex64`, `Complex128`
- Any integer or unsigned intger: `Integer`
- Any unsigned integer: `UInt`
- Of particular precision: `UInt8`, `UInt16`, `UInt32`, `UInt64`
- Of particular precision: `UInt4`, `UInt8`, `UInt16`, `UInt32`, `UInt64`
- Any signed integer: `Int`
- Of particular precision: `Int8`, `Int16`, `Int32`, `Int64`
- Of particular precision: `Int4`, `Int8`, `Int16`, `Int32`, `Int64`
- Any floating, integer, or unsigned integer: `Real`.

Unless you really want to force a particular precision, then for most applications you should probably allow any floating-point, any integer, etc. That is, use
Expand Down
5 changes: 5 additions & 0 deletions jaxtyping/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
)
from ._import_hook import install_import_hook as install_import_hook
from ._ipython_extension import load_ipython_extension as load_ipython_extension
from ._storage import print_bindings as print_bindings


# Now import Array and ArrayLike
Expand Down Expand Up @@ -84,6 +85,7 @@ class ArrayLike:
Float64 as Float64,
Inexact as Inexact,
Int as Int,
Int4 as Int4,
Int8 as Int8,
Int16 as Int16,
Int32 as Int32,
Expand All @@ -94,6 +96,7 @@ class ArrayLike:
Real as Real,
Shaped as Shaped,
UInt as UInt,
Uint4 as Uint4,
UInt8 as UInt8,
UInt16 as UInt16,
UInt32 as UInt32,
Expand All @@ -112,6 +115,7 @@ class ArrayLike:
Float64 as Float64,
Inexact as Inexact,
Int as Int,
Int4 as Int4,
Int8 as Int8,
Int16 as Int16,
Int32 as Int32,
Expand All @@ -121,6 +125,7 @@ class ArrayLike:
Real as Real,
Shaped as Shaped,
UInt as UInt,
UInt4 as UInt4,
UInt8 as UInt8,
UInt16 as UInt16,
UInt32 as UInt32,
Expand Down
80 changes: 48 additions & 32 deletions jaxtyping/_array_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def _check_dims(
obj_shape: tuple[int, ...],
single_memo: dict[str, int],
arg_memo: dict[str, Any],
) -> bool:
) -> str:
assert len(cls_dims) == len(obj_shape)
for cls_dim, obj_size in zip(cls_dims, obj_shape):
if cls_dim is _anonymous_dim:
Expand All @@ -124,7 +124,7 @@ def _check_dims(
pass
elif type(cls_dim) is _FixedDim:
if cls_dim.size != obj_size:
return False
return f"the dimension size {obj_size} does not equal {cls_dim.size} as expected by the type hint" # noqa: E501
elif type(cls_dim) is _SymbolicDim:
try:
# Support f-string syntax.
Expand All @@ -141,7 +141,7 @@ def _check_dims(
"arguments."
) from e
if eval_size != obj_size:
return False
return f"the dimension size {obj_size} does not equal the existing value of {cls_dim.elem}={eval_size}" # noqa: E501
else:
assert type(cls_dim) is _NamedDim
if cls_dim.treepath:
Expand All @@ -154,20 +154,24 @@ def _check_dims(
single_memo[name] = obj_size
else:
if cls_size != obj_size:
return False
return True
return f"the size of dimension {cls_dim.name} is {obj_size} which does not equal the existing value of {cls_size}" # noqa: E501
return ""


class _MetaAbstractArray(type):
_skip_instancecheck: bool = False

def __instancecheck__(cls, obj):
def __instancecheck__(cls, obj: Any) -> bool:
return cls.__instancecheck_str__(obj) == ""

def __instancecheck_str__(cls, obj: Any) -> str:
if cls._skip_instancecheck:
return True
return ""

if not isinstance(obj, cls.array_type):
return False
return f"this value is not an instance of the underlying array type {cls.array_type}" # noqa: E501
if get_treeflatten_memo():
return True
return ""

if hasattr(obj.dtype, "type") and hasattr(obj.dtype.type, "__name__"):
# JAX, numpy
Expand Down Expand Up @@ -197,7 +201,10 @@ def __instancecheck__(cls, obj):
if in_dtypes:
break
if not in_dtypes:
return False
if len(cls.dtypes) == 1:
return f"this array has dtype {dtype}, not {cls.dtypes[0]} as expected by the type hint" # noqa: E501
else:
return f"this array has dtype {dtype}, not any of {cls.dtypes} as expected by the type hint" # noqa: E501

single_memo, variadic_memo, pytree_memo, arg_memo = get_shape_memo()
single_memo_bak = single_memo.copy()
Expand All @@ -211,41 +218,46 @@ def __instancecheck__(cls, obj):
single_memo_bak, variadic_memo_bak, pytree_memo_bak, arg_memo_bak
)
raise
if check:
return True
if check == "":
return check
else:
set_shape_memo(
single_memo_bak, variadic_memo_bak, pytree_memo_bak, arg_memo_bak
)
return False
return check

def _check_shape(
cls,
obj,
single_memo: dict[str, int],
variadic_memo: dict[str, tuple[bool, tuple[int, ...]]],
arg_memo: dict[str, Any],
):
) -> str:
if cls.index_variadic is None:
if obj.ndim != len(cls.dims):
return False
return f"this array has {obj.ndim} dimensions, not the {len(cls.dims)} expected by the type hint" # noqa: E501
return _check_dims(cls.dims, obj.shape, single_memo, arg_memo)
else:
if obj.ndim < len(cls.dims) - 1:
return False
return f"this array has {obj.ndim} dimensions, which is fewer than {len(cls.dims - 1)} that is the minimum expected by the type hint" # noqa: E501
i = cls.index_variadic
j = -(len(cls.dims) - i - 1)
if j == 0:
j = None
if not _check_dims(cls.dims[:i], obj.shape[:i], single_memo, arg_memo):
return False
if j is not None and not _check_dims(
cls.dims[j:], obj.shape[j:], single_memo, arg_memo
):
return False
prefix_check = _check_dims(
cls.dims[:i], obj.shape[:i], single_memo, arg_memo
)
if prefix_check != "":
return prefix_check
if j is not None:
suffix_check = _check_dims(
cls.dims[j:], obj.shape[j:], single_memo, arg_memo
)
if suffix_check != "":
return suffix_check
variadic_dim = cls.dims[i]
if variadic_dim is _anonymous_variadic_dim:
return True
return ""
else:
assert type(variadic_dim) is _NamedVariadicDim
if variadic_dim.treepath:
Expand All @@ -257,16 +269,16 @@ def _check_shape(
prev_broadcastable, prev_shape = variadic_memo[name]
except KeyError:
variadic_memo[name] = (broadcastable, obj.shape[i:j])
return True
return ""
else:
new_shape = obj.shape[i:j]
if prev_broadcastable:
try:
broadcast_shape = np.broadcast_shapes(new_shape, prev_shape)
except ValueError: # not broadcastable e.g. (3, 4) and (5,)
return False
return f"the shape of its variadic dimensions '*{variadic_dim.name}' is {new_shape}, which cannot be broadcast with the existing value of {prev_shape}" # noqa: E501
if not broadcastable and broadcast_shape != new_shape:
return False
return f"the shape of its variadic dimensions '*{variadic_dim.name}' is {new_shape}, which the existing value of {prev_shape} cannot be broadcast to" # noqa: E501
variadic_memo[name] = (broadcastable, broadcast_shape)
else:
if broadcastable:
Expand All @@ -275,13 +287,13 @@ def _check_shape(
new_shape, prev_shape
)
except ValueError: # not broadcastable e.g. (3, 4) and (5,)
return False
return f"the shape of its variadic dimensions '*{variadic_dim.name}' is {new_shape}, which cannot be broadcast with the existing value of {prev_shape}" # noqa: E501
if broadcast_shape != prev_shape:
return False
return f"the shape of its variadic dimensions '*{variadic_dim.name}' is {new_shape}, which cannot be broadcast to the existing value of {prev_shape}" # noqa: E501
else:
if new_shape != prev_shape:
return False
return True
return f"the shape of its variadic dimensions '*{variadic_dim.name}' is {new_shape}, which does not equal the existing value of {prev_shape}" # noqa: E501
return ""
assert False


Expand Down Expand Up @@ -645,10 +657,12 @@ def __init_subclass__(cls, **kwargs):
_prng_key = "prng_key"
_bool = "bool"
_bool_ = "bool_"
_uint4 = "uint4"
_uint8 = "uint8"
_uint16 = "uint16"
_uint32 = "uint32"
_uint64 = "uint64"
_int4 = "int4"
_int8 = "int8"
_int16 = "int16"
_int32 = "int32"
Expand All @@ -674,10 +688,12 @@ class _Cls(AbstractDtype):
return _Cls


UInt4 = _make_dtype(_uint4, "UInt4")
UInt8 = _make_dtype(_uint8, "UInt8")
UInt16 = _make_dtype(_uint16, "UInt16")
UInt32 = _make_dtype(_uint32, "UInt32")
UInt64 = _make_dtype(_uint64, "UInt64")
Int4 = _make_dtype(_int4, "Int4")
Int8 = _make_dtype(_int8, "Int8")
Int16 = _make_dtype(_int16, "Int16")
Int32 = _make_dtype(_int32, "Int32")
Expand All @@ -690,8 +706,8 @@ class _Cls(AbstractDtype):
Complex128 = _make_dtype(_complex128, "Complex128")

bools = [_bool, _bool_]
uints = [_uint8, _uint16, _uint32, _uint64]
ints = [_int8, _int16, _int32, _int64]
uints = [_uint4, _uint8, _uint16, _uint32, _uint64]
ints = [_int4, _int8, _int16, _int32, _int64]
floats = [_bfloat16, _float16, _float32, _float64]
complexes = [_complex64, _complex128]

Expand Down
42 changes: 4 additions & 38 deletions jaxtyping/_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from ._array_types import _MetaAbstractArray
from ._config import config
from ._errors import AnnotationError, TypeCheckError
from ._storage import pop_shape_memo, push_shape_memo
from ._storage import pop_shape_memo, push_shape_memo, shape_str


class _Sentinel:
Expand Down Expand Up @@ -343,7 +343,7 @@ def wrapped_fn(*args, **kwargs): # pyright: ignore
except Exception as e:
# add_note api is support from python 3.11+
if sys.version_info >= (3, 11) and _no_jaxtyping_note(e):
shape_info = _exc_shape_info(memos)
shape_info = shape_str(memos)
if shape_info != "":
msg = (
"The preceding error occurred within the scope of a "
Expand Down Expand Up @@ -435,7 +435,7 @@ def wrapped_fn(*args, **kwargs):
"----------------------\n"
f"Called with parameters: {param_values}\n"
f"Parameter annotations: {param_hints}.\n"
+ _exc_shape_info(memos)
+ shape_str(memos)
)
if config.jaxtyping_remove_typechecker_stack:
raise TypeCheckError(msg) from None
Expand Down Expand Up @@ -488,7 +488,7 @@ def wrapped_fn(*args, **kwargs):
"----------------------\n"
f"Called with parameters: {param_values}\n"
f"Parameter annotations: {param_hints}.\n"
+ _exc_shape_info(memos)
+ shape_str(memos)
)
if config.jaxtyping_remove_typechecker_stack:
raise TypeCheckError(msg) from None
Expand Down Expand Up @@ -780,40 +780,6 @@ def _pformat(x, short_self: bool):
return pformat(x)


def _exc_shape_info(memos) -> str:
"""Gives debug information on the current state of jaxtyping's internal memos.
Used in type-checking error messages.
"""
single_memo, variadic_memo, pytree_memo, _ = memos
single_memo = {
name: size
for name, size in single_memo.items()
if not name.startswith("~~delete~~")
}
variadic_memo = {
name: shape
for name, (_, shape) in variadic_memo.items()
if not name.startswith("~~delete~~")
}
pieces = []
if len(single_memo) > 0 or len(variadic_memo) > 0:
pieces.append(
"The current values for each jaxtyping axis annotation are as follows."
)
for name, size in single_memo.items():
pieces.append(f"{name}={size}")
for name, shape in variadic_memo.items():
pieces.append(f"{name}={shape}")
if len(pytree_memo) > 0:
pieces.append(
"The current values for each jaxtyping PyTree structure annotation are as "
"follows."
)
for name, structure in pytree_memo.items():
pieces.append(f"{name}={structure}")
return "\n".join(pieces)


class _jaxtyping_note_str(str):
"""Used with `_no_jaxtyping_note` to flag that a note came from jaxtyping."""

Expand Down
2 changes: 2 additions & 0 deletions jaxtyping/_indirection.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
Annotated as Float64, # noqa: F401
Annotated as Inexact, # noqa: F401
Annotated as Int, # noqa: F401
Annotated as Int4, # noqa: F401
Annotated as Int8, # noqa: F401
Annotated as Int16, # noqa: F401
Annotated as Int32, # noqa: F401
Expand All @@ -42,6 +43,7 @@
Annotated as Real, # noqa: F401
Annotated as Shaped, # noqa: F401
Annotated as UInt, # noqa: F401
Annotated as UInt4, # noqa: F401
Annotated as UInt8, # noqa: F401
Annotated as UInt16, # noqa: F401
Annotated as UInt32, # noqa: F401
Expand Down
Loading

0 comments on commit 01d11fc

Please sign in to comment.