Skip to content

Commit 2e6c6eb

Browse files
committed
refactoring
1 parent 24809b1 commit 2e6c6eb

1 file changed

Lines changed: 60 additions & 127 deletions

File tree

retrieval_augmented_link_predictor.py

Lines changed: 60 additions & 127 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
5050
5151
"""
52+
import argparse
5253

5354
#### NOTE: LF: First implementation approach
5455

@@ -391,19 +392,20 @@ def eval(self):
391392
def __call__(self,*args,**kwargs):
392393
"""Predicting missing triples"""
393394

395+
394396
class Dummy(AbstractBaseLinkPredictorClass):
395-
def __init__(self, knowledge_graph:KG=None, name="dummy") -> None:
396-
super().__init__(knowledge_graph,name)
397+
def __init__(self, knowledge_graph: KG = None, name="dummy") -> None:
398+
super().__init__(knowledge_graph, name)
397399

398400
def __call__(self,indexed_triples:torch.LongTensor):
399401
n,d=indexed_triples.shape
400402
# For the time being, only a single triple per call is supported (n == 1, d == 3)
401-
assert d==3
402-
assert n==1
403-
scores=[]
403+
assert d == 3
404+
assert n == 1
405+
scores = []
404406
for triple in indexed_triples.tolist():
405407
idx_h, idx_r, idx_t = triple
406-
h,r,t=self.idx_to_entity[idx_h], self.idx_to_relation[idx_r], self.idx_to_entity[idx_t]
408+
h, r, t = self.idx_to_entity[idx_h], self.idx_to_relation[idx_r], self.idx_to_entity[idx_t]
407409
# Given this triple, we need to assign a score; this dummy predictor always assigns 0.0
408410
scores.append([0.0])
409411
return torch.FloatTensor(scores)
@@ -414,10 +416,12 @@ def __init__(self, knowledge_graph: KG = None,
414416
name="ralp-1.0",
415417
base_url="http://tentris-ml.cs.upb.de:8501/v1",
416418
api_key=None,
417-
model="tentris")-> None:
419+
llm_model="tentris",
420+
temperature=1) -> None:
418421
super().__init__(knowledge_graph, name)
419422
self.client = OpenAI(base_url=base_url, api_key=api_key)
420-
self.model = model
423+
self.llm_model = llm_model
424+
self.temperature = temperature
421425

422426
def extract_float(self, text):
423427
"""Extract the float number from a string. Used mainly to filter the LLM-output for the scoring task."""
@@ -457,11 +461,13 @@ def get_score(self, triple: tuple, triples_h: str) -> float:
457461
Assign a score to the given triple based on the provided training triples.
458462
"""
459463
response = self.client.chat.completions.create(
460-
model=self.model,
464+
model=self.llm_model,
461465
messages=[
462466
{"role": "system", "content": system_prompt},
463467
{"role": "user", "content": user_prompt},
464468
],
469+
seed=42,
470+
temperature=self.temperature
465471
)
466472

467473
# Extract the response content
@@ -487,126 +493,53 @@ def __call__(self, indexed_triples: torch.LongTensor):
487493
triples_h_str += f'- ("{self.ru(self.idx_to_entity[trp[0]])}", "{self.ru(self.idx_to_relation[trp[1]])}", "{self.ru(self.idx_to_entity[trp[2]])}") \n'
488494

489495
# Get the score from the LLM
490-
score = self.get_score((h, r, t), triples_h)
496+
score = self.get_score((h, r, t), triples_h_str)
491497
scores.append([score])
492498
return torch.FloatTensor(scores)
493499

494500

