Commit 48530be

comparison tests added for gcl and rcl
1 parent 771b99f commit 48530be

1 file changed

Lines changed: 323 additions & 0 deletions

File tree

tests/test_compare_gcl_rcl.py

@@ -0,0 +1,323 @@
import pytest
import os
import time
from dicee.evaluator import evaluate_lp_k_vs_all
from dicee.knowledge_graph import KG
from retrieval_augmented_link_predictor import GCL, RCL
import pandas as pd
from datetime import datetime


class TestCompareGCLRCL:
    """
    Regression tests comparing GCL (with 3 hops) and RCL on various datasets.
    This test suite measures both model performance and runtime.
    """

    # Store results for all test runs to create a comparison report at the end
    results_data = []

    def setup_method(self):
        """Setup for each test method: check for API key"""
        # Get API key from environment variable
        self.api_key = os.environ.get("TENTRIS_TOKEN")
        assert self.api_key is not None, "TENTRIS_TOKEN environment variable not set"

        # Common API settings
        self.base_url = "http://harebell.cs.upb.de:8501/v1"
        self.llm_model = "tentris"
        self.temperature = 0.0
        self.seed = 42

        # Test settings
        self.batch_size = 1

        # Ensure the temp directory exists for saving results
        os.makedirs("temp", exist_ok=True)

    def run_model_eval(self, model_name, model, kg, test_size=None, dataset_name="Unknown"):
        """Run evaluation for a model and record performance and runtime"""
        test_triples = kg.test_set if test_size is None else kg.test_set[:test_size]

        # Start timer
        start_time = time.time()

        # Run evaluation
        results = evaluate_lp_k_vs_all(
            model=model,
            triple_idx=test_triples,
            er_vocab=kg.er_vocab,
            info=f'Testing {model_name} on {dataset_name}',
            batch_size=self.batch_size
        )

        # End timer
        end_time = time.time()
        runtime = end_time - start_time

        # Add runtime to results
        results['Runtime'] = runtime

        # Store results for comparison
        self.results_data.append({
            'Dataset': dataset_name,
            'Model': model_name,
            'H@1': results['H@1'],
            'H@3': results['H@3'],
            'H@10': results['H@10'],
            'MRR': results['MRR'],
            'Runtime (s)': runtime,
            'Test Size': len(test_triples)
        })

        return results

    def test_countries_s1(self):
        """Compare GCL and RCL on Countries-S1 dataset"""
        # Setup
        dataset_name = "Countries-S1"
        kg = KG(dataset_dir=f"KGs/{dataset_name}", separator=r"\s+", eval_model="train_value_test", add_reciprocal=False)

        # Test GCL with 3 hops
        gcl_model = GCL(
            knowledge_graph=kg,
            base_url=self.base_url,
            api_key=self.api_key,
            llm_model=self.llm_model,
            temperature=self.temperature,
            seed=self.seed,
            num_of_hops=3
        )

        gcl_results = self.run_model_eval("GCL-3hops", gcl_model, kg, dataset_name=dataset_name)

        # Test RCL with exclude_source=True (default)
        rcl_model = RCL(
            knowledge_graph=kg,
            base_url=self.base_url,
            api_key=self.api_key,
            llm_model=self.llm_model,
            temperature=self.temperature,
            seed=self.seed,
            max_relation_examples=50,
            exclude_source=True
        )

        rcl_results = self.run_model_eval("RCL-exclude", rcl_model, kg, dataset_name=dataset_name)

        # Test RCL with exclude_source=False
        rcl_include_model = RCL(
            knowledge_graph=kg,
            base_url=self.base_url,
            api_key=self.api_key,
            llm_model=self.llm_model,
            temperature=self.temperature,
            seed=self.seed,
            max_relation_examples=50,
            exclude_source=False
        )

        rcl_include_results = self.run_model_eval("RCL-include", rcl_include_model, kg, dataset_name=dataset_name)

        # Check expectations
        assert gcl_results['H@1'] >= 0.9, f"GCL H@1 score too low: {gcl_results['H@1']}"
        assert rcl_results['H@1'] >= 0.9, f"RCL H@1 score too low: {rcl_results['H@1']}"
        assert rcl_include_results['H@1'] >= 0.9, f"RCL-include H@1 score too low: {rcl_include_results['H@1']}"

    def test_countries_s2(self):
        """Compare GCL and RCL on Countries-S2 dataset"""
        # Setup
        dataset_name = "Countries-S2"
        kg = KG(dataset_dir=f"KGs/{dataset_name}", separator=r"\s+", eval_model="train_value_test", add_reciprocal=False)

        # Test GCL with 3 hops
        gcl_model = GCL(
            knowledge_graph=kg,
            base_url=self.base_url,
            api_key=self.api_key,
            llm_model=self.llm_model,
            temperature=self.temperature,
            seed=self.seed,
            num_of_hops=3
        )

        gcl_results = self.run_model_eval("GCL-3hops", gcl_model, kg, dataset_name=dataset_name)

        # Test RCL with exclude_source=True (default)
        rcl_model = RCL(
            knowledge_graph=kg,
            base_url=self.base_url,
            api_key=self.api_key,
            llm_model=self.llm_model,
            temperature=self.temperature,
            seed=self.seed,
            max_relation_examples=50,
            exclude_source=True
        )

        rcl_results = self.run_model_eval("RCL-exclude", rcl_model, kg, dataset_name=dataset_name)

        # Test RCL with exclude_source=False
        rcl_include_model = RCL(
            knowledge_graph=kg,
            base_url=self.base_url,
            api_key=self.api_key,
            llm_model=self.llm_model,
            temperature=self.temperature,
            seed=self.seed,
            max_relation_examples=50,
            exclude_source=False
        )

        rcl_include_results = self.run_model_eval("RCL-include", rcl_include_model, kg, dataset_name=dataset_name)

        # Check expectations for S2 (slightly more challenging than S1)
        assert gcl_results['H@1'] >= 0.85, f"GCL H@1 score too low: {gcl_results['H@1']}"
        assert rcl_results['H@1'] >= 0.85, f"RCL H@1 score too low: {rcl_results['H@1']}"
        assert rcl_include_results['H@1'] >= 0.85, f"RCL-include H@1 score too low: {rcl_include_results['H@1']}"

    def test_umls(self):
        """Compare GCL and RCL on UMLS dataset with limited test set"""
        # Setup
        dataset_name = "UMLS"
        kg = KG(dataset_dir=f"KGs/{dataset_name}", separator=r"\s+", eval_model="train_value_test", add_reciprocal=False)

        # Limit test size for UMLS to save time
        test_size = 20

        # Test GCL with 3 hops
        gcl_model = GCL(
            knowledge_graph=kg,
            base_url=self.base_url,
            api_key=self.api_key,
            llm_model=self.llm_model,
            temperature=self.temperature,
            seed=self.seed,
            num_of_hops=3
        )

        gcl_results = self.run_model_eval("GCL-3hops", gcl_model, kg, test_size=test_size, dataset_name=dataset_name)

        # Test RCL with exclude_source=True (default)
        rcl_model = RCL(
            knowledge_graph=kg,
            base_url=self.base_url,
            api_key=self.api_key,
            llm_model=self.llm_model,
            temperature=self.temperature,
            seed=self.seed,
            max_relation_examples=50,
            exclude_source=True
        )

        rcl_results = self.run_model_eval("RCL-exclude", rcl_model, kg, test_size=test_size, dataset_name=dataset_name)

        # Test RCL with exclude_source=False
        rcl_include_model = RCL(
            knowledge_graph=kg,
            base_url=self.base_url,
            api_key=self.api_key,
            llm_model=self.llm_model,
            temperature=self.temperature,
            seed=self.seed,
            max_relation_examples=50,
            exclude_source=False
        )

        rcl_include_results = self.run_model_eval("RCL-include", rcl_include_model, kg, test_size=test_size, dataset_name=dataset_name)

        # Check expectations for UMLS
        assert gcl_results['H@1'] >= 0.45, f"GCL H@1 score too low: {gcl_results['H@1']}"
        assert rcl_results['H@1'] >= 0.45, f"RCL H@1 score too low: {rcl_results['H@1']}"
        assert rcl_include_results['H@1'] >= 0.45, f"RCL-include H@1 score too low: {rcl_include_results['H@1']}"

    def test_kinship(self):
        """Compare GCL and RCL on KINSHIP dataset with limited test set"""
        # Setup
        dataset_name = "KINSHIP"
        kg = KG(dataset_dir=f"KGs/{dataset_name}", separator=r"\s+", eval_model="train_value_test", add_reciprocal=False)

        # Limit test size for KINSHIP to save time
        test_size = 20

        # Test GCL with 3 hops
        gcl_model = GCL(
            knowledge_graph=kg,
            base_url=self.base_url,
            api_key=self.api_key,
            llm_model=self.llm_model,
            temperature=self.temperature,
            seed=self.seed,
            num_of_hops=3
        )

        gcl_results = self.run_model_eval("GCL-3hops", gcl_model, kg, test_size=test_size, dataset_name=dataset_name)

        # Test RCL with exclude_source=True (default)
        rcl_model = RCL(
            knowledge_graph=kg,
            base_url=self.base_url,
            api_key=self.api_key,
            llm_model=self.llm_model,
            temperature=self.temperature,
            seed=self.seed,
            max_relation_examples=50,
            exclude_source=True
        )

        rcl_results = self.run_model_eval("RCL-exclude", rcl_model, kg, test_size=test_size, dataset_name=dataset_name)

        # Test RCL with exclude_source=False
        rcl_include_model = RCL(
            knowledge_graph=kg,
            base_url=self.base_url,
            api_key=self.api_key,
            llm_model=self.llm_model,
            temperature=self.temperature,
            seed=self.seed,
            max_relation_examples=50,
            exclude_source=False
        )

        rcl_include_results = self.run_model_eval("RCL-include", rcl_include_model, kg, test_size=test_size, dataset_name=dataset_name)

        # Check expectations for KINSHIP
        assert gcl_results['H@1'] >= 0.05, f"GCL H@1 score too low: {gcl_results['H@1']}"
        assert rcl_results['H@1'] >= 0.05, f"RCL H@1 score too low: {rcl_results['H@1']}"
        assert rcl_include_results['H@1'] >= 0.05, f"RCL-include H@1 score too low: {rcl_include_results['H@1']}"

    def teardown_method(self):
        """After each test method, print current results"""
        if len(self.results_data) % 3 == 0:  # Print after each dataset's tests complete
            current_df = pd.DataFrame(self.results_data[-3:])
            print("\nLatest results:")
            print(current_df.to_string(index=False))

    @pytest.fixture(scope="session", autouse=True)
    def save_comparison_results(self, request):
        """Save all results at the end of the test session"""
        def finalize():
            if TestCompareGCLRCL.results_data:
                # Create a timestamp for the results file
                timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                results_df = pd.DataFrame(TestCompareGCLRCL.results_data)

                # Save to CSV in temp directory
                output_file = f"temp/compare_gcl_rcl_results_{timestamp}.csv"
                results_df.to_csv(output_file, index=False)
                print(f"\nComplete comparison results saved to {output_file}")

                # Also save as JSON for easier programmatic access
                json_file = f"temp/compare_gcl_rcl_results_{timestamp}.json"
                results_df.to_json(json_file, orient="records", indent=2)

                # Print the final comparison table
                print("\nFinal Comparison Results:")
                print(results_df.to_string(index=False))

                # Create summary by dataset/model
                summary = results_df.groupby(['Dataset', 'Model']).mean().reset_index()
                print("\nAverage Metrics by Dataset and Model:")
                print(summary.to_string(index=False))

        request.addfinalizer(finalize)
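
Note on running the suite: these tests call a live LLM endpoint, so they require the TENTRIS_TOKEN environment variable and network access to the configured base_url, and they expect the KGs/ datasets on disk. A typical invocation, assuming pytest is installed, would be: pytest tests/test_compare_gcl_rcl.py -s. The -s flag disables pytest's output capture so the per-dataset tables and the final comparison report are printed; the CSV and JSON result files are written to the temp/ directory.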
