Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions packages/reflex-base/src/reflex_base/event/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import warnings
from base64 import b64encode
from collections.abc import Callable, Mapping, Sequence
from concurrent.futures import Executor
from functools import lru_cache, partial
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -177,6 +178,7 @@ def from_event_type(

BACKGROUND_TASK_MARKER = "_reflex_background_task"
EVENT_ACTIONS_MARKER = "_rx_event_actions"
EXECUTOR_MARKER = "_reflex_event_executor"
UPLOAD_FILES_CLIENT_HANDLER = "uploadFiles"


Expand Down Expand Up @@ -426,6 +428,16 @@ def is_background(self) -> bool:
"""
return getattr(self.fn, BACKGROUND_TASK_MARKER, False)

@property
def executor(self) -> Executor | None:
"""The executor to run a non-async handler in, if any.

Returns:
The executor specified via ``@rx.event(executor=...)`` or
``None`` if no executor was specified.
"""
return getattr(self.fn, EXECUTOR_MARKER, None)

def __call__(self, *args: Any, **kwargs: Any) -> "EventSpec":
"""Pass arguments to the handler to get an event spec.

Expand Down Expand Up @@ -2689,6 +2701,7 @@ class EventNamespace:
# Constants
BACKGROUND_TASK_MARKER = BACKGROUND_TASK_MARKER
EVENT_ACTIONS_MARKER = EVENT_ACTIONS_MARKER
EXECUTOR_MARKER = EXECUTOR_MARKER
_EVENT_FIELDS = _EVENT_FIELDS
FORM_DATA = FORM_DATA
upload_files = upload_files
Expand Down Expand Up @@ -2717,6 +2730,7 @@ def __new__(
throttle: int | None = None,
debounce: int | None = None,
temporal: bool | None = None,
executor: Executor | None = None,
) -> (
"Callable[[Callable[[BASE_STATE, Unpack[P]], Any]], EventCallback[Unpack[P]]]"
): ...
Expand All @@ -2732,6 +2746,7 @@ def __new__(
throttle: int | None = None,
debounce: int | None = None,
temporal: bool | None = None,
executor: Executor | None = None,
) -> EventCallback[Unpack[P]]: ...

def __new__(
Expand All @@ -2744,6 +2759,7 @@ def __new__(
throttle: int | None = None,
debounce: int | None = None,
temporal: bool | None = None,
executor: Executor | None = None,
) -> (
EventCallback[Unpack[P]]
| Callable[[Callable[[BASE_STATE, Unpack[P]], Any]], EventCallback[Unpack[P]]]
Expand All @@ -2758,6 +2774,9 @@ def __new__(
throttle: Throttle the event handler to limit calls (in milliseconds).
debounce: Debounce the event handler to delay calls (in milliseconds).
temporal: Whether the event should be dropped when the backend is down.
executor: The executor to run a non-async handler in. If omitted,
the EventProcessor's default thread pool is used. Ignored for
async (coroutine or async-generator) handlers.

Returns:
The wrapped function.
Expand Down Expand Up @@ -2804,6 +2823,8 @@ def wrapper(
msg = "Background task must be async function or generator."
raise TypeError(msg)
setattr(func, BACKGROUND_TASK_MARKER, True)
if executor is not None:
setattr(func, EXECUTOR_MARKER, executor)
if getattr(func, "__name__", "").startswith("_"):
msg = "Event handlers cannot be private."
raise ValueError(msg)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@

from __future__ import annotations

import asyncio
import dataclasses
import functools
import inspect
import warnings
from collections.abc import Mapping, Sequence
from concurrent.futures import Executor
from enum import Enum
from importlib.util import find_spec
from typing import TYPE_CHECKING, Any
Expand Down Expand Up @@ -202,11 +204,44 @@ async def chain_updates(
await ctx.enqueue(*(e for e in fixed_events if not e.name.startswith("_")))


@dataclasses.dataclass(frozen=True, slots=True)
class _GeneratorStep:
"""One step of advancing a sync generator across the executor boundary.

``StopIteration`` cannot be propagated out of an executor — the executor
re-raises it as ``RuntimeError`` and ``StopIteration.value`` (the
generator's return value) is discarded. This wrapper captures both
"yielded a value" and "generator returned" outcomes in a plain result
that crosses the boundary cleanly.
"""

value: Any = None
done: bool = False


def _next_or_done(generator: Any) -> _GeneratorStep:
"""Advance a generator one step.

Args:
generator: The sync generator to advance.

Returns:
A ``_GeneratorStep`` carrying either the next yielded value, or the
generator's return value with ``done=True`` when the generator is
exhausted.
"""
try:
return _GeneratorStep(value=next(generator))
except StopIteration as si:
return _GeneratorStep(value=si.value, done=True)


async def process_event(
handler: EventHandler,
payload: dict,
state: BaseState | StateProxy,
root_state: BaseState,
executor: Executor | None = None,
):
"""Process event.

Expand All @@ -215,6 +250,10 @@ async def process_event(
payload: The event payload.
state: State to process the handler.
root_state: The root state of the app, used for emitting deltas.
executor: The executor to run a non-async handler in. Async handlers
run on the asyncio loop and ignore this argument. If None, sync
handlers run inline on the asyncio loop (matching pre-executor
behavior).

Raises:
ValueError: If a string value is received for an int or float type and cannot be converted.
Expand All @@ -237,7 +276,16 @@ async def process_event(
if inspect.iscoroutinefunction(fn.func):
events = await fn(**payload)

# Handle regular functions.
# Handle async generators - the function itself returns synchronously
# (yielding an async generator object); only the body iteration is async.
elif inspect.isasyncgenfunction(fn.func):
events = fn(**payload)

# Handle regular functions - run off the asyncio loop when an executor
# is configured so blocking calls in user code don't stall the loop.
elif executor is not None:
loop = asyncio.get_running_loop()
events = await loop.run_in_executor(executor, functools.partial(fn, **payload))
else:
events = fn(**payload)
# Handle async generators.
Expand All @@ -248,18 +296,21 @@ async def process_event(

# Handle regular generators.
elif inspect.isgenerator(events):
try:
while True:
await chain_updates(
next(events), root_state=root_state, handler_name=handler_name
)
except StopIteration as si:
# the "return" value of the generator is not available
# in the loop, we must catch StopIteration to access it
if si.value is not None:
await chain_updates(
si.value, root_state=root_state, handler_name=handler_name
)
loop = asyncio.get_running_loop() if executor is not None else None
while True:
if loop is not None:
step = await loop.run_in_executor(executor, _next_or_done, events)
else:
step = _next_or_done(events)
if step.done:
if step.value is not None:
await chain_updates(
step.value, root_state=root_state, handler_name=handler_name
)
break
await chain_updates(
step.value, root_state=root_state, handler_name=handler_name
)
await chain_updates(None, root_state=root_state, handler_name=handler_name)

# Handle regular event chains.
Expand Down Expand Up @@ -365,6 +416,7 @@ async def _execute_event(
payload=event.payload,
state=substate,
root_state=root_state,
executor=self.get_executor_for(registered_handler),
)
return
# Otherwise drop the state lock and start processing the background task with a proxy state.
Expand All @@ -373,6 +425,7 @@ async def _execute_event(
state=StateProxy(substate),
payload=event.payload,
root_state=root_state,
executor=self.get_executor_for(registered_handler),
)

async def _handle_backend_exception(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import time
import traceback
from collections.abc import AsyncGenerator, Callable, Coroutine, Mapping, Sequence
from concurrent.futures import Executor, ThreadPoolExecutor
from contextvars import Token, copy_context
from typing import TYPE_CHECKING, Any, TypeVar

Expand Down Expand Up @@ -92,19 +93,22 @@ class EventProcessor:
middleware: An optional middleware mixin to apply to all events processed by this processor.
backend_exception_handler: An optional function to handle exceptions raised during event processing. The function should take an Exception as input and return an EventSpec or list of EventSpecs to be emitted in response, or None to not emit any events.
graceful_shutdown_timeout: An optional amount of time in seconds to wait for the queue to drain before forcefully cancelling tasks when stopping the processor. If None, the processor will not wait and will cancel tasks immediately.
default_executor: The executor used to run non-async event handlers off the asyncio loop. If unset, a ``ThreadPoolExecutor`` is created lazily on first use and shut down when the processor stops. Per-handler executors set via ``@rx.event(executor=...)`` take precedence over this default.

_queue: The asyncio queue for events to be processed.
_queue_task: The task responsible for processing the event queue.
_root_context: The root event context to use for events enqueued without an explicit context.
_attached_root_context_token: The context variable token for the attached root context, used to reset the context variable on shutdown.
_tasks: A mapping of active transaction ids to their corresponding event handler tasks, used for tracking and cancellation on shutdown.
_owns_default_executor: True when the processor lazily created its own ``default_executor`` and is responsible for shutting it down on stop.
"""

middleware: MiddlewareMixin | None = None
backend_exception_handler: (
Callable[[Exception], EventSpec | list[EventSpec] | None] | None
) = None
graceful_shutdown_timeout: float | None = None
default_executor: Executor | None = None

_queue: asyncio.Queue[EventQueueEntry] | None = dataclasses.field(
default=None, init=False
Expand All @@ -124,6 +128,30 @@ class EventProcessor:
str,
collections.deque[tuple[EventQueueEntry, RegisteredEventHandler]],
] = dataclasses.field(default_factory=dict, init=False)
_owns_default_executor: bool = dataclasses.field(default=False, init=False)

def get_executor_for(self, handler: RegisteredEventHandler) -> Executor:
"""Return the executor used to run a handler's non-async body.

The handler's per-handler executor (from ``@rx.event(executor=...)``)
takes precedence over the processor's ``default_executor``. If neither
is set, a ``ThreadPoolExecutor`` is created lazily and owned by this
processor.

Args:
handler: The registered event handler whose executor to resolve.

Returns:
The executor for running this handler's non-async body.
"""
if (handler_executor := handler.handler.executor) is not None:
return handler_executor
if self.default_executor is None:
self.default_executor = ThreadPoolExecutor(
thread_name_prefix="reflex_event_handler"
)
self._owns_default_executor = True
return self.default_executor

def configure(
self,
Expand Down Expand Up @@ -240,13 +268,26 @@ async def _stop_tasks(self, timeout: float | None = None) -> None:
"""
finished_tasks = set()
# Graceful drain time, wait for tasks to finish and handle any exceptions.
if timeout is not None and self._tasks:
with contextlib.suppress(asyncio.TimeoutError):
async for task in as_completed(self._tasks.values(), timeout=timeout):
# Exceptions are handled in _finish_task and ignored here.
with contextlib.suppress(Exception):
await task
finished_tasks.add(task)
# Re-snapshot self._tasks each iteration so tasks dispatched by
# ``_finish_task`` after a previous task completes (e.g. the next
# entry in a per-token queue) are also awaited within the budget.
if timeout is not None:
deadline = time.monotonic() + timeout
while True:
remaining = deadline - time.monotonic()
if remaining <= 0:
break
pending = [
task for task in self._tasks.values() if task not in finished_tasks
]
if not pending:
break
with contextlib.suppress(asyncio.TimeoutError):
async for task in as_completed(pending, timeout=remaining):
# Exceptions are handled in _finish_task and ignored here.
with contextlib.suppress(Exception):
await task
finished_tasks.add(task)
# Cancel all outstanding event handler tasks.
outstanding_tasks = [
task for task in self._tasks.values() if task not in finished_tasks
Expand Down Expand Up @@ -316,6 +357,11 @@ async def stop(self, graceful_shutdown_timeout: float | None = None) -> None:
if not future.done():
future.cancel()
self._futures.clear()
# Shut down the lazily-created default executor (if owned).
if self._owns_default_executor and self.default_executor is not None:
self.default_executor.shutdown(wait=False)
self.default_executor = None
self._owns_default_executor = False

async def join(
self, timeout: float | None = None, queue: asyncio.Queue | None = None
Expand Down
3 changes: 3 additions & 0 deletions reflex/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from reflex_base.event import (
BACKGROUND_TASK_MARKER,
EVENT_ACTIONS_MARKER,
EXECUTOR_MARKER,
Event,
EventHandler,
EventSpec,
Expand Down Expand Up @@ -741,6 +742,8 @@ def _copy_fn(fn: Callable) -> Callable:
# Preserve event_actions from @rx.event decorator
if event_actions := getattr(fn, EVENT_ACTIONS_MARKER, None):
object.__setattr__(newfn, EVENT_ACTIONS_MARKER, event_actions)
if executor := getattr(fn, EXECUTOR_MARKER, None):
object.__setattr__(newfn, EXECUTOR_MARKER, executor)
return newfn

@staticmethod
Expand Down
Loading
Loading