diff --git a/config/multipanel_example.yaml b/config/multipanel_example.yaml new file mode 100644 index 00000000..1b961786 --- /dev/null +++ b/config/multipanel_example.yaml @@ -0,0 +1,117 @@ +# yaml-language-server: $schema=../workflow/tools/config.schema.json +description: | + Evaluate skill of Stage E with/without cutoff edges trained with and without subgrid orography. + +dates: + start: 2025-01-01T06:00 + end: 2025-12-26T00:00 + frequency: 30h + +runs: + + - forecaster: + inference_resources: + slurm_partition: normal-shared + checkpoint: https://service.meteoswiss.ch/mlstore#/experiments/602/runs/fd63e17043014af59170c7beca516b95 + label: stage_E_realch1 + steps: 0/120/6 + config: resources/inference/configs/sgm-multidataset-forecaster-global-ich1-oper.yaml + extra_requirements: + - git+https://github.com/ecmwf/anemoi-inference.git@0.10.0 + + - forecaster: + inference_resources: + slurm_partition: normal-shared + checkpoint: https://service.meteoswiss.ch/mlstore#/experiments/602/runs/c30490b6ba064e4db03b430f3a2595ad + label: stage_E_icon_1km_cutoff_edges_subgrid_horography + steps: 0/120/6 + config: resources/inference/configs/sgm-multidataset-forecaster-global-ich1-oper.yaml + extra_requirements: + - git+https://github.com/ecmwf/anemoi-inference.git@b9aaee5df86614cad9d8d08b76876a4be4e980db + + # - forecaster: + # inference_resources: + # slurm_partition: normal-shared + # checkpoint: https://service.meteoswiss.ch/mlstore#/experiments/602/runs/57684b20f64f414b937cce10e5ceeb68 + # label: stage_E_realch1_new + # steps: 0/120/6 + # config: resources/inference/configs/sgm-multidataset-forecaster-global-ich1-oper.yaml + # extra_requirements: + # - git+https://github.com/ecmwf/anemoi-inference.git@b9aaee5df86614cad9d8d08b76876a4be4e980db + + # - forecaster: + # inference_resources: + # slurm_partition: normal-shared + # checkpoint: https://service.meteoswiss.ch/mlstore#/experiments/602/runs/2265ae18b04e4470ab89314a85a822ae + # label: 
multipanel_plots:
  bias_overview:
    rows: 2
    cols: 2
    figsize: [12, 8]
    title: "BIAS vs lead time"
    panels:
      - {metric: BIAS, param: T_2M, season: all, title: "T_2M — all"}
      - {metric: BIAS, param: T_2M, season: JJA, title: "T_2M — JJA"}
      - {metric: BIAS, param: PMSL, season: all, title: "PMSL — all"}
      - {metric: BIAS, param: PMSL, season: JJA, title: "PMSL — JJA"}
  rmse_overview:
    rows: 2
    cols: 2
    figsize: [12, 8]
    title: "RMSE vs lead time"
    panels:
      # init_hour selects one initialisation hour. -999 is the sentinel for the
      # aggregate over all init hours, so the "00 UTC" panels must use 0.
      - {metric: RMSE, param: T_2M, init_hour: 0, title: "T_2M — 00 UTC"}
      - {metric: RMSE, param: T_2M, init_hour: 12, title: "T_2M — 12 UTC"}
      - {metric: RMSE, param: PMSL, init_hour: 0, title: "PMSL — 00 UTC"}
      - {metric: RMSE, param: PMSL, init_hour: 12, title: "PMSL — 12 UTC"}
