1+ import pytest
2+ import os
3+ import time
4+ from dicee .evaluator import evaluate_lp_k_vs_all
5+ from dicee .knowledge_graph import KG
6+ from retrieval_augmented_link_predictor import GCL , RCL
7+ import pandas as pd
8+ import json
9+ from datetime import datetime
10+
11+ class TestCompareGCLRCL :
12+ """
13+ Regression tests comparing GCL (with 3 hops) and RCL on various datasets.
14+ This test suite measures both model performance and runtime.
15+ """
16+
17+ # Store results for all test runs to create a comparison report at the end
18+ results_data = []
19+
20+ def setup_method (self ):
21+ """Setup for each test method: check for API key"""
22+ # Get API key from environment variable
23+ self .api_key = os .environ .get ("TENTRIS_TOKEN" )
24+ assert self .api_key is not None , "TENTRIS_TOKEN environment variable not set"
25+
26+ # Common API settings
27+ self .base_url = "http://harebell.cs.upb.de:8501/v1"
28+ self .llm_model = "tentris"
29+ self .temperature = 0.0
30+ self .seed = 42
31+
32+ # Test settings
33+ self .batch_size = 1
34+
35+ # Ensure the temp directory exists for saving results
36+ os .makedirs ("temp" , exist_ok = True )
37+
38+ def run_model_eval (self , model_name , model , kg , test_size = None , dataset_name = "Unknown" ):
39+ """Run evaluation for a model and record performance and runtime"""
40+ test_triples = kg .test_set if test_size is None else kg .test_set [:test_size ]
41+
42+ # Start timer
43+ start_time = time .time ()
44+
45+ # Run evaluation
46+ results = evaluate_lp_k_vs_all (
47+ model = model ,
48+ triple_idx = test_triples ,
49+ er_vocab = kg .er_vocab ,
50+ info = f'Testing { model_name } on { dataset_name } ' ,
51+ batch_size = self .batch_size
52+ )
53+
54+ # End timer
55+ end_time = time .time ()
56+ runtime = end_time - start_time
57+
58+ # Add runtime to results
59+ results ['Runtime' ] = runtime
60+
61+ # Store results for comparison
62+ self .results_data .append ({
63+ 'Dataset' : dataset_name ,
64+ 'Model' : model_name ,
65+ 'H@1' : results ['H@1' ],
66+ 'H@3' : results ['H@3' ],
67+ 'H@10' : results ['H@10' ],
68+ 'MRR' : results ['MRR' ],
69+ 'Runtime (s)' : runtime ,
70+ 'Test Size' : len (test_triples )
71+ })
72+
73+ return results
74+
75+ def test_countries_s1 (self ):
76+ """Compare GCL and RCL on Countries-S1 dataset"""
77+ # Setup
78+ dataset_name = "Countries-S1"
79+ kg = KG (dataset_dir = f"KGs/{ dataset_name } " , separator = "\s+" , eval_model = "train_value_test" , add_reciprocal = False )
80+
81+ # Test GCL with 3 hops
82+ gcl_model = GCL (
83+ knowledge_graph = kg ,
84+ base_url = self .base_url ,
85+ api_key = self .api_key ,
86+ llm_model = self .llm_model ,
87+ temperature = self .temperature ,
88+ seed = self .seed ,
89+ num_of_hops = 3
90+ )
91+
92+ gcl_results = self .run_model_eval ("GCL-3hops" , gcl_model , kg , dataset_name = dataset_name )
93+
94+ # Test RCL with exclude_source=True (default)
95+ rcl_model = RCL (
96+ knowledge_graph = kg ,
97+ base_url = self .base_url ,
98+ api_key = self .api_key ,
99+ llm_model = self .llm_model ,
100+ temperature = self .temperature ,
101+ seed = self .seed ,
102+ max_relation_examples = 50 ,
103+ exclude_source = True
104+ )
105+
106+ rcl_results = self .run_model_eval ("RCL-exclude" , rcl_model , kg , dataset_name = dataset_name )
107+
108+ # Test RCL with exclude_source=False
109+ rcl_include_model = RCL (
110+ knowledge_graph = kg ,
111+ base_url = self .base_url ,
112+ api_key = self .api_key ,
113+ llm_model = self .llm_model ,
114+ temperature = self .temperature ,
115+ seed = self .seed ,
116+ max_relation_examples = 50 ,
117+ exclude_source = False
118+ )
119+
120+ rcl_include_results = self .run_model_eval ("RCL-include" , rcl_include_model , kg , dataset_name = dataset_name )
121+
122+ # Check expectations
123+ assert gcl_results ['H@1' ] >= 0.9 , f"GCL H@1 score too low: { gcl_results ['H@1' ]} "
124+ assert rcl_results ['H@1' ] >= 0.9 , f"RCL H@1 score too low: { rcl_results ['H@1' ]} "
125+ assert rcl_include_results ['H@1' ] >= 0.9 , f"RCL-include H@1 score too low: { rcl_include_results ['H@1' ]} "
126+
127+ def test_countries_s2 (self ):
128+ """Compare GCL and RCL on Countries-S2 dataset"""
129+ # Setup
130+ dataset_name = "Countries-S2"
131+ kg = KG (dataset_dir = f"KGs/{ dataset_name } " , separator = "\s+" , eval_model = "train_value_test" , add_reciprocal = False )
132+
133+ # Test GCL with 3 hops
134+ gcl_model = GCL (
135+ knowledge_graph = kg ,
136+ base_url = self .base_url ,
137+ api_key = self .api_key ,
138+ llm_model = self .llm_model ,
139+ temperature = self .temperature ,
140+ seed = self .seed ,
141+ num_of_hops = 3
142+ )
143+
144+ gcl_results = self .run_model_eval ("GCL-3hops" , gcl_model , kg , dataset_name = dataset_name )
145+
146+ # Test RCL with exclude_source=True (default)
147+ rcl_model = RCL (
148+ knowledge_graph = kg ,
149+ base_url = self .base_url ,
150+ api_key = self .api_key ,
151+ llm_model = self .llm_model ,
152+ temperature = self .temperature ,
153+ seed = self .seed ,
154+ max_relation_examples = 50 ,
155+ exclude_source = True
156+ )
157+
158+ rcl_results = self .run_model_eval ("RCL-exclude" , rcl_model , kg , dataset_name = dataset_name )
159+
160+ # Test RCL with exclude_source=False
161+ rcl_include_model = RCL (
162+ knowledge_graph = kg ,
163+ base_url = self .base_url ,
164+ api_key = self .api_key ,
165+ llm_model = self .llm_model ,
166+ temperature = self .temperature ,
167+ seed = self .seed ,
168+ max_relation_examples = 50 ,
169+ exclude_source = False
170+ )
171+
172+ rcl_include_results = self .run_model_eval ("RCL-include" , rcl_include_model , kg , dataset_name = dataset_name )
173+
174+ # Check expectations for S2 (slightly more challenging than S1)
175+ assert gcl_results ['H@1' ] >= 0.85 , f"GCL H@1 score too low: { gcl_results ['H@1' ]} "
176+ assert rcl_results ['H@1' ] >= 0.85 , f"RCL H@1 score too low: { rcl_results ['H@1' ]} "
177+ assert rcl_include_results ['H@1' ] >= 0.85 , f"RCL-include H@1 score too low: { rcl_include_results ['H@1' ]} "
178+
179+ def test_umls (self ):
180+ """Compare GCL and RCL on UMLS dataset with limited test set"""
181+ # Setup
182+ dataset_name = "UMLS"
183+ kg = KG (dataset_dir = f"KGs/{ dataset_name } " , separator = "\s+" , eval_model = "train_value_test" , add_reciprocal = False )
184+
185+ # Limit test size for UMLS to save time
186+ test_size = 20
187+
188+ # Test GCL with 3 hops
189+ gcl_model = GCL (
190+ knowledge_graph = kg ,
191+ base_url = self .base_url ,
192+ api_key = self .api_key ,
193+ llm_model = self .llm_model ,
194+ temperature = self .temperature ,
195+ seed = self .seed ,
196+ num_of_hops = 3
197+ )
198+
199+ gcl_results = self .run_model_eval ("GCL-3hops" , gcl_model , kg , test_size = test_size , dataset_name = dataset_name )
200+
201+ # Test RCL with exclude_source=True (default)
202+ rcl_model = RCL (
203+ knowledge_graph = kg ,
204+ base_url = self .base_url ,
205+ api_key = self .api_key ,
206+ llm_model = self .llm_model ,
207+ temperature = self .temperature ,
208+ seed = self .seed ,
209+ max_relation_examples = 50 ,
210+ exclude_source = True
211+ )
212+
213+ rcl_results = self .run_model_eval ("RCL-exclude" , rcl_model , kg , test_size = test_size , dataset_name = dataset_name )
214+
215+ # Test RCL with exclude_source=False
216+ rcl_include_model = RCL (
217+ knowledge_graph = kg ,
218+ base_url = self .base_url ,
219+ api_key = self .api_key ,
220+ llm_model = self .llm_model ,
221+ temperature = self .temperature ,
222+ seed = self .seed ,
223+ max_relation_examples = 50 ,
224+ exclude_source = False
225+ )
226+
227+ rcl_include_results = self .run_model_eval ("RCL-include" , rcl_include_model , kg , test_size = test_size , dataset_name = dataset_name )
228+
229+ # Check expectations for UMLS
230+ assert gcl_results ['H@1' ] >= 0.45 , f"GCL H@1 score too low: { gcl_results ['H@1' ]} "
231+ assert rcl_results ['H@1' ] >= 0.45 , f"RCL H@1 score too low: { rcl_results ['H@1' ]} "
232+ assert rcl_include_results ['H@1' ] >= 0.45 , f"RCL-include H@1 score too low: { rcl_include_results ['H@1' ]} "
233+
234+ def test_kinship (self ):
235+ """Compare GCL and RCL on KINSHIP dataset with limited test set"""
236+ # Setup
237+ dataset_name = "KINSHIP"
238+ kg = KG (dataset_dir = f"KGs/{ dataset_name } " , separator = "\s+" , eval_model = "train_value_test" , add_reciprocal = False )
239+
240+ # Limit test size for KINSHIP to save time
241+ test_size = 20
242+
243+ # Test GCL with 3 hops
244+ gcl_model = GCL (
245+ knowledge_graph = kg ,
246+ base_url = self .base_url ,
247+ api_key = self .api_key ,
248+ llm_model = self .llm_model ,
249+ temperature = self .temperature ,
250+ seed = self .seed ,
251+ num_of_hops = 3
252+ )
253+
254+ gcl_results = self .run_model_eval ("GCL-3hops" , gcl_model , kg , test_size = test_size , dataset_name = dataset_name )
255+
256+ # Test RCL with exclude_source=True (default)
257+ rcl_model = RCL (
258+ knowledge_graph = kg ,
259+ base_url = self .base_url ,
260+ api_key = self .api_key ,
261+ llm_model = self .llm_model ,
262+ temperature = self .temperature ,
263+ seed = self .seed ,
264+ max_relation_examples = 50 ,
265+ exclude_source = True
266+ )
267+
268+ rcl_results = self .run_model_eval ("RCL-exclude" , rcl_model , kg , test_size = test_size , dataset_name = dataset_name )
269+
270+ # Test RCL with exclude_source=False
271+ rcl_include_model = RCL (
272+ knowledge_graph = kg ,
273+ base_url = self .base_url ,
274+ api_key = self .api_key ,
275+ llm_model = self .llm_model ,
276+ temperature = self .temperature ,
277+ seed = self .seed ,
278+ max_relation_examples = 50 ,
279+ exclude_source = False
280+ )
281+
282+ rcl_include_results = self .run_model_eval ("RCL-include" , rcl_include_model , kg , test_size = test_size , dataset_name = dataset_name )
283+
284+ # Check expectations for KINSHIP
285+ assert gcl_results ['H@1' ] >= 0.05 , f"GCL H@1 score too low: { gcl_results ['H@1' ]} "
286+ assert rcl_results ['H@1' ] >= 0.05 , f"RCL H@1 score too low: { rcl_results ['H@1' ]} "
287+ assert rcl_include_results ['H@1' ] >= 0.05 , f"RCL-include H@1 score too low: { rcl_include_results ['H@1' ]} "
288+
289+ def teardown_method (self ):
290+ """After each test method, print current results"""
291+ if len (self .results_data ) % 3 == 0 : # Print after each dataset's tests complete
292+ current_df = pd .DataFrame (self .results_data [- 3 :])
293+ print (f"\n Latest results:" )
294+ print (current_df .to_string (index = False ))
295+
296+ @pytest .fixture (scope = "session" , autouse = True )
297+ def save_comparison_results (self , request ):
298+ """Save all results at the end of the test session"""
299+ def finalize ():
300+ if TestCompareGCLRCL .results_data :
301+ # Create a timestamp for the results file
302+ timestamp = datetime .now ().strftime ("%Y%m%d_%H%M%S" )
303+ results_df = pd .DataFrame (TestCompareGCLRCL .results_data )
304+
305+ # Save to CSV in temp directory
306+ output_file = f"temp/compare_gcl_rcl_results_{ timestamp } .csv"
307+ results_df .to_csv (output_file , index = False )
308+ print (f"\n Complete comparison results saved to { output_file } " )
309+
310+ # Also save as JSON for easier programmatic access
311+ json_file = f"temp/compare_gcl_rcl_results_{ timestamp } .json"
312+ results_df .to_json (json_file , orient = "records" , indent = 2 )
313+
314+ # Print the final comparison table
315+ print ("\n Final Comparison Results:" )
316+ print (results_df .to_string (index = False ))
317+
318+ # Create summary by dataset/model
319+ summary = results_df .groupby (['Dataset' , 'Model' ]).mean ().reset_index ()
320+ print ("\n Average Metrics by Dataset and Model:" )
321+ print (summary .to_string (index = False ))
322+
323+ request .addfinalizer (finalize )