Skip to content

Commit a1e3187

Browse files
committed
Ensemble of Chain of Thought Models leading to the best link prediction result on Countries
1 parent 149bae5 commit a1e3187

4 files changed

Lines changed: 39 additions & 28 deletions

File tree

dicee/knowledge_graph.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import sys
44
import pandas as pd
55
import polars as pl
6+
import numpy as np
7+
68
class KG:
79
""" Knowledge Graph """
810

@@ -137,8 +139,22 @@ def exists(self,h:str,r:str,t:str):
137139
return ((self.raw_train_set == pd.Series(row_to_check)).all(axis=1)).any()
138140

139141
def __iter__(self):
140-
for h, r, t in self.raw_train_set.to_numpy().tolist():
142+
if self.raw_train_set is not None:
143+
graph=self.raw_train_set.to_numpy()
144+
elif self.train_set is not None:
145+
assert isinstance(self.train_set,np.ndarray)
146+
graph=self.train_set
147+
else:
148+
raise RuntimeError(f"Dataset {self.dataset_dir} and {self.raw_train_set} & {self.train_set} are None")
149+
assert graph.shape[0]>=0 and graph.shape[1]==3, "Invalid graph shape!"
150+
151+
if hasattr(self,"idx_to_entity") is False:
152+
self.idx_to_entity = self.entity_to_idx.set_index(self.entity_to_idx.index)['entity'].to_dict()
153+
self.idx_to_relations = self.relation_to_idx.set_index(self.relation_to_idx.index)['relation'].to_dict()
154+
155+
for h, r, t in graph.tolist():
141156
yield self.idx_to_entity[h], self.idx_to_relations[r], self.idx_to_entity[t]
157+
142158
def __len__(self):
143159
return len(self.raw_train_set)
144160

dicee/read_preprocess_save_load_kg/preprocess.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,6 @@ def preprocess_with_pandas(self) -> None:
311311
self.kg.raw_test_set = apply_reciprical_or_noise(add_reciprical=self.kg.add_reciprocal,
312312
eval_model=self.kg.eval_model,
313313
df=self.kg.raw_test_set, info="Test")
314-
315314
# (2) Construct integer indexing for entities and relations.
316315
self.sequential_vocabulary_construction()
317316
self.kg.num_entities, self.kg.num_relations = len(self.kg.entity_to_idx), len(self.kg.relation_to_idx)

retrieval_aug_predictors/models/Demir.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
"H@10": 0.9583333333333334,
1515
"MRR": 0.7921296296296297
1616
}
17-
python -m retrieval_aug_predictors.models.Demir --dataset_dir KGs/Countries-S3 --out "countries_s3_results.json" && cat countries_s3_results.json
17+
python -m retrieval_aug_predictors.models.demir_ensemble --dataset_dir KGs/Countries-S3 --out "countries_s3_results.json" && cat countries_s3_results.json
1818
{
1919
"H@1": 0.7083333333333334,
2020
"H@3": 0.9583333333333334,

retrieval_aug_predictors/models/demir_ensemble.py

Lines changed: 21 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -6,21 +6,20 @@
66
"H@10": 1.0,
77
"MRR": 1.0
88
}
9-
109
python -m retrieval_aug_predictors.models.demir_ensemble --dataset_dir KGs/Countries-S2 --out "countries_s2_results.json" && cat countries_s2_results.json
1110
{
12-
"H@1": 0.9583333333333334,
13-
"H@3": 0.9583333333333334,
11+
"H@1": 1.0,
12+
"H@3": 1.0,
1413
"H@10": 1.0,
15-
"MRR": 0.9666666666666667
14+
"MRR": 1.0
1615
}
1716
python -m retrieval_aug_predictors.models.demir_ensemble --dataset_dir KGs/Countries-S3 --out "countries_s3_results.json" && cat countries_s3_results.json
1817
{
19-
"H@1": 0.875,
20-
"H@3": 0.9583333333333334,
18+
"H@1": 0.9166666666666666,
19+
"H@3": 1.0,
2120
"H@10": 1.0,
22-
"MRR": 0.9249999999999999
23-
}
21+
"MRR": 0.951388888888889
22+
}(
2423
"""
2524

2625
import dspy
@@ -49,7 +48,7 @@ class MultiLabelLinkPredictionWithScores(dspy.Signature):
4948
class MultiLabelLinkPredictor(dspy.Module):
5049
def __init__(self):
5150
super().__init__()
52-
self.predictor = dspy.Predict(MultiLabelLinkPredictionWithScores)
51+
self.predictor = dspy.ChainOfThought(MultiLabelLinkPredictionWithScores)
5352

5453
def forward(self, subject, predicate, few_shot_examples) -> List[Tuple[str, float]]:
5554
# Format examples more structured with clearer JSON expectations
@@ -96,16 +95,13 @@ def __init__(self, knowledge_graph, base_url, api_key, temperature, seed, llm_mo
9695
super().__init__(knowledge_graph, name="DemirEnsemble")
9796
self.temperature = temperature
9897
self.seed = seed
99-
100-
# Create multiple LLM models with different parameters
101-
self.lm_high_temp = dspy.LM(model=f"openai/{llm_model}", api_key=api_key,
102-
api_base=base_url, seed=seed, temperature=0.7,
103-
cache=True, cache_in_memory=True)
104-
105-
self.lm_low_temp = dspy.LM(model=f"openai/{llm_model}", api_key=api_key,
106-
api_base=base_url, seed=seed, temperature=0.1,
107-
cache=True, cache_in_memory=True)
108-
98+
# () Initialize ensemble.
99+
self.ensemble=[]
100+
for i in range(0, 9):
101+
temperature_coefficient=i*0.1
102+
self.ensemble.append(dspy.LM(model=f"openai/{llm_model}", api_key=api_key,
103+
api_base=base_url, seed=seed, temperature=temperature_coefficient,
104+
cache=True, cache_in_memory=True))
109105
# Initialize data same as original
110106
self.train_set = [(self.idx_to_entity[idx_h],
111107
self.idx_to_relation[idx_r],
@@ -136,14 +132,14 @@ def __init__(self, knowledge_graph, base_url, api_key, temperature, seed, llm_mo
136132
def _create_ensemble_predictors(self):
137133
"""Create multiple predictors with different configurations"""
138134
predictors = []
139-
140-
# Standard predictor
141-
dspy.configure(lm=self.lm_low_temp)
142-
predictors.append(MultiLabelLinkPredictor())
135+
for i in self.ensemble:
136+
# Standard predictor
137+
dspy.configure(lm=i)
138+
predictors.append(MultiLabelLinkPredictor())
143139

144140
# Diverse predictor (high temp)
145-
dspy.configure(lm=self.lm_high_temp)
146-
predictors.append(MultiLabelLinkPredictor())
141+
#dspy.configure(lm=self.lm_high_temp)
142+
#predictors.append(MultiLabelLinkPredictor())
147143

148144
return predictors
149145

0 commit comments

Comments
 (0)