Skip to content

Commit 13d8088

Browse files
committed
fix: add types for plots
1 parent f6c8006 commit 13d8088

1 file changed

Lines changed: 27 additions & 23 deletions

File tree

mallm/evaluation/plotting/plots.py

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import json
33
import os
44
from pathlib import Path
5+
from typing import Optional, Any, Union
56

67
import matplotlib.pyplot as plt
78
import pandas as pd
@@ -16,16 +17,16 @@
1617
# Define a beautiful color palette
1718
COLORS = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4', '#FFEAA7', '#DDA0DD', '#98D8C8']
1819

19-
def get_colors(n_colors):
20+
def get_colors(n_colors: int) -> Union[list[str], np.ndarray[Any, Any]]:
2021
"""Generate enough colors for n_colors by cycling or using colormap"""
2122
if n_colors <= len(COLORS):
2223
return COLORS[:n_colors]
2324
else:
2425
# Use a colormap for more colors
25-
return plt.cm.Set3(np.linspace(0, 1, n_colors))
26+
return plt.cm.Set3(np.linspace(0, 1, n_colors)) # type: ignore
2627

2728

28-
def get_consistent_color_mapping(options):
29+
def get_consistent_color_mapping(options: list[str]) -> dict[str, Any]:
2930
"""Create consistent color mapping based on option names"""
3031
# Sort options to ensure consistent assignment
3132
unique_options = sorted(set(options))
@@ -93,7 +94,7 @@ def aggregate_data(
9394
return eval_df, stats_df
9495

9596

96-
def plot_turns_with_std(df: pd.DataFrame, input_path: str, global_color_mapping: dict = None) -> None:
97+
def plot_turns_with_std(df: pd.DataFrame, input_path: str, global_color_mapping: Optional[dict[str, Any]] = None) -> None:
9798
"""Create a beautiful violin plot for turns distribution"""
9899
# Filter out rows with missing or invalid turns data
99100
df = df.dropna(subset=['turns'])
@@ -114,7 +115,7 @@ def plot_turns_with_std(df: pd.DataFrame, input_path: str, global_color_mapping:
114115

115116
# Use global color mapping if provided, otherwise create local one
116117
if global_color_mapping is None:
117-
color_mapping = get_consistent_color_mapping(grouped_data['option'].unique())
118+
color_mapping = get_consistent_color_mapping(grouped_data['option'].unique().tolist())
118119
else:
119120
color_mapping = global_color_mapping
120121

@@ -175,7 +176,7 @@ def plot_turns_with_std(df: pd.DataFrame, input_path: str, global_color_mapping:
175176
plt.close()
176177

177178

178-
def plot_clock_seconds_with_std(df: pd.DataFrame, input_path: str, global_color_mapping: dict = None) -> None:
179+
def plot_clock_seconds_with_std(df: pd.DataFrame, input_path: str, global_color_mapping: Optional[dict[str, Any]] = None) -> None:
179180
"""Create a beautiful horizontal lollipop chart for clock seconds"""
180181
grouped = (
181182
df.groupby(["option", "dataset"])["clockSeconds"]
@@ -187,7 +188,7 @@ def plot_clock_seconds_with_std(df: pd.DataFrame, input_path: str, global_color_
187188
grouped['label'] = unique_labels
188189

189190
# Sort data: baselines first, then others by shortest time
190-
def sort_key(row):
191+
def sort_key(row: pd.Series) -> tuple[int, float]:
191192
option = row['option'].lower()
192193
if option.startswith('baseline'):
193194
return (0, row['mean']) # Baselines first, sorted by time
@@ -206,7 +207,7 @@ def sort_key(row):
206207

207208
# Use global color mapping if provided, otherwise create local one
208209
if global_color_mapping is None:
209-
color_mapping = get_consistent_color_mapping(grouped['option'].unique())
210+
color_mapping = get_consistent_color_mapping(grouped['option'].unique().tolist())
210211
else:
211212
color_mapping = global_color_mapping
212213
colors = [color_mapping[option] for option in grouped['option']]
@@ -252,7 +253,7 @@ def sort_key(row):
252253
plt.close()
253254

254255

255-
def plot_decision_success_with_std(df: pd.DataFrame, input_path: str, global_color_mapping: dict = None) -> None:
256+
def plot_decision_success_with_std(df: pd.DataFrame, input_path: str, global_color_mapping: Optional[dict[str, Any]] = None) -> None:
256257
"""Create a beautiful horizontal bar chart for decision success rates"""
257258
if "decisionSuccess" not in df.columns:
258259
print(
@@ -283,7 +284,7 @@ def plot_decision_success_with_std(df: pd.DataFrame, input_path: str, global_col
283284

284285
# Use global color mapping if provided, otherwise create local one
285286
if global_color_mapping is None:
286-
color_mapping = get_consistent_color_mapping(grouped['option'].unique())
287+
color_mapping = get_consistent_color_mapping(grouped['option'].unique().tolist())
287288
else:
288289
color_mapping = global_color_mapping
289290
colors = [color_mapping[option] for option in grouped['option']]
@@ -364,38 +365,41 @@ def get_unique_labels(df: pd.DataFrame) -> list[str]:
364365
return unique_labels
365366

366367

367-
def get_unique_labels_from_conditions(conditions) -> list[str]:
368+
def get_unique_labels_from_conditions(conditions: Union[list[str], np.ndarray[Any, Any]]) -> list[str]:
368369
"""Helper function to get unique labels from condition strings"""
369370
# Convert to list if it's a numpy array
371+
condition_list: list[str]
370372
if hasattr(conditions, 'tolist'):
371-
conditions = conditions.tolist()
373+
condition_list = conditions.tolist()
374+
else:
375+
condition_list = conditions
372376

373-
if len(conditions) == 0:
377+
if len(condition_list) == 0:
374378
return []
375379

376380
# Find the longest common prefix
377381
common_prefix = ""
378-
if len(conditions) > 0:
379-
first_condition = conditions[0]
382+
if len(condition_list) > 0:
383+
first_condition = condition_list[0]
380384
for i in range(len(first_condition)):
381-
if all(condition.startswith(first_condition[:i + 1]) for condition in conditions):
385+
if all(condition.startswith(first_condition[:i + 1]) for condition in condition_list):
382386
common_prefix = first_condition[:i + 1]
383387
else:
384388
break
385389

386390
# Find the longest common suffix
387391
common_suffix = ""
388-
if len(conditions) > 0:
389-
first_condition = conditions[0]
392+
if len(condition_list) > 0:
393+
first_condition = condition_list[0]
390394
for i in range(len(first_condition)):
391-
if all(condition.endswith(first_condition[-(i + 1):]) for condition in conditions):
395+
if all(condition.endswith(first_condition[-(i + 1):]) for condition in condition_list):
392396
common_suffix = first_condition[-(i + 1):]
393397
else:
394398
break
395399

396400
# Extract unique parts by removing common prefix and suffix
397401
unique_labels = []
398-
for condition in conditions:
402+
for condition in condition_list:
399403
unique_part = condition
400404
if common_prefix and condition.startswith(common_prefix):
401405
unique_part = unique_part[len(common_prefix):]
@@ -406,7 +410,7 @@ def get_unique_labels_from_conditions(conditions) -> list[str]:
406410
return unique_labels
407411

408412

409-
def plot_score_distributions_with_std(df: pd.DataFrame, input_path: str, global_color_mapping: dict = None) -> None:
413+
def plot_score_distributions_with_std(df: pd.DataFrame, input_path: str, global_color_mapping: Optional[dict[str, Any]] = None) -> None:
410414
"""Create beautiful enhanced bar charts for score distributions"""
411415
print("Shape of stats_df:", df.shape)
412416
print("Columns in stats_df:", df.columns)
@@ -445,7 +449,7 @@ def plot_score_distributions_with_std(df: pd.DataFrame, input_path: str, global_
445449
score_data = grouped[grouped["Score Type"] == score_type].copy()
446450

447451
# Sort data: baselines first, then alphabetically
448-
def sort_key(row):
452+
def sort_key(row: pd.Series) -> tuple[int, str]:
449453
option = row['option'].lower()
450454
if option.startswith('baseline'):
451455
return (0, option) # Baselines first
@@ -465,7 +469,7 @@ def sort_key(row):
465469

466470
# Use global color mapping if provided, otherwise create local one
467471
if global_color_mapping is None:
468-
color_mapping = get_consistent_color_mapping(score_data['option'].unique())
472+
color_mapping = get_consistent_color_mapping(score_data['option'].unique().tolist())
469473
else:
470474
color_mapping = global_color_mapping
471475
colors = [color_mapping[option] for option in score_data['option']]

0 commit comments

Comments
 (0)