44from typing import List , Tuple
55from retrieval_aug_predictors .models import KG , AbstractBaseLinkPredictorClass
66from openai import OpenAI
7-
8- # 1. Define the Signature
9- class KGLikelihood (dspy .Signature ):
10- """Assess the likelihood that a triple (subject, predicate, candidate_object) is true,
11- given some context triples. Output a score between 0.0 and 1.0."""
12-
13- context = dspy .InputField (desc = "Known knowledge graph triples." )
14- subject = dspy .InputField (desc = "The subject entity." )
15- predicate = dspy .InputField (desc = "The relationship type." )
16- candidate_object = dspy .InputField (desc = "The candidate object entity to score." )
17-
18- score = dspy .OutputField (desc = "A likelihood score between 0.0 and 1.0." )
19-
7+ from collections import OrderedDict
8+ from retrieval_aug_predictors .arguments import parser
9+ from retrieval_aug_predictors .utils import sanity_checking
10+ from dicee .evaluator import evaluate_lp , evaluate_lp_k_vs_all
11+ from dotenv import load_dotenv
12+ load_dotenv ()
2013
2114class MultiLabelLinkPredictionWithScores (dspy .Signature ):
22- """Given a subject entity and a predicate (relation), predict a list of
23- object entities that satisfy the relation, along with a likelihood score for each.
24- Use the provided examples as a guide.
25- Output a JSON formatted list of objects, where each object has an 'entity' (string)
26- and a 'score' (float between 0.0 and 1.0) key."""
27-
2815 examples = dspy .InputField (
2916 desc = "Few-shot examples of (subject, predicate) -> [{'entity': entity1, 'score': score1}, ...]." )
30- subject = dspy .InputField (desc = "The subject entity." )
31- predicate = dspy .InputField (desc = "The relationship type." )
17+ subject : str = dspy .InputField (desc = "The subject entity." )
18+ predicate : str = dspy .InputField (desc = "The relationship type." )
3219
3320 # Updated OutputField requesting JSON
3421 objects_with_scores = dspy .OutputField (
@@ -48,6 +35,32 @@ def forward(self, subject, predicate, few_shot_examples)->List[Tuple[str, float]
4835 return [ (i ["entity" ],i ["score" ])for i in json .loads (dspy_pred .objects_with_scores )]
4936
5037class Demir (AbstractBaseLinkPredictorClass ):
38+ def __init__ (self ,knowledge_graph , base_url ,api_key ,temperature , seed ,llm_model ,use_val :bool = False ):
39+ super ().__init__ (knowledge_graph ,name = "Demir" )
40+ self .temperature = temperature
41+ self .seed = seed
42+ self .lm = dspy .LM (model = f"openai/{ llm_model } " , api_key = api_key ,
43+ api_base = base_url ,
44+ seed = seed ,
45+ temperature = temperature ,
46+ cache = True ,cache_in_memory = True )
47+ dspy .configure (lm = self .lm )
48+ self .train_set : List [Tuple [str ]] = [(self .idx_to_entity [idx_h ],
49+ self .idx_to_relation [idx_r ],
50+ self .idx_to_entity [idx_t ]) for idx_h , idx_r , idx_t in
51+ self .kg .train_set .tolist ()]
52+ # Validation dataset
53+ self .val_set : List [Tuple [str ]] = [(self .idx_to_entity [idx_h ],
54+ self .idx_to_relation [idx_r ],
55+ self .idx_to_entity [idx_t ]) for idx_h , idx_r , idx_t in
56+ self .kg .valid_set .tolist ()]
57+ self .triples = self .train_set + self .val_set if use_val else self .train_set
58+
59+ self .entity_relation_to_entities = dict ()
60+ for s ,p ,o in self .triples :
61+ self .entity_relation_to_entities .setdefault ((s ,p ),[]).append (o )
62+ self .scoring_func = MultiLabelLinkPredictor ()
63+
5164 def forward_triples (self , x : torch .LongTensor ) -> torch .FloatTensor :
5265 raise NotImplementedError ("RCL needs to implement it" )
5366 def forward_k_vs_all (self , x : torch .LongTensor ) -> torch .FloatTensor :
@@ -70,35 +83,19 @@ def forward_k_vs_all(self, x: torch.LongTensor) -> torch.FloatTensor:
7083 batch_predictions .append (scores )
7184 return torch .FloatTensor (batch_predictions )
7285
73- def __init__ (self ,knowledge_graph , base_url ,api_key ,temperature , seed ,llm_model ,use_val :bool = False ):
74- super ().__init__ (knowledge_graph ,name = "Demir" )
75- self .client = OpenAI (base_url = base_url , api_key = api_key )
76- self .temperature = temperature
77- self .seed = seed
7886
79- self .lm = dspy .LM (model = f"openai/{ llm_model } " , api_key = api_key ,
80- api_base = base_url ,
81- seed = seed ,
82- temperature = temperature ,
83- cache = True ,cache_in_memory = True ,
84- kwargs = {"extra_body" :{"truncate_prompt_tokens" : 32_000 }})
85- dspy .configure (lm = self .lm )
86- self .train_set : List [Tuple [str ]] = [(self .idx_to_entity [idx_h ],
87- self .idx_to_relation [idx_r ],
88- self .idx_to_entity [idx_t ]) for idx_h , idx_r , idx_t in
89- self .kg .train_set .tolist ()]
90- # Validation dataset
91- self .val_set : List [Tuple [str ]] = [(self .idx_to_entity [idx_h ],
92- self .idx_to_relation [idx_r ],
93- self .idx_to_entity [idx_t ]) for idx_h , idx_r , idx_t in
94- self .kg .valid_set .tolist ()]
95- self .triples = self .train_set + self .val_set if use_val else self .train_set
87+ # test the dspy model -> remove later
if __name__ == "__main__":
    # Script entry: build the knowledge graph from the command-line arguments,
    # run the Demir predictor on a slice of the test set and print the metrics.
    args = parser.parse_args()
    print(args)

    # Important: add_reciprocal=False in KvsAll implies that inverse relation has been introduced.
    # Therefore, the link prediction results are based on the missing tail rankings only!
    # The separator is a regular expression, hence the raw string.
    kg = KG(
        dataset_dir=args.dataset_dir,
        separator=r"\s+",
        eval_model=args.eval_model,
        add_reciprocal=False,
    )

    sanity_checking(args, kg)

    model = Demir(
        knowledge_graph=kg,
        base_url=args.base_url,
        api_key=args.api_key,
        llm_model=args.llm_model_name,
        temperature=args.temperature,
        seed=args.seed,
    )

    # Only the first `eval_size` test triples are evaluated.
    results: dict = evaluate_lp_k_vs_all(
        model=model,
        triple_idx=kg.test_set[:args.eval_size],
        er_vocab=kg.er_vocab,
        info='Eval KvsAll Starts',
        batch_size=args.batch_size,
    )
    print(results)
0 commit comments