Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 14 additions & 11 deletions services/doc_agent_chat/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from doc_agent_chat.prompt import build_system_prompt
from doc_agent_chat.tools import TOOL_DEFINITIONS, search_documents, format_search_results_as_documents
from doc_agent_chat.config_loader import ConfigLoader
from models import preferred_chat_model
from models import preferred_chat_model, call_with_model_fallback

logger = create_logger("agent")

Expand Down Expand Up @@ -59,16 +59,19 @@ def run(
for iteration in range(self.max_tool_calls):
logger.info(f"Agentic loop iteration {iteration + 1}")

response = self.client.messages.create(
model=self.model,
max_tokens=self.max_tokens,
system=system_prompt,
messages=messages,
tools=TOOL_DEFINITIONS,
# Per-request timeout (same values as the SDK default):
# required for non-streaming calls with max_tokens > ~21k,
# which the SDK otherwise rejects.
timeout=httpx.Timeout(600.0, connect=5.0),
response = call_with_model_fallback(
lambda m: self.client.messages.create(
model=m,
max_tokens=self.max_tokens,
system=system_prompt,
messages=messages,
tools=TOOL_DEFINITIONS,
# Per-request timeout (same values as the SDK default):
# required for non-streaming calls with max_tokens > ~21k,
# which the SDK otherwise rejects.
timeout=httpx.Timeout(600.0, connect=5.0),
),
preferred=self.model,
)

if hasattr(response, "usage"):
Expand Down
85 changes: 49 additions & 36 deletions services/global_chat/planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,11 @@
STATUS_PLANNING,
)
from global_chat.config_loader import ConfigLoader
from models import preferred_chat_model
from models import (
preferred_chat_model,
call_with_model_fallback,
stream_with_model_fallback,
)
from global_chat.tools.tool_definitions import TOOL_DEFINITIONS
from global_chat.yaml_utils import stitch_job_code, redact_job_bodies, find_job_in_yaml
from tools.search_documentation.search_documentation import search_documentation_tool
Expand Down Expand Up @@ -276,47 +280,56 @@ def _call_api(self, system_prompt, messages, stream):
task-specific status messages sent before each tool execution.
"""
if stream:
buffered_text = []

with self.client.messages.stream(
model=self.model,
max_tokens=self.max_tokens,
system=system_prompt,
messages=messages,
tools=self.tools,
thinking={"type": "adaptive"},
output_config={"effort": "medium"},
) as stream_obj:
def _consume(stream_obj, commit):
buffered_text = []
for event in stream_obj:
if event.type == "content_block_delta":
if event.delta.type == "text_delta":
buffered_text.append(event.delta.text)
commit()
return stream_obj.get_final_message(), buffered_text

return stream_with_model_fallback(
lambda m: self.client.messages.stream(
model=m,
max_tokens=self.max_tokens,
system=system_prompt,
messages=messages,
tools=self.tools,
thinking={"type": "adaptive"},
output_config={"effort": "medium"},
),
_consume,
preferred=self.model,
)
else:
response = self.client.beta.messages.create(
model=self.model,
max_tokens=self.max_tokens,
system=system_prompt,
messages=messages,
tools=self.tools,
thinking={"type": "adaptive"},
output_config={"effort": "medium"},
# Per-request timeout (same values as the SDK default):
# required for non-streaming calls with max_tokens > ~21k,
# which the SDK otherwise rejects.
timeout=httpx.Timeout(600.0, connect=5.0),
betas=["context-management-2025-06-27"],
context_management={
"edits": [
{
"type": "clear_tool_uses_20250919",
"trigger": {"type": "tool_uses", "value": 20},
"keep": {"type": "tool_uses", "value": 10},
"exclude_tools": ["search_documentation"],
"clear_tool_inputs": True,
}
]
},
response = call_with_model_fallback(
lambda m: self.client.beta.messages.create(
model=m,
max_tokens=self.max_tokens,
system=system_prompt,
messages=messages,
tools=self.tools,
thinking={"type": "adaptive"},
output_config={"effort": "medium"},
# Per-request timeout (same values as the SDK default):
# required for non-streaming calls with max_tokens > ~21k,
# which the SDK otherwise rejects.
timeout=httpx.Timeout(600.0, connect=5.0),
betas=["context-management-2025-06-27"],
context_management={
"edits": [
{
"type": "clear_tool_uses_20250919",
"trigger": {"type": "tool_uses", "value": 20},
"keep": {"type": "tool_uses", "value": 10},
"exclude_tools": ["search_documentation"],
"clear_tool_inputs": True,
}
]
},
),
preferred=self.model,
)
return response, []

Expand Down
92 changes: 57 additions & 35 deletions services/job_chat/job_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,11 @@
STATUS_WORKING,
STATUS_WRITING_CODE,
)
from models import preferred_chat_model
from models import (
preferred_chat_model,
call_with_model_fallback,
stream_with_model_fallback,
)

_MODEL = preferred_chat_model("job_chat")

Expand Down Expand Up @@ -231,26 +235,18 @@ def generate(
with sentry_sdk.start_span(description="anthropic_api_call"):
if stream:
logger.info("Making streaming API call")
text_started = False
sent_length = 0
accumulated_response = ""
self._stream_applied = False
self._stream_suggested_code = None
self._stream_diff = None

original_code = context.get("expression") if context and isinstance(context, dict) else None

stream_kwargs = dict(
max_tokens=self.config.max_tokens,
messages=prompt,
model=self.config.model,
system=system_message,
thinking={"type": "adaptive"},
output_config=output_config,
**tool_kwargs
)
def _consume(stream_obj, commit):
# Reset per attempt so a model fallback never reuses a
# prior (failed) stream's partial state.
text_started = False
sent_length = 0
accumulated_response = ""
self._stream_applied = False
self._stream_suggested_code = None
self._stream_diff = None

with self.client.messages.stream(**stream_kwargs) as stream_obj:
for event in stream_obj:
if event.type == "message_start":
stream_manager.send_thinking(STATUS_WORKING)
Expand All @@ -268,20 +264,40 @@ def generate(
original_code,
content
)
message = stream_obj.get_final_message()
# Once user-facing text has streamed, we can't cleanly
# fall back to another model without re-sending it.
if text_started:
commit()

msg = stream_obj.get_final_message()

# Flush any remaining buffered text, stripping JSON closing chars
if suggest_code and text_started:
if sent_length < len(accumulated_response):
remaining = accumulated_response[sent_length:]
remaining = re.sub(r'"\s*}\s*$', '', remaining)
if remaining:
stream_manager.send_text(self._unescape_json_string(remaining))
return msg

# Flush any remaining buffered text, stripping JSON closing chars
if suggest_code and text_started:
if sent_length < len(accumulated_response):
remaining = accumulated_response[sent_length:]
remaining = re.sub(r'"\s*}\s*$', '', remaining)
if remaining:
stream_manager.send_text(self._unescape_json_string(remaining))
stream_kwargs = dict(
max_tokens=self.config.max_tokens,
messages=prompt,
system=system_message,
thinking={"type": "adaptive"},
output_config=output_config,
**tool_kwargs
)
message = stream_with_model_fallback(
lambda m: self.client.messages.stream(model=m, **stream_kwargs),
_consume,
preferred=self.config.model,
)

else:
logger.info("Making non-streaming API call")
create_kwargs = dict(
max_tokens=self.config.max_tokens, messages=prompt, model=self.config.model, system=system_message,
max_tokens=self.config.max_tokens, messages=prompt, system=system_message,
thinking={"type": "adaptive"},
output_config=output_config,
# Per-request timeout (same values as the SDK default):
Expand All @@ -290,7 +306,10 @@ def generate(
timeout=httpx.Timeout(600.0, connect=5.0),
**tool_kwargs
)
message = self.client.messages.create(**create_kwargs)
message = call_with_model_fallback(
lambda m: self.client.messages.create(model=m, **create_kwargs),
preferred=self.config.model,
)

if hasattr(message, "usage"):
if message.usage.cache_creation_input_tokens:
Expand Down Expand Up @@ -537,13 +556,16 @@ def try_error_correction(self, content: str, error_message: str, old_code: str,
# structured outputs removed here too (see note in generate); the
# correction prompt already instructs the {explanation, corrected_*}
# JSON shape and json.loads below is wrapped in try/except.
message = self.client.messages.create(
max_tokens=16384,
messages=prompt,
model=self.config.model,
system=system_message,
output_config={"effort": "medium"},
thinking={"type": "adaptive"}
message = call_with_model_fallback(
lambda m: self.client.messages.create(
max_tokens=16384,
messages=prompt,
model=m,
system=system_message,
output_config={"effort": "medium"},
thinking={"type": "adaptive"}
),
preferred=self.config.model,
)

response = "\n\n".join([block.text for block in message.content if block.type == "text"])
Expand Down
Loading
Loading