diff --git a/verifiers/v1/trace.py b/verifiers/v1/trace.py index 9944483a9..1e3d44dbb 100644 --- a/verifiers/v1/trace.py +++ b/verifiers/v1/trace.py @@ -227,6 +227,8 @@ class Trace(StrictBaseModel, Generic[TaskT, StateT]): _head_index: dict = PrivateAttr(default_factory=dict) """`(parent, msg_hash) -> node_id` for the graph builder (`graph.prepare_turn` / `commit`); rebuilt lazily from `nodes` after deserialization.""" + _num_turns_cache: tuple[int, int] = PrivateAttr(default=(0, 0)) + """`(counted nodes, sampled turns)` for the append-only message graph.""" @property def reward(self) -> float: @@ -300,7 +302,19 @@ def num_branches(self) -> int: def num_turns(self) -> int: """Total model turns (sampled responses) across all branches — prompt-supplied assistant messages don't count.""" - return sum(1 for n in self.nodes if n.sampled) + counted_nodes, num_turns = self._num_turns_cache + node_count = len(self.nodes) + # Graph commits append nodes, so count only the unseen suffix between reads. + # If a caller shrinks the list, rebuild the count from the remaining nodes. + if node_count == counted_nodes: + return num_turns + if node_count < counted_nodes: + counted_nodes, num_turns = 0, 0 + while counted_nodes < node_count: + num_turns += self.nodes[counted_nodes].sampled + counted_nodes += 1 + self._num_turns_cache = (node_count, num_turns) + return num_turns @property def is_truncated(self) -> bool: