Skip to content

Commit ba8dc4e

Browse files
committed
Allow the MIPROv2 optimizer's `auto` parameter to be passed to the model
1 parent bc33305 commit ba8dc4e

1 file changed

Lines changed: 15 additions & 35 deletions

File tree

retrieval_aug_predictors/models/demir_ensemble_mipro.py

Lines changed: 15 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,3 @@
1-
"""
2-
python -m retrieval_aug_predictors.models.demir_ensemble_mipro --dataset_dir KGs/Countries-S1 --out "countries_s1_results.json" && cat countries_s1_results.json
3-
{
4-
"H@1": 0.75,
5-
"H@3": 0.875,
6-
"H@10": 1.0,
7-
"MRR": 0.8416666666666667
8-
}
9-
python -m retrieval_aug_predictors.models.demir_ensemble_mipro --dataset_dir KGs/Countries-S2 --out "countries_s2_results.json" && cat countries_s2_results.json
10-
{
11-
"H@1": 0.75,
12-
"H@3": 1.0,
13-
"H@10": 1.0,
14-
"MRR": 0.8680555555555555
15-
}(
16-
17-
18-
python -m retrieval_aug_predictors.models.demir_ensemble_mipro --dataset_dir KGs/Countries-S3 --out "countries_s3_results.json" && cat countries_s3_results.json
19-
{
20-
"H@1": 0.041666666666666664,
21-
"H@3": 0.4583333333333333,
22-
"H@10": 0.625,
23-
"MRR": 0.2626660300405415
24-
}
25-
26-
"""
271
import dspy
282
import torch
293
import json
@@ -39,7 +13,7 @@
3913
import math
4014
import pandas as pd
4115
import random
42-
from typing import TypeAlias, Union
16+
from typing import TypeAlias, Union, Literal, Optional
4317
# --- Constants ---
4418
load_dotenv()
4519
pd.set_option('display.max_columns', None)
@@ -267,10 +241,10 @@ def __init__(self, entities: List[str],g:GraphType):
267241
self.g = g
268242
self.finder = dspy.ChainOfThought(EntityFinder)
269243
self.scorer = dspy.ChainOfThought(Scorer)
270-
271-
def _graph_based_content_builder(self,subject:str):
244+
def _graph_based_content_builder(self,subject:str,hops:int=5):
245+
assert hops>=0
272246
hop_to_triples = dict()
273-
graph_report = traverse_beam_by_hop(graph=self.g, start_entities=subject, hops=5,
247+
graph_report = traverse_beam_by_hop(graph=self.g, start_entities=subject, hops=hops,
274248
beam_width=len(self.entities), return_triples_only=True)
275249
# Accumulate triples over hops: assert hop_to_triples[i].issubset(hop_to_triples[i+1])
276250
for k,v in graph_report.items():
@@ -291,7 +265,7 @@ def forward(self, subject: str, predicate: str) -> dspy.Prediction:
291265
for idx, score in enumerate(scores.target):
292266
entity = retrieved_entities.target[idx]
293267
intermediate_predictions.setdefault(entity, []).append(score)
294-
# Avg scores of intermediate predictions. @TODO: CD: Sum of average ?!
268+
# Average the scores of intermediate predictions. TODO(CD): Is there a better way to aggregate them?
295269
predictions=dict()
296270
for k,v in intermediate_predictions.items():
297271
predictions[k]=sum(v)/len(v)
@@ -306,18 +280,20 @@ class DemirEnsembleMPRO(AbstractBaseLinkPredictorClass):
306280
"""
307281
def __init__(self, knowledge_graph: KG, base_url: str, api_key: str, llm_model: str,
308282
temperature: float, seed: int, use_val: bool = True, ensemble_temperatures=None,
309-
save_dir: str = SAVE_DIR_BASE):
283+
save_dir: str = SAVE_DIR_BASE,auto: Optional[Literal["light", "medium", "heavy"]] = "light"):
310284
super().__init__(knowledge_graph, name="DemirEnsembleMIPRO")
311285

312286
# Configuration
313287
self.base_url = base_url
314288
self.api_key = api_key
315289
self.llm_model = llm_model
290+
self.auto=auto
316291
self.seed = seed
317292
self.use_val = use_val
318293
self.save_dir = save_dir
319294
self.ensemble_temperatures = [i * 0.1 for i in range(1)] if (ensemble_temperatures
320295
is None) else ensemble_temperatures
296+
assert isinstance(self.ensemble_temperatures, list) and isinstance(self.ensemble_temperatures[0], float) and 1.0 > self.ensemble_temperatures[0] >= 0.0
321297
self.mipro_optimizer_temperature = temperature
322298
# Seed random for reproducibility.
323299
random.seed(self.seed)
@@ -366,7 +342,7 @@ def _create_and_optimize_predictors(self) -> List[MultiLabelLinkPredictor]:
366342
predictor_to_optimize=base_predictor,
367343
train_examples=train_examples,
368344
test_examples=test_examples,
369-
temperature=temp, save_filename=save_filename)
345+
temperature=temp, save_filename=save_filename,auto=self.auto)
370346
# --- Optional: Evaluate after optimization ---
371347
# print(f"Evaluating optimized predictor for temp {temp:.1f}...")
372348
# evaluator = Evaluate(devset=self.test_examples[:50], # Limit evaluation size
@@ -397,11 +373,15 @@ def _create_and_optimize_predictors(self) -> List[MultiLabelLinkPredictor]:
397373
dspy.configure(lm=None) # Reset global LM config after use
398374
return predictors
399375
def _compile_predictor_for_temperature(self, predictor_to_optimize: MultiLabelLinkPredictor,
400-
train_examples:List[dspy.Example],test_examples,temperature: float,save_filename:str) -> MultiLabelLinkPredictor:
376+
train_examples:List[dspy.Example],test_examples,temperature: float,
377+
save_filename:str,
378+
auto: Optional[Literal["light", "medium", "heavy"]] = "light") -> MultiLabelLinkPredictor:
401379
"""Configures LM and runs MIPROv2 compilation."""
380+
402381
assert isinstance(predictor_to_optimize, MultiLabelLinkPredictor)
403382
assert isinstance(train_examples, list) and isinstance(train_examples[0],dspy.Example)
404383
assert isinstance(test_examples, list) and isinstance(test_examples[0],dspy.Example)
384+
assert auto in ("light", "medium", "heavy")
405385
# Configure DSPy LM specifically for this optimization run
406386
lm = dspy.LM(model=f"openai/{self.llm_model}", api_key=self.api_key,api_base=self.base_url,
407387
seed=self.seed, temperature=temperature,
@@ -418,7 +398,7 @@ def _compile_predictor_for_temperature(self, predictor_to_optimize: MultiLabelLi
418398
num_threads=6, display_table=False,
419399
display_progress=True, provide_traceback=True)(predictor_to_optimize)
420400
# Generate examples needed for optimization outside the loop
421-
optim = dspy.MIPROv2(metric=dspy_quality_score_closeness, auto="light")
401+
optim = dspy.MIPROv2(metric=dspy_quality_score_closeness, auto=auto)
422402
optimized_predictor = optim.compile(predictor_to_optimize.deepcopy(), trainset=train_examples[:],
423403
valset=test_examples[:],
424404
requires_permission_to_run=False)

0 commit comments

Comments
 (0)