Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/11766.feature.md
Comment thread
jopemachine marked this conversation as resolved.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Introduce `TusUploadSession`, a concurrency-safe upload engine for resumable (TUS) uploads that keeps session metadata and a per-session lock in Valkey (via `ValkeyTusClient`) while storing chunk payloads on the shared filesystem, so multiple storage-proxy replicas can accept chunks without corruption and without relying on filesystem lock semantics.
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""Valkey client for TUS resumable-upload session metadata."""

from .client import ValkeyTusClient

__all__ = ["ValkeyTusClient"]
116 changes: 116 additions & 0 deletions src/ai/backend/common/clients/valkey_client/valkey_tus/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
"""
Valkey client for TUS resumable-upload sessions.

Stores per-session upload metadata (the source of truth) on the Glide-based
valkey connection so that multiple Storage Proxy replicas share one view of an
upload's progress without relying on shared-filesystem semantics. This client is
metadata-only; the per-session distributed lock that serializes the
read-modify-write window is a separate :class:`DistributedLockFactory` resource
(see :mod:`ai.backend.storage.services.upload.tus_session`). The payload bytes
stay on the shared filesystem; only this small metadata lives in Valkey.
"""

from __future__ import annotations

import logging
from typing import Final, Self

from glide import ExpirySet, ExpiryType

from ai.backend.common.clients.valkey_client.client import (
AbstractValkeyClient,
create_valkey_client,
)
from ai.backend.common.exception import BackendAIError
from ai.backend.common.metrics.metric import DomainType, LayerType
from ai.backend.common.resilience import (
BackoffStrategy,
MetricArgs,
MetricPolicy,
Resilience,
RetryArgs,
RetryPolicy,
)
from ai.backend.common.types import TusSessionId, ValkeyTarget
from ai.backend.logging import BraceStyleAdapter

log = BraceStyleAdapter(logging.getLogger(__spec__.name))

valkey_tus_resilience = Resilience(
policies=[
MetricPolicy(MetricArgs(domain=DomainType.VALKEY, layer=LayerType.VALKEY_TUS)),
RetryPolicy(
RetryArgs(
max_retries=3,
retry_delay=0.1,
backoff_strategy=BackoffStrategy.FIXED,
non_retryable_exceptions=(BackendAIError,),
)
),
]
)

_STATE_KEY_PREFIX: Final = "tus.upload.session" # tus.upload.session:{session_id}

# Stale (never-completed / abandoned) sessions are reclaimed by this TTL, so no
# separate GC sweep is needed.
_DEFAULT_STATE_TTL_SECONDS: Final = 24 * 60 * 60


class ValkeyTusClient:
"""Valkey-backed metadata store for TUS uploads."""

_client: AbstractValkeyClient

def __init__(self, client: AbstractValkeyClient) -> None:
self._client = client

@classmethod
async def create(
cls,
valkey_target: ValkeyTarget,
*,
db_id: int,
human_readable_name: str,
) -> Self:
client = create_valkey_client(
valkey_target=valkey_target,
db_id=db_id,
human_readable_name=human_readable_name,
)
await client.connect()
return cls(client=client)

@valkey_tus_resilience.apply()
async def close(self) -> None:
await self._client.disconnect()

@staticmethod
def _state_key(session_id: TusSessionId) -> str:
return f"{_STATE_KEY_PREFIX}:{session_id}"

@valkey_tus_resilience.apply()
async def get_session_state(self, session_id: TusSessionId) -> bytes | None:
"""Return the raw serialized session state, or ``None`` if absent."""
async with self._client.client() as conn:
return await conn.get(self._state_key(session_id))

@valkey_tus_resilience.apply()
async def set_session_state(
self,
session_id: TusSessionId,
payload: str | bytes,
*,
ttl_seconds: int = _DEFAULT_STATE_TTL_SECONDS,
) -> None:
async with self._client.client() as conn:
await conn.set(
self._state_key(session_id),
payload,
expiry=ExpirySet(ExpiryType.SEC, ttl_seconds),
)

@valkey_tus_resilience.apply()
async def delete_session_state(self, session_id: TusSessionId) -> None:
async with self._client.client() as conn:
await conn.delete([self._state_key(session_id)])
1 change: 1 addition & 0 deletions src/ai/backend/common/defs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
REDIS_STREAM_LOCK: Final = 5
REDIS_CONTAINER_LOG: Final = 6
REDIS_BGTASK_DB: Final = 7
REDIS_TUS_DB: Final = 8


class RedisRole(StrEnum):
Expand Down
7 changes: 6 additions & 1 deletion src/ai/backend/common/lock.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from collections.abc import Mapping
from io import IOBase
from pathlib import Path
from typing import Any, ClassVar
from typing import Any, ClassVar, Protocol, runtime_checkable

import trafaret as t
from etcd_client import Client as EtcdClient
Expand Down Expand Up @@ -52,6 +52,11 @@ async def __aexit__(self, *exc_info: Any) -> bool | None:
raise NotImplementedError


@runtime_checkable
class DistributedLockFactory(Protocol):
def __call__(self, lock_id: str, lifetime_hint: float) -> AbstractDistributedLock: ...


class FileLock(AbstractDistributedLock):
default_timeout: float = 3 # not allow infinite timeout for safety

Expand Down
1 change: 1 addition & 0 deletions src/ai/backend/common/metrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,7 @@ class LayerType(enum.StrEnum):
VALKEY_STREAM = "valkey_stream"
VALKEY_BGTASK = "valkey_bgtask"
VALKEY_VOLUME_STATS = "valkey_volume_stats"
VALKEY_TUS = "valkey_tus"

# Client layers
AGENT_CLIENT = "agent_client"
Expand Down
1 change: 1 addition & 0 deletions src/ai/backend/common/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,7 @@ def check_typed_tuple(value: tuple[Any, ...], types: tuple[type, ...]) -> tuple[
ContainerPID = NewType("ContainerPID", PID)

ContainerId = NewType("ContainerId", str)
TusSessionId = NewType("TusSessionId", str)
RuleId = NewType("RuleId", UUID)
SessionId = NewType("SessionId", UUID)
KernelId = NewType("KernelId", UUID)
Expand Down
4 changes: 4 additions & 0 deletions src/ai/backend/storage/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@
from ai.backend.common.clients.valkey_client.valkey_artifact.client import (
ValkeyArtifactDownloadTrackingClient,
)
from ai.backend.common.clients.valkey_client.valkey_tus import ValkeyTusClient
from ai.backend.common.etcd import AsyncEtcd
from ai.backend.common.events.dispatcher import (
EventDispatcher,
EventProducer,
)
from ai.backend.common.health_checker.probe import HealthProbe
from ai.backend.common.lock import DistributedLockFactory
from ai.backend.common.metrics.metric import CommonMetricRegistry
from ai.backend.logging import BraceStyleAdapter

Expand Down Expand Up @@ -106,6 +108,8 @@ class RootContext:
cors_options: Mapping[str, aiohttp_cors.ResourceOptions]
manager_client_pool: ManagerHTTPClientPool
valkey_artifact_client: ValkeyArtifactDownloadTrackingClient
valkey_tus_client: ValkeyTusClient
tus_lock_factory: DistributedLockFactory
health_probe: HealthProbe
volume_stats_observer: VolumeStatsObserver
volume_stats_state: VolumeState
Expand Down
7 changes: 7 additions & 0 deletions src/ai/backend/storage/errors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@
QuotaScopeNotFoundError,
QuotaTreeNotFoundError,
)
from .upload import (
ChunkConflictError,
UploadSessionCorruptedError,
)
from .vfolder import (
InvalidSubpathError,
VFolderNotFoundError,
Expand Down Expand Up @@ -90,6 +94,9 @@
"InvalidDataLengthError",
"ServiceNotInitializedError",
"UploadOffsetMismatchError",
# upload
"ChunkConflictError",
"UploadSessionCorruptedError",
# vfolder
"VFolderNotFoundError",
"InvalidSubpathError",
Expand Down
49 changes: 49 additions & 0 deletions src/ai/backend/storage/errors/upload.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
"""
Upload session related exceptions.
"""

from __future__ import annotations

from aiohttp import web

from ai.backend.common.exception import (
BackendAIError,
ErrorCode,
ErrorDetail,
ErrorDomain,
ErrorOperation,
)


class ChunkConflictError(BackendAIError, web.HTTPConflict):
"""
Raised when an incoming chunk targets an offset that already holds a
different chunk in the upload session (409 Conflict).
"""

error_type = "https://api.backend.ai/probs/storage/chunk-conflict"
error_title = "Upload Chunk Conflict"

def error_code(self) -> ErrorCode:
return ErrorCode(
domain=ErrorDomain.STORAGE_PROXY,
operation=ErrorOperation.UPDATE,
error_detail=ErrorDetail.CONFLICT,
)


class UploadSessionCorruptedError(BackendAIError, web.HTTPInternalServerError):
"""
Raised when the on-disk upload session metadata cannot be parsed or is
structurally invalid.
"""

error_type = "https://api.backend.ai/probs/storage/upload-session-corrupted"
error_title = "Upload Session Corrupted"

def error_code(self) -> ErrorCode:
return ErrorCode(
domain=ErrorDomain.STORAGE_PROXY,
operation=ErrorOperation.READ,
error_detail=ErrorDetail.INTERNAL_ERROR,
)
2 changes: 2 additions & 0 deletions src/ai/backend/storage/migration.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,8 @@ async def check_and_upgrade(
cors_options={},
manager_client_pool=manager_client_pool,
valkey_artifact_client=None, # type: ignore[arg-type]
valkey_tus_client=None, # type: ignore[arg-type]
tus_lock_factory=None, # type: ignore[arg-type]
backends={**DEFAULT_BACKENDS},
volumes={},
health_probe=health_probe,
Expand Down
25 changes: 25 additions & 0 deletions src/ai/backend/storage/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,13 @@
from aiohttp.typedefs import Middleware
from setproctitle import setproctitle

from ai.backend.common import redis_helper
from ai.backend.common.bgtask.bgtask import BackgroundTaskManager
from ai.backend.common.clients.valkey_client.valkey_artifact.client import (
ValkeyArtifactDownloadTrackingClient,
)
from ai.backend.common.clients.valkey_client.valkey_bgtask.client import ValkeyBgtaskClient
from ai.backend.common.clients.valkey_client.valkey_tus import ValkeyTusClient
from ai.backend.common.clients.valkey_client.valkey_volume_stats import ValkeyVolumeStatsClient
from ai.backend.common.config import (
ConfigurationError,
Expand All @@ -42,6 +44,8 @@
REDIS_BGTASK_DB,
REDIS_STATISTICS_DB,
REDIS_STREAM_DB,
REDIS_STREAM_LOCK,
REDIS_TUS_DB,
RedisRole,
)
from ai.backend.common.etcd import AsyncEtcd
Expand Down Expand Up @@ -111,6 +115,7 @@
StorageManagerWebappPluginContext,
StoragePluginContext,
)
from .services.upload.lock import create_tus_lock_factory
from .storages.storage_pool import StoragePool
from .volumes.noop import init_noop_volume
from .volumes.pool import VolumePool
Expand Down Expand Up @@ -644,6 +649,24 @@ async def server_main(
)
storage_init_stack.push_async_callback(valkey_artifact_client.close)

# TUS upload session metadata store (Valkey) and the per-session lock
# factory, kept as separate resources. The lock backend is a dedicated
# redis-for-lock connection (redis.asyncio, required by RedisLock),
# provisioned and owned here; the factory closes over it.
valkey_tus_client = await ValkeyTusClient.create(
valkey_target=valkey_target,
db_id=REDIS_TUS_DB,
human_readable_name=f"storage-proxy-tus-uploader-{pidx}",
)
storage_init_stack.push_async_callback(valkey_tus_client.close)
tus_lock_redis = redis_helper.get_redis_object_for_lock(
redis_config.to_redis_profile_target().profile_target(RedisRole.STREAM_LOCK),
name=f"storage-proxy-tus-lock-{pidx}",
db=REDIS_STREAM_LOCK,
)
storage_init_stack.push_async_callback(tus_lock_redis.close)
tus_lock_factory = create_tus_lock_factory(tus_lock_redis)

# Initialize health probe
health_probe = HealthProbe(options=HealthProbeOptions(check_interval=60))
# Liveness-registered: also surfaced in readiness β€” connection-stuck
Expand Down Expand Up @@ -724,6 +747,8 @@ async def server_main(
},
manager_client_pool=manager_client_pool,
valkey_artifact_client=valkey_artifact_client,
valkey_tus_client=valkey_tus_client,
tus_lock_factory=tus_lock_factory,
health_probe=health_probe,
volume_stats_observer=volume_stats_observer,
volume_stats_state=volume_stats_state,
Expand Down
3 changes: 3 additions & 0 deletions src/ai/backend/storage/services/upload/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
"""
Chunk-based TUS upload session services.
"""
37 changes: 37 additions & 0 deletions src/ai/backend/storage/services/upload/lock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
"""TUS upload session locking β€” distributed lock factory and constants."""

from __future__ import annotations

from ai.backend.common.lock import DistributedLockFactory, RedisLock
from ai.backend.common.types import RedisConnectionInfo

# Public β€” used by the engine to build session lock keys and pass lifetime to
# the factory.
LOCK_KEY_PREFIX = "tus.upload.lock" # tus.upload.lock:{session_id}
LOCK_LIFETIME_SECONDS = 30.0 # lock auto-expires after this (crash safety)

_ACQUIRE_TIMEOUT_SECONDS = 10.0 # max wait to acquire the per-session lock
# Poll interval while waiting for the lock; the RedisLock default (1s) is far
# too coarse for the short, highly-contended per-chunk critical section.
_RETRY_INTERVAL_SECONDS = 0.05


def create_tus_lock_factory(redis: RedisConnectionInfo) -> DistributedLockFactory:
"""
Build the per-session lock factory backed by :class:`RedisLock` over ``redis``.

Mirrors the manager's ``create_lock_factory``: the caller owns ``redis``
(lifecycle/close) and the returned factory closes over it, producing a fresh
lock per ``lock_id`` so the factory can live as a standalone resource.
"""

def _factory(lock_id: str, lifetime_hint: float) -> RedisLock:
return RedisLock(
lock_id,
redis,
timeout=_ACQUIRE_TIMEOUT_SECONDS,
lifetime=lifetime_hint,
lock_retry_interval=_RETRY_INTERVAL_SECONDS,
)

return _factory
Loading
Loading