Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 100 additions & 1 deletion tests/unit/test_agent_session_messages.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,28 @@
from __future__ import annotations

import contextlib
import uuid
from collections.abc import AsyncIterator
from types import SimpleNamespace
from typing import Any, cast
from unittest.mock import AsyncMock, MagicMock, Mock

import orjson
import pytest
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from tracecat.agent.executor.loopback import (
LoopbackHandler,
LoopbackInput,
_session_line_db_content,
_session_line_from_json,
_session_line_jsonb_safe_content,
)
from tracecat.agent.session.service import AgentSessionService
from tracecat.auth.types import Role
from tracecat.chat.enums import MessageKind
from tracecat.db.models import AgentSession
from tracecat.db.models import AgentSession, AgentSessionHistory


def _mock_scalar_result(items: list[Any]) -> Mock:
Expand Down Expand Up @@ -46,6 +57,94 @@ def _build_service() -> tuple[AgentSessionService, AgentSession]:
return service, agent_session


def test_session_line_db_content_sanitizes_nul_only_for_retry() -> None:
    """The first-attempt payload keeps NULs; only the retry payload escapes them."""
    # JSON text whose \u0000 escapes appear in a key, in message content and
    # in a nested tool result.
    raw_json = (
        r'{"type":"user","uuid":"line-uuid","bad\u0000key":"v",'
        r'"message":{"role":"user",'
        r'"content":"hello\u0000world"},"toolUseResult":{"stdout":"a\u0000b"}}'
    )
    parsed = _session_line_from_json(raw_json)

    # Payload used for the normal insert: NUL characters are left untouched.
    db_payload = _session_line_db_content(parsed)
    assert db_payload["bad\x00key"] == "v"
    assert db_payload["message"]["content"] == "hello\x00world"
    assert db_payload["toolUseResult"]["stdout"] == "a\x00b"

    # Payload used for the retry: every NUL becomes literal backslash-u-0000
    # text, in keys as well as in values.
    escaped = _session_line_jsonb_safe_content(parsed)
    assert escaped[r"bad\u0000key"] == "v"
    assert escaped["message"]["content"] == r"hello\u0000world"
    assert escaped["toolUseResult"]["stdout"] == r"a\u0000b"

    # The serialized retry payload must not contain a raw NUL anywhere.
    serialized = orjson.dumps(escaped).decode("utf-8")
    assert "\x00" not in serialized


