|
| 1 | +import igraph |
| 2 | +import torch |
| 3 | +import json |
| 4 | +from typing import List, Tuple |
| 5 | +from dicee.knowledge_graph import KG |
| 6 | +from ..abstract import AbstractBaseLinkPredictorClass |
| 7 | +from ..schemas import PredictionResponse |
| 8 | +from openai import OpenAI |
| 9 | + |
class GCL(AbstractBaseLinkPredictorClass):
    """In-context learning on neighbouring triples to predict missing entities.

    For a test triple (h, r, t) in G_test:

    1. Get all nodes that are within ``num_of_hops`` (default 3) hops of h.
    2. Get all triples from G_train (plus G_val when ``use_val``) involving (1).
    3. Generate a prompt based on (2) and (h, r) asking the LLM to assign
       scores for candidate tail entities e in E.

    @TODO:CD: We should write a regression test on the Countries S1 dataset.
    @TODO:CD: We should ensure that the input tokens do not exceed the allowed limit.
    """

    def __init__(self, knowledge_graph: KG = None, base_url: str = None, api_key: str = None, llm_model: str = None,
                 temperature: float = 0.0, seed: int = 42, num_of_hops: int = 3, use_val: bool = True) -> None:
        """Build the neighbourhood index over the training (+ validation) graph.

        :param knowledge_graph: dicee KG holding integer-indexed train/valid splits.
        :param base_url: OpenAI-compatible endpoint URL (required).
        :param api_key: API key for the endpoint.
        :param llm_model: model identifier passed to the chat-completions call.
        :param temperature: sampling temperature (0.0 for deterministic output).
        :param seed: sampling seed forwarded to the LLM for reproducibility.
        :param num_of_hops: neighbourhood radius used to select context triples.
        :param use_val: if True, validation triples are added to the context graph.
        """
        super().__init__(knowledge_graph, name="GCL")
        # @TODO: CD: input arguments should be passed onto the abstract class
        assert base_url is not None and isinstance(base_url, str)
        self.base_url = base_url
        self.api_key = api_key
        self.llm_model = llm_model
        self.temperature = temperature
        self.seed = seed
        self.client = OpenAI(base_url=self.base_url, api_key=self.api_key)
        # Decode integer-indexed triples into string triples once, up front.
        self.train_set: List[Tuple[str, str, str]] = self._decode_triples(self.kg.train_set)
        self.val_set: List[Tuple[str, str, str]] = self._decode_triples(self.kg.valid_set)

        triples = self.train_set + self.val_set if use_val else self.train_set
        self.igraph = self.build_igraph(triples)
        self.str_entity_to_igraph_vertice = {v["name"]: v for v in self.igraph.vs}
        # NOTE(review): if several edges share a label, the last edge wins here —
        # confirm whether a label -> single-edge mapping is really intended.
        self.str_rel_to_igraph_edges = {e["label"]: e for e in self.igraph.es}

        # Mapping from an entity to relevant triples.
        # A relevant triple connects entities that lie within num_of_hops of the given entity.
        self.node_to_relevant_triples = dict()
        for entity, entity_node_object in self.str_entity_to_igraph_vertice.items():
            neighboring_nodes = self.igraph.neighborhood(entity_node_object, order=num_of_hops)
            subgraph = self.igraph.subgraph(neighboring_nodes)
            self.node_to_relevant_triples[entity] = {
                (subgraph.vs[edge.source]["name"], edge["label"], subgraph.vs[edge.target]["name"])
                for edge in subgraph.es}

        # Deterministic candidate ordering keeps prompts (and seeded LLM output) reproducible.
        self.target_entities = sorted(self.entity_to_idx.keys())

    def _decode_triples(self, indexed_triples) -> List[Tuple[str, str, str]]:
        """Map an (n, 3) tensor of (h, r, t) indices to a list of string triples."""
        return [(self.idx_to_entity[idx_h],
                 self.idx_to_relation[idx_r],
                 self.idx_to_entity[idx_t]) for idx_h, idx_r, idx_t in indexed_triples.tolist()]

    def _create_prompt_based_on_neighbours(self, source: str, relation: str) -> str:
        """Build the LLM prompt for the query (source, relation, ?).

        Combines (a) every triple in the precomputed neighbourhood of ``source``
        and (b) the subset of those triples sharing ``relation`` — grouping by
        relation is important to reach MRR 1.0.

        :raises ValueError: if no relevant triples are known for ``source``.
        """
        relevant_triples = list(self.node_to_relevant_triples.get(source, ()))
        if not relevant_triples:
            # Explicit raise instead of a bare assert: survives `python -O`.
            raise ValueError(f"No relevant triples found for source entity {source!r}")
        # @TODO:CD:Potential improvement by trading off the test runtime:
        # @TODO: Find triples in relevant_triples whose removal leaves the prediction invariant
        # @TODO: Prediction does not change but the input size decreases
        # @TODO: The removed triples can be seen as noise
        triples_context = "Here are some known facts about the source entity that might be relevant:\n"
        triples_context += "".join(f"- {s} {p} {o}\n" for s, p, o in sorted(relevant_triples))
        triples_context += "\n"

        # Important: Grouping relations is important to reach MRR 1.0
        similar_relations = [(s, p, o) for s, p, o in relevant_triples
                             if p == relation and s != source]

        similar_relations_context = "Here are examples of similar relations in the knowledge base:\n"
        similar_relations_context += "".join(f"- {s} {p} {o}\n" for s, p, o in similar_relations)
        similar_relations_context += "\n"

        base_prompt = f"""
        I'm trying to predict the most likely target entities for the following query:
        Source entity: {source}
        Relation: {relation}
        Query: ({source}, {relation}, ?)

        Subgraph:
        {triples_context}
        {similar_relations_context}

        Please provide a ranked list of at most {min(len(self.target_entities), 15)} likely target entities from the following list, along with likelihoods for each: {self.target_entities}

        Provide your answer in the following JSON format: {{"predictions": [{{"entity": "entity_name", "score": float_number}}]}}

        Notes:
        1. Use the provided knowledge about the source entity and similar relations to inform your predictions.
        2. Only include entities that are plausible targets for this relation.
        3. For geographic entities, consider geographic location, regional classifications, and political associations.
        4. Rank the entities by likelihood of being the correct target.
        5. ONLY INCLUDE entities from the provided list in your predictions.
        6. If certain entities are not suitable for this relation, don't include them.
        7. Return a valid JSON output.
        8. Make sure scores are floating point numbers between 0 and 1, not strings.
        9. A score can only be between 0 and 1, i.e. score ∈ [0, 1]. They can never be negative or greater than 1!
        """
        return base_prompt

    @staticmethod
    def build_igraph(graph: List[Tuple[str, str, str]]):
        """Build a directed igraph from string triples.

        Vertices carry the entity string as ``name``; each edge carries its
        relation string as ``label``. Edge insertion order matches ``graph``.
        """
        ig_graph = igraph.Graph(directed=True)
        # Extract unique vertices from all triples
        vertices = set()
        edges = []
        labels = []
        for s, p, o in graph:
            vertices.add(s)
            vertices.add(o)
            # ORDER MATTERS! edges[i] must line up with labels[i].
            edges.append((s, o))
            labels.append(p)

        # Add all unique vertices at once
        ig_graph.add_vertices(list(vertices))
        # Add edges with labels
        ig_graph.add_edges(edges)
        ig_graph.es["label"] = labels
        # Validate edge count
        assert len(edges) == len(ig_graph.es), "Edge mismatch after graph construction!"
        extracted_triples = [(ig_graph.vs[edge.source]["name"], edge["label"], ig_graph.vs[edge.target]["name"]) for
                             edge in ig_graph.es]
        # Not only the number but even the order must match
        assert extracted_triples == list(graph)
        return ig_graph

    def forward_triples(self, x: torch.LongTensor):
        """Score individual (h, r, t) triples — not supported by this model."""
        raise NotImplementedError("GraphContextLearner needs to implement it")

    def forward_k_vs_all(self, x: torch.LongTensor) -> torch.FloatTensor:
        """Score every entity as the tail for each (head, relation) pair in ``x``.

        :param x: LongTensor of shape (batch, 2) holding (entity_idx, relation_idx).
        :return: FloatTensor of shape (batch, num_entities); entities the LLM did
                 not rank keep a sentinel score of -1.0 so they rank last.
        """
        batch_output = []
        # Iterate over batch of subject and relation pairs
        for i in x.tolist():
            # index of an entity and index of a relation.
            idx_h, idx_r = i
            # String representations of an entity and a relation, respectively.
            h, r = self.idx_to_entity[idx_h], self.idx_to_relation[idx_r]
            llm_response = self.client.chat.completions.create(
                model=self.llm_model, temperature=self.temperature, seed=self.seed,
                messages=[{"role": "user",
                           "content": "You are a knowledgeable assistant that helps with link prediction tasks.\n" +
                                      self._create_prompt_based_on_neighbours(source=h, relation=r)}],
                # guided_json constrains the output to the PredictionResponse schema;
                # truncate_prompt_tokens guards against over-long contexts.
                extra_body={"guided_json": PredictionResponse.model_json_schema(),
                            "truncate_prompt_tokens": 30_000,
                            }).choices[0].message.content

            prediction_response = PredictionResponse(**json.loads(llm_response))
            # Initialize scores for all entities with the sentinel -1.0.
            scores_for_all_entities = [-1.0 for _ in range(len(self.idx_to_entity))]
            for pred in prediction_response.predictions:
                try:
                    scores_for_all_entities[self.entity_to_idx[pred.entity]] = pred.score
                except KeyError:
                    # The LLM hallucinated an entity outside the vocabulary; skip it.
                    print(f"For {h},{r}, {pred} not found\tPrediction Size: {len(prediction_response.predictions)}")
                    continue
            batch_output.append(scores_for_all_entities)
        return torch.FloatTensor(batch_output)
0 commit comments