Skip to content

Commit 771b99f

Browse files
committed
RCL model added (context generated based on relations only)
1 parent 0fb586d commit 771b99f

1 file changed

Lines changed: 126 additions & 2 deletions

File tree

retrieval_augmented_link_predictor.py

Lines changed: 126 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -647,6 +647,124 @@ def forward_triples(self, indexed_triples: torch.LongTensor):
647647
return torch.FloatTensor(scores)
648648

649649

650+
class RCL(AbstractBaseLinkPredictorClass):
    """ Relation-based Context Learning to predict missing entities.

    (h, r, t) ∈ G_test

    1. Use all triples from G_train involving relation r to create context.
    2. Generate a prompt based on these triples and (h,r) to assign scores for all e ∈ E.
    """
    def __init__(self, knowledge_graph: KG = None, base_url:str=None, api_key:str=None, llm_model:str=None,
                 temperature:float=0.0, seed:int=42, max_relation_examples:int=50, use_val:bool=True,
                 exclude_source:bool=True) -> None:
        """Build the relation→triples context index and the LLM client.

        Args:
            knowledge_graph: dataset wrapper; train/valid index tensors are read from it.
            base_url: OpenAI-compatible endpoint (required).
            api_key: API key passed to the OpenAI client.
            llm_model: model name used for chat completions.
            temperature: sampling temperature (0.0 = deterministic).
            seed: sampling seed forwarded to the API.
            max_relation_examples: cap on context triples per relation.
            use_val: if True, validation triples are also used as context.
            exclude_source: if True, context triples sharing the query's head entity
                are filtered out at prompt time (avoids leaking the answer).
        """
        super().__init__(knowledge_graph, name="RCL")
        assert base_url is not None and isinstance(base_url, str)
        self.base_url = base_url
        self.api_key = api_key
        self.llm_model = llm_model
        self.temperature = temperature
        self.seed = seed
        self.max_relation_examples = max_relation_examples
        self.exclude_source = exclude_source
        self.client = OpenAI(base_url=self.base_url, api_key=self.api_key)

        # NOTE(review): idx_to_entity / idx_to_relation / entity_to_idx / self.kg are
        # presumably provided by AbstractBaseLinkPredictorClass — confirm in the base class.
        # Training dataset (string triples decoded from index triples).
        self.train_set: List[Tuple[str]] = self._decode_triples(self.kg.train_set.tolist())
        # Validation dataset.
        self.val_set: List[Tuple[str]] = self._decode_triples(self.kg.valid_set.tolist())

        triples = self.train_set + self.val_set if use_val else self.train_set

        # Create a mapping from relation to all triples using that relation.
        self.relation_to_triples = {}
        for s, p, o in triples:
            self.relation_to_triples.setdefault(p, []).append((s, p, o))

        # sorted() already returns a list; deterministic candidate order for the prompt.
        self.target_entities = sorted(self.entity_to_idx.keys())

    def _decode_triples(self, index_triples) -> List[Tuple[str]]:
        """Map (head_idx, relation_idx, tail_idx) triples to their string forms."""
        return [(self.idx_to_entity[idx_h],
                 self.idx_to_relation[idx_r],
                 self.idx_to_entity[idx_t]) for idx_h, idx_r, idx_t in index_triples]

    def _create_prompt_based_on_relation(self, source: str, relation: str) -> str:
        """Build the LLM prompt for query (source, relation, ?) from relation context.

        Context = up to max_relation_examples training/validation triples that use
        `relation`, optionally excluding those whose head equals `source`.
        """
        # Get all triples with the current relation (empty list for unseen relations).
        relation_triples = self.relation_to_triples.get(relation, [])

        # Exclude triples where the source entity is the current one if flag is set.
        if self.exclude_source:
            relation_triples = [triple for triple in relation_triples if triple[0] != source]

        # Limit examples if too many.
        # NOTE(review): truncation happens BEFORE sorting, so which examples survive
        # depends on triple insertion order — confirm this is intended.
        if len(relation_triples) > self.max_relation_examples:
            relation_triples = relation_triples[:self.max_relation_examples]

        relation_context = ("Here are examples of how the relation is used in the knowledge base:\n"
                            + "".join(f"- {s} {p} {o}\n" for s, p, o in sorted(relation_triples))
                            + "\n")

        base_prompt = f"""
I'm trying to predict the most likely target entities for the following query:
Source entity: {source}
Relation: {relation}
Query: ({source}, {relation}, ?)

{relation_context}

Please provide a ranked list of at most {min(len(self.target_entities),15)} likely target entities from the following list, along with likelihoods for each: {self.target_entities}

Provide your answer in the following JSON format: {{"predictions": [{{"entity": "entity_name", "score": float_number}}]}}

Notes:
1. Use the provided knowledge about how the relation is used to inform your predictions.
2. Only include entities that are plausible targets for this relation.
3. For geographic entities, consider geographic location, regional classifications, and political associations.
4. Rank the entities by likelihood of being the correct target.
5. ONLY INCLUDE entities from the provided list in your predictions.
6. If certain entities are not suitable for this relation, don't include them.
7. Return a valid JSON output.
8. Make sure scores are floating point numbers between 0 and 1, not strings.
9. A score can only be between 0 and 1, i.e. score ∈ [0, 1]. They can never be negative or greater than 1!
"""
        return base_prompt

    def forward_triples(self, x: torch.LongTensor):
        """Triple scoring is not supported by RCL; only KvsAll prediction is."""
        raise NotImplementedError("RCL needs to implement it")

    def forward_k_vs_all(self, x: torch.LongTensor) -> torch.FloatTensor:
        """Score every entity as tail for each (head, relation) pair in the batch.

        Args:
            x: LongTensor of shape (batch, 2) holding (entity_idx, relation_idx) pairs
               — assumed from the indexing below; TODO confirm against the caller.

        Returns:
            FloatTensor of shape (batch, |E|); entities the LLM did not rank keep
            the sentinel score -1.0.
        """
        batch_output = []
        # Iterate over batch of subject and relation pairs.
        for idx_h, idx_r in x.tolist():
            # String representations of an entity and a relation, respectively.
            h, r = self.idx_to_entity[idx_h], self.idx_to_relation[idx_r]
            llm_response = self.client.chat.completions.create(
                model=self.llm_model, temperature=self.temperature, seed=self.seed,
                messages=[{"role": "user",
                           "content": "You are a knowledgeable assistant that helps with link prediction tasks.\n" +
                                      self._create_prompt_based_on_relation(source=h, relation=r)}],
                # Constrain decoding to the PredictionResponse JSON schema (vLLM guided decoding).
                extra_body={"guided_json": PredictionResponse.model_json_schema(),
                            "truncate_prompt_tokens": 30_000,
                            }).choices[0].message.content

            prediction_response = PredictionResponse(**json.loads(llm_response))
            # Initialize scores for all entities with the -1.0 "not predicted" sentinel.
            scores_for_all_entities = [-1.0] * len(self.idx_to_entity)
            for pred in prediction_response.predictions:
                try:
                    scores_for_all_entities[self.entity_to_idx[pred.entity]] = pred.score
                except KeyError:
                    # LLM hallucinated an entity outside the vocabulary; report and skip it.
                    print(f"For {h},{r}, {pred} not found\tPrediction Size: {len(prediction_response.predictions)}")
                    continue
            batch_output.append(scores_for_all_entities)
        return torch.FloatTensor(batch_output)
