Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion docs/providers/factories.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ class DIContainer(BaseContainer):
```

## AsyncFactory
- Async function is required.
- Allows both sync and async creators.
- Can only be resolved asynchronously, even if the creator is sync.
```python
import datetime

Expand All @@ -35,6 +36,7 @@ class DIContainer(BaseContainer):
async_factory = providers.AsyncFactory(async_factory)
```

> Note: If you have a class that has dependencies which need to be resolved asynchronously, you can use `AsyncFactory` to create instances of that class. The factory will handle the async resolution of dependencies.


## Retrieving provider as a Callable
Expand Down
41 changes: 41 additions & 0 deletions tests/providers/test_factories.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import random

from that_depends import BaseContainer, Provide, inject, providers


async def test_async_factory_with_sync_creator() -> None:
return_value = random.random()
f = providers.AsyncFactory(lambda: return_value)

assert await f.resolve() == return_value


async def test_async_factory_with_sync_creator_multiple_parents() -> None:
"""Dependencies of async factory get resolved correctly with a sync creator."""
_return_value_1 = 32
_return_value_2 = 12

async def _async_creator() -> int:
return 32

def _sync_creator() -> int:
return _return_value_2

class _Adder:
def __init__(self, x: int, y: int) -> None:
self.x = x
self.y = y

def result(self) -> int:
return self.x + self.y

class _Container(BaseContainer):
p1 = providers.AsyncFactory(_async_creator)
p2 = providers.Factory(_sync_creator)
p3 = providers.AsyncFactory(_Adder, x=p1.cast, y=p2.cast)

@inject
async def _injected(adder: _Adder = Provide[_Container.p3]) -> int:
return adder.result()

assert await _injected() == _return_value_1 + _return_value_2
4 changes: 2 additions & 2 deletions that_depends/providers/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,12 @@ class Dict(AbstractProvider[dict[str, T_co]]):
dict_provider = Dict(key1=provider1, key2=provider2)

# Synchronous resolution
resolved_dict = dict_provider.sync_resolve()
resolved_dict = dict_provider.resolve_sync()
print(resolved_dict) # Output: {"key1": 1, "key2": 2}

# Asynchronous resolution
import asyncio
resolved_dict_async = asyncio.run(dict_provider.async_resolve())
resolved_dict_async = asyncio.run(dict_provider.resolve())
print(resolved_dict_async) # Output: {"key1": 1, "key2": 2}
```

Expand Down
2 changes: 1 addition & 1 deletion that_depends/providers/context_resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,7 +607,7 @@ def __call__(self, func: typing.Callable[P, T_co]) -> typing.Callable[P, T_co]:
```python
@container_context(MyContainer)
async def my_async_function():
result = await MyContainer.some_resource.async_resolve()
result = await MyContainer.some_resource.resolve()
return result
```

Expand Down
29 changes: 24 additions & 5 deletions that_depends/providers/factories.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import abc
import inspect
import typing
from typing import overload

from typing_extensions import override

Expand Down Expand Up @@ -153,11 +155,21 @@ def _deregister_arguments(self) -> None:

__slots__ = "_args", "_factory", "_kwargs", "_override"

def __init__(self, factory: typing.Callable[P, typing.Awaitable[T_co]], *args: P.args, **kwargs: P.kwargs) -> None:
@overload
def __init__(
self, factory: typing.Callable[P, typing.Awaitable[T_co]], *args: P.args, **kwargs: P.kwargs
) -> None: ...

@overload
def __init__(self, factory: typing.Callable[P, T_co], *args: P.args, **kwargs: P.kwargs) -> None: ...

def __init__(
self, factory: typing.Callable[P, T_co | typing.Awaitable[T_co]], *args: P.args, **kwargs: P.kwargs
) -> None:
"""Initialize an AsyncFactory instance.

Args:
factory (Callable[P, Awaitable[T_co]]): Async function that returns the resource.
factory (Callable[P, T_co | Awaitable[T_co]]): Function that returns the resource (sync or async).
*args: Arguments to pass to the factory function.
**kwargs: Keyword arguments to pass to the factory

Expand All @@ -174,11 +186,18 @@ async def resolve(self) -> T_co:
if self._override:
return typing.cast(T_co, self._override)

return await self._factory(
*[await x.resolve() if isinstance(x, AbstractProvider) else x for x in self._args], # type: ignore[arg-type]
**{k: await v.resolve() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items()}, # type: ignore[arg-type]
args = [await x.resolve() if isinstance(x, AbstractProvider) else x for x in self._args]
kwargs = {k: await v.resolve() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items()}

result = self._factory(
*args, # type:ignore[arg-type]
**kwargs, # type:ignore[arg-type]
)

if inspect.isawaitable(result):
return await result
return result

@override
def resolve_sync(self) -> typing.NoReturn:
msg = "AsyncFactory cannot be resolved synchronously"
Expand Down
8 changes: 4 additions & 4 deletions that_depends/providers/local_singleton.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@ def factory():
singleton = ThreadLocalSingleton(factory)

# Same thread, same instance
instance1 = singleton.sync_resolve()
instance2 = singleton.sync_resolve()
instance1 = singleton.resolve_sync()
instance2 = singleton.resolve_sync()

def thread_task():
return singleton.sync_resolve()
return singleton.resolve_sync()

threads = [threading.Thread(target=thread_task) for i in range(10)]
for thread in threads:
Expand Down Expand Up @@ -119,7 +119,7 @@ def tear_down_sync(self, propagate: bool = True, raise_on_async: bool = True) ->
async def tear_down(self, propagate: bool = True) -> None:
"""Reset the thread-local instance.

