Skip to content

Commit a82338e

Browse files
committed
Initial experiments on Countries done
1 parent 18f48ef commit a82338e

2 files changed

Lines changed: 36 additions & 6 deletions

File tree

retrieval_aug_predictors/arguments.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,6 @@
2424
parser.add_argument("--max_relation_examples", type=int, default=2000,
2525
help="Maximum number of relation examples to include in RCL context.")
2626
parser.add_argument("--exclude_source", action="store_true",
27-
help="Exclude triples with the same source entity in RCL context.")
27+
help="Exclude triples with the same source entity in RCL context.")
28+
parser.add_argument("--out", type=str, default=None,
29+
help="A path of a json file reporting the link prediction results.")

retrieval_aug_predictors/models/Demir.py

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,29 @@
1+
"""
2+
python -m retrieval_aug_predictors.models.Demir --dataset_dir KGs/Countries-S1 --out "countries_s1_results.json" && cat countries_s1_results.json
3+
{
4+
"H@1": 1.0,
5+
"H@3": 1.0,
6+
"H@10": 1.0,
7+
"MRR": 1.0
8+
}
9+
10+
python -m retrieval_aug_predictors.models.Demir --dataset_dir KGs/Countries-S2 --out "countries_s2_results.json" && cat countries_s2_results.json
11+
{
12+
"H@1": 0.7083333333333334,
13+
"H@3": 1.0,
14+
"H@10": 1.0,
15+
"MRR": 0.8541666666666666
16+
}
17+
18+
python -m retrieval_aug_predictors.models.Demir --dataset_dir KGs/Countries-S3 --out "countries_s3_results.json" && cat countries_s3_results.json
19+
{
20+
"H@1": 0.3333333333333333,
21+
"H@3": 1.0,
22+
"H@10": 1.0,
23+
"MRR": 0.6666666666666666
24+
}
25+
"""
26+
127
import dspy
228
import torch
329
import json
@@ -11,6 +37,7 @@
1137
from dotenv import load_dotenv
1238
load_dotenv()
1339

40+
1441
class MultiLabelLinkPredictionWithScores(dspy.Signature):
1542
examples = dspy.InputField(
1643
desc="Few-shot examples of (subject, predicate) -> [{'entity': entity1, 'score': score1}, ...].")
@@ -89,13 +116,14 @@ def forward_k_vs_all(self, x: torch.LongTensor) -> torch.FloatTensor:
89116
args=parser.parse_args()
90117
# Important: add_reciprocal=False in KvsAll implies that inverse relation has been introduced.
91118
# Therefore, The link prediction results are based on the missing tail rankings only!
92-
print(args)
93119
kg = KG(dataset_dir=args.dataset_dir, separator="\s+", eval_model=args.eval_model, add_reciprocal=False)
94-
95120
sanity_checking(args,kg)
96-
97121
model = Demir(knowledge_graph=kg, base_url=args.base_url, api_key=args.api_key, llm_model=args.llm_model_name, temperature=args.temperature, seed=args.seed)
98-
99122
results:dict = evaluate_lp_k_vs_all(model=model, triple_idx=kg.test_set[:args.eval_size],
100123
er_vocab=kg.er_vocab, info='Eval KvsAll Starts', batch_size=args.batch_size)
101-
print(results)
124+
if args.out and results:
125+
# Writing the dictionary to a JSON file
126+
print(results)
127+
with open(args.out, 'w') as json_file:
128+
json.dump(results, json_file, indent=4)
129+
print(f"Results has been saved to {args.out}")

0 commit comments

Comments
 (0)