501+
def run(args):
502+
503+
# () Read KG
504+
kg = KG(dataset_dir=args.dataset_dir, separator="\s+", eval_model=args.eval_model)
505+
if args.eval_size is not None:
506+
assert len(kg.test_set) >= args.eval_size, (f"Evaluation size cant be greater than the "
507+
f"total amount of triples in the test set: {len(kg.test_set)}")
508+
else:
509+
args.eval_size = len(kg.test_set)
510+
model = None
511+
512+
# () Initialize the link prediction model
513+
if args.model == "RALP":
514+
model = RALP(knowledge_graph=kg,
515+
base_url=args.base_url,
516+
api_key=args.api_key,
517+
llm_model=args.llm_model_name,
518+
temperature=args.temperature)
519+
520+
assert model is not None, f"Couldn't assign a model named: {args.model}"
521+
522+
# () Start evaluation
523+
evaluate_lp(model=model, triple_idx=kg.test_set[:args.eval_size], num_entities=len(kg.entity_to_idx),
524+
er_vocab=kg.er_vocab, re_vocab=kg.re_vocab, info='Eval LP Starts', batch_size=args.batch_size,
525+
chunk_size=args.chunk_size)
526+
527+
495528
if __name__ == "__main__":
496-
# () Read / Preprocess KG
497-
kg = KG(dataset_dir="KGs/Countries-S1",separator="\s+",eval_model="train_val_test")
498-
499-
# It takes ~14 h to evaluate this model :/
500-
evaluate_lp(model=RALP(knowledge_graph=kg, api_key="API_KEY"), triple_idx=kg.train_set, num_entities=len(kg.entity_to_idx), er_vocab=kg.er_vocab,
501-
re_vocab=kg.re_vocab, info='Eval LP Starts', batch_size=1, chunk_size=1)
502-
503-
# @TODO: Create classes inherits from AbstractBaseLinkPredictorClass and improve the link prediction results
504-
exit(1)
505-
# @TODO:CD -> Luke: Please refactor the below code to work with the above code.
506-
# Create predictor (uses Tentris model by default)
507-
predictor = KnowledgeGraphPredictor()
508-
509-
print("\nExample: Countries that border Italy")
510-
head = "Italy"
511-
relation = "borders"
512-
candidates = ["France", "Austria", "Switzerland", "Slovenia", "Vatican", "San Marino"]
513-
514-
# Predict missing tails
515-
ranked_candidates = predictor.predict_missing_tails(head, relation, candidates, "data/countries.ttl")
516-
517-
# Print results
518-
print(f"\nPredicting missing tails for ({head}, {relation}, ?)")
519-
print("\nRanked candidates with scores:")
520-
for candidate, score in ranked_candidates:
521-
print(f"{candidate}: {score:.2f}")
522-
523-
print("\ndone!")
524-
525-
526-
'''
527-
Data used:
528-
529-
@prefix ex: <http://example.org/> .
530-
@prefix dbo: <http://dbpedia.org/ontology/> .
531-
@prefix dbr: <http://dbpedia.org/resource/> .
532-
@prefix rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#> .
533-
@prefix rdfs: <http://www.w3.org/2000/01/rdf-schema#> .
534-
535-
# Germany and its borders
536-
ex:Germany a dbo:Country ;
537-
rdfs:label "Germany" ;
538-
dbo:capital dbr:Berlin ;
539-
dbo:continent dbr:Europe ;
540-
dbo:borders ex:France, ex:Poland, ex:Netherlands, ex:Austria, ex:Czech_Republic, ex:Denmark .
541-
542-
# France with partial missing borders
543-
ex:France a dbo:Country ;
544-
rdfs:label "France" ;
545-
dbo:capital dbr:Paris ;
546-
dbo:continent dbr:Europe ;
547-
dbo:borders ex:Germany, ex:Belgium, ex:Spain, ex:Italy .
548-
# Missing Luxembourg, Switzerland, Monaco
549-
550-
# Poland with missing some eastern borders
551-
ex:Poland a dbo:Country ;
552-
rdfs:label "Poland" ;
553-
dbo:capital dbr:Warsaw ;
554-
dbo:continent dbr:Europe ;
555-
dbo:borders ex:Germany, ex:Czech_Republic, ex:Slovakia, ex:Lithuania .
556-
# Missing Ukraine, Belarus
557-
558-
# Netherlands with incomplete neighbors
559-
ex:Netherlands a dbo:Country ;
560-
rdfs:label "Netherlands" ;
561-
dbo:capital dbr:Amsterdam ;
562-
dbo:continent dbr:Europe ;
563-
dbo:borders ex:Germany, ex:Belgium .
564-
# Missing North Sea relation
565-
566-
# Belgium with missing some minor relations
567-
ex:Belgium a dbo:Country ;
568-
rdfs:label "Belgium" ;
569-
dbo:capital dbr:Brussels ;
570-
dbo:borders ex:France, ex:Netherlands, ex:Germany .
571-
# Missing Luxembourg
572-
573-
# Austria with some missing borders
574-
ex:Austria a dbo:Country ;
575-
rdfs:label "Austria" ;
576-
dbo:capital dbr:Vienna ;
577-
dbo:continent dbr:Europe ;
578-
dbo:borders ex:Germany, ex:Czech_Republic, ex:Slovakia, ex:Italy .
579-
# Missing Switzerland, Slovenia, Hungary
580-
581-
# Czech Republic with missing some eastern borders
582-
ex:Czech_Republic a dbo:Country ;
583-
rdfs:label "Czech Republic" ;
584-
dbo:capital dbr:Prague ;
585-
dbo:continent dbr:Europe ;
586-
dbo:borders ex:Germany, ex:Poland, ex:Austria .
587-
# Missing Slovakia
588-
589-
# Denmark with only one neighbor
590-
ex:Denmark a dbo:Country ;
591-
rdfs:label "Denmark" ;
592-
dbo:capital dbr:Copenhagen ;
593-
dbo:continent dbr:Europe ;
594-
dbo:borders ex:Germany .
595-
# Missing maritime neighbors (Sweden, Norway via sea)
596-
597-
# Italy with missing eastern borders
598-
ex:Italy a dbo:Country ;
599-
rdfs:label "Italy" ;
600-
dbo:capital dbr:Rome ;
601-
dbo:continent dbr:Europe ;
602-
dbo:borders ex:France, ex:Austria .
603-
# Missing Slovenia, Switzerland, Vatican, San Marino
604-
605-
# Slovakia with partial information
606-
ex:Slovakia a dbo:Country ;
607-
rdfs:label "Slovakia" ;
608-
dbo:capital dbr:Bratislava ;
609-
dbo:continent dbr:Europe ;
610-
dbo:borders ex:Poland, ex:Czech_Republic, ex:Austria .
611-
# Missing Hungary, Ukraine
612-
'''
529+
530+
parser = argparse.ArgumentParser()
531+
parser.add_argument("--dataset_dir", type=str, default="KGs/Countries-S1", help="Path to dataset.")
532+
parser.add_argument("--model", type=str, default="RALP", help="Model name to use for link prediction.", choices=["RALP"]) # add new models in 'choices'
533+
parser.add_argument("--base_url", type=str, default="http://tentris-ml.cs.upb.de:8501/v1",
534+
help="Base URL for the OpenAI client.")
535+
parser.add_argument("--llm_model_name", type=str, default="tentris", help="Model name of the LLM to use.")
536+
parser.add_argument("--api_key", type=str, default="INSERT_API_KEY", help="API key for the OpenAI client.")
537+
parser.add_argument("--temperature", type=float, default=1, help="Temperature hyperparameter for LLM calls.")
538+
parser.add_argument("--eval_size", type=int, default=None,
539+
help="Amount of triples from the test set to evaluate. "
540+
"Leave it None to include all triples on the test set.")
541+
parser.add_argument("--eval_model", type=str, default="train_value_test", help="Type of evaluation model.")
542+
parser.add_argument("--batch_size", type=int, default=1)
543+
parser.add_argument("--chunk_size", type=int, default=1)
544+
545+
run(parser.parse_args())

0 commit comments

Comments
 (0)