Skip to content

Commit 0752d93

Browse files
committed
feat: Sprint C+1 architecture - LossSGCMonitor and gating fixes
Key changes based on colleague's analysis: 1. LossSGCMonitor: Real-time grokking detector measuring trans_rate of APPROX_EQUAL on loss sequence (like Sprint 2). This fires when loss values cluster, signaling the network found the invariant structure. 2. Fermi quench gating: Don't crystallize BEFORE grokking detected! Prior Sprint C quenched from step 1 (sigma=0). Now sigma_quench=0 until LossSGCMonitor detects grokking, then Fermi engages. 3. Higher weight decay: 1.0 default (per Nanda et al. grokking paper) instead of 0.1. This is the dominant regularizer for spectral collapse. 4. Smaller network: hidden_dim=128 default (was 256). Forces model to find invariant structure instead of memorizing. 5. Smaller prime: p=11 default (was 97). Groks in ~5k steps, well within budget. Also avoids CUDA index errors. 6. Smaller permutation group: S_3 instead of S_5 for faster grokking. The architecture now matches what worked in Sprint 2/5: - Wavelet noise pump (already present) accelerates grokking - Loss-SGC monitor DETECTS grokking in real-time - Rayleigh measurement CONFIRMS the quotient structure CLI: --weight_decay, --learning_rate, --prime added.
1 parent d73a07d commit 0752d93

File tree

1 file changed

+196
-20
lines changed

1 file changed

+196
-20
lines changed

perihelion/experiments/sprint_c_tower_validation.py

