
Commit acc5748 (1 parent: ef47705)

initial simple labelled few shot implementation

1 file changed: retrieval_aug_predictors/models.py (109 additions, 5 deletions)

@@ -9,6 +9,8 @@
 import igraph
 from typing import Tuple, Dict
 import dspy
+from tqdm import tqdm
+from dspy.teleprompt import LabeledFewShot
 class PredictionItem(BaseModel):
     """Individual prediction item with entity name and confidence score."""
     entity: str = Field(..., description="Name of the predicted entity")
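
The two new imports back the code added below: tqdm drives the progress bar in manual_evaluation, and LabeledFewShot is the DSPy optimizer wrapped by train_labeledFewShot.
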
@@ -537,6 +539,7 @@ def __init__(self, knowledge_graph: KG = None, base_url: str = None, api_key: st
                          self.kg.valid_set.tolist()]
 
         triples = self.train_set + self.val_set if use_val else self.train_set
+        self.triples = triples
 
         # Create a mapping from relation to all triples using that relation
         self.relation_to_triples = {}
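
self.triples keeps the assembled triple list (train plus validation when use_val is set) on the instance; nothing in this commit reads it yet, so it appears to be stored for later use.
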
@@ -551,16 +554,109 @@ def metric(self, example, pred, trace=None):
         # Calculate MRR
         mrr = 0
         for i, (h, r, t) in enumerate(example):
-            if t in pred:
+            # Check if the target entity is in the list of predicted entities
+            if t in [p.entity for p in pred]:
                 mrr += 1 / (i + 1)
         mrr /= len(example)
         return mrr
+
+    def generate_examples(self):
+        """
+        Generate DSPy examples for training the model.
+
+        Returns:
+            List[dspy.Example]: A list of DSPy examples for training.
+        """
+        examples = []
+
+        # Iterate through each relation
+        for relation, triples in self.relation_to_triples.items():
+            # Group triples by head entity
+            head_to_tails = {}
+            for s, p, o in triples:
+                if s not in head_to_tails:
+                    head_to_tails[s] = []
+                head_to_tails[s].append(o)
+
+            # Create examples for each head entity
+            for source, targets in head_to_tails.items():
+                # Convert target entities to PredictionItem objects with score 1.0
+                prediction_items = [PredictionItem(entity=target, score=1.0) for target in targets]
+
+                # Create a DSPy example with the input being the head entity and relation
+                # and the output being all correct tail entities as PredictionItem objects
+                example = dspy.Example(
+                    source=source,
+                    relation=relation,
+                    target_entities=self.target_entities,
+                    predictions=prediction_items
+                ).with_inputs("source", "relation", "target_entities")
+                examples.append(example)
+
+        return examples
+
+    def generate_train_test_split(self, examples, test_size=0.2):
+        """
+        Split the examples into training and testing sets.
+
+        Args:
+            examples (List[dspy.Example]): A list of DSPy examples to split.
+            test_size (float): The proportion of examples to include in the test set.
+
+        Returns:
+            Tuple[List[dspy.Example], List[dspy.Example]]: A tuple containing the training and testing examples.
+        """
+        import random
+        random.seed(self.seed)
+
+        # Shuffle the examples
+        shuffled_examples = examples.copy()
+        random.shuffle(shuffled_examples)
+
+        # Calculate the split point
+        split_idx = int(len(shuffled_examples) * (1 - test_size))
+
+        # Split the examples
+        train_examples = shuffled_examples[:split_idx]
+        test_examples = shuffled_examples[split_idx:]
+
+        return train_examples, test_examples
+
+    def manual_evaluation(self, examples):
+        """
+        Manually evaluate the model on a list of examples using the metric method.
+
+        Args:
+            examples (List[dspy.Example]): A list of DSPy examples to evaluate.
+
+        Returns:
+            float: The average metric score across all examples.
+        """
+        total_score = 0.0
+        for example in tqdm(examples, desc="Evaluating examples", unit="ex", ncols=100, leave=True):
+            # Extract the input values from the example
+            source = example.source
+            relation = example.relation
+            target_entities = example.target_entities
+            # Get model predictions
+            pred = self.model(source=source, relation=relation, target_entities=target_entities)
+            formatted_example = [(source, relation, item.entity) for item in example.predictions]
+            score = self.metric(formatted_example, pred.predictions)
+            total_score += score
+        # Return the average score
+        return total_score / len(examples) if examples else 0.0
+
+    def train_labeledFewShot(self, train_set, few_shot_k):
+        lfs_optimizer = LabeledFewShot(k=few_shot_k)
+        lfs_model = lfs_optimizer.compile(self.model, trainset=train_set)
+        self.model = lfs_model
+        lfs_model.save("./lfs_model.json")
+        return lfs_model
 
     def forward(self, x: torch.LongTensor) -> torch.FloatTensor:
         idx_h, idx_r = x.tolist()[0]
         h, r = self.idx_to_entity[idx_h], self.idx_to_relation[idx_r]
         pred = self.model(source=h, relation=r, target_entities=self.target_entities)
-        print(self.lm.inspect_history())
         return pred.predictions
 
     def forward_k_vs_all(self, x: torch.LongTensor) -> torch.FloatTensor:
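
The one-line metric fix is the substantive change in this hunk: pred now holds PredictionItem objects, so the old membership test `t in pred` compared a string against model objects and could never match. A self-contained sketch of the before/after behaviour (toy entities; PredictionItem is re-declared here, with an assumed score description, only to keep the snippet runnable):

# Sketch: why the membership test had to change (illustrative, not repo code).
from pydantic import BaseModel, Field

class PredictionItem(BaseModel):
    entity: str = Field(..., description="Name of the predicted entity")
    score: float = Field(..., description="Confidence score")  # assumed field doc

pred = [PredictionItem(entity="austria", score=0.9),
        PredictionItem(entity="france", score=0.4)]

print("austria" in pred)                      # False: a str never equals a PredictionItem
print("austria" in [p.entity for p in pred])  # True: compare against entity names

Note also that the score weights a hit by the triple's position in example (1 / (i + 1)), not by the rank of the match inside pred, so it behaves as a position-weighted hit rate rather than textbook MRR.
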
@@ -569,7 +665,6 @@ def forward_k_vs_all(self, x: torch.LongTensor) -> torch.FloatTensor:
             idx_h, idx_r = i
             h, r = self.idx_to_entity[idx_h], self.idx_to_relation[idx_r]
             pred = self.model(source=h, relation=r, target_entities=self.target_entities)
-            print(self.lm.inspect_history())
             batch_output.append(pred.predictions)
         return torch.FloatTensor(batch_output)
 
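As in forward, the self.lm.inspect_history() debug print is dropped from the batched path, so inference no longer dumps the full LM transcript on every call.
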
@@ -579,5 +674,14 @@ def forward_triples(self, x: torch.LongTensor) -> torch.FloatTensor:
 # test the dspy model -> remove later
 if __name__ == "__main__":
     kg = KG(dataset_dir="KGs/Countries-S1", separator="\s+", eval_model="train_value_test", add_reciprocal=False)
-    model = DSPy_RCL(knowledge_graph=kg, base_url="http://harebell.cs.upb.de:8501/v1", api_key="secure-key-123")
-    print(model.forward(torch.tensor([(1, 1)])))
+    model = DSPy_RCL(knowledge_graph=kg, base_url="http://harebell.cs.upb.de:8501/v1", api_key=":)")
+
+    examples = model.generate_examples()
+    train_examples, test_examples = model.generate_train_test_split(examples, test_size=0.2)
+
+    # Train the model
+    model.train_labeledFewShot(train_examples, few_shot_k=3)
+
+    # eval model
+    print(model.manual_evaluation(test_examples))
+
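
For context on what train_labeledFewShot does: DSPy's LabeledFewShot is the simplest optimizer in dspy.teleprompt. compile performs no instruction or demonstration search; it samples up to k examples from the train set and attaches them as few-shot demos to each predictor of a copy of the program, roughly like this (an illustrative paraphrase, not this repository's code or DSPy's exact source):

# Rough paraphrase of LabeledFewShot(k).compile(student, trainset=...).
import random

def compile_labeled_few_shot(student, trainset, k, seed=0):
    compiled = student.deepcopy()  # leave the original program untouched
    rng = random.Random(seed)
    for predictor in compiled.predictors():
        # No search: attach up to k labeled examples as demos.
        predictor.demos = rng.sample(trainset, min(k, len(trainset)))
    return compiled

Because all of the learned state lives in those demos, lfs_model.save("./lfs_model.json") is enough to persist the compiled program.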
