Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ venv.bak/

.vscode
.idea
.cursor

# custom
*.pkl
Expand Down
247 changes: 233 additions & 14 deletions ms_agent/agent/llm_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
from ms_agent.memory import Memory, get_memory_meta_safe, memory_mapping
from ms_agent.memory.memory_manager import SharedMemoryManager
from ms_agent.rag.base import RAG
from ms_agent.session import ContextAssembler, SessionLog
from ms_agent.session.strategies import SummaryCompactor, ToolOutputPruner
from ms_agent.rag.utils import rag_mapping
from ms_agent.tools import ToolManager
from ms_agent.utils import async_retry, read_history, save_history
Expand Down Expand Up @@ -107,9 +109,11 @@ def __init__(
self.tool_manager: Optional[ToolManager] = None
self.memory_tools: List[Memory] = []
self.rag: Optional[RAG] = None
self.knowledge_search: Optional[SirschmunkSearch] = None
self.knowledge_search: Optional[SirchmunkSearch] = None
self.llm: Optional[LLM] = None
self.runtime: Optional[Runtime] = None
self.session_log: Optional[SessionLog] = None
self.context_assembler: Optional[ContextAssembler] = None
self.max_chat_round: int = 0
self.load_cache = kwargs.get('load_cache', False)
self.config.load_cache = self.load_cache
Expand Down Expand Up @@ -733,6 +737,11 @@ async def do_skill(self,
async def load_memory(self):
"""Initialize and append memory tool instances based on the configuration provided in the global config.

For ``unified_memory``, this also:
- Passes the agent's LLM instance to the orchestrator
- Registers the ``memory`` / ``memory_read`` tools into ToolManager
- Injects memory-usage guidance into the system prompt

Raises:
AssertionError: If a specified memory type in the config does not exist in memory_mapping.
"""
Expand All @@ -747,6 +756,40 @@ async def load_memory(self):
self.config, mem_instance_type)
self.memory_tools.append(shared_memory)

# Wire unified_memory into the tool system
if mem_instance_type == 'unified_memory':
await self._register_memory_tool(shared_memory)

async def _register_memory_tool(self, orchestrator):
    """Expose a unified-memory orchestrator to the agent as a tool.

    Nothing is done when ``orchestrator`` has no ``get_tool_schemas``
    attribute. Otherwise the method, in order:

    1. hands the agent's LLM to the orchestrator and initialises its
       update queue (only when an LLM is set),
    2. shares the agent's ``SessionLog`` with the orchestrator (only when
       both sides have one),
    3. registers a ``MemoryTool`` with the ``ToolManager`` and re-indexes it,
    4. appends ``MEMORY_USAGE_PROMPT`` to ``config.prompt.system`` unless
       the prompt already contains the ``'Long-term Memory'`` marker.

    Args:
        orchestrator: The memory instance created for ``unified_memory``.
    """
    from ms_agent.memory.unified.memory_tool import MemoryTool, MEMORY_USAGE_PROMPT

    supports_tools = hasattr(orchestrator, 'get_tool_schemas')
    if not supports_tools:
        return

    # The orchestrator needs the LLM (and, when present, the session log)
    # for consolidation / extraction.
    llm = self.llm
    if llm is not None:
        orchestrator.set_llm(llm)
        orchestrator.init_update_queue()

    session_log = self.session_log
    if session_log is not None and hasattr(orchestrator, '_session_log'):
        orchestrator._session_log = session_log

    # Make the memory tool callable through the agent's tool system.
    manager = self.tool_manager
    if manager is not None:
        manager.register_tool(MemoryTool(self.config, orchestrator))
        await manager.reindex_tool()
        logger.info('[unified_memory] Memory tool registered')

    # Append the usage guidance once; the marker check prevents duplicates.
    has_system_prompt = (
        hasattr(self.config, 'prompt')
        and hasattr(self.config.prompt, 'system'))
    if has_system_prompt:
        system_prompt = self.config.prompt.system or ''
        if 'Long-term Memory' not in system_prompt:
            OmegaConf.update(
                self.config,
                'prompt.system',
                system_prompt + '\n\n' + MEMORY_USAGE_PROMPT,
                merge=True)

async def prepare_rag(self):
"""Load and initialize the RAG component from the config."""
if hasattr(self.config, 'rag'):
Expand All @@ -770,19 +813,135 @@ async def prepare_knowledge_search(self):
self.config)

