3838}
3939
4040
41+ def _resolve_consensus_provider (
42+ consensus_model : dict [str , str ] | None ,
43+ api_keys : dict [str , str ],
44+ ) -> tuple [str | None , str | None , str | None ]:
45+ """Resolve provider, model, and API key for consensus checking.
46+
47+ Resolution order: explicit ``consensus_model`` dict → first available
48+ key in ``api_keys``.
49+
50+ Args:
51+ consensus_model: Optional dict with 'provider' and/or 'model' keys
52+ api_keys: Dictionary mapping provider names to API keys
53+
54+ Returns:
55+ tuple of (provider, model, api_key) — any element may be None
56+ """
57+ if consensus_model :
58+ provider = consensus_model .get ("provider" )
59+ model = consensus_model .get ("model" )
60+ if not provider and model :
61+ provider = get_provider (model )
62+ if provider and not model :
63+ model = get_default_model (provider )
64+ else :
65+ provider = None
66+ model = None
67+ for p , key in api_keys .items ():
68+ if key :
69+ provider = p
70+ model = get_default_model (p )
71+ break
72+
73+ api_key = api_keys .get (provider ) if provider else None
74+ return provider , model , api_key
75+
76+
4177def _call_llm_with_retry (
4278 prompt : str ,
4379 provider : str ,
@@ -188,7 +224,7 @@ def _extract_metrics_from_text(
188224
189225 # Regex patterns (mirroring R's .CONSENSUS_CONSTANTS)
190226 consensus_indicator_pattern = r"^\s*[01]\s*$"
191- proportion_pattern = r"^\s*(0\.\d+|1\.0*|1 )\s*$"
227+ proportion_pattern = r"^\s*(0\.\d+|1\.0*|[01] )\s*$"
192228 entropy_pattern = r"^\s*(\d+\.\d+|\d+)\s*$"
193229 general_numeric_pattern = r"^\s*\d+(\.\d+)?\s*$"
194230
@@ -224,7 +260,7 @@ def _extract_metrics_from_text(
224260 parts = line .split ("=" )
225261 if len (parts ) > 1 :
226262 last_part = parts [- 1 ].strip ()
227- value_match = re .search (r"(0\.\d+|1\.0*|1 )" , last_part )
263+ value_match = re .search (r"(0\.\d+|1\.0*|[01] )" , last_part )
228264 if value_match :
229265 with contextlib .suppress (ValueError ):
230266 potential_cp = float (value_match .group (1 ))
@@ -407,26 +443,9 @@ def check_consensus(
407443 # Use LLM to check consensus among annotations
408444 prompt = create_consensus_check_prompt (cluster_annotations )
409445
410- # Determine which model to use: explicit consensus_model → api_keys
411- if consensus_model :
412- primary_provider = consensus_model .get ("provider" )
413- primary_model = consensus_model .get ("model" )
414- if not primary_provider and primary_model :
415- primary_provider = get_provider (primary_model )
416- if primary_provider and not primary_model :
417- primary_model = get_default_model (primary_provider )
418- else :
419- # Pick from user's available api_keys
420- primary_provider = None
421- primary_model = None
422- if api_keys :
423- for provider , key in api_keys .items ():
424- if key :
425- primary_provider = provider
426- primary_model = get_default_model (provider )
427- break
428-
429- primary_api_key = api_keys .get (primary_provider ) if primary_provider else None
446+ primary_provider , primary_model , primary_api_key = _resolve_consensus_provider (
447+ consensus_model , api_keys
448+ )
430449
431450 llm_response = _call_llm_with_retry (
432451 prompt = prompt ,
@@ -528,7 +547,7 @@ def check_consensus_for_discussion_round(
528547 consensus_threshold : float = 0.7 ,
529548 entropy_threshold : float = 1.0 ,
530549 api_keys : dict [str , str ] | None = None ,
531- consensus_check_model : dict [str , str ] | None = None ,
550+ consensus_model : dict [str , str ] | None = None ,
532551 base_urls : str | dict [str , str ] | None = None ,
533552) -> dict [str , Any ]:
534553 """Check consensus among model responses for a single discussion round.
@@ -547,7 +566,7 @@ def check_consensus_for_discussion_round(
547566 consensus_threshold: Agreement threshold (default: 0.7)
548567 entropy_threshold: Entropy threshold (default: 1.0)
549568 api_keys: Dictionary mapping provider names to API keys
550- consensus_check_model : Optional dict with 'provider' and 'model' keys
569+ consensus_model : Optional dict with 'provider' and 'model' keys
551570 base_urls: Custom base URLs for API endpoints
552571
553572 Returns:
@@ -562,28 +581,12 @@ def check_consensus_for_discussion_round(
562581 write_log ("No responses to check consensus" , level = "warning" )
563582 return DEFAULT_CONSENSUS_RESULT .copy ()
564583
565- # Resolve LLM parameters: explicit consensus_check_model → api_keys → give up
566584 if api_keys is None :
567585 api_keys = {}
568586
569- if consensus_check_model :
570- primary_provider = consensus_check_model .get ("provider" )
571- primary_model = consensus_check_model .get ("model" )
572- if not primary_provider and primary_model :
573- primary_provider = get_provider (primary_model )
574- if primary_provider and not primary_model :
575- primary_model = get_default_model (primary_provider )
576- else :
577- # Pick from user's available api_keys
578- primary_provider = None
579- primary_model = None
580- for provider , key in api_keys .items ():
581- if key :
582- primary_provider = provider
583- primary_model = get_default_model (provider )
584- break
585-
586- primary_api_key = api_keys .get (primary_provider ) if primary_provider else None
587+ primary_provider , primary_model , primary_api_key = _resolve_consensus_provider (
588+ consensus_model , api_keys
589+ )
587590
588591 # Single response — extract label but cannot establish consensus
589592 if len (round_responses ) == 1 :
@@ -669,7 +672,7 @@ def process_controversial_clusters(
669672 cache_dir : str | None = None ,
670673 base_urls : str | dict [str , str ] | None = None ,
671674 force_rerun : bool = False ,
672- consensus_check_model : dict [str , str ] | None = None ,
675+ consensus_model : dict [str , str ] | None = None ,
673676) -> tuple [dict [str , str ], dict [str , list [dict ]], dict [str , float ], dict [str , float ]]:
674677 """Process controversial clusters through multi-model discussion.
675678
@@ -694,7 +697,7 @@ def process_controversial_clusters(
694697 cache_dir: Directory to store cache files
695698 base_urls: Custom base URLs for API endpoints
696699 force_rerun: If True, ignore cached results
697- consensus_check_model : Optional dict with 'provider' and 'model' keys
700+ consensus_model : Optional dict with 'provider' and 'model' keys
698701 to specify which model to use for consensus checking with LLM.
699702 If not provided, picks from the caller's api_keys.
700703
@@ -872,7 +875,7 @@ def process_controversial_clusters(
872875 consensus_threshold = consensus_threshold ,
873876 entropy_threshold = entropy_threshold ,
874877 api_keys = api_keys ,
875- consensus_check_model = consensus_check_model ,
878+ consensus_model = consensus_model ,
876879 base_urls = base_urls ,
877880 )
878881
@@ -906,7 +909,7 @@ def process_controversial_clusters(
906909 consensus_threshold = consensus_threshold ,
907910 entropy_threshold = entropy_threshold ,
908911 api_keys = api_keys ,
909- consensus_check_model = consensus_check_model ,
912+ consensus_model = consensus_model ,
910913 base_urls = base_urls ,
911914 )
912915 final_decision = last_consensus ["majority_prediction" ]
@@ -926,7 +929,7 @@ def process_controversial_clusters(
926929 consensus_threshold = consensus_threshold ,
927930 entropy_threshold = entropy_threshold ,
928931 api_keys = api_keys ,
929- consensus_check_model = consensus_check_model ,
932+ consensus_model = consensus_model ,
930933 base_urls = base_urls ,
931934 )
932935 cell_type = last_consensus ["majority_prediction" ]
@@ -1241,7 +1244,7 @@ def interactive_consensus_annotation(
12411244 cache_dir = cache_dir ,
12421245 base_urls = base_urls ,
12431246 force_rerun = force_rerun ,
1244- consensus_check_model = consensus_model_dict , # Pass consensus model for LLM verification
1247+ consensus_model = consensus_model_dict ,
12451248 )
12461249
12471250 # Update consensus proportion and entropy for resolved clusters
0 commit comments