Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 23 additions & 4 deletions src/realtime/src/realtime/_async/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,24 @@ async def subscribe(
"Tried to subscribe multiple times. 'subscribe' can only be called a single time per channel instance"
)
else:
subscribe_result: asyncio.Future[None] = (
asyncio.get_running_loop().create_future()
)

def complete_subscription(
state: RealtimeSubscribeStates, error: Optional[Exception]
) -> None:
if callback:
try:
callback(state, error)
except Exception as exc:
if not subscribe_result.done():
subscribe_result.set_exception(exc)
return

if not subscribe_result.done():
subscribe_result.set_result(None)

config: RealtimeChannelConfig = self.params["config"]
broadcast = config.get("broadcast")
presence = config.get("presence") or RealtimeChannelPresenceConfig(
Expand Down Expand Up @@ -245,7 +263,7 @@ def on_join_push_ok(payload: ReplyPostgresChanges):
new_postgres_bindings.append(postgres_callback)
else:
asyncio.create_task(self.unsubscribe())
callback and callback(
complete_subscription(
RealtimeSubscribeStates.CHANNEL_ERROR,
Exception(
"mismatch between server and client bindings for postgres changes"
Expand All @@ -254,16 +272,16 @@ def on_join_push_ok(payload: ReplyPostgresChanges):
return

self.postgres_changes_callbacks = new_postgres_bindings
callback and callback(RealtimeSubscribeStates.SUBSCRIBED, None)
complete_subscription(RealtimeSubscribeStates.SUBSCRIBED, None)

def on_join_push_error(payload: Dict[str, Any]):
callback and callback(
complete_subscription(
RealtimeSubscribeStates.CHANNEL_ERROR,
Exception(json.dumps(payload)),
)

def on_join_push_timeout(*args):
callback and callback(RealtimeSubscribeStates.TIMED_OUT, None)
complete_subscription(RealtimeSubscribeStates.TIMED_OUT, None)

self.join_push.receive(
RealtimeAcknowledgementStatus.Ok, on_join_push_ok
Expand All @@ -272,6 +290,7 @@ def on_join_push_timeout(*args):
)

await self._rejoin()
await subscribe_result

return self

Expand Down
71 changes: 71 additions & 0 deletions src/realtime/tests/test_channel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import asyncio
from typing import Any, Optional, cast

import pytest

from realtime import AsyncRealtimeClient
from realtime.message import ReplyPostgresChanges
from realtime.types import RealtimeAcknowledgementStatus, RealtimeSubscribeStates


@pytest.mark.asyncio
async def test_subscribe_waits_for_ack_before_returning():
socket = AsyncRealtimeClient("ws://localhost:54321/realtime/v1", "test-key")
socket._ws_connection = cast(
Any, object()
) # mark socket as connected without network traffic
channel = socket.channel("test-channel")

subscribed = asyncio.Event()

def on_subscribe(state: RealtimeSubscribeStates, error: Optional[Exception]):
if state == RealtimeSubscribeStates.SUBSCRIBED and error is None:
subscribed.set()

async def fake_rejoin():
async def delayed_ack():
await asyncio.sleep(0.05)
channel.join_push.trigger(
RealtimeAcknowledgementStatus.Ok,
ReplyPostgresChanges(postgres_changes=[]),
)

asyncio.create_task(delayed_ack())

channel._rejoin = fake_rejoin # type: ignore[method-assign]

subscribe_task = asyncio.create_task(channel.subscribe(on_subscribe))

await asyncio.sleep(0.01)
assert not subscribe_task.done()

await asyncio.wait_for(subscribe_task, 1)
assert subscribed.is_set()


@pytest.mark.asyncio
async def test_subscribe_propagates_callback_errors():
socket = AsyncRealtimeClient("ws://localhost:54321/realtime/v1", "test-key")
socket._ws_connection = cast(
Any, object()
) # mark socket as connected without network traffic
channel = socket.channel("test-channel")

async def fake_rejoin():
async def delayed_ack():
await asyncio.sleep(0.01)
channel.join_push.trigger(
RealtimeAcknowledgementStatus.Ok,
ReplyPostgresChanges(postgres_changes=[]),
)

asyncio.create_task(delayed_ack())

channel._rejoin = fake_rejoin # type: ignore[method-assign]

def raising_callback(state: RealtimeSubscribeStates, error: Optional[Exception]):
if state == RealtimeSubscribeStates.SUBSCRIBED and error is None:
raise RuntimeError("subscribe callback failed")

with pytest.raises(RuntimeError, match="subscribe callback failed"):
await channel.subscribe(raising_callback)