async def condense_memory(self, messages: List[Message]) -> List[Message]:
    """Inject long-term memory context into messages.

    .. deprecated::
        Historically this also ran context compressors. Compression is
        now handled by :class:`ContextAssembler` before this method is
        called. This method only performs memory *injection* (adding
        ``<long-term-memory>`` blocks, etc.) and is kept as a
        backward-compatible alias of :meth:`inject_memory`.

    Args:
        messages (List[Message]): Current message history.

    Returns:
        List[Message]: The message history returned by the last memory tool.
    """
    # Delegate so the two entry points cannot drift apart.
    return await self.inject_memory(messages)

async def inject_memory(self, messages: List[Message]) -> List[Message]:
    """Inject long-term memory context into the message list.

    Every tool in ``self.memory_tools`` is run in registration order, each
    receiving the output of the previous one. This method does no trimming
    or compression of its own; whatever a tool's ``run`` returns is passed
    on unchanged.

    Args:
        messages (List[Message]): Current message history.

    Returns:
        List[Message]: The history after all memory tools have run. With no
        memory tools configured, ``messages`` is returned as-is.
    """
    for memory_tool in self.memory_tools:
        messages = await memory_tool.run(messages)
    return messages

def _init_session_log(self) -> None:
    """Set up ``self.session_log`` and ``self.context_assembler``.

    Reads the optional ``session_log`` and ``compaction`` config sections.
    Both default to *enabled* when the section or its ``enabled`` key is
    missing. When session logging is disabled nothing is created. When only
    compaction is disabled, a ``ContextAssembler`` with no strategies and
    an empty config is installed.
    """

    def _option(section, key, default=None):
        # A missing/empty section behaves as if the key were absent.
        if not section:
            return default
        return getattr(section, key, default)

    log_cfg = getattr(self.config, 'session_log', None)
    if not _option(log_cfg, 'enabled', True):
        return

    # Fall back to ``<output_dir>/sessions`` when no directory is configured.
    log_dir = _option(log_cfg, 'dir')
    if log_dir is None:
        output_root = getattr(self.config, 'output_dir', 'output')
        log_dir = os.path.join(output_root, 'sessions')

    self.session_log = SessionLog(
        log_dir, session_key=_option(log_cfg, 'session_key'))

    compact_cfg = getattr(self.config, 'compaction', None)
    if not _option(compact_cfg, 'enabled', True):
        # Logging without compaction: assembler with nothing to apply.
        self.context_assembler = ContextAssembler(
            session_log=self.session_log, strategies=[], config={},
        )
        return

    chosen_strategies = self._build_compaction_strategies(compact_cfg)
    merged_config = self._build_assembler_config(compact_cfg, log_cfg)
    on_flush = self._make_memory_flush_callback()

    self.context_assembler = ContextAssembler(
        session_log=self.session_log,
        strategies=chosen_strategies,
        config=merged_config,
        memory_flush_callback=on_flush,
    )

def _build_compaction_strategies(self, compaction_cfg):
    """Build the strategy list from YAML ``compaction.strategies``.

    Args:
        compaction_cfg: The ``compaction`` config section, or ``None``.

    Returns:
        A list of strategy instances in config order. Entries with
        ``enabled`` set to a falsy value are skipped and unknown names are
        logged and skipped. Without a ``strategies`` key the default pair
        ``[ToolOutputPruner(), SummaryCompactor(llm=self.llm)]`` is returned.
    """
    if not (compaction_cfg and hasattr(compaction_cfg, 'strategies')):
        return [ToolOutputPruner(), SummaryCompactor(llm=self.llm)]

    # Factories are invoked lazily, once per enabled entry.
    factories = {
        'tool_output_pruner': lambda: ToolOutputPruner(),
        'summary_compactor': lambda: SummaryCompactor(llm=self.llm),
    }

    built = []
    for entry in compaction_cfg.strategies:
        name = getattr(entry, 'name', '')
        if not getattr(entry, 'enabled', True):
            continue
        make = factories.get(name)
        if make is None:
            logger.warning(f"Unknown compaction strategy: {name}")
        else:
            built.append(make())
    return built

