Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 44 additions & 6 deletions src/data_input/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,17 @@ def _collect_icon_archive_files(
]


def _discover_icon_member_ids(
root: Path, reftime: datetime, steps: list[int]
) -> list[str]:
"""Return sorted list of numeric member IDs present in the ICON archive for `reftime`."""
first_file = _collect_icon_archive_files(root, reftime, [steps[0]])[0]
prefix = first_file.name.rsplit("_", 1)[0]
return sorted(
p.name.rsplit("_", 1)[1] for p in first_file.parent.glob(f"{prefix}_???")
)


def load_from_grib_file(file: str | list[str], sel_kwargs):
fieldlist = ekd.from_source("file", file, lazily=True).to_fieldlist()
return fieldlist_to_xarray(fieldlist.sel(**sel_kwargs))
Expand Down Expand Up @@ -584,16 +595,43 @@ def load_icon_baseline_from_grib(
reftime: datetime,
steps: list[int],
params: list[str],
ensmean: bool = False,
) -> xr.Dataset:
"""Load an ICON-CH1-EPS or ICON-CH2-EPS baseline from the operational GRIB archive."""
return load_forecast_data_from_grib(
files=_collect_icon_archive_files(root, reftime, steps),
params=params,
)
if not ensmean:
return load_forecast_data_from_grib(
files=_collect_icon_archive_files(root, reftime, steps),
params=params,
)

member_ids = _discover_icon_member_ids(root, reftime, steps)
LOG.info("Computing ensemble mean over %d members: %s", len(member_ids), member_ids)
acc = None
n_loaded = 0
for mid in member_ids:
try:
ds = load_forecast_data_from_grib(
files=_collect_icon_archive_files(root, reftime, steps, member_id=mid),
params=params,
)
if "number" in ds.dims:
ds = ds.isel(number=0, drop=True)
acc = ds if acc is None else acc + ds
n_loaded += 1
except Exception as exc:
LOG.warning("Skipping member %s: %s", mid, exc)

if acc is None:
raise ValueError(
f"No ensemble members could be loaded for {reftime} from {root}"
)

LOG.info("Ensemble mean computed over %d members.", n_loaded)
return acc / n_loaded


def load_forecast_data(
root, reftime: datetime, steps: list[int], params: list[str]
root, reftime: datetime, steps: list[int], params: list[str], ensmean: bool = False
) -> xr.Dataset:
"""Load forecast data from GRIB files or an ICON archive.

Expand All @@ -615,4 +653,4 @@ def load_forecast_data(
LOG.info("Loading INCA baseline from NetCDF files...")
return load_INCA_baseline_from_netcdf(root, reftime, steps, params)
LOG.info("Loading baseline forecasts from ICON GRIB archive...")
return load_icon_baseline_from_grib(root, reftime, steps, params)
return load_icon_baseline_from_grib(root, reftime, steps, params, ensmean=ensmean)
4 changes: 4 additions & 0 deletions src/evalml/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,10 @@ class BaselineConfig(BaseModel):
description="Forecast steps to be used from baseline, e.g. '10/120/1'.",
pattern=r"^\d*/\d*/\d*$",
)
ensmean: bool = Field(
False,
description="If true, load all ensemble members and compute their mean before verification.",
)


class TruthConfig(BaseModel):
Expand Down
5 changes: 5 additions & 0 deletions workflow/rules/verification.smk
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,17 @@ rule verification_metrics_baseline:
OUT_ROOT / "data/baselines/{baseline_id}/{init_time}/verif.nc",
log:
OUT_ROOT / "logs/verification_metrics_baseline/{baseline_id}-{init_time}.log",
localrule: True
resources:
cpus_per_task=24,
mem_mb=50_000,
runtime="60m",
params:
baseline_label=lambda wc: BASELINE_CONFIGS[wc.baseline_id].get("label"),
baseline_steps=lambda wc: BASELINE_CONFIGS[wc.baseline_id]["steps"],
ensmean=lambda wc: (
"--ensmean" if BASELINE_CONFIGS[wc.baseline_id].get("ensmean") else ""
),
truth_label=config["truth"]["label"],
regions=REGIONS,
experiment_params=",".join(EXPERIMENT_PARAMS),
Expand All @@ -44,6 +48,7 @@ rule verification_metrics_baseline:
--regions "{params.regions}" \
--params "{params.experiment_params}" \
--threshold_dict "{params.threshold_dict}" \
{params.ensmean} \
--output {output} >{log} 2>&1
"""

Expand Down
10 changes: 9 additions & 1 deletion workflow/scripts/verification_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@ def main(args: ScriptConfig):
# get baseline forecast data
now = datetime.now()

fcst = load_forecast_data(args.forecast, args.reftime, args.steps, args.params)
fcst = load_forecast_data(
args.forecast, args.reftime, args.steps, args.params, ensmean=args.ensmean
)

LOG.info(
"Loaded forecast data in %s seconds: \n%s",
Expand Down Expand Up @@ -147,6 +149,12 @@ def main(args: ScriptConfig):
help="Dictionary of thresholds for each parameter in the format '{param: [threshold1, threshold2, ...]}' (default: None).",
default=None,
)
parser.add_argument(
"--ensmean",
action="store_true",
default=False,
help="Compute ensemble mean across all members before verification.",
)
parser.add_argument(
"--output",
type=Path,
Expand Down
6 changes: 6 additions & 0 deletions workflow/tools/config.schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,12 @@
"pattern": "^\\d*/\\d*/\\d*$",
"title": "Steps",
"type": "string"
},
"ensmean": {
"default": false,
"description": "If true, load all ensemble members and compute their mean before verification.",
"title": "Ensmean",
"type": "boolean"
}
},
"required": [
Expand Down