Skip to content

Commit 149bae5

Browse files
committed
Ensemble updated
1 parent df69229 commit 149bae5

2 files changed

Lines changed: 8 additions & 5 deletions

File tree

retrieval_aug_predictors/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,4 @@
44
from .models.RALP import RALP
55
from .models.RCL import RCL
66
from .models.Demir import Demir
7-
7+
from .models.demir_ensemble import DemirEnsemble

retrieval_aug_predictors/models/demir_ensemble.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,20 @@
11
"""
2-
python -m retrieval_aug_predictors.models.Demir --dataset_dir KGs/Countries-S1 --out "countries_s1_results.json" && cat countries_s1_results.json
2+
python -m retrieval_aug_predictors.models.demir_ensemble --dataset_dir KGs/Countries-S1 --out "countries_s1_results.json" && cat countries_s1_results.json
33
{
44
"H@1": 1.0,
55
"H@3": 1.0,
66
"H@10": 1.0,
77
"MRR": 1.0
88
}
99
10-
python -m retrieval_aug_predictors.models.Demir --dataset_dir KGs/Countries-S2 --out "countries_s2_results.json" && cat countries_s2_results.json
10+
python -m retrieval_aug_predictors.models.demir_ensemble --dataset_dir KGs/Countries-S2 --out "countries_s2_results.json" && cat countries_s2_results.json
1111
{
1212
"H@1": 0.9583333333333334,
1313
"H@3": 0.9583333333333334,
1414
"H@10": 1.0,
1515
"MRR": 0.9666666666666667
1616
}
17-
python -m retrieval_aug_predictors.models.Demir --dataset_dir KGs/Countries-S3 --out "countries_s3_results.json" && cat countries_s3_results.json
17+
python -m retrieval_aug_predictors.models.demir_ensemble --dataset_dir KGs/Countries-S3 --out "countries_s3_results.json" && cat countries_s3_results.json
1818
{
1919
"H@1": 0.875,
2020
"H@3": 0.9583333333333334,
@@ -231,14 +231,17 @@ def forward_k_vs_all(self, x: torch.LongTensor) -> torch.FloatTensor:
231231
return torch.FloatTensor(batch_predictions)
232232
def forward_triples(self, x: torch.LongTensor) -> torch.FloatTensor:
233233
raise NotImplementedError("RCL needs to implement it")
234+
235+
234236
# test the dspy model -> remove later
235237
if __name__ == "__main__":
236238
args=parser.parse_args()
237239
# Important: add_reciprocal=False in KvsAll implies that inverse relation has been introduced.
238240
# Therefore, The link prediction results are based on the missing tail rankings only!
239241
kg = KG(dataset_dir=args.dataset_dir, separator="\s+", eval_model=args.eval_model, add_reciprocal=False)
240242
sanity_checking(args,kg)
241-
model = DemirEnsemble(knowledge_graph=kg, base_url=args.base_url, api_key=args.api_key, llm_model=args.llm_model_name, temperature=args.temperature, seed=args.seed)
243+
model = DemirEnsemble(knowledge_graph=kg, base_url=args.base_url, api_key=args.api_key,
244+
llm_model=args.llm_model_name, temperature=args.temperature, seed=args.seed)
242245
results:dict = evaluate_lp_k_vs_all(model=model, triple_idx=kg.test_set[:args.eval_size],
243246
er_vocab=kg.er_vocab, info='Eval KvsAll Starts', batch_size=args.batch_size)
244247
if args.out and results:

0 commit comments

Comments
 (0)