From a647e3028c0e2efdf23b875b788f0f185774c408 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Wed, 12 Jun 2024 14:10:42 +0200 Subject: [PATCH] Now reporting the correct source code line numbers when using the import hook --- jaxtyping/_import_hook.py | 35 +++++++++++++++++++---------------- 1 file changed, 19 insertions(+), 16 deletions(-) diff --git a/jaxtyping/_import_hook.py b/jaxtyping/_import_hook.py index 43d99be..bc37e0e 100644 --- a/jaxtyping/_import_hook.py +++ b/jaxtyping/_import_hook.py @@ -80,8 +80,10 @@ def _optimized_cache_from_source(typechecker_hash, /, path, debug_override=None) # for importlib and decorator lookup. # Version 8: Now using new-style `jaxtyped(typechecker=...)` rather than old-style # double-decorators. + # Version 9: Now reporting the correct source code lines. (Important when used with + # a debugger.) return cache_from_source( - path, debug_override, optimization=f"jaxtyping8{typechecker_hash}" + path, debug_override, optimization=f"jaxtyping9{typechecker_hash}" ) @@ -89,8 +91,6 @@ class Typechecker: lookup = {} def __init__(self, typechecker): - self.ast = None - if isinstance(typechecker, str): # If the typechecker is a string, then we parse it string_to_eval = ( @@ -121,18 +121,17 @@ def get_hash(self): return self.hash def get_ast(self): - # we compile AST only if we missed importlib cache - if self.ast is None: - self.ast = ( - ast.parse( - f"@jaxtyping.jaxtyped(typechecker=jaxtyping._import_hook.Typechecker.lookup['{self.hash}'])\n" - "def _():\n ..." - ) - .body[0] - .decorator_list[0] + # Note that we compile AST only if we missed importlib cache. + # No caching on this function! We modify the return type every time, with + # its appropriate source code location. + return ( + ast.parse( + f"@jaxtyping.jaxtyped(typechecker=jaxtyping._import_hook.Typechecker.lookup['{self.hash}'])\n" + "def _():\n ..." ) - - return self.ast + .body[0] + .decorator_list[0] + ) class JaxtypingTransformer(ast.NodeVisitor): @@ -159,7 +158,9 @@ def visit_Module(self, node: ast.Module): def visit_ClassDef(self, node: ast.ClassDef): # Place at the start of the decorator list, so that `@dataclass` decorators get # called first. - node.decorator_list.insert(0, self._typechecker.get_ast()) + decorator = self._typechecker.get_ast() + ast.copy_location(decorator, node) + node.decorator_list.insert(0, decorator) self._parents.append(node) self.generic_visit(node) self._parents.pop() @@ -173,6 +174,8 @@ def visit_FunctionDef(self, node: ast.FunctionDef): # had type annotations in the body of the function (or # `assert isinstance(..., SomeType)`). + decorator = self._typechecker.get_ast() + ast.copy_location(decorator, node) # Place at the end of the decorator list, because: # - as otherwise we wrap e.g. `jax.custom_{jvp,vjp}` and lose the ability # to `defjvp` etc. @@ -187,7 +190,7 @@ def visit_FunctionDef(self, node: ast.FunctionDef): # case we're just going to have to need to ask the user to remove their # typechecking annotation (and let this decorator do it instead). # It's more important we be compatible with normal JAX code. - node.decorator_list.append(self._typechecker.get_ast()) + node.decorator_list.append(decorator) self._parents.append(node) self.generic_visit(node)