Skip to content

Commit 47ec162

Browse files
authored
Merge pull request #148 from Multi-Agent-LLMs/judge
Judge Agent
2 parents ebe07a9 + fbe666e commit 47ec162

25 files changed

Lines changed: 366 additions & 506 deletions

README.md

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,14 +120,18 @@ use_ablation: bool = False
120120
shuffle_input_samples: bool = False
121121
all_agents_generate_first_draft: bool = False
122122
all_agents_generate_draft: bool = False
123-
policy: Optional[str] = None
124123
voting_protocols_with_alterations: bool = False
125124
calculate_persona_diversity: bool = False
126125
challenge_final_results: bool = False
126+
judge_intervention: Optional[str] = None
127+
judge_metric: Optional[str] = None
128+
judge_endpoint_url: Optional[str] = None
129+
judge_api_key: str = "-"
130+
judge_always_intervene: bool = False
127131
```
128132

129133
### Discussion Parameters:
130-
Response Generators: `freetext`, `json`, `simple`, `splitfreetext`
134+
Response Generators: `freetext`, `simple`, `splitfreetext`
131135

132136
Decision Protocols: `approval_voting`, `consensus_voting`, `cumulative_voting`, `hybrid_consensus`, `majority_consensus`, `ranked_voting`, `simple_voting`, `summary`, `supermajority_consensus`, `unanimity_consensus`
133137

mallm/agents/agent.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,12 @@ def get_memories(
232232

233233
return context_memory, memory_ids, current_draft
234234

235+
def forget_memories(self, turn: int) -> None:
    """Drop every memory entry this agent recorded during the given turn."""
    # Collect keys first: the dict must not be mutated while iterating it.
    keys_to_delete = []
    for key, memory in self.memory.items():
        if memory.turn == turn:
            keys_to_delete.append(key)
    for key in keys_to_delete:
        self.memory.pop(key)
    logger.debug(f"Forgot memories {keys_to_delete} from turn {turn} from agent {self.id}")
240+
235241
def get_own_messages(self, context_length: Optional[int] = None) -> list[str]:
236242
"""
237243
Retrieves memory from the agents memory bucket as a string

mallm/agents/judge.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
from __future__ import annotations
2+
3+
import logging
4+
from typing import TYPE_CHECKING, Optional
5+
6+
import httpx
7+
8+
from mallm.models.Chat import Chat
9+
from mallm.models.discussion.ResponseGenerator import ResponseGenerator
10+
11+
if TYPE_CHECKING:
12+
from mallm.coordinator import Coordinator
13+
14+
from mallm.agents.agent import Agent
15+
from mallm.evaluation.evaluator import Evaluator
16+
from mallm.utils.types import Memory, TemplateFilling
17+
18+
logger = logging.getLogger("mallm")
19+
20+
21+
class Judge(Agent):
    """
    Agent that monitors the quality of the discussion between turns and may
    intervene in it.

    Two judging modes, selected in `intervention` by whether the coordinator
    carries a judge LLM:
    - LLM-as-a-judge: compares the two most recent solutions through the
      response generator's judgement prompt (`llm_as_a_judge`).
    - Metric-based: scores each solution against `references` with the
      configured evaluation metric and compares consecutive scores.
    """

    def __init__(
        self,
        llm: Chat,
        client: httpx.Client,
        coordinator: Coordinator,
        response_generator: ResponseGenerator,
        persona: str,
        persona_description: str,
        metric: str,
        chain_of_thought: bool = False,
        drafting_agent: bool = False,
        intervention_type: str = "regenerate",
        references: Optional[list[str]] = None,
    ):
        """
        Args:
            metric: Name of the evaluation metric, resolved through
                `Evaluator._initialize_metrics` (NOTE(review): a private
                helper of Evaluator — consider a public accessor).
            intervention_type: "regenerate" (redo the turn) or "policy"
                (inject feedback into the discussion memory).
            references: Gold references used for metric-based judging;
                defaults to an empty list (never a shared mutable default).
        """
        if references is None:
            references = []
        super().__init__(
            llm,
            client,
            coordinator,
            response_generator,
            persona,
            persona_description,
            chain_of_thought,
            drafting_agent,
        )
        # First (and only) metric from the initializer list.
        self.metric = Evaluator._initialize_metrics([metric])[0]
        # Per-turn verdicts; None marks a failed/invalid LLM verdict.
        self.judgements: list[Optional[bool]] = []
        # Metric scores per judged solution (metric-based mode only).
        self.performances: list[float] = []
        # Every solution handed to `intervention`, in order.
        self.judged_solutions: list[str] = []
        self.intervention_type = intervention_type
        self.coordinator = coordinator
        self.references = references

    def llm_as_a_judge(self, template_filling: TemplateFilling) -> Optional[bool]:
        """
        Ask the judge LLM to compare the two most recent judged solutions.

        Returns True for a "[[A]]" verdict (earlier answer better), False for
        "[[B]]" (later answer better), or None when no valid verdict was
        produced after 3 attempts.

        Precondition: len(self.judged_solutions) >= 2 — enforced by the
        caller (`intervention`).
        """
        repeats = 0
        while repeats < 3:
            # check for drift
            response = self.response_generator.generate_judgement(
                template_filling, self.judged_solutions[-2], self.judged_solutions[-1]
            )
            if "[[A]]" in response.message:
                return True  # answer_before is better
            if "[[B]]" in response.message:
                return False  # answer_after is better (problem drift)
            logger.warning(f"Judge verdict is not valid: {response.message}. Retry number {repeats + 1}.")
            repeats += 1
        # `response` is always bound here: the loop body runs at least once.
        logger.warning(f"Judge verdict is not valid: {response.message}. All retries failed. The verdict will be saved as None.")
        return None

    def intervention(self,
        unique_id: int,
        turn: int,
        memory_ids: list[int],
        template_filling: TemplateFilling,
        answer: str,
        threshold: float = 0,
        always_intervene: bool = False,
    ) -> tuple[int, int]:
        """
        Record a judgement for `answer` and possibly intervene in the turn.

        Returns a (unique_id, turn) pair, adjusted when an intervention
        changes the message numbering or rewinds the discussion turn;
        unchanged otherwise.
        """
        self.judged_solutions.append(answer)

        if self.coordinator.judge_llm is not None:
            # LLM-as-a-judge mode needs two solutions to compare.
            if len(self.judged_solutions) < 2:
                logger.debug("Judge skipped this turn because there are not enough solutions to judge.")
                return unique_id, turn
            on_track = self.llm_as_a_judge(template_filling)
        else:
            # Metric-based mode: score the new answer against the references.
            self.performances.append(Evaluator.calculate_score(answer, self.references, self.metric)["value"])
            # NOTE(review): this is True when the newest score fell more than
            # `threshold` below the previous one, i.e. True marks a DROP.
            # Combined with the `on_track is False` gate below, the judge then
            # intervenes when no drop was detected — the polarity of
            # "on_track" looks inverted in both branches; confirm against the
            # judgement prompt and the intended behavior.
            on_track = len(self.performances) > 1 and self.performances[-1] + threshold < self.performances[-2]
        self.judgements.append(on_track)

        # `is False` deliberately excludes None (no valid LLM verdict):
        # a failed judgement never triggers an intervention on its own.
        if on_track is False or always_intervene:  # regenerates at most once per turn
            if self.intervention_type == "regenerate":
                # delete and restart the turn
                logger.debug("Judge decided to regenerate the turn.")
                self.coordinator.forget_memories(turn)
                # Rewind message ids past this turn's agent contributions and
                # step the turn counter back so the turn is replayed.
                # NOTE(review): the `+ 1` presumably keeps the judge's own id
                # slot — verify against the coordinator's id accounting.
                return unique_id - len(self.coordinator.agents) + 1, turn - 1
            if self.intervention_type == "policy":
                # Give the agents tips on how to improve their policy
                logger.debug("Judge decided to give policy feedback.")
                response = self.response_generator.generate_policy_intervention(
                    template_filling,
                    provide_labels=False
                )
                # Persist the feedback as a regular discussion memory so all
                # agents see it in their context next turn.
                memory = Memory(
                    message_id=unique_id,
                    turn=turn,
                    agent_id=self.id,
                    persona=self.persona,
                    contribution="judge",
                    message=response.message,
                    agreement=None,
                    solution=None,
                    memory_ids=memory_ids,
                    additional_args={},
                )
                self.coordinator.update_memories([memory], self.coordinator.agents)
                self.coordinator.memory.append(memory)
                return unique_id + 1, turn
        return unique_id, turn

mallm/agents/policyFeedback.py

Lines changed: 0 additions & 90 deletions
This file was deleted.

mallm/coordinator.py

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@
77
from typing import Optional
88

99
import httpx
10-
from rich.progress import Console # type: ignore
10+
from rich.progress import Console
1111

1212
from mallm.agents.agent import Agent
1313
from mallm.agents.draftProposer import DraftProposer
14+
from mallm.agents.judge import Judge
1415
from mallm.agents.panelist import Panelist
15-
from mallm.agents.policyFeedback import PolicyFeedback
1616
from mallm.decision_protocol.protocol import DecisionProtocol
1717
from mallm.discourse_policy.policy import DiscoursePolicy
1818
from mallm.models.Chat import Chat
@@ -50,9 +50,9 @@ def __init__(
5050
model: Chat,
5151
client: httpx.Client,
5252
agent_generators: Optional[list[str]] = None,
53-
policy: Optional[str] = None,
5453
num_neutral_agents: int = 0,
5554
console: Optional[Console] = None,
55+
judge_model: Optional[Chat] = None,
5656
):
5757
if agent_generators is None:
5858
agent_generators = ["expert", "expert", "expert"]
@@ -68,9 +68,9 @@ def __init__(
6868
self.response_generator: ResponseGenerator = SimpleResponseGenerator(self.llm)
6969
self.client = client
7070
self.agent_generators = agent_generators
71-
self.policy = policy
7271
self.memory: list[Memory] = []
7372
self.console = console or Console()
73+
self.judge_llm = judge_model
7474

7575
def init_agents(
7676
self,
@@ -80,14 +80,16 @@ def init_agents(
8080
num_agents: int,
8181
chain_of_thought: bool,
8282
sample: InputExample,
83+
judge_intervention: Optional[str] = None,
84+
judge_metric: Optional[str] = None,
8385
) -> None:
8486
"""
8587
Instantiates the agents by
8688
1) identify helpful personas depending on the agent_generator
8789
2) create agents with the personas
8890
"""
8991
logger.debug(
90-
f"Coordinator {self.id} creates {num_agents} agents ({self.agent_generators}). Policy: {self.policy}"
92+
f"Coordinator {self.id} creates {num_agents} agents ({self.agent_generators})."
9193
)
9294
self.panelists = []
9395
self.agents = []
@@ -142,16 +144,22 @@ def init_agents(
142144
"Created only 1 agent. The discussion will be replaced by a self-improvement mechanism."
143145
)
144146

145-
if self.policy:
146-
policyFeedback = PolicyFeedback(
147-
self.llm,
147+
self.judge = None
148+
if judge_intervention and self.judge_llm:
149+
self.judge = Judge(
150+
self.judge_llm,
148151
self.client,
149152
self,
150153
response_generator=self.response_generator,
151-
persona="Policy Moderator",
152-
policy=self.policy,
154+
persona="Judge",
155+
persona_description="Responsible for evaluating the solutions and providing feedback to the agents.",
156+
metric=str(judge_metric),
157+
chain_of_thought=False,
158+
drafting_agent=False,
159+
intervention_type=judge_intervention,
160+
references=sample.references,
153161
)
154-
self.agents.append(policyFeedback)
162+
self.agents.append(self.judge)
155163

156164
def get_agents(
157165
self, config: Config, worker_functions: WorkerFunctions
@@ -202,6 +210,8 @@ def discuss(
202210
bool,
203211
dict[int, Optional[VotingResultList]],
204212
ChallengeResult,
213+
Optional[list[Optional[bool]]],
214+
Optional[list[str]],
205215
]:
206216
"""
207217
The routine responsible for the discussion between agents to solve a task.
@@ -241,6 +251,8 @@ def discuss(
241251
num_agents=config.num_agents,
242252
chain_of_thought=config.use_chain_of_thought,
243253
sample=sample,
254+
judge_intervention=config.judge_intervention,
255+
judge_metric=config.judge_metric,
244256
)
245257

246258
if config.decision_protocol not in DECISION_PROTOCOLS:
@@ -349,6 +361,8 @@ def discuss(
349361
decision_success,
350362
voting_results_per_turn,
351363
challenged_answers,
364+
self.judge.judgements if self.judge else None,
365+
self.judge.judged_solutions if self.judge else None,
352366
)
353367

354368
def challenge_solution(
@@ -432,6 +446,10 @@ def get_memories(
432446

433447
return context_memory, memory_ids, current_draft
434448

449+
def forget_memories(self, turn: int) -> None:
    """Purge the coordinator's global memory of everything from the given turn."""
    kept = []
    for memory in self.memory:
        if memory.turn != turn:
            kept.append(memory)
    self.memory = kept
    logger.debug(f"Memories from turn {turn} have been removed from global memory.")
452+
435453
def get_discussion_history(
436454
self,
437455
context_length: Optional[int] = None,

mallm/decision_protocol/consensus_voting.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def make_decision(
4242
) -> tuple[str, bool, list[Agreement], str, Optional[VotingResultList]]:
4343
if len(agreements) > self.total_agents:
4444
agreements = agreements[-self.total_agents :]
45+
4546
if agent_index != self.total_agents - 1:
4647
return "", False, agreements, "", None
4748

@@ -166,6 +167,10 @@ def process_votes(
166167
) -> tuple[str, Any, bool, str]:
167168
success = False
168169
vote_int = int("".join([x for x in vote_str if x.isnumeric()]))
170+
171+
# if len(final_answers) == 1: # TODO: Add this in a future PR
172+
# vote_int = 0 # If there is only one answer, the agent must vote for it
173+
169174
if 0 <= vote_int < len(final_answers):
170175
vote.append(vote_int)
171176
logger.info(

0 commit comments

Comments
 (0)