diff --git a/changes/11769.feature.md b/changes/11769.feature.md new file mode 100644 index 00000000000..f421b889b9b --- /dev/null +++ b/changes/11769.feature.md @@ -0,0 +1 @@ +Support the TUS Checksum extension on storage proxy uploads: clients may now send `Upload-Checksum: sha256 <base64-digest>` and the server rejects mismatched chunks with HTTP 460. diff --git a/src/ai/backend/storage/api/client.py b/src/ai/backend/storage/api/client.py index d648cbbb82a..b17b7d9bac9 100644 --- a/src/ai/backend/storage/api/client.py +++ b/src/ai/backend/storage/api/client.py @@ -5,6 +5,8 @@ from __future__ import annotations import asyncio +import base64 +import binascii import logging import os import urllib.parse @@ -47,7 +49,9 @@ from ai.backend.storage import __version__ from ai.backend.storage.dto.context import StorageRootCtx from ai.backend.storage.errors import ( + ChunkChecksumMismatchError, InvalidAPIParameters, + InvalidUploadChecksumHeaderError, TusSessionNotFoundError, UploadChunkExceedsTotalSizeError, UploadOffsetMismatchError, @@ -400,6 +404,10 @@ class Params(TypedDict): f"Upload-Offset {client_offset} is out of range [0, {total_size}]" ) + expected_checksum_hex = _parse_upload_checksum_header( + request.headers.get("Upload-Checksum") + ) + async with ctx.get_volume(token_data["volume"]) as volume: session_dir = ( volume.mangle_vfpath(token_data["vfid"]) / ".upload" / token_data["session"] @@ -426,6 +434,10 @@ class Params(TypedDict): f"Chunk at offset {client_offset} with length {written.length} " f"exceeds declared size {total_size}" ) + if expected_checksum_hex is not None and expected_checksum_hex != written.sha256: + raise ChunkChecksumMismatchError( + f"Chunk at offset {client_offset} failed SHA-256 verification" + ) commit_result = await session.commit_chunk( offset=client_offset, chunk_path=written.path, @@ -453,7 +465,9 @@ class Params(TypedDict): return web.Response(status=HTTPStatus.NO_CONTENT, headers=headers) -_TUS_HEADER_LIST = "Tus-Resumable, Upload-Length, 
Upload-Metadata, Upload-Offset, Content-Type" +_TUS_HEADER_LIST = ( + "Tus-Resumable, Upload-Length, Upload-Metadata, Upload-Offset, Upload-Checksum, Content-Type" +) def _prepare_tus_session_headers(*, upload_offset: int, upload_length: int) -> dict[str, str]: @@ -469,6 +483,37 @@ def _prepare_tus_session_headers(*, upload_offset: int, upload_length: int) -> d } +def _parse_upload_checksum_header(raw: str | None) -> str | None: + """ + Parse a TUS Checksum extension header value into a hex SHA-256 digest. + + Returns ``None`` if no header was sent. Raises + ``InvalidUploadChecksumHeaderError`` for malformed values or unsupported + algorithms (only ``sha256`` is accepted). + """ + if raw is None: + return None + parts = raw.strip().split(None, 1) + if len(parts) != 2: + raise InvalidUploadChecksumHeaderError( + f"Upload-Checksum must be ' ', got {raw!r}" + ) + algorithm, encoded = parts + if algorithm.lower() != "sha256": + raise InvalidUploadChecksumHeaderError( + f"Unsupported checksum algorithm: {algorithm}. Only 'sha256' is accepted." + ) + try: + digest = base64.b64decode(encoded, validate=True) + except (binascii.Error, ValueError) as e: + raise InvalidUploadChecksumHeaderError( + f"Invalid base64 in Upload-Checksum: {encoded!r}" + ) from e + if len(digest) != 32: + raise InvalidUploadChecksumHeaderError(f"sha256 digest must be 32 bytes, got {len(digest)}") + return digest.hex() + + async def tus_options(request: web.Request) -> web.Response: """ Let clients discover the supported features of our tus.io server-side implementation. 
@@ -476,15 +521,13 @@ async def tus_options(request: web.Request) -> web.Response: ctx: RootContext = request.app["ctx"] headers = {} headers["Access-Control-Allow-Origin"] = "*" - headers["Access-Control-Allow-Headers"] = ( - "Tus-Resumable, Upload-Length, Upload-Metadata, Upload-Offset, Content-Type" - ) - headers["Access-Control-Expose-Headers"] = ( - "Tus-Resumable, Upload-Length, Upload-Metadata, Upload-Offset, Content-Type" - ) + headers["Access-Control-Allow-Headers"] = _TUS_HEADER_LIST + headers["Access-Control-Expose-Headers"] = _TUS_HEADER_LIST headers["Access-Control-Allow-Methods"] = "*" headers["Tus-Resumable"] = "1.0.0" headers["Tus-Version"] = "1.0.0" + headers["Tus-Extension"] = "checksum" + headers["Tus-Checksum-Algorithm"] = "sha256" headers["Tus-Max-Size"] = str( int(BinarySize.from_str(ctx.local_config.storage_proxy.max_upload_size)), ) diff --git a/src/ai/backend/storage/errors/__init__.py b/src/ai/backend/storage/errors/__init__.py index 47a76c8e901..35cac077496 100644 --- a/src/ai/backend/storage/errors/__init__.py +++ b/src/ai/backend/storage/errors/__init__.py @@ -63,7 +63,9 @@ QuotaTreeNotFoundError, ) from .upload import ( + ChunkChecksumMismatchError, ChunkConflictError, + InvalidUploadChecksumHeaderError, TusSessionNotFoundError, UploadChunkExceedsTotalSizeError, UploadSessionCorruptedError, @@ -97,10 +99,12 @@ "ServiceNotInitializedError", "UploadOffsetMismatchError", # upload + "ChunkChecksumMismatchError", "ChunkConflictError", + "InvalidUploadChecksumHeaderError", + "TusSessionNotFoundError", "UploadChunkExceedsTotalSizeError", "UploadSessionCorruptedError", - "TusSessionNotFoundError", # vfolder "VFolderNotFoundError", "InvalidSubpathError", diff --git a/src/ai/backend/storage/errors/upload.py b/src/ai/backend/storage/errors/upload.py index 48c6baf1d95..48ce2bd7d7b 100644 --- a/src/ai/backend/storage/errors/upload.py +++ b/src/ai/backend/storage/errors/upload.py @@ -82,3 +82,46 @@ def error_code(self) -> ErrorCode: 
class _ChecksumMismatch(web.HTTPClientError):
    """
    Client-error response carrying HTTP status 460.

    The TUS Checksum extension assigns 460 to checksum mismatches; aiohttp
    ships no ready-made exception class for that code, hence this one.
    """

    status_code = 460


class ChunkChecksumMismatchError(BackendAIError, _ChecksumMismatch):
    """
    Signals that the SHA-256 digest of a received chunk body differs from the
    one announced in the ``Upload-Checksum`` header (HTTP 460, as defined by
    the TUS Checksum extension).
    """

    error_type = "https://api.backend.ai/probs/storage/chunk-checksum-mismatch"
    error_title = "Upload Chunk Checksum Mismatch"

    def error_code(self) -> ErrorCode:
        # Storage-proxy domain, update operation, mismatch detail.
        mismatch_code = ErrorCode(
            domain=ErrorDomain.STORAGE_PROXY,
            operation=ErrorOperation.UPDATE,
            error_detail=ErrorDetail.MISMATCH,
        )
        return mismatch_code
def _sha256_b64(data: bytes) -> str:
    """Return the Base64 text form of the SHA-256 digest of *data*."""
    raw_digest = hashlib.sha256(data).digest()
    return base64.b64encode(raw_digest).decode("ascii")


class TestUploadChecksum:
    """Exercise ``Upload-Checksum`` handling through ``tus_upload_part``."""

    async def test_matching_checksum_accepted(
        self,
        patch_env: _PatchEnv,
        valkey_tus_client: ValkeyTusClient,
        tus_lock_factory: DistributedLockFactory,
    ) -> None:
        """A correct digest lets the upload finish and the file be written."""
        body = b"X" * 1024
        good_header = f"sha256 {_sha256_b64(body)}"
        params_patch = _patch_handler_params(
            _token_data(session_id=patch_env.session_id, total_size=1024, relpath="result.bin")
        )
        try:
            response = await tus_upload_part(
                _build_request(
                    vfpath=patch_env.vfpath,
                    session_id=patch_env.session_id,
                    total_size=1024,
                    body=body,
                    offset_header="0",
                    valkey_client=valkey_tus_client,
                    lock_factory=tus_lock_factory,
                    checksum_header=good_header,
                )
            )
        finally:
            params_patch.stop()

        assert response.headers["Upload-Offset"] == "1024"
        uploaded = patch_env.vfpath / "result.bin"
        assert uploaded.read_bytes() == body

    async def test_mismatched_checksum_rejected(
        self,
        patch_env: _PatchEnv,
        valkey_tus_client: ValkeyTusClient,
        tus_lock_factory: DistributedLockFactory,
    ) -> None:
        """A digest computed over other data makes the handler raise."""
        body = b"X" * 1024
        bad_header = f"sha256 {_sha256_b64(b'different-payload')}"
        params_patch = _patch_handler_params(
            _token_data(session_id=patch_env.session_id, total_size=1024, relpath="result.bin")
        )
        try:
            request = _build_request(
                vfpath=patch_env.vfpath,
                session_id=patch_env.session_id,
                total_size=1024,
                body=body,
                offset_header="0",
                valkey_client=valkey_tus_client,
                lock_factory=tus_lock_factory,
                checksum_header=bad_header,
            )
            with pytest.raises(ChunkChecksumMismatchError):
                await tus_upload_part(request)
        finally:
            params_patch.stop()

        # Nothing from the rejected chunk may remain: no final file and no
        # entries under the session's chunks directory.
        target = patch_env.vfpath / "result.bin"
        assert not target.exists()
        chunks_dir = patch_env.session_dir / "chunks"
        leftovers = list(chunks_dir.glob("*")) if chunks_dir.exists() else []
        assert leftovers == []

    @pytest.mark.parametrize(
        "header",
        [
            "sha256",  # missing digest
            "sha1 abc",  # unsupported algorithm
            "sha256 not_base64!!",  # invalid base64
            "sha256 " + base64.b64encode(b"too-short").decode("ascii"),  # wrong length
        ],
    )
    async def test_malformed_checksum_header_rejected(
        self,
        patch_env: _PatchEnv,
        valkey_tus_client: ValkeyTusClient,
        tus_lock_factory: DistributedLockFactory,
        header: str,
    ) -> None:
        """Each malformed header value is refused with the header error."""
        params_patch = _patch_handler_params(
            _token_data(session_id=patch_env.session_id, total_size=1024, relpath="result.bin")
        )
        try:
            request = _build_request(
                vfpath=patch_env.vfpath,
                session_id=patch_env.session_id,
                total_size=1024,
                body=b"X" * 1024,
                offset_header="0",
                valkey_client=valkey_tus_client,
                lock_factory=tus_lock_factory,
                checksum_header=header,
            )
            with pytest.raises(InvalidUploadChecksumHeaderError):
                await tus_upload_part(request)
        finally:
            params_patch.stop()