diff --git a/packages/reflex-base/src/reflex_base/event/__init__.py b/packages/reflex-base/src/reflex_base/event/__init__.py index 97a8af96ba1..d4edf024d1d 100644 --- a/packages/reflex-base/src/reflex_base/event/__init__.py +++ b/packages/reflex-base/src/reflex_base/event/__init__.py @@ -8,6 +8,7 @@ import warnings from base64 import b64encode from collections.abc import Callable, Mapping, Sequence +from concurrent.futures import Executor from functools import lru_cache, partial from typing import ( TYPE_CHECKING, @@ -177,6 +178,7 @@ def from_event_type( BACKGROUND_TASK_MARKER = "_reflex_background_task" EVENT_ACTIONS_MARKER = "_rx_event_actions" +EXECUTOR_MARKER = "_reflex_event_executor" UPLOAD_FILES_CLIENT_HANDLER = "uploadFiles" @@ -426,6 +428,16 @@ def is_background(self) -> bool: """ return getattr(self.fn, BACKGROUND_TASK_MARKER, False) + @property + def executor(self) -> Executor | None: + """The executor to run a non-async handler in, if any. + + Returns: + The executor specified via ``@rx.event(executor=...)`` or + ``None`` if no executor was specified. + """ + return getattr(self.fn, EXECUTOR_MARKER, None) + def __call__(self, *args: Any, **kwargs: Any) -> "EventSpec": """Pass arguments to the handler to get an event spec. @@ -2689,6 +2701,7 @@ class EventNamespace: # Constants BACKGROUND_TASK_MARKER = BACKGROUND_TASK_MARKER EVENT_ACTIONS_MARKER = EVENT_ACTIONS_MARKER + EXECUTOR_MARKER = EXECUTOR_MARKER _EVENT_FIELDS = _EVENT_FIELDS FORM_DATA = FORM_DATA upload_files = upload_files @@ -2717,6 +2730,7 @@ def __new__( throttle: int | None = None, debounce: int | None = None, temporal: bool | None = None, + executor: Executor | None = None, ) -> ( "Callable[[Callable[[BASE_STATE, Unpack[P]], Any]], EventCallback[Unpack[P]]]" ): ... @@ -2732,6 +2746,7 @@ def __new__( throttle: int | None = None, debounce: int | None = None, temporal: bool | None = None, + executor: Executor | None = None, ) -> EventCallback[Unpack[P]]: ... def __new__( @@ -2744,6 +2759,7 @@ def __new__( throttle: int | None = None, debounce: int | None = None, temporal: bool | None = None, + executor: Executor | None = None, ) -> ( EventCallback[Unpack[P]] | Callable[[Callable[[BASE_STATE, Unpack[P]], Any]], EventCallback[Unpack[P]]] @@ -2758,6 +2774,9 @@ def __new__( throttle: Throttle the event handler to limit calls (in milliseconds). debounce: Debounce the event handler to delay calls (in milliseconds). temporal: Whether the event should be dropped when the backend is down. + executor: The executor to run a non-async handler in. If omitted, + the EventProcessor's default thread pool is used. Ignored for + async (coroutine or async-generator) handlers. Returns: The wrapped function. @@ -2804,6 +2823,8 @@ def wrapper( msg = "Background task must be async function or generator." raise TypeError(msg) setattr(func, BACKGROUND_TASK_MARKER, True) + if executor is not None: + setattr(func, EXECUTOR_MARKER, executor) if getattr(func, "__name__", "").startswith("_"): msg = "Event handlers cannot be private." raise ValueError(msg) diff --git a/packages/reflex-base/src/reflex_base/event/processor/base_state_processor.py b/packages/reflex-base/src/reflex_base/event/processor/base_state_processor.py index c8ad3e5d589..4321e527e62 100644 --- a/packages/reflex-base/src/reflex_base/event/processor/base_state_processor.py +++ b/packages/reflex-base/src/reflex_base/event/processor/base_state_processor.py @@ -2,11 +2,13 @@ from __future__ import annotations +import asyncio import dataclasses import functools import inspect import warnings from collections.abc import Mapping, Sequence +from concurrent.futures import Executor from enum import Enum from importlib.util import find_spec from typing import TYPE_CHECKING, Any @@ -202,11 +204,44 @@ async def chain_updates( await ctx.enqueue(*(e for e in fixed_events if not e.name.startswith("_"))) +@dataclasses.dataclass(frozen=True, slots=True) +class _GeneratorStep: + """One step of advancing a sync generator across the executor boundary. + + ``StopIteration`` cannot be propagated out of an executor — the executor + re-raises it as ``RuntimeError`` and ``StopIteration.value`` (the + generator's return value) is discarded. This wrapper captures both + "yielded a value" and "generator returned" outcomes in a plain result + that crosses the boundary cleanly. + """ + + value: Any = None + done: bool = False + + +def _next_or_done(generator: Any) -> _GeneratorStep: + """Advance a generator one step. + + Args: + generator: The sync generator to advance. + + Returns: + A ``_GeneratorStep`` carrying either the next yielded value, or the + generator's return value with ``done=True`` when the generator is + exhausted. + """ + try: + return _GeneratorStep(value=next(generator)) + except StopIteration as si: + return _GeneratorStep(value=si.value, done=True) + + async def process_event( handler: EventHandler, payload: dict, state: BaseState | StateProxy, root_state: BaseState, + executor: Executor | None = None, ): """Process event. @@ -215,6 +250,10 @@ async def process_event( payload: The event payload. state: State to process the handler. root_state: The root state of the app, used for emitting deltas. + executor: The executor to run a non-async handler in. Async handlers + run on the asyncio loop and ignore this argument. If None, sync + handlers run inline on the asyncio loop (matching pre-executor + behavior). Raises: ValueError: If a string value is received for an int or float type and cannot be converted. @@ -237,7 +276,16 @@ async def process_event( if inspect.iscoroutinefunction(fn.func): events = await fn(**payload) - # Handle regular functions. + # Handle async generators - the function itself returns synchronously + # (yielding an async generator object); only the body iteration is async. + elif inspect.isasyncgenfunction(fn.func): + events = fn(**payload) + + # Handle regular functions - run off the asyncio loop when an executor + # is configured so blocking calls in user code don't stall the loop. + elif executor is not None: + loop = asyncio.get_running_loop() + events = await loop.run_in_executor(executor, functools.partial(fn, **payload)) else: events = fn(**payload) # Handle async generators. @@ -248,18 +296,21 @@ async def process_event( # Handle regular generators. elif inspect.isgenerator(events): - try: - while True: - await chain_updates( - next(events), root_state=root_state, handler_name=handler_name - ) - except StopIteration as si: - # the "return" value of the generator is not available - # in the loop, we must catch StopIteration to access it - if si.value is not None: - await chain_updates( - si.value, root_state=root_state, handler_name=handler_name - ) + loop = asyncio.get_running_loop() if executor is not None else None + while True: + if loop is not None: + step = await loop.run_in_executor(executor, _next_or_done, events) + else: + step = _next_or_done(events) + if step.done: + if step.value is not None: + await chain_updates( + step.value, root_state=root_state, handler_name=handler_name + ) + break + await chain_updates( + step.value, root_state=root_state, handler_name=handler_name + ) await chain_updates(None, root_state=root_state, handler_name=handler_name) # Handle regular event chains. @@ -365,6 +416,7 @@ async def _execute_event( payload=event.payload, state=substate, root_state=root_state, + executor=self.get_executor_for(registered_handler), ) return # Otherwise drop the state lock and start processing the background task with a proxy state. @@ -373,6 +425,7 @@ async def _execute_event( state=StateProxy(substate), payload=event.payload, root_state=root_state, + executor=self.get_executor_for(registered_handler), ) async def _handle_backend_exception( diff --git a/packages/reflex-base/src/reflex_base/event/processor/event_processor.py b/packages/reflex-base/src/reflex_base/event/processor/event_processor.py index 7d9296fe4dd..edf52f0b737 100644 --- a/packages/reflex-base/src/reflex_base/event/processor/event_processor.py +++ b/packages/reflex-base/src/reflex_base/event/processor/event_processor.py @@ -11,6 +11,7 @@ import time import traceback from collections.abc import AsyncGenerator, Callable, Coroutine, Mapping, Sequence +from concurrent.futures import Executor, ThreadPoolExecutor from contextvars import Token, copy_context from typing import TYPE_CHECKING, Any, TypeVar @@ -92,12 +93,14 @@ class EventProcessor: middleware: An optional middleware mixin to apply to all events processed by this processor. backend_exception_handler: An optional function to handle exceptions raised during event processing. The function should take an Exception as input and return an EventSpec or list of EventSpecs to be emitted in response, or None to not emit any events. graceful_shutdown_timeout: An optional amount of time in seconds to wait for the queue to drain before forcefully cancelling tasks when stopping the processor. If None, the processor will not wait and will cancel tasks immediately. + default_executor: The executor used to run non-async event handlers off the asyncio loop. If unset, a ``ThreadPoolExecutor`` is created lazily on first use and shut down when the processor stops. Per-handler executors set via ``@rx.event(executor=...)`` take precedence over this default. _queue: The asyncio queue for events to be processed. _queue_task: The task responsible for processing the event queue. _root_context: The root event context to use for events enqueued without an explicit context. _attached_root_context_token: The context variable token for the attached root context, used to reset the context variable on shutdown. _tasks: A mapping of active transaction ids to their corresponding event handler tasks, used for tracking and cancellation on shutdown. + _owns_default_executor: True when the processor lazily created its own ``default_executor`` and is responsible for shutting it down on stop. """ middleware: MiddlewareMixin | None = None @@ -105,6 +108,7 @@ class EventProcessor: Callable[[Exception], EventSpec | list[EventSpec] | None] | None ) = None graceful_shutdown_timeout: float | None = None + default_executor: Executor | None = None _queue: asyncio.Queue[EventQueueEntry] | None = dataclasses.field( default=None, init=False @@ -124,6 +128,30 @@ class EventProcessor: str, collections.deque[tuple[EventQueueEntry, RegisteredEventHandler]], ] = dataclasses.field(default_factory=dict, init=False) + _owns_default_executor: bool = dataclasses.field(default=False, init=False) + + def get_executor_for(self, handler: RegisteredEventHandler) -> Executor: + """Return the executor used to run a handler's non-async body. + + The handler's per-handler executor (from ``@rx.event(executor=...)``) + takes precedence over the processor's ``default_executor``. If neither + is set, a ``ThreadPoolExecutor`` is created lazily and owned by this + processor. + + Args: + handler: The registered event handler whose executor to resolve. + + Returns: + The executor for running this handler's non-async body. + """ + if (handler_executor := handler.handler.executor) is not None: + return handler_executor + if self.default_executor is None: + self.default_executor = ThreadPoolExecutor( + thread_name_prefix="reflex_event_handler" + ) + self._owns_default_executor = True + return self.default_executor def configure( self, @@ -240,13 +268,26 @@ async def _stop_tasks(self, timeout: float | None = None) -> None: """ finished_tasks = set() # Graceful drain time, wait for tasks to finish and handle any exceptions. - if timeout is not None and self._tasks: - with contextlib.suppress(asyncio.TimeoutError): - async for task in as_completed(self._tasks.values(), timeout=timeout): - # Exceptions are handled in _finish_task and ignored here. - with contextlib.suppress(Exception): - await task - finished_tasks.add(task) + # Re-snapshot self._tasks each iteration so tasks dispatched by + # ``_finish_task`` after a previous task completes (e.g. the next + # entry in a per-token queue) are also awaited within the budget. + if timeout is not None: + deadline = time.monotonic() + timeout + while True: + remaining = deadline - time.monotonic() + if remaining <= 0: + break + pending = [ + task for task in self._tasks.values() if task not in finished_tasks + ] + if not pending: + break + with contextlib.suppress(asyncio.TimeoutError): + async for task in as_completed(pending, timeout=remaining): + # Exceptions are handled in _finish_task and ignored here. + with contextlib.suppress(Exception): + await task + finished_tasks.add(task) # Cancel all outstanding event handler tasks. outstanding_tasks = [ task for task in self._tasks.values() if task not in finished_tasks @@ -316,6 +357,11 @@ async def stop(self, graceful_shutdown_timeout: float | None = None) -> None: if not future.done(): future.cancel() self._futures.clear() + # Shut down the lazily-created default executor (if owned). + if self._owns_default_executor and self.default_executor is not None: + self.default_executor.shutdown(wait=False) + self.default_executor = None + self._owns_default_executor = False async def join( self, timeout: float | None = None, queue: asyncio.Queue | None = None diff --git a/reflex/state.py b/reflex/state.py index e3e959a2d44..ac4eb0ff8ca 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -32,6 +32,7 @@ from reflex_base.event import ( BACKGROUND_TASK_MARKER, EVENT_ACTIONS_MARKER, + EXECUTOR_MARKER, Event, EventHandler, EventSpec, @@ -741,6 +742,8 @@ def _copy_fn(fn: Callable) -> Callable: # Preserve event_actions from @rx.event decorator if event_actions := getattr(fn, EVENT_ACTIONS_MARKER, None): object.__setattr__(newfn, EVENT_ACTIONS_MARKER, event_actions) + if executor := getattr(fn, EXECUTOR_MARKER, None): + object.__setattr__(newfn, EXECUTOR_MARKER, executor) return newfn @staticmethod diff --git a/tests/units/reflex_base/event/processor/test_base_state_processor.py b/tests/units/reflex_base/event/processor/test_base_state_processor.py index e414369a40e..91465d48d14 100644 --- a/tests/units/reflex_base/event/processor/test_base_state_processor.py +++ b/tests/units/reflex_base/event/processor/test_base_state_processor.py @@ -1,7 +1,10 @@ """Tests for BaseStateEventProcessor, specifically the _rehydrate path.""" +import asyncio +import threading import traceback from collections.abc import Mapping +from concurrent.futures import ThreadPoolExecutor from typing import Any import pytest @@ -150,3 +153,102 @@ def noop(self): assert len(hydrated_deltas) >= 1, ( f"Expected at least one delta with is_hydrated=True, got deltas: {emitted_deltas}" ) + + +async def test_sync_handler_runs_off_event_loop( + app_module_mock, + real_base_state_processor: BaseStateEventProcessor, + token: str, +): + """A non-async @rx.event handler runs in a thread, not the asyncio loop. + + The handler captures the thread it ran on and asserts the asyncio loop is + still responsive while the handler is blocking (it blocks for a short + period using a real ``threading.Event``). + + Args: + app_module_mock: The mock app module fixture. + real_base_state_processor: The unmocked BaseStateEventProcessor. + token: The client token. + """ + from reflex.app import App + from reflex.event import Event + from reflex.state import OnLoadInternalState, State + + handler_threads: list[int] = [] + release = threading.Event() + + class MyState(State): + @event + def blocking(self): + handler_threads.append(threading.get_ident()) + # Block the worker thread until the loop signals us. + assert release.wait(timeout=5), ( + "asyncio loop never reached release.set() — sync handler" + " is blocking the event loop instead of running in a thread" + ) + + OnLoadInternalState._app_ref = None + app = app_module_mock.app = App() + assert real_base_state_processor._root_context is not None + app._state_manager = real_base_state_processor._root_context.state_manager + main_thread = threading.get_ident() + + async with real_base_state_processor as processor: + future = await processor.enqueue( + token, Event.from_event_type(MyState.blocking())[0] + ) + # While the handler is blocked, the asyncio loop must remain responsive. + await asyncio.sleep(0.05) + release.set() + await future + + assert len(handler_threads) == 1 + assert handler_threads[0] != main_thread + + +async def test_custom_executor_used_for_sync_handler( + app_module_mock, + real_base_state_processor: BaseStateEventProcessor, + token: str, +): + """Handlers decorated with ``executor=...`` run on that exact pool. + + Args: + app_module_mock: The mock app module fixture. + real_base_state_processor: The unmocked BaseStateEventProcessor. + token: The client token. + """ + from reflex.app import App + from reflex.event import Event + from reflex.state import OnLoadInternalState, State + + custom_executor = ThreadPoolExecutor( + max_workers=1, thread_name_prefix="rx_custom_pool" + ) + handler_thread_names: list[str] = [] + try: + + class MyState(State): + @event(executor=custom_executor) + def on_custom(self): + handler_thread_names.append(threading.current_thread().name) + + OnLoadInternalState._app_ref = None + app = app_module_mock.app = App() + assert real_base_state_processor._root_context is not None + app._state_manager = real_base_state_processor._root_context.state_manager + + async with real_base_state_processor as processor: + await processor.enqueue( + token, Event.from_event_type(MyState.on_custom())[0] + ) + await processor.join(1) + # When a handler brings its own executor, the processor never falls + # back to creating the lazy default. + assert processor.default_executor is None + + assert len(handler_thread_names) == 1 + assert handler_thread_names[0].startswith("rx_custom_pool") + finally: + custom_executor.shutdown(wait=False) diff --git a/tests/units/reflex_base/event/processor/test_event_processor.py b/tests/units/reflex_base/event/processor/test_event_processor.py index bcc1108be98..b0ea9e6a19e 100644 --- a/tests/units/reflex_base/event/processor/test_event_processor.py +++ b/tests/units/reflex_base/event/processor/test_event_processor.py @@ -2,6 +2,7 @@ import asyncio import contextlib +from concurrent.futures import ThreadPoolExecutor from typing import Any import pytest @@ -11,7 +12,7 @@ QueueShutDown, _stream_queue_until_done, ) -from reflex_base.registry import RegistrationContext +from reflex_base.registry import RegisteredEventHandler, RegistrationContext from reflex.event import Event, EventHandler @@ -735,3 +736,99 @@ async def _watcher(): # noqa: RUF029 collected = [v async for v in _stream_queue_until_done(queue, _watcher())] assert collected == [99] + + +def _make_registered_handler(handler: EventHandler) -> RegisteredEventHandler: + """Build a RegisteredEventHandler around a bare EventHandler for testing. + + Args: + handler: The EventHandler to wrap. + + Returns: + A RegisteredEventHandler with empty states. + """ + return RegisteredEventHandler(handler=handler, states=()) + + +def test_get_executor_for_uses_handler_executor(): + """A handler with its own executor wins over the processor default.""" + handler_executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="rx_test") + default_executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="rx_test") + try: + + def _sync_handler(): + pass + + _sync_handler._reflex_event_executor = handler_executor # type: ignore[attr-defined] + processor = EventProcessor(default_executor=default_executor) + chosen = processor.get_executor_for( + _make_registered_handler(EventHandler(fn=_sync_handler)) + ) + assert chosen is handler_executor + # Did not touch the default - still owned externally. + assert processor.default_executor is default_executor + assert processor._owns_default_executor is False + finally: + handler_executor.shutdown(wait=False) + default_executor.shutdown(wait=False) + + +def test_get_executor_for_lazily_creates_default(): + """When no executor is set, the processor lazily creates and owns one.""" + + def _sync_handler(): + pass + + processor = EventProcessor() + assert processor.default_executor is None + chosen = processor.get_executor_for( + _make_registered_handler(EventHandler(fn=_sync_handler)) + ) + assert chosen is processor.default_executor + assert processor._owns_default_executor is True + # Second call returns the same lazy instance. + assert ( + processor.get_executor_for( + _make_registered_handler(EventHandler(fn=_sync_handler)) + ) + is chosen + ) + # Shut it down ourselves since we never started the processor. + chosen.shutdown(wait=False) + + +async def test_stop_shuts_down_owned_default_executor(processor: EventProcessor): + """After stop(), a lazily-created default executor is shut down and cleared. + + Args: + processor: The event processor fixture. + """ + processor.configure() + async with processor: + executor = processor.get_executor_for( + _make_registered_handler(EventHandler(fn=_noop_handler)) + ) + assert processor._owns_default_executor is True + # The owned executor is released on stop(). + assert processor.default_executor is None + assert processor._owns_default_executor is False + # The reference we held is no longer accepting tasks. + with pytest.raises(RuntimeError): + executor.submit(lambda: None) + + +async def test_stop_does_not_shutdown_externally_provided_executor(): + """Externally provided executors are left alone on stop().""" + external = ThreadPoolExecutor(max_workers=1, thread_name_prefix="rx_test") + try: + processor = EventProcessor( + default_executor=external, graceful_shutdown_timeout=1 + ) + processor.configure() + async with processor: + pass + # Still set, still usable. + assert processor.default_executor is external + external.submit(lambda: None).result(timeout=1) + finally: + external.shutdown(wait=False) diff --git a/tests/units/test_event.py b/tests/units/test_event.py index 500dd542097..02153e4cd1f 100644 --- a/tests/units/test_event.py +++ b/tests/units/test_event.py @@ -1,11 +1,13 @@ import json from collections.abc import Callable +from concurrent.futures import ThreadPoolExecutor from typing import Any, cast import pytest from reflex_base.constants.compiler import Hooks, Imports from reflex_base.event import ( BACKGROUND_TASK_MARKER, + EXECUTOR_MARKER, Event, EventChain, EventChainVar, @@ -691,6 +693,8 @@ async def handle_old_background(self): assert isinstance(old_handler, EventHandler) assert old_handler.event_actions == {} assert not hasattr(old_handler.fn, BACKGROUND_TASK_MARKER) # pyright: ignore [reportAttributeAccessIssue] + assert old_handler.executor is None + assert not hasattr(old_handler.fn, EXECUTOR_MARKER) # pyright: ignore [reportAttributeAccessIssue] # Old background parameter should work unchanged bg_handler = MyTestState.handle_old_background @@ -698,6 +702,44 @@ async def handle_old_background(self): assert hasattr(bg_handler.fn, BACKGROUND_TASK_MARKER) # pyright: ignore [reportAttributeAccessIssue] +def test_event_decorator_with_executor(): + """Test that the @rx.event decorator accepts a custom executor.""" + custom_executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="rx_test") + try: + + class MyTestState(BaseState): + @event(executor=custom_executor) + def handle_with_executor(self): + pass + + @event + def handle_default_executor(self): + pass + + @event(executor=custom_executor, throttle=300) + def handle_combined(self): + pass + + # Custom executor should be exposed via the handler. + with_executor = MyTestState.handle_with_executor + assert isinstance(with_executor, EventHandler) + assert with_executor.executor is custom_executor + assert getattr(with_executor.fn, EXECUTOR_MARKER) is custom_executor + + # No executor specified - handler exposes None. + default_handler = MyTestState.handle_default_executor + assert isinstance(default_handler, EventHandler) + assert default_handler.executor is None + + # Combined with other event actions. + combined = MyTestState.handle_combined + assert isinstance(combined, EventHandler) + assert combined.executor is custom_executor + assert combined.event_actions == {"throttle": 300} + finally: + custom_executor.shutdown(wait=False) + + def test_event_var_in_rx_cond(): """Test that EventVar and EventChainVar cannot be used in rx.cond().""" from reflex_components_core.core.cond import cond as rx_cond