Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Half compatibility with typeguard v4. #273

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/api/runtime-type-checking.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ Runtime type checking **synergises beautifully with `jax.jit`!** All shape check

There are two approaches: either use [`jaxtyping.jaxtyped`][] to typecheck a single function, or [`jaxtyping.install_import_hook`][] to typecheck a whole codebase.

In either case, the actual business of checking types is performed with the help of a runtime type-checking library. The two most popular are [beartype](https://github.com/beartype/beartype) and [typeguard](https://github.com/agronholm/typeguard). (If using typeguard, then specifically the version `2.*` series should be used. Later versions -- `3` and `4` -- have some known issues.)
In either case, the actual business of checking types is performed with the help of a runtime type-checking library. The two most popular are [beartype](https://github.com/beartype/beartype) and [typeguard](https://github.com/agronholm/typeguard).

!!! warning

Expand Down
27 changes: 27 additions & 0 deletions jaxtyping/_array_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import enum
import functools as ft
import importlib.metadata
import importlib.util
import re
import sys
Expand Down Expand Up @@ -738,6 +739,28 @@ def __init_subclass__(cls, **kwargs):
_complex128 = "complex128"


# Workaround a longstanding bug in typeguard v4, by monkeypatching their internals.
# https://stackoverflow.com/questions/79201839/hello-world-for-jaxtyping/79205145#79205145
# https://github.com/patrick-kidger/jaxtyping/issues/80
# https://github.com/agronholm/typeguard/issues/353
# This is as robust as I can make it to future changes in typeguard, I think.
typeguard_v4_compat = False
try:
typeguard_distribution = importlib.metadata.distribution("typeguard")
except importlib.metadata.PackageNotFoundError:
pass
else:
if typeguard_distribution.version.split(".", 1)[0] == "4":
if importlib.util.find_spec("typeguard._transformer") is not None:
import typeguard._transformer

if hasattr(typeguard._transformer, "annotated_names"):
annotated_names = typeguard._transformer.annotated_names
if type(annotated_names) is tuple:
if all(type(x) is str for x in annotated_names):
typeguard_v4_compat = True


def _make_dtype(_dtypes, name):
class _Cls(AbstractDtype):
dtypes = _dtypes
Expand All @@ -748,6 +771,10 @@ class _Cls(AbstractDtype):
_Cls.__module__ = "builtins"
else:
_Cls.__module__ = "jaxtyping"
if typeguard_v4_compat:
typeguard._transformer.annotated_names = (
typeguard._transformer.annotated_names + (f"jaxtyping.{name}",)
)
return _Cls


Expand Down
24 changes: 19 additions & 5 deletions jaxtyping/_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,11 +430,13 @@ def wrapped_fn(*args, **kwargs): # pyright: ignore
module = getattr(fn, "__module__", "<generated_by_jaxtyping>")

# Use the same name so that typeguard warnings look correct.
# Set the line number so that typeguard v4 finds us.
lineno = getattr(getattr(fn, "__code__", None), "co_firstlineno", 1)
full_fn, output_name = _make_fn_with_signature(
name, qualname, module, full_signature, output=True
name, qualname, module, full_signature, output=True, lineno=lineno
)
param_fn = _make_fn_with_signature(
name, qualname, module, param_signature, output=False
name, qualname, module, param_signature, output=False, lineno=lineno
)
full_fn = _apply_typechecker(typechecker, full_fn)
param_fn = _apply_typechecker(typechecker, param_fn)
Expand Down Expand Up @@ -616,13 +618,19 @@ def _check_dataclass_annotations(self, typechecker):
self.__class__.__module__,
signature,
output=False,
lineno=1,
)
f = jaxtyped(f, typechecker=typechecker)
f(self, **values)


def _make_fn_with_signature(
name: str, qualname: str, module: str, signature: inspect.Signature, output: bool
name: str,
qualname: str,
module: str,
signature: inspect.Signature,
output: bool,
lineno: int,
):
"""Dynamically creates a function `fn` with name `name` and signature `signature`.

Expand Down Expand Up @@ -740,7 +748,8 @@ def _make_fn_with_signature(
else:
retstr = f"-> {name_to_annotation['return']}"

fnstr = f"def {name}({argstr}){retstr}:\n {outstr}"
newlines = "\n" * (lineno - 1)
fnstr = f"{newlines}def {name}({argstr}){retstr}:\n {outstr}"
exec(fnstr, scope)
fn = scope[name]
del scope[name] # Avoids introducing a reference cycle.
Expand Down Expand Up @@ -802,7 +811,12 @@ def _get_problem_arg(
assert keep_annotation is not sentinel
new_signature = inspect.Signature(new_parameters)
fn = _make_fn_with_signature(
"check_single_arg", "check_single_arg", module, new_signature, output=False
"check_single_arg",
"check_single_arg",
module,
new_signature,
output=False,
lineno=1,
)
fn = _apply_typechecker(
typechecker, fn
Expand Down
2 changes: 1 addition & 1 deletion test/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ numpy<2
pytest
pytest-asyncio
tensorflow
typeguard<3
typeguard
20 changes: 0 additions & 20 deletions test/test_import_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,26 +32,6 @@
_here = pathlib.Path(__file__).parent


try:
typeguard_version = importlib.metadata.version("typeguard")
except Exception as e:
raise ImportError("Could not find typeguard version") from e
else:
try:
major, _, _ = typeguard_version.split(".")
major = int(major)
except Exception as e:
raise ImportError(
f"Unexpected typeguard version {typeguard_version}; not formatted as "
"`major.minor.patch`"
) from e
if major != 2:
raise ImportError(
"jaxtyping's tests required typeguard version 2. (Versions 3 and 4 are both "
"known to have bugs.)"
)


assert not hasattr(jaxtyping, "_test_import_hook_counter")
jaxtyping._test_import_hook_counter = 0

Expand Down
Loading