
Commit bc33305

chainOfThought + regression test

1 parent aa4abc1 commit bc33305

2 files changed
Lines changed: 123 additions & 2 deletions

File tree

retrieval_aug_predictors/models/demir_ensemble_mipro.py

Lines changed: 8 additions & 2 deletions
@@ -265,8 +265,8 @@ def __init__(self, entities: List[str],g:GraphType):
         super().__init__()
         self.entities = sorted(list(set(entities)))
         self.g = g
-        self.finder = dspy.Predict(EntityFinder)
-        self.scorer = dspy.Predict(Scorer)
+        self.finder = dspy.ChainOfThought(EntityFinder)
+        self.scorer = dspy.ChainOfThought(Scorer)

     def _graph_based_content_builder(self,subject:str):
         hop_to_triples = dict()
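
The first hunk swaps dspy.Predict for dspy.ChainOfThought on both predictors. In DSPy, ChainOfThought wraps the same signature as Predict but prompts the LM to produce an intermediate reasoning field before the declared outputs, which typically helps on multi-step tasks like entity scoring. A minimal sketch of the difference; the real EntityFinder and Scorer signatures are not visible in this diff, so the fields below are placeholders:

import dspy

class Scorer(dspy.Signature):
    """Placeholder signature; the real Scorer fields are not shown in this diff."""
    context: str = dspy.InputField()
    candidate: str = dspy.InputField()
    score: float = dspy.OutputField()

direct = dspy.Predict(Scorer)      # generates the output fields directly
cot = dspy.ChainOfThought(Scorer)  # generates a rationale first, then the outputs

pred = cot(context="...", candidate="...")
# pred.reasoning holds the rationale (named `rationale` in older DSPy releases);
# pred.score holds the typed answer, exactly as with Predict.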
@@ -435,6 +435,12 @@ def forward_k_vs_all(self, x: torch.LongTensor) -> torch.FloatTensor:
         batch_predictions = []
         num_entities = len(self.idx_to_entity)

+        # Configure LM for prediction
+        lm = dspy.LM(model=f"openai/{self.llm_model}", api_key=self.api_key, api_base=self.base_url,
+                     seed=self.seed, temperature=self.mipro_optimizer_temperature,
+                     cache=True, cache_in_memory=True)
+        dspy.configure(lm=lm)
+
         # Use tqdm for progress visualization
         for hr in tqdm(x.tolist(), desc="Predicting Batches (K vs All)"):
             idx_h, idx_r = hr
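
The second hunk constructs the LM inside forward_k_vs_all and installs it with dspy.configure, so the ChainOfThought predictors above pick it up; a fixed seed, temperature 0, and the caches make repeated prediction runs reproducible and cheaper. Note that dspy.configure sets a process-wide default; if the override should only last for the prediction loop, DSPy also offers a scoped alternative. A sketch of both options, reusing the model name and endpoint that appear in the regression test below:

import dspy

lm = dspy.LM(model="openai/tentris",                       # llm_model from the test
             api_key="<TENTRIS_TOKEN>",
             api_base="http://harebell.cs.upb.de:8501/v1",
             seed=42, temperature=0.0,
             cache=True, cache_in_memory=True)

# Global default, as in this commit: every DSPy module uses `lm` from here on.
dspy.configure(lm=lm)

# Scoped override: `lm` applies only inside the block, then the old default returns.
with dspy.context(lm=lm):
    ...  # calls to self.finder / self.scorer here would use `lm`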
Lines changed: 115 additions & 0 deletions
@@ -0,0 +1,115 @@
+import os
+import json
+import pytest
+import torch
+import tempfile
+import shutil
+from retrieval_aug_predictors.models import KG
+from retrieval_aug_predictors.models.demir_ensemble_mipro import DemirEnsembleMPRO
+from dicee.evaluator import evaluate_lp_k_vs_all
+
+class TestDemirEnsembleMPRORegression:
+    @classmethod
+    def setup_class(cls):
+        # Create a temporary directory for model outputs
+        cls.temp_dir = tempfile.mkdtemp()
+
+        # Configure model parameters
+        cls.llm_model = "tentris"
+        cls.api_key = os.getenv("TENTRIS_TOKEN")
+        cls.base_url = os.getenv("OPENAI_API_BASE", "http://harebell.cs.upb.de:8501/v1")
+        cls.temperature = 0.0
+        cls.seed = 42
+
+        # Expected benchmark results, taken from the comment at the top of the demir file
+        cls.expected_results = {
+            "Countries-S1": {
+                "H@1": 0.75,
+                "H@3": 0.875,
+                "H@10": 1.0,
+                "MRR": 0.8416666666666667
+            },
+            "Countries-S2": {
+                "H@1": 0.75,
+                "H@3": 1.0,
+                "H@10": 1.0,
+                "MRR": 0.8680555555555555
+            },
+            "Countries-S3": {
+                "H@1": 0.041666666666666664,
+                "H@3": 0.4583333333333333,
+                "H@10": 0.625,
+                "MRR": 0.2626660300405415
+            }
+        }
+
+        # Dataset directories
+        cls.dataset_dirs = {
+            "Countries-S1": "KGs/Countries-S1",
+            "Countries-S2": "KGs/Countries-S2",
+            "Countries-S3": "KGs/Countries-S3"
+        }
+
+    @classmethod
+    def teardown_class(cls):
+        # Clean up temporary directory
+        shutil.rmtree(cls.temp_dir)
+
+    @pytest.mark.parametrize("dataset_name", ["Countries-S1", "Countries-S2", "Countries-S3"])
+    def test_model_performance(self, dataset_name):
+        """Test model performance against benchmarks for each dataset."""
+        dataset_dir = self.dataset_dirs[dataset_name]
+        expected_metrics = self.expected_results[dataset_name]
+
+        # Create a dataset-specific save directory
+        save_dir = os.path.join(self.temp_dir, dataset_name)
+        os.makedirs(save_dir, exist_ok=True)
+
+        kg = KG(dataset_dir=dataset_dir, separator=r"\s+", eval_model="KvsAll", add_reciprocal=False)
+
+        model = DemirEnsembleMPRO(
+            knowledge_graph=kg,
+            base_url=self.base_url,
+            api_key=self.api_key,
+            llm_model=self.llm_model,
+            temperature=self.temperature,
+            seed=self.seed,
+            use_val=True,
+            ensemble_temperatures=[0.0],  # Use a single temperature for faster testing
+            save_dir=save_dir,
+        )
+
+        # Use the full test set to match the original experiments
+        test_triples = kg.test_set
+
+        # Run evaluation
+        results = evaluate_lp_k_vs_all(
+            model=model,
+            triple_idx=test_triples,
+            er_vocab=kg.er_vocab,
+            info=f"Regression Test (DemirEnsembleMIPRO) - {dataset_name}"
+        )
+
+        # Save test results for inspection
+        results_file = os.path.join(save_dir, f"test_results_{dataset_name}.json")
+        with open(results_file, "w") as f:
+            json.dump(results, f, indent=2)
+
+        print(f"\nResults for {dataset_name}:")
+        print(json.dumps(results, indent=2))
+        print("Expected results:")
+        print(json.dumps(expected_metrics, indent=2))
+
+        # Check that results report exactly the expected metrics
+        assert set(results.keys()) == set(expected_metrics.keys())
+
+        # Regression check: every metric must match or beat its benchmark (no tolerance).
+        for metric in expected_metrics:
+            # For "H@" metrics and MRR, higher is better
+            if metric.startswith("H@") or metric == "MRR":
+                assert results[metric] >= expected_metrics[metric], \
+                    f"Performance regression in {dataset_name} - {metric}: " \
+                    f"got {results[metric]}, expected at least {expected_metrics[metric]}"
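With the commit checked out, the regression suite can be driven through pytest. The test module's path is not shown on this page, so the filename below is a placeholder; TENTRIS_TOKEN must be set because setup_class reads it from the environment, and OPENAI_API_BASE falls back to the harebell endpoint if unset:

# Hypothetical runner; the actual test file path is not visible in this diff.
import os
import sys
import pytest

os.environ.setdefault("TENTRIS_TOKEN", "<api token>")  # consumed by setup_class

# -k selects one parametrized split; drop it to run all three Countries datasets.
sys.exit(pytest.main(["tests/test_demir_ensemble_mipro.py", "-k", "Countries-S1", "-s"]))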
0 commit comments
