From b12ee6ca115b39962ca4b5baac6e1e1a8af4167e Mon Sep 17 00:00:00 2001 From: IvanKirpichnikov Date: Fri, 18 Oct 2024 18:54:52 +0300 Subject: [PATCH 1/5] add override for alias and from context --- docs/provider/alias.rst | 17 +++- docs/provider/from_context.rst | 15 ++++ src/dishka/dependency_source/alias.py | 6 +- src/dishka/dependency_source/context_var.py | 18 ++-- src/dishka/dependency_source/make_alias.py | 2 + .../dependency_source/make_context_var.py | 8 +- src/dishka/dependency_source/make_factory.py | 6 +- .../dependency_source/unpack_provides.py | 2 + src/dishka/registry.py | 2 +- src/dishka/registry_builder.py | 45 +++++++++- tests/unit/container/override/__init__.py | 0 tests/unit/container/override/test_alias.py | 88 +++++++++++++++++++ .../container/override/test_context_var.py | 80 +++++++++++++++++ .../test_provide.py} | 38 +------- .../container/override/test_provide_all.py | 34 +++++++ 15 files changed, 311 insertions(+), 50 deletions(-) create mode 100644 tests/unit/container/override/__init__.py create mode 100644 tests/unit/container/override/test_alias.py create mode 100644 tests/unit/container/override/test_context_var.py rename tests/unit/container/{test_override.py => override/test_provide.py} (66%) create mode 100644 tests/unit/container/override/test_provide_all.py diff --git a/docs/provider/alias.rst b/docs/provider/alias.rst index 8f1777b1..945de81d 100644 --- a/docs/provider/alias.rst +++ b/docs/provider/alias.rst @@ -18,4 +18,19 @@ Provider object has also a ``.alias`` method with the same logic. a_proto = alias(source=A, provides=AProtocol) -Additionally, alias has own setting for caching: it caches by default regardless if source is cached. You can disable it providing ``cache=False`` argument. \ No newline at end of file +Additionally, alias has own setting for caching: it caches by default regardless if source is cached. You can disable it providing ``cache=False`` argument. + +* Do you want to override the alias? To do this, specify the parameter ``override=True``. This can be checked when passing proper ``validation_settings`` when creating container. + +.. code-block:: python + + from dishka import WithParents, provide, Provider, Scope, alias + class MyProvider(Provider): + scope=Scope.APP + a = provide(lambda: 1, provides=int) + a_alias = alias(float, provides=int) + a_alias_override = alias(float, provides=int, override=True) + + container = make_async_container(MyProvider()) + a = await container.get(int) + # 2 diff --git a/docs/provider/from_context.rst b/docs/provider/from_context.rst index ccc6c36c..f129d050 100644 --- a/docs/provider/from_context.rst +++ b/docs/provider/from_context.rst @@ -23,3 +23,18 @@ You can put some data manually when entering scope and rely on it in your provid container = make_container(MyProvider(), context={App: app}) with container(context={RequestClass: request_instance}) as request_container: pass + + +* Do you want to override the from_context? To do this, specify the parameter ``override=True``. This can be checked when passing proper ``validation_settings`` when creating container. + +.. code-block:: python + + from dishka import WithParents, from_context, Provider, Scope + class MyProvider(Provider): + scope=Scope.APP + a = from_context(provides=int) + a_override = from_context(provides=int, override=True) + + container = make_async_container(MyProvider()) + a = await container.get(int) + # 2 diff --git a/src/dishka/dependency_source/alias.py b/src/dishka/dependency_source/alias.py index db0e7d5b..daa11388 100644 --- a/src/dishka/dependency_source/alias.py +++ b/src/dishka/dependency_source/alias.py @@ -14,17 +14,19 @@ def _identity(x: Any) -> Any: class Alias: - __slots__ = ("source", "provides", "cache", "component") + __slots__ = ("source", "provides", "cache", "component", "override") def __init__( self, *, source: DependencyKey, provides: DependencyKey, cache: bool, + override: bool, ) -> None: self.source = source self.provides = provides self.cache = cache + self.override = override def as_factory( self, scope: BaseScope | None, component: Component | None, @@ -38,7 +40,7 @@ def as_factory( kw_dependencies={}, type_=FactoryType.ALIAS, cache=self.cache, - override=False, + override=self.override, ) def __get__(self, instance: Any, owner: Any) -> Alias: diff --git a/src/dishka/dependency_source/context_var.py b/src/dishka/dependency_source/context_var.py index 8700e443..87034155 100644 --- a/src/dishka/dependency_source/context_var.py +++ b/src/dishka/dependency_source/context_var.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any +from typing import Any, NoReturn from dishka.entities.component import DEFAULT_COMPONENT, Component from dishka.entities.factory_type import FactoryType @@ -10,20 +10,22 @@ from .factory import Factory -def context_stub() -> Any: +def context_stub() -> NoReturn: raise NotImplementedError class ContextVariable: - __slots__ = ("provides", "scope") + __slots__ = ("provides", "scope", "override") def __init__( self, *, provides: DependencyKey, scope: BaseScope | None, + override: bool, ) -> None: self.provides = provides self.scope = scope + self.override = override def as_factory( self, component: Component, @@ -38,14 +40,17 @@ def as_factory( kw_dependencies={}, type_=FactoryType.CONTEXT, cache=False, - override=False, + override=self.override, ) else: aliased = Alias( source=self.provides.with_component(DEFAULT_COMPONENT), - provides=DependencyKey(self.provides.type_hint, - component=component), cache=False, + override=self.override, + provides=DependencyKey( + component=component, + type_hint=self.provides.type_hint, + ), ) return aliased.as_factory(scope=self.scope, component=component) @@ -54,4 +59,5 @@ def __get__(self, instance: Any, owner: Any) -> ContextVariable: return ContextVariable( scope=scope, provides=self.provides, + override=self.override, ) diff --git a/src/dishka/dependency_source/make_alias.py b/src/dishka/dependency_source/make_alias.py index 68315b0f..c9a5791f 100644 --- a/src/dishka/dependency_source/make_alias.py +++ b/src/dishka/dependency_source/make_alias.py @@ -13,6 +13,7 @@ def alias( provides: Any | None = None, cache: bool = True, component: Component | None = None, + override: bool = False, ) -> CompositeDependencySource: if component is provides is None: raise ValueError("Either component or provides must be set in alias") @@ -27,6 +28,7 @@ def alias( ), provides=DependencyKey(provides, None), cache=cache, + override=override, ) composite.dependency_sources.extend(unpack_alias(alias_instance)) return composite diff --git a/src/dishka/dependency_source/make_context_var.py b/src/dishka/dependency_source/make_context_var.py index c2b1d995..13da74dc 100644 --- a/src/dishka/dependency_source/make_context_var.py +++ b/src/dishka/dependency_source/make_context_var.py @@ -8,16 +8,20 @@ def from_context( - provides: Any, *, scope: BaseScope | None = None, + provides: Any, + *, + scope: BaseScope | None = None, + override: bool = False, ) -> CompositeDependencySource: composite = CompositeDependencySource(origin=context_stub) composite.dependency_sources.append( ContextVariable( + scope=scope, + override=override, provides=DependencyKey( type_hint=provides, component=DEFAULT_COMPONENT, ), - scope=scope, ), ) return composite diff --git a/src/dishka/dependency_source/make_factory.py b/src/dishka/dependency_source/make_factory.py index e3fa0aa4..2faf4aa0 100644 --- a/src/dishka/dependency_source/make_factory.py +++ b/src/dishka/dependency_source/make_factory.py @@ -227,7 +227,7 @@ def _make_factory_by_class( f"If your are using `if TYPE_CHECKING` to import '{e.name}' " f"then try removing it. \n" f"Or, create a separate factory with all types imported.", - name=e.name, + name=e.name, # type: ignore[call-arg] ) from e hints.pop("return", _empty) @@ -285,7 +285,7 @@ def _make_factory_by_function( f"If your are using `if TYPE_CHECKING` to import '{e.name}' " f"then try removing it. \n" f"Or, create a separate factory with all types imported.", - name=e.name, + name=e.name, # type: ignore[call-arg] ) from e if is_in_class: self = next(iter(params.values()), None) @@ -351,7 +351,7 @@ def _make_factory_by_static_method( f"If your are using `if TYPE_CHECKING` to import '{e.name}' " f"then try removing it. \n" f"Or, create a separate factory with all types imported.", - name=e.name, + name=e.name, # type: ignore[call-arg] ) from e possible_dependency = hints.pop("return", _empty) diff --git a/src/dishka/dependency_source/unpack_provides.py b/src/dishka/dependency_source/unpack_provides.py index c11df772..173f51e3 100644 --- a/src/dishka/dependency_source/unpack_provides.py +++ b/src/dishka/dependency_source/unpack_provides.py @@ -19,6 +19,7 @@ def unpack_factory(factory: Factory) -> Sequence[DependencySource]: provides=DependencyKey(provides_other, factory.provides.component), source=DependencyKey(provides_first, factory.provides.component), cache=factory.cache, + override=factory.override, ) for provides_other in provides_others ] @@ -63,6 +64,7 @@ def unpack_alias(alias: Alias) -> Sequence[DependencySource]: provides=DependencyKey(provides, alias.provides.component), source=alias.source, cache=alias.cache, + override=alias.override, ) for provides in alias.provides.type_hint.items ] diff --git a/src/dishka/registry.py b/src/dishka/registry.py index c7d67b5c..0c2e5e54 100644 --- a/src/dishka/registry.py +++ b/src/dishka/registry.py @@ -90,7 +90,7 @@ def _get_type_var_factory(self, dependency: DependencyKey) -> Factory: scope=self.scope, dependencies=[], kw_dependencies={}, - provides=DependencyKey(type[typevar], dependency.component), + provides=DependencyKey(type[typevar], dependency.component), # type: ignore[misc] type_=FactoryType.FACTORY, is_to_bind=False, cache=False, diff --git a/src/dishka/registry_builder.py b/src/dishka/registry_builder.py index eaa57e0b..e7246850 100644 --- a/src/dishka/registry_builder.py +++ b/src/dishka/registry_builder.py @@ -136,6 +136,8 @@ def __init__( self.skip_validation = skip_validation self.validation_settings = validation_settings self.processed_factories: dict[DependencyKey, Factory] = {} + self.processed_aliases: dict[DependencyKey, Alias] = {} + self.processed_contex_vars: dict[DependencyKey, ContextVariable] = {} def _collect_components(self) -> None: for provider in self.providers: @@ -175,6 +177,7 @@ def _init_registries(self) -> None: context_var = ContextVariable( provides=DependencyKey(self.container_type, DEFAULT_COMPONENT), scope=scope, + override=False, ) for component in self.components: registry.add_factory(context_var.as_factory(component)) @@ -243,7 +246,27 @@ def _process_alias( registry = self.registries[scope] factory = alias.as_factory(scope, component) + if ( + self.validation_settings.nothing_overridden + and not self.skip_validation + and factory.override + and factory.provides not in self.processed_factories + ): + raise NothingOverriddenError(factory) + + if ( + self.validation_settings.implicit_override + and not self.skip_validation + and not factory.override + and factory.provides in self.processed_factories + ): + raise ImplicitOverrideDetectedError( + factory, + self.processed_factories[factory.provides], + ) + self.dependency_scopes[factory.provides] = scope + self.processed_factories[factory.provides] = factory registry.add_factory(factory) def _process_generic_decorator( @@ -375,7 +398,27 @@ def _process_context_var( ) registry = self.registries[context_var.scope] for component in self.components: - registry.add_factory(context_var.as_factory(component)) + factory = context_var.as_factory(component) + if ( + self.validation_settings.nothing_overridden + and not self.skip_validation + and factory.override + and factory.provides not in self.processed_factories + ): + raise NothingOverriddenError(factory) + + if ( + self.validation_settings.implicit_override + and not self.skip_validation + and not factory.override + and factory.provides in self.processed_factories + ): + raise ImplicitOverrideDetectedError( + factory, + self.processed_factories[factory.provides], + ) + self.processed_factories[context_var.provides] = factory + registry.add_factory(factory) def build(self) -> tuple[Registry, ...]: self._collect_components() diff --git a/tests/unit/container/override/__init__.py b/tests/unit/container/override/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/container/override/test_alias.py b/tests/unit/container/override/test_alias.py new file mode 100644 index 00000000..a1deddc5 --- /dev/null +++ b/tests/unit/container/override/test_alias.py @@ -0,0 +1,88 @@ +import pytest + +from dishka import Provider, Scope, alias, make_container, provide +from dishka.entities.validation_settigs import ( + STRICT_VALIDATION, + ValidationSettings, +) +from dishka.exceptions import ( + ImplicitOverrideDetectedError, + NothingOverriddenError, +) + + +def test_no_override() -> None: + class TestProvider(Provider): + scope = Scope.APP + provides = ( + provide(int, provides=int) + + alias(source=int, provides=float) + + alias(source=int, provides=float) + ) + + with pytest.raises(ImplicitOverrideDetectedError) as e: + make_container( + TestProvider(), + validation_settings=STRICT_VALIDATION, + ) + assert str(e.value) + + +def test_skip_no_override() -> None: + class TestProvider(Provider): + scope = Scope.APP + provides = ( + provide(int, provides=int) + + alias(source=int, provides=float) + + alias(source=int, provides=float) + ) + + make_container( + TestProvider(), + validation_settings=ValidationSettings(implicit_override=False), + ) + + +def test_override_ok() -> None: + class TestProvider(Provider): + scope = Scope.APP + provides = ( + provide(int, provides=int) + + alias(source=int, provides=float) + + alias(source=int, provides=float, override=True) + ) + + make_container( + TestProvider(), + validation_settings=STRICT_VALIDATION, + ) + + +def test_cant_override() -> None: + class TestProvider(Provider): + scope = Scope.APP + provides = ( + provide(int, provides=int) + + alias(source=int, provides=float, override=True) + ) + + with pytest.raises(NothingOverriddenError) as e: + make_container( + TestProvider(), + validation_settings=STRICT_VALIDATION, + ) + assert str(e.value) + + +def test_skip_cant_override() -> None: + class TestProvider(Provider): + scope = Scope.APP + provides = ( + provide(int, provides=int) + + alias(source=int, provides=float, override=True) + ) + + make_container( + TestProvider(), + validation_settings=ValidationSettings(nothing_overridden=False), + ) diff --git a/tests/unit/container/override/test_context_var.py b/tests/unit/container/override/test_context_var.py new file mode 100644 index 00000000..efe6fe5c --- /dev/null +++ b/tests/unit/container/override/test_context_var.py @@ -0,0 +1,80 @@ +import pytest + +from dishka import Provider, Scope, from_context, make_container +from dishka.entities.validation_settigs import ( + STRICT_VALIDATION, + ValidationSettings, +) +from dishka.exceptions import ( + ImplicitOverrideDetectedError, + NothingOverriddenError, +) + + +def test_no_override() -> None: + class TestProvider(Provider): + scope = Scope.APP + provides = ( + from_context(int) + + from_context(int) + ) + + with pytest.raises(ImplicitOverrideDetectedError) as e: + make_container( + TestProvider(), + validation_settings=STRICT_VALIDATION, + ) + assert str(e.value) + + +def test_skip_no_override() -> None: + class TestProvider(Provider): + scope = Scope.APP + provides = ( + from_context(int) + + from_context(int) + ) + + make_container( + TestProvider(), + validation_settings=ValidationSettings(implicit_override=False), + ) + + +def test_override_ok() -> None: + class TestProvider(Provider): + scope = Scope.APP + provides = ( + from_context(int) + + from_context(int, override=True) + ) + + make_container( + TestProvider(), + validation_settings=STRICT_VALIDATION, + ) + + +def test_cant_override() -> None: + class TestProvider(Provider): + scope = Scope.APP + provides = from_context(int, override=True) + + with pytest.raises(NothingOverriddenError) as e: + make_container( + TestProvider(), + validation_settings=STRICT_VALIDATION, + ) + assert str(e.value) + + +def test_skip_cant_override() -> None: + class TestProvider(Provider): + scope = Scope.APP + provides = from_context(int, override=True) + + make_container( + TestProvider(), + validation_settings=ValidationSettings(nothing_overridden=False), + ) + diff --git a/tests/unit/container/test_override.py b/tests/unit/container/override/test_provide.py similarity index 66% rename from tests/unit/container/test_override.py rename to tests/unit/container/override/test_provide.py index 57f98c0f..156f6a28 100644 --- a/tests/unit/container/test_override.py +++ b/tests/unit/container/override/test_provide.py @@ -1,6 +1,6 @@ import pytest -from dishka import Provider, Scope, make_container, provide, provide_all +from dishka import Provider, Scope, make_container, provide from dishka.entities.validation_settigs import ( STRICT_VALIDATION, ValidationSettings, @@ -11,7 +11,7 @@ ) -def test_no_override_provide() -> None: +def test_no_override() -> None: class TestProvider(Provider): scope = Scope.APP provides = ( @@ -27,7 +27,7 @@ class TestProvider(Provider): assert str(e.value) -def test_skip_no_override_provide() -> None: +def test_skip_no_override() -> None: class TestProvider(Provider): scope = Scope.APP provides = ( @@ -41,7 +41,7 @@ class TestProvider(Provider): ) -def test_override_provide_ok() -> None: +def test_override_ok() -> None: class TestProvider(Provider): scope = Scope.APP provides = ( @@ -55,35 +55,6 @@ class TestProvider(Provider): ) -def test_not_override_provide_all() -> None: - class TestProvider(Provider): - scope = Scope.APP - provides = ( - provide_all(int, str) - + provide_all(int, str) - ) - - with pytest.raises(ImplicitOverrideDetectedError): - make_container( - TestProvider(), - validation_settings=STRICT_VALIDATION, - ) - - -def test_override_provide_all() -> None: - class TestProvider(Provider): - scope = Scope.APP - provides = ( - provide_all(int, str) - + provide_all(int, str, override=True) - ) - - make_container( - TestProvider(), - validation_settings=STRICT_VALIDATION, - ) - - def test_cant_override() -> None: class TestProvider(Provider): scope = Scope.APP @@ -106,4 +77,3 @@ class TestProvider(Provider): TestProvider(), validation_settings=ValidationSettings(nothing_overridden=False), ) - diff --git a/tests/unit/container/override/test_provide_all.py b/tests/unit/container/override/test_provide_all.py new file mode 100644 index 00000000..f27ea297 --- /dev/null +++ b/tests/unit/container/override/test_provide_all.py @@ -0,0 +1,34 @@ +import pytest + +from dishka import Provider, Scope, make_container, provide_all +from dishka.entities.validation_settigs import STRICT_VALIDATION +from dishka.exceptions import ImplicitOverrideDetectedError + + +def test_not_override() -> None: + class TestProvider(Provider): + scope = Scope.APP + provides = ( + provide_all(int, str) + + provide_all(int, str) + ) + + with pytest.raises(ImplicitOverrideDetectedError): + make_container( + TestProvider(), + validation_settings=STRICT_VALIDATION, + ) + + +def test_override() -> None: + class TestProvider(Provider): + scope = Scope.APP + provides = ( + provide_all(int, str) + + provide_all(int, str, override=True) + ) + + make_container( + TestProvider(), + validation_settings=STRICT_VALIDATION, + ) From e3bf8981f38bb3d21ceb96c3322bbceab85ecfd5 Mon Sep 17 00:00:00 2001 From: IvanKirpichnikov Date: Fri, 18 Oct 2024 19:14:01 +0300 Subject: [PATCH 2/5] add parameter 'override' for provider methods --- src/dishka/provider.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/src/dishka/provider.py b/src/dishka/provider.py index 2b9618f3..ff256c4b 100644 --- a/src/dishka/provider.py +++ b/src/dishka/provider.py @@ -163,7 +163,10 @@ def provide_all( if scope is None: scope = self.scope composite = provide_all_on_instance( - *provides, scope=scope, cache=cache, recursive=recursive, + *provides, + scope=scope, + cache=cache, + recursive=recursive, override=override, ) self._add_dependency_sources("?", composite.dependency_sources) @@ -176,12 +179,14 @@ def alias( provides: Any = None, cache: bool = True, component: Component | None = None, + override: bool = False, ) -> CompositeDependencySource: composite = alias( source=source, provides=provides, cache=cache, component=component, + override=override, ) self._add_dependency_sources(str(source), composite.dependency_sources) return composite @@ -203,9 +208,17 @@ def to_component(self, component: Component) -> ProviderWrapper: return ProviderWrapper(component, self) def from_context( - self, *, provides: Any, scope: BaseScope, + self, + *, + provides: Any, + scope: BaseScope, + override: bool = False, ) -> CompositeDependencySource: - composite = from_context(provides, scope=scope) + composite = from_context( + provides=provides, + scope=scope, + override=override, + ) self._add_dependency_sources( name=str(provides), sources=composite.dependency_sources, From 9bd2f31825995c480cb236ad43bc511bd3125048 Mon Sep 17 00:00:00 2001 From: IvanKirpichnikov Date: Fri, 18 Oct 2024 21:00:59 +0300 Subject: [PATCH 3/5] remove unused attrs --- src/dishka/registry_builder.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/dishka/registry_builder.py b/src/dishka/registry_builder.py index e7246850..ccc75262 100644 --- a/src/dishka/registry_builder.py +++ b/src/dishka/registry_builder.py @@ -136,8 +136,6 @@ def __init__( self.skip_validation = skip_validation self.validation_settings = validation_settings self.processed_factories: dict[DependencyKey, Factory] = {} - self.processed_aliases: dict[DependencyKey, Alias] = {} - self.processed_contex_vars: dict[DependencyKey, ContextVariable] = {} def _collect_components(self) -> None: for provider in self.providers: From 772368a7683c3c24583c333ee9fe79fd910b0945 Mon Sep 17 00:00:00 2001 From: Andrey Tikhonov <17@itishka.org> Date: Sat, 19 Oct 2024 23:04:07 +0200 Subject: [PATCH 4/5] Revert some time ignore --- src/dishka/dependency_source/make_factory.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/dishka/dependency_source/make_factory.py b/src/dishka/dependency_source/make_factory.py index 2faf4aa0..e3fa0aa4 100644 --- a/src/dishka/dependency_source/make_factory.py +++ b/src/dishka/dependency_source/make_factory.py @@ -227,7 +227,7 @@ def _make_factory_by_class( f"If your are using `if TYPE_CHECKING` to import '{e.name}' " f"then try removing it. \n" f"Or, create a separate factory with all types imported.", - name=e.name, # type: ignore[call-arg] + name=e.name, ) from e hints.pop("return", _empty) @@ -285,7 +285,7 @@ def _make_factory_by_function( f"If your are using `if TYPE_CHECKING` to import '{e.name}' " f"then try removing it. \n" f"Or, create a separate factory with all types imported.", - name=e.name, # type: ignore[call-arg] + name=e.name, ) from e if is_in_class: self = next(iter(params.values()), None) @@ -351,7 +351,7 @@ def _make_factory_by_static_method( f"If your are using `if TYPE_CHECKING` to import '{e.name}' " f"then try removing it. \n" f"Or, create a separate factory with all types imported.", - name=e.name, # type: ignore[call-arg] + name=e.name, ) from e possible_dependency = hints.pop("return", _empty) From b425f6a6d9be42d3d60f470d237e49ef447429b3 Mon Sep 17 00:00:00 2001 From: Andrey Tikhonov <17@itishka.org> Date: Sat, 19 Oct 2024 23:08:54 +0200 Subject: [PATCH 5/5] revert type ignore from registry generics --- src/dishka/registry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dishka/registry.py b/src/dishka/registry.py index 0c2e5e54..c7d67b5c 100644 --- a/src/dishka/registry.py +++ b/src/dishka/registry.py @@ -90,7 +90,7 @@ def _get_type_var_factory(self, dependency: DependencyKey) -> Factory: scope=self.scope, dependencies=[], kw_dependencies={}, - provides=DependencyKey(type[typevar], dependency.component), # type: ignore[misc] + provides=DependencyKey(type[typevar], dependency.component), type_=FactoryType.FACTORY, is_to_bind=False, cache=False,