Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 56 additions & 7 deletions src/data_input/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,17 @@ def _collect_icon_archive_files(
]


def _discover_icon_member_ids(
    root: Path, reftime: datetime, steps: list[int]
) -> list[str]:
    """Return sorted list of numeric member IDs present in the ICON archive for `reftime`.

    The archive file of the first requested step is located with
    `_collect_icon_archive_files`. Every entry in the same directory whose name
    equals that file's name with the part after the last underscore replaced by
    a three-digit suffix is counted as one member.

    Args:
        root: Archive root, forwarded to `_collect_icon_archive_files`.
        reftime: Forecast reference time, forwarded to `_collect_icon_archive_files`.
        steps: Forecast steps; only the first one is used for discovery.

    Returns:
        The member-ID suffixes (e.g. ``"000"``) in ascending order.

    Raises:
        ValueError: If `steps` is empty or no archive file is found for the
            first step.
    """
    if not steps:
        raise ValueError("Cannot discover ICON member IDs: `steps` is empty.")

    probe_files = _collect_icon_archive_files(root, reftime, [steps[0]])
    if not probe_files:
        raise ValueError(
            f"No ICON archive file found for {reftime} (step {steps[0]}) under {root}"
        )

    first_file = probe_files[0]
    prefix = first_file.name.rsplit("_", 1)[0]
    suffixes = (
        p.name.rsplit("_", 1)[1] for p in first_file.parent.glob(f"{prefix}_???")
    )
    # The glob pattern `???` matches ANY three characters; keep only digit-only
    # suffixes so the result really is the list of numeric member IDs.
    return sorted(s for s in suffixes if s.isascii() and s.isdigit())


def load_from_grib_file(file: str | list[str], sel_kwargs):
    """Read GRIB file(s) lazily and return the selected fields via `fieldlist_to_xarray`.

    `file` (a single path or a list of paths) is opened through
    `ekd.from_source` with ``lazily=True`` and converted to a fieldlist. The
    fieldlist is filtered with ``.sel(**sel_kwargs)`` and the selection is
    passed to `fieldlist_to_xarray`, whose result is returned unchanged.
    """
    source = ekd.from_source("file", file, lazily=True)
    selected = source.to_fieldlist().sel(**sel_kwargs)
    return fieldlist_to_xarray(selected)
Expand Down Expand Up @@ -584,16 +595,54 @@ def load_icon_baseline_from_grib(
reftime: datetime,
steps: list[int],
params: list[str],
member: str = "000",
) -> xr.Dataset:
"""Load an ICON-CH1-EPS or ICON-CH2-EPS baseline from the operational GRIB archive."""
return load_forecast_data_from_grib(
files=_collect_icon_archive_files(root, reftime, steps),
params=params,
)
"""Load an ICON-CH1-EPS or ICON-CH2-EPS baseline from the operational GRIB archive.

`member` selects which data to load:
- ``"mean"``: compute the average over all available ensemble members
- ``"median"``: load the pre-computed median member file from the archive
- ``"control"`` or ``"000"``: load the control member
- any 3-digit string (e.g. ``"001"``…): load that specific member
"""
if member == "control":
member = "000"
if member == "mean":
member_ids = _discover_icon_member_ids(root, reftime, steps)
LOG.info(
"Computing ensemble mean over %d members: %s", len(member_ids), member_ids
)
acc = None
n_loaded = 0
for mid in member_ids:
try:
ds = load_forecast_data_from_grib(
files=_collect_icon_archive_files(
root, reftime, steps, member_id=mid
),
params=params,
)
if "number" in ds.dims:
ds = ds.isel(number=0, drop=True)
acc = ds if acc is None else acc + ds
n_loaded += 1
except Exception as exc:
LOG.warning("Skipping member %s: %s", mid, exc)
if acc is None:
raise ValueError(
f"No ensemble members could be loaded for {reftime} from {root}"
)
LOG.info("Ensemble mean computed over %d members.", n_loaded)
return acc / n_loaded
else:
return load_forecast_data_from_grib(
files=_collect_icon_archive_files(root, reftime, steps, member_id=member),
params=params,
)


