Skip to content

Commit

Permalink
ruff
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Nov 17, 2024
1 parent 72becb4 commit 8823419
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 16 deletions.
6 changes: 2 additions & 4 deletions jaxtyping/_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,15 +70,13 @@ def _apply_typechecker(typechecker, fn):
def jaxtyped(
*,
typechecker=_sentinel,
) -> Callable[[Callable[_Params, _Return]], Callable[_Params, _Return]]:
...
) -> Callable[[Callable[_Params, _Return]], Callable[_Params, _Return]]: ...


@overload
def jaxtyped(
fn: Callable[_Params, _Return], *, typechecker=_sentinel
) -> Callable[_Params, _Return]:
...
) -> Callable[_Params, _Return]: ...


def jaxtyped(fn=_sentinel, *, typechecker=_sentinel):
Expand Down
6 changes: 2 additions & 4 deletions test/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,8 +694,7 @@ def test_extension(jaxtyp, typecheck, getkey):
Y = Float[X, "b"]

@jaxtyp(typecheck)
def f(a: X, b: Y):
...
def f(a: X, b: Y): ...

a = jr.normal(getkey(), (3, 4))
b = jr.normal(getkey(), (4,))
Expand All @@ -708,8 +707,7 @@ def f(a: X, b: Y):
f(a, a)

@typecheck
def g(a: Shaped[PRNGKeyArray, "2"]):
...
def g(a: Shaped[PRNGKeyArray, "2"]): ...

with pytest.raises(ParamError):
g(jr.PRNGKey(0))
Expand Down
9 changes: 3 additions & 6 deletions test/test_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@

class M(metaclass=abc.ABCMeta):
@jaxtyped(typechecker=None)
def f(self):
...
def f(self): ...

@jaxtyped(typechecker=None)
@classmethod
Expand All @@ -39,13 +38,11 @@ def h2():

@jaxtyped(typechecker=None)
@abc.abstractmethod
def i1(self):
...
def i1(self): ...

@abc.abstractmethod
@jaxtyped(typechecker=None)
def i2(self):
...
def i2(self): ...


class N:
Expand Down
3 changes: 1 addition & 2 deletions test/test_pytree.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,7 @@ class OtherCustomNamedTuple(NamedTuple):
y: Float[jnp.ndarray, "b c"]

@typecheck
def g(x: PyTree[CustomNamedTuple]):
...
def g(x: PyTree[CustomNamedTuple]): ...

g(
CustomNamedTuple(
Expand Down

0 comments on commit 8823419

Please sign in to comment.