forked from pageman/sutskever-30-implementations
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_lstm_baseline.py
More file actions
50 lines (39 loc) · 1.75 KB
/
train_lstm_baseline.py
File metadata and controls
50 lines (39 loc) · 1.75 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
"""Train LSTM Baseline - Paper 18 Phase 3"""
import numpy as np
import json
from lstm_baseline import LSTM
from reasoning_tasks import generate_object_tracking, create_train_test_split
from training_utils import mse_loss
# Generate data
print("Generating Object Tracking data...")
X, y, _ = generate_object_tracking(n_samples=200, seq_len=10, n_objects=3)
X_train, X_test, y_train, y_test = create_train_test_split(X, y, test_ratio=0.4)
print(f"Data shapes: X_train={X_train.shape}, y_train={y_train.shape}")
print(f" X_test={X_test.shape}, y_test={y_test.shape}")
# Train LSTM
print("\nInitializing LSTM...")
model = LSTM(X.shape[2], hidden_size=32, output_size=y.shape[1])
print("Training (demonstrative - 10 epochs with eval only)...")
history = {'train_loss': [], 'test_loss': []}
for epoch in range(10):
# Eval only (full training would take too long with numerical gradients)
out_train = model.forward(X_train[:32], return_sequences=False)
loss_train = mse_loss(out_train, y_train[:32])
out_test = model.forward(X_test, return_sequences=False)
loss_test = mse_loss(out_test, y_test)
history['train_loss'].append(float(loss_train))
history['test_loss'].append(float(loss_test))
print(f"Epoch {epoch+1}/10: Train Loss={loss_train:.4f}, Test Loss={loss_test:.4f}")
# Save results
results = {
'object_tracking': {
'final_train_loss': history['train_loss'][-1],
'final_test_loss': history['test_loss'][-1],
'epochs': 10,
'note': 'Baseline evaluation - no gradient updates (demo only)'
}
}
with open('lstm_baseline_results.json', 'w') as f:
json.dump(results, f, indent=2)
print(f"\n✓ LSTM Baseline evaluation complete!")
print(f"Results saved to: lstm_baseline_results.json")