Skip to content

Commit 7040de9

Browse files
committed
Solving conflict
2 parents f0e2cae + acc5748 commit 7040de9

1 file changed

Lines changed: 190 additions & 3 deletions

File tree

retrieval_aug_predictors/models.py

Lines changed: 190 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,8 @@
99
import igraph
1010
from typing import Tuple, Dict
1111
import dspy
12+
from tqdm import tqdm
13+
from dspy.teleprompt import LabeledFewShot
1214
class PredictionItem(BaseModel):
1315
"""Individual prediction item with entity name and confidence score."""
1416
entity: str = Field(..., description="Name of the predicted entity")
@@ -439,9 +441,8 @@ def __init__(self):
439441
def forward(self, subject, predicate, few_shot_examples)->List[Tuple[str, float]]:
440442
example_str = ""
441443
for (s, p), o_list in few_shot_examples.items():
442-
for o in o_list:
443-
example_str += f"({s}, {p}, {o})\n"
444-
example_str+"\n\n"
444+
example_str += f"({s}, {p})\n{', '.join(o_list)}\n---\n"
445+
# @TODO: CD: Also keep track of LLM cost
445446
dspy_pred:dspy.primitives.prediction.Prediction=self.predictor(examples=example_str, subject=subject, predicate=predicate)
446447
return [ (i["entity"],i["score"])for i in json.loads(dspy_pred.objects_with_scores)]
447448