b/resources/report/dashboard/script.js @@ -62,6 +62,12 @@ document.getElementById("param-select").addEventListener("change", updateChart); data = JSON.parse(document.getElementById("verif-data").textContent) header = document.getElementById("header-text").textContent.trim() +// Pin the source -> color mapping to the full, alphabetically-sorted source +// list so it stays bijective even when sources are toggled in the UI. Must +// match src/plotting/source_colors.py to keep the dashboard and the static +// matplotlib figures consistent. +const allSources = [...new Set(data.map(d => d.source))].sort(); + // Define base spec var spec = { "data": { "values": data }, @@ -106,6 +112,7 @@ var spec = { "color": { "field": "source", "type": "nominal", + "scale": { "scheme": "tableau10", "domain": allSources }, "legend": { "orient": "top", "title": "Data Source", "offset": 0, "padding": 10 } }, "shape": { diff --git a/src/evalml/config.py b/src/evalml/config.py index f1343d07..b65365a1 100644 --- a/src/evalml/config.py +++ b/src/evalml/config.py @@ -1,7 +1,7 @@ from pathlib import Path from typing import Dict, List, Any, ClassVar, FrozenSet -from pydantic import BaseModel, Field, RootModel, field_validator +from pydantic import BaseModel, Field, RootModel, field_validator, model_validator PROJECT_ROOT = Path(__file__).parents[2] @@ -227,6 +227,66 @@ class Stratification(BaseModel): ) +class MultipanelPanelSpec(BaseModel): + """One panel inside a multi-panel metric-vs-lead-time figure.""" + + metric: str = Field(..., description="Metric name (e.g. 'rmse').") + param: str = Field(..., description="Parameter name (e.g. 'T_2M').") + region: str = Field( + "all", + description="Region to subset to. 'all' uses the unstratified aggregate.", + ) + season: str = Field( + "all", + description="Season to subset to. 'all' uses the unstratified aggregate.", + ) + init_hour: int = Field( + -999, + description="Init hour to subset to. 
class MultipanelPlotSpec(BaseModel):
    """Layout for a single multi-panel metric-vs-lead-time figure."""

    # NOTE: the class docstring and the Field descriptions are emitted into the
    # generated JSON schema, so they are intentionally left untouched here.

    # Grid shape; both dimensions must be at least 1.
    rows: int = Field(..., ge=1, description="Number of subplot rows.")
    cols: int = Field(..., ge=1, description="Number of subplot columns.")
    # Exactly two numbers when given; checked for positivity in _check_figsize.
    figsize: List[float] | None = Field(
        None,
        description="Optional [width, height] in inches. Defaults to (4.5*cols, 3.5*rows).",
        min_length=2,
        max_length=2,
    )
    title: str | None = Field(None, description="Optional figure-level title.")
    # One entry per grid cell, laid out row by row.
    panels: List[MultipanelPanelSpec] = Field(
        ...,
        description="Per-panel specs in row-major order. Length must equal rows*cols.",
    )

    # Reject misspelled or unknown keys instead of silently ignoring them.
    model_config = {"extra": "forbid"}

    @model_validator(mode="after")
    def _check_panel_count(self) -> "MultipanelPlotSpec":
        """Ensure the number of panels fills the rows x cols grid exactly.

        Raises:
            ValueError: if ``len(panels) != rows * cols``.
        """
        expected = self.rows * self.cols
        if len(self.panels) != expected:
            raise ValueError(
                f"panels has length {len(self.panels)}, expected rows*cols = {expected}"
            )
        return self

    @model_validator(mode="after")
    def _check_figsize(self) -> "MultipanelPlotSpec":
        """Reject figure sizes matplotlib cannot render.

        Validating here fails at config-load time rather than inside the
        plotting job, after the verification data has already been loaded.

        Raises:
            ValueError: if any figsize entry is not a positive finite number.
        """
        if self.figsize is not None:
            # `0 < v < inf` is False for NaN, zero, negatives and infinities.
            if not all(0 < v < float("inf") for v in self.figsize):
                raise ValueError(
                    f"figsize entries must be positive and finite, got {self.figsize}"
                )
        return self
