Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 27 additions & 13 deletions marimo/_server/api/endpoints/assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@
inject_script,
notebook_page_template,
)
from marimo._server.workspace import (
FileKey,
parse_file_key,
serialize_file_key,
)
from marimo._session.model import SessionMode
from marimo._utils.async_path import AsyncPath
from marimo._utils.paths import (
Expand Down Expand Up @@ -192,11 +197,14 @@ def og_thumbnail(*, request: Request) -> Response:
from marimo._utils.paths import normalize_path

app_state = AppState(request)
file_key = (
app_state.query_params(FILE_QUERY_PARAM_KEY)
or app_state.session_manager.workspace.get_unique_file_key()
raw_file_key = app_state.query_params(FILE_QUERY_PARAM_KEY)
# Empty ``?file=`` falls back to the workspace key — same as missing.
file_key: FileKey | None = (
parse_file_key(raw_file_key)
if raw_file_key
else app_state.session_manager.workspace.get_unique_file_key()
)
if not file_key:
if file_key is None:
raise HTTPException(
status_code=HTTPStatus.NOT_FOUND, detail="File not found"
)
Expand All @@ -216,7 +224,7 @@ def og_thumbnail(*, request: Request) -> Response:
notebook_path,
context=OpenGraphContext(
filepath=notebook_path,
file_key=file_key,
file_key=serialize_file_key(file_key),
base_url=app_state.base_url,
mode=app_state.mode.value,
),
Expand Down Expand Up @@ -309,9 +317,12 @@ async def index(request: Request) -> Response:
index_html = root / "index.html"

file_key_from_query = app_state.query_params(FILE_QUERY_PARAM_KEY)
file_key = (
file_key_from_query
or app_state.session_manager.workspace.get_unique_file_key()
# Empty ``?file=`` falls back to the workspace key — same as missing —
# which preserves the homepage rendering when no file is selected.
file_key: FileKey | None = (
parse_file_key(file_key_from_query)
if file_key_from_query
else app_state.session_manager.workspace.get_unique_file_key()
)

# Try local index.html first, fallback to asset_url if local file doesn't exist
Expand All @@ -329,7 +340,7 @@ async def index(request: Request) -> Response:
detail=_missing_index_html_detail(),
)

if not file_key:
if file_key is None:
# We don't know which file to use, so we need to render a homepage
LOGGER.debug("No file key provided, serving homepage")
html = home_page_template(
Expand All @@ -342,10 +353,11 @@ async def index(request: Request) -> Response:
asset_url=app_state.asset_url,
)
else:
config_manager = app_state.config_manager_at_file(file_key)
serialized_file_key = serialize_file_key(file_key)
config_manager = app_state.config_manager_at_file(serialized_file_key)

# We have a file key, so we can render the app with the file
LOGGER.debug(f"File key provided: {file_key}")
LOGGER.debug(f"File key provided: {serialized_file_key}")
app_manager = app_state.session_manager.app_manager(file_key)
app_config = app_manager.app.config
absolute_filepath = app_manager.filename
Expand Down Expand Up @@ -402,7 +414,7 @@ async def index(request: Request) -> Response:
)

# Inject service worker registration with the notebook ID
html = _inject_service_worker(html, file_key)
html = _inject_service_worker(html, serialized_file_key)

return HTMLResponse(html, headers=_HTML_SECURITY_HEADERS)

Expand Down Expand Up @@ -614,7 +626,9 @@ async def serve_public_file(request: Request) -> Response:
if notebook_id:
# Decode notebook ID
notebook_id = uri_decode_component(notebook_id)
app_manager = app_state.session_manager.app_manager(notebook_id)
app_manager = app_state.session_manager.app_manager(
parse_file_key(notebook_id)
)
if app_manager.filename:
notebook_dir = Path(app_manager.filename).parent
else:
Expand Down
31 changes: 21 additions & 10 deletions marimo/_server/api/endpoints/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@
)
from marimo._server.router import APIRouter
from marimo._server.uvicorn_utils import close_uvicorn
from marimo._server.workspace import MarimoFileKey
from marimo._server.workspace import (
FileKey,
PathFileKey,
parse_file_key,
)
from marimo._types.ids import ConsumerId

