Skip to content
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions xcdat/axis.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from typing import Literal

import cf_xarray as cfxr # noqa: F401
import numpy as np
import xarray as xr

Expand Down Expand Up @@ -73,7 +74,7 @@ def get_dim_keys(obj: xr.Dataset | xr.DataArray, axis: CFAxisKey) -> str | list[


def get_dim_coords(
obj: xr.Dataset | xr.DataArray, axis: CFAxisKey
obj: xr.Dataset | xr.DataArray, axis: CFAxisKey, multidim: bool = False
) -> xr.Dataset | xr.DataArray:
"""Gets the dimension coordinates for an axis.

Expand Down Expand Up @@ -117,10 +118,14 @@ def get_dim_coords(
----------
.. [1] https://cf-xarray.readthedocs.io/en/latest/coord_axes.html#axes-and-coordinates
"""
# Get the object's index keys, with each being a dimension.
# NOTE: xarray does not include multidimensional coordinates as index keys.
# Example: ["lat", "lon", "time"]
index_keys = obj.indexes.keys()
if multidim:
# multidimensional coordinates cannot be indexes, use all coords
index_keys = list([y for x in obj.cf.coordinates.values() for y in x])
else:
# Get the object's index keys, with each being a dimension.
# NOTE: xarray does not include multidimensional coordinates as index keys.
# Example: ["lat", "lon", "time"]
index_keys = list(obj.indexes.keys())
Comment on lines +121 to +135

@tomvothecoder tomvothecoder Apr 28, 2026

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, this is more robust than only using obj.cf.coordinates.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Combined cf.coordinates and cf.axes, should have complete coverage.


# Attempt to map the axis it all of its coordinate variable(s) using the
# axis and coordinate names in the object attributes (if they are set).
Expand Down
43 changes: 28 additions & 15 deletions xcdat/regridder/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,17 @@
from xcdat.axis import CFAxisKey, get_coords_by_name, get_dim_coords
from xcdat.bounds import create_bounds
from xcdat.regridder import regrid2, xesmf, xgcm
from xcdat.regridder.base import BaseRegridder
from xcdat.regridder.grid import _validate_grid_has_single_axis_dim

HorizontalRegridTools = Literal["xesmf", "regrid2"]
HORIZONTAL_REGRID_TOOLS = {
HORIZONTAL_REGRID_TOOLS: dict[str, type[BaseRegridder]] = {
"regrid2": regrid2.Regrid2Regridder,
"xesmf": xesmf.XESMFRegridder,
}

VerticalRegridTools = Literal["xgcm"]
VERTICAL_REGRID_TOOLS = {"xgcm": xgcm.XGCMRegridder}
VERTICAL_REGRID_TOOLS: dict[str, type[BaseRegridder]] = {"xgcm": xgcm.XGCMRegridder}


@xr.register_dataset_accessor(name="regridder")
Expand Down Expand Up @@ -166,7 +167,9 @@ def horizontal(
f"Tool {e!s} does not exist, valid choices {list(HORIZONTAL_REGRID_TOOLS)}"
) from e

input_grid = _get_input_grid(self._ds, data_var, ["X", "Y"])
input_grid = _get_input_grid(
self._ds, data_var, ["X", "Y"], multidim=regrid_tool.can_handle_multidim()
)
Comment on lines +170 to +172

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test needed.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added missing test.

regridder = regrid_tool(input_grid, output_grid, **options)
output_ds = regridder.horizontal(data_var, self._ds)

Expand Down Expand Up @@ -236,20 +239,17 @@ def vertical(
f"Tool {e!s} does not exist, valid choices "
f"{list(VERTICAL_REGRID_TOOLS)}"
) from e
input_grid = _get_input_grid(
self._ds,
data_var,
[
"Z",
],
)

input_grid = _get_input_grid(self._ds, data_var, ["Z"])
regridder = regrid_tool(input_grid, output_grid, **options)
output_ds = regridder.vertical(data_var, self._ds)

return output_ds


def _obj_to_grid_ds(obj: xr.Dataset | xr.DataArray) -> xr.Dataset:
def _obj_to_grid_ds(
obj: xr.Dataset | xr.DataArray, multidim: bool = False
) -> xr.Dataset:
"""
Convert an xarray object to a new Dataset containing axis coordinates and
bounds.
Expand Down Expand Up @@ -304,14 +304,20 @@ def _obj_to_grid_ds(obj: xr.Dataset | xr.DataArray) -> xr.Dataset:
attrs=obj.attrs,
)

# Multidimensional coordinates bounds generation is not supported
if multidim:
return output_ds
Comment on lines +307 to +309

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think Copilot is referencing the logic before this conditional in _obj_to_grid_ds():

axis_names: list[CFAxisKey] = ["X", "Y", "Z"]
axis_coords: dict[str, xr.DataArray] = {}
axis_bounds: dict[str, xr.DataArray] = {}
axis_has_bounds: dict[CFAxisKey, bool] = {}
with xr.set_options(keep_attrs=True):
for axis in axis_names:
coord, bounds = _get_axis_coord_and_bounds(obj, axis)
if coord is not None:
axis_coords[str(coord.name)] = coord
if bounds is not None:
axis_bounds[str(bounds.name)] = bounds
axis_has_bounds[axis] = True
else:
axis_has_bounds[axis] = False
# Create a new dataset with coordinates and bounds
output_ds = xr.Dataset(
coords=axis_coords,
data_vars=axis_bounds,
attrs=obj.attrs,
)


# Add bounds only for axes that do not already have them. This
# prevents multiple sets of bounds being added for the same axis.
# For example, curvilinear grids can have multiple coordinates for the
# same axis (e.g., (nlat, lat) for X and (nlon, lon) for Y). We only
# need lat_bnds and lon_bnds for the X and Y axes, respectively, and not
# nlat_bnds and nlon_bnds.

for axis, has_bounds in axis_has_bounds.items():
if not has_bounds:
# FIXME: Line 313 --Can't add bounds for multidimensional coordinates

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this #FIXME still relevant or should it be a #TODO? Also exact line number should be removed.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No longer relevant, will remove.

output_ds = output_ds.bounds.add_bounds(axis=axis)

return output_ds
Expand Down Expand Up @@ -347,7 +353,12 @@ def _get_axis_coord_and_bounds(
return coord_var, bounds_var


def _get_input_grid(ds: xr.Dataset, data_var: str, dup_check_dims: list[CFAxisKey]):
def _get_input_grid(
ds: xr.Dataset,
data_var: str,
dup_check_dims: list[CFAxisKey],
multidim: bool = False,
):
"""
Extract the grid from ``ds``.

Expand All @@ -374,10 +385,12 @@ def _get_input_grid(ds: xr.Dataset, data_var: str, dup_check_dims: list[CFAxisKe
all_coords = set(ds.coords.keys())

for dimension in dup_check_dims:
coords = get_dim_coords(ds, dimension)
coords = get_dim_coords(ds, dimension, multidim=multidim)

if isinstance(coords, xr.Dataset):
coord = set([get_dim_coords(ds[data_var], dimension).name])
coord = set(
[get_dim_coords(ds[data_var], dimension, multidim=multidim).name]
)

dimension_coords = set(ds.cf[[dimension]].coords.keys())

Expand All @@ -387,7 +400,7 @@ def _get_input_grid(ds: xr.Dataset, data_var: str, dup_check_dims: list[CFAxisKe
input_grid = ds.drop_dims(to_drop)

# drops extra dimensions on input grid
grid = input_grid.regridder.grid
grid = _obj_to_grid_ds(input_grid, multidim=multidim)

# preserve mask on grid
if "mask" in ds:
Expand Down
6 changes: 6 additions & 0 deletions xcdat/regridder/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,12 @@ def _drop_axis(ds: xr.Dataset, axis: list[CFAxisKey]) -> xr.Dataset:
class BaseRegridder(abc.ABC):
"""BaseRegridder."""

supports_multidim: bool

@classmethod
def can_handle_multidim(cls) -> bool:
return cls.supports_multidim

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed, defaulting to False.

def __init__(self, input_grid: xr.Dataset, output_grid: xr.Dataset, **options: Any):
self._input_grid = input_grid
self._output_grid = output_grid
Expand Down
2 changes: 2 additions & 0 deletions xcdat/regridder/regrid2.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@


class Regrid2Regridder(BaseRegridder):
supports_multidim = False

def __init__(
self,
input_grid: xr.Dataset,
Expand Down
2 changes: 2 additions & 0 deletions xcdat/regridder/xesmf.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@


class XESMFRegridder(BaseRegridder):
supports_multidim = True

def __init__(
self,
input_grid: xr.Dataset,
Expand Down
2 changes: 2 additions & 0 deletions xcdat/regridder/xgcm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@


class XGCMRegridder(BaseRegridder):
supports_multidim = False

def __init__(
self,
input_grid: xr.Dataset,
Expand Down
Loading