22import json
33import os
44from pathlib import Path
5+ from typing import Optional , Any , Union
56
67import matplotlib .pyplot as plt
78import pandas as pd
1617# Define a beautiful color palette
1718COLORS = ['#FF6B6B' , '#4ECDC4' , '#45B7D1' , '#96CEB4' , '#FFEAA7' , '#DDA0DD' , '#98D8C8' ]
1819
19- def get_colors (n_colors ) :
20+ def get_colors (n_colors : int ) -> Union [ list [ str ], np . ndarray [ Any , Any ]] :
2021 """Generate enough colors for n_colors by cycling or using colormap"""
2122 if n_colors <= len (COLORS ):
2223 return COLORS [:n_colors ]
2324 else :
2425 # Use a colormap for more colors
25- return plt .cm .Set3 (np .linspace (0 , 1 , n_colors ))
26+ return plt .cm .Set3 (np .linspace (0 , 1 , n_colors )) # type: ignore
2627
2728
28- def get_consistent_color_mapping (options ) :
29+ def get_consistent_color_mapping (options : list [ str ]) -> dict [ str , Any ] :
2930 """Create consistent color mapping based on option names"""
3031 # Sort options to ensure consistent assignment
3132 unique_options = sorted (set (options ))
@@ -93,7 +94,7 @@ def aggregate_data(
9394 return eval_df , stats_df
9495
9596
96- def plot_turns_with_std (df : pd .DataFrame , input_path : str , global_color_mapping : dict = None ) -> None :
97+ def plot_turns_with_std (df : pd .DataFrame , input_path : str , global_color_mapping : Optional [ dict [ str , Any ]] = None ) -> None :
9798 """Create a beautiful violin plot for turns distribution"""
9899 # Filter out rows with missing or invalid turns data
99100 df = df .dropna (subset = ['turns' ])
@@ -114,7 +115,7 @@ def plot_turns_with_std(df: pd.DataFrame, input_path: str, global_color_mapping:
114115
115116 # Use global color mapping if provided, otherwise create local one
116117 if global_color_mapping is None :
117- color_mapping = get_consistent_color_mapping (grouped_data ['option' ].unique ())
118+ color_mapping = get_consistent_color_mapping (grouped_data ['option' ].unique (). tolist () )
118119 else :
119120 color_mapping = global_color_mapping
120121
@@ -175,7 +176,7 @@ def plot_turns_with_std(df: pd.DataFrame, input_path: str, global_color_mapping:
175176 plt .close ()
176177
177178
178- def plot_clock_seconds_with_std (df : pd .DataFrame , input_path : str , global_color_mapping : dict = None ) -> None :
179+ def plot_clock_seconds_with_std (df : pd .DataFrame , input_path : str , global_color_mapping : Optional [ dict [ str , Any ]] = None ) -> None :
179180 """Create a beautiful horizontal lollipop chart for clock seconds"""
180181 grouped = (
181182 df .groupby (["option" , "dataset" ])["clockSeconds" ]
@@ -187,7 +188,7 @@ def plot_clock_seconds_with_std(df: pd.DataFrame, input_path: str, global_color_
187188 grouped ['label' ] = unique_labels
188189
189190 # Sort data: baselines first, then others by shortest time
190- def sort_key (row ) :
191+ def sort_key (row : pd . Series ) -> tuple [ int , float ] :
191192 option = row ['option' ].lower ()
192193 if option .startswith ('baseline' ):
193194 return (0 , row ['mean' ]) # Baselines first, sorted by time
@@ -206,7 +207,7 @@ def sort_key(row):
206207
207208 # Use global color mapping if provided, otherwise create local one
208209 if global_color_mapping is None :
209- color_mapping = get_consistent_color_mapping (grouped ['option' ].unique ())
210+ color_mapping = get_consistent_color_mapping (grouped ['option' ].unique (). tolist () )
210211 else :
211212 color_mapping = global_color_mapping
212213 colors = [color_mapping [option ] for option in grouped ['option' ]]
@@ -252,7 +253,7 @@ def sort_key(row):
252253 plt .close ()
253254
254255
255- def plot_decision_success_with_std (df : pd .DataFrame , input_path : str , global_color_mapping : dict = None ) -> None :
256+ def plot_decision_success_with_std (df : pd .DataFrame , input_path : str , global_color_mapping : Optional [ dict [ str , Any ]] = None ) -> None :
256257 """Create a beautiful horizontal bar chart for decision success rates"""
257258 if "decisionSuccess" not in df .columns :
258259 print (
@@ -283,7 +284,7 @@ def plot_decision_success_with_std(df: pd.DataFrame, input_path: str, global_col
283284
284285 # Use global color mapping if provided, otherwise create local one
285286 if global_color_mapping is None :
286- color_mapping = get_consistent_color_mapping (grouped ['option' ].unique ())
287+ color_mapping = get_consistent_color_mapping (grouped ['option' ].unique (). tolist () )
287288 else :
288289 color_mapping = global_color_mapping
289290 colors = [color_mapping [option ] for option in grouped ['option' ]]
@@ -364,38 +365,41 @@ def get_unique_labels(df: pd.DataFrame) -> list[str]:
364365 return unique_labels
365366
366367
367- def get_unique_labels_from_conditions (conditions ) -> list [str ]:
368+ def get_unique_labels_from_conditions (conditions : Union [ list [ str ], np . ndarray [ Any , Any ]] ) -> list [str ]:
368369 """Helper function to get unique labels from condition strings"""
369370 # Convert to list if it's a numpy array
371+ condition_list : list [str ]
370372 if hasattr (conditions , 'tolist' ):
371- conditions = conditions .tolist ()
373+ condition_list = conditions .tolist ()
374+ else :
375+ condition_list = conditions
372376
373- if len (conditions ) == 0 :
377+ if len (condition_list ) == 0 :
374378 return []
375379
376380 # Find the longest common prefix
377381 common_prefix = ""
378- if len (conditions ) > 0 :
379- first_condition = conditions [0 ]
382+ if len (condition_list ) > 0 :
383+ first_condition = condition_list [0 ]
380384 for i in range (len (first_condition )):
381- if all (condition .startswith (first_condition [:i + 1 ]) for condition in conditions ):
385+ if all (condition .startswith (first_condition [:i + 1 ]) for condition in condition_list ):
382386 common_prefix = first_condition [:i + 1 ]
383387 else :
384388 break
385389
386390 # Find the longest common suffix
387391 common_suffix = ""
388- if len (conditions ) > 0 :
389- first_condition = conditions [0 ]
392+ if len (condition_list ) > 0 :
393+ first_condition = condition_list [0 ]
390394 for i in range (len (first_condition )):
391- if all (condition .endswith (first_condition [- (i + 1 ):]) for condition in conditions ):
395+ if all (condition .endswith (first_condition [- (i + 1 ):]) for condition in condition_list ):
392396 common_suffix = first_condition [- (i + 1 ):]
393397 else :
394398 break
395399
396400 # Extract unique parts by removing common prefix and suffix
397401 unique_labels = []
398- for condition in conditions :
402+ for condition in condition_list :
399403 unique_part = condition
400404 if common_prefix and condition .startswith (common_prefix ):
401405 unique_part = unique_part [len (common_prefix ):]
@@ -406,7 +410,7 @@ def get_unique_labels_from_conditions(conditions) -> list[str]:
406410 return unique_labels
407411
408412
409- def plot_score_distributions_with_std (df : pd .DataFrame , input_path : str , global_color_mapping : dict = None ) -> None :
413+ def plot_score_distributions_with_std (df : pd .DataFrame , input_path : str , global_color_mapping : Optional [ dict [ str , Any ]] = None ) -> None :
410414 """Create beautiful enhanced bar charts for score distributions"""
411415 print ("Shape of stats_df:" , df .shape )
412416 print ("Columns in stats_df:" , df .columns )
@@ -445,7 +449,7 @@ def plot_score_distributions_with_std(df: pd.DataFrame, input_path: str, global_
445449 score_data = grouped [grouped ["Score Type" ] == score_type ].copy ()
446450
447451 # Sort data: baselines first, then alphabetically
448- def sort_key (row ) :
452+ def sort_key (row : pd . Series ) -> tuple [ int , str ] :
449453 option = row ['option' ].lower ()
450454 if option .startswith ('baseline' ):
451455 return (0 , option ) # Baselines first
@@ -465,7 +469,7 @@ def sort_key(row):
465469
466470 # Use global color mapping if provided, otherwise create local one
467471 if global_color_mapping is None :
468- color_mapping = get_consistent_color_mapping (score_data ['option' ].unique ())
472+ color_mapping = get_consistent_color_mapping (score_data ['option' ].unique (). tolist () )
469473 else :
470474 color_mapping = global_color_mapping
471475 colors = [color_mapping [option ] for option in score_data ['option' ]]
0 commit comments