|
9 | 9 | import igraph |
10 | 10 | from typing import Tuple, Dict |
11 | 11 | import dspy |
| 12 | +from tqdm import tqdm |
| 13 | +from dspy.teleprompt import LabeledFewShot |
12 | 14 | class PredictionItem(BaseModel): |
13 | 15 | """Individual prediction item with entity name and confidence score.""" |
14 | 16 | entity: str = Field(..., description="Name of the predicted entity") |
@@ -439,9 +441,8 @@ def __init__(self): |
def forward(self, subject, predicate, few_shot_examples) -> List[Tuple[str, float]]:
    """Predict scored object entities for (subject, predicate) via the LLM.

    Args:
        subject: Head entity of the query.
        predicate: Relation of the query.
        few_shot_examples: Mapping of (subject, predicate) -> list of object
            entities, rendered into the prompt as demonstration blocks.

    Returns:
        List of (entity, score) pairs parsed from the predictor's JSON output.
    """
    # Render each (s, p) group as one demonstration block separated by "---".
    blocks = [
        f"({s}, {p})\n{', '.join(objects)}\n---\n"
        for (s, p), objects in few_shot_examples.items()
    ]
    prompt_examples = "".join(blocks)
    # @TODO: CD: Also keep track of LLM cost
    dspy_pred: dspy.primitives.prediction.Prediction = self.predictor(
        examples=prompt_examples, subject=subject, predicate=predicate
    )
    parsed = json.loads(dspy_pred.objects_with_scores)
    return [(item["entity"], item["score"]) for item in parsed]
