Skip to content

Commit b57c833

Browse files
committed
process_inserted_message() now does one branch fetch
1 parent 4acbfff commit b57c833

File tree

4 files changed

+331
-100
lines changed

4 files changed

+331
-100
lines changed

src/server/core/acontext_core/service/controller/message.py

Lines changed: 28 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from ..data import message as MD
22
from ..data import learning_space as LS
33
from ...infra.db import DB_CLIENT
4+
from ...schema.orm import Message
45
from ...schema.session.task import TaskStatus
56
from ...schema.session.message import MessageBlob
67
from ...schema.utils import asUUID
@@ -19,11 +20,11 @@
1920
}
2021

2122

def _claimable_branch_message_ids(messages: list[Message]) -> list[asUUID]:
    """Collect the ids of branch messages that may still be claimed.

    A message qualifies when its ``session_task_process_status`` is one of
    ``_CLAIMABLE_BRANCH_STATUSES``. The returned ids keep the order of
    ``messages``.
    """
    claimable_ids: list[asUUID] = []
    for candidate in messages:
        if candidate.session_task_process_status in _CLAIMABLE_BRANCH_STATUSES:
            claimable_ids.append(candidate.id)
    return claimable_ids
2829

2930

@@ -50,15 +51,18 @@ async def process_inserted_message(
5051
wide = get_wide_event()
5152
disabled = await get_metrics(project_id, ExcessMetricTags.new_task_created)
5253
target_message_ids: list[asUUID] = []
54+
branch_messages: list[Message] = []
5355

5456
try:
5557
async with DB_CLIENT.get_session_context() as session:
56-
r = await MD.fetch_message_branch_path_rows(session, message_id, session_id)
57-
message_rows, eil = r.unpack()
58+
r = await MD.fetch_message_branch_path_messages(
59+
session, message_id, session_id
60+
)
61+
branch_messages, eil = r.unpack()
5862
if eil:
5963
return r
6064

61-
target_message_ids = _claimable_branch_message_ids(message_rows)
65+
target_message_ids = _claimable_branch_message_ids(branch_messages)
6266
wide["claimed_branch_message_count"] = len(target_message_ids)
6367

6468
if not target_message_ids:
@@ -78,25 +82,22 @@ async def process_inserted_message(
7882
session, target_message_ids, TaskStatus.RUNNING
7983
)
8084

81-
async with DB_CLIENT.get_session_context() as session:
82-
r = await MD.fetch_message_branch_path_data(
83-
session, message_id, session_id, user_kek=user_kek
84-
)
85-
messages, eil = r.unpack()
86-
if eil:
87-
await _try_rollback_to_failed(target_message_ids)
88-
return r
85+
r = await MD.hydrate_message_parts(branch_messages, user_kek=user_kek)
86+
messages, eil = r.unpack()
87+
if eil:
88+
await _try_rollback_to_failed(target_message_ids)
89+
return r
8990

90-
messages_data = [
91-
MessageBlob(
92-
message_id=m.id,
93-
parent_id=m.parent_id,
94-
role=m.role,
95-
parts=m.parts,
96-
task_id=m.task_id,
97-
)
98-
for m in messages
99-
]
91+
messages_data = [
92+
MessageBlob(
93+
message_id=m.id,
94+
parent_id=m.parent_id,
95+
role=m.role,
96+
parts=m.parts,
97+
task_id=m.task_id,
98+
)
99+
for m in messages
100+
]
100101

101102
async with DB_CLIENT.get_session_context() as session:
102103
r = await LS.get_learning_space_for_session(session, session_id)
@@ -139,6 +140,8 @@ async def process_inserted_message(
139140
wide["task_agent_outcome"] = "exception"
140141
await _try_rollback_to_failed(target_message_ids)
141142
raise
143+
144+
142145
async def process_session_pending_message(
143146
project_config: ProjectConfig,
144147
project_id: asUUID,

src/server/core/acontext_core/service/data/message.py

Lines changed: 87 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -109,28 +109,41 @@ async def fetch_messages_data_by_ids(
109109
f"Some messages({message_ids}) not found in database: {e}"
110110
)
111111

112-
if not ordered_messages:
112+
return await hydrate_message_parts(ordered_messages, user_kek=user_kek)
113+
114+
except Exception as e:
115+
return Result.reject(f"Error fetching messages by IDs {message_ids}: {e}")
116+
117+
async def hydrate_message_parts(
    messages: List[Message],
    user_kek: bytes | None = None,
) -> Result[List[Message]]:
    """
    Attach parts loaded from S3 to message rows that were fetched earlier.

    Each message is mutated in place: `message.parts` is set to the fetched
    parts, or to `None` when the fetch for that particular message reported
    an error. The same message objects are returned, in the same order.
    Any exception raised along the way is turned into a rejected Result.
    """
    try:
        if messages:
            # One fetch per message, all awaited together; gather() returns
            # results in the order the awaitables were passed in.
            downloads = await asyncio.gather(
                *(
                    _fetch_message_parts(row.parts_asset_meta, user_kek=user_kek)
                    for row in messages
                )
            )

            for row, download in zip(messages, downloads):
                parts, error = download.unpack()
                # A failed fetch does not abort hydration; that row simply
                # ends up without parts.
                row.parts = None if error else parts

            return Result.resolve(messages)

        return Result.resolve([])
    except Exception as exc:
        return Result.reject(f"Error hydrating message parts: {exc}")
134147

135148

136149
async def fetch_message_branch_path_ids(
@@ -226,6 +239,69 @@ async def fetch_message_branch_path_rows(
226239
)
227240

228241

async def fetch_message_branch_path_messages(
    db_session: AsyncSession,
    message_id: asUUID,
    session_id: asUUID | None = None,
) -> Result[List[Message]]:
    """
    Load the branch path of `message_id` as Message rows in a single query.

    A recursive CTE starts at the given message and follows `parent_id`
    upward. Rows are ordered by descending depth (then id), so the
    top-most ancestor comes first and the requested message comes last.

    The Result is rejected when:
    - no row is found for `message_id`;
    - `session_id` is given and the path's session ids are not exactly
      that one session;
    - the path spans more than one session;
    - any exception is raised while querying.
    """
    try:
        branch_path_sql = text(
            """
            WITH RECURSIVE message_path AS (
                SELECT id, parent_id, session_id, 0 AS depth
                FROM messages
                WHERE id = :message_id

                UNION ALL

                SELECT
                    parent.id,
                    parent.parent_id,
                    parent.session_id,
                    child.depth + 1 AS depth
                FROM messages AS parent
                JOIN message_path AS child
                    ON parent.id = child.parent_id
            )
            SELECT m.*
            FROM message_path AS mp
            JOIN messages AS m ON m.id = mp.id
            ORDER BY mp.depth DESC, m.id ASC
            """
        )
        statement = select(Message).from_statement(branch_path_sql)
        rows = await db_session.execute(statement, {"message_id": message_id})
        branch = list(rows.scalars().all())

        if not branch:
            return Result.reject(f"Message {message_id} doesn't exist")

        session_ids_on_path = {row.session_id for row in branch}

        # When a session is requested, the whole path must live in it.
        matches_requested_session = (
            session_id is None or session_ids_on_path == {session_id}
        )
        if not matches_requested_session:
            return Result.reject(
                f"Message {message_id} does not belong to session {session_id}"
            )

        # Without a requested session, still refuse a chain that crosses
        # sessions.
        if len(session_ids_on_path) != 1:
            return Result.reject(
                f"Message {message_id} has an invalid cross-session parent chain"
            )

        return Result.resolve(branch)
    except Exception as exc:
        return Result.reject(
            f"Error fetching branch path messages for message {message_id}: {exc}"
        )
303+
304+
229305
async def branch_pending_message_length(
230306
db_session: AsyncSession,
231307
message_id: asUUID,
@@ -267,13 +343,11 @@ async def fetch_message_branch_path_data(
267343
"""
268344
Fetch one message's branch path with parts loaded from S3.
269345
"""
270-
r = await fetch_message_branch_path_ids(db_session, message_id, session_id)
271-
message_ids, eil = r.unpack()
346+
r = await fetch_message_branch_path_messages(db_session, message_id, session_id)
347+
messages, eil = r.unpack()
272348
if eil:
273349
return Result.reject(str(eil))
274-
return await fetch_messages_data_by_ids(
275-
db_session, message_ids, user_kek=user_kek
276-
)
350+
return await hydrate_message_parts(messages, user_kek=user_kek)
277351

278352

279353
async def fetch_session_messages(

src/server/core/tests/service/test_message_data_branch_path.py

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
import json
12
import pytest
23
from uuid import uuid4
4+
from unittest.mock import AsyncMock, patch
35

46
from acontext_core.schema.orm import Message, Project, Session
57
from acontext_core.service.data import message as MD
@@ -167,6 +169,136 @@ async def test_fetch_message_branch_path_rows_returns_statuses(db_client):
167169
]
168170

169171

@pytest.mark.asyncio
async def test_fetch_message_branch_path_messages_returns_ordered_messages(db_client):
    """The branch path comes back root -> child -> leaf with statuses intact."""

    def asset_meta(name: str, size_b: int) -> dict:
        # Parts metadata for one message; only the name-derived fields and
        # the size differ between the messages in this test.
        return {
            "bucket": "test-bucket",
            "s3_key": f"parts/{name}.json",
            "etag": f"etag-{name}",
            "sha256": f"sha-{name}",
            "mime": "application/json",
            "size_b": size_b,
        }

    _, acontext_session = await _create_project_and_session(db_client)

    async with db_client.get_session_context() as session:

        async def add_message(name, role, size_b, status, parent=None):
            # parent_id is only passed for non-root messages, matching how
            # the rows were originally constructed.
            fields = {
                "session_id": acontext_session.id,
                "role": role,
                "parts_asset_meta": asset_meta(name, size_b),
                "session_task_process_status": status,
            }
            if parent is not None:
                fields["parent_id"] = parent.id
            row = Message(**fields)
            session.add(row)
            # Flush so the row gets its id before a child refers to it.
            await session.flush()
            return row

        root = await add_message("root", "user", 10, "success")
        child = await add_message("child", "assistant", 11, "pending", parent=root)
        leaf = await add_message("leaf", "user", 12, "running", parent=child)

        r = await MD.fetch_message_branch_path_messages(
            session, leaf.id, acontext_session.id
        )
        messages, eil = r.unpack()

        assert eil is None

        returned_ids = [message.id for message in messages]
        assert returned_ids == [root.id, child.id, leaf.id]

        returned_statuses = [
            message.session_task_process_status for message in messages
        ]
        assert returned_statuses == [
            "success",
            "pending",
            "running",
        ]
239+
240+
@pytest.mark.asyncio
async def test_hydrate_message_parts_uses_existing_branch_message_order(db_client):
    """Hydration keeps the fetched branch order and downloads once per message."""

    def asset_meta(name: str, size_b: int) -> dict:
        # Parts metadata for one message; only the name-derived fields and
        # the size differ between the messages in this test.
        return {
            "bucket": "test-bucket",
            "s3_key": f"parts/{name}.json",
            "etag": f"etag-{name}",
            "sha256": f"sha-{name}",
            "mime": "application/json",
            "size_b": size_b,
        }

    _, acontext_session = await _create_project_and_session(db_client)

    async with db_client.get_session_context() as session:

        async def add_message(name, role, size_b, parent=None):
            # parent_id is only passed for the non-root message, matching
            # how the rows were originally constructed.
            fields = {
                "session_id": acontext_session.id,
                "role": role,
                "parts_asset_meta": asset_meta(name, size_b),
            }
            if parent is not None:
                fields["parent_id"] = parent.id
            row = Message(**fields)
            session.add(row)
            # Flush so the row gets its id before a child refers to it.
            await session.flush()
            return row

        root = await add_message("root", "user", 10)
        leaf = await add_message("leaf", "assistant", 11, parent=root)

        branch_result = await MD.fetch_message_branch_path_messages(
            session, leaf.id, acontext_session.id
        )
        messages, eil = branch_result.unpack()

        assert eil is None

        # The mocked downloads are consumed in call order: root first, leaf second.
        payloads = [
            json.dumps([{"type": "text", "text": label}]).encode("utf-8")
            for label in ("root", "leaf")
        ]
        with patch(
            "acontext_core.service.data.message.S3_CLIENT.download_object",
            new_callable=AsyncMock,
            side_effect=payloads,
        ) as mock_download:
            hydrated_result = await MD.hydrate_message_parts(messages)
            hydrated_messages, hydrate_error = hydrated_result.unpack()

        assert hydrate_error is None
        assert [message.id for message in hydrated_messages] == [root.id, leaf.id]

        root_parts, leaf_parts = (message.parts for message in hydrated_messages)
        assert root_parts[0].text == "root"
        assert leaf_parts[0].text == "leaf"
        assert mock_download.await_count == 2
300+
301+
170302
@pytest.mark.asyncio
171303
async def test_branch_pending_message_length_ignores_sibling_branches(db_client):
172304
_, acontext_session = await _create_project_and_session(db_client)

0 commit comments

Comments
 (0)