def _build_assembler_config(self, compaction_cfg, session_cfg):
    """Merge compaction params from ``compaction`` and ``session_log``.

    Values are layered in this order, later layers overriding earlier ones:
    ``session_log`` keys, then ``compaction`` keys, then the
    ``prune_protect`` of any ``tool_output_pruner`` strategy entry.
    Keys still missing afterwards get built-in defaults.

    Args:
        compaction_cfg: The ``compaction`` config section, or ``None``.
        session_cfg: The ``session_log`` config section, or ``None``.

    Returns:
        Dict with ``context_limit``, ``reserved_buffer`` and
        ``prune_protect`` always present.
    """
    merged: Dict[str, Any] = {}

    def _copy_set_keys(section, keys):
        # Only explicitly set (non-None) values override earlier layers.
        for key in keys:
            value = getattr(section, key, None)
            if value is not None:
                merged[key] = value

    if session_cfg:
        _copy_set_keys(
            session_cfg, ('context_limit', 'reserved_buffer', 'prune_protect'))

    if compaction_cfg:
        _copy_set_keys(compaction_cfg, ('context_limit', 'reserved_buffer'))
        if hasattr(compaction_cfg, 'strategies'):
            for entry in compaction_cfg.strategies:
                if getattr(entry, 'name', '') != 'tool_output_pruner':
                    continue
                protect = getattr(entry, 'prune_protect', None)
                if protect is not None:
                    merged['prune_protect'] = protect

    fallbacks = (
        ('context_limit', 128000),
        ('reserved_buffer', 20000),
        ('prune_protect', 40000),
    )
    for key, value in fallbacks:
        merged.setdefault(key, value)
    return merged

def _make_memory_flush_callback(self):
    """Create a callback that flushes memory before context compaction.

    Returns:
        A synchronous callable ``_flush(discarded_messages)``.
        ``discarded_messages`` is an iterable of message dicts (``role``,
        ``content`` and optionally ``tool_calls``). For every entry of
        ``self.memory_tools`` that has a ``flush`` attribute the dicts are
        converted to ``Message`` objects and ``flush`` is invoked with them:

        * inside a running event loop the coroutine is scheduled as a task
          and the callback returns immediately;
        * otherwise it is run to completion with ``asyncio.run``.
    """
    # The event loop keeps only weak references to tasks, so a
    # fire-and-forget task can be garbage-collected before it finishes.
    # Hold a strong reference until each task is done.
    pending_tasks = set()

    def _flush(discarded_messages):
        # ``self.memory_tools`` is read at call time, not at creation time.
        flushable = [
            tool for tool in self.memory_tools if hasattr(tool, 'flush')
        ]
        if not flushable:
            return

        import asyncio
        from ms_agent.llm.utils import Message as _Msg

        def _to_messages():
            # Built per tool so one backend cannot mutate another's input.
            return [
                _Msg(
                    role=m.get('role', 'user'),
                    content=m.get('content', ''),
                    tool_calls=m.get('tool_calls'),
                ) for m in discarded_messages
            ]

        # Only the loop lookup is guarded: a RuntimeError here means "no
        # running loop", not a failure inside a flush implementation.
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            loop = None

        for tool in flushable:
            if loop is None:
                asyncio.run(tool.flush(_to_messages()))
            else:
                task = loop.create_task(tool.flush(_to_messages()))
                pending_tasks.add(task)
                task.add_done_callback(pending_tasks.discard)

    return _flush

def log_output(self, content: Union[str, list]):
"""
Log formatted output with a tag prefix.
Expand Down Expand Up @@ -1089,6 +1248,31 @@ def save_history(self, messages: List[Message], **kwargs):
save_history(
self.output_dir, task=self.tag, config=config, messages=messages)

@staticmethod
def _msg_to_dict(msg: Message) -> Dict[str, Any]:
"""Convert a Message to a plain dict for SessionLog.

