Skip to content

Commit 5bd8920

Browse files
committed
models added
1 parent 44a4861 commit 5bd8920

6 files changed

Lines changed: 866 additions & 0 deletions

File tree

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
import dspy
2+
import torch
3+
import json
4+
from typing import List, Tuple
5+
from retrieval_aug_predictors.models import KG, AbstractBaseLinkPredictorClass
6+
from openai import OpenAI
7+
8+
# 1. Define the Signature
9+
class KGLikelihood(dspy.Signature):
    """Assess the likelihood that a triple (subject, predicate, candidate_object) is true,
    given some context triples. Output a score between 0.0 and 1.0."""

    # Free-text serialization of known KG triples supplied as evidence.
    context = dspy.InputField(desc="Known knowledge graph triples.")
    subject = dspy.InputField(desc="The subject entity.")
    predicate = dspy.InputField(desc="The relationship type.")
    candidate_object = dspy.InputField(desc="The candidate object entity to score.")

    # NOTE(review): DSPy returns this field as text; callers must parse it to float.
    score = dspy.OutputField(desc="A likelihood score between 0.0 and 1.0.")
19+
20+
21+
class MultiLabelLinkPredictionWithScores(dspy.Signature):
    """Given a subject entity and a predicate (relation), predict a list of
    object entities that satisfy the relation, along with a likelihood score for each.
    Use the provided examples as a guide.
    Output a JSON formatted list of objects, where each object has an 'entity' (string)
    and a 'score' (float between 0.0 and 1.0) key."""

    # Few-shot context: a preformatted string of (subject, predicate) -> objects mappings.
    examples = dspy.InputField(
        desc="Few-shot examples of (subject, predicate) -> [{'entity': entity1, 'score': score1}, ...].")
    subject = dspy.InputField(desc="The subject entity.")
    predicate = dspy.InputField(desc="The relationship type.")

    # Updated OutputField requesting JSON
    # NOTE(review): the LLM emits this as a raw string; it is json.loads()-ed downstream,
    # so malformed JSON from the model will surface there.
    objects_with_scores = dspy.OutputField(
        desc="A JSON string representing a list of objects. "
             "Each object in the list should be a dictionary with 'entity' (string) and 'score' (float, 0.0-1.0) keys.")
37+
38+
class MultiLabelLinkPredictor(dspy.Module):
    """Predict (object entity, likelihood score) pairs for a (subject, predicate) query.

    Wraps ``dspy.Predict`` over :class:`MultiLabelLinkPredictionWithScores` and parses
    the JSON string the LLM emits into a list of ``(entity, score)`` tuples.
    """

    def __init__(self):
        super().__init__()
        self.predictor = dspy.Predict(MultiLabelLinkPredictionWithScores)

    @staticmethod
    def _strip_code_fences(text: str) -> str:
        """Remove an optional markdown code fence (e.g. ```json ... ```) around *text*.

        LLMs frequently wrap JSON answers in fences; json.loads would otherwise fail.
        """
        text = text.strip()
        if text.startswith("```"):
            # Drop the opening fence line (which may read "```json").
            if "\n" in text:
                text = text.split("\n", 1)[1]
            if text.endswith("```"):
                text = text[:-3]
        return text.strip()

    def forward(self, subject, predicate, few_shot_examples) -> List[Tuple[str, float]]:
        """Query the LLM and return ``[(entity, score), ...]`` for the given pair.

        :param subject: subject entity label.
        :param predicate: relation label.
        :param few_shot_examples: mapping ``(subject, predicate) -> [object, ...]``
            used to build the few-shot context string.
        :raises json.JSONDecodeError: if the LLM output is not valid JSON even
            after fence stripping.
        """
        example_str = "".join(
            f"({s}, {p})\n{', '.join(o_list)}\n---\n"
            for (s, p), o_list in few_shot_examples.items()
        )
        # @TODO: CD: Also keep track of LLM cost
        dspy_pred: dspy.primitives.prediction.Prediction = self.predictor(
            examples=example_str, subject=subject, predicate=predicate)
        payload = self._strip_code_fences(dspy_pred.objects_with_scores)
        # float(...) guards against the LLM returning scores as strings ("0.7").
        return [(item["entity"], float(item["score"])) for item in json.loads(payload)]
