Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 31 additions & 9 deletions fast_depends/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,10 +241,22 @@ def solve(

if self.custom_fields:
for custom in self.custom_fields.values():
if custom.field:
custom.use_field(kwargs)
custom_override = provider.overrides.get(
custom
) or provider.overrides.get(type(custom))
if custom_override:
kwargs[custom.param_name] = custom_override.solve(
*args,
stack=stack,
cache_dependencies=cache_dependencies,
nested=True,
**kwargs,
)
else:
kwargs = custom.use(**kwargs)
if custom.field:
custom.use_field(kwargs)
else:
kwargs = custom.use(**kwargs)

final_args, final_kwargs = cast_gen.send(kwargs)

Expand Down Expand Up @@ -310,9 +322,7 @@ async def asolve(

for dep_arg, dep_key in self.dependencies.items():
if dep_arg not in kwargs:
kwargs[dep_arg] = await provider.get_dependant(
dep_key
).asolve(
kwargs[dep_arg] = await provider.get_dependant(dep_key).asolve(
*args,
stack=stack,
cache_dependencies=cache_dependencies,
Expand All @@ -326,10 +336,22 @@ async def asolve(
try:
async with anyio.create_task_group() as tg:
for custom in self.custom_fields.values():
if custom.field:
tg.start_soon(run_async, custom.use_field, kwargs)
custom_override = provider.overrides.get(
custom
) or provider.overrides.get(type(custom))
if custom_override:
kwargs[custom.param_name] = await custom_override.asolve(
*args,
stack=stack,
cache_dependencies=cache_dependencies,
nested=True,
**kwargs,
)
else:
custom_to_solve.append(custom)
if custom.field:
tg.start_soon(run_async, custom.use_field, kwargs)
else:
custom_to_solve.append(custom)

except ExceptionGroup as exgr:
for ex in exgr.exceptions:
Expand Down
11 changes: 6 additions & 5 deletions fast_depends/dependencies/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import TYPE_CHECKING, Any, TypeAlias

from fast_depends.core import build_call_model
from fast_depends.library import CustomField

if TYPE_CHECKING:
from fast_depends.core import CallModel
Expand Down Expand Up @@ -41,7 +42,7 @@ def get_dependant(self, key: Key) -> "CallModel":

def override(
self,
original: Callable[..., Any],
original: "Callable[..., Any] | CustomField",
override: Callable[..., Any],
) -> None:
key = self.__get_original_key(original)
Expand All @@ -51,7 +52,7 @@ def override(
if original_dependant := self.dependencies.get(key):
serializer_cls = original_dependant.serializer_cls

else:
elif not isinstance(original, CustomField):
self.dependencies[key] = build_call_model(
original,
dependency_provider=self,
Expand All @@ -67,7 +68,7 @@ def override(

def __setitem__(
self,
key: Callable[..., Any],
key: "Callable[..., Any] | CustomField",
value: Callable[..., Any],
) -> None:
"""Alias for `provider[key] = value` syntax"""
Expand All @@ -76,12 +77,12 @@ def __setitem__(
@contextmanager
def scope(
self,
original: Callable[..., Any],
original: "Callable[..., Any] | CustomField",
override: Callable[..., Any],
) -> Iterator[None]:
self.override(original, override)
yield
self.overrides.pop(self.__get_original_key(original), None)

def __get_original_key(self, original: Callable[..., Any]) -> Key:
def __get_original_key(self, original: "Callable[..., Any] | CustomField") -> Key:
return original
69 changes: 68 additions & 1 deletion tests/test_overrides.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from collections.abc import AsyncGenerator, Generator
from typing import Annotated
from typing import Annotated, Any
from unittest.mock import Mock

import pytest

from fast_depends import Depends, Provider, inject
from fast_depends.library import CustomField


def test_not_override(provider: Provider) -> None:
Expand Down Expand Up @@ -305,3 +306,69 @@ def func(d: Annotated[int, Depends(base_dep)]) -> int:
assert len(provider.overrides) == 0
assert len(provider.dependencies) == 1
assert func() == 1 # original dep called


class Header(CustomField):
def use(self, /, **kwargs: dict[str, Any]) -> dict[str, Any]:
kwargs = super().use(**kwargs)
kwargs[self.param_name] = kwargs.get("headers", {}).get(self.param_name)
return kwargs


def test_override_custom_field_class(provider: Provider) -> None:
@inject(dependency_provider=provider)
def func(h: int = Header()) -> int:
return h

provider.override(Header, lambda: 1)

assert func() == 1


def test_override_custom_field_instance(provider: Provider) -> None:
h = Header()

@inject(dependency_provider=provider)
def func(header_field: int = h) -> int:
return header_field

provider.override(h, lambda: 2)

assert func() == 2


@pytest.mark.anyio
async def test_async_override_custom_field(provider: Provider) -> None:
h = Header()

@inject(dependency_provider=provider)
async def func(header_field: int = h) -> int:
return header_field

async def override_dep() -> int:
return 3

provider.override(h, override_dep)

assert await func() == 3


class HeaderField(CustomField):
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.field = True

def use_field(self, kwargs: dict[str, Any]) -> None:
kwargs[self.param_name] = kwargs.get("headers", {}).get(self.param_name)


def test_override_custom_field_field(provider: Provider) -> None:
h = HeaderField()

@inject(dependency_provider=provider)
def func(header_field: int = h) -> int:
return header_field

provider.override(h, lambda: 2)

assert func() == 2
Loading