Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/11769.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Support the TUS Checksum extension on storage proxy uploads: clients may now send `Upload-Checksum: sha256 <base64>` and the server rejects mismatched chunks with HTTP 460.
57 changes: 50 additions & 7 deletions src/ai/backend/storage/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from __future__ import annotations

import asyncio
import base64
import binascii
import logging
import os
import urllib.parse
Expand Down Expand Up @@ -47,7 +49,9 @@
from ai.backend.storage import __version__
from ai.backend.storage.dto.context import StorageRootCtx
from ai.backend.storage.errors import (
ChunkChecksumMismatchError,
InvalidAPIParameters,
InvalidUploadChecksumHeaderError,
TusSessionNotFoundError,
UploadChunkExceedsTotalSizeError,
UploadOffsetMismatchError,
Expand Down Expand Up @@ -400,6 +404,10 @@ class Params(TypedDict):
f"Upload-Offset {client_offset} is out of range [0, {total_size}]"
)

expected_checksum_hex = _parse_upload_checksum_header(
request.headers.get("Upload-Checksum")
)

async with ctx.get_volume(token_data["volume"]) as volume:
session_dir = (
volume.mangle_vfpath(token_data["vfid"]) / ".upload" / token_data["session"]
Expand All @@ -426,6 +434,10 @@ class Params(TypedDict):
f"Chunk at offset {client_offset} with length {written.length} "
f"exceeds declared size {total_size}"
)
if expected_checksum_hex is not None and expected_checksum_hex != written.sha256:
raise ChunkChecksumMismatchError(
f"Chunk at offset {client_offset} failed SHA-256 verification"
)
commit_result = await session.commit_chunk(
offset=client_offset,
chunk_path=written.path,
Expand Down Expand Up @@ -453,7 +465,9 @@ class Params(TypedDict):
return web.Response(status=HTTPStatus.NO_CONTENT, headers=headers)


_TUS_HEADER_LIST = "Tus-Resumable, Upload-Length, Upload-Metadata, Upload-Offset, Content-Type"
# Comma-separated list of the TUS-related header names this API deals with.
# tus_options() sends this one string as both Access-Control-Allow-Headers and
# Access-Control-Expose-Headers, so the two CORS lists cannot drift apart.
# It includes Upload-Checksum, the header read by the chunk upload handler.
_TUS_HEADER_LIST = (
"Tus-Resumable, Upload-Length, Upload-Metadata, Upload-Offset, Upload-Checksum, Content-Type"
)


def _prepare_tus_session_headers(*, upload_offset: int, upload_length: int) -> dict[str, str]:
Expand All @@ -469,22 +483,51 @@ def _prepare_tus_session_headers(*, upload_offset: int, upload_length: int) -> d
}


def _parse_upload_checksum_header(raw: str | None) -> str | None:
"""
Parse a TUS Checksum extension header value into a hex SHA-256 digest.

Returns ``None`` if no header was sent. Raises
``InvalidUploadChecksumHeaderError`` for malformed values or unsupported
algorithms (only ``sha256`` is accepted).
"""
if raw is None:
return None
parts = raw.strip().split(None, 1)
if len(parts) != 2:
raise InvalidUploadChecksumHeaderError(
f"Upload-Checksum must be '<algorithm> <base64>', got {raw!r}"
)
algorithm, encoded = parts
if algorithm.lower() != "sha256":
raise InvalidUploadChecksumHeaderError(
f"Unsupported checksum algorithm: {algorithm}. Only 'sha256' is accepted."
)
try:
digest = base64.b64decode(encoded, validate=True)
except (binascii.Error, ValueError) as e:
raise InvalidUploadChecksumHeaderError(
f"Invalid base64 in Upload-Checksum: {encoded!r}"
) from e
if len(digest) != 32:
raise InvalidUploadChecksumHeaderError(f"sha256 digest must be 32 bytes, got {len(digest)}")
return digest.hex()