def plot_panel(
    ax: Axes,
    sub_df: pd.DataFrame,
    *,
    metric: str,
    param: str | None = None,
    title: str | None = None,
    panel_label: str | None = None,
    xlabel: str | None = "Lead Time [h]",
    ylabel: str | None = None,
    show_legend: bool = True,
    color_map: dict[str, str] | None = None,
) -> None:
    """Plot one metric-vs-lead-time panel onto `ax`.

    `sub_df` must already be filtered to a single (metric, param, region, season,
    init_hour) combo and contain at least the columns: source, lead_time, value.
    One line per source is drawn.

    Args:
        ax: Axes to draw on.
        sub_df: Long-form rows for this panel. May be empty, in which case no
            lines are drawn but the title and axis labels are still applied.
        metric: Metric name, used to build the default y-axis label.
        param: Parameter name, used to look up units for the default label.
        title: Centred panel title. Skipped when None or empty.
        panel_label: Optional label (e.g. "a)") rendered bold and left-aligned
            at the same height as the centred title.
        xlabel: X-axis label. None or "" leaves the axis unlabelled.
        ylabel: Y-axis label. None builds "<metric label> [<units>]" (units are
            omitted when unknown or when `param` is None); pass "" to suppress
            the label entirely.
        show_legend: Whether each plotted line adds a legend to `ax`.
        color_map: Optional ``{source: color}`` mapping. Sources missing from
            the map fall back to matplotlib's default color cycle. Use
            ``plotting.source_colors.source_color_map`` to build a map that
            matches the dashboard.
    """
    if ylabel is None:
        ylabel = _default_ylabel(metric, param)
    colors = color_map or {}
    for source, df in sub_df.groupby("source"):
        df.plot(
            x="lead_time",
            y="value",
            kind="line",
            marker="o",
            label=source,
            color=colors.get(source),
            ax=ax,
            legend=show_legend,
        )
    # Decorate the axes directly instead of via DataFrame.plot: when sub_df is
    # empty the loop above never runs, and the panel would otherwise be left
    # without title or labels.
    if title:
        ax.set_title(title)
    ax.set_xlabel(xlabel or "")
    ax.set_ylabel(ylabel or "")
    if panel_label:
        ax.set_title(panel_label, loc="left", fontweight="bold")
# Units in which each verification parameter is stored; metrics that keep the
# parameter's dimension report in these units.
PARAM_UNITS: dict[str, str] = {
    "T_2M": "K",
    "TD_2M": "K",
    "PMSL": "Pa",
    "PS": "Pa",
    "TOT_PREC": "mm",
    "U_10M": "m/s",
    "V_10M": "m/s",
    "SP_10M": "m/s",
}

# Metrics treated as dimensionless regardless of the parameter.
UNITLESS_METRICS: set[str] = {"CORR", "R2"}


def metric_units(metric: str, param: str) -> str:
    """Look up the units for a (metric, param) pair.

    The metric name is matched case-insensitively against UNITLESS_METRICS;
    the parameter name is looked up in PARAM_UNITS as given.

    Returns:
        The unit string, or '' when the metric is unitless or the parameter
        has no entry in PARAM_UNITS.
    """
    is_unitless = metric.upper() in UNITLESS_METRICS
    return "" if is_unitless else PARAM_UNITS.get(param, "")
def subset_df(df: pd.DataFrame, **kwargs) -> pd.DataFrame:
    """Return the rows of `df` that satisfy every ``column=value`` constraint.

    Args:
        df: Frame to filter. Any index is supported.
        **kwargs: One constraint per column name. A list, tuple or set means
            "column value is one of these"; anything else is compared with
            ``==``.

    Returns:
        The matching rows of `df`, keeping their original index labels. With
        no constraints, all rows are returned.

    Raises:
        KeyError: if a constraint names a column that is not in `df`.
    """
    # Build the mask on df's own index. Boolean ops between Series align on
    # index labels, so a mask on a fresh RangeIndex silently misaligns (or
    # fails) for any frame that was filtered, sliced or re-indexed beforehand.
    mask = pd.Series(True, index=df.index, dtype=bool)
    for key, value in kwargs.items():
        if isinstance(value, (list, tuple, set)):
            mask &= df[key].isin(value)
        else:
            mask &= df[key] == value
    return df[mask]
def _spec(rows, cols, panel_count=None):
    """Build a raw multipanel spec dict for the validation tests.

    The grid is filled with identical BIAS/T_2M panels. `panel_count`
    overrides the number of panels so tests can create a deliberate mismatch
    with ``rows * cols``.
    """
    if panel_count is None:
        panel_count = rows * cols
    panels = []
    for _ in range(panel_count):
        panels.append({"metric": "BIAS", "param": "T_2M"})
    return {"rows": rows, "cols": cols, "panels": panels}
