@@ -647,6 +647,124 @@ def forward_triples(self, indexed_triples: torch.LongTensor):
647647 return torch .FloatTensor (scores )
648648
649649
class RCL(AbstractBaseLinkPredictorClass):
    """Relation-based Context Learning to predict missing entities.

    (h, r, t) ∈ G_test

    1. Use all triples from G_train involving relation r to create context.
    2. Generate a prompt based on these triples and (h, r) to assign scores for all e ∈ E.
    """

    def __init__(self, knowledge_graph: KG = None, base_url: str = None, api_key: str = None,
                 llm_model: str = None, temperature: float = 0.0, seed: int = 42,
                 max_relation_examples: int = 50, use_val: bool = True,
                 exclude_source: bool = True) -> None:
        """Build relation→triples context tables from the training (and optionally validation) graph.

        :param knowledge_graph: Knowledge graph providing train/valid index triples and idx↔string maps.
        :param base_url: Base URL of the OpenAI-compatible endpoint (required).
        :param api_key: API key for the endpoint.
        :param llm_model: Name of the LLM to query.
        :param temperature: Sampling temperature; 0.0 for (near-)deterministic output.
        :param seed: Sampling seed forwarded to the LLM.
        :param max_relation_examples: Cap on in-context example triples per relation.
        :param use_val: If True, validation triples also contribute to the relation context.
        :param exclude_source: If True, drop context triples whose head equals the query head.
        """
        super().__init__(knowledge_graph, name="RCL")
        assert base_url is not None and isinstance(base_url, str)
        self.base_url = base_url
        self.api_key = api_key
        self.llm_model = llm_model
        self.temperature = temperature
        self.seed = seed
        self.max_relation_examples = max_relation_examples
        self.exclude_source = exclude_source
        self.client = OpenAI(base_url=self.base_url, api_key=self.api_key)

        # Training dataset as string triples (idx triples mapped back to labels).
        self.train_set: List[Tuple[str, str, str]] = [(self.idx_to_entity[idx_h],
                                                       self.idx_to_relation[idx_r],
                                                       self.idx_to_entity[idx_t])
                                                      for idx_h, idx_r, idx_t in self.kg.train_set.tolist()]
        # Validation dataset as string triples.
        self.val_set: List[Tuple[str, str, str]] = [(self.idx_to_entity[idx_h],
                                                     self.idx_to_relation[idx_r],
                                                     self.idx_to_entity[idx_t])
                                                    for idx_h, idx_r, idx_t in self.kg.valid_set.tolist()]

        triples = self.train_set + self.val_set if use_val else self.train_set

        # Create a mapping from relation to all triples using that relation.
        self.relation_to_triples = {}
        for s, p, o in triples:
            self.relation_to_triples.setdefault(p, []).append((s, p, o))

        # Deterministically ordered candidate targets, embedded verbatim in every prompt.
        self.target_entities = sorted(self.entity_to_idx.keys())

    def _create_prompt_based_on_relation(self, source: str, relation: str) -> str:
        """Return the LLM prompt for the query (source, relation, ?).

        The prompt embeds up to ``self.max_relation_examples`` known triples that
        use ``relation`` as in-context examples, followed by scoring instructions.
        """
        # All known triples for this relation (empty list for an unseen relation).
        relation_triples = self.relation_to_triples.get(relation, [])

        # Exclude triples about the query's own source entity if the flag is set
        # (avoids leaking the answer for the query head).
        if self.exclude_source:
            relation_triples = [triple for triple in relation_triples if triple[0] != source]

        # Limit examples if too many, to keep the prompt bounded.
        if len(relation_triples) > self.max_relation_examples:
            relation_triples = relation_triples[:self.max_relation_examples]

        relation_context = "Here are examples of how the relation is used in the knowledge base:\n"
        for s, p, o in sorted(relation_triples):
            relation_context += f"- {s} {p} {o}\n"
        relation_context += "\n"

        base_prompt = f"""
I'm trying to predict the most likely target entities for the following query:
Source entity: {source}
Relation: {relation}
Query: ({source}, {relation}, ?)

{relation_context}

Please provide a ranked list of at most {min(len(self.target_entities), 15)} likely target entities from the following list, along with likelihoods for each: {self.target_entities}

Provide your answer in the following JSON format: {{"predictions": [{{"entity": "entity_name", "score": float_number}}]}}

Notes:
1. Use the provided knowledge about how the relation is used to inform your predictions.
2. Only include entities that are plausible targets for this relation.
3. For geographic entities, consider geographic location, regional classifications, and political associations.
4. Rank the entities by likelihood of being the correct target.
5. ONLY INCLUDE entities from the provided list in your predictions.
6. If certain entities are not suitable for this relation, don't include them.
7. Return a valid JSON output.
8. Make sure scores are floating point numbers between 0 and 1, not strings.
9. A score can only be between 0 and 1, i.e. score ∈ [0, 1]. They can never be negative or greater than 1!
"""
        return base_prompt

    def forward_triples(self, x: torch.LongTensor):
        """Per-triple scoring is not supported; RCL scores all tails at once."""
        raise NotImplementedError("RCL needs to implement it")

    def forward_k_vs_all(self, x: torch.LongTensor) -> torch.FloatTensor:
        """Score every entity as a tail candidate for each (head, relation) pair.

        :param x: LongTensor of shape (batch, 2) holding (head_idx, relation_idx) pairs.
        :return: FloatTensor of shape (batch, |E|); entities the LLM did not rank
                 keep a sentinel score of -1.0 so they sort below any valid score in [0, 1].
        """
        batch_output = []
        # Iterate over batch of subject and relation pairs.
        for idx_h, idx_r in x.tolist():
            # String representations of an entity and a relation, respectively.
            h, r = self.idx_to_entity[idx_h], self.idx_to_relation[idx_r]
            llm_response = self.client.chat.completions.create(
                model=self.llm_model, temperature=self.temperature, seed=self.seed,
                messages=[{"role": "user",
                           "content": "You are a knowledgeable assistant that helps with link prediction tasks.\n" +
                                      self._create_prompt_based_on_relation(source=h, relation=r)}],
                # guided_json constrains the endpoint to emit schema-valid JSON;
                # truncate_prompt_tokens bounds overly long contexts server-side.
                extra_body={"guided_json": PredictionResponse.model_json_schema(),
                            "truncate_prompt_tokens": 30_000,
                            }).choices[0].message.content

            prediction_response = PredictionResponse(**json.loads(llm_response))
            # Initialize scores for all entities with the -1.0 sentinel.
            scores_for_all_entities = [-1.0] * len(self.idx_to_entity)
            for pred in prediction_response.predictions:
                try:
                    scores_for_all_entities[self.entity_to_idx[pred.entity]] = pred.score
                except KeyError:
                    # The LLM returned an entity outside the vocabulary; report and skip it.
                    print(f"For {h},{r}, {pred} not found\tPrediction Size: {len(prediction_response.predictions)}")
                    continue
            batch_output.append(scores_for_all_entities)
        return torch.FloatTensor(batch_output)