Preserves ``prompt_tokens`` and ``completion_tokens`` individually
so that :class:`ContextAssembler` strategies can leverage API-reported
usage data for accurate overflow detection.
"""
d: Dict[str, Any] = {'role': msg.role, 'content': msg.content or ''}
if msg.tool_calls:
d['tool_calls'] = msg.tool_calls
if hasattr(msg, 'tool_call_id') and msg.tool_call_id:
d['tool_call_id'] = msg.tool_call_id
if hasattr(msg, 'name') and msg.name:
d['name'] = msg.name
prompt_tokens = int(getattr(msg, 'prompt_tokens', 0) or 0)
completion_tokens = int(getattr(msg, 'completion_tokens', 0) or 0)
if prompt_tokens:
d['prompt_tokens'] = prompt_tokens
if completion_tokens:
d['completion_tokens'] = completion_tokens
if prompt_tokens or completion_tokens:
d['tokens'] = prompt_tokens + completion_tokens
return d

async def run_loop(self, messages: Union[List[Message], str],
**kwargs) -> AsyncGenerator[Any, Any]:
"""
Expand All @@ -1112,13 +1296,28 @@ async def run_loop(self, messages: Union[List[Message], str],
await self.load_memory()
await self.prepare_rag()
await self.prepare_knowledge_search()
self._init_session_log()
self.runtime.tag = self.tag

if messages is None:
messages = self.query

# Load history and restore state
self.config, self.runtime, messages = self.read_history(messages)
if self.session_log is not None:
restored = self.session_log.get_all_messages()
if restored and self.load_cache:
from ms_agent.llm.utils import Message as _Msg
messages = [_Msg(
role=m.get('role', 'user'),
content=m.get('content', ''),
tool_calls=m.get('tool_calls'),
) for m in restored]
Comment on lines +1310 to +1314
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

When restoring messages from the session log, the tool_call_id and name fields are omitted. This will break the conversation history for tool-use interactions, as many LLM providers require these fields to correctly associate tool outputs with their corresponding calls. These fields should be included in the restoration logic, similar to how they are handled in _dicts_to_messages in the ContextAssembler.

                    messages = [_Msg(
                        role=m.get('role', 'user'),
                        content=m.get('content', ''),
                        tool_calls=m.get('tool_calls'),
                        tool_call_id=m.get('tool_call_id'),
                        name=m.get('name'),
                    ) for m in restored]

else:
self.config, self.runtime, messages = self.read_history(
messages)
else:
self.config, self.runtime, messages = self.read_history(
messages)

if self.runtime.round == 0:
# New task: create standardized messages first
Expand All @@ -1137,14 +1336,30 @@ async def run_loop(self, messages: Union[List[Message], str],
await self.do_rag(messages)
await self.on_task_begin(messages)

# Seed SessionLog with initial messages
if self.session_log is not None:
for msg in messages:
self.session_log.append(self._msg_to_dict(msg))

for message in messages:
if message.role != 'system':
self.log_output('[' + message.role + ']:')
self.log_output(message.content)
while not self.runtime.should_stop:
# Rebuild context view from SessionLog (non-destructive compression)
if self.context_assembler is not None and self.runtime.round > 0:
messages = self.context_assembler.assemble()

pre_step_len = len(messages)
async for messages in self.step(messages):
yield messages
self.runtime.round += 1

# Append new messages to SessionLog
if self.session_log is not None:
for msg in messages[pre_step_len:]:
self.session_log.append(self._msg_to_dict(msg))

# save memory and history
await self.add_memory(
messages, add_type='add_after_step', **kwargs)
Expand All @@ -1153,13 +1368,17 @@ async def run_loop(self, messages: Union[List[Message], str],
# +1 means the next round the assistant may give a conclusion
if self.runtime.round >= self.max_chat_round + 1:
if not self.runtime.should_stop:
messages.append(
Message(
role='assistant',
content=
f'Task {messages[1].content} was cutted off, because '
f'max round({self.max_chat_round}) exceeded.',
))
cutoff_msg = Message(
role='assistant',
content=
f'Task {messages[1].content} was cutted off, because '
f'max round({self.max_chat_round}) exceeded.',
)
messages.append(cutoff_msg)
if self.session_log is not None:
self.session_log.append(
self._msg_to_dict(cutoff_msg))
self.save_history(messages)
self.runtime.should_stop = True
yield messages

Expand Down
2 changes: 2 additions & 0 deletions ms_agent/llm/openai_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,8 @@ def generate(self,
args = self.args.copy()
args.update(kwargs)
stream = args.get('stream', False)
if not stream:
args.pop('stream_options', None)

args = {key: value for key, value in args.items() if key in parameters}
completion = self._call_llm(messages, self.format_tools(tools), **args)
Expand Down
46 changes: 46 additions & 0 deletions ms_agent/memory/unified/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""Unified Memory — a protocol-driven, backend-pluggable memory system.

Register ``unified_memory`` in ``memory_mapping`` to use this system.

Architecture::

    Orchestrator --delegates-to--> MemoryBackend (Protocol)
                                         |
                        +----------------+----------------+
                        v                v                v
                FileBasedBackend    ReMeBackend   MempalaceBackend ...
                   (built-in)        (adapter)        (adapter)

Switch backends via YAML config::

    storage:
      backend: "file"  # or "reme", "mempalace", "mem0", "byterover", "supermemory"
"""
from .config import MemoryConfig
from .orchestrator import MemoryOrchestrator
from .protocols import (
BaseMemoryBackend,
MemoryBackend,
MemoryEntry,
MemoryEvent,
MemoryEventBus,
MemoryNamespace,
)
from .registry import backend_registry

# Import backends so they self-register
from .backends import file_based as _fb # noqa: F401

# Names exported by ``from ms_agent.memory.unified import *``; each one is
# imported at the top of this module.
__all__ = [
"MemoryConfig",
"MemoryOrchestrator",
# Layer 2 — primary contract
"MemoryBackend",
"BaseMemoryBackend",
"backend_registry",
# Layer 1 — data structures
"MemoryEntry",
"MemoryEvent",
"MemoryEventBus",
"MemoryNamespace",
]
Loading
Loading