Commit bf91ecd
Refactoring
1 parent 2d1f2fd commit bf91ecd

1 file changed: retrieval_augmented_link_predictor.py
Lines changed: 52 additions & 58 deletions
@@ -2,7 +2,20 @@
 Additional dependencies
 pip install openai==1.66.3
 pip install igraph==0.11.8
+pip install jellyfish==1.1.3
 
+# @TODO:CD:@Luke Writing a few regression tests would help us ensure that our modifications do not break the model.
+python retrieval_augmented_link_predictor.py --dataset_dir "KGs/Countries-S1" --model "GCL" --base_url "http://harebell.cs.upb.de:8501/v1" --num_of_hops 1
+@TODO: CD: There is some randomness in this setup. I do not know why.
+{'H@1': 0.7916666666666666, 'H@3': 0.875, 'H@10': 0.9583333333333334, 'MRR': 0.8472644080996884}
+
+python retrieval_augmented_link_predictor.py --dataset_dir "KGs/Countries-S1" --model "GCL" --base_url "http://harebell.cs.upb.de:8501/v1" --num_of_hops 2
+{'H@1': 1.0, 'H@3': 1.0, 'H@10': 1.0, 'MRR': 1.0}
+
+python retrieval_augmented_link_predictor.py --dataset_dir "KGs/Countries-S2" --model "GCL" --base_url "http://harebell.cs.upb.de:8501/v1" --num_of_hops 2
+@TODO: CD: There is some randomness in this setup. I do not know why.
+{'H@1': 0.875, 'H@3': 1.0, 'H@10': 1.0, 'MRR': 0.9305555555555555}
+{'H@1': 0.9166666666666666, 'H@3': 1.0, 'H@10': 1.0, 'MRR': 0.9583333333333334}
 
 """
 import argparse
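
For the regression-test TODO above, one possible shape for such a test, as a minimal pytest sketch: it shells out to the script, parses the metrics dict that run() prints as its last output line, and asserts loose lower bounds instead of exact values because of the randomness noted above. The thresholds and the reachable endpoint are assumptions.

# test_countries_s1.py -- hedged sketch, not part of this commit.
import ast
import subprocess
import sys

def test_countries_s1_two_hops():
    # Assumes the vLLM endpoint behind --base_url is reachable.
    out = subprocess.run(
        [sys.executable, "retrieval_augmented_link_predictor.py",
         "--dataset_dir", "KGs/Countries-S1", "--model", "GCL",
         "--base_url", "http://harebell.cs.upb.de:8501/v1", "--num_of_hops", "2"],
        capture_output=True, text=True, check=True)
    # run() ends with print(results), so the last stdout line is the metrics dict.
    metrics = ast.literal_eval(out.stdout.strip().splitlines()[-1])
    # Loose, illustrative lower bounds rather than exact equality, to absorb the randomness.
    assert metrics["MRR"] >= 0.95
    assert metrics["H@1"] >= 0.90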
@@ -29,7 +42,6 @@
 
 load_dotenv()
 
-
 class KnowledgeGraphPredictor:
     """
     A class for predicting missing relations in knowledge graphs using LLMs.
@@ -324,7 +336,6 @@ def predict_missing_tails(self, head: str, relation: str, candidates: List[str],
         ranked_candidates.sort(key=lambda x: x[1], reverse=True)
 
         return ranked_candidates
-
 class AbstractBaseLinkPredictorClass(ABC):
     def __init__(self, knowledge_graph: KG = None, name="dummy"):
         assert knowledge_graph is not None
@@ -372,7 +383,6 @@ def __call__(self, x: torch.LongTensor | Tuple[torch.LongTensor, torch.LongTenso
         else:
             raise RuntimeError("Unsupported shape: {}".format(shape_info))
 
-
 class PredictionItem(BaseModel):
     """Individual prediction item with entity name and confidence score."""
     entity: str = Field(..., description="Name of the predicted entity")
@@ -382,7 +392,7 @@ class PredictionResponse(BaseModel):
     predictions: List[PredictionItem] = Field(..., description="List of predicted entities with scores")
 
 class GCL(AbstractBaseLinkPredictorClass):
-    """ In context Learning on neighbouring triples to predict missing entities.
+    """ In-context learning on neighbouring triples to predict missing entities.
 
     (h, r, t) \in G_test
 
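
These two models define the guided-JSON contract that the LLM response must satisfy later in the diff (via extra_body={"guided_json": ...}). A self-contained sketch of how a conforming payload parses; the score field on PredictionItem is inferred from its later use as pred.score, and the raw payload is invented:

from typing import List
from pydantic import BaseModel, Field

class PredictionItem(BaseModel):
    """Individual prediction item with entity name and confidence score."""
    entity: str = Field(..., description="Name of the predicted entity")
    score: float = Field(..., description="Confidence score between 0 and 1")  # inferred, not shown in the hunk

class PredictionResponse(BaseModel):
    predictions: List[PredictionItem] = Field(..., description="List of predicted entities with scores")

# An invented payload in the shape that guided decoding enforces.
raw = '{"predictions": [{"entity": "austria", "score": 0.92}, {"entity": "spain", "score": 0.31}]}'
parsed = PredictionResponse.model_validate_json(raw)
print(parsed.predictions[0].entity, parsed.predictions[0].score)  # austria 0.92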
@@ -406,98 +416,80 @@ def __init__(self, knowledge_graph: KG = None,base_url:str=None, api_key:str=Non
         self.seed = seed
         self.client = OpenAI(base_url=self.base_url, api_key=self.api_key)
         # Training dataset
-        self.train_set:List[Tuple[str]] = [(self.idx_to_entity[idx_h],self.idx_to_relation[idx_r],self.idx_to_entity[idx_t]) for idx_h,idx_r,idx_t in self.kg.train_set.tolist()]
+        self.train_set:List[Tuple[str]] = [(self.idx_to_entity[idx_h],
+                                            self.idx_to_relation[idx_r],
+                                            self.idx_to_entity[idx_t]) for idx_h,idx_r,idx_t in self.kg.train_set.tolist()]
         # Validation dataset
         self.val_set:List[Tuple[str]] = [(self.idx_to_entity[idx_h],
                                           self.idx_to_relation[idx_r],
                                           self.idx_to_entity[idx_t]) for idx_h,idx_r,idx_t in self.kg.valid_set.tolist()]
 
-        self.igraph = self.build_igraph(self.train_set + self.val_set if use_val else self.train_set)
+        triples = self.train_set + self.val_set if use_val else self.train_set
+        self.igraph = self.build_igraph(triples)
         self.str_entity_to_igraph_vertice={i["name"]:i for i in self.igraph.vs}
         self.str_rel_to_igraph_edges={i["label"]:i for i in self.igraph.es}
 
+        # Mapping from an entity to relevant triples.
+        # A relevant triple contains an entity that is within num_of_hops of the given entity.
         self.node_to_relevant_triples=dict()
         for entity, entity_node_object in self.str_entity_to_igraph_vertice.items():
             neighboring_nodes = self.igraph.neighborhood(entity_node_object, order=num_of_hops)
             subgraph = self.igraph.subgraph(neighboring_nodes)
-            triples = {(subgraph.vs[edge.source]["name"], edge["label"], subgraph.vs[edge.target]["name"]) for edge in subgraph.es}
-            self.node_to_relevant_triples[entity] = triples
+            self.node_to_relevant_triples[entity] = {(subgraph.vs[edge.source]["name"], edge["label"], subgraph.vs[edge.target]["name"]) for edge in subgraph.es}
 
         self.target_entities = list(sorted(self.entity_to_idx.keys()))
 
-
-    def _create_prompt(self, source: str, relation: str) -> str:
-        # neighbouring nodes of source
-        base_prompt = f"""
-I'm trying to predict the most likely target entities for the following query:
-Source entity: {source}
-Relation: {relation}
-Query: ({source}, {relation}, ?)
-
-Please provide a ranked list of the 10 most likely target entities from the following list, along with likelihoods for each: {self.target_entities}
-
-Provide your answer in the following JSON format: {{"predictions": [{{"entity": "entity_name", "score": "float number"}}]}}
-
-Notes:
-1. Only include entities that are plausible targets for this relation.
-2. For geographic entities, consider geographic location, regional classifications, and political associations.
-3. Rank the entities by likelihood of being the correct target.
-4. Only include entities from the provided list in your predictions.
-5. If certain entities are not suitable for this relation, don't include them.
-6. Return a valid JSON output.
-"""
-        return base_prompt
-
     def _create_prompt_based_on_neighbours(self, source: str, relation: str) -> str:
         # Get relevant triples for the source entity
         relevant_triples = []
         if source in self.node_to_relevant_triples:
             relevant_triples = list(self.node_to_relevant_triples[source])
-        # Limit to top 20 triples to avoid overwhelming the prompt
-        relevant_triples = relevant_triples
 
-        # Format the relevant triples for display
-        triples_context = ""
-        if relevant_triples:
-            triples_context = "Here are some known facts about the source entity that might be relevant:\n"
-            for s, p, o in relevant_triples:
-                triples_context += f"- {s} {p} {o}\n"
-            triples_context += "\n"
+        assert len(relevant_triples) > 0
+        # @TODO:CD: Potential improvement by trading off test runtime:
+        # @TODO: Find a subset of triples in relevant_triples to which the prediction is invariant.
+        # @TODO: The prediction does not change, but the input size decreases.
+        # @TODO: The removed triples can be seen as noise.
+        triples_context = "Here are some known facts about the source entity that might be relevant:\n"
+        for s, p, o in sorted(relevant_triples):
+            triples_context += f"- {s} {p} {o}\n"
+        triples_context += "\n"
 
         # Important: Grouping relations is important to reach MRR 1.0
         similar_relations = []
         for s, p, o in relevant_triples:
             if p == relation and s != source:
                 similar_relations.append((s, p, o))
-        similar_relations_context = ""
-        if similar_relations:
-            similar_relations_context = "Here are examples of similar relations in the knowledge base:\n"
-            for s, p, o in similar_relations:
-                similar_relations_context += f"- {s} {p} {o}\n"
-            similar_relations_context += "\n"
+
+        similar_relations_context = "Here are examples of similar relations in the knowledge base:\n"
+        for s, p, o in similar_relations:
+            similar_relations_context += f"- {s} {p} {o}\n"
+        similar_relations_context += "\n"
+
 
         base_prompt = f"""
 I'm trying to predict the most likely target entities for the following query:
 Source entity: {source}
 Relation: {relation}
 Query: ({source}, {relation}, ?)
 
+Subgraph:
 {triples_context}
 {similar_relations_context}
 
-Please provide a ranked list of the 10 most likely target entities from the following list, along with likelihoods for each: {self.target_entities}
+Please provide a ranked list of at most {min(len(self.target_entities),15)} likely target entities from the following list, along with likelihoods for each: {self.target_entities}
 
 Provide your answer in the following JSON format: {{"predictions": [{{"entity": "entity_name", "score": float_number}}]}}
 
 Notes:
+1. Use the provided knowledge about the source entity and similar relations to inform your predictions.
 1. Only include entities that are plausible targets for this relation.
 2. For geographic entities, consider geographic location, regional classifications, and political associations.
 3. Rank the entities by likelihood of being the correct target.
-4. Only include entities from the provided list in your predictions.
+4. ONLY INCLUDE entities from the provided list in your predictions.
 5. If certain entities are not suitable for this relation, don't include them.
 6. Return a valid JSON output.
 7. Make sure scores are floating point numbers between 0 and 1, not strings.
-8. Use the provided knowledge about the source entity and similar relations to inform your predictions.
 """
         return base_prompt
 
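The k-hop retrieval in __init__ above (igraph neighborhood plus subgraph) is the heart of GCL. A self-contained sketch of the same idiom with python-igraph; the toy triples are invented, not from Countries-S1:

import igraph as ig

# Toy triples (head, relation, tail); invented for illustration.
triples = [("germany", "neighbor", "france"),
           ("france", "neighbor", "spain"),
           ("spain", "locatedin", "europe")]

# TupleList expects (source, target, *edge_attrs); the relation becomes the "label" edge attribute.
g = ig.Graph.TupleList(((h, t, r) for h, r, t in triples),
                       directed=True, edge_attrs=["label"])

# Vertices within num_of_hops=1 of "france", ignoring edge direction.
nodes = g.neighborhood(g.vs.find(name="france"), order=1)
sub = g.subgraph(nodes)

# Re-materialise the triples of the induced subgraph, as node_to_relevant_triples does.
relevant = {(sub.vs[e.source]["name"], e["label"], sub.vs[e.target]["name"])
            for e in sub.es}
print(relevant)  # {('germany', 'neighbor', 'france'), ('france', 'neighbor', 'spain')}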
@@ -543,16 +535,18 @@ def forward_k_vs_all(self,x: torch.LongTensor) -> torch.FloatTensor:
                 messages=[{"role": "user",
                            "content": "You are a knowledgeable assistant that helps with link prediction tasks.\n" +
                                       self._create_prompt_based_on_neighbours(source=h, relation=r)}],
-                extra_body={"guided_json": PredictionResponse.model_json_schema()}).choices[0].message.content
+                extra_body={"guided_json": PredictionResponse.model_json_schema(),
+                            "truncate_prompt_tokens": 30_000,
+                            }).choices[0].message.content
 
             prediction_response = PredictionResponse(**json.loads(llm_response))
             # Initialize scores for all entities
-            scores_for_all_entities = [ 0.0 for _ in range(len(self.idx_to_entity))]
+            scores_for_all_entities = [ -1.0 for _ in range(len(self.idx_to_entity))]
             for pred in prediction_response.predictions:
                 try:
                     scores_for_all_entities[self.entity_to_idx[pred.entity]]=pred.score
                 except KeyError:
-                    print(f"For {h},{r}, {pred} not found")
+                    print(f"For {h},{r}, {pred} not found\tPrediction Size: {len(prediction_response.predictions)}")
                     continue
             batch_output.append(scores_for_all_entities)
         return torch.FloatTensor(batch_output)
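
The switch of the default score from 0.0 to -1.0 matters for the KvsAll ranking: an entity the LLM does return, even with score 0.0, should still outrank every entity it omitted. A toy illustration with an invented three-entity vocabulary:

import torch

entity_to_idx = {"france": 0, "spain": 1, "europe": 2}  # invented vocabulary
predictions = [("europe", 0.9), ("spain", 0.0)]         # LLM output; "france" omitted

scores = torch.full((len(entity_to_idx),), -1.0)        # omitted entities sink below any returned score
for entity, score in predictions:
    scores[entity_to_idx[entity]] = score

print(scores)                                  # tensor([-1.0000,  0.0000,  0.9000])
print(torch.argsort(scores, descending=True))  # tensor([2, 1, 0]): europe, spain, france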
@@ -668,11 +662,12 @@ def get_model(args,kg)->AbstractBaseLinkPredictorClass:
                      llm_model=args.llm_model_name, temperature=args.temperature, seed=args.seed)
     elif args.model == "GCL":
         model = GCL(knowledge_graph=kg, base_url=args.base_url, api_key=args.api_key,
-                    llm_model=args.llm_model_name, temperature=args.temperature, seed=args.seed)
+                    llm_model=args.llm_model_name, temperature=args.temperature, seed=args.seed,num_of_hops=args.num_of_hops)
     else:
         raise KeyError(f"{args.model} is not a valid model")
     assert model is not None, f"Couldn't assign a model named: {args.model}"
     return model
+
 def run(args):
     # Important: add_reciprocal=False in KvsAll implies that inverse relation has been introduced.
     # Therefore, The link prediction results are based on the missing tail rankings only!
@@ -682,10 +677,9 @@ def run(args):
 
     model = get_model(args,kg)
 
-    evaluate_lp_k_vs_all(model=model, triple_idx=kg.test_set[:args.eval_size],
+    results:dict = evaluate_lp_k_vs_all(model=model, triple_idx=kg.test_set[:args.eval_size],
                          er_vocab=kg.er_vocab, info='Eval KvsAll Starts', batch_size=args.batch_size)
-
-    # @TODO:CD: We need to introduce a flag to use negative sampling eval or kvsall eval
+    print(results)
     #evaluate_lp(model=model, triple_idx=kg.test_set[:args.eval_size], num_entities=len(kg.entity_to_idx),
     #            er_vocab=kg.er_vocab, re_vocab=kg.re_vocab, info='Eval LP Starts', batch_size=args.batch_size,
     #            chunk_size=args.chunk_size)
@@ -694,8 +688,7 @@ def run(args):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--dataset_dir", type=str, default="KGs/Countries-S1", help="Path to dataset.")
-    parser.add_argument("--model", type=str, default="GCL", help="Model name to use for link prediction.",
-                        choices=["RALP","GCL"])
+    parser.add_argument("--model", type=str, default="GCL", help="Model name to use for link prediction.", choices=["RALP",'GCL'])
     parser.add_argument("--base_url", type=str, default="http://harebell.cs.upb.de:8501/v1",
                         choices=["http://harebell.cs.upb.de:8501/v1", "http://tentris-ml.cs.upb.de:8502/v1"],
                         help="Base URL for the OpenAI client.")
@@ -712,4 +705,5 @@ def run(args):
     parser.add_argument("--batch_size", type=int, default=1)
     parser.add_argument("--chunk_size", type=int, default=1)
     parser.add_argument("--seed", type=int, default=42)
+    parser.add_argument("--num_of_hops", type=int, default=1, help="Number of hops to use to extract a subgraph around an entity.")
     run(parser.parse_args())
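
Regarding the randomness TODOs in the module docstring: on an OpenAI-compatible (vLLM-style) backend, pinning temperature to 0.0 together with a fixed seed usually makes decoding close to deterministic, though the seed parameter is only best-effort. A sketch; the model name is a placeholder to be filled with whatever --llm_model_name points at:

from openai import OpenAI

client = OpenAI(base_url="http://harebell.cs.upb.de:8501/v1", api_key="dummy")
response = client.chat.completions.create(
    model="...",  # placeholder: the model served at the endpoint (cf. --llm_model_name)
    messages=[{"role": "user", "content": "Query: (germany, neighbor, ?)"}],
    temperature=0.0,  # greedy decoding removes sampling noise
    seed=42,          # best-effort reproducibility on OpenAI-compatible servers
)
print(response.choices[0].message.content)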
