Skip to content

Commit

Permalink
Merge pull request #321 from ApostolFet/add_params
Browse files Browse the repository at this point in the history
Add parameters in the right order
  • Loading branch information
Tishka17 authored Dec 12, 2024
2 parents 2ed9851 + 7b600b7 commit 7f4b71d
Show file tree
Hide file tree
Showing 5 changed files with 147 additions and 1 deletion.
47 changes: 46 additions & 1 deletion src/dishka/integrations/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from inspect import (
Parameter,
Signature,
_ParameterKind,
isasyncgenfunction,
isgeneratorfunction,
signature,
Expand Down Expand Up @@ -127,7 +128,7 @@ def wrap_injection(

auto_injected_func: Callable[P, T | Awaitable[T]]
if additional_params:
new_params.extend(additional_params)
new_params = _add_params(new_params, additional_params)
for param in additional_params:
new_annotations[param.name] = param.annotation

Expand Down Expand Up @@ -230,3 +231,47 @@ def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T:
return func(*args, **kwargs, **solved)

return auto_injected_func


def _add_params(
params: Sequence[Parameter],
additional_params: Sequence[Parameter],
):
params_kind_dict: dict[_ParameterKind, list[Parameter]] = {}

for param in params:
params_kind_dict.setdefault(param.kind, []).append(param)

for param in additional_params:
params_kind_dict.setdefault(param.kind, []).append(param)


var_positional = params_kind_dict.get(Parameter.VAR_POSITIONAL, [])
if len(var_positional) > 1:
param_names = (param.name for param in var_positional)
var_positional_names = ", *".join(param_names)
base_msg = "more than one variadic positional parameter: *"
msg = base_msg + var_positional_names
raise ValueError(msg)

var_keyword = params_kind_dict.get(Parameter.VAR_KEYWORD, [])
if len(var_keyword) > 1:
var_keyword_names = ", **".join(param.name for param in var_keyword)
msg = "more than one variadic keyword parameter: " + var_keyword_names
raise ValueError(msg)

positional_only = params_kind_dict.get(Parameter.POSITIONAL_ONLY, [])
positional_or_keyword = params_kind_dict.get(
Parameter.POSITIONAL_OR_KEYWORD,
[],
)
keyword_only = params_kind_dict.get(Parameter.KEYWORD_ONLY, [])

result_params = []
result_params.extend(positional_only)
result_params.extend(positional_or_keyword)
result_params.extend(var_positional)
result_params.extend(keyword_only)
result_params.extend(var_keyword)

return result_params
Empty file.
78 changes: 78 additions & 0 deletions tests/integrations/base/test_add_params.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from inspect import Parameter, Signature, signature

import pytest

from dishka.integrations.base import _add_params


def func(
pos_only,
/,
pos_keyword,
*,
keyword_only,
) -> None: ...


def func_expected(
pos_only,
add_pos_only,
/,
pos_keyword,
add_pos_keyword,
*add_var_pos,
keyword_only,
add_keyword_only,
**add_var_keyword,
) -> None: ...


def func_with_args_kwargs(*args, **kwargs): ...


def test_add_all_params():
additional_params = [
Parameter("add_pos_only", Parameter.POSITIONAL_ONLY),
Parameter("add_pos_keyword", Parameter.POSITIONAL_OR_KEYWORD),
Parameter("add_var_pos", Parameter.VAR_POSITIONAL),
Parameter("add_keyword_only", Parameter.KEYWORD_ONLY),
Parameter("add_var_keyword", Parameter.VAR_KEYWORD),
]
func_signature = signature(func)
func_params = list(func_signature.parameters.values())

result_params = _add_params(func_params, additional_params)
new_signature = Signature(
parameters=result_params,
return_annotation=func_signature.return_annotation,
)

assert new_signature == signature(func_expected)


def test_fail_add_second_args():
additional_params = [
Parameter("add_var_pos", Parameter.VAR_POSITIONAL),
]

func_signature = signature(func_with_args_kwargs)
func_params = list(func_signature.parameters.values())

with pytest.raises(
ValueError, match="more than one variadic positional parameter",
):
_add_params(func_params, additional_params)


def test_fail_add_second_kwargs():
additional_params = [
Parameter("add_var_keyword", Parameter.VAR_KEYWORD),
]

func_signature = signature(func_with_args_kwargs)
func_params = list(func_signature.parameters.values())

with pytest.raises(
ValueError, match="more than one variadic keyword parameter",
):
_add_params(func_params, additional_params)
20 changes: 20 additions & 0 deletions tests/integrations/taskiq/test_taskiq.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,14 @@ async def return_int_task(data: FromDishka[int]) -> int:
return data


@inject
async def task_with_kwargs(
_: FromDishka[int],
**kwargs: str,
) -> dict[str, str]:
return kwargs


@asynccontextmanager
async def create_broker() -> AsyncIterator[AsyncBroker]:
in_memory_broker = InMemoryBroker().with_result_backend(
Expand All @@ -41,3 +49,15 @@ async def test_return_int_task() -> None:
kiq = await task.kiq()
result = await kiq.wait_result()
assert result.return_value == hash("dishka")


@pytest.mark.asyncio
async def test_task_with_kwargs() -> None:
async with create_broker() as broker:
task = broker.task(task_with_kwargs)
kwargs = {"key": "value"}

kiq = await task.kiq(**kwargs)
result = await kiq.wait_result()

assert result.return_value == kwargs
3 changes: 3 additions & 0 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ requires =
env_list =
unit,
real_world_example,
integrations-base,
fastapi-{0096,0109},
aiohttp-393,
flask-302,
Expand All @@ -25,6 +26,7 @@ use_develop = true
deps =
pytest
pytest-cov
integrations-base: -r requirements/test.txt
aiohttp-393: -r requirements/aiohttp-393.txt
aiohttp-latest: -r requirements/aiohttp-latest.txt
fastapi-latest: -r requirements/fastapi-latest.txt
Expand Down Expand Up @@ -57,6 +59,7 @@ deps =
click-latest: -r requirements/click-latest.txt

commands =
integrations-base: pytest --cov=dishka --cov-append --cov-report=term-missing -v tests/integrations/base
aiohttp: pytest --cov=dishka --cov-append --cov-report=term-missing -v tests/integrations/aiohttp
fastapi: pytest --cov=dishka --cov-append --cov-report=term-missing -v tests/integrations/fastapi
aiogram: pytest --cov=dishka --cov-append --cov-report=term-missing -v tests/integrations/aiogram
Expand Down

0 comments on commit 7f4b71d

Please sign in to comment.