"""add curr_run_id to agent_session_history

Revision ID: 243a597b6a3a
Revises: e1a2b3c4d5f6
Create Date: 2026-05-27 10:43:33.204849

"""

from collections.abc import Sequence

import sqlalchemy as sa

from alembic import op

# revision identifiers, used by Alembic.
revision: str = "243a597b6a3a"
down_revision: str | None = "e1a2b3c4d5f6"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None

# Table and column touched by this revision; shared by upgrade and downgrade
# so the two directions cannot drift apart.
_TABLE = "agent_session_history"
_COLUMN = "curr_run_id"


def upgrade() -> None:
    """Add the nullable ``curr_run_id`` UUID column and a non-unique index on it."""
    run_id_column = sa.Column(_COLUMN, sa.UUID(), nullable=True)
    op.add_column(_TABLE, run_id_column)

    index_name = op.f("ix_agent_session_history_curr_run_id")
    op.create_index(index_name, _TABLE, [_COLUMN], unique=False)


def downgrade() -> None:
    """Reverse ``upgrade``: drop the index first, then the column it covers."""
    index_name = op.f("ix_agent_session_history_curr_run_id")
    op.drop_index(index_name, table_name=_TABLE)
    op.drop_column(_TABLE, _COLUMN)
/**
 * Get Session Status
 * Lifecycle status for polling without loading message history.
 * @param data The data for the request.
 * @param data.sessionId
 * @param data.workspaceId
 * @returns AgentSessionStatusRead Successful Response
 * @throws ApiError
 */
