1- """
2- python -m retrieval_aug_predictors.models.demir_ensemble_mipro --dataset_dir KGs/Countries-S1 --out "countries_s1_results.json" && cat countries_s1_results.json
3- {
4- "H@1": 0.75,
5- "H@3": 0.875,
6- "H@10": 1.0,
7- "MRR": 0.8416666666666667
8- }
9- python -m retrieval_aug_predictors.models.demir_ensemble_mipro --dataset_dir KGs/Countries-S2 --out "countries_s2_results.json" && cat countries_s2_results.json
10- {
11- "H@1": 0.75,
12- "H@3": 1.0,
13- "H@10": 1.0,
14- "MRR": 0.8680555555555555
15- }
16-
17-
18- python -m retrieval_aug_predictors.models.demir_ensemble_mipro --dataset_dir KGs/Countries-S3 --out "countries_s3_results.json" && cat countries_s3_results.json
19- {
20- "H@1": 0.041666666666666664,
21- "H@3": 0.4583333333333333,
22- "H@10": 0.625,
23- "MRR": 0.2626660300405415
24- }
25-
26- """
271import dspy
282import torch
293import json
3913import math
4014import pandas as pd
4115import random
42- from typing import TypeAlias , Union
16+ from typing import TypeAlias , Union , Literal , Optional
4317# --- Constants ---
4418load_dotenv ()
4519pd .set_option ('display.max_columns' , None )
@@ -267,10 +241,10 @@ def __init__(self, entities: List[str],g:GraphType):
267241 self .g = g
268242 self .finder = dspy .ChainOfThought (EntityFinder )
269243 self .scorer = dspy .ChainOfThought (Scorer )
270-
271- def _graph_based_content_builder ( self , subject : str ):
244+ def _graph_based_content_builder ( self , subject : str , hops : int = 5 ):
245+ assert hops >= 0
272246 hop_to_triples = dict ()
273- graph_report = traverse_beam_by_hop (graph = self .g , start_entities = subject , hops = 5 ,
247+ graph_report = traverse_beam_by_hop (graph = self .g , start_entities = subject , hops = hops ,
274248 beam_width = len (self .entities ), return_triples_only = True )
275249 # Accumulate triples over hops: assert hop_to_triples[i].issubset(hop_to_triples[i+1])
276250 for k ,v in graph_report .items ():
@@ -291,7 +265,7 @@ def forward(self, subject: str, predicate: str) -> dspy.Prediction:
291265 for idx , score in enumerate (scores .target ):
292266 entity = retrieved_entities .target [idx ]
293267 intermediate_predictions .setdefault (entity , []).append (score )
294- # Avg scores of intermediate predictions. @TODO: CD: Sum of average ?!
268+ # Average the scores collected for each entity across intermediate predictions. TODO(CD): is there a better aggregation than the plain mean?
295269 predictions = dict ()
296270 for k ,v in intermediate_predictions .items ():
297271 predictions [k ]= sum (v )/ len (v )
@@ -306,18 +280,20 @@ class DemirEnsembleMPRO(AbstractBaseLinkPredictorClass):
306280 """
307281 def __init__ (self , knowledge_graph : KG , base_url : str , api_key : str , llm_model : str ,
308282 temperature : float , seed : int , use_val : bool = True , ensemble_temperatures = None ,
309- save_dir : str = SAVE_DIR_BASE ):
283+ save_dir : str = SAVE_DIR_BASE , auto : Optional [ Literal [ "light" , "medium" , "heavy" ]] = "light" ):
310284 super ().__init__ (knowledge_graph , name = "DemirEnsembleMIPRO" )
311285
312286 # Configuration
313287 self .base_url = base_url
314288 self .api_key = api_key
315289 self .llm_model = llm_model
290+ self .auto = auto
316291 self .seed = seed
317292 self .use_val = use_val
318293 self .save_dir = save_dir
319294 self .ensemble_temperatures = [i * 0.1 for i in range (1 )] if (ensemble_temperatures
320295 is None ) else ensemble_temperatures
296+ assert isinstance (self .ensemble_temperatures , list ) and isinstance (self .ensemble_temperatures [0 ], float ) and 1.0 > self .ensemble_temperatures [0 ] >= 0.0
321297 self .mipro_optimizer_temperature = temperature
322298 # Seed random for reproducibility.
323299 random .seed (self .seed )
@@ -366,7 +342,7 @@ def _create_and_optimize_predictors(self) -> List[MultiLabelLinkPredictor]:
366342 predictor_to_optimize = base_predictor ,
367343 train_examples = train_examples ,
368344 test_examples = test_examples ,
369- temperature = temp , save_filename = save_filename )
345+ temperature = temp , save_filename = save_filename , auto = self . auto )
370346 # --- Optional: Evaluate after optimization ---
371347 # print(f"Evaluating optimized predictor for temp {temp:.1f}...")
372348 # evaluator = Evaluate(devset=self.test_examples[:50], # Limit evaluation size
@@ -397,11 +373,15 @@ def _create_and_optimize_predictors(self) -> List[MultiLabelLinkPredictor]:
397373 dspy .configure (lm = None ) # Reset global LM config after use
398374 return predictors
399375 def _compile_predictor_for_temperature (self , predictor_to_optimize : MultiLabelLinkPredictor ,
400- train_examples :List [dspy .Example ],test_examples ,temperature : float ,save_filename :str ) -> MultiLabelLinkPredictor :
376+ train_examples :List [dspy .Example ],test_examples ,temperature : float ,
377+ save_filename :str ,
378+ auto : Optional [Literal ["light" , "medium" , "heavy" ]] = "light" ) -> MultiLabelLinkPredictor :
401379 """Configures LM and runs MIPROv2 compilation."""
380+
402381 assert isinstance (predictor_to_optimize , MultiLabelLinkPredictor )
403382 assert isinstance (train_examples , list ) and isinstance (train_examples [0 ],dspy .Example )
404383 assert isinstance (test_examples , list ) and isinstance (test_examples [0 ],dspy .Example )
384+ assert auto in ("light" , "medium" , "heavy" )
405385 # Configure DSPy LM specifically for this optimization run
406386 lm = dspy .LM (model = f"openai/{ self .llm_model } " , api_key = self .api_key ,api_base = self .base_url ,
407387 seed = self .seed , temperature = temperature ,
@@ -418,7 +398,7 @@ def _compile_predictor_for_temperature(self, predictor_to_optimize: MultiLabelLi
418398 num_threads = 6 , display_table = False ,
419399 display_progress = True , provide_traceback = True )(predictor_to_optimize )
420400 # Generate examples needed for optimization outside the loop
421- optim = dspy .MIPROv2 (metric = dspy_quality_score_closeness , auto = "light" )
401+ optim = dspy .MIPROv2 (metric = dspy_quality_score_closeness , auto = auto )
422402 optimized_predictor = optim .compile (predictor_to_optimize .deepcopy (), trainset = train_examples [:],
423403 valset = test_examples [:],
424404 requires_permission_to_run = False )
0 commit comments