def load_forecast_data(
root, reftime: datetime, steps: list[int], params: list[str]
root, reftime: datetime, steps: list[int], params: list[str], member: str = "000"
) -> xr.Dataset:
"""Load forecast data from GRIB files or an ICON archive.

Expand All @@ -615,4 +664,4 @@ def load_forecast_data(
LOG.info("Loading INCA baseline from NetCDF files...")
return load_INCA_baseline_from_netcdf(root, reftime, steps, params)
LOG.info("Loading baseline forecasts from ICON GRIB archive...")
return load_icon_baseline_from_grib(root, reftime, steps, params)
return load_icon_baseline_from_grib(root, reftime, steps, params, member=member)
4 changes: 4 additions & 0 deletions src/evalml/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,10 @@ class BaselineConfig(BaseModel):
description="Forecast steps to be used from baseline, e.g. '10/120/1'.",
pattern=r"^\d*/\d*/\d*$",
)
member: str = Field(
"000",
description="Ensemble member to use: '000'/'control' for control, 'median' for the pre-computed median, 'mean' to average all members, or any 3-digit member ID.",
)


class TruthConfig(BaseModel):
Expand Down
2 changes: 2 additions & 0 deletions workflow/rules/verification.smk
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ rule verification_metrics_baseline:
params:
baseline_label=lambda wc: BASELINE_CONFIGS[wc.baseline_id].get("label"),
baseline_steps=lambda wc: BASELINE_CONFIGS[wc.baseline_id]["steps"],
member=lambda wc: BASELINE_CONFIGS[wc.baseline_id].get("member", "000"),
truth_label=config["truth"]["label"],
regions=REGIONS,
experiment_params=",".join(EXPERIMENT_PARAMS),
Expand All @@ -44,6 +45,7 @@ rule verification_metrics_baseline:
--regions "{params.regions}" \
--params "{params.experiment_params}" \
--threshold_dict "{params.threshold_dict}" \
--member "{params.member}" \
--output {output} >{log} 2>&1
"""

Expand Down
10 changes: 9 additions & 1 deletion workflow/scripts/verification_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@ def main(args: ScriptConfig):
# get baseline forecast data
now = datetime.now()

fcst = load_forecast_data(args.forecast, args.reftime, args.steps, args.params)
fcst = load_forecast_data(
args.forecast, args.reftime, args.steps, args.params, member=args.member
)

LOG.info(
"Loaded forecast data in %s seconds: \n%s",
Expand Down Expand Up @@ -147,6 +149,12 @@ def main(args: ScriptConfig):
help="Dictionary of thresholds for each parameter in the format '{param: [threshold1, threshold2, ...]}' (default: None).",
default=None,
)
# Ensemble member selector; the parsed value is passed to load_forecast_data()
# as `member`. The help text lists the 'control' alias as well, because the
# loader maps "control" to "000" and the config description advertises it.
parser.add_argument(
    "--member",
    type=str,
    default="000",
    help=(
        "Ensemble member to load: '000'/'control' for control, 'median' for the "
        "pre-computed median, 'mean' to average all members, or any 3-digit "
        "member ID."
    ),
)
parser.add_argument(
"--output",
type=Path,
Expand Down
6 changes: 6 additions & 0 deletions workflow/tools/config.schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,12 @@
"pattern": "^\\d*/\\d*/\\d*$",
"title": "Steps",
"type": "string"
},
"member": {
"default": "000",
"description": "Ensemble member to use: '000'/'control' for control, 'median' for the pre-computed median, 'mean' to average all members, or any 3-digit member ID.",
"title": "Member",
"type": "string"
}
},
"required": [
Expand Down