4949
5050
5151"""
52+ import argparse
5253
5354#### NOTE: LF: First implementation approach
5455
@@ -391,19 +392,20 @@ def eval(self):
391392 def __call__ (self ,* args ,** kwargs ):
392393 """Predicting missing triples"""
393394
395+
class Dummy(AbstractBaseLinkPredictorClass):
    """Placeholder link predictor that gives every triple a score of 0.0."""

    def __init__(self, knowledge_graph: KG = None, name="dummy") -> None:
        """Forward the knowledge graph and the model name to the base class."""
        super().__init__(knowledge_graph, name)

    def __call__(self, indexed_triples: torch.LongTensor):
        """Score a batch of index-encoded triples.

        Args:
            indexed_triples: Tensor whose shape must be exactly (1, 3); each
                row is unpacked as (head index, relation index, tail index).

        Returns:
            A ``torch.FloatTensor`` holding one ``[0.0]`` row per input triple.
        """
        num_triples, width = indexed_triples.shape
        # For the time being only a single (h, r, t) triple per call is accepted.
        assert width == 3
        assert num_triples == 1
        scores = []
        for idx_h, idx_r, idx_t in indexed_triples.tolist():
            # The decoded labels are not used yet, but the lookups are kept so
            # that an unknown index still fails here.
            # NOTE(review): idx_to_entity / idx_to_relation are presumably set
            # up by the base class - not visible in this block.
            h = self.idx_to_entity[idx_h]
            r = self.idx_to_relation[idx_r]
            t = self.idx_to_entity[idx_t]
            # Given this triple, we need to assign a score
            scores.append([0.0])
        return torch.FloatTensor(scores)
@@ -414,10 +416,12 @@ def __init__(self, knowledge_graph: KG = None,
414416 name = "ralp-1.0" ,
415417 base_url = "http://tentris-ml.cs.upb.de:8501/v1" ,
416418 api_key = None ,
417- model = "tentris" )-> None :
419+ llm_model = "tentris" ,
420+ temperature = 1 ) -> None :
418421 super ().__init__ (knowledge_graph , name )
419422 self .client = OpenAI (base_url = base_url , api_key = api_key )
420- self .model = model
423+ self .llm_model = llm_model
424+ self .temperature = temperature
421425
422426 def extract_float (self , text ):
423427 """Extract the float number from a string. Used mainly to filter the LLM-output for the scoring task."""
@@ -457,11 +461,13 @@ def get_score(self, triple: tuple, triples_h: str) -> float:
457461 Assign a score to the given triple based on the provided training triples.
458462 """
459463 response = self .client .chat .completions .create (
460- model = self .model ,
464+ model = self .llm_model ,
461465 messages = [
462466 {"role" : "system" , "content" : system_prompt },
463467 {"role" : "user" , "content" : user_prompt },
464468 ],
469+ seed = 42 ,
470+ temperature = self .temperature
465471 )
466472
467473 # Extract the response content
@@ -487,126 +493,53 @@ def __call__(self, indexed_triples: torch.LongTensor):
487493 triples_h_str += f'- ("{ self .ru (self .idx_to_entity [trp [0 ]])} ", "{ self .ru (self .idx_to_relation [trp [1 ]])} ", "{ self .ru (self .idx_to_entity [trp [2 ]])} ") \n '
488494
489495 # Get the score from the LLM
490- score = self .get_score ((h , r , t ), triples_h )
496+ score = self .get_score ((h , r , t ), triples_h_str )
491497 scores .append ([score ])
492498 return torch .FloatTensor (scores )
493499
494500
def run(args):
    """Evaluate the selected link prediction model on a knowledge graph's test set.

    Args:
        args: Parsed command-line namespace. The attributes read are
            ``dataset_dir``, ``eval_model``, ``eval_size``, ``model``,
            ``base_url``, ``api_key``, ``llm_model_name``, ``temperature``,
            ``batch_size`` and ``chunk_size``. When ``eval_size`` is None it is
            set in place to the size of the test set.

    Raises:
        ValueError: If ``eval_size`` is negative or larger than the test set,
            or if ``model`` does not name a known link predictor.
    """
    # () Read KG
    # Raw string: "\s" is not a valid escape sequence in a normal string literal.
    kg = KG(dataset_dir=args.dataset_dir, separator=r"\s+", eval_model=args.eval_model)

    num_test_triples = len(kg.test_set)
    if args.eval_size is None:
        args.eval_size = num_test_triples
    elif not 0 <= args.eval_size <= num_test_triples:
        # Explicit exception instead of `assert`, which is stripped under -O.
        # A negative size would otherwise silently drop triples from the end
        # of the test set through the slice below.
        raise ValueError(
            f"Evaluation size must be between 0 and the total amount of "
            f"triples in the test set: {num_test_triples}"
        )

    # () Initialize the link prediction model
    if args.model == "RALP":
        model = RALP(knowledge_graph=kg,
                     base_url=args.base_url,
                     api_key=args.api_key,
                     llm_model=args.llm_model_name,
                     temperature=args.temperature)
    else:
        raise ValueError(f"Couldn't assign a model named: {args.model}")

    # () Start evaluation
    evaluate_lp(model=model,
                triple_idx=kg.test_set[:args.eval_size],
                num_entities=len(kg.entity_to_idx),
                er_vocab=kg.er_vocab,
                re_vocab=kg.re_vocab,
                info='Eval LP Starts',
                batch_size=args.batch_size,
                chunk_size=args.chunk_size)
526+
527+
def build_arg_parser():
    """Create the command-line parser for the link prediction evaluation script.

    Returns:
        An ``argparse.ArgumentParser`` exposing the dataset, model, LLM client
        and evaluation options consumed by ``run``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_dir", type=str, default="KGs/Countries-S1", help="Path to dataset.")
    # Add new models in 'choices'.
    parser.add_argument("--model", type=str, default="RALP", choices=["RALP"],
                        help="Model name to use for link prediction.")
    parser.add_argument("--base_url", type=str, default="http://tentris-ml.cs.upb.de:8501/v1",
                        help="Base URL for the OpenAI client.")
    parser.add_argument("--llm_model_name", type=str, default="tentris", help="Model name of the LLM to use.")
    parser.add_argument("--api_key", type=str, default="INSERT_API_KEY", help="API key for the OpenAI client.")
    # Float default: argparse does not run `type` on non-string defaults.
    parser.add_argument("--temperature", type=float, default=1.0,
                        help="Temperature hyperparameter for LLM calls.")
    parser.add_argument("--eval_size", type=int, default=None,
                        help="Amount of triples from the test set to evaluate. "
                             "Leave it None to include all triples on the test set.")
    # "train_val_test" is the spelling this script used before it had a CLI.
    parser.add_argument("--eval_model", type=str, default="train_val_test", help="Type of evaluation model.")
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--chunk_size", type=int, default=1)
    return parser


if __name__ == "__main__":
    run(build_arg_parser().parse_args())
0 commit comments