From f63ba5d390195ce9ae4d299a7aae0ee25b57e70e Mon Sep 17 00:00:00 2001 From: Kovbo Date: Thu, 4 Jun 2026 01:14:24 +0000 Subject: [PATCH] fix: honor langgraph chat model kwargs Closes #474 --- src/art/langgraph/llm_wrapper.py | 117 ++++++++--- src/art/rewards/ruler.py | 3 + tests/unit/test_langgraph_llm_wrapper.py | 242 +++++++++++++++++++++++ 3 files changed, 331 insertions(+), 31 deletions(-) create mode 100644 tests/unit/test_langgraph_llm_wrapper.py diff --git a/src/art/langgraph/llm_wrapper.py b/src/art/langgraph/llm_wrapper.py index 36b5314b3..8a40a7097 100644 --- a/src/art/langgraph/llm_wrapper.py +++ b/src/art/langgraph/llm_wrapper.py @@ -1,6 +1,7 @@ """LLM wrapper with logging functionality.""" import asyncio +from collections.abc import Callable import contextvars import json import os @@ -22,6 +23,9 @@ mappings = {} +DEFAULT_INVOKE_TIMEOUT = 10 * 60 +OPENAI_COMPATIBLE_PROVIDERS = {None, "openai", "openai-compatible", "openai_compatible"} + def add_thread(thread_id, base_url, api_key, model): log_path = f".art/langgraph/{thread_id}" @@ -108,31 +112,82 @@ async def wrapper(*args, **kwargs): def init_chat_model( - model: Literal[None] = None, + model: str | Runnable | None = None, *, model_provider: str | None = None, configurable_fields: Literal[None] = None, config_prefix: str | None = None, + invoke_timeout: float | None = DEFAULT_INVOKE_TIMEOUT, **kwargs: Any, ): + """Create a logged LangChain chat model for ART LangGraph rollouts. + + By default ART constructs a ChatOpenAI client pointed at the + OpenAI-compatible endpoint from the active rollout context. For other + LangChain providers, pass an already constructed chat model instance as + ``model``. Provider kwargs such as ``temperature`` and ``timeout`` are + forwarded to ChatOpenAI; ``invoke_timeout`` controls only ART's outer + ``asyncio.wait_for`` timeout. + """ config = CURRENT_CONFIG.get() + + if configurable_fields is not None: + raise ValueError( + "configurable_fields is not supported by ART's init_chat_model" + ) + if config_prefix is not None: + raise ValueError("config_prefix is not supported by ART's init_chat_model") + + if model is not None and not isinstance(model, str): + return LoggingLLM( + model, + config["logger"], + invoke_timeout=invoke_timeout, + ) + + if model_provider not in OPENAI_COMPATIBLE_PROVIDERS: + raise ValueError( + "ART's init_chat_model can construct only OpenAI-compatible chat " + "models. Pass a LangChain chat model instance as `model` to use " + f"provider {model_provider!r}." + ) + + model_name = model + + def chat_openai_factory(art_config: dict[str, Any]): + chat_model_kwargs: dict[str, Any] = { + "base_url": art_config["base_url"], + "api_key": art_config["api_key"], + "model": model_name or art_config["model"], + "temperature": 1.0, + } + chat_model_kwargs.update(kwargs) + return ChatOpenAI(**chat_model_kwargs) + return LoggingLLM( - ChatOpenAI( - base_url=config["base_url"], # ty:ignore[unknown-argument] - api_key=config["api_key"], # ty:ignore[unknown-argument] - model=config["model"], # ty:ignore[unknown-argument] - temperature=1.0, - ), + chat_openai_factory(config), config["logger"], + invoke_timeout=invoke_timeout, + chat_model_factory=chat_openai_factory, ) class LoggingLLM(Runnable): - def __init__(self, llm, logger, structured_output=None, tools=None): + def __init__( + self, + llm, + logger, + structured_output=None, + tools=None, + invoke_timeout: float | None = DEFAULT_INVOKE_TIMEOUT, + chat_model_factory: Callable[[dict[str, Any]], Any] | None = None, + ): self.llm = llm self.logger = logger self.structured_output = structured_output self.tools = [convert_to_openai_tool(t) for t in tools] if tools else None + self.invoke_timeout = invoke_timeout + self.chat_model_factory = chat_model_factory def _log(self, completion_id, input, output): if self.logger: @@ -143,7 +198,7 @@ def invoke(self, input, config=None, **kwargs): completion_id = str(uuid.uuid4()) def execute(): - result = self.llm.invoke(input, config=config) + result = self.llm.invoke(input, config=config, **kwargs) self._log(completion_id, input, result) return result @@ -166,9 +221,11 @@ async def ainvoke(self, input, config=None, **kwargs): async def execute(): try: - result = await asyncio.wait_for( - self.llm.ainvoke(input, config=config), timeout=10 * 60 - ) + call = self.llm.ainvoke(input, config=config, **kwargs) + if self.invoke_timeout is None: + result = await call + else: + result = await asyncio.wait_for(call, timeout=self.invoke_timeout) self._log(completion_id, input, result) except asyncio.TimeoutError as e: raise e @@ -194,10 +251,18 @@ def with_structured_output(self, tools): self.logger, structured_output=tools, tools=[tools], + invoke_timeout=self.invoke_timeout, + chat_model_factory=self.chat_model_factory, ) def bind_tools(self, tools): - return LoggingLLM(self.llm.bind_tools(tools), self.logger, tools=tools) + return LoggingLLM( + self.llm.bind_tools(tools), + self.logger, + tools=tools, + invoke_timeout=self.invoke_timeout, + chat_model_factory=self.chat_model_factory, + ) def with_retry( self, @@ -217,23 +282,13 @@ def with_config( art_config = CURRENT_CONFIG.get() self.logger = art_config["logger"] - if hasattr(self.llm, "bound"): - setattr( - self.llm, - "bound", - ChatOpenAI( - base_url=art_config["base_url"], # ty:ignore[unknown-argument] - api_key=art_config["api_key"], # ty:ignore[unknown-argument] - model=art_config["model"], # ty:ignore[unknown-argument] - temperature=1.0, - ), - ) - else: - self.llm = ChatOpenAI( - base_url=art_config["base_url"], # ty:ignore[unknown-argument] - api_key=art_config["api_key"], # ty:ignore[unknown-argument] - model=art_config["model"], # ty:ignore[unknown-argument] - temperature=1.0, - ) + if self.chat_model_factory is not None: + configured_llm = self.chat_model_factory(art_config) + if hasattr(self.llm, "bound"): + setattr(self.llm, "bound", configured_llm) + else: + self.llm = configured_llm + elif hasattr(self.llm, "with_config"): + self.llm = self.llm.with_config(config=config, **kwargs) return self diff --git a/src/art/rewards/ruler.py b/src/art/rewards/ruler.py index 75a3c120f..7c56dcb7d 100644 --- a/src/art/rewards/ruler.py +++ b/src/art/rewards/ruler.py @@ -112,6 +112,9 @@ async def ruler( - "openai/gpt-4o-mini" - Fast and cost-effective - "openai/o3" - Most capable but expensive (default) - "anthropic/claude-3-opus-20240229" - Alternative judge + - "ollama/qwen3:32b" - Local Ollama judge via LiteLLM + The default calls OpenAI through LiteLLM. Set this explicitly for + local or custom judge backends. extra_litellm_params: Additional parameters to pass to LiteLLM completion. Can include temperature, max_tokens, etc. rubric: The grading rubric. The default rubric works well for most tasks. diff --git a/tests/unit/test_langgraph_llm_wrapper.py b/tests/unit/test_langgraph_llm_wrapper.py new file mode 100644 index 000000000..4cdaacf3d --- /dev/null +++ b/tests/unit/test_langgraph_llm_wrapper.py @@ -0,0 +1,242 @@ +import importlib +import sys +import types +from typing import Any + +import pytest + + +class _FakeRunnable: + pass + + +class _FakeMessage: + pass + + +class _FakePromptValue: + pass + + +class _FakeBoundLLM: + def __init__(self, bound: Any, tools: list[Any]) -> None: + self.bound = bound + self.tools = tools + + def bind_tools(self, tools: list[Any]) -> "_FakeBoundLLM": + return _FakeBoundLLM(self.bound, tools) + + +class _FakeChatOpenAI: + instances: list["_FakeChatOpenAI"] = [] + + def __init__(self, **kwargs: Any) -> None: + self.kwargs = kwargs + self.calls: list[tuple[str, Any, dict[str, Any]]] = [] + self.instances.append(self) + + def invoke(self, input: Any, config: Any = None, **kwargs: Any) -> Any: + self.calls.append(("invoke", input, {"config": config, **kwargs})) + return types.SimpleNamespace(tool_calls=None) + + async def ainvoke(self, input: Any, config: Any = None, **kwargs: Any) -> Any: + self.calls.append(("ainvoke", input, {"config": config, **kwargs})) + return types.SimpleNamespace(tool_calls=None) + + def bind_tools(self, tools: list[Any]) -> _FakeBoundLLM: + return _FakeBoundLLM(self, tools) + + +class _FakeLogger: + def __init__(self) -> None: + self.entries: list[tuple[str, Any]] = [] + + def log(self, key: str, entry: Any) -> None: + self.entries.append((key, entry)) + + +@pytest.fixture +def llm_wrapper(monkeypatch: pytest.MonkeyPatch): + _FakeChatOpenAI.instances.clear() + + messages_module = types.ModuleType("langchain_core.messages") + setattr(messages_module, "AIMessage", _FakeMessage) + setattr(messages_module, "BaseMessage", _FakeMessage) + setattr(messages_module, "FunctionMessage", _FakeMessage) + setattr(messages_module, "HumanMessage", _FakeMessage) + setattr(messages_module, "SystemMessage", _FakeMessage) + setattr(messages_module, "ToolMessage", _FakeMessage) + + prompt_values_module = types.ModuleType("langchain_core.prompt_values") + setattr(prompt_values_module, "ChatPromptValue", _FakePromptValue) + + runnables_module = types.ModuleType("langchain_core.runnables") + setattr(runnables_module, "Runnable", _FakeRunnable) + + function_calling_module = types.ModuleType("langchain_core.utils.function_calling") + setattr(function_calling_module, "convert_to_openai_tool", lambda tool: tool) + + utils_module = types.ModuleType("langchain_core.utils") + core_module = types.ModuleType("langchain_core") + openai_module = types.ModuleType("langchain_openai") + setattr(openai_module, "ChatOpenAI", _FakeChatOpenAI) + + for module_name, module in { + "langchain_core": core_module, + "langchain_core.messages": messages_module, + "langchain_core.prompt_values": prompt_values_module, + "langchain_core.runnables": runnables_module, + "langchain_core.utils": utils_module, + "langchain_core.utils.function_calling": function_calling_module, + "langchain_openai": openai_module, + }.items(): + monkeypatch.setitem(sys.modules, module_name, module) + + for module_name in [ + "art.langgraph", + "art.langgraph.llm_wrapper", + "art.langgraph.message_utils", + ]: + sys.modules.pop(module_name, None) + + return importlib.import_module("art.langgraph.llm_wrapper") + + +def _set_current_config(module: Any, **overrides: Any) -> _FakeLogger: + logger = overrides.pop("logger", _FakeLogger()) + module.CURRENT_CONFIG.set( + { + "logger": logger, + "base_url": overrides.pop("base_url", "http://rollout.test/v1"), + "api_key": overrides.pop("api_key", "test-key"), + "model": overrides.pop("model", "context-model"), + **overrides, + } + ) + return logger + + +def test_init_chat_model_forwards_model_and_provider_kwargs(llm_wrapper: Any) -> None: + _set_current_config(llm_wrapper) + + logged_llm = llm_wrapper.init_chat_model( + "explicit-model", + temperature=0.2, + timeout=123, + max_tokens=42, + invoke_timeout=5, + ) + + assert logged_llm.invoke_timeout == 5 + assert logged_llm.llm.kwargs == { + "base_url": "http://rollout.test/v1", + "api_key": "test-key", + "model": "explicit-model", + "temperature": 0.2, + "timeout": 123, + "max_tokens": 42, + } + + +def test_with_config_preserves_kwargs_and_uses_new_context( + llm_wrapper: Any, +) -> None: + first_logger = _set_current_config( + llm_wrapper, + base_url="http://first.test/v1", + api_key="first-key", + model="first-model", + ) + + logged_llm = llm_wrapper.init_chat_model(temperature=0.3, invoke_timeout=None) + assert logged_llm.logger is first_logger + assert logged_llm.llm.kwargs["model"] == "first-model" + + second_logger = _set_current_config( + llm_wrapper, + base_url="http://second.test/v1", + api_key="second-key", + model="second-model", + ) + + assert logged_llm.with_config() is logged_llm + assert logged_llm.logger is second_logger + assert logged_llm.invoke_timeout is None + assert logged_llm.llm.kwargs == { + "base_url": "http://second.test/v1", + "api_key": "second-key", + "model": "second-model", + "temperature": 0.3, + } + + +def test_with_config_updates_bound_chat_model(llm_wrapper: Any) -> None: + _set_current_config(llm_wrapper, model="first-model") + logged_llm = llm_wrapper.init_chat_model("explicit-model", temperature=0.4) + bound_llm = logged_llm.bind_tools(["tool"]) + + _set_current_config(llm_wrapper, model="second-context-model") + + bound_llm.with_config() + + assert bound_llm.llm.tools == ["tool"] + assert bound_llm.llm.bound.kwargs == { + "base_url": "http://rollout.test/v1", + "api_key": "test-key", + "model": "explicit-model", + "temperature": 0.4, + } + + +def test_custom_chat_model_is_wrapped_without_chat_openai(llm_wrapper: Any) -> None: + class CustomChatModel: + pass + + _set_current_config(llm_wrapper) + custom_model = CustomChatModel() + + logged_llm = llm_wrapper.init_chat_model( + custom_model, + model_provider="ollama", + invoke_timeout=7, + ) + + assert logged_llm.llm is custom_model + assert logged_llm.invoke_timeout == 7 + assert _FakeChatOpenAI.instances == [] + + +def test_unsupported_provider_raises_instead_of_using_chat_openai( + llm_wrapper: Any, +) -> None: + _set_current_config(llm_wrapper) + + with pytest.raises(ValueError, match="OpenAI-compatible"): + llm_wrapper.init_chat_model("llama3", model_provider="ollama") + + assert _FakeChatOpenAI.instances == [] + + +@pytest.mark.asyncio +async def test_ainvoke_uses_wrapper_timeout_and_forwards_kwargs( + llm_wrapper: Any, + monkeypatch: pytest.MonkeyPatch, +) -> None: + seen: dict[str, Any] = {} + + async def fake_wait_for(awaitable: Any, timeout: float | None) -> Any: + seen["timeout"] = timeout + return await awaitable + + monkeypatch.setattr(llm_wrapper.asyncio, "wait_for", fake_wait_for) + + logger = _set_current_config(llm_wrapper) + logged_llm = llm_wrapper.init_chat_model(invoke_timeout=2) + + await logged_llm.ainvoke("hello", config={"run": 1}, stop=["END"]) + + assert seen["timeout"] == 2 + assert logged_llm.llm.calls == [ + ("ainvoke", "hello", {"config": {"run": 1}, "stop": ["END"]}) + ] + assert len(logger.entries) == 1