Skip to content
This repository was archived by the owner on Apr 7, 2026. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 59 additions & 4 deletions src/serve/banco/handlers_train.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ pub async fn start_training_handler(
let data: Vec<Vec<f32>> = vec![vec![0.0; 64]; data_size.max(1)];

let vocab_size = state.model.info().and_then(|i| i.vocab_size).unwrap_or(32000);
let mut metrics = super::training::run_lora_training(&config, &data, vocab_size);
let result = super::training::run_lora_training(&config, &data, vocab_size);
let mut metrics = result.metrics;

// If we got real loss from model forward pass, replace first metric with it
#[cfg(feature = "realizar")]
Expand All @@ -67,7 +68,13 @@ pub async fn start_training_handler(
first.loss = real_loss_val;
first.tokens_per_sec = Some(tokens_eval as u64);
}
run.simulated = false; // At least one metric is real
run.simulated = false;
}

// Store adapter weights if training produced them
if let Some(weights) = result.adapter_weights {
state.training.set_adapter_weights(&run.id, weights);
run.simulated = false;
}

for m in &metrics {
Expand Down Expand Up @@ -181,7 +188,23 @@ pub async fn export_training_handler(
};
let filename =
if request.merge { format!("{id}-merged.{ext}") } else { format!("{id}-adapter.{ext}") };
let path = format!("~/.banco/exports/{filename}");

// Write real APR file when adapter weights are available
let (path, size_bytes) = if request.format == ExportFormat::Apr {
if let Some(ref weights) = run.adapter_weights {
match write_apr_adapter(&filename, weights) {
Ok((p, s)) => (p, s),
Err(e) => {
eprintln!("[banco] APR export error: {e}");
(format!("~/.banco/exports/{filename}"), 0)
}
}
} else {
(format!("~/.banco/exports/{filename}"), 0)
}
} else {
(format!("~/.banco/exports/{filename}"), 0)
};

state.training.set_export_path(&id, &path);

Expand All @@ -190,10 +213,42 @@ pub async fn export_training_handler(
format: request.format,
merged: request.merge,
path,
size_bytes: 0, // populated when real export happens
size_bytes,
}))
}

/// Write LoRA adapter weights to an APR format file.
///
/// The file is placed in `.banco/exports` under the user's home directory,
/// or in `/tmp` when no home directory can be determined. The APR metadata
/// records the format tag `lora-adapter` and the LoRA rank; the two flattened
/// matrices are stored as the tensors `lora_a` (`[rank, len_a / rank]`) and
/// `lora_b` (`[len_b / rank, rank]`).
///
/// Returns the path of the written file and its size in bytes.
///
/// # Errors
///
/// Returns a human-readable message when:
/// - `weights.rank` is zero, or either matrix length is not a multiple of the
///   rank (the tensor shapes could not describe the data),
/// - APR serialization fails,
/// - the export directory cannot be created, or
/// - the file cannot be written.
fn write_apr_adapter(
    filename: &str,
    weights: &super::training::AdapterWeights,
) -> Result<(String, u64), String> {
    use aprender::serialization::apr::AprWriter;

    let rank = weights.rank;
    // A zero rank would make the shape divisions below panic.
    if rank == 0 {
        return Err("APR write failed: LoRA rank must be non-zero".to_string());
    }

    // Each matrix gets its shape from its OWN length, so the shape always
    // matches the data that is written alongside it.
    let a_len = weights.lora_a.len();
    let b_len = weights.lora_b.len();
    if a_len % rank != 0 || b_len % rank != 0 {
        return Err(format!(
            "APR write failed: matrix lengths (lora_a={a_len}, lora_b={b_len}) \
             are not multiples of rank {rank}"
        ));
    }

    let mut writer = AprWriter::new();
    writer.set_metadata("format", serde_json::Value::String("lora-adapter".to_string()));
    writer.set_metadata("lora_rank", serde_json::Value::Number(serde_json::Number::from(rank)));

    writer.add_tensor_f32("lora_a", vec![rank, a_len / rank], &weights.lora_a);
    writer.add_tensor_f32("lora_b", vec![b_len / rank, rank], &weights.lora_b);

    let bytes = writer.to_bytes().map_err(|e| format!("APR write failed: {e}"))?;

    // Write to ~/.banco/exports/ (falls back to /tmp without a home directory).
    let export_dir =
        dirs::home_dir().map(|h| h.join(".banco/exports")).unwrap_or_else(|| "/tmp".into());
    // Surface directory-creation failures directly instead of letting the
    // subsequent file write fail with a less specific message.
    std::fs::create_dir_all(&export_dir).map_err(|e| {
        format!("Failed to create export directory {}: {e}", export_dir.display())
    })?;
    let path = export_dir.join(filename);
    std::fs::write(&path, &bytes).map_err(|e| format!("File write failed: {e}"))?;

    let size = bytes.len() as u64;
    eprintln!("[banco] Exported LoRA adapter to {} ({size} bytes)", path.display());
    Ok((path.to_string_lossy().to_string(), size))
}

