diff --git a/docs/app/reflex_docs/templates/docpage/docpage.py b/docs/app/reflex_docs/templates/docpage/docpage.py index 5ba080dcc8c..6393966a525 100644 --- a/docs/app/reflex_docs/templates/docpage/docpage.py +++ b/docs/app/reflex_docs/templates/docpage/docpage.py @@ -250,7 +250,7 @@ def feedback_button_toc() -> rx.Component: @rx.memo -def copy_to_markdown(text: str) -> rx.Component: +def copy_to_markdown(text: rx.Var[str]) -> rx.Component: copied = ClientStateVar.create("is_copied", default=False, global_ref=False) return marketing_button( rx.cond( @@ -297,7 +297,7 @@ def link_pill(text: str, href: str) -> rx.Component: @rx.memo -def docpage_footer(path: str): +def docpage_footer(path: rx.Var[str]) -> rx.Component: from reflex_site_shared.constants import FORUM_URL, ROADMAP_URL return rx.el.footer( diff --git a/docs/enterprise/drag-and-drop.md b/docs/enterprise/drag-and-drop.md index 58c59243c54..9edc26736b6 100644 --- a/docs/enterprise/drag-and-drop.md +++ b/docs/enterprise/drag-and-drop.md @@ -34,7 +34,7 @@ class BasicDndState(rx.State): @rx.memo -def draggable_card(): +def draggable_card() -> rxe.dnd.Draggable: return rxe.dnd.draggable( rx.card( rx.text("Drag me!", weight="bold"), @@ -95,7 +95,7 @@ class MultiPositionState(rx.State): @rx.memo -def movable_card(): +def movable_card() -> rxe.dnd.Draggable: return rxe.dnd.draggable( rx.card( rx.text("Movable Card", weight="bold"), @@ -166,7 +166,7 @@ class StateTrackingState(rx.State): @rx.memo -def tracked_draggable(): +def tracked_draggable() -> rxe.dnd.Draggable: drag_params = rxe.dnd.Draggable.collected_params return rxe.dnd.draggable( rx.card( @@ -274,7 +274,7 @@ class DynamicListState(rx.State): @rx.memo -def draggable_list_item(item: ListItem): +def draggable_list_item(item: rx.Var[ListItem]) -> rx.Component: return rxe.dnd.draggable( rx.card( rx.text(item.text, weight="bold"), diff --git a/docs/enterprise/react_flow/edges.md b/docs/enterprise/react_flow/edges.md index 31bcaf3da4d..dd48b047236 100644 --- a/docs/enterprise/react_flow/edges.md +++ b/docs/enterprise/react_flow/edges.md @@ -165,7 +165,7 @@ def button_edge( sourcePosition: rx.Var[Position], targetPosition: rx.Var[Position], markerEnd: rx.Var[str], -): +) -> rx.Fragment: bezier_path = rxe.components.flow.util.get_bezier_path( source_x=sourceX, source_y=sourceY, diff --git a/docs/enterprise/react_flow/examples.md b/docs/enterprise/react_flow/examples.md index e2c425dbf9e..2d40f84e6e1 100644 --- a/docs/enterprise/react_flow/examples.md +++ b/docs/enterprise/react_flow/examples.md @@ -179,7 +179,7 @@ class ConnectionLimitState(rx.State): @rx.memo def custom_handle( type: rx.Var[HandleType], position: rx.Var[Position], connection_count: rx.Var[int] -): +) -> rxe.components.flow.Handle: connections = rxe.flow.api.get_node_connections() return rxe.flow.handle( type=type, @@ -190,7 +190,7 @@ def custom_handle( @rx.memo -def custom_node(): +def custom_node() -> rx.el.Div: return rx.el.div( custom_handle(type="target", position="left", connection_count=1), rx.el.div("← Only one edge allowed"), diff --git a/docs/enterprise/react_flow/nodes.md b/docs/enterprise/react_flow/nodes.md index 7bb14a1b672..c9a97a96289 100644 --- a/docs/enterprise/react_flow/nodes.md +++ b/docs/enterprise/react_flow/nodes.md @@ -201,7 +201,9 @@ class CustomNodeState(rx.State): @rx.memo -def color_selector_node(data: rx.Var[dict], isConnectable: rx.Var[bool]): +def color_selector_node( + data: rx.Var[dict], isConnectable: rx.Var[bool] +) -> rx.Component: data = data.to(dict) return rx.el.div( rxe.flow.handle( diff --git a/docs/library/data-display/icon.md b/docs/library/data-display/icon.md index 9c8dd449a44..36a9bf78df7 100644 --- a/docs/library/data-display/icon.md +++ b/docs/library/data-display/icon.md @@ -13,7 +13,7 @@ icon_search_cs = ClientStateVar.create("icon_search", default="") @rx.memo -def lucide_icons(): +def lucide_icons() -> rx.Component: return rx.box( rx.box( rx.box( diff --git a/docs/library/other/memo.md b/docs/library/other/memo.md index f108d34a6a2..a3415b6cd00 100644 --- a/docs/library/other/memo.md +++ b/docs/library/other/memo.md @@ -4,21 +4,22 @@ import reflex as rx # Memo -The `memo` decorator is used to optimize component rendering by memoizing components that don't need to be re-rendered. This is particularly useful for expensive components that depend on specific props and don't need to be re-rendered when other state changes in your application. +The `@rx.memo` decorator turns a function into a memoized React component. The compiler emits the function as its own module, and React's `memo` only re-renders it when its declared props change. Reach for it when a subtree is expensive to render and depends on a narrow slice of state. ## Requirements -When using `rx.memo`, you must follow these requirements: +Every parameter must be annotated with `rx.Var[...]` or `rx.RestProp`. The compiler reads those annotations to generate prop names, prop forwarding, and the JS function signature. -1. **Type all arguments**: All arguments to a memoized component must have type annotations. -2. **Use keyword arguments**: When calling a memoized component, you must use keyword arguments (not positional arguments). +1. **`rx.Var[T]` for props** — annotate each prop as `rx.Var[T]` where `T` is the prop's runtime type (`str`, `int`, a TypedDict, etc.). Inside the function body, the parameter is a `Var` you compose into the rendered tree. +2. **`rx.RestProp` for spread props** — at most one parameter may be annotated as `rx.RestProp`, which forwards unrecognized kwargs through to the rendered root. +3. **`rx.Var[rx.Component]` for slot children** — a parameter named `children` annotated as `rx.Var[rx.Component]` accepts children rendered by the caller. +4. **Keyword arguments at the call site** — pass props by name, not by position. -## Basic Usage +Defaults need to be `rx.Var` values. For the common empty cases use the module-level constants `rx.EMPTY_VAR_STR` (an empty string) and `rx.EMPTY_VAR_INT` (zero): `class_name: rx.Var[str] = rx.EMPTY_VAR_STR` falls back to `""` when the caller omits the prop. -When you wrap a component function with `@rx.memo`, the component will only re-render when its props change. This helps improve performance by preventing unnecessary re-renders. +## Basic Usage ```python -# Define a state class to track count class DemoState(rx.State): count: int = 0 @@ -27,150 +28,148 @@ class DemoState(rx.State): self.count += 1 -# Define a memoized component @rx.memo -def expensive_component(label: str) -> rx.Component: +def expensive_component(label: rx.Var[str]) -> rx.Component: return rx.vstack( rx.heading(label), - rx.text("This component only re-renders when props change!"), + rx.text("This component only re-renders when props change."), rx.divider(), ) -# Use the memoized component in your app def index(): return rx.vstack( - rx.heading("Memo Example"), - rx.text("Count: 0"), # This will update with state.count + rx.text(f"Count: {DemoState.count}"), rx.button("Increment", on_click=DemoState.increment), - rx.divider(), - expensive_component(label="Memoized Component"), # Must use keyword arguments - spacing="4", - padding="4", - border_radius="md", - border="1px solid #eaeaea", + expensive_component(label="Memoized Component"), ) ``` -In this example, the `expensive_component` will only re-render when the `label` prop changes, not when the `count` state changes. +`expensive_component` re-renders only when `label` changes — bumping `DemoState.count` does not invalidate it. -## With Event Handlers +## With State Variables -You can also use `rx.memo` with components that have event handlers: +Props can be ordinary Vars. The memoized component re-renders when those Vars change: ```python -# Define a state class to track clicks -class ButtonState(rx.State): - clicks: int = 0 - - @rx.event - def increment(self): - self.clicks += 1 +class AppState(rx.State): + name: str = "World" -# Define a memoized button component @rx.memo -def my_button(text: str, on_click: rx.EventHandler) -> rx.Component: - return rx.button(text, on_click=on_click) +def greeting(name: rx.Var[str]) -> rx.Component: + return rx.heading("Hello, " + name) -# Use the memoized button in your app def index(): return rx.vstack( - rx.text("Clicks: 0"), # This will update with state.clicks - my_button(text="Click me", on_click=ButtonState.increment), - spacing="4", + greeting(name=AppState.name), + rx.input(value=AppState.name, on_change=AppState.set_name), ) ``` -## With State Variables +## Forwarding Props with `rx.RestProp` -When used with state variables, memoized components will only re-render when the specific state variables they depend on change: +Use `rx.RestProp` to accept and forward arbitrary props (think `...rest` in JSX). Useful for thin wrappers that re-style a primitive without redeclaring every prop. ```python -# Define a state class with multiple variables -class AppState(rx.State): - name: str = "World" - count: int = 0 +@rx.memo +def primary_button( + rest: rx.RestProp, + *, + label: rx.Var[str], +) -> rx.Component: + return rx.button(label, class_name="bg-primary-9 text-white", **rest) - @rx.event - def increment(self): - self.count += 1 - @rx.event - def set_name(self, name: str): - self.name = name +def index(): + return primary_button( + label="Save", + on_click=rx.console_log("clicked"), + id="save", + ) +``` +At most one `rx.RestProp` parameter is allowed per memo. -# Define a memoized greeting component +## Accepting Children + +Declare a parameter named `children` typed as `rx.Var[rx.Component]` to receive a child subtree. + +```python @rx.memo -def greeting(name: str) -> rx.Component: - return rx.heading("Hello, " + name) # Will display the name prop +def card( + children: rx.Var[rx.Component], + *, + title: rx.Var[str], +) -> rx.Component: + return rx.box( + rx.heading(title), + children, + class_name="border border-slate-5 rounded-lg p-4", + ) -# Use the memoized component with state variables def index(): - return rx.vstack( - greeting(name=AppState.name), # Must use keyword arguments - rx.text("Count: 0"), # Will display the count - rx.button("Increment Count", on_click=AppState.increment), - rx.input( - placeholder="Enter your name", - on_change=AppState.set_name, - value="World", # Will be bound to AppState.name - ), - spacing="4", + return card( + rx.text("Body copy goes here."), + title="Memoized card", ) ``` -## Advanced Event Handler Example +## Returning a `Var` Instead of a Component -You can also pass arguments to event handlers in memoized components: +A memo function can return `rx.Var[T]` instead of `rx.Component`. The compiler emits a plain JavaScript function and the call site is just a `Var` you can compose into the page. ```python -# Define a state class to track messages -class MessageState(rx.State): - message: str = "" - - @rx.event - def set_message(self, text: str): - self.message = text +class PriceState(rx.State): + amount: int = 100 + currency: str = "USD" -# Define a memoized component with event handlers that pass arguments @rx.memo -def action_buttons( - on_action: rx.EventHandler[rx.event.passthrough_event_spec(str)], -) -> rx.Component: - return rx.hstack( - rx.button("Save", on_click=on_action("Saved!")), - rx.button("Delete", on_click=on_action("Deleted!")), - rx.button("Cancel", on_click=on_action("Cancelled!")), - spacing="2", - ) +def format_price(amount: rx.Var[int], currency: rx.Var[str]) -> rx.Var[str]: + return currency.to(str) + ": $" + amount.to(str) -# Use the memoized component with event handlers def index(): + formatted = format_price(amount=PriceState.amount, currency=PriceState.currency) return rx.vstack( - rx.text("Status: "), # Will display the message - action_buttons(on_action=MessageState.set_message), - spacing="4", + rx.text(formatted), ) ``` +The body of a `Var`-returning memo runs at compile time and is restricted to Var operations — no hooks, no Python branching on the Vars. + ## Performance Considerations -Use `rx.memo` for: +Reach for `rx.memo` when: -- Components with expensive rendering logic -- Components that render the same result given the same props -- Components that re-render too often due to parent component updates +- The component is expensive to render. +- Its output is a stable function of a small set of props. +- A frequently-updating ancestor would otherwise force it to re-render. -Avoid using `rx.memo` for: +Skip it when: + +- The component is cheap and the bookkeeping is not worth it. +- The props change on every render anyway — memo never gets to short-circuit. + +## Migrating from the Old `rx.memo` + +The previous `rx.memo` accepted plain-typed arguments (`def card(title: str)`). The new one requires `rx.Var[...]`. To migrate: + +```python +# Before +@rx.memo +def card(title: str) -> rx.Component: ... + + +# After +@rx.memo +def card(title: rx.Var[str]) -> rx.Component: ... +``` -- Simple components where the memoization overhead might exceed the performance gain -- Components that almost always receive different props on re-render +The old `rx._x.memo` alias still resolves to the new memo and prints a one-time `was promoted to rx.memo` notice. ## API Reference @@ -180,8 +179,8 @@ Avoid using `rx.memo` for: rx.memo(component_fn) ``` -Decorates a function that returns a Reflex component so it can be reused as a memoized component. The function arguments must be type annotated, and memoized components should be called with keyword arguments. +Wraps a function whose parameters are all `rx.Var[...]` or `rx.RestProp`. Returns a callable that constructs the memoized component (or a `Var` if the function's return annotation is `rx.Var[T]`). | Argument | Type | Description | | --- | --- | --- | -| `component_fn` | `Callable[..., rx.Component]` | Function that returns the component to memoize. | +| `component_fn` | `Callable[..., rx.Component \| rx.Var]` | The function to memoize. All parameters must be `rx.Var[...]` or `rx.RestProp`. | diff --git a/packages/reflex-base/src/reflex_base/compiler/templates.py b/packages/reflex-base/src/reflex_base/compiler/templates.py index 08bb95a71d9..838d414b964 100644 --- a/packages/reflex-base/src/reflex_base/compiler/templates.py +++ b/packages/reflex-base/src/reflex_base/compiler/templates.py @@ -785,22 +785,6 @@ def memo_single_function_template( """ -def memo_index_template(reexports: Iterable[tuple[str, str]]) -> str: - """Template for the memo index module that re-exports every memo file. - - Args: - reexports: Iterable of ``(export_name, relative_module_specifier)``. - - Returns: - The rendered index module code. - """ - lines = [ - f'export {{ {export_name} }} from "{specifier}";' - for export_name, specifier in reexports - ] - return "\n".join(lines) + "\n" - - def styles_template(stylesheets: list[str]) -> str: """Template for styles.css. diff --git a/packages/reflex-base/src/reflex_base/components/component.py b/packages/reflex-base/src/reflex_base/components/component.py index 14544841662..6af99f0c449 100644 --- a/packages/reflex-base/src/reflex_base/components/component.py +++ b/packages/reflex-base/src/reflex_base/components/component.py @@ -6,16 +6,14 @@ import dataclasses import enum import functools -import inspect import operator import typing from abc import ABC, ABCMeta, abstractmethod from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence from dataclasses import _MISSING_TYPE, MISSING -from functools import wraps from hashlib import md5 from types import SimpleNamespace -from typing import TYPE_CHECKING, Any, ClassVar, TypeVar, cast, get_args, get_origin +from typing import TYPE_CHECKING, Any, ClassVar, TypeVar, cast from rich.markup import escape from typing_extensions import dataclass_transform @@ -27,18 +25,13 @@ from reflex_base.components.tags import Tag from reflex_base.constants import Dirs, EventTriggers, Hooks, Imports, MemoizationMode from reflex_base.constants.compiler import SpecialAttributes -from reflex_base.constants.state import CAMEL_CASE_MEMO_MARKER from reflex_base.event import ( EventCallback, EventChain, - EventHandler, EventSpec, args_specs_from_fields, no_args_event_spec, - parse_args_spec, pointer_event_spec, - run_script, - unwrap_var_annotation, ) from reflex_base.style import Style, format_as_emotion from reflex_base.utils import console, format, imports, types @@ -51,11 +44,7 @@ Var, cached_property_no_lock, ) -from reflex_base.vars.function import ( - ArgsFunctionOperation, - FunctionStringVar, - FunctionVar, -) +from reflex_base.vars.function import ArgsFunctionOperation, FunctionStringVar from reflex_base.vars.number import ternary_operation from reflex_base.vars.object import ObjectVar from reflex_base.vars.sequence import LiteralArrayVar, LiteralStringVar, StringVar @@ -2179,315 +2168,6 @@ def _get_all_app_wrap_components( return components -class CustomComponent(Component): - """A custom user-defined component.""" - - # Use the components library. - library = f"$/{Dirs.COMPONENTS_PATH}" - - component_fn: Callable[..., Component] = field( - doc="The function that creates the component.", default=Component.create - ) - - props: dict[str, Any] = field( - doc="The props of the component.", default_factory=dict - ) - - def _post_init(self, **kwargs): - """Initialize the custom component. - - Args: - **kwargs: The kwargs to pass to the component. - """ - component_fn = kwargs.get("component_fn") - - # Set the props. - props_types = typing.get_type_hints(component_fn) if component_fn else {} - props = {key: value for key, value in kwargs.items() if key in props_types} - kwargs = {key: value for key, value in kwargs.items() if key not in props_types} - - event_types = { - key - for key in props - if ( - (get_origin((annotation := props_types.get(key))) or annotation) - == EventHandler - ) - } - - def get_args_spec(key: str) -> types.ArgsSpec | Sequence[types.ArgsSpec]: - type_ = props_types[key] - - return ( - args[0] - if (args := get_args(type_)) - else ( - annotation_args[1] - if get_origin( - annotation := inspect.getfullargspec(component_fn).annotations[ - key - ] - ) - is typing.Annotated - and (annotation_args := get_args(annotation)) - else no_args_event_spec - ) - ) - - super()._post_init( - event_triggers={ - key: EventChain.create( - value=props[key], - args_spec=get_args_spec(key), - key=key, - ) - for key in event_types - }, - **kwargs, - ) - - to_camel_cased_props = { - format.to_camel_case(key): None for key in props if key not in event_types - } - self.get_props = lambda: to_camel_cased_props # pyright: ignore [reportIncompatibleVariableOverride] - - # Unset the style. - self.style = Style() - - # Set the tag to the name of the function. - self.tag = format.to_title_case(self.component_fn.__name__) - - for key, value in props.items(): - # Skip kwargs that are not props. - if key not in props_types: - continue - - camel_cased_key = format.to_camel_case(key) - - # Get the type based on the annotation. - type_ = props_types[key] - - # Handle event chains. - if type_ is EventHandler: - inspect.getfullargspec(component_fn).annotations[key] - self.props[camel_cased_key] = EventChain.create( - value=value, args_spec=get_args_spec(key), key=key - ) - continue - - value = LiteralVar.create(value) - self.props[camel_cased_key] = value - setattr(self, camel_cased_key, value) - - def __eq__(self, other: Any) -> bool: - """Check if the component is equal to another. - - Args: - other: The other component. - - Returns: - Whether the component is equal to the other. - """ - return isinstance(other, CustomComponent) and self.tag == other.tag - - def __hash__(self) -> int: - """Get the hash of the component. - - Returns: - The hash of the component. - """ - return hash(self.tag) - - @classmethod - def get_props(cls) -> Iterable[str]: - """Get the props for the component. - - Returns: - The set of component props. - """ - return () - - @staticmethod - def _get_event_spec_from_args_spec(name: str, event: EventChain) -> Callable: - """Get the event spec from the args spec. - - Args: - name: The name of the event - event: The args spec. - - Returns: - The event spec. - """ - - def fn(*args): - return run_script(Var(name).to(FunctionVar).call(*args)) - - if event.args_spec: - arg_spec = ( - event.args_spec - if not isinstance(event.args_spec, Sequence) - else event.args_spec[0] - ) - names = inspect.getfullargspec(arg_spec).args - fn.__signature__ = inspect.Signature( # pyright: ignore[reportFunctionMemberAccess] - parameters=[ - inspect.Parameter( - name=name, - kind=inspect.Parameter.POSITIONAL_ONLY, - annotation=arg._var_type, - ) - for name, arg in zip( - names, parse_args_spec(event.args_spec)[0], strict=True - ) - ] - ) - - return fn - - def get_prop_vars(self) -> list[Var | Callable]: - """Get the prop vars. - - Returns: - The prop vars. - """ - return [ - Var( - _js_expr=name + CAMEL_CASE_MEMO_MARKER, - _var_type=(prop._var_type if isinstance(prop, Var) else type(prop)), - ).guess_type() - if isinstance(prop, Var) or not isinstance(prop, EventChain) - else CustomComponent._get_event_spec_from_args_spec( - name + CAMEL_CASE_MEMO_MARKER, prop - ) - for name, prop in self.props.items() - ] - - @functools.cache # noqa: B019 - def get_component(self) -> Component: - """Render the component. - - Returns: - The code to render the component. - """ - component = self.component_fn(*self.get_prop_vars()) - - try: - from reflex.utils.prerequisites import get_and_validate_app - - style = get_and_validate_app().app.style - except Exception: - style = {} - - component._add_style_recursive(style) - return component - - def _get_all_app_wrap_components( - self, *, ignore_ids: set[int] | None = None - ) -> dict[tuple[int, str], Component]: - """Get the app wrap components for the custom component. - - Args: - ignore_ids: A set of IDs to ignore to avoid infinite recursion. - - Returns: - The app wrap components. - """ - ignore_ids = ignore_ids or set() - component = self.get_component() - if id(component) in ignore_ids: - return {} - ignore_ids.add(id(component)) - return self.get_component()._get_all_app_wrap_components(ignore_ids=ignore_ids) - - -CUSTOM_COMPONENTS: dict[str, CustomComponent] = {} - - -def _register_custom_component( - component_fn: Callable[..., Component], -): - """Register a custom component to be compiled. - - Args: - component_fn: The function that creates the component. - - Returns: - The custom component. - - Raises: - TypeError: If the tag name cannot be determined. - """ - dummy_props = { - prop: ( - Var( - "", - _var_type=unwrap_var_annotation(annotation), - ).guess_type() - if not types.safe_issubclass(annotation, EventHandler) - else EventSpec(handler=EventHandler(fn=no_args_event_spec)) - ) - for prop, annotation in typing.get_type_hints(component_fn).items() - if prop != "return" - } - dummy_component = CustomComponent._create( - children=[], - component_fn=component_fn, - **dummy_props, - ) - if dummy_component.tag is None: - msg = f"Could not determine the tag name for {component_fn!r}" - raise TypeError(msg) - CUSTOM_COMPONENTS[dummy_component.tag] = dummy_component - return dummy_component - - -def custom_component( - component_fn: Callable[..., Component], -) -> Callable[..., CustomComponent]: - """Create a custom component from a function. - - Args: - component_fn: The function that creates the component. - - Returns: - The decorated function. - """ - - @wraps(component_fn) - def wrapper(*children, **props) -> CustomComponent: - # Remove the children from the props. - props.pop("children", None) - return CustomComponent._create( - children=list(children), component_fn=component_fn, **props - ) - - # Register this component so it can be compiled. - dummy_component = _register_custom_component(component_fn) - if tag := dummy_component.tag: - object.__setattr__( - wrapper, - "_as_var", - lambda: Var( - tag, - _var_type=type[Component], - _var_data=VarData( - imports={ - f"$/{constants.Dirs.UTILS}/components": [ImportVar(tag=tag)], - "@emotion/react": [ - ImportVar(tag="jsx"), - ], - } - ), - ), - ) - - return wrapper - - -# Alias memo to custom_component. -memo = custom_component - - class NoSSRComponent(Component): """A dynamic component that is not rendered on the server.""" diff --git a/packages/reflex-base/src/reflex_base/components/dynamic.py b/packages/reflex-base/src/reflex_base/components/dynamic.py index a668bd341fc..762dd93aed1 100644 --- a/packages/reflex-base/src/reflex_base/components/dynamic.py +++ b/packages/reflex-base/src/reflex_base/components/dynamic.py @@ -31,7 +31,6 @@ def get_cdn_url(lib: str) -> str: "@emotion/react", f"$/{constants.Dirs.UTILS}/context", f"$/{constants.Dirs.UTILS}/state", - f"$/{constants.Dirs.UTILS}/components", ] bundled_libraries = list(DEFAULT_BUNDLED_LIBRARIES) diff --git a/packages/reflex-base/src/reflex_base/components/memo.py b/packages/reflex-base/src/reflex_base/components/memo.py new file mode 100644 index 00000000000..0f8b0218023 --- /dev/null +++ b/packages/reflex-base/src/reflex_base/components/memo.py @@ -0,0 +1,1769 @@ +"""Memo support for vars and components.""" + +from __future__ import annotations + +import contextlib +import dataclasses +import importlib +import inspect +from collections.abc import Callable, Iterator, Mapping, Sequence +from copy import copy +from enum import Enum +from functools import cache, update_wrapper +from typing import ( + Annotated, + Any, + ClassVar, + TypeVar, + get_args, + get_origin, + get_type_hints, + overload, +) + +from reflex_components_core.base.fragment import Fragment + +from reflex_base import constants +from reflex_base.components.component import Component +from reflex_base.components.dynamic import bundled_libraries +from reflex_base.components.memoize_helpers import ( + MemoizationStrategy, + get_memoization_strategy, +) +from reflex_base.constants.compiler import ( + MemoizationDisposition, + MemoizationMode, + SpecialAttributes, +) +from reflex_base.constants.state import CAMEL_CASE_MEMO_MARKER +from reflex_base.event import EventChain, EventHandler, no_args_event_spec, run_script +from reflex_base.utils import console, format +from reflex_base.utils.imports import ImportVar +from reflex_base.utils.types import safe_issubclass, typehint_issubclass +from reflex_base.vars import VarData +from reflex_base.vars.base import LiteralVar, Var +from reflex_base.vars.function import ( + ArgsFunctionOperation, + DestructuredArg, + FunctionStringVar, + FunctionVar, + ReflexCallable, +) +from reflex_base.vars.object import RestProp + + +class MemoParamKind(str, Enum): + """The role a memo parameter plays in the compiled component. + + Each kind owns its full behavior — annotation classification, call-site + validation, placeholder construction, runtime binding, and JS signature + emission — via the per-kind :class:`_MemoParamSpec` instance in + :data:`_SPECS`. Adding a new kind means one new entry in :data:`_SPECS` + and one extra step in :data:`_CLASSIFICATION_ORDER`; the rest of the + module learns nothing else about the new kind. + """ + + VALUE = "value" + CHILDREN = "children" + REST = "rest" + EVENT_TRIGGER = "event_trigger" + + +@dataclasses.dataclass(frozen=True, slots=True, kw_only=True) +class MemoParam: + """Metadata about an analyzed memo parameter.""" + + name: str + kind: MemoParamKind + annotation: Any + parameter_kind: inspect._ParameterKind + js_prop_name: str + placeholder_name: str + kind_data: Any = None + default: Any = inspect.Parameter.empty + + @property + def spec(self) -> _MemoParamSpec: + """The per-kind behavior bundle for this parameter.""" + return _SPECS[self.kind] + + def make_placeholder(self) -> Any: + """Build the value passed to the memo function during analysis. + + Returns: + The placeholder value (a ``Var``, ``RestProp``, or plain callable). + """ + return self.spec.make_placeholder(self) + + def bind_call_value(self, binding: _MemoCallBinding) -> None: + """Route a user-provided value to props/event_triggers at instantiation. + + Args: + binding: The call-site routing accumulator. + """ + self.spec.bind_call_value(self, binding) + + def signature_field(self) -> str | None: + """The destructured JSX signature entry, or ``None`` if emitted elsewhere. + + Returns: + The destructured field (e.g. ``"event:eventRxMemo"``), or ``None`` + when this kind is emitted out-of-band by the compiler. + """ + return self.spec.signature_field(self) + + +@dataclasses.dataclass(frozen=True, slots=True) +class _MemoParamSpec: + """The role-owned behavior for one :class:`MemoParamKind`. + + Hooks (in classification + lifecycle order): + ``classify``: ``(annotation, param_name) -> (matches, kind_data)``. + Returns whether the annotation belongs to this kind, plus any + kind-specific payload (the args spec for ``EVENT_TRIGGER``). + ``validate``: ``(inspect.Parameter, fn_name, for_component) -> None``. + Raise ``TypeError`` for misuses (no defaults on EH, ``children`` + naming, rest-on-var-memo, etc.). + ``placeholder_name``: choose the destructured JS identifier (Var/EH + use ``camelCase + RxMemo``; children/rest keep the bare name). + ``make_placeholder``: build the analysis-time value passed to the memo + body function (a ``Var``, a ``RestProp``, or a plain callable). + ``bind_call_value``: at instantiation, pop the user value from kwargs + and route it via ``_MemoCallBinding`` to props or event_triggers. + ``signature_field``: the destructured JSX entry, or ``None`` for kinds + emitted out-of-band (REST -> spread; CHILDREN -> hardcoded prefix). + """ + + kind: MemoParamKind + classify: Callable[[Any, str], tuple[bool, Any]] + validate: Callable[[inspect.Parameter, str, bool], None] + placeholder_name: Callable[[str, str, bool], str] + make_placeholder: Callable[[MemoParam], Any] + bind_call_value: Callable[[MemoParam, _MemoCallBinding], None] + signature_field: Callable[[MemoParam], str | None] + + +@dataclasses.dataclass(frozen=True, slots=True) +class MemoDefinition: + """Base metadata for a memo.""" + + fn: Callable[..., Any] + python_name: str + params: tuple[MemoParam, ...] + + +@dataclasses.dataclass(frozen=True, slots=True) +class MemoFunctionDefinition(MemoDefinition): + """A memo that compiles to a JavaScript function.""" + + function: ArgsFunctionOperation + imported_var: FunctionVar + + +@dataclasses.dataclass(frozen=True, slots=True) +class MemoComponentDefinition(MemoDefinition): + """A memo that compiles to a React component.""" + + export_name: str + component: Component + # For passthrough wrappers built by the auto-memoize plugin: the + # ``Bare``-wrapped ``{children}`` placeholder used when rendering the memo + # body. The ``component`` keeps its ORIGINAL children so compile-time + # walkers (``Form._get_form_refs`` etc.) can introspect the subtree; the + # compiler swaps to this placeholder only for the JSX render and for + # imports collection, so descendants emit their refs/imports/hooks in the + # page scope rather than being duplicated inside the memo body. + passthrough_hole_child: Component | None = None + + +class MemoComponent(Component): + """A rendered instance of a memo component.""" + + library = f"$/{constants.Dirs.COMPONENTS_PATH}" + _memoization_mode = MemoizationMode(disposition=MemoizationDisposition.NEVER) + + # The user-authored component class this wrapper stands in for. Populated + # on the dynamic subclass by ``_get_memo_component_class`` so + # introspection (e.g. compile telemetry) can recover the underlying type + # without parsing the wrapper's auto-generated class name. + _wrapped_component_type: ClassVar[type[Component] | None] = None + + def _validate_component_children(self, children: list[Component]) -> None: + """Skip direct parent/child validation for memo wrapper instances. + + Memos wrap an underlying compiled component definition. + The runtime wrapper should not interpose on `_valid_parents` checks for + the authored subtree because the wrapper itself is not the semantic + parent in the user-authored component tree. + + Args: + children: The children of the component (ignored). + """ + + def _post_init(self, **kwargs): + """Initialize the memo component. + + Args: + **kwargs: The kwargs to pass to the component. + """ + definition = kwargs.pop("memo_definition") + binding = _MemoCallBinding(kwargs) + + for param in definition.params: + param.bind_call_value(binding) + + has_rest = _get_rest_param(definition.params) is not None + rest_props = binding.take_rest(self.get_fields()) if has_rest else {} + + super()._post_init(**binding.build_super_kwargs()) + + prop_names = binding.finalize(self, rest_props) + object.__setattr__(self, "get_props", lambda: prop_names) + + +@cache +def _get_memo_component_class( + export_name: str, + wrapped_component_type: type[Component] = Component, +) -> type[MemoComponent]: + """Get the component subclass for a memo export. + + Class-level metadata that the compiler reads via ``type(comp)._get_*()`` + (notably ``_get_app_wrap_components``, which carries providers like + ``UploadFilesProvider`` that must reach the app root) is inherited from + ``wrapped_component_type`` so the wrapper is a transparent substitute for + the original in the compile tree. + + Args: + export_name: The exported React component name. + wrapped_component_type: The class of the component being memoized. + Defaults to ``Component`` for memos that don't wrap a user + component (e.g. function memos, raw passthroughs). + + Returns: + A cached component subclass with the tag set at class definition time. + """ + attrs: dict[str, Any] = { + "__module__": __name__, + "tag": export_name, + # Point each memo at its own per-file module so pages import directly + # from ``$/utils/components/`` rather than through the index. + # Per-file import paths give Vite distinct module boundaries per + # memo, enabling actual code-split by page. + "library": f"$/{constants.Dirs.COMPONENTS_PATH}/{export_name}", + "_wrapped_component_type": wrapped_component_type, + } + if ( + wrapped_component_type._get_app_wrap_components + is not Component._get_app_wrap_components + ): + attrs["_get_app_wrap_components"] = staticmethod( + wrapped_component_type._get_app_wrap_components + ) + return type( + f"MemoComponent_{export_name}", + (MemoComponent,), + attrs, + ) + + +MEMOS: dict[str, MemoDefinition] = {} + + +def _memo_registry_key(definition: MemoDefinition) -> str: + """Get the registry key for a memo. + + Args: + definition: The memo definition. + + Returns: + The registry key for the memo. + """ + if isinstance(definition, MemoComponentDefinition): + return definition.export_name + return definition.python_name + + +def _is_memo_reregistration( + existing: MemoDefinition, + definition: MemoDefinition, +) -> bool: + """Check whether a memo definition replaces the same memo during reload. + + Args: + existing: The currently registered memo definition. + definition: The new memo definition being registered. + + Returns: + Whether the new definition should replace the existing one. + """ + return ( + type(existing) is type(definition) + and existing.python_name == definition.python_name + and existing.fn.__module__ == definition.fn.__module__ + and existing.fn.__qualname__ == definition.fn.__qualname__ + ) + + +def _register_memo_definition(definition: MemoDefinition) -> None: + """Register a memo definition. + + Args: + definition: The memo definition to register. + + Raises: + ValueError: If another memo already compiles to the same exported name. + """ + key = _memo_registry_key(definition) + if (existing := MEMOS.get(key)) is not None and ( + not _is_memo_reregistration(existing, definition) + ): + msg = ( + f"Memo name collision for `{key}`: " + f"`{existing.fn.__module__}.{existing.python_name}` and " + f"`{definition.fn.__module__}.{definition.python_name}` both compile " + "to the same memo name." + ) + raise ValueError(msg) + + MEMOS[key] = definition + + +def _annotation_inner_type(annotation: Any) -> Any: + """Unwrap a Var-like annotation to its inner type. + + Args: + annotation: The annotation to unwrap. + + Returns: + The inner type for the annotation. + """ + if _is_rest_annotation(annotation): + return dict[str, Any] + + annotation = _strip_annotated(annotation) + origin = get_origin(annotation) or annotation + if safe_issubclass(origin, Var) and (args := get_args(annotation)): + return args[0] + return Any + + +def _strip_annotated(annotation: Any) -> Any: + """Unwrap ``Annotated[X, ...]`` to ``X``; pass other annotations through. + + Args: + annotation: The annotation to unwrap. + + Returns: + The inner annotation, or the original if not ``Annotated``. + """ + if get_origin(annotation) is Annotated: + return get_args(annotation)[0] + return annotation + + +def _is_rest_annotation(annotation: Any) -> bool: + """Check whether an annotation is a RestProp. + + Args: + annotation: The annotation to check. + + Returns: + Whether the annotation is a RestProp. + """ + annotation = _strip_annotated(annotation) + origin = get_origin(annotation) or annotation + return isinstance(origin, type) and issubclass(origin, RestProp) + + +def _is_var_annotation(annotation: Any) -> bool: + """Check whether an annotation is a Var-like annotation. + + Args: + annotation: The annotation to check. + + Returns: + Whether the annotation is Var-like. + """ + annotation = _strip_annotated(annotation) + origin = get_origin(annotation) or annotation + return isinstance(origin, type) and issubclass(origin, Var) + + +def _is_event_handler_annotation(annotation: Any) -> tuple[bool, Any]: + """Detect ``EventHandler`` / ``EventHandler[spec]`` / ``EventHandler[s1, s2]``. + + ``EventHandler.__class_getitem__`` returns ``Annotated[EventHandler, spec]`` for a + single spec and ``Annotated[EventHandler, (s1, s2)]`` (a tuple in the single + metadata slot) for multiple specs. + + Args: + annotation: The annotation to inspect. + + Returns: + ``(is_event_handler, args_spec)`` — ``args_spec`` is ``no_args_event_spec`` for + bare ``EventHandler``, a single spec callable for ``EventHandler[spec]``, or + the tuple of specs for the multi-spec form. + """ + if get_origin(annotation) is Annotated: + inner, *metadata = get_args(annotation) + if isinstance(inner, type) and safe_issubclass(inner, EventHandler): + return True, metadata[0] + return False, None + if isinstance(annotation, type) and safe_issubclass(annotation, EventHandler): + return True, no_args_event_spec + return False, None + + +def _is_component_annotation(annotation: Any) -> bool: + """Check whether an annotation is component-like. + + Args: + annotation: The annotation to check. + + Returns: + Whether the annotation resolves to Component. + """ + annotation = _strip_annotated(annotation) + origin = get_origin(annotation) or annotation + return isinstance(origin, type) and ( + safe_issubclass(origin, Component) + or bool( + safe_issubclass(origin, Var) + and (args := get_args(annotation)) + and safe_issubclass(args[0], Component) + ) + ) + + +def _children_annotation_is_valid(annotation: Any) -> bool: + """Check whether an annotation is valid for children. + + Args: + annotation: The annotation to check. + + Returns: + Whether the annotation is valid for children. + """ + return _is_var_annotation(annotation) and typehint_issubclass( + _annotation_inner_type(annotation), Component + ) + + +def _get_children_param(params: tuple[MemoParam, ...]) -> MemoParam | None: + return next((p for p in params if p.kind is MemoParamKind.CHILDREN), None) + + +def _get_rest_param(params: tuple[MemoParam, ...]) -> MemoParam | None: + return next((p for p in params if p.kind is MemoParamKind.REST), None) + + +def _imported_function_var(name: str, return_type: Any) -> FunctionVar: + """Create the imported FunctionVar for a memo. + + Args: + name: The exported function name. + return_type: The return type of the function. + + Returns: + The imported FunctionVar. + """ + return FunctionStringVar.create( + name, + _var_type=ReflexCallable[Any, return_type], + _var_data=VarData( + imports={ + f"$/{constants.Dirs.COMPONENTS_PATH}/{name}": [ImportVar(tag=name)] + } + ), + ) + + +def _component_import_var(name: str) -> Var: + """Create the imported component var for a memo component. + + Args: + name: The exported component name. + + Returns: + The component var. + """ + return Var( + name, + _var_type=type[Component], + _var_data=VarData( + imports={ + f"$/{constants.Dirs.COMPONENTS_PATH}/{name}": [ImportVar(tag=name)], + "@emotion/react": [ImportVar(tag="jsx")], + } + ), + ) + + +def _validate_var_return_expr(return_expr: Var, func_name: str) -> None: + """Validate that a var-returning memo can compile safely. + + Args: + return_expr: The return expression. + func_name: The function name for error messages. + + Raises: + TypeError: If the return expression depends on unsupported features. + """ + var_data = VarData.merge(return_expr._get_all_var_data()) + if var_data is None: + return + + if var_data.hooks: + msg = ( + f"Var-returning `@rx.memo` `{func_name}` cannot depend on hooks. " + "Use a component-returning `@rx.memo` instead." + ) + raise TypeError(msg) + + if var_data.components: + msg = ( + f"Var-returning `@rx.memo` `{func_name}` cannot depend on embedded " + "components, custom code, or dynamic imports. Use a component-returning " + "`@rx.memo` instead." + ) + raise TypeError(msg) + + for lib in dict(var_data.imports): + if not lib: + continue + if lib.startswith((".", "/", "$/", "http")): + continue + if format.format_library_name(lib) in bundled_libraries: + continue + msg = ( + f"Var-returning `@rx.memo` `{func_name}` cannot import `{lib}` because " + "it is not bundled. Use a component-returning `@rx.memo` instead." + ) + raise TypeError(msg) + + +def _rest_placeholder(name: str) -> RestProp: + """Create the placeholder RestProp. + + Args: + name: The JavaScript identifier. + + Returns: + The placeholder rest prop. + """ + return RestProp(_js_expr=name, _var_type=dict[str, Any]) + + +def _var_placeholder(name: str, annotation: Any) -> Var: + """Create a placeholder Var for a memo parameter. + + Args: + name: The JavaScript identifier. + annotation: The parameter annotation. + + Returns: + The placeholder Var. + """ + return Var(_js_expr=name, _var_type=_annotation_inner_type(annotation)).guess_type() + + +def _event_handler_placeholder(placeholder_name: str, args_spec: Any) -> Callable: + """Placeholder callable that compiles calls to the destructured JS prop. + + Returned as a plain callable (not an ``EventHandler``) so it flows through + ``EventChain.create`` -> ``call_event_fn``, which actually invokes it. + Wrapping in an ``EventHandler`` would skip the function body and bake the + Python function name into the rendered ``ReflexEvent(...)`` payload. + + Args: + placeholder_name: The destructured JS prop identifier (e.g. ``eventRxMemo``). + args_spec: The user-declared spec, or a tuple of specs from + ``EventHandler[s1, s2]``. Only the first spec shapes the placeholder's + signature; the inner-trigger boundary handles the rest. + + Returns: + A plain callable suitable as a memo-function placeholder. + """ + prop_callback = Var(_js_expr=placeholder_name).to(FunctionVar) + primary_spec = args_spec[0] if isinstance(args_spec, tuple) else args_spec + + def _placeholder(*args: Any) -> Any: + return run_script(prop_callback.call(*args)) + + _placeholder.__signature__ = inspect.signature(primary_spec) # pyright: ignore[reportFunctionMemberAccess] + return _placeholder + + +def _classify_value(annotation: Any, name: str) -> tuple[bool, Any]: + # ``RestProp`` is a ``Var`` subclass, so guard against it here even though + # ``_CLASSIFICATION_ORDER`` already tries REST first — keeping the classifier + # self-exclusive removes the implicit ordering dependency. + return ( + _is_var_annotation(annotation) and not _is_rest_annotation(annotation), + None, + ) + + +def _classify_children(annotation: Any, name: str) -> tuple[bool, Any]: + return ( + name == "children" and _children_annotation_is_valid(annotation), + None, + ) + + +def _classify_rest(annotation: Any, name: str) -> tuple[bool, Any]: + return _is_rest_annotation(annotation), None + + +def _classify_event_trigger(annotation: Any, name: str) -> tuple[bool, Any]: + return _is_event_handler_annotation(annotation) + + +def _validate_noop( + parameter: inspect.Parameter, fn_name: str, for_component: bool +) -> None: + pass + + +def _validate_children( + parameter: inspect.Parameter, fn_name: str, for_component: bool +) -> None: + if parameter.name != "children": + msg = ( + f"`rx.Var[rx.Component]` parameters in `{fn_name}` must be named " + "`children`." + ) + raise TypeError(msg) + + +def _validate_rest( + parameter: inspect.Parameter, fn_name: str, for_component: bool +) -> None: + if parameter.name == "children": + msg = f"`children` in `{fn_name}` cannot be `rx.RestProp`." + raise TypeError(msg) + + +def _validate_event_trigger( + parameter: inspect.Parameter, fn_name: str, for_component: bool +) -> None: + if not for_component: + msg = ( + f"`rx.EventHandler` parameters are only supported on component-" + f"returning memos. Got `{parameter.name}` in `{fn_name}`." + ) + raise TypeError(msg) + if parameter.name == "children": + msg = ( + f"`children` in `{fn_name}` cannot be an `rx.EventHandler`; " + "use `rx.Var[rx.Component]`." + ) + raise TypeError(msg) + if parameter.default is not inspect.Parameter.empty: + msg = ( + f"`rx.EventHandler` parameter `{parameter.name}` in `{fn_name}` " + "must not have a default value." + ) + raise TypeError(msg) + + +def _placeholder_name_value(name: str, js_prop_name: str, for_component: bool) -> str: + return js_prop_name + CAMEL_CASE_MEMO_MARKER if for_component else name + + +def _placeholder_name_passthrough( + name: str, js_prop_name: str, for_component: bool +) -> str: + return name + + +def _make_value_placeholder(param: MemoParam) -> Var: + return _var_placeholder(param.placeholder_name, param.annotation) + + +def _make_rest_placeholder_spec(param: MemoParam) -> RestProp: + return _rest_placeholder(param.placeholder_name) + + +def _make_event_trigger_placeholder(param: MemoParam) -> Callable[..., Any]: + return _event_handler_placeholder(param.placeholder_name, param.kind_data) + + +def _bind_value(param: MemoParam, binding: _MemoCallBinding) -> None: + if param.name in binding.raw_kwargs: + binding.add_prop(param.js_prop_name, binding.take(param.name)) + + +def _bind_children(param: MemoParam, binding: _MemoCallBinding) -> None: + pass + + +def _bind_rest(param: MemoParam, binding: _MemoCallBinding) -> None: + pass + + +def _bind_event_trigger(param: MemoParam, binding: _MemoCallBinding) -> None: + if param.name in binding.raw_kwargs: + binding.add_event_trigger( + param.js_prop_name, binding.take(param.name), param.kind_data + ) + + +def _signature_destructured(param: MemoParam) -> str: + return f"{param.js_prop_name}:{param.placeholder_name}" + + +def _signature_none(param: MemoParam) -> None: + return None + + +_SPECS: dict[MemoParamKind, _MemoParamSpec] = { + MemoParamKind.VALUE: _MemoParamSpec( + kind=MemoParamKind.VALUE, + classify=_classify_value, + validate=_validate_noop, + placeholder_name=_placeholder_name_value, + make_placeholder=_make_value_placeholder, + bind_call_value=_bind_value, + signature_field=_signature_destructured, + ), + MemoParamKind.CHILDREN: _MemoParamSpec( + kind=MemoParamKind.CHILDREN, + classify=_classify_children, + validate=_validate_children, + placeholder_name=_placeholder_name_passthrough, + make_placeholder=_make_value_placeholder, + bind_call_value=_bind_children, + signature_field=_signature_none, + ), + MemoParamKind.REST: _MemoParamSpec( + kind=MemoParamKind.REST, + classify=_classify_rest, + validate=_validate_rest, + placeholder_name=_placeholder_name_passthrough, + make_placeholder=_make_rest_placeholder_spec, + bind_call_value=_bind_rest, + signature_field=_signature_none, + ), + MemoParamKind.EVENT_TRIGGER: _MemoParamSpec( + kind=MemoParamKind.EVENT_TRIGGER, + classify=_classify_event_trigger, + validate=_validate_event_trigger, + placeholder_name=_placeholder_name_value, + make_placeholder=_make_event_trigger_placeholder, + bind_call_value=_bind_event_trigger, + signature_field=_signature_destructured, + ), +} + +# Order matters: REST and CHILDREN before VALUE (``Var[Component]`` matches +# VALUE's classifier, so children must be tried first). EVENT_TRIGGER is +# independent (``Annotated[EventHandler, ...]`` is not a Var), but listing it +# before VALUE makes the precedence explicit. VALUE is the open fallback. +_CLASSIFICATION_ORDER: tuple[MemoParamKind, ...] = ( + MemoParamKind.REST, + MemoParamKind.CHILDREN, + MemoParamKind.EVENT_TRIGGER, + MemoParamKind.VALUE, +) + + +class _MemoCallBinding: + """Accumulates routing decisions for one memo component instantiation. + + Role specs call :meth:`take`, :meth:`add_prop`, and :meth:`add_event_trigger` + via ``param.bind_call_value(binding)``. The component then calls + :meth:`build_super_kwargs` (what :meth:`Component._post_init` should see) and + :meth:`finalize` (apply collected props as attributes after super returns). + """ + + __slots__ = ("_event_triggers", "_props", "raw_kwargs") + + def __init__(self, raw_kwargs: dict[str, Any]) -> None: + self.raw_kwargs = raw_kwargs + self._props: dict[str, Any] = {} + self._event_triggers: dict[str, EventChain | Var] = {} + + def take(self, key: str) -> Any: + return self.raw_kwargs.pop(key) + + def add_prop(self, js_prop_name: str, value: Any) -> None: + self._props[js_prop_name] = LiteralVar.create(value) + + def add_event_trigger(self, js_prop_name: str, value: Any, args_spec: Any) -> None: + self._event_triggers[js_prop_name] = EventChain.create( + value=value, args_spec=args_spec, key=js_prop_name + ) + + def take_rest(self, component_fields: Mapping[str, Any]) -> dict[str, Any]: + rest: dict[str, Any] = {} + for key in list(self.raw_kwargs): + if key in component_fields or SpecialAttributes.is_special(key): + continue + rest[format.to_camel_case(key)] = LiteralVar.create( + self.raw_kwargs.pop(key) + ) + return rest + + def build_super_kwargs(self) -> dict[str, Any]: + """Merge collected event triggers into raw kwargs for ``super()._post_init``. + + Mutates ``raw_kwargs`` in place. Call once per instantiation. + + Returns: + The kwargs to forward to ``Component._post_init``. + """ + if self._event_triggers: + self.raw_kwargs.setdefault("event_triggers", {}).update( + self._event_triggers + ) + return self.raw_kwargs + + def finalize( + self, component: Component, rest_props: dict[str, Any] + ) -> tuple[str, ...]: + all_props = {**self._props, **rest_props} + for key, value in all_props.items(): + setattr(component, key, value) + return tuple(all_props) + + +def _evaluate_memo_function( + fn: Callable[..., Any], + params: tuple[MemoParam, ...], +) -> Any: + """Evaluate a memo function with placeholder vars. + + Args: + fn: The function to evaluate. + params: The memo parameters. + + Returns: + The return value from the function. + """ + positional_args = [] + keyword_args = {} + + for param in params: + placeholder = param.make_placeholder() + if param.parameter_kind in ( + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + ): + positional_args.append(placeholder) + else: + keyword_args[param.name] = placeholder + + return fn(*positional_args, **keyword_args) + + +def _normalize_component_return(value: Any) -> Component | None: + """Normalize a component-like memo return value into a Component. + + Args: + value: The value returned from the memo function. + + Returns: + The normalized component, or ``None`` if the value is not component-like. + """ + if isinstance(value, Component): + return value + + if isinstance(value, Var) and typehint_issubclass(value._var_type, Component): + from reflex_components_core.base.bare import Bare + + return Bare.create(value) + + return None + + +def _lift_rest_props(component: Component) -> Component: + """Convert RestProp children into special props. + + Args: + component: The component tree to rewrite. + + Returns: + The rewritten component tree. + """ + from reflex_components_core.base.bare import Bare + + special_props = list(component.special_props) + rewritten_children = [] + + for child in component.children: + if isinstance(child, Bare) and isinstance(child.contents, RestProp): + special_props.append(child.contents) + continue + + if isinstance(child, Component): + child = _lift_rest_props(child) + + rewritten_children.append(child) + + component.children = rewritten_children + component.special_props = special_props + return component + + +def _analyze_params( + fn: Callable[..., Any], + *, + for_component: bool, + hints: dict[str, Any] | None = None, + defaulted_params: list[str] | None = None, +) -> tuple[MemoParam, ...]: + """Analyze and validate memo parameters. + + Args: + fn: The function to analyze. + for_component: Whether the memo returns a component. + hints: Pre-computed type hints with ``include_extras=True``; computed + from ``fn`` when omitted. + defaulted_params: When provided, parameters missing an annotation are + defaulted (``Var[Component]`` for ``children``, otherwise + ``Var[Any]``) and their names appended; when ``None``, a missing + annotation raises ``TypeError``. + + Returns: + The analyzed parameters. + + Raises: + TypeError: If the function signature is not supported. + """ + signature = inspect.signature(fn) + if hints is None: + hints = get_type_hints(fn, include_extras=True) + + params: list[MemoParam] = [] + rest_count = 0 + + for parameter in signature.parameters.values(): + _check_parameter_kind(parameter, fn.__name__) + + annotation = hints.get(parameter.name, parameter.annotation) + if annotation is inspect.Parameter.empty: + if defaulted_params is None: + msg = ( + f"All parameters of `{fn.__name__}` must be annotated as `rx.Var[...]` " + f"or `rx.RestProp`. Missing annotation for `{parameter.name}`." + ) + raise TypeError(msg) + annotation = Var[Component] if parameter.name == "children" else Var[Any] + defaulted_params.append(parameter.name) + + # Children parameters by name must match the children kind exactly — + # otherwise we accept a value-typed `children` and emit confusing JSX. + if ( + parameter.name == "children" + and not _children_annotation_is_valid(annotation) + and not _is_event_handler_annotation(annotation)[0] + ): + msg = ( + f"`children` in `{fn.__name__}` must be annotated as " + "`rx.Var[rx.Component]`." + ) + raise TypeError(msg) + + kind, kind_data = _classify_parameter(annotation, parameter.name, fn.__name__) + spec = _SPECS[kind] + spec.validate(parameter, fn.__name__, for_component) + + if kind is MemoParamKind.REST: + rest_count += 1 + if rest_count > 1: + msg = f"`@rx.memo` only supports one `rx.RestProp` in `{fn.__name__}`." + raise TypeError(msg) + + js_prop_name = format.to_camel_case(parameter.name) + placeholder_name = spec.placeholder_name( + parameter.name, js_prop_name, for_component + ) + + params.append( + MemoParam( + name=parameter.name, + kind=kind, + kind_data=kind_data, + annotation=annotation, + parameter_kind=parameter.kind, + default=parameter.default, + js_prop_name=js_prop_name, + placeholder_name=placeholder_name, + ) + ) + + return tuple(params) + + +def _check_parameter_kind(parameter: inspect.Parameter, fn_name: str) -> None: + """Reject Python parameter kinds (``*args`` / ``**kwargs`` / positional-only) + that memo does not support. + + Args: + parameter: The parameter to check. + fn_name: The function name for error messages. + + Raises: + TypeError: If the parameter uses an unsupported kind. + """ + if parameter.kind is inspect.Parameter.VAR_POSITIONAL: + msg = f"`@rx.memo` does not support `*args` in `{fn_name}`." + raise TypeError(msg) + if parameter.kind is inspect.Parameter.VAR_KEYWORD: + msg = f"`@rx.memo` does not support `**kwargs` in `{fn_name}`." + raise TypeError(msg) + if parameter.kind is inspect.Parameter.POSITIONAL_ONLY: + msg = f"`@rx.memo` does not support positional-only parameters in `{fn_name}`." + raise TypeError(msg) + + +def _classify_parameter( + annotation: Any, param_name: str, fn_name: str +) -> tuple[MemoParamKind, Any]: + """Walk ``_CLASSIFICATION_ORDER`` and return the first matching kind. + + Args: + annotation: The parameter annotation. + param_name: The parameter name (some kinds care, e.g. ``children``). + fn_name: The function name for error messages. + + Returns: + The matched ``(kind, kind_data)``. + + Raises: + TypeError: If no kind matches. + """ + for kind in _CLASSIFICATION_ORDER: + matched, kind_data = _SPECS[kind].classify(annotation, param_name) + if matched: + return kind, kind_data + msg = ( + f"All parameters of `{fn_name}` must be annotated as `rx.Var[...]` " + f"or `rx.RestProp`, got `{annotation}` for `{param_name}`." + ) + raise TypeError(msg) + + +def _build_args_function( + params: tuple[MemoParam, ...], return_expr: Var +) -> ArgsFunctionOperation: + """Build the JS ``ArgsFunctionOperation`` that wraps a memo's return expression. + + Args: + params: The memo parameters. + return_expr: The return expression of the memo body. + + Returns: + The compiled function operation. + """ + rest_param = _get_rest_param(params) + if _get_children_param(params) is None and rest_param is None: + return ArgsFunctionOperation.create( + args_names=tuple(param.placeholder_name for param in params), + return_expr=return_expr, + ) + return ArgsFunctionOperation.create( + args_names=( + DestructuredArg( + fields=tuple( + param.placeholder_name + for param in params + if param.kind is not MemoParamKind.REST + ), + rest=rest_param.placeholder_name if rest_param is not None else None, + ), + ), + return_expr=return_expr, + ) + + +def _create_component_definition( + fn: Callable[..., Any], + return_annotation: Any, +) -> MemoComponentDefinition: + """Create a definition for a component-returning memo. + + Args: + fn: The function to analyze. + return_annotation: The return annotation. + + Returns: + The component memo definition. + + Raises: + TypeError: If the function does not return a component. + """ + params = _analyze_params(fn, for_component=True) + component = _normalize_component_return(_evaluate_memo_function(fn, params)) + if component is None: + msg = ( + f"Component-returning `@rx.memo` `{fn.__name__}` must return an " + "`rx.Component` or `rx.Var[rx.Component]`." + ) + raise TypeError(msg) + + return MemoComponentDefinition( + fn=fn, + python_name=fn.__name__, + params=params, + export_name=format.to_title_case(fn.__name__), + component=_lift_rest_props(component), + ) + + +def _bind_function_runtime_args( + definition: MemoFunctionDefinition, + *args: Any, + **kwargs: Any, +) -> tuple[Any, ...]: + """Bind runtime args for a var-returning memo. + + Args: + definition: The function memo definition. + *args: Positional arguments. + **kwargs: Keyword arguments. + + Returns: + The ordered arguments for the imported FunctionVar. + + Raises: + TypeError: If the provided arguments are invalid. + """ + children_param = _get_children_param(definition.params) + rest_param = _get_rest_param(definition.params) + + # Validate positional children usage and reserved keywords. + if "children" in kwargs: + msg = f"`{definition.python_name}` only accepts children positionally." + raise TypeError(msg) + + if rest_param is not None and rest_param.name in kwargs: + msg = ( + f"`{definition.python_name}` captures rest props from extra keyword " + f"arguments. Do not pass `{rest_param.name}=` directly." + ) + raise TypeError(msg) + + if args and children_param is None: + msg = f"`{definition.python_name}` only accepts keyword props." + raise TypeError(msg) + + if any(not _is_component_child(child) for child in args): + msg = ( + f"`{definition.python_name}` only accepts positional children that are " + "`rx.Component` or `rx.Var[rx.Component]`." + ) + raise TypeError(msg) + + # Bind declared props before collecting any rest props. + explicit_params = [ + param + for param in definition.params + if param.kind not in (MemoParamKind.REST, MemoParamKind.CHILDREN) + ] + explicit_values = {} + remaining_props = kwargs.copy() + for param in explicit_params: + if param.name in remaining_props: + explicit_values[param.name] = remaining_props.pop(param.name) + elif param.default is not inspect.Parameter.empty: + explicit_values[param.name] = param.default + else: + msg = f"`{definition.python_name}` is missing required prop `{param.name}`." + raise TypeError(msg) + + # Reject unknown props unless a rest prop is declared. + if remaining_props and rest_param is None: + unexpected_prop = next(iter(remaining_props)) + msg = ( + f"`{definition.python_name}` does not accept prop `{unexpected_prop}`. " + "Only declared props may be passed when no `rx.RestProp` is present." + ) + raise TypeError(msg) + + # Return ordered explicit args when no packed props object is needed. + if children_param is None and rest_param is None: + return tuple(explicit_values[param.name] for param in explicit_params) + + # Build the props object passed to the imported FunctionVar. + children_value: Any | None = None + if children_param is not None: + from reflex_components_core.base.fragment import Fragment + + children_value = args[0] if len(args) == 1 else Fragment.create(*args) + + # Convert rest-prop keys to camelCase to match component memo behavior. + camel_cased_remaining_props = { + format.to_camel_case(key): value for key, value in remaining_props.items() + } + + bound_props = {} + if children_param is not None: + bound_props[children_param.name] = children_value + bound_props.update(explicit_values) + bound_props.update(camel_cased_remaining_props) + return (bound_props,) + + +def _is_component_child(value: Any) -> bool: + """Check whether a value is valid as a memo child. + + Args: + value: The value to check. + + Returns: + Whether the value is a component child. + """ + return isinstance(value, Component) or ( + isinstance(value, Var) and typehint_issubclass(value._var_type, Component) + ) + + +class _MemoFunctionWrapper: + """Callable wrapper for a var-returning memo.""" + + def __init__(self, definition: MemoFunctionDefinition): + """Initialize the wrapper. + + Args: + definition: The function memo definition. + """ + self._definition = definition + self._imported_var = definition.imported_var + update_wrapper(self, definition.fn) + + def __call__(self, *args: Any, **kwargs: Any) -> Var: + """Call the wrapped memo and return a var. + + Args: + *args: Positional children, if supported. + **kwargs: Explicit props and rest props. + + Returns: + The function call var. + """ + return self.call(*args, **kwargs) + + def call(self, *args: Any, **kwargs: Any) -> Var: + """Call the imported memo function. + + Args: + *args: Positional children, if supported. + **kwargs: Explicit props and rest props. + + Returns: + The function call var. + """ + return self._imported_var.call( + *_bind_function_runtime_args(self._definition, *args, **kwargs) + ) + + def partial(self, *args: Any, **kwargs: Any) -> FunctionVar: + """Partially apply the imported memo function. + + Args: + *args: Positional children, if supported. + **kwargs: Explicit props and rest props. + + Returns: + The partially applied function var. + """ + return self._imported_var.partial( + *_bind_function_runtime_args(self._definition, *args, **kwargs) + ) + + def _as_var(self) -> FunctionVar: + """Expose the imported function var. + + Returns: + The imported function var. + """ + return self._imported_var + + +class _MemoComponentWrapper: + """Callable wrapper for a component-returning memo.""" + + def __init__(self, definition: MemoComponentDefinition): + """Initialize the wrapper. + + Args: + definition: The component memo definition. + """ + self._definition = definition + self._children_param = _get_children_param(definition.params) + self._rest_param = _get_rest_param(definition.params) + self._explicit_params = [ + param + for param in definition.params + if param.kind not in (MemoParamKind.CHILDREN, MemoParamKind.REST) + ] + update_wrapper(self, definition.fn) + + def __call__(self, *children: Any, **props: Any) -> MemoComponent: + """Call the wrapped memo and return a component. + + Args: + *children: Positional children passed to the memo. + **props: Explicit props and rest props. + + Returns: + The rendered memo component. + """ + definition = self._definition + rest_param = self._rest_param + + # Validate positional children usage and reserved keywords. + if "children" in props: + msg = f"`{definition.python_name}` only accepts children positionally." + raise TypeError(msg) + if rest_param is not None and rest_param.name in props: + msg = ( + f"`{definition.python_name}` captures rest props from extra keyword " + f"arguments. Do not pass `{rest_param.name}=` directly." + ) + raise TypeError(msg) + if children and self._children_param is None: + msg = f"`{definition.python_name}` only accepts keyword props." + raise TypeError(msg) + if any(not _is_component_child(child) for child in children): + msg = ( + f"`{definition.python_name}` only accepts positional children that are " + "`rx.Component` or `rx.Var[rx.Component]`." + ) + raise TypeError(msg) + + # Bind declared props before collecting any rest props. + explicit_values = {} + remaining_props = props.copy() + for param in self._explicit_params: + if param.name in remaining_props: + explicit_values[param.name] = remaining_props.pop(param.name) + elif param.default is not inspect.Parameter.empty: + explicit_values[param.name] = param.default + else: + msg = f"`{definition.python_name}` is missing required prop `{param.name}`." + raise TypeError(msg) + + # Reject unknown props unless a rest prop is declared. + if remaining_props and rest_param is None: + unexpected_prop = next(iter(remaining_props)) + msg = ( + f"`{definition.python_name}` does not accept prop `{unexpected_prop}`. " + "Only declared props may be passed when no `rx.RestProp` is present." + ) + raise TypeError(msg) + + # Build the component props passed into the memo wrapper. + return _get_memo_component_class( + definition.export_name, type(definition.component) + )._create( + children=list(children), + memo_definition=definition, + **explicit_values, + **remaining_props, + ) + + def _as_var(self) -> Var: + """Expose the imported component var. + + Returns: + The imported component var. + """ + return _component_import_var(self._definition.export_name) + + +def _create_function_wrapper( + definition: MemoFunctionDefinition, +) -> _MemoFunctionWrapper: + """Create the Python wrapper for a var-returning memo. + + Args: + definition: The function memo definition. + + Returns: + The wrapper callable. + """ + return _MemoFunctionWrapper(definition) + + +def _create_component_wrapper( + definition: MemoComponentDefinition, +) -> _MemoComponentWrapper: + """Create the Python wrapper for a component-returning memo. + + Args: + definition: The component memo definition. + + Returns: + The wrapper callable. + """ + return _MemoComponentWrapper(definition) + + +def create_passthrough_component_memo( + component: Component, +) -> tuple[ + Callable[..., MemoComponent], + MemoComponentDefinition, +]: + """Create an unregistered ``@rx.memo``-style passthrough component memo. + + This is used by compiler auto-memoization so generated wrappers compile + through the memo pipeline instead of emitting ad-hoc page-local + ``React.memo`` declarations. + + The exported memo name is derived from ``component._compute_memo_tag()`` + after the ``{children}`` hole has been substituted into the wrapped + component's children (passthrough mode), so two call-sites differing only + in their children — whose generated memo bodies are identical — collapse + to one wrapper. + + Args: + component: The component to wrap. + + Returns: + The callable memo wrapper and its component definition. + """ + # Snapshot-boundary components (see ``is_snapshot_boundary``) own their + # subtree — the ``.children`` slot is internal machinery from the + # subclass's ``.create`` (e.g. the dropzone Div built inside + # ``Upload.create``), not a user content hole. The memoize plugin wraps + # the boundary with no structural children on the page side, so the memo + # body renders the full snapshot rather than a ``{children}``-holed + # template. + render_snapshot = ( + get_memoization_strategy(component) is MemoizationStrategy.SNAPSHOT + ) + + captured_hole_child: list[Component] = [] + + def passthrough(children: Var[Component]) -> Component: + new_component = copy(component) + if render_snapshot: + return new_component + # Components with no original structural children own their own JSX + # output (e.g. ``CodeBlock`` injects ``code`` as the ``children`` prop + # in ``_render``). Substituting a ``{children}`` hole here would emit + # ``jsx(Inner, {children: "..."}, hole)``, and an undefined hole at + # call time clobbers the prop. Skip the substitution so the wrapper's + # ``children`` parameter is present in the signature but unused. + if not component.children: + return new_component + from reflex_components_core.base.bare import Bare + + hole_bare = Bare.create(children) + captured_hole_child.append(hole_bare) + # Substitute the ``{children}`` hole for the original descendants so + # the memo body's hash and JSX both reflect the placeholder, not the + # specific children at any given call site. Original descendants stay + # reachable on the page-level wrapper via the plugin's + # ``_get_all_refs`` delegation back to the source component. + new_component.children = [hole_bare] + # Compile-time walkers that need the real subtree (notably + # ``Form._get_form_refs`` collecting id-based input refs into the + # generated ``handleSubmit`` JS) call ``self._get_all_refs()`` while + # the memo body's hooks are computed. With the hole substituted in, + # that walk would return nothing and the form handler would emit an + # empty ``field_ref_mapping``. Delegate ref collection back to the + # source component so descendants behind the hole remain visible. + object.__setattr__(new_component, "_get_all_refs", component._get_all_refs) + return new_component + + # Evaluate once to compute the tag from the rendered memo body shape. + # ``_create_component_definition`` will evaluate again internally; the + # second pass overwrites ``captured_hole_child`` but the captured value + # is identical. + params = _analyze_params(passthrough, for_component=True) + preview = _normalize_component_return(_evaluate_memo_function(passthrough, params)) + if preview is None: + msg = ( + "`create_passthrough_component_memo` requires a component that " + "normalizes to `rx.Component`." + ) + raise TypeError(msg) + tag = preview._compute_memo_tag() + + passthrough.__name__ = format.to_snake_case(tag) + passthrough.__qualname__ = passthrough.__name__ + passthrough.__module__ = __name__ + + definition = _create_component_definition(passthrough, Component) + replacements: dict[str, Any] = {} + if definition.export_name != tag: + replacements["export_name"] = tag + if captured_hole_child: + replacements["passthrough_hole_child"] = captured_hole_child[0] + if replacements: + definition = dataclasses.replace(definition, **replacements) + + return _create_component_wrapper(definition), definition + + +@contextlib.contextmanager +def _bind_self_reference(fn: Callable[..., Any], wrapper: Any) -> Iterator[None]: + """Bind ``wrapper`` to ``fn.__name__`` so the body can self-reference. + + Python only assigns the decorated name after the decorator returns, but + memo bodies are evaluated during decoration (and ``rx.foreach`` eagerly + invokes its render function once). The binding is installed at both the + module-global slot and the matching free-variable cell so recursion works + for module-level memos and for memos defined inside another function. + """ + fn_name = fn.__name__ + fn_globals = fn.__globals__ + sentinel = object() + previous_global = fn_globals.get(fn_name, sentinel) + fn_globals[fn_name] = wrapper + + cell = None + previous_cell_value: Any = sentinel + free_vars = fn.__code__.co_freevars + if fn_name in free_vars and fn.__closure__: + cell = fn.__closure__[free_vars.index(fn_name)] + # An unset cell stays in the ``sentinel`` state; the decorator's + # eventual return assigns the wrapper to the same cell anyway, so + # leaving our temporary write in place is a no-op. + with contextlib.suppress(ValueError): + previous_cell_value = cell.cell_contents + cell.cell_contents = wrapper + + try: + yield + finally: + if previous_global is sentinel: + fn_globals.pop(fn_name, None) + else: + fn_globals[fn_name] = previous_global + if cell is not None and previous_cell_value is not sentinel: + cell.cell_contents = previous_cell_value + + +_MemoVarT = TypeVar("_MemoVarT") + + +_PUBLIC_NAMESPACES: tuple[tuple[str, str], ...] = ( + # (display prefix, dotted attribute path to walk). Order matters — the + # shortest user-facing name wins. ``rxe`` only resolves when the optional + # ``reflex_enterprise`` package is installed. + ("rx.el", "reflex.el"), + ("rx", "reflex"), + ("rxe.dnd", "reflex_enterprise.dnd"), + ("rxe.flow", "reflex_enterprise.flow"), + ("rxe.components.dnd", "reflex_enterprise.components.dnd"), + ("rxe.components.flow", "reflex_enterprise.components.flow"), + ("rxe", "reflex_enterprise"), +) + + +def _resolve_namespace(dotted: str) -> Any: + """Walk a dotted path of attribute accesses rooted at an importable module. + + Args: + dotted: e.g. ``"reflex.el"`` or ``"reflex_enterprise.components.flow"``. + + Returns: + The resolved namespace object, or ``None`` if any step fails. + """ + head, *rest = dotted.split(".") + try: + ns: Any = importlib.import_module(head) + except ImportError: + return None + for attr in rest: + ns = getattr(ns, attr, None) + if ns is None: + return None + return ns + + +def _resolve_component_qualname(cls: type) -> str | None: + """Find the shortest public ``rx``/``rxe`` qualname under which ``cls`` lives. + + Args: + cls: The class to resolve. + + Returns: + The qualname (e.g. ``"rxe.dnd.Draggable"``), or ``None`` when no public + path is found. + """ + name = cls.__name__ + for display_prefix, dotted in _PUBLIC_NAMESPACES: + ns = _resolve_namespace(dotted) + if ns is not None and getattr(ns, name, None) is cls: + return f"{display_prefix}.{name}" + return None + + +def _suggest_return_annotation(result: Any, is_component: bool) -> str | None: + """Infer a copy-pasteable return annotation from a memo body's eval result. + + Args: + result: The value the body returned during memo eval. + is_component: Whether the memo was treated as component-returning. + + Returns: + A suggestion like ``"rxe.dnd.Draggable"`` or ``"rx.Var[str]"``, or + ``None`` when the result doesn't map cleanly to a public name. + """ + if is_component: + body = _normalize_component_return(result) + if body is None: + return None + return _resolve_component_qualname(type(body)) + if isinstance(result, Var): + inner = result._var_type + if isinstance(inner, type): + qual = _resolve_component_qualname(inner) + if qual is not None: + return f"rx.Var[{qual}]" + if inner.__module__ == "builtins": + return f"rx.Var[{inner.__name__}]" + return None + + +def _warn_missing_annotations( + fn_name: str, + missing_return: bool, + defaulted_params: Sequence[str], + suggested_return: str | None = None, +) -> None: + """Emit a deprecation warning for ``@rx.memo`` without explicit annotations. + + Args: + fn_name: Name of the decorated function (for the warning text). + missing_return: Whether the return annotation was missing. + defaulted_params: Names of parameters whose annotation was defaulted. + suggested_return: Inferred return type (e.g. ``"rxe.dnd.Draggable"``) + to surface in the message. When ``None``, the generic hint is used. + """ + parts: list[str] = [] + if missing_return: + if suggested_return is not None: + parts.append(f"a return annotation `-> {suggested_return}`") + else: + parts.append("a return annotation (`-> rx.Component` or `-> rx.Var[...]`)") + if defaulted_params: + joined = ", ".join(f"`{name}`" for name in defaulted_params) + parts.append(f"annotations on parameter(s) {joined} (`rx.Var[...]`)") + console.deprecate( + feature_name=f"`@rx.memo` on `{fn_name}` without explicit annotations", + reason=( + f"Add {' and '.join(parts)}. Missing annotations now default to " + "`rx.Component` / `rx.Var[Any]`" + ), + deprecation_version="0.9.3", + removal_version="1.0", + ) + + +@overload +def memo(fn: Callable[..., Component]) -> _MemoComponentWrapper: ... +@overload +def memo(fn: Callable[..., Var[_MemoVarT]]) -> _MemoFunctionWrapper: ... +def memo(fn: Callable[..., Any]) -> _MemoComponentWrapper | _MemoFunctionWrapper: + """Create a memo from a function. + + Args: + fn: The function to memoize. + + Returns: + The wrapped function or component factory. + + Raises: + TypeError: If the return type is not supported. + """ + hints = get_type_hints(fn, include_extras=True) + return_annotation = hints.get("return", inspect.Signature.empty) + missing_return = return_annotation is inspect.Signature.empty + if missing_return: + return_annotation = Component + hints["return"] = Component + + is_component = _is_component_annotation(return_annotation) + if not is_component and not _is_var_annotation(return_annotation): + msg = ( + f"`@rx.memo` on `{fn.__name__}` must return `rx.Component` or " + f"`rx.Var[...]`, got `{return_annotation}`." + ) + raise TypeError(msg) + + defaulted_params: list[str] = [] + params = _analyze_params( + fn, + for_component=is_component, + hints=hints, + defaulted_params=defaulted_params, + ) + + # Construct the wrapper against a placeholder body so the user's body can + # self-reference the memo during eager evaluation; the real body is patched + # in after eval completes (see `_bind_self_reference`). + definition: MemoComponentDefinition | MemoFunctionDefinition + if is_component: + definition = MemoComponentDefinition( + fn=fn, + python_name=fn.__name__, + params=params, + export_name=format.to_title_case(fn.__name__), + component=Fragment.create(), + ) + wrapper = _create_component_wrapper(definition) + else: + definition = MemoFunctionDefinition( + fn=fn, + python_name=fn.__name__, + params=params, + function=ArgsFunctionOperation.create( + args_names=(), return_expr=LiteralVar.create(None) + ), + imported_var=_imported_function_var( + fn.__name__, _annotation_inner_type(return_annotation) + ), + ) + wrapper = _create_function_wrapper(definition) + + with _bind_self_reference(fn, wrapper): + result = _evaluate_memo_function(fn, params) + + if missing_return or defaulted_params: + _warn_missing_annotations( + fn.__name__, + missing_return, + defaulted_params, + suggested_return=_suggest_return_annotation(result, is_component) + if missing_return + else None, + ) + + if is_component: + body = _normalize_component_return(result) + if body is None: + msg = ( + f"Component-returning `@rx.memo` `{fn.__name__}` must return an " + "`rx.Component` or `rx.Var[rx.Component]`." + ) + raise TypeError(msg) + object.__setattr__(definition, "component", _lift_rest_props(body)) + else: + return_expr = Var.create(result) + _validate_var_return_expr(return_expr, fn.__name__) + object.__setattr__( + definition, "function", _build_args_function(params, return_expr) + ) + + _register_memo_definition(definition) + return wrapper + + +__all__ = [ + "MEMOS", + "MemoComponent", + "MemoComponentDefinition", + "MemoDefinition", + "MemoFunctionDefinition", + "create_passthrough_component_memo", + "memo", +] diff --git a/packages/reflex-base/src/reflex_base/plugins/compiler.py b/packages/reflex-base/src/reflex_base/plugins/compiler.py index ecb55a03d92..5720d27aab7 100644 --- a/packages/reflex-base/src/reflex_base/plugins/compiler.py +++ b/packages/reflex-base/src/reflex_base/plugins/compiler.py @@ -765,9 +765,8 @@ class CompileContext(BaseContext): # Auto-memoize wrapper tags seen during the tree walk (populated by # ``MemoizeStatefulPlugin``). memoize_wrappers: dict[str, None] = dataclasses.field(default_factory=dict) - # Compiler-generated experimental memo definitions for auto-memoized - # stateful wrappers. Stored as ``Any`` to keep ``reflex_base`` decoupled - # from ``reflex.experimental.memo``. + # Compiler-generated memo definitions for auto-memoized stateful wrappers. + # Stored as ``Any`` to avoid an import cycle with ``reflex_base.components.memo``. auto_memo_components: dict[str, Any] = dataclasses.field(default_factory=dict) def compile( diff --git a/packages/reflex-base/src/reflex_base/vars/__init__.py b/packages/reflex-base/src/reflex_base/vars/__init__.py index 4c0ebe85c94..c986cf1fd36 100644 --- a/packages/reflex-base/src/reflex_base/vars/__init__.py +++ b/packages/reflex-base/src/reflex_base/vars/__init__.py @@ -2,6 +2,8 @@ from . import base, color, datetime, function, number, object, sequence from .base import ( + EMPTY_VAR_INT, + EMPTY_VAR_STR, BaseStateMeta, EvenMoreBasicBaseState, Field, @@ -28,6 +30,8 @@ ) __all__ = [ + "EMPTY_VAR_INT", + "EMPTY_VAR_STR", "ArrayVar", "BaseStateMeta", "BooleanVar", diff --git a/packages/reflex-base/src/reflex_base/vars/base.py b/packages/reflex-base/src/reflex_base/vars/base.py index 4034b69aa3c..56760594f84 100644 --- a/packages/reflex-base/src/reflex_base/vars/base.py +++ b/packages/reflex-base/src/reflex_base/vars/base.py @@ -362,7 +362,7 @@ def can_use_in_object_var(cls: GenericType) -> bool: Whether the class can be used in an ObjectVar. """ if types.is_union(cls): - return all(can_use_in_object_var(t) for t in types.get_args(cls)) + return all(can_use_in_object_var(t) for t in get_args(cls)) return ( isinstance(cls, type) and not safe_issubclass(cls, Var) @@ -3691,3 +3691,7 @@ def add_field(cls, name: str, var: Var, default_value: Any): annotated_type=var._var_type, ) cls.__fields__[name] = new_field + + +EMPTY_VAR_STR: Var[str] = LiteralVar.create("") +EMPTY_VAR_INT: Var[int] = LiteralVar.create(0) diff --git a/packages/reflex-components-internal/src/reflex_components_internal/blocks/lemcal.py b/packages/reflex-components-internal/src/reflex_components_internal/blocks/lemcal.py index 35a108fb911..75195eb04fb 100644 --- a/packages/reflex-components-internal/src/reflex_components_internal/blocks/lemcal.py +++ b/packages/reflex-components-internal/src/reflex_components_internal/blocks/lemcal.py @@ -9,8 +9,17 @@ LEMCAL_DEMO_URL = "https://app.lemcal.com/@alek/reflex-demo-call" +def lemcal_script(**props) -> rx.Component: + """Return the Lemcal integrations script tag.""" + return rx.script( + src="https://cdn.lemcal.com/lemcal-integrations.min.js", + defer=True, + **props, + ) + + @rx.memo -def lemcal_booking_calendar(): +def lemcal_booking_calendar() -> rx.Component: """Return the Lemcal booking calendar.""" return rx.fragment( rx.el.div( @@ -31,15 +40,6 @@ def lemcal_booking_calendar(): ) -def lemcal_script(**props) -> rx.Component: - """Return the Lemcal integrations script tag.""" - return rx.script( - src="https://cdn.lemcal.com/lemcal-integrations.min.js", - defer=True, - **props, - ) - - def lemcal_dialog(trigger: rx.Component, **props) -> rx.Component: """Return a Lemcal dialog container element.""" class_name = cn("w-auto", props.pop("class_name", "")) diff --git a/packages/reflex-components-internal/src/reflex_components_internal/components/base/skeleton.py b/packages/reflex-components-internal/src/reflex_components_internal/components/base/skeleton.py index 7baa29dd51d..3207225b0f5 100644 --- a/packages/reflex-components-internal/src/reflex_components_internal/components/base/skeleton.py +++ b/packages/reflex-components-internal/src/reflex_components_internal/components/base/skeleton.py @@ -2,8 +2,9 @@ from reflex_components_core.el.elements.typography import Div -from reflex.components.component import Component, memo -from reflex.vars.base import Var +from reflex.components.component import Component +from reflex.components.memo import memo +from reflex.vars.base import EMPTY_VAR_STR, Var from reflex_components_internal.utils.twmerge import cn @@ -15,7 +16,7 @@ class ClassNames: @memo def skeleton_component( - class_name: str | Var[str] = "", + class_name: Var[str] = EMPTY_VAR_STR, ) -> Component: """Skeleton component. diff --git a/packages/reflex-components-internal/src/reflex_components_internal/components/base/theme_switcher.py b/packages/reflex-components-internal/src/reflex_components_internal/components/base/theme_switcher.py index 351505bffc5..51826659d02 100644 --- a/packages/reflex-components-internal/src/reflex_components_internal/components/base/theme_switcher.py +++ b/packages/reflex-components-internal/src/reflex_components_internal/components/base/theme_switcher.py @@ -4,8 +4,10 @@ from reflex_components_core.el.elements.forms import Button from reflex_components_core.el.elements.typography import Div -from reflex.components.component import Component, memo +from reflex.components.component import Component +from reflex.components.memo import memo from reflex.style import LiteralColorMode, color_mode, set_color_mode +from reflex.vars.base import EMPTY_VAR_STR, Var from reflex_components_internal.components.icons.hugeicon import hi from reflex_components_internal.utils.twmerge import cn @@ -29,7 +31,7 @@ def theme_switcher_item(mode: LiteralColorMode, icon: str) -> Component: ) -def theme_switcher(class_name: str = "") -> Component: +def theme_switcher(class_name: str | Var[str] = "") -> Component: """Theme switcher component. Returns: @@ -47,7 +49,7 @@ def theme_switcher(class_name: str = "") -> Component: @memo -def memoized_theme_switcher(class_name: str = "") -> Component: +def memoized_theme_switcher(class_name: Var[str] = EMPTY_VAR_STR) -> Component: """Memoized theme switcher component. Returns: diff --git a/packages/reflex-components-internal/src/reflex_components_internal/components/icons/others.py b/packages/reflex-components-internal/src/reflex_components_internal/components/icons/others.py index 419e31e6575..b187b08914a 100644 --- a/packages/reflex-components-internal/src/reflex_components_internal/components/icons/others.py +++ b/packages/reflex-components-internal/src/reflex_components_internal/components/icons/others.py @@ -2,15 +2,16 @@ from reflex_components_core.el.elements.media import svg -from reflex.components.component import Component, memo -from reflex.vars.base import Var +from reflex.components.component import Component +from reflex.components.memo import memo +from reflex.vars.base import EMPTY_VAR_STR, Var from reflex_components_internal.components.icons.hugeicon import hi from reflex_components_internal.utils.twmerge import cn @memo def spinner_component( - class_name: str | Var[str] = "", + class_name: Var[str] = EMPTY_VAR_STR, ) -> Component: """Create a spinner SVG icon. @@ -44,7 +45,7 @@ def spinner_component( @memo def select_arrow_icon( - class_name: str | Var[str] = "", + class_name: Var[str] = EMPTY_VAR_STR, ) -> Component: """A select arrow SVG icon. @@ -58,7 +59,7 @@ def select_arrow_icon( @memo -def arrow_svg_component(class_name: str | Var[str] = "") -> Component: +def arrow_svg_component(class_name: Var[str] = EMPTY_VAR_STR) -> Component: """Create a tooltip arrow SVG icon. The arrow SVG icon. diff --git a/packages/reflex-components-markdown/src/reflex_components_markdown/markdown.py b/packages/reflex-components-markdown/src/reflex_components_markdown/markdown.py index aea8542fd17..f7ae600d5d6 100644 --- a/packages/reflex-components-markdown/src/reflex_components_markdown/markdown.py +++ b/packages/reflex-components-markdown/src/reflex_components_markdown/markdown.py @@ -13,7 +13,6 @@ BaseComponent, Component, ComponentNamespace, - CustomComponent, MemoizationLeaf, field, ) @@ -413,15 +412,7 @@ def _get_map_fn_custom_code_from_children( if isinstance(component, MarkdownComponentMap): custom_code_list.append(component.get_component_map_custom_code()) - # If the component is a custom component(rx.memo), obtain the underlining - # component and get the custom code from the children. - if isinstance(component, CustomComponent): - custom_code_list.extend( - self._get_map_fn_custom_code_from_children( - component.component_fn(*component.get_prop_vars()) - ) - ) - elif isinstance(component, Component): + if isinstance(component, Component): for child in component.children: custom_code_list.extend( self._get_map_fn_custom_code_from_children(child) diff --git a/packages/reflex-site-shared/src/reflex_site_shared/components/blocks/code.py b/packages/reflex-site-shared/src/reflex_site_shared/components/blocks/code.py index a2557b46da7..2da45e16140 100644 --- a/packages/reflex-site-shared/src/reflex_site_shared/components/blocks/code.py +++ b/packages/reflex-site-shared/src/reflex_site_shared/components/blocks/code.py @@ -11,7 +11,7 @@ @rx.memo -def _plain_code_block(code: str, language: str): +def _plain_code_block(code: rx.Var[str], language: rx.Var[str]) -> rx.Component: """Shared plain code block implementation. Returns: @@ -83,7 +83,7 @@ def code_block(code: str, language: str): @rx.memo -def code_block_dark(code: str, language: str): +def code_block_dark(code: rx.Var[str], language: rx.Var[str]) -> rx.Component: """Code block dark. Returns: diff --git a/packages/reflex-site-shared/src/reflex_site_shared/components/blocks/headings.py b/packages/reflex-site-shared/src/reflex_site_shared/components/blocks/headings.py index 4a0808f6613..c07c6118c43 100644 --- a/packages/reflex-site-shared/src/reflex_site_shared/components/blocks/headings.py +++ b/packages/reflex-site-shared/src/reflex_site_shared/components/blocks/headings.py @@ -144,7 +144,7 @@ def create( @rx.memo -def h1_comp(text: str) -> rx.Component: +def h1_comp(text: rx.Var[str]) -> rx.Component: """H1 comp. Returns: @@ -158,7 +158,7 @@ def h1_comp(text: str) -> rx.Component: @rx.memo -def h1_comp_xd(text: str) -> rx.Component: +def h1_comp_xd(text: rx.Var[str]) -> rx.Component: """H1 comp xd. Returns: @@ -172,7 +172,7 @@ def h1_comp_xd(text: str) -> rx.Component: @rx.memo -def h2_comp(text: str) -> rx.Component: +def h2_comp(text: rx.Var[str]) -> rx.Component: """H2 comp. Returns: @@ -187,7 +187,7 @@ def h2_comp(text: str) -> rx.Component: @rx.memo -def h2_comp_xd(text: str) -> rx.Component: +def h2_comp_xd(text: rx.Var[str]) -> rx.Component: """H2 comp xd. Returns: @@ -202,7 +202,7 @@ def h2_comp_xd(text: str) -> rx.Component: @rx.memo -def h3_comp(text: str) -> rx.Component: +def h3_comp(text: rx.Var[str]) -> rx.Component: """H3 comp. Returns: @@ -217,7 +217,7 @@ def h3_comp(text: str) -> rx.Component: @rx.memo -def h3_comp_xd(text: str) -> rx.Component: +def h3_comp_xd(text: rx.Var[str]) -> rx.Component: """H3 comp xd. Returns: @@ -232,7 +232,7 @@ def h3_comp_xd(text: str) -> rx.Component: @rx.memo -def h4_comp(text: str) -> rx.Component: +def h4_comp(text: rx.Var[str]) -> rx.Component: """H4 comp. Returns: @@ -247,7 +247,7 @@ def h4_comp(text: str) -> rx.Component: @rx.memo -def h4_comp_xd(text: str) -> rx.Component: +def h4_comp_xd(text: rx.Var[str]) -> rx.Component: """H4 comp xd. Returns: @@ -262,7 +262,7 @@ def h4_comp_xd(text: str) -> rx.Component: @rx.memo -def img_comp_xd(src: str) -> rx.Component: +def img_comp_xd(src: rx.Var[str]) -> rx.Component: """Img comp xd. Returns: diff --git a/packages/reflex-site-shared/src/reflex_site_shared/components/code_card.py b/packages/reflex-site-shared/src/reflex_site_shared/components/code_card.py index f8a23424058..a2e1fa0e413 100644 --- a/packages/reflex-site-shared/src/reflex_site_shared/components/code_card.py +++ b/packages/reflex-site-shared/src/reflex_site_shared/components/code_card.py @@ -11,8 +11,8 @@ @rx.memo def install_command( - command: str, - show_dollar_sign: bool = True, + command: rx.Var[str], + show_dollar_sign: rx.Var[bool] = True, ) -> rx.Component: """Install command. diff --git a/packages/reflex-site-shared/src/reflex_site_shared/views/cta_card.py b/packages/reflex-site-shared/src/reflex_site_shared/views/cta_card.py index b47c4a9ba97..24f2e0b7bfc 100644 --- a/packages/reflex-site-shared/src/reflex_site_shared/views/cta_card.py +++ b/packages/reflex-site-shared/src/reflex_site_shared/views/cta_card.py @@ -9,7 +9,7 @@ @rx.memo -def cta_card(): +def cta_card() -> rx.Component: """Cta card. Returns: diff --git a/packages/reflex-site-shared/src/reflex_site_shared/views/footer.py b/packages/reflex-site-shared/src/reflex_site_shared/views/footer.py index f703dc9320c..69d80a90d4b 100644 --- a/packages/reflex-site-shared/src/reflex_site_shared/views/footer.py +++ b/packages/reflex-site-shared/src/reflex_site_shared/views/footer.py @@ -231,7 +231,10 @@ def footer_legal(class_name: str = "") -> rx.Component: @rx.memo -def footer_index(class_name: str = "", grid_class_name: str = "") -> rx.Component: +def footer_index( + class_name: rx.Var[str] = rx.EMPTY_VAR_STR, + grid_class_name: rx.Var[str] = rx.EMPTY_VAR_STR, +) -> rx.Component: """Full marketing footer: logo, newsletter, links, and legal. Returns: diff --git a/pyi_hashes.json b/pyi_hashes.json index 00ea0719dfd..b4498b3581b 100644 --- a/pyi_hashes.json +++ b/pyi_hashes.json @@ -40,7 +40,7 @@ "packages/reflex-components-dataeditor/src/reflex_components_dataeditor/dataeditor.pyi": "8e379fa038c7c6c0672639eb5902934d", "packages/reflex-components-gridjs/src/reflex_components_gridjs/datatable.pyi": "d2dc211d707c402eb24678a4cba945f7", "packages/reflex-components-lucide/src/reflex_components_lucide/icon.pyi": "b692058e40b15da293fbf463ad300a83", - "packages/reflex-components-markdown/src/reflex_components_markdown/markdown.pyi": "27661fcc57f3aa6b22ebefbc1082350c", + "packages/reflex-components-markdown/src/reflex_components_markdown/markdown.pyi": "e04f22f5d3d2b5dfd99f9fbedb2b4f3d", "packages/reflex-components-moment/src/reflex_components_moment/moment.pyi": "d6a02e447dfd3c91bba84bcd02722aed", "packages/reflex-components-plotly/src/reflex_components_plotly/plotly.pyi": "91e956633778c6992f04940c69ff7140", "packages/reflex-components-radix/src/reflex_components_radix/__init__.pyi": "19216eb3618f68c8a76e5e43801cf4af", @@ -118,7 +118,7 @@ "packages/reflex-components-recharts/src/reflex_components_recharts/polar.pyi": "1979bb6c22bb7a0d3342b2d63fb19d74", "packages/reflex-components-recharts/src/reflex_components_recharts/recharts.pyi": "c5288f311fe37b23539518ba2a3d4482", "packages/reflex-components-sonner/src/reflex_components_sonner/toast.pyi": "2c5fadcc014056f041cd4d916137d9e7", - "reflex/__init__.pyi": "3a9bb8544cbc338ffaf0a5927d9156df", + "reflex/__init__.pyi": "12a863ddbcac050c702a3ec6092ae17c", "reflex/components/__init__.pyi": "f39a2af77f438fa243c58c965f19d42e", - "reflex/experimental/memo.pyi": "d09629b81bf0df6153b131ac0ee10bd7" + "reflex/experimental/memo.pyi": "5bfbbd60585132d7a76840a0dbacbdd2" } diff --git a/reflex/__init__.py b/reflex/__init__.py index 3a81b098db6..6e711166871 100644 --- a/reflex/__init__.py +++ b/reflex/__init__.py @@ -140,9 +140,9 @@ "reflex_base.components.component": [ "Component", "NoSSRComponent", - "memo", "ComponentNamespace", ], + "reflex_base.components.memo": ["memo"], "reflex_components_core.el.elements.media": ["image"], "reflex_components_lucide": ["icon"], **_COMPONENTS_BASE_MAPPING, @@ -235,7 +235,7 @@ "utils.imports": ["ImportDict", "ImportVar"], "utils.misc": ["run_in_thread"], "utils.serializers": ["serializer"], - "vars": ["Var", "field", "Field", "RestProp"], + "vars": ["Var", "field", "Field", "RestProp", "EMPTY_VAR_STR", "EMPTY_VAR_INT"], } _SUBMODULES: set[str] = { diff --git a/reflex/app.py b/reflex/app.py index 91a8180e262..267b02c3ec1 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -194,9 +194,9 @@ def default_overlay_component() -> Component: Returns: The default overlay component, which is a connection banner/toaster set. """ - from reflex_base.components.component import memo + from reflex_base.components.memo import memo - def default_overlay_components(): + def default_overlay_components() -> Component: return Fragment.create( connection_pulser(), connection_toaster(), @@ -1162,10 +1162,10 @@ def _should_compile(self) -> bool: def _setup_sticky_badge(self): """Add the sticky badge to the app.""" - from reflex_base.components.component import memo + from reflex_base.components.memo import memo @memo - def memoized_badge(): + def memoized_badge() -> Component: sticky_badge = sticky() sticky_badge._add_style_recursive({}) return sticky_badge diff --git a/reflex/compiler/compiler.py b/reflex/compiler/compiler.py index a0a934c1b6d..b4b562e1121 100644 --- a/reflex/compiler/compiler.py +++ b/reflex/compiler/compiler.py @@ -11,13 +11,17 @@ from reflex_base import constants from reflex_base.components.component import ( - CUSTOM_COMPONENTS, BaseComponent, Component, ComponentStyle, - CustomComponent, evaluate_style_namespaces, ) +from reflex_base.components.memo import ( + MEMOS, + MemoComponentDefinition, + MemoDefinition, + MemoFunctionDefinition, +) from reflex_base.config import get_config from reflex_base.constants.compiler import PageNames, ResetStylesheet from reflex_base.constants.state import FIELD_MARKER @@ -35,12 +39,6 @@ from reflex.compiler import templates, utils from reflex.compiler.plugins import default_page_plugins -from reflex.experimental.memo import ( - EXPERIMENTAL_MEMOS, - ExperimentalMemoComponentDefinition, - ExperimentalMemoDefinition, - ExperimentalMemoFunctionDefinition, -) from reflex.state import BaseState, code_uses_state_contexts from reflex.utils import console, frontend_skeleton, path_ops, prerequisites from reflex.utils.exec import get_compile_context, is_prod_mode @@ -394,56 +392,29 @@ def _compile_component(component: Component) -> str: def _compile_memo_components( - components: Iterable[CustomComponent], - experimental_memos: Iterable[ExperimentalMemoDefinition] = (), + memos: Iterable[MemoDefinition] = (), ) -> tuple[list[tuple[str, str]], dict[str, list[ImportVar]]]: - """Compile each memo/custom-component as its own module plus an index. + """Compile each memo as its own module. Each memo lands in ``.web//.jsx`` with only the imports - it actually uses. Experimental memo wrappers declare their ``library`` as - that per-memo file path so page-side imports resolve directly to the + it actually uses. Memo wrappers declare their ``library`` as that + per-memo file path so page-side imports resolve directly to the individual module. - The ``$/utils/components`` index only re-exports the legacy - ``@rx.memo`` custom components, which are the ones app-level code - (``root.jsx``) imports by name. Keeping experimental memos out of the - index is what lets root's ``import * as utils_components`` avoid - transitively dragging every page-specific memo into the always-loaded - chunk — the tree-shaking win of per-memo files relies on that. - Args: - components: The components to compile. - experimental_memos: The experimental memos to compile. + memos: The memos to compile. Returns: - A list of ``(path, code)`` pairs to write — one per memo plus one - index — and the aggregated imports across all memo modules. + A list of ``(path, code)`` pairs to write — one per memo — and the + aggregated imports across all memo modules. """ per_memo_files: list[tuple[str, str]] = [] - # Only legacy custom components go through the index: they are the ones - # root.jsx/custom code imports by name from ``$/utils/components``. - # Experimental memos declare their library per-file (see - # ``_get_experimental_memo_component_class``) so pages import them - # directly and the index stays small. - index_entries: list[tuple[str, str]] = [] aggregate_imports: dict[str, list[ImportVar]] = {} base_dir = utils.get_memo_components_dir() - for component in components: - component_render, component_imports = utils.compile_custom_component(component) - name = component_render["name"] - code, file_imports = _compile_single_memo_component( - component_render, component_imports - ) - path = _memo_component_file_path(base_dir, name) - specifier = _memo_component_index_specifier(name) - per_memo_files.append((path, code)) - index_entries.append((name, specifier)) - _extend_imports_in_place(aggregate_imports, file_imports) - - for memo in experimental_memos: - if isinstance(memo, ExperimentalMemoComponentDefinition): + for memo in memos: + if isinstance(memo, MemoComponentDefinition): memo_render, memo_imports = utils.compile_experimental_component_memo(memo) name = memo_render["name"] code, file_imports = _compile_single_memo_component( @@ -452,7 +423,7 @@ def _compile_memo_components( path = _memo_component_file_path(base_dir, name) per_memo_files.append((path, code)) _extend_imports_in_place(aggregate_imports, file_imports) - elif isinstance(memo, ExperimentalMemoFunctionDefinition): + elif isinstance(memo, MemoFunctionDefinition): memo_render, memo_imports = utils.compile_experimental_function_memo(memo) name = memo_render["name"] code, file_imports = _compile_single_memo_function( @@ -462,9 +433,7 @@ def _compile_memo_components( per_memo_files.append((path, code)) _extend_imports_in_place(aggregate_imports, file_imports) - index_path = utils.get_components_path() - index_code = templates.memo_index_template(index_entries) - return [(index_path, index_code), *per_memo_files], aggregate_imports + return per_memo_files, aggregate_imports def _compile_single_memo_component( @@ -678,20 +647,18 @@ def compile_page_from_context(page_ctx: PageContext) -> tuple[str, str]: def compile_memo_components( - components: Iterable[CustomComponent], - experimental_memos: Iterable[ExperimentalMemoDefinition] = (), + memos: Iterable[MemoDefinition] = (), ) -> tuple[list[tuple[str, str]], dict[str, list[ImportVar]]]: - """Compile the custom components into one module per memo plus an index. + """Compile the memos into one module per memo. Args: - components: The custom components to compile. - experimental_memos: The experimental memos to compile. + memos: The memos to compile. Returns: - A list of ``(path, code)`` pairs (one per memo module and one index) - alongside the aggregated imports across all memo modules. + A list of ``(path, code)`` pairs (one per memo module) alongside the + aggregated imports across all memo modules. """ - return _compile_memo_components(components, experimental_memos) + return _compile_memo_components(memos) def purge_web_pages_dir(): @@ -944,10 +911,10 @@ def _resolve_app_wrap_components( app_wrappers[200, "StrictMode"] = StrictMode.create() if (toaster := app.toaster) is not None: - from reflex_base.components.component import memo + from reflex_base.components.memo import memo @memo - def memoized_toast_provider(): + def memoized_toast_provider() -> Component: return toaster app_wrappers[44, "ToasterProvider"] = Fragment.create(memoized_toast_provider()) @@ -1129,10 +1096,9 @@ def compile_app( all_imports = utils.merge_imports(all_imports, app_root._get_all_imports()) memo_component_files, memo_components_imports = compile_memo_components( - dict.fromkeys(CUSTOM_COMPONENTS.values()), ( - *tuple(EXPERIMENTAL_MEMOS.values()), - *tuple(compile_ctx.auto_memo_components.values()), + *MEMOS.values(), + *compile_ctx.auto_memo_components.values(), ), ) compile_results.extend(memo_component_files) diff --git a/reflex/compiler/plugins/memoize.py b/reflex/compiler/plugins/memoize.py index cdd7c3e7c49..62878ef222a 100644 --- a/reflex/compiler/plugins/memoize.py +++ b/reflex/compiler/plugins/memoize.py @@ -9,8 +9,8 @@ Each unique subtree shape contributes: -- One generated experimental memo component definition, compiled into the - shared ``$/utils/components`` module. +- One generated experimental memo component definition, compiled into its own + per-memo module at ``$/utils/components/``. - ``useCallback`` hook lines for each non-lifecycle event trigger, emitted into the generated memo body so handler hooks stay inside that rendering domain. @@ -23,6 +23,7 @@ from typing import Any from reflex_base.components.component import BaseComponent, Component +from reflex_base.components.memo import create_passthrough_component_memo from reflex_base.components.memoize_helpers import ( MemoizationStrategy, _is_structural_memoization_child, @@ -34,8 +35,6 @@ from reflex_base.plugins import ComponentAndChildren, PageContext from reflex_base.plugins.base import Plugin -from reflex.experimental.memo import create_passthrough_component_memo - def _subtree_has_reactive_data( component: Component, _cache: dict[int, bool] | None = None diff --git a/reflex/compiler/utils.py b/reflex/compiler/utils.py index ca1cf146d11..bafb2290fb0 100644 --- a/reflex/compiler/utils.py +++ b/reflex/compiler/utils.py @@ -14,8 +14,13 @@ from urllib.parse import urlparse from reflex_base import constants -from reflex_base.components.component import Component, ComponentStyle, CustomComponent -from reflex_base.constants.state import CAMEL_CASE_MEMO_MARKER, FIELD_MARKER +from reflex_base.components.component import Component, ComponentStyle +from reflex_base.components.memo import ( + MemoComponentDefinition, + MemoFunctionDefinition, + MemoParamKind, +) +from reflex_base.constants.state import FIELD_MARKER from reflex_base.style import Style from reflex_base.utils import format, imports from reflex_base.utils.imports import ImportVar, ParsedImportDict @@ -28,10 +33,6 @@ from reflex_components_core.el.elements.other import Html from reflex_components_core.el.elements.sectioning import Body -from reflex.experimental.memo import ( - ExperimentalMemoComponentDefinition, - ExperimentalMemoFunctionDefinition, -) from reflex.istate.storage import Cookie, LocalStorage, SessionStorage from reflex.state import BaseState, _resolve_delta from reflex.utils import path_ops @@ -321,53 +322,6 @@ def compile_client_storage( } -def compile_custom_component( - component: CustomComponent, -) -> tuple[dict, ParsedImportDict]: - """Compile a custom component. - - Args: - component: The custom component to compile. - - Returns: - A tuple of the compiled component and the imports required by the component. - """ - # Render the component. - render = component.get_component() - - # Get the imports. - imports: ParsedImportDict = {} - for lib, fields in render._get_all_imports().items(): - if lib != component.library: - imports[lib] = fields - continue - - filtered_fields = [field for field in fields if field.tag != component.tag] - if filtered_fields: - imports[lib] = filtered_fields - - imports.setdefault("@emotion/react", []).append(ImportVar("jsx")) - - # Concatenate the props. - props = list(component.props) - - # Compile the component. - return ( - { - "name": component.tag, - "props": props, - "signature": DestructuredArg( - fields=tuple(f"{prop}:{prop}{CAMEL_CASE_MEMO_MARKER}" for prop in props) - ).to_javascript(), - "render": render.render(), - "hooks": render._get_all_hooks(), - "custom_code": render._get_all_custom_code(), - "dynamic_imports": render._get_all_dynamic_imports(), - }, - imports, - ) - - def _apply_component_style_for_compile(component: Component) -> Component: """Apply the app style to a compiled component tree. @@ -421,9 +375,9 @@ def _app_style() -> ComponentStyle | Style: def compile_experimental_component_memo( - definition: ExperimentalMemoComponentDefinition, + definition: MemoComponentDefinition, ) -> tuple[dict, ParsedImportDict]: - """Compile an experimental memo component. + """Compile a memo component. Args: definition: The component memo definition. @@ -466,11 +420,9 @@ def compile_experimental_component_memo( dynamic_imports = render._get_all_dynamic_imports() all_imports = render._get_all_imports() - # Each experimental memo now lives in ``web/utils/components/.jsx``, - # so importing the ``$/utils/components`` index from this file is only - # circular when ```` itself appears in that index — i.e. a legacy - # ``@rx.memo`` wrapper file. For auto-memo wrappers around legacy custom - # components, the index import is legitimate and must be preserved. + # Each memo lives in ``web/utils/components/.jsx`` and is imported + # from ``$/utils/components/``. Strip a self-import so a memo body + # that references the wrapper's own module specifier doesn't recurse. self_module = f"$/{constants.Dirs.COMPONENTS_PATH}/{definition.export_name}" imports: ParsedImportDict = { lib: fields for lib, fields in all_imports.items() if lib != self_module @@ -479,15 +431,17 @@ def compile_experimental_component_memo( imports.setdefault("@emotion/react", []).append(ImportVar("jsx")) signature_fields = [ - f"{param.js_prop_name}:{param.placeholder_name}" + field for param in definition.params - if not param.is_children and not param.is_rest + if (field := param.signature_field()) is not None ] - if any(param.is_children for param in definition.params): + if any(p.kind is MemoParamKind.CHILDREN for p in definition.params): signature_fields.insert(0, "children") - rest_param = next((param for param in definition.params if param.is_rest), None) + rest_param = next( + (p for p in definition.params if p.kind is MemoParamKind.REST), None + ) return ( { @@ -561,9 +515,9 @@ def _root_only_dynamic_imports(component: Component) -> set[str]: def compile_experimental_function_memo( - definition: ExperimentalMemoFunctionDefinition, + definition: MemoFunctionDefinition, ) -> tuple[dict, ParsedImportDict]: - """Compile an experimental memo function. + """Compile a memo function. Args: definition: The function memo definition. @@ -787,25 +741,12 @@ def get_context_path() -> str: return str(get_web_dir() / (constants.Dirs.CONTEXTS_PATH + constants.Ext.JS)) -def get_components_path() -> str: - """Get the path of the compiled components. - - Returns: - The path of the compiled components. - """ - return str( - get_web_dir() - / constants.Dirs.UTILS - / (constants.PageNames.COMPONENTS + constants.Ext.JSX), - ) - - def get_memo_components_dir() -> str: """Get the directory that holds per-memo module files. Returns: - The directory used for per-memo ``.jsx`` modules re-exported by the - top-level components index. + The directory used for per-memo ``.jsx`` modules. Pages import each + wrapper directly from ``$/utils/components/``. """ return str( get_web_dir() / constants.Dirs.UTILS / constants.PageNames.COMPONENTS, diff --git a/reflex/components/memo.py b/reflex/components/memo.py new file mode 100644 index 00000000000..4de9661b2e8 --- /dev/null +++ b/reflex/components/memo.py @@ -0,0 +1,4 @@ +# pyright: reportWildcardImportFromLibrary=false +"""Re-export from reflex_base.""" + +from reflex_base.components.memo import * # pragma: no cover diff --git a/reflex/experimental/__init__.py b/reflex/experimental/__init__.py index 5854243bea7..0b3ca628a8d 100644 --- a/reflex/experimental/__init__.py +++ b/reflex/experimental/__init__.py @@ -1,7 +1,9 @@ """Namespace for experimental features.""" from types import SimpleNamespace +from typing import Any +from reflex_base.components.memo import memo as _memo from reflex_base.utils.console import warn from reflex_components_code.shiki_code_block import code_block as code_block @@ -9,7 +11,6 @@ from . import hooks as hooks from .client_state import ClientStateVar as ClientStateVar -from .memo import memo as memo class ExperimentalNamespace(SimpleNamespace): @@ -42,6 +43,16 @@ def run_in_thread(self): self.register_component_warning("run_in_thread") return run_in_thread + @property + def memo(self) -> Any: + """Deprecated alias for :func:`rx.memo`. + + Returns: + The promoted memo decorator from ``reflex_base.components.memo``. + """ + self.register_component_warning("memo") + return _memo + @staticmethod def register_component_warning(component_name: str): """Add component to emitted warnings and throw a warning if it @@ -60,5 +71,4 @@ def register_component_warning(component_name: str): client_state=ClientStateVar.create, hooks=hooks, code_block=code_block, - memo=memo, ) diff --git a/reflex/experimental/memo.py b/reflex/experimental/memo.py index b5a18e3f26c..ea01542314d 100644 --- a/reflex/experimental/memo.py +++ b/reflex/experimental/memo.py @@ -1,1160 +1,11 @@ -"""Experimental memo support for vars and components.""" +"""Deprecated alias for :mod:`reflex_base.components.memo`.""" -from __future__ import annotations +import sys -import dataclasses -import inspect -from collections.abc import Callable -from copy import copy -from functools import cache, update_wrapper -from typing import Any, ClassVar, get_args, get_origin, get_type_hints +from reflex_base.components import memo -from reflex_base import constants -from reflex_base.components.component import Component -from reflex_base.components.dynamic import bundled_libraries -from reflex_base.components.memoize_helpers import ( - MemoizationStrategy, - get_memoization_strategy, -) -from reflex_base.constants.compiler import ( - MemoizationDisposition, - MemoizationMode, - SpecialAttributes, -) -from reflex_base.constants.state import CAMEL_CASE_MEMO_MARKER -from reflex_base.utils import format -from reflex_base.utils.imports import ImportVar -from reflex_base.utils.types import safe_issubclass -from reflex_base.vars import VarData -from reflex_base.vars.base import LiteralVar, Var -from reflex_base.vars.function import ( - ArgsFunctionOperation, - DestructuredArg, - FunctionStringVar, - FunctionVar, - ReflexCallable, -) -from reflex_base.vars.object import RestProp -from reflex_components_core.base.bare import Bare -from reflex_components_core.base.fragment import Fragment +from reflex.experimental import ExperimentalNamespace -from reflex.utils import types as type_utils +ExperimentalNamespace.register_component_warning("memo") - -@dataclasses.dataclass(frozen=True, slots=True, kw_only=True) -class MemoParam: - """Metadata about a memo parameter.""" - - name: str - annotation: Any - kind: inspect._ParameterKind - default: Any = inspect.Parameter.empty - js_prop_name: str | None = None - placeholder_name: str = "" - is_children: bool = False - is_rest: bool = False - - -@dataclasses.dataclass(frozen=True, slots=True) -class ExperimentalMemoDefinition: - """Base metadata for an experimental memo.""" - - fn: Callable[..., Any] - python_name: str - params: tuple[MemoParam, ...] - - -@dataclasses.dataclass(frozen=True, slots=True) -class ExperimentalMemoFunctionDefinition(ExperimentalMemoDefinition): - """A memo that compiles to a JavaScript function.""" - - function: ArgsFunctionOperation - imported_var: FunctionVar - - -@dataclasses.dataclass(frozen=True, slots=True) -class ExperimentalMemoComponentDefinition(ExperimentalMemoDefinition): - """A memo that compiles to a React component.""" - - export_name: str - component: Component - # For passthrough wrappers built by the auto-memoize plugin: the - # ``Bare``-wrapped ``{children}`` placeholder used when rendering the memo - # body. The ``component`` keeps its ORIGINAL children so compile-time - # walkers (``Form._get_form_refs`` etc.) can introspect the subtree; the - # compiler swaps to this placeholder only for the JSX render and for - # imports collection, so descendants emit their refs/imports/hooks in the - # page scope rather than being duplicated inside the memo body. - passthrough_hole_child: Component | None = None - - -class ExperimentalMemoComponent(Component): - """A rendered instance of an experimental memo component.""" - - library = f"$/{constants.Dirs.COMPONENTS_PATH}" - _memoization_mode = MemoizationMode(disposition=MemoizationDisposition.NEVER) - - # The user-authored component class this wrapper stands in for. Populated - # on the dynamic subclass by ``_get_experimental_memo_component_class`` so - # introspection (e.g. compile telemetry) can recover the underlying type - # without parsing the wrapper's auto-generated class name. - _wrapped_component_type: ClassVar[type[Component] | None] = None - - def _validate_component_children(self, children: list[Component]) -> None: - """Skip direct parent/child validation for memo wrapper instances. - - Experimental memos wrap an underlying compiled component definition. - The runtime wrapper should not interpose on `_valid_parents` checks for - the authored subtree because the wrapper itself is not the semantic - parent in the user-authored component tree. - - Args: - children: The children of the component (ignored). - """ - - def _post_init(self, **kwargs): - """Initialize the experimental memo component. - - Args: - **kwargs: The kwargs to pass to the component. - """ - definition = kwargs.pop("memo_definition") - - explicit_props = { - param.name - for param in definition.params - if not param.is_children and not param.is_rest - } - component_fields = self.get_fields() - - declared_props = { - key: kwargs.pop(key) for key in list(kwargs) if key in explicit_props - } - - rest_props = {} - if _get_rest_param(definition.params) is not None: - rest_props = { - key: kwargs.pop(key) - for key in list(kwargs) - if key not in component_fields and not SpecialAttributes.is_special(key) - } - - super()._post_init(**kwargs) - - props: dict[str, Any] = {} - for key, value in {**declared_props, **rest_props}.items(): - camel_cased_key = format.to_camel_case(key) - literal_value = LiteralVar.create(value) - props[camel_cased_key] = literal_value - setattr(self, camel_cased_key, literal_value) - - prop_names = tuple(props) - object.__setattr__(self, "get_props", lambda: prop_names) - - -@cache -def _get_experimental_memo_component_class( - export_name: str, - wrapped_component_type: type[Component] = Component, -) -> type[ExperimentalMemoComponent]: - """Get the component subclass for an experimental memo export. - - Class-level metadata that the compiler reads via ``type(comp)._get_*()`` - (notably ``_get_app_wrap_components``, which carries providers like - ``UploadFilesProvider`` that must reach the app root) is inherited from - ``wrapped_component_type`` so the wrapper is a transparent substitute for - the original in the compile tree. - - Args: - export_name: The exported React component name. - wrapped_component_type: The class of the component being memoized. - Defaults to ``Component`` for memos that don't wrap a user - component (e.g. function memos, raw passthroughs). - - Returns: - A cached component subclass with the tag set at class definition time. - """ - attrs: dict[str, Any] = { - "__module__": __name__, - "tag": export_name, - # Point each memo at its own per-file module so pages import directly - # from ``$/utils/components/`` rather than through the index. - # Per-file import paths give Vite distinct module boundaries per - # memo, enabling actual code-split by page. - "library": f"$/{constants.Dirs.COMPONENTS_PATH}/{export_name}", - "_wrapped_component_type": wrapped_component_type, - } - if ( - wrapped_component_type._get_app_wrap_components - is not Component._get_app_wrap_components - ): - attrs["_get_app_wrap_components"] = staticmethod( - wrapped_component_type._get_app_wrap_components - ) - return type( - f"ExperimentalMemoComponent_{export_name}", - (ExperimentalMemoComponent,), - attrs, - ) - - -EXPERIMENTAL_MEMOS: dict[str, ExperimentalMemoDefinition] = {} - - -def _memo_registry_key(definition: ExperimentalMemoDefinition) -> str: - """Get the registry key for an experimental memo. - - Args: - definition: The memo definition. - - Returns: - The registry key for the memo. - """ - if isinstance(definition, ExperimentalMemoComponentDefinition): - return definition.export_name - return definition.python_name - - -def _is_memo_reregistration( - existing: ExperimentalMemoDefinition, - definition: ExperimentalMemoDefinition, -) -> bool: - """Check whether a memo definition replaces the same memo during reload. - - Args: - existing: The currently registered memo definition. - definition: The new memo definition being registered. - - Returns: - Whether the new definition should replace the existing one. - """ - return ( - type(existing) is type(definition) - and existing.python_name == definition.python_name - and existing.fn.__module__ == definition.fn.__module__ - and existing.fn.__qualname__ == definition.fn.__qualname__ - ) - - -def _register_memo_definition(definition: ExperimentalMemoDefinition) -> None: - """Register an experimental memo definition. - - Args: - definition: The memo definition to register. - - Raises: - ValueError: If another memo already compiles to the same exported name. - """ - key = _memo_registry_key(definition) - if (existing := EXPERIMENTAL_MEMOS.get(key)) is not None and ( - not _is_memo_reregistration(existing, definition) - ): - msg = ( - f"Experimental memo name collision for `{key}`: " - f"`{existing.fn.__module__}.{existing.python_name}` and " - f"`{definition.fn.__module__}.{definition.python_name}` both compile " - "to the same memo name." - ) - raise ValueError(msg) - - EXPERIMENTAL_MEMOS[key] = definition - - -def _annotation_inner_type(annotation: Any) -> Any: - """Unwrap a Var-like annotation to its inner type. - - Args: - annotation: The annotation to unwrap. - - Returns: - The inner type for the annotation. - """ - if _is_rest_annotation(annotation): - return dict[str, Any] - - origin = get_origin(annotation) or annotation - if type_utils.safe_issubclass(origin, Var) and (args := get_args(annotation)): - return args[0] - return Any - - -def _is_rest_annotation(annotation: Any) -> bool: - """Check whether an annotation is a RestProp. - - Args: - annotation: The annotation to check. - - Returns: - Whether the annotation is a RestProp. - """ - origin = get_origin(annotation) or annotation - return isinstance(origin, type) and issubclass(origin, RestProp) - - -def _is_var_annotation(annotation: Any) -> bool: - """Check whether an annotation is a Var-like annotation. - - Args: - annotation: The annotation to check. - - Returns: - Whether the annotation is Var-like. - """ - origin = get_origin(annotation) or annotation - return isinstance(origin, type) and issubclass(origin, Var) - - -def _is_component_annotation(annotation: Any) -> bool: - """Check whether an annotation is component-like. - - Args: - annotation: The annotation to check. - - Returns: - Whether the annotation resolves to Component. - """ - origin = get_origin(annotation) or annotation - return isinstance(origin, type) and ( - safe_issubclass(origin, Component) - or bool( - safe_issubclass(origin, Var) - and (args := get_args(annotation)) - and safe_issubclass(args[0], Component) - ) - ) - - -def _children_annotation_is_valid(annotation: Any) -> bool: - """Check whether an annotation is valid for children. - - Args: - annotation: The annotation to check. - - Returns: - Whether the annotation is valid for children. - """ - return _is_var_annotation(annotation) and type_utils.typehint_issubclass( - _annotation_inner_type(annotation), Component - ) - - -def _get_children_param(params: tuple[MemoParam, ...]) -> MemoParam | None: - return next((param for param in params if param.is_children), None) - - -def _get_rest_param(params: tuple[MemoParam, ...]) -> MemoParam | None: - return next((param for param in params if param.is_rest), None) - - -def _imported_function_var(name: str, return_type: Any) -> FunctionVar: - """Create the imported FunctionVar for an experimental memo. - - Args: - name: The exported function name. - return_type: The return type of the function. - - Returns: - The imported FunctionVar. - """ - return FunctionStringVar.create( - name, - _var_type=ReflexCallable[Any, return_type], - _var_data=VarData( - imports={ - f"$/{constants.Dirs.COMPONENTS_PATH}/{name}": [ImportVar(tag=name)] - } - ), - ) - - -def _component_import_var(name: str) -> Var: - """Create the imported component var for an experimental memo component. - - Args: - name: The exported component name. - - Returns: - The component var. - """ - return Var( - name, - _var_type=type[Component], - _var_data=VarData( - imports={ - f"$/{constants.Dirs.COMPONENTS_PATH}/{name}": [ImportVar(tag=name)], - "@emotion/react": [ImportVar(tag="jsx")], - } - ), - ) - - -def _validate_var_return_expr(return_expr: Var, func_name: str) -> None: - """Validate that a var-returning memo can compile safely. - - Args: - return_expr: The return expression. - func_name: The function name for error messages. - - Raises: - TypeError: If the return expression depends on unsupported features. - """ - var_data = VarData.merge(return_expr._get_all_var_data()) - if var_data is None: - return - - if var_data.hooks: - msg = ( - f"Var-returning `@rx._x.memo` `{func_name}` cannot depend on hooks. " - "Use a component-returning `@rx._x.memo` instead." - ) - raise TypeError(msg) - - if var_data.components: - msg = ( - f"Var-returning `@rx._x.memo` `{func_name}` cannot depend on embedded " - "components, custom code, or dynamic imports. Use a component-returning " - "`@rx._x.memo` instead." - ) - raise TypeError(msg) - - for lib in dict(var_data.imports): - if not lib: - continue - if lib.startswith((".", "/", "$/", "http")): - continue - if format.format_library_name(lib) in bundled_libraries: - continue - msg = ( - f"Var-returning `@rx._x.memo` `{func_name}` cannot import `{lib}` because " - "it is not bundled. Use a component-returning `@rx._x.memo` instead." - ) - raise TypeError(msg) - - -def _rest_placeholder(name: str) -> RestProp: - """Create the placeholder RestProp. - - Args: - name: The JavaScript identifier. - - Returns: - The placeholder rest prop. - """ - return RestProp(_js_expr=name, _var_type=dict[str, Any]) - - -def _var_placeholder(name: str, annotation: Any) -> Var: - """Create a placeholder Var for a memo parameter. - - Args: - name: The JavaScript identifier. - annotation: The parameter annotation. - - Returns: - The placeholder Var. - """ - return Var(_js_expr=name, _var_type=_annotation_inner_type(annotation)).guess_type() - - -def _placeholder_for_param(param: MemoParam) -> Var: - """Create a placeholder var for a parameter. - - Args: - param: The parameter metadata. - - Returns: - The placeholder var. - """ - if param.is_rest: - return _rest_placeholder(param.placeholder_name) - return _var_placeholder(param.placeholder_name, param.annotation) - - -def _evaluate_memo_function( - fn: Callable[..., Any], - params: tuple[MemoParam, ...], -) -> Any: - """Evaluate a memo function with placeholder vars. - - Args: - fn: The function to evaluate. - params: The memo parameters. - - Returns: - The return value from the function. - """ - positional_args = [] - keyword_args = {} - - for param in params: - placeholder = _placeholder_for_param(param) - if param.kind in ( - inspect.Parameter.POSITIONAL_ONLY, - inspect.Parameter.POSITIONAL_OR_KEYWORD, - ): - positional_args.append(placeholder) - else: - keyword_args[param.name] = placeholder - - return fn(*positional_args, **keyword_args) - - -def _normalize_component_return(value: Any) -> Component | None: - """Normalize a component-like memo return value into a Component. - - Args: - value: The value returned from the memo function. - - Returns: - The normalized component, or ``None`` if the value is not component-like. - """ - if isinstance(value, Component): - return value - - if isinstance(value, Var) and type_utils.typehint_issubclass( - value._var_type, Component - ): - return Bare.create(value) - - return None - - -def _lift_rest_props(component: Component) -> Component: - """Convert RestProp children into special props. - - Args: - component: The component tree to rewrite. - - Returns: - The rewritten component tree. - """ - special_props = list(component.special_props) - rewritten_children = [] - - for child in component.children: - if isinstance(child, Bare) and isinstance(child.contents, RestProp): - special_props.append(child.contents) - continue - - if isinstance(child, Component): - child = _lift_rest_props(child) - - rewritten_children.append(child) - - component.children = rewritten_children - component.special_props = special_props - return component - - -def _analyze_params( - fn: Callable[..., Any], - *, - for_component: bool, -) -> tuple[MemoParam, ...]: - """Analyze and validate memo parameters. - - Args: - fn: The function to analyze. - for_component: Whether the memo returns a component. - - Returns: - The analyzed parameters. - - Raises: - TypeError: If the function signature is not supported. - """ - signature = inspect.signature(fn) - hints = get_type_hints(fn) - - params: list[MemoParam] = [] - rest_count = 0 - - for parameter in signature.parameters.values(): - if parameter.kind is inspect.Parameter.VAR_POSITIONAL: - msg = f"`@rx._x.memo` does not support `*args` in `{fn.__name__}`." - raise TypeError(msg) - if parameter.kind is inspect.Parameter.VAR_KEYWORD: - msg = f"`@rx._x.memo` does not support `**kwargs` in `{fn.__name__}`." - raise TypeError(msg) - if parameter.kind is inspect.Parameter.POSITIONAL_ONLY: - msg = ( - f"`@rx._x.memo` does not support positional-only parameters in " - f"`{fn.__name__}`." - ) - raise TypeError(msg) - - annotation = hints.get(parameter.name, parameter.annotation) - if annotation is inspect.Parameter.empty: - msg = ( - f"All parameters of `{fn.__name__}` must be annotated as `rx.Var[...]` " - f"or `rx.RestProp`. Missing annotation for `{parameter.name}`." - ) - raise TypeError(msg) - - is_rest = _is_rest_annotation(annotation) - is_children = parameter.name == "children" and _children_annotation_is_valid( - annotation - ) - - if parameter.name == "children" and not is_children: - msg = ( - f"`children` in `{fn.__name__}` must be annotated as " - "`rx.Var[rx.Component]`." - ) - raise TypeError(msg) - - if not is_rest and not _is_var_annotation(annotation): - msg = ( - f"All parameters of `{fn.__name__}` must be annotated as `rx.Var[...]` " - f"or `rx.RestProp`, got `{annotation}` for `{parameter.name}`." - ) - raise TypeError(msg) - - if is_rest: - rest_count += 1 - if rest_count > 1: - msg = ( - f"`@rx._x.memo` only supports one `rx.RestProp` in `{fn.__name__}`." - ) - raise TypeError(msg) - - js_prop_name = format.to_camel_case(parameter.name) - placeholder_name = ( - parameter.name - if is_children or is_rest or not for_component - else js_prop_name + CAMEL_CASE_MEMO_MARKER - ) - - params.append( - MemoParam( - name=parameter.name, - annotation=annotation, - kind=parameter.kind, - default=parameter.default, - js_prop_name=js_prop_name, - placeholder_name=placeholder_name, - is_children=is_children, - is_rest=is_rest, - ) - ) - - return tuple(params) - - -def _create_function_definition( - fn: Callable[..., Any], - return_annotation: Any, -) -> ExperimentalMemoFunctionDefinition: - """Create a definition for a var-returning memo. - - Args: - fn: The function to analyze. - return_annotation: The return annotation. - - Returns: - The function memo definition. - """ - params = _analyze_params(fn, for_component=False) - return_expr = Var.create(_evaluate_memo_function(fn, params)) - _validate_var_return_expr(return_expr, fn.__name__) - - children_param = _get_children_param(params) - rest_param = _get_rest_param(params) - if children_param is None and rest_param is None: - function = ArgsFunctionOperation.create( - args_names=tuple(param.placeholder_name for param in params), - return_expr=return_expr, - ) - else: - function = ArgsFunctionOperation.create( - args_names=( - DestructuredArg( - fields=tuple( - param.placeholder_name for param in params if not param.is_rest - ), - rest=( - rest_param.placeholder_name if rest_param is not None else None - ), - ), - ), - return_expr=return_expr, - ) - - return ExperimentalMemoFunctionDefinition( - fn=fn, - python_name=fn.__name__, - params=params, - function=function, - imported_var=_imported_function_var( - fn.__name__, _annotation_inner_type(return_annotation) - ), - ) - - -def _create_component_definition( - fn: Callable[..., Any], - return_annotation: Any, -) -> ExperimentalMemoComponentDefinition: - """Create a definition for a component-returning memo. - - Args: - fn: The function to analyze. - return_annotation: The return annotation. - - Returns: - The component memo definition. - - Raises: - TypeError: If the function does not return a component. - """ - params = _analyze_params(fn, for_component=True) - component = _normalize_component_return(_evaluate_memo_function(fn, params)) - if component is None: - msg = ( - f"Component-returning `@rx._x.memo` `{fn.__name__}` must return an " - "`rx.Component` or `rx.Var[rx.Component]`." - ) - raise TypeError(msg) - - return ExperimentalMemoComponentDefinition( - fn=fn, - python_name=fn.__name__, - params=params, - export_name=format.to_title_case(fn.__name__), - component=_lift_rest_props(component), - ) - - -def _bind_function_runtime_args( - definition: ExperimentalMemoFunctionDefinition, - *args: Any, - **kwargs: Any, -) -> tuple[Any, ...]: - """Bind runtime args for a var-returning memo. - - Args: - definition: The function memo definition. - *args: Positional arguments. - **kwargs: Keyword arguments. - - Returns: - The ordered arguments for the imported FunctionVar. - - Raises: - TypeError: If the provided arguments are invalid. - """ - children_param = _get_children_param(definition.params) - rest_param = _get_rest_param(definition.params) - - # Validate positional children usage and reserved keywords. - if "children" in kwargs: - msg = f"`{definition.python_name}` only accepts children positionally." - raise TypeError(msg) - - if rest_param is not None and rest_param.name in kwargs: - msg = ( - f"`{definition.python_name}` captures rest props from extra keyword " - f"arguments. Do not pass `{rest_param.name}=` directly." - ) - raise TypeError(msg) - - if args and children_param is None: - msg = f"`{definition.python_name}` only accepts keyword props." - raise TypeError(msg) - - if any(not _is_component_child(child) for child in args): - msg = ( - f"`{definition.python_name}` only accepts positional children that are " - "`rx.Component` or `rx.Var[rx.Component]`." - ) - raise TypeError(msg) - - # Bind declared props before collecting any rest props. - explicit_params = [ - param - for param in definition.params - if not param.is_rest and not param.is_children - ] - explicit_values = {} - remaining_props = kwargs.copy() - for param in explicit_params: - if param.name in remaining_props: - explicit_values[param.name] = remaining_props.pop(param.name) - elif param.default is not inspect.Parameter.empty: - explicit_values[param.name] = param.default - else: - msg = f"`{definition.python_name}` is missing required prop `{param.name}`." - raise TypeError(msg) - - # Reject unknown props unless a rest prop is declared. - if remaining_props and rest_param is None: - unexpected_prop = next(iter(remaining_props)) - msg = ( - f"`{definition.python_name}` does not accept prop `{unexpected_prop}`. " - "Only declared props may be passed when no `rx.RestProp` is present." - ) - raise TypeError(msg) - - # Return ordered explicit args when no packed props object is needed. - if children_param is None and rest_param is None: - return tuple(explicit_values[param.name] for param in explicit_params) - - # Build the props object passed to the imported FunctionVar. - children_value: Any | None = None - if children_param is not None: - children_value = args[0] if len(args) == 1 else Fragment.create(*args) - - # Convert rest-prop keys to camelCase to match component memo behavior. - camel_cased_remaining_props = { - format.to_camel_case(key): value for key, value in remaining_props.items() - } - - bound_props = {} - if children_param is not None: - bound_props[children_param.name] = children_value - bound_props.update(explicit_values) - bound_props.update(camel_cased_remaining_props) - return (bound_props,) - - -def _is_component_child(value: Any) -> bool: - """Check whether a value is valid as an experimental memo child. - - Args: - value: The value to check. - - Returns: - Whether the value is a component child. - """ - return isinstance(value, Component) or ( - isinstance(value, Var) - and type_utils.typehint_issubclass(value._var_type, Component) - ) - - -class _ExperimentalMemoFunctionWrapper: - """Callable wrapper for a var-returning experimental memo.""" - - def __init__(self, definition: ExperimentalMemoFunctionDefinition): - """Initialize the wrapper. - - Args: - definition: The function memo definition. - """ - self._definition = definition - self._imported_var = definition.imported_var - update_wrapper(self, definition.fn) - - def __call__(self, *args: Any, **kwargs: Any) -> Var: - """Call the wrapped memo and return a var. - - Args: - *args: Positional children, if supported. - **kwargs: Explicit props and rest props. - - Returns: - The function call var. - """ - return self.call(*args, **kwargs) - - def call(self, *args: Any, **kwargs: Any) -> Var: - """Call the imported memo function. - - Args: - *args: Positional children, if supported. - **kwargs: Explicit props and rest props. - - Returns: - The function call var. - """ - return self._imported_var.call( - *_bind_function_runtime_args(self._definition, *args, **kwargs) - ) - - def partial(self, *args: Any, **kwargs: Any) -> FunctionVar: - """Partially apply the imported memo function. - - Args: - *args: Positional children, if supported. - **kwargs: Explicit props and rest props. - - Returns: - The partially applied function var. - """ - return self._imported_var.partial( - *_bind_function_runtime_args(self._definition, *args, **kwargs) - ) - - def _as_var(self) -> FunctionVar: - """Expose the imported function var. - - Returns: - The imported function var. - """ - return self._imported_var - - -class _ExperimentalMemoComponentWrapper: - """Callable wrapper for a component-returning experimental memo.""" - - def __init__(self, definition: ExperimentalMemoComponentDefinition): - """Initialize the wrapper. - - Args: - definition: The component memo definition. - """ - self._definition = definition - self._children_param = _get_children_param(definition.params) - self._rest_param = _get_rest_param(definition.params) - self._explicit_params = [ - param - for param in definition.params - if not param.is_children and not param.is_rest - ] - update_wrapper(self, definition.fn) - - def __call__(self, *children: Any, **props: Any) -> ExperimentalMemoComponent: - """Call the wrapped memo and return a component. - - Args: - *children: Positional children passed to the memo. - **props: Explicit props and rest props. - - Returns: - The rendered memo component. - """ - definition = self._definition - rest_param = self._rest_param - - # Validate positional children usage and reserved keywords. - if "children" in props: - msg = f"`{definition.python_name}` only accepts children positionally." - raise TypeError(msg) - if rest_param is not None and rest_param.name in props: - msg = ( - f"`{definition.python_name}` captures rest props from extra keyword " - f"arguments. Do not pass `{rest_param.name}=` directly." - ) - raise TypeError(msg) - if children and self._children_param is None: - msg = f"`{definition.python_name}` only accepts keyword props." - raise TypeError(msg) - if any(not _is_component_child(child) for child in children): - msg = ( - f"`{definition.python_name}` only accepts positional children that are " - "`rx.Component` or `rx.Var[rx.Component]`." - ) - raise TypeError(msg) - - # Bind declared props before collecting any rest props. - explicit_values = {} - remaining_props = props.copy() - for param in self._explicit_params: - if param.name in remaining_props: - explicit_values[param.name] = remaining_props.pop(param.name) - elif param.default is not inspect.Parameter.empty: - explicit_values[param.name] = param.default - else: - msg = f"`{definition.python_name}` is missing required prop `{param.name}`." - raise TypeError(msg) - - # Reject unknown props unless a rest prop is declared. - if remaining_props and rest_param is None: - unexpected_prop = next(iter(remaining_props)) - msg = ( - f"`{definition.python_name}` does not accept prop `{unexpected_prop}`. " - "Only declared props may be passed when no `rx.RestProp` is present." - ) - raise TypeError(msg) - - # Build the component props passed into the memo wrapper. - return _get_experimental_memo_component_class( - definition.export_name, type(definition.component) - )._create( - children=list(children), - memo_definition=definition, - **explicit_values, - **remaining_props, - ) - - def _as_var(self) -> Var: - """Expose the imported component var. - - Returns: - The imported component var. - """ - return _component_import_var(self._definition.export_name) - - -def _create_function_wrapper( - definition: ExperimentalMemoFunctionDefinition, -) -> _ExperimentalMemoFunctionWrapper: - """Create the Python wrapper for a var-returning memo. - - Args: - definition: The function memo definition. - - Returns: - The wrapper callable. - """ - return _ExperimentalMemoFunctionWrapper(definition) - - -def _create_component_wrapper( - definition: ExperimentalMemoComponentDefinition, -) -> _ExperimentalMemoComponentWrapper: - """Create the Python wrapper for a component-returning memo. - - Args: - definition: The component memo definition. - - Returns: - The wrapper callable. - """ - return _ExperimentalMemoComponentWrapper(definition) - - -def create_passthrough_component_memo( - component: Component, -) -> tuple[ - Callable[..., ExperimentalMemoComponent], - ExperimentalMemoComponentDefinition, -]: - """Create an unregistered ``@rx._x.memo``-style passthrough component memo. - - This is used by compiler auto-memoization so generated wrappers compile - through the experimental memo pipeline instead of emitting ad-hoc page-local - ``React.memo`` declarations. - - The exported memo name is derived from ``component._compute_memo_tag()`` - after the ``{children}`` hole has been substituted into the wrapped - component's children (passthrough mode), so two call-sites differing only - in their children — whose generated memo bodies are identical — collapse - to one wrapper. - - Args: - component: The component to wrap. - - Returns: - The callable memo wrapper and its component definition. - """ - # Snapshot-boundary components (see ``is_snapshot_boundary``) own their - # subtree — the ``.children`` slot is internal machinery from the - # subclass's ``.create`` (e.g. the dropzone Div built inside - # ``Upload.create``), not a user content hole. The memoize plugin wraps - # the boundary with no structural children on the page side, so the memo - # body renders the full snapshot rather than a ``{children}``-holed - # template. - render_snapshot = ( - get_memoization_strategy(component) is MemoizationStrategy.SNAPSHOT - ) - - captured_hole_child: list[Component] = [] - - def passthrough(children: Var[Component]) -> Component: - new_component = copy(component) - if render_snapshot: - return new_component - # Components with no original structural children own their own JSX - # output (e.g. ``CodeBlock`` injects ``code`` as the ``children`` prop - # in ``_render``). Substituting a ``{children}`` hole here would emit - # ``jsx(Inner, {children: "..."}, hole)``, and an undefined hole at - # call time clobbers the prop. Skip the substitution so the wrapper's - # ``children`` parameter is present in the signature but unused. - if not component.children: - return new_component - hole_bare = Bare.create(children) - captured_hole_child.append(hole_bare) - # Substitute the ``{children}`` hole for the original descendants so - # the memo body's hash and JSX both reflect the placeholder, not the - # specific children at any given call site. Original descendants stay - # reachable on the page-level wrapper via the plugin's - # ``_get_all_refs`` delegation back to the source component. - new_component.children = [hole_bare] - # Compile-time walkers that need the real subtree (notably - # ``Form._get_form_refs`` collecting id-based input refs into the - # generated ``handleSubmit`` JS) call ``self._get_all_refs()`` while - # the memo body's hooks are computed. With the hole substituted in, - # that walk would return nothing and the form handler would emit an - # empty ``field_ref_mapping``. Delegate ref collection back to the - # source component so descendants behind the hole remain visible. - object.__setattr__(new_component, "_get_all_refs", component._get_all_refs) - return new_component - - # Evaluate once to compute the tag from the rendered memo body shape. - # ``_create_component_definition`` will evaluate again internally; the - # second pass overwrites ``captured_hole_child`` but the captured value - # is identical. - params = _analyze_params(passthrough, for_component=True) - preview = _normalize_component_return(_evaluate_memo_function(passthrough, params)) - if preview is None: - msg = ( - "`create_passthrough_component_memo` requires a component that " - "normalizes to `rx.Component`." - ) - raise TypeError(msg) - tag = preview._compute_memo_tag() - - passthrough.__name__ = format.to_snake_case(tag) - passthrough.__qualname__ = passthrough.__name__ - passthrough.__module__ = __name__ - - definition = _create_component_definition(passthrough, Component) - replacements: dict[str, Any] = {} - if definition.export_name != tag: - replacements["export_name"] = tag - if captured_hole_child: - replacements["passthrough_hole_child"] = captured_hole_child[0] - if replacements: - definition = dataclasses.replace(definition, **replacements) - - return _create_component_wrapper(definition), definition - - -def memo(fn: Callable[..., Any]) -> Callable[..., Any]: - """Create an experimental memo from a function. - - Args: - fn: The function to memoize. - - Returns: - The wrapped function or component factory. - - Raises: - TypeError: If the return type is not supported. - """ - hints = get_type_hints(fn) - return_annotation = hints.get("return", inspect.Signature.empty) - if return_annotation is inspect.Signature.empty: - msg = ( - f"`@rx._x.memo` requires a return annotation on `{fn.__name__}`. " - "Use `-> rx.Component` or `-> rx.Var[...]`." - ) - raise TypeError(msg) - - if _is_component_annotation(return_annotation): - definition = _create_component_definition(fn, return_annotation) - _register_memo_definition(definition) - return _create_component_wrapper(definition) - - if _is_var_annotation(return_annotation): - definition = _create_function_definition(fn, return_annotation) - _register_memo_definition(definition) - return _create_function_wrapper(definition) - - msg = ( - f"`@rx._x.memo` on `{fn.__name__}` must return `rx.Component` or `rx.Var[...]`, " - f"got `{return_annotation}`." - ) - raise TypeError(msg) - - -__all__ = [ - "EXPERIMENTAL_MEMOS", - "ExperimentalMemoComponent", - "ExperimentalMemoComponentDefinition", - "ExperimentalMemoDefinition", - "ExperimentalMemoFunctionDefinition", - "create_passthrough_component_memo", - "memo", -] +sys.modules[__name__] = memo # pyright: ignore[reportArgumentType] diff --git a/reflex/testing.py b/reflex/testing.py index 3e23afdaae8..2838c7ef746 100644 --- a/reflex/testing.py +++ b/reflex/testing.py @@ -26,7 +26,7 @@ from typing import TYPE_CHECKING, Any, ClassVar, Literal, TypeVar import uvicorn -from reflex_base.components.component import CUSTOM_COMPONENTS, CustomComponent +from reflex_base.components.memo import MEMOS from reflex_base.config import get_config from reflex_base.environment import environment from reflex_base.registry import RegistrationContext @@ -40,7 +40,6 @@ import reflex.utils.format import reflex.utils.prerequisites import reflex.utils.processes -from reflex.experimental.memo import EXPERIMENTAL_MEMOS from reflex.istate.shared import SharedState as SharedState # To register it. from reflex.state import reload_state_module from reflex.utils import console, js_runtimes @@ -239,11 +238,9 @@ def _get_source_from_app_source(self, app_source: Any) -> str: def _initialize_app(self): # disable telemetry reporting for tests os.environ["REFLEX_TELEMETRY_ENABLED"] = "false" - # Reset global memo registries so previous AppHarness apps do not + # Reset the global memo registry so previous AppHarness apps do not # leak compiled component definitions into the next test app. - CUSTOM_COMPONENTS.clear() - EXPERIMENTAL_MEMOS.clear() - CustomComponent.create().get_component.cache_clear() + MEMOS.clear() self.app_path.mkdir(parents=True, exist_ok=True) if self.app_source is not None: app_globals = self._get_globals_from_signature(self.app_source) diff --git a/reflex/utils/telemetry_accounting.py b/reflex/utils/telemetry_accounting.py index bda3f068b10..7dbffb58107 100644 --- a/reflex/utils/telemetry_accounting.py +++ b/reflex/utils/telemetry_accounting.py @@ -113,7 +113,7 @@ def _count_components(pages: Iterable[BaseComponent]) -> dict[str, int]: """Count component types across one or more component trees. Auto-memoized components live in the tree as dynamic - ``ExperimentalMemoComponent___`` subclasses. Bucketing by + ``MemoComponent___`` subclasses. Bucketing by the raw class name would explode telemetry cardinality (each handler hash produces a new key), so wrappers are counted under the user-authored component they stand in for, exposed via ``_wrapped_component_type``. diff --git a/tests/integration/test_experimental_memo.py b/tests/integration/test_experimental_memo.py deleted file mode 100644 index 935f8c75ce6..00000000000 --- a/tests/integration/test_experimental_memo.py +++ /dev/null @@ -1,140 +0,0 @@ -"""Integration tests for rx._x.memo.""" - -from collections.abc import Generator - -import pytest -from selenium.webdriver.common.by import By - -from reflex.testing import AppHarness - - -def ExperimentalMemoApp(): - """Reflex app that exercises experimental memo functions and components.""" - import reflex as rx - - class FooComponent(rx.Fragment): - def add_custom_code(self) -> list[str]: - return [ - "const foo = 'bar'", - ] - - @rx._x.memo - def foo_component(label: rx.Var[str]) -> rx.Component: - return FooComponent.create(label, rx.Var("foo")) - - @rx._x.memo - def format_price(amount: rx.Var[int], currency: rx.Var[str]) -> rx.Var[str]: - return currency.to(str) + ": $" + amount.to(str) - - @rx._x.memo - def summary_card( - children: rx.Var[rx.Component], - rest: rx.RestProp, - *, - title: rx.Var[str], - value: rx.Var[str], - ) -> rx.Component: - return rx.box( - rx.heading(title, id="summary-title"), - rx.text(value, id="summary-value"), - children, - rest, - ) - - class ExperimentalMemoState(rx.State): - amount: int = 125 - currency: str = "USD" - title: str = "Current Price" - - @rx.event - def increment_amount(self): - self.amount += 5 - - def index() -> rx.Component: - formatted_price = format_price( - amount=ExperimentalMemoState.amount, - currency=ExperimentalMemoState.currency, - ) - return rx.vstack( - rx.vstack( - foo_component(label="foo"), - foo_component(label="bar"), - id="experimental-memo-custom-code", - ), - rx.text(formatted_price, id="formatted-price"), - rx.button( - "Increment", - id="increment-price", - on_click=ExperimentalMemoState.increment_amount, - ), - summary_card( - rx.text("Children are passed positionally.", id="summary-child"), - title=ExperimentalMemoState.title, - value=formatted_price, - id="summary-card", - class_name="forwarded-summary-card", - ), - ) - - app = rx.App() - app.add_page(index) - - -@pytest.fixture -def experimental_memo_app(tmp_path) -> Generator[AppHarness, None, None]: - """Start ExperimentalMemoApp app at tmp_path via AppHarness. - - Args: - tmp_path: pytest tmp_path fixture. - - Yields: - Running AppHarness instance. - """ - with AppHarness.create( - root=tmp_path, - app_source=ExperimentalMemoApp, - ) as harness: - yield harness - - -def test_experimental_memo_app(experimental_memo_app: AppHarness): - """Render experimental memos and assert on their frontend behavior. - - Args: - experimental_memo_app: Harness for ExperimentalMemoApp. - """ - assert experimental_memo_app.app_instance is not None, "app is not running" - driver = experimental_memo_app.frontend() - - memo_custom_code_stack = AppHarness.poll_for_or_raise_timeout( - lambda: driver.find_element(By.ID, "experimental-memo-custom-code") - ) - assert ( - experimental_memo_app.poll_for_content(memo_custom_code_stack, exp_not_equal="") - == "foobarbarbar" - ) - assert memo_custom_code_stack.text == "foobarbarbar" - - formatted_price = driver.find_element(By.ID, "formatted-price") - assert ( - experimental_memo_app.poll_for_content(formatted_price, exp_not_equal="") - == "USD: $125" - ) - - summary_card = driver.find_element(By.ID, "summary-card") - assert "forwarded-summary-card" in (summary_card.get_attribute("class") or "") - assert driver.find_element(By.ID, "summary-title").text == "Current Price" - assert ( - driver.find_element(By.ID, "summary-child").text - == "Children are passed positionally." - ) - - summary_value = driver.find_element(By.ID, "summary-value") - assert ( - experimental_memo_app.poll_for_content(summary_value, exp_not_equal="") - == "USD: $125" - ) - - driver.find_element(By.ID, "increment-price").click() - assert experimental_memo_app.poll_for_content(formatted_price) == "USD: $130" - assert experimental_memo_app.poll_for_content(summary_value) == "USD: $130" diff --git a/tests/integration/test_memo.py b/tests/integration/test_memo.py index 07442fab6ce..4a8524d6e2f 100644 --- a/tests/integration/test_memo.py +++ b/tests/integration/test_memo.py @@ -1,4 +1,4 @@ -"""Integration tests for rx.memo components.""" +"""Integration tests for the ``rx._x.memo`` deprecation shim.""" from collections.abc import Generator @@ -9,7 +9,7 @@ def MemoApp(): - """Reflex app with memo components.""" + """Reflex app that exercises memo functions and components via ``rx._x.memo``.""" import reflex as rx class FooComponent(rx.Fragment): @@ -18,39 +18,61 @@ def add_custom_code(self) -> list[str]: "const foo = 'bar'", ] - @rx.memo - def foo_component(t: str): - return FooComponent.create(t, rx.Var("foo")) - - @rx.memo - def foo_component2(t: str): - return FooComponent.create(t, rx.Var("foo")) + @rx._x.memo + def foo_component(label: rx.Var[str]) -> rx.Component: + return FooComponent.create(label, rx.Var("foo")) + + @rx._x.memo + def format_price(amount: rx.Var[int], currency: rx.Var[str]) -> rx.Var[str]: + return currency.to(str) + ": $" + amount.to(str) + + @rx._x.memo + def summary_card( + children: rx.Var[rx.Component], + rest: rx.RestProp, + *, + title: rx.Var[str], + value: rx.Var[str], + ) -> rx.Component: + return rx.box( + rx.heading(title, id="summary-title"), + rx.text(value, id="summary-value"), + children, + rest, + ) class MemoState(rx.State): - last_value: str = "" + amount: int = 125 + currency: str = "USD" + title: str = "Current Price" @rx.event - def set_last_value(self, value: str): - self.last_value = value - - @rx.memo - def my_memoed_component( - some_value: str, - event: rx.EventHandler[rx.event.passthrough_event_spec(str)], - ) -> rx.Component: - return rx.vstack( - rx.button(some_value, id="memo-button", on_click=event(some_value)), - rx.input(id="memo-input", on_change=event), - ) + def increment_amount(self): + self.amount += 5 def index() -> rx.Component: + formatted_price = format_price( + amount=MemoState.amount, + currency=MemoState.currency, + ) return rx.vstack( rx.vstack( - foo_component(t="foo"), foo_component2(t="bar"), id="memo-custom-code" + foo_component(label="foo"), + foo_component(label="bar"), + id="experimental-memo-custom-code", + ), + rx.text(formatted_price, id="formatted-price"), + rx.button( + "Increment", + id="increment-price", + on_click=MemoState.increment_amount, ), - rx.text(MemoState.last_value, id="memo-last-value"), - my_memoed_component( - some_value="memod_some_value", event=MemoState.set_last_value + summary_card( + rx.text("Children are passed positionally.", id="summary-child"), + title=MemoState.title, + value=formatted_price, + id="summary-card", + class_name="forwarded-summary-card", ), ) @@ -63,10 +85,10 @@ def memo_app(tmp_path) -> Generator[AppHarness, None, None]: """Start MemoApp app at tmp_path via AppHarness. Args: - tmp_path: pytest tmp_path fixture + tmp_path: pytest tmp_path fixture. Yields: - running AppHarness instance + Running AppHarness instance. """ with AppHarness.create( root=tmp_path, @@ -76,17 +98,16 @@ def memo_app(tmp_path) -> Generator[AppHarness, None, None]: def test_memo_app(memo_app: AppHarness): - """Render various memo'd components and assert on the output. + """Render experimental memos and assert on their frontend behavior. Args: - memo_app: harness for MemoApp app + memo_app: Harness for MemoApp. """ assert memo_app.app_instance is not None, "app is not running" driver = memo_app.frontend() - # check that the output matches memo_custom_code_stack = AppHarness.poll_for_or_raise_timeout( - lambda: driver.find_element(By.ID, "memo-custom-code") + lambda: driver.find_element(By.ID, "experimental-memo-custom-code") ) assert ( memo_app.poll_for_content(memo_custom_code_stack, exp_not_equal="") @@ -94,13 +115,20 @@ def test_memo_app(memo_app: AppHarness): ) assert memo_custom_code_stack.text == "foobarbarbar" - # click the button to trigger partial event application - button = driver.find_element(By.ID, "memo-button") - button.click() - last_value = driver.find_element(By.ID, "memo-last-value") - assert memo_app.poll_for_content(last_value, exp_not_equal="") == "memod_some_value" + formatted_price = driver.find_element(By.ID, "formatted-price") + assert memo_app.poll_for_content(formatted_price, exp_not_equal="") == "USD: $125" + + summary_card = driver.find_element(By.ID, "summary-card") + assert "forwarded-summary-card" in (summary_card.get_attribute("class") or "") + assert driver.find_element(By.ID, "summary-title").text == "Current Price" + assert ( + driver.find_element(By.ID, "summary-child").text + == "Children are passed positionally." + ) + + summary_value = driver.find_element(By.ID, "summary-value") + assert memo_app.poll_for_content(summary_value, exp_not_equal="") == "USD: $125" - # enter text to trigger passed argument to event handler - textbox = driver.find_element(By.ID, "memo-input") - textbox.send_keys("new_value") - AppHarness.expect(lambda: memo_app.poll_for_content(last_value) == "new_value") + driver.find_element(By.ID, "increment-price").click() + assert memo_app.poll_for_content(formatted_price) == "USD: $130" + assert memo_app.poll_for_content(summary_value) == "USD: $130" diff --git a/tests/integration/test_var_operations.py b/tests/integration/test_var_operations.py index 409a0838b2e..5b5949188ef 100644 --- a/tests/integration/test_var_operations.py +++ b/tests/integration/test_var_operations.py @@ -59,12 +59,14 @@ class VarOperationState(rx.State): app = rx.App() @rx.memo - def memo_comp(list1: list[int], int_var1: int, id: str): + def memo_comp( + list1: rx.Var[list[int]], int_var1: rx.Var[int], id: rx.Var[str] + ) -> rx.Component: return rx.text(list1, int_var1, id=id) @rx.memo - def memo_comp_nested(int_var2: int, id: str): - return memo_comp(list1=[3, 4], int_var1=int_var2, id=id) + def memo_comp_nested(int_var2: rx.Var[int], id: rx.Var[str]) -> rx.Component: + return memo_comp(list1=rx.Var.create([3, 4]), int_var1=int_var2, id=id) @app.add_page def index(): @@ -634,13 +636,13 @@ def index(): id="foreach_list_arg2", ), memo_comp( - list1=VarOperationState.list1, + list1=VarOperationState.list1.to(list[int]), int_var1=VarOperationState.int_var1, - id="memo_comp", + id=rx.Var.create("memo_comp"), ), memo_comp_nested( int_var2=VarOperationState.int_var2, - id="memo_comp_nested", + id=rx.Var.create("memo_comp_nested"), ), # length rx.box( diff --git a/tests/integration/tests_playwright/test_memo.py b/tests/integration/tests_playwright/test_memo.py new file mode 100644 index 00000000000..67f06fa4243 --- /dev/null +++ b/tests/integration/tests_playwright/test_memo.py @@ -0,0 +1,169 @@ +"""Integration tests for ``rx.memo`` runtime behavior. + +Covers behaviors previously exercised by the deleted +``tests/integration/test_memo.py`` (Selenium): partial-application of an +``EventHandler`` prop (``event(some_value)``) and raw pass-through to an +inner event trigger (``on_change=event``). Also covers recursion through a +self-referencing component memo rendering a tree via ``rx.foreach``. +""" + +from collections.abc import Generator + +import pytest +from playwright.sync_api import Page, expect + +from reflex.testing import AppHarness + + +def MemoApp(): + """App exercising ``rx.memo`` with ``EventHandler`` props and recursion.""" + from collections.abc import Sequence + from typing import TypedDict + + import reflex as rx + + class TreeNode(TypedDict): + name: str + children: Sequence["TreeNode"] + + class MemoState(rx.State): + last_value: str = "" + tree: TreeNode = TreeNode( + name="root", + children=[ + TreeNode(name="child1", children=[]), + TreeNode( + name="child2", + children=[TreeNode(name="grandchild1", children=[])], + ), + ], + ) + + @rx.event + def set_last_value(self, value: str): + self.last_value = value + + @rx.event + def replace_tree(self): + self.tree = TreeNode( + name="root2", + children=[TreeNode(name="only-child", children=[])], + ) + + @rx.memo + def my_memoed_component( + some_value: rx.Var[str], + event: rx.EventHandler[rx.event.passthrough_event_spec(str)], + ) -> rx.Component: + return rx.vstack( + rx.button(some_value, id="memo-button", on_click=event(some_value)), + rx.input(id="memo-input", on_change=event), + ) + + @rx.memo + def tree_node(data: rx.vars.ObjectVar[TreeNode]) -> rx.Component: + return rx.vstack( + rx.text(data.name, class_name="tree-node-name"), + rx.foreach(data.children, lambda child: tree_node(data=child)), + class_name="pl-4 border-l", + ) + + def index() -> rx.Component: + return rx.vstack( + rx.text(MemoState.last_value, id="memo-last-value"), + my_memoed_component( + some_value="memod_some_value", event=MemoState.set_last_value + ), + rx.button( + "replace-tree", id="replace-tree", on_click=MemoState.replace_tree + ), + rx.box(tree_node(data=MemoState.tree), id="tree-root"), + ) + + app = rx.App() + app.add_page(index) + + +@pytest.fixture(scope="module") +def memo_app( + tmp_path_factory: pytest.TempPathFactory, +) -> Generator[AppHarness, None, None]: + """Run the memo app under an AppHarness. + + Args: + tmp_path_factory: Pytest fixture for creating temporary directories. + + Yields: + The running harness. + """ + with AppHarness.create( + root=tmp_path_factory.mktemp("memo_app"), + app_source=MemoApp, + ) as harness: + yield harness + + +def test_memo_event_handler_partial_application( + memo_app: AppHarness, page: Page +) -> None: + """Clicking a button whose ``on_click`` is ``event(some_value)`` dispatches it. + + Args: + memo_app: Running app harness. + page: Playwright page. + """ + assert memo_app.frontend_url is not None + page.goto(memo_app.frontend_url) + + expect(page.locator("#memo-last-value")).to_have_text("") + page.click("#memo-button") + expect(page.locator("#memo-last-value")).to_have_text("memod_some_value") + + +def test_memo_event_handler_raw_pass_through(memo_app: AppHarness, page: Page) -> None: + """Typing into an input whose ``on_change`` is the raw handler dispatches it. + + Args: + memo_app: Running app harness. + page: Playwright page. + """ + assert memo_app.frontend_url is not None + page.goto(memo_app.frontend_url) + + page.locator("#memo-input").fill("typed_value") + expect(page.locator("#memo-last-value")).to_have_text("typed_value") + + +def test_memo_recursive_tree_render(memo_app: AppHarness, page: Page) -> None: + """A self-referencing component memo renders nested children via ``rx.foreach``. + + Args: + memo_app: Running app harness. + page: Playwright page. + """ + assert memo_app.frontend_url is not None + page.goto(memo_app.frontend_url) + + tree_root = page.locator("#tree-root") + node_names = tree_root.locator(".tree-node-name") + expect(node_names).to_have_count(4) + expect(node_names).to_have_text(["root", "child1", "child2", "grandchild1"]) + + +def test_memo_recursive_tree_reacts_to_state(memo_app: AppHarness, page: Page) -> None: + """Replacing the tree in state re-renders the recursive memo with new data. + + Args: + memo_app: Running app harness. + page: Playwright page. + """ + assert memo_app.frontend_url is not None + page.goto(memo_app.frontend_url) + + node_names = page.locator("#tree-root .tree-node-name") + expect(node_names).to_have_count(4) + + page.click("#replace-tree") + + expect(node_names).to_have_count(2) + expect(node_names).to_have_text(["root2", "only-child"]) diff --git a/tests/units/compiler/test_memoize_plugin.py b/tests/units/compiler/test_memoize_plugin.py index 755ef8914e7..4061d0470f1 100644 --- a/tests/units/compiler/test_memoize_plugin.py +++ b/tests/units/compiler/test_memoize_plugin.py @@ -10,6 +10,11 @@ import pytest from reflex_base.components.component import Component from reflex_base.components.component import field as component_field +from reflex_base.components.memo import ( + MemoComponent, + MemoComponentDefinition, + create_passthrough_component_memo, +) from reflex_base.components.memoize_helpers import ( MemoizationStrategy, get_memoization_strategy, @@ -45,11 +50,6 @@ import reflex.compiler.plugins.memoize as memoize_plugin from reflex.compiler.plugins import DefaultCollectorPlugin, default_page_plugins from reflex.compiler.plugins.memoize import MemoizeStatefulPlugin, _should_memoize -from reflex.experimental.memo import ( - ExperimentalMemoComponent, - ExperimentalMemoComponentDefinition, - create_passthrough_component_memo, -) from reflex.state import BaseState STATE_VAR = LiteralVar.create("value")._replace( @@ -198,8 +198,8 @@ def test_should_not_memoize_when_disposition_never() -> None: assert not _should_memoize(comp) -def test_memoize_wrapper_uses_experimental_memo_component_and_call_site() -> None: - """Memoizable component imports a generated ``rx._x.memo`` wrapper.""" +def test_memoize_wrapper_uses_memo_component_and_call_site() -> None: + """Memoizable component imports a generated ``rx.memo`` wrapper.""" ctx, page_ctx = _compile_single_page(lambda: Plain.create(STATE_VAR)) assert len(ctx.memoize_wrappers) == 1 @@ -327,8 +327,7 @@ def special_child() -> Component: ctx, page_ctx = _compile_single_page(lambda: rx.box(special_child())) memo_files, _memo_imports = compile_memo_components( - components=(), - experimental_memos=tuple(ctx.auto_memo_components.values()), + memos=tuple(ctx.auto_memo_components.values()), ) memo_code = "\n".join(code for _, code in memo_files) @@ -369,8 +368,7 @@ def accordion() -> Component: ) memo_files, _memo_imports = compile_memo_components( - components=(), - experimental_memos=tuple(ctx.auto_memo_components.values()), + memos=tuple(ctx.auto_memo_components.values()), ) foreach_code = next( code for path, code in memo_files if "/Foreach" in path or "\\Foreach" in path @@ -403,7 +401,7 @@ def test_foreach_parent_does_not_absorb_sibling_into_snapshot() -> None: wrapped_definitions = [ definition for definition in ctx.auto_memo_components.values() - if isinstance(definition, ExperimentalMemoComponentDefinition) + if isinstance(definition, MemoComponentDefinition) ] wrapped_types = {type(definition.component) for definition in wrapped_definitions} @@ -495,7 +493,7 @@ def test_generated_memo_component_is_not_itself_memoized() -> None: """The generated memo component instance itself is skipped by the heuristic.""" wrapper_factory, _definition = create_passthrough_component_memo(Fragment.create()) wrapper = wrapper_factory(Plain.create()) - assert isinstance(wrapper, ExperimentalMemoComponent) + assert isinstance(wrapper, MemoComponent) assert not _should_memoize(wrapper) @@ -543,7 +541,7 @@ def test_generated_memo_component_renders_as_its_exported_tag() -> None: """The generated experimental memo component renders as its exported tag.""" wrapper_factory, definition = create_passthrough_component_memo(Fragment.create()) wrapper = wrapper_factory(Plain.create()) - assert isinstance(wrapper, ExperimentalMemoComponent) + assert isinstance(wrapper, MemoComponent) tag = definition.export_name assert tag.startswith("Fragment_"), ( f"Expected the wrapped class qualname to be encoded in the tag prefix; " @@ -775,8 +773,7 @@ def create(cls, *children, **props): "expected an auto-memo wrapper to be generated for the leaf" ) memo_files, _memo_imports = compile_memo_components( - components=(), - experimental_memos=tuple(ctx.auto_memo_components.values()), + memos=tuple(ctx.auto_memo_components.values()), ) memo_code = "\n".join(code for _, code in memo_files) assert "useLeafProbe" in memo_code, ( @@ -1019,8 +1016,7 @@ def page() -> Component: ) memo_files, _memo_imports = compile_memo_components( - components=(), - experimental_memos=tuple(ctx.auto_memo_components.values()), + memos=tuple(ctx.auto_memo_components.values()), ) memo_code = "\n".join(code for _, code in memo_files) @@ -1117,8 +1113,7 @@ def page() -> Component: ) memo_files, _memo_imports = compile_memo_components( - components=(), - experimental_memos=tuple(ctx.auto_memo_components.values()), + memos=tuple(ctx.auto_memo_components.values()), ) match_memo_code = next( code @@ -1230,8 +1225,7 @@ def page() -> Component: wrapper_tag = next(iter(ctx.memoize_wrappers)) memo_files, _memo_imports = compile_memo_components( - components=(), - experimental_memos=tuple(ctx.auto_memo_components.values()), + memos=tuple(ctx.auto_memo_components.values()), ) memo_code = next( code for path, code in memo_files if Path(path).name == f"{wrapper_tag}.jsx" @@ -1314,8 +1308,7 @@ def page() -> Component: ) memo_files, _memo_imports = compile_memo_components( - components=(), - experimental_memos=tuple(ctx.auto_memo_components.values()), + memos=tuple(ctx.auto_memo_components.values()), ) memo_code = next( code for path, code in memo_files if Path(path).name == f"{wrapper_tag}.jsx" @@ -1663,8 +1656,7 @@ def _compile_memo_module_text(ctx: CompileContext) -> str: from reflex.compiler.compiler import compile_memo_components memo_files, _imports = compile_memo_components( - components=(), - experimental_memos=tuple(ctx.auto_memo_components.values()), + memos=tuple(ctx.auto_memo_components.values()), ) return "\n".join(code for _, code in memo_files) @@ -2184,8 +2176,7 @@ def test_each_memo_wrapper_emits_one_component_module_file() -> None: ) ) memo_files, _imports = compile_memo_components( - components=(), - experimental_memos=tuple(ctx.auto_memo_components.values()), + memos=tuple(ctx.auto_memo_components.values()), ) component_module_names = { Path(path).name diff --git a/tests/units/components/markdown/test_markdown.py b/tests/units/components/markdown/test_markdown.py index 8874b126bc9..d51d3666f4d 100644 --- a/tests/units/components/markdown/test_markdown.py +++ b/tests/units/components/markdown/test_markdown.py @@ -1,5 +1,6 @@ import pytest -from reflex_base.components.component import Component, memo +from reflex_base.components.component import Component +from reflex_base.components.memo import memo from reflex_base.plugins import CompileContext, CompilerHooks, PageContext from reflex_base.vars.base import Var from reflex_components_code.code import CodeBlock @@ -42,7 +43,7 @@ def get_fn_body(cls) -> Var: def syntax_highlighter_memoized_component(codeblock: type[Component]): @memo - def code_block(code: str, language: str): + def code_block(code: Var[str], language: Var[str]) -> Component: return Box.create( codeblock.create( code, @@ -222,7 +223,7 @@ def _compile_page_output(root: Component) -> str: hooks.compile_page(page_ctx, compile_context=compile_ctx) _, page_code = compiler.compile_page_from_context(page_ctx) memo_files, _ = compiler.compile_memo_components( - (), compile_ctx.auto_memo_components.values() + compile_ctx.auto_memo_components.values() ) return "\n".join([page_code, *(code for _, code in memo_files)]) diff --git a/tests/units/components/test_component.py b/tests/units/components/test_component.py index 293132190c2..e22aa1434bc 100644 --- a/tests/units/components/test_component.py +++ b/tests/units/components/test_component.py @@ -3,13 +3,7 @@ from typing import Any, ClassVar import pytest -from reflex_base.components.component import ( - CUSTOM_COMPONENTS, - Component, - CustomComponent, - custom_component, - field, -) +from reflex_base.components.component import Component, field from reflex_base.constants import EventTriggers from reflex_base.constants.state import FIELD_MARKER from reflex_base.event import ( @@ -46,7 +40,6 @@ _COMPONENTS_BASE_MAPPING, # pyright: ignore[reportAttributeAccessIssue] _COMPONENTS_CORE_MAPPING, # pyright: ignore[reportAttributeAccessIssue] ) -from reflex.compiler.utils import compile_custom_component from reflex.state import BaseState from reflex.utils import imports @@ -272,20 +265,6 @@ def on_click2(): return EventHandler(fn=on_click2) -@pytest.fixture -def my_component(): - """A test component function. - - Returns: - A test component function. - """ - - def my_component(prop1: Var[str], prop2: Var[int]): - return Box.create(prop1, prop2) - - return my_component - - def test_set_style_attrs(component1): """Test that style attributes are set in the dict. @@ -860,52 +839,6 @@ def get_event_triggers(cls) -> dict[str, Any]: C1.create(on_foo=C1State.mock_handler) -def test_create_custom_component(my_component): - """Test that we can create a custom component. - - Args: - my_component: A test custom component. - """ - component = rx.memo(my_component)(prop1="test", prop2=1) - assert component.tag == "MyComponent" - assert set(component.get_props()) == {"prop1", "prop2"} - assert component.tag in CUSTOM_COMPONENTS - - -def test_custom_component_hash(my_component): - """Test that the hash of a custom component is correct. - - Args: - my_component: A test custom component. - """ - component1 = rx.memo(my_component)(prop1="test", prop2=1) - component2 = rx.memo(my_component)(prop1="test", prop2=2) - assert {component1, component2} == {component1} - - -def test_custom_component_wrapper(): - """Test that the wrapper of a custom component is correct.""" - - @custom_component - def my_component(width: Var[int], color: Var[str]): - return rx.box( - width=width, - color=color, - ) - - from reflex_components_radix.themes.typography.text import Text - - ccomponent = my_component( - rx.text("child"), width=LiteralVar.create(1), color=LiteralVar.create("red") - ) - assert isinstance(ccomponent, CustomComponent) - assert len(ccomponent.children) == 1 - assert isinstance(ccomponent.children[0], Text) - - component = ccomponent.get_component() - assert isinstance(component, Box) - - def test_invalid_event_handler_args(component2, test_state: type[TestState]): """Test that an invalid event handler raises an error. @@ -1758,43 +1691,6 @@ class C2(C1): assert 'renamed_prop3:"prop3_2"' in rendered_c2["props"] -def test_custom_component_get_imports(): - class Inner(Component): - tag = "Inner" - library = "inner" - - @rx.memo - def wrapper(): - return Inner.create() - - @rx.memo - def outer(): - return wrapper() - - custom_comp = wrapper() - - # Inner is not imported directly, but it is imported by the custom component. - assert "inner" not in custom_comp._get_all_imports() - assert "outer" not in custom_comp._get_all_imports() - - # The imports are only resolved during compilation. - custom_comp.get_component() - _, imports_inner = compile_custom_component(custom_comp) - assert "inner" in imports_inner - assert "outer" not in imports_inner - - outer_comp = outer() - - # Nested custom components are only imported during compilation. - assert "inner" not in outer_comp._get_all_imports() - - # The imports are only resolved during compilation. - _, imports_outer = compile_custom_component(outer_comp) - assert "inner" not in imports_outer - assert "$/utils/components" in imports_outer - assert imports_outer["$/utils/components"] == [ImportVar(tag="Wrapper")] - - def test_custom_component_declare_event_handlers_in_fields(): class ReferenceComponent(Component): @classmethod diff --git a/tests/units/components/test_memo.py b/tests/units/components/test_memo.py new file mode 100644 index 00000000000..0881e2ad075 --- /dev/null +++ b/tests/units/components/test_memo.py @@ -0,0 +1,1015 @@ +"""Tests for rx.memo support.""" + +from __future__ import annotations + +import inspect +from types import SimpleNamespace +from typing import Any, cast +from unittest.mock import patch + +import pytest +from reflex_base.components.component import Component +from reflex_base.components.memo import ( + _SPECS, + MEMOS, + MemoComponent, + MemoComponentDefinition, + MemoFunctionDefinition, + MemoParam, + MemoParamKind, + _analyze_params, + _MemoCallBinding, +) +from reflex_base.event import EventChain, EventHandler, no_args_event_spec +from reflex_base.style import Style +from reflex_base.utils import console +from reflex_base.utils import format as format_utils +from reflex_base.utils.imports import ImportVar +from reflex_base.vars import VarData +from reflex_base.vars.base import Var +from reflex_base.vars.function import FunctionVar + +import reflex as rx +from reflex.compiler import compiler +from reflex.compiler import utils as compiler_utils + + +@pytest.fixture(autouse=True) +def _restore_memo_registries(preserve_memo_registries): + """Autouse wrapper around the shared preserve_memo_registries fixture.""" + + +def test_var_returning_memo(): + """Var-returning memos should behave like imported function vars.""" + + @rx.memo + def format_price(amount: rx.Var[int], currency: rx.Var[str]) -> rx.Var[str]: + return currency.to(str) + ": $" + amount.to(str) + + price = Var(_js_expr="price", _var_type=int) + currency = Var(_js_expr="currency", _var_type=str) + + assert ( + str(format_price(amount=price, currency=currency)) + == "(format_price(price, currency))" + ) + assert ( + str(format_price.call(amount=price, currency=currency)) + == "(format_price(price, currency))" + ) + assert isinstance(format_price._as_var(), FunctionVar) + + definition = MEMOS["format_price"] + assert isinstance(definition, MemoFunctionDefinition) + assert ( + str(definition.function) == '((amount, currency) => ((currency+": $")+amount))' + ) + + with pytest.raises(TypeError, match="only accepts keyword props"): + format_price(price, currency) + + +def test_component_returning_memo_with_children_and_rest(): + """Component-returning memos should accept positional children and forwarded props.""" + + @rx.memo + def my_card( + children: rx.Var[rx.Component], + rest: rx.RestProp, + *, + title: rx.Var[str], + ) -> rx.Component: + return rx.box( + rx.heading(title), + children, + rest, + ) + + component = my_card( + rx.text("child 1"), + rx.text("child 2"), + title="Hello", + foo="extra", + class_name="extra", + ) + component_again = my_card(title="World") + + assert isinstance(component, MemoComponent) + assert len(component.children) == 2 + assert component.get_props() == ("title", "foo") + assert type(component) is type(component_again) + assert type(component).tag == "MyCard" + assert type(component).get_fields()["tag"].default == "MyCard" + + rendered = component.render() + assert rendered["name"] == "MyCard" + assert 'title:"Hello"' in rendered["props"] + assert 'foo:"extra"' in rendered["props"] + assert 'className:"extra"' in rendered["props"] + + definition = MEMOS["MyCard"] + assert isinstance(definition, MemoComponentDefinition) + assert any(str(prop) == "rest" for prop in definition.component.special_props) + + files, _ = compiler.compile_memo_components(tuple(MEMOS.values())) + code = "\n".join(c for _, c in files) + assert "export const MyCard = memo(({children, title:title" in code + assert "...rest" in code + assert "jsx(RadixThemesBox,{...rest}" in code + + +def test_component_returning_memo_accepts_component_var_result(): + """Component-returning memos should accept component-typed var results.""" + + @rx.memo + def conditional_slot( + show: rx.Var[bool], + first: rx.Var[rx.Component], + second: rx.Var[rx.Component], + ) -> rx.Var[rx.Component]: + return rx.cond(show, first, second) + + definition = MEMOS["ConditionalSlot"] + assert isinstance(definition, MemoComponentDefinition) + assert definition.component.render() == { + "contents": "(showRxMemo ? firstRxMemo : secondRxMemo)" + } + + files, _ = compiler.compile_memo_components(tuple(MEMOS.values())) + code = "\n".join(c for _, c in files) + assert "export const ConditionalSlot = memo(({show:showRxMemo" in code + assert "(showRxMemo ? firstRxMemo : secondRxMemo)" in code + + +def test_var_returning_memo_with_rest_props(): + """Var-returning memos should capture extra keyword args into RestProp.""" + + @rx.memo + def merge_styles( + base: rx.Var[dict[str, str]], + overrides: rx.RestProp, + ) -> rx.Var[Any]: + return base.to(dict).merge(overrides) + + base = Var(_js_expr="base", _var_type=dict[str, str]) + merged = merge_styles(base=base, color="red", class_name="primary") + + assert "merge_styles" in str(merged) + assert '["base"] : base' in str(merged) + assert '["color"] : "red"' in str(merged) + assert '["className"] : "primary"' in str(merged) + + files, _ = compiler.compile_memo_components(tuple(MEMOS.values())) + code = "\n".join(c for _, c in files) + assert ( + "export const merge_styles = (({base, ...overrides}) => ({...base, ...overrides}));" + in code + ) + + with pytest.raises(TypeError, match="Do not pass `overrides=` directly"): + merge_styles(base=base, overrides={"color": "red"}) + + +def test_component_returning_memo_with_only_rest(): + """Component-returning memos with only RestProp should emit valid JSX (#6443).""" + + @rx.memo + def hover_trigger(rest: rx.RestProp) -> rx.Component: + return rx.text("hover me", rest) + + files, _ = compiler.compile_memo_components(tuple(MEMOS.values())) + code = "\n".join(c for _, c in files) + assert "memo(({...rest})" in code + assert "({," not in code + + +def test_var_returning_memo_with_only_rest(): + """Var-returning memos with only RestProp should emit valid JS (#6443).""" + + @rx.memo + def merge_only(overrides: rx.RestProp) -> rx.Var[Any]: + return overrides + + files, _ = compiler.compile_memo_components(tuple(MEMOS.values())) + code = "\n".join(c for _, c in files) + assert "(({...overrides}) => overrides)" in code + assert "({," not in code + + +def test_var_returning_memo_with_children_and_rest(): + """Var-returning memos should accept positional children plus keyword props.""" + + @rx.memo + def label_slot( + children: rx.Var[rx.Component], + rest: rx.RestProp, + *, + label: rx.Var[str], + ) -> rx.Var[str]: + return label + + rendered = label_slot( + rx.text("child"), + label="Hello", + class_name="slot", + ) + + assert "label_slot" in str(rendered) + assert '["children"]' in str(rendered) + assert '["className"] : "slot"' in str(rendered) + + files, _ = compiler.compile_memo_components(tuple(MEMOS.values())) + code = "\n".join(c for _, c in files) + assert "export const label_slot = (({children, label, ...rest}) => label);" in code + + +def test_memo_requires_var_annotations(): + """Memos should reject non-Var annotations on parameters.""" + with pytest.raises(TypeError, match="must be annotated"): + + @rx.memo + def bad_annotation(value: int) -> rx.Var[str]: + return rx.Var.create("x") + + +def test_memo_warns_on_missing_param_annotation(): + """Unannotated parameters should fall back to ``rx.Var[Any]`` with a warning.""" + with patch.object(console, "deprecate") as mock_deprecate: + + @rx.memo + def soft_missing(value) -> rx.Component: + return rx.text(value.to(str)) + + mock_deprecate.assert_called_once() + kwargs = mock_deprecate.call_args.kwargs + assert "soft_missing" in kwargs["feature_name"] + assert "`value`" in kwargs["reason"] + + +def test_memo_warns_on_missing_return_annotation(): + """A missing return annotation should default to ``rx.Component`` with a warning.""" + with patch.object(console, "deprecate") as mock_deprecate: + + @rx.memo + def soft_return(): + return rx.box() + + mock_deprecate.assert_called_once() + kwargs = mock_deprecate.call_args.kwargs + assert "soft_return" in kwargs["feature_name"] + assert "return annotation" in kwargs["reason"] + + +def test_memo_warning_suggests_inferred_return_type(): + """The warning should surface the inferred public qualname of the body's return.""" + with patch.object(console, "deprecate") as mock_deprecate: + + @rx.memo + def fragment_memo(): + return rx.fragment(rx.text("x")) + + reason = mock_deprecate.call_args.kwargs["reason"] + assert "-> rx.Fragment" in reason + + +def test_memo_warns_once_when_return_and_param_both_missing(): + """A function missing both should emit a single combined warning.""" + with patch.object(console, "deprecate") as mock_deprecate: + + @rx.memo + def soft_both(value): + return rx.text(value.to(str)) + + mock_deprecate.assert_called_once() + reason = mock_deprecate.call_args.kwargs["reason"] + assert "return annotation" in reason + assert "`value`" in reason + + +def test_memo_defaults_children_to_var_component(): + """An unannotated ``children`` parameter must default to ``Var[Component]``. + + ``Var[Any]`` would fail the children-name validation in ``_analyze_params``; + this guards the name-based special case. + """ + with patch.object(console, "deprecate") as mock_deprecate: + + @rx.memo + def soft_children(children) -> rx.Component: + return rx.box(children) + + mock_deprecate.assert_called_once() + + definition = MEMOS["SoftChildren"] + assert isinstance(definition, MemoComponentDefinition) + (children_param,) = definition.params + assert children_param.name == "children" + assert children_param.kind is MemoParamKind.CHILDREN + + +def test_memo_does_not_warn_when_fully_annotated(): + """Fully-annotated memos must not trigger the deprecation fallback.""" + with patch.object(console, "deprecate") as mock_deprecate: + + @rx.memo + def fully_typed(value: rx.Var[str]) -> rx.Component: + return rx.text(value) + + mock_deprecate.assert_not_called() + + +def test_analyze_params_strict_mode_still_raises(): + """Internal callers (``defaulted_params=None``) must keep the strict contract.""" + + def missing_annotation(value) -> rx.Component: + return rx.text("x") + + with pytest.raises(TypeError, match="Missing annotation"): + _analyze_params(missing_annotation, for_component=True) + + +def test_memo_rejects_invalid_children_annotation(): + """Component memos should validate the special children annotation.""" + with pytest.raises(TypeError, match="children"): + + @rx.memo + def bad_children(children: rx.Var[str]) -> rx.Component: + return rx.text(children) + + +def test_memo_rejects_multiple_rest_props(): + """Experimental memos should only allow a single RestProp.""" + with pytest.raises(TypeError, match="only supports one"): + + @rx.memo + def too_many_rest( + first: rx.RestProp, + second: rx.RestProp, + ) -> rx.Var[Any]: + return first + + +def test_memo_rejects_component_and_function_name_collision(): + """Experimental memos should reject same exported name across kinds.""" + + @rx.memo + def foo_bar() -> rx.Component: + return rx.box() + + assert "FooBar" in MEMOS + + with pytest.raises(ValueError, match=r"name collision.*FooBar"): + + @rx.memo + def FooBar() -> rx.Var[str]: + return rx.Var.create("x") + + +def test_memo_rejects_component_export_name_collision(): + """Experimental memos should reject duplicate component export names.""" + + @rx.memo + def foo_bar() -> rx.Component: + return rx.box() + + with pytest.raises(ValueError, match=r"name collision.*FooBar"): + + @rx.memo + def foo__bar() -> rx.Component: + return rx.box() + + +def test_memo_rejects_varargs(): + """Experimental memos should reject *args and **kwargs.""" + with pytest.raises(TypeError, match=r"\*args"): + + @rx.memo + def bad_args(*values: rx.Var[str]) -> rx.Var[str]: + return rx.Var.create("x") + + with pytest.raises(TypeError, match=r"\*\*kwargs"): + + @rx.memo + def bad_kwargs(**values: rx.Var[str]) -> rx.Var[str]: + return rx.Var.create("x") + + +def test_component_memo_rejects_invalid_positional_usage(): + """Component memos should only accept positional children.""" + + @rx.memo + def title_card(*, title: rx.Var[str]) -> rx.Component: + return rx.box(rx.heading(title)) + + with pytest.raises(TypeError, match="only accepts keyword props"): + title_card(rx.text("child")) + + @rx.memo + def child_card( + children: rx.Var[rx.Component], *, title: rx.Var[str] + ) -> rx.Component: + return rx.box(rx.heading(title), children) + + with pytest.raises(TypeError, match="only accepts positional children"): + child_card("not a component", title="Hello") + + +def test_var_memo_rejects_invalid_positional_usage(): + """Var memos should also reserve positional arguments for children only.""" + + @rx.memo + def format_price(amount: rx.Var[int], currency: rx.Var[str]) -> rx.Var[str]: + return currency.to(str) + ": $" + amount.to(str) + + price = Var(_js_expr="price", _var_type=int) + currency = Var(_js_expr="currency", _var_type=str) + + with pytest.raises(TypeError, match="only accepts keyword props"): + format_price(price, currency) + + @rx.memo + def child_label( + children: rx.Var[rx.Component], *, label: rx.Var[str] + ) -> rx.Var[str]: + return label + + with pytest.raises(TypeError, match="only accepts positional children"): + child_label("not a component", label="Hello") + + +def test_var_returning_memo_rejects_hooks(): + """Var-returning memos should reject hook-bearing expressions.""" + with pytest.raises(TypeError, match="cannot depend on hooks"): + + @rx.memo + def bad_hook(value: rx.Var[str]) -> rx.Var[str]: + return Var( + _js_expr="value", + _var_type=str, + _var_data=VarData(hooks={"const badHook = 1": None}), + ) + + +def test_var_returning_memo_rejects_non_bundled_imports(): + """Var-returning memos should reject non-bundled imports.""" + with pytest.raises(TypeError, match="not bundled"): + + @rx.memo + def bad_import(value: rx.Var[str]) -> rx.Var[str]: + return Var( + _js_expr="value", + _var_type=str, + _var_data=VarData(imports={"some-lib": [ImportVar(tag="x")]}), + ) + + +def test_compile_memo_components_includes_functions_and_components(): + """The shared memo output should include both function and component memos.""" + + @rx.memo + def text_wrapper(title: rx.Var[str]) -> rx.Component: + return rx.text(title) + + @rx.memo + def format_price(amount: rx.Var[int], currency: rx.Var[str]) -> rx.Var[str]: + return currency.to(str) + ": $" + amount.to(str) + + @rx.memo + def my_card(children: rx.Var[rx.Component], *, title: rx.Var[str]) -> rx.Component: + return rx.box(rx.heading(title), children) + + files, _ = compiler.compile_memo_components(tuple(MEMOS.values())) + code = "\n".join(c for _, c in files) + + assert "export const TextWrapper = memo(" in code + assert "export const format_price =" in code + assert "export const MyCard = memo(" in code + + +def test_compile_memo_components_extends_imports_without_remerging( + monkeypatch: pytest.MonkeyPatch, +): + """Memo import aggregation should not repeatedly reprocess prior imports.""" + + def noop() -> None: + pass + + memos = tuple( + MemoComponentDefinition( + fn=noop, + python_name=f"memo_{idx}", + params=(), + export_name=f"Memo{idx}", + component=rx.fragment(), + passthrough_hole_child=None, + ) + for idx in range(5) + ) + + def fake_compile_experimental_component_memo( + definition: MemoComponentDefinition, + ) -> tuple[dict[str, str], dict[str, list[ImportVar]]]: + return {"name": definition.export_name}, {} + + def fake_compile_single_memo_component( + component_render: dict[str, str], + component_imports: dict[str, list[ImportVar]], + ) -> tuple[str, dict[str, list[ImportVar]]]: + return ( + f"export const {component_render['name']} = null", + {"shared-lib": [ImportVar(tag=component_render["name"])]}, + ) + + real_merge_imports = compiler_utils.merge_imports + + def reject_growing_merge(*imports): + if len(imports) == 2 and imports[0]: + msg = "aggregate imports should be extended, not remerged" + raise AssertionError(msg) + return real_merge_imports(*imports) + + monkeypatch.setattr( + compiler_utils, + "compile_experimental_component_memo", + fake_compile_experimental_component_memo, + ) + monkeypatch.setattr( + compiler, + "_compile_single_memo_component", + fake_compile_single_memo_component, + ) + monkeypatch.setattr(compiler_utils, "merge_imports", reject_growing_merge) + + files, aggregate_imports = compiler.compile_memo_components(memos) + + assert len(files) == len(memos) + assert [import_var.tag for import_var in aggregate_imports["shared-lib"]] == [ + f"Memo{idx}" for idx in range(5) + ] + + +def test_experimental_component_memo_get_imports(): + """Experimental component memos should resolve imports during compilation.""" + + class Inner(Component): + tag = "Inner" + library = "inner" + + @rx.memo + def wrapper() -> rx.Component: + return Inner.create() + + experimental_component = wrapper() + + assert "inner" not in experimental_component._get_all_imports() + + definition = MEMOS["Wrapper"] + assert isinstance(definition, MemoComponentDefinition) + _, imports = compiler_utils.compile_experimental_component_memo(definition) + assert "inner" in imports + + +def test_compile_experimental_component_memo_does_not_mutate_definition( + monkeypatch: pytest.MonkeyPatch, +): + """Experimental component memo compilation should not mutate stored components.""" + + @rx.memo + def wrapper() -> rx.Component: + return rx.box("hi") + + definition = MEMOS["Wrapper"] + assert isinstance(definition, MemoComponentDefinition) + assert definition.component.style == Style() + + monkeypatch.setattr( + "reflex.utils.prerequisites.get_and_validate_app", + lambda: SimpleNamespace( + app=SimpleNamespace( + style={type(definition.component): Style({"color": "red"})} + ) + ), + ) + + render, _ = compiler_utils.compile_experimental_component_memo(definition) + + assert render["render"]["props"] == ['css:({ ["color"] : "red" })'] + assert definition.component.style == Style() + + +def test_component_returning_memo_is_transparent_for_child_validation(): + """Experimental memo wrappers should not break `_valid_parents` checks.""" + + class ValidParent(Component): + tag = "ValidParent" + library = "valid-parent" + + class RestrictedChild(Component): + tag = "RestrictedChild" + library = "restricted-child" + _valid_parents = ["ValidParent"] + + @rx.memo + def transparent(children: rx.Var[rx.Component]) -> rx.Component: + return children # type: ignore[return-value] + + wrapped_child = transparent(RestrictedChild.create()) + parent = ValidParent.create(wrapped_child) + + assert isinstance(wrapped_child, MemoComponent) + assert parent.children == [wrapped_child] + + +def test_compile_memo_components_includes_experimental_custom_code(): + """Experimental component memos should include custom code in compiled output.""" + + class FooComponent(rx.Fragment): + def add_custom_code(self) -> list[str]: + return [ + "const foo = 'bar'", + ] + + @rx.memo + def foo_component(label: rx.Var[str]) -> rx.Component: + return FooComponent.create(label, rx.Var("foo")) + + files, _ = compiler.compile_memo_components(tuple(MEMOS.values())) + code = "\n".join(c for _, c in files) + + assert "const foo = 'bar'" in code + + +def test_component_memo_accepts_event_handler(): + """Component memos should accept EventHandler params with passthrough specs.""" + + @rx.memo + def eh_memo( + some_value: rx.Var[str], + event: rx.EventHandler[rx.event.passthrough_event_spec(str)], + ) -> rx.Component: + return rx.vstack( + rx.button(some_value, on_click=event(some_value)), + rx.input(on_change=event), + ) + + definition = MEMOS["EhMemo"] + assert isinstance(definition, MemoComponentDefinition) + event_param = next(p for p in definition.params if p.name == "event") + assert event_param.kind is MemoParamKind.EVENT_TRIGGER + assert event_param.kind_data is not None + assert event_param.kind_data is not no_args_event_spec + + +def test_component_memo_accepts_bare_event_handler(): + """Component memos should accept bare EventHandler (no spec) params.""" + + @rx.memo + def bare_eh_memo(event: rx.EventHandler) -> rx.Component: + return rx.button("click", on_click=event()) + + definition = MEMOS["BareEhMemo"] + assert isinstance(definition, MemoComponentDefinition) + event_param = next(p for p in definition.params if p.name == "event") + assert event_param.kind is MemoParamKind.EVENT_TRIGGER + assert event_param.kind_data is no_args_event_spec + + +def test_component_memo_event_handler_compiles_to_prop_callback(): + """`event(value)` and `on_change=event` should compile to the destructured JS prop.""" + + @rx.memo + def eh_compile_memo( + some_value: rx.Var[str], + event: rx.EventHandler[rx.event.passthrough_event_spec(str)], + ) -> rx.Component: + return rx.vstack( + rx.button(some_value, on_click=event(some_value)), + rx.input(on_change=event), + ) + + files, _ = compiler.compile_memo_components(tuple(MEMOS.values())) + code = "\n".join(c for _, c in files) + + # Signature destructures the EH prop with the RxMemo suffix. + assert "event:eventRxMemo" in code + # Partial application: event(some_value) -> eventRxMemo(someValueRxMemo). + assert "eventRxMemo(someValueRxMemo)" in code + # Raw pass-through: on_change=event -> eventRxMemo(...input event arg...). + assert ( + "eventRxMemo(_ev_0)" in code or "eventRxMemo(" in code.split("onChange:", 1)[1] + ) + + +def test_component_memo_event_handler_wires_event_chain_at_call_site(): + """Instantiating an EH memo should wrap the handler in an EventChain trigger.""" + + def _handler_fn(value: str): # pyright: ignore[reportUnusedFunction] + pass + + raw_handler = EventHandler(fn=_handler_fn) + + @rx.memo + def eh_wired_memo( + some_value: rx.Var[str], + event: rx.EventHandler[rx.event.passthrough_event_spec(str)], + ) -> rx.Component: + return rx.button(some_value, on_click=event(some_value)) + + component = eh_wired_memo(some_value="hello", event=raw_handler) + assert isinstance(component, MemoComponent) + # EH props live on event_triggers, not in get_props(). + assert "event" not in component.get_props() + assert "event" in component.event_triggers + assert isinstance(component.event_triggers["event"], EventChain) + + +def test_var_returning_memo_rejects_event_handler(): + """Var-returning memos should reject EventHandler params.""" + with pytest.raises(TypeError, match="component-returning"): + + @rx.memo + def bad_var_eh( + event: rx.EventHandler[rx.event.passthrough_event_spec(str)], + ) -> rx.Var[str]: + return rx.Var.create("x") + + +def test_component_memo_rejects_event_handler_with_default(): + """EH params should not allow defaults (matches old CustomComponent behavior).""" + with pytest.raises(TypeError, match="default"): + + @rx.memo + def bad_eh_default( + event: rx.EventHandler[rx.event.passthrough_event_spec(str)] = None, # pyright: ignore[reportArgumentType] + ) -> rx.Component: + return rx.button("hi") + + +def test_component_memo_rejects_event_handler_named_children(): + """A `children` parameter must not be an EventHandler.""" + with pytest.raises(TypeError, match="children"): + + @rx.memo + def bad_eh_children( + children: rx.EventHandler[rx.event.passthrough_event_spec(str)], + ) -> rx.Component: + return rx.box() + + +# --------------------------------------------------------------------------- +# Interface-level tests: target the _MemoParamSpec Seam directly. +# These exercise per-kind behavior without going through the @rx.memo decorator, +# giving a tight feedback loop for adding new kinds in the future. +# --------------------------------------------------------------------------- + + +def _make_param( + *, + name: str = "x", + kind: MemoParamKind, + annotation: Any = None, + kind_data: Any = None, + placeholder_name: str | None = None, + js_prop_name: str | None = None, +) -> MemoParam: + """Build a MemoParam directly, bypassing _analyze_params. + + Returns: + A populated ``MemoParam`` with sensible defaults for tests. + """ + js = js_prop_name if js_prop_name is not None else format_utils.to_camel_case(name) + return MemoParam( + name=name, + kind=kind, + kind_data=kind_data, + annotation=annotation if annotation is not None else rx.Var[int], + parameter_kind=inspect.Parameter.KEYWORD_ONLY, + js_prop_name=js, + placeholder_name=placeholder_name if placeholder_name is not None else name, + ) + + +def test_classify_routes_each_annotation_to_the_expected_kind(): + """Ordered classification routes each supported annotation to one kind.""" + from reflex_base.components.memo import _classify_parameter + + cases = [ + ("var", rx.Var[int], "x", MemoParamKind.VALUE), + ("rest", rx.RestProp, "rest", MemoParamKind.REST), + ( + "event_with_spec", + rx.EventHandler[rx.event.passthrough_event_spec(str)], + "event", + MemoParamKind.EVENT_TRIGGER, + ), + ("bare_event", rx.EventHandler, "event", MemoParamKind.EVENT_TRIGGER), + ("children_var", rx.Var[rx.Component], "children", MemoParamKind.CHILDREN), + # Var[Component] *not* named children classifies as VALUE — that's the + # path conditional_slot/component-typed slots take in the existing suite. + ("named_x_var_component", rx.Var[rx.Component], "x", MemoParamKind.VALUE), + ] + for case_name, annotation, param_name, expected in cases: + kind, _ = _classify_parameter(annotation, param_name, "test_fn") + assert kind is expected, f"{case_name}: got {kind}, expected {expected}" + + +def test_classify_value_excludes_rest_independent_of_order(): + """The VALUE classifier must reject RestProp even called in isolation. + + ``_CLASSIFICATION_ORDER`` puts REST before VALUE, but the classifier itself + is also self-exclusive so a reordering wouldn't silently regress. + """ + assert _SPECS[MemoParamKind.VALUE].classify(rx.RestProp, "x") == (False, None) + assert _SPECS[MemoParamKind.VALUE].classify(rx.Var[int], "x") == (True, None) + + +def test_children_classifier_requires_named_children(): + """CHILDREN is the only name-sensitive kind; verify it gates on the name.""" + spec = _SPECS[MemoParamKind.CHILDREN] + component_var_annotation = rx.Var[rx.Component] + assert spec.classify(component_var_annotation, "children")[0] is True + assert spec.classify(component_var_annotation, "x")[0] is False + + +def test_value_make_placeholder_returns_typed_var(): + """VALUE kind builds a Var placeholder whose _var_type unwraps the annotation.""" + param = _make_param( + kind=MemoParamKind.VALUE, + annotation=rx.Var[int], + placeholder_name="xRxMemo", + ) + placeholder = param.make_placeholder() + assert isinstance(placeholder, Var) + assert placeholder._js_expr == "xRxMemo" + + +def test_event_trigger_make_placeholder_returns_plain_callable(): + """EVENT_TRIGGER kind builds a plain callable, not an EventHandler. + + The body's `event(value)` call site must compile to the destructured JS + prop name, which requires call_event_fn to actually execute the placeholder. + A synthetic EventHandler(fn=_stub) would bake the Python identifier into + the rendered ReflexEvent instead. + """ + spec = rx.event.passthrough_event_spec(str) + param = _make_param( + name="event", + kind=MemoParamKind.EVENT_TRIGGER, + annotation=rx.EventHandler[spec], + kind_data=spec, + placeholder_name="eventRxMemo", + js_prop_name="event", + ) + placeholder = param.make_placeholder() + assert callable(placeholder) + assert not isinstance(placeholder, EventHandler) + + arg = Var(_js_expr="someValueRxMemo", _var_type=str) + rendered = str(placeholder(arg)) + assert "eventRxMemo" in rendered + assert "someValueRxMemo" in rendered + + +def test_bind_value_routes_to_props(): + """VALUE binding pops the kwarg into binding._props (camelCased).""" + binding = _MemoCallBinding({"my_value": 42, "other": "x"}) + param = _make_param(name="my_value", kind=MemoParamKind.VALUE) + param.bind_call_value(binding) + + assert "my_value" not in binding.raw_kwargs + assert "other" in binding.raw_kwargs # untouched + assert binding._props["myValue"]._js_expr == "42" + assert binding._event_triggers == {} + + +def test_bind_event_trigger_routes_to_event_triggers(): + """EVENT_TRIGGER binding wraps the value in an EventChain on event_triggers.""" + + def _handler(value: str): + pass + + handler = EventHandler(fn=_handler) + spec = rx.event.passthrough_event_spec(str) + binding = _MemoCallBinding({"event": handler}) + param = _make_param( + name="event", + kind=MemoParamKind.EVENT_TRIGGER, + kind_data=spec, + ) + + param.bind_call_value(binding) + assert "event" not in binding.raw_kwargs + assert binding._props == {} + assert isinstance(binding._event_triggers["event"], EventChain) + + +def test_bind_children_and_rest_are_noops_at_the_param_level(): + """CHILDREN comes in positionally; REST is swept by binding.take_rest.""" + binding = _MemoCallBinding({"children": object(), "extra": 1}) + children_param = _make_param(name="children", kind=MemoParamKind.CHILDREN) + rest_param = _make_param(name="rest", kind=MemoParamKind.REST) + + children_param.bind_call_value(binding) + rest_param.bind_call_value(binding) + + # Neither method consumed any kwarg. + assert binding.raw_kwargs == { + "children": binding.raw_kwargs["children"], + "extra": 1, + } + assert binding._props == {} + assert binding._event_triggers == {} + + +def test_take_rest_sweeps_unconsumed_keys_into_camel_cased_dict(): + """binding.take_rest collects every leftover kwarg not on the Component.""" + binding = _MemoCallBinding({"foo_bar": "x", "class_name": "y"}) + rest = binding.take_rest(component_fields={}) + assert set(rest) == {"fooBar", "className"} + assert binding.raw_kwargs == {} + + +@pytest.mark.parametrize( + ("kind", "expected"), + [ + (MemoParamKind.VALUE, "amount:amountRxMemo"), + (MemoParamKind.EVENT_TRIGGER, "amount:amountRxMemo"), + (MemoParamKind.CHILDREN, None), + (MemoParamKind.REST, None), + ], +) +def test_signature_field_for_each_kind(kind: MemoParamKind, expected: str | None): + """VALUE/EVENT_TRIGGER destructure; CHILDREN/REST emit out-of-band.""" + param = _make_param( + name="amount", + kind=kind, + js_prop_name="amount", + placeholder_name="amountRxMemo", + ) + assert param.signature_field() == expected + + +def test_event_trigger_validate_rejects_default_directly(): + """The validate hook on _SPECS[EVENT_TRIGGER] rejects defaults without + going through the decorator. This pins per-kind invariants at the Seam. + """ + parameter = inspect.Parameter( + name="event", + kind=inspect.Parameter.KEYWORD_ONLY, + default=None, + annotation=rx.EventHandler[rx.event.passthrough_event_spec(str)], + ) + with pytest.raises(TypeError, match="default"): + _SPECS[MemoParamKind.EVENT_TRIGGER].validate(parameter, "fn", True) + + +def test_event_trigger_validate_rejects_in_var_returning_memo(): + """EVENT_TRIGGER is only valid on component-returning memos.""" + parameter = inspect.Parameter( + name="event", + kind=inspect.Parameter.KEYWORD_ONLY, + annotation=rx.EventHandler, + ) + with pytest.raises(TypeError, match="component-returning"): + _SPECS[MemoParamKind.EVENT_TRIGGER].validate(parameter, "fn", False) + + +def test_self_referencing_component_memo(): + """Component memos whose body recursively calls themselves should decorate.""" + + @rx.memo + def recursive_box(items: rx.Var[list[int]]) -> rx.Component: + return rx.box( + rx.foreach(items, lambda item: recursive_box(items=items)), + ) + + assert "RecursiveBox" in MEMOS + definition = MEMOS["RecursiveBox"] + assert isinstance(definition, MemoComponentDefinition) + + files, _ = compiler.compile_memo_components(tuple(MEMOS.values())) + body_source = next( + code for path, code in files if path.endswith("RecursiveBox.jsx") + ) + # ``>= 2``: once for the export, once for the recursive foreach call site. + assert body_source.count("RecursiveBox") >= 2 + + instance = recursive_box(items=Var(_js_expr="items", _var_type=list[int])) + assert isinstance(instance, MemoComponent) + assert type(instance).tag == "RecursiveBox" + + +def test_self_referencing_var_memo(): + """Var-returning memos whose body recursively calls themselves should decorate.""" + + @rx.memo + def recursive_count(n: rx.vars.NumberVar[int]) -> rx.Var[int]: + recurse = cast("rx.vars.NumberVar[int]", recursive_count(n=n - 1)) + return cast("rx.Var[int]", rx.cond(n.bool(), n + recurse, 0)) + + definition = MEMOS["recursive_count"] + assert isinstance(definition, MemoFunctionDefinition) + assert "recursive_count" in str(definition.function) + + invoked = recursive_count(n=Var(_js_expr="three", _var_type=int)) + assert "recursive_count" in str(invoked) diff --git a/tests/units/conftest.py b/tests/units/conftest.py index 36baee0ec8e..acc05045bfe 100644 --- a/tests/units/conftest.py +++ b/tests/units/conftest.py @@ -10,14 +10,13 @@ import pytest import pytest_asyncio -from reflex_base.components.component import CUSTOM_COMPONENTS +from reflex_base.components.memo import MEMOS from reflex_base.event import Event, EventSpec from reflex_base.event.context import EventContext from reflex_base.event.processor import BaseStateEventProcessor, EventProcessor from reflex_base.registry import RegistrationContext from reflex.app import App -from reflex.experimental.memo import EXPERIMENTAL_MEMOS from reflex.istate.manager import StateManager from reflex.istate.manager.disk import StateManagerDisk from reflex.istate.manager.memory import StateManagerMemory @@ -491,17 +490,14 @@ def clean_registration_context() -> Generator[RegistrationContext, None, None]: @pytest.fixture def preserve_memo_registries(): - """Save and restore global memo registries around a test. + """Save and restore the global memo registry around a test. Yields: None """ - custom_components = dict(CUSTOM_COMPONENTS) - experimental_memos = dict(EXPERIMENTAL_MEMOS) + memos = dict(MEMOS) try: yield finally: - CUSTOM_COMPONENTS.clear() - CUSTOM_COMPONENTS.update(custom_components) - EXPERIMENTAL_MEMOS.clear() - EXPERIMENTAL_MEMOS.update(experimental_memos) + MEMOS.clear() + MEMOS.update(memos) diff --git a/tests/units/experimental/__init__.py b/tests/units/experimental/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/tests/units/experimental/test_memo.py b/tests/units/experimental/test_memo.py deleted file mode 100644 index efb006d545d..00000000000 --- a/tests/units/experimental/test_memo.py +++ /dev/null @@ -1,542 +0,0 @@ -"""Tests for experimental memo support.""" - -from __future__ import annotations - -from types import SimpleNamespace -from typing import Any - -import pytest -from reflex_base.components.component import CUSTOM_COMPONENTS, Component -from reflex_base.style import Style -from reflex_base.utils.imports import ImportVar -from reflex_base.vars import VarData -from reflex_base.vars.base import Var -from reflex_base.vars.function import FunctionVar - -import reflex as rx -from reflex.compiler import compiler -from reflex.compiler import utils as compiler_utils -from reflex.experimental.memo import ( - EXPERIMENTAL_MEMOS, - ExperimentalMemoComponent, - ExperimentalMemoComponentDefinition, - ExperimentalMemoFunctionDefinition, -) - - -@pytest.fixture(autouse=True) -def _restore_memo_registries(preserve_memo_registries): - """Autouse wrapper around the shared preserve_memo_registries fixture.""" - - -def test_var_returning_memo(): - """Var-returning memos should behave like imported function vars.""" - - @rx._x.memo - def format_price(amount: rx.Var[int], currency: rx.Var[str]) -> rx.Var[str]: - return currency.to(str) + ": $" + amount.to(str) - - price = Var(_js_expr="price", _var_type=int) - currency = Var(_js_expr="currency", _var_type=str) - - assert ( - str(format_price(amount=price, currency=currency)) - == "(format_price(price, currency))" - ) - assert ( - str(format_price.call(amount=price, currency=currency)) - == "(format_price(price, currency))" - ) - assert isinstance(format_price._as_var(), FunctionVar) - - definition = EXPERIMENTAL_MEMOS["format_price"] - assert isinstance(definition, ExperimentalMemoFunctionDefinition) - assert ( - str(definition.function) == '((amount, currency) => ((currency+": $")+amount))' - ) - - with pytest.raises(TypeError, match="only accepts keyword props"): - format_price(price, currency) - - -def test_component_returning_memo_with_children_and_rest(): - """Component-returning memos should accept positional children and forwarded props.""" - - @rx._x.memo - def my_card( - children: rx.Var[rx.Component], - rest: rx.RestProp, - *, - title: rx.Var[str], - ) -> rx.Component: - return rx.box( - rx.heading(title), - children, - rest, - ) - - component = my_card( - rx.text("child 1"), - rx.text("child 2"), - title="Hello", - foo="extra", - class_name="extra", - ) - component_again = my_card(title="World") - - assert isinstance(component, ExperimentalMemoComponent) - assert len(component.children) == 2 - assert component.get_props() == ("title", "foo") - assert type(component) is type(component_again) - assert type(component).tag == "MyCard" - assert type(component).get_fields()["tag"].default == "MyCard" - - rendered = component.render() - assert rendered["name"] == "MyCard" - assert 'title:"Hello"' in rendered["props"] - assert 'foo:"extra"' in rendered["props"] - assert 'className:"extra"' in rendered["props"] - - definition = EXPERIMENTAL_MEMOS["MyCard"] - assert isinstance(definition, ExperimentalMemoComponentDefinition) - assert any(str(prop) == "rest" for prop in definition.component.special_props) - - files, _ = compiler.compile_memo_components((), tuple(EXPERIMENTAL_MEMOS.values())) - code = "\n".join(c for _, c in files) - assert "export const MyCard = memo(({children, title:title" in code - assert "...rest" in code - assert "jsx(RadixThemesBox,{...rest}" in code - - -def test_component_returning_memo_accepts_component_var_result(): - """Component-returning memos should accept component-typed var results.""" - - @rx._x.memo - def conditional_slot( - show: rx.Var[bool], - first: rx.Var[rx.Component], - second: rx.Var[rx.Component], - ) -> rx.Var[rx.Component]: - return rx.cond(show, first, second) - - definition = EXPERIMENTAL_MEMOS["ConditionalSlot"] - assert isinstance(definition, ExperimentalMemoComponentDefinition) - assert definition.component.render() == { - "contents": "(showRxMemo ? firstRxMemo : secondRxMemo)" - } - - files, _ = compiler.compile_memo_components((), tuple(EXPERIMENTAL_MEMOS.values())) - code = "\n".join(c for _, c in files) - assert "export const ConditionalSlot = memo(({show:showRxMemo" in code - assert "(showRxMemo ? firstRxMemo : secondRxMemo)" in code - - -def test_var_returning_memo_with_rest_props(): - """Var-returning memos should capture extra keyword args into RestProp.""" - - @rx._x.memo - def merge_styles( - base: rx.Var[dict[str, str]], - overrides: rx.RestProp, - ) -> rx.Var[Any]: - return base.to(dict).merge(overrides) - - base = Var(_js_expr="base", _var_type=dict[str, str]) - merged = merge_styles(base=base, color="red", class_name="primary") - - assert "merge_styles" in str(merged) - assert '["base"] : base' in str(merged) - assert '["color"] : "red"' in str(merged) - assert '["className"] : "primary"' in str(merged) - - files, _ = compiler.compile_memo_components((), tuple(EXPERIMENTAL_MEMOS.values())) - code = "\n".join(c for _, c in files) - assert ( - "export const merge_styles = (({base, ...overrides}) => ({...base, ...overrides}));" - in code - ) - - with pytest.raises(TypeError, match="Do not pass `overrides=` directly"): - merge_styles(base=base, overrides={"color": "red"}) - - -def test_component_returning_memo_with_only_rest(): - """Component-returning memos with only RestProp should emit valid JSX (#6443).""" - - @rx._x.memo - def hover_trigger(rest: rx.RestProp) -> rx.Component: - return rx.text("hover me", rest) - - files, _ = compiler.compile_memo_components((), tuple(EXPERIMENTAL_MEMOS.values())) - code = "\n".join(c for _, c in files) - assert "memo(({...rest})" in code - assert "({," not in code - - -def test_var_returning_memo_with_only_rest(): - """Var-returning memos with only RestProp should emit valid JS (#6443).""" - - @rx._x.memo - def merge_only(overrides: rx.RestProp) -> rx.Var[Any]: - return overrides - - files, _ = compiler.compile_memo_components((), tuple(EXPERIMENTAL_MEMOS.values())) - code = "\n".join(c for _, c in files) - assert "(({...overrides}) => overrides)" in code - assert "({," not in code - - -def test_var_returning_memo_with_children_and_rest(): - """Var-returning memos should accept positional children plus keyword props.""" - - @rx._x.memo - def label_slot( - children: rx.Var[rx.Component], - rest: rx.RestProp, - *, - label: rx.Var[str], - ) -> rx.Var[str]: - return label - - rendered = label_slot( - rx.text("child"), - label="Hello", - class_name="slot", - ) - - assert "label_slot" in str(rendered) - assert '["children"]' in str(rendered) - assert '["className"] : "slot"' in str(rendered) - - files, _ = compiler.compile_memo_components((), tuple(EXPERIMENTAL_MEMOS.values())) - code = "\n".join(c for _, c in files) - assert "export const label_slot = (({children, label, ...rest}) => label);" in code - - -def test_memo_requires_var_annotations(): - """Experimental memos should require Var annotations on parameters.""" - with pytest.raises(TypeError, match="must be annotated"): - - @rx._x.memo - def bad_annotation(value: int) -> rx.Var[str]: - return rx.Var.create("x") - - with pytest.raises(TypeError, match="Missing annotation"): - - @rx._x.memo - def missing_annotation(value) -> rx.Var[str]: - return rx.Var.create("x") - - -def test_memo_rejects_invalid_children_annotation(): - """Component memos should validate the special children annotation.""" - with pytest.raises(TypeError, match="children"): - - @rx._x.memo - def bad_children(children: rx.Var[str]) -> rx.Component: - return rx.text(children) - - -def test_memo_rejects_multiple_rest_props(): - """Experimental memos should only allow a single RestProp.""" - with pytest.raises(TypeError, match="only supports one"): - - @rx._x.memo - def too_many_rest( - first: rx.RestProp, - second: rx.RestProp, - ) -> rx.Var[Any]: - return first - - -def test_memo_rejects_component_and_function_name_collision(): - """Experimental memos should reject same exported name across kinds.""" - - @rx._x.memo - def foo_bar() -> rx.Component: - return rx.box() - - assert "FooBar" in EXPERIMENTAL_MEMOS - - with pytest.raises(ValueError, match=r"name collision.*FooBar"): - - @rx._x.memo - def FooBar() -> rx.Var[str]: - return rx.Var.create("x") - - -def test_memo_rejects_component_export_name_collision(): - """Experimental memos should reject duplicate component export names.""" - - @rx._x.memo - def foo_bar() -> rx.Component: - return rx.box() - - with pytest.raises(ValueError, match=r"name collision.*FooBar"): - - @rx._x.memo - def foo__bar() -> rx.Component: - return rx.box() - - -def test_memo_rejects_varargs(): - """Experimental memos should reject *args and **kwargs.""" - with pytest.raises(TypeError, match=r"\*args"): - - @rx._x.memo - def bad_args(*values: rx.Var[str]) -> rx.Var[str]: - return rx.Var.create("x") - - with pytest.raises(TypeError, match=r"\*\*kwargs"): - - @rx._x.memo - def bad_kwargs(**values: rx.Var[str]) -> rx.Var[str]: - return rx.Var.create("x") - - -def test_component_memo_rejects_invalid_positional_usage(): - """Component memos should only accept positional children.""" - - @rx._x.memo - def title_card(*, title: rx.Var[str]) -> rx.Component: - return rx.box(rx.heading(title)) - - with pytest.raises(TypeError, match="only accepts keyword props"): - title_card(rx.text("child")) - - @rx._x.memo - def child_card( - children: rx.Var[rx.Component], *, title: rx.Var[str] - ) -> rx.Component: - return rx.box(rx.heading(title), children) - - with pytest.raises(TypeError, match="only accepts positional children"): - child_card("not a component", title="Hello") - - -def test_var_memo_rejects_invalid_positional_usage(): - """Var memos should also reserve positional arguments for children only.""" - - @rx._x.memo - def format_price(amount: rx.Var[int], currency: rx.Var[str]) -> rx.Var[str]: - return currency.to(str) + ": $" + amount.to(str) - - price = Var(_js_expr="price", _var_type=int) - currency = Var(_js_expr="currency", _var_type=str) - - with pytest.raises(TypeError, match="only accepts keyword props"): - format_price(price, currency) - - @rx._x.memo - def child_label( - children: rx.Var[rx.Component], *, label: rx.Var[str] - ) -> rx.Var[str]: - return label - - with pytest.raises(TypeError, match="only accepts positional children"): - child_label("not a component", label="Hello") - - -def test_var_returning_memo_rejects_hooks(): - """Var-returning memos should reject hook-bearing expressions.""" - with pytest.raises(TypeError, match="cannot depend on hooks"): - - @rx._x.memo - def bad_hook(value: rx.Var[str]) -> rx.Var[str]: - return Var( - _js_expr="value", - _var_type=str, - _var_data=VarData(hooks={"const badHook = 1": None}), - ) - - -def test_var_returning_memo_rejects_non_bundled_imports(): - """Var-returning memos should reject non-bundled imports.""" - with pytest.raises(TypeError, match="not bundled"): - - @rx._x.memo - def bad_import(value: rx.Var[str]) -> rx.Var[str]: - return Var( - _js_expr="value", - _var_type=str, - _var_data=VarData(imports={"some-lib": [ImportVar(tag="x")]}), - ) - - -def test_compile_memo_components_includes_experimental_functions_and_components(): - """The shared memo output should include both experimental functions and components.""" - - @rx.memo - def old_wrapper(title: rx.Var[str]) -> rx.Component: - return rx.text(title) - - @rx._x.memo - def format_price(amount: rx.Var[int], currency: rx.Var[str]) -> rx.Var[str]: - return currency.to(str) + ": $" + amount.to(str) - - @rx._x.memo - def my_card(children: rx.Var[rx.Component], *, title: rx.Var[str]) -> rx.Component: - return rx.box(rx.heading(title), children) - - files, _ = compiler.compile_memo_components( - dict.fromkeys(CUSTOM_COMPONENTS.values()), - tuple(EXPERIMENTAL_MEMOS.values()), - ) - code = "\n".join(c for _, c in files) - - assert "export const OldWrapper = memo(" in code - assert "export const format_price =" in code - assert "export const MyCard = memo(" in code - - -def test_compile_memo_components_extends_imports_without_remerging( - monkeypatch: pytest.MonkeyPatch, -): - """Memo import aggregation should not repeatedly reprocess prior imports.""" - - def noop() -> None: - pass - - memos = tuple( - ExperimentalMemoComponentDefinition( - fn=noop, - python_name=f"memo_{idx}", - params=(), - export_name=f"Memo{idx}", - component=rx.fragment(), - passthrough_hole_child=None, - ) - for idx in range(5) - ) - - def fake_compile_experimental_component_memo( - definition: ExperimentalMemoComponentDefinition, - ) -> tuple[dict[str, str], dict[str, list[ImportVar]]]: - return {"name": definition.export_name}, {} - - def fake_compile_single_memo_component( - component_render: dict[str, str], - component_imports: dict[str, list[ImportVar]], - ) -> tuple[str, dict[str, list[ImportVar]]]: - return ( - f"export const {component_render['name']} = null", - {"shared-lib": [ImportVar(tag=component_render["name"])]}, - ) - - real_merge_imports = compiler_utils.merge_imports - - def reject_growing_merge(*imports): - if len(imports) == 2 and imports[0]: - msg = "aggregate imports should be extended, not remerged" - raise AssertionError(msg) - return real_merge_imports(*imports) - - monkeypatch.setattr( - compiler_utils, - "compile_experimental_component_memo", - fake_compile_experimental_component_memo, - ) - monkeypatch.setattr( - compiler, - "_compile_single_memo_component", - fake_compile_single_memo_component, - ) - monkeypatch.setattr(compiler_utils, "merge_imports", reject_growing_merge) - - files, aggregate_imports = compiler.compile_memo_components((), memos) - - assert len(files) == len(memos) + 1 - assert [import_var.tag for import_var in aggregate_imports["shared-lib"]] == [ - f"Memo{idx}" for idx in range(5) - ] - - -def test_experimental_component_memo_get_imports(): - """Experimental component memos should resolve imports during compilation.""" - - class Inner(Component): - tag = "Inner" - library = "inner" - - @rx._x.memo - def wrapper() -> rx.Component: - return Inner.create() - - experimental_component = wrapper() - - assert "inner" not in experimental_component._get_all_imports() - - definition = EXPERIMENTAL_MEMOS["Wrapper"] - assert isinstance(definition, ExperimentalMemoComponentDefinition) - _, imports = compiler_utils.compile_experimental_component_memo(definition) - assert "inner" in imports - - -def test_compile_experimental_component_memo_does_not_mutate_definition( - monkeypatch: pytest.MonkeyPatch, -): - """Experimental component memo compilation should not mutate stored components.""" - - @rx._x.memo - def wrapper() -> rx.Component: - return rx.box("hi") - - definition = EXPERIMENTAL_MEMOS["Wrapper"] - assert isinstance(definition, ExperimentalMemoComponentDefinition) - assert definition.component.style == Style() - - monkeypatch.setattr( - "reflex.utils.prerequisites.get_and_validate_app", - lambda: SimpleNamespace( - app=SimpleNamespace( - style={type(definition.component): Style({"color": "red"})} - ) - ), - ) - - render, _ = compiler_utils.compile_experimental_component_memo(definition) - - assert render["render"]["props"] == ['css:({ ["color"] : "red" })'] - assert definition.component.style == Style() - - -def test_component_returning_memo_is_transparent_for_child_validation(): - """Experimental memo wrappers should not break `_valid_parents` checks.""" - - class ValidParent(Component): - tag = "ValidParent" - library = "valid-parent" - - class RestrictedChild(Component): - tag = "RestrictedChild" - library = "restricted-child" - _valid_parents = ["ValidParent"] - - @rx._x.memo - def transparent(children: rx.Var[rx.Component]) -> rx.Component: - return children # type: ignore[return-value] - - wrapped_child = transparent(RestrictedChild.create()) - parent = ValidParent.create(wrapped_child) - - assert isinstance(wrapped_child, ExperimentalMemoComponent) - assert parent.children == [wrapped_child] - - -def test_compile_memo_components_includes_experimental_custom_code(): - """Experimental component memos should include custom code in compiled output.""" - - class FooComponent(rx.Fragment): - def add_custom_code(self) -> list[str]: - return [ - "const foo = 'bar'", - ] - - @rx._x.memo - def foo_component(label: rx.Var[str]) -> rx.Component: - return FooComponent.create(label, rx.Var("foo")) - - files, _ = compiler.compile_memo_components((), tuple(EXPERIMENTAL_MEMOS.values())) - code = "\n".join(c for _, c in files) - - assert "const foo = 'bar'" in code diff --git a/tests/units/test_app.py b/tests/units/test_app.py index 887f6a3d8b4..2461ce18afc 100644 --- a/tests/units/test_app.py +++ b/tests/units/test_app.py @@ -38,7 +38,8 @@ import reflex as rx from reflex import AdminDash, constants -from reflex.app import App, ComponentCallable, upload +from reflex._upload import upload +from reflex.app import App, ComponentCallable from reflex.environment import environment from reflex.istate.data import RouterData from reflex.istate.manager.disk import StateManagerDisk @@ -2232,7 +2233,7 @@ def test_compile_writes_app_wrap_memo_components( compilable_app: tuple[App, Path], mocker, ) -> None: - """App-wrap memo components are emitted to the shared components module.""" + """App-wrap memo components are emitted as per-memo modules.""" conf = rx.Config(app_name="testing") mocker.patch("reflex_base.config._get_config", return_value=conf) app, web_dir = compilable_app @@ -2240,19 +2241,9 @@ def test_compile_writes_app_wrap_memo_components( app.add_page(rx.box("Index"), route="/") app._compile() - components_index = ( - web_dir - / constants.Dirs.UTILS - / f"{constants.PageNames.COMPONENTS}{constants.Ext.JSX}" - ).read_text() - - # Per-memo modules live under .web/utils/components/; the index re-exports - # each one so page-side ``$/utils/components`` resolves the same tags. - assert "DefaultOverlayComponents" in components_index - assert "MemoizedToastProvider" in components_index - assert 'from "./components/DefaultOverlayComponents"' in components_index - assert 'from "./components/MemoizedToastProvider"' in components_index - + # Per-memo modules live under .web/utils/components/; each memo wrapper + # declares its ``library`` as the per-memo file path, so pages import it + # directly. memo_dir = web_dir / constants.Dirs.UTILS / constants.PageNames.COMPONENTS assert (memo_dir / f"DefaultOverlayComponents{constants.Ext.JSX}").exists() assert (memo_dir / f"MemoizedToastProvider{constants.Ext.JSX}").exists() diff --git a/tests/units/test_testing.py b/tests/units/test_testing.py index e1f6c20576b..82a40a2dd26 100644 --- a/tests/units/test_testing.py +++ b/tests/units/test_testing.py @@ -6,13 +6,12 @@ import pytest import reflex_base.config -from reflex_base.components.component import CUSTOM_COMPONENTS +from reflex_base.components.memo import MEMOS from reflex_base.constants import IS_WINDOWS import reflex.reflex as reflex_cli import reflex.testing as reflex_testing import reflex.utils.prerequisites -from reflex.experimental.memo import EXPERIMENTAL_MEMOS from reflex.testing import AppHarness @@ -88,7 +87,7 @@ def harness_mocks(monkeypatch): def test_app_harness_initialize_clears_memo_registries( tmp_path, preserve_memo_registries, harness_mocks, monkeypatch ): - """Ensure app initialization clears leaked memo registries. + """Ensure app initialization clears the leaked memo registry. Args: tmp_path: pytest tmp_path fixture @@ -98,8 +97,7 @@ def test_app_harness_initialize_clears_memo_registries( """ monkeypatch.setattr(reflex_cli, "_init", lambda **kwargs: None) - CUSTOM_COMPONENTS["FooComponent"] = mock.sentinel.component - EXPERIMENTAL_MEMOS["format_value"] = mock.sentinel.memo + MEMOS["format_value"] = mock.sentinel.memo harness = AppHarness.create( root=tmp_path / "memo_app", @@ -109,8 +107,7 @@ def test_app_harness_initialize_clears_memo_registries( harness.app_module_path.parent.mkdir(parents=True, exist_ok=True) harness._initialize_app() - assert "FooComponent" not in CUSTOM_COMPONENTS - assert "format_value" not in EXPERIMENTAL_MEMOS + assert "format_value" not in MEMOS harness_mocks.get_and_validate_app.assert_called_once_with(reload=True) diff --git a/tests/units/utils/test_telemetry_accounting.py b/tests/units/utils/test_telemetry_accounting.py index 83103d7f7fb..7fcd996d055 100644 --- a/tests/units/utils/test_telemetry_accounting.py +++ b/tests/units/utils/test_telemetry_accounting.py @@ -221,7 +221,7 @@ def test_memo_wrapper_class_records_wrapped_component_type(): memo_module = importlib.import_module("reflex.experimental.memo") - wrapper_cls = memo_module._get_experimental_memo_component_class( + wrapper_cls = memo_module._get_memo_component_class( "Button_button_deadbeefcafebabe", Button, )