if TYPE_CHECKING:
Expand Down Expand Up @@ -423,12 +427,16 @@ async def restart_session(
session = app_state.require_current_session()
session_manager.close_session(session_id)

# Close RTC doc if it exists
file_key: MarimoFileKey | None = (
app_state.query_params(FILE_QUERY_PARAM_KEY)
or session_manager.workspace.get_unique_file_key()
or session.app_file_manager.path
)
# Close RTC doc if it exists. Empty ``?file=`` falls back to the workspace
# key — same as a missing query param — to preserve the prior or-chain.
raw_file_key = app_state.query_params(FILE_QUERY_PARAM_KEY)
file_key: FileKey | None
if raw_file_key:
file_key = parse_file_key(raw_file_key)
else:
file_key = session_manager.workspace.get_unique_file_key()
if file_key is None and session.app_file_manager.path is not None:
file_key = PathFileKey(session.app_file_manager.path)
if file_key is not None:
Comment thread
mscolnick marked this conversation as resolved.
await DOC_MANAGER.remove_doc(file_key)
else:
Expand Down Expand Up @@ -514,9 +522,12 @@ async def takeover_endpoint(
"""
app_state = AppState(request)

file_key: MarimoFileKey | None = (
app_state.query_params(FILE_QUERY_PARAM_KEY)
or app_state.session_manager.workspace.get_unique_file_key()
raw_file_key = app_state.query_params(FILE_QUERY_PARAM_KEY)
# Empty ``?file=`` falls back to the workspace key — same as missing.
file_key: FileKey | None = (
parse_file_key(raw_file_key)
if raw_file_key
else app_state.session_manager.workspace.get_unique_file_key()
)
if file_key is None:
LOGGER.error("No file key provided")
Expand Down
3 changes: 2 additions & 1 deletion marimo/_server/api/endpoints/health.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from marimo._dependencies.dependencies import DependencyManager
from marimo._server.api.deps import AppState
from marimo._server.router import APIRouter
from marimo._server.workspace import NEW_FILE_WIRE
from marimo._utils.health import (
get_cgroup_cpu_percent,
get_cgroup_mem_stats,
Expand Down Expand Up @@ -74,7 +75,7 @@ async def status(request: Request) -> JSONResponse:
"""
app_state = AppState(request)
files = [
session.app_file_manager.filename or "__new__"
session.app_file_manager.filename or NEW_FILE_WIRE
for session in app_state.session_manager.sessions.values()
]
return JSONResponse(
Expand Down
3 changes: 2 additions & 1 deletion marimo/_server/api/endpoints/home.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
)
from marimo._server.router import APIRouter
from marimo._server.workspace import (
PathFileKey,
count_files,
flatten_files,
)
Expand Down Expand Up @@ -116,7 +117,7 @@ def get_files_with_metadata() -> list[FileInfo]:
for file in marimo_files:
try:
resolved_path = session_manager.workspace.resolve(
file.path
PathFileKey(file.path)
)
except HTTPException as e:
if e.status_code == HTTPStatus.NOT_FOUND:
Expand Down
37 changes: 24 additions & 13 deletions marimo/_server/api/endpoints/ws/ws_connection_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@
from marimo._server.api.auth import validate_auth
from marimo._server.api.deps import AppState
from marimo._server.codes import WebSocketCodes
from marimo._server.workspace import MarimoFileKey
from marimo._server.workspace import (
FileKey,
parse_file_key,
serialize_file_key,
)
from marimo._types.ids import SessionId

LOGGER = _loggers.marimo_logger()
Expand All @@ -26,7 +30,7 @@ class ConnectionParams:
"""Parameters extracted from WebSocket connection request."""

session_id: SessionId
file_key: MarimoFileKey
file_key: FileKey
kiosk: bool
auto_instantiate: bool
rtc_enabled: bool
Expand Down Expand Up @@ -71,10 +75,7 @@ async def extract_connection_params(
session_id = SessionId(raw_session_id)

# Extract file_key
file_key: MarimoFileKey | None = (
self.app_state.query_params(FILE_QUERY_PARAM_KEY)
or self.app_state.session_manager.workspace.get_unique_file_key()
)
file_key = self._extract_file_key()

if file_key is None:
await self.websocket.close(
Expand All @@ -86,7 +87,9 @@ async def extract_connection_params(
kiosk = self.app_state.query_params(KIOSK_QUERY_PARAM_KEY) == "true"

# Extract config-based parameters
config = self.app_state.config_manager_at_file(file_key).get_config()
config = self.app_state.config_manager_at_file(
serialize_file_key(file_key)
).get_config()
rtc_enabled = config.get("experimental", {}).get("rtc_v2", False)
auto_instantiate = config["runtime"]["auto_instantiate"]

Expand All @@ -98,16 +101,13 @@ async def extract_connection_params(
rtc_enabled=rtc_enabled,
)

async def extract_file_key_only(self) -> MarimoFileKey | None:
async def extract_file_key_only(self) -> FileKey | None:
"""Extract only the file_key parameter (for RTC endpoint).

Returns:
MarimoFileKey if valid, None otherwise.
FileKey if valid, None otherwise.
"""
file_key: MarimoFileKey | None = (
self.app_state.query_params(FILE_QUERY_PARAM_KEY)
or self.app_state.session_manager.workspace.get_unique_file_key()
)
file_key = self._extract_file_key()

if file_key is None:
LOGGER.warning("RTC: Closing websocket - no file key")
Expand All @@ -117,3 +117,14 @@ async def extract_file_key_only(self) -> MarimoFileKey | None:
return None

return file_key

def _extract_file_key(self) -> FileKey | None:
"""Extract a FileKey from query params or fall back to the workspace.

An empty ``?file=`` value falls back to the workspace key — same as a
missing query param — to preserve the prior ``or``-chain semantics.
"""
raw = self.app_state.query_params(FILE_QUERY_PARAM_KEY)
if raw:
return parse_file_key(raw)
return self.app_state.session_manager.workspace.get_unique_file_key()
6 changes: 3 additions & 3 deletions marimo/_server/api/endpoints/ws/ws_kernel_ready.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
if TYPE_CHECKING:
from marimo._server.rtc.doc import LoroDocManager
from marimo._server.session_manager import SessionManager
from marimo._server.workspace import MarimoFileKey
from marimo._server.workspace import FileKey
from marimo._session import Session

LOGGER = _loggers.marimo_logger()
Expand All @@ -40,7 +40,7 @@ def build_kernel_ready(
last_execution_time: dict[CellId_t, float],
kiosk: bool,
rtc_enabled: bool,
file_key: MarimoFileKey,
file_key: FileKey,
mode: SessionMode,
doc_manager: LoroDocManager,
auto_instantiated: bool = False,
Expand Down Expand Up @@ -167,7 +167,7 @@ def _should_init_rtc(rtc_enabled: bool, mode: SessionMode) -> bool:
def _try_init_rtc_doc(
cell_ids: tuple[CellId_t, ...],
codes: tuple[str, ...],
file_key: MarimoFileKey,
file_key: FileKey,
doc_manager: LoroDocManager,
) -> None:
"""Try to initialize RTC document with cell data.
Expand Down
4 changes: 2 additions & 2 deletions marimo/_server/api/endpoints/ws/ws_rtc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from starlette.websockets import WebSocket

from marimo._server.rtc.doc import LoroDocManager
from marimo._server.workspace import MarimoFileKey
from marimo._server.workspace import FileKey

LOGGER = _loggers.marimo_logger()

Expand All @@ -24,7 +24,7 @@ class RTCWebSocketHandler:
def __init__(
self,
websocket: WebSocket,
file_key: MarimoFileKey,
file_key: FileKey,
doc_manager: LoroDocManager,
):
self.websocket = websocket
Expand Down
4 changes: 2 additions & 2 deletions marimo/_server/api/lifespans.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from marimo._server.tokens import AuthToken
from marimo._server.utils import initialize_mimetypes
from marimo._server.uvicorn_utils import close_uvicorn
from marimo._server.workspace import NEW_FILE
from marimo._server.workspace import NewFileKey
from marimo._session.model import SessionMode
from marimo._utils.subprocess import cancel_pending_reaps

Expand Down Expand Up @@ -170,7 +170,7 @@ async def logging(app: Starlette) -> AsyncIterator[None]:
file_name=file.name if file else None,
url=_startup_url(state),
run=manager.mode == SessionMode.RUN,
new=workspace.get_unique_file_key() == NEW_FILE,
new=isinstance(workspace.get_unique_file_key(), NewFileKey),
network=state.host == "0.0.0.0",
startup_tip=state.startup_tip,
)
Expand Down
18 changes: 11 additions & 7 deletions marimo/_server/resume_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
LOGGER = _loggers.marimo_logger()

if TYPE_CHECKING:
from marimo._server.workspace import MarimoFileKey
from marimo._server.workspace import FileKey


class SessionResumeStrategy(Protocol):
Expand All @@ -26,7 +26,7 @@ class SessionResumeStrategy(Protocol):
def try_resume(
self,
new_session_id: SessionId,
file_key: MarimoFileKey,
file_key: FileKey,
) -> Session | None:
"""Try to resume a session.

Expand All @@ -53,18 +53,22 @@ def __init__(self, repository: SessionRepository) -> None:
def try_resume(
self,
new_session_id: SessionId,
file_key: MarimoFileKey,
file_key: FileKey,
) -> Session | None:
"""Try to resume an orphaned session for the same file."""
import os

from marimo._server.workspace import NewFileKey

if isinstance(file_key, NewFileKey):
return None

abs_path = os.path.abspath(file_key.path)
# Find sessions with the same file
sessions_with_file = []
for session in self._repository.get_all():
session_id = self._repository.get_session_id(session)
if session_id and session.app_file_manager.path == os.path.abspath(
file_key
):
if session_id and session.app_file_manager.path == abs_path:
sessions_with_file.append((session_id, session))

if len(sessions_with_file) == 0:
Expand Down Expand Up @@ -108,7 +112,7 @@ def __init__(self, repository: SessionRepository) -> None:
def try_resume(
self,
new_session_id: SessionId,
file_key: MarimoFileKey, # noqa: ARG002
file_key: FileKey, # noqa: ARG002
) -> Session | None:
"""Try to resume a session with matching ID if orphaned."""
session = self._repository.get_sync(new_session_id)
Expand Down
Loading
Loading