export const agentSessionsGetSessionStatus = (
  data: AgentSessionsGetSessionStatusData
): CancelablePromise<AgentSessionsGetSessionStatusResponse> => {
  // Issues a GET against the per-session status route; both path parameters
  // are taken from `data`, and a 422 is surfaced as "Validation Error".
  return __request(OpenAPI, {
    method: "GET",
    url: "/workspaces/{workspace_id}/agent/sessions/{session_id}/status",
    path: {
      session_id: data.sessionId,
      workspace_id: data.workspaceId,
    },
    errors: {
      422: "Validation Error",
    },
  })
}
diff --git a/frontend/src/client/types.gen.ts b/frontend/src/client/types.gen.ts index ce48d61ea2..38e7a0712a 100644 --- a/frontend/src/client/types.gen.ts +++ b/frontend/src/client/types.gen.ts @@ -954,6 +954,21 @@ export type AgentSessionStatus = | "stopped" | "failed" +/** + * Lightweight session lifecycle status for cheap polling. + * + * Clients poll this (instead of the full message history) to learn when a turn + * starts elsewhere and attach to the live stream. + */ +export type AgentSessionStatusRead = { + turn_status?: AgentSessionStatus + curr_run_id?: string | null + /** + * Human prompt that started the active run, for observer clients to render (user messages cannot stream over the Vercel protocol). + */ + prompt?: string | null +} + /** * Request schema for updating an agent session. */ @@ -10781,6 +10796,13 @@ export type AgentSessionsGetSessionVercelResponse = | AgentSessionReadVercel | ChatReadVercel +export type AgentSessionsGetSessionStatusData = { + sessionId: string + workspaceId: string +} + +export type AgentSessionsGetSessionStatusResponse = AgentSessionStatusRead + export type AgentSessionsCancelSessionData = { requestBody?: AgentSessionCancelRequest | null sessionId: string @@ -15780,6 +15802,21 @@ export type $OpenApiTs = { } } } + "/workspaces/{workspace_id}/agent/sessions/{session_id}/status": { + get: { + req: AgentSessionsGetSessionStatusData + res: { + /** + * Successful Response + */ + 200: AgentSessionStatusRead + /** + * Validation Error + */ + 422: HTTPValidationError + } + } + } "/workspaces/{workspace_id}/agent/sessions/{session_id}/cancel": { post: { req: AgentSessionsCancelSessionData diff --git a/frontend/src/components/chat/chat-session-pane.tsx b/frontend/src/components/chat/chat-session-pane.tsx index 6799620fac..be2d1dc937 100644 --- a/frontend/src/components/chat/chat-session-pane.tsx +++ b/frontend/src/components/chat/chat-session-pane.tsx @@ -102,6 +102,7 @@ import { makeContinueMessage, parseChatError, 
useCancelChatTurn, + useSessionStatus, useUpdateChat, useVercelChat, } from "@/hooks/use-chat" @@ -279,12 +280,26 @@ export function ChatSessionPane({ () => (chat?.messages || []).map(toUIMessage), [chat?.messages] ) + // Cheap status poll drives attaching to a live turn (started by another tab). + // `prompt` carries the active turn's user message so observer tabs can show + // it (the Vercel stream protocol only streams assistant messages). + const { + turnStatus, + currRunId, + prompt: activePrompt, + } = useSessionStatus({ + chatId: isReadonly ? undefined : chat?.id, + workspaceId, + }) const { sendMessage, messages, status, regenerate, lastError, clearError } = useVercelChat({ chatId: chat?.id, workspaceId, messages: uiMessages, modelInfo, + turnStatus, + currRunId, + activePrompt, }) // Track pending message sends to avoid duplicate sends diff --git a/frontend/src/hooks/use-chat.ts b/frontend/src/hooks/use-chat.ts index a28fe99114..197ea0b265 100644 --- a/frontend/src/hooks/use-chat.ts +++ b/frontend/src/hooks/use-chat.ts @@ -6,12 +6,14 @@ import { useQueryClient, } from "@tanstack/react-query" import { DefaultChatTransport, type UIMessage } from "ai" -import { useCallback, useMemo, useState } from "react" +import { useCallback, useEffect, useMemo, useRef, useState } from "react" import { type AgentSessionCancelResponse, type AgentSessionCreate, type AgentSessionEntity, type AgentSessionRead, + type AgentSessionStatus, + type AgentSessionStatusRead, type AgentSessionsGetSessionResponse, type AgentSessionsGetSessionVercelResponse, type AgentSessionsListSessionsResponse, @@ -21,6 +23,7 @@ import { agentSessionsCreateSession, agentSessionsDeleteSession, agentSessionsGetSession, + agentSessionsGetSessionStatus, agentSessionsGetSessionVercel, agentSessionsListSessions, agentSessionsUpdateSession, @@ -34,6 +37,170 @@ import { type ModelInfo, toServerUIMessage } from "@/lib/chat" const DEFAULT_CHAT_ERROR_MESSAGE = "The assistant couldn't complete that request. 
/** A turn still has active run metadata while running or awaiting approval. */
function isActiveTurnStatus(status: AgentSessionStatus | undefined): boolean {
  return status === "running" || status === "waiting_for_approval"
}

/** Fresh stream attachment is only safe while the Redis stream is still open. */
function isStreamAttachableTurnStatus(
  status: AgentSessionStatus | undefined
): boolean {
  return status === "running"
}

/**
 * Concatenate the `text` of every text part of a message, in order.
 * Parts of any other type are skipped.
 */
function getMessageText(message: UIMessage): string {
  return message.parts
    .filter(
      // Type predicate narrows the part union to its text member so `.text`
      // is accessible below.
      (part): part is Extract<UIMessage["parts"][number], { type: "text" }> =>
        part.type === "text"
    )
    .map((part) => part.text)
    .join("")
}

/**
 * True when `message` is a user message whose joined text equals `prompt`
 * exactly (no trimming or case folding). An undefined message never matches.
 */
function isMatchingUserPrompt(
  message: UIMessage | undefined,
  prompt: string
): boolean {
  return message?.role === "user" && getMessageText(message) === prompt
}

/**
 * Decide whether the active turn's user prompt is already represented in
 * `messages`. It counts as present when any of these holds:
 *
 * 1. a message with id `promptMessageId` exists;
 * 2. the message immediately before the assistant message with id
 *    `activeAssistantId` is a user message with the same text;
 * 3. the last message is a user message with the same text.
 */
function hasActivePromptMessage(
  messages: UIMessage[],
  promptMessageId: string,
  activeAssistantId: string,
  prompt: string
): boolean {
  if (messages.some((message) => message.id === promptMessageId)) {
    return true
  }

  const activeAssistantIndex = messages.findIndex(
    (message) =>
      message.role === "assistant" && message.id === activeAssistantId
  )
  // Index 0 has no predecessor, hence `> 0` rather than `>= 0`.
  if (
    activeAssistantIndex > 0 &&
    isMatchingUserPrompt(messages[activeAssistantIndex - 1], prompt)
  ) {
    return true
  }

  return isMatchingUserPrompt(messages[messages.length - 1], prompt)
}
/**
 * Insert the active turn's user prompt into SDK chat state for observer tabs.
 *
 * The sender already gets an optimistic user message from `sendMessage`, while
 * observers only receive the assistant over the Vercel UI stream. This keeps
 * the prompt as durable chat state instead of deriving it during render.
 *
 * @param messages Current chat messages; never mutated.
 * @returns The same `messages` array (by identity) when `chatId`, `currRunId`
 *   or a non-blank `prompt` is missing, or when the prompt is already present.
 *   Otherwise a new array with a synthetic user message placed directly before
 *   the active assistant message, or appended when that message does not exist.
 */
export function upsertActivePromptMessage(
  messages: UIMessage[],
  {
    chatId,
    currRunId,
    prompt,
  }: {
    chatId?: string
    currRunId?: string | null
    prompt?: string | null
  }
): UIMessage[] {
  if (!chatId || !currRunId || !prompt?.trim()) {
    return messages
  }

  // Deterministic ids: the synthetic prompt is keyed by chat + run, and the
  // streaming assistant message is looked up under `${chatId}:${currRunId}`.
  const promptMessageId = `active-user:${chatId}:${currRunId}`
  const activeAssistantId = `${chatId}:${currRunId}`

  if (
    hasActivePromptMessage(messages, promptMessageId, activeAssistantId, prompt)
  ) {
    return messages
  }

  const promptMessage: UIMessage = {
    id: promptMessageId,
    role: "user",
    parts: [{ type: "text", text: prompt }],
  }
  const activeAssistantIndex = messages.findIndex(
    (message) =>
      message.role === "assistant" && message.id === activeAssistantId
  )

  if (activeAssistantIndex === -1) {
    return [...messages, promptMessage]
  }

  return [
    ...messages.slice(0, activeAssistantIndex),
    promptMessage,
    ...messages.slice(activeAssistantIndex),
  ]
}

/**
 * Read an SSE response stream and record the latest `id:` line into a ref.
 *
 * The AI SDK does not surface SSE event ids, so we capture them ourselves to
 * resume from the right place on reconnect (sent back as `Last-Event-ID`).
 * Consumes its own tee'd branch of the body; cancels cleanly on stream end.
 *
 * An id is published to `lastEventIdRef` only once the blank line that
 * terminates its event has been seen, so an event cut off mid-flight never
 * advances the resume position. Read errors are swallowed (best effort).
 *
 * @param stream Byte stream of the SSE body.
 * @param lastEventIdRef Mutable holder updated with each completed event's id.
 */
export async function scanSseIds(
  stream: ReadableStream<Uint8Array>,
  lastEventIdRef: { current: string | null }
): Promise<void> {
  const reader = stream.getReader()
  const decoder = new TextDecoder()
  let buffer = ""
  let pendingEventId: string | null = null

  function processLine(rawLine: string) {
    // Lines are split on "\n"; drop the "\r" of a CRLF terminator.
    const line = rawLine.endsWith("\r") ? rawLine.slice(0, -1) : rawLine
    if (line === "") {
      // Blank line = end of event: commit the id seen inside it, if any.
      if (pendingEventId !== null) {
        lastEventIdRef.current = pendingEventId
        pendingEventId = null
      }
      return
    }

    if (line.startsWith("id:")) {
      const candidate = line.slice(3).trim()
      // The SSE parsing rules ignore an id field containing U+0000, and such
      // a value could not be sent back as a `Last-Event-ID` header anyway.
      if (!candidate.includes("\u0000")) {
        pendingEventId = candidate
      }
    }
  }

  function processCompleteLines() {
    // Only newline-terminated lines are processed; a trailing partial line
    // stays in `buffer` until more bytes arrive.
    let newlineIndex = buffer.indexOf("\n")
    while (newlineIndex !== -1) {
      processLine(buffer.slice(0, newlineIndex))
      buffer = buffer.slice(newlineIndex + 1)
      newlineIndex = buffer.indexOf("\n")
    }
  }

  try {
    while (true) {
      const { done, value } = await reader.read()
      if (done) {
        // Flush any bytes the decoder was holding for a split code point.
        buffer += decoder.decode()
        processCompleteLines()
        break
      }
      buffer += decoder.decode(value, { stream: true })
      processCompleteLines()
    }
  } catch {
    // Best-effort: a cancelled/aborted stream is not an error for id tracking.
  } finally {
    reader.cancel().catch(() => {})
  }
}

type UpdateableChatRecord =
  | AgentSessionsGetSessionResponse
  | AgentSessionsGetSessionVercelResponse

/**
 * Poll the lifecycle status endpoint so a client learns when a turn
 * starts (e.g. from another tab) and can attach to the live stream. Kept
 * separate from the heavy message-history fetch. Polls a touch faster while a
 * turn is live; always polls so an idle client notices a new turn.
 *
 * The query is disabled while `chatId` is undefined. Polling runs every 2000 ms
 * while the last known status is running or waiting for approval, every
 * 3000 ms otherwise, and not while the page is in the background.
 *
 * @returns `turnStatus`, `currRunId` and `prompt` from the latest response;
 *   each is `undefined` until data is available.
 */
export function useSessionStatus({
  chatId,
  workspaceId,
}: {
  chatId?: string
  workspaceId: string
}) {
  const { data } = useQuery<AgentSessionStatusRead>({
    queryKey: ["chat", chatId, workspaceId, "status"],
    queryFn: async () => {
      // Defensive: `enabled` already prevents this from running without an id.
      if (!chatId) {
        throw new Error("No chat ID available")
      }
      return await agentSessionsGetSessionStatus({
        sessionId: chatId,
        workspaceId,
      })
    },
    enabled: !!chatId,
    refetchInterval: (query) =>
      isActiveTurnStatus(query.state.data?.turn_status) ? 2000 : 3000,
    refetchIntervalInBackground: false,
  })
  return {
    turnStatus: data?.turn_status,
    currRunId: data?.curr_run_id,
    prompt: data?.prompt,
  }
}
+ const trackingFetch = useCallback(async (input, init) => { + const response = await fetch(input, init) + if (!response.body) return response + const [toSdk, toScan] = response.body.tee() + void scanSseIds(toScan, lastEventIdRef) + return new Response(toSdk, { + status: response.status, + statusText: response.statusText, + headers: response.headers, + }) + }, []) // Build the Vercel streaming endpoint URL const apiEndpoint = useMemo(() => { @@ -434,19 +670,27 @@ export function useVercelChat({ }, [chatId, workspaceId]) // Use Vercel's useChat hook for streaming + // We attach to live turns via the status-driven effect below, not the SDK's + // one-shot `resume` (fires once on mount, never re-fires for a turn started + // afterward by another client). const chat = aiSdk.useChat({ id: chatId, - resume: !!chatId, messages, transport: new DefaultChatTransport({ api: apiEndpoint, credentials: "include", + fetch: trackingFetch, prepareReconnectToStreamRequest: ({ id }) => { const url = new URL(`/api/agent/sessions/${id}/stream`, getBaseUrl()) url.searchParams.set("workspace_id", workspaceId) + const headers: Record = {} + if (lastEventIdRef.current) { + headers["Last-Event-ID"] = lastEventIdRef.current + } return { api: url.toString(), credentials: "include", + headers, } }, prepareSendMessagesRequest: ({ messages }) => { @@ -486,7 +730,15 @@ export function useVercelChat({ description: friendlyMessage, }) }, - onFinish: () => { + onFinish: ({ isAbort, isDisconnect, isError }) => { + if (activeResumeRunKeyRef.current) { + if (isAbort || isDisconnect || isError) { + resumeAttemptKeyRef.current = null + } else { + completedResumeRunKeyRef.current = activeResumeRunKeyRef.current + } + activeResumeRunKeyRef.current = null + } setLastError(null) queryClient.invalidateQueries({ queryKey: ["chat", chatId, workspaceId, "vercel"], @@ -495,6 +747,110 @@ export function useVercelChat({ }, }) + const { messages: chatMessages, status, resumeStream, setMessages } = chat + const 
activePromptText = activePrompt?.trim() ? activePrompt : undefined + const activePromptPresent = + chatId && currRunId && activePromptText + ? hasActivePromptMessage( + chatMessages, + `active-user:${chatId}:${currRunId}`, + `${chatId}:${currRunId}`, + activePromptText + ) + : false + + useEffect(() => { + const previous = previousActiveRunRef.current + if ( + previous.currRunId !== currRunId || + !isActiveTurnStatus(turnStatus) || + (previous.turnStatus === "waiting_for_approval" && + turnStatus === "running") + ) { + resumeAttemptKeyRef.current = null + activeResumeRunKeyRef.current = null + completedResumeRunKeyRef.current = null + insertedPromptKeyRef.current = null + } + + previousActiveRunRef.current = { currRunId, turnStatus } + }, [currRunId, turnStatus]) + + // Observer tabs don't call sendMessage, so they don't get the optimistic user + // bubble. Add it to useChat state from the status poll before attaching. + useEffect(() => { + if ( + !isStreamAttachableTurnStatus(turnStatus) || + status !== "ready" || + !chatId || + !currRunId || + !activePromptText || + activePromptPresent + ) { + return + } + + const promptInsertKey = `${chatId}:${currRunId}` + if (insertedPromptKeyRef.current === promptInsertKey) { + return + } + + insertedPromptKeyRef.current = promptInsertKey + setMessages((current) => + upsertActivePromptMessage(current, { + chatId, + currRunId, + prompt: activePromptText, + }) + ) + }, [ + activePromptPresent, + activePromptText, + chatId, + currRunId, + setMessages, + status, + turnStatus, + ]) + + // Attach to a live turn whenever the server reports one and we're idle. + // The condition self-guards: resumeStream() flips status off "ready", so it + // won't re-fire mid-stream; if the stream drops while the turn is still + // running, the next status poll re-attaches. Handles back-to-back turns + // started by another client without relying on the SDK's one-shot `resume`. 
+ useEffect(() => { + if (!isStreamAttachableTurnStatus(turnStatus) || status !== "ready") { + return + } + + if (activePromptText && !activePromptPresent) { + return + } + + const resumeRunKey = + chatId && currRunId ? `${chatId}:${currRunId}` : undefined + if (!resumeRunKey || completedResumeRunKeyRef.current === resumeRunKey) { + return + } + + const resumeAttemptKey = `${resumeRunKey}:${lastEventIdRef.current ?? "0-0"}` + if (resumeAttemptKeyRef.current === resumeAttemptKey) { + return + } + + resumeAttemptKeyRef.current = resumeAttemptKey + activeResumeRunKeyRef.current = resumeRunKey + void resumeStream() + }, [ + activePromptPresent, + activePromptText, + chatId, + currRunId, + resumeStream, + status, + turnStatus, + ]) + return { ...chat, lastError, diff --git a/frontend/tests/chat-session-pane.test.tsx b/frontend/tests/chat-session-pane.test.tsx index baa30a8f8e..812e1485bf 100644 --- a/frontend/tests/chat-session-pane.test.tsx +++ b/frontend/tests/chat-session-pane.test.tsx @@ -7,17 +7,26 @@ import { MessagePart, } from "@/components/chat/chat-session-pane" import { TooltipProvider } from "@/components/ui/tooltip" -import { useUpdateChat, useVercelChat } from "@/hooks/use-chat" +import { + useSessionStatus, + useUpdateChat, + useVercelChat, +} from "@/hooks/use-chat" import { useBuilderRegistryActions } from "@/lib/hooks" jest.mock("@/hooks/use-chat", () => ({ useVercelChat: jest.fn(), useGetChat: jest.fn(() => ({ chat: null })), - useUpdateChat: jest.fn(() => ({ updateChat: jest.fn(), isUpdating: false })), + useSessionStatus: jest.fn(() => ({ + turnStatus: undefined, + currRunId: undefined, + prompt: undefined, + })), useCancelChatTurn: jest.fn(() => ({ cancelChatTurn: jest.fn(), isCancellingChatTurn: false, })), + useUpdateChat: jest.fn(() => ({ updateChat: jest.fn(), isUpdating: false })), parseChatError: (error: unknown) => error instanceof Error ? 
error.message : "Chat error", makeContinueMessage: (decisions: unknown) => ({ @@ -80,6 +89,9 @@ const mockUseVercelChat = useVercelChat as jest.MockedFunction< const mockUseUpdateChat = useUpdateChat as jest.MockedFunction< typeof useUpdateChat > +const mockUseSessionStatus = useSessionStatus as jest.MockedFunction< + typeof useSessionStatus +> const mockUseBuilderRegistryActions = useBuilderRegistryActions as jest.MockedFunction< typeof useBuilderRegistryActions @@ -236,6 +248,30 @@ describe("ChatSessionPane", () => { }) }) + it("does not poll session status for read-only legacy chats", () => { + mockUseVercelChatStatus("ready") + const readonlyChat = { ...createChatFixture(), is_readonly: true } + + render( + + + + + + ) + + expect(mockUseSessionStatus).toHaveBeenCalledWith({ + chatId: undefined, + workspaceId: "workspace-1", + }) + }) + it("renders dots indicator while status is submitted", () => { mockUseVercelChat.mockReturnValue({ sendMessage: jest.fn(), diff --git a/frontend/tests/memo-render-count.test.tsx b/frontend/tests/memo-render-count.test.tsx index 69bdef3bba..6a4dad3d95 100644 --- a/frontend/tests/memo-render-count.test.tsx +++ b/frontend/tests/memo-render-count.test.tsx @@ -57,6 +57,11 @@ jest.mock("@/components/editor/codemirror/code-editor", () => ({ jest.mock("@/hooks/use-chat", () => ({ useVercelChat: jest.fn(), useGetChat: jest.fn(() => ({ chat: null })), + useSessionStatus: jest.fn(() => ({ + turnStatus: undefined, + currRunId: undefined, + prompt: undefined, + })), useUpdateChat: jest.fn(() => ({ updateChat: jest.fn(), isUpdating: false })), useCancelChatTurn: jest.fn(() => ({ cancelChatTurn: jest.fn(), diff --git a/frontend/tests/use-chat.test.tsx b/frontend/tests/use-chat.test.tsx index d7a310a8c5..6fc8c6e394 100644 --- a/frontend/tests/use-chat.test.tsx +++ b/frontend/tests/use-chat.test.tsx @@ -1,12 +1,17 @@ import { QueryClient, QueryClientProvider } from "@tanstack/react-query" import { renderHook, waitFor } from 
"@testing-library/react" +import type { UIMessage } from "ai" import type { AgentSessionRead, AgentSessionReadVercel, AgentSessionReadWithMessages, } from "@/client" import { agentSessionsUpdateSession } from "@/client" -import { useUpdateChat } from "@/hooks/use-chat" +import { + scanSseIds, + upsertActivePromptMessage, + useUpdateChat, +} from "@/hooks/use-chat" jest.mock("@/client", () => { const actual = jest.requireActual("@/client") @@ -73,6 +78,143 @@ function createSessionReadVercel( } } +function streamFromChunks(chunks: string[]): ReadableStream { + const encoder = new TextEncoder() + return new ReadableStream({ + start(controller) { + for (const chunk of chunks) { + controller.enqueue(encoder.encode(chunk)) + } + controller.close() + }, + }) +} + +describe("scanSseIds", () => { + it("publishes the event id after the blank line terminates the SSE event", async () => { + const lastEventIdRef = { current: null as string | null } + + await scanSseIds( + streamFromChunks(["id: 1000-0:1\n", 'data: {"type":"text"}\n\n']), + lastEventIdRef + ) + + expect(lastEventIdRef.current).toBe("1000-0:1") + }) + + it("does not publish an id from an incomplete SSE event", async () => { + const lastEventIdRef = { current: "999-0:0" as string | null } + + await scanSseIds( + streamFromChunks(["id: 1000-0:1\n", 'data: {"type":"text"}']), + lastEventIdRef + ) + + expect(lastEventIdRef.current).toBe("999-0:0") + }) + + it("uses the last id line in a completed SSE event", async () => { + const lastEventIdRef = { current: null as string | null } + + await scanSseIds( + streamFromChunks(["id: 1000-0:1\ndata: first\n", "id: 1000-0:2\n\n"]), + lastEventIdRef + ) + + expect(lastEventIdRef.current).toBe("1000-0:2") + }) +}) + +describe("upsertActivePromptMessage", () => { + it("appends the active prompt before the stream creates an assistant", () => { + const messages: UIMessage[] = [] + + const result = upsertActivePromptMessage(messages, { + chatId: "chat-1", + currRunId: "run-1", + 
prompt: "Investigate this alert", + }) + + expect(result).toEqual([ + { + id: "active-user:chat-1:run-1", + role: "user", + parts: [{ type: "text", text: "Investigate this alert" }], + }, + ]) + }) + + it("inserts the active prompt before an existing streaming assistant", () => { + const messages: UIMessage[] = [ + { + id: "chat-1:run-1", + role: "assistant", + parts: [{ type: "text", text: "Working" }], + }, + ] + + const result = upsertActivePromptMessage(messages, { + chatId: "chat-1", + currRunId: "run-1", + prompt: "Investigate this alert", + }) + + expect(result).toHaveLength(2) + expect(result[0]).toEqual({ + id: "active-user:chat-1:run-1", + role: "user", + parts: [{ type: "text", text: "Investigate this alert" }], + }) + expect(result[1]).toBe(messages[0]) + }) + + it("does not duplicate the sender's optimistic user prompt", () => { + const messages: UIMessage[] = [ + { + id: "local-user", + role: "user", + parts: [{ type: "text", text: "Investigate this alert" }], + }, + { + id: "chat-1:run-1", + role: "assistant", + parts: [{ type: "text", text: "Working" }], + }, + ] + + const result = upsertActivePromptMessage(messages, { + chatId: "chat-1", + currRunId: "run-1", + prompt: "Investigate this alert", + }) + + expect(result).toBe(messages) + }) + + it("is idempotent after inserting the active prompt", () => { + const messages: UIMessage[] = [ + { + id: "chat-1:run-1", + role: "assistant", + parts: [{ type: "text", text: "Working" }], + }, + ] + + const first = upsertActivePromptMessage(messages, { + chatId: "chat-1", + currRunId: "run-1", + prompt: "Investigate this alert", + }) + const second = upsertActivePromptMessage(first, { + chatId: "chat-1", + currRunId: "run-1", + prompt: "Investigate this alert", + }) + + expect(second).toBe(first) + }) +}) + describe("useUpdateChat", () => { let queryClient: QueryClient diff --git a/packages/tracecat-ee/tracecat_ee/agent/workflows/durable.py b/packages/tracecat-ee/tracecat_ee/agent/workflows/durable.py index 
ec941b4515..d6f185e32a 100644 --- a/packages/tracecat-ee/tracecat_ee/agent/workflows/durable.py +++ b/packages/tracecat-ee/tracecat_ee/agent/workflows/durable.py @@ -4,7 +4,7 @@ import uuid from dataclasses import dataclass from datetime import UTC, datetime, timedelta -from typing import Any, Literal +from typing import Any from pydantic import BaseModel, ConfigDict, Field from temporalio import workflow @@ -74,6 +74,7 @@ AgentCancelReason, AgentSessionEntity, AgentSessionStatus, + AgentWorkflowTurnStatus, ) from tracecat.agent.subagents import has_manual_tool_approvals from tracecat.agent.tokens import ( @@ -354,7 +355,7 @@ def __init__(self, args: AgentWorkflowArgs): ctx_role.set(args.role) AgentContext.set(session_id=args.agent_args.session_id) - self._status: Literal["running", "waiting_for_results", "done"] = "running" + self._status: AgentWorkflowTurnStatus = AgentSessionStatus.RUNNING.value self._turn: int = 0 if args.role.workspace_id is None: raise ApplicationError("Role must have a workspace ID", non_retryable=True) @@ -372,6 +373,11 @@ def __init__(self, args: AgentWorkflowArgs): self._cancel_requested = False self._cancel_reason: AgentCancelReason | None = None + @workflow.query + def get_turn_status(self) -> AgentWorkflowTurnStatus: + """Return the live product turn status owned by this workflow.""" + return self._status + def _upsert_tracecat_search_attributes(self) -> None: """Ensure direct agent runs have core Tracecat search attributes. 
@@ -701,6 +707,7 @@ async def run(self, args: AgentWorkflowArgs) -> AgentOutput: cfg = await self._build_config(args) return await self._run_with_agent_executor(args, cfg) except ActivityError as e: + self._status = AgentSessionStatus.FAILED.value await self._set_agent_session_status( AgentSessionStatus.FAILED, clear_curr_run_id=True, @@ -709,6 +716,7 @@ async def run(self, args: AgentWorkflowArgs) -> AgentOutput: await self._emit_session_error(_activity_error_message(e)) raise except ApplicationError as e: + self._status = AgentSessionStatus.FAILED.value await self._set_agent_session_status( AgentSessionStatus.FAILED, clear_curr_run_id=True, @@ -722,6 +730,7 @@ async def run(self, args: AgentWorkflowArgs) -> AgentOutput: except TemporalCancelledError: raise except Exception: + self._status = AgentSessionStatus.FAILED.value await self._set_agent_session_failed_for_unhandled_failure() raise @@ -1017,7 +1026,7 @@ async def _run_with_agent_executor( # Run the executor activity while True: logger.info("Executing agent turn", turn=self._turn) - self._status = "running" + self._status = AgentSessionStatus.RUNNING.value await self._set_agent_session_status(AgentSessionStatus.RUNNING) result = await self._run_agent_activity_turn(executor_input) @@ -1028,7 +1037,7 @@ async def _run_with_agent_executor( session_id=self.session_id, reason=result.cancelled_reason, ) - self._status = "done" + self._status = AgentSessionStatus.STOPPED.value await self._set_agent_session_status( AgentSessionStatus.STOPPED, clear_curr_run_id=True, @@ -1058,7 +1067,7 @@ async def _run_with_agent_executor( if result.approval_requested: logger.info("Agent waiting for approval", session_id=self.session_id) - self._status = "waiting_for_results" + self._status = AgentSessionStatus.WAITING_FOR_APPROVAL.value await self._set_agent_session_status( AgentSessionStatus.WAITING_FOR_APPROVAL ) @@ -1101,7 +1110,7 @@ async def _run_with_agent_executor( } ) await self.approvals.handle_decisions() - self._status = 
"done" + self._status = AgentSessionStatus.STOPPED.value await self._set_agent_session_status( AgentSessionStatus.STOPPED, clear_curr_run_id=True, @@ -1122,7 +1131,7 @@ async def _run_with_agent_executor( session_id=self.session_id, reason=self._cancel_reason, ) - self._status = "done" + self._status = AgentSessionStatus.STOPPED.value await self._set_agent_session_status( AgentSessionStatus.STOPPED, clear_curr_run_id=True, @@ -1135,7 +1144,7 @@ async def _run_with_agent_executor( usage=RunUsage(requests=0, input_tokens=0, output_tokens=0), session_id=self.session_id, ) - self._status = "running" + self._status = AgentSessionStatus.RUNNING.value await self._set_agent_session_status(AgentSessionStatus.RUNNING) if self._cancel_requested: logger.info( @@ -1143,7 +1152,7 @@ async def _run_with_agent_executor( session_id=self.session_id, reason=self._cancel_reason, ) - self._status = "done" + self._status = AgentSessionStatus.STOPPED.value await self._set_agent_session_status( AgentSessionStatus.STOPPED, clear_curr_run_id=True, @@ -1210,7 +1219,7 @@ async def _run_with_agent_executor( output=result.output, ) message_history = await self._load_terminal_message_history(result) - self._status = "done" + self._status = AgentSessionStatus.IDLE.value await self._set_agent_session_status( AgentSessionStatus.IDLE, clear_curr_run_id=True, diff --git a/tests/unit/test_agent_executor_loopback.py b/tests/unit/test_agent_executor_loopback.py index f2740bf40c..242744f868 100644 --- a/tests/unit/test_agent_executor_loopback.py +++ b/tests/unit/test_agent_executor_loopback.py @@ -1,14 +1,19 @@ from __future__ import annotations import asyncio +import contextlib import uuid +from collections.abc import AsyncIterator from pathlib import Path +from typing import cast from unittest.mock import AsyncMock from uuid import UUID import orjson import pytest +from sqlalchemy import select from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.ext.asyncio import AsyncSession from 
tracecat.agent.common.protocol import RuntimeEventEnvelope from tracecat.agent.common.socket_io import MessageType, build_message @@ -19,6 +24,8 @@ LoopbackHandler, LoopbackInput, ) +from tracecat.auth.types import Role +from tracecat.db.models import AgentSession, AgentSessionHistory class _FakeStream: @@ -417,3 +424,66 @@ async def test_cancelled_loopback_preserves_later_stream_and_history() -> None: assert stream.append.await_count == 1 persist.assert_awaited_once() + + +@pytest.mark.anyio +@pytest.mark.usefixtures("db") +async def test_persist_session_line_stamps_curr_run_id( + monkeypatch: pytest.MonkeyPatch, + session: AsyncSession, + svc_role: Role, +) -> None: + """Persisted history rows carry the session's active run id so mid-turn DB + loads can hide them.""" + run_id = uuid.uuid4() + agent_session = AgentSession( + id=uuid.uuid4(), + workspace_id=svc_role.workspace_id, + title="Chat", + entity_type="case", + entity_id=uuid.uuid4(), + curr_run_id=run_id, + ) + session.add(agent_session) + await session.commit() + + handler = LoopbackHandler( + input=LoopbackInput( + session_id=agent_session.id, + workspace_id=cast(uuid.UUID, svc_role.workspace_id), + ) + ) + + @contextlib.asynccontextmanager + async def _ctx() -> AsyncIterator[AsyncSession]: + yield session + + monkeypatch.setattr( + "tracecat.agent.executor.loopback.get_async_session_bypass_rls_context_manager", + _ctx, + ) + + await handler._persist_session_line( + "sdk-session", + orjson.dumps( + { + "type": "assistant", + "uuid": str(uuid.uuid4()), + "message": {"role": "assistant", "content": []}, + } + ).decode(), + ) + + rows = ( + ( + await session.execute( + select(AgentSessionHistory).where( + AgentSessionHistory.session_id == agent_session.id + ) + ) + ) + .scalars() + .all() + ) + assert len(rows) == 1 + assert rows[0].curr_run_id == run_id diff --git a/tests/unit/test_agent_session_messages.py b/tests/unit/test_agent_session_messages.py index 2fbef003db..42fb7d6c79 100644 --- 
a/tests/unit/test_agent_session_messages.py +++ b/tests/unit/test_agent_session_messages.py @@ -7,13 +7,14 @@ import orjson import pytest +from sqlalchemy.ext.asyncio import AsyncSession from tracecat.agent.approvals.enums import ApprovalStatus from tracecat.agent.session.service import AgentSessionService from tracecat.agent.session.types import AgentSessionStatus from tracecat.auth.types import Role from tracecat.chat.enums import MessageKind -from tracecat.db.models import AgentSession +from tracecat.db.models import AgentSession, AgentSessionHistory def _mock_scalar_result(items: list[Any]) -> Mock: @@ -407,3 +408,100 @@ async def test_list_messages_skips_misclassified_continuation_artifacts() -> Non assert len(messages) == 1 assert messages[0].message is not None + + +def _chat_message_content(text: str) -> dict[str, Any]: + return { + "type": "assistant", + "uuid": str(uuid.uuid4()), + "message": {"role": "assistant", "content": [{"type": "text", "text": text}]}, + } + + +@pytest.mark.anyio +@pytest.mark.usefixtures("db") +async def test_list_messages_hides_active_run_rows_while_running( + session: AsyncSession, + svc_role: Role, +) -> None: + """While RUNNING, the active run's rows are hidden (Redis owns them); rows + from prior runs or without a run id stay visible.""" + active_run = uuid.uuid4() + prior_run = uuid.uuid4() + agent_session = AgentSession( + id=uuid.uuid4(), + workspace_id=svc_role.workspace_id, + title="Chat", + entity_type="case", + entity_id=uuid.uuid4(), + status=AgentSessionStatus.RUNNING.value, + curr_run_id=active_run, + ) + session.add(agent_session) + for text, run_id in ( + ("prior turn", prior_run), + ("legacy row", None), + ("active partial", active_run), + ): + session.add( + AgentSessionHistory( + session_id=agent_session.id, + workspace_id=svc_role.workspace_id, + content=_chat_message_content(text), + kind=MessageKind.CHAT_MESSAGE.value, + curr_run_id=run_id, + ) + ) + await session.commit() + + service = 
AgentSessionService(session=session, role=svc_role) + texts = [ + getattr(part, "text", None) + for m in await service.list_messages(agent_session.id) + if m.message is not None + for part in cast(Any, m.message).content + ] + + assert "prior turn" in texts + assert "legacy row" in texts + assert "active partial" not in texts + + +@pytest.mark.anyio +@pytest.mark.usefixtures("db") +async def test_list_messages_keeps_waiting_for_approval_rows( + session: AsyncSession, + svc_role: Role, +) -> None: + """WAITING_FOR_APPROVAL is DB-backed; fresh observers should load history.""" + active_run = uuid.uuid4() + agent_session = AgentSession( + id=uuid.uuid4(), + workspace_id=svc_role.workspace_id, + title="Chat", + entity_type="case", + entity_id=uuid.uuid4(), + status=AgentSessionStatus.WAITING_FOR_APPROVAL.value, + curr_run_id=active_run, + ) + session.add(agent_session) + session.add( + AgentSessionHistory( + session_id=agent_session.id, + workspace_id=svc_role.workspace_id, + content=_chat_message_content("approval is visible"), + kind=MessageKind.CHAT_MESSAGE.value, + curr_run_id=active_run, + ) + ) + await session.commit() + + service = AgentSessionService(session=session, role=svc_role) + texts = [ + getattr(part, "text", None) + for m in await service.list_messages(agent_session.id) + if m.message is not None + for part in cast(Any, m.message).content + ] + + assert "approval is visible" in texts diff --git a/tests/unit/test_agent_session_router.py b/tests/unit/test_agent_session_router.py index 63862ff528..d3669274b2 100644 --- a/tests/unit/test_agent_session_router.py +++ b/tests/unit/test_agent_session_router.py @@ -16,10 +16,12 @@ from tracecat.agent.session.router import ( cancel_session, get_session, + get_session_status, get_session_vercel, send_message, stream_session_events, ) +from tracecat.agent.session.service import AgentTurnState from tracecat.agent.session.types import AgentSessionEntity, AgentSessionStatus from tracecat.auth.types import Role from 
tracecat.chat.schemas import ( @@ -197,6 +199,46 @@ async def test_cancel_session_maps_conflict_to_409() -> None: assert exc_info.value.detail == {"status": "idle"} +@pytest.mark.anyio +@pytest.mark.parametrize("status_value", ["running", "waiting_for_approval"]) +async def test_get_session_status_includes_prompt_for_active_turn( + status_value: str, +) -> None: + session_id = uuid.uuid4() + workspace_id = uuid.uuid4() + run_id = uuid.uuid4() + session_stub = SimpleNamespace( + status=AgentSessionStatus.IDLE.value, + curr_run_id=None, + ) + fake_svc = SimpleNamespace( + get_turn_state=AsyncMock( + return_value=AgentTurnState( + session=cast(Any, session_stub), + turn_status=AgentSessionStatus(status_value), + curr_run_id=run_id, + ) + ), + get_active_run_prompt=AsyncMock(return_value="Investigate this alert"), + ) + + with patch( + "tracecat.agent.session.router.AgentSessionService", return_value=fake_svc + ): + raw_get_session_status = cast(Any, get_session_status).__wrapped__ + response = await raw_get_session_status( + session_id=session_id, + role=_read_role(workspace_id), + session=AsyncMock(), + ) + + assert response.turn_status == status_value + assert response.curr_run_id == run_id + assert response.prompt == "Investigate this alert" + fake_svc.get_turn_state.assert_awaited_once_with(session_id) + fake_svc.get_active_run_prompt.assert_awaited_once_with(session_id, run_id) + + @pytest.mark.anyio async def test_send_message_continue_uses_path_session_id_for_stream_key() -> None: session_id = uuid.uuid4() @@ -222,6 +264,7 @@ async def test_send_message_continue_uses_path_session_id_for_stream_key() -> No is_legacy_session=AsyncMock(return_value=False), validate_turn_request=AsyncMock(return_value=None), run_turn=AsyncMock(return_value=None), + get_session=AsyncMock(return_value=SimpleNamespace(curr_run_id=uuid.uuid4())), ) fake_stream = SimpleNamespace( reset_for_new_turn=AsyncMock(return_value=None), @@ -287,13 +330,16 @@ async def 
test_send_message_new_turn_resets_stream_before_streaming() -> None: model_provider="openai", ) + run_id = uuid.uuid4() fake_svc = SimpleNamespace( is_legacy_session=AsyncMock(return_value=False), validate_turn_request=AsyncMock(return_value=None), run_turn=AsyncMock(return_value=None), + get_session=AsyncMock(return_value=SimpleNamespace(curr_run_id=run_id)), ) fake_stream = SimpleNamespace( reset_for_new_turn=AsyncMock(return_value=None), + append=AsyncMock(return_value=None), abort_new_turn=AsyncMock(return_value=None), sse=Mock(return_value=_empty_event_stream()), ) @@ -466,17 +512,57 @@ def _make_stream_role(workspace_id: uuid.UUID) -> Role: ) -@pytest.mark.anyio -async def test_stream_session_events_returns_204_when_no_turn_started() -> None: - """last_stream_id=None with no Last-Event-ID header returns 204.""" - session_id = uuid.uuid4() - workspace_id = uuid.uuid4() - role = _make_stream_role(workspace_id) +def _fake_stream_session( + *, + status_value: str, + last_stream_id: str | None = None, + curr_run_id: uuid.UUID | None = None, +) -> SimpleNamespace: + return SimpleNamespace( + status=status_value, + last_stream_id=last_stream_id, + curr_run_id=curr_run_id, + ) - fake_session = SimpleNamespace(last_stream_id=None) - fake_svc = SimpleNamespace(get_session=AsyncMock(return_value=fake_session)) - fake_stream = SimpleNamespace(sse=Mock(return_value=_empty_event_stream())) +def _fake_stream(min_id: str | None = None) -> SimpleNamespace: + return SimpleNamespace( + sse=Mock(return_value=_empty_event_stream()), + finished_sse=Mock(return_value=_empty_event_stream()), + min_entry_id=AsyncMock(return_value=min_id), + ) + + +async def _run_stream_endpoint( + *, + session: SimpleNamespace | None, + stream: SimpleNamespace, + headers: dict[str, str], + session_id: uuid.UUID | None = None, + latest_run_id: uuid.UUID | None = None, + turn_status: AgentSessionStatus | None = None, + turn_run_id: uuid.UUID | None = None, +) -> Any: + session_id = session_id or 
uuid.uuid4() + role = _make_stream_role(uuid.uuid4()) + if turn_status is None: + turn_status = ( + AgentSessionStatus(session.status) + if session is not None + else AgentSessionStatus.IDLE + ) + if turn_run_id is None and session is not None: + turn_run_id = session.curr_run_id + fake_svc = SimpleNamespace( + get_turn_state=AsyncMock( + return_value=AgentTurnState( + session=cast(Any, session), + turn_status=turn_status, + curr_run_id=turn_run_id, + ) + ), + get_latest_history_run_id=AsyncMock(return_value=latest_run_id), + ) with ( patch( "tracecat.agent.session.router.AgentSessionService.with_session", @@ -484,19 +570,32 @@ async def test_stream_session_events_returns_204_when_no_turn_started() -> None: ), patch( "tracecat.agent.session.router.AgentStream.new", - AsyncMock(return_value=fake_stream), + AsyncMock(return_value=stream), ), ): raw = cast(Any, stream_session_events).__wrapped__ - response = await raw( + return await raw( role=role, - request=SimpleNamespace(headers={}), + request=SimpleNamespace( + headers=headers, is_disconnected=AsyncMock(return_value=False) + ), session_id=session_id, ) + +@pytest.mark.anyio +async def test_stream_session_events_returns_204_when_no_turn_started() -> None: + """Idle status with no Last-Event-ID header returns 204.""" + stream = _fake_stream() + response = await _run_stream_endpoint( + session=_fake_stream_session(status_value="idle"), + stream=stream, + headers={}, + ) + assert isinstance(response, Response) assert response.status_code == status.HTTP_204_NO_CONTENT - fake_stream.sse.assert_not_called() + stream.sse.assert_not_called() @pytest.mark.anyio @@ -504,65 +603,138 @@ async def test_stream_session_events_attaches_when_turn_in_progress_no_events_ye None ): session_id = uuid.uuid4() - workspace_id = uuid.uuid4() - role = _make_stream_role(workspace_id) + curr_run_id = uuid.uuid4() + stream = _fake_stream() + response = await _run_stream_endpoint( + session=_fake_stream_session(status_value="running", 
curr_run_id=curr_run_id), + stream=stream, + headers={}, + session_id=session_id, + ) - fake_session = SimpleNamespace(last_stream_id="0-0") - fake_svc = SimpleNamespace(get_session=AsyncMock(return_value=fake_session)) - fake_stream = SimpleNamespace(sse=Mock(return_value=_empty_event_stream())) + assert isinstance(response, StreamingResponse) + stream.sse.assert_called_once() + assert stream.sse.call_args.kwargs["last_id"] == "0-0" + assert stream.sse.call_args.kwargs["message_id"] == f"{session_id}:{curr_run_id}" - with ( - patch( - "tracecat.agent.session.router.AgentSessionService.with_session", - return_value=_AsyncContext(fake_svc), - ), - patch( - "tracecat.agent.session.router.AgentStream.new", - AsyncMock(return_value=fake_stream), - ), - ): - raw = cast(Any, stream_session_events).__wrapped__ - response = await raw( - role=role, - request=SimpleNamespace( - headers={}, is_disconnected=AsyncMock(return_value=False) - ), - session_id=session_id, - ) + +@pytest.mark.anyio +async def test_stream_session_events_uses_temporal_turn_state_for_attach() -> None: + session_id = uuid.uuid4() + curr_run_id = uuid.uuid4() + stream = _fake_stream() + response = await _run_stream_endpoint( + session=_fake_stream_session(status_value="idle", curr_run_id=curr_run_id), + stream=stream, + headers={}, + session_id=session_id, + turn_status=AgentSessionStatus.RUNNING, + ) assert isinstance(response, StreamingResponse) - fake_stream.sse.assert_called_once() + stream.sse.assert_called_once() + assert stream.sse.call_args.kwargs["message_id"] == f"{session_id}:{curr_run_id}" + + +@pytest.mark.anyio +async def test_stream_session_events_waiting_without_cursor_uses_db_history() -> None: + stream = _fake_stream() + response = await _run_stream_endpoint( + session=_fake_stream_session(status_value="waiting_for_approval"), + stream=stream, + headers={}, + ) + + assert isinstance(response, Response) + assert response.status_code == status.HTTP_204_NO_CONTENT + 
stream.sse.assert_not_called() + stream.finished_sse.assert_not_called() @pytest.mark.anyio -async def test_stream_session_events_attaches_when_last_event_id_present() -> None: +async def test_stream_session_events_resumes_after_last_event_id() -> None: + """A live cursor newer than the buffer min resumes after it (composite id).""" + curr_run_id = uuid.uuid4() + stream = _fake_stream(min_id="1000-0") + response = await _run_stream_endpoint( + session=_fake_stream_session(status_value="running", curr_run_id=curr_run_id), + stream=stream, + headers={"Last-Event-ID": "1234-0:2"}, + ) + + assert isinstance(response, StreamingResponse) + stream.sse.assert_called_once() + assert stream.sse.call_args.kwargs["last_id"] == "1234-0" + assert stream.sse.call_args.kwargs["resume_from"] == "1234-0:2" + + +@pytest.mark.anyio +async def test_stream_session_events_terminal_reconnect_uses_latest_run_id() -> None: + """Terminal reconnect against retained Redis keeps the turn's assistant id.""" session_id = uuid.uuid4() - workspace_id = uuid.uuid4() - role = _make_stream_role(workspace_id) + latest_run_id = uuid.uuid4() + stream = _fake_stream(min_id="1000-0") + response = await _run_stream_endpoint( + session=_fake_stream_session(status_value="stopped"), + stream=stream, + headers={"Last-Event-ID": "1234-0:2"}, + session_id=session_id, + latest_run_id=latest_run_id, + ) - fake_session = SimpleNamespace(last_stream_id=None) - fake_svc = SimpleNamespace(get_session=AsyncMock(return_value=fake_session)) - fake_stream = SimpleNamespace(sse=Mock(return_value=_empty_event_stream())) + assert isinstance(response, StreamingResponse) + stream.sse.assert_called_once() + assert stream.sse.call_args.kwargs["message_id"] == ( + f"{session_id}:{latest_run_id}" + ) + assert stream.sse.call_args.kwargs["resume_from"] == "1234-0:2" - with ( - patch( - "tracecat.agent.session.router.AgentSessionService.with_session", - return_value=_AsyncContext(fake_svc), - ), - patch( - 
"tracecat.agent.session.router.AgentStream.new", - AsyncMock(return_value=fake_stream), - ), - ): - raw = cast(Any, stream_session_events).__wrapped__ - response = await raw( - role=role, - request=SimpleNamespace( - headers={"Last-Event-ID": "1234-0"}, - is_disconnected=AsyncMock(return_value=False), - ), - session_id=session_id, - ) + +@pytest.mark.anyio +async def test_stream_session_events_stale_cursor_running_replays_from_start() -> None: + """Cursor older than the buffer min, still running -> replay from 0-0.""" + curr_run_id = uuid.uuid4() + stream = _fake_stream(min_id="2000-0") + response = await _run_stream_endpoint( + session=_fake_stream_session(status_value="running", curr_run_id=curr_run_id), + stream=stream, + headers={"Last-Event-ID": "1000-0:0"}, + ) assert isinstance(response, StreamingResponse) - fake_stream.sse.assert_called_once() + stream.sse.assert_called_once() + assert stream.sse.call_args.kwargs["last_id"] == "0-0" + stream.finished_sse.assert_not_called() + + +@pytest.mark.anyio +async def test_stream_session_events_stale_cursor_terminal_finishes() -> None: + """Cursor older than the buffer min and turn terminal -> finishing stream.""" + stream = _fake_stream(min_id=None) + response = await _run_stream_endpoint( + session=_fake_stream_session(status_value="stopped"), + stream=stream, + headers={"Last-Event-ID": "1000-0:0"}, + ) + + assert isinstance(response, StreamingResponse) + stream.finished_sse.assert_called_once() + assert stream.finished_sse.call_args.kwargs["message_id"] is None + stream.sse.assert_not_called() + + +@pytest.mark.anyio +async def test_stream_session_events_terminal_without_run_id_finishes() -> None: + """Terminal reconnects without a run id do not synthesize a new bubble id.""" + stream = _fake_stream(min_id="1000-0") + response = await _run_stream_endpoint( + session=_fake_stream_session(status_value="stopped"), + stream=stream, + headers={"Last-Event-ID": "1234-0:2"}, + latest_run_id=None, + ) + + assert 
isinstance(response, StreamingResponse) + stream.finished_sse.assert_called_once() + assert stream.finished_sse.call_args.kwargs["message_id"] is None + stream.sse.assert_not_called() diff --git a/tests/unit/test_agent_session_turn_validation.py b/tests/unit/test_agent_session_turn_validation.py index 7ca03e3a96..49694a0b0e 100644 --- a/tests/unit/test_agent_session_turn_validation.py +++ b/tests/unit/test_agent_session_turn_validation.py @@ -7,6 +7,7 @@ import pytest from sqlalchemy.ext.asyncio import AsyncSession +from temporalio.client import WorkflowExecutionStatus from tracecat.agent.session.schemas import AgentSessionCancelRequest from tracecat.agent.session.service import AgentSessionService @@ -87,6 +88,128 @@ async def test_validate_turn_request_rejects_active_turn_before_pending_rows( has_pending_approvals.assert_not_awaited() +@pytest.mark.anyio +async def test_validate_turn_request_rejects_idle_projection_with_active_workflow() -> ( + None +): + workspace_id = uuid.uuid4() + session_id = uuid.uuid4() + run_id = uuid.uuid4() + role = _build_role(workspace_id) + service = _build_service(role) + agent_session = _build_agent_session( + workspace_id=workspace_id, + session_id=session_id, + status=AgentSessionStatus.IDLE, + curr_run_id=run_id, + ) + workflow_handle = SimpleNamespace( + describe=AsyncMock( + return_value=SimpleNamespace(status=WorkflowExecutionStatus.RUNNING) + ) + ) + temporal_client = SimpleNamespace( + get_workflow_handle=MagicMock(return_value=workflow_handle), + ) + has_pending_approvals = AsyncMock(return_value=False) + + with ( + patch.object(service, "get_session", AsyncMock(return_value=agent_session)), + patch.object(service, "has_pending_approvals", has_pending_approvals), + patch( + "tracecat.agent.session.service.get_temporal_client", + AsyncMock(return_value=temporal_client), + ), + ): + with pytest.raises(TracecatConflictError, match="active turn"): + await service.validate_turn_request( + session_id=session_id, + 
request=BasicChatRequest(message="hello"), + ) + + has_pending_approvals.assert_not_awaited() + workflow_handle.describe.assert_awaited_once() + + +@pytest.mark.anyio +async def test_get_turn_state_uses_workflow_phase_over_db_projection() -> None: + workspace_id = uuid.uuid4() + session_id = uuid.uuid4() + run_id = uuid.uuid4() + role = _build_role(workspace_id) + service = _build_service(role) + agent_session = _build_agent_session( + workspace_id=workspace_id, + session_id=session_id, + status=AgentSessionStatus.IDLE, + curr_run_id=run_id, + ) + workflow_handle = SimpleNamespace( + describe=AsyncMock( + return_value=SimpleNamespace(status=WorkflowExecutionStatus.RUNNING) + ), + query=AsyncMock(return_value="waiting_for_approval"), + ) + temporal_client = SimpleNamespace( + get_workflow_handle=MagicMock(return_value=workflow_handle), + get_workflow_handle_for=MagicMock(return_value=workflow_handle), + ) + + with ( + patch.object(service, "get_session", AsyncMock(return_value=agent_session)), + patch( + "tracecat.agent.session.service.get_temporal_client", + AsyncMock(return_value=temporal_client), + ), + ): + state = await service.get_turn_state(session_id) + + assert state.session is agent_session + assert state.curr_run_id == run_id + assert state.turn_status is AgentSessionStatus.WAITING_FOR_APPROVAL + assert state.is_stream_attachable is False + workflow_handle.describe.assert_awaited_once() + workflow_handle.query.assert_awaited_once() + + +@pytest.mark.anyio +async def test_get_turn_state_repairs_terminal_projection() -> None: + workspace_id = uuid.uuid4() + session_id = uuid.uuid4() + run_id = uuid.uuid4() + role = _build_role(workspace_id) + service = _build_service(role) + agent_session = _build_agent_session( + workspace_id=workspace_id, + session_id=session_id, + status=AgentSessionStatus.RUNNING, + curr_run_id=run_id, + ) + workflow_handle = SimpleNamespace( + describe=AsyncMock( + return_value=SimpleNamespace(status=WorkflowExecutionStatus.TERMINATED) 
+ ) + ) + temporal_client = SimpleNamespace( + get_workflow_handle=MagicMock(return_value=workflow_handle), + ) + + with ( + patch.object(service, "get_session", AsyncMock(return_value=agent_session)), + patch( + "tracecat.agent.session.service.get_temporal_client", + AsyncMock(return_value=temporal_client), + ), + ): + state = await service.get_turn_state(session_id) + + assert state.curr_run_id is None + assert state.turn_status is AgentSessionStatus.STOPPED + assert agent_session.curr_run_id is None + assert agent_session.status == AgentSessionStatus.STOPPED.value + cast(Any, service.session.commit).assert_awaited_once() + + @pytest.mark.anyio @pytest.mark.parametrize( "status", @@ -106,9 +229,15 @@ async def test_request_cancel_accepts_active_turn_statuses( status=status, curr_run_id=run_id, ) - workflow_handle = SimpleNamespace(execute_update=AsyncMock()) + workflow_handle = SimpleNamespace( + describe=AsyncMock( + return_value=SimpleNamespace(status=WorkflowExecutionStatus.RUNNING) + ), + execute_update=AsyncMock(), + ) temporal_client = SimpleNamespace( - get_workflow_handle_for=MagicMock(return_value=workflow_handle) + get_workflow_handle=MagicMock(return_value=workflow_handle), + get_workflow_handle_for=MagicMock(return_value=workflow_handle), ) with ( @@ -126,4 +255,90 @@ async def test_request_cancel_accepts_active_turn_statuses( assert response.session_id == session_id assert response.run_id == run_id assert response.turn_status is status + workflow_handle.describe.assert_awaited_once() + workflow_handle.execute_update.assert_awaited_once() + + +@pytest.mark.anyio +async def test_request_cancel_accepts_idle_projection_with_active_workflow() -> None: + workspace_id = uuid.uuid4() + session_id = uuid.uuid4() + run_id = uuid.uuid4() + role = _build_role(workspace_id) + service = _build_service(role) + agent_session = _build_agent_session( + workspace_id=workspace_id, + session_id=session_id, + status=AgentSessionStatus.IDLE, + curr_run_id=run_id, + ) + 
workflow_handle = SimpleNamespace( + describe=AsyncMock( + return_value=SimpleNamespace(status=WorkflowExecutionStatus.RUNNING) + ), + execute_update=AsyncMock(), + ) + temporal_client = SimpleNamespace( + get_workflow_handle=MagicMock(return_value=workflow_handle), + get_workflow_handle_for=MagicMock(return_value=workflow_handle), + ) + + with ( + patch.object(service, "get_session", AsyncMock(return_value=agent_session)), + patch( + "tracecat.agent.session.service.get_temporal_client", + AsyncMock(return_value=temporal_client), + ), + ): + response = await service.request_cancel( + session_id, + AgentSessionCancelRequest(reason="user_cancel"), + ) + + assert response.session_id == session_id + assert response.run_id == run_id + assert response.turn_status is AgentSessionStatus.IDLE + workflow_handle.describe.assert_awaited_once() workflow_handle.execute_update.assert_awaited_once() + + +@pytest.mark.anyio +async def test_request_cancel_handles_finished_current_run() -> None: + workspace_id = uuid.uuid4() + session_id = uuid.uuid4() + run_id = uuid.uuid4() + role = _build_role(workspace_id) + service = _build_service(role) + agent_session = _build_agent_session( + workspace_id=workspace_id, + session_id=session_id, + status=AgentSessionStatus.RUNNING, + curr_run_id=run_id, + ) + workflow_handle = SimpleNamespace( + describe=AsyncMock( + return_value=SimpleNamespace(status=WorkflowExecutionStatus.TERMINATED) + ), + execute_update=AsyncMock(), + ) + temporal_client = SimpleNamespace( + get_workflow_handle=MagicMock(return_value=workflow_handle), + get_workflow_handle_for=MagicMock(return_value=workflow_handle), + ) + + with ( + patch.object(service, "get_session", AsyncMock(return_value=agent_session)), + patch( + "tracecat.agent.session.service.get_temporal_client", + AsyncMock(return_value=temporal_client), + ), + ): + with pytest.raises(TracecatConflictError, match="active turn"): + await service.request_cancel( + session_id, + 
AgentSessionCancelRequest(reason="user_cancel"), + ) + + assert agent_session.curr_run_id is None + assert agent_session.status == AgentSessionStatus.STOPPED.value + workflow_handle.execute_update.assert_not_awaited() diff --git a/tests/unit/test_agent_stream_connector.py b/tests/unit/test_agent_stream_connector.py index 7f91c402fa..87954c19e3 100644 --- a/tests/unit/test_agent_stream_connector.py +++ b/tests/unit/test_agent_stream_connector.py @@ -74,7 +74,8 @@ async def test_stream_events_clears_buffer_after_terminal_marker() -> None: assert isinstance(event, StreamEnd) - stream._set_last_stream_id.assert_awaited_once_with(None) + # Readers never write the cursor; on terminal we only expire the buffer. + stream._set_last_stream_id.assert_not_awaited() raw_client.expire.assert_awaited_once_with( name=stream._stream_key, time=stream.COMPLETED_STREAM_TTL_SECONDS, @@ -82,7 +83,7 @@ async def test_stream_events_clears_buffer_after_terminal_marker() -> None: @pytest.mark.anyio -async def test_stream_events_preserves_cursor_when_stream_not_completed() -> None: +async def test_stream_events_does_not_write_cursor_when_not_completed() -> None: workspace_id = uuid.uuid4() session_id = uuid.uuid4() raw_client = SimpleNamespace(expire=AsyncMock(return_value=None)) @@ -120,5 +121,91 @@ async def test_stream_events_preserves_cursor_when_stream_not_completed() -> Non assert len(events) == 1 assert isinstance(events[0], StreamDelta) - stream._set_last_stream_id.assert_awaited() + # Readers never write the cursor: the browser owns it via Last-Event-ID. 
+ stream._set_last_stream_id.assert_not_awaited() raw_client.expire.assert_not_awaited() + + +@pytest.mark.anyio +async def test_stream_events_can_replay_cursor_entry_before_xread() -> None: + workspace_id = uuid.uuid4() + session_id = uuid.uuid4() + raw_client = SimpleNamespace(expire=AsyncMock(return_value=None)) + client = SimpleNamespace( + xrange=AsyncMock( + return_value=[ + ( + "1717426372768-0", + { + tokens.DATA_KEY: b'{"type":"text_delta","text":"first"}', + }, + ) + ] + ), + xread=AsyncMock( + return_value=[ + ( + f"agent-stream:{workspace_id}:{session_id}", + [ + ( + "1717426372769-0", + { + tokens.DATA_KEY: b'{"type":"text_delta","text":"next"}', + }, + ) + ], + ) + ] + ), + _get_client=AsyncMock(return_value=raw_client), + ) + stream = AgentStream( + client=cast(RedisClient, client), + workspace_id=workspace_id, + session_id=session_id, + ) + + events = [ + event + async for event in stream._stream_events( + AsyncMock(side_effect=[False, True]), + last_id="1717426372768-0", + include_last_id=True, + ) + ] + + assert [event.id for event in events if isinstance(event, StreamDelta)] == [ + "1717426372768-0", + "1717426372769-0", + ] + client.xrange.assert_awaited_once_with( + stream._stream_key, + min_id="1717426372768-0", + max_id="1717426372768-0", + count=1, + ) + client.xread.assert_awaited_once_with( + streams={stream._stream_key: "1717426372768-0"}, + count=100, + block=1000, + ) + + +@pytest.mark.anyio +async def test_min_entry_id_returns_oldest_or_none() -> None: + workspace_id = uuid.uuid4() + session_id = uuid.uuid4() + client = SimpleNamespace( + xrange=AsyncMock(return_value=[("1717426372768-0", {})]), + ) + stream = AgentStream( + client=cast(RedisClient, client), + workspace_id=workspace_id, + session_id=session_id, + ) + + assert await stream.min_entry_id() == "1717426372768-0" + client.xrange.assert_awaited_once_with(stream._stream_key, count=1) + + client.xrange = AsyncMock(return_value=[]) + assert await stream.min_entry_id() is None 
diff --git a/tests/unit/test_vercel_stream_context.py b/tests/unit/test_vercel_stream_context.py index 8dea768499..598ce4eea0 100644 --- a/tests/unit/test_vercel_stream_context.py +++ b/tests/unit/test_vercel_stream_context.py @@ -24,6 +24,7 @@ VercelSSEPayload, VercelStreamContext, format_sse, + sse_vercel, ) from tracecat.agent.common.stream_types import ( StreamEventType, @@ -980,3 +981,81 @@ async def test_format_sse_produces_valid_output(): data = json.loads(json_str) assert data["type"] == "text-start" assert data["id"] == "test_id" + + +@pytest.mark.anyio +async def test_format_sse_emits_id_line_when_given(): + """An sse_id prefixes an `id:` line so the browser can replay from it.""" + payload = TextDeltaEventPayload(id="part", delta="hi") + assert format_sse(payload, "1234-0:2").startswith("id: 1234-0:2\ndata: ") + + +@pytest.mark.anyio +async def test_sse_vercel_uses_stable_bubble_id_and_composite_frame_ids(): + """sse_vercel keeps a stable messageId and stamps composite ids per frame.""" + from tracecat.agent.stream.events import StreamDelta + + async def events(): + yield StreamDelta( + id="1000-0", + event=UnifiedStreamEvent( + type=StreamEventType.TEXT_START, part_id=0, text="hello" + ), + ) + + frames: list[str] = [] + async for frame in sse_vercel(events(), message_id="sess:run"): + frames.append(frame) + + # Start frame carries the stable bubble id (no random msg_*). + start = next(f for f in frames if '"type":"start"' in f) + assert '"messageId":"sess:run"' in start + + # Both delta frames from the single Redis entry share the redis id with an + # incrementing frame index. 
+ id_frames = [f for f in frames if f.startswith("id: ")] + assert id_frames[0].startswith("id: 1000-0:0\n") + assert id_frames[1].startswith("id: 1000-0:1\n") + + +@pytest.mark.anyio +async def test_sse_vercel_omits_start_frame_when_message_id_unknown(): + """Terminal reconnect finish streams should not create a synthetic bubble.""" + + async def events(): + return + yield # pragma: no cover - establishes async generator + + frames: list[str] = [] + async for frame in sse_vercel(events(), message_id=None): + frames.append(frame) + + assert not any('"type":"start"' in frame for frame in frames) + assert any('"type":"finish"' in frame for frame in frames) + + +@pytest.mark.anyio +async def test_sse_vercel_skips_frames_at_composite_resume_cursor(): + """Reconnect inside a Redis entry replays that entry but drops seen frames.""" + from tracecat.agent.stream.events import StreamDelta + + async def events(): + yield StreamDelta( + id="1000-0", + event=UnifiedStreamEvent( + type=StreamEventType.TEXT_START, part_id=0, text="hello" + ), + ) + + frames: list[str] = [] + async for frame in sse_vercel( + events(), + message_id="sess:run", + resume_from="1000-0:0", + ): + frames.append(frame) + + id_frames = [f for f in frames if f.startswith("id: ")] + assert len(id_frames) == 1 + assert id_frames[0].startswith("id: 1000-0:1\n") + assert '"delta":"hello"' in id_frames[0] diff --git a/tracecat/agent/adapter/vercel.py b/tracecat/agent/adapter/vercel.py index 879efa33fa..6c6cdd65d7 100644 --- a/tracecat/agent/adapter/vercel.py +++ b/tracecat/agent/adapter/vercel.py @@ -71,6 +71,7 @@ StreamEvent, StreamKeepAlive, StreamMessage, + parse_vercel_frame_cursor, ) from tracecat.agent.types import UnifiedMessage from tracecat.chat.constants import ( @@ -736,9 +737,14 @@ class ErrorEventPayload: ) -def format_sse(data: VercelSSEPayload) -> str: - """Formats a dictionary into a Server-Sent Event string.""" - return f"data: {to_json(data).decode()}\n\n" +def format_sse(data: 
VercelSSEPayload, sse_id: str | None = None) -> str: + """Formats a payload into a Server-Sent Event string. + + When ``sse_id`` is given, emit an ``id:`` line so the browser records it as + the Last-Event-ID for reconnect. + """ + prefix = f"id: {sse_id}\n" if sse_id else "" + return f"{prefix}data: {to_json(data).decode()}\n\n" @dataclasses.dataclass @@ -760,7 +766,7 @@ class VercelStreamContext: consistent start/delta/end sequences required by the Vercel protocol. """ - message_id: str + message_id: str | None # Active parts keyed by event index -> maintains per-part lifecycle state. part_states: dict[int, _PartState] = dataclasses.field(default_factory=dict) tool_finished: dict[str, bool] = dataclasses.field(default_factory=dict) @@ -1796,24 +1802,58 @@ def convert_chat_messages_to_ui( return UIMessagesTA.validate_python(raw_messages) -async def sse_vercel(events: AsyncIterable[StreamEvent]) -> AsyncIterable[str]: - """Stream Redis events as Vercel AI SDK frames without persisting adapter output.""" +async def sse_vercel( + events: AsyncIterable[StreamEvent], + *, + message_id: str | None, + resume_from: str | None = None, +) -> AsyncIterable[str]: + """Stream Redis events as Vercel AI SDK frames without persisting adapter output. + + ``message_id`` is the stable assistant-bubble id (``session_id:curr_run_id``) + so reconnects within a turn resume the same bubble instead of spawning a new + one. It is omitted for terminal reconnects where no run id can be resolved. + Each Redis entry can fan out to many Vercel frames, so frames carry a + composite ``id: {redis_id}:{frame_index}`` that the browser replays from. + """ - message_id = f"msg_{uuid.uuid4().hex}" context = VercelStreamContext(message_id=message_id) + resume_cursor = parse_vercel_frame_cursor(resume_from) + + # Composite-id state: the Redis id of the entry currently fanning out, and a + # per-entry frame counter. format_sse omits id: when redis_id is None. 
+ redis_id: str | None = None + frame_index = 0 + + def emit(payload: VercelSSEPayload) -> str | None: + nonlocal frame_index + sse_id = f"{redis_id}:{frame_index}" if redis_id else None + current_frame_index = frame_index + frame_index += 1 + if ( + resume_cursor is not None + and redis_id == resume_cursor.redis_id + and current_frame_index <= resume_cursor.frame_index + ): + return None + return format_sse(payload, sse_id) try: # 1. Start of the message stream - yield format_sse(StartEventPayload(messageId=message_id)) + if message_id is not None: + yield format_sse(StartEventPayload(messageId=message_id)) # 2. Process events from Redis stream async for stream_event in events: match stream_event: - case StreamDelta(event=agent_event): + case StreamDelta(id=delta_id, event=agent_event): + redis_id, frame_index = delta_id, 0 # Process agent stream events (PartStartEvent, PartDeltaEvent, etc.) async for msg in context.handle_event(agent_event): - yield format_sse(msg) - case StreamMessage(message=message): + if frame := emit(msg): + yield frame + case StreamMessage(id=message_redis_id, message=message): + redis_id, frame_index = message_redis_id, 0 if approval_payload := _extract_approval_payload_from_message( message ): @@ -1836,7 +1876,8 @@ async def sse_vercel(events: AsyncIterable[StreamEvent]) -> AsyncIterable[str]: ) in context.collect_current_part_end_events( index=index ): - yield format_sse(end_evt) + if frame := emit(end_evt): + yield frame except Exception: # Best-effort only; do not abort streaming on cache/finalize errors pass @@ -1847,7 +1888,8 @@ async def sse_vercel(events: AsyncIterable[StreamEvent]) -> AsyncIterable[str]: ) ) for data_event in context.flush_data_events(): - yield format_sse(data_event) + if frame := emit(data_event): + yield frame continue case StreamKeepAlive(): yield StreamKeepAlive.sse() diff --git a/tracecat/agent/executor/loopback.py b/tracecat/agent/executor/loopback.py index ce6d823ea8..4afb9efe67 100644 --- 
a/tracecat/agent/executor/loopback.py +++ b/tracecat/agent/executor/loopback.py @@ -248,6 +248,9 @@ def __init__(self, input: LoopbackInput) -> None: self._stream_sink: LoopbackEventSink | None = None self._result = LoopbackResult(success=False) self._sdk_session_id: str | None = None # Track SDK session ID for this run + self._curr_run_id: uuid.UUID | None = ( + None # Active run; stamped on history rows + ) self._stream_done_emitted: bool = False # Dedupe flag for stream.done() self._interrupt_notice_emitted: bool = False # Track which session lines have been persisted to avoid duplicates @@ -872,13 +875,17 @@ async def _persist_session_line( ) result = await session.execute(stmt) agent_session = result.scalar_one_or_none() - if agent_session and agent_session.sdk_session_id is None: - agent_session.sdk_session_id = sdk_session_id - logger.info( - "Updated AgentSession with sdk_session_id", - session_id=self.input.session_id, - sdk_session_id=sdk_session_id, - ) + if agent_session: + # Stamp the active run on every history row so mid-turn DB + # loads can hide this turn's partial rows (Redis owns them). 
+ self._curr_run_id = agent_session.curr_run_id + if agent_session.sdk_session_id is None: + agent_session.sdk_session_id = sdk_session_id + logger.info( + "Updated AgentSession with sdk_session_id", + session_id=self.input.session_id, + sdk_session_id=sdk_session_id, + ) # Use explicit internal flag from runtime, not content-based heuristics kind: SessionLineKind = "internal" if internal else "chat-message" @@ -896,6 +903,7 @@ async def _persist_session_line( workspace_id=self.input.workspace_id, content=_session_line_db_content(line_data), kind=kind, + curr_run_id=self._curr_run_id, ) session.add(history_entry) await session.commit() diff --git a/tracecat/agent/session/router.py b/tracecat/agent/session/router.py index 3623da7b2b..26e21c22f7 100644 --- a/tracecat/agent/session/router.py +++ b/tracecat/agent/session/router.py @@ -18,12 +18,13 @@ AgentSessionRead, AgentSessionReadVercel, AgentSessionReadWithMessages, + AgentSessionStatusRead, AgentSessionUpdate, ) from tracecat.agent.session.service import AgentSessionService from tracecat.agent.session.types import AgentSessionEntity, AgentSessionStatus from tracecat.agent.stream.connector import AgentStream -from tracecat.agent.stream.events import StreamFormat +from tracecat.agent.stream.events import StreamFormat, parse_vercel_frame_cursor from tracecat.agent.subagents import ResolvedAgentsConfig from tracecat.auth.dependencies import WorkspaceUserRouteRole from tracecat.authz.controls import require_scope @@ -41,6 +42,21 @@ router = APIRouter(prefix="/agent/sessions", tags=["agent-sessions"]) +def _bubble_id(session_id: uuid.UUID, curr_run_id: uuid.UUID | None) -> str | None: + """Stable assistant-bubble id for a turn, if the turn is known.""" + return f"{session_id}:{curr_run_id}" if curr_run_id else None + + +def _redis_id_lt(a: str, b: str) -> bool: + """Order Redis stream ids ("ms-seq") as (ms, seq) tuples.""" + + def parts(rid: str) -> tuple[int, int]: + ms, _, seq = rid.partition("-") + return int(ms), 
int(seq or 0) + + return parts(a) < parts(b) + + @router.post("") @require_scope("agent:execute") async def create_session( @@ -236,6 +252,36 @@ async def get_session_vercel( ) +@router.get("/{session_id}/status") +@require_scope("agent:read") +async def get_session_status( + session_id: uuid.UUID, + role: WorkspaceUserRouteRole, + session: AsyncDBSession, +) -> AgentSessionStatusRead: + """Lifecycle status for polling without loading message history.""" + svc = AgentSessionService(session, role) + turn_state = await svc.get_turn_state(session_id) + agent_session = turn_state.session + if agent_session is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Session not found", + ) + prompt: str | None = None + if ( + turn_state.turn_status + in {AgentSessionStatus.RUNNING, AgentSessionStatus.WAITING_FOR_APPROVAL} + and turn_state.curr_run_id is not None + ): + prompt = await svc.get_active_run_prompt(session_id, turn_state.curr_run_id) + return AgentSessionStatusRead( + turn_status=turn_state.turn_status, + curr_run_id=turn_state.curr_run_id, + prompt=prompt, + ) + + @router.patch("/{session_id}") @require_scope("agent:execute") async def update_session( @@ -385,6 +431,12 @@ async def send_message( ) raise + # run_turn set curr_run_id; build a bubble id stable for this turn. 
+ updated = await svc.get_session(session_id) + message_id = _bubble_id( + session_id, updated.curr_run_id if updated else None + ) + logger.info( "Starting Vercel streaming session", session_id=session_id, @@ -393,7 +445,12 @@ async def send_message( # Create stream and return with Vercel format return StreamingResponse( - stream.sse(http_request.is_disconnected, last_id=start_id, format="vercel"), + stream.sse( + http_request.is_disconnected, + last_id=start_id, + format="vercel", + message_id=message_id, + ), media_type="text/event-stream", headers={ "Cache-Control": "no-cache, no-transform", @@ -456,25 +513,24 @@ async def stream_session_events( detail="Workspace access required", ) - # Try to get last_stream_id from session, but don't fail if session doesn't exist yet. - # This handles the race condition where frontend connects before session is created. - last_stream_id: str | None = None - async with AgentSessionService.with_session(role=role) as svc: - agent_session = await svc.get_session(session_id) - if agent_session is not None: - last_stream_id = agent_session.last_stream_id - + # Don't fail if the session doesn't exist yet: the frontend can connect + # before the session row is created (handled by the 204 below). last_event_id = request.headers.get("Last-Event-ID") - if last_stream_id is None and not last_event_id: + async with AgentSessionService.with_session(role=role) as svc: + turn_state = await svc.get_turn_state(session_id) + curr_run_id = turn_state.curr_run_id + if curr_run_id is None and last_event_id: + curr_run_id = await svc.get_latest_history_run_id(session_id) + + is_stream_attachable = turn_state.is_stream_attachable + # Nothing live to attach to and no client cursor to resume -> let the client + # fall back to the persisted DB history. 
+ if not is_stream_attachable and not last_event_id: return Response(status_code=status.HTTP_204_NO_CONTENT) - start_id = last_event_id or last_stream_id or "0-0" - logger.info( - "Starting session stream", - last_id=start_id, - session_id=session_id, - ) stream = await AgentStream.new(session_id, workspace_id) + message_id = _bubble_id(session_id, curr_run_id) + headers = { "Cache-Control": "no-cache, no-transform", "Connection": "keep-alive", @@ -484,8 +540,55 @@ async def stream_session_events( } if format == "vercel": headers["x-vercel-ai-ui-message-stream"] = "v1" + + # Browser owns the cursor: no Last-Event-ID -> replay 0-0; has one -> resume + # after it. Readers never read last_stream_id (producer/lifecycle-only). + start_id = "0-0" + resume_from: str | None = None + if last_event_id: + cursor = parse_vercel_frame_cursor(last_event_id) + requested = cursor.redis_id if cursor else last_event_id.split(":", 1)[0] + min_id = await stream.min_entry_id() + if min_id is None or _redis_id_lt(requested, min_id): + # Cursor predates the live buffer (maxlen/TTL eviction). While running, + # replay from the start. Otherwise there is nothing live, so emit a + # finishing stream that ends cleanly -> client refetches DB history. + if not is_stream_attachable: + return StreamingResponse( + stream.finished_sse(format=format, message_id=message_id), + media_type="text/event-stream", + headers=headers, + ) + # else start_id stays "0-0" + else: + start_id = requested + resume_from = last_event_id if cursor else None + + # Terminal reconnects can outlive curr_run_id and, in edge cases, fail to + # resolve a persisted run id. Do not invent a session-only assistant id: + # finish cleanly so the client refetches DB history without rendering an + # empty synthetic bubble. 
+ if not is_stream_attachable and message_id is None: + return StreamingResponse( + stream.finished_sse(format=format, message_id=None), + media_type="text/event-stream", + headers=headers, + ) + + logger.info( + "Starting session stream", + last_id=start_id, + session_id=session_id, + ) + return StreamingResponse( - stream.sse(request.is_disconnected, last_id=start_id, format=format), + stream.sse( + request.is_disconnected, + last_id=start_id, + format=format, + message_id=message_id, + resume_from=resume_from, + ), media_type="text/event-stream", headers=headers, ) diff --git a/tracecat/agent/session/schemas.py b/tracecat/agent/session/schemas.py index 30163c3efc..fb6799f814 100644 --- a/tracecat/agent/session/schemas.py +++ b/tracecat/agent/session/schemas.py @@ -155,6 +155,22 @@ class AgentSessionReadVercel(AgentSessionRead): ) +class AgentSessionStatusRead(BaseModel): + """Lightweight session lifecycle status for cheap polling. + + Clients poll this (instead of the full message history) to learn when a turn + starts elsewhere and attach to the live stream. 
+ """ + + turn_status: AgentSessionStatus = Field(default=AgentSessionStatus.IDLE) + curr_run_id: uuid.UUID | None = None + prompt: str | None = Field( + default=None, + description="Human prompt that started the active run, for observer " + "clients to render (user messages cannot stream over the Vercel protocol).", + ) + + class AgentSessionForkRequest(BaseModel): """Request schema for forking an agent session.""" diff --git a/tracecat/agent/session/service.py b/tracecat/agent/session/service.py index 3b9ca1bae7..6fa20466bb 100644 --- a/tracecat/agent/session/service.py +++ b/tracecat/agent/session/service.py @@ -60,7 +60,11 @@ AgentSessionUpdate, ) from tracecat.agent.session.title_generator import generate_session_title -from tracecat.agent.session.types import AgentSessionEntity, AgentSessionStatus +from tracecat.agent.session.types import ( + AgentSessionEntity, + AgentSessionStatus, + AgentWorkflowTurnStatus, +) from tracecat.agent.subagents import ( ResolvedAgentsConfig, ) @@ -115,6 +119,18 @@ WorkflowExecutionStatus.RUNNING, WorkflowExecutionStatus.CONTINUED_AS_NEW, ) +# Lifecycle source of truth: +# - curr_run_id + Temporal describe decides whether a turn is actually active. +# - AgentSession.status is a DB projection for cheap reads and durable product +# terminal state; do not use it alone for correctness gates. 
+TERMINAL_FAILURE_WORKFLOW_STATUSES = { + WorkflowExecutionStatus.FAILED, + WorkflowExecutionStatus.TIMED_OUT, +} +TERMINAL_STOPPED_WORKFLOW_STATUSES = { + WorkflowExecutionStatus.CANCELED, + WorkflowExecutionStatus.TERMINATED, +} @dataclass @@ -126,6 +142,22 @@ class SessionHistoryData: is_fork: bool = False # If True, SDK should use fork_session=True +@dataclass(frozen=True) +class AgentTurnState: + """Authoritative live turn state, with DB projection as fallback.""" + + session: AgentSession | None + turn_status: AgentSessionStatus + curr_run_id: uuid.UUID | None + + @property + def is_stream_attachable(self) -> bool: + return ( + self.turn_status is AgentSessionStatus.RUNNING + and self.curr_run_id is not None + ) + + class AgentSessionService(BaseWorkspaceService): """Service for managing agent sessions and history.""" @@ -594,38 +626,202 @@ async def set_session_status( await self.session.commit() return bool(cast(CursorResult[Any], result).rowcount) - async def _has_active_agent_turn(self, agent_session: AgentSession) -> bool: - """Return whether DB/Temporal state should block a new turn.""" + async def _describe_curr_run( + self, agent_session: AgentSession + ) -> WorkflowExecutionStatus | None: + """Describe the Temporal run pointed to by the session, if any.""" if agent_session.curr_run_id is None: - return agent_session.status in ACTIVE_AGENT_SESSION_STATUSES + return None + + client = await get_temporal_client() + workflow_id = AgentWorkflowID(agent_session.curr_run_id) + handle = client.get_workflow_handle(str(workflow_id)) + description = await handle.describe() + return description.status + + @staticmethod + def _terminal_status_for_workflow_status( + workflow_status: WorkflowExecutionStatus, + ) -> AgentSessionStatus: + """Map Temporal terminal state to the product status shown in the UI.""" + if workflow_status in TERMINAL_FAILURE_WORKFLOW_STATUSES: + return AgentSessionStatus.FAILED + if workflow_status in TERMINAL_STOPPED_WORKFLOW_STATUSES: + 
return AgentSessionStatus.STOPPED + return AgentSessionStatus.IDLE + + @staticmethod + def _status_for_workflow_turn_status( + workflow_turn_status: AgentWorkflowTurnStatus, + ) -> AgentSessionStatus: + """Return the workflow-owned status using the DB status vocabulary.""" + return AgentSessionStatus(workflow_turn_status) + + async def _query_curr_run_turn_status( + self, agent_session: AgentSession + ) -> AgentWorkflowTurnStatus: + """Query the live workflow phase for the session's current run.""" + from tracecat_ee.agent.workflows.durable import DurableAgentWorkflow + + if agent_session.curr_run_id is None: + raise ValueError("Agent session has no current run") client = await get_temporal_client() workflow_id = AgentWorkflowID(agent_session.curr_run_id) + handle = client.get_workflow_handle_for( + DurableAgentWorkflow.run, + str(workflow_id), + ) + return await handle.query(DurableAgentWorkflow.get_turn_status) + + async def _repair_terminal_projection( + self, + agent_session: AgentSession, + workflow_status: WorkflowExecutionStatus, + ) -> AgentSessionStatus: + """Cache terminal Temporal state back to the DB projection.""" + turn_status = self._terminal_status_for_workflow_status(workflow_status) + agent_session.status = turn_status.value + agent_session.curr_run_id = None + await self.session.commit() + return turn_status + + async def get_turn_state(self, session_id: uuid.UUID) -> AgentTurnState: + """Return Temporal-owned turn state for live decisions. + + The DB row locates the current workflow run and caches status. Temporal + owns whether that run is live and whether it is waiting for approval. 
+ """ + agent_session = await self.get_session(session_id) + if agent_session is None: + return AgentTurnState( + session=None, + turn_status=AgentSessionStatus.IDLE, + curr_run_id=None, + ) + + projected_status = AgentSessionStatus(agent_session.status) + if agent_session.curr_run_id is None: + return AgentTurnState( + session=agent_session, + turn_status=projected_status, + curr_run_id=None, + ) + + try: + workflow_status = await self._describe_curr_run(agent_session) + except Exception as exc: + logger.warning( + "Failed to describe agent workflow; using DB projection", + session_id=str(session_id), + run_id=str(agent_session.curr_run_id), + error=str(exc), + ) + return AgentTurnState( + session=agent_session, + turn_status=projected_status, + curr_run_id=agent_session.curr_run_id, + ) + + if workflow_status in ACTIVE_WORKFLOW_STATUSES: + try: + workflow_turn_status = await self._query_curr_run_turn_status( + agent_session + ) + turn_status = self._status_for_workflow_turn_status( + workflow_turn_status + ) + except Exception as exc: + logger.warning( + "Failed to query agent workflow phase; using DB projection", + session_id=str(session_id), + run_id=str(agent_session.curr_run_id), + error=str(exc), + ) + turn_status = projected_status + + return AgentTurnState( + session=agent_session, + turn_status=turn_status, + curr_run_id=agent_session.curr_run_id, + ) + + if workflow_status is None: + return AgentTurnState( + session=agent_session, + turn_status=projected_status, + curr_run_id=agent_session.curr_run_id, + ) + + turn_status = await self._repair_terminal_projection( + agent_session, + workflow_status, + ) + return AgentTurnState( + session=agent_session, + turn_status=turn_status, + curr_run_id=None, + ) + + async def _has_active_agent_turn(self, agent_session: AgentSession) -> bool: + """Return whether Temporal/projected state should block a new turn.""" + if agent_session.curr_run_id is None: + return agent_session.status in ACTIVE_AGENT_SESSION_STATUSES 
+ try: - handle = client.get_workflow_handle(str(workflow_id)) - description = await handle.describe() + workflow_status = await self._describe_curr_run(agent_session) except Exception: return True - if description.status in ACTIVE_WORKFLOW_STATUSES: + if workflow_status in ACTIVE_WORKFLOW_STATUSES: return True - agent_session.status = ( - AgentSessionStatus.FAILED.value - if description.status - in {WorkflowExecutionStatus.FAILED, WorkflowExecutionStatus.TIMED_OUT} - else AgentSessionStatus.STOPPED.value - if description.status - in { - WorkflowExecutionStatus.CANCELED, - WorkflowExecutionStatus.TERMINATED, - } - else AgentSessionStatus.IDLE.value - ) - agent_session.curr_run_id = None - await self.session.commit() + if workflow_status is not None: + await self._repair_terminal_projection(agent_session, workflow_status) return False + async def get_active_run_prompt( + self, session_id: uuid.UUID, run_id: uuid.UUID + ) -> str | None: + """Return the human prompt text that started the given run, if any. + + The prompt is the run's first user row whose message content is a plain + string (excludes leading queue-operation rows and tool_result arrays). + Used so observer clients can show the prompt while the assistant streams + (the Vercel stream protocol cannot carry user messages). 
+ """ + content = AgentSessionHistory.content + stmt = ( + select(content["message"]["content"].astext) + .where( + AgentSessionHistory.session_id == session_id, + AgentSessionHistory.curr_run_id == run_id, + content["type"].astext == "user", + func.jsonb_typeof(content["message"]["content"]) == "string", + ) + .order_by(AgentSessionHistory.surrogate_id) + .limit(1) + ) + result = await self.session.execute(stmt) + return result.scalar_one_or_none() + + async def get_latest_history_run_id( + self, session_id: uuid.UUID + ) -> uuid.UUID | None: + """Return the newest stamped history run id for terminal stream reconnects.""" + stmt = ( + select(AgentSessionHistory.curr_run_id) + .where( + AgentSessionHistory.session_id == session_id, + AgentSessionHistory.workspace_id == self.workspace_id, + AgentSessionHistory.curr_run_id.is_not(None), + ) + .order_by(AgentSessionHistory.surrogate_id.desc()) + .limit(1) + ) + result = await self.session.execute(stmt) + return result.scalar_one_or_none() + async def request_cancel( self, session_id: uuid.UUID, @@ -642,19 +838,29 @@ async def request_cancel( raise TracecatNotFoundError(f"Session with ID {session_id} not found") current_status = AgentSessionStatus(agent_session.status) - if current_status not in { - AgentSessionStatus.RUNNING, - AgentSessionStatus.WAITING_FOR_APPROVAL, - }: + curr_run_id = agent_session.curr_run_id + if curr_run_id is None: + raise TracecatConflictError("Agent session has no active workflow run") + + try: + workflow_status = await self._describe_curr_run(agent_session) + except Exception as e: raise TracecatConflictError( "Agent session does not have an active turn", detail={"status": current_status.value}, + ) from e + if workflow_status not in ACTIVE_WORKFLOW_STATUSES: + if workflow_status is not None: + agent_session.status = self._terminal_status_for_workflow_status( + workflow_status + ).value + agent_session.curr_run_id = None + await self.session.commit() + raise TracecatConflictError( + "Agent 
session does not have an active turn", + detail={"status": AgentSessionStatus(agent_session.status).value}, ) - curr_run_id = agent_session.curr_run_id - if curr_run_id is None: - raise TracecatConflictError("Agent session has no active workflow run") - client = await get_temporal_client() workflow_id = AgentWorkflowID(curr_run_id) handle = client.get_workflow_handle_for( @@ -1664,6 +1870,18 @@ async def list_messages( .where(AgentSessionHistory.session_id.in_(session_ids)) .order_by(AgentSessionHistory.surrogate_id) ) + # While a turn is running, the active run's partial rows live in Redis; + # hide them here so the live assistant has exactly one source (no dedupe). + if ( + getattr(agent_session, "status", None) == AgentSessionStatus.RUNNING.value + and getattr(agent_session, "curr_run_id", None) is not None + ): + all_history_stmt = all_history_stmt.where( + or_( + AgentSessionHistory.curr_run_id.is_(None), + AgentSessionHistory.curr_run_id != agent_session.curr_run_id, + ) + ) all_history_result = await self.session.execute(all_history_stmt) all_entries = list(all_history_result.scalars().all()) diff --git a/tracecat/agent/session/types.py b/tracecat/agent/session/types.py index 2bae61fe58..5468ebfb15 100644 --- a/tracecat/agent/session/types.py +++ b/tracecat/agent/session/types.py @@ -37,3 +37,10 @@ class AgentSessionStatus(StrEnum): type AgentCancelReason = Literal["user_cancel", "worker_drain"] +type AgentWorkflowTurnStatus = Literal[ + "idle", + "running", + "waiting_for_approval", + "stopped", + "failed", +] diff --git a/tracecat/agent/stream/connector.py b/tracecat/agent/stream/connector.py index 5f05cf89e2..61c6dd74ee 100644 --- a/tracecat/agent/stream/connector.py +++ b/tracecat/agent/stream/connector.py @@ -2,7 +2,14 @@ import asyncio import uuid -from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable +from collections.abc import ( + AsyncIterable, + AsyncIterator, + Awaitable, + Callable, + Mapping, + Sequence, +) from time 
import monotonic from typing import Any @@ -21,6 +28,7 @@ StreamKeepAlive, StreamMessage, UnifiedStreamEventTA, + parse_vercel_frame_cursor, ) from tracecat.agent.types import ModelMessageTA, StreamKey from tracecat.chat import tokens @@ -93,6 +101,15 @@ async def _expire_completed_stream(self) -> None: error=str(exc), ) + async def min_entry_id(self) -> str | None: + """Oldest id still in the live buffer, or None if empty/evicted. + + Used for reconnect gap detection: a client cursor older than this was + trimmed (maxlen) or TTL-evicted, so it cannot be resumed. + """ + entries = await self.client.xrange(self._stream_key, count=1) + return entries[0][0] if entries else None + async def _set_last_stream_id(self, last_stream_id: str | None) -> None: """Update last stream ID for reconnection support.""" @@ -101,7 +118,11 @@ async def _set_last_stream_id(self, last_stream_id: str | None) -> None: await session_svc.update_last_stream_id(agent_session, last_stream_id) async def _stream_events( - self, stop_condition: Callable[[], Awaitable[bool]], last_id: str + self, + stop_condition: Callable[[], Awaitable[bool]], + last_id: str, + *, + include_last_id: bool = False, ) -> AsyncIterator[StreamEvent]: """Stream events from Redis until a stop condition is met. @@ -120,7 +141,7 @@ async def _stream_events( StreamError (error events), or StreamEnd (end-of-stream marker). 
Note: - - Periodically updates the chat's last_stream_id for reconnection support + - Read-only: never writes last_stream_id (browser owns the cursor) - Implements exponential backoff on errors (1s sleep) - Blocks for up to 1 second waiting for new messages - Processes up to 100 messages per read operation @@ -128,7 +149,56 @@ async def _stream_events( current_id = last_id last_keepalive = monotonic() stream_completed = False + + async def parse_stream_messages( + messages: Sequence[tuple[str, Mapping[str, str]]], + ) -> AsyncIterator[StreamEvent]: + nonlocal current_id, stream_completed + for msg_id, fields in messages: + data = orjson.loads(fields[tokens.DATA_KEY]) + current_id = msg_id + match data: + case {tokens.END_TOKEN: tokens.END_TOKEN_VALUE}: + stream_completed = True + yield StreamEnd(id=msg_id) + case {"event_kind": _}: + legacy_event = AgentStreamEventTA.validate_python(data) + yield StreamDelta(id=msg_id, event=legacy_event) + case {"type": _}: + unified_event = UnifiedStreamEventTA.validate_python(data) + yield StreamDelta(id=msg_id, event=unified_event) + case {"kind": "error", "error": error_message}: + logger.warning( + "Stream error received", + error=error_message, + message_id=msg_id, + ) + yield StreamError(error=error_message) + case {"kind": _}: + message = ModelMessageTA.validate_python(data) + yield StreamMessage(id=msg_id, message=message) + case _: + logger.warning( + "Invalid stream message", + error="Unexpected payload", + message_id=msg_id, + ) + try: + if include_last_id and current_id != "0-0": + try: + entries = await self.client.xrange( + self._stream_key, + min_id=current_id, + max_id=current_id, + count=1, + ) + async for event in parse_stream_messages(entries): + yield event + except Exception as e: + logger.error("Error reading Redis cursor entry", error=str(e)) + yield StreamError(error="Stream read error") + while not await stop_condition(): try: if result := await self.client.xread( @@ -138,44 +208,8 @@ async def 
_stream_events( ): last_keepalive = monotonic() for _stream_name, messages in result: - for msg_id, fields in messages: - data = orjson.loads(fields[tokens.DATA_KEY]) - current_id = msg_id - match data: - case {tokens.END_TOKEN: tokens.END_TOKEN_VALUE}: - stream_completed = True - yield StreamEnd(id=msg_id) - case {"event_kind": _}: - legacy_event = ( - AgentStreamEventTA.validate_python(data) - ) - yield StreamDelta(id=msg_id, event=legacy_event) - case {"type": _}: - unified_event = ( - UnifiedStreamEventTA.validate_python(data) - ) - yield StreamDelta( - id=msg_id, event=unified_event - ) - case {"kind": "error", "error": error_message}: - logger.warning( - "Stream error received", - error=error_message, - message_id=msg_id, - ) - yield StreamError(error=error_message) - case {"kind": _}: - message = ModelMessageTA.validate_python(data) - yield StreamMessage(id=msg_id, message=message) - case _: - logger.warning( - "Invalid stream message", - error="Unexpected payload", - message_id=msg_id, - ) - - if not stream_completed: - await self._set_last_stream_id(current_id) + async for event in parse_stream_messages(messages): + yield event now = monotonic() if now - last_keepalive >= self.KEEPALIVE_INTERVAL_SECONDS: @@ -194,17 +228,24 @@ async def _stream_events( yield StreamError(error="Fatal stream error") finally: logger.info("Chat stream ended", stream_key=self._stream_key) + # Readers never write last_stream_id; the browser owns the reconnect + # cursor (Last-Event-ID). We only expire the buffer after terminal. 
if stream_completed: - await self._set_last_stream_id(None) await self._expire_completed_stream() - else: - await self._set_last_stream_id(current_id) async def stream_events( - self, stop_condition: Callable[[], Awaitable[bool]], last_id: str + self, + stop_condition: Callable[[], Awaitable[bool]], + last_id: str, + *, + include_last_id: bool = False, ) -> AsyncIterator[StreamEvent]: """Public stream-events iterator for external stream consumers.""" - async for event in self._stream_events(stop_condition, last_id): + async for event in self._stream_events( + stop_condition, + last_id, + include_last_id=include_last_id, + ): yield event def sse( @@ -212,17 +253,56 @@ def sse( stop_condition: Callable[[], Awaitable[bool]], last_id: str, format: StreamFormat, + *, + message_id: str | None, + resume_from: str | None = None, ) -> AsyncIterable[str]: + cursor = parse_vercel_frame_cursor(resume_from) match format: case "vercel": from tracecat.agent.adapter.vercel import sse_vercel - return sse_vercel(self.stream_events(stop_condition, last_id)) + return sse_vercel( + self.stream_events( + stop_condition, + last_id, + include_last_id=cursor is not None, + ), + message_id=message_id, + resume_from=resume_from, + ) case "basic": return self.simple_sse(stop_condition, last_id) case _: raise ValueError(f"Invalid format: {format}") + def finished_sse( + self, format: StreamFormat, *, message_id: str | None + ) -> AsyncIterable[str]: + """Emit an immediately-finishing stream (no live content). + + Used on reconnect when the cursor is stale and the turn is already + terminal: the client gets a clean finish and refetches DB history. 
+ """ + + async def _empty() -> AsyncIterator[StreamEvent]: + return + yield # pragma: no cover - establishes async generator + + match format: + case "vercel": + from tracecat.agent.adapter.vercel import sse_vercel + + return sse_vercel(_empty(), message_id=message_id) + case "basic": + + async def _end() -> AsyncIterable[str]: + yield StreamEnd.sse() + + return _end() + case _: + raise ValueError(f"Invalid format: {format}") + async def simple_sse( self, stop_condition: Callable[[], Awaitable[bool]], last_id: str ) -> AsyncIterable[str]: diff --git a/tracecat/agent/stream/events.py b/tracecat/agent/stream/events.py index 1b106e73ff..02f16ea05f 100644 --- a/tracecat/agent/stream/events.py +++ b/tracecat/agent/stream/events.py @@ -13,6 +13,24 @@ AgentStreamEventTA: TypeAdapter[AgentStreamEvent] = TypeAdapter(AgentStreamEvent) +@dataclass(frozen=True, slots=True) +class VercelFrameCursor: + """Browser SSE cursor for a Vercel frame fanned out from one Redis entry.""" + + redis_id: str + frame_index: int + + +def parse_vercel_frame_cursor(event_id: str | None) -> VercelFrameCursor | None: + """Parse ``<redis_id>:<frame_index>`` cursors emitted by the Vercel adapter.""" + if not event_id: + return None + redis_id, separator, frame_index = event_id.rpartition(":") + if not separator or not redis_id or not frame_index.isdecimal(): + return None + return VercelFrameCursor(redis_id=redis_id, frame_index=int(frame_index)) + + @dataclass(slots=True, kw_only=True) class StreamDelta: """Container for Redis stream payloads and adapter errors.""" diff --git a/tracecat/db/models.py b/tracecat/db/models.py index 02b09d0972..78e4caeea7 100644 --- a/tracecat/db/models.py +++ b/tracecat/db/models.py @@ -2915,6 +2915,12 @@ class AgentSessionHistory(WorkspaceModel): index=True, doc="Message kind for filtering (chat-message, internal). 
Default to internal - only user/assistant messages explicitly marked visible.", ) + curr_run_id: Mapped[uuid.UUID | None] = mapped_column( + UUID, + nullable=True, + index=True, + doc="Workflow run that produced this row; used to hide active-turn rows mid-stream", + ) session: Mapped[AgentSession] = relationship( "AgentSession",