async def tus_options(request: web.Request) -> web.Response:
"""
Let clients discover the supported features of our tus.io server-side implementation.
"""
ctx: RootContext = request.app["ctx"]
headers = {}
headers["Access-Control-Allow-Origin"] = "*"
headers["Access-Control-Allow-Headers"] = (
"Tus-Resumable, Upload-Length, Upload-Metadata, Upload-Offset, Content-Type"
)
headers["Access-Control-Expose-Headers"] = (
"Tus-Resumable, Upload-Length, Upload-Metadata, Upload-Offset, Content-Type"
)
headers["Access-Control-Allow-Headers"] = _TUS_HEADER_LIST
headers["Access-Control-Expose-Headers"] = _TUS_HEADER_LIST
headers["Access-Control-Allow-Methods"] = "*"
headers["Tus-Resumable"] = "1.0.0"
headers["Tus-Version"] = "1.0.0"
headers["Tus-Extension"] = "checksum"
headers["Tus-Checksum-Algorithm"] = "sha256"
headers["Tus-Max-Size"] = str(
int(BinarySize.from_str(ctx.local_config.storage_proxy.max_upload_size)),
)
Expand Down
6 changes: 5 additions & 1 deletion src/ai/backend/storage/errors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,9 @@
QuotaTreeNotFoundError,
)
from .upload import (
ChunkChecksumMismatchError,
ChunkConflictError,
InvalidUploadChecksumHeaderError,
TusSessionNotFoundError,
UploadChunkExceedsTotalSizeError,
UploadSessionCorruptedError,
Expand Down Expand Up @@ -97,10 +99,12 @@
"ServiceNotInitializedError",
"UploadOffsetMismatchError",
# upload
"ChunkChecksumMismatchError",
"ChunkConflictError",
"InvalidUploadChecksumHeaderError",
"TusSessionNotFoundError",
"UploadChunkExceedsTotalSizeError",
"UploadSessionCorruptedError",
"TusSessionNotFoundError",
# vfolder
"VFolderNotFoundError",
"InvalidSubpathError",
Expand Down
43 changes: 43 additions & 0 deletions src/ai/backend/storage/errors/upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,46 @@ def error_code(self) -> ErrorCode:
operation=ErrorOperation.READ,
error_detail=ErrorDetail.INTERNAL_ERROR,
)


class _ChecksumMismatch(web.HTTPClientError):
    """
    Client-error HTTP exception carrying the status code 460.

    The TUS Checksum extension assigns HTTP 460 to checksum mismatches.
    aiohttp provides no built-in exception class for that code, so it is
    declared here.
    """

    status_code = 460


class ChunkChecksumMismatchError(BackendAIError, _ChecksumMismatch):
    """
    Signals that the SHA-256 digest of a received chunk body differs from the
    digest announced in the ``Upload-Checksum`` header.

    Carries HTTP 460, the status the TUS Checksum extension defines for
    checksum mismatches.
    """

    error_type = "https://api.backend.ai/probs/storage/chunk-checksum-mismatch"
    error_title = "Upload Chunk Checksum Mismatch"

    def error_code(self) -> ErrorCode:
        """Return the storage-proxy / update / mismatch error code."""
        code = ErrorCode(
            domain=ErrorDomain.STORAGE_PROXY,
            operation=ErrorOperation.UPDATE,
            error_detail=ErrorDetail.MISMATCH,
        )
        return code


class InvalidUploadChecksumHeaderError(BackendAIError, web.HTTPBadRequest):
    """
    Signals an ``Upload-Checksum`` header that cannot be used: either it is
    malformed or it names an algorithm other than ``sha256``, the only one
    accepted.
    """

    error_type = "https://api.backend.ai/probs/storage/invalid-upload-checksum"
    error_title = "Invalid Upload-Checksum header"

    def error_code(self) -> ErrorCode:
        """Return the storage-proxy / request / invalid-parameters error code."""
        code = ErrorCode(
            domain=ErrorDomain.STORAGE_PROXY,
            operation=ErrorOperation.REQUEST,
            error_detail=ErrorDetail.INVALID_PARAMETERS,
        )
        return code
111 changes: 111 additions & 0 deletions tests/unit/storage/api/test_tus_upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@

from __future__ import annotations