49+
50+
class Demir(AbstractBaseLinkPredictorClass):
    """LLM-based link predictor that scores every candidate object for a (head, relation)
    query by prompting an OpenAI-compatible endpoint through DSPy.

    Training (and optionally validation) triples are indexed as few-shot examples
    keyed by (subject, predicate).
    """

    def __init__(self, knowledge_graph, base_url, api_key, temperature, seed, llm_model, use_val: bool = False):
        """Configure the DSPy language model and build the few-shot triple index.

        :param knowledge_graph: KG object providing train/valid index tensors.
        :param base_url: OpenAI-compatible API base URL.
        :param api_key: API key for the endpoint.
        :param temperature: sampling temperature passed to the LLM.
        :param seed: sampling seed passed to the LLM.
        :param llm_model: model name (prefixed with "openai/" for dspy.LM).
        :param use_val: if True, validation triples are also used as few-shot context.
        """
        super().__init__(knowledge_graph, name="Demir")
        self.client = OpenAI(base_url=base_url, api_key=api_key)
        self.temperature = temperature
        self.seed = seed

        self.lm = dspy.LM(model=f"openai/{llm_model}", api_key=api_key,
                          api_base=base_url,
                          seed=seed,
                          temperature=temperature,
                          cache=True, cache_in_memory=True,
                          # Cap prompt size accepted by the server.
                          kwargs={"extra_body": {"truncate_prompt_tokens": 32_000}})
        dspy.configure(lm=self.lm)
        # Training triples as label strings.
        self.train_set: List[Tuple[str, str, str]] = [
            (self.idx_to_entity[idx_h], self.idx_to_relation[idx_r], self.idx_to_entity[idx_t])
            for idx_h, idx_r, idx_t in self.kg.train_set.tolist()]
        # Validation triples as label strings.
        self.val_set: List[Tuple[str, str, str]] = [
            (self.idx_to_entity[idx_h], self.idx_to_relation[idx_r], self.idx_to_entity[idx_t])
            for idx_h, idx_r, idx_t in self.kg.valid_set.tolist()]
        self.triples = self.train_set + self.val_set if use_val else self.train_set

        # Few-shot index: (subject, predicate) -> list of known objects.
        self.entity_relation_to_entities = dict()
        for s, p, o in self.triples:
            self.entity_relation_to_entities.setdefault((s, p), []).append(o)

        # Predictor that turns (subject, predicate, examples) into scored objects.
        self.scoring_func = MultiLabelLinkPredictor()
        self.entities: List[str] = sorted(self.entity_to_idx.keys())

    def forward_triples(self, x: torch.LongTensor) -> torch.FloatTensor:
        """Single-triple scoring is not supported by this model."""
        raise NotImplementedError("RCL needs to implement it")

    def forward_k_vs_all(self, x: torch.LongTensor) -> torch.FloatTensor:
        """Score all entities as object candidates for each (head, relation) pair.

        :param x: LongTensor of shape (batch, 2) with entity and relation indices.
        :return: FloatTensor of shape (batch, num_entities); entities the LLM does
            not mention keep a large negative sentinel score (-100).
        """
        batch_predictions = []
        for idx_h, idx_r in x.tolist():
            h, r = self.idx_to_entity[idx_h], self.idx_to_relation[idx_r]
            predictions = self.scoring_func.forward(
                subject=h,
                predicate=r,
                few_shot_examples=self.entity_relation_to_entities)
            # -100 marks "not proposed by the LLM" so ranked metrics push them last.
            scores = [-100] * len(self.idx_to_entity)
            for entity, score in predictions:
                try:
                    idx_entity = self.entity_to_idx[entity]
                except KeyError:
                    # The LLM proposed an entity outside the KG vocabulary; skip it.
                    print(f"Entity:{entity} not found")
                    continue
                scores[idx_entity] = score
            batch_predictions.append(scores)
        return torch.FloatTensor(batch_predictions)
Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
import igraph
2+
import torch
3+
import json
4+
from typing import List, Tuple
5+
from dicee.knowledge_graph import KG
6+
from ..abstract import AbstractBaseLinkPredictorClass
7+
from ..schemas import PredictionResponse
8+
from openai import OpenAI
9+
10+
class GCL(AbstractBaseLinkPredictorClass):
    r""" in-context Learning on neighbouring triples to predict missing entities.

    (h, r, t) \in G_test

    1. Get all nodes that are n=3 hop around h.
    2. Get all triples from G_train involving (1).
    3. Generate a prompt based on (2) and (h,r) to assign scores for all e \in E.

    @TODO:CD: We should write a regression test on the Countries S1 dataset.
    @TODO:CD: We should ensure that the input tokens do not exceed the allowed limit.

    """

    def __init__(self, knowledge_graph: KG = None, base_url: str = None, api_key: str = None, llm_model: str = None,
                 temperature: float = 0.0, seed: int = 42, num_of_hops: int = 3, use_val: bool = True) -> None:
        """Build the igraph over (train [+ valid]) triples and precompute, for every
        entity, the set of triples inside its num_of_hops-neighbourhood subgraph.

        :param knowledge_graph: KG object providing train/valid index tensors.
        :param base_url: OpenAI-compatible API base URL (required).
        :param api_key: API key for the endpoint.
        :param llm_model: model name used in chat completion calls.
        :param temperature: sampling temperature.
        :param seed: sampling seed.
        :param num_of_hops: neighbourhood radius for context extraction.
        :param use_val: if True, validation triples join the context graph.
        """
        super().__init__(knowledge_graph, name="GCL")
        # @TODO: CD: input arguments should be passed onto the abstract class
        assert base_url is not None and isinstance(base_url, str)
        self.base_url = base_url
        self.api_key = api_key
        self.llm_model = llm_model
        self.temperature = temperature
        self.seed = seed
        self.client = OpenAI(base_url=self.base_url, api_key=self.api_key)
        # Training dataset
        self.train_set: List[Tuple[str]] = [(self.idx_to_entity[idx_h],
                                             self.idx_to_relation[idx_r],
                                             self.idx_to_entity[idx_t]) for idx_h, idx_r, idx_t in
                                            self.kg.train_set.tolist()]
        # Validation dataset
        self.val_set: List[Tuple[str]] = [(self.idx_to_entity[idx_h],
                                           self.idx_to_relation[idx_r],
                                           self.idx_to_entity[idx_t]) for idx_h, idx_r, idx_t in
                                          self.kg.valid_set.tolist()]

        triples = self.train_set + self.val_set if use_val else self.train_set
        self.igraph = self.build_igraph(triples)
        self.str_entity_to_igraph_vertice = {i["name"]: i for i in self.igraph.vs}
        # NOTE(review): keyed by relation label — if multiple edges share a label,
        # later edges overwrite earlier ones; confirm this map is only used for
        # label-existence lookups.
        self.str_rel_to_igraph_edges = {i["label"]: i for i in self.igraph.es}

        # Mapping from an entity to relevant triples.
        # A relevant triple contains a entity that is num_of_hops around of a given entity
        self.node_to_relevant_triples = dict()
        for entity, entity_node_object in self.str_entity_to_igraph_vertice.items():
            neighboring_nodes = self.igraph.neighborhood(entity_node_object, order=num_of_hops)
            subgraph = self.igraph.subgraph(neighboring_nodes)
            # All edges of the induced subgraph, rendered back to string triples.
            self.node_to_relevant_triples[entity] = {
                (subgraph.vs[edge.source]["name"], edge["label"], subgraph.vs[edge.target]["name"]) for edge in
                subgraph.es}

        self.target_entities = list(sorted(self.entity_to_idx.keys()))

    def _create_prompt_based_on_neighbours(self, source: str, relation: str) -> str:
        """Render the link-prediction prompt for (source, relation, ?) from the
        precomputed neighbourhood triples of *source*."""
        # Get relevant triples for the source entity
        relevant_triples = []
        if source in self.node_to_relevant_triples:
            relevant_triples = list(self.node_to_relevant_triples[source])

        # NOTE(review): entities with an empty neighbourhood (or unseen at init)
        # crash here via AssertionError — consider a graceful fallback.
        assert len(relevant_triples) > 0
        # @TODO:CD:Potential improvement by trade offing the test runtime:
        # @TODO: Finding an some triples from relevant_triples while the prediction is being invariant to it
        # @TODO: Prediction does not change but the input size decreases
        # @TODO: The removed triples can be seen as noise
        triples_context = "Here are some known facts about the source entity that might be relevant:\n"
        # sorted() keeps prompt text deterministic across runs (set iteration order is not).
        for s, p, o in sorted(relevant_triples):
            triples_context += f"- {s} {p} {o}\n"
        triples_context += "\n"

        # Important: Grouping relations is important to reach MRR 1.0
        similar_relations = []
        for s, p, o in relevant_triples:
            if p == relation and s != source:
                similar_relations.append((s, p, o))

        similar_relations_context = "Here are examples of similar relations in the knowledge base:\n"
        for s, p, o in similar_relations:
            similar_relations_context += f"- {s} {p} {o}\n"
        similar_relations_context += "\n"

        # NOTE(review): the Notes list below repeats "1." twice and the heading
        # reads "Subgraph Graph"; both are cosmetic prompt defects left untouched
        # here because prompt text is behavior.
        base_prompt = f"""
I'm trying to predict the most likely target entities for the following query:
Source entity: {source}
Relation: {relation}
Query: ({source}, {relation}, ?)

Subgraph Graph:
{triples_context}
{similar_relations_context}

Please provide a ranked list of at most {min(len(self.target_entities), 15)} likely target entities from the following list, along with likelihoods for each: {self.target_entities}

Provide your answer in the following JSON format: {{"predictions": [{{"entity": "entity_name", "score": float_number}}]}}

Notes:
1. Use the provided knowledge about the source entity and similar relations to inform your predictions.
1. Only include entities that are plausible targets for this relation.
2. For geographic entities, consider geographic location, regional classifications, and political associations.
3. Rank the entities by likelihood of being the correct target.
4. ONLY INCLUDE entities from the provided list in your predictions.
5. If certain entities are not suitable for this relation, don't include them.
6. Return a valid JSON output.
7. Make sure scores are floating point numbers between 0 and 1, not strings.
8. A score can only be between 0 and 1, i.e. score ∈ [0, 1]. They can never be negative or greater than 1!
"""
        return base_prompt

    @staticmethod
    def build_igraph(graph: List[Tuple[str, str, str]]):
        """Build a directed igraph whose edges preserve the input triple order.

        :param graph: list of (subject, predicate, object) string triples.
        :return: igraph.Graph with vertex attribute "name" and edge attribute "label".
        """
        ig_graph = igraph.Graph(directed=True)
        # Extract unique vertices from all triples
        vertices = set()
        edges = []
        labels = []
        for s, p, o in graph:
            vertices.add(s)
            vertices.add(o)
            # ORDER MATTERS! edges[i] must align with labels[i].
            edges.append((s, o))
            labels.append(p)

        # Add all unique vertices at once
        ig_graph.add_vertices(list(vertices))
        # Add edges with labels
        ig_graph.add_edges(edges)
        ig_graph.es["label"] = labels
        # Validate edge count
        assert len(edges) == len(ig_graph.es), "Edge mismatch after graph construction!"
        extracted_triples = [(ig_graph.vs[edge.source]["name"], edge["label"], ig_graph.vs[edge.target]["name"]) for
                             edge in ig_graph.es]
        # Not only the number but even the order must match
        assert extracted_triples == [triple for triple in graph]
        return ig_graph

    def forward_triples(self, x: torch.LongTensor):
        """Single-triple scoring is not supported by this model."""
        raise NotImplementedError("GraphContextLearner needs to implement it")

    def forward_k_vs_all(self, x: torch.LongTensor) -> torch.FloatTensor:
        """Score all entities as objects for each (head, relation) pair in x.

        :param x: LongTensor of shape (batch, 2) with entity/relation indices.
        :return: FloatTensor of shape (batch, num_entities); unmentioned entities
            keep the sentinel score -1.0.
        """
        batch_output = []
        # Iterate over batch of subject and relation pairs
        for i in x.tolist():
            # index of an entity and index of a relation.
            idx_h, idx_r = i
            # String representations of an entity and a relation, respectively.
            h, r = self.idx_to_entity[idx_h], self.idx_to_relation[idx_r]
            # guided_json constrains the server to emit PredictionResponse-shaped JSON.
            llm_response = self.client.chat.completions.create(
                model=self.llm_model, temperature=self.temperature, seed=self.seed,
                messages=[{"role": "user",
                           "content": "You are a knowledgeable assistant that helps with link prediction tasks.\n" +
                                      self._create_prompt_based_on_neighbours(source=h, relation=r)}],
                extra_body={"guided_json": PredictionResponse.model_json_schema(),
                            "truncate_prompt_tokens": 30_000,
                            }).choices[0].message.content

            prediction_response = PredictionResponse(**json.loads(llm_response))
            # Initialize scores for all entities
            scores_for_all_entities = [-1.0 for _ in range(len(self.idx_to_entity))]
            for pred in prediction_response.predictions:
                try:
                    scores_for_all_entities[self.entity_to_idx[pred.entity]] = pred.score
                except KeyError:
                    # The LLM proposed an entity outside the KG vocabulary; skip it.
                    print(f"For {h},{r}, {pred} not found\tPrediction Size: {len(prediction_response.predictions)}")
                    continue
            batch_output.append(scores_for_all_entities)
        return torch.FloatTensor(batch_output)
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
import re
2+
import torch
3+
from openai import OpenAI
4+
from retrieval_aug_predictors.models import KG, AbstractBaseLinkPredictorClass
5+
6+
7+
class RALP(AbstractBaseLinkPredictorClass):
    """Retrieval-augmented link predictor: scores a single triple by prompting an
    OpenAI-compatible LLM with all training triples that mention the subject.
    """

    def __init__(self, knowledge_graph: KG = None,
                 name="ralp-1.0",
                 base_url="http://tentris-ml.cs.upb.de:8501/v1",
                 api_key=None,
                 llm_model="tentris",
                 temperature: float = 1, seed: int = 42) -> None:
        """Store LLM connection settings.

        :param knowledge_graph: KG object providing the indexed training set.
        :param name: model name passed to the abstract base class.
        :param base_url: OpenAI-compatible API base URL.
        :param api_key: API key for the endpoint.
        :param llm_model: model name used in chat completion calls.
        :param temperature: sampling temperature.
        :param seed: sampling seed used for reproducible LLM calls.
        """
        super().__init__(knowledge_graph, name)
        self.client = OpenAI(base_url=base_url, api_key=api_key)
        self.llm_model = llm_model
        self.temperature = temperature
        self.seed = seed

    def extract_float(self, text):
        """Extract the first number in *text* as a float; 0.0 if none is found.

        Used mainly to filter the LLM-output for the scoring task. The pattern
        also accepts bare integers ("1"), which the old decimal-only pattern
        silently mapped to 0.0 even though LLMs frequently answer "0" or "1".
        """
        # Matches "3.14", "5.", ".5" and plain integers such as "1" or "-2".
        pattern = r"-?(?:\d+\.\d*|\.\d+|\d+)"
        match = re.search(pattern, text)
        return float(match.group()) if match else 0.0

    def ru(self, entity):
        """Remove underscore from the entity (as str) for readable prompts."""
        return entity.replace("_", " ")

    def get_score(self, triple: tuple, triples_h: str) -> float:
        """Ask the LLM for a plausibility score of *triple* given *triples_h*.

        :param triple: (head, relation, tail) string labels.
        :param triples_h: preformatted string of training triples mentioning the head.
        :return: parsed float score (0.0 if the LLM output contains no number).
        """
        system_prompt = """You are an expert in knowledge graphs and link prediction. Your task is to assign a plausibility score (from 0 to 1) to a given triple (subject, predicate, object) based on a set of known training triples for the same subject.

- A score of 1.0 means the triple is highly likely to be true.
- A score of 0.0 means the triple is highly unlikely to be true.
- Intermediate values (e.g., 0.4, 0.7) reflect varying levels of plausibility.


**Guidelines for scoring:**
1. **Exact Match:** If the triple already exists in the training set or if the facts clearly state that the triple must be true assign a score close to 1.0.
2. **Pattern Matching:** If the predicate-object pair frequently occurs for the given subject, assign a high score.
3. **Semantic Similarity:** If the object is semantically close to known objects for the subject-predicate pair, assign a moderate to high score.
4. **Rare or Unseen Combinations:** If the triple does not follow the learned patterns, assign a low score.
5. **Contradictions:** If the triple contradicts existing facts (perform your own reasoning), assign a very low score.

You must analyze the given triple and the training triples, apply the reasoning above, and output only a single **floating-point score** between **0.0 and 1.0**, without any explanation or additional text.
Do not depend only on triples provided to you, also use your own knowledge as an AI assistant to reason about the truthness of the given triple as a fact.
You are strictly required to provide only the score as an answer and do not explain it."""

        user_prompt = f"""Here is the triple we want to evaluate:
(subject: {triple[0]}, predicate: {triple[1]}, object: {triple[2]})

Here are the known training triples for the subject "{triple[0]}":
{triples_h}

Assign a score to the given triple based on the provided training triples.
"""
        response = self.client.chat.completions.create(
            model=self.llm_model,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            # BUGFIX: was hard-coded to 42, silently ignoring the configured seed.
            seed=self.seed,
            temperature=self.temperature
        )

        # Extract the response content
        content = response.choices[0].message.content
        return self.extract_float(content)

    def forward_k_vs_all(self, x):
        """KvsAll scoring is not supported by this model."""
        raise NotImplementedError("RALP needs to implement it")

    def forward_triples(self, indexed_triples: torch.LongTensor):
        """Score the given indexed triple(s) via the LLM.

        :param indexed_triples: LongTensor of shape (1, 3) — batching beyond a
            single triple is not yet supported.
        :return: FloatTensor of shape (1, 1) with the plausibility score.
        """
        n, d = indexed_triples.shape
        # For the time being
        assert d == 3
        assert n == 1
        scores = []
        for triple in indexed_triples.tolist():
            idx_h, idx_r, idx_t = triple
            h, r, t = self.idx_to_entity[idx_h], self.idx_to_relation[idx_r], self.idx_to_entity[idx_t]

            # Retrieve triples where 'h' is a subject or an object
            triples_h = [trp for trp in self.kg.train_set if (trp[0] == idx_h or trp[2] == idx_h)]

            # Format the triples into structured string output that will be used in the prompt.
            triples_h_str = ""
            for trp in triples_h:
                triples_h_str += f'- ("{self.ru(self.idx_to_entity[trp[0]])}", "{self.ru(self.idx_to_relation[trp[1]])}", "{self.ru(self.idx_to_entity[trp[2]])}") \n'

            # Get the score from the LLM
            score = self.get_score((h, r, t), triples_h_str)
            scores.append([score])
        return torch.FloatTensor(scores)

0 commit comments

Comments
 (0)