77import json
88import re
99import igraph
10+ from typing import Tuple , Dict
11+ import dspy
1012class PredictionItem (BaseModel ):
1113 """Individual prediction item with entity name and confidence score."""
1214 entity : str = Field (..., description = "Name of the predicted entity" )
@@ -190,12 +192,9 @@ def __init__(self, knowledge_graph: KG = None,
190192 llm_model = "tentris" ,
191193 temperature : float = 1 , seed : int = 42 ) -> None :
192194 super ().__init__ (knowledge_graph , name )
193- # @TODO: CD: input arguments should be passed onto the abstract class
194-
195195 self .client = OpenAI (base_url = base_url , api_key = api_key )
196196 self .llm_model = llm_model
197197 self .temperature = temperature
198- # @TODO:CD: Use the seed
199198 self .seed = seed
200199
201200 def extract_float (self , text ):
@@ -401,4 +400,102 @@ def forward_k_vs_all(self, x: torch.LongTensor) -> torch.FloatTensor:
401400 print (f"For { h } ,{ r } , { pred } not found\t Prediction Size: { len (prediction_response .predictions )} " )
402401 continue
403402 batch_output .append (scores_for_all_entities )
404- return torch .FloatTensor (batch_output )
403+ return torch .FloatTensor (batch_output )
404+
405+ # 1. Define the Signature
# DSPy signature for scoring a single candidate triple.
# NOTE: the class docstring and every `desc` string below are prompt text --
# per DSPy's Signature convention they are sent to the language model -- so
# editing them changes model behaviour, not just documentation.
# NOTE(review): nothing in the visible part of this file instantiates this
# signature; confirm it is used elsewhere before relying on it.
class KGLikelihood(dspy.Signature):
    """Assess the likelihood that a triple (subject, predicate, candidate_object) is true,
    given some context triples. Output a score between 0.0 and 1.0."""

    # Inputs: supporting triples plus the (subject, predicate, candidate_object)
    # triple whose plausibility is being judged.
    context = dspy.InputField(desc="Known knowledge graph triples.")
    subject = dspy.InputField(desc="The subject entity.")
    predicate = dspy.InputField(desc="The relationship type.")
    candidate_object = dspy.InputField(desc="The candidate object entity to score.")

    # Output: no type annotation is given, so the caller presumably receives
    # text and has to convert it to a number itself -- TODO confirm.
    score = dspy.OutputField(desc="A likelihood score between 0.0 and 1.0.")
416+
417+
# DSPy signature for multi-label tail prediction: one (subject, predicate)
# query in, a JSON list of {'entity', 'score'} objects out.
# NOTE: the class docstring and every `desc` string below are prompt text --
# per DSPy's Signature convention they are sent to the language model -- so
# editing them changes model behaviour, not just documentation.
class MultiLabelLinkPredictionWithScores(dspy.Signature):
    """Given a subject entity and a predicate (relation), predict a list of
    object entities that satisfy the relation, along with a likelihood score for each.
    Use the provided examples as a guide.
    Output a JSON formatted list of objects, where each object has an 'entity' (string)
    and a 'score' (float between 0.0 and 1.0) key."""

    # Inputs: pre-rendered few-shot examples plus the query itself.
    examples = dspy.InputField(
        desc="Few-shot examples of (subject, predicate) -> [{'entity': entity1, 'score': score1}, ...].")
    subject = dspy.InputField(desc="The subject entity.")
    predicate = dspy.InputField(desc="The relationship type.")

    # Updated OutputField requesting JSON.
    # The field is declared without a type, so the consumer is expected to
    # decode the JSON text itself (MultiLabelLinkPredictor does so with json.loads).
    objects_with_scores = dspy.OutputField(
        desc="A JSON string representing a list of objects. "
             "Each object in the list should be a dictionary with 'entity' (string) and 'score' (float, 0.0-1.0) keys.")
434+
class MultiLabelLinkPredictor(dspy.Module):
    """Few-shot tail-entity predictor built on ``MultiLabelLinkPredictionWithScores``.

    ``forward`` renders the few-shot examples as text, queries the language
    model once, and decodes the JSON answer into ``(entity, score)`` pairs.
    """

    def __init__(self):
        super().__init__()
        self.predictor = dspy.Predict(MultiLabelLinkPredictionWithScores)

    @staticmethod
    def _format_examples(few_shot_examples) -> str:
        """Render ``{(subject, predicate): [objects]}`` as the few-shot prompt text."""
        # str.join avoids the quadratic cost of repeated ``+=`` on a large mapping.
        return "".join(
            f"({s}, {p})\n{', '.join(o_list)}\n---\n"
            for (s, p), o_list in few_shot_examples.items()
        )

    @staticmethod
    def _parse_predictions(raw) -> List[Tuple[str, float]]:
        """Decode the model answer into ``(entity, score)`` pairs.

        Language models frequently wrap JSON in markdown fences or prose, or
        emit items with missing keys / non-numeric scores. Such output is
        reported and skipped instead of aborting the whole batch.
        """
        items = None
        if isinstance(raw, str):
            candidates = [raw]
            # Fallback: the outermost bracketed span, which recovers answers
            # such as "```json\n[...]\n```" or "Here is the list: [...]".
            start, end = raw.find("["), raw.rfind("]")
            if 0 <= start < end:
                candidates.append(raw[start:end + 1])
            for candidate in candidates:
                try:
                    items = json.loads(candidate)
                    break
                except json.JSONDecodeError:
                    continue
        else:
            # Already-decoded output (e.g. a list) is used as is.
            items = raw

        if not isinstance(items, list):
            print(f"Could not parse LLM predictions: {raw!r}")
            return []

        parsed = []
        for item in items:
            try:
                parsed.append((item["entity"], float(item["score"])))
            except (KeyError, TypeError, ValueError):
                print(f"Skipping malformed prediction: {item!r}")
        return parsed

    def forward(self, subject, predicate, few_shot_examples) -> List[Tuple[str, float]]:
        """Predict scored tail entities for ``(subject, predicate)``.

        Args:
            subject: Label of the head entity.
            predicate: Label of the relation.
            few_shot_examples: Mapping ``(subject, predicate) -> list of object
                labels`` that is rendered into the prompt as examples.

        Returns:
            A list of ``(entity, score)`` tuples; empty when the model answer
            cannot be decoded.
        """
        example_str = self._format_examples(few_shot_examples)
        # @TODO: CD: Also keep track of LLM cost
        dspy_pred: dspy.primitives.prediction.Prediction = self.predictor(
            examples=example_str, subject=subject, predicate=predicate)
        return self._parse_predictions(dspy_pred.objects_with_scores)
