Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 43 additions & 3 deletions verifiers/v1/interception/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,19 +87,59 @@ def reached(self, trace: Trace) -> str | None:
"""The name of the first limit `trace` has reached, or None if within all caps."""
if self.max_turns is not None and trace.num_turns >= self.max_turns:
return "max_turns"
if (
self.max_input_tokens is None
and self.max_output_tokens is None
and self.max_total_tokens is None
):
return None

nodes = trace.nodes
# Only canonical append order safely proves a single root-to-leaf path.
if all(
node.parent == (index - 1 if index else None)
for index, node in enumerate(nodes)
):
total_tokens = (
sum(len(node.token_ids) for node in nodes)
if self.max_input_tokens is not None
or self.max_total_tokens is not None
else 0
)
if self.max_input_tokens is not None:
last_completion = next(
(sum(node.mask) for node in reversed(nodes) if any(node.mask)), 0
)
if total_tokens - last_completion >= self.max_input_tokens:
return "max_input_tokens"
if (
self.max_output_tokens is not None
and sum(sum(node.mask) for node in nodes) >= self.max_output_tokens
):
return "max_output_tokens"
if (
self.max_total_tokens is not None
and total_tokens >= self.max_total_tokens
):
return "max_total_tokens"
return None

# Reuse one branch view so enabled token caps do not repeat the graph walk.
branches = trace.branches
if (
self.max_input_tokens is not None
and trace.prompt_len >= self.max_input_tokens
and sum(branch.prompt_len for branch in branches) >= self.max_input_tokens
):
return "max_input_tokens"
if (
self.max_output_tokens is not None
and trace.completion_len >= self.max_output_tokens
and sum(branch.completion_len for branch in branches)
>= self.max_output_tokens
):
return "max_output_tokens"
if (
self.max_total_tokens is not None
and trace.total_tokens >= self.max_total_tokens
and sum(branch.total_tokens for branch in branches) >= self.max_total_tokens
):
return "max_total_tokens"
return None
Expand Down
Loading