Skip to content

Commit 6e27ba8

Browse files
committed
renaming and small tweaks
1 parent 216cddb commit 6e27ba8

3 files changed

Lines changed: 10 additions & 7 deletions

File tree

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,9 @@ def run(args):
2525
triples = f.readlines()
2626

2727
for prd in predictions:
28-
triple = prd[0] + "\t" + prd[1] + " \t" + prd[2][0][0]
29-
triples.append(triple + "\n")
28+
triple = prd[0] + "\t" + prd[1] + " \t" + prd[2][0][0]+ "\n"
29+
if triple not in triples:
30+
triples.append(triple)
3031

3132
with open(args.dataset_dir + "/" + args.kg_out, "w") as out:
3233
out.writelines(triples)
@@ -40,7 +41,7 @@ def run(args):
4041
parser.add_argument("--pred_out", type=str, default=None,
4142
help="Name of the output file where the predictions will be saved.")
4243

43-
parser.add_argument("--kg_out", type=str, default="extended_train.txt",
44+
parser.add_argument("--kg_out", type=str, default="enriched_train.txt",
4445
help="Name of the output file where the extended train set will be saved.")
4546

4647
run(parser.parse_args())
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
from retrieval_aug_predictors.models import KG, AbstractBaseLinkPredictorClass
55

66

7-
class RALP(AbstractBaseLinkPredictorClass):
7+
class CATC(AbstractBaseLinkPredictorClass):
8+
"""Context-Aware Triple Completion"""
89
def __init__(self, knowledge_graph: KG = None,
910
name="ralp-1.0",
1011
base_url="http://tentris-ml.cs.upb.de:8501/v1",

retrieval_aug_predictors/models/demir_ensemble.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,9 @@ def forward(self, subject, predicate, few_shot_examples) -> List[Tuple[str, floa
8888
return []
8989

9090

91-
class DemirEnsemble(AbstractBaseLinkPredictorClass):
92-
"""Ensemble approach combining multiple prediction strategies"""
91+
class RALP(AbstractBaseLinkPredictorClass):
92+
"""Retrieval-Augmented Link Prediction (ex DemirEnsemble).
93+
Ensemble approach combining multiple prediction strategies"""
9394

9495
def __init__(self, knowledge_graph, base_url, api_key, temperature, seed, llm_model, use_val: bool = False):
9596
super().__init__(knowledge_graph, name="DemirEnsemble")
@@ -288,7 +289,7 @@ def process(scores, storage, **kwargs):
288289
results:dict = evaluate_lp_k_vs_all(model=model, triple_idx=kg.test_set[:args.eval_size],
289290
er_vocab=kg.er_vocab, info='Eval KvsAll Starts', batch_size=args.batch_size)
290291
else:
291-
x = kg.test_set[:, [0, 1]]
292+
x = kg.train_set[:, [0, 1]]
292293
results = model.get_predicted_triples(x, args.k)
293294

294295
print("Results: {}".format(results))

0 commit comments

Comments
 (0)