5555 RayleighMeasurement , SpectralEquivalenceResult , PartitionData ,
5656 measure_egi_fixed_point , print_egi_report
5757)
58+ from collections import deque
59+
60+
61+ # ============================================================================
62+ # LOSS-SGC MONITOR: Real-time grokking detector (Sprint C+1 critical fix)
63+ # ============================================================================
64+
65+ class LossSGCMonitor :
66+ """
67+ Measures trans_rate of APPROX_EQUAL on the running loss sequence.
68+
69+ This is the DIRECT grokking detector from Sprint 2 — when loss values
70+ cluster near zero, APPROX_EQUAL transitivity spikes and trans_rate → 1.
71+
72+ Key insight: Prior successful experiments measured trans_rate on the
73+ system's own dynamics (loss), not on the weight matrix. The ε measurement
74+ on weights is a POST-HOC verifier, not a real-time detector.
75+
76+ ZERO-PARAMETER: tolerance derived from std of loss window.
77+ """
78+
79+ def __init__ (self , window : int = 20 , k_sustained : int = 2 ,
80+ trans_threshold : float = 0.83 ):
81+ """
82+ Args:
83+ window: Size of loss history window
84+ k_sustained: Consecutive windows above threshold needed for detection
85+ trans_threshold: trans_rate threshold for grokking (default 0.83)
86+ """
87+ self .window = window
88+ self .k_sustained = k_sustained
89+ self .trans_threshold = trans_threshold
90+
91+ self .loss_history = deque (maxlen = window )
92+ self .trans_rate_history = []
93+ self .consecutive_above = 0
94+ self .grokking_detected = False
95+ self .grokking_step = - 1
96+
97+ def _measure_approx_equal_transitivity (self , values : list , tol : float ) -> float :
98+ """
99+ Measure transitivity of APPROX_EQUAL relation on values.
100+
101+ trans_rate = P(|a-c| < tol | |a-b| < tol AND |b-c| < tol)
102+
103+ When values cluster (all near zero), transitivity → 1.
104+ When values spread, transitivity drops.
105+ """
106+ n = len (values )
107+ if n < 3 :
108+ return 0.0
109+
110+ # Sample triplets
111+ n_chains = 0
112+ n_transitive = 0
113+
114+ # For efficiency, sample up to 100 triplets
115+ max_triplets = min (100 , n * (n - 1 ) * (n - 2 ) // 6 )
116+
117+ for i in range (n ):
118+ for j in range (i + 1 , n ):
119+ if abs (values [i ] - values [j ]) < tol : # a ~ b
120+ for k in range (j + 1 , n ):
121+ if abs (values [j ] - values [k ]) < tol : # b ~ c
122+ n_chains += 1
123+ if abs (values [i ] - values [k ]) < tol : # a ~ c?
124+ n_transitive += 1
125+ if n_chains >= max_triplets :
126+ break
127+ if n_chains >= max_triplets :
128+ break
129+ if n_chains >= max_triplets :
130+ break
131+
132+ if n_chains == 0 :
133+ return 0.0
134+
135+ return n_transitive / n_chains
136+
137+ def update (self , loss : float , step : int ) -> float :
138+ """
139+ Update monitor with new loss value.
140+
141+ Returns:
142+ Current trans_rate
143+ """
144+ self .loss_history .append (loss )
145+
146+ if len (self .loss_history ) < 3 :
147+ return 0.0
148+
149+ # ZERO-PARAMETER: derive tolerance from std of window
150+ # In pre-grokking phase, losses have high variance
151+ # Post-grokking, losses cluster near zero, std drops
152+ values = list (self .loss_history )
153+ std = np .std (values )
154+ mean = np .mean (values )
155+
156+ # Tolerance = 0.5 * std (adaptive to current dynamics)
157+ # Also cap at mean to handle near-zero case
158+ tol = max (0.5 * std , 0.01 * mean + 1e-6 )
159+
160+ trans_rate = self ._measure_approx_equal_transitivity (values , tol )
161+ self .trans_rate_history .append ((step , trans_rate ))
162+
163+ # Detection logic
164+ if trans_rate >= self .trans_threshold :
165+ self .consecutive_above += 1
166+ if self .consecutive_above >= self .k_sustained and not self .grokking_detected :
167+ self .grokking_detected = True
168+ self .grokking_step = step
169+ print (f"\n *** LOSS-SGC GROKKING DETECTED at step { step } ***" )
170+ print (f" trans_rate = { trans_rate :.4f} (threshold { self .trans_threshold } )" )
171+ print (f" loss_mean = { mean :.6f} , loss_std = { std :.6f} " )
172+ else :
173+ self .consecutive_above = 0
174+
175+ return trans_rate
176+
177+ def get_summary (self ) -> dict :
178+ return {
179+ 'grokking_detected' : self .grokking_detected ,
180+ 'grokking_step' : self .grokking_step ,
181+ 'final_trans_rate' : self .trans_rate_history [- 1 ][1 ] if self .trans_rate_history else 0.0 ,
182+ 'n_measurements' : len (self .trans_rate_history ),
183+ }
58184
59185
60186# ============================================================================
@@ -336,7 +462,9 @@ def __init__(self,
336462 device : str = 'cuda' ,
337463 max_steps_per_task : int = 50000 ,
338464 log_interval : int = 100 ,
339- preservation_check_interval : int = 500 ):
465+ preservation_check_interval : int = 500 ,
466+ weight_decay : float = 1.0 ,
467+ learning_rate : float = 1e-3 ):
340468 """
341469 Initialize Sprint C conductor.
342470
@@ -347,13 +475,21 @@ def __init__(self,
347475 max_steps_per_task: Maximum training steps per task
348476 log_interval: Steps between logging
349477 preservation_check_interval: Steps between preservation checks
478+ weight_decay: Base weight decay (Sprint C+1: use 1.0 per Nanda et al.)
479+ learning_rate: Base learning rate
350480 """
351481 self .tasks = tasks
352482 self .hidden_dim = hidden_dim
353483 self .device = device
354484 self .max_steps_per_task = max_steps_per_task
355485 self .log_interval = log_interval
356486 self .preservation_check_interval = preservation_check_interval
487+ self .base_weight_decay = weight_decay
488+ self .base_learning_rate = learning_rate
489+
490+ # Sprint C+1: Loss-SGC monitor for real-time grokking detection
491+ # This gates the thermal pump quench - don't crystallize before grokking!
492+ self .loss_sgc_monitors = {} # Per-task monitors
357493
358494 # Build task configs
359495 self .task_configs = {}
@@ -593,6 +729,11 @@ def train_task(self, task) -> TaskResult:
593729 Train a single task using the full technology stack.
594730
595731 Implements Phase 1 (accelerated grokking) and Phase 2 (fixed point verification).
732+
733+ Sprint C+1 changes:
734+ - LossSGCMonitor for real-time grokking detection
735+ - Fermi quench gated on loss-SGC detection
736+ - Higher weight decay (1.0 default per Nanda et al.)
596737 """
597738 task_name = task .name
598739 train_loader , test_loader = self .dataloaders [task_name ]
@@ -601,19 +742,28 @@ def train_task(self, task) -> TaskResult:
601742 print (f"PHASE 1: Training task '{ task_name } '" )
602743 print (f"{ '=' * 60 } " )
603744
745+ # Sprint C+1: Initialize loss-SGC monitor for this task
746+ self .loss_sgc_monitors [task_name ] = LossSGCMonitor (
747+ window = 20 , k_sustained = 2 , trans_threshold = 0.83
748+ )
749+ loss_monitor = self .loss_sgc_monitors [task_name ]
750+
604751 # Set up controller
605752 self .controller .start_task (task_name )
606753 self .model .set_task (task_name )
607754
608- # Optimizer with temperature-dependent weight decay
609- base_wd = 0.1
755+ # Sprint C+1: Use higher weight decay (1.0 per Nanda et al. grokking paper)
756+ base_wd = self .base_weight_decay
757+ base_lr = self .base_learning_rate
610758 optimizer = optim .AdamW (
611759 self .model .parameters (),
612- lr = 1e-3 ,
760+ lr = base_lr ,
613761 weight_decay = base_wd
614762 )
615763 criterion = nn .CrossEntropyLoss ()
616764
765+ print (f" weight_decay = { base_wd } , lr = { base_lr } " )
766+
617767 step = 0
618768 grokked_step = - 1
619769 final_metrics = None
@@ -639,23 +789,31 @@ def train_task(self, task) -> TaskResult:
639789 # FIX 2 & 3: Track energy for Cv computation and self-derived Re_crit
640790 self .controller .thermal .update_energy (loss .item ())
641791
792+ # Sprint C+1: Update loss-SGC monitor for real-time grokking detection
793+ trans_rate = loss_monitor .update (loss .item (), step )
794+
642795 # Update weight decay based on temperature
643796 wd = self .controller .get_weight_decay (base_wd )
644797 for param_group in optimizer .param_groups :
645798 param_group ['weight_decay' ] = wd
646799
647- # FIX 3: Modulate learning rate by Fermi quench factor
648- # This replaces hard grokking quench with smooth phase transition
649- # sigma_quench ∈ [0, 1]: higher = more crystallization
650- sigma_quench = self .controller .thermal .get_fermi_quench_factor ()
800+ # Sprint C+1 FIX: Gate Fermi quench on loss-SGC detection
801+ # Don't crystallize BEFORE grokking is detected!
802+ # Prior Sprint C quenched immediately (sigma=0 from step 1)
803+ if not loss_monitor .grokking_detected :
804+ # Before grokking: full exploration, no crystallization
805+ sigma_quench = 0.0
806+ else :
807+ # After loss-SGC detects grokking: engage Fermi quench
808+ sigma_quench = self .controller .thermal .get_fermi_quench_factor ()
651809
652810 # Reduce learning rate smoothly as system crystallizes
653811 # lr_effective = lr_base * (1 - 0.9 * sigma_quench)
654812 # At sigma=0 (exploring): full learning rate
655813 # At sigma=1 (crystallized): 10% of learning rate
656814 lr_scale = 1.0 - 0.9 * sigma_quench
657815 for param_group in optimizer .param_groups :
658- param_group ['lr' ] = 1e-3 * lr_scale
816+ param_group ['lr' ] = base_lr * lr_scale
659817
660818 # Optimizer step
661819 optimizer .step ()
@@ -664,13 +822,13 @@ def train_task(self, task) -> TaskResult:
664822 # Logging
665823 if step % self .log_interval == 0 :
666824 train_acc , test_acc = self .compute_accuracy (task_name )
667- sigma_q = self .controller .thermal .get_fermi_quench_factor ()
668825 thermal = self .controller .thermal
669826
670- # Basic metrics
827+ # Basic metrics + Sprint C+1 trans_rate
671828 log_line = (f"Step { step :5d} | Train: { train_acc :.3f} | Test: { test_acc :.3f} | "
672829 f"eps: { metrics .epsilon :.4f} | R: { metrics .ridge_ratio :.2f} | "
673- f"T: { thermal .temperature :.2f} | sigma_q: { sigma_q :.3f} " )
830+ f"T: { thermal .temperature :.2f} | sigma_q: { sigma_quench :.3f} | "
831+ f"trans: { trans_rate :.3f} " )
674832
675833 # Verbose: add Cv and Re_SGC tracking
676834 if getattr (self , 'verbose' , False ):
@@ -794,12 +952,13 @@ def save_results(self, output_path: str):
794952def main ():
795953 import argparse
796954
797- parser = argparse .ArgumentParser (description = 'Sprint C: EGI Tower Validation' )
955+ parser = argparse .ArgumentParser (description = 'Sprint C+1 : EGI Tower Validation' )
798956 parser .add_argument ('--device' , type = str ,
799957 default = 'cuda' if torch .cuda .is_available () else 'cpu' )
800- parser .add_argument ('--max_steps' , type = int , default = 20000 ,
958+ parser .add_argument ('--max_steps' , type = int , default = 50000 ,
801959 help = 'Maximum steps per task' )
802- parser .add_argument ('--hidden_dim' , type = int , default = 256 )
960+ parser .add_argument ('--hidden_dim' , type = int , default = 128 ,
961+ help = 'Hidden dimension (Sprint C+1: smaller=forces structure)' )
803962 parser .add_argument ('--output' , type = str , default = 'sprint_c_tower_result.json' )
804963 parser .add_argument ('--quick' , action = 'store_true' ,
805964 help = 'Quick test with reduced parameters' )
@@ -809,6 +968,13 @@ def main():
809968 help = 'Run 500-step smoke test to verify Cv peak detection' )
810969 parser .add_argument ('--early_stop_on_grok' , action = 'store_true' ,
811970 help = 'Stop task immediately after grokking (for smoke test)' )
971+ # Sprint C+1 additions
972+ parser .add_argument ('--weight_decay' , type = float , default = 1.0 ,
973+ help = 'Weight decay (Sprint C+1: 1.0 per Nanda et al.)' )
974+ parser .add_argument ('--learning_rate' , type = float , default = 1e-3 ,
975+ help = 'Base learning rate' )
976+ parser .add_argument ('--prime' , type = int , default = 11 ,
977+ help = 'Prime for modular addition task (Sprint C+1: 11 groks in ~5k steps)' )
812978
813979 args = parser .parse_args ()
814980
@@ -826,24 +992,34 @@ def main():
826992 args .max_steps = min (args .max_steps , 2000 )
827993 prime = 17 # Smaller prime for quick/smoke tests
828994 else :
829- prime = 97
995+ prime = args . prime # Sprint C+1: Use CLI-specified prime (default 11)
830996
831997 # Create tasks - ORDERED EASIEST TO HARDEST for strongest preservation test
832998 # Parity (Z_2) groks fastest, Permutation (S_n) medium, Modular Add (Z_p) slowest
999+ # Sprint C+1: Use S_3 instead of S_5 for faster grokking
8331000 tasks = [
8341001 ParityTask (n_bits = 8 ), # Task 1: Simplest invariant, fastest grokking
835- PermutationTask (n_elements = 5 ), # Task 2: Medium complexity
836- ModularAdditionTask (prime = prime ), # Task 3: Hardest (Fourier invariant)
1002+ PermutationTask (n_elements = 3 ), # Task 2: Reduced complexity for faster grokking
1003+ ModularAdditionTask (prime = prime ), # Task 3: p=11 groks in ~5k steps
8371004 ]
8381005
839- # Run Sprint C
1006+ print (f"\n [Sprint C+1] Configuration:" )
1007+ print (f" hidden_dim = { args .hidden_dim } " )
1008+ print (f" weight_decay = { args .weight_decay } " )
1009+ print (f" learning_rate = { args .learning_rate } " )
1010+ print (f" prime = { prime } " )
1011+ print (f" max_steps = { args .max_steps } " )
1012+
1013+ # Run Sprint C+1
8401014 conductor = SprintCConductor (
8411015 tasks = tasks ,
8421016 hidden_dim = args .hidden_dim ,
8431017 device = args .device ,
8441018 max_steps_per_task = args .max_steps ,
8451019 log_interval = 50 if args .verbose else 100 ,
846- preservation_check_interval = 500
1020+ preservation_check_interval = 500 ,
1021+ weight_decay = args .weight_decay ,
1022+ learning_rate = args .learning_rate
8471023 )
8481024
8491025 # Pass verbosity setting
0 commit comments