1+ import argparse
2+ import pytest
3+ import os
4+ from dicee .evaluator import evaluate_lp_k_vs_all
5+ from dicee .knowledge_graph import KG
6+ from retrieval_augmented_link_predictor import GCL
7+
class TestRegressionGCL:
    """Regression tests for the GCL (Graph Context Learning) model.

    Each test loads a benchmark knowledge graph from ``KGs/``, builds a GCL
    model that talks to the configured LLM endpoint, runs
    ``evaluate_lp_k_vs_all`` on (part of) the test split and checks every
    reported metric against a minimum score.

    The API key is read from the ``TENTRIS_TOKEN`` environment variable; a
    test fails with an explicit message when the variable is not set.
    """

    # Endpoint and model name shared by all regression tests.
    BASE_URL = "http://harebell.cs.upb.de:8501/v1"
    LLM_MODEL = "tentris"

    def _run_evaluation(self, dataset_dir, info, num_of_hops, max_triples=None):
        """Build a GCL model for ``dataset_dir`` and evaluate it.

        Args:
            dataset_dir: Directory of the dataset handed to ``KG``.
            info: Label forwarded to ``evaluate_lp_k_vs_all``.
            num_of_hops: Value passed to ``GCL`` as ``num_of_hops``.
            max_triples: When given, only the first ``max_triples`` entries
                of ``kg.test_set`` are evaluated; otherwise the whole test
                set is used.

        Returns:
            The object returned by ``evaluate_lp_k_vs_all``; the tests index
            it by metric name ('H@1', 'H@3', 'H@10', 'MRR').

        Raises:
            AssertionError: If ``TENTRIS_TOKEN`` is not set.
        """
        # Raw string: "\s" is an invalid escape sequence in a plain literal.
        kg = KG(
            dataset_dir=dataset_dir,
            separator=r"\s+",
            eval_model="train_value_test",
            add_reciprocal=False,
        )

        # Get API key from environment variable
        api_key = os.environ.get("TENTRIS_TOKEN")
        assert api_key is not None, "TENTRIS_TOKEN environment variable not set"

        # temperature and seed are pinned to the same values in every test.
        model = GCL(
            knowledge_graph=kg,
            base_url=self.BASE_URL,
            api_key=api_key,
            llm_model=self.LLM_MODEL,
            temperature=0.0,
            seed=42,
            num_of_hops=num_of_hops,
        )

        if max_triples is None:
            test_triples = kg.test_set
        else:
            test_triples = kg.test_set[:max_triples]

        return evaluate_lp_k_vs_all(
            model=model,
            triple_idx=test_triples,
            er_vocab=kg.er_vocab,
            info=info,
            batch_size=1,
        )

    @staticmethod
    def _assert_min_scores(results, minimums):
        """Assert that ``results[metric] >= minimum`` for every pair in ``minimums``.

        Args:
            results: Mapping from metric name to achieved score.
            minimums: Mapping from metric name to the lowest accepted score;
                metrics are checked in insertion order.

        Raises:
            AssertionError: On the first metric below its minimum.
        """
        for metric, minimum in minimums.items():
            assert results[metric] >= minimum, f"{metric} score too low: {results[metric]}"

    #@pytest.mark.filterwarnings('ignore::UserWarning')
    def test_countries_s1_hop3(self):
        """Test GCL on Countries-S1 dataset with 3 hops"""
        results = self._run_evaluation(
            dataset_dir="KGs/Countries-S1",
            info='Testing Countries-S1 with 3 hops',
            num_of_hops=3,
        )

        # Check results - the thresholds require perfect scores with 3 hops
        self._assert_min_scores(
            results, {'H@1': 1.0, 'H@3': 1.0, 'H@10': 1.0, 'MRR': 1.0}
        )

    def test_umls_hop3(self):
        """Test GCL on UMLS dataset with 3 hops"""
        # Only the first 24 test triples are evaluated.
        results = self._run_evaluation(
            dataset_dir="KGs/UMLS",
            info='Testing UMLS with 3 hops',
            num_of_hops=3,
            max_triples=24,
        )

        # Check results - based on typical performance for UMLS
        self._assert_min_scores(
            results, {'H@1': 0.1, 'H@3': 0.1, 'H@10': 0.1, 'MRR': 0.1}
        )

    def test_kinship_hop3(self):
        """Test GCL on KINSHIP dataset with 3 hops"""
        # Only the first 24 test triples are evaluated.
        results = self._run_evaluation(
            dataset_dir="KGs/KINSHIP",
            info='Testing KINSHIP with 3 hops',
            num_of_hops=3,
            max_triples=24,
        )

        # Check results - based on typical performance for KINSHIP
        self._assert_min_scores(
            results, {'H@1': 0.1, 'H@3': 0.08, 'H@10': 0.08, 'MRR': 0.1}
        )
0 commit comments