diff --git a/config/forecasters-co1e.yaml b/config/forecasters-co1e.yaml index 777492ca..03fcfcfe 100644 --- a/config/forecasters-co1e.yaml +++ b/config/forecasters-co1e.yaml @@ -72,3 +72,36 @@ profile: jobs: 50 batch_rules: plot_forecast_frame: 32 + +score_maps: + params: + - T_2M + # - TD_2M + # - U_10M + # - V_10M + # - SP_10M + # - PS + # - PMSL + # - TOT_PREC + leadtimes: [6, 24] + # Or, to compute every available leadtime from runs+baselines: + # leadtimes: "all" + scores: + - BIAS + # - RMSE + # - MAE + regions: + - switzerland + # - centraleurope + seasons: + - all + # - DJF + # - MAM + # - JJA + # - SON + init_hours: + - all + # - "00" + # - "06" + # - "12" + # - "18" diff --git a/config/forecasters-co2-disentangled.yaml b/config/forecasters-co2-disentangled.yaml index 77bdba4b..91489c00 100644 --- a/config/forecasters-co2-disentangled.yaml +++ b/config/forecasters-co2-disentangled.yaml @@ -91,3 +91,36 @@ profile: runtime: "1h" gpus: 0 jobs: 50 + +score_maps: + params: + - T_2M + # - TD_2M + # - U_10M + # - V_10M + # - SP_10M + # - PS + # - PMSL + # - TOT_PREC + leadtimes: [6, 24] + # Or, to compute every available leadtime from runs+baselines: + # leadtimes: "all" + scores: + - BIAS + # - RMSE + # - MAE + regions: + - switzerland + # - centraleurope + seasons: + - all + # - DJF + # - MAM + # - JJA + # - SON + init_hours: + - all + # - "00" + # - "06" + # - "12" + # - "18" diff --git a/config/forecasters-co2.yaml b/config/forecasters-co2.yaml index a51c1203..ccf67396 100644 --- a/config/forecasters-co2.yaml +++ b/config/forecasters-co2.yaml @@ -68,3 +68,36 @@ profile: jobs: 50 batch_rules: plot_forecast_frame: 32 + +score_maps: + params: + - T_2M + # - TD_2M + # - U_10M + # - V_10M + # - SP_10M + # - PS + # - PMSL + # - TOT_PREC + leadtimes: [6, 24] + # Or, to compute every available leadtime from runs+baselines: + # leadtimes: "all" + scores: + - BIAS + # - RMSE + # - MAE + regions: + - switzerland + # - centraleurope + seasons: + - all + # - DJF + # - MAM + # - JJA + # - SON + init_hours: + - all + # - "00" + # - "06" + # - "12" + # - "18" diff --git a/config/forecasters-ich1-oper-fixed.yaml b/config/forecasters-ich1-oper-fixed.yaml index 9c5cb970..09261399 100644 --- a/config/forecasters-ich1-oper-fixed.yaml +++ b/config/forecasters-ich1-oper-fixed.yaml @@ -84,3 +84,36 @@ profile: jobs: 50 batch_rules: plot_forecast_frame: 32 + +score_maps: + params: + - T_2M + # - TD_2M + # - U_10M + # - V_10M + # - SP_10M + # - PS + # - PMSL + # - TOT_PREC + leadtimes: [6, 24] + # Or, to compute every available leadtime from runs+baselines: + # leadtimes: "all" + scores: + - BIAS + # - RMSE + # - MAE + regions: + - switzerland + # - centraleurope + seasons: + - all + # - DJF + # - MAM + # - JJA + # - SON + init_hours: + - all + # - "00" + # - "06" + # - "12" + # - "18" diff --git a/config/forecasters-ich1-oper.yaml b/config/forecasters-ich1-oper.yaml index cac91861..1f53a534 100644 --- a/config/forecasters-ich1-oper.yaml +++ b/config/forecasters-ich1-oper.yaml @@ -79,3 +79,36 @@ profile: jobs: 50 batch_rules: plot_forecast_frame: 32 + +score_maps: + params: + - T_2M + # - TD_2M + # - U_10M + # - V_10M + # - SP_10M + # - PS + # - PMSL + # - TOT_PREC + leadtimes: [6, 24] + # Or, to compute every available leadtime from runs+baselines: + # leadtimes: "all" + scores: + - BIAS + # - RMSE + # - MAE + regions: + - switzerland + # - centraleurope + seasons: + - all + # - DJF + # - MAM + # - JJA + # - SON + init_hours: + - all + # - "00" + # - "06" + # - "12" + # - "18" diff --git a/config/forecasters-ich1.yaml b/config/forecasters-ich1.yaml index 3f8ee7db..09fbe0af 100644 --- a/config/forecasters-ich1.yaml +++ b/config/forecasters-ich1.yaml @@ -91,3 +91,36 @@ profile: jobs: 50 batch_rules: plot_forecast_frame: 32 + +score_maps: + params: + - T_2M + # - TD_2M + # - U_10M + # - V_10M + # - SP_10M + # - PS + # - PMSL + # - TOT_PREC + leadtimes: [6, 24] + # Or, to compute every available leadtime from runs+baselines: + # leadtimes: "all" + scores: + - BIAS + # - RMSE + # - MAE + regions: + - switzerland + # - centraleurope + seasons: + - all + # - DJF + # - MAM + # - JJA + # - SON + init_hours: + - all + # - "00" + # - "06" + # - "12" + # - "18" diff --git a/config/interpolators-co2.yaml b/config/interpolators-co2.yaml index 87ea6fa3..64c3f30f 100644 --- a/config/interpolators-co2.yaml +++ b/config/interpolators-co2.yaml @@ -99,3 +99,36 @@ profile: jobs: 50 batch_rules: plot_forecast_frame: 32 + +score_maps: + params: + - T_2M + # - TD_2M + # - U_10M + # - V_10M + # - SP_10M + # - PS + # - PMSL + # - TOT_PREC + leadtimes: [6, 24] + # Or, to compute every available leadtime from runs+baselines: + # leadtimes: "all" + scores: + - BIAS + # - RMSE + # - MAE + regions: + - switzerland + # - centraleurope + seasons: + - all + # - DJF + # - MAM + # - JJA + # - SON + init_hours: + - all + # - "00" + # - "06" + # - "12" + # - "18" diff --git a/config/interpolators-ich1.yaml b/config/interpolators-ich1.yaml index a56296eb..833ed22d 100644 --- a/config/interpolators-ich1.yaml +++ b/config/interpolators-ich1.yaml @@ -88,3 +88,36 @@ profile: jobs: 50 batch_rules: plot_forecast_frame: 32 + +score_maps: + params: + - T_2M + # - TD_2M + # - U_10M + # - V_10M + # - SP_10M + # - PS + # - PMSL + # - TOT_PREC + leadtimes: [6, 24] + # Or, to compute every available leadtime from runs+baselines: + # leadtimes: "all" + scores: + - BIAS + # - RMSE + # - MAE + regions: + - switzerland + # - centraleurope + seasons: + - all + # - DJF + # - MAM + # - JJA + # - SON + init_hours: + - all + # - "00" + # - "06" + # - "12" + # - "18" diff --git a/src/data_input/__init__.py b/src/data_input/__init__.py index 3c990b70..71e49bf6 100644 --- a/src/data_input/__init__.py +++ b/src/data_input/__init__.py @@ -111,9 +111,28 @@ def load_fct_data_from_grib( root: Path, reftime: datetime, steps: list[int], params: list[str] ) -> xr.Dataset: """Load forecast data from GRIB files for a specific valid time.""" + # TODO: this function carries a large per-call setup cost that is + # independent of data volume (likely eccodes/FileDataSource init or GRIB + # index build, not decoding). It dominates runtime in any rule that calls + # it inside a per-init-time loop (e.g. verification_score_maps) and also + # adds noticeable overhead to verif_metrics and the plot rules. files = sorted(root.glob(f"{reftime:%Y%m%d%H%M}*.grib")) fds = data_source.FileDataSource(datafiles=files) - ds = grib_decoder.load(fds, {"param": params, "step": steps}) + # For TOT_PREC (cumulative-from-start) we need step 0 to disaggregate to a + # 0->step period accumulation even when the caller asks for a single step. + # anemoi-inference may omit step 0 from the GRIB; tolerate that and + # synthesize lead_time=0, TOT_PREC=0 below (cumulative-from-start has + # nothing accumulated at the IC by definition). + needs_step_zero = "TOT_PREC" in params and 0 not in steps + fetch_steps = [0, *steps] if needs_step_zero else list(steps) + ds = grib_decoder.load(fds, {"param": params, "step": fetch_steps}) + # grib_decoder.load may silently drop steps that aren't on disk + # (anemoi-inference often omits step 0 even with cumulative-from-start + # accumulation). Detect that here so the TOT_PREC block can synthesize + # lead_time=0, TOT_PREC=0 below. + zero_lt = np.timedelta64(0, "h") + loaded_lead_times = next(iter(ds.values())).lead_time.values + step_zero_synthetic = needs_step_zero and zero_lt not in loaded_lead_times for var, da in ds.items(): if "z" in da.dims and da.sizes["z"] == 1: ds[var] = da.squeeze("z", drop=True) @@ -121,11 +140,17 @@ def load_fct_data_from_grib( ds[var] = da.rename({"z": da.attrs["vcoord_type"]}) ds = xr.merge([ds[p].rename(p) for p in ds], compat="no_conflicts") lead_times = np.array(steps, dtype="timedelta64[h]") - # Restrict to the requested lead times so that the TOT_PREC disaggregation - # below operates on the correct step interval even if the GRIB contains - # extra (e.g. hourly) steps beyond those requested — e.g. when consuming - # output from an interpolator emulator or a baseline with sub-step output. - ds = ds.sel(lead_time=lead_times) + fetch_lead_times = np.array(fetch_steps, dtype="timedelta64[h]") + # Restrict to the lead times we'll work with (fetch_lead_times = requested + # steps + step 0 if needed). This drops any extra (e.g. hourly) steps the + # GRIB may contain beyond what we asked for — e.g. when consuming output + # from an interpolator emulator or a baseline with sub-step output. + if step_zero_synthetic: + # Step 0 is missing from the GRIB; reindex inserts NaN at lead 0, + # which the xr.where below replaces with 0. + ds = ds.sel(lead_time=lead_times).reindex(lead_time=fetch_lead_times) + else: + ds = ds.sel(lead_time=fetch_lead_times) if "TOT_PREC" in ds.data_vars: ## Disaggregate TOT_PREC from cumulative-from-start (expected when the ## accumulate_from_start_of_forecast post-processor is enabled in @@ -166,6 +191,8 @@ def load_fct_data_from_grib( ## small float-noise negatives to zero (anything below -0.1 mm has ## already been caught by the check above). ds = ds.assign(TOT_PREC=diff.clip(min=0.0).reindex(lead_time=lead_times)) + # Drop the auxiliary step 0 from any non-TOT_PREC variables. + ds = ds.sel(lead_time=lead_times) # make sure time coordinate is available, and valid_time is not if "valid_time" in ds.coords: ds = ds.rename({"valid_time": "time"}) @@ -192,10 +219,18 @@ def load_baseline_from_zarr( {"forecast_reference_time": "ref_time", "step": "lead_time"} ).sortby("lead_time") lead_times = np.array(steps, dtype="timedelta64[h]") + # For TOT_PREC (cumulative-from-start) we need step 0 in the slice so that + # .diff() yields a 0->step period accumulation even when the caller + # requested a single step. The extra step is dropped at the final reindex. + zero_lt = np.timedelta64(0, "h") + if "TOT_PREC" in params and zero_lt not in lead_times: + fetch_lead_times = np.concatenate([[zero_lt], lead_times]) + else: + fetch_lead_times = lead_times # Restrict to the requested lead times up-front so that the TOT_PREC # disaggregation below operates on the correct step interval, and so that # all other variables avoid loading unused hourly steps from the zarr. - baseline = baseline[params].sel(ref_time=reftime, lead_time=lead_times) + baseline = baseline[params].sel(ref_time=reftime, lead_time=fetch_lead_times) if "TOT_PREC" in baseline.data_vars: if baseline.TOT_PREC.units == "m": baseline = baseline.assign(TOT_PREC=lambda x: x.TOT_PREC * 1000) @@ -222,6 +257,8 @@ def load_baseline_from_zarr( baseline = baseline.assign( TOT_PREC=diff.clip(min=0.0).reindex(lead_time=lead_times) ) + # Drop the auxiliary step 0 from any non-TOT_PREC variables. + baseline = baseline.sel(lead_time=lead_times) baseline = baseline.assign_coords(time=baseline.ref_time + baseline.lead_time) if "latitude" in baseline.coords and "longitude" in baseline: baseline = baseline.rename({"latitude": "lat", "longitude": "lon"}) diff --git a/src/evalml/cli.py b/src/evalml/cli.py index 51a9ed45..fcc295ff 100644 --- a/src/evalml/cli.py +++ b/src/evalml/cli.py @@ -129,6 +129,7 @@ def execute_workflow( dag: bool = False, rulegraph: bool = False, extra_smk_args: tuple[str, ...] = (), + extra_targets: list[str] = [], ): if dag or rulegraph: generate_graph( @@ -146,7 +147,7 @@ def execute_workflow( if report and not dry_run: command += ["--report-after-run", "--report", str(report)] - command.append(target) + command += [target] + extra_targets command += list(extra_smk_args) if not verbose: command += ["--quiet", "rules"] # reduce verobosity of snakemake output @@ -163,9 +164,24 @@ def cli(): @click.argument( "configfile", type=click.Path(exists=True, dir_okay=False, path_type=Path) ) +@click.option( + "--maps", + is_flag=True, + default=False, + help="Also produce score maps (computationally intensive).", +) @workflow_options def experiment( - configfile, cores, verbose, dry_run, unlock, report, dag, rulegraph, extra_smk_args + configfile, + maps, + cores, + verbose, + dry_run, + unlock, + report, + dag, + rulegraph, + extra_smk_args, ): execute_workflow( configfile, @@ -178,6 +194,7 @@ def experiment( dag, rulegraph, extra_smk_args, + extra_targets=["score_maps_all"] if maps else [], ) diff --git a/src/evalml/config.py b/src/evalml/config.py index f1343d07..73925af8 100644 --- a/src/evalml/config.py +++ b/src/evalml/config.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Dict, List, Any, ClassVar, FrozenSet +from typing import Dict, List, Any, ClassVar, FrozenSet, Literal from pydantic import BaseModel, Field, RootModel, field_validator @@ -208,6 +208,45 @@ class BaselineItem(BaseModel): baseline: BaselineConfig +class ScoreMapsConfig(BaseModel): + """Parameters controlling which score map plots are produced.""" + + params: List[str] = Field( + default=["T_2M", "TD_2M", "U_10M", "V_10M", "SP_10M", "PS", "PMSL", "TOT_PREC"], + description=( + "List of parameters to plot. Supported values: T_2M, TD_2M, U_10M, V_10M, " + "PS, PMSL, TOT_PREC (native), and SP_10M (derived wind speed from U_10M/V_10M)." + ), + ) + leadtimes: List[int] | Literal["all"] = Field( + default=list(range(6, 121, 6)), + description=( + "List of lead times (hours) to plot, or the literal string 'all' " + "to expand to the union of step lists from all configured runs " + "and baselines." + ), + ) + scores: List[str] = Field( + default=["BIAS", "RMSE", "MAE"], + description="List of verification scores to plot.", + ) + regions: List[str] = Field( + default=["switzerland", "centraleurope"], + description="List of regions to plot.", + ) + seasons: List[str] = Field( + default=["all", "DJF", "MAM", "JJA", "SON"], + description="List of seasons to plot.", + ) + init_hours: List[str] = Field( + default=["all"], + description=( + "List of initialization hours to plot. Use 'all' for the unstratified " + "view, or zero-padded hour strings like '00', '06', '12', '18'." + ), + ) + + class Locations(BaseModel): """Locations of data and services used in the workflow.""" @@ -351,6 +390,10 @@ def validate_threshold_operators( dashboard: Dashboard locations: Locations profile: Profile + score_maps: ScoreMapsConfig = Field( + default_factory=ScoreMapsConfig, + description="Parameters for score map plots (used with --maps flag).", + ) model_config = { "extra": "forbid", # fail on misspelled keys diff --git a/src/plotting/__init__.py b/src/plotting/__init__.py index ce5e4e63..08c200ca 100644 --- a/src/plotting/__init__.py +++ b/src/plotting/__init__.py @@ -30,11 +30,11 @@ "projection": _PROJECTIONS["orthographic"], }, "centraleurope": { - "extent": [-2.6, 19.5, 40.2, 52.3], + "extent": [-1.5, 18, 41.5, 51], "projection": _PROJECTIONS["orthographic"], }, "switzerland": { - "extent": [0, 17.5, 40.5, 53.0], + "extent": [5.5, 11.0, 45.5, 48.0], "projection": _PROJECTIONS["orthographic"], }, } diff --git a/src/plotting/colormap_defaults.py b/src/plotting/colormap_defaults.py index 6794b82a..568646da 100644 --- a/src/plotting/colormap_defaults.py +++ b/src/plotting/colormap_defaults.py @@ -4,6 +4,7 @@ from matplotlib import pyplot as plt import warnings from .colormap_loader import load_ncl_colormap +import numpy as np def _fallback(): @@ -110,6 +111,135 @@ def _fallback(): 120.0, ], }, + # Sequential Reds for RMSE and MAE: error is non-negative, larger ⇒ darker. + # Levels start at 0 so saturation maps directly to error magnitude; + # discrete levels make absolute values readable from the colour bar. + # RMSE: + "U_10M.RMSE.map": { + "cmap": plt.get_cmap("Reds", 6), + "levels": [0, 0.5, 1, 1.5, 2, 2.5, 3], + } + | {"units": "m/s"}, + "V_10M.RMSE.map": { + "cmap": plt.get_cmap("Reds", 6), + "levels": [0, 0.5, 1, 1.5, 2, 2.5, 3], + } + | {"units": "m/s"}, + "SP_10M.RMSE.map": { + "cmap": plt.get_cmap("Reds", 6), + "levels": [0, 0.5, 1, 1.5, 2, 2.5, 3], + } + | {"units": "m/s"}, + "TD_2M.RMSE.map": { + "cmap": plt.get_cmap("Reds", 6), + "levels": [0, 0.5, 1, 1.5, 2, 2.5, 3], + } + | {"units": "°C"}, + "T_2M.RMSE.map": { + "cmap": plt.get_cmap("Reds", 6), + "levels": [0, 0.5, 1, 1.5, 2, 2.5, 3], + } + | {"units": "°C"}, + "PMSL.RMSE.map": { + "cmap": plt.get_cmap("Reds", 7), + "levels": [0, 50, 100, 150, 200, 250, 300, 350], + } + | {"units": "Pa"}, + "PS.RMSE.map": { + "cmap": plt.get_cmap("Reds", 7), + "levels": [0, 50, 100, 150, 200, 250, 300, 350], + } + | {"units": "Pa"}, + "TOT_PREC.RMSE.map": { + "cmap": plt.get_cmap("Reds", 6), + "levels": [0, 1, 1.5, 2, 3, 4], + } + | {"units": "mm"}, + # MAE: + "U_10M.MAE.map": { + "cmap": plt.get_cmap("Reds", 6), + "levels": [0, 0.5, 1, 1.5, 2, 2.5, 3], + } + | {"units": "m/s"}, + "V_10M.MAE.map": { + "cmap": plt.get_cmap("Reds", 6), + "levels": [0, 0.5, 1, 1.5, 2, 2.5, 3], + } + | {"units": "m/s"}, + "SP_10M.MAE.map": { + "cmap": plt.get_cmap("Reds", 6), + "levels": [0, 0.5, 1, 1.5, 2, 2.5, 3], + } + | {"units": "m/s"}, + "TD_2M.MAE.map": { + "cmap": plt.get_cmap("Reds", 6), + "levels": [0, 0.5, 1, 1.5, 2, 2.5, 3], + } + | {"units": "°C"}, + "T_2M.MAE.map": { + "cmap": plt.get_cmap("Reds", 6), + "levels": [0, 0.5, 1, 1.5, 2, 2.5, 3], + } + | {"units": "°C"}, + "PMSL.MAE.map": { + "cmap": plt.get_cmap("Reds", 7), + "levels": [0, 50, 100, 150, 200, 250, 300, 350], + } + | {"units": "Pa"}, + "PS.MAE.map": { + "cmap": plt.get_cmap("Reds", 7), + "levels": [0, 50, 100, 150, 200, 250, 300, 350], + } + | {"units": "Pa"}, + "TOT_PREC.MAE.map": { + "cmap": plt.get_cmap("Reds", 6), + "levels": [0, 1, 1.5, 2, 3, 4], + } + | {"units": "mm"}, + # the levels for precip are a bit on the bright side, but still worth keeping consistent with RMSE. + # Bias: + # diverging colour scheme for the Bias to reflect the nature of the data (can be positive or negative, symmetric). + # Red-Blue colour scheme for all variables except precipitation, where a Brown-Green scheme is more suggestive. + "U_10M.BIAS.map": { + "cmap": plt.get_cmap("RdBu_r", 9), + "levels": np.arange(start=-2.25, stop=2.26, step=0.5), + } + | {"units": "m/s"}, + "V_10M.BIAS.map": { + "cmap": plt.get_cmap("RdBu_r", 9), + "levels": np.arange(start=-2.25, stop=2.26, step=0.5), + } + | {"units": "m/s"}, + "SP_10M.BIAS.map": { + "cmap": plt.get_cmap("RdBu_r", 9), + "levels": np.arange(start=-2.25, stop=2.26, step=0.5), + } + | {"units": "m/s"}, + "TD_2M.BIAS.map": { + "cmap": plt.get_cmap("RdBu_r", 11), + "levels": np.arange(start=-2.75, stop=2.76, step=0.5), + } + | {"units": "°C"}, + "T_2M.BIAS.map": { + "cmap": plt.get_cmap("RdBu_r", 11), + "levels": np.arange(start=-2.75, stop=2.76, step=0.5), + } + | {"units": "°C"}, + "PMSL.BIAS.map": { + "cmap": plt.get_cmap("RdBu_r", 11), + "levels": np.arange(start=-110, stop=111, step=20), + } + | {"units": "Pa"}, + "PS.BIAS.map": { + "cmap": plt.get_cmap("RdBu_r", 11), + "levels": np.arange(start=-110, stop=111, step=20), + } + | {"units": "Pa"}, + "TOT_PREC.BIAS.map": { + "cmap": plt.get_cmap("BrBG", 9), + "levels": [-1, -0.5, -0.25, -0.1, 0.1, 0.25, 0.5, 1], + } + | {"units": "mm"}, } CMAP_DEFAULTS = defaultdict(_fallback, _CMAP_DEFAULTS) diff --git a/workflow/Snakefile b/workflow/Snakefile index a2586cb4..d8001d2a 100644 --- a/workflow/Snakefile +++ b/workflow/Snakefile @@ -135,6 +135,33 @@ rule experiment_all: ), +rule score_maps_all: + """Target rule for score maps (opt-in via evalml experiment --maps).""" + input: + expand( + rules.plot_score_maps.output, + run_id=collect_all_candidates(), + leadtime=resolve_leadtimes(config["score_maps"]["leadtimes"]), + score=config["score_maps"]["scores"], + param=config["score_maps"]["params"], + region=config["score_maps"]["regions"], + season=config["score_maps"]["seasons"], + init_hour=config["score_maps"]["init_hours"], + experiment=EXPERIMENT_NAME, + ), + expand( + rules.plot_score_maps_baseline.output, + baseline_id=list(BASELINE_CONFIGS), + leadtime=resolve_leadtimes(config["score_maps"]["leadtimes"]), + score=config["score_maps"]["scores"], + param=config["score_maps"]["params"], + region=config["score_maps"]["regions"], + season=config["score_maps"]["seasons"], + init_hour=config["score_maps"]["init_hours"], + experiment=EXPERIMENT_NAME, + ), + + rule showcase_all: """Target rule for showcase workflow.""" input: diff --git a/workflow/rules/common.smk b/workflow/rules/common.smk index e858c80f..b44bde3d 100644 --- a/workflow/rules/common.smk +++ b/workflow/rules/common.smk @@ -299,3 +299,20 @@ RUN_CONFIGS = collect_all_runs() ENV_CONFIGS = collect_all_envs() BASELINE_CONFIGS = collect_all_baselines() EXPERIMENT_PARTICIPANTS = collect_experiment_participants() + + +def resolve_leadtimes(spec): + """Resolve a lead-time specification from config. + + Accepts: + - a list of ints — returned verbatim. + - the literal string "all" — expanded to the union of step lists + from all configured runs and baselines. + """ + if spec != "all": + return spec + all_steps = set() + for cfg in (*RUN_CONFIGS.values(), *BASELINE_CONFIGS.values()): + start, end, step = map(int, cfg["steps"].split("/")) + all_steps.update(range(start, end + 1, step)) + return sorted(all_steps) diff --git a/workflow/rules/plot.smk b/workflow/rules/plot.smk index 7eb90a59..8f48eed2 100644 --- a/workflow/rules/plot.smk +++ b/workflow/rules/plot.smk @@ -142,3 +142,49 @@ rule make_forecast_animation: """ convert -delay {params.delay} -loop 0 {input} {output} """ + + +rule plot_score_maps: + # localrule: True + input: + script="workflow/scripts/plot_score_maps.mo.py", + verif_file=OUT_ROOT / "data/runs/{run_id}/score_maps/{param}_{leadtime}.nc", + output: + OUT_ROOT + / "results/{experiment}/score_maps/runs/{run_id}/{param}_{score}_{region}_{season}_{init_hour}_{leadtime}.png", + wildcard_constraints: + leadtime=r"\d+", # only digits + init_hour=r"all|\d{1,2}", + log: + OUT_ROOT + / "logs/plot_score_maps/{experiment}/{run_id}-{param}-{score}-{region}-{season}-{init_hour}-{leadtime}.log", + resources: + slurm_partition="postproc", + cpus_per_task=1, + runtime="10m", + shell: + """ + export ECCODES_DEFINITION_PATH=$(realpath .venv/share/eccodes-cosmo-resources/definitions) + uv run python {input.script} \ + --input {input.verif_file} --outfn {output[0]} --region {wildcards.region} \ + --param {wildcards.param} --leadtime {wildcards.leadtime} --score {wildcards.score} \ + --season {wildcards.season} --init_hour {wildcards.init_hour} > {log} 2>&1 + # interactive editing (needs to set localrule: True and use only one core) + # marimo edit {input.script} -- \ + # --input {input.verif_file} --outfn {output[0]} --region {wildcards.region} \ + # --param {wildcards.param} --leadtime {wildcards.leadtime} --score {wildcards.score} \ + # --season {wildcards.season} --init_hour {wildcards.init_hour} + """ + + +use rule plot_score_maps as plot_score_maps_baseline with: + input: + script="workflow/scripts/plot_score_maps.mo.py", + verif_file=OUT_ROOT + / f"data/baselines/{{baseline_id}}/{config['truth']['label']}/score_maps/{{param}}_{{leadtime}}.nc", + output: + OUT_ROOT + / "results/{experiment}/score_maps/baselines/{baseline_id}/{param}_{score}_{region}_{season}_{init_hour}_{leadtime}.png", + log: + OUT_ROOT + / "logs/plot_score_maps/{experiment}/{baseline_id}-{param}-{score}-{region}-{season}-{init_hour}-{leadtime}.log", diff --git a/workflow/rules/verification.smk b/workflow/rules/verification.smk index 58b215e5..195ad2e5 100644 --- a/workflow/rules/verification.smk +++ b/workflow/rules/verification.smk @@ -164,3 +164,79 @@ rule verification_metrics_plot: """ uv run {input.script} {input.verif} --output_dir {output} > {log} 2>&1 """ + + +rule verification_score_maps: + input: + "src/verification/__init__.py", + "src/data_input/__init__.py", + script="workflow/scripts/verification_score_maps.py", + inference_okfiles=lambda wc: expand( + rules.inference_execute.output.okfile, + init_time=_restrict_reftimes_to_hours(REFTIMES), + allow_missing=True, + ), + truth=config["truth"]["root"], + output: + OUT_ROOT / "data/runs/{run_id}/score_maps/{param}_{leadtime}.nc", + # wildcard_constraints: + # run_id="^" # to avoid ambiguitiy with run_baseline_verif + # TODO: implement logic to use experiment name instead of run_id as wildcard + params: + fcst_label=lambda wc: RUN_CONFIGS[wc.run_id].get("label"), + fcst_steps=lambda wc: RUN_CONFIGS[wc.run_id]["steps"], + truth_label=config["truth"]["label"], + reftimes=" ".join(t.strftime("%Y%m%d%H%M") for t in REFTIMES), + log: + OUT_ROOT / "logs/verification_score_maps/{run_id}-{param}-{leadtime}.log", + resources: + cpus_per_task=24, + mem_mb=50_000, + runtime="60m", + shell: + """ + uv run {input.script} \ + --run_root output/data/runs/{wildcards.run_id} \ + --reftimes {params.reftimes} \ + --truth {input.truth} \ + --step {wildcards.leadtime} \ + --param {wildcards.param} \ + --output {output} > {log} 2>&1 + """ + + +rule verification_score_maps_baseline: + input: + script="workflow/scripts/verification_score_maps.py", + # Declared as inputs purely for dependency tracking (re-run if the baseline + # archive changes). The script discovers the zarrs itself by globbing + # `--baseline_root`, so this list is intentionally not passed on the CLI. + baseline_zarrs=lambda wc: expand( + "{root}/FCST{year}.zarr", + root=BASELINE_CONFIGS[wc.baseline_id].get("root"), + year=sorted({t.strftime("%y") for t in REFTIMES}), + ), + truth=config["truth"]["root"], + output: + OUT_ROOT + / f"data/baselines/{{baseline_id}}/{config['truth']['label']}/score_maps/{{param}}_{{leadtime}}.nc", + params: + baseline_root=lambda wc: BASELINE_CONFIGS[wc.baseline_id].get("root"), + reftimes=" ".join(t.strftime("%Y%m%d%H%M") for t in REFTIMES), + log: + OUT_ROOT + / f"logs/verification_score_maps_baseline/{{baseline_id}}-{config['truth']['label']}-{{param}}-{{leadtime}}.log", + resources: + cpus_per_task=24, + mem_mb=50_000, + runtime="60m", + shell: + """ + uv run {input.script} \ + --baseline_root {params.baseline_root} \ + --reftimes {params.reftimes} \ + --truth {input.truth} \ + --step {wildcards.leadtime} \ + --param {wildcards.param} \ + --output {output} > {log} 2>&1 + """ diff --git a/workflow/scripts/plot_score_maps.mo.py b/workflow/scripts/plot_score_maps.mo.py new file mode 100644 index 00000000..c484ad3e --- /dev/null +++ b/workflow/scripts/plot_score_maps.mo.py @@ -0,0 +1,222 @@ +import marimo + +__generated_with = "0.19.4" +app = marimo.App(width="medium") + + +@app.cell +def _(): + import logging + from argparse import ArgumentParser + from pathlib import Path + + import earthkit.plots as ekp + import numpy as np + import xarray as xr + + from plotting import DOMAINS + from plotting import StatePlotter + from plotting.colormap_defaults import CMAP_DEFAULTS + + return ( + ArgumentParser, + CMAP_DEFAULTS, + DOMAINS, + Path, + StatePlotter, + ekp, + logging, + np, + xr, + ) + + +@app.cell +def _(logging): + LOG = logging.getLogger(__name__) + LOG_FMT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + logging.basicConfig(level=logging.INFO, format=LOG_FMT) + return (LOG,) + + +@app.cell +def _(ArgumentParser, Path, np): + parser = ArgumentParser() + + parser.add_argument( + "--input", + type=str, + default=None, + help="Directory to .nc data containing the error fields", + ) + parser.add_argument("--outfn", type=str, help="output filename") + parser.add_argument("--leadtime", type=str, help="leadtime") + parser.add_argument("--param", type=str, help="parameter") + parser.add_argument("--region", type=str, help="name of region") + parser.add_argument( + "--score", + type=str, + help="Evaluation Score. So far Bias, RMSE, MAE or STDE are implemented.", + ) + parser.add_argument("--season", type=str, default="all", help="season filter") + parser.add_argument( + "--init_hour", type=str, default="all", help="initialization hour filter" + ) + + args = parser.parse_args() + verif_file = Path(args.input) + outfn = Path(args.outfn) + lead_time = args.leadtime + param = args.param + region = args.region + season = args.season + init_hour = args.init_hour + score = args.score + + if isinstance(init_hour, str): + if init_hour == "all": + init_hour = -999 + else: + try: + init_hour = int(init_hour) + except ValueError as exc: + raise ValueError("init_hour must be 'all' or an integer hour") from exc + + lead_time = np.timedelta64(lead_time, "h") + return ( + init_hour, + lead_time, + outfn, + param, + region, + score, + season, + verif_file, + ) + + +@app.cell +def _(LOG, init_hour, param, score, season, verif_file, xr): + ds = xr.open_dataset(verif_file) + LOG.info("Opened dataset: %s", ds) + var = f"{param}.{score}" + LOG.info( + "Selecting variable '%s' for season '%s', init_hour=%s", var, season, init_hour + ) + ds = ds[var].sel(season=season, init_hour=init_hour) + LOG.info( + "Selected DataArray: dims=%s, shape=%s, dtype=%s", ds.dims, ds.shape, ds.dtype + ) + LOG.info( + "Value range: min=%.4g, max=%.4g, n_nan=%d", + float(ds.min()), + float(ds.max()), + int(ds.isnull().sum()), + ) + return (ds,) + + +@app.cell +def _(CMAP_DEFAULTS, ekp): + def get_style(param, score, units_override=None): + """Get style and colormap settings for the plot. + Needed because cmap/norm does not work in Style(colors=cmap), + still needs to be passed as arguments to tripcolor()/tricontourf(). + """ + score_key = f"{param}.{score}.map" + cfg = ( + CMAP_DEFAULTS[score_key] + if score_key in CMAP_DEFAULTS + else CMAP_DEFAULTS.get(param, {}) + ) + units = units_override if units_override is not None else cfg.get("units", "") + return { + "style": ekp.styles.Style( + levels=cfg.get("bounds", cfg.get("levels", None)), + extend="both", + units=units, + colors=cfg.get("colors", None), + ), + "norm": cfg.get("norm", None), + "cmap": cfg.get("cmap", None), + "levels": cfg.get("levels", None), + "vmin": cfg.get("vmin", None), + "vmax": cfg.get("vmax", None), + "colors": cfg.get("colors", None), + } + + return (get_style,) + + +@app.cell +def _( + DOMAINS, + LOG, + StatePlotter, + ds, + get_style, + init_hour, + lead_time, + np, + outfn, + param, + region, + score, + season, +): + # plot individual fields + + plotter = StatePlotter( + ds["lon"].values.ravel(), + ds["lat"].values.ravel(), + outfn.parent, + ) + fig = plotter.init_geoaxes( + nrows=1, + ncols=1, + projection=DOMAINS[region]["projection"], + bbox=DOMAINS[region]["extent"], + name=region, + size=(6, 6), + ) + subplot = fig.add_map(row=0, column=0) + + plot_vals = ds.values.ravel() + + style_kwargs = get_style(param, score) + LOG.info("style_kwargs: %s", style_kwargs) + + if np.all(np.isnan(plot_vals)): + LOG.warning( + "All values are NaN for %s %s season=%s — plotting empty map.", + param, + score, + season, + ) + import matplotlib.patches as mpatches + + subplot.ax.set_facecolor("#cccccc") + subplot.standard_layers() + grey_patch = mpatches.Patch(color="#cccccc", label="No data") + subplot.ax.legend(handles=[grey_patch], loc="lower left", fontsize=8) + else: + plotter.plot_field(subplot, plot_vals, **style_kwargs) + + # black coast lines and country borders for better visibility + # grey is hardly visible, especially when the shading colours are intense. + subplot.coastlines(edgecolor="black", linewidth=1.0, zorder=5) + subplot.borders(edgecolor="black", linewidth=0.5, zorder=5) + + init_hour_lbl = "all" if init_hour == -999 else f"{init_hour:02d}" + fig.title( + f"{score} of {param}, Season: {season}, " + f"Init hour: {init_hour_lbl}, Lead Time: {lead_time}" + ) + + fig.save(outfn, bbox_inches="tight", dpi=200) + LOG.info(f"saved: {outfn}") + return + + +if __name__ == "__main__": + app.run() diff --git a/workflow/scripts/report_experiment_dashboard.py b/workflow/scripts/report_experiment_dashboard.py index 866d18a5..f2372d80 100644 --- a/workflow/scripts/report_experiment_dashboard.py +++ b/workflow/scripts/report_experiment_dashboard.py @@ -38,9 +38,7 @@ def main(args): ds = xr.concat(dfs, dim="source", join="outer") LOG.info("Loaded verification netcdf: \n%s", ds) - # extract only non-spatial variables to pd.DataFrame - nonspatial_vars = [d for d in ds.data_vars if "spatial" not in d] - df = ds[nonspatial_vars].to_array("stack").to_dataframe(name="value").reset_index() + df = ds.to_array("stack").to_dataframe(name="value").reset_index() df[["param", "metric"]] = df["stack"].str.split(".", n=1, expand=True) df["metric"] = df.metric.apply(decode_metric) df.drop(columns=["stack"], inplace=True) diff --git a/workflow/scripts/verification_score_maps.py b/workflow/scripts/verification_score_maps.py new file mode 100644 index 00000000..5f3afdaa --- /dev/null +++ b/workflow/scripts/verification_score_maps.py @@ -0,0 +1,572 @@ +"""Compute spatial maps of temporally-aggregated forecast errors. + +For a fixed lead time and variable, iterates over all initialisation times found +under a run directory, loads the corresponding GRIB forecast field and the +matching truth slice from a reference zarr, maps the forecast onto the truth +grid, and accumulates running error statistics without ever holding the full +time series in memory. The final BIAS / RMSE / MAE / STDE maps are written to a +NetCDF file. + +Usage +----- + uv run workflow/scripts/verification_score_maps.py \\ + output/data/runs/ \\ + --truth /path/to/truth.zarr \\ + --step 24 \\ + --param T_2M +""" + +import logging +from argparse import ArgumentParser, Namespace +from datetime import datetime, timedelta +from pathlib import Path + +import numpy as np +import xarray as xr + +from data_input import load_fct_data_from_grib, load_baseline_from_zarr +from verification.spatial import map_forecast_to_truth + +LOG = logging.getLogger(__name__) +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) + +DATETIME_FMT = "%Y%m%d%H%M" + +SEASONS = ["DJF", "MAM", "JJA", "SON", "all"] +# Init hour buckets. -999 is the "all" sentinel (matches verification_aggregation.py). +INIT_HOURS = [0, 6, 12, 18, -999] + + +def _season_of(dt: datetime) -> str: + """Return the meteorological season string for a given datetime.""" + month = dt.month + if month in (12, 1, 2): + return "DJF" + if month in (3, 4, 5): + return "MAM" + if month in (6, 7, 8): + return "JJA" + return "SON" + + +# Maps from standard parameter names to zarr variable names. +# COSMO-2e zarrs use short CF names; COSMO-1e zarrs keep the COSMO names. +_PARAMS_MAP_CO2 = { + "T_2M": "2t", + "TD_2M": "2d", + "U_10M": "10u", + "V_10M": "10v", + "PS": "sp", + "PMSL": "msl", + "TOT_PREC": "tp", +} +_PARAMS_MAP_CO1 = {k: k.replace("TOT_PREC", "TOT_PREC_6H") for k in _PARAMS_MAP_CO2} + +# Derived variables and the components they require. +_DERIVED = { + "SP_10M": ("U_10M", "V_10M"), +} + + +def _params_map(truth_root: Path) -> dict[str, str]: + return _PARAMS_MAP_CO2 if "co2" in truth_root.name else _PARAMS_MAP_CO1 + + +def _compute_derived(ds: xr.Dataset, param: str) -> xr.DataArray: + """Compute a derived variable from its components already present in *ds*.""" + if param == "SP_10M": + return (ds["U_10M"] ** 2 + ds["V_10M"] ** 2) ** 0.5 + raise ValueError(f"No recipe for derived variable '{param}'") + + +# --------------------------------------------------------------------------- +# Truth loading +# --------------------------------------------------------------------------- +# TODO: consolidate with src/data_input/__init__.py once the ongoing +# data-loading refactor lands. _open_zarr_component below duplicates +# ~80% of load_analysis_data_from_zarr but returns a lazy DataArray +# rather than a time-sliced Dataset, which is what our streaming +# aggregation needs. The right end-state is a shared lazy-open primitive +# in data_input that both consumers use; not introduced here because +# data_input is being reworked separately and we don't want to conflict. + + +def _open_zarr_component(root: Path, param: str) -> xr.DataArray: + """Open a single native zarr variable lazily as a DataArray.""" + zarr_param = _params_map(root)[param] + + ds = xr.open_zarr(root, consolidated=False) + ds = ds.set_index(time="dates") + + # Extract lat/lon before selecting on variable (they live on cell only). + spatial_dim = "cell" + lat = ds["latitudes"] if "latitudes" in ds else None + lon = ds["longitudes"] if "longitudes" in ds else None + + ds = ds.assign_coords(variable=ds.attrs["variables"]) + ds = ds.sel(variable=zarr_param).squeeze("ensemble", drop=True) + + # Recover 2-D spatial shape when stored as a flat cell dimension. + if len(ds.attrs["field_shape"]) == 2: + ny, nx = ds.attrs["field_shape"] + y_idx, x_idx = np.unravel_index(np.arange(ny * nx), (ny, nx)) + ds = ds.assign_coords(y=(spatial_dim, y_idx), x=(spatial_dim, x_idx)) + ds = ds.set_index(**{spatial_dim: ("y", "x")}).unstack(spatial_dim) + spatial_dim = None # now (y, x) + + da = ds["data"].rename(param).drop_vars("variable", errors="ignore") + + # Attach lat/lon as coordinates on the spatial dimension(s). + if lat is not None and lon is not None: + if spatial_dim is not None: + # flat 1-D case: cell/values dim + da = da.assign_coords( + lat=(spatial_dim, lat.values), lon=(spatial_dim, lon.values) + ) + else: + # 2-D case: lat/lon still on original flat index — attach via unstack + da = da.assign_coords( + lat=(["y", "x"], lat.values.reshape(ny, nx)), + lon=(["y", "x"], lon.values.reshape(ny, nx)), + ) + + return da + + +def open_truth_zarr(root: Path, param: str) -> xr.DataArray: + """Open the truth zarr lazily and return a DataArray for *param*. + + For derived variables (e.g. SP_10M) the required components are loaded and + the derivation is applied on the fly. The returned DataArray has dimensions + ``(time, y, x)`` or ``(time, values)`` and always exposes ``lat``/``lon``. + """ + if param in _DERIVED: + components = { + c: _open_zarr_component(root, c).drop_vars("variable", errors="ignore") + for c in _DERIVED[param] + } + ds = xr.Dataset(components) + return _compute_derived(ds, param) + return _open_zarr_component(root, param) + + +# --------------------------------------------------------------------------- +# Init-time discovery +# --------------------------------------------------------------------------- + + +def iter_init_dirs(run_root: Path) -> list[tuple[datetime, Path]]: + """Return ``(reftime, grib_dir)`` pairs for every complete init time. + + Expects subdirectories named ``YYYYMMDDHHMI`` directly under *run_root*. + GRIB files may live either directly in the init-time directory or inside a + ``grib/`` subdirectory. + """ + result = [] + for d in sorted(run_root.iterdir()): + if not d.is_dir(): + continue + try: + reftime = datetime.strptime(d.name, DATETIME_FMT) + except ValueError: + continue + grib_dir = d / "grib" if (d / "grib").is_dir() else d + if not any(grib_dir.glob("*.grib")): + LOG.debug("No GRIB files in %s, skipping", grib_dir) + continue + result.append((reftime, grib_dir)) + return result + + +def iter_baseline_init_times(baseline_root: Path, step: int) -> list[datetime]: + """Return all init times from a baseline's zarr(s) that have the requested step available. + + The per-year zarrs are discovered by globbing ``FCST*.zarr`` under ``baseline_root`` + rather than taking an explicit list: the layout is fixed and the discovered init + times are filtered down to the configured reftimes by the caller anyway, so any + extra years picked up from the archive are harmless. + """ + step_td = np.timedelta64(step, "h") + reftimes = [] + for zarr_path in sorted(baseline_root.glob("FCST*.zarr")): + if not zarr_path.exists(): + LOG.warning("Baseline zarr not found: %s", zarr_path) + continue + ds = xr.open_zarr(zarr_path, consolidated=True, decode_timedelta=True) + if step_td not in ds["step"].values: + LOG.warning("Step %dh not in %s, skipping", step, zarr_path) + continue + for rt in ds["forecast_reference_time"].values: + ts = (rt - np.datetime64("1970-01-01T00:00:00")) / np.timedelta64(1, "s") + reftimes.append(datetime.utcfromtimestamp(float(ts))) + return sorted(reftimes) + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +def main(args: Namespace) -> None: + LOG.info("=" * 60) + LOG.info("Spatial verification param=%s step=%dh", args.param, args.step) + LOG.info("Run root : %s", args.run_root) + LOG.info("Truth : %s", args.truth) + LOG.info("Output : %s", args.output) + LOG.info("=" * 60) + + # Open the truth zarr once; individual time slices are loaded on demand. + truth_da = open_truth_zarr(args.truth, args.param) + # Normalise to datetime64[ns] so membership checks work regardless of zarr precision. + truth_da = truth_da.assign_coords( + time=truth_da.time.values.astype("datetime64[ns]") + ) + # Rename flat spatial dim to 'values' if the zarr uses 'cell'. + if "cell" in truth_da.dims: + truth_da = truth_da.rename({"cell": "values"}) + truth_times = set( + truth_da.time.values + ) # keep as datetime64, tolist() yields ints for ns precision + LOG.info("Truth opened lazily: %s", truth_da) + + if args.baseline_root: + init_items = [ + (rt, None) + for rt in iter_baseline_init_times(args.baseline_root, args.step) + ] + LOG.info("Found %d baseline init times", len(init_items)) + else: + init_items = iter_init_dirs(args.run_root) + LOG.info("Found %d init time directories", len(init_items)) + + # Restrict to the experiment's configured init times if provided. + # Without this, baseline zarrs (which contain a continuous archive) would + # cause the script to process every init time in the file rather than + # only those in the user's hindcast period. + if args.reftimes: + wanted = {datetime.strptime(s, DATETIME_FMT) for s in args.reftimes} + init_items = [(rt, d) for rt, d in init_items if rt in wanted] + LOG.info("Filtered to %d init times matching --reftimes", len(init_items)) + + step_td = timedelta(hours=args.step) + + # Running accumulators keyed by (season, init_hour) – initialised on the + # first successfully processed sample so that we can infer the spatial + # shape from the data. Each entry is a numpy array over the spatial + # dimension(s). + bucket_keys = [(s, h) for s in SEASONS for h in INIT_HOURS] + accum_n: dict[tuple[str, int], np.ndarray | None] = {k: None for k in bucket_keys} + accum_sum_e: dict[tuple[str, int], np.ndarray | None] = { + k: None for k in bucket_keys + } + accum_sum_se: dict[tuple[str, int], np.ndarray | None] = { + k: None for k in bucket_keys + } + accum_sum_ae: dict[tuple[str, int], np.ndarray | None] = { + k: None for k in bucket_keys + } + ref_truth_slice: xr.DataArray | None = None # kept for output coordinates + + n_ok = 0 + n_skip = 0 + + for reftime, grib_dir in init_items: + valid_time = np.datetime64(reftime + step_td).astype("datetime64[ns]") + + if valid_time not in truth_times: + LOG.debug("Valid time %s not in truth, skipping %s", valid_time, reftime) + n_skip += 1 + continue + + LOG.info( + "Processing reftime=%s valid=%s", + reftime.strftime(DATETIME_FMT), + valid_time, + ) + + first_iter = n_ok == 0 + + # --- load forecast --- + fct_params = ( + list(_DERIVED[args.param]) if args.param in _DERIVED else [args.param] + ) + + try: + if args.baseline_root: + zarr_path = args.baseline_root / f"FCST{reftime.strftime('%y')}.zarr" + fcst = load_baseline_from_zarr( + root=zarr_path, + reftime=reftime, + steps=[args.step], + params=fct_params, + ) + else: + # The loaders handle cumulative-from-start disaggregation + # internally (including fetching step 0 when needed for + # TOT_PREC), so a single-step request is sufficient here. + fcst = load_fct_data_from_grib( + root=grib_dir, + reftime=reftime, + steps=[args.step], + params=fct_params, + ) + except Exception as exc: + LOG.warning("Could not load forecast for %s: %s", reftime, exc) + n_skip += 1 + continue + + # Drop lead_time dimension (select only the requested step). + if "lead_time" in fcst.dims: + fcst = fcst.sel(lead_time=np.timedelta64(args.step, "h")) + + # Compute derived variable if needed. + if args.param in _DERIVED: + fcst = fcst.assign({args.param: _compute_derived(fcst, args.param)}) + + if first_iter: + LOG.info("fcst (after step selection): %s", fcst) + fcst_raw = fcst[args.param].values if args.param in fcst else None + if fcst_raw is not None: + n_nan_fcst = int(np.isnan(fcst_raw).sum()) + LOG.info( + "fcst[%s]: shape=%s, min=%.4g, max=%.4g, n_nan=%d", + args.param, + fcst_raw.shape, + float(np.nanmin(fcst_raw)) + if n_nan_fcst < fcst_raw.size + else float("nan"), + float(np.nanmax(fcst_raw)) + if n_nan_fcst < fcst_raw.size + else float("nan"), + n_nan_fcst, + ) + + # --- load truth slice --- + truth_slice = truth_da.sel(time=valid_time).compute() + # For derived variables truth_da is already the derived DataArray, + # so wrap it in a Dataset for map_forecast_to_truth compatibility. + truth_ds = ( + truth_slice.to_dataset(name=args.param) + if isinstance(truth_slice, xr.DataArray) + else truth_slice + ) + + if first_iter: + truth_raw = truth_slice.values + n_nan_truth = int(np.isnan(truth_raw).sum()) + LOG.info( + "truth_slice[%s]: shape=%s, min=%.4g, max=%.4g, n_nan=%d", + args.param, + truth_raw.shape, + float(np.nanmin(truth_raw)) + if n_nan_truth < truth_raw.size + else float("nan"), + float(np.nanmax(truth_raw)) + if n_nan_truth < truth_raw.size + else float("nan"), + n_nan_truth, + ) + + # --- map forecast onto truth grid --- + try: + fcst_mapped = map_forecast_to_truth(fcst, truth_ds) + except Exception as exc: + LOG.warning("Spatial mapping failed for %s: %s", reftime, exc) + n_skip += 1 + continue + + fcst_param = fcst_mapped[args.param] + # Squeeze any ensemble/eps dimension (deterministic run stored with size-1 eps). + for dim in ["eps", "ensemble", "number"]: + if dim in fcst_param.dims and fcst_param.sizes[dim] == 1: + fcst_param = fcst_param.squeeze(dim, drop=True) + fcst_vals = fcst_param.values + truth_vals = truth_slice.values + error = fcst_vals - truth_vals # shape: spatial dims of truth + + if first_iter: + n_nan_mapped = int(np.isnan(fcst_vals).sum()) + LOG.info( + "fcst_mapped[%s]: shape=%s, min=%.4g, max=%.4g, n_nan=%d", + args.param, + fcst_vals.shape, + float(np.nanmin(fcst_vals)) + if n_nan_mapped < fcst_vals.size + else float("nan"), + float(np.nanmax(fcst_vals)) + if n_nan_mapped < fcst_vals.size + else float("nan"), + n_nan_mapped, + ) + n_nan_err = int(np.isnan(error).sum()) + LOG.info( + "error: shape=%s, min=%.4g, max=%.4g, n_nan=%d / %d", + error.shape, + float(np.nanmin(error)) if n_nan_err < error.size else float("nan"), + float(np.nanmax(error)) if n_nan_err < error.size else float("nan"), + n_nan_err, + error.size, + ) + + n_nan_error = int(np.isnan(error).sum()) + if n_nan_error == error.size: + LOG.warning( + "reftime=%s: error is all-NaN (%d points) — nothing accumulated.", + reftime.strftime(DATETIME_FMT), + error.size, + ) + + # --- initialise accumulators on first valid sample --- + if accum_n[("all", -999)] is None: + for k in bucket_keys: + accum_n[k] = np.zeros(error.shape, dtype=np.int64) + accum_sum_e[k] = np.zeros(error.shape, dtype=np.float64) + accum_sum_se[k] = np.zeros(error.shape, dtype=np.float64) + accum_sum_ae[k] = np.zeros(error.shape, dtype=np.float64) + ref_truth_slice = truth_slice + + # --- accumulate into matching (season, init_hour) buckets, plus the + # "all" rows/cols on each axis (NaN-safe) --- + season = _season_of(reftime) + ih = reftime.hour + valid = ~np.isnan(error) + for s in (season, "all"): + for h in (ih, -999): + accum_n[(s, h)][valid] += 1 + accum_sum_e[(s, h)][valid] += error[valid] + accum_sum_se[(s, h)][valid] += error[valid] ** 2 + accum_sum_ae[(s, h)][valid] += np.abs(error[valid]) + n_ok += 1 + + LOG.info("Finished: %d init times processed, %d skipped", n_ok, n_skip) + + if n_ok == 0: + LOG.error("No data could be processed – no output written.") + return + + # --- compute aggregate maps per (season, init_hour), then stack --- + spatial_coords = { + c: ref_truth_slice[c] + for c in ref_truth_slice.coords + if set(ref_truth_slice[c].dims).issubset(set(ref_truth_slice.dims)) + and c != "time" + } + spatial_dims = list(ref_truth_slice.dims) + out_dims = ["season", "init_hour"] + spatial_dims + out_coords = {"season": SEASONS, "init_hour": INIT_HOURS, **spatial_coords} + + def _strat_da(compute_fn) -> xr.DataArray: + """Stack per-(season, init_hour) arrays into a (season, init_hour, *spatial) DataArray.""" + out_shape = (len(SEASONS), len(INIT_HOURS)) + ref_truth_slice.shape + arr = np.empty(out_shape, dtype=np.float32) + for i, s in enumerate(SEASONS): + for j, h in enumerate(INIT_HOURS): + n = accum_n[(s, h)] + with np.errstate(invalid="ignore", divide="ignore"): + arr[i, j] = compute_fn(n, s, h).astype(np.float32) + return xr.DataArray(arr, dims=out_dims, coords=out_coords) + + out = xr.Dataset( + { + f"{args.param}.BIAS": _strat_da( + lambda n, s, h: np.where(n > 0, accum_sum_e[(s, h)] / n, np.nan) + ), + f"{args.param}.RMSE": _strat_da( + lambda n, s, h: np.where( + n > 0, np.sqrt(accum_sum_se[(s, h)] / n), np.nan + ) + ), + f"{args.param}.MAE": _strat_da( + lambda n, s, h: np.where(n > 0, accum_sum_ae[(s, h)] / n, np.nan) + ), + f"{args.param}.STDE": _strat_da( + lambda n, s, h: np.where( + n > 0, + np.sqrt( + np.maximum( + accum_sum_se[(s, h)] / n + - (accum_sum_e[(s, h)] / n) ** 2, + 0.0, + ) + ), + np.nan, + ) + ), + f"{args.param}.N": _strat_da(lambda n, s, h: np.where(n > 0, n, np.nan)), + }, + attrs={ + "param": args.param, + "step_h": args.step, + "source": str(args.baseline_root if args.baseline_root else args.run_root), + "n_processed": n_ok, + "n_skipped": n_skip, + }, + ) + + LOG.info("Output dataset:\n%s", out) + args.output.parent.mkdir(parents=True, exist_ok=True) + out.to_netcdf(args.output) + LOG.info("Saved to %s", args.output) + + +if __name__ == "__main__": + parser = ArgumentParser( + description=( + "Compute spatial maps of temporally-aggregated forecast errors. " + "Supports both model runs (GRIB) and baselines (zarr). " + "Exactly one of --run_root or --baseline_root must be provided." + ) + ) + parser.add_argument( + "--run_root", + type=Path, + default=None, + help="Root directory of a model run (e.g. output/data/runs/).", + ) + parser.add_argument( + "--baseline_root", + type=Path, + default=None, + help="Root directory of a baseline (e.g. /path/to/ICON-CH1-EPS), containing FCST.zarr files.", + ) + parser.add_argument( + "--truth", + type=Path, + required=True, + help="Path to the reference zarr dataset.", + ) + parser.add_argument( + "--step", + type=int, + required=True, + help="Forecast lead time in hours (e.g. 24).", + ) + parser.add_argument( + "--param", + type=str, + required=True, + help="Variable to verify (e.g. T_2M, TD_2M, U_10M).", + ) + parser.add_argument( + "--reftimes", + nargs="+", + default=None, + help=( + "Optional list of init times (YYYYMMDDHHMM) to restrict processing to. " + "Required for baselines whose zarr contains a continuous archive." + ), + ) + parser.add_argument( + "--output", + type=Path, + required=True, + help="Output NetCDF file.", + ) + args = parser.parse_args() + + if bool(args.run_root) == bool(args.baseline_root): + parser.error("Exactly one of --run_root or --baseline_root must be provided.") + + main(args) diff --git a/workflow/tools/config.schema.json b/workflow/tools/config.schema.json index d6153e87..40e2936f 100644 --- a/workflow/tools/config.schema.json +++ b/workflow/tools/config.schema.json @@ -483,6 +483,120 @@ "title": "Profile", "type": "object" }, + "ScoreMapsConfig": { + "description": "Parameters controlling which score map plots are produced.", + "properties": { + "params": { + "default": [ + "T_2M", + "TD_2M", + "U_10M", + "V_10M", + "SP_10M", + "PS", + "PMSL", + "TOT_PREC" + ], + "description": "List of parameters to plot. Supported values: T_2M, TD_2M, U_10M, V_10M, PS, PMSL, TOT_PREC (native), and SP_10M (derived wind speed from U_10M/V_10M).", + "items": { + "type": "string" + }, + "title": "Params", + "type": "array" + }, + "leadtimes": { + "anyOf": [ + { + "items": { + "type": "integer" + }, + "type": "array" + }, + { + "const": "all", + "type": "string" + } + ], + "default": [ + 6, + 12, + 18, + 24, + 30, + 36, + 42, + 48, + 54, + 60, + 66, + 72, + 78, + 84, + 90, + 96, + 102, + 108, + 114, + 120 + ], + "description": "List of lead times (hours) to plot, or the literal string 'all' to expand to the union of step lists from all configured runs and baselines.", + "title": "Leadtimes" + }, + "scores": { + "default": [ + "BIAS", + "RMSE", + "MAE" + ], + "description": "List of verification scores to plot.", + "items": { + "type": "string" + }, + "title": "Scores", + "type": "array" + }, + "regions": { + "default": [ + "switzerland", + "centraleurope" + ], + "description": "List of regions to plot.", + "items": { + "type": "string" + }, + "title": "Regions", + "type": "array" + }, + "seasons": { + "default": [ + "all", + "DJF", + "MAM", + "JJA", + "SON" + ], + "description": "List of seasons to plot.", + "items": { + "type": "string" + }, + "title": "Seasons", + "type": "array" + }, + "init_hours": { + "default": [ + "all" + ], + "description": "List of initialization hours to plot. Use 'all' for the unstratified view, or zero-padded hour strings like '00', '06', '12', '18'.", + "items": { + "type": "string" + }, + "title": "Init Hours", + "type": "array" + } + }, + "title": "ScoreMapsConfig", + "type": "object" + }, "Stratification": { "description": "Stratification settings for the analysis.", "properties": { @@ -624,6 +738,10 @@ }, "profile": { "$ref": "#/$defs/Profile" + }, + "score_maps": { + "$ref": "#/$defs/ScoreMapsConfig", + "description": "Parameters for score map plots (used with --maps flag)." } }, "required": [