Lines changed: 196 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,132 @@
5555
RayleighMeasurement, SpectralEquivalenceResult, PartitionData,
5656
measure_egi_fixed_point, print_egi_report
5757
)
58+
from collections import deque
59+
60+
61+
# ============================================================================
62+
# LOSS-SGC MONITOR: Real-time grokking detector (Sprint C+1 critical fix)
63+
# ============================================================================
64+
65+
class LossSGCMonitor:
    """
    Real-time grokking detector based on the loss trajectory itself.

    Tracks the transitivity rate (trans_rate) of the APPROX_EQUAL relation
    over a sliding window of recent loss values. While the network is still
    memorizing, losses are spread out and transitivity is low; once the
    network groks, losses cluster tightly (near zero) and trans_rate -> 1.

    This mirrors the Sprint 2 approach: the detector observes the system's
    own dynamics (the loss sequence), whereas the epsilon measurement on the
    weight matrix serves only as a post-hoc verifier.

    ZERO-PARAMETER: the comparison tolerance is derived from the statistics
    (std / mean) of the current loss window rather than hand-tuned.
    """

    def __init__(self, window: int = 20, k_sustained: int = 2,
                 trans_threshold: float = 0.83):
        """
        Args:
            window: Number of recent loss values retained for measurement.
            k_sustained: How many consecutive above-threshold windows are
                required before grokking is declared.
            trans_threshold: trans_rate level that counts as "above" (0.83).
        """
        self.window = window
        self.k_sustained = k_sustained
        self.trans_threshold = trans_threshold

        # Rolling state: bounded loss window plus full measurement log.
        self.loss_history = deque(maxlen=window)
        self.trans_rate_history = []

        # Detection state.
        self.consecutive_above = 0
        self.grokking_detected = False
        self.grokking_step = -1

    def _measure_approx_equal_transitivity(self, values: list, tol: float) -> float:
        """
        Estimate transitivity of APPROX_EQUAL over `values`.

        trans_rate = P(|a-c| < tol  |  |a-b| < tol and |b-c| < tol)

        Clustered values (e.g. all near zero) push this toward 1; spread
        values pull it down. Triplet enumeration is capped for efficiency.
        """
        count = len(values)
        if count < 3:
            return 0.0

        # Cap the number of a~b, b~c chains examined (at most 100, and never
        # more than the total number of distinct ordered triplets).
        budget = min(100, count * (count - 1) * (count - 2) // 6)

        chains = 0
        transitive = 0
        budget_spent = False

        for a in range(count):
            if budget_spent:
                break
            for b in range(a + 1, count):
                if budget_spent:
                    break
                if abs(values[a] - values[b]) >= tol:
                    continue  # a !~ b: no chain can start here
                for c in range(b + 1, count):
                    if abs(values[b] - values[c]) < tol:  # b ~ c
                        chains += 1
                        if abs(values[a] - values[c]) < tol:  # a ~ c?
                            transitive += 1
                        if chains >= budget:
                            budget_spent = True
                            break

        if chains == 0:
            return 0.0
        return transitive / chains

    def update(self, loss: float, step: int) -> float:
        """
        Record a new loss value and re-measure trans_rate.

        Fires the one-shot grokking detection (and prints a banner) once
        trans_rate has stayed at/above the threshold for k_sustained
        consecutive measurements.

        Returns:
            The current trans_rate (0.0 until the window holds >= 3 values).
        """
        self.loss_history.append(loss)

        if len(self.loss_history) < 3:
            return 0.0

        window_vals = list(self.loss_history)

        # ZERO-PARAMETER tolerance, adaptive to current dynamics:
        # pre-grokking losses have high variance (large std => loose tol);
        # post-grokking losses cluster near zero (std collapses). The
        # mean-scaled floor keeps tol sane in the near-zero regime.
        std = np.std(window_vals)
        mean = np.mean(window_vals)
        tol = max(0.5 * std, 0.01 * mean + 1e-6)

        trans_rate = self._measure_approx_equal_transitivity(window_vals, tol)
        self.trans_rate_history.append((step, trans_rate))

        # Sustained-detection logic: require k_sustained consecutive hits.
        if trans_rate < self.trans_threshold:
            self.consecutive_above = 0
        else:
            self.consecutive_above += 1
            if self.consecutive_above >= self.k_sustained and not self.grokking_detected:
                self.grokking_detected = True
                self.grokking_step = step
                print(f"\n*** LOSS-SGC GROKKING DETECTED at step {step} ***")
                print(f" trans_rate = {trans_rate:.4f} (threshold {self.trans_threshold})")
                print(f" loss_mean = {mean:.6f}, loss_std = {std:.6f}")

        return trans_rate

    def get_summary(self) -> dict:
        """Return a dict summarizing detection status and measurements."""
        last_rate = self.trans_rate_history[-1][1] if self.trans_rate_history else 0.0
        return {
            'grokking_detected': self.grokking_detected,
            'grokking_step': self.grokking_step,
            'final_trans_rate': last_rate,
            'n_measurements': len(self.trans_rate_history),
        }
58184

59185

60186
# ============================================================================
@@ -336,7 +462,9 @@ def __init__(self,
336462
device: str = 'cuda',
337463
max_steps_per_task: int = 50000,
338464
log_interval: int = 100,
339-
preservation_check_interval: int = 500):
465+
preservation_check_interval: int = 500,
466+
weight_decay: float = 1.0,
467+
learning_rate: float = 1e-3):
340468
"""
341469
Initialize Sprint C conductor.
342470
@@ -347,13 +475,21 @@ def __init__(self,
347475
max_steps_per_task: Maximum training steps per task
348476
log_interval: Steps between logging
349477
preservation_check_interval: Steps between preservation checks
478+
weight_decay: Base weight decay (Sprint C+1: use 1.0 per Nanda et al.)
479+
learning_rate: Base learning rate
350480
"""
351481
self.tasks = tasks
352482
self.hidden_dim = hidden_dim
353483
self.device = device
354484
self.max_steps_per_task = max_steps_per_task
355485
self.log_interval = log_interval
356486
self.preservation_check_interval = preservation_check_interval
487+
self.base_weight_decay = weight_decay
488+
self.base_learning_rate = learning_rate
489+
490+
# Sprint C+1: Loss-SGC monitor for real-time grokking detection
491+
# This gates the thermal pump quench - don't crystallize before grokking!
492+
self.loss_sgc_monitors = {} # Per-task monitors
357493

