1+ """
2+ python -m retrieval_aug_predictors.models.Demir --dataset_dir KGs/Countries-S1 --out "countries_s1_results.json" && cat countries_s1_results.json
3+ {
4+ "H@1": 1.0,
5+ "H@3": 1.0,
6+ "H@10": 1.0,
7+ "MRR": 1.0
8+ }
9+
10+ python -m retrieval_aug_predictors.models.Demir --dataset_dir KGs/Countries-S2 --out "countries_s2_results.json" && cat countries_s2_results.json
11+ {
12+ "H@1": 0.7083333333333334,
13+ "H@3": 1.0,
14+ "H@10": 1.0,
15+ "MRR": 0.8541666666666666
16+ }
17+
18+ python -m retrieval_aug_predictors.models.Demir --dataset_dir KGs/Countries-S3 --out "countries_s3_results.json" && cat countries_s3_results.json
19+ {
20+ "H@1": 0.3333333333333333,
21+ "H@3": 1.0,
22+ "H@10": 1.0,
23+ "MRR": 0.6666666666666666
24+ }
25+ """
26+
127import dspy
228import torch
329import json
1137from dotenv import load_dotenv
1238load_dotenv ()
1339
40+
1441class MultiLabelLinkPredictionWithScores (dspy .Signature ):
1542 examples = dspy .InputField (
1643 desc = "Few-shot examples of (subject, predicate) -> [{'entity': entity1, 'score': score1}, ...]." )
@@ -89,13 +116,14 @@ def forward_k_vs_all(self, x: torch.LongTensor) -> torch.FloatTensor:
89116 args = parser .parse_args ()
90117 # Important: add_reciprocal=False in KvsAll implies that inverse relation has been introduced.
91118 # Therefore, The link prediction results are based on the missing tail rankings only!
92- print (args )
93119 kg = KG (dataset_dir = args .dataset_dir , separator = "\s+" , eval_model = args .eval_model , add_reciprocal = False )
94-
95120 sanity_checking (args ,kg )
96-
97121 model = Demir (knowledge_graph = kg , base_url = args .base_url , api_key = args .api_key , llm_model = args .llm_model_name , temperature = args .temperature , seed = args .seed )
98-
99122 results :dict = evaluate_lp_k_vs_all (model = model , triple_idx = kg .test_set [:args .eval_size ],
100123 er_vocab = kg .er_vocab , info = 'Eval KvsAll Starts' , batch_size = args .batch_size )
101- print (results )
124+ if args .out and results :
125+ # Writing the dictionary to a JSON file
126+ print (results )
127+ with open (args .out , 'w' ) as json_file :
128+ json .dump (results , json_file , indent = 4 )
129+ print (f"Results has been saved to { args .out } " )
0 commit comments