Skip to content

Commit

Permalink
Merge pull request #277 from IvanKirpichnikov/override-provides
Browse files Browse the repository at this point in the history
add override for alias and from context
  • Loading branch information
Tishka17 authored Oct 19, 2024
2 parents 169a4f3 + b425f6a commit 451660b
Show file tree
Hide file tree
Showing 14 changed files with 313 additions and 47 deletions.
17 changes: 16 additions & 1 deletion docs/provider/alias.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,19 @@ Provider object has also a ``.alias`` method with the same logic.
a_proto = alias(source=A, provides=AProtocol)
Additionally, alias has own setting for caching: it caches by default regardless if source is cached. You can disable it providing ``cache=False`` argument.
Additionally, alias has own setting for caching: it caches by default regardless if source is cached. You can disable it providing ``cache=False`` argument.

* Do you want to override the alias? To do this, specify the parameter ``override=True``. This can be checked when passing proper ``validation_settings`` when creating container.

.. code-block:: python
from dishka import WithParents, provide, Provider, Scope, alias
class MyProvider(Provider):
scope=Scope.APP
a = provide(lambda: 1, provides=int)
a_alias = alias(float, provides=int)
a_alias_override = alias(float, provides=int, override=True)
container = make_async_container(MyProvider())
a = await container.get(int)
# 2
15 changes: 15 additions & 0 deletions docs/provider/from_context.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,18 @@ You can put some data manually when entering scope and rely on it in your provid
container = make_container(MyProvider(), context={App: app})
with container(context={RequestClass: request_instance}) as request_container:
pass
* Do you want to override the from_context? To do this, specify the parameter ``override=True``. This can be checked when passing proper ``validation_settings`` when creating container.

.. code-block:: python
from dishka import WithParents, from_context, Provider, Scope
class MyProvider(Provider):
scope=Scope.APP
a = from_context(provides=int)
a_override = from_context(provides=int, override=True)
container = make_async_container(MyProvider())
a = await container.get(int)
# 2
6 changes: 4 additions & 2 deletions src/dishka/dependency_source/alias.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,19 @@ def _identity(x: Any) -> Any:


class Alias:
__slots__ = ("source", "provides", "cache", "component")
__slots__ = ("source", "provides", "cache", "component", "override")

def __init__(
self, *,
source: DependencyKey,
provides: DependencyKey,
cache: bool,
override: bool,
) -> None:
self.source = source
self.provides = provides
self.cache = cache
self.override = override

