Skip to content

Commit 0fb586d

Browse files
committed
prompt update + few regression tests
1 parent bf91ecd commit 0fb586d

2 files changed

Lines changed: 115 additions & 0 deletions

File tree

retrieval_augmented_link_predictor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -490,6 +490,7 @@ def _create_prompt_based_on_neighbours(self, source: str, relation: str) -> str:
490490
5. If certain entities are not suitable for this relation, don't include them.
491491
6. Return a valid JSON output.
492492
7. Make sure scores are floating point numbers between 0 and 1, not strings.
493+
8. A score can only be between 0 and 1, i.e. score ∈ [0, 1]. They can never be negative or greater than 1!
493494
"""
494495
return base_prompt
495496

tests/test_regression_gcl.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
import argparse
2+
import pytest
3+
import os
4+
from dicee.evaluator import evaluate_lp_k_vs_all
5+
from dicee.knowledge_graph import KG
6+
from retrieval_augmented_link_predictor import GCL
7+
8+
class TestRegressionGCL:
9+
"""Regression tests for the GCL (Graph Context Learning) model"""
10+
#@pytest.mark.filterwarnings('ignore::UserWarning')
11+
def test_countries_s1_hop3(self):
12+
"""Test GCL on Countries-S1 dataset with 3 hops"""
13+
# Setup
14+
kg = KG(dataset_dir="KGs/Countries-S1", separator="\s+", eval_model="train_value_test", add_reciprocal=False)
15+
16+
# Get API key from environment variable
17+
api_key = os.environ.get("TENTRIS_TOKEN")
18+
assert api_key is not None, "TENTRIS_TOKEN environment variable not set"
19+
20+
# Initialize model
21+
model = GCL(
22+
knowledge_graph=kg,
23+
base_url="http://harebell.cs.upb.de:8501/v1",
24+
api_key=api_key,
25+
llm_model="tentris",
26+
temperature=0.0,
27+
seed=42,
28+
num_of_hops=3
29+
)
30+
31+
# Run evaluation
32+
results = evaluate_lp_k_vs_all(
33+
model=model,
34+
triple_idx=kg.test_set,
35+
er_vocab=kg.er_vocab,
36+
info='Testing Countries-S1 with 3 hops',
37+
batch_size=1
38+
)
39+
40+
# Check results - we expect perfect or near-perfect scores with 3 hops
41+
assert results['H@1'] >= 1.0, f"H@1 score too low: {results['H@1']}"
42+
assert results['H@3'] >= 1.0, f"H@3 score too low: {results['H@3']}"
43+
assert results['H@10'] >= 1.0, f"H@10 score too low: {results['H@10']}"
44+
assert results['MRR'] >= 1.0, f"MRR score too low: {results['MRR']}"
45+
46+
def test_umls_hop3(self):
47+
"""Test GCL on UMLS dataset with 2 hops"""
48+
# Setup
49+
kg = KG(dataset_dir="KGs/UMLS", separator="\s+", eval_model="train_value_test", add_reciprocal=False)
50+
51+
# Get API key from environment variable
52+
api_key = os.environ.get("TENTRIS_TOKEN")
53+
assert api_key is not None, "TENTRIS_TOKEN environment variable not set"
54+
55+
# Initialize model
56+
model = GCL(
57+
knowledge_graph=kg,
58+
base_url="http://harebell.cs.upb.de:8501/v1",
59+
api_key=api_key,
60+
llm_model="tentris",
61+
temperature=0.0,
62+
seed=42,
63+
num_of_hops=3
64+
)
65+
66+
# Run evaluation
67+
results = evaluate_lp_k_vs_all(
68+
model=model,
69+
triple_idx=kg.test_set[:24],
70+
er_vocab=kg.er_vocab,
71+
info='Testing UMLS with 2 hops',
72+
batch_size=1
73+
)
74+
75+
# Check results - based on typical performance for UMLS
76+
assert results['H@1'] >= 0.1, f"H@1 score too low: {results['H@1']}"
77+
assert results['H@3'] >= 0.1, f"H@3 score too low: {results['H@3']}"
78+
assert results['H@10'] >= 0.1, f"H@10 score too low: {results['H@10']}"
79+
assert results['MRR'] >= 0.1, f"MRR score too low: {results['MRR']}"
80+
81+
def test_kinship_hop3(self):
82+
"""Test GCL on KINSHIP dataset with 3 hops"""
83+
# Setup
84+
kg = KG(dataset_dir="KGs/KINSHIP", separator="\s+", eval_model="train_value_test", add_reciprocal=False)
85+
86+
# Get API key from environment variable
87+
api_key = os.environ.get("TENTRIS_TOKEN")
88+
assert api_key is not None, "TENTRIS_TOKEN environment variable not set"
89+
90+
# Initialize model
91+
model = GCL(
92+
knowledge_graph=kg,
93+
base_url="http://harebell.cs.upb.de:8501/v1",
94+
api_key=api_key,
95+
llm_model="tentris",
96+
temperature=0.0,
97+
seed=42,
98+
num_of_hops=3
99+
)
100+
101+
# Run evaluation
102+
results = evaluate_lp_k_vs_all(
103+
model=model,
104+
triple_idx=kg.test_set[:24],
105+
er_vocab=kg.er_vocab,
106+
info='Testing KINSHIP with 2 hops',
107+
batch_size=1
108+
)
109+
110+
# Check results - based on typical performance for KINSHIP
111+
assert results['H@1'] >= 0.1, f"H@1 score too low: {results['H@1']}"
112+
assert results['H@3'] >= 0.08, f"H@3 score too low: {results['H@3']}"
113+
assert results['H@10'] >= 0.08, f"H@10 score too low: {results['H@10']}"
114+
assert results['MRR'] >= 0.1, f"MRR score too low: {results['MRR']}"

0 commit comments

Comments
 (0)