diff --git a/src/verification/spatial.py b/src/verification/spatial.py index a5186d5f..cb06fe70 100644 --- a/src/verification/spatial.py +++ b/src/verification/spatial.py @@ -121,7 +121,25 @@ def map_forecast_to_truth(fcst: xr.Dataset, truth: xr.Dataset) -> xr.Dataset: xr.Dataset Mapped forecast dataset. """ - # TODO: return fcst unchanged when forecast and truth are already aligned + fcst_lat = fcst["lat"].values + fcst_lon = fcst["lon"].values + truth_lat = truth["lat"].values + truth_lon = truth["lon"].values + if ( + fcst_lat.shape == truth_lat.shape + and fcst_lon.shape == truth_lon.shape + and np.max(np.abs(fcst_lat - truth_lat)) < 1e-6 + and np.max(np.abs(fcst_lon - truth_lon)) < 1e-6 + ): + if np.array_equal(fcst_lat, truth_lat) and np.array_equal(fcst_lon, truth_lon): + return fcst + coords = { + "lat": (fcst["lat"].dims, truth["lat"].data), + "lon": (fcst["lon"].dims, truth["lon"].data), + } + if "values" in fcst.dims and "values" in truth.dims: + coords["values"] = truth["values"].data + return fcst.assign_coords(coords) truth_is_grid = "y" in truth.dims and "x" in truth.dims diff --git a/tests/unit/test_spatial_mapping.py b/tests/unit/test_spatial_mapping.py index 73d56954..fa277665 100644 --- a/tests/unit/test_spatial_mapping.py +++ b/tests/unit/test_spatial_mapping.py @@ -91,6 +91,115 @@ def test_map_forecast_to_truth_maps_forecast_to_truth_locations(): ) +def test_map_forecast_to_truth_returns_fcst_unchanged_when_grids_are_aligned(): + fcst_time = np.array(["2024-01-01T00:00"], dtype="datetime64[ns]") + lat = np.array([[46.0, 46.0], [47.0, 47.0]]) + lon = np.array([[7.0, 8.0], [7.0, 8.0]]) + + fcst = xr.Dataset( + data_vars={"T_2M": (("time", "y", "x"), np.array([[[1.0, 2.0], [3.0, 4.0]]]))}, + coords={ + "time": fcst_time, + "y": [0, 1], + "x": [0, 1], + "lat": (("y", "x"), lat), + "lon": (("y", "x"), lon), + }, + ) + truth = xr.Dataset( + data_vars={"T_2M": (("time", "y", "x"), np.zeros((1, 2, 2)))}, + coords={ + "time": fcst_time, + "y": [0, 1], + "x": [0, 1], + "lat": (("y", "x"), lat), + "lon": (("y", "x"), lon), + }, + ) + + result = map_forecast_to_truth(fcst, truth) + _, result_aligned = xr.align(truth, result) + + assert result is fcst + assert result["T_2M"].values is fcst["T_2M"].values + assert np.array_equal(result["lat"].values, truth["lat"].values) + assert np.array_equal(result["lon"].values, truth["lon"].values) + assert np.array_equal(result_aligned["T_2M"].values, fcst["T_2M"].values) + + +def test_map_forecast_to_truth_returns_fcst_unchanged_when_grids_are_within_tolerance(): + fcst_time = np.array(["2024-01-01T00:00"], dtype="datetime64[ns]") + lat = np.array([[46.0, 46.0], [47.0, 47.0]]) + lon = np.array([[7.0, 8.0], [7.0, 8.0]]) + + fcst = xr.Dataset( + data_vars={"T_2M": (("time", "y", "x"), np.array([[[1.0, 2.0], [3.0, 4.0]]]))}, + coords={ + "time": fcst_time, + "y": [0, 1], + "x": [0, 1], + "lat": (("y", "x"), lat + 5e-8), + "lon": (("y", "x"), lon - 5e-8), + }, + ) + # Nudge coordinates by less than the 1e-6 tolerance — should still be treated as aligned. + truth = xr.Dataset( + data_vars={"T_2M": (("time", "y", "x"), np.zeros((1, 2, 2)))}, + coords={ + "time": fcst_time, + "y": [0, 1], + "x": [0, 1], + "lat": (("y", "x"), lat), + "lon": (("y", "x"), lon), + }, + ) + + result = map_forecast_to_truth(fcst, truth) + _, result_aligned = xr.align(truth, result) + + assert result is not fcst + assert result["T_2M"].values is fcst["T_2M"].values + assert np.array_equal(result["lat"].values, truth["lat"].values) + assert np.array_equal(result["lon"].values, truth["lon"].values) + assert np.array_equal(result_aligned["T_2M"].values, fcst["T_2M"].values) + + +def test_map_forecast_to_truth_returns_fcst_unchanged_when_grids_are_within_tolerance_icon(): + fcst_time = np.array(["2024-01-01T00:00"], dtype="datetime64[ns]") + lat = np.array([[46.0, 46.0], [47.0, 47.0]]).flatten() + lon = np.array([[7.0, 8.0], [7.0, 8.0]]).flatten() + + fcst = xr.Dataset( + data_vars={"T_2M": (("time", "values"), np.array([[1.0, 2.0, 3.0, 4.0]]))}, + coords={ + "time": fcst_time, + "values": [0, 1, 2, 3], + "lat": (("values"), lat + 5e-8), + "lon": (("values"), lon - 5e-8), + }, + ) + # Nudge coordinates by less than the 1e-6 tolerance — should still be treated as aligned. + truth = xr.Dataset( + data_vars={"T_2M": (("time", "values"), np.zeros((1, 4)))}, + coords={ + "time": fcst_time, + "values": [3, 1, 2, 0], + "lat": (("values"), lat), + "lon": (("values"), lon), + }, + ) + + result = map_forecast_to_truth(fcst, truth) + _, result_aligned = xr.align(truth, result) + + assert result is not fcst + assert result["T_2M"].values is fcst["T_2M"].values + assert np.array_equal(result["lat"].values, truth["lat"].values) + assert np.array_equal(result["lon"].values, truth["lon"].values) + assert np.array_equal(result["values"].values, truth["values"].values) + assert np.array_equal(result_aligned["T_2M"].values, fcst["T_2M"].values) + + def test_map_forecast_to_truth_restores_grid_when_truth_is_gridded(): fcst_time = np.array(["2024-01-01T00:00"], dtype="datetime64[ns]")