def as_factory(
self, scope: BaseScope | None, component: Component | None,
Expand All @@ -38,7 +40,7 @@ def as_factory(
kw_dependencies={},
type_=FactoryType.ALIAS,
cache=self.cache,
override=False,
override=self.override,
)

def __get__(self, instance: Any, owner: Any) -> Alias:
Expand Down
18 changes: 12 additions & 6 deletions src/dishka/dependency_source/context_var.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Any
from typing import Any, NoReturn

from dishka.entities.component import DEFAULT_COMPONENT, Component
from dishka.entities.factory_type import FactoryType
Expand All @@ -10,20 +10,22 @@
from .factory import Factory


def context_stub() -> Any:
def context_stub() -> NoReturn:
raise NotImplementedError


class ContextVariable:
__slots__ = ("provides", "scope")
__slots__ = ("provides", "scope", "override")

def __init__(
self, *,
provides: DependencyKey,
scope: BaseScope | None,
override: bool,
) -> None:
self.provides = provides
self.scope = scope
self.override = override

def as_factory(
self, component: Component,
Expand All @@ -38,14 +40,17 @@ def as_factory(
kw_dependencies={},
type_=FactoryType.CONTEXT,
cache=False,
override=False,
override=self.override,
)
else:
aliased = Alias(
source=self.provides.with_component(DEFAULT_COMPONENT),
provides=DependencyKey(self.provides.type_hint,
component=component),
cache=False,
override=self.override,
provides=DependencyKey(
component=component,
type_hint=self.provides.type_hint,
),
)
return aliased.as_factory(scope=self.scope, component=component)

Expand All @@ -54,4 +59,5 @@ def __get__(self, instance: Any, owner: Any) -> ContextVariable:
return ContextVariable(
scope=scope,
provides=self.provides,
override=self.override,
)
2 changes: 2 additions & 0 deletions src/dishka/dependency_source/make_alias.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ def alias(
provides: Any | None = None,
cache: bool = True,
component: Component | None = None,
override: bool = False,
) -> CompositeDependencySource:
if component is provides is None:
raise ValueError("Either component or provides must be set in alias")
Expand All @@ -27,6 +28,7 @@ def alias(
),
provides=DependencyKey(provides, None),
cache=cache,
override=override,
)
composite.dependency_sources.extend(unpack_alias(alias_instance))
return composite
8 changes: 6 additions & 2 deletions src/dishka/dependency_source/make_context_var.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,20 @@


def from_context(
provides: Any, *, scope: BaseScope | None = None,
provides: Any,
*,
scope: BaseScope | None = None,
override: bool = False,
) -> CompositeDependencySource:
composite = CompositeDependencySource(origin=context_stub)
composite.dependency_sources.append(
ContextVariable(
scope=scope,
override=override,
provides=DependencyKey(
type_hint=provides,
component=DEFAULT_COMPONENT,
),
scope=scope,
),
)
return composite
2 changes: 2 additions & 0 deletions src/dishka/dependency_source/unpack_provides.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def unpack_factory(factory: Factory) -> Sequence[DependencySource]:
provides=DependencyKey(provides_other, factory.provides.component),
source=DependencyKey(provides_first, factory.provides.component),
cache=factory.cache,
override=factory.override,
)
for provides_other in provides_others
]
Expand Down Expand Up @@ -63,6 +64,7 @@ def unpack_alias(alias: Alias) -> Sequence[DependencySource]:
provides=DependencyKey(provides, alias.provides.component),
source=alias.source,
cache=alias.cache,
override=alias.override,
)
for provides in alias.provides.type_hint.items
]
9 changes: 8 additions & 1 deletion src/dishka/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,10 @@ def provide_all(
if scope is None:
scope = self.scope
composite = provide_all_on_instance(
*provides, scope=scope, cache=cache, recursive=recursive,
*provides,
scope=scope,
cache=cache,
recursive=recursive,
override=override,
)
self._add_dependency_sources("?", composite.dependency_sources)
Expand All @@ -176,12 +179,14 @@ def alias(
provides: Any = None,
cache: bool = True,
component: Component | None = None,
override: bool = False,
) -> CompositeDependencySource:
composite = alias(
source=source,
provides=provides,
cache=cache,
component=component,
override=override,
)
self._add_dependency_sources(str(source), composite.dependency_sources)
return composite
Expand All @@ -207,10 +212,12 @@ def from_context(
provides: Any,
*,
scope: BaseScope | None = None,
override: bool = False,
) -> CompositeDependencySource:
composite = from_context(
provides=provides,
scope=scope or self.scope,
override=override,
)
self._add_dependency_sources(
name=str(provides),
Expand Down
43 changes: 42 additions & 1 deletion src/dishka/registry_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ def _init_registries(self) -> None:
context_var = ContextVariable(
provides=DependencyKey(self.container_type, DEFAULT_COMPONENT),
scope=scope,
override=False,
)
for component in self.components:
registry.add_factory(context_var.as_factory(component))
Expand Down Expand Up @@ -243,7 +244,27 @@ def _process_alias(
registry = self.registries[scope]

factory = alias.as_factory(scope, component)
if (
self.validation_settings.nothing_overridden
and not self.skip_validation
and factory.override
and factory.provides not in self.processed_factories
):
raise NothingOverriddenError(factory)

if (
self.validation_settings.implicit_override
and not self.skip_validation
and not factory.override
and factory.provides in self.processed_factories
):
raise ImplicitOverrideDetectedError(
factory,
self.processed_factories[factory.provides],
)

self.dependency_scopes[factory.provides] = scope
self.processed_factories[factory.provides] = factory
registry.add_factory(factory)

def _process_generic_decorator(
Expand Down Expand Up @@ -375,7 +396,27 @@ def _process_context_var(
)
registry = self.registries[context_var.scope]
for component in self.components:
registry.add_factory(context_var.as_factory(component))
factory = context_var.as_factory(component)
if (
self.validation_settings.nothing_overridden
and not self.skip_validation
and factory.override
and factory.provides not in self.processed_factories
):
raise NothingOverriddenError(factory)

if (
self.validation_settings.implicit_override
and not self.skip_validation
and not factory.override
and factory.provides in self.processed_factories
):
raise ImplicitOverrideDetectedError(
factory,
self.processed_factories[factory.provides],
)
self.processed_factories[context_var.provides] = factory
registry.add_factory(factory)

def build(self) -> tuple[Registry, ...]:
self._collect_components()
Expand Down
Empty file.
88 changes: 88 additions & 0 deletions tests/unit/container/override/test_alias.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import pytest

from dishka import Provider, Scope, alias, make_container, provide
from dishka.entities.validation_settigs import (
STRICT_VALIDATION,
ValidationSettings,
)
from dishka.exceptions import (
ImplicitOverrideDetectedError,
NothingOverriddenError,
)


def test_no_override() -> None:
class TestProvider(Provider):
scope = Scope.APP
provides = (
provide(int, provides=int)
+ alias(source=int, provides=float)
+ alias(source=int, provides=float)
)

with pytest.raises(ImplicitOverrideDetectedError) as e:
make_container(
TestProvider(),
validation_settings=STRICT_VALIDATION,
)
assert str(e.value)


def test_skip_no_override() -> None:
class TestProvider(Provider):
scope = Scope.APP
provides = (
provide(int, provides=int)
+ alias(source=int, provides=float)
+ alias(source=int, provides=float)
)

make_container(
TestProvider(),
validation_settings=ValidationSettings(implicit_override=False),
)


def test_override_ok() -> None:
class TestProvider(Provider):
scope = Scope.APP
provides = (
provide(int, provides=int)
+ alias(source=int, provides=float)
+ alias(source=int, provides=float, override=True)
)

make_container(
TestProvider(),
validation_settings=STRICT_VALIDATION,
)


def test_cant_override() -> None:
class TestProvider(Provider):
scope = Scope.APP
provides = (
provide(int, provides=int)
+ alias(source=int, provides=float, override=True)
)

with pytest.raises(NothingOverriddenError) as e:
make_container(
TestProvider(),
validation_settings=STRICT_VALIDATION,
)
assert str(e.value)


def test_skip_cant_override() -> None:
class TestProvider(Provider):
scope = Scope.APP
provides = (
provide(int, provides=int)
+ alias(source=int, provides=float, override=True)
)

make_container(
TestProvider(),
validation_settings=ValidationSettings(nothing_overridden=False),
)
Loading

0 comments on commit 451660b

Please sign in to comment.