Skip to content

Commit da450a3

Browse files
committed
fix: preserve .tool attr on error-wrapped runnables for confirmation detection
1 parent 5700f6b commit da450a3

File tree

3 files changed

+15
-23
lines changed

3 files changed

+15
-23
lines changed

src/uipath_langchain/agent/tools/tool_node.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,11 @@ async def _afunc(state: AgentGraphState) -> OutputType:
274274
raise
275275
return result
276276

277-
return RunnableCallable(func=_func, afunc=_afunc, name=tool_name)
277+
wrapped = RunnableCallable(func=_func, afunc=_afunc, name=tool_name)
278+
# Preserve .tool so _get_tool_confirmation_info can find metadata
279+
if hasattr(tool_node, "tool"):
280+
wrapped.tool = tool_node.tool # type: ignore[attr-defined]
281+
return wrapped
278282

279283

280284
class ToolWrapperMixin:

src/uipath_langchain/runtime/messages.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -435,9 +435,7 @@ async def map_current_message_to_start_tool_call_events(self):
435435
self.map_tool_call_to_tool_call_start_event(
436436
self.current_message.id,
437437
tool_call,
438-
require_confirmation=require_confirmation
439-
if require_confirmation
440-
else None,
438+
require_confirmation=require_confirmation or None,
441439
input_schema=input_schema,
442440
)
443441
)
@@ -671,7 +669,7 @@ def _map_langchain_human_message_to_uipath_message_data(
671669
)
672670

673671
return UiPathConversationMessageData(
674-
role="user", content_parts=content_parts, tool_calls=[], interrupts=[]
672+
role="user", content_parts=content_parts, tool_calls=[]
675673
)
676674

677675
@staticmethod
@@ -721,7 +719,6 @@ def _map_langchain_ai_message_to_uipath_message_data(
721719
role="assistant",
722720
content_parts=content_parts,
723721
tool_calls=uipath_tool_calls,
724-
interrupts=[],
725722
)
726723

727724

src/uipath_langchain/runtime/runtime.py

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,9 @@ def __init__(
6565
self.entrypoint: str | None = entrypoint
6666
self.callbacks: list[BaseCallbackHandler] = callbacks or []
6767
self.chat = UiPathChatMessagesMapper(self.runtime_id, storage)
68-
self.chat.tool_names_requiring_confirmation = (
69-
self._get_tool_names_requiring_confirmation()
70-
)
71-
self.chat.tool_confirmation_schemas = self._get_tool_confirmation_schemas()
68+
confirmation_names, confirmation_schemas = self._get_tool_confirmation_info()
69+
self.chat.tool_names_requiring_confirmation = confirmation_names
70+
self.chat.tool_confirmation_schemas = confirmation_schemas
7271
self._middleware_node_names: set[str] = self._detect_middleware_nodes()
7372

7473
async def execute(
@@ -491,33 +490,25 @@ def _detect_middleware_nodes(self) -> set[str]:
491490

492491
return middleware_nodes
493492

494-
def _get_tool_names_requiring_confirmation(self) -> set[str]:
493+
def _get_tool_confirmation_info(self) -> tuple[set[str], dict[str, Any]]:
494+
"""Single pass over graph nodes to collect confirmation tool names and schemas."""
495495
names: set[str] = set()
496-
for node_name, node_spec in self.graph.nodes.items():
497-
# langgraph's processing node.bound -> runnable.tool -> baseTool (if tool node)
498-
tool = getattr(getattr(node_spec, "bound", None), "tool", None)
499-
if tool is None:
500-
continue
501-
metadata = getattr(tool, "metadata", None) or {}
502-
if metadata.get(REQUIRE_CONVERSATIONAL_CONFIRMATION):
503-
names.add(getattr(tool, "name", node_name))
504-
return names
505-
506-
def _get_tool_confirmation_schemas(self) -> dict[str, Any]:
507496
schemas: dict[str, Any] = {}
508497
for node_name, node_spec in self.graph.nodes.items():
498+
# langgraph's processing node.bound -> runnable.tool -> baseTool (if tool node)
509499
tool = getattr(getattr(node_spec, "bound", None), "tool", None)
510500
if tool is None:
511501
continue
512502
metadata = getattr(tool, "metadata", None) or {}
513503
if metadata.get(REQUIRE_CONVERSATIONAL_CONFIRMATION):
514504
tool_name = getattr(tool, "name", node_name)
505+
names.add(tool_name)
515506
tool_call_schema = getattr(tool, "tool_call_schema", None)
516507
if tool_call_schema is not None:
517508
schemas[tool_name] = tool_call_schema.model_json_schema()
518509
else:
519510
schemas[tool_name] = {}
520-
return schemas
511+
return names, schemas
521512

522513
def _is_middleware_node(self, node_name: str) -> bool:
523514
"""Check if a node name represents a middleware node."""

0 commit comments

Comments
 (0)