@pytest.mark.anyio
async def test_persist_session_line_lazily_sanitizes_jsonb_nul_failure(
    session: AsyncSession,
    svc_role: Role,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Persisting a session line that contains NULs stores the escaped content.

    The handler is pointed at the test's own database session, a line with
    NUL escapes is persisted, and the stored history row and the parent
    ``AgentSession`` are then read back and checked.
    """
    workspace_id = svc_role.workspace_id
    assert workspace_id is not None

    # Parent session row that the history entry will be attached to.
    chat_session = AgentSession(
        id=uuid.uuid4(),
        workspace_id=workspace_id,
        title="Chat",
        created_by=None,
        entity_type="case",
        entity_id=uuid.uuid4(),
    )
    session.add(chat_session)
    await session.commit()

    @contextlib.asynccontextmanager
    async def patched_bypass_session() -> AsyncIterator[AsyncSession]:
        # Hand the handler the fixture session instead of opening a new one.
        yield session

    monkeypatch.setattr(
        "tracecat.agent.executor.loopback.get_async_session_bypass_rls_context_manager",
        patched_bypass_session,
    )

    # Two text parts: one with a \u0000 escape, one with an escaped backslash
    # followed by "u0000" (i.e. literal text, not a NUL), plus a tool result
    # that also carries a \u0000 escape.
    jsonl_line = (
        r'{"type":"user","uuid":"line-uuid","message":{"role":"user",'
        r'"content":[{"type":"text","text":"hello\u0000world"},'
        r'{"type":"text","text":"literal\\u0000text"}]},'
        r'"toolUseResult":{"stdout":"a\u0000b"}}'
    )

    loopback_input = LoopbackInput(
        session_id=chat_session.id,
        workspace_id=workspace_id,
    )
    handler = LoopbackHandler(input=loopback_input)
    await handler._persist_session_line("sdk-session-123", jsonl_line)

    # Exactly one history row must exist for this session.
    history_query = select(AgentSessionHistory).where(
        AgentSessionHistory.session_id == chat_session.id
    )
    history_rows = await session.execute(history_query)
    stored_entry = history_rows.scalar_one()
    content_blocks = stored_entry.content["message"]["content"]

    assert isinstance(content_blocks, list)
    assert content_blocks[0]["text"] == r"hello\u0000world"
    assert content_blocks[1]["text"] == r"literal\u0000text"
    assert stored_entry.content["toolUseResult"]["stdout"] == r"a\u0000b"
    assert stored_entry.kind == MessageKind.CHAT_MESSAGE.value

    # Handler bookkeeping is updated once persistence has succeeded.
    assert handler._sdk_session_id == "sdk-session-123"
    assert "line-uuid" in handler._persisted_line_uuids

    # The parent session row also records the SDK session id.
    session_query = select(AgentSession).where(AgentSession.id == chat_session.id)
    session_rows = await session.execute(session_query)
    stored_session = session_rows.scalar_one()
    assert stored_session.sdk_session_id == "sdk-session-123"


@pytest.mark.anyio
async def test_list_messages_preserves_compaction_metadata() -> None:
service, agent_session = _build_service()
Expand Down
97 changes: 85 additions & 12 deletions tracecat/agent/executor/loopback.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import orjson
from sqlalchemy import select
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession

from tracecat.agent.channels.schemas import ChannelType
from tracecat.agent.channels.service import PENDING_SLACK_BOT_TOKEN, AgentChannelService
Expand Down Expand Up @@ -222,11 +223,53 @@ def _session_line_from_json(session_line: str) -> ClaudeSessionLine:
return cast(ClaudeSessionLine, decoded)


def _jsonb_safe_value(value: object) -> object:
match value:
case str() as text:
return text.replace("\x00", "\\u0000")

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Preserve NUL session bytes for resume

When a Claude JSONL line contains \u0000, _session_line_from_json decodes it to an actual NUL, and this fallback changes it to the six-character string \\u0000 before storing it. AgentSessionService.load_session_history later reconstructs the resume JSONL from entry.content with orjson.dumps(content), so affected sessions resume with a literal backslash-u sequence instead of the original NUL value. For tool outputs or message text containing NULs, persistence now succeeds but the stored session is no longer an exact Claude session line, which undermines the raw-resume fix this change is meant to provide.

Useful? React with 👍 / 👎.

case list() as items:
return [_jsonb_safe_value(item) for item in items]
case dict() as mapping:
return {
key.replace("\x00", "\\u0000") if isinstance(key, str) else key: (
_jsonb_safe_value(item)
)
for key, item in mapping.items()
}
case _:
return value


def _session_line_db_content(line: ClaudeSessionLine) -> dict[str, Any]:
"""Return a SQLAlchemy JSONB payload for an already validated session line."""
return cast(dict[str, Any], line)


def _session_line_jsonb_safe_content(line: ClaudeSessionLine) -> dict[str, Any]:
    """Build the fallback payload used when retrying after a NUL failure.

    The line is passed through ``_jsonb_safe_value`` and the result is
    returned typed as a plain dict.
    """
    escaped = _jsonb_safe_value(line)
    return cast(dict[str, Any], escaped)


def _is_jsonb_nul_error(exc: SQLAlchemyError) -> bool:
"""Return True when PostgreSQL rejected a JSONB value containing NUL text."""
error_text = str(exc).lower()
if not any(marker in error_text for marker in ("\\u0000", "\x00", "nul")):
return False
if "unsupported unicode escape sequence" in error_text:
return True
if "cannot be converted to text" in error_text:
return True

candidates: list[object | None] = [exc, getattr(exc, "orig", None)]
candidates.extend([exc.__cause__, exc.__context__])
candidates.extend(
getattr(candidate, "__cause__", None) for candidate in tuple(candidates)
)
return any(
getattr(candidate, "sqlstate", None) == "22P05" for candidate in candidates
)


class LoopbackHandler:
"""Handles socket communication with the NSJail runtime.

Expand Down Expand Up @@ -814,7 +857,7 @@ async def _persist_session_line(
session_line: Raw JSONL line from the SDK session file.
internal: If True, this is internal state not shown in UI timeline.
"""
# Parse and sanitize to prevent XSS from untrusted content (e.g., tool results)
# Parse for validation. Rare JSONB NUL failures are sanitized on retry.
line_data = _session_line_from_json(session_line)
if not internal and line_data.get("type") == "assistant":
self._received_assistant_content = True
Expand All @@ -837,10 +880,14 @@ async def _persist_session_line(
uuid=line_uuid,
)

async with get_async_session_bypass_rls_context_manager() as session:
# On first session line, update AgentSession with sdk_session_id
if self._sdk_session_id is None:
self._sdk_session_id = sdk_session_id
should_set_sdk_session_id = self._sdk_session_id is None

async def stage_history_entry(
session: AsyncSession, content: dict[str, Any]
) -> bool:
did_set_sdk_session_id = False
# On first session line, update AgentSession with sdk_session_id.
if should_set_sdk_session_id:
stmt = select(AgentSession).where(
AgentSession.id == self.input.session_id,
AgentSession.workspace_id == self.input.workspace_id,
Expand All @@ -849,11 +896,7 @@ async def _persist_session_line(
agent_session = result.scalar_one_or_none()
if agent_session and agent_session.sdk_session_id is None:
agent_session.sdk_session_id = sdk_session_id
logger.info(
"Updated AgentSession with sdk_session_id",
session_id=self.input.session_id,
sdk_session_id=sdk_session_id,
)
did_set_sdk_session_id = True

# Use explicit internal flag from runtime, not content-based heuristics
kind: SessionLineKind = "internal" if internal else "chat-message"
Expand All @@ -869,11 +912,41 @@ async def _persist_session_line(
history_entry = AgentSessionHistory(
session_id=self.input.session_id,
workspace_id=self.input.workspace_id,
content=_session_line_db_content(line_data),
content=content,
kind=kind,
)
session.add(history_entry)
await session.commit()
return did_set_sdk_session_id

async with get_async_session_bypass_rls_context_manager() as session:
did_set_sdk_session_id = await stage_history_entry(
session, _session_line_db_content(line_data)
)
try:
await session.commit()
except SQLAlchemyError as exc:
await session.rollback()
if not _is_jsonb_nul_error(exc):
raise

logger.warning(
"Retrying session line persistence with JSONB-safe content",
session_id=self.input.session_id,
sdk_session_id=sdk_session_id,
uuid=line_uuid,
)
did_set_sdk_session_id = await stage_history_entry(
session, _session_line_jsonb_safe_content(line_data)
)
await session.commit()
if should_set_sdk_session_id:
self._sdk_session_id = sdk_session_id
if did_set_sdk_session_id:
logger.info(
"Updated AgentSession with sdk_session_id",
session_id=self.input.session_id,
sdk_session_id=sdk_session_id,
)

# Track as persisted after successful commit
if isinstance(line_uuid, str):
Expand Down
Loading