diff --git a/docs/providers/factories.md b/docs/providers/factories.md index 079fde9..e4ade97 100644 --- a/docs/providers/factories.md +++ b/docs/providers/factories.md @@ -20,7 +20,8 @@ class DIContainer(BaseContainer): ``` ## AsyncFactory -- Async function is required. +- Allows both sync and async creators. +- Can only be resolved asynchronously, even if the creator is sync. ```python import datetime @@ -35,6 +36,7 @@ class DIContainer(BaseContainer): async_factory = providers.AsyncFactory(async_factory) ``` +> Note: If you have a class that has dependencies which need to be resolved asynchronously, you can use `AsyncFactory` to create instances of that class. The factory will handle the async resolution of dependencies. ## Retrieving provider as a Callable diff --git a/tests/providers/test_factories.py b/tests/providers/test_factories.py new file mode 100644 index 0000000..228f63d --- /dev/null +++ b/tests/providers/test_factories.py @@ -0,0 +1,41 @@ +import random + +from that_depends import BaseContainer, Provide, inject, providers + + +async def test_async_factory_with_sync_creator() -> None: + return_value = random.random() + f = providers.AsyncFactory(lambda: return_value) + + assert await f.resolve() == return_value + + +async def test_async_factory_with_sync_creator_multiple_parents() -> None: + """Dependencies of async factory get resolved correctly with a sync creator.""" + _return_value_1 = 32 + _return_value_2 = 12 + + async def _async_creator() -> int: + return 32 + + def _sync_creator() -> int: + return _return_value_2 + + class _Adder: + def __init__(self, x: int, y: int) -> None: + self.x = x + self.y = y + + def result(self) -> int: + return self.x + self.y + + class _Container(BaseContainer): + p1 = providers.AsyncFactory(_async_creator) + p2 = providers.Factory(_sync_creator) + p3 = providers.AsyncFactory(_Adder, x=p1.cast, y=p2.cast) + + @inject + async def _injected(adder: _Adder = Provide[_Container.p3]) -> int: + return adder.result() + + assert await _injected() == _return_value_1 + _return_value_2 diff --git a/that_depends/providers/collection.py b/that_depends/providers/collection.py index 33afaec..28533d1 100644 --- a/that_depends/providers/collection.py +++ b/that_depends/providers/collection.py @@ -80,12 +80,12 @@ class Dict(AbstractProvider[dict[str, T_co]]): dict_provider = Dict(key1=provider1, key2=provider2) # Synchronous resolution - resolved_dict = dict_provider.sync_resolve() + resolved_dict = dict_provider.resolve_sync() print(resolved_dict) # Output: {"key1": 1, "key2": 2} # Asynchronous resolution import asyncio - resolved_dict_async = asyncio.run(dict_provider.async_resolve()) + resolved_dict_async = asyncio.run(dict_provider.resolve()) print(resolved_dict_async) # Output: {"key1": 1, "key2": 2} ``` diff --git a/that_depends/providers/context_resources.py b/that_depends/providers/context_resources.py index 80eb56e..b995b87 100644 --- a/that_depends/providers/context_resources.py +++ b/that_depends/providers/context_resources.py @@ -607,7 +607,7 @@ def __call__(self, func: typing.Callable[P, T_co]) -> typing.Callable[P, T_co]: ```python @container_context(MyContainer) async def my_async_function(): - result = await MyContainer.some_resource.async_resolve() + result = await MyContainer.some_resource.resolve() return result ``` diff --git a/that_depends/providers/factories.py b/that_depends/providers/factories.py index fc54297..6242ea5 100644 --- a/that_depends/providers/factories.py +++ b/that_depends/providers/factories.py @@ -1,5 +1,7 @@ import abc +import inspect import typing +from typing import overload from typing_extensions import override @@ -153,11 +155,21 @@ def _deregister_arguments(self) -> None: __slots__ = "_args", "_factory", "_kwargs", "_override" - def __init__(self, factory: typing.Callable[P, typing.Awaitable[T_co]], *args: P.args, **kwargs: P.kwargs) -> None: + @overload + def __init__( + self, factory: typing.Callable[P, typing.Awaitable[T_co]], *args: P.args, **kwargs: P.kwargs + ) -> None: ... + + @overload + def __init__(self, factory: typing.Callable[P, T_co], *args: P.args, **kwargs: P.kwargs) -> None: ... + + def __init__( + self, factory: typing.Callable[P, T_co | typing.Awaitable[T_co]], *args: P.args, **kwargs: P.kwargs + ) -> None: """Initialize an AsyncFactory instance. Args: - factory (Callable[P, Awaitable[T_co]]): Async function that returns the resource. + factory (Callable[P, T_co | Awaitable[T_co]]): Function that returns the resource (sync or async). *args: Arguments to pass to the factory function. **kwargs: Keyword arguments to pass to the factory @@ -174,11 +186,18 @@ async def resolve(self) -> T_co: if self._override: return typing.cast(T_co, self._override) - return await self._factory( - *[await x.resolve() if isinstance(x, AbstractProvider) else x for x in self._args], # type: ignore[arg-type] - **{k: await v.resolve() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items()}, # type: ignore[arg-type] + args = [await x.resolve() if isinstance(x, AbstractProvider) else x for x in self._args] + kwargs = {k: await v.resolve() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items()} + + result = self._factory( + *args, # type:ignore[arg-type] + **kwargs, # type:ignore[arg-type] ) + if inspect.isawaitable(result): + return await result + return result + @override def resolve_sync(self) -> typing.NoReturn: msg = "AsyncFactory cannot be resolved synchronously" diff --git a/that_depends/providers/local_singleton.py b/that_depends/providers/local_singleton.py index 2cd2a5b..dd8f1ef 100644 --- a/that_depends/providers/local_singleton.py +++ b/that_depends/providers/local_singleton.py @@ -27,11 +27,11 @@ def factory(): singleton = ThreadLocalSingleton(factory) # Same thread, same instance - instance1 = singleton.sync_resolve() - instance2 = singleton.sync_resolve() + instance1 = singleton.resolve_sync() + instance2 = singleton.resolve_sync() def thread_task(): - return singleton.sync_resolve() + return singleton.resolve_sync() threads = [threading.Thread(target=thread_task) for i in range(10)] for thread in threads: @@ -119,7 +119,7 @@ def tear_down_sync(self, propagate: bool = True, raise_on_async: bool = True) -> async def tear_down(self, propagate: bool = True) -> None: """Reset the thread-local instance. - After calling this method, subsequent calls to `sync_resolve` on the + After calling this method, subsequent calls to `resolve_sync()` on the same thread will produce a new instance. """ if self._instance is not None: diff --git a/that_depends/providers/object.py b/that_depends/providers/object.py index 11d13e0..b29ed7c 100644 --- a/that_depends/providers/object.py +++ b/that_depends/providers/object.py @@ -18,7 +18,7 @@ class Object(AbstractProvider[T_co]): Example: ```python provider = Object(1) - result = provider.sync_resolve() + result = provider.resolve_sync() print(result) # 1 ``` diff --git a/that_depends/providers/resources.py b/that_depends/providers/resources.py index 1430e99..e5a84c6 100644 --- a/that_depends/providers/resources.py +++ b/that_depends/providers/resources.py @@ -33,7 +33,7 @@ class MyContainer: async_resource = Resource(create_async_resource) async def main(): - async_resource_instance = await MyContainer.async_resource.async_resolve() + async_resource_instance = await MyContainer.async_resource.resolve() await MyContainer.async_resource.tear_down() ``` diff --git a/that_depends/providers/selector.py b/that_depends/providers/selector.py index 711326f..5bbe023 100644 --- a/that_depends/providers/selector.py +++ b/that_depends/providers/selector.py @@ -29,7 +29,7 @@ def environment_selector(): ) # Synchronously resolve the selected provider - service = selector_instance.sync_resolve() + service = selector_instance.resolve_sync() ``` """ @@ -59,7 +59,7 @@ def my_selector(): ) # The "remote" provider will be selected - selected_service = my_selector_instance.sync_resolve() + selected_service = my_selector_instance.resolve_sync() ``` """ diff --git a/that_depends/providers/singleton.py b/that_depends/providers/singleton.py index 49465ba..daa470d 100644 --- a/that_depends/providers/singleton.py +++ b/that_depends/providers/singleton.py @@ -18,7 +18,7 @@ class Singleton(ProviderWithArguments, SupportsTeardown, AbstractProvider[T_co]) """A provider that creates an instance once and caches it for subsequent injections. This provider is safe to use concurrently in both threading and asyncio contexts. - On the first call to either ``sync_resolve()`` or ``async_resolve()``, the instance + On the first call to either ``resolve()`` or ``resolve_sync()``, the instance is created by calling the provided factory. All future calls return the cached instance. Example: @@ -27,8 +27,8 @@ def my_factory() -> float: return 0.5 singleton = Singleton(my_factory) - value1 = singleton.sync_resolve() - value2 = singleton.sync_resolve() + value1 = singleton.resolve_sync() + value2 = singleton.resolve_sync() assert value1 == value2 ``` @@ -101,7 +101,7 @@ def resolve_sync(self) -> T_co: async def tear_down(self, propagate: bool = True) -> None: """Reset the cached instance. - After calling this method, the next async_resolve call will recreate the instance. + After calling this method, the next resolve() call will recreate the instance. """ if self._instance is not None: self._instance = None @@ -126,7 +126,7 @@ class AsyncSingleton(ProviderWithArguments, SupportsTeardown, AbstractProvider[T """A provider that creates an instance asynchronously and caches it for subsequent injections. This provider is safe to use concurrently in asyncio contexts. On the first call - to ``async_resolve()``, the instance is created by awaiting the provided factory. + to ``resolve()``, the instance is created by awaiting the provided factory. All subsequent calls return the cached instance. Example: @@ -135,8 +135,8 @@ async def my_async_factory() -> float: return 0.5 async_singleton = AsyncSingleton(my_async_factory) - value1 = await async_singleton.async_resolve() - value2 = await async_singleton.async_resolve() + value1 = await async_singleton.resolve() + value2 = await async_singleton.resolve() assert value1 == value2 ``` @@ -202,7 +202,7 @@ def resolve_sync(self) -> typing.NoReturn: async def tear_down(self, propagate: bool = True) -> None: """Reset the cached instance. - After calling this method, the next call to ``async_resolve()`` will recreate the instance. + After calling this method, the next call to ``resolve()`` will recreate the instance. """ if self._instance is not None: self._instance = None @@ -214,7 +214,7 @@ async def tear_down(self, propagate: bool = True) -> None: def tear_down_sync(self, propagate: bool = True, raise_on_async: bool = True) -> None: """Reset the cached instance. - After calling this method, the next call to ``sync_resolve()`` will recreate the instance. + After calling this method, the next call to ``resolve_sync()`` will recreate the instance. """ if self._instance is not None: self._instance = None