Skip to content

Commit 1ba5cd8

Browse files
committed
Refactored
1 parent 2bcaa0c commit 1ba5cd8

3 files changed

Lines changed: 108 additions & 7 deletions

File tree

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
from .abstract import AbstractBaseLinkPredictorClass
2-
from .models import RCL,RALP, GCL
2+
from .models import RCL,RALP, GCL, Demir

retrieval_aug_predictors/models.py

Lines changed: 101 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
import json
88
import re
99
import igraph
10+
from typing import Tuple, Dict
11+
import dspy
1012
class PredictionItem(BaseModel):
1113
"""Individual prediction item with entity name and confidence score."""
1214
entity: str = Field(..., description="Name of the predicted entity")
@@ -190,12 +192,9 @@ def __init__(self, knowledge_graph: KG = None,
190192
llm_model="tentris",
191193
temperature: float = 1, seed: int = 42) -> None:
192194
super().__init__(knowledge_graph, name)
193-
# @TODO: CD: input arguments should be passed onto the abstract class
194-
195195
self.client = OpenAI(base_url=base_url, api_key=api_key)
196196
self.llm_model = llm_model
197197
self.temperature = temperature
198-
# @TODO:CD: Use the seed
199198
self.seed = seed
200199

201200
def extract_float(self, text):
@@ -401,4 +400,102 @@ def forward_k_vs_all(self, x: torch.LongTensor) -> torch.FloatTensor:
401400
print(f"For {h},{r}, {pred} not found\tPrediction Size: {len(prediction_response.predictions)}")
402401
continue
403402
batch_output.append(scores_for_all_entities)
404-
return torch.FloatTensor(batch_output)
403+
return torch.FloatTensor(batch_output)
404+
405+
# 1. Define the Signature
class KGLikelihood(dspy.Signature):
    # NOTE(review): for dspy.Signature subclasses the docstring below is part of
    # the prompt sent to the LM — it is behavior, not documentation. Do not
    # reword it casually.
    """Assess the likelihood that a triple (subject, predicate, candidate_object) is true,
    given some context triples. Output a score between 0.0 and 1.0."""

    # Input fields: each `desc` is also surfaced to the LM by dspy.
    context = dspy.InputField(desc="Known knowledge graph triples.")
    subject = dspy.InputField(desc="The subject entity.")
    predicate = dspy.InputField(desc="The relationship type.")
    candidate_object = dspy.InputField(desc="The candidate object entity to score.")

    # Output field: the LM is asked for a single likelihood value.
    score = dspy.OutputField(desc="A likelihood score between 0.0 and 1.0.")
416+
417+
418+
class MultiLabelLinkPredictionWithScores(dspy.Signature):
    # NOTE(review): this docstring doubles as the LM prompt instructions
    # (dspy.Signature semantics) — changing its wording changes model output.
    """Given a subject entity and a predicate (relation), predict a list of
    object entities that satisfy the relation, along with a likelihood score for each.
    Use the provided examples as a guide.
    Output a JSON formatted list of objects, where each object has an 'entity' (string)
    and a 'score' (float between 0.0 and 1.0) key."""

    # Few-shot demonstrations, pre-rendered to plain text by the caller
    # (see MultiLabelLinkPredictor.forward).
    examples = dspy.InputField(
        desc="Few-shot examples of (subject, predicate) -> [{'entity': entity1, 'score': score1}, ...].")
    subject = dspy.InputField(desc="The subject entity.")
    predicate = dspy.InputField(desc="The relationship type.")

    # Updated OutputField requesting JSON; the caller parses it with json.loads.
    objects_with_scores = dspy.OutputField(
        desc="A JSON string representing a list of objects. "
             "Each object in the list should be a dictionary with 'entity' (string) and 'score' (float, 0.0-1.0) keys.")
434+
435+
class MultiLabelLinkPredictor(dspy.Module):
    """dspy module that, for a (subject, predicate) pair, predicts candidate
    object entities together with a likelihood score for each."""

    def __init__(self):
        super().__init__()
        self.predictor = dspy.Predict(MultiLabelLinkPredictionWithScores)

    def forward(self, subject, predicate, few_shot_examples) -> List[Tuple[str, float]]:
        """Return [(entity, score), ...] predicted for (subject, predicate).

        few_shot_examples maps (subject, predicate) tuples to lists of object
        entities; they are rendered into a plain-text few-shot section of the
        prompt. Raises ValueError if the LM emits malformed JSON.
        """
        # Build the few-shot prompt text with a single join instead of the
        # original quadratic `+=` concatenation; rendered text is identical.
        example_str = "".join(
            f"({s}, {p})\n{', '.join(o_list)}\n---\n"
            for (s, p), o_list in few_shot_examples.items()
        )
        # @TODO: CD: Also keep track of LLM cost
        dspy_pred: dspy.primitives.prediction.Prediction = self.predictor(
            examples=example_str, subject=subject, predicate=predicate)
        # The signature asks the LM for a JSON list of {'entity': ..., 'score': ...}.
        return [(i["entity"], i["score"]) for i in json.loads(dspy_pred.objects_with_scores)]