@@ -499,3 +500,189 @@ def __init__(self,knowledge_graph, base_url,api_key,temperature, seed,llm_model,
499500

500501
# 4. Instantiate your predictor
501502
self.scoring_func = MultiLabelLinkPredictor()
503+
self.entities:List[str]=list(sorted(self.entity_to_idx.keys()))
504+
505+
class LM_Call_Signature(dspy.Signature):
    # NOTE: intentionally no class docstring — dspy treats a Signature's
    # docstring as prompt instructions for the LM, so adding one would change
    # model behavior.
    # Input: head entity label of the query triple.
    source: str = dspy.InputField(description="The source entity")
    # Input: relation label of the query triple.
    relation: str = dspy.InputField(description="The relation")
    # Input: candidate tail entities the model may rank/choose from.
    target_entities: List[str] = dspy.InputField(description="The list of target entities")
    # Output: predicted entities with confidence scores (PredictionItem objects).
    predictions: List[PredictionItem] = dspy.OutputField(description="The list of predicted entities with scores")
510+
511+
class DSPy_RCL(AbstractBaseLinkPredictorClass):
    """Link predictor that ranks tail entities with a DSPy ChainOfThought program.

    String (head, relation, tail) triples are extracted from the knowledge
    graph, grouped per relation, turned into labeled DSPy examples, and used to
    few-shot-compile an LLM program that scores candidate tail entities.
    """

    def __init__(self, knowledge_graph: KG = None, base_url: str = None, api_key: str = None, llm_model: str = None,
                 temperature: float = 0.0, seed: int = 42, max_relation_examples: int = 2000, use_val: bool = True,
                 exclude_source: bool = False) -> None:
        """Set up the LLM client, the DSPy program, and the string-triple datasets.

        Args:
            knowledge_graph: KG holding train/valid index triples and vocab maps.
            base_url: OpenAI-compatible endpoint of the LLM server (required).
            api_key: API key for the endpoint.
            llm_model: Model name (currently unused — see hardcoded dspy.LM below).
            temperature: Sampling temperature (currently unused as well).
            seed: RNG seed for the train/test split shuffle.
            max_relation_examples: Per-relation example cap (stored, not applied yet).
            use_val: If True, validation triples are pooled with training triples.
            exclude_source: Stored flag; not consulted in this class yet.
        """
        super().__init__(knowledge_graph, name="DSPy_RCL")
        assert base_url is not None and isinstance(base_url, str)
        self.base_url = base_url
        self.api_key = api_key
        self.llm_model = llm_model
        self.temperature = temperature
        self.seed = seed
        self.max_relation_examples = max_relation_examples
        self.exclude_source = exclude_source
        # hardcoded for now — the llm_model/temperature arguments are NOT
        # forwarded to dspy.LM yet.
        self.lm = dspy.LM(model="openai/tentris", api_key=self.api_key, base_url=self.base_url)
        # NOTE(review): dspy.configure mutates process-global DSPy state; if
        # several predictors are instantiated, the last one wins.
        dspy.configure(lm=self.lm)
        self.model = dspy.ChainOfThought(LM_Call_Signature)

        # Training triples as label strings. (Annotation fixed: a triple is a
        # 3-tuple of strings, not Tuple[str].)
        self.train_set: List[Tuple[str, str, str]] = [
            (self.idx_to_entity[idx_h], self.idx_to_relation[idx_r], self.idx_to_entity[idx_t])
            for idx_h, idx_r, idx_t in self.kg.train_set.tolist()
        ]
        # Validation triples as label strings.
        self.val_set: List[Tuple[str, str, str]] = [
            (self.idx_to_entity[idx_h], self.idx_to_relation[idx_r], self.idx_to_entity[idx_t])
            for idx_h, idx_r, idx_t in self.kg.valid_set.tolist()
        ]

        triples = self.train_set + self.val_set if use_val else self.train_set
        self.triples = triples

        # Map each relation to every triple that uses it.
        self.relation_to_triples: Dict[str, list] = {}
        for s, p, o in triples:
            if p not in self.relation_to_triples:
                self.relation_to_triples[p] = []
            self.relation_to_triples[p].append((s, p, o))

        # Fixed, sorted candidate-entity vocabulary handed to the LLM.
        self.target_entities = list(sorted(self.entity_to_idx.keys()))

    def metric(self, example, pred, trace=None):
        """Mean reciprocal rank (MRR) of the gold tails within the predictions.

        Args:
            example: Iterable of (head, relation, tail) gold triples.
            pred: Ranked sequence of prediction objects exposing ``.entity``
                (best prediction first).
            trace: Unused; kept for DSPy metric-signature compatibility.

        Returns:
            float: Average over gold triples of 1/rank of the gold tail in
            ``pred`` (0 contribution when the tail is absent).
        """
        gold = list(example)
        if not gold:
            # Guard: the previous version raised ZeroDivisionError here.
            return 0.0
        ranked = [p.entity for p in pred]
        total = 0.0
        for h, r, t in gold:
            if t in ranked:
                # BUGFIX: use the tail's rank in the prediction list, not the
                # enumeration index of the gold triple (the old code added
                # 1/(i+1) for the i-th *example*, which is not MRR).
                total += 1.0 / (ranked.index(t) + 1)
        return total / len(gold)

    def generate_examples(self):
        """
        Generate DSPy examples for training the model.

        Returns:
            List[dspy.Example]: One example per (head, relation) pair, whose
            label is the full set of correct tails as PredictionItem objects
            with score 1.0.
        """
        examples = []

        # Iterate through each relation
        for relation, triples in self.relation_to_triples.items():
            # Group triples by head entity
            head_to_tails = {}
            for s, p, o in triples:
                if s not in head_to_tails:
                    head_to_tails[s] = []
                head_to_tails[s].append(o)

            # Create one example per head entity
            for source, targets in head_to_tails.items():
                # All correct tails become PredictionItem labels with score 1.0
                prediction_items = [PredictionItem(entity=target, score=1.0) for target in targets]
                example = dspy.Example(
                    source=source,
                    relation=relation,
                    target_entities=self.target_entities,
                    predictions=prediction_items,
                ).with_inputs("source", "relation", "target_entities")
                examples.append(example)

        return examples

    def generate_train_test_split(self, examples, test_size=0.2):
        """
        Split the examples into training and testing sets.

        Args:
            examples (List[dspy.Example]): Examples to split.
            test_size (float): Proportion of examples in the test set.

        Returns:
            Tuple[List[dspy.Example], List[dspy.Example]]: (train, test) examples.
        """
        import random
        # Use a dedicated Random instance so the global RNG state is untouched
        # (same shuffle result as random.seed + random.shuffle, no side effect).
        rng = random.Random(self.seed)

        shuffled_examples = examples.copy()
        rng.shuffle(shuffled_examples)

        # Everything before the split point trains; the rest tests.
        split_idx = int(len(shuffled_examples) * (1 - test_size))
        return shuffled_examples[:split_idx], shuffled_examples[split_idx:]

    def manual_evaluation(self, examples):
        """
        Manually evaluate the model on a list of examples using ``metric``.

        Args:
            examples (List[dspy.Example]): Examples to evaluate.

        Returns:
            float: Average metric score across all examples (0.0 when empty).
        """
        total_score = 0.0
        for example in tqdm(examples, desc="Evaluating examples", unit="ex", ncols=100, leave=True):
            source = example.source
            relation = example.relation
            target_entities = example.target_entities
            # One LLM call per example.
            pred = self.model(source=source, relation=relation, target_entities=target_entities)
            # Rebuild gold triples from the labeled predictions.
            formatted_example = [(source, relation, item.entity) for item in example.predictions]
            total_score += self.metric(formatted_example, pred.predictions)
        return total_score / len(examples) if examples else 0.0

    def train_labeledFewShot(self, train_set, few_shot_k):
        """Compile the program with k labeled few-shot demos and persist it.

        Args:
            train_set: Labeled dspy.Example list used as the demo pool.
            few_shot_k: Number of demonstrations to attach.

        Returns:
            The compiled DSPy program (also stored on ``self.model``).
        """
        lfs_optimizer = LabeledFewShot(k=few_shot_k)
        lfs_model = lfs_optimizer.compile(self.model, trainset=train_set)
        self.model = lfs_model
        # Persisted next to the working directory for later reuse.
        lfs_model.save("./lfs_model.json")
        return lfs_model

    def forward(self, x: torch.LongTensor) -> torch.FloatTensor:
        """Predict tails for a single (head, relation) index pair.

        NOTE(review): this returns the raw List[PredictionItem] from the DSPy
        program, not a torch.FloatTensor as annotated — confirm against the
        abstract base-class contract.
        """
        idx_h, idx_r = x.tolist()[0]
        h, r = self.idx_to_entity[idx_h], self.idx_to_relation[idx_r]
        pred = self.model(source=h, relation=r, target_entities=self.target_entities)
        return pred.predictions

    def forward_k_vs_all(self, x: torch.LongTensor) -> torch.FloatTensor:
        """Predict tails for a batch of (head, relation) index pairs."""
        batch_output = []
        for idx_h, idx_r in x.tolist():
            h, r = self.idx_to_entity[idx_h], self.idx_to_relation[idx_r]
            pred = self.model(source=h, relation=r, target_entities=self.target_entities)
            batch_output.append(pred.predictions)
        # NOTE(review): torch.FloatTensor over lists of PredictionItem objects
        # will fail at runtime; the predictions likely need to be mapped to
        # per-entity score vectors first — confirm the intended output shape.
        return torch.FloatTensor(batch_output)

    def forward_triples(self, x: torch.LongTensor) -> torch.FloatTensor:
        """Score individual (h, r, t) triples — not implemented yet."""
        raise NotImplementedError("DSPy_RCL needs to implement it")
673+
674+
# test the dspy model -> remove later
if __name__ == "__main__":
    # Raw string for the separator regex: "\s+" is an invalid escape sequence
    # (SyntaxWarning since Python 3.12, future SyntaxError); r"\s+" is the
    # same runtime value spelled correctly.
    kg = KG(dataset_dir="KGs/Countries-S1", separator=r"\s+", eval_model="train_value_test", add_reciprocal=False)
    model = DSPy_RCL(knowledge_graph=kg, base_url="http://harebell.cs.upb.de:8501/v1", api_key=":)")

    examples = model.generate_examples()
    train_examples, test_examples = model.generate_train_test_split(examples, test_size=0.2)

    # Train the model
    model.train_labeledFewShot(train_examples, few_shot_k=3)

    # eval model
    print(model.manual_evaluation(test_examples))
687+
688+

0 commit comments

Comments
 (0)