from collections import defaultdict
from typing import Dict, Tuple

import dspy
import igraph
from dspy.teleprompt import LabeledFewShot
from tqdm import tqdm
1214class PredictionItem (BaseModel ):
1315 """Individual prediction item with entity name and confidence score."""
1416 entity : str = Field (..., description = "Name of the predicted entity" )
@@ -537,6 +539,7 @@ def __init__(self, knowledge_graph: KG = None, base_url: str = None, api_key: st
537539 self .kg .valid_set .tolist ()]
538540
539541 triples = self .train_set + self .val_set if use_val else self .train_set
542+ self .triples = triples
540543
541544 # Create a mapping from relation to all triples using that relation
542545 self .relation_to_triples = {}
@@ -551,16 +554,109 @@ def metric(self, example, pred, trace=None):
551554 # Calculate MRR
552555 mrr = 0
553556 for i , (h , r , t ) in enumerate (example ):
554- if t in pred :
557+ # Check if the target entity is in the list of predicted entities
558+ if t in [p .entity for p in pred ]:
555559 mrr += 1 / (i + 1 )
556560 mrr /= len (example )
557561 return mrr
562+
563+ def generate_examples (self ):
564+ """
565+ Generate DSPy examples for training the model.
566+
567+ Returns:
568+ List[dspy.Example]: A list of DSPy examples for training.
569+ """
570+ examples = []
571+
572+ # Iterate through each relation
573+ for relation , triples in self .relation_to_triples .items ():
574+ # Group triples by head entity
575+ head_to_tails = {}
576+ for s , p , o in triples :
577+ if s not in head_to_tails :
578+ head_to_tails [s ] = []
579+ head_to_tails [s ].append (o )
580+
581+ # Create examples for each head entity
582+ for source , targets in head_to_tails .items ():
583+ # Convert target entities to PredictionItem objects with score 1.0
584+ prediction_items = [PredictionItem (entity = target , score = 1.0 ) for target in targets ]
585+
586+ # Create a DSPy example with the input being the head entity and relation
587+ # and the output being all correct tail entities as PredictionItem objects
588+ example = dspy .Example (
589+ source = source ,
590+ relation = relation ,
591+ target_entities = self .target_entities ,
592+ predictions = prediction_items
593+ ).with_inputs ("source" , "relation" , "target_entities" )
594+ examples .append (example )
595+
596+ return examples
597+
598+ def generate_train_test_split (self , examples , test_size = 0.2 ):
599+ """
600+ Split the examples into training and testing sets.
601+
602+ Args:
603+ examples (List[dspy.Example]): A list of DSPy examples to split.
604+ test_size (float): The proportion of examples to include in the test set.
605+
606+ Returns:
607+ Tuple[List[dspy.Example], List[dspy.Example]]: A tuple containing the training and testing examples.
608+ """
609+ import random
610+ random .seed (self .seed )
611+
612+ # Shuffle the examples
613+ shuffled_examples = examples .copy ()
614+ random .shuffle (shuffled_examples )
615+
616+ # Calculate the split point
617+ split_idx = int (len (shuffled_examples ) * (1 - test_size ))
618+
619+ # Split the examples
620+ train_examples = shuffled_examples [:split_idx ]
621+ test_examples = shuffled_examples [split_idx :]
622+
623+ return train_examples , test_examples
624+
625+ def manual_evaluation (self , examples ):
626+ """
627+ Manually evaluate the model on a list of examples using the metric method.
628+
629+ Args:
630+ examples (List[dspy.Example]): A list of DSPy examples to evaluate.
631+
632+ Returns:
633+ float: The average metric score across all examples.
634+ """
635+ total_score = 0.0
636+ for example in tqdm (examples , desc = "Evaluating examples" , unit = "ex" , ncols = 100 , leave = True ):
637+ # Extract the input values from the example
638+ source = example .source
639+ relation = example .relation
640+ target_entities = example .target_entities
641+ # Get model predictions
642+ pred = self .model (source = source , relation = relation , target_entities = target_entities )
643+ formatted_example = [(source , relation , item .entity ) for item in example .predictions ]
644+ score = self .metric (formatted_example , pred .predictions )
645+ total_score += score
646+ # Return the average score
647+ return total_score / len (examples ) if examples else 0.0
648+
649+ def train_labeledFewShot (self , train_set , few_shot_k ):
650+ lfs_optimizer = LabeledFewShot (k = few_shot_k )
651+ lfs_model = lfs_optimizer .compile (self .model , trainset = train_set )
652+ self .model = lfs_model
653+ lfs_model .save ("./lfs_model.json" )
654+ return lfs_model
558655
    def forward(self, x: torch.LongTensor) -> torch.FloatTensor:
        """Predict tail entities for a single (head, relation) index pair.

        Args:
            x: Tensor whose first row holds ``(head_index, relation_index)``.

        Returns:
            The model output's ``predictions`` field.  NOTE(review): this is
            whatever ``self.model`` returns (a list of prediction items), not
            a ``torch.FloatTensor`` as the annotation claims — confirm what
            callers expect before trusting the annotation.
        """
        # Map integer indices back to their entity / relation labels.
        idx_h, idx_r = x.tolist()[0]
        h, r = self.idx_to_entity[idx_h], self.idx_to_relation[idx_r]
        # Query the DSPy model against the full candidate entity list.
        pred = self.model(source=h, relation=r, target_entities=self.target_entities)
        return pred.predictions
565661
566662 def forward_k_vs_all (self , x : torch .LongTensor ) -> torch .FloatTensor :
@@ -569,7 +665,6 @@ def forward_k_vs_all(self, x: torch.LongTensor) -> torch.FloatTensor:
569665 idx_h , idx_r = i
570666 h , r = self .idx_to_entity [idx_h ], self .idx_to_relation [idx_r ]
571667 pred = self .model (source = h , relation = r , target_entities = self .target_entities )
572- print (self .lm .inspect_history ())
573668 batch_output .append (pred .predictions )
574669 return torch .FloatTensor (batch_output )
575670
@@ -579,5 +674,14 @@ def forward_triples(self, x: torch.LongTensor) -> torch.FloatTensor:
# test the dspy model -> remove later
if __name__ == "__main__":
    # Raw string avoids the invalid "\s" escape-sequence warning; the
    # runtime value is unchanged.
    kg = KG(dataset_dir="KGs/Countries-S1", separator=r"\s+",
            eval_model="train_value_test", add_reciprocal=False)
    # SECURITY(review): base_url and api_key are hard-coded; move them to
    # environment variables or a config file before sharing this script.
    model = DSPy_RCL(knowledge_graph=kg,
                     base_url="http://harebell.cs.upb.de:8501/v1",
                     api_key=":)")

    examples = model.generate_examples()
    train_examples, test_examples = model.generate_train_test_split(examples, test_size=0.2)

    # Train the model with labeled few-shot demonstrations.
    model.train_labeledFewShot(train_examples, few_shot_k=3)

    # Evaluate the compiled model on the held-out split.
    print(model.manual_evaluation(test_examples))
0 commit comments