import base64
import dataclasses
import hashlib
import secrets
from collections.abc import AsyncIterator
from pathlib import Path
Expand All @@ -29,7 +31,9 @@
from ai.backend.common.types import RedisConnectionInfo, TusSessionId, ValkeyTarget
from ai.backend.storage.api.client import tus_upload_part
from ai.backend.storage.errors import (
ChunkChecksumMismatchError,
InvalidAPIParameters,
InvalidUploadChecksumHeaderError,
UploadChunkExceedsTotalSizeError,
UploadOffsetMismatchError,
)
Expand Down Expand Up @@ -88,6 +92,7 @@ def _build_request(
offset_header: str | None,
valkey_client: ValkeyTusClient,
lock_factory: DistributedLockFactory,
checksum_header: str | None = None,
) -> MagicMock:
volume = MagicMock()
volume.mangle_vfpath.return_value = vfpath
Expand All @@ -104,6 +109,8 @@ def _build_request(
request.headers = {}
if offset_header is not None:
request.headers["Upload-Offset"] = offset_header
if checksum_header is not None:
request.headers["Upload-Checksum"] = checksum_header
request.query = {"token": "test-token"}

if body is None:
Expand Down Expand Up @@ -441,3 +448,107 @@ async def test_chunk_exceeding_declared_size_raises(
chunks_dir = patch_env.session_dir / "chunks"
if chunks_dir.exists():
assert list(chunks_dir.glob("*.tmp")) == []


def _sha256_b64(data: bytes) -> str:
return base64.b64encode(hashlib.sha256(data).digest()).decode("ascii")


class TestUploadChecksum:
    """Checks how ``tus_upload_part`` handles the ``Upload-Checksum`` header."""

    async def test_matching_checksum_accepted(
        self,
        patch_env: _PatchEnv,
        valkey_tus_client: ValkeyTusClient,
        tus_lock_factory: DistributedLockFactory,
    ) -> None:
        """A chunk sent with its correct SHA-256 digest is stored in full."""
        chunk_body = b"X" * 1024
        checksum_value = f"sha256 {_sha256_b64(chunk_body)}"
        token_data = _token_data(
            session_id=patch_env.session_id,
            total_size=1024,
            relpath="result.bin",
        )
        params_patch = _patch_handler_params(token_data)
        try:
            request = _build_request(
                vfpath=patch_env.vfpath,
                session_id=patch_env.session_id,
                total_size=1024,
                body=chunk_body,
                offset_header="0",
                valkey_client=valkey_tus_client,
                lock_factory=tus_lock_factory,
                checksum_header=checksum_value,
            )
            response = await tus_upload_part(request)
        finally:
            params_patch.stop()

        target_file = patch_env.vfpath / "result.bin"
        assert response.headers["Upload-Offset"] == "1024"
        assert target_file.read_bytes() == chunk_body

    async def test_mismatched_checksum_rejected(
        self,
        patch_env: _PatchEnv,
        valkey_tus_client: ValkeyTusClient,
        tus_lock_factory: DistributedLockFactory,
    ) -> None:
        """A chunk sent with the digest of other data is refused and not kept."""
        chunk_body = b"X" * 1024
        bogus_digest = _sha256_b64(b"different-payload")
        token_data = _token_data(
            session_id=patch_env.session_id,
            total_size=1024,
            relpath="result.bin",
        )
        params_patch = _patch_handler_params(token_data)
        try:
            request = _build_request(
                vfpath=patch_env.vfpath,
                session_id=patch_env.session_id,
                total_size=1024,
                body=chunk_body,
                offset_header="0",
                valkey_client=valkey_tus_client,
                lock_factory=tus_lock_factory,
                checksum_header=f"sha256 {bogus_digest}",
            )
            with pytest.raises(ChunkChecksumMismatchError):
                await tus_upload_part(request)
        finally:
            params_patch.stop()

        # The mismatched chunk must not be committed.
        target_file = patch_env.vfpath / "result.bin"
        assert not target_file.exists()
        chunks_dir = patch_env.session_dir / "chunks"
        assert not chunks_dir.exists() or list(chunks_dir.glob("*")) == []

    @pytest.mark.parametrize(
        "header",
        [
            "sha256",  # missing digest
            "sha1 abc",  # unsupported algorithm
            "sha256 not_base64!!",  # invalid base64
            f"sha256 {base64.b64encode(b'too-short').decode('ascii')}",  # wrong length
        ],
    )
    async def test_malformed_checksum_header_rejected(
        self,
        patch_env: _PatchEnv,
        valkey_tus_client: ValkeyTusClient,
        tus_lock_factory: DistributedLockFactory,
        header: str,
    ) -> None:
        """Each unusable ``Upload-Checksum`` value is refused as invalid."""
        token_data = _token_data(
            session_id=patch_env.session_id,
            total_size=1024,
            relpath="result.bin",
        )
        params_patch = _patch_handler_params(token_data)
        try:
            request = _build_request(
                vfpath=patch_env.vfpath,
                session_id=patch_env.session_id,
                total_size=1024,
                body=b"X" * 1024,
                offset_header="0",
                valkey_client=valkey_tus_client,
                lock_factory=tus_lock_factory,
                checksum_header=header,
            )
            with pytest.raises(InvalidUploadChecksumHeaderError):
                await tus_upload_part(request)
        finally:
            params_patch.stop()
Loading