Skip to content

Commit 40df4f6

Browse files
Refactor provider API calls and consensus logic
Refactored provider modules to remove unnecessary chunking logic and process all input in a single request, simplifying API call flows and error handling. Improved retry and error handling for all providers, and standardized result post-processing. Updated consensus logic to use find_agreement() for controversial cluster identification, and improved variable naming and logging for clarity. Minor translation and comment updates for consistency.
1 parent dfb80f3 commit 40df4f6

14 files changed

Lines changed: 533 additions & 712 deletions

File tree

python/mllmcelltype/annotate.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def annotate_clusters(
101101
write_log("Using cached results")
102102
return format_results(cached_results, clusters)
103103

104-
# 解析base URL
104+
# Resolve base URL
105105
from .url_utils import resolve_provider_base_url
106106

107107
base_url = resolve_provider_base_url(provider, base_urls)
@@ -234,7 +234,7 @@ def batch_annotate_clusters(
234234
start_idx = end_idx
235235
return result_sets
236236

237-
# 解析base URL
237+
# Resolve base URL
238238
from .url_utils import resolve_provider_base_url
239239

240240
base_url = resolve_provider_base_url(provider, base_urls)

python/mllmcelltype/consensus.py

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def _call_llm_with_retry(
8686
"""
8787
from .url_utils import resolve_provider_base_url
8888

89-
# 解析base URL
89+
# Resolve base URL
9090
primary_base_url = resolve_provider_base_url(provider, base_urls)
9191

9292
# First try with primary provider
@@ -121,7 +121,7 @@ def _call_llm_with_retry(
121121
if api_keys:
122122
fallback_api_key = _get_api_key(fallback_provider, api_keys)
123123
if fallback_api_key:
124-
# 解析fallback provider的base URL
124+
# Resolve base URL for fallback provider
125125
fallback_base_url = resolve_provider_base_url(fallback_provider, base_urls)
126126
try:
127127
response = get_model_response(
@@ -474,7 +474,7 @@ def process_controversial_clusters(
474474
from .prompts import create_consensus_check_prompt
475475
from .url_utils import resolve_provider_base_url
476476

477-
# 解析base URL
477+
# Resolve base URL
478478
base_url = resolve_provider_base_url(provider, base_urls)
479479

480480
results = {}
@@ -486,8 +486,8 @@ def process_controversial_clusters(
486486
write_log(f"Processing controversial cluster {cluster_id}")
487487

488488
# Get marker genes for this cluster
489-
cluster_markers = marker_genes.get(cluster_id, [])
490-
if not cluster_markers:
489+
current_marker_genes = marker_genes.get(cluster_id, [])
490+
if not current_marker_genes:
491491
write_log(
492492
f"Warning: No marker genes found for cluster {cluster_id}",
493493
level="warning",
@@ -541,33 +541,33 @@ def process_controversial_clusters(
541541
lines = consensus_check_response.strip().split("\n")
542542
if len(lines) >= 3:
543543
# Extract consensus proportion
544-
cp = float(lines[1].strip())
544+
cp_value = float(lines[1].strip())
545545

546546
# Extract entropy value
547-
h = float(lines[2].strip())
547+
h_value = float(lines[2].strip())
548548

549549
write_log(
550-
f"Initial metrics for cluster {cluster_id} (LLM calculated): CP={cp:.2f}, H={h:.2f}"
550+
f"Initial metrics for cluster {cluster_id} (LLM calculated): CP={cp_value:.2f}, H={h_value:.2f}"
551551
)
552552
else:
553553
# Fallback if LLM response format is unexpected
554-
cp = 0.25 # Low consensus to ensure discussion happens
555-
h = 2.0 # High entropy to indicate uncertainty
554+
cp_value = 0.25 # Low consensus to ensure discussion happens
555+
h_value = 2.0 # High entropy to indicate uncertainty
556556
write_log(
557-
f"Could not parse LLM consensus check response, using default values: CP={cp:.2f}, H={h:.2f}",
557+
f"Could not parse LLM consensus check response, using default values: CP={cp_value:.2f}, H={h_value:.2f}",
558558
level="warning",
559559
)
560560
except (ValueError, IndexError, AttributeError, TypeError) as e:
561561
# Fallback if parsing fails
562-
cp = 0.25 # Low consensus to ensure discussion happens
563-
h = 2.0 # High entropy to indicate uncertainty
562+
cp_value = 0.25 # Low consensus to ensure discussion happens
563+
h_value = 2.0 # High entropy to indicate uncertainty
564564
write_log(
565-
f"Error parsing LLM consensus check response: {str(e)}, using default values: CP={cp:.2f}, H={h:.2f}",
565+
f"Error parsing LLM consensus check response: {str(e)}, using default values: CP={cp_value:.2f}, H={h_value:.2f}",
566566
level="warning",
567567
)
568568

569569
rounds_history.append(
570-
f"Initial votes: {current_votes}\nConsensus Proportion (CP): {cp:.2f}\nShannon Entropy (H): {h:.2f}"
570+
f"Initial votes: {current_votes}\nConsensus Proportion (CP): {cp_value:.2f}\nShannon Entropy (H): {h_value:.2f}"
571571
)
572572

573573
# Start iterative discussion process
@@ -580,7 +580,7 @@ def process_controversial_clusters(
580580
# Initial discussion round
581581
prompt = create_discussion_prompt(
582582
cluster_id=cluster_id,
583-
marker_genes=cluster_markers,
583+
marker_genes=current_marker_genes,
584584
model_votes=current_votes,
585585
species=species,
586586
tissue=tissue,
@@ -589,7 +589,7 @@ def process_controversial_clusters(
589589
# Follow-up rounds include previous discussion
590590
prompt = create_discussion_prompt(
591591
cluster_id=cluster_id,
592-
marker_genes=cluster_markers,
592+
marker_genes=current_marker_genes,
593593
model_votes=current_votes,
594594
species=species,
595595
tissue=tissue,
@@ -1220,11 +1220,11 @@ def interactive_consensus_annotation(
12201220
)
12211221

12221222
# Update consensus proportion and entropy for resolved clusters
1223-
for cluster_id, cp in updated_cp.items():
1224-
consensus_proportion[cluster_id] = cp
1223+
for cluster_id, cp_value in updated_cp.items():
1224+
consensus_proportion[cluster_id] = cp_value
12251225

1226-
for cluster_id, h in updated_h.items():
1227-
entropy[cluster_id] = h
1226+
for cluster_id, h_value in updated_h.items():
1227+
entropy[cluster_id] = h_value
12281228

12291229
if verbose:
12301230
write_log(f"Successfully resolved {len(resolved)} controversial clusters")
@@ -1314,8 +1314,8 @@ def print_consensus_summary(result: dict[str, Any]) -> None:
13141314

13151315
print("Cluster annotations:")
13161316
for cluster, annotation in sorted(consensus.items(), key=lambda x: x[0]):
1317-
cp = consensus_proportion.get(cluster, 0)
1318-
ent = entropy.get(cluster, 0)
1317+
stored_cp = consensus_proportion.get(cluster, 0)
1318+
stored_entropy = entropy.get(cluster, 0)
13191319
if cluster in resolved:
13201320
# For resolved clusters, show CP and H if available in the discussion logs
13211321
discussion_logs = result.get("discussion_logs", {})
@@ -1343,8 +1343,8 @@ def print_consensus_summary(result: dict[str, Any]) -> None:
13431343
print(f" Cluster {cluster}: {annotation} [Resolved, CP: {cp_value}, H: {h_value}]")
13441344
else:
13451345
# For non-resolved clusters, use the calculated CP and entropy values
1346-
cp_value = cp
1347-
h_value = ent
1346+
cp_value = stored_cp
1347+
h_value = stored_entropy
13481348

13491349
# Display different messages based on agreement level
13501350
# Use the already calculated entropy value

python/mllmcelltype/functions.py

Lines changed: 12 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
process_stepfun,
1616
process_zhipu,
1717
)
18-
from .utils import clean_annotation
18+
from .utils import clean_annotation, find_agreement
1919

2020
# Global provider function mapping for reuse across modules
2121
PROVIDER_FUNCTIONS = {
@@ -49,7 +49,6 @@
4949
"claude-opus-4-1-20250805",
5050
"claude-opus-4-20250514",
5151
"claude-sonnet-4-20250514",
52-
"claude-sonnet-4-20250514",
5352
"claude-3-5-sonnet-latest",
5453
"claude-3-5-haiku-latest",
5554
"claude-3-opus",
@@ -121,7 +120,6 @@ def get_provider(model: str) -> str:
121120
"claude-opus-4",
122121
"claude-sonnet-4-20250514",
123122
"claude-sonnet-4",
124-
"claude-sonnet-4-20250514",
125123
"claude-3-5-sonnet-20241022",
126124
"claude-3-5-sonnet-20240620",
127125
"claude-3-5-haiku-20241022",
@@ -322,6 +320,9 @@ def identify_controversial_clusters(
322320
) -> list[str]:
323321
"""Identify clusters with inconsistent annotations across models.
324322
323+
This function uses find_agreement() to compute consensus statistics,
324+
then filters clusters where the consensus proportion is below the threshold.
325+
325326
Args:
326327
annotations: Dictionary mapping model names to dictionaries of cluster annotations
327328
threshold: Agreement threshold below which a cluster is considered controversial
@@ -333,36 +334,14 @@ def identify_controversial_clusters(
333334
if not annotations or len(annotations) < 2:
334335
return []
335336

336-
# Get all clusters
337-
all_clusters = set()
338-
for model_results in annotations.values():
339-
all_clusters.update(model_results.keys())
340-
341-
controversial = []
337+
# Use find_agreement() to compute consensus statistics for all clusters
338+
_consensus, consensus_proportion, _entropy = find_agreement(annotations)
342339

343-
# Check each cluster for agreement level
344-
for cluster in all_clusters:
345-
# Get all annotations for this cluster
346-
cluster_annotations = []
347-
for _model, results in annotations.items():
348-
if cluster in results:
349-
annotation = clean_annotation(results[cluster])
350-
if annotation:
351-
cluster_annotations.append(annotation)
352-
353-
# Count occurrences
354-
counts = {}
355-
for anno in cluster_annotations:
356-
counts[anno] = counts.get(anno, 0) + 1
357-
358-
# Find most common annotation and its frequency
359-
if counts:
360-
most_common = max(counts.items(), key=lambda x: x[1])
361-
most_common_count = most_common[1]
362-
agreement = most_common_count / len(cluster_annotations) if cluster_annotations else 0
363-
364-
# Mark as controversial if agreement is below threshold
365-
if agreement < threshold:
366-
controversial.append(cluster)
340+
# Filter clusters where agreement is below threshold
341+
controversial = [
342+
cluster
343+
for cluster, agreement in consensus_proportion.items()
344+
if agreement < threshold
345+
]
367346

368347
return controversial

python/mllmcelltype/providers/anthropic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ def process_anthropic_direct(
172172

173173
write_log("Falling back to direct API calls for Anthropic")
174174

175-
# 使用自定义URL或默认URL
175+
# Use custom URL or default URL
176176
if base_url:
177177
from ..url_utils import validate_base_url
178178

0 commit comments

Comments
 (0)