Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion packages/reflex-base/src/reflex_base/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ class BaseConfig:
react_strict_mode: Whether to use React strict mode.
frontend_packages: Additional frontend packages to install.
state_manager_mode: Indicate which type of state manager to use.
redis_cluster: Whether Redis connections target a Redis Cluster and need hash-tagged keys.
redis_lock_expiration: Maximum expiration lock time for redis state manager.
redis_lock_warning_threshold: Maximum lock time before warning for redis state manager.
redis_token_expiration: Token expiration time for redis state manager.
Expand Down Expand Up @@ -228,6 +229,8 @@ class BaseConfig:

state_manager_mode: constants.StateManagerMode = constants.StateManagerMode.DISK

redis_cluster: bool = False

redis_lock_expiration: int = constants.Expiration.LOCK

redis_lock_warning_threshold: int = constants.Expiration.LOCK_WARNING_THRESHOLD
Expand Down Expand Up @@ -309,7 +312,7 @@ class Config(BaseConfig):
- **Server**: `frontend_port`, `backend_port`, `api_url`, `cors_allowed_origins`
- **Database**: `db_url`, `async_db_url`, `redis_url`
- **Frontend**: `frontend_packages`, `react_strict_mode`
- **State Management**: `state_manager_mode`, `state_auto_setters`
- **State Management**: `state_manager_mode`, `redis_cluster`, `state_auto_setters`
- **Plugins**: `plugins`, `disable_plugins`

See the [configuration docs](https://reflex.dev/docs/advanced-onboarding/configuration) for complete details on all available options.
Expand Down
1 change: 1 addition & 0 deletions reflex/istate/manager/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def create(cls):
return StateManagerRedis(
redis=redis,
token_expiration=config.redis_token_expiration,
redis_cluster=config.redis_cluster,
lock_expiration=config.redis_lock_expiration,
lock_warning_threshold=config.redis_lock_warning_threshold,
)
Expand Down
64 changes: 55 additions & 9 deletions reflex/istate/manager/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
)
from reflex.istate.manager.token import TOKEN_TYPE, BaseStateToken, StateToken
from reflex.state import BaseState
from reflex.utils.redis_keys import format_redis_key, logical_redis_key
from reflex.utils.tasks import ensure_task


Expand All @@ -42,6 +43,15 @@ def _default_lock_expiration() -> int:
return get_config().redis_lock_expiration


def _default_redis_cluster() -> bool:
"""Get whether Redis Cluster key tagging is enabled.

Returns:
Whether Redis Cluster key tagging is enabled.
"""
return get_config().redis_cluster


def _default_lock_warning_threshold() -> int:
"""Get the default lock warning threshold.

Expand Down Expand Up @@ -93,6 +103,9 @@ class StateManagerRedis(StateManager):
# The token expiration time (s).
token_expiration: int = dataclasses.field(default_factory=_default_token_expiration)

# Whether Redis key names should include Redis Cluster hash tags.
redis_cluster: bool = dataclasses.field(default_factory=_default_redis_cluster)

# The maximum time to hold a lock (ms).
lock_expiration: int = dataclasses.field(default_factory=_default_lock_expiration)

Expand Down Expand Up @@ -284,7 +297,7 @@ async def get_state(
token = self._coerce_token(token)
if not isinstance(token, BaseStateToken):
# Non-BaseState token: simple single-key fetch.
redis_data = await self.redis.get(str(token))
redis_data = await self.redis.get(self._state_key(token))
if redis_data is not None:
return token.deserialize(data=redis_data)
return token.cls()
Expand All @@ -305,7 +318,7 @@ async def get_state(

redis_pipeline = self.redis.pipeline()
for state_cls in required_state_classes:
redis_pipeline.get(str(token.with_cls(state_cls)))
redis_pipeline.get(self._state_key(token.with_cls(state_cls)))

for state_cls, redis_state in zip(
required_state_classes,
Expand Down Expand Up @@ -394,7 +407,9 @@ async def set_state(
# Non-BaseState token: simple single-key write.
pickle_state = token.serialize(state)
if pickle_state:
await self.redis.set(str(token), pickle_state, ex=self.token_expiration)
await self.redis.set(
self._state_key(token), pickle_state, ex=self.token_expiration
)
return

base_state = cast(BaseState, state)
Expand Down Expand Up @@ -435,7 +450,7 @@ async def set_state(
pickle_state = base_state._serialize()
if pickle_state:
await self.redis.set(
str(token.with_cls(type(base_state))),
self._state_key(token.with_cls(type(base_state))),
pickle_state,
ex=self.token_expiration,
)
Expand Down Expand Up @@ -730,8 +745,22 @@ async def lease_breaker():
return task
return None

@staticmethod
def _lock_key(token: StateToken[Any]) -> bytes:
def _state_key(self, token: StateToken[Any]) -> str:
"""Get the Redis key for a token's state.