446+
447+
class Demir(AbstractBaseLinkPredictorClass):
    """Link predictor that asks an LLM (via dspy) for scored object entities.

    The training (and optionally validation) triples are indexed as a
    (subject, predicate) -> [objects] few-shot pool; for each query the LM
    returns scored candidates which are mapped back onto entity indices.
    """

    def __init__(self, knowledge_graph, base_url, api_key, temperature, seed, llm_model,
                 use_val: bool = False):
        """Set up the dspy LM client and index the triples for few-shot prompting.

        use_val: if True, validation triples are added to the few-shot pool.
        """
        super().__init__(knowledge_graph, name="Demir")
        self.client = OpenAI(base_url=base_url, api_key=api_key)
        self.temperature = temperature
        self.seed = seed

        # cache=True / cache_in_memory=True avoid re-querying the LM for
        # repeated prompts; prompts are truncated to 32k tokens server-side.
        self.lm = dspy.LM(model=f"openai/{llm_model}", api_key=api_key,
                          api_base=base_url,
                          seed=seed,
                          temperature=temperature,
                          cache=True, cache_in_memory=True,
                          kwargs={"extra_body": {"truncate_prompt_tokens": 32_000}})
        dspy.configure(lm=self.lm)

        # Materialize index triples back into (head, relation, tail) strings.
        self.train_set: List[Tuple[str]] = [(self.idx_to_entity[idx_h],
                                             self.idx_to_relation[idx_r],
                                             self.idx_to_entity[idx_t]) for idx_h, idx_r, idx_t in
                                            self.kg.train_set.tolist()]
        # Validation dataset
        self.val_set: List[Tuple[str]] = [(self.idx_to_entity[idx_h],
                                           self.idx_to_relation[idx_r],
                                           self.idx_to_entity[idx_t]) for idx_h, idx_r, idx_t in
                                          self.kg.valid_set.tolist()]
        self.triples = self.train_set + self.val_set if use_val else self.train_set

        # (subject, predicate) -> [object, ...], the few-shot example pool.
        # (Removed an unused `from collections import OrderedDict`.)
        self.entity_relation_to_entities = dict()
        for s, p, o in self.triples:
            self.entity_relation_to_entities.setdefault((s, p), []).append(o)

        # Instantiate the dspy-backed scoring predictor.
        self.scoring_func = MultiLabelLinkPredictor()
        # sorted() already returns a list; the original wrapped it in list().
        self.entities: List[str] = sorted(self.entity_to_idx.keys())

    def forward_triples(self, x: torch.LongTensor) -> torch.FloatTensor:
        # Fixed copy-paste defect: the original message wrongly referred to RCL.
        raise NotImplementedError("Demir needs to implement it")

    def forward_k_vs_all(self, x: torch.LongTensor) -> torch.FloatTensor:
        """Score every known entity as tail for each (head, relation) row in x.

        Returns a (batch, num_entities) float tensor; entities the LM did not
        mention keep the sentinel score -100.
        """
        batch_predictions = []
        for idx_h, idx_r in x.tolist():
            h, r = self.idx_to_entity[idx_h], self.idx_to_relation[idx_r]
            predictions = self.scoring_func.forward(
                subject=h,
                predicate=r,
                few_shot_examples=self.entity_relation_to_entities)
            # -100 ranks unmentioned entities below any LM score in [0, 1].
            scores = [-100] * len(self.idx_to_entity)
            for entity, score in predictions:
                try:
                    idx_entity = self.entity_to_idx[entity]
                except KeyError:
                    # The LM proposed an entity outside the KG vocabulary.
                    print(f"Entity:{entity} not found")
                    continue
                scores[idx_entity] = score
            batch_predictions.append(scores)
        return torch.FloatTensor(batch_predictions)

retrieval_augmented_link_predictor.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
import numpy as np
3838
from typing import List, Optional
3939
from dotenv import load_dotenv
40-
from retrieval_aug_predictors import AbstractBaseLinkPredictorClass, RALP, GCL, RCL
40+
from retrieval_aug_predictors import AbstractBaseLinkPredictorClass, RALP, GCL, RCL, Demir
4141

4242
load_dotenv()
4343

@@ -63,6 +63,10 @@ def get_model(args,kg)->AbstractBaseLinkPredictorClass:
6363
model = RCL(knowledge_graph=kg, base_url=args.base_url, api_key=args.api_key,
6464
llm_model=args.llm_model_name, temperature=args.temperature, seed=args.seed,
6565
max_relation_examples=args.max_relation_examples, exclude_source=args.exclude_source)
66+
elif args.model == "Demir":
67+
model = Demir(knowledge_graph=kg, base_url=args.base_url, api_key=args.api_key,
68+
llm_model=args.llm_model_name, temperature=args.temperature, seed=args.seed)
69+
6670
else:
6771
raise KeyError(f"{args.model} is not a valid model")
6872
assert model is not None, f"Couldn't assign a model named: {args.model}"
@@ -85,7 +89,7 @@ def run(args):
8589
if __name__ == "__main__":
8690
parser = argparse.ArgumentParser()
8791
parser.add_argument("--dataset_dir", type=str, default="KGs/Countries-S1", help="Path to dataset.")
88-
parser.add_argument("--model", type=str, default="GCL", help="Model name to use for link prediction.", choices=["RALP", "GCL", "RCL"])
92+
parser.add_argument("--model", type=str, default="Demir", help="Model name to use for link prediction.", choices=["Demir", "GCL", "RCL","RALP"])
8993
parser.add_argument("--base_url", type=str, default="http://harebell.cs.upb.de:8501/v1",
9094
choices=["http://harebell.cs.upb.de:8501/v1", "http://tentris-ml.cs.upb.de:8502/v1"],
9195
help="Base URL for the OpenAI client.")

0 commit comments

Comments
 (0)