446+
class Demir(AbstractBaseLinkPredictorClass):
    """Link predictor that asks an LLM (through DSPy) for scored tail entities.

    The training triples (optionally with the validation triples) are grouped
    by ``(head, relation)`` and handed to ``MultiLabelLinkPredictor`` as
    few-shot examples for every query.
    """

    def __init__(self, knowledge_graph, base_url, api_key, temperature, seed, llm_model, use_val: bool = False):
        """Set up the LLM client and the few-shot example index.

        Args:
            knowledge_graph: Knowledge graph forwarded to the base class.
            base_url: Base URL of the OpenAI-compatible endpoint.
            api_key: API key for that endpoint.
            temperature: Sampling temperature passed to the language model.
            seed: Seed passed to the language model.
            llm_model: Model name; prefixed with ``openai/`` for DSPy.
            use_val: If True, validation triples are added to the few-shot examples.
        """
        super().__init__(knowledge_graph, name="Demir")
        self.client = OpenAI(base_url=base_url, api_key=api_key)
        self.temperature = temperature
        self.seed = seed

        # NOTE(review): ``kwargs=`` is passed as a literal keyword argument named
        # "kwargs". If the intent is to send ``extra_body`` to the endpoint, it
        # probably needs to be passed as ``extra_body=...`` directly -- confirm
        # against the dspy.LM constructor before changing.
        self.lm = dspy.LM(model=f"openai/{llm_model}", api_key=api_key,
                          api_base=base_url,
                          seed=seed,
                          temperature=temperature,
                          cache=True, cache_in_memory=True,
                          kwargs={"extra_body": {"truncate_prompt_tokens": 32_000}})
        # Process-wide side effect: this becomes DSPy's configured language model.
        dspy.configure(lm=self.lm)

        self.train_set: List[Tuple[str, str, str]] = self._to_labelled_triples(self.kg.train_set)
        # Validation dataset
        self.val_set: List[Tuple[str, str, str]] = self._to_labelled_triples(self.kg.valid_set)
        self.triples = self.train_set + self.val_set if use_val else self.train_set

        # (head, relation) -> tail labels, in order of appearance.
        self.entity_relation_to_entities = dict()
        for s, p, o in self.triples:
            self.entity_relation_to_entities.setdefault((s, p), []).append(o)

        # 4. Instantiate your predictor
        self.scoring_func = MultiLabelLinkPredictor()
        self.entities: List[str] = sorted(self.entity_to_idx.keys())

    def _to_labelled_triples(self, indexed_triples) -> List[Tuple[str, str, str]]:
        """Map index triples to ``(head, relation, tail)`` label triples."""
        # A missing split yields no triples instead of failing on ``None.tolist()``.
        if indexed_triples is None:
            return []
        return [(self.idx_to_entity[idx_h],
                 self.idx_to_relation[idx_r],
                 self.idx_to_entity[idx_t])
                for idx_h, idx_r, idx_t in indexed_triples.tolist()]

    def forward_triples(self, x: torch.LongTensor) -> torch.FloatTensor:
        """Triple scoring is not supported by this predictor."""
        raise NotImplementedError("RCL needs to implement it")

    def forward_k_vs_all(self, x: torch.LongTensor) -> torch.FloatTensor:
        """Score every entity as a tail for each ``(head, relation)`` index pair in ``x``.

        Entities the model does not return keep the default score of -100;
        returned names that are not in ``entity_to_idx`` are reported and ignored.

        Returns:
            A float tensor with one row per query and one column per entity.
        """
        num_entities = len(self.idx_to_entity)  # loop-invariant
        batch_predictions = []
        for idx_h, idx_r in x.tolist():
            h, r = self.idx_to_entity[idx_h], self.idx_to_relation[idx_r]
            predictions = self.scoring_func.forward(
                subject=h,
                predicate=r,
                few_shot_examples=self.entity_relation_to_entities)
            scores = [-100] * num_entities
            for entity, score in predictions:
                try:
                    idx_entity = self.entity_to_idx[entity]
                except KeyError:
                    print(f"Entity:{entity} not found")
                    continue
                scores[idx_entity] = score
            batch_predictions.append(scores)
        return torch.FloatTensor(batch_predictions)
0 commit comments