|
1 | 1 | """ |
2 | | -python -m retrieval_aug_predictors.models.Demir --dataset_dir KGs/Countries-S1 --out "countries_s1_results.json" && cat countries_s1_results.json |
| 2 | +python -m retrieval_aug_predictors.models.demir_ensemble --dataset_dir KGs/Countries-S1 --out "countries_s1_results.json" && cat countries_s1_results.json |
3 | 3 | { |
4 | 4 | "H@1": 1.0, |
5 | 5 | "H@3": 1.0, |
6 | 6 | "H@10": 1.0, |
7 | 7 | "MRR": 1.0 |
8 | 8 | } |
9 | 9 |
|
10 | | -python -m retrieval_aug_predictors.models.Demir --dataset_dir KGs/Countries-S2 --out "countries_s2_results.json" && cat countries_s2_results.json |
| 10 | +python -m retrieval_aug_predictors.models.demir_ensemble --dataset_dir KGs/Countries-S2 --out "countries_s2_results.json" && cat countries_s2_results.json |
11 | 11 | { |
12 | 12 | "H@1": 0.9583333333333334, |
13 | 13 | "H@3": 0.9583333333333334, |
14 | 14 | "H@10": 1.0, |
15 | 15 | "MRR": 0.9666666666666667 |
16 | 16 | } |
17 | | -python -m retrieval_aug_predictors.models.Demir --dataset_dir KGs/Countries-S3 --out "countries_s3_results.json" && cat countries_s3_results.json |
| 17 | +python -m retrieval_aug_predictors.models.demir_ensemble --dataset_dir KGs/Countries-S3 --out "countries_s3_results.json" && cat countries_s3_results.json |
18 | 18 | { |
19 | 19 | "H@1": 0.875, |
20 | 20 | "H@3": 0.9583333333333334, |
@@ -231,14 +231,17 @@ def forward_k_vs_all(self, x: torch.LongTensor) -> torch.FloatTensor: |
231 | 231 | return torch.FloatTensor(batch_predictions) |
232 | 232 | def forward_triples(self, x: torch.LongTensor) -> torch.FloatTensor: |
233 | 233 | raise NotImplementedError("RCL needs to implement it") |
| 234 | + |
| 235 | + |
234 | 236 | # test the dspy model -> remove later |
235 | 237 | if __name__ == "__main__": |
236 | 238 | args=parser.parse_args() |
237 | 239 | # Important: add_reciprocal=False in KvsAll implies that inverse relation has been introduced. |
238 | 240 | # Therefore, The link prediction results are based on the missing tail rankings only! |
239 | 241 | kg = KG(dataset_dir=args.dataset_dir, separator="\s+", eval_model=args.eval_model, add_reciprocal=False) |
240 | 242 | sanity_checking(args,kg) |
241 | | - model = DemirEnsemble(knowledge_graph=kg, base_url=args.base_url, api_key=args.api_key, llm_model=args.llm_model_name, temperature=args.temperature, seed=args.seed) |
| 243 | + model = DemirEnsemble(knowledge_graph=kg, base_url=args.base_url, api_key=args.api_key, |
| 244 | + llm_model=args.llm_model_name, temperature=args.temperature, seed=args.seed) |
242 | 245 | results:dict = evaluate_lp_k_vs_all(model=model, triple_idx=kg.test_set[:args.eval_size], |
243 | 246 | er_vocab=kg.er_vocab, info='Eval KvsAll Starts', batch_size=args.batch_size) |
244 | 247 | if args.out and results: |
|
0 commit comments