diff --git a/marimo/_server/api/endpoints/assets.py b/marimo/_server/api/endpoints/assets.py index 793da1ca3f8..5a8f016f41a 100644 --- a/marimo/_server/api/endpoints/assets.py +++ b/marimo/_server/api/endpoints/assets.py @@ -37,6 +37,11 @@ inject_script, notebook_page_template, ) +from marimo._server.workspace import ( + FileKey, + parse_file_key, + serialize_file_key, +) from marimo._session.model import SessionMode from marimo._utils.async_path import AsyncPath from marimo._utils.paths import ( @@ -192,11 +197,14 @@ def og_thumbnail(*, request: Request) -> Response: from marimo._utils.paths import normalize_path app_state = AppState(request) - file_key = ( - app_state.query_params(FILE_QUERY_PARAM_KEY) - or app_state.session_manager.workspace.get_unique_file_key() + raw_file_key = app_state.query_params(FILE_QUERY_PARAM_KEY) + # Empty ``?file=`` falls back to the workspace key — same as missing. + file_key: FileKey | None = ( + parse_file_key(raw_file_key) + if raw_file_key + else app_state.session_manager.workspace.get_unique_file_key() ) - if not file_key: + if file_key is None: raise HTTPException( status_code=HTTPStatus.NOT_FOUND, detail="File not found" ) @@ -216,7 +224,7 @@ def og_thumbnail(*, request: Request) -> Response: notebook_path, context=OpenGraphContext( filepath=notebook_path, - file_key=file_key, + file_key=serialize_file_key(file_key), base_url=app_state.base_url, mode=app_state.mode.value, ), @@ -309,9 +317,12 @@ async def index(request: Request) -> Response: index_html = root / "index.html" file_key_from_query = app_state.query_params(FILE_QUERY_PARAM_KEY) - file_key = ( - file_key_from_query - or app_state.session_manager.workspace.get_unique_file_key() + # Empty ``?file=`` falls back to the workspace key — same as missing — + # which preserves the homepage rendering when no file is selected. + file_key: FileKey | None = ( + parse_file_key(file_key_from_query) + if file_key_from_query + else app_state.session_manager.workspace.get_unique_file_key() ) # Try local index.html first, fallback to asset_url if local file doesn't exist @@ -329,7 +340,7 @@ async def index(request: Request) -> Response: detail=_missing_index_html_detail(), ) - if not file_key: + if file_key is None: # We don't know which file to use, so we need to render a homepage LOGGER.debug("No file key provided, serving homepage") html = home_page_template( @@ -342,10 +353,11 @@ async def index(request: Request) -> Response: asset_url=app_state.asset_url, ) else: - config_manager = app_state.config_manager_at_file(file_key) + serialized_file_key = serialize_file_key(file_key) + config_manager = app_state.config_manager_at_file(serialized_file_key) # We have a file key, so we can render the app with the file - LOGGER.debug(f"File key provided: {file_key}") + LOGGER.debug(f"File key provided: {serialized_file_key}") app_manager = app_state.session_manager.app_manager(file_key) app_config = app_manager.app.config absolute_filepath = app_manager.filename @@ -402,7 +414,7 @@ async def index(request: Request) -> Response: ) # Inject service worker registration with the notebook ID - html = _inject_service_worker(html, file_key) + html = _inject_service_worker(html, serialized_file_key) return HTMLResponse(html, headers=_HTML_SECURITY_HEADERS) @@ -614,7 +626,9 @@ async def serve_public_file(request: Request) -> Response: if notebook_id: # Decode notebook ID notebook_id = uri_decode_component(notebook_id) - app_manager = app_state.session_manager.app_manager(notebook_id) + app_manager = app_state.session_manager.app_manager( + parse_file_key(notebook_id) + ) if app_manager.filename: notebook_dir = Path(app_manager.filename).parent else: diff --git a/marimo/_server/api/endpoints/execution.py b/marimo/_server/api/endpoints/execution.py index 1b94f81fa32..c77b2f32628 100644 --- a/marimo/_server/api/endpoints/execution.py +++ b/marimo/_server/api/endpoints/execution.py @@ -31,7 +31,11 @@ ) from marimo._server.router import APIRouter from marimo._server.uvicorn_utils import close_uvicorn -from marimo._server.workspace import MarimoFileKey +from marimo._server.workspace import ( + FileKey, + PathFileKey, + parse_file_key, +) from marimo._types.ids import ConsumerId if TYPE_CHECKING: @@ -423,12 +427,16 @@ async def restart_session( session = app_state.require_current_session() session_manager.close_session(session_id) - # Close RTC doc if it exists - file_key: MarimoFileKey | None = ( - app_state.query_params(FILE_QUERY_PARAM_KEY) - or session_manager.workspace.get_unique_file_key() - or session.app_file_manager.path - ) + # Close RTC doc if it exists. Empty ``?file=`` falls back to the workspace + # key — same as a missing query param — to preserve the prior or-chain. + raw_file_key = app_state.query_params(FILE_QUERY_PARAM_KEY) + file_key: FileKey | None + if raw_file_key: + file_key = parse_file_key(raw_file_key) + else: + file_key = session_manager.workspace.get_unique_file_key() + if file_key is None and session.app_file_manager.path is not None: + file_key = PathFileKey(session.app_file_manager.path) if file_key is not None: await DOC_MANAGER.remove_doc(file_key) else: @@ -514,9 +522,12 @@ async def takeover_endpoint( """ app_state = AppState(request) - file_key: MarimoFileKey | None = ( - app_state.query_params(FILE_QUERY_PARAM_KEY) - or app_state.session_manager.workspace.get_unique_file_key() + raw_file_key = app_state.query_params(FILE_QUERY_PARAM_KEY) + # Empty ``?file=`` falls back to the workspace key — same as missing. + file_key: FileKey | None = ( + parse_file_key(raw_file_key) + if raw_file_key + else app_state.session_manager.workspace.get_unique_file_key() ) if file_key is None: LOGGER.error("No file key provided") diff --git a/marimo/_server/api/endpoints/health.py b/marimo/_server/api/endpoints/health.py index 702a4b19cd5..0bc59e63d52 100644 --- a/marimo/_server/api/endpoints/health.py +++ b/marimo/_server/api/endpoints/health.py @@ -11,6 +11,7 @@ from marimo._dependencies.dependencies import DependencyManager from marimo._server.api.deps import AppState from marimo._server.router import APIRouter +from marimo._server.workspace import NEW_FILE_WIRE from marimo._utils.health import ( get_cgroup_cpu_percent, get_cgroup_mem_stats, @@ -74,7 +75,7 @@ async def status(request: Request) -> JSONResponse: """ app_state = AppState(request) files = [ - session.app_file_manager.filename or "__new__" + session.app_file_manager.filename or NEW_FILE_WIRE for session in app_state.session_manager.sessions.values() ] return JSONResponse( diff --git a/marimo/_server/api/endpoints/home.py b/marimo/_server/api/endpoints/home.py index a16b6603135..091936b39e5 100644 --- a/marimo/_server/api/endpoints/home.py +++ b/marimo/_server/api/endpoints/home.py @@ -25,6 +25,7 @@ ) from marimo._server.router import APIRouter from marimo._server.workspace import ( + PathFileKey, count_files, flatten_files, ) @@ -116,7 +117,7 @@ def get_files_with_metadata() -> list[FileInfo]: for file in marimo_files: try: resolved_path = session_manager.workspace.resolve( - file.path + PathFileKey(file.path) ) except HTTPException as e: if e.status_code == HTTPStatus.NOT_FOUND: diff --git a/marimo/_server/api/endpoints/ws/ws_connection_validator.py b/marimo/_server/api/endpoints/ws/ws_connection_validator.py index 21f65a3d038..b36bc1b4da0 100644 --- a/marimo/_server/api/endpoints/ws/ws_connection_validator.py +++ b/marimo/_server/api/endpoints/ws/ws_connection_validator.py @@ -11,7 +11,11 @@ from marimo._server.api.auth import validate_auth from marimo._server.api.deps import AppState from marimo._server.codes import WebSocketCodes -from marimo._server.workspace import MarimoFileKey +from marimo._server.workspace import ( + FileKey, + parse_file_key, + serialize_file_key, +) from marimo._types.ids import SessionId LOGGER = _loggers.marimo_logger() @@ -26,7 +30,7 @@ class ConnectionParams: """Parameters extracted from WebSocket connection request.""" session_id: SessionId - file_key: MarimoFileKey + file_key: FileKey kiosk: bool auto_instantiate: bool rtc_enabled: bool @@ -71,10 +75,7 @@ async def extract_connection_params( session_id = SessionId(raw_session_id) # Extract file_key - file_key: MarimoFileKey | None = ( - self.app_state.query_params(FILE_QUERY_PARAM_KEY) - or self.app_state.session_manager.workspace.get_unique_file_key() - ) + file_key = self._extract_file_key() if file_key is None: await self.websocket.close( @@ -86,7 +87,9 @@ async def extract_connection_params( kiosk = self.app_state.query_params(KIOSK_QUERY_PARAM_KEY) == "true" # Extract config-based parameters - config = self.app_state.config_manager_at_file(file_key).get_config() + config = self.app_state.config_manager_at_file( + serialize_file_key(file_key) + ).get_config() rtc_enabled = config.get("experimental", {}).get("rtc_v2", False) auto_instantiate = config["runtime"]["auto_instantiate"] @@ -98,16 +101,13 @@ async def extract_connection_params( rtc_enabled=rtc_enabled, ) - async def extract_file_key_only(self) -> MarimoFileKey | None: + async def extract_file_key_only(self) -> FileKey | None: """Extract only the file_key parameter (for RTC endpoint). Returns: - MarimoFileKey if valid, None otherwise. + FileKey if valid, None otherwise. """ - file_key: MarimoFileKey | None = ( - self.app_state.query_params(FILE_QUERY_PARAM_KEY) - or self.app_state.session_manager.workspace.get_unique_file_key() - ) + file_key = self._extract_file_key() if file_key is None: LOGGER.warning("RTC: Closing websocket - no file key") @@ -117,3 +117,14 @@ async def extract_file_key_only(self) -> MarimoFileKey | None: return None return file_key + + def _extract_file_key(self) -> FileKey | None: + """Extract a FileKey from query params or fall back to the workspace. + + An empty ``?file=`` value falls back to the workspace key — same as a + missing query param — to preserve the prior ``or``-chain semantics. + """ + raw = self.app_state.query_params(FILE_QUERY_PARAM_KEY) + if raw: + return parse_file_key(raw) + return self.app_state.session_manager.workspace.get_unique_file_key() diff --git a/marimo/_server/api/endpoints/ws/ws_kernel_ready.py b/marimo/_server/api/endpoints/ws/ws_kernel_ready.py index f3bdea78cfe..2f2cbe23845 100644 --- a/marimo/_server/api/endpoints/ws/ws_kernel_ready.py +++ b/marimo/_server/api/endpoints/ws/ws_kernel_ready.py @@ -20,7 +20,7 @@ if TYPE_CHECKING: from marimo._server.rtc.doc import LoroDocManager from marimo._server.session_manager import SessionManager - from marimo._server.workspace import MarimoFileKey + from marimo._server.workspace import FileKey from marimo._session import Session LOGGER = _loggers.marimo_logger() @@ -40,7 +40,7 @@ def build_kernel_ready( last_execution_time: dict[CellId_t, float], kiosk: bool, rtc_enabled: bool, - file_key: MarimoFileKey, + file_key: FileKey, mode: SessionMode, doc_manager: LoroDocManager, auto_instantiated: bool = False, @@ -167,7 +167,7 @@ def _should_init_rtc(rtc_enabled: bool, mode: SessionMode) -> bool: def _try_init_rtc_doc( cell_ids: tuple[CellId_t, ...], codes: tuple[str, ...], - file_key: MarimoFileKey, + file_key: FileKey, doc_manager: LoroDocManager, ) -> None: """Try to initialize RTC document with cell data. diff --git a/marimo/_server/api/endpoints/ws/ws_rtc_handler.py b/marimo/_server/api/endpoints/ws/ws_rtc_handler.py index 6e0263ba56c..17ed3c7077d 100644 --- a/marimo/_server/api/endpoints/ws/ws_rtc_handler.py +++ b/marimo/_server/api/endpoints/ws/ws_rtc_handler.py @@ -13,7 +13,7 @@ from starlette.websockets import WebSocket from marimo._server.rtc.doc import LoroDocManager - from marimo._server.workspace import MarimoFileKey + from marimo._server.workspace import FileKey LOGGER = _loggers.marimo_logger() @@ -24,7 +24,7 @@ class RTCWebSocketHandler: def __init__( self, websocket: WebSocket, - file_key: MarimoFileKey, + file_key: FileKey, doc_manager: LoroDocManager, ): self.websocket = websocket diff --git a/marimo/_server/api/lifespans.py b/marimo/_server/api/lifespans.py index 10903d0b088..e5648f41506 100644 --- a/marimo/_server/api/lifespans.py +++ b/marimo/_server/api/lifespans.py @@ -25,7 +25,7 @@ from marimo._server.tokens import AuthToken from marimo._server.utils import initialize_mimetypes from marimo._server.uvicorn_utils import close_uvicorn -from marimo._server.workspace import NEW_FILE +from marimo._server.workspace import NewFileKey from marimo._session.model import SessionMode from marimo._utils.subprocess import cancel_pending_reaps @@ -170,7 +170,7 @@ async def logging(app: Starlette) -> AsyncIterator[None]: file_name=file.name if file else None, url=_startup_url(state), run=manager.mode == SessionMode.RUN, - new=workspace.get_unique_file_key() == NEW_FILE, + new=isinstance(workspace.get_unique_file_key(), NewFileKey), network=state.host == "0.0.0.0", startup_tip=state.startup_tip, ) diff --git a/marimo/_server/resume_strategies.py b/marimo/_server/resume_strategies.py index 1996339cf62..250ed753968 100644 --- a/marimo/_server/resume_strategies.py +++ b/marimo/_server/resume_strategies.py @@ -17,7 +17,7 @@ LOGGER = _loggers.marimo_logger() if TYPE_CHECKING: - from marimo._server.workspace import MarimoFileKey + from marimo._server.workspace import FileKey class SessionResumeStrategy(Protocol): @@ -26,7 +26,7 @@ class SessionResumeStrategy(Protocol): def try_resume( self, new_session_id: SessionId, - file_key: MarimoFileKey, + file_key: FileKey, ) -> Session | None: """Try to resume a session. @@ -53,18 +53,22 @@ def __init__(self, repository: SessionRepository) -> None: def try_resume( self, new_session_id: SessionId, - file_key: MarimoFileKey, + file_key: FileKey, ) -> Session | None: """Try to resume an orphaned session for the same file.""" import os + from marimo._server.workspace import NewFileKey + + if isinstance(file_key, NewFileKey): + return None + + abs_path = os.path.abspath(file_key.path) # Find sessions with the same file sessions_with_file = [] for session in self._repository.get_all(): session_id = self._repository.get_session_id(session) - if session_id and session.app_file_manager.path == os.path.abspath( - file_key - ): + if session_id and session.app_file_manager.path == abs_path: sessions_with_file.append((session_id, session)) if len(sessions_with_file) == 0: @@ -108,7 +112,7 @@ def __init__(self, repository: SessionRepository) -> None: def try_resume( self, new_session_id: SessionId, - file_key: MarimoFileKey, # noqa: ARG002 + file_key: FileKey, # noqa: ARG002 ) -> Session | None: """Try to resume a session with matching ID if orphaned.""" session = self._repository.get_sync(new_session_id) diff --git a/marimo/_server/rtc/doc.py b/marimo/_server/rtc/doc.py index 7165f6bad65..b63ed3ca3e2 100644 --- a/marimo/_server/rtc/doc.py +++ b/marimo/_server/rtc/doc.py @@ -5,7 +5,7 @@ from typing import TYPE_CHECKING from marimo import _loggers -from marimo._server.workspace import MarimoFileKey +from marimo._server.workspace import FileKey from marimo._types.ids import CellId_t if TYPE_CHECKING: @@ -16,17 +16,13 @@ class LoroDocManager: def __init__(self) -> None: - self.loro_docs: dict[MarimoFileKey, LoroDoc] = {} + self.loro_docs: dict[FileKey, LoroDoc] = {} self.loro_docs_lock = asyncio.Lock() - self.loro_docs_clients: dict[ - MarimoFileKey, set[asyncio.Queue[bytes]] - ] = {} - self.loro_docs_cleaners: dict[ - MarimoFileKey, asyncio.Task[None] | None - ] = {} + self.loro_docs_clients: dict[FileKey, set[asyncio.Queue[bytes]]] = {} + self.loro_docs_cleaners: dict[FileKey, asyncio.Task[None] | None] = {} async def _clean_loro_doc( - self, file_key: MarimoFileKey, timeout: float = 60 + self, file_key: FileKey, timeout: float = 60 ) -> None: """Clean up a loro doc if no clients are connected.""" try: @@ -49,7 +45,7 @@ async def _clean_loro_doc( async def create_doc( self, - file_key: MarimoFileKey, + file_key: FileKey, cell_ids: tuple[CellId_t, ...], codes: tuple[str, ...], ) -> LoroDoc: @@ -80,7 +76,7 @@ async def create_doc( # when the client connects for the first time. return doc - async def get_or_create_doc(self, file_key: MarimoFileKey) -> LoroDoc: + async def get_or_create_doc(self, file_key: FileKey) -> LoroDoc: """Get or create a loro doc for a file key.""" from loro import LoroDoc @@ -102,7 +98,7 @@ async def get_or_create_doc(self, file_key: MarimoFileKey) -> LoroDoc: return doc def add_client_to_doc( - self, file_key: MarimoFileKey, update_queue: asyncio.Queue[bytes] + self, file_key: FileKey, update_queue: asyncio.Queue[bytes] ) -> None: """Add a client queue to the loro doc clients.""" if file_key not in self.loro_docs_clients: @@ -112,7 +108,7 @@ def add_client_to_doc( async def broadcast_update( self, - file_key: MarimoFileKey, + file_key: FileKey, message: bytes, exclude_queue: asyncio.Queue[bytes] | None = None, ) -> None: @@ -125,7 +121,7 @@ async def broadcast_update( async def remove_client( self, - file_key: MarimoFileKey, + file_key: FileKey, update_queue: asyncio.Queue[bytes], ) -> None: """Clean up a loro client and potentially the doc if no clients remain.""" @@ -151,7 +147,7 @@ async def remove_client( self._clean_loro_doc(file_key, 60.0) ) - async def _do_remove_doc(self, file_key: MarimoFileKey) -> None: + async def _do_remove_doc(self, file_key: FileKey) -> None: """Actual implementation of removing a doc, separate from remove_doc to avoid deadlocks.""" if file_key in self.loro_docs: del self.loro_docs[file_key] @@ -160,7 +156,7 @@ async def _do_remove_doc(self, file_key: MarimoFileKey) -> None: if file_key in self.loro_docs_cleaners: del self.loro_docs_cleaners[file_key] - async def remove_doc(self, file_key: MarimoFileKey) -> None: + async def remove_doc(self, file_key: FileKey) -> None: """Remove a loro doc and all associated clients""" async with self.loro_docs_lock: await self._do_remove_doc(file_key) diff --git a/marimo/_server/session_manager.py b/marimo/_server/session_manager.py index 3c115d659da..7c516073a93 100644 --- a/marimo/_server/session_manager.py +++ b/marimo/_server/session_manager.py @@ -27,10 +27,12 @@ from marimo._server.token_manager import TokenManager from marimo._server.tokens import AuthToken, SkewProtectionToken from marimo._server.workspace import ( - NEW_FILE, - MarimoFileKey, + FileKey, + NewFileKey, NotebookWorkspace, + PathFileKey, flatten_files, + serialize_file_key, ) from marimo._session.app_host import AppHostContext, AppHostPool from marimo._session.consumer import SessionConsumer @@ -159,11 +161,11 @@ def sessions(self) -> Mapping[SessionId, Session]: """Get all sessions as a dict.""" return self._repository.sessions - def app_manager(self, key: MarimoFileKey) -> AppFileManager: + def app_manager(self, key: FileKey) -> AppFileManager: """Get the app manager for the given key.""" defaults = AppDefaults.from_config_manager(self._config_manager) - if self.mode is SessionMode.EDIT and not key.startswith(NEW_FILE): - self.workspace.register_allowed_path(key) + if self.mode is SessionMode.EDIT and isinstance(key, PathFileKey): + self.workspace.register_allowed_path(key.path) return self.workspace.load(key, defaults) def create_session( @@ -171,7 +173,7 @@ def create_session( session_id: SessionId, session_consumer: SessionConsumer, query_params: SerializedQueryParams, - file_key: MarimoFileKey, + file_key: FileKey, auto_instantiate: bool, ) -> Session: """Create a new session.""" @@ -184,8 +186,8 @@ def create_session( # Get app file manager defaults = AppDefaults.from_config_manager(self._config_manager) - if self.mode is SessionMode.EDIT and not file_key.startswith(NEW_FILE): - self.workspace.register_allowed_path(file_key) + if self.mode is SessionMode.EDIT and isinstance(file_key, PathFileKey): + self.workspace.register_allowed_path(file_key.path) app_file_manager = self.workspace.load(file_key, defaults) # Create the session @@ -202,7 +204,7 @@ def create_session( ) session = SessionImpl.create( - initialization_id=file_key, + initialization_id=serialize_file_key(file_key), session_consumer=session_consumer, mode=self.mode, app_metadata=AppMetadata( @@ -309,14 +311,12 @@ def get_session(self, session_id: SessionId) -> Session | None: # Search for kiosk sessions by consumer ID return self._repository.get_by_consumer_id(ConsumerId(session_id)) - def get_session_by_file_key( - self, file_key: MarimoFileKey - ) -> Session | None: + def get_session_by_file_key(self, file_key: FileKey) -> Session | None: """Get a session by file key.""" return self._repository.get_by_file_key(file_key) def maybe_resume_session( - self, new_session_id: SessionId, file_key: MarimoFileKey + self, new_session_id: SessionId, file_key: FileKey ) -> Session | None: """Try to resume a session if one is resumable. @@ -349,12 +349,12 @@ def _cleanup_dead_sessions(self) -> None: if session.kernel_state() is KernelState.STOPPED: self.close_session(session_id) - def any_clients_connected(self, key: MarimoFileKey) -> bool: + def any_clients_connected(self, key: FileKey) -> bool: """Returns True if at least one client has an open socket.""" - if key.startswith(NEW_FILE): + if isinstance(key, NewFileKey): return False - sessions_for_file = self._repository.get_by_file_path(key) + sessions_for_file = self._repository.get_by_file_path(key.path) return any( session.connection_state() == ConnectionState.OPEN for session in sessions_for_file diff --git a/marimo/_server/workspace/__init__.py b/marimo/_server/workspace/__init__.py index 68dc8c90ce6..daff5395889 100644 --- a/marimo/_server/workspace/__init__.py +++ b/marimo/_server/workspace/__init__.py @@ -14,8 +14,6 @@ from marimo import _loggers from marimo._server.workspace._base import ( - NEW_FILE, - MarimoFileKey, NotebookWorkspace, count_files, flatten_files, @@ -23,6 +21,14 @@ from marimo._server.workspace._directory import DirectoryWorkspace from marimo._server.workspace._empty import EmptyWorkspace from marimo._server.workspace._fixed import FixedFilesWorkspace +from marimo._server.workspace._keys import ( + NEW_FILE_WIRE, + FileKey, + NewFileKey, + PathFileKey, + parse_file_key, + serialize_file_key, +) from marimo._server.workspace._single import SingleFileWorkspace from marimo._utils.http import HTTPException, HTTPStatus from marimo._utils.marimo_path import MarimoPath @@ -45,14 +51,18 @@ def infer_workspace(path: str) -> NotebookWorkspace: __all__ = [ - "NEW_FILE", + "NEW_FILE_WIRE", "DirectoryWorkspace", "EmptyWorkspace", + "FileKey", "FixedFilesWorkspace", - "MarimoFileKey", + "NewFileKey", "NotebookWorkspace", + "PathFileKey", "SingleFileWorkspace", "count_files", "flatten_files", "infer_workspace", + "parse_file_key", + "serialize_file_key", ] diff --git a/marimo/_server/workspace/_base.py b/marimo/_server/workspace/_base.py index 1afb8cd3540..15b1f3c3910 100644 --- a/marimo/_server/workspace/_base.py +++ b/marimo/_server/workspace/_base.py @@ -8,6 +8,12 @@ from typing import TYPE_CHECKING from marimo._server.app_defaults import AppDefaults +from marimo._server.workspace._keys import ( + FileKey, + NewFileKey, + PathFileKey, + serialize_file_key, +) from marimo._session.notebook import AppFileManager from marimo._utils.http import HTTPException, HTTPStatus from marimo._utils.marimo_path import MarimoPath @@ -19,14 +25,6 @@ from marimo._server.models.files import FileInfo from marimo._server.models.home import MarimoFile -# Wire-format key for an untitled notebook. The string boundary is preserved -# for HTTP query params and session initialization IDs. -NEW_FILE: str = "__new__" - -# Some unique identifier for a file. Phase 3 will replace this with a tagged -# union (NewFileKey | PathFileKey). -MarimoFileKey = str - class NotebookWorkspace(abc.ABC): """A server-side abstraction for the set of notebooks a server is hosting. @@ -50,11 +48,11 @@ def single_file(self) -> MarimoFile | None: """If this workspace represents a single notebook, return it.""" @abc.abstractmethod - def get_unique_file_key(self) -> MarimoFileKey | None: + def get_unique_file_key(self) -> FileKey | None: """The unique file key for this workspace, if any.""" @abc.abstractmethod - def resolve(self, key: MarimoFileKey) -> str | None: + def resolve(self, key: FileKey) -> str | None: """Resolve a key to an absolute path; ``None`` for new files. Useful for endpoints that need file-backed resources (e.g. thumbnails) @@ -63,7 +61,7 @@ def resolve(self, key: MarimoFileKey) -> str | None: def load( self, - key: MarimoFileKey, + key: FileKey, defaults: AppDefaults | None = None, ) -> AppFileManager: """Load the notebook for the given key into an ``AppFileManager``. @@ -127,11 +125,11 @@ def set_include_markdown(self, include_markdown: bool) -> None: del include_markdown -def file_not_found(key: MarimoFileKey) -> HTTPException: +def file_not_found(key: FileKey) -> HTTPException: """Build the standard 404 response for an unresolvable file key.""" return HTTPException( status_code=HTTPStatus.NOT_FOUND, - detail=f"File {key} not found", + detail=f"File {serialize_file_key(key)} not found", ) @@ -165,3 +163,16 @@ def flatten_files(files: list[FileInfo]) -> Iterator[FileInfo]: stack.extend(file.children) else: yield file + + +__all__ = [ + "FileKey", + "NewFileKey", + "NotebookWorkspace", + "PathFileKey", + "count_files", + "file_not_found", + "flatten_files", + "normalize_allowlist_entry", + "serialize_file_key", +] diff --git a/marimo/_server/workspace/_directory.py b/marimo/_server/workspace/_directory.py index b77b5ada508..b5fb9339f8b 100644 --- a/marimo/_server/workspace/_directory.py +++ b/marimo/_server/workspace/_directory.py @@ -10,11 +10,13 @@ from marimo._server.files.directory_scanner import DirectoryScanner from marimo._server.files.path_validator import PathValidator from marimo._server.workspace._base import ( - NEW_FILE, - MarimoFileKey, NotebookWorkspace, file_not_found, ) +from marimo._server.workspace._keys import ( + FileKey, + NewFileKey, +) from marimo._utils.http import HTTPException, HTTPStatus if TYPE_CHECKING: @@ -69,15 +71,15 @@ def is_in_allowed_temp_dir(self, path: str) -> bool: def single_file(self) -> MarimoFile | None: return None - def get_unique_file_key(self) -> MarimoFileKey | None: + def get_unique_file_key(self) -> FileKey | None: return None - def resolve(self, key: MarimoFileKey) -> str | None: - if key.startswith(NEW_FILE): + def resolve(self, key: FileKey) -> str | None: + if isinstance(key, NewFileKey): return None directory = Path(self._directory) - filepath = Path(key) + filepath = Path(key.path) # Resolve relative paths against the workspace directory. if not filepath.is_absolute(): diff --git a/marimo/_server/workspace/_empty.py b/marimo/_server/workspace/_empty.py index 2391fef6ebb..b9c93ec3485 100644 --- a/marimo/_server/workspace/_empty.py +++ b/marimo/_server/workspace/_empty.py @@ -8,11 +8,13 @@ from typing import TYPE_CHECKING from marimo._server.workspace._base import ( - NEW_FILE, - MarimoFileKey, NotebookWorkspace, file_not_found, ) +from marimo._server.workspace._keys import ( + FileKey, + NewFileKey, +) from marimo._utils.paths import normalize_path if TYPE_CHECKING: @@ -29,8 +31,8 @@ class EmptyWorkspace(NotebookWorkspace): continue to work. """ - def get_unique_file_key(self) -> MarimoFileKey | None: - return NEW_FILE + def get_unique_file_key(self) -> FileKey | None: + return NewFileKey() def single_file(self) -> MarimoFile | None: return None @@ -39,12 +41,12 @@ def single_file(self) -> MarimoFile | None: def files(self) -> list[FileInfo]: return [] - def resolve(self, key: MarimoFileKey) -> str | None: - if key.startswith(NEW_FILE): + def resolve(self, key: FileKey) -> str | None: + if isinstance(key, NewFileKey): return None - if os.path.exists(key): + if os.path.exists(key.path): # Match sibling workspaces: return an absolute normalized path so # downstream comparisons (e.g. session lookups) don't trip on # relative-vs-absolute mismatches. - return str(normalize_path(Path(key))) + return str(normalize_path(Path(key.path))) raise file_not_found(key) diff --git a/marimo/_server/workspace/_fixed.py b/marimo/_server/workspace/_fixed.py index aab60798a4c..8e0bf8d2db0 100644 --- a/marimo/_server/workspace/_fixed.py +++ b/marimo/_server/workspace/_fixed.py @@ -8,12 +8,14 @@ from marimo._server.models.files import FileInfo from marimo._server.workspace._base import ( - NEW_FILE, - MarimoFileKey, NotebookWorkspace, file_not_found, normalize_allowlist_entry, ) +from marimo._server.workspace._keys import ( + FileKey, + NewFileKey, +) from marimo._utils.paths import normalize_path if TYPE_CHECKING: @@ -59,14 +61,14 @@ def files(self) -> list[FileInfo]: def single_file(self) -> MarimoFile | None: return None - def get_unique_file_key(self) -> MarimoFileKey | None: + def get_unique_file_key(self) -> FileKey | None: return None - def resolve(self, key: MarimoFileKey) -> str | None: - if key.startswith(NEW_FILE): + def resolve(self, key: FileKey) -> str | None: + if isinstance(key, NewFileKey): raise file_not_found(key) - filepath = Path(key) + filepath = Path(key.path) if not filepath.is_absolute() and self._directory: filepath = Path(self._directory) / filepath normalized_path = normalize_path(filepath) diff --git a/marimo/_server/workspace/_keys.py b/marimo/_server/workspace/_keys.py new file mode 100644 index 00000000000..f617580691f --- /dev/null +++ b/marimo/_server/workspace/_keys.py @@ -0,0 +1,56 @@ +# Copyright 2026 Marimo. All rights reserved. +"""Tagged-union ADT for notebook file keys. + +The wire format is a string (HTTP query params, session ``initialization_id``) +but inside the server we model the two cases explicitly: + +- :class:`NewFileKey` — the untitled (``__new__``) notebook +- :class:`PathFileKey` — a notebook identified by a filesystem path +""" + +from __future__ import annotations + +from dataclasses import dataclass + +# Wire-format sentinel for an untitled notebook. Preserved across HTTP query +# params and session initialization IDs. +NEW_FILE_WIRE: str = "__new__" + + +@dataclass(frozen=True) +class NewFileKey: + """Sentinel key for an untitled notebook.""" + + +@dataclass(frozen=True) +class PathFileKey: + """Key for a notebook identified by a filesystem path. + + The ``path`` is the raw value supplied at the boundary; workspaces are + responsible for normalizing and validating it. + """ + + path: str + + +FileKey = NewFileKey | PathFileKey + + +def parse_file_key(raw: str) -> FileKey: + """Parse a wire-format string into a :class:`FileKey`. + + The literal ``__new__`` becomes a :class:`NewFileKey`; any other string is + treated as a path. We match the sentinel exactly so non-sentinel keys + flow through normal path validation rather than being short-circuited + to the blank notebook. + """ + if raw == NEW_FILE_WIRE: + return NewFileKey() + return PathFileKey(raw) + + +def serialize_file_key(key: FileKey) -> str: + """Serialize a :class:`FileKey` back to its wire-format string.""" + if isinstance(key, NewFileKey): + return NEW_FILE_WIRE + return key.path diff --git a/marimo/_server/workspace/_single.py b/marimo/_server/workspace/_single.py index ff9325dcb28..e719827acf3 100644 --- a/marimo/_server/workspace/_single.py +++ b/marimo/_server/workspace/_single.py @@ -8,12 +8,15 @@ from marimo._server.models.home import MarimoFile from marimo._server.workspace._base import ( - NEW_FILE, - MarimoFileKey, NotebookWorkspace, file_not_found, normalize_allowlist_entry, ) +from marimo._server.workspace._keys import ( + FileKey, + NewFileKey, + PathFileKey, +) from marimo._utils.marimo_path import MarimoPath from marimo._utils.paths import normalize_path @@ -62,14 +65,14 @@ def files(self) -> list[FileInfo]: def single_file(self) -> MarimoFile | None: return self._file - def get_unique_file_key(self) -> MarimoFileKey | None: - return self._file.path + def get_unique_file_key(self) -> FileKey | None: + return PathFileKey(self._file.path) - def resolve(self, key: MarimoFileKey) -> str | None: - if key.startswith(NEW_FILE): + def resolve(self, key: FileKey) -> str | None: + if isinstance(key, NewFileKey): return None - filepath = Path(key) + filepath = Path(key.path) normalized_path = normalize_path(filepath) absolute_path = str(normalized_path) if absolute_path not in self._allowed_paths: diff --git a/marimo/_session/session_repository.py b/marimo/_session/session_repository.py index 65a3565a948..cc3f4720663 100644 --- a/marimo/_session/session_repository.py +++ b/marimo/_session/session_repository.py @@ -13,7 +13,7 @@ if TYPE_CHECKING: from collections.abc import Mapping - from marimo._server.workspace import MarimoFileKey + from marimo._server.workspace import FileKey class SessionRepository: @@ -52,14 +52,27 @@ def get_by_consumer_id(self, consumer_id: ConsumerId) -> Session | None: return session return None - def get_by_file_key(self, file_key: MarimoFileKey) -> Session | None: + def get_by_file_key(self, file_key: FileKey) -> Session | None: """Get a session by file key.""" import os + from marimo._server.workspace import ( + NewFileKey, + serialize_file_key, + ) + + serialized = serialize_file_key(file_key) + abs_path = ( + None + if isinstance(file_key, NewFileKey) + else os.path.abspath(file_key.path) + ) for session in self._sessions.values(): + if session.initialization_id == serialized: + return session if ( - session.initialization_id == file_key - or session.app_file_manager.path == os.path.abspath(file_key) + abs_path is not None + and session.app_file_manager.path == abs_path ): return session return None diff --git a/tests/_server/api/endpoints/test_assets.py b/tests/_server/api/endpoints/test_assets.py index 065263fa12e..87da771b462 100644 --- a/tests/_server/api/endpoints/test_assets.py +++ b/tests/_server/api/endpoints/test_assets.py @@ -19,6 +19,7 @@ EmptyWorkspace, FixedFilesWorkspace, SingleFileWorkspace, + serialize_file_key, ) from marimo._session.model import SessionMode from marimo._utils.marimo_path import MarimoPath @@ -44,10 +45,11 @@ def test_index(client: TestClient) -> None: response = client.get("/", headers=token_header()) assert response.status_code == 200, response.text content = response.text - filename = session_manager.workspace.get_unique_file_key() + file_key = session_manager.workspace.get_unique_file_key() + assert file_key is not None + filename = serialize_file_key(file_key) title = parse_title(filename) assert f"" in content - assert filename is not None assert filename in content assert '"mode": "edit"' in content assert f"{title}" in content @@ -405,10 +407,11 @@ def test_public_file_serving(client: TestClient) -> None: app_state = AppState.from_app(cast(Any, client.app)) file_key = app_state.session_manager.workspace.get_unique_file_key() assert file_key is not None - assert file_key.endswith(".py") + filepath = serialize_file_key(file_key) + assert filepath.endswith(".py") # Create a test file in a public directory - notebook_dir = Path(file_key).parent + notebook_dir = Path(filepath).parent public_dir = notebook_dir / "public" public_dir.mkdir(parents=True, exist_ok=True) test_file = public_dir / "test.txt" @@ -419,7 +422,7 @@ def test_public_file_serving(client: TestClient) -> None: assert response.status_code == 404 # Test with notebook ID header - headers = {**token_header(), "X-Notebook-Id": file_key} + headers = {**token_header(), "X-Notebook-Id": filepath} response = client.get("/public/test.txt", headers=headers) assert response.status_code == 200 assert response.text == "test content" @@ -445,10 +448,11 @@ def test_public_file_security(client: TestClient) -> None: app_state = AppState.from_app(cast(Any, client.app)) file_key = app_state.session_manager.workspace.get_unique_file_key() assert file_key is not None - assert file_key.endswith(".py") + filepath = serialize_file_key(file_key) + assert filepath.endswith(".py") # Setup notebook and directories - notebook_dir = Path(file_key).parent + notebook_dir = Path(filepath).parent public_dir = notebook_dir / "public" secret_dir = notebook_dir / "secret" public_dir.mkdir(parents=True, exist_ok=True) @@ -467,7 +471,7 @@ def test_public_file_security(client: TestClient) -> None: app_manager = app_state.session_manager.app_manager(file_key) app_manager.filename = str(notebook_dir / "notebook.py") - headers = {**token_header(), "X-Notebook-Id": file_key} + headers = {**token_header(), "X-Notebook-Id": filepath} # Test normal file access response = client.get("/public/safe.txt", headers=headers) @@ -640,7 +644,7 @@ def test_index_lsp_workspace_with_root_directory( client, DirectoryWorkspace(str(temp_project_dir), include_markdown=False), ): - response = client.get("/?file=__new__file.py", headers=token_header()) + response = client.get("/?file=__new__", headers=token_header()) root_path = temp_project_dir root_uri = json.dumps(root_path.as_uri()) document_path = root_path.joinpath(DEFAULT_NOTEBOOK_NAME) @@ -660,7 +664,7 @@ def test_index_lsp_workspace_with_sub_directory( with workspace_scope( client, DirectoryWorkspace(str(subdir), include_markdown=False) ): - response = client.get("/?file=__new__file.py", headers=token_header()) + response = client.get("/?file=__new__", headers=token_header()) root_path = temp_project_dir root_uri = json.dumps(root_path.as_uri()) document_path = subdir.joinpath(DEFAULT_NOTEBOOK_NAME) diff --git a/tests/_server/api/endpoints/test_file_explorer.py b/tests/_server/api/endpoints/test_file_explorer.py index 8e954b0ddcc..1e8118eb1da 100644 --- a/tests/_server/api/endpoints/test_file_explorer.py +++ b/tests/_server/api/endpoints/test_file_explorer.py @@ -8,6 +8,7 @@ import pytest +from marimo._server.workspace import serialize_file_key from marimo._utils.platform import is_windows from tests._server.mocks import get_session_manager, token_header @@ -122,9 +123,9 @@ def test_update_file_with_session(client: TestClient) -> None: # Enable watch mode (file watcher is set up automatically) sm.watch = True - file_path = sm.workspace.get_unique_file_key() - assert file_path - file_path = Path(file_path) + file_key = sm.workspace.get_unique_file_key() + assert file_key + file_path = Path(serialize_file_key(file_key)) assert file_path.exists() # Create a session by connecting via websocket diff --git a/tests/_server/api/endpoints/test_files.py b/tests/_server/api/endpoints/test_files.py index ade1c754dac..0cef47297a2 100644 --- a/tests/_server/api/endpoints/test_files.py +++ b/tests/_server/api/endpoints/test_files.py @@ -12,6 +12,7 @@ import msgspec import pytest +from marimo._server.workspace import serialize_file_key from marimo._utils.platform import is_windows from tests._server.mocks import ( get_session_manager, @@ -35,12 +36,12 @@ @with_session(SESSION_ID) def test_rename(client: TestClient) -> None: - current_filename = get_session_manager( + current_file_key = get_session_manager( client ).workspace.get_unique_file_key() - assert current_filename - current_path = Path(current_filename) + assert current_file_key + current_path = Path(serialize_file_key(current_file_key)) assert current_path.exists() directory = current_path.parent @@ -103,8 +104,9 @@ def test_read_code_in_run_mode_without_include_code( @pytest.mark.flaky(reruns=5) @with_session(SESSION_ID) def test_save_file(client: TestClient) -> None: - filename = get_session_manager(client).workspace.get_unique_file_key() - assert filename + file_key = get_session_manager(client).workspace.get_unique_file_key() + assert file_key + filename = serialize_file_key(file_key) path = Path(filename) response = client.post( @@ -157,8 +159,9 @@ def _assert_contents(): ) @with_session(SESSION_ID) def test_save_with_header(client: TestClient) -> None: - filename = get_session_manager(client).workspace.get_unique_file_key() - assert filename + file_key = get_session_manager(client).workspace.get_unique_file_key() + assert file_key + filename = serialize_file_key(file_key) path = Path(filename) assert path.exists() @@ -208,8 +211,9 @@ def _assert_contents(): @pytest.mark.flaky(reruns=5) @with_session(SESSION_ID) def test_save_with_invalid_file(client: TestClient) -> None: - filename = get_session_manager(client).workspace.get_unique_file_key() - assert filename + file_key = get_session_manager(client).workspace.get_unique_file_key() + assert file_key + filename = serialize_file_key(file_key) path = Path(filename) assert path.exists() @@ -286,8 +290,9 @@ def test_save_file_cannot_rename(client: TestClient) -> None: @pytest.mark.flaky(reruns=5) @with_session(SESSION_ID) def test_save_app_config(client: TestClient) -> None: - filename = get_session_manager(client).workspace.get_unique_file_key() - assert filename + file_key = get_session_manager(client).workspace.get_unique_file_key() + assert file_key + filename = serialize_file_key(file_key) path = Path(filename) def _wait_for_file_reset(): @@ -315,8 +320,9 @@ def _assert_contents(): @with_session(SESSION_ID) def test_copy_file(client: TestClient) -> None: - filename = get_session_manager(client).workspace.get_unique_file_key() - assert filename + file_key = get_session_manager(client).workspace.get_unique_file_key() + assert file_key + filename = serialize_file_key(file_key) path = Path(filename) assert path.exists() file_contents = path.read_text() @@ -347,8 +353,9 @@ def _assert_contents(): @with_session(SESSION_ID) def test_copy_file_with_relative_paths(client: TestClient) -> None: """Test that copy works with relative paths when workspace has a directory.""" - filename = get_session_manager(client).workspace.get_unique_file_key() - assert filename + file_key = get_session_manager(client).workspace.get_unique_file_key() + assert file_key + filename = serialize_file_key(file_key) path = Path(filename) assert path.exists() file_contents = path.read_text() @@ -393,8 +400,9 @@ def _assert_contents(): @with_session(SESSION_ID) def test_path_traversal_save_blocked(client: TestClient) -> None: """Save endpoint must not write outside the workspace's directory.""" - filename = get_session_manager(client).workspace.get_unique_file_key() - assert filename + file_key = get_session_manager(client).workspace.get_unique_file_key() + assert file_key + filename = serialize_file_key(file_key) path = Path(filename) directory = str(path.parent) workspace = get_session_manager(client).workspace @@ -423,8 +431,9 @@ def test_path_traversal_save_blocked(client: TestClient) -> None: @with_session(SESSION_ID) def test_path_traversal_rename_blocked(client: TestClient) -> None: """Rename endpoint must not move files outside the workspace's directory.""" - filename = get_session_manager(client).workspace.get_unique_file_key() - assert filename + file_key = get_session_manager(client).workspace.get_unique_file_key() + assert file_key + filename = serialize_file_key(file_key) path = Path(filename) directory = str(path.parent) workspace = get_session_manager(client).workspace @@ -446,8 +455,9 @@ def test_path_traversal_rename_blocked(client: TestClient) -> None: @with_session(SESSION_ID) def test_path_traversal_copy_blocked(client: TestClient) -> None: """Copy endpoint must not write outside the workspace's directory.""" - filename = get_session_manager(client).workspace.get_unique_file_key() - assert filename + file_key = get_session_manager(client).workspace.get_unique_file_key() + assert file_key + filename = serialize_file_key(file_key) path = Path(filename) directory = str(path.parent) workspace = get_session_manager(client).workspace @@ -473,11 +483,12 @@ def test_path_traversal_copy_blocked(client: TestClient) -> None: def test_rename_propagates( client: TestClient, websocket: WebSocketTestSession ) -> None: - current_filename = get_session_manager( + current_file_key = get_session_manager( client ).workspace.get_unique_file_key() - assert current_filename + assert current_file_key + current_filename = serialize_file_key(current_file_key) assert os.path.exists(current_filename) initial_response = client.post( @@ -558,8 +569,9 @@ def test_read_code_without_saved_file(client: TestClient) -> None: @with_session(SESSION_ID) def test_save_with_unicode_content(client: TestClient) -> None: """Test save endpoint with unicode and special characters.""" - filename = get_session_manager(client).workspace.get_unique_file_key() - assert filename + file_key = get_session_manager(client).workspace.get_unique_file_key() + assert file_key + filename = serialize_file_key(file_key) path = Path(filename) unicode_code = """# Unicode test: 你好世界 🌍 ñáéíóú diff --git a/tests/_server/api/endpoints/test_home.py b/tests/_server/api/endpoints/test_home.py index 9533ebb99f8..12064dff2c6 100644 --- a/tests/_server/api/endpoints/test_home.py +++ b/tests/_server/api/endpoints/test_home.py @@ -15,6 +15,8 @@ from marimo._server.workspace import ( DirectoryWorkspace, FixedFilesWorkspace, + PathFileKey, + serialize_file_key, ) from marimo._session.model import SessionMode from tests._server.mocks import get_session_manager, token_header, with_session @@ -31,10 +33,10 @@ @with_session(SESSION_ID) def test_workspace_files(client: TestClient) -> None: - current_filename = get_session_manager( + current_file_key = get_session_manager( client ).workspace.get_unique_file_key() - assert current_filename + assert current_file_key response = client.post( "/api/home/workspace_files", @@ -44,7 +46,7 @@ def test_workspace_files(client: TestClient) -> None: body = response.json() files = body["files"] assert len(files) == 1 - assert files[0]["path"] == current_filename + assert files[0]["path"] == serialize_file_key(current_file_key) # Check that new fields are present assert "hasMore" in body assert "fileCount" in body @@ -65,10 +67,10 @@ def test_workspace_files_no_files(client: TestClient) -> None: @with_session(SESSION_ID) def test_running_notebooks(client: TestClient) -> None: - current_filename = get_session_manager( + current_file_key = get_session_manager( client ).workspace.get_unique_file_key() - assert current_filename + assert current_file_key response = client.post( "/api/home/running_notebooks", @@ -77,7 +79,7 @@ def test_running_notebooks(client: TestClient) -> None: body = response.json() files = body["files"] assert len(files) == 1 - assert files[0]["path"] == current_filename + assert files[0]["path"] == serialize_file_key(current_file_key) # TODO: Debug on Windows @@ -379,7 +381,7 @@ def test_tutorial_file_accessible_after_open(client: TestClient) -> None: # Try to get a file manager for the tutorial file # This should not raise an HTTPException about being outside the directory - file_manager = session_manager.app_manager(tutorial_path) + file_manager = session_manager.app_manager(PathFileKey(tutorial_path)) assert file_manager is not None assert file_manager.path == tutorial_path diff --git a/tests/_server/api/endpoints/test_resume_session.py b/tests/_server/api/endpoints/test_resume_session.py index 92005530006..82c40addc06 100644 --- a/tests/_server/api/endpoints/test_resume_session.py +++ b/tests/_server/api/endpoints/test_resume_session.py @@ -11,6 +11,7 @@ CellNotification, KernelReadyNotification, ) +from marimo._server.workspace import serialize_file_key from marimo._session import Session from marimo._types.ids import SessionId from marimo._utils.parse_dataclass import parse_raw @@ -263,8 +264,9 @@ def test_resume_session_after_file_change(client: TestClient) -> None: # Write to the notebook file to add a new cell # we write it as the second to last cell - filename = session_manager.workspace.get_unique_file_key() - assert filename + file_key = session_manager.workspace.get_unique_file_key() + assert file_key + filename = serialize_file_key(file_key) with open(filename) as f: content = f.read() last_cell_pos = content.rindex("@app.cell") diff --git a/tests/_server/api/endpoints/test_ws.py b/tests/_server/api/endpoints/test_ws.py index 245ce376a99..e450dc1f12f 100644 --- a/tests/_server/api/endpoints/test_ws.py +++ b/tests/_server/api/endpoints/test_ws.py @@ -18,6 +18,7 @@ ) from marimo._server.codes import WebSocketCodes from marimo._server.session_manager import SessionManager +from marimo._server.workspace import serialize_file_key from marimo._session.model import ConnectionState, SessionMode from marimo._utils.parse_dataclass import parse_raw from tests._server.api.endpoints.ws_helpers import ( @@ -163,8 +164,9 @@ async def test_file_watcher_calls_reload(client: TestClient) -> None: with client.websocket_connect(WS_URL) as websocket: data = websocket.receive_json() assert_kernel_ready_response(data) - filename = session_manager.workspace.get_unique_file_key() - assert filename + file_key = session_manager.workspace.get_unique_file_key() + assert file_key + filename = serialize_file_key(file_key) with open(filename, "a") as f: # noqa: ASYNC230 f.write("\n# test") f.close() diff --git a/tests/_server/api/endpoints/test_ws_rtc.py b/tests/_server/api/endpoints/test_ws_rtc.py index 4d4cc720255..7a290b5a52d 100644 --- a/tests/_server/api/endpoints/test_ws_rtc.py +++ b/tests/_server/api/endpoints/test_ws_rtc.py @@ -9,6 +9,7 @@ from marimo._config.manager import UserConfigManager from marimo._server.api.endpoints.ws_endpoint import DOC_MANAGER +from marimo._server.workspace import serialize_file_key from tests._server.api.endpoints.ws_helpers import ( assert_kernel_ready_response, assert_parse_ready_response, @@ -240,7 +241,9 @@ async def test_ws_sync_without_existing_session(client: TestClient) -> None: file_key = get_session_manager(client).workspace.get_unique_file_key() assert file_key is not None - ws_sync_url = f"/ws_sync?file={file_key}&access_token=fake-token" + ws_sync_url = ( + f"/ws_sync?file={serialize_file_key(file_key)}&access_token=fake-token" + ) # Try to connect to ws_sync without creating a main session first with pytest.raises(WebSocketDisconnect) as exc_info: @@ -265,7 +268,9 @@ async def test_ws_sync_cleanup_on_main_disconnect(client: TestClient) -> None: DOC_MANAGER.loro_docs_clients.clear() ws_1 = "/ws?session_id=123&access_token=fake-token" - ws_sync_url = f"/ws_sync?file={file_key}&access_token=fake-token" + ws_sync_url = ( + f"/ws_sync?file={serialize_file_key(file_key)}&access_token=fake-token" + ) with rtc_enabled(get_user_config_manager(client)): with client.websocket_connect(ws_1) as main_websocket: diff --git a/tests/_server/rtc/test_rtc_doc.py b/tests/_server/rtc/test_rtc_doc.py index f0920a881b1..b3cc736dbc6 100644 --- a/tests/_server/rtc/test_rtc_doc.py +++ b/tests/_server/rtc/test_rtc_doc.py @@ -7,7 +7,7 @@ import pytest from marimo._server.rtc.doc import LoroDocManager -from marimo._server.workspace import MarimoFileKey +from marimo._server.workspace import FileKey, PathFileKey from marimo._types.ids import CellId_t if sys.version_info >= (3, 11) and sys.version_info < (3, 14): @@ -41,7 +41,7 @@ async def test_quick_reconnection(setup_doc_manager: None) -> None: """Test that quick reconnection properly handles cleanup task cancellation""" del setup_doc_manager # Setup - file_key = MarimoFileKey("test_file") + file_key = PathFileKey("test_file") # Create initial loro_doc doc = LoroDoc() @@ -78,7 +78,7 @@ async def test_quick_reconnection(setup_doc_manager: None) -> None: async def test_two_users_sync(setup_doc_manager: None) -> None: """Test that two users can connect and sync text properly without duplicates""" del setup_doc_manager - file_key = MarimoFileKey("test_file") + file_key = PathFileKey("test_file") cell_id = str(CellId_t("test_cell")) # Convert CellId to string for loro # First user connects @@ -126,7 +126,7 @@ async def test_two_users_sync(setup_doc_manager: None) -> None: async def test_concurrent_doc_creation(setup_doc_manager: None) -> None: """Test concurrent doc creation doesn't cause issues""" del setup_doc_manager - file_key = MarimoFileKey("test_file") + file_key = PathFileKey("test_file") cell_ids = (CellId_t("cell1"), CellId_t("cell2")) codes = ("print('hello')", "print('world')") @@ -149,7 +149,7 @@ async def test_concurrent_client_operations( ) -> None: """Test concurrent client operations don't cause deadlocks""" del setup_doc_manager - file_key = MarimoFileKey("test_file") + file_key = PathFileKey("test_file") doc = LoroDoc() doc_manager.loro_docs[file_key] = doc @@ -176,7 +176,7 @@ async def client_operation(queue: asyncio.Queue[bytes]) -> None: async def test_cleanup_task_management(setup_doc_manager: None) -> None: """Test cleanup task management and cancellation""" del setup_doc_manager - file_key = MarimoFileKey("test_file") + file_key = PathFileKey("test_file") doc = LoroDoc() doc_manager.loro_docs[file_key] = doc @@ -210,7 +210,7 @@ async def test_cleanup_task_management(setup_doc_manager: None) -> None: async def test_broadcast_update(setup_doc_manager: None) -> None: """Test broadcast update functionality""" del setup_doc_manager - file_key = MarimoFileKey("test_file") + file_key = PathFileKey("test_file") doc = LoroDoc() doc_manager.loro_docs[file_key] = doc @@ -238,7 +238,7 @@ async def test_broadcast_update(setup_doc_manager: None) -> None: async def test_remove_nonexistent_doc(setup_doc_manager: None) -> None: """Test removing a doc that doesn't exist""" del setup_doc_manager - file_key = MarimoFileKey("nonexistent") + file_key = PathFileKey("nonexistent") await doc_manager.remove_doc(file_key) assert file_key not in doc_manager.loro_docs assert file_key not in doc_manager.loro_docs_clients @@ -251,7 +251,7 @@ async def test_remove_nonexistent_doc(setup_doc_manager: None) -> None: async def test_remove_nonexistent_client(setup_doc_manager: None) -> None: """Test removing a client that doesn't exist""" del setup_doc_manager - file_key = MarimoFileKey("test_file") + file_key = PathFileKey("test_file") queue = asyncio.Queue[bytes]() await doc_manager.remove_client(file_key, queue) assert file_key not in doc_manager.loro_docs_clients @@ -263,7 +263,7 @@ async def test_remove_nonexistent_client(setup_doc_manager: None) -> None: async def test_concurrent_doc_removal(setup_doc_manager: None) -> None: """Test concurrent doc removal doesn't cause issues""" del setup_doc_manager - file_key = MarimoFileKey("test_file") + file_key = PathFileKey("test_file") doc = LoroDoc() doc_manager.loro_docs[file_key] = doc @@ -291,7 +291,7 @@ async def test_prevent_lock_deadlock(setup_doc_manager: None) -> None: The fixed implementation should handle this without deadlocking. """ del setup_doc_manager - file_key = MarimoFileKey("test_file") + file_key = PathFileKey("test_file") # Create a doc and add a client doc = LoroDoc() @@ -332,7 +332,7 @@ async def long_lock_operation() -> None: original_clean_loro_doc = doc_manager._clean_loro_doc async def test_clean_loro_doc( - file_key: MarimoFileKey, timeout: float = original_timeout + file_key: FileKey, timeout: float = original_timeout ) -> None: del timeout # Override timeout with our test value diff --git a/tests/_server/test_file_manager_absolute_path.py b/tests/_server/test_file_manager_absolute_path.py index 141bc7d7acb..a1ca51b6833 100644 --- a/tests/_server/test_file_manager_absolute_path.py +++ b/tests/_server/test_file_manager_absolute_path.py @@ -9,7 +9,7 @@ import pytest -from marimo._server.workspace import DirectoryWorkspace +from marimo._server.workspace import DirectoryWorkspace, PathFileKey from marimo._utils.http import HTTPException, HTTPStatus is_windows = sys.platform == "win32" @@ -58,7 +58,7 @@ def __(): assert file_info.path == "notebook.py" # Try to get a file manager using the relative path from files list - file_manager = router.load(file_info.path) + file_manager = router.load(PathFileKey(file_info.path)) assert file_manager is not None # File manager resolves to absolute path assert file_manager.filename == str(test_file) @@ -108,7 +108,7 @@ def __(): assert file_info.path == "notebook.py" # Try to get a file manager using the relative file path - file_manager = router.load(file_info.path) + file_manager = router.load(PathFileKey(file_info.path)) assert file_manager is not None assert file_manager.is_notebook_named # File manager resolves to absolute path @@ -151,7 +151,9 @@ def __(): # Get file manager with absolute path absolute_file_path = absolute_files[0].path - absolute_file_manager = absolute_router.load(absolute_file_path) + absolute_file_manager = absolute_router.load( + PathFileKey(absolute_file_path) + ) # Change to the parent directory original_cwd = os.getcwd() @@ -168,7 +170,9 @@ def __(): # Get file manager with relative path relative_file_path = relative_files[0].path - relative_file_manager = relative_router.load(relative_file_path) + relative_file_manager = relative_router.load( + PathFileKey(relative_file_path) + ) # Both should reference the same file assert ( @@ -229,7 +233,7 @@ def __(): assert os.path.exists(full_path) # Try to get a file manager - file_manager = router.load(full_path) + file_manager = router.load(PathFileKey(full_path)) assert file_manager is not None assert file_manager.is_notebook_named assert file_manager.read_file() is not None @@ -285,7 +289,7 @@ def __(): # Try to get a file manager file_path = files_after[0].path - file_manager = router.load(file_path) + file_manager = router.load(PathFileKey(file_path)) assert file_manager is not None assert file_manager.is_notebook_named content = file_manager.read_file() @@ -336,7 +340,7 @@ def __(): # the router's directory relative_filename = "notebook.py" - file_manager = router.load(relative_filename) + file_manager = router.load(PathFileKey(relative_filename)) assert file_manager is not None assert file_manager.is_notebook_named # Verify it opened the correct file @@ -400,7 +404,7 @@ def __(): file_info = files[0] # This should work - using the full absolute path - file_manager = router.load(file_info.path) + file_manager = router.load(PathFileKey(file_info.path)) assert file_manager is not None assert file_manager.is_notebook_named @@ -411,7 +415,7 @@ def __(): # This currently fails but should succeed # The router should resolve the basename relative to its directory try: - file_manager_from_basename = router.load(basename) + file_manager_from_basename = router.load(PathFileKey(basename)) assert file_manager_from_basename is not None assert file_manager_from_basename.is_notebook_named except HTTPException: @@ -470,7 +474,7 @@ def __(): relative_path = os.path.join("subdir", "notebook.py") try: - file_manager = router.load(relative_path) + file_manager = router.load(PathFileKey(relative_path)) assert file_manager is not None assert file_manager.is_notebook_named except HTTPException: @@ -512,7 +516,7 @@ def test_files_list_shows_correct_paths_for_absolute_dir( # Verify each file can be opened using its relative path for file_info in files: - file_manager = router.load(file_info.path) + file_manager = router.load(PathFileKey(file_info.path)) assert file_manager is not None assert file_manager.is_notebook_named @@ -539,12 +543,12 @@ def test_absolute_path_outside_directory_denied( router = DirectoryWorkspace(absolute_dir, include_markdown=False) # Should be able to open file within the directory - file_manager = router.load(str(test_file.absolute())) + file_manager = router.load(PathFileKey(str(test_file.absolute()))) assert file_manager is not None # Should NOT be able to open file outside the directory with pytest.raises(HTTPException) as exc_info: - router.load(str(other_file.absolute())) + router.load(PathFileKey(str(other_file.absolute()))) assert exc_info.value.status_code == HTTPStatus.FORBIDDEN assert "Access denied" in exc_info.value.detail @@ -570,7 +574,7 @@ def test_absolute_path_with_symlink_attack_denied( # Try to access the secret file using absolute path with pytest.raises(HTTPException) as exc_info: - router.load(str(secret_file.absolute())) + router.load(PathFileKey(str(secret_file.absolute()))) assert exc_info.value.status_code == HTTPStatus.FORBIDDEN @@ -609,7 +613,7 @@ def test_nested_directory_returns_relative_paths( # Both should be openable using their relative paths for path in _collect_file_paths(files): - file_manager = router.load(path) + file_manager = router.load(PathFileKey(path)) assert file_manager is not None assert file_manager.is_notebook_named @@ -642,7 +646,7 @@ def test_deep_nesting_returns_correct_paths(self, tmp_path: Path) -> None: # Should be openable using the actual path from the files list actual_paths = _collect_file_paths(files) - file_manager = router.load(actual_paths[0]) + file_manager = router.load(PathFileKey(actual_paths[0])) assert file_manager is not None assert file_manager.filename == str(deep_file.absolute()) @@ -667,12 +671,12 @@ def test_both_relative_and_absolute_paths_work( assert relative_path == "notebook.py" # Open with relative path - manager_from_relative = router.load(relative_path) + manager_from_relative = router.load(PathFileKey(relative_path)) assert manager_from_relative is not None # Open with absolute path absolute_path = str(test_file.absolute()) - manager_from_absolute = router.load(absolute_path) + manager_from_absolute = router.load(PathFileKey(absolute_path)) assert manager_from_absolute is not None # Both should point to the same file @@ -695,7 +699,7 @@ def test_path_traversal_in_relative_path_blocked( # Try to access file outside using path traversal with pytest.raises(HTTPException) as exc_info: - router.load("../outside.py") + router.load(PathFileKey("../outside.py")) assert exc_info.value.status_code == HTTPStatus.FORBIDDEN @@ -720,7 +724,7 @@ def test_path_traversal_multiple_levels_blocked( for attempt in traversal_attempts: with pytest.raises(HTTPException) as exc_info: - router.load(str(attempt)) + router.load(PathFileKey(str(attempt))) assert exc_info.value.status_code == HTTPStatus.FORBIDDEN, ( f"Path traversal '{attempt}' should be blocked" ) @@ -743,7 +747,7 @@ def test_file_manager_with_dot_path(self, tmp_path: Path) -> None: ) # Try to open with ./ prefix - file_manager = router.load("./notebook.py") + file_manager = router.load(PathFileKey("./notebook.py")) assert file_manager is not None assert file_manager.filename == str(test_file.absolute()) @@ -763,7 +767,7 @@ def test_file_manager_with_redundant_slashes(self, tmp_path: Path) -> None: ) # Path normalization should handle redundant slashes - file_manager = router.load("subdir//notebook.py") + file_manager = router.load(PathFileKey("subdir//notebook.py")) assert file_manager is not None assert file_manager.filename == str(test_file.absolute()) @@ -777,7 +781,7 @@ def test_nonexistent_file_returns_not_found(self, tmp_path: Path) -> None: ) with pytest.raises(HTTPException) as exc_info: - router.load("nonexistent.py") + router.load(PathFileKey("nonexistent.py")) assert exc_info.value.status_code == HTTPStatus.NOT_FOUND @@ -827,7 +831,7 @@ def test_files_accessible_after_cwd_change(self, tmp_path: Path) -> None: os.chdir(other_dir) # Should still be able to open file using relative path - file_manager = router.load(relative_path) + file_manager = router.load(PathFileKey(relative_path)) assert file_manager is not None assert file_manager.filename == str(test_file.absolute()) diff --git a/tests/_server/test_session_manager.py b/tests/_server/test_session_manager.py index ce7f34b0ffd..c78bf633857 100644 --- a/tests/_server/test_session_manager.py +++ b/tests/_server/test_session_manager.py @@ -12,7 +12,12 @@ from marimo._server.session.listeners import RecentsTrackerListener from marimo._server.session_manager import SessionManager from marimo._server.tokens import AuthToken, SkewProtectionToken -from marimo._server.workspace import NEW_FILE, EmptyWorkspace, infer_workspace +from marimo._server.workspace import ( + EmptyWorkspace, + NewFileKey, + PathFileKey, + infer_workspace, +) from marimo._session import ( KernelManager, Session, @@ -100,7 +105,7 @@ async def test_create_session_new( session_id, mock_session_consumer, query_params={}, - file_key=NEW_FILE, + file_key=NewFileKey(), auto_instantiate=False, ) assert session_id in session_manager.sessions @@ -118,7 +123,7 @@ async def test_create_session_absolute_url( session_id, mock_session_consumer, query_params={}, - file_key=temp_marimo_file, + file_key=PathFileKey(temp_marimo_file), auto_instantiate=False, ) assert session_id in session_manager.sessions @@ -137,20 +142,20 @@ def test_maybe_resume_session_for_new_file( # Resume the same session_id with a new file -> doesn't match resumed_session = session_manager.maybe_resume_session( - session_id, NEW_FILE + session_id, NewFileKey() ) assert resumed_session is None # Resume the same session_id with a different file -> doesn't match # This is technically a bad state and should be unreachable resumed_session = session_manager.maybe_resume_session( - session_id, "different_file.py" + session_id, PathFileKey("different_file.py") ) assert resumed_session is None # Resume with a different session_id -> doesn't match resumed_session = session_manager.maybe_resume_session( - "different_session_id", NEW_FILE + "different_session_id", NewFileKey() ) assert resumed_session is None @@ -166,20 +171,20 @@ def test_maybe_resume_session_for_existing_file( # Resume the same session_id with the same file -> matches resumed_session = session_manager.maybe_resume_session( - session_id, temp_marimo_file + session_id, PathFileKey(temp_marimo_file) ) assert resumed_session is mock_session # Resume the same session_id with a different file -> doesn't match # This is technically a bad state and should be unreachable resumed_session = session_manager.maybe_resume_session( - session_id, "different_file.py" + session_id, PathFileKey("different_file.py") ) assert resumed_session is None # Resume with a different session_id -> matches resumed_session = session_manager.maybe_resume_session( - "different_session_id", temp_marimo_file + "different_session_id", PathFileKey(temp_marimo_file) ) assert resumed_session is mock_session @@ -199,8 +204,11 @@ def test_any_clients_connected_new_file( ) -> None: add_session(session_manager, session_id, mock_session) mock_session.app_file_manager = AppFileManager(filename=None) - assert session_manager.any_clients_connected(NEW_FILE) is False - assert session_manager.any_clients_connected("different_file.py") is False + assert session_manager.any_clients_connected(NewFileKey()) is False + assert ( + session_manager.any_clients_connected(PathFileKey("different_file.py")) + is False + ) def test_any_clients_connected_existing_file( @@ -210,9 +218,15 @@ def test_any_clients_connected_existing_file( ) -> None: add_session(session_manager, session_id, mock_session) mock_session.app_file_manager = AppFileManager(filename=temp_marimo_file) - assert session_manager.any_clients_connected(NEW_FILE) is False - assert session_manager.any_clients_connected(temp_marimo_file) is True - assert session_manager.any_clients_connected("different_file.py") is False + assert session_manager.any_clients_connected(NewFileKey()) is False + assert ( + session_manager.any_clients_connected(PathFileKey(temp_marimo_file)) + is True + ) + assert ( + session_manager.any_clients_connected(PathFileKey("different_file.py")) + is False + ) def test_close_all_sessions( @@ -258,7 +272,7 @@ async def test_create_session_with_script_config_overrides( session_id, mock_session_consumer, query_params={}, - file_key=str(tmp_path / "test.py"), + file_key=PathFileKey(str(tmp_path / "test.py")), auto_instantiate=False, ) assert session_id in session_manager.sessions @@ -465,7 +479,7 @@ def mock_touch(filename: str) -> None: SessionId("recents_test_session"), mock_session_consumer, query_params={}, - file_key=str(tmp_file), + file_key=PathFileKey(str(tmp_file)), auto_instantiate=False, ) diff --git a/tests/_server/test_sessions.py b/tests/_server/test_sessions.py index 8d0694ecdaa..08d689f9450 100644 --- a/tests/_server/test_sessions.py +++ b/tests/_server/test_sessions.py @@ -37,7 +37,7 @@ ) from marimo._server.session_manager import SessionManager from marimo._server.utils import initialize_asyncio -from marimo._server.workspace import SingleFileWorkspace +from marimo._server.workspace import PathFileKey, SingleFileWorkspace from marimo._session import Session from marimo._session.consumer import SessionConsumer from marimo._session.events import SessionEventBus @@ -453,7 +453,7 @@ def __(): 1 """ ) - file_key = str(tmp_file) + file_key = PathFileKey(str(tmp_file)) try: # Create a session manager with file watching enabled @@ -702,7 +702,7 @@ def __(): session_id=session_id, session_consumer=session_consumer, query_params={}, - file_key=str(tmp_file), + file_key=PathFileKey(str(tmp_file)), auto_instantiate=False, ) mock_session_view = MagicMock(spec=SessionView) @@ -811,7 +811,7 @@ def __(): session_id=session_id, session_consumer=session_consumer, query_params={}, - file_key=str(tmp_file), + file_key=PathFileKey(str(tmp_file)), auto_instantiate=False, ) @@ -907,7 +907,7 @@ def __(): session_id=session_id, session_consumer=session_consumer, query_params={}, - file_key=str(tmp_path1), + file_key=PathFileKey(str(tmp_path1)), auto_instantiate=False, ) diff --git a/tests/_server/test_workspace.py b/tests/_server/test_workspace.py index c4b10629ea9..5d4be31b356 100644 --- a/tests/_server/test_workspace.py +++ b/tests/_server/test_workspace.py @@ -13,14 +13,18 @@ from marimo._server.models.files import FileInfo from marimo._server.models.home import MarimoFile from marimo._server.workspace import ( - NEW_FILE, + NEW_FILE_WIRE, DirectoryWorkspace, EmptyWorkspace, FixedFilesWorkspace, + NewFileKey, + PathFileKey, SingleFileWorkspace, count_files, flatten_files, infer_workspace, + parse_file_key, + serialize_file_key, ) from marimo._utils.http import HTTPException, HTTPStatus from marimo._utils.marimo_path import MarimoPath @@ -96,7 +100,7 @@ def test_fixed_files_workspace(self): def test_empty_workspace(self): workspace = EmptyWorkspace() assert workspace.single_file() is None - assert workspace.get_unique_file_key() == NEW_FILE + assert workspace.get_unique_file_key() == NewFileKey() def test_directory_workspace_lists_files(self): workspace = DirectoryWorkspace(self.test_dir, include_markdown=False) @@ -124,26 +128,26 @@ def test_directory_workspace_set_include_markdown_mutates_in_place(self): def test_fixed_files_restricts_access(self): workspace = FixedFilesWorkspace([_marimo_file(self.test_file1.name)]) with pytest.raises(HTTPException) as exc: - workspace.load(self.test_file2.name) + workspace.load(PathFileKey(self.test_file2.name)) assert exc.value.status_code == HTTPStatus.NOT_FOUND def test_fixed_files_register_allowed_path_is_noop(self): workspace = FixedFilesWorkspace([_marimo_file(self.test_file1.name)]) workspace.register_allowed_path(self.test_file2.name) with pytest.raises(HTTPException) as exc: - workspace.load(self.test_file2.name) + workspace.load(PathFileKey(self.test_file2.name)) assert exc.value.status_code == HTTPStatus.NOT_FOUND def test_single_file_register_allowed_path_grows_allowlist(self): workspace = SingleFileWorkspace(_marimo_file(self.test_file1.name)) workspace.register_allowed_path(self.test_file2.name) - manager = workspace.load(self.test_file2.name) + manager = workspace.load(PathFileKey(self.test_file2.name)) assert manager is not None def test_fixed_files_disallows_new_file_key(self): workspace = FixedFilesWorkspace([_marimo_file(self.test_file1.name)]) with pytest.raises(HTTPException) as exc: - workspace.load(NEW_FILE) + workspace.load(NewFileKey()) assert exc.value.status_code == HTTPStatus.NOT_FOUND def test_fixed_files_resolve_supports_relative_key(self): @@ -154,20 +158,20 @@ def test_fixed_files_resolve_supports_relative_key(self): relative_key = str( Path(self.test_file1.name).relative_to(self.test_dir) ) - resolved = workspace.resolve(relative_key) + resolved = workspace.resolve(PathFileKey(relative_key)) assert resolved == self.test_file1.name def test_directory_workspace_load(self): workspace = DirectoryWorkspace(self.test_dir, include_markdown=False) filename = self.test_file1.name assert os.path.exists(filename), f"File {filename} does not exist" - file_manager = workspace.load(key=filename) + file_manager = workspace.load(key=PathFileKey(filename)) assert file_manager.filename == filename def test_directory_workspace_load_nested(self): workspace = DirectoryWorkspace(self.test_dir, include_markdown=False) nested_filename = self.nested_file.name - file_manager = workspace.load(key=nested_filename) + file_manager = workspace.load(key=PathFileKey(nested_filename)) assert file_manager.filename == self.nested_file.name assert file_manager.filename is not None assert os.path.exists(file_manager.filename) @@ -211,7 +215,7 @@ def test_single_file_blocks_unrelated_absolute_path(self): MarimoPath(self.test_file1.name) ) with pytest.raises(HTTPException) as exc: - workspace.resolve(self.test_file2.name) + workspace.resolve(PathFileKey(self.test_file2.name)) assert exc.value.status_code == HTTPStatus.NOT_FOUND def test_single_file_blocks_path_traversal(self): @@ -224,7 +228,7 @@ def test_single_file_blocks_path_traversal(self): self.nested_dir, "..", "..", "..", "etc", "passwd" ) with pytest.raises(HTTPException) as exc: - workspace.resolve(traversal) + workspace.resolve(PathFileKey(traversal)) assert exc.value.status_code == HTTPStatus.NOT_FOUND def test_single_file_accepts_dotdot_that_normalizes_to_allowed(self): @@ -236,7 +240,7 @@ def test_single_file_accepts_dotdot_that_normalizes_to_allowed(self): equivalent = os.path.join( self.nested_dir, "..", os.path.basename(self.test_file1.name) ) - resolved = workspace.resolve(equivalent) + resolved = workspace.resolve(PathFileKey(equivalent)) assert resolved == self.test_file1.name # ----- security: FixedFilesWorkspace ----- @@ -248,7 +252,7 @@ def test_fixed_files_blocks_path_traversal(self): directory=self.test_dir, ) with pytest.raises(HTTPException) as exc: - workspace.resolve("../../../etc/passwd") + workspace.resolve(PathFileKey("../../../etc/passwd")) assert exc.value.status_code == HTTPStatus.NOT_FOUND def test_fixed_files_blocks_absolute_path_outside_allowlist(self): @@ -260,7 +264,7 @@ def test_fixed_files_blocks_absolute_path_outside_allowlist(self): # test_file2 exists on disk inside the same directory but is not on the # allowlist — must still be denied. with pytest.raises(HTTPException) as exc: - workspace.resolve(self.test_file2.name) + workspace.resolve(PathFileKey(self.test_file2.name)) assert exc.value.status_code == HTTPStatus.NOT_FOUND def test_fixed_files_accepts_dotdot_that_normalizes_to_allowed(self): @@ -269,7 +273,7 @@ def test_fixed_files_accepts_dotdot_that_normalizes_to_allowed(self): self.nested_dir, "..", os.path.basename(self.test_file1.name) ) workspace = FixedFilesWorkspace([_marimo_file(dotdot_path)]) - resolved = workspace.resolve(dotdot_path) + resolved = workspace.resolve(PathFileKey(dotdot_path)) assert resolved is not None assert os.path.exists(resolved) @@ -277,25 +281,34 @@ def test_fixed_files_new_file_key_is_rejected(self): """``__new__`` keys are not valid in run mode.""" workspace = FixedFilesWorkspace([]) with pytest.raises(HTTPException) as exc: - workspace.resolve(NEW_FILE) + workspace.resolve(NewFileKey()) assert exc.value.status_code == HTTPStatus.NOT_FOUND # ----- security: DirectoryWorkspace ----- - def test_directory_workspace_new_file_prefix_does_not_leak(self): - """``__new__`` prefix returns None — does not bypass containment. + def test_directory_workspace_new_file_key_returns_none(self): + """``NewFileKey`` resolves to ``None`` (a blank notebook).""" + workspace = DirectoryWorkspace(self.test_dir, include_markdown=False) + assert workspace.resolve(NewFileKey()) is None + + def test_directory_workspace_only_exact_sentinel_is_new(self): + """Only the exact ``__new__`` sentinel resolves to the blank notebook. - ``startswith(NEW_FILE)`` could otherwise be coaxed into accepting - crafted keys like ``__new__../etc/passwd``; verify those still resolve - to ``None`` (a blank notebook), never an arbitrary file. + After the FileKey ADT migration, anything other than the literal + ``__new__`` parses to a :class:`PathFileKey` and goes through normal + containment validation. """ workspace = DirectoryWorkspace(self.test_dir, include_markdown=False) - for key in ( - NEW_FILE, - f"{NEW_FILE}../etc/passwd", - f"{NEW_FILE}/etc/passwd", + for raw in ( + "__new__suffix.py", + "__new__/nested.py", ): - assert workspace.resolve(key) is None + with pytest.raises(HTTPException) as exc: + workspace.resolve(PathFileKey(raw)) + assert exc.value.status_code in ( + HTTPStatus.FORBIDDEN, + HTTPStatus.NOT_FOUND, + ) def test_directory_workspace_temp_dir_does_not_enable_traversal(self): """A temp-dir bypass must not let attackers escape via ``..``.""" @@ -309,7 +322,7 @@ def test_directory_workspace_temp_dir_does_not_enable_traversal(self): # Path that traverses through the temp dir to escape entirely. traversal = os.path.join(temp_dir, "..", "..", "etc", "passwd") with pytest.raises(HTTPException) as exc: - workspace.resolve(traversal) + workspace.resolve(PathFileKey(traversal)) # The path normalizes to /etc/passwd (or equivalent), neither in # the temp dir nor in the workspace directory — must be denied. assert exc.value.status_code in ( @@ -322,14 +335,14 @@ def test_directory_workspace_temp_dir_does_not_enable_traversal(self): def test_empty_workspace_load_with_new_file_key_returns_blank(self): """Bare ``__new__`` key always yields an unbacked manager.""" workspace = EmptyWorkspace() - manager = workspace.load(NEW_FILE) + manager = workspace.load(NewFileKey()) assert manager.filename is None def test_empty_workspace_load_with_nonexistent_path_404s(self): """Concrete keys that don't exist on disk must 404.""" workspace = EmptyWorkspace() with pytest.raises(HTTPException) as exc: - workspace.load("/this/path/does/not/exist.py") + workspace.load(PathFileKey("/this/path/does/not/exist.py")) assert exc.value.status_code == HTTPStatus.NOT_FOUND def test_empty_workspace_resolve_returns_absolute_normalized(self): @@ -342,7 +355,7 @@ def test_empty_workspace_resolve_returns_absolute_normalized(self): os.chdir(self.test_dir) relative_key = os.path.basename(self.test_file1.name) workspace = EmptyWorkspace() - resolved = workspace.resolve(relative_key) + resolved = workspace.resolve(PathFileKey(relative_key)) assert resolved is not None assert os.path.isabs(resolved) assert os.path.samefile(resolved, self.test_file1.name) @@ -359,10 +372,33 @@ def test_empty_workspace_load_with_existing_path_falls_back(self): process owner. """ workspace = EmptyWorkspace() - manager = workspace.load(self.test_file1.name) + manager = workspace.load(PathFileKey(self.test_file1.name)) assert manager.filename == self.test_file1.name +def test_file_key_roundtrip() -> None: + """parse(serialize(k)) == k for both variants.""" + new_key = NewFileKey() + assert parse_file_key(serialize_file_key(new_key)) == new_key + + path_key = PathFileKey("/tmp/notebook.py") + assert parse_file_key(serialize_file_key(path_key)) == path_key + + +def test_file_key_wire_format() -> None: + """The serialized form is the wire string preserved across HTTP/WS.""" + assert serialize_file_key(NewFileKey()) == NEW_FILE_WIRE + assert serialize_file_key(PathFileKey("/foo.py")) == "/foo.py" + + +def test_parse_file_key_only_exact_sentinel_is_new() -> None: + """Strings that look like ``__new__`` prefixes are paths, not new files.""" + assert parse_file_key(NEW_FILE_WIRE) == NewFileKey() + assert parse_file_key("__new__suffix.py") == PathFileKey( + "__new__suffix.py" + ) + + def test_flatten_files() -> None: root = FileInfo( id="root", @@ -818,7 +854,7 @@ def test_lazy_router_allows_temp_dir_files(tmp_path: Path): # Without registering temp dir, accessing temp file should fail with pytest.raises(HTTPException) as exc_info: - router.load(str(temp_file)) + router.load(PathFileKey(str(temp_file))) assert exc_info.value.status_code == HTTPStatus.FORBIDDEN assert "outside the allowed directory" in exc_info.value.detail @@ -826,7 +862,7 @@ def test_lazy_router_allows_temp_dir_files(tmp_path: Path): router.register_temp_dir(str(temp_dir)) # Now accessing the temp file should succeed - manager = router.load(str(temp_file)) + manager = router.load(PathFileKey(str(temp_file))) assert manager is not None assert manager.path == str(temp_file) @@ -858,13 +894,13 @@ def test_lazy_router_temp_dir_doesnt_affect_normal_files( router.register_temp_dir(str(other_temp_dir)) # Base file should still be accessible - manager = router.load(str(base_file)) + manager = router.load(PathFileKey(str(base_file))) assert manager is not None assert manager.path == str(base_file) # Outside file should still be blocked (not in registered temp dir) with pytest.raises(HTTPException) as exc_info: - router.load(str(outside_file)) + router.load(PathFileKey(str(outside_file))) assert exc_info.value.status_code == HTTPStatus.FORBIDDEN @@ -896,5 +932,5 @@ def test_lazy_router_symlink_directory_outside_allowed(tmp_path: Path): # Symlinks are preserved (not resolved), so the path # base_dir/shared/outside.py is inside base_dir router = DirectoryWorkspace(str(base_dir), include_markdown=False) - manager = router.load(str(file_through_symlink)) + manager = router.load(PathFileKey(str(file_through_symlink))) assert manager is not None diff --git a/tests/_session/test_resume_strategies.py b/tests/_session/test_resume_strategies.py index 92202596185..5e2de2b1346 100644 --- a/tests/_session/test_resume_strategies.py +++ b/tests/_session/test_resume_strategies.py @@ -13,6 +13,7 @@ EditModeResumeStrategy, RunModeResumeStrategy, ) +from marimo._server.workspace import PathFileKey from marimo._session.model import ConnectionState from marimo._session.session_repository import SessionRepository from marimo._types.ids import SessionId @@ -32,7 +33,9 @@ def test_edit_mode_resume_no_existing_sessions() -> None: repository = SessionRepository() strategy = EditModeResumeStrategy(repository) - result = strategy.try_resume(SessionId("new-session"), "test.py") + result = strategy.try_resume( + SessionId("new-session"), PathFileKey("test.py") + ) assert result is None @@ -48,7 +51,7 @@ def test_edit_mode_resume_orphaned_session() -> None: # Try to resume new_session_id = SessionId("new-session") - result = strategy.try_resume(new_session_id, "test.py") + result = strategy.try_resume(new_session_id, PathFileKey("test.py")) # Should return the orphaned session assert result is orphaned_session @@ -70,7 +73,7 @@ def test_edit_mode_resume_open_session_not_resumed() -> None: # Try to resume new_session_id = SessionId("new-session") - result = strategy.try_resume(new_session_id, "test.py") + result = strategy.try_resume(new_session_id, PathFileKey("test.py")) # Should return None because session is not orphaned assert result is None @@ -93,7 +96,7 @@ def test_edit_mode_resume_different_file() -> None: # Try to resume for different file new_session_id = SessionId("new-session") - result = strategy.try_resume(new_session_id, "test.py") + result = strategy.try_resume(new_session_id, PathFileKey("test.py")) # Should return None because file doesn't match assert result is None @@ -115,7 +118,7 @@ def test_edit_mode_resume_multiple_sessions_raises_error() -> None: InvalidSessionException, match="Only one session should exist while editing", ): - strategy.try_resume(SessionId("new-session"), "test.py") + strategy.try_resume(SessionId("new-session"), PathFileKey("test.py")) def test_run_mode_resume_no_existing_session() -> None: @@ -123,7 +126,9 @@ def test_run_mode_resume_no_existing_session() -> None: repository = SessionRepository() strategy = RunModeResumeStrategy(repository) - result = strategy.try_resume(SessionId("non-existent"), "test.py") + result = strategy.try_resume( + SessionId("non-existent"), PathFileKey("test.py") + ) assert result is None @@ -138,7 +143,7 @@ def test_run_mode_resume_orphaned_session() -> None: repository.add_sync(session_id, orphaned_session) # Try to resume with same ID - result = strategy.try_resume(session_id, "test.py") + result = strategy.try_resume(session_id, PathFileKey("test.py")) # Should return the orphaned session assert result is orphaned_session @@ -155,7 +160,7 @@ def test_run_mode_resume_open_session_not_resumed() -> None: repository.add_sync(session_id, open_session) # Try to resume - result = strategy.try_resume(session_id, "test.py") + result = strategy.try_resume(session_id, PathFileKey("test.py")) # Should return None because session is not orphaned assert result is None @@ -175,8 +180,8 @@ def test_run_mode_allows_multiple_sessions_same_file() -> None: repository.add_sync(session2_id, session2) # Should be able to resume each by their ID - result1 = strategy.try_resume(session1_id, "test.py") - result2 = strategy.try_resume(session2_id, "test.py") + result1 = strategy.try_resume(session1_id, PathFileKey("test.py")) + result2 = strategy.try_resume(session2_id, PathFileKey("test.py")) assert result1 is session1 assert result2 is session2