/// GET /api/v1/train/presets — list available training presets.
pub async fn list_presets_handler() -> Json<PresetsResponse> {
let presets: Vec<PresetInfo> = TrainingPreset::all()
Expand Down
27 changes: 25 additions & 2 deletions src/serve/banco/training.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,17 @@ pub use super::training_engine::{run_lora_training, TrainingPreset};
// Training types
// ============================================================================

/// The pair of LoRA matrices produced by a training run, together with the
/// rank they were trained at.
#[derive(Clone, Debug)]
pub struct AdapterWeights {
    /// Flattened contents of the LoRA A matrix.
    pub lora_a: Vec<f32>,

    /// Flattened contents of the LoRA B matrix.
    pub lora_b: Vec<f32>,

    /// Rank of the LoRA decomposition.
    pub rank: usize,
}

/// Training run metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingRun {
Expand All @@ -25,13 +36,15 @@ pub struct TrainingRun {
pub created_at: u64,
pub metrics: Vec<TrainingMetric>,
/// True when metrics are from simulated cosine schedule, not real gradients.
/// Honest labeling per Jidoka — stop-the-line on false claims.
#[serde(default)]
pub simulated: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub export_path: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub error: Option<String>,
/// Trained adapter weights (not serialized to JSON — stored in memory for export).
#[serde(skip)]
pub adapter_weights: Option<AdapterWeights>,
}

/// Training method.
Expand Down Expand Up @@ -243,9 +256,10 @@ impl TrainingStore {
status: TrainingStatus::Queued,
created_at: epoch_secs(),
metrics: Vec::new(),
simulated: true, // No real gradient-based training yet
simulated: true,
export_path: None,
error: None,
adapter_weights: None,
};
if let Ok(mut store) = self.runs.write() {
store.insert(run.id.clone(), run.clone());
Expand All @@ -262,6 +276,15 @@ impl TrainingStore {
}
}

/// Attach trained adapter weights to the run identified by `run_id`.
///
/// Does nothing when the write lock cannot be acquired or when no run with
/// that id is present in the store.
pub fn set_adapter_weights(&self, run_id: &str, weights: AdapterWeights) {
    let mut runs = match self.runs.write() {
        Ok(guard) => guard,
        Err(_) => return,
    };
    if let Some(entry) = runs.get_mut(run_id) {
        entry.adapter_weights = Some(weights);
    }
}

/// Update run status.
pub fn set_status(&self, run_id: &str, status: TrainingStatus) {
if let Ok(mut store) = self.runs.write() {
Expand Down
33 changes: 21 additions & 12 deletions src/serve/banco/training_engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,21 +144,23 @@ impl TrainingPreset {
// entrenar integration (behind ml feature)
// ============================================================================

/// Training result — metrics plus trained adapter weights.
///
/// Derives `Debug` and `Clone` so results can be logged and duplicated like
/// the other training types in this module.
#[derive(Debug, Clone)]
pub struct TrainingResult {
    /// Metrics recorded during the training loop, in step order.
    pub metrics: Vec<TrainingMetric>,
    /// Trained LoRA adapter weights; `None` when the run produced no weights
    /// (for example, the simulated training path).
    pub adapter_weights: Option<super::training::AdapterWeights>,
}

/// Run a LoRA training loop using entrenar's real optimizer.
///
/// Creates LoRA adapter tensors, runs AdamW optimizer steps with
/// gradient computation. When a real loss value is provided (from
/// model forward pass), the first gradient is derived from it.
/// Subsequent steps use the optimizer's momentum for realistic decay.
///
/// This is REAL optimizer execution — AdamW updates LoRA weights
/// with proper momentum, bias correction, and weight decay.
/// Creates LoRA adapter tensors, runs AdamW optimizer steps with real
/// gradient computation. Returns metrics AND trained adapter weights
/// for APR export serialization.
#[cfg(feature = "entrenar")]
pub fn run_lora_training(
config: &TrainingConfig,
data: &[Vec<f32>],
_vocab_size: usize,
) -> Vec<TrainingMetric> {
) -> TrainingResult {
use entrenar::autograd::Tensor;
use entrenar::lora::LoRAConfig;
use entrenar::optim::{AdamW, Optimizer};
Expand Down Expand Up @@ -224,16 +226,23 @@ pub fn run_lora_training(
eta_secs: Some(((total_steps - step) as f64 * elapsed / (step + 1) as f64) as u64),
});
}
metrics

// Return metrics + trained adapter weights for APR export
let weights = super::training::AdapterWeights {
lora_a: lora_a.data().to_vec(),
lora_b: lora_b.data().to_vec(),
rank: lora_dim,
};
TrainingResult { metrics, adapter_weights: Some(weights) }
}

/// Simulated training (no ml feature) — produces realistic metric progression.
/// Simulated training (no entrenar feature).
#[cfg(not(feature = "entrenar"))]
pub fn run_lora_training(
config: &TrainingConfig,
data: &[Vec<f32>],
_vocab_size: usize,
) -> Vec<TrainingMetric> {
) -> TrainingResult {
let total_steps =
(data.len().max(1) / config.batch_size.max(1) as usize).max(1) * config.epochs as usize;

Expand All @@ -253,7 +262,7 @@ pub fn run_lora_training(
eta_secs: Some(((total_steps - step) as u64) * 2),
});
}
metrics
TrainingResult { metrics, adapter_weights: None }
}