767+
650768def sanity_checking (args ,kg ):
651769 if args .eval_size is not None :
652770 assert len (kg .test_set ) >= args .eval_size , (f"Evaluation size cant be greater than the "
@@ -664,6 +782,10 @@ def get_model(args,kg)->AbstractBaseLinkPredictorClass:
664782 elif args .model == "GCL" :
665783 model = GCL (knowledge_graph = kg , base_url = args .base_url , api_key = args .api_key ,
666784 llm_model = args .llm_model_name , temperature = args .temperature , seed = args .seed ,num_of_hops = args .num_of_hops )
785+ elif args .model == "RCL" :
786+ model = RCL (knowledge_graph = kg , base_url = args .base_url , api_key = args .api_key ,
787+ llm_model = args .llm_model_name , temperature = args .temperature , seed = args .seed ,
788+ max_relation_examples = args .max_relation_examples , exclude_source = args .exclude_source )
667789 else :
668790 raise KeyError (f"{ args .model } is not a valid model" )
669791 assert model is not None , f"Couldn't assign a model named: { args .model } "
@@ -689,7 +811,7 @@ def run(args):
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_dir", type=str, default="KGs/Countries-S1", help="Path to dataset.")
    # NOTE(review): the choices list previously mixed quote styles and appeared to carry
    # a leading space in 'GCL'; a real space would make the default "GCL" an invalid
    # choice and abort every run — normalized here. Confirm against the repository.
    parser.add_argument("--model", type=str, default="GCL", help="Model name to use for link prediction.",
                        choices=["RALP", "GCL", "RCL"])
    parser.add_argument("--base_url", type=str, default="http://harebell.cs.upb.de:8501/v1",
                        choices=["http://harebell.cs.upb.de:8501/v1", "http://tentris-ml.cs.upb.de:8502/v1"],
                        help="Base URL for the OpenAI client.")
@@ -706,5 +828,7 @@ def run(args):
706828 parser .add_argument ("--batch_size" , type = int , default = 1 )
707829 parser .add_argument ("--chunk_size" , type = int , default = 1 )
708830 parser .add_argument ("--seed" , type = int , default = 42 )
709- parser .add_argument ("--num_of_hops" , type = int , default = 1 , help = "Number of hops to use to extract a subgraph around an entity-" )
831+ parser .add_argument ("--num_of_hops" , type = int , default = 1 , help = "Number of hops to use to extract a subgraph around an entity." )
832+ parser .add_argument ("--max_relation_examples" , type = int , default = 50 , help = "Maximum number of relation examples to include in RCL context." )
833+ parser .add_argument ("--exclude_source" , default = True , help = "Exclude triples with the same source entity in RCL context." )
710834 run (parser .parse_args ())
0 commit comments