Skip to content

Commit 5f90bde

Browse files
Resolve consensus provider and parsing fixes
Add _resolve_consensus_provider helper to centralize resolution of provider, model, and API key (prefers explicit consensus_model then falls back to first available api_keys). Refactor functions to accept consensus_model (previously consensus_check_model) and use the new resolver. Fix numeric parsing regexes to accept 0/1 values and improve robustness of metric extraction. Update resolve_provider_base_url to handle a None provider. Improve format_results: make cluster prefix matching case-insensitive, tighten line-by-line parsing to strip "Cluster X:" prefixes, log when fewer lines than clusters and mark missing entries as "Unknown". Update tests to reflect stripped prefixes behavior.
1 parent 6d5dd9b commit 5f90bde

File tree

4 files changed

+71
-86
lines changed

4 files changed

+71
-86
lines changed

python/mllmcelltype/consensus.py

Lines changed: 52 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,42 @@
3838
}
3939

4040

41+
def _resolve_consensus_provider(
42+
consensus_model: dict[str, str] | None,
43+
api_keys: dict[str, str],
44+
) -> tuple[str | None, str | None, str | None]:
45+
"""Resolve provider, model, and API key for consensus checking.
46+
47+
Resolution order: explicit ``consensus_model`` dict → first available
48+
key in ``api_keys``.
49+
50+
Args:
51+
consensus_model: Optional dict with 'provider' and/or 'model' keys
52+
api_keys: Dictionary mapping provider names to API keys
53+
54+
Returns:
55+
tuple of (provider, model, api_key) — any element may be None
56+
"""
57+
if consensus_model:
58+
provider = consensus_model.get("provider")
59+
model = consensus_model.get("model")
60+
if not provider and model:
61+
provider = get_provider(model)
62+
if provider and not model:
63+
model = get_default_model(provider)
64+
else:
65+
provider = None
66+
model = None
67+
for p, key in api_keys.items():
68+
if key:
69+
provider = p
70+
model = get_default_model(p)
71+
break
72+
73+
api_key = api_keys.get(provider) if provider else None
74+
return provider, model, api_key
75+
76+
4177
def _call_llm_with_retry(
4278
prompt: str,
4379
provider: str,
@@ -188,7 +224,7 @@ def _extract_metrics_from_text(
188224

189225
# Regex patterns (mirroring R's .CONSENSUS_CONSTANTS)
190226
consensus_indicator_pattern = r"^\s*[01]\s*$"
191-
proportion_pattern = r"^\s*(0\.\d+|1\.0*|1)\s*$"
227+
proportion_pattern = r"^\s*(0\.\d+|1\.0*|[01])\s*$"
192228
entropy_pattern = r"^\s*(\d+\.\d+|\d+)\s*$"
193229
general_numeric_pattern = r"^\s*\d+(\.\d+)?\s*$"
194230

@@ -224,7 +260,7 @@ def _extract_metrics_from_text(
224260
parts = line.split("=")
225261
if len(parts) > 1:
226262
last_part = parts[-1].strip()
227-
value_match = re.search(r"(0\.\d+|1\.0*|1)", last_part)
263+
value_match = re.search(r"(0\.\d+|1\.0*|[01])", last_part)
228264
if value_match:
229265
with contextlib.suppress(ValueError):
230266
potential_cp = float(value_match.group(1))
@@ -407,26 +443,9 @@ def check_consensus(
407443
# Use LLM to check consensus among annotations
408444
prompt = create_consensus_check_prompt(cluster_annotations)
409445

410-
# Determine which model to use: explicit consensus_model → api_keys
411-
if consensus_model:
412-
primary_provider = consensus_model.get("provider")
413-
primary_model = consensus_model.get("model")
414-
if not primary_provider and primary_model:
415-
primary_provider = get_provider(primary_model)
416-
if primary_provider and not primary_model:
417-
primary_model = get_default_model(primary_provider)
418-
else:
419-
# Pick from user's available api_keys
420-
primary_provider = None
421-
primary_model = None
422-
if api_keys:
423-
for provider, key in api_keys.items():
424-
if key:
425-
primary_provider = provider
426-
primary_model = get_default_model(provider)
427-
break
428-
429-
primary_api_key = api_keys.get(primary_provider) if primary_provider else None
446+
primary_provider, primary_model, primary_api_key = _resolve_consensus_provider(
447+
consensus_model, api_keys
448+
)
430449

431450
llm_response = _call_llm_with_retry(
432451
prompt=prompt,
@@ -528,7 +547,7 @@ def check_consensus_for_discussion_round(
528547
consensus_threshold: float = 0.7,
529548
entropy_threshold: float = 1.0,
530549
api_keys: dict[str, str] | None = None,
531-
consensus_check_model: dict[str, str] | None = None,
550+
consensus_model: dict[str, str] | None = None,
532551
base_urls: str | dict[str, str] | None = None,
533552
) -> dict[str, Any]:
534553
"""Check consensus among model responses for a single discussion round.
@@ -547,7 +566,7 @@ def check_consensus_for_discussion_round(
547566
consensus_threshold: Agreement threshold (default: 0.7)
548567
entropy_threshold: Entropy threshold (default: 1.0)
549568
api_keys: Dictionary mapping provider names to API keys
550-
consensus_check_model: Optional dict with 'provider' and 'model' keys
569+
consensus_model: Optional dict with 'provider' and 'model' keys
551570
base_urls: Custom base URLs for API endpoints
552571
553572
Returns:
@@ -562,28 +581,12 @@ def check_consensus_for_discussion_round(
562581
write_log("No responses to check consensus", level="warning")
563582
return DEFAULT_CONSENSUS_RESULT.copy()
564583

565-
# Resolve LLM parameters: explicit consensus_check_model → api_keys → give up
566584
if api_keys is None:
567585
api_keys = {}
568586

569-
if consensus_check_model:
570-
primary_provider = consensus_check_model.get("provider")
571-
primary_model = consensus_check_model.get("model")
572-
if not primary_provider and primary_model:
573-
primary_provider = get_provider(primary_model)
574-
if primary_provider and not primary_model:
575-
primary_model = get_default_model(primary_provider)
576-
else:
577-
# Pick from user's available api_keys
578-
primary_provider = None
579-
primary_model = None
580-
for provider, key in api_keys.items():
581-
if key:
582-
primary_provider = provider
583-
primary_model = get_default_model(provider)
584-
break
585-
586-
primary_api_key = api_keys.get(primary_provider) if primary_provider else None
587+
primary_provider, primary_model, primary_api_key = _resolve_consensus_provider(
588+
consensus_model, api_keys
589+
)
587590

588591
# Single response — extract label but cannot establish consensus
589592
if len(round_responses) == 1:
@@ -669,7 +672,7 @@ def process_controversial_clusters(
669672
cache_dir: str | None = None,
670673
base_urls: str | dict[str, str] | None = None,
671674
force_rerun: bool = False,
672-
consensus_check_model: dict[str, str] | None = None,
675+
consensus_model: dict[str, str] | None = None,
673676
) -> tuple[dict[str, str], dict[str, list[dict]], dict[str, float], dict[str, float]]:
674677
"""Process controversial clusters through multi-model discussion.
675678
@@ -694,7 +697,7 @@ def process_controversial_clusters(
694697
cache_dir: Directory to store cache files
695698
base_urls: Custom base URLs for API endpoints
696699
force_rerun: If True, ignore cached results
697-
consensus_check_model: Optional dict with 'provider' and 'model' keys
700+
consensus_model: Optional dict with 'provider' and 'model' keys
698701
to specify which model to use for consensus checking with LLM.
699702
If not provided, picks from the caller's api_keys.
700703
@@ -872,7 +875,7 @@ def process_controversial_clusters(
872875
consensus_threshold=consensus_threshold,
873876
entropy_threshold=entropy_threshold,
874877
api_keys=api_keys,
875-
consensus_check_model=consensus_check_model,
878+
consensus_model=consensus_model,
876879
base_urls=base_urls,
877880
)
878881

@@ -906,7 +909,7 @@ def process_controversial_clusters(
906909
consensus_threshold=consensus_threshold,
907910
entropy_threshold=entropy_threshold,
908911
api_keys=api_keys,
909-
consensus_check_model=consensus_check_model,
912+
consensus_model=consensus_model,
910913
base_urls=base_urls,
911914
)
912915
final_decision = last_consensus["majority_prediction"]
@@ -926,7 +929,7 @@ def process_controversial_clusters(
926929
consensus_threshold=consensus_threshold,
927930
entropy_threshold=entropy_threshold,
928931
api_keys=api_keys,
929-
consensus_check_model=consensus_check_model,
932+
consensus_model=consensus_model,
930933
base_urls=base_urls,
931934
)
932935
cell_type = last_consensus["majority_prediction"]
@@ -1241,7 +1244,7 @@ def interactive_consensus_annotation(
12411244
cache_dir=cache_dir,
12421245
base_urls=base_urls,
12431246
force_rerun=force_rerun,
1244-
consensus_check_model=consensus_model_dict, # Pass consensus model for LLM verification
1247+
consensus_model=consensus_model_dict,
12451248
)
12461249

12471250
# Update consensus proportion and entropy for resolved clusters

python/mllmcelltype/url_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def resolve_provider_base_url(provider: str, base_urls: str | dict | None) -> st
2222
Returns:
2323
Resolved base URL or None
2424
"""
25-
if base_urls is None:
25+
if base_urls is None or provider is None:
2626
return None
2727

2828
if isinstance(base_urls, str):

python/mllmcelltype/utils.py

Lines changed: 15 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,7 @@ def format_results(results: list[str], clusters: list[str]) -> dict[str, str]:
281281

282282
# Case 1: Try to parse the format "Cluster X: Annotation" (most common format from our prompts)
283283
result = {}
284-
cluster_pattern = r"Cluster\s+(.+?):\s*(.*)"
284+
cluster_pattern = r"(?i)Cluster\s+(.+?):\s*(.*)"
285285

286286
# First pass: try to find annotations for each cluster by ID
287287
for cluster in clusters:
@@ -354,45 +354,27 @@ def format_results(results: list[str], clusters: list[str]) -> dict[str, str]:
354354
except (json.JSONDecodeError, ValueError, KeyError, TypeError, AttributeError) as e:
355355
write_log(f"Failed to parse JSON response: {e!s}", level="debug")
356356

357-
# Case 3: Check if this is a simple response where each line corresponds to a cluster
358-
# This is the expected format from the R version
359-
if len(clean_results) >= len(clusters):
360-
# Simple case: one result per cluster
361-
simple_result = {}
362-
for i, cluster in enumerate(clusters):
363-
if i < len(clean_results):
364-
# Check if this line contains a cluster prefix and remove it
365-
line = clean_results[i]
366-
match = re.match(cluster_pattern, line)
367-
if match:
368-
simple_result[str(cluster)] = match.group(2).strip()
369-
else:
370-
simple_result[str(cluster)] = line.strip()
371-
else:
372-
simple_result[str(cluster)] = "Unknown"
373-
374-
write_log("Successfully parsed response as simple line-by-line format", level="info")
375-
return simple_result
357+
# Case 3: Line-by-line mapping — each line corresponds to a cluster
358+
if len(clean_results) < len(clusters):
359+
write_log(
360+
f"Fewer result lines ({len(clean_results)}) than clusters ({len(clusters)}), "
361+
"remaining clusters will be marked Unknown",
362+
level="warning",
363+
)
376364

377-
# Case 4: Fall back to the original method
378-
write_log(
379-
"Could not parse complex LLM response, falling back to simple mapping",
380-
level="warning",
381-
)
382365
result = {}
383366
for i, cluster in enumerate(clusters):
384367
if i < len(clean_results):
385-
result[str(cluster)] = clean_results[i]
368+
line = clean_results[i]
369+
match = re.match(cluster_pattern, line)
370+
if match:
371+
result[str(cluster)] = match.group(2).strip()
372+
else:
373+
result[str(cluster)] = line.strip()
386374
else:
387375
result[str(cluster)] = "Unknown"
388376

389-
# Check if number of results matches number of clusters
390-
if len(result) != len(clusters):
391-
write_log(
392-
f"Number of results ({len(result)}) does not match number of clusters ({len(clusters)})",
393-
level="warning",
394-
)
395-
377+
write_log("Parsed response as line-by-line format", level="info")
396378
return result
397379

398380

python/tests/test_core.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -175,9 +175,9 @@ def test_format_results_mismatched():
175175
assert "2" in formatted
176176
# The function adds "Unknown" for missing clusters
177177
assert "3" in formatted
178-
# In simple mapping mode, it doesn't clean prefixes
179-
assert "Cluster 1: T cells" in formatted["1"]
180-
assert "Cluster 2: B cells" in formatted["2"]
178+
# Line-by-line mapping strips "Cluster X:" prefix when present
179+
assert formatted["1"] == "T cells"
180+
assert formatted["2"] == "B cells"
181181
assert formatted["3"] == "Unknown"
182182

183183

0 commit comments

Comments
 (0)