# Render one named layout from the `multipanel_plots` config section into a
# single PNG. The `plot_name` wildcard selects the layout; `experiment` only
# determines the output location.
rule verification_metrics_multipanel_plot:
    input:
        # Unnamed entry: not referenced by the shell command below, so it acts
        # purely as a dependency of the rule.
        "src/verification/__init__.py",
        script="workflow/scripts/verification_plot_metrics_multipanel.py",
        verif=list(EXPERIMENT_PARTICIPANTS.values()),
    output:
        OUT_ROOT / "results/{experiment}/multipanel/{plot_name}.png",
    params:
        # The layout is serialised to JSON and shell-quoted so it reaches the
        # script as one --spec_json argument regardless of its content.
        spec_json=lambda wc: shlex.quote(
            json.dumps(_multipanel_plots_cfg()[wc.plot_name])
        ),
    log:
        OUT_ROOT
        / "logs/verification_metrics_multipanel_plot/{experiment}-{plot_name}.log",
    resources:
        cpus_per_task=4,
        mem_mb=20_000,
        runtime="20m",
    shell:
        """
        uv run {input.script} {input.verif} \
            --spec_json {params.spec_json} \
            --output {output} > {log} 2>&1
        """
= ds.get_index("lead_time") - except Exception: - idx = pd.Index(ds["lead_time"].values) - if getattr(idx, "has_duplicates", False): - keep = ~idx.duplicated(keep="first") - ds = ds.isel(lead_time=keep) - return ds - - -def _select_best_sources(dfs: list[xr.Dataset]) -> list[xr.Dataset]: - """ - If the same 'source' exists in multiple datasets, keep it only from the dataset - that has the largest number of unique lead_time entries. Drop it from others. - """ - # Compute unique sources per dataset - src_sets = [set(d.source.values.tolist()) for d in dfs] - all_sources = set().union(*src_sets) - - # Decide best provider (dataset index) for each source - best = {} - for s in all_sources: - candidates = [] - for i, d in enumerate(dfs): - if s in d.source.values: - di = d.sel(source=s) - try: - n = pd.Index(di["lead_time"].values).unique().size - except Exception: - n = len(pd.unique(di["lead_time"].values)) - candidates.append((i, n)) - if candidates: - best_idx, _ = max(candidates, key=lambda t: t[1]) - best[s] = best_idx - - # Drop non-best occurrences - out = [] - for i, d in enumerate(dfs): - drop_src = [s for s, b in best.items() if b != i and s in d.source.values] - if drop_src: - d = d.drop_sel(source=drop_src) - out.append(d) - return out - - -def subset_df(df, **kwargs): - mask = pd.Series([True] * len(df)) - for key, value in kwargs.items(): - if isinstance(value, (list, tuple, set)): - mask &= df[key].isin(value) - else: - mask &= df[key] == value - return df[mask] - - def main(args: Namespace) -> None: """Main function to verify results from KENDA-1 data.""" - # remove duplicated but not identical values from analyses (rounding errors) - dfs = [xr.open_dataset(f) for f in args.verif_files] - # 1) Ensure each dataset has unique lead_time values - dfs = [_ensure_unique_lead_time(d) for d in dfs] - # 2) For sources present in multiple datasets, keep the one with most lead_times - dfs = _select_best_sources(dfs) - # 3) Concatenate by source; outer join to keep 
the union of lead_times - ds = xr.concat(dfs, dim="source", join="outer") - - # extract only non-spatial variables to pd.DataFrame - nonspatial_vars = [d for d in ds.data_vars if "spatial" not in d] - all_df = ( - ds[nonspatial_vars].to_array("stack").to_dataframe(name="value").reset_index() - ) - all_df[["param", "metric"]] = all_df["stack"].str.split(".", n=1, expand=True) - all_df.drop(columns=["stack"], inplace=True) - all_df["lead_time"] = all_df["lead_time"].dt.total_seconds() / 3600 + all_df = load_long_df(args.verif_files) + color_map = source_color_map(all_df["source"].unique()) metrics = all_df["metric"].unique() params = all_df["param"].unique() @@ -108,17 +37,14 @@ def main(args: Namespace) -> None: f"Processing region: {region}, metric: {metric}, param: {param}, season: {season}, init_hour: {init_hour}" ) - def _subset_df(df): - return subset_df( - df, - region=region, - metric=metric, - param=param, - season=season, - init_hour=init_hour, - ) - - sub_df = _subset_df(all_df).dropna() + sub_df = subset_df( + all_df, + region=region, + metric=metric, + param=param, + season=season, + init_hour=init_hour, + ).dropna() if sub_df.empty: continue @@ -126,19 +52,15 @@ def _subset_df(df): title = f"{metric} - {param} - {region}" title += f"- {season} - {init_hour}" if args.stratify else "" - for source, df in sub_df.groupby("source"): - df.plot( - x="lead_time", - y="value", - kind="line", - marker="o", - title=title, - xlabel="Lead Time [h]", - ylabel=decode_metric(metric), - label=source, - color="black" if "analysis" in source else None, - ax=ax, - ) + plot_panel( + ax, + sub_df, + metric=metric, + param=param, + title=title, + color_map=color_map, + ) + args.output_dir.mkdir(parents=True, exist_ok=True) fn = f"{metric}_{param}" fn += f"_{season}_{init_hour}.png" if args.stratify else ".png" @@ -153,8 +75,6 @@ def _subset_df(df): type=Path, nargs="+", help="Paths to verification files.", - # "--verif_files", type=Path, nargs="+", help="Paths to verification 
def _panel_label(idx: int) -> str:
    """Return the panel label for a 0-based index: 'a)', ..., 'z)', 'aa)', 'ab)', ...

    Labels follow spreadsheet-column order (bijective base 26), so the
    sequence continues with 'zz)', 'aaa)', ... for any non-negative index.

    Raises:
        ValueError: if `idx` is negative.
    """
    if idx < 0:
        raise ValueError(f"panel index must be non-negative, got {idx}")
    letters = string.ascii_lowercase
    chars = []
    n = idx + 1  # bijective numeration is 1-based
    while n > 0:
        n, rem = divmod(n - 1, len(letters))
        chars.append(letters[rem])
    return "".join(reversed(chars)) + ")"


