Skip to content

Commit 9198ffb

Browse files
committed
mypy fixes
1 parent d0a018c commit 9198ffb

4 files changed

Lines changed: 16 additions & 16 deletions

File tree

mallm/coordinator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ def init_agents(
146146
self.judge = None
147147
if judge_intervention and self.judge_llm:
148148
# Lazy import to avoid heavy evaluation dependencies when judge is not used
149-
from mallm.agents.judge import Judge # type: ignore # noqa: PLC0415
149+
from mallm.agents.judge import Judge # noqa: PLC0415
150150

151151
self.judge = Judge(
152152
self.judge_llm,

mallm/evaluation/evaluator.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44
from typing import Any, Optional
55

66
try:
7-
import fire # type: ignore
7+
import fire
88
except Exception: # pragma: no cover - optional dependency
9-
fire = None # type: ignore[assignment]
9+
fire = None
1010
import json_repair
1111
from tqdm import tqdm
1212

@@ -433,7 +433,7 @@ def run_evaluator(
433433

434434
def main() -> None:
435435
if fire is not None:
436-
fire.Fire(run_evaluator) # type: ignore[attr-defined]
436+
fire.Fire(run_evaluator)
437437
else:
438438
print("Fire is not available. Please call run_evaluator(...) programmatically.")
439439

mallm/evaluation/plotting/plots.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def plot_turns_with_std(df: pd.DataFrame, input_path: str, global_color_mapping:
123123
colors = [color_mapping[option] for option in grouped_data['option']]
124124

125125
# Expand the grouped data back to individual rows for violin plot
126-
expanded_data = []
126+
expanded_data: list[dict[str, Any]] = []
127127
for _i, row in grouped_data.iterrows():
128128
expanded_data.extend({
129129
'option': row['option'],

mallm/scheduler.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,12 @@
1212
from multiprocessing.pool import ThreadPool
1313
from pathlib import Path
1414
from threading import Lock
15-
from typing import Any, Optional
15+
from typing import Any, Optional, cast
1616

1717
try:
18-
import fire # type: ignore
18+
import fire
1919
except Exception: # pragma: no cover - optional dependency
20-
fire = None # type: ignore[assignment]
20+
fire = None
2121
import httpx
2222
import langchain
2323
import langchain_core
@@ -29,7 +29,7 @@
2929
try:
3030
from datasets import load_dataset
3131
except Exception: # pragma: no cover - optional for file-only runs
32-
load_dataset = None # type: ignore[assignment]
32+
load_dataset = None
3333
from openai import OpenAI
3434
from rich import print # noqa: A004
3535
from rich.logging import RichHandler
@@ -39,9 +39,9 @@
3939
from sentence_transformers import SentenceTransformer
4040
from sentence_transformers.util import cos_sim
4141
except Exception: # pragma: no cover - fallback for lightweight/mock runs
42-
SentenceTransformer = None # type: ignore[assignment]
42+
SentenceTransformer = None
4343

44-
def cos_sim(a: Any, b: Any) -> Any: # type: ignore[no-redef]
44+
def cos_sim(a: Any, b: Any) -> Any:
4545
return [[1.0]]
4646

4747
try:
@@ -385,7 +385,7 @@ def worker_paraphrase_function(
385385
# Acquire the lock before using the model
386386
if paraphrase_model is not None:
387387
with paraphrase_lock:
388-
return paraphrase_model.encode(input_data) # type: ignore[attr-defined]
388+
return cast(list[Any], paraphrase_model.encode(input_data))
389389
# Fallback lightweight deterministic embeddings
390390
return [[1.0, 0.0] for _ in input_data]
391391

@@ -402,14 +402,14 @@ def worker_persona_diversity_function(
402402
# Acquire the lock before using the model
403403
if all_model is not None and SentenceTransformer is not None:
404404
with persona_diversity_lock:
405-
embeddings = all_model.encode(input_data, convert_to_tensor=True) # type: ignore[attr-defined]
405+
embeddings = all_model.encode(input_data, convert_to_tensor=True)
406406
cos_sims = cos_sim(embeddings, embeddings)
407-
similarities = [
407+
similarities: list[float] = [
408408
cos_sims[i][j].item()
409409
for i in range(len(input_data))
410410
for j in range(i)
411411
]
412-
persona_diversity = sum(similarities) / len(similarities)
412+
persona_diversity: float = float(sum(similarities) / len(similarities)) if similarities else 0.0
413413
return round(persona_diversity, 4)
414414
# Fallback: neutral value
415415
return 0.0
@@ -720,7 +720,7 @@ def main() -> None:
720720
if fire is None:
721721
print("Fire is not available. Please run via batch_mallm.py or provide a Config programmatically.")
722722
return
723-
config = fire.Fire(Config, serialize=print) # type: ignore[attr-defined]
723+
config = fire.Fire(Config, serialize=print)
724724
print("\n" + "=" * width)
725725
print("END OF CONFIGURATION PARAMETERS".center(width))
726726
print("=" * width + "\n")

0 commit comments

Comments (0)