From 0ef7b34cb9c8e6eeacb6c489fb7a6be90a71386b Mon Sep 17 00:00:00 2001 From: anzzyspeaksgit Date: Sun, 22 Mar 2026 12:06:18 +0000 Subject: [PATCH] feat: allow overriding CustomField instances via dependency_provider Closes #176 --- fast_depends/core/model.py | 40 ++++++++++++---- fast_depends/dependencies/provider.py | 11 +++-- tests/test_overrides.py | 69 ++++++++++++++++++++++++++- 3 files changed, 105 insertions(+), 15 deletions(-) diff --git a/fast_depends/core/model.py b/fast_depends/core/model.py index 928c9ccc..76519a3c 100644 --- a/fast_depends/core/model.py +++ b/fast_depends/core/model.py @@ -241,10 +241,22 @@ def solve( if self.custom_fields: for custom in self.custom_fields.values(): - if custom.field: - custom.use_field(kwargs) + custom_override = provider.overrides.get( + custom + ) or provider.overrides.get(type(custom)) + if custom_override: + kwargs[custom.param_name] = custom_override.solve( + *args, + stack=stack, + cache_dependencies=cache_dependencies, + nested=True, + **kwargs, + ) else: - kwargs = custom.use(**kwargs) + if custom.field: + custom.use_field(kwargs) + else: + kwargs = custom.use(**kwargs) final_args, final_kwargs = cast_gen.send(kwargs) @@ -310,9 +322,7 @@ async def asolve( for dep_arg, dep_key in self.dependencies.items(): if dep_arg not in kwargs: - kwargs[dep_arg] = await provider.get_dependant( - dep_key - ).asolve( + kwargs[dep_arg] = await provider.get_dependant(dep_key).asolve( *args, stack=stack, cache_dependencies=cache_dependencies, @@ -326,10 +336,22 @@ async def asolve( try: async with anyio.create_task_group() as tg: for custom in self.custom_fields.values(): - if custom.field: - tg.start_soon(run_async, custom.use_field, kwargs) + custom_override = provider.overrides.get( + custom + ) or provider.overrides.get(type(custom)) + if custom_override: + kwargs[custom.param_name] = await custom_override.asolve( + *args, + stack=stack, + cache_dependencies=cache_dependencies, + nested=True, + **kwargs, + ) else: - custom_to_solve.append(custom) + if custom.field: + tg.start_soon(run_async, custom.use_field, kwargs) + else: + custom_to_solve.append(custom) except ExceptionGroup as exgr: for ex in exgr.exceptions: diff --git a/fast_depends/dependencies/provider.py b/fast_depends/dependencies/provider.py index e3c6cfc0..19ac7113 100644 --- a/fast_depends/dependencies/provider.py +++ b/fast_depends/dependencies/provider.py @@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Any, TypeAlias from fast_depends.core import build_call_model +from fast_depends.library import CustomField if TYPE_CHECKING: from fast_depends.core import CallModel @@ -41,7 +42,7 @@ def get_dependant(self, key: Key) -> "CallModel": def override( self, - original: Callable[..., Any], + original: "Callable[..., Any] | CustomField", override: Callable[..., Any], ) -> None: key = self.__get_original_key(original) @@ -51,7 +52,7 @@ def override( if original_dependant := self.dependencies.get(key): serializer_cls = original_dependant.serializer_cls - else: + elif not isinstance(original, CustomField): self.dependencies[key] = build_call_model( original, dependency_provider=self, @@ -67,7 +68,7 @@ def override( def __setitem__( self, - key: Callable[..., Any], + key: "Callable[..., Any] | CustomField", value: Callable[..., Any], ) -> None: """Alias for `provider[key] = value` syntax""" @@ -76,12 +77,12 @@ def __setitem__( @contextmanager def scope( self, - original: Callable[..., Any], + original: "Callable[..., Any] | CustomField", override: Callable[..., Any], ) -> Iterator[None]: self.override(original, override) yield self.overrides.pop(self.__get_original_key(original), None) - def __get_original_key(self, original: Callable[..., Any]) -> Key: + def __get_original_key(self, original: "Callable[..., Any] | CustomField") -> Key: return original diff --git a/tests/test_overrides.py b/tests/test_overrides.py index 94486d14..8434b0a9 100644 --- a/tests/test_overrides.py +++ b/tests/test_overrides.py @@ -1,10 +1,11 @@ from collections.abc import AsyncGenerator, Generator -from typing import Annotated +from typing import Annotated, Any from unittest.mock import Mock import pytest from fast_depends import Depends, Provider, inject +from fast_depends.library import CustomField def test_not_override(provider: Provider) -> None: @@ -305,3 +306,69 @@ def func(d: Annotated[int, Depends(base_dep)]) -> int: assert len(provider.overrides) == 0 assert len(provider.dependencies) == 1 assert func() == 1 # original dep called + + +class Header(CustomField): + def use(self, /, **kwargs: dict[str, Any]) -> dict[str, Any]: + kwargs = super().use(**kwargs) + kwargs[self.param_name] = kwargs.get("headers", {}).get(self.param_name) + return kwargs + + +def test_override_custom_field_class(provider: Provider) -> None: + @inject(dependency_provider=provider) + def func(h: int = Header()) -> int: + return h + + provider.override(Header, lambda: 1) + + assert func() == 1 + + +def test_override_custom_field_instance(provider: Provider) -> None: + h = Header() + + @inject(dependency_provider=provider) + def func(header_field: int = h) -> int: + return header_field + + provider.override(h, lambda: 2) + + assert func() == 2 + + +@pytest.mark.anyio +async def test_async_override_custom_field(provider: Provider) -> None: + h = Header() + + @inject(dependency_provider=provider) + async def func(header_field: int = h) -> int: + return header_field + + async def override_dep() -> int: + return 3 + + provider.override(h, override_dep) + + assert await func() == 3 + + +class HeaderField(CustomField): + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self.field = True + + def use_field(self, kwargs: dict[str, Any]) -> None: + kwargs[self.param_name] = kwargs.get("headers", {}).get(self.param_name) + + +def test_override_custom_field_field(provider: Provider) -> None: + h = HeaderField() + + @inject(dependency_provider=provider) + def func(header_field: int = h) -> int: + return header_field + + provider.override(h, lambda: 2) + + assert func() == 2