LOG = logging.getLogger(__name__)
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)


def _load_spec(args: Namespace) -> dict:
    """Parse the layout spec from ``--spec_json`` if given, else from ``--spec_path``."""
    # Compare against None rather than truthiness: an empty --spec_json should
    # fail as invalid JSON instead of falling through to a missing spec_path.
    if args.spec_json is not None:
        return json.loads(args.spec_json)
    return json.loads(args.spec_path.read_text())


def main(args: Namespace) -> None:
    """Render the multi-panel figure described by the spec and save it as a PNG.

    Args:
        args: Parsed CLI arguments providing ``verif_files``, ``output`` and
            one of ``spec_json`` / ``spec_path``.

    Raises:
        ValueError: if the number of panels does not equal ``rows * cols``.
    """
    spec = _load_spec(args)
    rows = int(spec["rows"])
    cols = int(spec["cols"])
    panels = spec["panels"]
    if len(panels) != rows * cols:
        raise ValueError(
            f"panels has length {len(panels)}, expected rows*cols = {rows * cols}"
        )

    all_df = load_long_df(args.verif_files)
    # Colors are derived from the full source list so every panel agrees.
    color_map = source_color_map(all_df["source"].unique())

    figsize = tuple(spec.get("figsize") or (4.5 * cols, 3.5 * rows))
    fig, axes = plt.subplots(rows, cols, sharex=True, figsize=figsize, squeeze=False)

    legend_entries: dict[str, object] = {}
    left_ylabel = None  # y-label of the previously drawn panel
    for idx, panel in enumerate(panels):
        r, c = divmod(idx, cols)  # panels are listed in row-major order
        ax = axes[r][c]
        metric = panel["metric"]
        param = panel["param"]
        selectors = {
            "region": panel.get("region", "all"),
            "season": panel.get("season", "all"),
            "init_hour": panel.get("init_hour", -999),
        }
        sub = subset_df(all_df, metric=metric, param=param, **selectors).dropna()
        if sub.empty:
            LOG.warning(
                "No data for panel %d (metric=%s, param=%s, region=%s, season=%s, init_hour=%s)",
                idx,
                metric,
                param,
                selectors["region"],
                selectors["season"],
                selectors["init_hour"],
            )

        is_bottom = r == rows - 1
        is_left = c == 0

        # An explicit `title: null` in the spec means "use the default", just
        # like an absent key.
        title = panel.get("title")
        if title is None:
            title = f"{metric} - {param}"

        units = metric_units(metric, param)
        metric_label = decode_metric(metric)
        full_ylabel = f"{metric_label} [{units}]" if units else metric_label
        # Show the y-label in the left column, and elsewhere only when it
        # differs from the panel to its left, so units are never lost.
        # Suppression needs "" rather than None: plot_panel treats
        # ylabel=None as a request to build the default label.
        show_ylabel = is_left or full_ylabel != left_ylabel
        left_ylabel = full_ylabel

        plot_panel(
            ax,
            sub,
            metric=metric,
            param=param,
            title=title,
            panel_label=_panel_label(idx),
            xlabel="Lead Time [h]" if is_bottom else None,
            ylabel=full_ylabel if show_ylabel else "",
            show_legend=False,
            color_map=color_map,
        )
        if panel.get("ylim"):
            ax.set_ylim(panel["ylim"])

        # Collect one handle per source across panels for the shared legend.
        handles, labels = ax.get_legend_handles_labels()
        for handle, label in zip(handles, labels):
            legend_entries.setdefault(label, handle)

    if spec.get("title"):
        fig.suptitle(spec["title"])

    if legend_entries:
        fig.legend(
            list(legend_entries.values()),
            list(legend_entries.keys()),
            loc="lower center",
            ncol=min(len(legend_entries), 4),
            bbox_to_anchor=(0.5, 0.0),
        )

    # Leave room at the top for the suptitle and at the bottom for the legend.
    top = 0.92 if spec.get("title") else 0.96
    fig.subplots_adjust(
        left=0.09,
        right=0.97,
        top=top,
        bottom=0.13,
        hspace=0.45,
        wspace=0.28,
    )

    args.output.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(args.output, dpi=150)
    plt.close(fig)