After calling this method, subsequent calls to `sync_resolve` on the
After calling this method, subsequent calls to `resolve_sync()` on the
same thread will produce a new instance.
"""
if self._instance is not None:
Expand Down
2 changes: 1 addition & 1 deletion that_depends/providers/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class Object(AbstractProvider[T_co]):
Example:
```python
provider = Object(1)
result = provider.sync_resolve()
result = provider.resolve_sync()
print(result) # 1
```

Expand Down
2 changes: 1 addition & 1 deletion that_depends/providers/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class MyContainer:
async_resource = Resource(create_async_resource)

async def main():
async_resource_instance = await MyContainer.async_resource.async_resolve()
async_resource_instance = await MyContainer.async_resource.resolve()
await MyContainer.async_resource.tear_down()
```

Expand Down
4 changes: 2 additions & 2 deletions that_depends/providers/selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def environment_selector():
)

# Synchronously resolve the selected provider
service = selector_instance.sync_resolve()
service = selector_instance.resolve_sync()
```

"""
Expand Down Expand Up @@ -59,7 +59,7 @@ def my_selector():
)

# The "remote" provider will be selected
selected_service = my_selector_instance.sync_resolve()
selected_service = my_selector_instance.resolve_sync()
```

"""
Expand Down
18 changes: 9 additions & 9 deletions that_depends/providers/singleton.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class Singleton(ProviderWithArguments, SupportsTeardown, AbstractProvider[T_co])
"""A provider that creates an instance once and caches it for subsequent injections.

This provider is safe to use concurrently in both threading and asyncio contexts.
On the first call to either ``sync_resolve()`` or ``async_resolve()``, the instance
On the first call to either ``resolve()`` or ``resolve_sync()``, the instance
is created by calling the provided factory. All future calls return the cached instance.

Example:
Expand All @@ -27,8 +27,8 @@ def my_factory() -> float:
return 0.5

singleton = Singleton(my_factory)
value1 = singleton.sync_resolve()
value2 = singleton.sync_resolve()
value1 = singleton.resolve_sync()
value2 = singleton.resolve_sync()
assert value1 == value2
```

Expand Down Expand Up @@ -101,7 +101,7 @@ def resolve_sync(self) -> T_co:
async def tear_down(self, propagate: bool = True) -> None:
"""Reset the cached instance.

After calling this method, the next async_resolve call will recreate the instance.
After calling this method, the next resolve() call will recreate the instance.
"""
if self._instance is not None:
self._instance = None
Expand All @@ -126,7 +126,7 @@ class AsyncSingleton(ProviderWithArguments, SupportsTeardown, AbstractProvider[T
"""A provider that creates an instance asynchronously and caches it for subsequent injections.

This provider is safe to use concurrently in asyncio contexts. On the first call
to ``async_resolve()``, the instance is created by awaiting the provided factory.
to ``resolve()``, the instance is created by awaiting the provided factory.
All subsequent calls return the cached instance.

Example:
Expand All @@ -135,8 +135,8 @@ async def my_async_factory() -> float:
return 0.5

async_singleton = AsyncSingleton(my_async_factory)
value1 = await async_singleton.async_resolve()
value2 = await async_singleton.async_resolve()
value1 = await async_singleton.resolve()
value2 = await async_singleton.resolve()
assert value1 == value2
```

Expand Down Expand Up @@ -202,7 +202,7 @@ def resolve_sync(self) -> typing.NoReturn:
async def tear_down(self, propagate: bool = True) -> None:
"""Reset the cached instance.

After calling this method, the next call to ``async_resolve()`` will recreate the instance.
After calling this method, the next call to ``resolve()`` will recreate the instance.
"""
if self._instance is not None:
self._instance = None
Expand All @@ -214,7 +214,7 @@ async def tear_down(self, propagate: bool = True) -> None:
def tear_down_sync(self, propagate: bool = True, raise_on_async: bool = True) -> None:
"""Reset the cached instance.

After calling this method, the next call to ``sync_resolve()`` will recreate the instance.
After calling this method, the next call to ``resolve_sync()`` will recreate the instance.
"""
if self._instance is not None:
self._instance = None
Expand Down