From f708df786b979df33e512d9147d81f37d2317924 Mon Sep 17 00:00:00 2001 From: Louis-Frey Date: Tue, 7 Apr 2026 11:12:26 +0200 Subject: [PATCH 1/4] Skip kd-tree remapping in map_forecast_to_truth when grids are already aligned Avoids O(N log N) kd-tree build and query (~1M points at 1km resolution) when forecast and truth share the same grid, reducing baseline spatial verification from ~20 minutes to near-instantaneous. --- src/verification/spatial.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/verification/spatial.py b/src/verification/spatial.py index a5186d5f..cdae3546 100644 --- a/src/verification/spatial.py +++ b/src/verification/spatial.py @@ -121,7 +121,17 @@ def map_forecast_to_truth(fcst: xr.Dataset, truth: xr.Dataset) -> xr.Dataset: xr.Dataset Mapped forecast dataset. """ - # TODO: return fcst unchanged when forecast and truth are already aligned + fcst_lat = fcst["lat"].values + fcst_lon = fcst["lon"].values + truth_lat = truth["lat"].values + truth_lon = truth["lon"].values + if ( + fcst_lat.shape == truth_lat.shape + and fcst_lon.shape == truth_lon.shape + and np.max(np.abs(fcst_lat - truth_lat)) < 1e-6 + and np.max(np.abs(fcst_lon - truth_lon)) < 1e-6 + ): + return fcst truth_is_grid = "y" in truth.dims and "x" in truth.dims From f14106158694513755594b11253d9554b8779264 Mon Sep 17 00:00:00 2001 From: Louis-Frey Date: Tue, 7 Apr 2026 11:14:44 +0200 Subject: [PATCH 2/4] Add test for map_forecast_to_truth fast path when grids are aligned --- tests/unit/test_spatial_mapping.py | 31 ++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/tests/unit/test_spatial_mapping.py b/tests/unit/test_spatial_mapping.py index 73d56954..0060b3ce 100644 --- a/tests/unit/test_spatial_mapping.py +++ b/tests/unit/test_spatial_mapping.py @@ -91,6 +91,37 @@ def test_map_forecast_to_truth_maps_forecast_to_truth_locations(): ) +def test_map_forecast_to_truth_returns_fcst_unchanged_when_grids_are_aligned(): + fcst_time = np.array(["2024-01-01T00:00"], dtype="datetime64[ns]") + lat = np.array([[46.0, 46.0], [47.0, 47.0]]) + lon = np.array([[7.0, 8.0], [7.0, 8.0]]) + + fcst = xr.Dataset( + data_vars={"T_2M": (("time", "y", "x"), np.array([[[1.0, 2.0], [3.0, 4.0]]]))}, + coords={ + "time": fcst_time, + "y": [0, 1], + "x": [0, 1], + "lat": (("y", "x"), lat), + "lon": (("y", "x"), lon), + }, + ) + truth = xr.Dataset( + data_vars={"T_2M": (("time", "y", "x"), np.zeros((1, 2, 2)))}, + coords={ + "time": fcst_time, + "y": [0, 1], + "x": [0, 1], + "lat": (("y", "x"), lat), + "lon": (("y", "x"), lon), + }, + ) + + result = map_forecast_to_truth(fcst, truth) + + assert result is fcst + + def test_map_forecast_to_truth_restores_grid_when_truth_is_gridded(): fcst_time = np.array(["2024-01-01T00:00"], dtype="datetime64[ns]") From 4cc8d4dbe8ba694fcefdde44817d8e7d7b51f2e6 Mon Sep 17 00:00:00 2001 From: Jonas Bhend Date: Mon, 1 Jun 2026 09:10:25 +0200 Subject: [PATCH 3/4] catch edge case --- src/verification/spatial.py | 10 +++- tests/unit/test_spatial_mapping.py | 73 ++++++++++++++++++++++++++++++ 2 files changed, 82 insertions(+), 1 deletion(-) diff --git a/src/verification/spatial.py b/src/verification/spatial.py index cdae3546..cb06fe70 100644 --- a/src/verification/spatial.py +++ b/src/verification/spatial.py @@ -131,7 +131,15 @@ def map_forecast_to_truth(fcst: xr.Dataset, truth: xr.Dataset) -> xr.Dataset: and np.max(np.abs(fcst_lat - truth_lat)) < 1e-6 and np.max(np.abs(fcst_lon - truth_lon)) < 1e-6 ): - return fcst + if np.array_equal(fcst_lat, truth_lat) and np.array_equal(fcst_lon, truth_lon): + return fcst + coords = { + "lat": (fcst["lat"].dims, truth["lat"].data), + "lon": (fcst["lon"].dims, truth["lon"].data), + } + if "values" in fcst.dims and "values" in truth.dims: + coords["values"] = truth["values"].data + return fcst.assign_coords(coords) truth_is_grid = "y" in truth.dims and "x" in truth.dims diff --git a/tests/unit/test_spatial_mapping.py b/tests/unit/test_spatial_mapping.py index 0060b3ce..ceeba0e3 100644 --- a/tests/unit/test_spatial_mapping.py +++ b/tests/unit/test_spatial_mapping.py @@ -122,6 +122,79 @@ def test_map_forecast_to_truth_returns_fcst_unchanged_when_grids_are_aligned(): assert result is fcst +def test_map_forecast_to_truth_returns_fcst_unchanged_when_grids_are_within_tolerance(): + fcst_time = np.array(["2024-01-01T00:00"], dtype="datetime64[ns]") + lat = np.array([[46.0, 46.0], [47.0, 47.0]]) + lon = np.array([[7.0, 8.0], [7.0, 8.0]]) + + fcst = xr.Dataset( + data_vars={"T_2M": (("time", "y", "x"), np.array([[[1.0, 2.0], [3.0, 4.0]]]))}, + coords={ + "time": fcst_time, + "y": [0, 1], + "x": [0, 1], + "lat": (("y", "x"), lat + 5e-8), + "lon": (("y", "x"), lon - 5e-8), + }, + ) + # Nudge coordinates by less than the 1e-6 tolerance — should still be treated as aligned. + truth = xr.Dataset( + data_vars={"T_2M": (("time", "y", "x"), np.zeros((1, 2, 2)))}, + coords={ + "time": fcst_time, + "y": [0, 1], + "x": [0, 1], + "lat": (("y", "x"), lat), + "lon": (("y", "x"), lon), + }, + ) + + result = map_forecast_to_truth(fcst, truth) + _, result_aligned = xr.align(truth, result) + + assert result is not fcst + assert result["T_2M"].values is fcst["T_2M"].values + assert np.array_equal(result["lat"].values, truth["lat"].values) + assert np.array_equal(result["lon"].values, truth["lon"].values) + assert np.array_equal(result_aligned["T_2M"].values, fcst["T_2M"].values) + + +def test_map_forecast_to_truth_returns_fcst_unchanged_when_grids_are_within_tolerance_icon(): + fcst_time = np.array(["2024-01-01T00:00"], dtype="datetime64[ns]") + lat = np.array([[46.0, 46.0], [47.0, 47.0]]).flatten() + lon = np.array([[7.0, 8.0], [7.0, 8.0]]).flatten() + + fcst = xr.Dataset( + data_vars={"T_2M": (("time", "values"), np.array([[1.0, 2.0, 3.0, 4.0]]))}, + coords={ + "time": fcst_time, + "values": [0, 1, 2, 3], + "lat": (("values"), lat + 5e-8), + "lon": (("values"), lon - 5e-8), + }, + ) + # Nudge coordinates by less than the 1e-6 tolerance — should still be treated as aligned. + truth = xr.Dataset( + data_vars={"T_2M": (("time", "values"), np.zeros((1, 4)))}, + coords={ + "time": fcst_time, + "values": [3, 1, 2, 0], + "lat": (("values"), lat), + "lon": (("values"), lon), + }, + ) + + result = map_forecast_to_truth(fcst, truth) + _, result_aligned = xr.align(truth, result) + + assert result is not fcst + assert result["T_2M"].values is fcst["T_2M"].values + assert np.array_equal(result["lat"].values, truth["lat"].values) + assert np.array_equal(result["lon"].values, truth["lon"].values) + assert np.array_equal(result["values"].values, truth["values"].values) + assert np.array_equal(result_aligned["T_2M"].values, fcst["T_2M"].values) + + def test_map_forecast_to_truth_restores_grid_when_truth_is_gridded(): fcst_time = np.array(["2024-01-01T00:00"], dtype="datetime64[ns]") From aa70e539e326ebf0f10e298abc5230cc45b81a99 Mon Sep 17 00:00:00 2001 From: Jonas Bhend Date: Mon, 1 Jun 2026 09:23:48 +0200 Subject: [PATCH 4/4] expand original test --- tests/unit/test_spatial_mapping.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/unit/test_spatial_mapping.py b/tests/unit/test_spatial_mapping.py index ceeba0e3..fa277665 100644 --- a/tests/unit/test_spatial_mapping.py +++ b/tests/unit/test_spatial_mapping.py @@ -118,8 +118,13 @@ def test_map_forecast_to_truth_returns_fcst_unchanged_when_grids_are_aligned(): ) result = map_forecast_to_truth(fcst, truth) + _, result_aligned = xr.align(truth, result) assert result is fcst + assert result["T_2M"].values is fcst["T_2M"].values + assert np.array_equal(result["lat"].values, truth["lat"].values) + assert np.array_equal(result["lon"].values, truth["lon"].values) + assert np.array_equal(result_aligned["T_2M"].values, fcst["T_2M"].values) def test_map_forecast_to_truth_returns_fcst_unchanged_when_grids_are_within_tolerance():