From a454d4887bebc3b2b18ecdd742b0a088d29b046b Mon Sep 17 00:00:00 2001 From: Jordan Umusu Date: Tue, 26 May 2026 18:18:25 -0400 Subject: [PATCH 1/5] feat(agent): agent interruptions --- .../e1a2b3c4d5f6_add_agent_session_status.py | 42 ++++ deployments/k8s | 2 +- frontend/src/client/schemas.gen.ts | 81 +++++++- frontend/src/client/services.gen.ts | 30 +++ frontend/src/client/types.gen.ts | 66 +++++++ .../src/components/chat/chat-session-pane.tsx | 109 ++++++++++- frontend/src/hooks/use-chat.ts | 37 ++++ frontend/src/lib/chat.ts | 11 ++ frontend/tests/chat-session-pane.test.tsx | 4 + frontend/tests/memo-render-count.test.tsx | 4 + .../tracecat_ee/agent/workflows/durable.py | 185 ++++++++++++++++-- tests/integration/test_agent_worker.py | 18 ++ tests/integration/test_dsl_agent_wiring.py | 14 ++ tests/temporal/test_durable_agent_workflow.py | 99 +++++++++- tests/unit/test_agent_activities.py | 3 +- tests/unit/test_agent_executor_loopback.py | 40 ++++ tests/unit/test_agent_runtime.py | 113 +++++++++++ tests/unit/test_agent_runtime_broker.py | 39 ++++ tests/unit/test_agent_sandbox_litellm.py | 5 + tests/unit/test_agent_session_messages.py | 85 ++++++++ tests/unit/test_agent_session_router.py | 64 +++++- tests/unit/test_vercel_adapter.py | 62 +++++- tests/unit/test_vercel_stream_context.py | 17 ++ tracecat/agent/adapter/vercel.py | 35 ++++ tracecat/agent/common/stream_types.py | 9 + tracecat/agent/executor/activity.py | 55 +++++- tracecat/agent/executor/loopback.py | 23 ++- tracecat/agent/runtime/claude_code/broker.py | 85 +++++--- tracecat/agent/runtime/claude_code/runtime.py | 40 ++++ tracecat/agent/session/activities.py | 35 +++- tracecat/agent/session/router.py | 40 +++- tracecat/agent/session/schemas.py | 29 ++- tracecat/agent/session/service.py | 163 +++++++++++++-- tracecat/agent/session/types.py | 14 ++ tracecat/chat/constants.py | 3 + tracecat/chat/enums.py | 1 + tracecat/chat/schemas.py | 5 + tracecat/db/models.py | 8 + 38 files changed, 1607 insertions(+), 68 deletions(-) create mode 100644 alembic/versions/e1a2b3c4d5f6_add_agent_session_status.py diff --git a/alembic/versions/e1a2b3c4d5f6_add_agent_session_status.py b/alembic/versions/e1a2b3c4d5f6_add_agent_session_status.py new file mode 100644 index 0000000000..f35165dbed --- /dev/null +++ b/alembic/versions/e1a2b3c4d5f6_add_agent_session_status.py @@ -0,0 +1,42 @@ +"""add_agent_session_status + +Revision ID: e1a2b3c4d5f6 +Revises: a3d7c9e8b4f2 +Create Date: 2026-05-26 00:00:00.000000 + +""" + +from collections.abc import Sequence + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "e1a2b3c4d5f6" +down_revision: str | None = "a3d7c9e8b4f2" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + op.add_column( + "agent_session", + sa.Column( + "status", + sa.String(length=32), + server_default="idle", + nullable=False, + ), + ) + op.create_index( + op.f("ix_agent_session_status"), + "agent_session", + ["status"], + unique=False, + ) + + +def downgrade() -> None: + op.drop_index(op.f("ix_agent_session_status"), table_name="agent_session") + op.drop_column("agent_session", "status") diff --git a/deployments/k8s b/deployments/k8s index d19d6666ad..d94b52b411 160000 --- a/deployments/k8s +++ b/deployments/k8s @@ -1 +1 @@ -Subproject commit d19d6666adf43f51cc123862fe950ca1c1d9d03b +Subproject commit d94b52b411a61111d475662a3fc012bb96c19cf2 diff --git a/frontend/src/client/schemas.gen.ts b/frontend/src/client/schemas.gen.ts index f8550ed32a..537b92a3f5 100644 --- a/frontend/src/client/schemas.gen.ts +++ b/frontend/src/client/schemas.gen.ts @@ -1212,6 +1212,11 @@ export const $AdminUserRead = { description: "Admin view of a user.", } as const +export const $AgentCancelReason = { + type: "string", + enum: ["user_cancel", "worker_drain"], +} as const + export const $AgentCatalogListResponse = { properties: { items: { @@ -3329,6 +3334,45 @@ export const $AgentPresetVersionReadMinimal = { description: "Metadata returned when listing immutable preset versions.", } as const +export const $AgentSessionCancelRequest = { + properties: { + reason: { + $ref: "#/components/schemas/AgentCancelReason", + description: "Reason for requesting cancellation.", + default: "user_cancel", + }, + }, + type: "object", + title: "AgentSessionCancelRequest", + description: "Request schema for cancelling an active agent session turn.", +} as const + +export const $AgentSessionCancelResponse = { + properties: { + session_id: { + type: "string", + format: "uuid", + title: "Session Id", + }, + run_id: { + type: "string", + format: "uuid", + title: "Run Id", + }, + reason: { + $ref: "#/components/schemas/AgentCancelReason", + }, + turn_status: { + $ref: "#/components/schemas/AgentSessionStatus", + }, + }, + type: "object", + required: ["session_id", "run_id", "reason", "turn_status"], + title: "AgentSessionCancelResponse", + description: + "Response schema for an accepted agent session cancellation request.", +} as const + export const $AgentSessionCreate = { properties: { id: { @@ -3593,6 +3637,10 @@ export const $AgentSessionRead = { ], title: "Last Stream Id", }, + turn_status: { + $ref: "#/components/schemas/AgentSessionStatus", + default: "idle", + }, parent_session_id: { anyOf: [ { @@ -3755,6 +3803,10 @@ export const $AgentSessionReadVercel = { ], title: "Last Stream Id", }, + turn_status: { + $ref: "#/components/schemas/AgentSessionStatus", + default: "idle", + }, parent_session_id: { anyOf: [ { @@ -3925,6 +3977,10 @@ export const $AgentSessionReadWithMessages = { ], title: "Last Stream Id", }, + turn_status: { + $ref: "#/components/schemas/AgentSessionStatus", + default: "idle", + }, parent_session_id: { anyOf: [ { @@ -3974,6 +4030,13 @@ export const $AgentSessionReadWithMessages = { description: "Response schema for agent session with message history.", } as const +export const $AgentSessionStatus = { + type: "string", + enum: ["idle", "running", "waiting_for_approval", "stopped", "failed"], + title: "AgentSessionStatus", + description: "Lifecycle state for an agent session turn.", +} as const + export const $AgentSessionUpdate = { properties: { title: { @@ -8911,6 +8974,20 @@ export const $ChatMessage = { description: "Compaction status data for badge rendering (for kind=COMPACTION)", }, + interrupt: { + anyOf: [ + { + additionalProperties: true, + type: "object", + }, + { + type: "null", + }, + ], + title: "Interrupt", + description: + "Interruption status data for badge rendering (for kind=INTERRUPT)", + }, }, type: "object", required: ["id"], @@ -8920,7 +8997,8 @@ export const $ChatMessage = { This model supports multiple message kinds: - kind=CHAT_MESSAGE: Contains message field with user/assistant content - kind=APPROVAL_REQUEST/APPROVAL_DECISION: Contains approval field with approval data -- kind=COMPACTION: Contains compaction field with compaction status data`, +- kind=COMPACTION: Contains compaction field with compaction status data +- kind=INTERRUPT: Contains interrupt field with interruption status data`, } as const export const $ChatRead = { @@ -15249,6 +15327,7 @@ export const $MessageKind = { "approval-decision", "internal", "compaction", + "interrupt", ], title: "MessageKind", description: "The type/kind of message stored in the chat.", diff --git a/frontend/src/client/services.gen.ts b/frontend/src/client/services.gen.ts index 3e57837079..4ee4f0a81c 100644 --- a/frontend/src/client/services.gen.ts +++ b/frontend/src/client/services.gen.ts @@ -161,6 +161,8 @@ import type { AgentPresetsRestoreAgentPresetVersionResponse, AgentPresetsUpdateAgentPresetData, AgentPresetsUpdateAgentPresetResponse, + AgentSessionsCancelSessionData, + AgentSessionsCancelSessionResponse, AgentSessionsCreateSessionData, AgentSessionsCreateSessionResponse, AgentSessionsDeleteSessionData, @@ -6611,6 +6613,34 @@ export const agentSessionsGetSessionVercel = ( }) } +/** + * Cancel Session + * Request graceful cancellation for the active agent session turn. + * @param data The data for the request. + * @param data.sessionId + * @param data.workspaceId + * @param data.requestBody + * @returns AgentSessionCancelResponse Successful Response + * @throws ApiError + */ +export const agentSessionsCancelSession = ( + data: AgentSessionsCancelSessionData +): CancelablePromise => { + return __request(OpenAPI, { + method: "POST", + url: "/workspaces/{workspace_id}/agent/sessions/{session_id}/cancel", + path: { + session_id: data.sessionId, + workspace_id: data.workspaceId, + }, + body: data.requestBody, + mediaType: "application/json", + errors: { + 422: "Validation Error", + }, + }) +} + /** * Send Message * Send a message to the agent session with streaming response. diff --git a/frontend/src/client/types.gen.ts b/frontend/src/client/types.gen.ts index 9a3d5cfb01..ce48d61ea2 100644 --- a/frontend/src/client/types.gen.ts +++ b/frontend/src/client/types.gen.ts @@ -319,6 +319,8 @@ export type AdminUserRead = { last_login_at?: string | null } +export type AgentCancelReason = "user_cancel" | "worker_drain" + /** * List catalog entries with pagination. */ @@ -766,6 +768,26 @@ export type AgentPresetVersionReadMinimal = { updated_at: string } +/** + * Request schema for cancelling an active agent session turn. + */ +export type AgentSessionCancelRequest = { + /** + * Reason for requesting cancellation. + */ + reason?: AgentCancelReason +} + +/** + * Response schema for an accepted agent session cancellation request. + */ +export type AgentSessionCancelResponse = { + session_id: string + run_id: string + reason: AgentCancelReason + turn_status: AgentSessionStatus +} + /** * Request schema for creating an agent session. */ @@ -858,6 +880,7 @@ export type AgentSessionRead = { agents_binding?: ResolvedAgentsConfig | null harness_type: string | null last_stream_id?: string | null + turn_status?: AgentSessionStatus parent_session_id?: string | null created_at: string updated_at: string @@ -882,6 +905,7 @@ export type AgentSessionReadVercel = { agents_binding?: ResolvedAgentsConfig | null harness_type: string | null last_stream_id?: string | null + turn_status?: AgentSessionStatus parent_session_id?: string | null created_at: string updated_at: string @@ -910,6 +934,7 @@ export type AgentSessionReadWithMessages = { agents_binding?: ResolvedAgentsConfig | null harness_type: string | null last_stream_id?: string | null + turn_status?: AgentSessionStatus parent_session_id?: string | null created_at: string updated_at: string @@ -919,6 +944,16 @@ export type AgentSessionReadWithMessages = { messages?: Array } +/** + * Lifecycle state for an agent session turn. + */ +export type AgentSessionStatus = + | "idle" + | "running" + | "waiting_for_approval" + | "stopped" + | "failed" + /** * Request schema for updating an agent session. */ @@ -2343,6 +2378,7 @@ export type ChannelType = "slack" * - kind=CHAT_MESSAGE: Contains message field with user/assistant content * - kind=APPROVAL_REQUEST/APPROVAL_DECISION: Contains approval field with approval data * - kind=COMPACTION: Contains compaction field with compaction status data + * - kind=INTERRUPT: Contains interrupt field with interruption status data */ export type ChatMessage = { /** @@ -2375,6 +2411,12 @@ export type ChatMessage = { compaction?: { [key: string]: unknown } | null + /** + * Interruption status data for badge rendering (for kind=INTERRUPT) + */ + interrupt?: { + [key: string]: unknown + } | null } /** @@ -4693,6 +4735,7 @@ export type MessageKind = | "approval-decision" | "internal" | "compaction" + | "interrupt" export type ModelConfig = { /** @@ -10738,6 +10781,14 @@ export type AgentSessionsGetSessionVercelResponse = | AgentSessionReadVercel | ChatReadVercel +export type AgentSessionsCancelSessionData = { + requestBody?: AgentSessionCancelRequest | null + sessionId: string + workspaceId: string +} + +export type AgentSessionsCancelSessionResponse = AgentSessionCancelResponse + export type AgentSessionsSendMessageData = { requestBody: VercelChatRequest | ContinueRunRequest sessionId: string @@ -15729,6 +15780,21 @@ export type $OpenApiTs = { } } } + "/workspaces/{workspace_id}/agent/sessions/{session_id}/cancel": { + post: { + req: AgentSessionsCancelSessionData + res: { + /** + * Successful Response + */ + 200: AgentSessionCancelResponse + /** + * Validation Error + */ + 422: HTTPValidationError + } + } + } "/workspaces/{workspace_id}/agent/sessions/{session_id}/messages": { post: { req: AgentSessionsSendMessageData diff --git a/frontend/src/components/chat/chat-session-pane.tsx b/frontend/src/components/chat/chat-session-pane.tsx index e1c64193a1..6799620fac 100644 --- a/frontend/src/components/chat/chat-session-pane.tsx +++ b/frontend/src/components/chat/chat-session-pane.tsx @@ -11,7 +11,9 @@ import { type UITools, } from "ai" import { + AlertTriangleIcon, CheckIcon, + InfoIcon, Loader2, MousePointer2OffIcon, MousePointerClickIcon, @@ -99,6 +101,7 @@ import { type ApprovalCard, makeContinueMessage, parseChatError, + useCancelChatTurn, useUpdateChat, useVercelChat, } from "@/hooks/use-chat" @@ -264,6 +267,8 @@ export function ChatSessionPane({ const [toolMention, setToolMention] = useState() const [selectedTools, setSelectedTools] = useState([]) const { updateChat, isUpdating: isUpdatingTools } = useUpdateChat(workspaceId) + const { cancelChatTurn, isCancellingChatTurn } = + useCancelChatTurn(workspaceId) const { registryActions, registryActionsIsLoading } = useBuilderRegistryActions() @@ -322,7 +327,12 @@ export function ChatSessionPane({ (lastPart.type === "reasoning" && lastPart.text.length > 0) || (isToolUIPart(lastPart) && lastPart.state === "input-streaming") - if (lastPart.type === "data-approval-request") return false + if ( + lastPart.type === "data-approval-request" || + lastPart.type === "data-interrupt" + ) { + return false + } return !isStreamingVisual } @@ -333,6 +343,31 @@ export function ChatSessionPane({ const isInputDisabledRef = useRef(isInputDisabled) isInputDisabledRef.current = isInputDisabled const wasInputDisabledRef = useRef(isInputDisabled) + const isGenerating = status === "submitted" || status === "streaming" + const [cancelRequested, setCancelRequested] = useState(false) + + useEffect(() => { + if (!isGenerating) { + setCancelRequested(false) + } + }, [isGenerating]) + + const handleStop = useCallback(async () => { + if (!chat?.id || cancelRequested) { + return + } + + setCancelRequested(true) + try { + await cancelChatTurn({ chatId: chat.id }) + } catch (error) { + setCancelRequested(false) + toast({ + title: "Failed to stop run", + description: parseChatError(error), + }) + } + }, [cancelChatTurn, cancelRequested, chat?.id]) useEffect(() => { const wasInputDisabled = wasInputDisabledRef.current @@ -1083,7 +1118,12 @@ export function ChatSessionPane({ ) : null} void handleStop()} status={status} className="text-muted-foreground/80" /> @@ -1291,6 +1331,50 @@ function PromptModelIndicator({ modelInfo }: { modelInfo: ModelInfo }) { ) } +type TimelineNoticeTone = "neutral" | "warning" + +function TimelineNoticePart({ + label, + tone, +}: { + label: string + tone: TimelineNoticeTone +}) { + const Icon = tone === "warning" ? AlertTriangleIcon : InfoIcon + + return ( +
+
+ + {label} +
+
+ ) +} + +function getCompactionNoticeLabel(data: unknown): string { + const phase = + data && typeof data === "object" && "phase" in data + ? (data as { phase?: unknown }).phase + : undefined + + switch (phase) { + case "started": + return "Compacting conversation" + case "failed": + return "Compaction failed" + default: + return "Conversation compacted" + } +} + export function MessagePart({ part, partIdx, @@ -1322,6 +1406,27 @@ export function MessagePart({ ) } + if (part.type === "data-compaction") { + const payload = (part as { data?: unknown }).data + return ( + + ) + } + + if (part.type === "data-interrupt") { + return ( + + ) + } + if (part.type === "text") { return ( diff --git a/frontend/src/hooks/use-chat.ts b/frontend/src/hooks/use-chat.ts index 0775bc1547..a28fe99114 100644 --- a/frontend/src/hooks/use-chat.ts +++ b/frontend/src/hooks/use-chat.ts @@ -8,6 +8,7 @@ import { import { DefaultChatTransport, type UIMessage } from "ai" import { useCallback, useMemo, useState } from "react" import { + type AgentSessionCancelResponse, type AgentSessionCreate, type AgentSessionEntity, type AgentSessionRead, @@ -16,6 +17,7 @@ import { type AgentSessionsListSessionsResponse, type AgentSessionUpdate, type ApiError, + agentSessionsCancelSession, agentSessionsCreateSession, agentSessionsDeleteSession, agentSessionsGetSession, @@ -346,6 +348,41 @@ export function useDeleteChat(workspaceId: string) { } } +/** Request cancellation for the currently running agent turn. */ +export function useCancelChatTurn(workspaceId: string) { + const queryClient = useQueryClient() + + const mutation = useMutation< + AgentSessionCancelResponse, + ApiError, + { chatId: string } + >({ + mutationFn: ({ chatId }) => + agentSessionsCancelSession({ + sessionId: chatId, + workspaceId, + requestBody: { reason: "user_cancel" }, + }), + onSettled: (_, __, variables) => { + queryClient.invalidateQueries({ + queryKey: ["chat", variables.chatId, workspaceId], + }) + queryClient.invalidateQueries({ + queryKey: ["chat", variables.chatId, workspaceId, "vercel"], + }) + queryClient.invalidateQueries({ + queryKey: ["chats", workspaceId], + }) + }, + }) + + return { + cancelChatTurn: mutation.mutateAsync, + isCancellingChatTurn: mutation.isPending, + cancelChatTurnError: mutation.error, + } +} + export function useGetChatVercel({ chatId, workspaceId, diff --git a/frontend/src/lib/chat.ts b/frontend/src/lib/chat.ts index a64012d189..a5db10b13c 100644 --- a/frontend/src/lib/chat.ts +++ b/frontend/src/lib/chat.ts @@ -302,12 +302,14 @@ export function transformMessages(messages: ai.UIMessage[]): ai.UIMessage[] { // Array positions to ignore (using "msgIndex-partIndex" string format) const ignorePos = new Set() let pendingCompactionStartPos: string | null = null + let pendingInterruptPos: string | null = null for (const [i, message] of messages.entries()) { for (const [j, part] of message.parts.entries()) { const posKey = `${i}-${j}` if (ai.isToolUIPart(part)) { + pendingInterruptPos = null const { state, toolCallId } = part if (state === "input-available") { // OPEN STATE @@ -333,6 +335,7 @@ export function transformMessages(messages: ai.UIMessage[]): ai.UIMessage[] { states.set(toolCallId, { ...currState, close: posKey }) } } else if (part.type === "data-approval-request") { + pendingInterruptPos = null // Handle approval request parts // 1. If approval request we mark positions, only ignore if we hit a close state // 2. If we see approval requests after a close state, we should ignore the approval requests @@ -356,6 +359,7 @@ export function transformMessages(messages: ai.UIMessage[]): ai.UIMessage[] { } } } else if (part.type === "data-compaction") { + pendingInterruptPos = null const phase = part.data && typeof part.data === "object" && @@ -373,6 +377,13 @@ export function transformMessages(messages: ai.UIMessage[]): ai.UIMessage[] { pendingCompactionStartPos = null } } + } else if (part.type === "data-interrupt") { + if (pendingInterruptPos) { + ignorePos.add(pendingInterruptPos) + } + pendingInterruptPos = posKey + } else { + pendingInterruptPos = null } } } diff --git a/frontend/tests/chat-session-pane.test.tsx b/frontend/tests/chat-session-pane.test.tsx index 8119405526..baa30a8f8e 100644 --- a/frontend/tests/chat-session-pane.test.tsx +++ b/frontend/tests/chat-session-pane.test.tsx @@ -14,6 +14,10 @@ jest.mock("@/hooks/use-chat", () => ({ useVercelChat: jest.fn(), useGetChat: jest.fn(() => ({ chat: null })), useUpdateChat: jest.fn(() => ({ updateChat: jest.fn(), isUpdating: false })), + useCancelChatTurn: jest.fn(() => ({ + cancelChatTurn: jest.fn(), + isCancellingChatTurn: false, + })), parseChatError: (error: unknown) => error instanceof Error ? error.message : "Chat error", makeContinueMessage: (decisions: unknown) => ({ diff --git a/frontend/tests/memo-render-count.test.tsx b/frontend/tests/memo-render-count.test.tsx index b8e6fe6413..69bdef3bba 100644 --- a/frontend/tests/memo-render-count.test.tsx +++ b/frontend/tests/memo-render-count.test.tsx @@ -58,6 +58,10 @@ jest.mock("@/hooks/use-chat", () => ({ useVercelChat: jest.fn(), useGetChat: jest.fn(() => ({ chat: null })), useUpdateChat: jest.fn(() => ({ updateChat: jest.fn(), isUpdating: false })), + useCancelChatTurn: jest.fn(() => ({ + cancelChatTurn: jest.fn(), + isCancellingChatTurn: false, + })), parseChatError: (error: unknown) => error instanceof Error ? error.message : "Chat error", makeContinueMessage: jest.fn(), diff --git a/packages/tracecat-ee/tracecat_ee/agent/workflows/durable.py b/packages/tracecat-ee/tracecat_ee/agent/workflows/durable.py index 5d92b11d52..6be784ea56 100644 --- a/packages/tracecat-ee/tracecat_ee/agent/workflows/durable.py +++ b/packages/tracecat-ee/tracecat_ee/agent/workflows/durable.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import uuid from dataclasses import dataclass from datetime import UTC, datetime, timedelta @@ -8,7 +9,13 @@ from pydantic import BaseModel, ConfigDict, Field from temporalio import workflow from temporalio.common import TypedSearchAttributes -from temporalio.exceptions import ActivityError, ApplicationError +from temporalio.exceptions import ( + ActivityError, + ApplicationError, +) +from temporalio.exceptions import ( + CancelledError as TemporalCancelledError, +) with workflow.unsafe.imports_passed_through(): from pydantic_ai.messages import ToolCallPart @@ -56,12 +63,18 @@ LoadSessionMessagesInput, PendingToolResult, ReconcileToolResultsInput, + UpdateSessionStatusInput, create_session_activity, load_session_activity, load_session_messages_activity, reconcile_tool_results_activity, + update_session_status_activity, + ) + from tracecat.agent.session.types import ( + AgentCancelReason, + AgentSessionEntity, + AgentSessionStatus, ) - from tracecat.agent.session.types import AgentSessionEntity from tracecat.agent.subagents import has_manual_tool_approvals from tracecat.agent.tokens import ( InternalToolContext, @@ -124,6 +137,10 @@ def _activity_error_message(error: ActivityError) -> str: return str(error) +def _activity_was_cancelled(error: ActivityError) -> bool: + return isinstance(error.cause, TemporalCancelledError) + + def _build_approved_tool_run_input( *, tool_call: ApprovedToolCall, @@ -300,6 +317,10 @@ class WorkflowApprovalSubmission(BaseModel): decision_metadata: dict[str, dict[str, Any]] | None = None +class WorkflowCancelRequest(BaseModel): + reason: AgentCancelReason = "user_cancel" + + def _resolve_agent_output( *, output: Any, @@ -319,6 +340,9 @@ def _resolve_agent_output( # marker have aged out, then use workflow.deprecate_patch(...) before removing # the marker entirely in a later cleanup. LOAD_TERMINAL_MESSAGE_HISTORY_PATCH = "durable-agent-load-terminal-message-history-v1" +AGENT_ACTIVITY_GRACEFUL_CANCEL_PATCH = "durable-agent-activity-graceful-cancel-v1" +AGENT_SESSION_STATUS_PATCH = "durable-agent-session-status-v1" +AGENT_ACTIVITY_GRACEFUL_CANCEL_HEARTBEAT_TIMEOUT_SECONDS = 10 @workflow.defn @@ -344,6 +368,8 @@ def __init__(self, args: AgentWorkflowArgs): self.approvals = ApprovalManager(role=self.role) self.max_requests = args.agent_args.max_requests self.max_tool_calls = args.agent_args.max_tool_calls + self._cancel_requested = False + self._cancel_reason: AgentCancelReason | None = None def _upsert_tracecat_search_attributes(self) -> None: """Ensure direct agent runs have core Tracecat search attributes. @@ -674,10 +700,18 @@ async def run(self, args: AgentWorkflowArgs) -> AgentOutput: cfg = await self._build_config(args) return await self._run_with_agent_executor(args, cfg) except ActivityError as e: + await self._set_agent_session_status( + AgentSessionStatus.FAILED, + clear_curr_run_id=True, + ) if workflow.patched(EMIT_PRE_STREAM_SESSION_ERRORS_PATCH): await self._emit_session_error(_activity_error_message(e)) raise except ApplicationError as e: + await self._set_agent_session_status( + AgentSessionStatus.FAILED, + clear_curr_run_id=True, + ) if e.type == AGENT_TOOL_DEFINITION_ERROR or ( e.type != AGENT_RUNTIME_EXECUTION_ERROR and workflow.patched(EMIT_PRE_STREAM_SESSION_ERRORS_PATCH) @@ -704,6 +738,22 @@ async def _emit_session_error(self, message: str) -> None: error=str(emit_error), ) + @workflow.update + def request_cancel(self, request: WorkflowCancelRequest) -> None: + request = WorkflowCancelRequest.model_validate(request) + logger.info( + "Agent cancellation requested", + session_id=self.session_id, + reason=request.reason, + ) + if self._cancel_reason is None: + self._cancel_reason = request.reason + self._cancel_requested = True + + @request_cancel.validator + def validate_request_cancel(self, request: WorkflowCancelRequest) -> None: + WorkflowCancelRequest.model_validate(request) + @workflow.update def set_approvals(self, submission: WorkflowApprovalSubmission) -> None: submission = WorkflowApprovalSubmission.model_validate(submission) @@ -738,6 +788,93 @@ def validate_set_approvals(self, submission: WorkflowApprovalSubmission) -> None + ", ".join(sorted(unexpected_metadata_ids)) ) + async def _set_agent_session_status( + self, + status: AgentSessionStatus, + *, + clear_curr_run_id: bool = False, + ) -> None: + """Best-effort session status update for UI/API lifecycle gates.""" + if not workflow.patched(AGENT_SESSION_STATUS_PATCH): + return + + try: + await workflow.execute_activity( + update_session_status_activity, + UpdateSessionStatusInput( + role=self.role, + session_id=self.session_id, + status=status, + clear_curr_run_id=clear_curr_run_id, + ), + start_to_close_timeout=timedelta(seconds=30), + retry_policy=RETRY_POLICIES["activity:fail_fast"], + ) + except ActivityError as e: + logger.warning( + "Failed to update agent session status", + session_id=self.session_id, + status=status, + error=str(e), + ) + + async def _run_agent_activity_turn( + self, + executor_input: AgentExecutorInput, + ) -> AgentExecutorResult: + """Run one executor activity turn with update-driven cancellation.""" + if not workflow.patched(AGENT_ACTIVITY_GRACEFUL_CANCEL_PATCH): + return await workflow.execute_activity( + run_agent_activity, + executor_input, + task_queue=config.TRACECAT__AGENT_EXECUTOR_QUEUE, + start_to_close_timeout=timedelta( + seconds=config.TRACECAT__AGENT_SANDBOX_TIMEOUT + ), + heartbeat_timeout=timedelta(seconds=60), + retry_policy=RETRY_POLICIES["activity:fail_fast"], + ) + + activity_handle = workflow.start_activity( + run_agent_activity, + executor_input, + cancellation_type=workflow.ActivityCancellationType.WAIT_CANCELLATION_COMPLETED, + task_queue=config.TRACECAT__AGENT_EXECUTOR_QUEUE, + start_to_close_timeout=timedelta( + seconds=config.TRACECAT__AGENT_SANDBOX_TIMEOUT + ), + heartbeat_timeout=timedelta( + seconds=AGENT_ACTIVITY_GRACEFUL_CANCEL_HEARTBEAT_TIMEOUT_SECONDS + ), + retry_policy=RETRY_POLICIES["activity:fail_fast"], + ) + cancel_wait_task = asyncio.create_task( + workflow.wait_condition(lambda: self._cancel_requested) + ) + try: + done, _pending = await asyncio.wait( + [activity_handle, cancel_wait_task], + return_when=asyncio.FIRST_COMPLETED, + ) + + if activity_handle in done: + return await activity_handle + + activity_handle.cancel() + try: + return await activity_handle + except ActivityError as e: + if self._cancel_requested and _activity_was_cancelled(e): + return AgentExecutorResult( + success=True, + cancelled=True, + cancelled_reason=self._cancel_reason or "user_cancel", + ) + raise + finally: + if not cancel_wait_task.done(): + cancel_wait_task.cancel() + async def _run_with_agent_executor( self, args: AgentWorkflowArgs, cfg: AgentConfig ) -> AgentOutput: @@ -864,17 +1001,30 @@ async def _run_with_agent_executor( # Run the executor activity while True: logger.info("Executing agent turn", turn=self._turn) + self._status = "running" + await self._set_agent_session_status(AgentSessionStatus.RUNNING) - result = await workflow.execute_activity( - run_agent_activity, - executor_input, - task_queue=config.TRACECAT__AGENT_EXECUTOR_QUEUE, - start_to_close_timeout=timedelta( - seconds=config.TRACECAT__AGENT_SANDBOX_TIMEOUT - ), - heartbeat_timeout=timedelta(seconds=60), - retry_policy=RETRY_POLICIES["activity:fail_fast"], - ) + result = await self._run_agent_activity_turn(executor_input) + + if result.cancelled: + logger.info( + "Agent turn cancelled", + session_id=self.session_id, + reason=result.cancelled_reason, + ) + self._status = "done" + await self._set_agent_session_status( + AgentSessionStatus.STOPPED, + clear_curr_run_id=True, + ) + message_history = await self._load_terminal_message_history(result) + return AgentOutput( + output=None, + message_history=message_history, + duration=(datetime.now(UTC) - info.start_time).total_seconds(), + usage=RunUsage(requests=0, input_tokens=0, output_tokens=0), + session_id=self.session_id, + ) if not result.success: # Missing means a legacy activity result from before the flag @@ -892,6 +1042,10 @@ async def _run_with_agent_executor( if result.approval_requested: logger.info("Agent waiting for approval", session_id=self.session_id) + self._status = "waiting_for_results" + await self._set_agent_session_status( + AgentSessionStatus.WAITING_FOR_APPROVAL + ) # Convert ToolCallContent to ToolCallPart for ApprovalManager if result.approval_items: tool_call_parts = [ @@ -916,6 +1070,8 @@ async def _run_with_agent_executor( await self.approvals.wait() # Persist approval decisions to DB (atomic with chat messages) await self.approvals.handle_decisions() + self._status = "running" + await self._set_agent_session_status(AgentSessionStatus.RUNNING) # Execute approved tools and reconcile the SDK transcript. approved_tools, denied_tools = self._build_tool_lists_from_approvals( @@ -970,6 +1126,11 @@ async def _run_with_agent_executor( output=result.output, ) message_history = await self._load_terminal_message_history(result) + self._status = "done" + await self._set_agent_session_status( + AgentSessionStatus.IDLE, + clear_curr_run_id=True, + ) return AgentOutput( output=output, message_history=message_history, diff --git a/tests/integration/test_agent_worker.py b/tests/integration/test_agent_worker.py index 2848ed60bb..83c084e499 100644 --- a/tests/integration/test_agent_worker.py +++ b/tests/integration/test_agent_worker.py @@ -52,6 +52,8 @@ LoadSessionResult, ReconcileToolResultsInput, ReconcileToolResultsResult, + UpdateSessionStatusInput, + UpdateSessionStatusResult, get_session_activities, ) from tracecat.agent.session.types import AgentSessionEntity @@ -115,6 +117,19 @@ async def mock_load_session_messages_activity( return mock_load_session_messages_activity +def create_mock_update_session_status_activity() -> Callable[..., Any]: + """Create a mock update_session_status_activity.""" + + @activity.defn(name="update_session_status_activity") + async def mock_update_session_status_activity( + input: UpdateSessionStatusInput, + ) -> UpdateSessionStatusResult: + del input + return UpdateSessionStatusResult(found=True) + + return mock_update_session_status_activity + + def create_mock_run_agent_activity( response_callback: Callable[[AgentExecutorInput], AgentExecutorResult], ) -> Callable[..., Any]: @@ -155,6 +170,7 @@ def create_activities_with_mock_executor( # Add mocked session activities (to avoid DB FK constraints) activities.append(create_mock_create_session_activity()) activities.append(create_mock_load_session_activity()) + activities.append(create_mock_update_session_status_activity()) activities.append(create_mock_load_session_messages_activity()) # Add mocked runtime activity @@ -245,6 +261,7 @@ async def test_session_activities_registered(self) -> None: assert set(activity_names) == { "create_session_activity", "load_session_activity", + "update_session_status_activity", "load_session_messages_activity", "reconcile_tool_results_activity", } @@ -513,6 +530,7 @@ def mock_executor(input: AgentExecutorInput) -> AgentExecutorResult: *agent_activities.get_activities(), mock_create_session_activity, mock_load_session_activity, + create_mock_update_session_status_activity(), create_mock_load_session_messages_activity(), mock_record_approval_requests, mock_apply_approval_decisions, diff --git a/tests/integration/test_dsl_agent_wiring.py b/tests/integration/test_dsl_agent_wiring.py index 7b7329567c..ce67e948ce 100644 --- a/tests/integration/test_dsl_agent_wiring.py +++ b/tests/integration/test_dsl_agent_wiring.py @@ -29,6 +29,8 @@ LoadSessionMessagesInput, LoadSessionMessagesResult, LoadSessionResult, + UpdateSessionStatusInput, + UpdateSessionStatusResult, ) from tracecat.agent.worker import ( get_activities as get_agent_worker_activities, @@ -125,6 +127,16 @@ async def mock_load_session_messages_activity( return mock_load_session_messages_activity +def create_mock_update_session_status_activity() -> Callable[..., Any]: + @activity.defn(name="update_session_status_activity") + async def mock_update_session_status_activity( + _: UpdateSessionStatusInput, + ) -> UpdateSessionStatusResult: + return UpdateSessionStatusResult(found=True) + + return mock_update_session_status_activity + + def create_mock_build_tool_definitions_activity() -> Callable[..., Any]: @activity.defn(name="build_agent_tool_definitions") async def mock_build_tool_definitions( @@ -197,6 +209,7 @@ async def test_dsl_workflow_executes_ai_agent_on_agent_worker( for replacement in ( create_mock_create_session_activity(), create_mock_load_session_activity(), + create_mock_update_session_status_activity(), create_mock_load_session_messages_activity(), create_mock_build_tool_definitions_activity(), ): @@ -265,6 +278,7 @@ async def test_dsl_workflow_marks_existing_agent_session_as_required( for replacement in ( create_mock_create_session_activity(captured_inputs), create_mock_load_session_activity(), + create_mock_update_session_status_activity(), create_mock_load_session_messages_activity(), create_mock_build_tool_definitions_activity(), ): diff --git a/tests/temporal/test_durable_agent_workflow.py b/tests/temporal/test_durable_agent_workflow.py index 21b5445aed..3abaacd257 100644 --- a/tests/temporal/test_durable_agent_workflow.py +++ b/tests/temporal/test_durable_agent_workflow.py @@ -48,6 +48,7 @@ AgentWorkflowArgs, DurableAgentWorkflow, WorkflowApprovalSubmission, + WorkflowCancelRequest, ) from tracecat import config @@ -69,13 +70,16 @@ LoadSessionResult, ReconcileToolResultsInput, ReconcileToolResultsResult, + UpdateSessionStatusInput, + UpdateSessionStatusResult, create_session_activity, load_session_activity, load_session_messages_activity, reconcile_tool_results_activity, + update_session_status_activity, ) from tracecat.agent.session.service import AgentSessionService -from tracecat.agent.session.types import AgentSessionEntity +from tracecat.agent.session.types import AgentSessionEntity, AgentSessionStatus from tracecat.agent.types import AgentConfig from tracecat.auth.types import Role from tracecat.authz.scopes import SERVICE_PRINCIPAL_SCOPES @@ -154,6 +158,22 @@ async def mock_load_session_messages_activity( return mock_load_session_messages_activity +def create_mock_update_session_status_activity( + captured_inputs: list[UpdateSessionStatusInput] | None = None, +) -> Callable[..., Any]: + """Create a mock update_session_status_activity.""" + + @activity.defn(name="update_session_status_activity") + async def mock_update_session_status_activity( + input: UpdateSessionStatusInput, + ) -> UpdateSessionStatusResult: + if captured_inputs is not None: + captured_inputs.append(input) + return UpdateSessionStatusResult(found=True) + + return mock_update_session_status_activity + + def create_mock_build_tool_definitions_activity( tool_definitions: dict[str, MCPToolDefinition] | None = None, ) -> Callable[..., Any]: @@ -270,6 +290,7 @@ def create_activities_with_mock_executor( tool_definitions: dict[str, MCPToolDefinition] | None = None, message_load_inputs: list[LoadSessionMessagesInput] | None = None, session_messages: list[ChatMessage] | None = None, + status_update_inputs: list[UpdateSessionStatusInput] | None = None, ) -> Sequence[Callable[..., Any]]: """Create a full activity list with mocked agent executor. @@ -284,6 +305,7 @@ def create_activities_with_mock_executor( activities: list[Callable[..., Any]] = [ create_mock_create_session_activity(), create_mock_load_session_activity(), + create_mock_update_session_status_activity(status_update_inputs), create_mock_load_session_messages_activity( captured_inputs=message_load_inputs, messages=session_messages, @@ -500,6 +522,7 @@ async def mock_emit_session_error(args: EmitSessionErrorInputs) -> None: activities = [ create_mock_create_session_activity(), + create_mock_update_session_status_activity(), mock_build_tool_definitions, mock_emit_session_error, ] @@ -672,6 +695,76 @@ def mock_executor( assert [input.session_id for input in message_load_inputs] == [mock_session_id] +@pytest.mark.anyio +@pytest.mark.integration +async def test_request_cancel_gracefully_stops_running_agent_turn( + temporal_client: Client, + agent_worker_factory, + agent_workflow_args: AgentWorkflowArgs, + mock_session_id: uuid.UUID, +) -> None: + """Workflow cancel update requests activity cancellation and exits cleanly.""" + queue = f"test-agent-queue-{mock_session_id}" + activity_started = asyncio.Event() + cancellation_seen = asyncio.Event() + status_update_inputs: list[UpdateSessionStatusInput] = [] + + @activity.defn(name="run_agent_activity") + async def mock_run_agent_activity( + input: AgentExecutorInput, + ) -> AgentExecutorResult: + del input + activity_started.set() + try: + while True: + activity.heartbeat("Mock agent waiting for cancellation") + await asyncio.sleep(0.1) + except asyncio.CancelledError: + cancellation_seen.set() + return AgentExecutorResult( + success=True, + cancelled=True, + cancelled_reason="user_cancel", + ) + + activities: list[Callable[..., Any]] = [ + create_mock_create_session_activity(), + create_mock_load_session_activity(), + create_mock_update_session_status_activity(status_update_inputs), + create_mock_load_session_messages_activity(), + create_mock_build_tool_definitions_activity(), + mock_run_agent_activity, + create_mock_execute_action_activity(), + create_mock_reconcile_tool_results_activity(), + *ApprovalManager.get_activities(), + ] + + async with agent_worker_factory( + temporal_client, task_queue=queue, custom_activities=activities + ): + handle = await temporal_client.start_workflow( + DurableAgentWorkflow.run, + agent_workflow_args, + id=AgentWorkflowID(mock_session_id), + task_queue=queue, + retry_policy=RETRY_POLICIES["workflow:fail_fast"], + execution_timeout=timedelta(seconds=60), + ) + await asyncio.wait_for(activity_started.wait(), timeout=10) + + await handle.execute_update( + DurableAgentWorkflow.request_cancel, + WorkflowCancelRequest(reason="user_cancel"), + ) + result = await handle.result() + + assert cancellation_seen.is_set() + assert result.session_id == mock_session_id + assert result.output is None + assert status_update_inputs[-1].status is AgentSessionStatus.STOPPED + assert status_update_inputs[-1].clear_curr_run_id is True + + @pytest.mark.anyio @pytest.mark.integration async def test_agent_workflow_preserves_legacy_activity_message_history( @@ -1002,6 +1095,7 @@ async def mock_execute_action_activity( activities=[ create_session_activity, load_session_activity, + update_session_status_activity, load_session_messages_activity, reconcile_tool_results_activity, mock_build_tool_definitions, @@ -1301,6 +1395,7 @@ async def mock_reconcile_tool_results_activity( activities=[ mock_create_session_activity, mock_load_session_activity, + create_mock_update_session_status_activity(), create_mock_load_session_messages_activity(), mock_build_tool_definitions, mock_record_approval_requests, @@ -1491,6 +1586,7 @@ async def mock_apply_approval_decisions( activities = [ create_mock_create_session_activity(), mock_load_session_activity, + create_mock_update_session_status_activity(), create_mock_load_session_messages_activity(), create_mock_build_tool_definitions_activity(), mock_run_agent_activity, @@ -1748,6 +1844,7 @@ async def mock_execute_action_activity( activities=[ create_session_activity, load_session_activity, + update_session_status_activity, load_session_messages_activity, reconcile_tool_results_activity, mock_build_tool_definitions, diff --git a/tests/unit/test_agent_activities.py b/tests/unit/test_agent_activities.py index 7046f89384..b2be54827d 100644 --- a/tests/unit/test_agent_activities.py +++ b/tests/unit/test_agent_activities.py @@ -95,7 +95,7 @@ def test_get_session_activities_returns_list(self): """Test that get_session_activities returns a list of activity functions.""" activities = get_session_activities() assert isinstance(activities, list) - assert len(activities) == 4 + assert len(activities) == 5 # All returned items should have the temporal activity definition for activity in activities: @@ -109,6 +109,7 @@ def test_session_activities_names(self): ] assert "create_session_activity" in activity_names assert "load_session_activity" in activity_names + assert "update_session_status_activity" in activity_names assert "load_session_messages_activity" in activity_names assert "reconcile_tool_results_activity" in activity_names diff --git a/tests/unit/test_agent_executor_loopback.py b/tests/unit/test_agent_executor_loopback.py index c033de4c64..f2740bf40c 100644 --- a/tests/unit/test_agent_executor_loopback.py +++ b/tests/unit/test_agent_executor_loopback.py @@ -377,3 +377,43 @@ async def test_send_done_preserves_existing_error_state() -> None: assert handler._result.error == "runtime failed" stream.error.assert_not_awaited() stream.done.assert_awaited_once() + + +@pytest.mark.anyio +async def test_mark_cancelled_skips_empty_runtime_completion_validation() -> None: + handler = _make_handler() + stream = _FakeStream() + handler._stream_sink = stream + + handler.mark_cancelled("user_cancel") + await handler.send_done() + + assert handler._result.success is True + assert handler._result.cancelled is True + assert handler._result.cancelled_reason == "user_cancel" + stream.append.assert_awaited_once() + append_args = stream.append.await_args + assert append_args is not None + event = append_args.args[0] + assert event.type == StreamEventType.INTERRUPT + assert event.metadata == {"reason": "user_cancel"} + stream.error.assert_not_awaited() + stream.done.assert_awaited_once() + + +@pytest.mark.anyio +async def test_cancelled_loopback_preserves_later_stream_and_history() -> None: + handler = _make_handler() + stream = _FakeStream() + handler._stream_sink = stream + persist = AsyncMock() + handler._persist_session_line = persist + + handler.mark_cancelled("user_cancel") + await handler.send_stream_event( + UnifiedStreamEvent(type=StreamEventType.TEXT_DELTA, text="late") + ) + await handler.send_session_line("sdk-session", '{"type":"assistant"}') + + assert stream.append.await_count == 1 + persist.assert_awaited_once() diff --git a/tests/unit/test_agent_runtime.py b/tests/unit/test_agent_runtime.py index 3976153dbc..2d79177c94 100644 --- a/tests/unit/test_agent_runtime.py +++ b/tests/unit/test_agent_runtime.py @@ -295,6 +295,80 @@ async def test_sends_done_on_completion( mock_socket_writer.send_done.assert_awaited_once() + @pytest.mark.anyio + async def test_interrupt_before_query_stops_sdk_after_query_ready( + self, + mock_socket_writer: MagicMock, + mock_claude_sdk_client: MagicMock, + sample_init_payload: RuntimeInitPayload, + ) -> None: + """A pre-ready interrupt is retained until the SDK query can be stopped.""" + with ( + patch( + "tracecat.agent.runtime.claude_code.runtime.ClaudeSDKClient", + return_value=mock_claude_sdk_client, + ), + ): + runtime = ClaudeAgentRuntime( + mock_socket_writer, transport_factory=lambda _: MagicMock() + ) + interrupt_task = asyncio.create_task( + runtime.interrupt(reason="user_cancel") + ) + await asyncio.sleep(0) + await runtime.run(sample_init_payload) + await interrupt_task + + mock_claude_sdk_client.interrupt.assert_awaited_once() + mock_socket_writer.send_done.assert_awaited_once() + + @pytest.mark.anyio + async def test_interrupt_during_receive_lets_sdk_finish_turn( + self, + mock_socket_writer: MagicMock, + mock_claude_sdk_client: MagicMock, + sample_init_payload: RuntimeInitPayload, + ) -> None: + """Cancellation interrupts the SDK but preserves its remaining output.""" + query_sent = asyncio.Event() + + async def query(_input: object) -> None: + query_sent.set() + + async def receive_response() -> Any: + await query_sent.wait() + yield StreamEvent( + uuid="stream-event-uuid", + session_id="sdk-session", + event={ + "type": "content_block_delta", + "index": 0, + "delta": {"type": "text_delta", "text": "interrupted"}, + }, + ) + + mock_claude_sdk_client.query.side_effect = query + mock_claude_sdk_client.receive_response = receive_response + + with ( + patch( + "tracecat.agent.runtime.claude_code.runtime.ClaudeSDKClient", + return_value=mock_claude_sdk_client, + ), + ): + runtime = ClaudeAgentRuntime( + mock_socket_writer, transport_factory=lambda _: MagicMock() + ) + run_task = asyncio.create_task(runtime.run(sample_init_payload)) + await asyncio.wait_for(query_sent.wait(), timeout=1) + + await runtime.interrupt(reason="user_cancel") + await asyncio.wait_for(run_task, timeout=1) + + mock_claude_sdk_client.interrupt.assert_awaited_once() + mock_socket_writer.send_stream_event.assert_awaited() + mock_socket_writer.send_done.assert_awaited_once() + @pytest.mark.anyio async def test_handles_stream_events( self, @@ -2352,6 +2426,45 @@ def test_hides_structured_compaction_artifacts( assert runtime._is_internal_session_line(line_data) is True + def test_interrupt_rows_are_internal( + self, + mock_socket_writer: MagicMock, + ) -> None: + """SDK interruption artifacts should not render as chat messages.""" + runtime = ClaudeAgentRuntime( + mock_socket_writer, transport_factory=lambda _: MagicMock() + ) + + text_marker = { + "type": "user", + "message": { + "content": [ + { + "type": "text", + "text": "[Request interrupted by user]", + } + ] + }, + } + tool_result = { + "type": "user", + "message": { + "content": [ + { + "type": "tool_result", + "tool_use_id": "call_123", + "is_error": True, + "content": ( + "The user doesn't want to take this action right now." + ), + } + ] + }, + } + + assert runtime._is_internal_session_line(text_marker) is True + assert runtime._is_internal_session_line(tool_result) is True + @pytest.mark.anyio async def test_approval_continuation_hides_sdk_meta_prompt_only( self, diff --git a/tests/unit/test_agent_runtime_broker.py b/tests/unit/test_agent_runtime_broker.py index e28789a024..f9255bcde7 100644 --- a/tests/unit/test_agent_runtime_broker.py +++ b/tests/unit/test_agent_runtime_broker.py @@ -135,6 +135,45 @@ async def run(self, _payload: RuntimeInitPayload) -> None: await first_task +@pytest.mark.anyio +async def test_broker_interrupts_active_runtime( + monkeypatch: pytest.MonkeyPatch, tmp_path: Path +) -> None: + release = asyncio.Event() + runtimes: list[Any] = [] + + class FakeRuntime: + def __init__(self, *_args, **_kwargs) -> None: + self.interrupt = AsyncMock() + runtimes.append(self) + + async def run(self, _payload: RuntimeInitPayload) -> None: + await release.wait() + + monkeypatch.setattr(broker_module, "ClaudeAgentRuntime", FakeRuntime) + + broker = ClaudeRuntimeBroker() + await broker.start() + + request = _make_request(tmp_path) + handler = cast( + Any, + SimpleNamespace( + prepare=AsyncMock(), + process_envelope=AsyncMock(), + ), + ) + + task = asyncio.create_task(broker.run_turn(request, handler)) + await asyncio.sleep(0) + + await broker.interrupt_turn(str(request.init_payload.session_id), "user_cancel") + + runtimes[0].interrupt.assert_awaited_once_with(reason="user_cancel") + release.set() + await task + + @pytest.mark.anyio async def test_broker_rechecks_closed_state_after_waiting_for_lock( tmp_path: Path, diff --git a/tests/unit/test_agent_sandbox_litellm.py b/tests/unit/test_agent_sandbox_litellm.py index 7370fd2cea..cd15efff4b 100644 --- a/tests/unit/test_agent_sandbox_litellm.py +++ b/tests/unit/test_agent_sandbox_litellm.py @@ -292,6 +292,7 @@ class _FakeBroker: def __init__(self) -> None: self.requests: list[ClaudeTurnRequest] = [] self.cancelled_session_ids: list[str] = [] + self.interrupted_session_ids: list[str] = [] async def run_turn( self, @@ -304,6 +305,10 @@ async def run_turn( async def cancel_turn(self, session_id: str) -> None: self.cancelled_session_ids.append(session_id) + async def interrupt_turn(self, session_id: str, reason: str) -> None: + del reason + self.interrupted_session_ids.append(session_id) + class _FakeProxy: def __init__(self) -> None: diff --git a/tests/unit/test_agent_session_messages.py b/tests/unit/test_agent_session_messages.py index c8e7eada39..64af71d97a 100644 --- a/tests/unit/test_agent_session_messages.py +++ b/tests/unit/test_agent_session_messages.py @@ -9,6 +9,7 @@ import pytest from tracecat.agent.session.service import AgentSessionService +from tracecat.agent.session.types import AgentSessionStatus from tracecat.auth.types import Role from tracecat.chat.enums import MessageKind from tracecat.db.models import AgentSession @@ -80,6 +81,90 @@ async def test_list_messages_preserves_compaction_metadata() -> None: } +@pytest.mark.anyio +async def test_list_messages_adds_interrupt_notice_for_stopped_session() -> None: + service, agent_session = _build_service() + agent_session.status = AgentSessionStatus.STOPPED.value + interrupt_entry = SimpleNamespace( + id=uuid.uuid4(), + session_id=agent_session.id, + kind=MessageKind.INTERNAL.value, + content={ + "type": "user", + "uuid": "interrupt-uuid", + "message": { + "role": "user", + "content": [{"type": "text", "text": "[Request interrupted by user]"}], + }, + }, + ) + duplicate_interrupt_entry = SimpleNamespace( + id=uuid.uuid4(), + session_id=agent_session.id, + kind=MessageKind.INTERNAL.value, + content={ + "type": "user", + "uuid": "interrupt-tool-result-uuid", + "message": { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "call_123", + "is_error": True, + "content": "The user doesn't want to take this action right now.", + } + ], + }, + }, + ) + + service.get_session = AsyncMock(return_value=agent_session) + service.session.execute = AsyncMock( + side_effect=[ + _mock_scalar_result([interrupt_entry, duplicate_interrupt_entry]), + _mock_scalar_result([]), + ] + ) + + messages = await service.list_messages(agent_session.id) + + assert len(messages) == 1 + assert messages[0].kind == MessageKind.INTERRUPT + assert messages[0].interrupt == {"reason": "user_cancel"} + + +@pytest.mark.anyio +async def test_list_messages_keeps_approval_interrupt_artifacts_hidden() -> None: + service, agent_session = _build_service() + agent_session.status = AgentSessionStatus.WAITING_FOR_APPROVAL.value + interrupt_entry = SimpleNamespace( + id=uuid.uuid4(), + session_id=agent_session.id, + kind=MessageKind.INTERNAL.value, + content={ + "type": "user", + "uuid": "interrupt-uuid", + "message": { + "role": "user", + "content": [{"type": "text", "text": "[Request interrupted by user]"}], + }, + }, + ) + + service.get_session = AsyncMock(return_value=agent_session) + service.session.execute = AsyncMock( + side_effect=[ + _mock_scalar_result([interrupt_entry]), + _mock_scalar_result([]), + ] + ) + + messages = await service.list_messages(agent_session.id) + + assert messages == [] + + @pytest.mark.anyio async def test_load_session_history_omits_internal_rows_and_repairs_parent_chain() -> ( None diff --git a/tests/unit/test_agent_session_router.py b/tests/unit/test_agent_session_router.py index af670ae27e..429277ece6 100644 --- a/tests/unit/test_agent_session_router.py +++ b/tests/unit/test_agent_session_router.py @@ -14,6 +14,7 @@ from tracecat.agent.adapter.vercel import UIMessage from tracecat.agent.common.stream_types import HarnessType from tracecat.agent.session.router import ( + cancel_session, get_session, get_session_vercel, send_message, @@ -26,7 +27,7 @@ ContinueRunRequest, VercelChatRequest, ) -from tracecat.exceptions import TracecatNotFoundError +from tracecat.exceptions import TracecatConflictError, TracecatNotFoundError async def _empty_event_stream() -> AsyncIterator[None]: @@ -57,6 +58,7 @@ def _agent_session_stub(**overrides: Any) -> SimpleNamespace: "created_at": now, "updated_at": now, "last_stream_id": None, + "status": "idle", } values.update(overrides) return SimpleNamespace(**values) @@ -131,6 +133,66 @@ async def test_get_session_vercel_includes_agents_binding() -> None: } +@pytest.mark.anyio +async def test_cancel_session_delegates_to_service() -> None: + session_id = uuid.uuid4() + workspace_id = uuid.uuid4() + role = _read_role(workspace_id) + fake_response = SimpleNamespace( + session_id=session_id, + run_id=uuid.uuid4(), + reason="user_cancel", + status="running", + ) + fake_svc = SimpleNamespace(request_cancel=AsyncMock(return_value=fake_response)) + + with patch( + "tracecat.agent.session.router.AgentSessionService", return_value=fake_svc + ): + raw_cancel_session = cast(Any, cancel_session).__wrapped__ + response = await raw_cancel_session( + session_id=session_id, + role=role, + session=AsyncMock(), + request=None, + ) + + assert response is fake_response + fake_svc.request_cancel.assert_awaited_once() + assert fake_svc.request_cancel.await_args.args[0] == session_id + assert fake_svc.request_cancel.await_args.args[1].reason == "user_cancel" + + +@pytest.mark.anyio +async def test_cancel_session_maps_conflict_to_409() -> None: + session_id = uuid.uuid4() + workspace_id = uuid.uuid4() + role = _read_role(workspace_id) + fake_svc = SimpleNamespace( + request_cancel=AsyncMock( + side_effect=TracecatConflictError( + "not running", + detail={"status": "idle"}, + ) + ) + ) + + with patch( + "tracecat.agent.session.router.AgentSessionService", return_value=fake_svc + ): + raw_cancel_session = cast(Any, cancel_session).__wrapped__ + with pytest.raises(HTTPException) as exc_info: + await raw_cancel_session( + session_id=session_id, + role=role, + session=AsyncMock(), + request=None, + ) + + assert exc_info.value.status_code == status.HTTP_409_CONFLICT + assert exc_info.value.detail == {"status": "idle"} + + @pytest.mark.anyio async def test_send_message_continue_uses_path_session_id_for_stream_key() -> None: session_id = uuid.uuid4() diff --git a/tests/unit/test_vercel_adapter.py b/tests/unit/test_vercel_adapter.py index 10466bb280..5181c96a67 100644 --- a/tests/unit/test_vercel_adapter.py +++ b/tests/unit/test_vercel_adapter.py @@ -1,8 +1,12 @@ from datetime import UTC, datetime +from typing import cast from uuid import uuid4 -from tracecat.agent.adapter.vercel import convert_chat_messages_to_ui +from claude_agent_sdk.types import TextBlock, ToolResultBlock, UserMessage + +from tracecat.agent.adapter.vercel import DataUIPart, convert_chat_messages_to_ui from tracecat.agent.approvals.enums import ApprovalStatus +from tracecat.chat.constants import INTERRUPT_DATA_PART_TYPE from tracecat.chat.enums import MessageKind from tracecat.chat.schemas import ApprovalRead, ChatMessage @@ -34,3 +38,59 @@ def test_convert_chat_messages_to_ui_skips_resolved_approval_request() -> None: messages = convert_chat_messages_to_ui([_approval_message(ApprovalStatus.APPROVED)]) assert messages == [] + + +def test_convert_chat_messages_to_ui_filters_interrupt_text() -> None: + messages = convert_chat_messages_to_ui( + [ + ChatMessage( + id=str(uuid4()), + message=UserMessage( + content=[ + TextBlock(text="[Request interrupted by user]"), + ] + ), + ) + ] + ) + + assert messages == [] + + +def test_convert_chat_messages_to_ui_filters_interrupt_tool_result() -> None: + messages = convert_chat_messages_to_ui( + [ + ChatMessage( + id=str(uuid4()), + message=UserMessage( + content=[ + ToolResultBlock( + tool_use_id="call_123", + content="The user doesn't want to take this action right now.", + is_error=True, + ) + ] + ), + ) + ] + ) + + assert messages == [] + + +def test_convert_chat_messages_to_ui_emits_interrupt_notice() -> None: + messages = convert_chat_messages_to_ui( + [ + ChatMessage( + id=str(uuid4()), + kind=MessageKind.INTERRUPT, + interrupt={"reason": "user_cancel"}, + ) + ] + ) + + assert len(messages) == 1 + assert messages[0].role == "system" + part = cast(DataUIPart, messages[0].parts[0]) + assert part["type"] == INTERRUPT_DATA_PART_TYPE + assert part["data"] == {"reason": "user_cancel"} diff --git a/tests/unit/test_vercel_stream_context.py b/tests/unit/test_vercel_stream_context.py index 5b8775cc1a..8dea768499 100644 --- a/tests/unit/test_vercel_stream_context.py +++ b/tests/unit/test_vercel_stream_context.py @@ -30,6 +30,7 @@ ToolCallContent, UnifiedStreamEvent, ) +from tracecat.chat.constants import INTERRUPT_DATA_PART_TYPE async def collect_frames( @@ -944,6 +945,22 @@ async def test_tool_input_emitted_state_tracking(): assert ctx.tool_input_emitted[tool_call_id] is True +@pytest.mark.anyio +async def test_interrupt_event_emits_data_payload(): + """Test interrupted turns emit a system data part.""" + ctx = VercelStreamContext(message_id="msg_test") + + frames = await collect_frames( + ctx, + [UnifiedStreamEvent.interrupt_event(reason="user_cancel")], + ) + + assert len(frames) == 1 + assert isinstance(frames[0], DataEventPayload) + assert frames[0].type == INTERRUPT_DATA_PART_TYPE + assert frames[0].data.reason == "user_cancel" + + # ============================================================================== # SSE Format Tests # ============================================================================== diff --git a/tracecat/agent/adapter/vercel.py b/tracecat/agent/adapter/vercel.py index f43f93e6cd..879efa33fa 100644 --- a/tracecat/agent/adapter/vercel.py +++ b/tracecat/agent/adapter/vercel.py @@ -27,6 +27,7 @@ from claude_agent_sdk.types import ( AssistantMessage, TextBlock, + ToolResultBlock, ToolUseBlock, UserMessage, ) @@ -76,6 +77,7 @@ APPROVAL_DATA_PART_TYPE, APPROVAL_REQUEST_HEADER, COMPACTION_DATA_PART_TYPE, + INTERRUPT_DATA_PART_TYPE, ) from tracecat.chat.enums import MessageKind from tracecat.logger import logger @@ -694,6 +696,11 @@ class CompactionDataPayload: pre_tokens: int | None = None +@dataclasses.dataclass(slots=True, kw_only=True) +class InterruptDataPayload: + reason: str | None = None + + @dataclasses.dataclass(slots=True, kw_only=True) class DataEventPayload: type: str @@ -1030,6 +1037,17 @@ async def _handle_unified_event( data=payload, ) + case StreamEventType.INTERRUPT: + metadata = event.metadata or {} + reason = metadata.get("reason") + payload = InterruptDataPayload( + reason=reason if isinstance(reason, str) else None + ) + yield DataEventPayload( + type=INTERRUPT_DATA_PART_TYPE, + data=payload, + ) + case StreamEventType.ERROR: yield ErrorEventPayload(errorText=event.error or "Unknown error") @@ -1508,6 +1526,11 @@ def _is_internal_interrupt_message(chat_message: ChatMessage) -> bool: return True if part.text == "No response requested.": return True + elif isinstance(part, ToolResultBlock): + if part.is_error: + text = str(part.content or "") + if "doesn't want to take this action" in text: + return True return False @@ -1542,6 +1565,18 @@ def convert_chat_messages_to_ui( mutable_messages.append(mutable_message) continue + # Handle interrupt status badges from DB (kind=INTERRUPT) + # These show when a chat turn was stopped by the user. + if chat_message.kind == MessageKind.INTERRUPT and chat_message.interrupt: + interrupt_data = chat_message.interrupt + mutable_message = MutableMessage( + id=chat_message.id, + role="system", + parts=[DataUIPart(type=INTERRUPT_DATA_PART_TYPE, data=interrupt_data)], + ) + mutable_messages.append(mutable_message) + continue + # Handle approval request bubbles from DB (kind=APPROVAL_REQUEST) # These are inserted by list_messages() when loading session history if chat_message.kind == MessageKind.APPROVAL_REQUEST and chat_message.approval: diff --git a/tracecat/agent/common/stream_types.py b/tracecat/agent/common/stream_types.py index a15e347bfb..ffe7945681 100644 --- a/tracecat/agent/common/stream_types.py +++ b/tracecat/agent/common/stream_types.py @@ -44,6 +44,7 @@ class StreamEventType(StrEnum): # System/status events COMPACTION = "compaction" + INTERRUPT = "interrupt" # Control events ERROR = "error" @@ -208,6 +209,14 @@ def compaction_event( metadata=event_metadata, ) + @classmethod + def interrupt_event(cls, *, reason: str | None = None) -> UnifiedStreamEvent: + """Factory method for creating interrupted turn status events.""" + metadata: dict[str, Any] = {} + if reason is not None: + metadata["reason"] = reason + return cls(type=StreamEventType.INTERRUPT, metadata=metadata) + @classmethod def tool_result_event( cls, diff --git a/tracecat/agent/executor/activity.py b/tracecat/agent/executor/activity.py index ef51eaedec..850255ccdc 100644 --- a/tracecat/agent/executor/activity.py +++ b/tracecat/agent/executor/activity.py @@ -51,6 +51,7 @@ LLMSocketProxy, ) from tracecat.agent.session.service import AgentSessionService +from tracecat.agent.session.types import AgentCancelReason from tracecat.agent.skill.service import SkillService from tracecat.agent.types import AgentConfig from tracecat.auth.types import Role @@ -69,6 +70,9 @@ ToolExecutionResult, ) +GRACEFUL_CANCEL_TIMEOUT_SECONDS = 30 +AGENT_ACTIVITY_HEARTBEAT_INTERVAL_SECONDS = 2 + class AgentExecutorInput(BaseModel): """Input for the agent executor activity.""" @@ -126,6 +130,8 @@ class AgentExecutorResult(BaseModel): ) result_usage: dict[str, Any] | None = None result_num_turns: int | None = None + cancelled: bool = False + cancelled_reason: AgentCancelReason | None = None class ExecuteApprovedToolsInput(BaseModel): @@ -400,6 +406,16 @@ def _apply_loopback_result( result.output = loopback_result.output result.result_usage = loopback_result.result_usage result.result_num_turns = loopback_result.result_num_turns + result.cancelled = loopback_result.cancelled + result.cancelled_reason = loopback_result.cancelled_reason + + @staticmethod + def _activity_cancel_reason() -> AgentCancelReason: + """Map Temporal activity cancellation metadata to agent cancellation reason.""" + details = activity.cancellation_details() + if details is not None and details.worker_shutdown: + return "worker_drain" + return "user_cancel" async def _run_with_broker( self, @@ -441,7 +457,7 @@ async def wait_fatal_error() -> str: return self._fatal_error or "Unknown LLM error" fatal_error_task = asyncio.create_task(wait_fatal_error()) - heartbeat_interval = 30 + heartbeat_interval = AGENT_ACTIVITY_HEARTBEAT_INTERVAL_SECONDS elapsed = 0 try: @@ -492,8 +508,37 @@ async def wait_fatal_error() -> str: if not isinstance(e, ConcurrentSessionTurnError): raise except asyncio.CancelledError: - await broker.cancel_turn(str(self.input.session_id)) - raise + reason = self._activity_cancel_reason() + logger.info( + "Agent activity cancellation requested", + session_id=self.input.session_id, + reason=reason, + ) + handler.mark_cancelled(reason) + try: + async with asyncio.timeout(GRACEFUL_CANCEL_TIMEOUT_SECONDS): + await broker.interrupt_turn(str(self.input.session_id), reason) + await broker_task + except TimeoutError: + logger.warning( + "Timed out gracefully cancelling agent activity", + session_id=self.input.session_id, + timeout_seconds=GRACEFUL_CANCEL_TIMEOUT_SECONDS, + ) + await broker.cancel_turn(str(self.input.session_id)) + raise + except asyncio.CancelledError: + await broker.cancel_turn(str(self.input.session_id)) + raise + + self._apply_loopback_result(result, handler.build_result()) + result.success = True + result.cancelled = True + result.cancelled_reason = reason + self._log_benchmark_phase( + "broker_activity_cancelled", + reason=reason, + ) finally: for task in (fatal_error_task, broker_task): if not task.done(): @@ -727,7 +772,9 @@ async def run_agent_activity(input: AgentExecutorInput) -> AgentExecutorResult: executor = SandboxedAgentExecutor(input=input) result = await executor.run() - if result.success: + if result.cancelled: + activity.heartbeat(f"Agent execution cancelled: {input.session_id}") + elif result.success: activity.heartbeat(f"Agent execution completed: {input.session_id}") else: activity.heartbeat(f"Agent execution failed: {result.error}") diff --git a/tracecat/agent/executor/loopback.py b/tracecat/agent/executor/loopback.py index 7b427cb54c..ce6d823ea8 100644 --- a/tracecat/agent/executor/loopback.py +++ b/tracecat/agent/executor/loopback.py @@ -36,7 +36,7 @@ ToolCallContent, UnifiedStreamEvent, ) -from tracecat.agent.session.types import AgentSessionEntity +from tracecat.agent.session.types import AgentCancelReason, AgentSessionEntity from tracecat.agent.stream.connector import AgentStream from tracecat.db.engine import ( get_async_session_bypass_rls_context_manager, @@ -96,6 +96,8 @@ class LoopbackResult: output: RuntimeOutput | None = None result_usage: ResultUsage | None = None result_num_turns: int | None = None + cancelled: bool = False + cancelled_reason: AgentCancelReason | None = None @dataclass(frozen=True, kw_only=True, slots=True) @@ -247,6 +249,7 @@ def __init__(self, input: LoopbackInput) -> None: self._result = LoopbackResult(success=False) self._sdk_session_id: str | None = None # Track SDK session ID for this run self._stream_done_emitted: bool = False # Dedupe flag for stream.done() + self._interrupt_notice_emitted: bool = False # Track which session lines have been persisted to avoid duplicates self._persisted_line_uuids: set[str] = set() # Track pending approval tool IDs to suppress synthetic interruption results. @@ -624,6 +627,21 @@ async def send_error(self, error: str) -> None: """Handle a terminal runtime error.""" await self._handle_error(error) + def mark_cancelled(self, reason: AgentCancelReason) -> None: + """Record that the active runtime turn is expected to stop early.""" + self._result.cancelled = True + self._result.cancelled_reason = reason + + async def _emit_interrupt_notice_if_cancelled( + self, stream_sink: LoopbackEventSink + ) -> None: + if not self._result.cancelled or self._interrupt_notice_emitted: + return + self._interrupt_notice_emitted = True + await stream_sink.append( + UnifiedStreamEvent.interrupt_event(reason=self._result.cancelled_reason) + ) + async def _handle_done(self) -> bool: """Handle runtime completion.""" stream_sink = await self.prepare() @@ -641,6 +659,7 @@ async def _handle_done(self) -> bool: self._result.error = validation_error return True self._result.success = True + await self._emit_interrupt_notice_if_cancelled(stream_sink) await self._emit_stream_done() return True @@ -889,6 +908,8 @@ def _validate_runtime_completion(self) -> str | None: """Return a terminal error when runtime completion is missing a usable result.""" if self._result.approval_requested or self._result.error is not None: return None + if self._result.cancelled: + return None if not self._received_result: return "Runtime completed without final result" if ( diff --git a/tracecat/agent/runtime/claude_code/broker.py b/tracecat/agent/runtime/claude_code/broker.py index 81aeebec96..0b200e0246 100644 --- a/tracecat/agent/runtime/claude_code/broker.py +++ b/tracecat/agent/runtime/claude_code/broker.py @@ -17,6 +17,7 @@ from tracecat.agent.runtime.claude_code.transport import ( SandboxedCLITransport, ) +from tracecat.agent.session.types import AgentCancelReason class ConcurrentSessionTurnError(RuntimeError): @@ -35,12 +36,21 @@ class ClaudeTurnRequest: skills_dir: Path | None = None +@dataclass(frozen=True, slots=True) +class ActiveTurn: + """Runtime state for one broker-managed session turn.""" + + task: asyncio.Task[None] + runtime: ClaudeAgentRuntime + session_id: str + + class ClaudeRuntimeBroker: """Warm in-process broker that owns host-side Claude orchestration.""" def __init__(self) -> None: self._closed = False - self._active_turns: dict[str, asyncio.Task[None]] = {} + self._active_turns: dict[str, ActiveTurn] = {} self._lock = asyncio.Lock() async def start(self) -> None: @@ -51,7 +61,7 @@ async def stop(self) -> None: """Cancel any remaining active turns and reject future work.""" async with self._lock: self._closed = True - active_tasks = list(self._active_turns.values()) + active_tasks = [turn.task for turn in self._active_turns.values()] for task in active_tasks: task.cancel() for task in active_tasks: @@ -71,6 +81,27 @@ async def run_turn( if current_task is None: raise RuntimeError("Broker turn must run inside an asyncio task") + path_mapping = self._build_path_mapping( + session_id=str(request.init_payload.session_id) + ) + runtime = ClaudeAgentRuntime( + handler, + transport_factory=lambda options: SandboxedCLITransport( + options=options, + socket_dir=request.socket_dir, + llm_socket_path=request.llm_socket_path, + job_dir=request.job_dir, + path_mapping=path_mapping, + enable_internet_access=request.enable_internet_access, + use_jailed_paths=not TRACECAT__DISABLE_NSJAIL, + session_id=str(request.init_payload.session_id), + skills_dir=request.skills_dir, + ), + session_home_dir=path_mapping.host_home_dir, + cwd=path_mapping.runtime_cwd, + cwd_setup_path=path_mapping.host_project_dir, + ) + async with self._lock: if self._closed: raise RuntimeError("Claude runtime broker is not running") @@ -78,30 +109,14 @@ async def run_turn( raise ConcurrentSessionTurnError( f"Session {session_key} already has an active turn" ) - self._active_turns[session_key] = current_task + self._active_turns[session_key] = ActiveTurn( + task=current_task, + runtime=runtime, + session_id=session_key, + ) try: await handler.prepare() - path_mapping = self._build_path_mapping( - session_id=str(request.init_payload.session_id) - ) - runtime = ClaudeAgentRuntime( - handler, - transport_factory=lambda options: SandboxedCLITransport( - options=options, - socket_dir=request.socket_dir, - llm_socket_path=request.llm_socket_path, - job_dir=request.job_dir, - path_mapping=path_mapping, - enable_internet_access=request.enable_internet_access, - use_jailed_paths=not TRACECAT__DISABLE_NSJAIL, - session_id=str(request.init_payload.session_id), - skills_dir=request.skills_dir, - ), - session_home_dir=path_mapping.host_home_dir, - cwd=path_mapping.runtime_cwd, - cwd_setup_path=path_mapping.host_project_dir, - ) await runtime.run(request.init_payload) finally: async with self._lock: @@ -110,9 +125,27 @@ async def run_turn( async def cancel_turn(self, session_id: str) -> None: """Cancel an active turn for the provided session, if one exists.""" async with self._lock: - task = self._active_turns.get(session_id) - if task is not None: - task.cancel() + active = self._active_turns.get(session_id) + if active is not None: + active.task.cancel() + + async def interrupt_turn( + self, + session_id: str, + reason: AgentCancelReason, + timeout: float | None = None, + ) -> None: + """Interrupt an active turn through the runtime, if one exists.""" + async with self._lock: + active = self._active_turns.get(session_id) + if active is None: + return + + interrupt = active.runtime.interrupt(reason=reason) + if timeout is None: + await interrupt + else: + await asyncio.wait_for(interrupt, timeout=timeout) @staticmethod def _build_path_mapping(*, session_id: str) -> ClaudeSandboxPathMapping: diff --git a/tracecat/agent/runtime/claude_code/runtime.py b/tracecat/agent/runtime/claude_code/runtime.py index 4c27c4aa60..0c4312be87 100644 --- a/tracecat/agent/runtime/claude_code/runtime.py +++ b/tracecat/agent/runtime/claude_code/runtime.py @@ -80,6 +80,7 @@ is_meta_session_line, is_synthetic_session_line, ) +from tracecat.agent.session.types import AgentCancelReason from tracecat.logger import logger CLAUDE_PROJECT_DIR_MAX_LENGTH = 200 @@ -266,6 +267,11 @@ def __init__( self._pending_approval_tool_ids: set[str] = set() self.client: ClaudeSDKClient | None = None self._was_interrupted: bool = False + self._interrupt_requested: AgentCancelReason | None = None + self._interrupt_sent: bool = False + self._client_connected_event = asyncio.Event() + self._query_sent_event = asyncio.Event() + self._interrupt_lock = asyncio.Lock() # For incremental JSONL line tracking self._sdk_session_id: str | None = None self._last_seen_byte_offset: int = 0 @@ -863,6 +869,34 @@ async def _handle_approval_request( self._was_interrupted = True await self.client.interrupt() + async def interrupt(self, *, reason: AgentCancelReason) -> None: + """Request a graceful interrupt for this turn. + + The broker can call this before the SDK client is ready. In that case we + retain the requested reason and self-interrupt once the query is sent. + """ + self._interrupt_requested = reason + await self._client_connected_event.wait() + await self._query_sent_event.wait() + await self._send_pending_interrupt() + + async def _send_pending_interrupt(self) -> None: + """Deliver the pending user/worker interrupt to Claude once.""" + async with self._interrupt_lock: + reason = self._interrupt_requested + if reason is None or self._interrupt_sent or self.client is None: + return + + logger.info( + "Interrupting Claude runtime", + session_id=self._session_id, + reason=reason, + ) + self._was_interrupted = True + self._interrupt_sent = True + await self.client.interrupt() + await self._emit_new_session_lines() + async def _pre_tool_use_hook( self, input_data: HookInput, @@ -1074,6 +1108,9 @@ def _configure_runtime_state(self, payload: RuntimeInitPayload) -> None: self._sdk_session_id = None self._last_seen_byte_offset = 0 self._session_flush_event.clear() + self._client_connected_event.clear() + self._query_sent_event.clear() + self._interrupt_sent = False self.registry_tools = payload.allowed_actions self.tool_approvals = payload.config.tool_approvals self._agents_enabled = payload.config.agents.enabled @@ -1366,6 +1403,7 @@ async def drain_stderr() -> None: logger.debug("Client created, entering context") async with client: self.client = client + self._client_connected_event.set() log_benchmark_phase("runtime_client_connected") stderr_task = asyncio.create_task(drain_stderr()) session_flush_task = asyncio.create_task( @@ -1387,6 +1425,8 @@ async def drain_stderr() -> None: self._build_compaction_status_event(phase="started") ) await client.query(query_input) + self._query_sent_event.set() + await self._send_pending_interrupt() log_benchmark_phase("runtime_query_sent") await self._event_writer.send_log( diff --git a/tracecat/agent/session/activities.py b/tracecat/agent/session/activities.py index 914cd87560..afe5e66fea 100644 --- a/tracecat/agent/session/activities.py +++ b/tracecat/agent/session/activities.py @@ -18,7 +18,7 @@ from tracecat.agent.executor.schemas import ToolExecutionResult from tracecat.agent.session.schemas import AgentSessionCreate from tracecat.agent.session.service import AgentSessionService -from tracecat.agent.session.types import AgentSessionEntity +from tracecat.agent.session.types import AgentSessionEntity, AgentSessionStatus from tracecat.agent.stream.connector import AgentStream from tracecat.agent.subagents import ResolvedAgentsConfig from tracecat.auth.types import Role @@ -117,6 +117,21 @@ class LoadSessionMessagesResult(BaseModel): error: str | None = None +class UpdateSessionStatusInput(BaseModel): + """Input for update_session_status_activity.""" + + role: Role + session_id: uuid.UUID + status: AgentSessionStatus + clear_curr_run_id: bool = False + + +class UpdateSessionStatusResult(BaseModel): + """Result from update_session_status_activity.""" + + found: bool + + @activity.defn async def create_session_activity(input: CreateSessionInput) -> CreateSessionResult: """Create or get an existing agent session in the database. @@ -203,6 +218,7 @@ async def create_session_activity(input: CreateSessionInput) -> CreateSessionRes # Set curr_run_id if provided (for workflow-initiated sessions) if input.curr_run_id is not None: agent_session.curr_run_id = input.curr_run_id + agent_session.status = AgentSessionStatus.RUNNING.value service.session.add(agent_session) await service.session.commit() @@ -303,6 +319,22 @@ async def load_session_activity(input: LoadSessionInput) -> LoadSessionResult: return LoadSessionResult(found=False, error=str(e)) +@activity.defn +async def update_session_status_activity( + input: UpdateSessionStatusInput, +) -> UpdateSessionStatusResult: + """Update lightweight agent session lifecycle state.""" + ctx_role.set(input.role) + + async with AgentSessionService.with_session(role=input.role) as service: + found = await service.set_session_status( + input.session_id, + input.status, + clear_curr_run_id=input.clear_curr_run_id, + ) + return UpdateSessionStatusResult(found=found) + + @activity.defn async def load_session_messages_activity( input: LoadSessionMessagesInput, @@ -384,6 +416,7 @@ def get_session_activities() -> list: return [ create_session_activity, load_session_activity, + update_session_status_activity, load_session_messages_activity, reconcile_tool_results_activity, ] diff --git a/tracecat/agent/session/router.py b/tracecat/agent/session/router.py index 5da32d4ff9..3623da7b2b 100644 --- a/tracecat/agent/session/router.py +++ b/tracecat/agent/session/router.py @@ -11,6 +11,8 @@ from tracecat import config from tracecat.agent.adapter import vercel from tracecat.agent.session.schemas import ( + AgentSessionCancelRequest, + AgentSessionCancelResponse, AgentSessionCreate, AgentSessionForkRequest, AgentSessionRead, @@ -19,7 +21,7 @@ AgentSessionUpdate, ) from tracecat.agent.session.service import AgentSessionService -from tracecat.agent.session.types import AgentSessionEntity +from tracecat.agent.session.types import AgentSessionEntity, AgentSessionStatus from tracecat.agent.stream.connector import AgentStream from tracecat.agent.stream.events import StreamFormat from tracecat.agent.subagents import ResolvedAgentsConfig @@ -33,7 +35,7 @@ ContinueRunRequest, ) from tracecat.db.dependencies import AsyncDBSession -from tracecat.exceptions import TracecatNotFoundError +from tracecat.exceptions import TracecatConflictError, TracecatNotFoundError from tracecat.logger import logger router = APIRouter(prefix="/agent/sessions", tags=["agent-sessions"]) @@ -134,6 +136,7 @@ async def get_session( created_at=agent_session.created_at, updated_at=agent_session.updated_at, last_stream_id=agent_session.last_stream_id, + turn_status=AgentSessionStatus(agent_session.status), messages=messages, ) @@ -203,6 +206,7 @@ async def get_session_vercel( created_at=agent_session.created_at, updated_at=agent_session.updated_at, last_stream_id=agent_session.last_stream_id, + turn_status=AgentSessionStatus(agent_session.status), messages=ui_messages, ) @@ -288,6 +292,33 @@ async def delete_session( await svc.delete_session(agent_session) +@router.post("/{session_id}/cancel") +@require_scope("agent:execute") +async def cancel_session( + session_id: uuid.UUID, + role: WorkspaceUserRouteRole, + session: AsyncDBSession, + request: AgentSessionCancelRequest | None = None, +) -> AgentSessionCancelResponse: + """Request graceful cancellation for the active agent session turn.""" + svc = AgentSessionService(session, role) + try: + return await svc.request_cancel( + session_id, + request or AgentSessionCancelRequest(), + ) + except TracecatNotFoundError as e: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=str(e), + ) from e + except TracecatConflictError as e: + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail=e.detail or str(e), + ) from e + + @router.post("/{session_id}/messages") @require_scope("agent:execute") async def send_message( @@ -380,6 +411,11 @@ async def send_message( status_code=status.HTTP_404_NOT_FOUND, detail=str(e), ) from e + except TracecatConflictError as e: + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail=e.detail or str(e), + ) from e except ValueError as e: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, diff --git a/tracecat/agent/session/schemas.py b/tracecat/agent/session/schemas.py index 506f5acd93..a575579850 100644 --- a/tracecat/agent/session/schemas.py +++ b/tracecat/agent/session/schemas.py @@ -10,7 +10,11 @@ from tracecat.agent.adapter.vercel import UIMessage from tracecat.agent.common.stream_types import HarnessType -from tracecat.agent.session.types import AgentSessionEntity +from tracecat.agent.session.types import ( + AgentCancelReason, + AgentSessionEntity, + AgentSessionStatus, +) from tracecat.agent.subagents import ResolvedAgentsConfig @@ -121,6 +125,11 @@ class AgentSessionRead(BaseModel): harness_type: str | None # Stream tracking last_stream_id: str | None = None + # Lifecycle + turn_status: AgentSessionStatus = Field( + default=AgentSessionStatus.IDLE, + validation_alias="status", + ) # Fork tracking parent_session_id: uuid.UUID | None = None # Timestamps @@ -154,3 +163,21 @@ class AgentSessionForkRequest(BaseModel): description="Override entity type for the forked session. " "Use 'approval' for inbox forks to hide from main chat list.", ) + + +class AgentSessionCancelRequest(BaseModel): + """Request schema for cancelling an active agent session turn.""" + + reason: AgentCancelReason = Field( + default="user_cancel", + description="Reason for requesting cancellation.", + ) + + +class AgentSessionCancelResponse(BaseModel): + """Response schema for an accepted agent session cancellation request.""" + + session_id: uuid.UUID + run_id: uuid.UUID + reason: AgentCancelReason + turn_status: AgentSessionStatus diff --git a/tracecat/agent/session/service.py b/tracecat/agent/session/service.py index 9e6c680a4b..e2a059adb5 100644 --- a/tracecat/agent/session/service.py +++ b/tracecat/agent/session/service.py @@ -8,7 +8,7 @@ from collections.abc import AsyncIterator, Sequence from dataclasses import dataclass, replace from datetime import UTC, datetime -from typing import TYPE_CHECKING, Any, Literal +from typing import TYPE_CHECKING, Any, Literal, cast import orjson from pydantic_ai.messages import ( @@ -28,6 +28,7 @@ update, ) from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.engine import CursorResult from sqlalchemy.exc import SQLAlchemyError from temporalio.common import TypedSearchAttributes from tracecat_registry._internal.exceptions import SecretNotFoundError @@ -42,6 +43,7 @@ from tracecat.agent.runtime.claude_code.session_lines import ( APPROVAL_INTERRUPT_CONTENT_EXACT, APPROVAL_INTERRUPT_CONTENT_MARKERS, + is_approval_interrupt_content, is_approval_interrupt_tool_result, is_continuation_control_artifact, session_line_uuid, @@ -49,12 +51,14 @@ from tracecat.agent.schemas import RunAgentArgs from tracecat.agent.service import AgentManagementService from tracecat.agent.session.schemas import ( + AgentSessionCancelRequest, + AgentSessionCancelResponse, AgentSessionCreate, AgentSessionRead, AgentSessionUpdate, ) from tracecat.agent.session.title_generator import generate_session_title -from tracecat.agent.session.types import AgentSessionEntity +from tracecat.agent.session.types import AgentSessionEntity, AgentSessionStatus from tracecat.agent.subagents import ( ResolvedAgentsConfig, ) @@ -80,7 +84,11 @@ from tracecat.db.models import AgentSession, AgentSessionHistory, Approval, Chat from tracecat.dsl.client import get_temporal_client from tracecat.dsl.common import RETRY_POLICIES -from tracecat.exceptions import TracecatNotFoundError, TracecatValidationError +from tracecat.exceptions import ( + TracecatConflictError, + TracecatNotFoundError, + TracecatValidationError, +) from tracecat.identifiers import UserID from tracecat.logger import logger from tracecat.redis.client import get_redis_client @@ -555,6 +563,76 @@ async def update_last_stream_id( await self.session.refresh(agent_session) return agent_session + async def set_session_status( + self, + session_id: uuid.UUID, + status: AgentSessionStatus, + *, + clear_curr_run_id: bool = False, + ) -> bool: + """Update an agent session lifecycle status by ID.""" + values: dict[str, Any] = {"status": status.value} + if clear_curr_run_id: + values["curr_run_id"] = None + + stmt = ( + update(AgentSession) + .where( + AgentSession.id == session_id, + AgentSession.workspace_id == self.workspace_id, + ) + .values(**values) + ) + result = await self.session.execute(stmt) + await self.session.commit() + return bool(cast(CursorResult[Any], result).rowcount) + + async def request_cancel( + self, + session_id: uuid.UUID, + request: AgentSessionCancelRequest, + ) -> AgentSessionCancelResponse: + """Request graceful cancellation for the active agent workflow turn.""" + from tracecat_ee.agent.types import AgentWorkflowID + from tracecat_ee.agent.workflows.durable import ( + DurableAgentWorkflow, + WorkflowCancelRequest, + ) + + agent_session = await self.get_session(session_id) + if agent_session is None: + raise TracecatNotFoundError(f"Session with ID {session_id} not found") + + current_status = AgentSessionStatus(agent_session.status) + if current_status is not AgentSessionStatus.RUNNING: + raise TracecatConflictError( + "Agent session is not currently running", + detail={"status": current_status.value}, + ) + + curr_run_id = agent_session.curr_run_id + if curr_run_id is None: + raise TracecatConflictError("Agent session has no active workflow run") + + client = await get_temporal_client() + workflow_id = AgentWorkflowID(curr_run_id) + handle = client.get_workflow_handle_for( + DurableAgentWorkflow.run, + str(workflow_id), + ) + + await handle.execute_update( + DurableAgentWorkflow.request_cancel, + WorkflowCancelRequest(reason=request.reason), + ) + + return AgentSessionCancelResponse( + session_id=session_id, + run_id=curr_run_id, + reason=request.reason, + turn_status=current_status, + ) + # ========================================================================= # Session History Management (for Claude SDK session persistence) # ========================================================================= @@ -1074,6 +1152,7 @@ async def run_turn( # Update session with current run_id for approval lookups agent_session.curr_run_id = run_id + agent_session.status = AgentSessionStatus.RUNNING.value self.session.add(agent_session) await self.session.commit() @@ -1089,16 +1168,22 @@ async def run_turn( task_queue=config.TRACECAT__AGENT_QUEUE, ) - await client.start_workflow( - DurableAgentWorkflow.run, - workflow_args, - id=str(workflow_id), - task_queue=config.TRACECAT__AGENT_QUEUE, - retry_policy=RETRY_POLICIES["workflow:fail_fast"], - search_attributes=self._build_direct_agent_search_attributes( - session_id - ), - ) + try: + await client.start_workflow( + DurableAgentWorkflow.run, + workflow_args, + id=str(workflow_id), + task_queue=config.TRACECAT__AGENT_QUEUE, + retry_policy=RETRY_POLICIES["workflow:fail_fast"], + search_attributes=self._build_direct_agent_search_attributes( + session_id + ), + ) + except Exception: + agent_session.status = AgentSessionStatus.FAILED.value + self.session.add(agent_session) + await self.session.commit() + raise # Return ChatResponse with session_id for streaming stream_url = f"/api/agent/sessions/{session_id}/stream" @@ -1122,6 +1207,10 @@ async def validate_turn_request( ) return agent_session case VercelChatRequest() | BasicChatRequest(): + if agent_session.status == AgentSessionStatus.RUNNING.value: + raise TracecatConflictError( + "This session already has a running turn." + ) if await self.has_pending_approvals(session_id): raise ValueError( "This session is waiting for approval decisions. " @@ -1552,6 +1641,9 @@ async def list_messages( # Internal entries contain tool results that the adapter will extract messages: list[ChatMessage] = [] internal_uuids: set[str] = set() + show_interrupt_notice = ( + getattr(agent_session, "status", None) == AgentSessionStatus.STOPPED.value + ) for entry in all_entries: content = entry.content if not content: @@ -1559,6 +1651,21 @@ async def list_messages( # Skip internal entries (e.g., continuation prompts) if entry.kind == MessageKind.INTERNAL.value: + if ( + show_interrupt_notice + and getattr(entry, "session_id", session_id) == session_id + and self._is_interrupt_session_line(content) + and not (messages and messages[-1].kind == MessageKind.INTERRUPT) + ): + kind = MessageKind.INTERRUPT + if not kinds or kind in kinds: + messages.append( + ChatMessage( + id=str(entry.id), + kind=kind, + interrupt={"reason": "user_cancel"}, + ) + ) if line_uuid := session_line_uuid(content): internal_uuids.add(line_uuid) continue @@ -1649,6 +1756,36 @@ async def list_messages( return messages + @staticmethod + def _is_interrupt_session_line(content: dict[str, Any]) -> bool: + """Return True for internal SDK rows that represent a user interruption.""" + message = content.get("message") + if not isinstance(message, dict): + return False + + msg_content = message.get("content") + if isinstance(msg_content, str): + return is_approval_interrupt_content(msg_content) + if not isinstance(msg_content, list): + return False + + for block in msg_content: + if isinstance(block, dict): + if block.get("type") == "text" and is_approval_interrupt_content( + block.get("text", "") + ): + return True + if ( + block.get("type") == "tool_result" + and block.get("is_error") is True + and is_approval_interrupt_content(block.get("content", "")) + ): + return True + elif is_approval_interrupt_content(block): + return True + + return False + @staticmethod def _extract_tool_uses_from_message( message: dict[str, Any], diff --git a/tracecat/agent/session/types.py b/tracecat/agent/session/types.py index 037b3bd939..2bae61fe58 100644 --- a/tracecat/agent/session/types.py +++ b/tracecat/agent/session/types.py @@ -1,6 +1,7 @@ """Domain types for agent session management.""" from enum import StrEnum +from typing import Literal class AgentSessionEntity(StrEnum): @@ -23,3 +24,16 @@ class AgentSessionEntity(StrEnum): WORKFLOW = "workflow" APPROVAL = "approval" EXTERNAL_CHANNEL = "external_channel" + + +class AgentSessionStatus(StrEnum): + """Lifecycle state for an agent session turn.""" + + IDLE = "idle" + RUNNING = "running" + WAITING_FOR_APPROVAL = "waiting_for_approval" + STOPPED = "stopped" + FAILED = "failed" + + +type AgentCancelReason = Literal["user_cancel", "worker_drain"] diff --git a/tracecat/chat/constants.py b/tracecat/chat/constants.py index f265b0b952..6484012533 100644 --- a/tracecat/chat/constants.py +++ b/tracecat/chat/constants.py @@ -8,3 +8,6 @@ COMPACTION_DATA_PART_TYPE = "data-compaction" """UI data part identifier for transient compaction status payloads.""" + +INTERRUPT_DATA_PART_TYPE = "data-interrupt" +"""UI data part identifier for interrupted chat turn notifications.""" diff --git a/tracecat/chat/enums.py b/tracecat/chat/enums.py index 915bf8809f..0bfebea325 100644 --- a/tracecat/chat/enums.py +++ b/tracecat/chat/enums.py @@ -24,3 +24,4 @@ class MessageKind(StrEnum): COMPACTION = ( "compaction" # Compaction status badge shown when conversation is compacted ) + INTERRUPT = "interrupt" # Status badge shown when a chat turn is interrupted diff --git a/tracecat/chat/schemas.py b/tracecat/chat/schemas.py index 1633b71467..b699a7b0bc 100644 --- a/tracecat/chat/schemas.py +++ b/tracecat/chat/schemas.py @@ -199,6 +199,7 @@ class ChatMessage(BaseModel): - kind=CHAT_MESSAGE: Contains message field with user/assistant content - kind=APPROVAL_REQUEST/APPROVAL_DECISION: Contains approval field with approval data - kind=COMPACTION: Contains compaction field with compaction status data + - kind=INTERRUPT: Contains interrupt field with interruption status data """ id: str = Field(..., description="Unique message identifier") @@ -218,6 +219,10 @@ class ChatMessage(BaseModel): default=None, description="Compaction status data for badge rendering (for kind=COMPACTION)", ) + interrupt: dict[str, Any] | None = Field( + default=None, + description="Interruption status data for badge rendering (for kind=INTERRUPT)", + ) @classmethod def from_db(cls, db_msg: models.ChatMessage) -> ChatMessage: diff --git a/tracecat/db/models.py b/tracecat/db/models.py index d448df27f8..02b09d0972 100644 --- a/tracecat/db/models.py +++ b/tracecat/db/models.py @@ -2850,6 +2850,14 @@ class AgentSession(WorkspaceModel): nullable=True, doc="Last processed Redis stream ID - used to resume streaming from correct position", ) + status: Mapped[str] = mapped_column( + String(32), + default="idle", + server_default="idle", + nullable=False, + index=True, + doc="Agent session lifecycle status for active turn coordination", + ) # Parent session for forked sessions (approval continuations) parent_session_id: Mapped[uuid.UUID | None] = mapped_column( UUID, From e589dd7826dc1ebd68f591b5fa1070ce58da614c Mon Sep 17 00:00:00 2001 From: Jordan Umusu Date: Wed, 27 May 2026 11:45:47 -0400 Subject: [PATCH 2/5] fix(agent): clear running state on workflow failures --- .../tracecat_ee/agent/workflows/durable.py | 16 +++++++ tests/temporal/test_durable_agent_workflow.py | 43 +++++++++++++++++++ tests/unit/test_agent_activities.py | 37 +++++++++++++++- tests/unit/test_agent_session_router.py | 10 +++-- tracecat/agent/session/activities.py | 3 +- tracecat/agent/session/schemas.py | 6 +-- 6 files changed, 107 insertions(+), 8 deletions(-) diff --git a/packages/tracecat-ee/tracecat_ee/agent/workflows/durable.py b/packages/tracecat-ee/tracecat_ee/agent/workflows/durable.py index 6be784ea56..f8cb1ac540 100644 --- a/packages/tracecat-ee/tracecat_ee/agent/workflows/durable.py +++ b/packages/tracecat-ee/tracecat_ee/agent/workflows/durable.py @@ -342,6 +342,7 @@ def _resolve_agent_output( LOAD_TERMINAL_MESSAGE_HISTORY_PATCH = "durable-agent-load-terminal-message-history-v1" AGENT_ACTIVITY_GRACEFUL_CANCEL_PATCH = "durable-agent-activity-graceful-cancel-v1" AGENT_SESSION_STATUS_PATCH = "durable-agent-session-status-v1" +AGENT_SESSION_FAILURE_STATUS_PATCH = "durable-agent-session-failure-status-v1" AGENT_ACTIVITY_GRACEFUL_CANCEL_HEARTBEAT_TIMEOUT_SECONDS = 10 @@ -718,6 +719,11 @@ async def run(self, args: AgentWorkflowArgs) -> AgentOutput: ): await self._emit_session_error(e.message) raise + except TemporalCancelledError: + raise + except Exception: + await self._set_agent_session_failed_for_unhandled_failure() + raise async def _emit_session_error(self, message: str) -> None: try: @@ -818,6 +824,16 @@ async def _set_agent_session_status( error=str(e), ) + async def _set_agent_session_failed_for_unhandled_failure(self) -> None: + """Mark failed for post-deploy unhandled workflow failures.""" + if not workflow.patched(AGENT_SESSION_FAILURE_STATUS_PATCH): + return + + await self._set_agent_session_status( + AgentSessionStatus.FAILED, + clear_curr_run_id=True, + ) + async def _run_agent_activity_turn( self, executor_input: AgentExecutorInput, diff --git a/tests/temporal/test_durable_agent_workflow.py b/tests/temporal/test_durable_agent_workflow.py index 3abaacd257..4487cc470e 100644 --- a/tests/temporal/test_durable_agent_workflow.py +++ b/tests/temporal/test_durable_agent_workflow.py @@ -765,6 +765,49 @@ async def mock_run_agent_activity( assert status_update_inputs[-1].clear_curr_run_id is True +@pytest.mark.anyio +@pytest.mark.integration +async def test_agent_workflow_marks_failed_when_executor_activity_errors( + temporal_client: Client, + agent_worker_factory, + agent_workflow_args: AgentWorkflowArgs, + mock_session_id: uuid.UUID, +) -> None: + """Executor ActivityError failures should release the running session gate.""" + queue = f"test-agent-queue-{mock_session_id}" + status_update_inputs: list[UpdateSessionStatusInput] = [] + + def mock_executor( + _call_count: int, + _input: AgentExecutorInput, + ) -> AgentExecutorResult: + raise ApplicationError("sandbox worker timed out") + + activities = create_activities_with_mock_executor( + mock_executor, + status_update_inputs=status_update_inputs, + ) + + async with agent_worker_factory( + temporal_client, task_queue=queue, custom_activities=activities + ): + with pytest.raises(WorkflowFailureError): + await temporal_client.execute_workflow( + DurableAgentWorkflow.run, + agent_workflow_args, + id=AgentWorkflowID(mock_session_id), + task_queue=queue, + retry_policy=RETRY_POLICIES["workflow:fail_fast"], + execution_timeout=timedelta(seconds=30), + ) + + assert [input.status for input in status_update_inputs] == [ + AgentSessionStatus.RUNNING, + AgentSessionStatus.FAILED, + ] + assert status_update_inputs[-1].clear_curr_run_id is True + + @pytest.mark.anyio @pytest.mark.integration async def test_agent_workflow_preserves_legacy_activity_message_history( diff --git a/tests/unit/test_agent_activities.py b/tests/unit/test_agent_activities.py index b2be54827d..b0052abc04 100644 --- a/tests/unit/test_agent_activities.py +++ b/tests/unit/test_agent_activities.py @@ -47,7 +47,7 @@ load_session_activity, load_session_messages_activity, ) -from tracecat.agent.session.types import AgentSessionEntity +from tracecat.agent.session.types import AgentSessionEntity, AgentSessionStatus from tracecat.agent.skill.types import ResolvedSkillRef from tracecat.agent.subagents import ResolvedAgentsConfig from tracecat.agent.tools import BuildToolsResult @@ -396,6 +396,41 @@ async def test_idempotent_for_existing_session( assert result.success is True assert result.session_id == mock_session_id + @pytest.mark.anyio + @patch("tracecat.agent.session.activities.AgentSessionService.with_session") + async def test_curr_run_id_does_not_mark_session_running( + self, mock_with_session, mock_role: Role, mock_session_id: uuid.UUID + ): + """Create setup records the run token without taking the running lock.""" + curr_run_id = uuid.uuid4() + input = CreateSessionInput( + role=mock_role, + session_id=mock_session_id, + entity_type=AgentSessionEntity.WORKFLOW, + entity_id=uuid.uuid4(), + curr_run_id=curr_run_id, + ) + + mock_agent_session = MagicMock() + mock_agent_session.agents_binding = None + mock_agent_session.status = AgentSessionStatus.IDLE.value + mock_service = AsyncMock() + mock_service.get_or_create_session.return_value = (mock_agent_session, False) + mock_service.session = MagicMock() + mock_service.session.commit = AsyncMock() + + mock_ctx = AsyncMock() + mock_ctx.__aenter__.return_value = mock_service + mock_with_session.return_value = mock_ctx + + result = await create_session_activity(input) + + assert result.success is True + assert mock_agent_session.curr_run_id == curr_run_id + assert mock_agent_session.status == AgentSessionStatus.IDLE.value + mock_service.session.add.assert_called_once_with(mock_agent_session) + mock_service.session.commit.assert_awaited_once() + @pytest.mark.anyio @patch("tracecat.agent.session.activities.AgentSessionService.with_session") async def test_backfills_disabled_agents_binding_for_legacy_existing_session( diff --git a/tests/unit/test_agent_session_router.py b/tests/unit/test_agent_session_router.py index 429277ece6..63862ff528 100644 --- a/tests/unit/test_agent_session_router.py +++ b/tests/unit/test_agent_session_router.py @@ -20,7 +20,7 @@ send_message, stream_session_events, ) -from tracecat.agent.session.types import AgentSessionEntity +from tracecat.agent.session.types import AgentSessionEntity, AgentSessionStatus from tracecat.auth.types import Role from tracecat.chat.schemas import ( ApprovalDecision, @@ -87,7 +87,7 @@ async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None: @pytest.mark.anyio async def test_get_session_includes_agents_binding() -> None: - session_stub = _agent_session_stub() + session_stub = _agent_session_stub(status=AgentSessionStatus.RUNNING.value) fake_svc = SimpleNamespace( get_session=AsyncMock(return_value=session_stub), list_messages=AsyncMock(return_value=[]), @@ -107,11 +107,14 @@ async def test_get_session_includes_agents_binding() -> None: "enabled": False, "subagents": [], } + assert response.turn_status is AgentSessionStatus.RUNNING @pytest.mark.anyio async def test_get_session_vercel_includes_agents_binding() -> None: - session_stub = _agent_session_stub() + session_stub = _agent_session_stub( + status=AgentSessionStatus.WAITING_FOR_APPROVAL.value + ) fake_svc = SimpleNamespace( get_session=AsyncMock(return_value=session_stub), list_messages=AsyncMock(return_value=[]), @@ -131,6 +134,7 @@ async def test_get_session_vercel_includes_agents_binding() -> None: "enabled": False, "subagents": [], } + assert response.turn_status is AgentSessionStatus.WAITING_FOR_APPROVAL @pytest.mark.anyio diff --git a/tracecat/agent/session/activities.py b/tracecat/agent/session/activities.py index afe5e66fea..b187638569 100644 --- a/tracecat/agent/session/activities.py +++ b/tracecat/agent/session/activities.py @@ -217,8 +217,9 @@ async def create_session_activity(input: CreateSessionInput) -> CreateSessionRes # Set curr_run_id if provided (for workflow-initiated sessions) if input.curr_run_id is not None: + # The workflow marks RUNNING only after setup completes, just + # before handing control to the executor activity. agent_session.curr_run_id = input.curr_run_id - agent_session.status = AgentSessionStatus.RUNNING.value service.session.add(agent_session) await service.session.commit() diff --git a/tracecat/agent/session/schemas.py b/tracecat/agent/session/schemas.py index a575579850..30163c3efc 100644 --- a/tracecat/agent/session/schemas.py +++ b/tracecat/agent/session/schemas.py @@ -6,7 +6,7 @@ from datetime import datetime from typing import Any -from pydantic import BaseModel, Field +from pydantic import BaseModel, ConfigDict, Field from tracecat.agent.adapter.vercel import UIMessage from tracecat.agent.common.stream_types import HarnessType @@ -103,7 +103,7 @@ class AgentSessionHistoryRead(BaseModel): created_at: datetime updated_at: datetime - model_config = {"from_attributes": True} + model_config = ConfigDict(from_attributes=True) class AgentSessionRead(BaseModel): @@ -136,7 +136,7 @@ class AgentSessionRead(BaseModel): created_at: datetime updated_at: datetime - model_config = {"from_attributes": True} + model_config = ConfigDict(from_attributes=True, validate_by_name=True) class AgentSessionReadWithMessages(AgentSessionRead): From 2fe7b79fa71daced43b6c8e867c27c89d39c33f6 Mon Sep 17 00:00:00 2001 From: Jordan Umusu Date: Wed, 27 May 2026 15:47:47 -0400 Subject: [PATCH 3/5] fix(agent): block new turns while awaiting approval --- tests/unit/test_agent_activities.py | 97 ++++++++++++++++++- .../test_agent_session_turn_validation.py | 83 ++++++++++++++++ ...urable_agent_workflow_search_attributes.py | 4 +- tracecat/agent/executor/activity.py | 47 ++++++++- tracecat/agent/session/service.py | 7 +- 5 files changed, 231 insertions(+), 7 deletions(-) create mode 100644 tests/unit/test_agent_session_turn_validation.py diff --git a/tests/unit/test_agent_activities.py b/tests/unit/test_agent_activities.py index b0052abc04..4acb59abed 100644 --- a/tests/unit/test_agent_activities.py +++ b/tests/unit/test_agent_activities.py @@ -11,7 +11,7 @@ import shutil import uuid from pathlib import Path -from typing import Any +from typing import Any, cast from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -33,7 +33,11 @@ _hydrate_sdk_session_history, run_agent_activity, ) -from tracecat.agent.executor.loopback import LoopbackResult +from tracecat.agent.executor.loopback import ( + LoopbackHandler, + LoopbackInput, + LoopbackResult, +) from tracecat.agent.schemas import ToolFilters from tracecat.agent.session.activities import ( CreateSessionInput, @@ -1128,6 +1132,95 @@ def test_apply_loopback_result_copies_terminal_stream_error_flag( assert result.error == "runtime failed" assert result.terminal_stream_error_emitted is True + @pytest.mark.anyio + async def test_run_with_broker_heartbeats_while_gracefully_cancelling( + self, + executor_input: AgentExecutorInput, + tmp_path: Path, + ) -> None: + class SlowInterruptBroker: + def __init__(self) -> None: + self.turn_started = asyncio.Event() + self.interrupt_finished = asyncio.Event() + self.interrupted_session_ids: list[str] = [] + self.cancelled_session_ids: list[str] = [] + + async def run_turn(self, request: Any, handler: LoopbackHandler) -> None: + del request, handler + self.turn_started.set() + await self.interrupt_finished.wait() + + async def interrupt_turn(self, session_id: str, reason: str) -> None: + del reason + self.interrupted_session_ids.append(session_id) + await asyncio.sleep(0.03) + self.interrupt_finished.set() + + async def cancel_turn(self, session_id: str) -> None: + self.cancelled_session_ids.append(session_id) + + class StartedProxy: + async def start(self) -> None: + pass + + broker = SlowInterruptBroker() + executor = SandboxedAgentExecutor(input=executor_input) + executor._job_dir = tmp_path + executor._llm_proxy = cast(Any, StartedProxy()) + socket_dir = tmp_path / "sockets" + socket_dir.mkdir() + result = AgentExecutorResult( + success=False, + terminal_stream_error_emitted=False, + ) + handler = LoopbackHandler( + input=LoopbackInput( + session_id=executor_input.session_id, + workspace_id=executor_input.workspace_id, + ) + ) + + with ( + patch( + "tracecat.agent.executor.activity.get_claude_runtime_broker", + return_value=broker, + ), + patch("tracecat.agent.executor.activity.activity") as mock_activity, + patch( + "tracecat.agent.executor.activity." + "AGENT_ACTIVITY_HEARTBEAT_INTERVAL_SECONDS", + 0.01, + ), + ): + mock_activity.heartbeat = MagicMock() + mock_activity.cancellation_details.return_value = None + run_task = asyncio.create_task( + executor._run_with_broker( + result=result, + handler=handler, + init_payload=executor._build_runtime_init_payload(), + socket_dir=socket_dir, + llm_socket_path=socket_dir / "llm.sock", + ) + ) + await asyncio.wait_for(broker.turn_started.wait(), timeout=1) + + run_task.cancel() + await asyncio.wait_for(run_task, timeout=1) + + heartbeat_messages = [ + call.args[0] for call in mock_activity.heartbeat.call_args_list + ] + assert result.success is True + assert result.cancelled is True + assert result.cancelled_reason == "user_cancel" + assert broker.interrupted_session_ids == [str(executor_input.session_id)] + assert broker.cancelled_session_ids == [] + assert any( + "Agent cancellation interrupting runtime" in message + for message in heartbeat_messages + ) + @pytest.mark.anyio @patch("tracecat.agent.executor.activity.AgentSessionService.with_session") async def test_hydrates_session_history_for_runtime( diff --git a/tests/unit/test_agent_session_turn_validation.py b/tests/unit/test_agent_session_turn_validation.py new file mode 100644 index 0000000000..42e2edc8a9 --- /dev/null +++ b/tests/unit/test_agent_session_turn_validation.py @@ -0,0 +1,83 @@ +from __future__ import annotations + +import uuid +from typing import Any, cast +from unittest.mock import AsyncMock, create_autospec, patch + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession + +from tracecat.agent.session.service import AgentSessionService +from tracecat.agent.session.types import AgentSessionEntity, AgentSessionStatus +from tracecat.auth.types import Role +from tracecat.chat.schemas import BasicChatRequest +from tracecat.db.models import AgentSession +from tracecat.exceptions import TracecatConflictError + + +def _build_role(workspace_id: uuid.UUID) -> Role: + return Role( + type="user", + service_id="tracecat-api", + workspace_id=workspace_id, + organization_id=uuid.uuid4(), + user_id=uuid.uuid4(), + scopes=frozenset({"agent:execute"}), + ) + + +def _build_service(role: Role) -> AgentSessionService: + db_session = cast( + AsyncSession, + create_autospec(AsyncSession, instance=True, spec_set=True), + ) + return AgentSessionService(cast(Any, db_session), role) + + +def _build_agent_session( + *, + workspace_id: uuid.UUID, + session_id: uuid.UUID, + status: AgentSessionStatus, +) -> AgentSession: + agent_session = AgentSession( + workspace_id=workspace_id, + title="Test session", + entity_type=AgentSessionEntity.COPILOT.value, + entity_id=uuid.uuid4(), + status=status.value, + ) + agent_session.id = session_id + return agent_session + + +@pytest.mark.anyio +@pytest.mark.parametrize( + "status", + [AgentSessionStatus.RUNNING, AgentSessionStatus.WAITING_FOR_APPROVAL], +) +async def test_validate_turn_request_rejects_active_turn_before_pending_rows( + status: AgentSessionStatus, +) -> None: + workspace_id = uuid.uuid4() + session_id = uuid.uuid4() + role = _build_role(workspace_id) + service = _build_service(role) + agent_session = _build_agent_session( + workspace_id=workspace_id, + session_id=session_id, + status=status, + ) + has_pending_approvals = AsyncMock(return_value=False) + + with ( + patch.object(service, "get_session", AsyncMock(return_value=agent_session)), + patch.object(service, "has_pending_approvals", has_pending_approvals), + ): + with pytest.raises(TracecatConflictError, match="active turn"): + await service.validate_turn_request( + session_id=session_id, + request=BasicChatRequest(message="hello"), + ) + + has_pending_approvals.assert_not_awaited() diff --git a/tests/unit/test_durable_agent_workflow_search_attributes.py b/tests/unit/test_durable_agent_workflow_search_attributes.py index 30a932a858..5108b44313 100644 --- a/tests/unit/test_durable_agent_workflow_search_attributes.py +++ b/tests/unit/test_durable_agent_workflow_search_attributes.py @@ -13,6 +13,7 @@ from temporalio.exceptions import ActivityError from tracecat_ee.agent.activities import BuildToolDefsArgs, BuildToolDefsResult from tracecat_ee.agent.workflows.durable import ( + AGENT_SESSION_STATUS_PATCH, BUILD_AGENT_TOOL_DEFINITIONS_PATCH, EMIT_PRE_STREAM_SESSION_ERRORS_PATCH, LOAD_TERMINAL_MESSAGE_HISTORY_PATCH, @@ -233,7 +234,7 @@ async def test_run_skips_activity_error_emission_without_patch_marker() -> None: with ( patch( "tracecat_ee.agent.workflows.durable.workflow.patched", - side_effect=[False, False], + side_effect=[False, False, False], ) as patched_mock, patch( "tracecat_ee.agent.workflows.durable.workflow.unsafe.is_replaying", @@ -256,6 +257,7 @@ async def test_run_skips_activity_error_emission_without_patch_marker() -> None: assert patched_mock.call_args_list == [ ((UPSERT_TRACECAT_SEARCH_ATTRIBUTES_PATCH,),), + ((AGENT_SESSION_STATUS_PATCH,),), ((EMIT_PRE_STREAM_SESSION_ERRORS_PATCH,),), ] emit_error_mock.assert_not_awaited() diff --git a/tracecat/agent/executor/activity.py b/tracecat/agent/executor/activity.py index 850255ccdc..ef3b644406 100644 --- a/tracecat/agent/executor/activity.py +++ b/tracecat/agent/executor/activity.py @@ -8,6 +8,7 @@ import tempfile import uuid from collections import Counter +from collections.abc import Coroutine from dataclasses import dataclass, field from pathlib import Path from time import perf_counter @@ -417,6 +418,42 @@ def _activity_cancel_reason() -> AgentCancelReason: return "worker_drain" return "user_cancel" + async def _await_cancel_step_with_heartbeats[T]( + self, + awaitable: asyncio.Task[T] | Coroutine[Any, Any, T], + *, + phase: str, + ) -> T: + """Await a graceful-cancel task while keeping the activity heartbeat alive.""" + if isinstance(awaitable, asyncio.Task): + task = awaitable + owns_task = False + else: + task = asyncio.create_task(awaitable) + owns_task = True + heartbeat_interval = AGENT_ACTIVITY_HEARTBEAT_INTERVAL_SECONDS + elapsed = 0.0 + try: + while True: + try: + return await asyncio.wait_for( + asyncio.shield(task), + timeout=heartbeat_interval, + ) + except TimeoutError: + if task.done(): + return await task + elapsed += heartbeat_interval + activity.heartbeat( + f"Agent cancellation {phase}: {self.input.session_id} " + f"({elapsed:g}s elapsed)" + ) + finally: + if owns_task and not task.done(): + task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await task + async def _run_with_broker( self, *, @@ -517,8 +554,14 @@ async def wait_fatal_error() -> str: handler.mark_cancelled(reason) try: async with asyncio.timeout(GRACEFUL_CANCEL_TIMEOUT_SECONDS): - await broker.interrupt_turn(str(self.input.session_id), reason) - await broker_task + await self._await_cancel_step_with_heartbeats( + broker.interrupt_turn(str(self.input.session_id), reason), + phase="interrupting runtime", + ) + await self._await_cancel_step_with_heartbeats( + broker_task, + phase="waiting for runtime", + ) except TimeoutError: logger.warning( "Timed out gracefully cancelling agent activity", diff --git a/tracecat/agent/session/service.py b/tracecat/agent/session/service.py index e2a059adb5..f2c75d9fb8 100644 --- a/tracecat/agent/session/service.py +++ b/tracecat/agent/session/service.py @@ -1207,9 +1207,12 @@ async def validate_turn_request( ) return agent_session case VercelChatRequest() | BasicChatRequest(): - if agent_session.status == AgentSessionStatus.RUNNING.value: + if agent_session.status in { + AgentSessionStatus.RUNNING.value, + AgentSessionStatus.WAITING_FOR_APPROVAL.value, + }: raise TracecatConflictError( - "This session already has a running turn." + "This session already has an active turn." ) if await self.has_pending_approvals(session_id): raise ValueError( From d7054ea0034813397fabf2e3cbf6349958fb757f Mon Sep 17 00:00:00 2001 From: Jordan Umusu Date: Wed, 27 May 2026 16:40:47 -0400 Subject: [PATCH 4/5] fix(agent): interruptions during approvals --- .../tracecat_ee/agent/workflows/durable.py | 72 ++++++- tests/temporal/test_durable_agent_workflow.py | 200 ++++++++++++++++++ tests/unit/test_agent_activities.py | 6 +- tests/unit/test_agent_session_messages.py | 67 +++++- .../test_agent_session_turn_validation.py | 48 ++++- tracecat/agent/session/activities.py | 7 +- tracecat/agent/session/service.py | 149 ++++++++++--- 7 files changed, 510 insertions(+), 39 deletions(-) diff --git a/packages/tracecat-ee/tracecat_ee/agent/workflows/durable.py b/packages/tracecat-ee/tracecat_ee/agent/workflows/durable.py index f8cb1ac540..ec941b4515 100644 --- a/packages/tracecat-ee/tracecat_ee/agent/workflows/durable.py +++ b/packages/tracecat-ee/tracecat_ee/agent/workflows/durable.py @@ -1082,12 +1082,80 @@ async def _run_with_agent_executor( tool_call_parts, request_metadata=request_metadata, ) - # Wait for approval signal - await self.approvals.wait() + # Wait for either approval decisions or a user cancellation. + await workflow.wait_condition( + lambda: self.approvals.is_ready() or self._cancel_requested + ) + if self._cancel_requested: + logger.info( + "Agent turn cancelled while waiting for approval", + session_id=self.session_id, + reason=self._cancel_reason, + ) + self.approvals.set( + { + item.id: ToolDenied( + message="Cancelled while waiting for approval" + ) + for item in result.approval_items or [] + } + ) + await self.approvals.handle_decisions() + self._status = "done" + await self._set_agent_session_status( + AgentSessionStatus.STOPPED, + clear_curr_run_id=True, + ) + message_history = await self._load_terminal_message_history(result) + return AgentOutput( + output=None, + message_history=message_history, + duration=(datetime.now(UTC) - info.start_time).total_seconds(), + usage=RunUsage(requests=0, input_tokens=0, output_tokens=0), + session_id=self.session_id, + ) # Persist approval decisions to DB (atomic with chat messages) await self.approvals.handle_decisions() + if self._cancel_requested: + logger.info( + "Agent turn cancelled after approval decisions", + session_id=self.session_id, + reason=self._cancel_reason, + ) + self._status = "done" + await self._set_agent_session_status( + AgentSessionStatus.STOPPED, + clear_curr_run_id=True, + ) + message_history = await self._load_terminal_message_history(result) + return AgentOutput( + output=None, + message_history=message_history, + duration=(datetime.now(UTC) - info.start_time).total_seconds(), + usage=RunUsage(requests=0, input_tokens=0, output_tokens=0), + session_id=self.session_id, + ) self._status = "running" await self._set_agent_session_status(AgentSessionStatus.RUNNING) + if self._cancel_requested: + logger.info( + "Agent turn cancelled before approved tool execution", + session_id=self.session_id, + reason=self._cancel_reason, + ) + self._status = "done" + await self._set_agent_session_status( + AgentSessionStatus.STOPPED, + clear_curr_run_id=True, + ) + message_history = await self._load_terminal_message_history(result) + return AgentOutput( + output=None, + message_history=message_history, + duration=(datetime.now(UTC) - info.start_time).total_seconds(), + usage=RunUsage(requests=0, input_tokens=0, output_tokens=0), + session_id=self.session_id, + ) # Execute approved tools and reconcile the SDK transcript. approved_tools, denied_tools = self._build_tool_lists_from_approvals( diff --git a/tests/temporal/test_durable_agent_workflow.py b/tests/temporal/test_durable_agent_workflow.py index 4487cc470e..1d56d3ae86 100644 --- a/tests/temporal/test_durable_agent_workflow.py +++ b/tests/temporal/test_durable_agent_workflow.py @@ -765,6 +765,206 @@ async def mock_run_agent_activity( assert status_update_inputs[-1].clear_curr_run_id is True +@pytest.mark.anyio +@pytest.mark.integration +async def test_request_cancel_stops_agent_turn_waiting_for_approval( + temporal_client: Client, + agent_worker_factory, + agent_workflow_args: AgentWorkflowArgs, + mock_session_id: uuid.UUID, +) -> None: + """Workflow cancel update stops a turn parked on manual approval.""" + queue = f"test-agent-queue-{mock_session_id}" + approval_request_recorded = asyncio.Event() + status_update_inputs: list[UpdateSessionStatusInput] = [] + approval_decision_inputs: list[ApplyApprovalResultsActivityInputs] = [] + + def mock_executor( + call_count: int, + _input: AgentExecutorInput, + ) -> AgentExecutorResult: + assert call_count == 0 + return AgentExecutorResult( + success=True, + approval_requested=True, + approval_items=[ + ToolCallContent( + id="call_123", + name="core__http_request", + input={"url": "https://example.com", "method": "GET"}, + ) + ], + ) + + @activity.defn(name="record_approval_requests") + async def mock_record_approval_requests( + input: PersistApprovalsActivityInputs, + ) -> None: + del input + approval_request_recorded.set() + + @activity.defn(name="apply_approval_decisions") + async def mock_apply_approval_decisions( + input: ApplyApprovalResultsActivityInputs, + ) -> None: + approval_decision_inputs.append(input) + + activities: list[Callable[..., Any]] = [ + create_mock_create_session_activity(), + create_mock_load_session_activity(), + create_mock_update_session_status_activity(status_update_inputs), + create_mock_load_session_messages_activity(), + create_mock_build_tool_definitions_activity(), + create_mock_run_agent_activity(mock_executor), + create_mock_execute_action_activity(), + create_mock_reconcile_tool_results_activity(), + mock_record_approval_requests, + mock_apply_approval_decisions, + ] + + async with agent_worker_factory( + temporal_client, task_queue=queue, custom_activities=activities + ): + handle = await temporal_client.start_workflow( + DurableAgentWorkflow.run, + agent_workflow_args, + id=AgentWorkflowID(mock_session_id), + task_queue=queue, + retry_policy=RETRY_POLICIES["workflow:fail_fast"], + execution_timeout=timedelta(seconds=60), + ) + await asyncio.wait_for(approval_request_recorded.wait(), timeout=10) + + await handle.execute_update( + DurableAgentWorkflow.request_cancel, + WorkflowCancelRequest(reason="user_cancel"), + ) + result = await handle.result() + + assert result.session_id == mock_session_id + assert result.output is None + assert [input.status for input in status_update_inputs] == [ + AgentSessionStatus.RUNNING, + AgentSessionStatus.WAITING_FOR_APPROVAL, + AgentSessionStatus.STOPPED, + ] + assert status_update_inputs[-1].clear_curr_run_id is True + assert len(approval_decision_inputs) == 1 + assert len(approval_decision_inputs[0].decisions) == 1 + decision = approval_decision_inputs[0].decisions[0] + assert decision.tool_call_id == "call_123" + assert decision.approved is False + assert decision.reason == "Cancelled while waiting for approval" + + +@pytest.mark.anyio +@pytest.mark.integration +async def test_request_cancel_after_approval_decisions_skips_tool_execution( + temporal_client: Client, + agent_worker_factory, + agent_workflow_args: AgentWorkflowArgs, + mock_session_id: uuid.UUID, +) -> None: + """Cancel received while approval decisions persist stops before tools run.""" + queue = f"test-agent-queue-{mock_session_id}" + approval_request_recorded = asyncio.Event() + decision_apply_started = asyncio.Event() + release_decision_apply = asyncio.Event() + execute_action_called = asyncio.Event() + status_update_inputs: list[UpdateSessionStatusInput] = [] + + def mock_executor( + call_count: int, + _input: AgentExecutorInput, + ) -> AgentExecutorResult: + assert call_count == 0 + return AgentExecutorResult( + success=True, + approval_requested=True, + approval_items=[ + ToolCallContent( + id="call_123", + name="core__http_request", + input={"url": "https://example.com", "method": "GET"}, + ) + ], + ) + + @activity.defn(name="record_approval_requests") + async def mock_record_approval_requests( + input: PersistApprovalsActivityInputs, + ) -> None: + del input + approval_request_recorded.set() + + @activity.defn(name="apply_approval_decisions") + async def mock_apply_approval_decisions( + input: ApplyApprovalResultsActivityInputs, + ) -> None: + del input + decision_apply_started.set() + await release_decision_apply.wait() + + @activity.defn(name="execute_action_activity") + async def mock_execute_action_activity( + input: RunActionInput, + role: Role, + ) -> InlineObject[dict[str, str]]: + del input, role + execute_action_called.set() + raise AssertionError("approved tool should not execute after cancellation") + + activities: list[Callable[..., Any]] = [ + create_mock_create_session_activity(), + create_mock_load_session_activity(), + create_mock_update_session_status_activity(status_update_inputs), + create_mock_load_session_messages_activity(), + create_mock_build_tool_definitions_activity(), + create_mock_run_agent_activity(mock_executor), + mock_execute_action_activity, + create_mock_reconcile_tool_results_activity(), + mock_record_approval_requests, + mock_apply_approval_decisions, + ] + + async with agent_worker_factory( + temporal_client, task_queue=queue, custom_activities=activities + ): + handle = await temporal_client.start_workflow( + DurableAgentWorkflow.run, + agent_workflow_args, + id=AgentWorkflowID(mock_session_id), + task_queue=queue, + retry_policy=RETRY_POLICIES["workflow:fail_fast"], + execution_timeout=timedelta(seconds=60), + ) + await asyncio.wait_for(approval_request_recorded.wait(), timeout=10) + await handle.execute_update( + DurableAgentWorkflow.set_approvals, + WorkflowApprovalSubmission( + approvals={"call_123": True}, + approved_by=agent_workflow_args.role.user_id, + ), + ) + await asyncio.wait_for(decision_apply_started.wait(), timeout=10) + + cancel_task = asyncio.create_task( + handle.execute_update( + DurableAgentWorkflow.request_cancel, + WorkflowCancelRequest(reason="user_cancel"), + ) + ) + await asyncio.wait_for(cancel_task, timeout=10) + release_decision_apply.set() + result = await handle.result() + + assert result.session_id == mock_session_id + assert result.output is None + assert not execute_action_called.is_set() + assert status_update_inputs[-1].status is AgentSessionStatus.STOPPED + assert status_update_inputs[-1].clear_curr_run_id is True + + @pytest.mark.anyio @pytest.mark.integration async def test_agent_workflow_marks_failed_when_executor_activity_errors( diff --git a/tests/unit/test_agent_activities.py b/tests/unit/test_agent_activities.py index 4acb59abed..6f436e6750 100644 --- a/tests/unit/test_agent_activities.py +++ b/tests/unit/test_agent_activities.py @@ -402,10 +402,10 @@ async def test_idempotent_for_existing_session( @pytest.mark.anyio @patch("tracecat.agent.session.activities.AgentSessionService.with_session") - async def test_curr_run_id_does_not_mark_session_running( + async def test_curr_run_id_marks_session_running( self, mock_with_session, mock_role: Role, mock_session_id: uuid.UUID ): - """Create setup records the run token without taking the running lock.""" + """Create setup records the run token and takes the running lock.""" curr_run_id = uuid.uuid4() input = CreateSessionInput( role=mock_role, @@ -431,7 +431,7 @@ async def test_curr_run_id_does_not_mark_session_running( assert result.success is True assert mock_agent_session.curr_run_id == curr_run_id - assert mock_agent_session.status == AgentSessionStatus.IDLE.value + assert mock_agent_session.status == AgentSessionStatus.RUNNING.value mock_service.session.add.assert_called_once_with(mock_agent_session) mock_service.session.commit.assert_awaited_once() diff --git a/tests/unit/test_agent_session_messages.py b/tests/unit/test_agent_session_messages.py index 64af71d97a..2fbef003db 100644 --- a/tests/unit/test_agent_session_messages.py +++ b/tests/unit/test_agent_session_messages.py @@ -8,6 +8,7 @@ import orjson import pytest +from tracecat.agent.approvals.enums import ApprovalStatus from tracecat.agent.session.service import AgentSessionService from tracecat.agent.session.types import AgentSessionStatus from tracecat.auth.types import Role @@ -135,9 +136,9 @@ async def test_list_messages_adds_interrupt_notice_for_stopped_session() -> None @pytest.mark.anyio -async def test_list_messages_keeps_approval_interrupt_artifacts_hidden() -> None: +async def test_list_messages_adds_interrupt_notice_after_session_resumes() -> None: service, agent_session = _build_service() - agent_session.status = AgentSessionStatus.WAITING_FOR_APPROVAL.value + agent_session.status = AgentSessionStatus.IDLE.value interrupt_entry = SimpleNamespace( id=uuid.uuid4(), session_id=agent_session.id, @@ -162,6 +163,68 @@ async def test_list_messages_keeps_approval_interrupt_artifacts_hidden() -> None messages = await service.list_messages(agent_session.id) + assert len(messages) == 1 + assert messages[0].kind == MessageKind.INTERRUPT + assert messages[0].interrupt == {"reason": "user_cancel"} + + +@pytest.mark.anyio +async def test_list_messages_keeps_pending_approval_interrupt_artifacts_hidden() -> ( + None +): + service, agent_session = _build_service() + agent_session.status = AgentSessionStatus.WAITING_FOR_APPROVAL.value + assistant_entry = SimpleNamespace( + id=uuid.uuid4(), + session_id=agent_session.id, + kind=MessageKind.CHAT_MESSAGE.value, + content={ + "type": "assistant", + "uuid": "assistant-uuid", + "message": { + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": "call_123", + "name": "core__http_request", + "input": {"url": "https://example.com"}, + } + ], + }, + }, + ) + interrupt_entry = SimpleNamespace( + id=uuid.uuid4(), + session_id=agent_session.id, + kind=MessageKind.INTERNAL.value, + content={ + "type": "user", + "uuid": "interrupt-uuid", + "message": { + "role": "user", + "content": [{"type": "text", "text": "[Request interrupted by user]"}], + }, + }, + ) + pending_approval = SimpleNamespace( + tool_call_id="call_123", + status=ApprovalStatus.PENDING, + ) + + service.get_session = AsyncMock(return_value=agent_session) + service.session.execute = AsyncMock( + side_effect=[ + _mock_scalar_result([assistant_entry, interrupt_entry]), + _mock_scalar_result([pending_approval]), + ] + ) + + messages = await service.list_messages( + agent_session.id, + kinds=[MessageKind.INTERRUPT], + ) + assert messages == [] diff --git a/tests/unit/test_agent_session_turn_validation.py b/tests/unit/test_agent_session_turn_validation.py index 42e2edc8a9..7ca03e3a96 100644 --- a/tests/unit/test_agent_session_turn_validation.py +++ b/tests/unit/test_agent_session_turn_validation.py @@ -1,12 +1,14 @@ from __future__ import annotations import uuid +from types import SimpleNamespace from typing import Any, cast -from unittest.mock import AsyncMock, create_autospec, patch +from unittest.mock import AsyncMock, MagicMock, create_autospec, patch import pytest from sqlalchemy.ext.asyncio import AsyncSession +from tracecat.agent.session.schemas import AgentSessionCancelRequest from tracecat.agent.session.service import AgentSessionService from tracecat.agent.session.types import AgentSessionEntity, AgentSessionStatus from tracecat.auth.types import Role @@ -39,6 +41,7 @@ def _build_agent_session( workspace_id: uuid.UUID, session_id: uuid.UUID, status: AgentSessionStatus, + curr_run_id: uuid.UUID | None = None, ) -> AgentSession: agent_session = AgentSession( workspace_id=workspace_id, @@ -46,6 +49,7 @@ def _build_agent_session( entity_type=AgentSessionEntity.COPILOT.value, entity_id=uuid.uuid4(), status=status.value, + curr_run_id=curr_run_id, ) agent_session.id = session_id return agent_session @@ -81,3 +85,45 @@ async def test_validate_turn_request_rejects_active_turn_before_pending_rows( ) has_pending_approvals.assert_not_awaited() + + +@pytest.mark.anyio +@pytest.mark.parametrize( + "status", + [AgentSessionStatus.RUNNING, AgentSessionStatus.WAITING_FOR_APPROVAL], +) +async def test_request_cancel_accepts_active_turn_statuses( + status: AgentSessionStatus, +) -> None: + workspace_id = uuid.uuid4() + session_id = uuid.uuid4() + run_id = uuid.uuid4() + role = _build_role(workspace_id) + service = _build_service(role) + agent_session = _build_agent_session( + workspace_id=workspace_id, + session_id=session_id, + status=status, + curr_run_id=run_id, + ) + workflow_handle = SimpleNamespace(execute_update=AsyncMock()) + temporal_client = SimpleNamespace( + get_workflow_handle_for=MagicMock(return_value=workflow_handle) + ) + + with ( + patch.object(service, "get_session", AsyncMock(return_value=agent_session)), + patch( + "tracecat.agent.session.service.get_temporal_client", + AsyncMock(return_value=temporal_client), + ), + ): + response = await service.request_cancel( + session_id, + AgentSessionCancelRequest(reason="user_cancel"), + ) + + assert response.session_id == session_id + assert response.run_id == run_id + assert response.turn_status is status + workflow_handle.execute_update.assert_awaited_once() diff --git a/tracecat/agent/session/activities.py b/tracecat/agent/session/activities.py index b187638569..0995abc8dd 100644 --- a/tracecat/agent/session/activities.py +++ b/tracecat/agent/session/activities.py @@ -215,11 +215,12 @@ async def create_session_activity(input: CreateSessionInput) -> CreateSessionRes service.session.add(agent_session) await service.session.commit() - # Set curr_run_id if provided (for workflow-initiated sessions) + # Set curr_run_id and status together for workflow-initiated + # sessions so API lifecycle gates see this workflow as active + # throughout setup, not only once the executor activity starts. if input.curr_run_id is not None: - # The workflow marks RUNNING only after setup completes, just - # before handing control to the executor activity. agent_session.curr_run_id = input.curr_run_id + agent_session.status = AgentSessionStatus.RUNNING.value service.session.add(agent_session) await service.session.commit() diff --git a/tracecat/agent/session/service.py b/tracecat/agent/session/service.py index f2c75d9fb8..3b9ca1bae7 100644 --- a/tracecat/agent/session/service.py +++ b/tracecat/agent/session/service.py @@ -30,7 +30,9 @@ from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.engine import CursorResult from sqlalchemy.exc import SQLAlchemyError +from temporalio.client import WorkflowExecutionStatus from temporalio.common import TypedSearchAttributes +from tracecat_ee.agent.types import AgentWorkflowID from tracecat_registry._internal.exceptions import SecretNotFoundError import tracecat.agent.adapter.vercel @@ -108,6 +110,11 @@ AUTO_TITLE_SERVICE_ID = "tracecat-api" APPROVAL_CONTINUATION_DEDUP_TTL_SECONDS = 5 * 60 +ACTIVE_AGENT_SESSION_STATUSES = ("running", "waiting_for_approval") +ACTIVE_WORKFLOW_STATUSES = ( + WorkflowExecutionStatus.RUNNING, + WorkflowExecutionStatus.CONTINUED_AS_NEW, +) @dataclass @@ -587,13 +594,44 @@ async def set_session_status( await self.session.commit() return bool(cast(CursorResult[Any], result).rowcount) + async def _has_active_agent_turn(self, agent_session: AgentSession) -> bool: + """Return whether DB/Temporal state should block a new turn.""" + if agent_session.curr_run_id is None: + return agent_session.status in ACTIVE_AGENT_SESSION_STATUSES + + client = await get_temporal_client() + workflow_id = AgentWorkflowID(agent_session.curr_run_id) + try: + handle = client.get_workflow_handle(str(workflow_id)) + description = await handle.describe() + except Exception: + return True + + if description.status in ACTIVE_WORKFLOW_STATUSES: + return True + + agent_session.status = ( + AgentSessionStatus.FAILED.value + if description.status + in {WorkflowExecutionStatus.FAILED, WorkflowExecutionStatus.TIMED_OUT} + else AgentSessionStatus.STOPPED.value + if description.status + in { + WorkflowExecutionStatus.CANCELED, + WorkflowExecutionStatus.TERMINATED, + } + else AgentSessionStatus.IDLE.value + ) + agent_session.curr_run_id = None + await self.session.commit() + return False + async def request_cancel( self, session_id: uuid.UUID, request: AgentSessionCancelRequest, ) -> AgentSessionCancelResponse: """Request graceful cancellation for the active agent workflow turn.""" - from tracecat_ee.agent.types import AgentWorkflowID from tracecat_ee.agent.workflows.durable import ( DurableAgentWorkflow, WorkflowCancelRequest, @@ -604,9 +642,12 @@ async def request_cancel( raise TracecatNotFoundError(f"Session with ID {session_id} not found") current_status = AgentSessionStatus(agent_session.status) - if current_status is not AgentSessionStatus.RUNNING: + if current_status not in { + AgentSessionStatus.RUNNING, + AgentSessionStatus.WAITING_FOR_APPROVAL, + }: raise TracecatConflictError( - "Agent session is not currently running", + "Agent session does not have an active turn", detail={"status": current_status.value}, ) @@ -1061,7 +1102,6 @@ async def run_turn( TracecatNotFoundError: If the session is not found. ValueError: If the request/entity type is unsupported. """ - from tracecat_ee.agent.types import AgentWorkflowID from tracecat_ee.agent.workflows.durable import ( AgentWorkflowArgs, DurableAgentWorkflow, @@ -1152,7 +1192,6 @@ async def run_turn( # Update session with current run_id for approval lookups agent_session.curr_run_id = run_id - agent_session.status = AgentSessionStatus.RUNNING.value self.session.add(agent_session) await self.session.commit() @@ -1180,6 +1219,7 @@ async def run_turn( ), ) except Exception: + agent_session.curr_run_id = None agent_session.status = AgentSessionStatus.FAILED.value self.session.add(agent_session) await self.session.commit() @@ -1207,10 +1247,7 @@ async def validate_turn_request( ) return agent_session case VercelChatRequest() | BasicChatRequest(): - if agent_session.status in { - AgentSessionStatus.RUNNING.value, - AgentSessionStatus.WAITING_FOR_APPROVAL.value, - }: + if await self._has_active_agent_turn(agent_session): raise TracecatConflictError( "This session already has an active turn." ) @@ -1241,7 +1278,6 @@ async def _continue_with_approvals( TracecatNotFoundError: If no active session exists. """ from tracecat_ee.agent.approvals.service import ApprovalMap - from tracecat_ee.agent.types import AgentWorkflowID from tracecat_ee.agent.workflows.durable import ( DurableAgentWorkflow, WorkflowApprovalSubmission, @@ -1638,15 +1674,16 @@ async def list_messages( approval_by_tool_id: dict[str, Approval] = { a.tool_call_id: a for a in approvals } + pending_approval_tool_ids = { + a.tool_call_id for a in approvals if a.status == ApprovalStatus.PENDING + } # Build timeline with interleaved approvals # Process both chat-message and internal entries in order # Internal entries contain tool results that the adapter will extract messages: list[ChatMessage] = [] internal_uuids: set[str] = set() - show_interrupt_notice = ( - getattr(agent_session, "status", None) == AgentSessionStatus.STOPPED.value - ) + pending_approval_interrupt_tool_ids: set[str] = set() for entry in all_entries: content = entry.content if not content: @@ -1655,9 +1692,12 @@ async def list_messages( # Skip internal entries (e.g., continuation prompts) if entry.kind == MessageKind.INTERNAL.value: if ( - show_interrupt_notice - and getattr(entry, "session_id", session_id) == session_id + getattr(entry, "session_id", session_id) == session_id and self._is_interrupt_session_line(content) + and not self._is_pending_approval_interrupt_session_line( + content, + pending_approval_interrupt_tool_ids, + ) and not (messages and messages[-1].kind == MessageKind.INTERRUPT) ): kind = MessageKind.INTERRUPT @@ -1714,10 +1754,6 @@ async def list_messages( # Standard chat messages kind = MessageKind.CHAT_MESSAGE - # Filter by kinds if specified - if kinds and kind not in kinds: - continue - # Extract the inner message from JSONL envelope inner_message = content.get("message") if not inner_message: @@ -1728,27 +1764,47 @@ async def list_messages( # Deserialize the content using Claude SDK TypeAdapter message = ClaudeSDKMessageTA.validate_python(sanitized_message) - messages.append(ChatMessage(id=str(entry.id), message=message)) + if not kinds or kind in kinds: + messages.append(ChatMessage(id=str(entry.id), message=message)) # For assistant messages, check for tool calls needing approval bubbles if msg_type == "assistant": tool_uses = self._extract_tool_uses_from_message(sanitized_message) for tool_use in tool_uses: tool_use_id = tool_use.get("id") + if tool_use_id in pending_approval_tool_ids: + pending_approval_interrupt_tool_ids.add(tool_use_id) if tool_use_id and ( approval := approval_by_tool_id.get(tool_use_id) ): + should_append_approval_request = ( + not kinds or MessageKind.APPROVAL_REQUEST in kinds + ) + should_append_approval_decision = ( + approval.status != ApprovalStatus.PENDING + and (not kinds or MessageKind.APPROVAL_DECISION in kinds) + ) + if ( + should_append_approval_request + or should_append_approval_decision + ): + approval_read = ApprovalRead.model_validate(approval) + else: + approval_read = None + # Insert approval-request bubble - approval_read = ApprovalRead.model_validate(approval) - messages.append( - ChatMessage( - id=str(approval.id), - kind=MessageKind.APPROVAL_REQUEST, - approval=approval_read, + if should_append_approval_request: + assert approval_read is not None + messages.append( + ChatMessage( + id=str(approval.id), + kind=MessageKind.APPROVAL_REQUEST, + approval=approval_read, + ) ) - ) # If decided, also insert decision bubble - if approval.status != ApprovalStatus.PENDING: + if should_append_approval_decision: + assert approval_read is not None messages.append( ChatMessage( id=f"{approval.id}-decision", @@ -1759,6 +1815,43 @@ async def list_messages( return messages + @staticmethod + def _is_pending_approval_interrupt_session_line( + content: dict[str, Any], + pending_tool_call_ids: set[str], + ) -> bool: + """Return True for SDK interrupt rows caused by an active approval wait.""" + if not pending_tool_call_ids: + return False + + message = content.get("message") + if not isinstance(message, dict): + return False + + msg_content = message.get("content") + if isinstance(msg_content, str): + return is_approval_interrupt_content(msg_content) + if not isinstance(msg_content, list): + return False + + for block in msg_content: + if isinstance(block, dict): + if block.get("type") == "text" and is_approval_interrupt_content( + block.get("text", "") + ): + return True + if ( + block.get("type") == "tool_result" + and block.get("is_error") is True + and block.get("tool_use_id") in pending_tool_call_ids + and is_approval_interrupt_content(block.get("content", "")) + ): + return True + elif is_approval_interrupt_content(block): + return True + + return False + @staticmethod def _is_interrupt_session_line(content: dict[str, Any]) -> bool: """Return True for internal SDK rows that represent a user interruption.""" From dba556b8c43f0a04c30a60fdc20d04a9b173bf49 Mon Sep 17 00:00:00 2001 From: Jordan Umusu Date: Thu, 28 May 2026 16:55:18 -0400 Subject: [PATCH 5/5] fix(agent): use Temporal as source of truth for active turns --- .../test_agent_session_turn_validation.py | 275 +++++++++++++++++- tracecat/agent/session/service.py | 126 ++++++-- 2 files changed, 368 insertions(+), 33 deletions(-) diff --git a/tests/unit/test_agent_session_turn_validation.py b/tests/unit/test_agent_session_turn_validation.py index 7ca03e3a96..c311514fd9 100644 --- a/tests/unit/test_agent_session_turn_validation.py +++ b/tests/unit/test_agent_session_turn_validation.py @@ -7,6 +7,8 @@ import pytest from sqlalchemy.ext.asyncio import AsyncSession +from temporalio.client import WorkflowExecutionStatus +from temporalio.service import RPCError, RPCStatusCode from tracecat.agent.session.schemas import AgentSessionCancelRequest from tracecat.agent.session.service import AgentSessionService @@ -87,6 +89,96 @@ async def test_validate_turn_request_rejects_active_turn_before_pending_rows( has_pending_approvals.assert_not_awaited() +@pytest.mark.anyio +async def test_validate_turn_request_rejects_idle_projection_with_active_workflow() -> ( + None +): + workspace_id = uuid.uuid4() + session_id = uuid.uuid4() + run_id = uuid.uuid4() + role = _build_role(workspace_id) + service = _build_service(role) + agent_session = _build_agent_session( + workspace_id=workspace_id, + session_id=session_id, + status=AgentSessionStatus.IDLE, + curr_run_id=run_id, + ) + workflow_handle = SimpleNamespace( + describe=AsyncMock( + return_value=SimpleNamespace(status=WorkflowExecutionStatus.RUNNING) + ) + ) + temporal_client = SimpleNamespace( + get_workflow_handle=MagicMock(return_value=workflow_handle), + ) + has_pending_approvals = AsyncMock(return_value=False) + + with ( + patch.object(service, "get_session", AsyncMock(return_value=agent_session)), + patch.object(service, "has_pending_approvals", has_pending_approvals), + patch( + "tracecat.agent.session.service.get_temporal_client", + AsyncMock(return_value=temporal_client), + ), + ): + with pytest.raises(TracecatConflictError, match="active turn"): + await service.validate_turn_request( + session_id=session_id, + request=BasicChatRequest(message="hello"), + ) + + has_pending_approvals.assert_not_awaited() + workflow_handle.describe.assert_awaited_once() + + +@pytest.mark.anyio +async def test_validate_turn_request_clears_missing_current_run() -> None: + workspace_id = uuid.uuid4() + session_id = uuid.uuid4() + run_id = uuid.uuid4() + role = _build_role(workspace_id) + service = _build_service(role) + agent_session = _build_agent_session( + workspace_id=workspace_id, + session_id=session_id, + status=AgentSessionStatus.RUNNING, + curr_run_id=run_id, + ) + workflow_handle = SimpleNamespace( + describe=AsyncMock( + side_effect=RPCError( + "workflow not found", + RPCStatusCode.NOT_FOUND, + b"", + ) + ) + ) + temporal_client = SimpleNamespace( + get_workflow_handle=MagicMock(return_value=workflow_handle), + ) + has_pending_approvals = AsyncMock(return_value=False) + + with ( + patch.object(service, "get_session", AsyncMock(return_value=agent_session)), + patch.object(service, "has_pending_approvals", has_pending_approvals), + patch( + "tracecat.agent.session.service.get_temporal_client", + AsyncMock(return_value=temporal_client), + ), + ): + result = await service.validate_turn_request( + session_id=session_id, + request=BasicChatRequest(message="hello"), + ) + + assert result is agent_session + assert agent_session.curr_run_id is None + assert agent_session.status == AgentSessionStatus.IDLE.value + workflow_handle.describe.assert_awaited_once() + has_pending_approvals.assert_awaited_once_with(session_id) + + @pytest.mark.anyio @pytest.mark.parametrize( "status", @@ -106,9 +198,15 @@ async def test_request_cancel_accepts_active_turn_statuses( status=status, curr_run_id=run_id, ) - workflow_handle = SimpleNamespace(execute_update=AsyncMock()) + workflow_handle = SimpleNamespace( + describe=AsyncMock( + return_value=SimpleNamespace(status=WorkflowExecutionStatus.RUNNING) + ), + execute_update=AsyncMock(), + ) temporal_client = SimpleNamespace( - get_workflow_handle_for=MagicMock(return_value=workflow_handle) + get_workflow_handle=MagicMock(return_value=workflow_handle), + get_workflow_handle_for=MagicMock(return_value=workflow_handle), ) with ( @@ -126,4 +224,177 @@ async def test_request_cancel_accepts_active_turn_statuses( assert response.session_id == session_id assert response.run_id == run_id assert response.turn_status is status + workflow_handle.describe.assert_awaited_once() + workflow_handle.execute_update.assert_awaited_once() + + +@pytest.mark.anyio +async def test_request_cancel_accepts_idle_projection_with_active_workflow() -> None: + workspace_id = uuid.uuid4() + session_id = uuid.uuid4() + run_id = uuid.uuid4() + role = _build_role(workspace_id) + service = _build_service(role) + agent_session = _build_agent_session( + workspace_id=workspace_id, + session_id=session_id, + status=AgentSessionStatus.IDLE, + curr_run_id=run_id, + ) + workflow_handle = SimpleNamespace( + describe=AsyncMock( + return_value=SimpleNamespace(status=WorkflowExecutionStatus.RUNNING) + ), + execute_update=AsyncMock(), + ) + temporal_client = SimpleNamespace( + get_workflow_handle=MagicMock(return_value=workflow_handle), + get_workflow_handle_for=MagicMock(return_value=workflow_handle), + ) + + with ( + patch.object(service, "get_session", AsyncMock(return_value=agent_session)), + patch( + "tracecat.agent.session.service.get_temporal_client", + AsyncMock(return_value=temporal_client), + ), + ): + response = await service.request_cancel( + session_id, + AgentSessionCancelRequest(reason="user_cancel"), + ) + + assert response.session_id == session_id + assert response.run_id == run_id + assert response.turn_status is AgentSessionStatus.IDLE + workflow_handle.describe.assert_awaited_once() workflow_handle.execute_update.assert_awaited_once() + + +@pytest.mark.anyio +async def test_request_cancel_handles_finished_current_run() -> None: + workspace_id = uuid.uuid4() + session_id = uuid.uuid4() + run_id = uuid.uuid4() + role = _build_role(workspace_id) + service = _build_service(role) + agent_session = _build_agent_session( + workspace_id=workspace_id, + session_id=session_id, + status=AgentSessionStatus.RUNNING, + curr_run_id=run_id, + ) + workflow_handle = SimpleNamespace( + describe=AsyncMock( + return_value=SimpleNamespace(status=WorkflowExecutionStatus.TERMINATED) + ), + execute_update=AsyncMock(), + ) + temporal_client = SimpleNamespace( + get_workflow_handle=MagicMock(return_value=workflow_handle), + get_workflow_handle_for=MagicMock(return_value=workflow_handle), + ) + + with ( + patch.object(service, "get_session", AsyncMock(return_value=agent_session)), + patch( + "tracecat.agent.session.service.get_temporal_client", + AsyncMock(return_value=temporal_client), + ), + ): + with pytest.raises(TracecatConflictError, match="active turn"): + await service.request_cancel( + session_id, + AgentSessionCancelRequest(reason="user_cancel"), + ) + + assert agent_session.curr_run_id is None + assert agent_session.status == AgentSessionStatus.STOPPED.value + workflow_handle.execute_update.assert_not_awaited() + + +@pytest.mark.anyio +async def test_request_cancel_clears_missing_current_run() -> None: + workspace_id = uuid.uuid4() + session_id = uuid.uuid4() + run_id = uuid.uuid4() + role = _build_role(workspace_id) + service = _build_service(role) + agent_session = _build_agent_session( + workspace_id=workspace_id, + session_id=session_id, + status=AgentSessionStatus.RUNNING, + curr_run_id=run_id, + ) + workflow_handle = SimpleNamespace( + describe=AsyncMock( + side_effect=RPCError( + "workflow not found", + RPCStatusCode.NOT_FOUND, + b"", + ) + ), + execute_update=AsyncMock(), + ) + temporal_client = SimpleNamespace( + get_workflow_handle=MagicMock(return_value=workflow_handle), + get_workflow_handle_for=MagicMock(return_value=workflow_handle), + ) + + with ( + patch.object(service, "get_session", AsyncMock(return_value=agent_session)), + patch( + "tracecat.agent.session.service.get_temporal_client", + AsyncMock(return_value=temporal_client), + ), + ): + with pytest.raises(TracecatConflictError, match="active turn"): + await service.request_cancel( + session_id, + AgentSessionCancelRequest(reason="user_cancel"), + ) + + assert agent_session.curr_run_id is None + assert agent_session.status == AgentSessionStatus.IDLE.value + workflow_handle.execute_update.assert_not_awaited() + + +@pytest.mark.anyio +async def test_request_cancel_propagates_unexpected_temporal_describe_error() -> None: + workspace_id = uuid.uuid4() + session_id = uuid.uuid4() + run_id = uuid.uuid4() + role = _build_role(workspace_id) + service = _build_service(role) + agent_session = _build_agent_session( + workspace_id=workspace_id, + session_id=session_id, + status=AgentSessionStatus.RUNNING, + curr_run_id=run_id, + ) + error = RPCError( + "temporal unavailable", + RPCStatusCode.UNAVAILABLE, + b"", + ) + workflow_handle = SimpleNamespace(describe=AsyncMock(side_effect=error)) + temporal_client = SimpleNamespace( + get_workflow_handle=MagicMock(return_value=workflow_handle), + ) + + with ( + patch.object(service, "get_session", AsyncMock(return_value=agent_session)), + patch( + "tracecat.agent.session.service.get_temporal_client", + AsyncMock(return_value=temporal_client), + ), + ): + with pytest.raises(RPCError) as exc_info: + await service.request_cancel( + session_id, + AgentSessionCancelRequest(reason="user_cancel"), + ) + + assert exc_info.value is error + assert agent_session.curr_run_id == run_id + assert agent_session.status == AgentSessionStatus.RUNNING.value diff --git a/tracecat/agent/session/service.py b/tracecat/agent/session/service.py index 3b9ca1bae7..1b8443ad5d 100644 --- a/tracecat/agent/session/service.py +++ b/tracecat/agent/session/service.py @@ -32,6 +32,7 @@ from sqlalchemy.exc import SQLAlchemyError from temporalio.client import WorkflowExecutionStatus from temporalio.common import TypedSearchAttributes +from temporalio.service import RPCError, RPCStatusCode from tracecat_ee.agent.types import AgentWorkflowID from tracecat_registry._internal.exceptions import SecretNotFoundError @@ -111,10 +112,27 @@ AUTO_TITLE_SERVICE_ID = "tracecat-api" APPROVAL_CONTINUATION_DEDUP_TTL_SECONDS = 5 * 60 ACTIVE_AGENT_SESSION_STATUSES = ("running", "waiting_for_approval") +TERMINAL_AGENT_SESSION_STATUSES = { + "idle", + "stopped", + "failed", +} ACTIVE_WORKFLOW_STATUSES = ( WorkflowExecutionStatus.RUNNING, WorkflowExecutionStatus.CONTINUED_AS_NEW, ) +# Lifecycle source of truth: +# - curr_run_id + Temporal describe decides whether a turn is actually active. +# - AgentSession.status is a DB projection for cheap reads and durable product +# terminal state; do not use it alone for correctness gates. +TERMINAL_FAILURE_WORKFLOW_STATUSES = { + WorkflowExecutionStatus.FAILED, + WorkflowExecutionStatus.TIMED_OUT, +} +TERMINAL_STOPPED_WORKFLOW_STATUSES = { + WorkflowExecutionStatus.CANCELED, + WorkflowExecutionStatus.TERMINATED, +} @dataclass @@ -594,36 +612,65 @@ async def set_session_status( await self.session.commit() return bool(cast(CursorResult[Any], result).rowcount) - async def _has_active_agent_turn(self, agent_session: AgentSession) -> bool: - """Return whether DB/Temporal state should block a new turn.""" + async def _describe_curr_run( + self, agent_session: AgentSession + ) -> WorkflowExecutionStatus | None: + """Describe the Temporal run pointed to by the session, if any.""" if agent_session.curr_run_id is None: - return agent_session.status in ACTIVE_AGENT_SESSION_STATUSES + return None client = await get_temporal_client() workflow_id = AgentWorkflowID(agent_session.curr_run_id) + handle = client.get_workflow_handle(str(workflow_id)) + description = await handle.describe() + return description.status + + @staticmethod + def _terminal_status_for_workflow_status( + workflow_status: WorkflowExecutionStatus, + ) -> AgentSessionStatus: + """Map Temporal terminal state to the product status shown in the UI.""" + if workflow_status in TERMINAL_FAILURE_WORKFLOW_STATUSES: + return AgentSessionStatus.FAILED + if workflow_status in TERMINAL_STOPPED_WORKFLOW_STATUSES: + return AgentSessionStatus.STOPPED + return AgentSessionStatus.IDLE + + @staticmethod + def _status_after_missing_current_run( + current_status: AgentSessionStatus, + ) -> AgentSessionStatus: + """Keep terminal UI state if present; otherwise unblock as idle.""" + if current_status.value in TERMINAL_AGENT_SESSION_STATUSES: + return current_status + return AgentSessionStatus.IDLE + + async def _has_active_agent_turn(self, agent_session: AgentSession) -> bool: + """Return whether Temporal/projected state should block a new turn.""" + if agent_session.curr_run_id is None: + return agent_session.status in ACTIVE_AGENT_SESSION_STATUSES + try: - handle = client.get_workflow_handle(str(workflow_id)) - description = await handle.describe() - except Exception: - return True + workflow_status = await self._describe_curr_run(agent_session) + except RPCError as e: + if e.status != RPCStatusCode.NOT_FOUND: + raise + agent_session.status = self._status_after_missing_current_run( + AgentSessionStatus(agent_session.status) + ).value + agent_session.curr_run_id = None + await self.session.commit() + return False - if description.status in ACTIVE_WORKFLOW_STATUSES: + if workflow_status in ACTIVE_WORKFLOW_STATUSES: return True - agent_session.status = ( - AgentSessionStatus.FAILED.value - if description.status - in {WorkflowExecutionStatus.FAILED, WorkflowExecutionStatus.TIMED_OUT} - else AgentSessionStatus.STOPPED.value - if description.status - in { - WorkflowExecutionStatus.CANCELED, - WorkflowExecutionStatus.TERMINATED, - } - else AgentSessionStatus.IDLE.value - ) - agent_session.curr_run_id = None - await self.session.commit() + if workflow_status is not None: + agent_session.status = self._terminal_status_for_workflow_status( + workflow_status + ).value + agent_session.curr_run_id = None + await self.session.commit() return False async def request_cancel( @@ -642,19 +689,36 @@ async def request_cancel( raise TracecatNotFoundError(f"Session with ID {session_id} not found") current_status = AgentSessionStatus(agent_session.status) - if current_status not in { - AgentSessionStatus.RUNNING, - AgentSessionStatus.WAITING_FOR_APPROVAL, - }: - raise TracecatConflictError( - "Agent session does not have an active turn", - detail={"status": current_status.value}, - ) - curr_run_id = agent_session.curr_run_id if curr_run_id is None: raise TracecatConflictError("Agent session has no active workflow run") + try: + workflow_status = await self._describe_curr_run(agent_session) + except RPCError as e: + if e.status != RPCStatusCode.NOT_FOUND: + raise + agent_session.status = self._status_after_missing_current_run( + current_status + ).value + agent_session.curr_run_id = None + await self.session.commit() + raise TracecatConflictError( + "Agent session does not have an active turn", + detail={"status": AgentSessionStatus(agent_session.status).value}, + ) from e + if workflow_status not in ACTIVE_WORKFLOW_STATUSES: + if workflow_status is not None: + agent_session.status = self._terminal_status_for_workflow_status( + workflow_status + ).value + agent_session.curr_run_id = None + await self.session.commit() + raise TracecatConflictError( + "Agent session does not have an active turn", + detail={"status": AgentSessionStatus(agent_session.status).value}, + ) + client = await get_temporal_client() workflow_id = AgentWorkflowID(curr_run_id) handle = client.get_workflow_handle_for(