|
63 | 63 | from dicee.evaluator import evaluate_lp |
64 | 64 | from abc import ABC, abstractmethod |
65 | 65 | import torch |
| 66 | +import re |
66 | 67 |
|
67 | 68 | class KnowledgeGraphPredictor: |
68 | 69 | """ |
@@ -407,11 +408,96 @@ def __call__(self,indexed_triples:torch.LongTensor): |
407 | 408 | scores.append([0.0]) |
408 | 409 | return torch.FloatTensor(scores) |
409 | 410 |
|
| 411 | + |
class RALP(AbstractBaseLinkPredictorClass):
    """Retrieval-Augmented Link Prediction (RALP).

    Scores a triple (h, r, t) by prompting an OpenAI-compatible LLM with the
    candidate triple together with all training triples in which the head
    entity occurs, and parsing a plausibility score in [0.0, 1.0] from the
    model's reply.
    """

    def __init__(self, knowledge_graph: KG = None,
                 name: str = "ralp-1.0",
                 base_url: str = "http://tentris-ml.cs.upb.de:8501/v1",
                 api_key: str = None,
                 model: str = "tentris") -> None:
        """Initialize the predictor.

        :param knowledge_graph: KG providing train_set and index/label mappings.
        :param name: Human-readable model name.
        :param base_url: OpenAI-compatible chat-completions endpoint.
        :param api_key: API key for the endpoint.
        :param model: Model identifier sent with each request.
        """
        super().__init__(knowledge_graph, name)
        self.client = OpenAI(base_url=base_url, api_key=api_key)
        self.model = model

    def extract_float(self, text: str) -> float:
        """Extract the first number from *text* as a float.

        Used to parse the LLM output for the scoring task. Returns 0.0 when no
        number is present. Accepts plain integers as well as decimals: the
        previous pattern required a decimal point, so an answer of "1" was
        silently mapped to 0.0.
        """
        match = re.search(r"-?(?:\d+\.\d*|\.\d+|\d+)", text)
        return float(match.group()) if match else 0.0

    def ru(self, entity: str) -> str:
        """Remove underscores from an entity/relation label (as str) so it reads naturally in prompts."""
        return entity.replace("_", " ")

    def get_score(self, triple: tuple, triples_h: str) -> float:
        """Ask the LLM for a plausibility score of *triple* given the formatted context *triples_h*.

        :param triple: (subject, predicate, object) label strings.
        :param triples_h: Pre-formatted, newline-separated training triples for the subject.
        :return: Parsed float score (0.0 if the reply contains no number).
        """
        system_prompt = """You are an expert in knowledge graphs and link prediction. Your task is to assign a plausibility score (from 0 to 1) to a given triple (subject, predicate, object) based on a set of known training triples for the same subject.

        - A score of 1.0 means the triple is highly likely to be true.
        - A score of 0.0 means the triple is highly unlikely to be true.
        - Intermediate values (e.g., 0.4, 0.7) reflect varying levels of plausibility.

        **Guidelines for scoring:**
        1. **Exact Match:** If the triple already exists in the training set or if the facts clearly state that the triple must be true assign a score close to 1.0.
        2. **Pattern Matching:** If the predicate-object pair frequently occurs for the given subject, assign a high score.
        3. **Semantic Similarity:** If the object is semantically close to known objects for the subject-predicate pair, assign a moderate to high score.
        4. **Rare or Unseen Combinations:** If the triple does not follow the learned patterns, assign a low score.
        5. **Contradictions:** If the triple contradicts existing facts (perform your own reasoning), assign a very low score.

        You must analyze the given triple and the training triples, apply the reasoning above, and output only a single **floating-point score** between **0.0 and 1.0**, without any explanation or additional text.
        Do not depend only on triples provided to you, also use your own knowledge as an AI assistant to reason about the truthness of the given triple as a fact.
        You are strictly required to provide only the score as an answer and do not explain it."""

        user_prompt = f"""Here is the triple we want to evaluate:
        (subject: {triple[0]}, predicate: {triple[1]}, object: {triple[2]})

        Here are the known training triples for the subject "{triple[0]}":
        {triples_h}

        Assign a score to the given triple based on the provided training triples.
        """
        response = self.client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
        )

        # Extract the response content and parse the numeric score.
        content = response.choices[0].message.content
        return self.extract_float(content)

    def __call__(self, indexed_triples: torch.LongTensor) -> torch.FloatTensor:
        """Score a batch of index triples; returns an (n, 1) FloatTensor of scores."""
        n, d = indexed_triples.shape
        # For the time being only single-triple batches are supported.
        assert d == 3
        assert n == 1
        scores = []
        for idx_h, idx_r, idx_t in indexed_triples.tolist():
            h = self.idx_to_entity[idx_h]
            r = self.idx_to_relation[idx_r]
            t = self.idx_to_entity[idx_t]

            # Retrieve training triples where the head entity appears as subject or object.
            triples_h = [trp for trp in self.kg.train_set if trp[0] == idx_h or trp[2] == idx_h]

            # Format the triples into the structured bullet list expected by the prompt.
            triples_h_str = "".join(
                f'- ("{self.ru(self.idx_to_entity[s])}", "{self.ru(self.idx_to_relation[p])}", "{self.ru(self.idx_to_entity[o])}") \n'
                for s, p, o in triples_h
            )

            # BUG FIX: previously the raw index-tuple list `triples_h` was passed instead
            # of the formatted string, so the LLM saw integer indices and the formatted
            # context was dead code. Also de-underscore the query triple so it matches
            # the formatting of the context triples.
            score = self.get_score((self.ru(h), self.ru(r), self.ru(t)), triples_h_str)
            scores.append([score])
        return torch.FloatTensor(scores)
| 493 | + |
| 494 | + |
if __name__ == "__main__":
    import os

    # () Read / Preprocess KG. Raw string avoids the invalid "\s" escape warning.
    kg = KG(dataset_dir="KGs/Countries-S1", separator=r"\s+", eval_model="train_val_test")

    # Read the API key from the environment rather than hard-coding a secret;
    # falls back to the previous placeholder so behavior is unchanged when unset.
    api_key = os.environ.get("RALP_API_KEY", "API_KEY")

    # It takes ~14 h to evaluate this model :/
    evaluate_lp(model=RALP(knowledge_graph=kg, api_key=api_key),
                triple_idx=kg.train_set, num_entities=len(kg.entity_to_idx), er_vocab=kg.er_vocab,
                re_vocab=kg.re_vocab, info='Eval LP Starts', batch_size=1, chunk_size=1)
416 | 502 |
|
417 | 503 | # @TODO: Create classes inherits from AbstractBaseLinkPredictorClass and improve the link prediction results |
|
0 commit comments