Args:
token: The token to get the state key for.

Returns:
The Redis key for the token's state.
"""
return format_redis_key(
str(token),
cluster=self.redis_cluster,
slot_key=token.lock_key,
)

def _lock_key(self, token: StateToken[Any]) -> bytes:
"""Get the redis key for a token's lock.

Args:
Expand All @@ -740,7 +769,23 @@ def _lock_key(token: StateToken[Any]) -> bytes:
Returns:
The redis lock key for the token.
"""
return f"{token.lock_key}_lock".encode()
return format_redis_key(
f"{token.lock_key}_lock",
cluster=self.redis_cluster,
slot_key=token.lock_key,
).encode()

@staticmethod
def _logical_lock_key(lock_key: bytes) -> str:
"""Get the local logical lock key from a Redis lock key.

Args:
lock_key: The Redis lock key.

Returns:
The logical lock key without Redis Cluster hash tags or suffixes.
"""
return logical_redis_key(lock_key.decode()).rsplit("_lock", 1)[0]

async def _try_extend_lock(self, lock_key: bytes) -> bool | None:
"""Extends the current lock for another lock_expiration period.
Expand Down Expand Up @@ -790,7 +835,8 @@ async def _handle_lock_contention(self, message: RedisPubSubMessage) -> None:
message: The redis message.
"""
# Opportunistic lock contention notification.
token = message["channel"].rsplit(b":", 1)[1][: -len(b"_lock_waiters")].decode()
redis_key = message["channel"].split(b":", 1)[1].decode()
token = logical_redis_key(redis_key).removesuffix("_lock_waiters")
if (
message["data"] == b"sadd"
and (state_lock := self._cached_states_locks.get(token)) is not None
Expand Down Expand Up @@ -1001,7 +1047,7 @@ async def _wait_lock(self, lock_key: bytes, lock_id: bytes) -> None:
lock_key: The redis key for the lock.
lock_id: The ID of the lock.
"""
token = lock_key.decode().rsplit("_lock", 1)[0]
token = self._logical_lock_key(lock_key)
if (
# If there's not a line, try to get the lock immediately.
not self._n_lock_waiters(lock_key)
Expand Down
47 changes: 47 additions & 0 deletions reflex/utils/redis_keys.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""Helpers for Redis key names."""

from __future__ import annotations

import hashlib


def stable_redis_hash_tag(slot_key: str) -> str:
"""Generate a stable Redis Cluster hash tag for a logical slot key.

Args:
slot_key: The logical key that should determine the Redis Cluster slot.

Returns:
A stable hash tag safe to place between Redis Cluster hash braces.
"""
return hashlib.sha256(slot_key.encode()).hexdigest()
Comment thread
riebecj marked this conversation as resolved.
Outdated


def format_redis_key(logical_key: str, *, cluster: bool, slot_key: str) -> str:
"""Format a Redis key, optionally with a Redis Cluster hash tag.

Args:
logical_key: The unmodified logical Redis key.
cluster: Whether Redis Cluster key tagging is enabled.
slot_key: The logical key used to choose the Redis Cluster slot.

Returns:
The Redis key to send to Redis.
"""
if not cluster:
return logical_key
return f"{{{stable_redis_hash_tag(slot_key)}}}:{logical_key}"


def logical_redis_key(redis_key: str) -> str:
"""Remove a Redis Cluster hash tag from a key if present.

Args:
redis_key: The Redis key as sent to or received from Redis.

Returns:
The logical key without any Redis Cluster hash tag prefix.
"""
if redis_key.startswith("{") and "}:" in redis_key:
return redis_key.split("}:", 1)[1]
return redis_key
Comment thread
riebecj marked this conversation as resolved.
Outdated
51 changes: 42 additions & 9 deletions reflex/utils/token_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from reflex.istate.manager.redis import StateManagerRedis
from reflex.state import StateUpdate
from reflex.utils import console, prerequisites
from reflex.utils.redis_keys import format_redis_key, logical_redis_key
from reflex.utils.tasks import ensure_task

if TYPE_CHECKING:
Expand Down Expand Up @@ -197,6 +198,7 @@ def __init__(self, redis: Redis):

config = get_config()
self.token_expiration = config.redis_token_expiration
self.redis_cluster = config.redis_cluster

# Pub/sub tasks for handling sockets owned by other instances.
self._socket_record_task: asyncio.Task | None = None
Expand All @@ -211,7 +213,33 @@ def _get_redis_key(self, token: str) -> str:
Returns:
Redis key following Reflex conventions: token_manager_socket_record_{token}
"""
return f"{self._token_socket_record_prefix}{token}"
return format_redis_key(
f"{self._token_socket_record_prefix}{token}",
cluster=self.redis_cluster,
slot_key=token,
)

def _get_redis_key_pattern(self) -> str:
"""Get Redis key pattern for token mapping keys.

Returns:
Redis key pattern for token mapping keys.
"""
key_pattern = f"{self._token_socket_record_prefix}*"
return f"{{*}}:{key_pattern}" if self.redis_cluster else key_pattern

def _token_from_redis_key(self, redis_key: str) -> str:
"""Extract a client token from a Redis token mapping key.

Args:
redis_key: The Redis key for a token mapping.

Returns:
The client token.
"""
return logical_redis_key(redis_key).removeprefix(
self._token_socket_record_prefix
)

async def enumerate_tokens(self) -> AsyncIterator[str]:
"""Iterate over all tokens in the system.
Expand All @@ -221,11 +249,11 @@ async def enumerate_tokens(self) -> AsyncIterator[str]:
"""
cursor = 0
while scan_result := await self.redis.scan(
cursor=cursor, match=self._get_redis_key("*")
cursor=cursor, match=self._get_redis_key_pattern()
):
cursor = int(scan_result[0])
for key in scan_result[1]:
yield key.decode().replace(self._token_socket_record_prefix, "")
yield self._token_from_redis_key(key.decode())
if not cursor:
break

Expand All @@ -248,17 +276,19 @@ async def _handle_socket_record_del(

async def _subscribe_socket_record_updates(self) -> None:
"""Subscribe to Redis keyspace notifications for socket record updates."""
await StateManagerRedis(redis=self.redis)._enable_keyspace_notifications()
await StateManagerRedis(
redis=self.redis, redis_cluster=self.redis_cluster
)._enable_keyspace_notifications()
redis_db = self.redis.get_connection_kwargs().get("db", 0)

async with self.redis.pubsub() as pubsub:
await pubsub.psubscribe(
f"__keyspace@{redis_db}__:{self._get_redis_key('*')}"
f"__keyspace@{redis_db}__:{self._get_redis_key_pattern()}"
)
async for message in pubsub.listen():
if message["type"] == "pmessage":
key = message["channel"].split(b":", 1)[1].decode()
token = key.replace(self._token_socket_record_prefix, "")
token = self._token_from_redis_key(key)

if token not in self.token_to_socket:
# We don't know about this token, skip
Expand Down Expand Up @@ -357,8 +387,7 @@ async def disconnect_token(self, token: str, sid: str) -> None:
# Clean up local dicts (always do this)
await super().disconnect_token(token, sid)

@staticmethod
def _get_lost_and_found_key(instance_id: str) -> str:
def _get_lost_and_found_key(self, instance_id: str) -> str:
"""Get the Redis key for lost and found deltas for an instance.

Args:
Expand All @@ -367,7 +396,11 @@ def _get_lost_and_found_key(instance_id: str) -> str:
Returns:
The Redis key for lost and found deltas.
"""
return f"token_manager_lost_and_found_{instance_id}"
return format_redis_key(
f"token_manager_lost_and_found_{instance_id}",
cluster=self.redis_cluster,
slot_key=instance_id,
)

async def _subscribe_lost_and_found_updates(
self,
Expand Down
Loading
Loading