447 | 448 |
|
@@ -499,3 +500,189 @@ def __init__(self,knowledge_graph, base_url,api_key,temperature, seed,llm_model, |
499 | 500 |
|
500 | 501 | # 4. Instantiate your predictor |
501 | 502 | self.scoring_func = MultiLabelLinkPredictor() |
| 503 | + self.entities:List[str]=list(sorted(self.entity_to_idx.keys())) |
| 504 | + |
class LM_Call_Signature(dspy.Signature):
    # DSPy signature for a single link-prediction call: given a source entity,
    # a relation, and the candidate target vocabulary, the LM returns a scored
    # list of predicted target entities.
    # NOTE(review): deliberately no class docstring — dspy.Signature uses the
    # docstring as the task instruction, so adding one would change the prompt.
    source: str = dspy.InputField(description="The source entity")
    relation: str = dspy.InputField(description="The relation")
    target_entities: List[str] = dspy.InputField(description="The list of target entities")
    predictions: List[PredictionItem] = dspy.OutputField(description="The list of predicted entities with scores")
| 510 | + |
class DSPy_RCL(AbstractBaseLinkPredictorClass):
    """Relation-conditioned link predictor backed by a DSPy ChainOfThought program.

    For a (head, relation) query the LLM is prompted with the full candidate
    entity vocabulary and asked to return scored target-entity predictions.
    """

    def __init__(self, knowledge_graph: KG = None, base_url: str = None, api_key: str = None, llm_model: str = None,
                 temperature: float = 0.0, seed: int = 42, max_relation_examples: int = 2000, use_val: bool = True,
                 exclude_source: bool = False) -> None:
        """
        Args:
            knowledge_graph: KG supplying index triples and entity/relation mappings.
            base_url: OpenAI-compatible endpoint of the LLM server (required).
            api_key: API key forwarded to the endpoint.
            llm_model: Stored but currently unused — the model name is hardcoded below.
            temperature: Stored but currently unused (see hardcoded LM below).
            seed: Random seed for the train/test split.
            max_relation_examples: Stored but currently unused.
            use_val: If True, validation triples are pooled with the training triples.
            exclude_source: Stored but currently unused placeholder flag.
        """
        super().__init__(knowledge_graph, name="DSPy_RCL")
        assert base_url is not None and isinstance(base_url, str)
        self.base_url = base_url
        self.api_key = api_key
        self.llm_model = llm_model
        self.temperature = temperature
        self.seed = seed
        self.max_relation_examples = max_relation_examples
        self.exclude_source = exclude_source
        # hardcoded for now; NOTE(review): `llm_model`/`temperature`/`seed` are
        # stored but not applied to the LM — confirm whether this is intended.
        self.lm = dspy.LM(model="openai/tentris", api_key=self.api_key, base_url=self.base_url)
        dspy.configure(lm=self.lm)
        self.model = dspy.ChainOfThought(LM_Call_Signature)

        # Training dataset: index triples mapped to (head, relation, tail) labels.
        self.train_set: List[Tuple[str, str, str]] = [(self.idx_to_entity[idx_h],
                                                       self.idx_to_relation[idx_r],
                                                       self.idx_to_entity[idx_t]) for idx_h, idx_r, idx_t in
                                                      self.kg.train_set.tolist()]
        # Validation dataset.
        self.val_set: List[Tuple[str, str, str]] = [(self.idx_to_entity[idx_h],
                                                     self.idx_to_relation[idx_r],
                                                     self.idx_to_entity[idx_t]) for idx_h, idx_r, idx_t in
                                                    self.kg.valid_set.tolist()]

        triples = self.train_set + self.val_set if use_val else self.train_set
        self.triples = triples

        # Create a mapping from relation to all triples using that relation.
        self.relation_to_triples = {}
        for s, p, o in triples:
            self.relation_to_triples.setdefault(p, []).append((s, p, o))

        # Candidate vocabulary handed to the LM on every call (sorted for determinism).
        self.target_entities = list(sorted(self.entity_to_idx.keys()))

    def metric(self, example, pred, trace=None):
        """Mean reciprocal rank of the gold tails within the ranked predictions.

        Args:
            example: Iterable of (head, relation, tail) gold triples.
            pred: Sequence of prediction objects exposing `.entity`, assumed to be
                ordered best-first — TODO confirm the LM output ordering.
            trace: Unused; present to satisfy DSPy's metric signature.

        Returns:
            float: MRR in [0, 1]; 0.0 for an empty example list.
        """
        if not example:
            return 0.0
        ranked = [p.entity for p in pred]
        mrr = 0.0
        for _, _, t in example:
            # Bugfix: the old code added 1/(i+1) where i was the gold triple's
            # position in `example`, not the gold tail's rank in the predictions.
            if t in ranked:
                mrr += 1.0 / (ranked.index(t) + 1)
        return mrr / len(example)

    def generate_examples(self):
        """
        Generate DSPy examples for training the model.

        Returns:
            List[dspy.Example]: A list of DSPy examples for training.
        """
        examples = []

        # Iterate through each relation
        for relation, triples in self.relation_to_triples.items():
            # Group triples by head entity
            head_to_tails = {}
            for s, p, o in triples:
                head_to_tails.setdefault(s, []).append(o)

            # Create examples for each head entity
            for source, targets in head_to_tails.items():
                # Convert target entities to PredictionItem objects with score 1.0
                prediction_items = [PredictionItem(entity=target, score=1.0) for target in targets]

                # Create a DSPy example with the input being the head entity and relation
                # and the output being all correct tail entities as PredictionItem objects
                example = dspy.Example(
                    source=source,
                    relation=relation,
                    target_entities=self.target_entities,
                    predictions=prediction_items
                ).with_inputs("source", "relation", "target_entities")
                examples.append(example)

        return examples

    def generate_train_test_split(self, examples, test_size=0.2):
        """
        Split the examples into training and testing sets.

        Args:
            examples (List[dspy.Example]): A list of DSPy examples to split.
            test_size (float): The proportion of examples to include in the test set.

        Returns:
            Tuple[List[dspy.Example], List[dspy.Example]]: A tuple containing the training and testing examples.
        """
        import random
        random.seed(self.seed)

        # Shuffle a copy so the caller's list is left untouched.
        shuffled_examples = examples.copy()
        random.shuffle(shuffled_examples)

        # Calculate the split point
        split_idx = int(len(shuffled_examples) * (1 - test_size))

        # Split the examples
        train_examples = shuffled_examples[:split_idx]
        test_examples = shuffled_examples[split_idx:]

        return train_examples, test_examples

    def manual_evaluation(self, examples):
        """
        Manually evaluate the model on a list of examples using the metric method.

        Args:
            examples (List[dspy.Example]): A list of DSPy examples to evaluate.

        Returns:
            float: The average metric score across all examples (0.0 if empty).
        """
        total_score = 0.0
        for example in tqdm(examples, desc="Evaluating examples", unit="ex", ncols=100, leave=True):
            # Extract the input values from the example
            source = example.source
            relation = example.relation
            target_entities = example.target_entities
            # Get model predictions
            pred = self.model(source=source, relation=relation, target_entities=target_entities)
            # Reconstruct gold triples from the labelled PredictionItems.
            formatted_example = [(source, relation, item.entity) for item in example.predictions]
            score = self.metric(formatted_example, pred.predictions)
            total_score += score
        # Return the average score
        return total_score / len(examples) if examples else 0.0

    def train_labeledFewShot(self, train_set, few_shot_k):
        """Compile the program with k labelled demonstrations and persist it.

        Args:
            train_set: List of dspy.Example demonstrations.
            few_shot_k: Number of demonstrations sampled into the prompt.

        Returns:
            The compiled DSPy program (also installed as `self.model`).
        """
        lfs_optimizer = LabeledFewShot(k=few_shot_k)
        lfs_model = lfs_optimizer.compile(self.model, trainset=train_set)
        self.model = lfs_model
        lfs_model.save("./lfs_model.json")
        return lfs_model

    def forward(self, x: torch.LongTensor) -> torch.FloatTensor:
        """Predict targets for a single (head_idx, relation_idx) query.

        NOTE(review): returns the raw list of PredictionItem objects, not a
        FloatTensor as annotated — kept as-is since callers may rely on it.
        """
        idx_h, idx_r = x.tolist()[0]
        h, r = self.idx_to_entity[idx_h], self.idx_to_relation[idx_r]
        pred = self.model(source=h, relation=r, target_entities=self.target_entities)
        return pred.predictions

    def forward_k_vs_all(self, x: torch.LongTensor) -> torch.FloatTensor:
        """Score every entity for each (head_idx, relation_idx) query in `x`.

        Returns:
            torch.FloatTensor of shape (batch, num_entities): column j holds the
            LLM-assigned score for the entity with index j in `self.entity_to_idx`
            (0.0 for entities the LLM did not mention).
        """
        queries = x.tolist()
        scores = torch.zeros(len(queries), len(self.target_entities))
        for row, (idx_h, idx_r) in enumerate(queries):
            h, r = self.idx_to_entity[idx_h], self.idx_to_relation[idx_r]
            pred = self.model(source=h, relation=r, target_entities=self.target_entities)
            # Bugfix: the old code fed lists of PredictionItem objects straight
            # into torch.FloatTensor, which raises. Scatter scores into a dense
            # row instead, keyed by the KG's entity indexing.
            for item in pred.predictions:
                col = self.entity_to_idx.get(item.entity)
                if col is not None:  # drop hallucinated entity names
                    scores[row, col] = item.score
        return scores

    def forward_triples(self, x: torch.LongTensor) -> torch.FloatTensor:
        # Per-triple scoring is not supported by this LLM-based predictor yet.
        raise NotImplementedError("DSPy_RCL needs to implement it")
| 673 | + |
# Smoke test for the DSPy model -> remove later
if __name__ == "__main__":
    # Load a small benchmark KG and point the predictor at the local LLM server.
    graph = KG(dataset_dir="KGs/Countries-S1", separator="\s+", eval_model="train_value_test", add_reciprocal=False)
    predictor = DSPy_RCL(knowledge_graph=graph, base_url="http://harebell.cs.upb.de:8501/v1", api_key=":)")

    # Build labelled examples and hold out 20% for evaluation.
    all_examples = predictor.generate_examples()
    train_examples, held_out_examples = predictor.generate_train_test_split(all_examples, test_size=0.2)

    # "Train" by compiling k labelled few-shot demonstrations into the prompt.
    predictor.train_labeledFewShot(train_examples, few_shot_k=3)

    # Report the average metric on the held-out examples.
    print(predictor.manual_evaluation(held_out_examples))
| 688 | + |
0 commit comments