767+
650768
def sanity_checking(args,kg):
651769
if args.eval_size is not None:
652770
assert len(kg.test_set) >= args.eval_size, (f"Evaluation size cant be greater than the "
@@ -664,6 +782,10 @@ def get_model(args,kg)->AbstractBaseLinkPredictorClass:
664782
elif args.model == "GCL":
665783
model = GCL(knowledge_graph=kg, base_url=args.base_url, api_key=args.api_key,
666784
llm_model=args.llm_model_name, temperature=args.temperature, seed=args.seed,num_of_hops=args.num_of_hops)
785+
elif args.model == "RCL":
786+
model = RCL(knowledge_graph=kg, base_url=args.base_url, api_key=args.api_key,
787+
llm_model=args.llm_model_name, temperature=args.temperature, seed=args.seed,
788+
max_relation_examples=args.max_relation_examples, exclude_source=args.exclude_source)
667789
else:
668790
raise KeyError(f"{args.model} is not a valid model")
669791
assert model is not None, f"Couldn't assign a model named: {args.model}"
@@ -689,7 +811,7 @@ def run(args):
689811
if __name__ == "__main__":
690812
parser = argparse.ArgumentParser()
691813
parser.add_argument("--dataset_dir", type=str, default="KGs/Countries-S1", help="Path to dataset.")
692-
parser.add_argument("--model", type=str, default="GCL", help="Model name to use for link prediction.", choices=["RALP",'GCL'])
814+
parser.add_argument("--model", type=str, default="GCL", help="Model name to use for link prediction.", choices=["RALP", "GCL", "RCL"])
693815
parser.add_argument("--base_url", type=str, default="http://harebell.cs.upb.de:8501/v1",
694816
choices=["http://harebell.cs.upb.de:8501/v1", "http://tentris-ml.cs.upb.de:8502/v1"],
695817
help="Base URL for the OpenAI client.")
@@ -706,5 +828,7 @@ def run(args):
706828
parser.add_argument("--batch_size", type=int, default=1)
707829
parser.add_argument("--chunk_size", type=int, default=1)
708830
parser.add_argument("--seed", type=int, default=42)
709-
parser.add_argument("--num_of_hops", type=int, default=1, help="Number of hops to use to extract a subgraph around an entity-")
831+
parser.add_argument("--num_of_hops", type=int, default=1, help="Number of hops to use to extract a subgraph around an entity.")
832+
parser.add_argument("--max_relation_examples", type=int, default=50, help="Maximum number of relation examples to include in RCL context.")
833+
parser.add_argument("--exclude_source", default=True, help="Exclude triples with the same source entity in RCL context.")
710834
run(parser.parse_args())

0 commit comments

Comments
 (0)