-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1 from AllenInstitute/variable_args
Add support for a variable number of callback arguments
- Loading branch information
Showing
7 changed files
with
326 additions
and
55 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
import logging | ||
import inspect | ||
from copy import copy | ||
|
||
|
||
def setup_logging(logger_name: str, log_level: int = logging.INFO): | ||
logger = logging.getLogger(logger_name) | ||
handler = logging.StreamHandler() | ||
formatter = logging.Formatter( | ||
"%(asctime)s - %(name)s - %(levelname)s - %(message)s" | ||
) | ||
handler.setFormatter(formatter) | ||
logger.addHandler(handler) | ||
logger.setLevel(log_level) | ||
return logger | ||
|
||
|
||
def call_with_correct_args(func, *args, **kwargs): | ||
args = copy(args) | ||
kwargs = copy(kwargs) | ||
params = inspect.signature(func).parameters | ||
|
||
if True not in [ | ||
param.kind == inspect._ParameterKind.VAR_POSITIONAL for param in params.values() | ||
]: | ||
num_args = len( | ||
[ | ||
None | ||
for param in params.values() | ||
if param.default == param.empty and param.kind != param.VAR_KEYWORD | ||
] | ||
) | ||
if num_args > len(args): | ||
raise TypeError( | ||
f"Function '{func}' requires {num_args} arguments, but only {len(args)} are available." | ||
) | ||
args = args[:num_args] | ||
|
||
if True not in [ | ||
param.kind == inspect._ParameterKind.VAR_KEYWORD for param in params.values() | ||
]: | ||
allowed_keys = [key for key, val in params.items() if val.default != val.empty] | ||
for key in list(kwargs.keys()): | ||
if key not in allowed_keys: | ||
del kwargs[key] | ||
|
||
return func(*args, **kwargs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
from pigeon.utils import call_with_correct_args | ||
import pytest | ||
|
||
|
||
def test_not_enough_args(): | ||
def test_func(a, b, c, d): | ||
return a, b, c, d | ||
|
||
with pytest.raises(TypeError): | ||
call_with_correct_args(test_func, 1, 2, 3) | ||
|
||
|
||
def test_equal_args(): | ||
def test_func(a, b, c, d): | ||
return a, b, c, d | ||
|
||
assert call_with_correct_args(test_func, 1, 2, 3, 4) == (1, 2, 3, 4) | ||
|
||
|
||
def test_args(): | ||
def test_func(a, b, c, d): | ||
return a, b, c, d | ||
|
||
assert call_with_correct_args(test_func, 1, 2, 3, 4, 5) == (1, 2, 3, 4) | ||
|
||
|
||
def test_not_enough_kwargs(): | ||
def test_func(a=1, b=2, c=3): | ||
return a, b, c | ||
|
||
assert call_with_correct_args(test_func, a=10, b=11) == (10, 11, 3) | ||
|
||
|
||
def test_no_args(): | ||
def test_func(): | ||
return True | ||
|
||
assert call_with_correct_args(test_func, 1, 2, 3) | ||
|
||
|
||
def test_both(): | ||
def test_func(a, b, c, d=1, e=2): | ||
return a, b, c, d, e | ||
|
||
assert call_with_correct_args(test_func, 1, 2, 3, 4, 5, d=10, e=11, f=12) == ( | ||
1, | ||
2, | ||
3, | ||
10, | ||
11, | ||
) | ||
|
||
|
||
def test_var_args(): | ||
def test_func(a, b, *args): | ||
return a, b, args | ||
|
||
assert call_with_correct_args(test_func, 1, 2, 3, 4) == (1, 2, (3, 4)) | ||
|
||
|
||
def test_var_kwargs(): | ||
def test_func(a=1, b=2, **kwargs): | ||
return a, b, kwargs | ||
|
||
assert call_with_correct_args(test_func, 1, 2, 3, a=10, c=11, d=12) == ( | ||
10, | ||
2, | ||
{"c": 11, "d": 12}, | ||
) | ||
|
||
|
||
def test_both_var(): | ||
def test_func(a, b, *args, c=1, d=2, **kwargs): | ||
return a, b, c, d, args, kwargs | ||
|
||
assert call_with_correct_args(test_func, 1, 2, 3, 4, e=1, c=12, f=13) == ( | ||
1, | ||
2, | ||
12, | ||
2, | ||
(3, 4), | ||
{"e": 1, "f": 13}, | ||
) |
Oops, something went wrong.