'rmse').", + "title": "Metric", + "type": "string" + }, + "param": { + "description": "Parameter name (e.g. 'T_2M').", + "title": "Param", + "type": "string" + }, + "region": { + "default": "all", + "description": "Region to subset to. 'all' uses the unstratified aggregate.", + "title": "Region", + "type": "string" + }, + "season": { + "default": "all", + "description": "Season to subset to. 'all' uses the unstratified aggregate.", + "title": "Season", + "type": "string" + }, + "init_hour": { + "default": -999, + "description": "Init hour to subset to. -999 (sentinel) uses the unstratified aggregate.", + "title": "Init Hour", + "type": "integer" + }, + "title": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Panel title. Defaults to ' - '.", + "title": "Title" + }, + "ylim": { + "anyOf": [ + { + "items": { + "type": "number" + }, + "maxItems": 2, + "minItems": 2, + "type": "array" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Optional [ymin, ymax] for this panel's y-axis.", + "title": "Ylim" + } + }, + "required": [ + "metric", + "param" + ], + "title": "MultipanelPanelSpec", + "type": "object" + }, + "MultipanelPlotSpec": { + "additionalProperties": false, + "description": "Layout for a single multi-panel metric-vs-lead-time figure.", + "properties": { + "rows": { + "description": "Number of subplot rows.", + "minimum": 1, + "title": "Rows", + "type": "integer" + }, + "cols": { + "description": "Number of subplot columns.", + "minimum": 1, + "title": "Cols", + "type": "integer" + }, + "figsize": { + "anyOf": [ + { + "items": { + "type": "number" + }, + "maxItems": 2, + "minItems": 2, + "type": "array" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Optional [width, height] in inches. 
Defaults to (4.5*cols, 3.5*rows).", + "title": "Figsize" + }, + "title": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Optional figure-level title.", + "title": "Title" + }, + "panels": { + "description": "Per-panel specs in row-major order. Length must equal rows*cols.", + "items": { + "$ref": "#/$defs/MultipanelPanelSpec" + }, + "title": "Panels", + "type": "array" + } + }, + "required": [ + "rows", + "cols", + "panels" + ], + "title": "MultipanelPlotSpec", + "type": "object" + }, "Profile": { "description": "Workflow execution profile.", "properties": { @@ -624,6 +759,14 @@ }, "profile": { "$ref": "#/$defs/Profile" + }, + "multipanel_plots": { + "additionalProperties": { + "$ref": "#/$defs/MultipanelPlotSpec" + }, + "description": "Optional named multi-panel metric-vs-lead-time figures. Each entry produces one PNG under results//multipanel/.png when the verification_metrics_multipanel_plot_all target is built.", + "title": "Multipanel Plots", + "type": "object" } }, "required": [