diff --git a/README.md b/README.md index 1da50a3..b3c84d8 100644 --- a/README.md +++ b/README.md @@ -18,9 +18,8 @@ Use the `@poltergeist` decorator on any function: ```python from poltergeist import poltergeist -# Handle an exception type potentially raised -# within the function (Exception by default) -@poltergeist(error=OSError) +# Handle an exception type potentially raised within the function +@poltergeist(OSError) def read_text(path: str) -> str: with open(path) as f: return f.read() @@ -43,17 +42,6 @@ def read_text(path: str) -> Result[str, OSError]: return Err(e) ``` -It's also possible to wrap multiple exception types: - -```python -def read_text(path: str) -> Result[str, OSError | UnicodeDecodeError]: - try: - with open(path) as f: - return Ok(f.read()) - except (OSError, UnicodeDecodeError) as e: - return Err(e) -``` - Then handle the result in a type-safe way: ```python @@ -78,6 +66,26 @@ match result: print("File not found:", e.filename) ``` +It's also possible to wrap multiple exception types with the decorator: + +```python +@poltergeist(OSError, UnicodeDecodeError) +def read_text(path: str) -> str: + with open(path) as f: + return f.read() +``` + +Or manually: + +```python +def read_text(path: str) -> Result[str, OSError | UnicodeDecodeError]: + try: + with open(path) as f: + return Ok(f.read()) + except (OSError, UnicodeDecodeError) as e: + return Err(e) +``` + ## Contributing Set up the project using [Poetry](https://python-poetry.org/): diff --git a/poltergeist/decorator.py b/poltergeist/decorator.py index dab1f7e..909498d 100644 --- a/poltergeist/decorator.py +++ b/poltergeist/decorator.py @@ -1,48 +1,23 @@ import functools -from typing import Any, Callable, ParamSpec, Type, overload +from typing import Callable, ParamSpec, Type from poltergeist.result import E, Err, Ok, Result, T P = ParamSpec("P") -@overload -def poltergeist(func: Callable[P, T], /) -> Callable[P, Result[T, Exception]]: - # Called as @poltergeist - ... - - -@overload -def poltergeist() -> Callable[[Callable[P, T]], Callable[P, Result[T, Exception]]]: - # Called as @poltergeist() - ... - - -@overload def poltergeist( - *, - error: Type[E], + *errors: Type[E], ) -> Callable[[Callable[P, T]], Callable[P, Result[T, E]]]: - # Called as @poltergeist(error=SomeError) - ... - - -def poltergeist(func: Any = None, /, *, error: Any = Exception) -> Any: - """ - Decorator that wraps the result of a function into an Ok object if it - executes without raising an exception. Otherwise, returns an Err object with - the exception raised by the function. - """ - - if func is None: - return functools.partial(poltergeist, error=error) - - @functools.wraps(func) - def wrapper(*args: Any, **kwargs: Any) -> Any: - try: - result = func(*args, **kwargs) - except error as e: - return Err(e) - return Ok(result) - - return wrapper + def decorator(func: Callable[P, T]) -> Callable[P, Result[T, E]]: + @functools.wraps(func) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> Result[T, E]: + try: + result = func(*args, **kwargs) + except errors as e: + return Err(e) + return Ok(result) + + return wrapper + + return decorator diff --git a/tests/mypy/test_decorator.yml b/tests/mypy/test_decorator.yml index be74660..44afc38 100644 --- a/tests/mypy/test_decorator.yml +++ b/tests/mypy/test_decorator.yml @@ -1,39 +1,27 @@ -- case: decorator_no_args +- case: decorator_single_error main: | from poltergeist import poltergeist, Result - @poltergeist + @poltergeist(ValueError) def test(a: int, b: str) -> float | None: ... - reveal_type(test) # N: Revealed type is "def (a: builtins.int, b: builtins.str) -> Union[poltergeist.result.Ok[Union[builtins.float, None]], poltergeist.result.Err[builtins.Exception]]" - -- case: decorator_default - main: | - from poltergeist import poltergeist, Result - - @poltergeist() - def test(a: int, b: str) -> float | None: ... - - reveal_type(test) # N: Revealed type is "def (a: builtins.int, b: builtins.str) -> Union[poltergeist.result.Ok[Union[builtins.float, None]], poltergeist.result.Err[builtins.Exception]]" + reveal_type(test) # N: Revealed type is "def (a: builtins.int, b: builtins.str) -> Union[poltergeist.result.Ok[Union[builtins.float, None]], poltergeist.result.Err[builtins.ValueError]]" -- case: decorator_with_args +- case: decorator_multiple_errors + skip: True # TODO: Enable this test once MyPy properly detects the return type main: | from poltergeist import poltergeist, Result - @poltergeist(error=ValueError) + @poltergeist(ValueError, TypeError) def test(a: int, b: str) -> float | None: ... - reveal_type(test) # N: Revealed type is "def (a: builtins.int, b: builtins.str) -> Union[poltergeist.result.Ok[Union[builtins.float, None]], poltergeist.result.Err[builtins.ValueError]]" + reveal_type(test) # N: Revealed type is "def (a: builtins.int, b: builtins.str) -> Union[poltergeist.result.Ok[Union[builtins.float, None]], poltergeist.result.Err[Union[builtins.ValueError, builtins.TypeError]]]" - case: decorator_invalid_error_type main: | from poltergeist import poltergeist, Result - @poltergeist(error=123) + @poltergeist(123) def test(a: int, b: str) -> float | None: ... out: | - main:3: error: No overload variant of "poltergeist" matches argument type "int" [call-overload] - main:3: note: Possible overload variants: - main:3: note: def [P`-1, T] poltergeist(Callable[P, T], /) -> Callable[P, Union[Ok[T], Err[Exception]]] - main:3: note: def poltergeist() -> Callable[[Callable[P, T]], Callable[P, Union[Ok[T], Err[Exception]]]] - main:3: note: def [E <: BaseException] poltergeist(*, error: Type[E]) -> Callable[[Callable[P, T]], Callable[P, Union[Ok[T], Err[E]]]] + main:3: error: Argument 1 to "poltergeist" has incompatible type "int"; expected "Type[]" [arg-type] diff --git a/tests/test_decorator.py b/tests/test_decorator.py index 827e7b0..ae91e37 100644 --- a/tests/test_decorator.py +++ b/tests/test_decorator.py @@ -6,7 +6,7 @@ def test_decorator() -> None: - decorated = poltergeist(error=ZeroDivisionError)(operator.truediv) + decorated = poltergeist(ZeroDivisionError)(operator.truediv) assert decorated(4, 2) == Ok(2) @@ -20,7 +20,7 @@ def test_decorator() -> None: def test_decorator_other_error() -> None: # Only catching instances of ValueError - decorated = poltergeist(error=ValueError)(operator.truediv) + decorated = poltergeist(ValueError)(operator.truediv) assert decorated(4, 2) == Ok(2) @@ -29,8 +29,8 @@ def test_decorator_other_error() -> None: decorated(4, 0) -def test_decorator_default_error() -> None: - decorated = poltergeist()(operator.truediv) +def test_decorator_multiple_errors() -> None: + decorated = poltergeist(ZeroDivisionError, TypeError)(operator.truediv) assert decorated(4, 2) == Ok(2) @@ -41,15 +41,9 @@ def test_decorator_default_error() -> None: case _: pytest.fail("Should have been Err") - -def test_decorator_default_error_no_args() -> None: - decorated = poltergeist(operator.truediv) - - assert decorated(4, 2) == Ok(2) - - match decorated(4, 0): + match decorated("4", 0): case Err(e): - assert type(e) == ZeroDivisionError - assert e.args == ("division by zero",) + assert type(e) == TypeError + assert e.args == ("unsupported operand type(s) for /: 'str' and 'int'",) case _: pytest.fail("Should have been Err")