/// Compute real loss on training data via model forward pass.
Expand Down
9 changes: 6 additions & 3 deletions src/serve/banco/training_engine_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ fn test_TRAIN_011_all_presets() {
fn test_TRAIN_012_run_lora_training_produces_metrics() {
let config = TrainingConfig { epochs: 2, batch_size: 4, ..TrainingConfig::default() };
let data: Vec<Vec<f32>> = vec![vec![0.0; 64]; 20];
let metrics = super::training::run_lora_training(&config, &data, 32000);
let result = super::training::run_lora_training(&config, &data, 32000);
let metrics = result.metrics.clone();
assert!(!metrics.is_empty());
let first_loss = metrics.first().expect("first").loss;
let last_loss = metrics.last().expect("last").loss;
Expand All @@ -88,7 +89,8 @@ fn test_TRAIN_012_run_lora_training_produces_metrics() {
fn test_TRAIN_013_metrics_have_decreasing_loss() {
let config = TrainingConfig::default();
let data: Vec<Vec<f32>> = vec![vec![0.0; 64]; 100];
let metrics = super::training::run_lora_training(&config, &data, 32000);
let result = super::training::run_lora_training(&config, &data, 32000);
let metrics = result.metrics.clone();
for w in metrics.windows(2) {
assert!(w[1].loss <= w[0].loss, "loss should be monotonically decreasing");
}
Expand Down Expand Up @@ -240,7 +242,8 @@ async fn test_TRAIN_HDL_006_metrics_sse() {
let config = TrainingConfig { epochs: 1, batch_size: 4, ..TrainingConfig::default() };
let run = state.training.start("ds-test", TrainingMethod::Lora, config.clone());
let data: Vec<Vec<f32>> = vec![vec![0.0; 64]; 20];
let metrics = super::training::run_lora_training(&config, &data, 32000);
let result = super::training::run_lora_training(&config, &data, 32000);
let metrics = result.metrics.clone();
for m in &metrics {
state.training.push_metric(&run.id, m.clone());
}
Expand Down