Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 122 additions & 9 deletions src/data_input/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,107 @@ def load_baseline_from_zarr(
return baseline


def load_baseline_from_grib(
    root: Path, reftime: datetime, steps: list[int], params: list[str]
) -> xr.Dataset:
    """Load baseline forecast fields from an ICON GRIB archive directory.

    `root` is expected to be the FCST<year> directory of the operational
    archive, e.g. ``/store_new/mch/msopr/osm/ICON-CH1-EPS/FCST25``. Only the
    files of run_id ``000`` are read, one file per requested lead time, and
    only fields whose ``shortName`` is listed in `params` are kept. The
    result is meant to have the same structure as the Dataset returned by
    :func:`load_baseline_from_zarr`.

    Raises ``ValueError`` when no subdirectory matches `reftime`, when the
    model cannot be derived from `root`, when no field matches `params`, or
    when TOT_PREC decreases by more than 0.1 between consecutive lead times.
    """
    import earthkit.data as ekd
    from meteodatalab.icon_grid import load_grid_from_balfrin
    from uuid import UUID

    # Subdirectories are named yymmddHH_<version>; when several versions
    # exist, the lexicographically highest one wins.
    candidates = sorted(root.glob(f"{reftime:%y%m%d%H}_*"))
    if not candidates:
        raise ValueError(
            f"No archive subdirectory found for {reftime:%y%m%d%H} in {root}"
        )
    reftime_dir = candidates[-1]
    LOG.info("Reading ICON archive from %s", reftime_dir)

    # The GRIB filename prefix follows from the model name found in the path.
    prefix_by_model = {"ICON-CH1-EPS": "i1eff", "ICON-CH2-EPS": "i2eff"}
    gribname = next(
        (prefix for model, prefix in prefix_by_model.items() if model in root.parts),
        None,
    )
    if gribname is None:
        raise ValueError(
            f"Cannot determine model from path (expected ICON-CH1-EPS or "
            f"ICON-CH2-EPS): {root}"
        )

    # Gather the requested fields over all lead times. The filename encodes
    # the lead time as <days><hours>0000 followed by the run id.
    run_id = "000"
    grib_dir = reftime_dir / "grib"
    out = ekd.SimpleFieldList()
    last_field = None
    for lt in steps:
        ld, lh = divmod(lt, 24)
        filepath = grib_dir / f"{gribname}{ld:02}{lh:02}0000_{run_id}"
        wanted = (
            candidate
            for candidate in ekd.from_source("file", filepath)
            if candidate.metadata("shortName") in params
        )
        for field in wanted:
            out.append(field)
            last_field = field

    if last_field is None:
        raise ValueError(f"No fields matching params {params} found in {reftime_dir}")

    # Convert the collected fields to xarray with the native GRIB profile.
    ds = out.to_xarray(profile="grib")

    # Attach lat/lon from the horizontal grid identified by the UUID stored
    # in the last field that was read.
    grid_uuid = UUID(last_field.metadata("uuidOfHGrid"))
    hcoords = load_grid_from_balfrin()(grid_uuid)
    ds = ds.assign_coords(
        {name: hcoords[name].rename({"cell": "values"}) for name in ("lat", "lon")}
    )

    # The "grib" profile names the lead-time dimension "step".
    if "step" in ds.dims:
        ds = ds.rename({"step": "lead_time"})
    lead_times = np.array(steps, dtype="timedelta64[h]")
    ds = ds.sel(lead_time=lead_times)

    if "TOT_PREC" in ds.data_vars:
        # TOT_PREC is treated as cumulative since forecast start (kg m-2, no
        # unit conversion) and is turned into per-step accumulations.
        #
        # Lead time 0 is forced to 0 so that a missing step-0 field does not
        # propagate NaN into the first difference.
        zero_lead = np.timedelta64(0, "h")
        filled = xr.where(ds.lead_time == zero_lead, 0.0, ds.TOT_PREC)
        ds = ds.assign(TOT_PREC=filled)
        diff = ds.TOT_PREC.diff("lead_time")
        min_diff = float(diff.min().compute())
        # A clearly negative difference means the input is not cumulative.
        if min_diff < -0.1:
            raise ValueError(
                f"TOT_PREC in the GRIB archive appears to already be "
                f"period-accumulated (min(.diff()) = {min_diff:.3e} mm). "
                f"Check the archive at {reftime_dir}."
            )
        # Small negative differences are clipped; reindexing restores the
        # lead time that .diff() removed.
        per_step = diff.clip(min=0.0).reindex(lead_time=lead_times)
        ds = ds.assign(TOT_PREC=per_step)

    # Add ref_time/time coordinates and drop the earthkit GRIB coordinates
    # that are not wanted in the returned Dataset.
    ref_ns = np.datetime64(reftime, "ns")
    ds = ds.assign_coords(
        ref_time=ref_ns,
        time=("lead_time", ref_ns + lead_times),
    )
    unwanted = [
        coord
        for coord in ("valid_time", "forecast_reference_time")
        if coord in ds.coords
    ]
    ds = ds.drop_vars(unwanted)

    return ds


def load_obs_data_from_peakweather(
root, reftime: datetime, steps: list[int], params: list[str], freq: str = "1h"
) -> xr.Dataset:
Expand Down Expand Up @@ -316,23 +417,35 @@ def load_truth_data(
def load_forecast_data(
    root, reftime: datetime, steps: list[int], params: list[str]
) -> xr.Dataset:
    """Load forecast data from GRIB files, a Zarr dataset, or an ICON archive.

    `root` may be a string or a ``Path``; it is converted to a ``Path``
    before routing. `reftime`, `steps` and `params` are passed unchanged to
    the selected loader.

    Routing (in order):
    1. Path ends in ``.zarr`` → :func:`load_baseline_from_zarr`
    2. ``*.grib`` files present in *root* → :func:`load_fct_data_from_grib`
       (ML inference output)
    3. Otherwise → :func:`load_baseline_from_grib` (ICON operational archive)
    """
    root = Path(root)

    # A zarr store is recognised by its suffix alone, before looking inside.
    if root.suffix == ".zarr":
        LOG.info("Loading baseline forecasts from zarr dataset...")
        return load_baseline_from_zarr(
            root=root,
            reftime=reftime,
            steps=steps,
            params=params,
        )

    # Only files directly inside `root` are considered (non-recursive glob).
    if any(root.glob("*.grib")):
        LOG.info("Loading forecasts from GRIB files...")
        return load_fct_data_from_grib(
            root=root,
            reftime=reftime,
            steps=steps,
            params=params,
        )

    # Fallback: treat `root` as an ICON archive directory.
    LOG.info("Loading baseline forecasts from ICON GRIB archive...")
    return load_baseline_from_grib(
        root=root,
        reftime=reftime,
        steps=steps,
        params=params,
    )
18 changes: 12 additions & 6 deletions workflow/rules/verification.smk
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# VERIFICATION WORKFLOW #
# ----------------------------------------------------- #
from datetime import datetime
from pathlib import Path

import pandas as pd

Expand All @@ -10,16 +11,21 @@ include: "common.smk"


# TODO: make sure the boundaries aren't used
def _get_baseline_forecast_path(wc):
    """Pick the forecast input path for the baseline named by the wildcards.

    The path is built from the baseline's configured ``root`` and characters
    2-3 of ``wc.init_time``. ``<root>/FCST<yy>.zarr`` is returned when it
    exists on disk; otherwise ``<root>/FCST<yy>`` (the GRIB archive
    directory) is returned.
    """
    root = BASELINE_CONFIGS[wc.baseline_id].get("root")
    year = wc.init_time[2:4]
    archive_dir = f"{root}/FCST{year}"
    zarr_store = f"{archive_dir}.zarr"
    if Path(zarr_store).exists():
        return zarr_store
    return archive_dir


rule verification_metrics_baseline:
input:
"src/verification/__init__.py",
"src/data_input/__init__.py",
script="workflow/scripts/verification_metrics.py",
baseline_zarr=lambda wc: expand(
"{root}/FCST{year}.zarr",
root=BASELINE_CONFIGS[wc.baseline_id].get("root"),
year=wc.init_time[2:4],
),
forecast=_get_baseline_forecast_path,
truth=config["truth"]["root"],
params:
baseline_label=lambda wc: BASELINE_CONFIGS[wc.baseline_id].get("label"),
Expand All @@ -38,7 +44,7 @@ rule verification_metrics_baseline:
shell:
"""
uv run {input.script} \
--forecast {input.baseline_zarr} \
--forecast {input.forecast} \
--truth {input.truth} \
--reftime {wildcards.init_time} \
--steps "{params.baseline_steps}" \
Expand Down
Loading