358494
# Build task configs
359495
self.task_configs = {}
@@ -593,6 +729,11 @@ def train_task(self, task) -> TaskResult:
593729
Train a single task using the full technology stack.
594730
595731
Implements Phase 1 (accelerated grokking) and Phase 2 (fixed point verification).
732+
733+
Sprint C+1 changes:
734+
- LossSGCMonitor for real-time grokking detection
735+
- Fermi quench gated on loss-SGC detection
736+
- Higher weight decay (1.0 default per Nanda et al.)
596737
"""
597738
task_name = task.name
598739
train_loader, test_loader = self.dataloaders[task_name]
@@ -601,19 +742,28 @@ def train_task(self, task) -> TaskResult:
601742
print(f"PHASE 1: Training task '{task_name}'")
602743
print(f"{'='*60}")
603744

745+
# Sprint C+1: Initialize loss-SGC monitor for this task
746+
self.loss_sgc_monitors[task_name] = LossSGCMonitor(
747+
window=20, k_sustained=2, trans_threshold=0.83
748+
)
749+
loss_monitor = self.loss_sgc_monitors[task_name]
750+
604751
# Set up controller
605752
self.controller.start_task(task_name)
606753
self.model.set_task(task_name)
607754

608-
# Optimizer with temperature-dependent weight decay
609-
base_wd = 0.1
755+
# Sprint C+1: Use higher weight decay (1.0 per Nanda et al. grokking paper)
756+
base_wd = self.base_weight_decay
757+
base_lr = self.base_learning_rate
610758
optimizer = optim.AdamW(
611759
self.model.parameters(),
612-
lr=1e-3,
760+
lr=base_lr,
613761
weight_decay=base_wd
614762
)
615763
criterion = nn.CrossEntropyLoss()
616764

765+
print(f" weight_decay = {base_wd}, lr = {base_lr}")
766+
617767
step = 0
618768
grokked_step = -1
619769
final_metrics = None
@@ -639,23 +789,31 @@ def train_task(self, task) -> TaskResult:
639789
# FIX 2 & 3: Track energy for Cv computation and self-derived Re_crit
640790
self.controller.thermal.update_energy(loss.item())
641791

792+
# Sprint C+1: Update loss-SGC monitor for real-time grokking detection
793+
trans_rate = loss_monitor.update(loss.item(), step)
794+
642795
# Update weight decay based on temperature
643796
wd = self.controller.get_weight_decay(base_wd)
644797
for param_group in optimizer.param_groups:
645798
param_group['weight_decay'] = wd
646799

647-
# FIX 3: Modulate learning rate by Fermi quench factor
648-
# This replaces hard grokking quench with smooth phase transition
649-
# sigma_quench ∈ [0, 1]: higher = more crystallization
650-
sigma_quench = self.controller.thermal.get_fermi_quench_factor()
800+
# Sprint C+1 FIX: Gate Fermi quench on loss-SGC detection
801+
# Don't crystallize BEFORE grokking is detected!
802+
# Prior Sprint C quenched immediately (sigma=0 from step 1)
803+
if not loss_monitor.grokking_detected:
804+
# Before grokking: full exploration, no crystallization
805+
sigma_quench = 0.0
806+
else:
807+
# After loss-SGC detects grokking: engage Fermi quench
808+
sigma_quench = self.controller.thermal.get_fermi_quench_factor()
651809

652810
# Reduce learning rate smoothly as system crystallizes
653811
# lr_effective = lr_base * (1 - 0.9 * sigma_quench)
654812
# At sigma=0 (exploring): full learning rate
655813
# At sigma=1 (crystallized): 10% of learning rate
656814
lr_scale = 1.0 - 0.9 * sigma_quench
657815
for param_group in optimizer.param_groups:
658-
param_group['lr'] = 1e-3 * lr_scale
816+
param_group['lr'] = base_lr * lr_scale
659817

660818
# Optimizer step
661819
optimizer.step()
@@ -664,13 +822,13 @@ def train_task(self, task) -> TaskResult:
664822
# Logging
665823
if step % self.log_interval == 0:
666824
train_acc, test_acc = self.compute_accuracy(task_name)
667-
sigma_q = self.controller.thermal.get_fermi_quench_factor()
668825
thermal = self.controller.thermal
669826

670-
# Basic metrics
827+
# Basic metrics + Sprint C+1 trans_rate
671828
log_line = (f"Step {step:5d} | Train: {train_acc:.3f} | Test: {test_acc:.3f} | "
672829
f"eps: {metrics.epsilon:.4f} | R: {metrics.ridge_ratio:.2f} | "
673-
f"T: {thermal.temperature:.2f} | sigma_q: {sigma_q:.3f}")
830+
f"T: {thermal.temperature:.2f} | sigma_q: {sigma_quench:.3f} | "
831+
f"trans: {trans_rate:.3f}")
674832

675833
# Verbose: add Cv and Re_SGC tracking
676834
if getattr(self, 'verbose', False):
@@ -794,12 +952,13 @@ def save_results(self, output_path: str):
794952
def main():
795953
import argparse
796954

797-
parser = argparse.ArgumentParser(description='Sprint C: EGI Tower Validation')
955+
parser = argparse.ArgumentParser(description='Sprint C+1: EGI Tower Validation')
798956
parser.add_argument('--device', type=str,
799957
default='cuda' if torch.cuda.is_available() else 'cpu')
800-
parser.add_argument('--max_steps', type=int, default=20000,
958+
parser.add_argument('--max_steps', type=int, default=50000,
801959
help='Maximum steps per task')
802-
parser.add_argument('--hidden_dim', type=int, default=256)
960+
parser.add_argument('--hidden_dim', type=int, default=128,
961+
help='Hidden dimension (Sprint C+1: smaller=forces structure)')
803962
parser.add_argument('--output', type=str, default='sprint_c_tower_result.json')
804963
parser.add_argument('--quick', action='store_true',
805964
help='Quick test with reduced parameters')
@@ -809,6 +968,13 @@ def main():
809968
help='Run 500-step smoke test to verify Cv peak detection')
810969
parser.add_argument('--early_stop_on_grok', action='store_true',
811970
help='Stop task immediately after grokking (for smoke test)')
971+
# Sprint C+1 additions
972+
parser.add_argument('--weight_decay', type=float, default=1.0,
973+
help='Weight decay (Sprint C+1: 1.0 per Nanda et al.)')
974+
parser.add_argument('--learning_rate', type=float, default=1e-3,
975+
help='Base learning rate')
976+
parser.add_argument('--prime', type=int, default=11,
977+
help='Prime for modular addition task (Sprint C+1: 11 groks in ~5k steps)')
812978

813979
args = parser.parse_args()
814980

@@ -826,24 +992,34 @@ def main():
826992
args.max_steps = min(args.max_steps, 2000)
827993
prime = 17 # Smaller prime for quick/smoke tests
828994
else:
829-
prime = 97
995+
prime = args.prime # Sprint C+1: Use CLI-specified prime (default 11)
830996

831997
# Create tasks - ORDERED EASIEST TO HARDEST for strongest preservation test
832998
# Parity (Z_2) groks fastest, Permutation (S_n) medium, Modular Add (Z_p) slowest
999+
# Sprint C+1: Use S_3 instead of S_5 for faster grokking
8331000
tasks = [
8341001
ParityTask(n_bits=8), # Task 1: Simplest invariant, fastest grokking
835-
PermutationTask(n_elements=5), # Task 2: Medium complexity
836-
ModularAdditionTask(prime=prime), # Task 3: Hardest (Fourier invariant)
1002+
PermutationTask(n_elements=3), # Task 2: Reduced complexity for faster grokking
1003+
ModularAdditionTask(prime=prime), # Task 3: p=11 groks in ~5k steps
8371004
]
8381005

839-
# Run Sprint C
1006+
print(f"\n[Sprint C+1] Configuration:")
1007+
print(f" hidden_dim = {args.hidden_dim}")
1008+
print(f" weight_decay = {args.weight_decay}")
1009+
print(f" learning_rate = {args.learning_rate}")
1010+
print(f" prime = {prime}")
1011+
print(f" max_steps = {args.max_steps}")
1012+
1013+
# Run Sprint C+1
8401014
conductor = SprintCConductor(
8411015
tasks=tasks,
8421016
hidden_dim=args.hidden_dim,
8431017
device=args.device,
8441018
max_steps_per_task=args.max_steps,
8451019
log_interval=50 if args.verbose else 100,
846-
preservation_check_interval=500
1020+
preservation_check_interval=500,
1021+
weight_decay=args.weight_decay,
1022+
learning_rate=args.learning_rate
8471023
)
8481024

8491025
# Pass verbosity setting

0 commit comments

Comments
 (0)