
Commit 18f48ef

argparse included
1 parent 5bd8920 commit 18f48ef

3 files changed

Lines changed: 83 additions & 50 deletions

retrieval_aug_predictors/arguments.py

Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
+import argparse
+parser = argparse.ArgumentParser()
+parser.add_argument("--dataset_dir", type=str, default="/home/cdemir/Desktop/Softwares/dice-embeddings/KGs/Countries-S1", help="Path to dataset.")
+parser.add_argument("--model", type=str, default="Demir", help="Model name to use for link prediction.",
+                    choices=["Demir", "GCL", "RCL", "RALP"])
+parser.add_argument("--base_url", type=str, default="http://harebell.cs.upb.de:8501/v1",
+                    choices=["http://harebell.cs.upb.de:8501/v1", "http://tentris-ml.cs.upb.de:8502/v1"],
+                    help="Base URL for the OpenAI client.")
+parser.add_argument("--llm_model_name", type=str, default="tentris", help="Model name of the LLM to use.")
+parser.add_argument("--temperature", type=float, default=0.0, help="Temperature hyperparameter for LLM calls.")
+parser.add_argument("--api_key", type=str, default=None, help="API key for the OpenAI client. If left as None, "
+                                                              "it is read from the environment variable named "
+                                                              "TENTRIS_TOKEN in a local .env file.")
+parser.add_argument("--eval_size", type=int, default=None,
+                    help="Number of triples from the test set to evaluate. "
+                         "Leave as None to include all triples in the test set.")
+parser.add_argument("--eval_model", type=str, default="train_value_test",
+                    help="Type of evaluation model.")
+parser.add_argument("--batch_size", type=int, default=1)
+parser.add_argument("--chunk_size", type=int, default=1)
+parser.add_argument("--seed", type=int, default=42)
+parser.add_argument("--num_of_hops", type=int, default=1,
+                    help="Number of hops to use to extract a subgraph around an entity.")
+parser.add_argument("--max_relation_examples", type=int, default=2000,
+                    help="Maximum number of relation examples to include in RCL context.")
+parser.add_argument("--exclude_source", action="store_true",
+                    help="Exclude triples with the same source entity in RCL context.")

retrieval_aug_predictors/models/Demir.py

Lines changed: 47 additions & 50 deletions
@@ -4,31 +4,18 @@
 from typing import List, Tuple
 from retrieval_aug_predictors.models import KG, AbstractBaseLinkPredictorClass
 from openai import OpenAI
-
-# 1. Define the Signature
-class KGLikelihood(dspy.Signature):
-    """Assess the likelihood that a triple (subject, predicate, candidate_object) is true,
-    given some context triples. Output a score between 0.0 and 1.0."""
-
-    context = dspy.InputField(desc="Known knowledge graph triples.")
-    subject = dspy.InputField(desc="The subject entity.")
-    predicate = dspy.InputField(desc="The relationship type.")
-    candidate_object = dspy.InputField(desc="The candidate object entity to score.")
-
-    score = dspy.OutputField(desc="A likelihood score between 0.0 and 1.0.")
-
+from collections import OrderedDict
+from retrieval_aug_predictors.arguments import parser
+from retrieval_aug_predictors.utils import sanity_checking
+from dicee.evaluator import evaluate_lp, evaluate_lp_k_vs_all
+from dotenv import load_dotenv
+load_dotenv()

 class MultiLabelLinkPredictionWithScores(dspy.Signature):
-    """Given a subject entity and a predicate (relation), predict a list of
-    object entities that satisfy the relation, along with a likelihood score for each.
-    Use the provided examples as a guide.
-    Output a JSON formatted list of objects, where each object has an 'entity' (string)
-    and a 'score' (float between 0.0 and 1.0) key."""
-
     examples = dspy.InputField(
         desc="Few-shot examples of (subject, predicate) -> [{'entity': entity1, 'score': score1}, ...].")
-    subject = dspy.InputField(desc="The subject entity.")
-    predicate = dspy.InputField(desc="The relationship type.")
+    subject: str = dspy.InputField(desc="The subject entity.")
+    predicate: str = dspy.InputField(desc="The relationship type.")

     # Updated OutputField requesting JSON
     objects_with_scores = dspy.OutputField(
@@ -48,6 +35,32 @@ def forward(self, subject, predicate, few_shot_examples)->List[Tuple[str, float]
         return [(i["entity"], i["score"]) for i in json.loads(dspy_pred.objects_with_scores)]

 class Demir(AbstractBaseLinkPredictorClass):
+    def __init__(self, knowledge_graph, base_url, api_key, temperature, seed, llm_model, use_val: bool = False):
+        super().__init__(knowledge_graph, name="Demir")
+        self.temperature = temperature
+        self.seed = seed
+        self.lm = dspy.LM(model=f"openai/{llm_model}", api_key=api_key,
+                          api_base=base_url,
+                          seed=seed,
+                          temperature=temperature,
+                          cache=True, cache_in_memory=True)
+        dspy.configure(lm=self.lm)
+        self.train_set: List[Tuple[str]] = [(self.idx_to_entity[idx_h],
+                                             self.idx_to_relation[idx_r],
+                                             self.idx_to_entity[idx_t]) for idx_h, idx_r, idx_t in
+                                            self.kg.train_set.tolist()]
+        # Validation dataset
+        self.val_set: List[Tuple[str]] = [(self.idx_to_entity[idx_h],
+                                           self.idx_to_relation[idx_r],
+                                           self.idx_to_entity[idx_t]) for idx_h, idx_r, idx_t in
+                                          self.kg.valid_set.tolist()]
+        self.triples = self.train_set + self.val_set if use_val else self.train_set
+
+        self.entity_relation_to_entities = dict()
+        for s, p, o in self.triples:
+            self.entity_relation_to_entities.setdefault((s, p), []).append(o)
+        self.scoring_func = MultiLabelLinkPredictor()
+
     def forward_triples(self, x: torch.LongTensor) -> torch.FloatTensor:
         raise NotImplementedError("RCL needs to implement it")
     def forward_k_vs_all(self, x: torch.LongTensor) -> torch.FloatTensor:
@@ -70,35 +83,19 @@ def forward_k_vs_all(self, x: torch.LongTensor) -> torch.FloatTensor:
             batch_predictions.append(scores)
         return torch.FloatTensor(batch_predictions)

-    def __init__(self, knowledge_graph, base_url, api_key, temperature, seed, llm_model, use_val: bool = False):
-        super().__init__(knowledge_graph, name="Demir")
-        self.client = OpenAI(base_url=base_url, api_key=api_key)
-        self.temperature = temperature
-        self.seed = seed

-        self.lm = dspy.LM(model=f"openai/{llm_model}", api_key=api_key,
-                          api_base=base_url,
-                          seed=seed,
-                          temperature=temperature,
-                          cache=True, cache_in_memory=True,
-                          kwargs={"extra_body": {"truncate_prompt_tokens": 32_000}})
-        dspy.configure(lm=self.lm)
-        self.train_set: List[Tuple[str]] = [(self.idx_to_entity[idx_h],
-                                             self.idx_to_relation[idx_r],
-                                             self.idx_to_entity[idx_t]) for idx_h, idx_r, idx_t in
-                                            self.kg.train_set.tolist()]
-        # Validation dataset
-        self.val_set: List[Tuple[str]] = [(self.idx_to_entity[idx_h],
-                                           self.idx_to_relation[idx_r],
-                                           self.idx_to_entity[idx_t]) for idx_h, idx_r, idx_t in
-                                          self.kg.valid_set.tolist()]
-        self.triples = self.train_set + self.val_set if use_val else self.train_set
+# test the dspy model -> remove later
+if __name__ == "__main__":
+    args = parser.parse_args()
+    # Important: add_reciprocal=False in KvsAll implies that no inverse relations have been introduced.
+    # Therefore, the link prediction results are based on the missing-tail rankings only!
+    print(args)
+    kg = KG(dataset_dir=args.dataset_dir, separator="\s+", eval_model=args.eval_model, add_reciprocal=False)

-        self.entity_relation_to_entities = dict()
-        from collections import OrderedDict
-        for s, p, o in self.triples:
-            self.entity_relation_to_entities.setdefault((s, p), []).append(o)
+    sanity_checking(args, kg)

-        # 4. Instantiate your predictor
-        self.scoring_func = MultiLabelLinkPredictor()
-        self.entities: List[str] = list(sorted(self.entity_to_idx.keys()))
+    model = Demir(knowledge_graph=kg, base_url=args.base_url, api_key=args.api_key, llm_model=args.llm_model_name, temperature=args.temperature, seed=args.seed)
+
+    results: dict = evaluate_lp_k_vs_all(model=model, triple_idx=kg.test_set[:args.eval_size],
+                                         er_vocab=kg.er_vocab, info='Eval KvsAll Starts', batch_size=args.batch_size)
+    print(results)
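
The rewritten __init__ drops the raw OpenAI client and the truncate_prompt_tokens extra body, and builds a (subject, predicate) -> [objects] index over the training (plus, optionally, validation) triples. A self-contained sketch of that index, using hypothetical triples in place of kg.train_set / kg.valid_set:

# Hypothetical triples; in Demir.__init__ they come from kg.train_set / kg.valid_set.
triples = [("germany", "neighbor", "france"),
           ("germany", "neighbor", "poland"),
           ("france", "neighbor", "spain")]
entity_relation_to_entities = dict()
for s, p, o in triples:
    entity_relation_to_entities.setdefault((s, p), []).append(o)
assert entity_relation_to_entities[("germany", "neighbor")] == ["france", "poland"]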

retrieval_aug_predictors/utils.py

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
+import os
+def sanity_checking(args, kg):
+    if args.eval_size is not None:
+        assert len(kg.test_set) >= args.eval_size, (f"Evaluation size can't be greater than the "
+                                                    f"total number of triples in the test set: {len(kg.test_set)}")
+    else:
+        args.eval_size = len(kg.test_set)
+    if args.api_key is None:
+        args.api_key = os.environ.get("TENTRIS_TOKEN")
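
sanity_checking mutates args in place: it validates eval_size against the test set (filling it in when left as None) and falls back to the TENTRIS_TOKEN environment variable for the API key. A hypothetical smoke test with stand-in objects:

import os
from types import SimpleNamespace
from retrieval_aug_predictors.utils import sanity_checking

os.environ.setdefault("TENTRIS_TOKEN", "dummy-token")  # stand-in secret for this sketch
args = SimpleNamespace(eval_size=None, api_key=None)
kg = SimpleNamespace(test_set=[("s", "p", "o")] * 10)  # stand-in KG with 10 test triples
sanity_checking(args, kg)
assert args.eval_size == 10      # None was replaced by the full test-set size
assert args.api_key is not None  # api_key was read from the environment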
