import os
import json
import pytest
import torch
import tempfile
import shutil
from retrieval_aug_predictors.models import KG
from retrieval_aug_predictors.models.demir_ensemble_mipro import DemirEnsembleMPRO
from dicee.evaluator import evaluate_lp_k_vs_all

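
# These tests call a live LLM endpoint, so they cannot pass without a credential.
# A skip guard (a sketch, assuming TENTRIS_TOKEN is the only required secret; the
# project may prefer a hard failure instead) keeps credential-less runs green.
@pytest.mark.skipif(
    os.getenv("TENTRIS_TOKEN") is None,
    reason="TENTRIS_TOKEN is not set; these regression tests call a live LLM",
)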
class TestDemirEnsembleMPRORegression:
    @classmethod
    def setup_class(cls):
        # Create a temporary directory for model outputs
        cls.temp_dir = tempfile.mkdtemp()

        # Configure model parameters
        cls.llm_model = "tentris"
        cls.api_key = os.getenv("TENTRIS_TOKEN")
        cls.base_url = os.getenv("OPENAI_API_BASE", "http://harebell.cs.upb.de:8501/v1")
        cls.temperature = 0.0
        cls.seed = 42

        # Expected benchmark results, taken from the comment at the top of the
        # demir_ensemble_mipro module
        cls.expected_results = {
            "Countries-S1": {
                "H@1": 0.75,
                "H@3": 0.875,
                "H@10": 1.0,
                "MRR": 0.8416666666666667,
            },
            "Countries-S2": {
                "H@1": 0.75,
                "H@3": 1.0,
                "H@10": 1.0,
                "MRR": 0.8680555555555555,
            },
            "Countries-S3": {
                "H@1": 0.041666666666666664,
                "H@3": 0.4583333333333333,
                "H@10": 0.625,
                "MRR": 0.2626660300405415,
            },
        }

        # Dataset directories
        cls.dataset_dirs = {
            "Countries-S1": "KGs/Countries-S1",
            "Countries-S2": "KGs/Countries-S2",
            "Countries-S3": "KGs/Countries-S3",
        }

    @classmethod
    def teardown_class(cls):
        # Clean up the temporary directory
        shutil.rmtree(cls.temp_dir)

    @pytest.mark.parametrize("dataset_name", ["Countries-S1", "Countries-S2", "Countries-S3"])
    def test_model_performance(self, dataset_name):
        """Test model performance against benchmarks for each dataset."""
        dataset_dir = self.dataset_dirs[dataset_name]
        expected_metrics = self.expected_results[dataset_name]

        # Create a dataset-specific save directory
        save_dir = os.path.join(self.temp_dir, dataset_name)
        os.makedirs(save_dir, exist_ok=True)

        # The separator is a regex, so use a raw string to avoid an
        # invalid-escape warning for "\s"
        kg = KG(dataset_dir=dataset_dir, separator=r"\s+", eval_model="KvsAll", add_reciprocal=False)

        model = DemirEnsembleMPRO(
            knowledge_graph=kg,
            base_url=self.base_url,
            api_key=self.api_key,
            llm_model=self.llm_model,
            temperature=self.temperature,
            seed=self.seed,
            use_val=True,
            ensemble_temperatures=[0.0],  # a single temperature keeps the test fast
            save_dir=save_dir,
        )

        # Use the full test set to match the original experiments
        test_triples = kg.test_set

        # Run evaluation
        results = evaluate_lp_k_vs_all(
            model=model,
            triple_idx=test_triples,
            er_vocab=kg.er_vocab,
            info=f"Regression Test (DemirEnsembleMIPRO) - {dataset_name}",
        )

        # Save test results for inspection
        results_file = os.path.join(save_dir, f"test_results_{dataset_name}.json")
        with open(results_file, "w") as f:
            json.dump(results, f, indent=2)

        print(f"\nResults for {dataset_name}:")
        print(json.dumps(results, indent=2))
        print("Expected results:")
        print(json.dumps(expected_metrics, indent=2))

        # Check that the results report exactly the expected metrics
        assert set(results.keys()) == set(expected_metrics.keys())

        # For regression testing, results must be at least as good as the
        # benchmarks: no tolerance is applied.
        for metric in expected_metrics:
            # For H@k metrics and MRR, higher is better
            if metric.startswith("H@") or metric == "MRR":
                assert results[metric] >= expected_metrics[metric], (
                    f"Performance regression in {dataset_name} - {metric}: "
                    f"got {results[metric]}, expected at least {expected_metrics[metric]}"
                )
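
# Note: even with temperature 0.0 and a fixed seed, LLM-backed scores are not
# guaranteed to be bit-identical across runs. If the exact-floor check above
# proves flaky, a relaxed variant (a sketch, not the project's stated policy)
# could subtract a small epsilon:
#   assert results[metric] >= expected_metrics[metric] - 1e-6
# To run a single dataset (hypothetical file name; adjust to this file's path):
#   pytest -q tests/test_demir_ensemble_mipro.py -k "S1"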