77from typing import Optional
88
99import httpx
10- from rich .progress import Console # type: ignore
10+ from rich .progress import Console
1111
1212from mallm .agents .agent import Agent
1313from mallm .agents .draftProposer import DraftProposer
14+ from mallm .agents .judge import Judge
1415from mallm .agents .panelist import Panelist
15- from mallm .agents .policyFeedback import PolicyFeedback
1616from mallm .decision_protocol .protocol import DecisionProtocol
1717from mallm .discourse_policy .policy import DiscoursePolicy
1818from mallm .models .Chat import Chat
@@ -50,9 +50,9 @@ def __init__(
5050 model : Chat ,
5151 client : httpx .Client ,
5252 agent_generators : Optional [list [str ]] = None ,
53- policy : Optional [str ] = None ,
5453 num_neutral_agents : int = 0 ,
5554 console : Optional [Console ] = None ,
55+ judge_model : Optional [Chat ] = None ,
5656 ):
5757 if agent_generators is None :
5858 agent_generators = ["expert" , "expert" , "expert" ]
@@ -68,9 +68,9 @@ def __init__(
6868 self .response_generator : ResponseGenerator = SimpleResponseGenerator (self .llm )
6969 self .client = client
7070 self .agent_generators = agent_generators
71- self .policy = policy
7271 self .memory : list [Memory ] = []
7372 self .console = console or Console ()
73+ self .judge_llm = judge_model
7474
7575 def init_agents (
7676 self ,
@@ -80,14 +80,16 @@ def init_agents(
8080 num_agents : int ,
8181 chain_of_thought : bool ,
8282 sample : InputExample ,
83+ judge_intervention : Optional [str ] = None ,
84+ judge_metric : Optional [str ] = None ,
8385 ) -> None :
8486 """
8587 Instantiates the agents by
8688 1) identify helpful personas depending on the agent_generator
8789 2) create agents with the personas
8890 """
8991 logger .debug (
90- f"Coordinator { self .id } creates { num_agents } agents ({ self .agent_generators } ). Policy: { self . policy } "
92+ f"Coordinator { self .id } creates { num_agents } agents ({ self .agent_generators } )."
9193 )
9294 self .panelists = []
9395 self .agents = []
@@ -142,16 +144,22 @@ def init_agents(
142144 "Created only 1 agent. The discussion will be replaced by a self-improvement mechanism."
143145 )
144146
145- if self .policy :
146- policyFeedback = PolicyFeedback (
147- self .llm ,
147+ self .judge = None
148+ if judge_intervention and self .judge_llm :
149+ self .judge = Judge (
150+ self .judge_llm ,
148151 self .client ,
149152 self ,
150153 response_generator = self .response_generator ,
151- persona = "Policy Moderator" ,
152- policy = self .policy ,
154+ persona = "Judge" ,
155+ persona_description = "Responsible for evaluating the solutions and providing feedback to the agents." ,
156+ metric = str (judge_metric ),
157+ chain_of_thought = False ,
158+ drafting_agent = False ,
159+ intervention_type = judge_intervention ,
160+ references = sample .references ,
153161 )
154- self .agents .append (policyFeedback )
162+ self .agents .append (self . judge )
155163
156164 def get_agents (
157165 self , config : Config , worker_functions : WorkerFunctions
@@ -202,6 +210,8 @@ def discuss(
202210 bool ,
203211 dict [int , Optional [VotingResultList ]],
204212 ChallengeResult ,
213+ Optional [list [Optional [bool ]]],
214+ Optional [list [str ]],
205215 ]:
206216 """
207217 The routine responsible for the discussion between agents to solve a task.
@@ -241,6 +251,8 @@ def discuss(
241251 num_agents = config .num_agents ,
242252 chain_of_thought = config .use_chain_of_thought ,
243253 sample = sample ,
254+ judge_intervention = config .judge_intervention ,
255+ judge_metric = config .judge_metric ,
244256 )
245257
246258 if config .decision_protocol not in DECISION_PROTOCOLS :
@@ -349,6 +361,8 @@ def discuss(
349361 decision_success ,
350362 voting_results_per_turn ,
351363 challenged_answers ,
364+ self .judge .judgements if self .judge else None ,
365+ self .judge .judged_solutions if self .judge else None ,
352366 )
353367
354368 def challenge_solution (
@@ -432,6 +446,10 @@ def get_memories(
432446
433447 return context_memory , memory_ids , current_draft
434448
449+ def forget_memories (self , turn : int ) -> None :
450+ self .memory = [memory for memory in self .memory if memory .turn != turn ]
451+ logger .debug (f"Memories from turn { turn } have been removed from global memory." )
452+
435453 def get_discussion_history (
436454 self ,
437455 context_length : Optional [int ] = None ,
0 commit comments