diff --git a/src/realtime/src/realtime/_async/channel.py b/src/realtime/src/realtime/_async/channel.py index 805299cf..34d067f4 100644 --- a/src/realtime/src/realtime/_async/channel.py +++ b/src/realtime/src/realtime/_async/channel.py @@ -189,6 +189,24 @@ async def subscribe( "Tried to subscribe multiple times. 'subscribe' can only be called a single time per channel instance" ) else: + subscribe_result: asyncio.Future[None] = ( + asyncio.get_running_loop().create_future() + ) + + def complete_subscription( + state: RealtimeSubscribeStates, error: Optional[Exception] + ) -> None: + if callback: + try: + callback(state, error) + except Exception as exc: + if not subscribe_result.done(): + subscribe_result.set_exception(exc) + return + + if not subscribe_result.done(): + subscribe_result.set_result(None) + config: RealtimeChannelConfig = self.params["config"] broadcast = config.get("broadcast") presence = config.get("presence") or RealtimeChannelPresenceConfig( @@ -245,7 +263,7 @@ def on_join_push_ok(payload: ReplyPostgresChanges): new_postgres_bindings.append(postgres_callback) else: asyncio.create_task(self.unsubscribe()) - callback and callback( + complete_subscription( RealtimeSubscribeStates.CHANNEL_ERROR, Exception( "mismatch between server and client bindings for postgres changes" @@ -254,16 +272,16 @@ def on_join_push_ok(payload: ReplyPostgresChanges): return self.postgres_changes_callbacks = new_postgres_bindings - callback and callback(RealtimeSubscribeStates.SUBSCRIBED, None) + complete_subscription(RealtimeSubscribeStates.SUBSCRIBED, None) def on_join_push_error(payload: Dict[str, Any]): - callback and callback( + complete_subscription( RealtimeSubscribeStates.CHANNEL_ERROR, Exception(json.dumps(payload)), ) def on_join_push_timeout(*args): - callback and callback(RealtimeSubscribeStates.TIMED_OUT, None) + complete_subscription(RealtimeSubscribeStates.TIMED_OUT, None) self.join_push.receive( RealtimeAcknowledgementStatus.Ok, on_join_push_ok @@ -272,6 +290,7 @@ def on_join_push_timeout(*args): ) await self._rejoin() + await subscribe_result return self diff --git a/src/realtime/tests/test_channel.py b/src/realtime/tests/test_channel.py new file mode 100644 index 00000000..df3d38f7 --- /dev/null +++ b/src/realtime/tests/test_channel.py @@ -0,0 +1,71 @@ +import asyncio +from typing import Any, Optional, cast + +import pytest + +from realtime import AsyncRealtimeClient +from realtime.message import ReplyPostgresChanges +from realtime.types import RealtimeAcknowledgementStatus, RealtimeSubscribeStates + + +@pytest.mark.asyncio +async def test_subscribe_waits_for_ack_before_returning(): + socket = AsyncRealtimeClient("ws://localhost:54321/realtime/v1", "test-key") + socket._ws_connection = cast( + Any, object() + ) # mark socket as connected without network traffic + channel = socket.channel("test-channel") + + subscribed = asyncio.Event() + + def on_subscribe(state: RealtimeSubscribeStates, error: Optional[Exception]): + if state == RealtimeSubscribeStates.SUBSCRIBED and error is None: + subscribed.set() + + async def fake_rejoin(): + async def delayed_ack(): + await asyncio.sleep(0.05) + channel.join_push.trigger( + RealtimeAcknowledgementStatus.Ok, + ReplyPostgresChanges(postgres_changes=[]), + ) + + asyncio.create_task(delayed_ack()) + + channel._rejoin = fake_rejoin # type: ignore[method-assign] + + subscribe_task = asyncio.create_task(channel.subscribe(on_subscribe)) + + await asyncio.sleep(0.01) + assert not subscribe_task.done() + + await asyncio.wait_for(subscribe_task, 1) + assert subscribed.is_set() + + +@pytest.mark.asyncio +async def test_subscribe_propagates_callback_errors(): + socket = AsyncRealtimeClient("ws://localhost:54321/realtime/v1", "test-key") + socket._ws_connection = cast( + Any, object() + ) # mark socket as connected without network traffic + channel = socket.channel("test-channel") + + async def fake_rejoin(): + async def delayed_ack(): + await asyncio.sleep(0.01) + channel.join_push.trigger( + RealtimeAcknowledgementStatus.Ok, + ReplyPostgresChanges(postgres_changes=[]), + ) + + asyncio.create_task(delayed_ack()) + + channel._rejoin = fake_rejoin # type: ignore[method-assign] + + def raising_callback(state: RealtimeSubscribeStates, error: Optional[Exception]): + if state == RealtimeSubscribeStates.SUBSCRIBED and error is None: + raise RuntimeError("subscribe callback failed") + + with pytest.raises(RuntimeError, match="subscribe callback failed"): + await channel.subscribe(raising_callback)