Skip to content

Commit 8c4838b

Browse files
authored
Merge pull request #144 from Multi-Agent-LLMs/feat/challenge-results
Feat/challenge results
2 parents 12d6ae1 + 15d072e commit 8c4838b

5 files changed

Lines changed: 281 additions & 25 deletions

File tree

mallm/coordinator.py

Lines changed: 86 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
)
2828
from mallm.utils.types import (
2929
Agreement,
30+
ChallengeResult,
3031
InputExample,
3132
Memory,
3233
VotingResultList,
@@ -200,7 +201,7 @@ def discuss(
200201
float,
201202
bool,
202203
dict[int, Optional[VotingResultList]],
203-
dict[str, Optional[str]],
204+
ChallengeResult,
204205
]:
205206
"""
206207
The routine responsible for the discussion between agents to solve a task.
@@ -285,26 +286,50 @@ def discuss(
285286
)
286287
)
287288

288-
challenged_answers: dict[str, Optional[str]] = {}
289+
challenged_answers: ChallengeResult = ChallengeResult(
290+
answer or "No answer was provided."
291+
)
289292
if config.challenge_final_results:
290293
logger.info("Challenging final results...")
291-
for panelist in self.panelists:
292-
challenge_result = panelist.llm.invoke(
293-
panelist.response_generator.generate_challenge_prompt(
294-
panelist,
295-
input_str,
296-
sample_instruction,
297-
(answer or "No answer was provided."),
298-
)
294+
challenged_answers.additional_information = (
295+
worker_functions.worker_context_function(input_str)
296+
)
297+
challenged_answers.wrong_answer = self.llm.invoke(
298+
self.response_generator.generate_wrong_answer_prompt(
299+
sample_instruction, input_str
299300
)
300-
if "agree" in challenge_result.lower():
301-
logger.info(f"{panelist.persona} agrees with the final result.")
302-
challenged_answers[panelist.id] = None
303-
else:
304-
logger.info(
305-
f"{panelist.persona} disagrees with the final result and proposes a new solution:\n{challenge_result}"
306-
)
307-
challenged_answers[panelist.id] = challenge_result
301+
)
302+
challenged_answers.irrelevant_answer = "I) I don't know."
303+
304+
challenged_answers.challenged_answers = self.challenge_solution(
305+
answer, input_str, sample_instruction, None, False
306+
)
307+
challenged_answers.challenged_answers_wrong = self.challenge_solution(
308+
challenged_answers.wrong_answer,
309+
input_str,
310+
sample_instruction,
311+
None,
312+
False,
313+
)
314+
challenged_answers.challenged_answers_irrelevant = self.challenge_solution(
315+
challenged_answers.irrelevant_answer,
316+
input_str,
317+
sample_instruction,
318+
None,
319+
False,
320+
)
321+
challenged_answers.challenged_answers_history = self.challenge_solution(
322+
answer, input_str, sample_instruction, None, True
323+
)
324+
challenged_answers.challenged_answers_additional_information = (
325+
self.challenge_solution(
326+
answer,
327+
input_str,
328+
sample_instruction,
329+
challenged_answers.additional_information,
330+
False,
331+
)
332+
)
308333

309334
discussion_time = timedelta(
310335
seconds=time.perf_counter() - start_time
@@ -326,6 +351,49 @@ def discuss(
326351
challenged_answers,
327352
)
328353

354+
def challenge_solution(
355+
self,
356+
answer: Optional[str],
357+
input_str: str,
358+
sample_instruction: str,
359+
additional_information: Optional[str],
360+
history: bool,
361+
) -> dict[str, Optional[str]]:
362+
challenged_answers: dict[str, Optional[str]] = {}
363+
for panelist in self.panelists:
364+
agreement = panelist.llm.invoke(
365+
panelist.response_generator.generate_challenge_prompt(
366+
panelist,
367+
input_str,
368+
sample_instruction,
369+
(answer or "No answer was provided."),
370+
history,
371+
additional_information,
372+
)
373+
)
374+
if "disagree" in agreement.lower():
375+
challenge_result = panelist.llm.invoke(
376+
panelist.response_generator.generate_challenge_new_answer_prompt(
377+
panelist,
378+
input_str,
379+
sample_instruction,
380+
(answer or "No answer was provided."),
381+
history,
382+
additional_information,
383+
)
384+
)
385+
logger.info(
386+
f"{panelist.persona} disagrees with the final result and proposes a new solution:\n{challenge_result}"
387+
)
388+
challenged_answers[panelist.id] = challenge_result
389+
elif "agree" in agreement.lower():
390+
logger.info(f"{panelist.persona} agrees with the final result.")
391+
challenged_answers[panelist.id] = None
392+
else:
393+
logger.info(f"{panelist.persona} failed to challenge the final result.")
394+
challenged_answers[panelist.id] = None
395+
return challenged_answers
396+
329397
def get_memories(
330398
self,
331399
context_length: Optional[int] = None,

mallm/evaluation/evaluator.py

Lines changed: 88 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,11 +114,12 @@ def calculate_scores(
114114

115115
def add_scores(self) -> None:
116116
for item in tqdm(self.data, desc=f"Calculating scores of {self.input_file_path}: "):
117-
answer = item.get("finalAnswer", "")
117+
main_answer = item.get("finalAnswer", "")
118118
references = item.get("references", [])
119119
dataset_id = item.get("datasetId", None)
120-
if answer:
121-
item["scores"] = self.calculate_scores(answer, references, "", dataset_id)
120+
if main_answer:
121+
item["scores"] = self.calculate_scores(main_answer, references, "", dataset_id)
122+
122123
votes_each_turn = item.get("votesEachTurn", None)
123124
if votes_each_turn:
124125
alterations: dict[str, Any] = votes_each_turn[
@@ -136,6 +137,90 @@ def add_scores(self) -> None:
136137
self.calculate_scores(answer, references, alteration)
137138
)
138139

140+
challenged_answers: Any = item.get("challengedAnswers", None)
141+
if challenged_answers:
142+
if "scores" not in item:
143+
continue
144+
if "correct" not in item["scores"] and "f1" not in item["scores"]:
145+
continue
146+
if challenged_answers["challenged_answers"]:
147+
self.analyze_challenged_answers(
148+
"normal",
149+
challenged_answers["challenged_answers"],
150+
item,
151+
references,
152+
item["scores"],
153+
)
154+
if challenged_answers["challenged_answers_wrong"]:
155+
self.analyze_challenged_answers(
156+
"wrong",
157+
challenged_answers["challenged_answers_wrong"],
158+
item,
159+
references,
160+
self.calculate_scores(
161+
challenged_answers["wrong_answer"], references
162+
),
163+
)
164+
if challenged_answers["challenged_answers_irrelevant"]:
165+
self.analyze_challenged_answers(
166+
"irrelevant",
167+
challenged_answers["challenged_answers_irrelevant"],
168+
item,
169+
references,
170+
self.calculate_scores(
171+
challenged_answers["irrelevant_answer"], references
172+
),
173+
)
174+
if challenged_answers["challenged_answers_history"]:
175+
self.analyze_challenged_answers(
176+
"history",
177+
challenged_answers["challenged_answers_history"],
178+
item,
179+
references,
180+
item["scores"],
181+
)
182+
if challenged_answers["challenged_answers_additional_information"]:
183+
self.analyze_challenged_answers(
184+
"information",
185+
challenged_answers["challenged_answers_additional_information"],
186+
item,
187+
references,
188+
item["scores"],
189+
)
190+
191+
def analyze_challenged_answers(
192+
self,
193+
name: str,
194+
challenged_answers: dict[str, Optional[str]],
195+
item: Any,
196+
references: list[str],
197+
previous_score: Any,
198+
) -> None:
199+
new_answer = {
200+
f"{name}_no_challenge": True,
201+
f"{name}_challenge_failed": False,
202+
f"{name}_challenge_higher": False,
203+
f"{name}_challenge_lower": False,
204+
f"{name}_challenge_same": False,
205+
}
206+
previous_score = previous_score.get("f1", None) or previous_score.get(
207+
"correct", None
208+
)
209+
answer = next(iter(challenged_answers.values()))
210+
if answer:
211+
score = self.calculate_scores(answer, references)
212+
current_score = score.get("f1", None) or score.get("correct", None)
213+
if current_score is None or previous_score is None:
214+
new_answer[f"{name}_challenge_failed"] = True
215+
elif current_score > previous_score:
216+
new_answer[f"{name}_challenge_higher"] = True
217+
elif current_score < previous_score:
218+
new_answer[f"{name}_challenge_lower"] = True
219+
elif current_score == previous_score:
220+
new_answer[f"{name}_challenge_same"] = True
221+
new_answer[f"{name}_no_challenge"] = False
222+
item["scores"].update(new_answer)
223+
139224
def add_scores_extensive(self) -> None:
140225
for item in tqdm(self.data, desc="Extensive scores: "):
141226
references = item.get("references", [])

mallm/models/discussion/ResponseGenerator.py

Lines changed: 93 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -253,21 +253,85 @@ def generate_challenge_prompt(
253253
question: str,
254254
task: str,
255255
final_answer: str,
256+
history: bool = False,
257+
facts: Optional[str] = None,
256258
) -> list[dict[str, str]]:
257-
return [
259+
agent_history = panelist.get_discussion_history()[0] if history else []
260+
prompts = [
258261
{
259262
"role": "system",
260263
"content": f"You are a participant in a group discussion. Your role: {panelist.persona} ({panelist.persona_description})",
261-
},
264+
}
265+
]
266+
if history and agent_history:
267+
prompts.append(
268+
{
269+
"role": "system",
270+
"content": "This is the discussion to the current point:",
271+
}
272+
)
273+
prompts.extend(agent_history)
274+
if facts:
275+
prompts.append(
276+
{
277+
"role": "system",
278+
"content": f"Here is some helpful additional information to improve your answer quality: {facts}",
279+
}
280+
)
281+
prompts.append(
262282
{
263283
"role": "user",
264284
"content": (
265285
f"The task is: {task}. The question is: {question}. "
266286
f"This is the final answer generated by the discussion: '{final_answer}'. "
267-
"Please critically evaluate this answer. If you do not agree, provide a new solution based on the task and question. If you agree with the final answer, respond with the exact word 'AGREE' to confirm."
287+
"Please critically evaluate this answer. If you agree with the final answer, respond with the exact word 'AGREE' to confirm. If you do not agree, respond with the exact word 'DISAGREE' to challenge the answer."
268288
),
269289
},
290+
)
291+
return prompts
292+
293+
@staticmethod
294+
def generate_challenge_new_answer_prompt(
295+
panelist: Panelist,
296+
question: str,
297+
task: str,
298+
final_answer: str,
299+
history: bool = False,
300+
facts: Optional[str] = None,
301+
) -> list[dict[str, str]]:
302+
agent_history = panelist.get_discussion_history()[0] if history else []
303+
prompts = [
304+
{
305+
"role": "system",
306+
"content": f"You are a participant in a group discussion. Your role: {panelist.persona} ({panelist.persona_description})",
307+
}
270308
]
309+
if history and agent_history:
310+
prompts.append(
311+
{
312+
"role": "system",
313+
"content": "This is the discussion to the current point:",
314+
}
315+
)
316+
prompts.extend(agent_history)
317+
if facts:
318+
prompts.append(
319+
{
320+
"role": "system",
321+
"content": f"Here is some helpful additional information to improve your answer quality: {facts}",
322+
}
323+
)
324+
prompts.append(
325+
{
326+
"role": "user",
327+
"content": (
328+
f"The task is: {task}. The question is: {question}. "
329+
f"This is the final answer generated by the discussion: '{final_answer}'. "
330+
"You dont agree with the final answer. Please provide a new answer to the question. Include the letter corresponding to your answer in the solution."
331+
),
332+
},
333+
)
334+
return prompts
271335

272336
@staticmethod
273337
def voting_base_prompt(
@@ -520,3 +584,29 @@ def generate_summary_prompt(
520584

521585
# Return the prompts list
522586
return prompts
587+
588+
@staticmethod
589+
def generate_wrong_answer_prompt(task: str, question: str) -> list[dict[str, str]]:
590+
return [
591+
{
592+
"role": "system",
593+
"content": "You are tasked with providing an incorrect or wrong response to the given task and question.",
594+
},
595+
{
596+
"role": "user",
597+
"content": f"Task: {task}\nQuestion: {question}. Please provide an answer that is deliberately incorrect or inaccurate. Only answer with the incorrect response.",
598+
},
599+
]
600+
601+
@staticmethod
602+
def generate_irrelevant_answer_prompt(question: str) -> list[dict[str, str]]:
603+
return [
604+
{
605+
"role": "system",
606+
"content": "You are tasked with providing a completely unrelated response to the given question.",
607+
},
608+
{
609+
"role": "user",
610+
"content": f"Question: {question} \n\nPlease provide an answer that is irrelevant to the question. Only answer with the irrelevant response.",
611+
},
612+
]

mallm/scheduler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ def run_discussion(
233233
for voting_round, voting_result in voting_results_per_turn.items()
234234
if voting_result is not None
235235
},
236-
"challengedAnswers": challenged_answers,
236+
"challengedAnswers": dataclasses.asdict(challenged_answers),
237237
"references": sample.references,
238238
"metadata": sample.metadata,
239239
"decisionSuccess": decision_success,

mallm/utils/types.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,19 @@ class VotingResultList:
6262
alterations: dict[str, VotingResult]
6363

6464

65+
@dataclass
66+
class ChallengeResult:
67+
answer: str
68+
additional_information: Optional[str] = None
69+
wrong_answer: Optional[str] = None
70+
irrelevant_answer: Optional[str] = None
71+
challenged_answers: Optional[dict[str, Optional[str]]] = None
72+
challenged_answers_history: Optional[dict[str, Optional[str]]] = None
73+
challenged_answers_wrong: Optional[dict[str, Optional[str]]] = None
74+
challenged_answers_irrelevant: Optional[dict[str, Optional[str]]] = None
75+
challenged_answers_additional_information: Optional[dict[str, Optional[str]]] = None
76+
77+
6578
@dataclass
6679
class InputExample:
6780
example_id: str

0 commit comments

Comments (0)