Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
cd511b9
read data with earthkit
frazane Dec 11, 2025
ebc5dec
optionally disable eccodes local definitions
frazane Dec 11, 2025
acc794f
add minimal configs
frazane Dec 11, 2025
695cfcb
Merge branch 'main' into feat/support-global-models
dnerini Mar 12, 2026
6f23152
Update config
dnerini Mar 12, 2026
e2553a1
Merge branch 'main' into feat/support-global-models
dnerini Mar 13, 2026
7445145
Merge branch 'main' into feat/support-global-models
dnerini Mar 20, 2026
091f75c
Change example config to stage B checkpoint
dnerini Mar 20, 2026
7197a75
Merge branch 'main' into feat/support-global-models
dnerini Apr 8, 2026
80f062c
Merge branch 'main' into feat/support-global-models
dnerini Apr 20, 2026
e1e0a7b
Merge branch 'main' into feat/support-global-models
dnerini Apr 20, 2026
a5ee466
Stage B models work
dnerini Apr 22, 2026
f3e9439
Fix config
dnerini Apr 22, 2026
df329cc
Support for tp
dnerini Apr 23, 2026
a474436
lint
dnerini Apr 23, 2026
36fd358
Merge branch 'main' into feat/support-global-models
dnerini May 5, 2026
561130c
Merge branch 'main' into feat/support-global-models
dnerini May 8, 2026
9c3f7dc
Linting
dnerini May 8, 2026
d0bc642
chore: temporary fix of dependencies. Before merging we need to deal …
MicheleCattaneo May 19, 2026
9eac3f4
fix: removed hardcoded verification domain equal to the ICON domain, …
MicheleCattaneo May 19, 2026
ad76047
Add lat lon information to ICON baselines (#150)
jonasbhend May 8, 2026
1caa848
Blacklist of forecast initializations (#151)
jonasbhend May 18, 2026
ee3a5df
fix docstrings
frazane May 19, 2026
f25d0ac
update dependencies
frazane May 19, 2026
79c34f1
fix config
frazane May 19, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions config/forecasters-ea-global.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# yaml-language-server: $schema=../workflow/tools/config.schema.json
description: |
Evaluate skill of a Stage B (global) checkpoint.

dates:
start: 2024-01-01T06:00
end: 2024-02-01T00:00
frequency: 30h

runs:
- forecaster:
checkpoint: https://service.meteoswiss.ch/mlstore#/experiments/602/runs/721576a388114c8788405cc6936ca13c
label: Stage B
steps: 0/120/6
config: resources/inference/configs/global-forecaster.yaml
extra_requirements:
- git+https://github.com/ecmwf/anemoi-inference.git@b9aaee5df86614cad9d8d08b76876a4be4e980db
disable_local_eccodes_definitions: true
inference_resources:
slurm_partition: preemptible
- forecaster:
checkpoint: https://service.meteoswiss.ch/mlstore#/experiments/602/runs/878f19c35f6644d7bd91bd9374e47b91
label: Stage B w/ subgrid
steps: 0/120/6
config: resources/inference/configs/global-forecaster.yaml
extra_requirements:
- git+https://github.com/ecmwf/anemoi-inference.git@b9aaee5df86614cad9d8d08b76876a4be4e980db
disable_local_eccodes_definitions: true
inference_resources:
slurm_partition: preemptible

truth:
label: ERA5
root: /store_new/mch/msopr/ml/datasets/aifs-ea-an-oper-0001-mars-n320-1979-2024-6h-v1-for-single-v2.zarr

stratification:
regions: []
root: /scratch/mch/bhendj/regions/Prognoseregionen_LV95_20220517

dashboard:
stratification:
# - init_hour
# - region
- season

locations:
output_root: output/

profile:
executor: slurm
global_resources:
gpus: 16
default_resources:
slurm_partition: "postproc"
cpus_per_task: 1
mem_mb_per_cpu: 1800
runtime: "1h"
gpus: 0
jobs: 50
62 changes: 62 additions & 0 deletions config/swissAI_o96_global.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# yaml-language-server: $schema=../workflow/tools/config.schema.json
description: |
Evaluation of the multistep O96 global models for project "Advancing Temporal Resolution (a0170)"

dates:
start: 2024-01-01T06:00
end: 2024-02-01T00:00
frequency: 30h
# - 2024-02-01T00:00

runs:
- forecaster:
checkpoint: /scratch/mch/miccatta/SwissAI_checkpoints/checkpoint_resO96_fr1_st1_in7_out6/a689d740f37642c38fd01a39cdbde96f/inference-last.ckpt
label: resO96_fr1_st1_in7_out6
steps: 0/120/1
config: resources/inference/configs/o96-global-multistep-forecaster.yaml
extra_requirements:
# - git+https://github.com/ecmwf/anemoi-inference.git@b9aaee5df86614cad9d8d08b76876a4be4e980db # TODO change
- git+https://github.com/ecmwf/anemoi-inference.git@2c249caa75af191ff3fd3b2a6a6297459d534a55
disable_local_eccodes_definitions: true
# inference_resources:
# slurm_partition: preemptible
- forecaster:
checkpoint: /scratch/mch/miccatta/SwissAI_checkpoints/checkpoint_resO96_fr1_st1_in2_out1/877cb9559f6c484d97b3733b7481c39a/inference-last.ckpt
label: resO96_fr1_st1_in2_out1
steps: 0/120/1
config: resources/inference/configs/o96-global-multistep-forecaster.yaml
extra_requirements:
# - git+https://github.com/ecmwf/anemoi-inference.git@b9aaee5df86614cad9d8d08b76876a4be4e980db # TODO change
- git+https://github.com/ecmwf/anemoi-inference.git@2c249caa75af191ff3fd3b2a6a6297459d534a55
disable_local_eccodes_definitions: true


truth:
label: ERA5-o96
# root: /store_new/mch/msopr/ml/datasets/aifs-ea-an-oper-0001-mars-n320-1979-2024-1h-v2-with-era51.zarr
root: /store_new/mch/msopr/ml/datasets/aifs-ea-an-oper-0001-mars-o96-1979-2024-1h-v3-with-era51.zarr/

stratification:
type: global
regions: []

dashboard:
stratification:
# - init_hour
# - region
- season

locations:
output_root: output/

profile:
executor: slurm
global_resources:
gpus: 16
default_resources:
slurm_partition: "postproc"
cpus_per_task: 1
mem_mb_per_cpu: 1800
runtime: "1h"
gpus: 0
jobs: 50
14 changes: 5 additions & 9 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ dependencies = [
"snakemake<9.10", # https://github.com/snakemake/snakemake/issues/3289
"snakemake-executor-plugin-slurm",
"click",
"meteodata-lab>=0.4.0",
"anemoi-datasets>=0.5.31",
"mlflow>=3.1.1",
"pydantic>=2.11.7",
Expand All @@ -25,16 +24,11 @@ dependencies = [
"marimo>=0.23.3",
"geopandas>=0.14.0",
"peakweather",
"earthkit-data==1.0.0rc11",
"pyzmq>=27.1.0",
"scores>=2.0.0",
]

[project.optional-dependencies]
kerchunk = [
"kerchunk",
"zarr<3.0.0",
"fastparquet",
"ujson"
"eccodes<2.47",
"eccodes-cosmo-resources-python-internal",
]

[project.scripts]
Expand Down Expand Up @@ -64,3 +58,5 @@ packages = [

[tool.uv.sources]
peakweather = { git = "https://github.com/MeteoSwiss/PeakWeather.git" }
earthkit-data = { git = "https://github.com/ecmwf/earthkit-data" }
eccodes-cosmo-resources-python-internal = { git = "https://github.com/MeteoSwiss/eccodes-cosmo-resources-python-internal" }
14 changes: 14 additions & 0 deletions resources/inference/configs/global-forecaster.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
input:
test:
use_original_paths: true

allow_nans: true

output:
grib:
path: grib/{date}{time:04}_{step:03}.grib
negative_step_mode: skip
# templates:
# samples: resources/templates_index_global-ea.yaml

write_initial_state: true
25 changes: 25 additions & 0 deletions resources/inference/configs/o96-global-multistep-forecaster.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
input:
test:
use_original_paths: true

allow_nans: true

output:
grib:
path: grib/{date}{time:04}_{step:03}.grib
negative_step_mode: skip
# templates:
# samples: resources/templates_index_global-ea.yaml

write_initial_state: true

patch_metadata:
config:
dataloader:
test:
datasets:
data:
dataset_config:
dataset: /store_new/mch/msopr/ml/datasets/aifs-ea-an-oper-0001-mars-o96-1979-2024-1h-v3-with-era51.zarr/
start: null
end: null
105 changes: 88 additions & 17 deletions src/data_input/__init__.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,38 @@
import yaml
import logging
import os
import sys
from datetime import datetime, timedelta
from pathlib import Path

eccodes_definition_path = Path(sys.prefix) / "share/eccodes-cosmo-resources/definitions"
os.environ["ECCODES_DEFINITION_PATH"] = str(eccodes_definition_path)

from meteodatalab import data_source, grib_decoder # noqa: E402
from functools import lru_cache

import numpy as np # noqa: E402
import xarray as xr # noqa: E402
import earthkit.data as ekd # noqa: E402

LOG = logging.getLogger(__name__)

# IFS shortNames that COSMO eccodes definitions don't remap to COSMO names.
# These need explicit aliasing so callers can use COSMO names consistently.
_IFS_TO_COSMO = {
"tp": "TOT_PREC",
"msl": "PMSL",
"10u": "U_10M",
"10v": "V_10M",
"2t": "T_2M",
"2d": "TD_2M",
"sp": "PS",
"lsm": "FR_LAND",
"z": "FIS",
}
_COSMO_TO_IFS = {v: k for k, v in _IFS_TO_COSMO.items()}


@lru_cache(maxsize=1)
def earthkit_xarray_engine_profile() -> dict:
fn = Path(__file__).parent / "profile.yaml"
with open(fn) as f:
profile = yaml.safe_load(f)
return profile


def _select_valid_times(ds, times: np.datetime64):
# (handle special case where some valid times are not in the dataset, e.g. at the end)
Expand Down Expand Up @@ -65,7 +84,12 @@ def load_analysis_data_from_zarr(
PARAMS_MAP_COSMO1 = {
v: v.replace("TOT_PREC", tot_prec_string) for v in PARAMS_MAP_COSMO2.keys()
}
PARAMS_MAP = PARAMS_MAP_COSMO2 if "co2" in root.name else PARAMS_MAP_COSMO1
USE_IFS_NAMES = {"-co2-", "-ea-"}
PARAMS_MAP = (
PARAMS_MAP_COSMO2
if any(tag in root.name for tag in USE_IFS_NAMES)
else PARAMS_MAP_COSMO1
)

ds = xr.open_zarr(root, consolidated=False)

Expand All @@ -79,7 +103,7 @@ def load_analysis_data_from_zarr(
ds = ds.sel(variable=[PARAMS_MAP[p] for p in params]).squeeze("ensemble", drop=True)

# recover original 2D shape
if len(ds.attrs["field_shape"]) == 2:
if "field_shape" in ds.attrs and len(ds.attrs["field_shape"]) == 2:
ny, nx = ds.attrs["field_shape"]
y_idx, x_idx = np.unravel_index(np.arange(ny * nx), shape=(ny, nx))
ds = ds.assign_coords({"y": ("cell", y_idx), "x": ("cell", x_idx)})
Expand All @@ -89,7 +113,10 @@ def load_analysis_data_from_zarr(
# set lat lon as coords (optional)
if "latitudes" in ds and "longitudes" in ds:
ds = ds.rename({"latitudes": "lat", "longitudes": "lon"})
ds = ds.set_coords(["lat", "lon"])
if "latitude" in ds and "longitude" in ds:
ds = ds.rename({"latitude": "lat", "longitude": "lon"})
if "lat" in ds and "lon" in ds:
ds = ds.set_coords(["lat", "lon"])
ds = (
ds["data"]
.to_dataset("variable")
Expand All @@ -112,9 +139,43 @@ def load_fct_data_from_grib(
) -> xr.Dataset:
"""Load forecast data from GRIB files for a specific valid time."""
files = sorted(root.glob(f"{reftime:%Y%m%d%H%M}*.grib"))
fds = data_source.FileDataSource(datafiles=files)
ds = grib_decoder.load(fds, {"param": params, "step": steps})
for var, da in ds.items():

profile = earthkit_xarray_engine_profile()
LOG.debug(f"loading GRIB for params {params} and steps {steps} from {root}")
# Extend param selection to include IFS aliases (e.g. "tp" for "TOT_PREC") so
# that both COSMO-named and IFS-named GRIB files (global models) are handled.
params_sel = list(
{p for p in params} | {_COSMO_TO_IFS[p] for p in params if p in _COSMO_TO_IFS}
)
# Precipitation params don't have a step=0 field (accumulation is zero at
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suppose this was previously handled by meteodata-lab?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess this didn't occur previously because we haven't read from the global files at all. All our lam cutouts use COSMO names.

# analysis time and is often not written), so loading them together with
# other variables causes an inconsistent-step error in earthkit-data.
# Load them separately with step>0, then merge.
_PREC_PARAMS = {"tp", "TOT_PREC"}
prec_params = [p for p in params_sel if p in _PREC_PARAMS]
other_params = [p for p in params_sel if p not in _PREC_PARAMS]
fieldlist = ekd.from_source("file", files)
datasets = []
if other_params:
datasets.append(
fieldlist.sel(param=other_params, step=steps).to_xarray(profile=profile)
)
if prec_params:
prec_steps = [s for s in steps if s > 0]
datasets.append(
fieldlist.sel(param=prec_params, step=prec_steps).to_xarray(profile=profile)
)
ds: xr.Dataset = (
xr.merge(datasets, join="outer") if len(datasets) > 1 else datasets[0]
)
# Rename any IFS names back to COSMO names
ifs_rename = {
ifs: cosmo for ifs, cosmo in _IFS_TO_COSMO.items() if ifs in ds.data_vars
}
if ifs_rename:
ds = ds.rename(ifs_rename)

for var, da in ds.data_vars.items():
if "z" in da.dims and da.sizes["z"] == 1:
ds[var] = da.squeeze("z", drop=True)
elif "z" in da.dims and da.sizes["z"] > 1:
Expand Down Expand Up @@ -146,6 +207,7 @@ def load_fct_data_from_grib(
ds.TOT_PREC,
)
)

## Sanity-check that the incoming data is actually cumulative: if
## .diff() produces significantly negative values, TOT_PREC is already
## period-accumulated and a second disaggregation would produce
Expand All @@ -166,6 +228,15 @@ def load_fct_data_from_grib(
## small float-noise negatives to zero (anything below -0.1 mm has
## already been caught by the check above).
ds = ds.assign(TOT_PREC=diff.clip(min=0.0).reindex(lead_time=lead_times))

# set lat lon as coords (optional)
if "latitudes" in ds and "longitudes" in ds:
ds = ds.rename({"latitudes": "lat", "longitudes": "lon"})
if "latitude" in ds and "longitude" in ds:
ds = ds.rename({"latitude": "lat", "longitude": "lon"})
if "lat" in ds and "lon" in ds:
ds = ds.set_coords(["lat", "lon"])

# make sure time coordinate is available, and valid_time is not
if "valid_time" in ds.coords:
ds = ds.rename({"valid_time": "time"})
Expand Down Expand Up @@ -288,7 +359,7 @@ def load_truth_data(
) -> xr.Dataset:
"""Load truth data from analysis Zarr or PeakWeather observations."""
if root.suffix == ".zarr":
LOG.info("Loading ground truth from an analysis zarr dataset...")
LOG.info(f"Loading ground truth from {root}")
truth = load_analysis_data_from_zarr(
root=root,
reftime=reftime,
Expand All @@ -301,7 +372,7 @@ def load_truth_data(
else {"values": -1}
)
elif "peakweather" in str(root):
LOG.info("Loading ground truth from PeakWeather observations...")
LOG.info(f"Loading ground truth from {root}")
truth = load_obs_data_from_peakweather(
root=root,
reftime=reftime,
Expand All @@ -319,15 +390,15 @@ def load_forecast_data(
"""Load forecast data from GRIB files or a baseline Zarr dataset."""

if any(root.glob("*.grib")):
LOG.info("Loading forecasts from GRIB files...")
LOG.info(f"Loading forecasts from GRIB files from {root}")
fcst = load_fct_data_from_grib(
root=root,
reftime=reftime,
steps=steps,
params=params,
)
else:
LOG.info("Loading baseline forecasts from zarr dataset...")
LOG.info(f"Loading baseline forecasts from zarr dataset from {root}")
fcst = load_baseline_from_zarr(
root=root,
reftime=reftime,
Expand Down
Loading
Loading