Skip to content

Commit 24809b1

Browse files
committed
added RALP model for link prediction
1 parent 7e9720f commit 24809b1

2 files changed

Lines changed: 88 additions & 2 deletions

File tree

dicee/knowledge_graph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def __init__(self, dataset_dir: str = None,
3333
sample_triples_ratio
3434
:param training_technique
3535
"""
36-
assert dataset_dir is not None, f"dataset_dir cannot be None"
36+
assert dataset_dir is not None, "dataset_dir cannot be None"
3737
self.dataset_dir = dataset_dir
3838
self.sparql_endpoint = sparql_endpoint
3939
self.path_single_kg = path_single_kg

retrieval_augmented_link_predictor.py

Lines changed: 87 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
from dicee.evaluator import evaluate_lp
6464
from abc import ABC, abstractmethod
6565
import torch
66+
import re
6667

6768
class KnowledgeGraphPredictor:
6869
"""
@@ -407,11 +408,96 @@ def __call__(self,indexed_triples:torch.LongTensor):
407408
scores.append([0.0])
408409
return torch.FloatTensor(scores)
409410

411+
412+
class RALP(AbstractBaseLinkPredictorClass):
    """Retrieval-Augmented Link Predictor (RALP).

    Scores an (h, r, t) triple by retrieving every training triple that
    mentions the head entity and asking an OpenAI-compatible chat model to
    assign a plausibility score in [0.0, 1.0] based on those retrieved facts.
    """

    def __init__(self, knowledge_graph: KG = None,
                 name: str = "ralp-1.0",
                 base_url: str = "http://tentris-ml.cs.upb.de:8501/v1",
                 api_key: str = None,
                 model: str = "tentris") -> None:
        """
        :param knowledge_graph: indexed KG providing train_set and idx<->label maps.
        :param name: model identifier used by the evaluation framework.
        :param base_url: OpenAI-compatible endpoint (may be self-hosted).
        :param api_key: credential for the endpoint.
        :param model: chat model name served at the endpoint.
        """
        super().__init__(knowledge_graph, name)
        self.client = OpenAI(base_url=base_url, api_key=api_key)
        self.model = model

    def extract_float(self, text: str) -> float:
        """Extract the first number from a string. Used mainly to filter the
        LLM-output for the scoring task.

        Accepts decimal forms ("0.7", ".5", "1.") and, unlike the strict
        decimal-only pattern, also bare integers ("1"), since the LLM
        occasionally answers without a decimal point.
        Returns 0.0 when no number is found.
        """
        pattern = r"-?\d*\.\d+|-?\d+\.?\d*"
        match = re.search(pattern, text)
        return float(match.group()) if match else 0.0

    def ru(self, entity: str) -> str:
        """Remove underscore from the entity (as str) so labels read naturally in prompts."""
        return entity.replace("_", " ")

    def get_score(self, triple: tuple, triples_h: str) -> float:
        """Ask the LLM for a plausibility score of *triple*.

        :param triple: (subject, predicate, object) as human-readable strings.
        :param triples_h: pre-formatted, newline-separated training triples for
            the subject, embedded verbatim into the user prompt.
        :return: float in [0, 1] parsed from the answer (0.0 on parse failure).
        """
        system_prompt = """You are an expert in knowledge graphs and link prediction. Your task is to assign a plausibility score (from 0 to 1) to a given triple (subject, predicate, object) based on a set of known training triples for the same subject.

- A score of 1.0 means the triple is highly likely to be true.
- A score of 0.0 means the triple is highly unlikely to be true.
- Intermediate values (e.g., 0.4, 0.7) reflect varying levels of plausibility.


**Guidelines for scoring:**
1. **Exact Match:** If the triple already exists in the training set or if the facts clearly state that the triple must be true assign a score close to 1.0.
2. **Pattern Matching:** If the predicate-object pair frequently occurs for the given subject, assign a high score.
3. **Semantic Similarity:** If the object is semantically close to known objects for the subject-predicate pair, assign a moderate to high score.
4. **Rare or Unseen Combinations:** If the triple does not follow the learned patterns, assign a low score.
5. **Contradictions:** If the triple contradicts existing facts (perform your own reasoning), assign a very low score.

You must analyze the given triple and the training triples, apply the reasoning above, and output only a single **floating-point score** between **0.0 and 1.0**, without any explanation or additional text.
Do not depend only on triples provided to you, also use your own knowledge as an AI assistant to reason about the truthness of the given triple as a fact.
You are strictly required to provide only the score as an answer and do not explain it."""

        user_prompt = f"""Here is the triple we want to evaluate:
(subject: {triple[0]}, predicate: {triple[1]}, object: {triple[2]})

Here are the known training triples for the subject "{triple[0]}":
{triples_h}

Assign a score to the given triple based on the provided training triples.
"""
        response = self.client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
        )

        # Extract the response content and parse the numeric score out of it.
        content = response.choices[0].message.content
        return self.extract_float(content)

    def __call__(self, indexed_triples: torch.LongTensor):
        """Score a batch of index triples; returns an (n, 1) FloatTensor of scores."""
        n, d = indexed_triples.shape
        # For the time being only single-triple batches are supported.
        assert d == 3
        assert n == 1
        scores = []
        for triple in indexed_triples.tolist():
            idx_h, idx_r, idx_t = triple
            h, r, t = self.idx_to_entity[idx_h], self.idx_to_relation[idx_r], self.idx_to_entity[idx_t]

            # Retrieve triples where 'h' is a subject or an object.
            triples_h = [trp for trp in self.kg.train_set if (trp[0] == idx_h or trp[2] == idx_h)]

            # Format the triples into structured string output that will be used in the prompt.
            triples_h_str = "".join(
                f'- ("{self.ru(self.idx_to_entity[trp[0]])}", "{self.ru(self.idx_to_relation[trp[1]])}", "{self.ru(self.idx_to_entity[trp[2]])}") \n'
                for trp in triples_h
            )

            # Get the score from the LLM.
            # BUG FIX: pass the formatted string; the original passed the raw
            # index list `triples_h`, so the prompt embedded Python repr of
            # integer indices instead of the readable triples built above.
            score = self.get_score((h, r, t), triples_h_str)
            scores.append([score])
        return torch.FloatTensor(scores)
493+
494+
410495
if __name__ == "__main__":
    # () Read / Preprocess KG.
    # FIX: separator must be a raw string — "\s+" contains the invalid escape
    # sequence \s, which raises SyntaxWarning on Python >= 3.12 (and is slated
    # to become an error).
    kg = KG(dataset_dir="KGs/Countries-S1", separator=r"\s+", eval_model="train_val_test")

    # It takes ~14 h to evaluate this model :/
    # NOTE(review): "API_KEY" is a placeholder credential — consider reading it
    # from an environment variable before running for real.
    evaluate_lp(model=RALP(knowledge_graph=kg, api_key="API_KEY"), triple_idx=kg.train_set,
                num_entities=len(kg.entity_to_idx), er_vocab=kg.er_vocab,
                re_vocab=kg.re_vocab, info='Eval LP Starts', batch_size=1, chunk_size=1)

    # @TODO: Create classes inherits from AbstractBaseLinkPredictorClass and improve the link prediction results

0 commit comments

Comments
 (0)