Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions src/factorzen/core/progress.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""Small terminal progress helpers for CLI pipelines."""

from __future__ import annotations

import sys
from atexit import register


class OverallProgress:
"""Render a coarse overall progress bar for stage-based CLI workflows."""

def __init__(self, total: int, *, label: str = "Overall") -> None:
self.total = max(total, 1)
self.label = label
self.current = 0
self._enabled = sys.stderr.isatty()
self._closed = False

def start(self) -> "OverallProgress":
if self._enabled:
register(self.close)
self._render("starting")
return self

def advance(self, step: str) -> None:
self.current = min(self.current + 1, self.total)
if self._enabled:
self._render(step)

def close(self) -> None:
if self._enabled and not self._closed:
sys.stderr.write("\n")
sys.stderr.flush()
self._closed = True

def _render(self, step: str) -> None:
width = 28
filled = round(width * self.current / self.total)
bar = "#" * filled + "-" * (width - filled)
percent = round(100 * self.current / self.total)
sys.stderr.write(
f"\r{self.label} [{bar}] {self.current}/{self.total} {percent:3d}% {step}"
)
sys.stderr.flush()
19 changes: 19 additions & 0 deletions src/factorzen/pipelines/daily_single.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
run_experiment,
)
from factorzen.core.logger import get_logger, setup_logging
from factorzen.core.progress import OverallProgress
from factorzen.core.storage import load_parquet
from factorzen.core.timing import StageTimer
from factorzen.core.universe import get_universe
Expand Down Expand Up @@ -564,6 +565,7 @@ def _run(
timer: StageTimer | None = None,
) -> dict[str, str]:
timer = timer or StageTimer()
progress = OverallProgress(16, label="Daily run").start()
# ── 0b. 设置全局随机种子(可选)──
if args.seed is not None:
from factorzen.core.seed import set_global_seed
Expand All @@ -583,6 +585,7 @@ def _run(
factor_output_dir = daily_factor_output_dir(factor.name)
result_output_dir = daily_result_output_dir(factor.name)
report_output_dir = daily_report_output_dir(factor.name)
progress.advance("init")

# ── 2. 准备数据 ──
trade_dates = get_trade_dates(args.start, args.end)
Expand All @@ -608,6 +611,7 @@ def _run(
except Exception as e:
logger.error(f"数据保障失败: {e}")
raise RuntimeError(f"ensure_data_for_daily_run failed: {e}") from e
progress.advance("data")

# ── 3. 股票池 ──
universe = get_universe(args.end, args.universe)
Expand All @@ -624,6 +628,7 @@ def _run(
)
universe.write_parquet(str(universe_snapshot_path))
logger.info(f"Universe 快照已保存: {universe_snapshot_path} ({len(ts_codes)} 只)")
progress.advance("universe")

# ── 4. 计算因子 ──
ctx = FactorDataContext(
Expand All @@ -647,6 +652,7 @@ def _run(
raise RuntimeError("empty factor result")
if validation.get("coverage", 0) < 0.5:
logger.warning("因子覆盖率不足 50%,结果可能不可靠")
progress.advance("factor")

# ── 5. 预处理 ──
daily_basic_for_neutralize = None
Expand All @@ -667,6 +673,7 @@ def _run(
daily_basic=daily_basic_for_neutralize,
)
logger.info("预处理完成 (去极值 → 填充 → 标准化)")
progress.advance("preprocess")

# ── 6. 计算前向收益 ──
daily = ctx.daily.collect()
Expand Down Expand Up @@ -696,12 +703,14 @@ def _run(
if quality_report["warnings"]:
logger.warning(f"数据质量警告: {quality_report['warnings']}")
logger.info(f"数据质量报告已保存: {quality_path}")
progress.advance("returns-quality")

# ── 7. IC 分析 ──
with timer.stage("IC 分析"):
ic_result = compute_rank_ic(clean_df, ret_df, frequency=args.frequency)
ic_result.factor_name = factor.name
logger.info(f"\n{ic_result.summary()}")
progress.advance("ic")

# 可选:Pearson IC / Both IC
pearson_ic_result = None
Expand Down Expand Up @@ -773,6 +782,7 @@ def _run(
logger.info(f"Neutralized IC Mean: {neutralized_ic_result.ic_mean:.4f}")
except Exception as e:
logger.warning(f"中性化 IC 计算失败(跳过): {e}")
progress.advance("optional-ic")

# ── 8. 策略回测 ──
with timer.stage("策略回测"):
Expand All @@ -783,12 +793,14 @@ def _run(
factor_name=factor.name,
frequency=args.frequency,
)
progress.advance("backtest")

# ── 9. 换手率 ──
with timer.stage("换手率"):
to_result = compute_turnover(clean_df, frequency=args.frequency)
to_result.factor_name = factor.name
logger.info(f"\n{to_result.summary()}")
progress.advance("turnover")

factor_output_dir.mkdir(parents=True, exist_ok=True)
result_output_dir.mkdir(parents=True, exist_ok=True)
Expand All @@ -800,6 +812,7 @@ def _run(
ic_path = result_output_dir / f"{factor.name}_{args.start}_{args.end}_ic.parquet"
ic_result.ic_series.write_parquet(str(ic_path))
logger.info(f"IC 序列已保存: {ic_path}")
progress.advance("save-core")

# ── 10. Walk-forward / OOS 摘要 ──
if effective_config.walk_forward.enabled:
Expand All @@ -821,6 +834,7 @@ def _run(
walk_forward_summary = {"status": "disabled", "n_folds": 0}
walk_forward_result = None
logger.info("Walk-forward 已关闭,跳过")
progress.advance("walk-forward")

# ── 11. 落盘 ──
daily_basic_for_breakdowns = daily_basic_for_neutralize
Expand Down Expand Up @@ -860,6 +874,7 @@ def _run(
logger.info(f"事件研究完成: {event_study_result.n_events} 个事件")
except Exception as e:
logger.warning(f"事件研究计算失败(跳过): {e}")
progress.advance("advanced")

# ── 12. Benchmark 对比(可选)──
benchmark_result = None
Expand All @@ -881,6 +896,7 @@ def _run(
logger.info(f"Benchmark: {benchmark_result.summary()}")
except Exception as e:
logger.warning(f"Benchmark 计算失败(跳过): {e}")
progress.advance("benchmark")

# ── 13. HTML 报告(当 --benchmark 提供时生成,或始终生成)──
date_range = f"{args.start[:4]}-{args.start[4:6]}-{args.start[6:]} ~ {args.end[:4]}-{args.end[4:6]}-{args.end[6:]}"
Expand All @@ -902,6 +918,7 @@ def _run(
quality_report=quality_report,
backtest_direction=None,
)
progress.advance("llm")
with timer.stage("报告生成"):
html = generate_tear_sheet(
factor_name=factor.name,
Expand All @@ -928,6 +945,7 @@ def _run(
report_path = report_output_dir / f"{factor.name}_{args.start}_{args.end}.html"
report_path.write_text(html, encoding="utf-8")
logger.info(f"报告已生成: {report_path}")
progress.advance("report")

outputs = {
"factor": str(factor_path),
Expand All @@ -941,6 +959,7 @@ def _run(
outputs["llm_explanation"] = str(llm_explanation_path)
if getattr(args, "metrics_out", None):
_write_run_metrics(args.metrics_out, ic_result, bt_result)
progress.close()
return outputs


Expand Down
9 changes: 9 additions & 0 deletions src/factorzen/pipelines/generate_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
)
from factorzen.core.loader import fetch_daily
from factorzen.core.logger import get_logger, setup_logging
from factorzen.core.progress import OverallProgress
from factorzen.core.storage import load_parquet
from factorzen.core.timing import StageTimer
from factorzen.core.universe import get_universe
Expand Down Expand Up @@ -358,6 +359,7 @@ def _run(
timer: StageTimer | None = None,
) -> dict[str, str]:
timer = timer or StageTimer()
progress = OverallProgress(5, label="Report run").start()
logger.info(f"──── 因子报告生成: {args.factor} | {args.start} ~ {args.end} ────")

# ── 1. 获取因子类 ──
Expand All @@ -367,6 +369,7 @@ def _run(
logger.error(str(e))
raise RuntimeError(f"unknown factor: {args.factor}") from e
factor = factor_cls()
progress.advance("init")
logger.info(f"因子: {factor.name} | {factor.description}")

walk_forward_summary: dict | None = None
Expand Down Expand Up @@ -431,6 +434,7 @@ def _run(
strategy_results = {bt_result.strategy_name: bt_result}
logger.warning("日线数据为空,跳过高级评价")
advanced_results = None
progress.advance("results")
else:
if args.reuse:
logger.info("--reuse: 未找到缓存,退回完整计算")
Expand Down Expand Up @@ -577,6 +581,7 @@ def _run(
walk_forward_summary=walk_forward_summary,
backtest_direction=backtest_direction,
)
progress.advance("results")

# ── (Optional) Benchmark 对比 ──
benchmark_result = None
Expand All @@ -592,6 +597,7 @@ def _run(
logger.warning(f"Benchmark 计算失败(跳过): {e}")

# ── 11. 生成 HTML 报告 ──
progress.advance("benchmark")
date_range = f"{args.start[:4]}-{args.start[4:6]}-{args.start[6:]} ~ {args.end[:4]}-{args.end[4:6]}-{args.end[6:]}"
quality_report_for_llm: dict[str, Any] | None = None
quality_report_path = _quality_path(args.factor, args.start, args.end)
Expand All @@ -618,6 +624,7 @@ def _run(
quality_report=quality_report_for_llm,
backtest_direction=backtest_direction,
)
progress.advance("llm")
with timer.stage("报告生成"):
html = generate_tear_sheet(
factor_name=factor.name,
Expand Down Expand Up @@ -647,6 +654,7 @@ def _run(
report_dir.mkdir(parents=True, exist_ok=True)
report_path = report_dir / f"{factor.name}_{args.start}_{args.end}.html"
report_path.write_text(html, encoding="utf-8")
progress.advance("report")
logger.info(f"报告已生成: {report_path}")

outputs = {
Expand All @@ -658,6 +666,7 @@ def _run(
outputs["quality_report"] = str(quality_report_path)
if llm_explanation_path is not None:
outputs["llm_explanation"] = str(llm_explanation_path)
progress.close()
return outputs


Expand Down
Loading