diff --git a/CHANGELOG.md b/CHANGELOG.md index ae4b9886..4a6c7bb5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed - [issue/405] (https://github.com/podaac/l2ss-py/issues/405) Chunk logic now handles data without dimensions in root node. - Optimize performance of ScanTime variable computation +- [issue/435] (https://github.com/podaac/l2ss-py/issues/435) Extra unneeded dimension scales are no longer included in output ### Security - Updated dependency libraries diff --git a/podaac/subsetter/datatree_subset.py b/podaac/subsetter/datatree_subset.py index ba190dad..eaf2fb6a 100755 --- a/podaac/subsetter/datatree_subset.py +++ b/podaac/subsetter/datatree_subset.py @@ -295,7 +295,6 @@ def process_node( and (indexers.keys() - dataset[variable_name].dims) and set(indexers.keys()).intersection(dataset[variable_name].dims) ): - missing_dim = sorted(indexers.keys() - dataset[variable_name].dims)[0] var_indexers = { dim_name: dim_value @@ -800,91 +799,6 @@ def tree_get_spatial_bounds( return np.array([[min(min_lons), max(max_lons)], [min(min_lats), max(max_lats)]]) -def get_vars_with_paths(tree: DataTree) -> list[str]: - """ - Get all variables and coordinates with their full paths from a DataTree - - Parameters - ---------- - tree : DataTree - The input DataTree - - Returns - ------- - List[str] - List of variable paths in format '/group/var' or '/var' for root level, - including coordinate variables at root level - - Examples - -------- - >>> ds = xr.Dataset({'var1': [1], 'var2': [2], 'time': ('time', [0])}) - >>> tree = DataTree(data=ds) - >>> tree['group1'] = DataTree(data=ds.copy()) - >>> paths = get_vars_with_paths(tree) - >>> print(paths) - ['/time', '/var1', '/var2', '/group1/var1', '/group1/var2'] - """ - paths = [] - - def collect_vars(node: DataTree, current_path: str = "") -> None: - # Add data variables from current node - for var_name in node.ds.data_vars: - 
paths.append(f"{current_path}/{var_name}") - - # Recursively process child nodes - for child_name in node.children: - new_path = f"{current_path}/{child_name}" if current_path else f"/{child_name}" - collect_vars(node[child_name], new_path) - - collect_vars(tree) - return sorted(paths) # Sort for consistent ordering - - -def drop_vars_by_path(tree: DataTree, var_paths: str | list[str]) -> DataTree: - """ - Drop variables from a DataTree using paths in the format '/group/var' or '/var' for root level - - Parameters - ---------- - tree : DataTree - The input DataTree - var_paths : str or List[str] - Paths to variables to drop in format '/group/var' or '/var' for root level - Examples: - - '/var1' # root level variable - - '/group1/var1' # variable in group1 - - '/group1/subgroup/var1' # variable in nested group - - Returns - ------- - DataTree - Modified DataTree with variables dropped - """ - if isinstance(var_paths, str): - var_paths = [var_paths] - - for path in var_paths: - # Split the path into group path and variable name - parts = path.strip("/").split("/") - - if len(parts) == 1: - # Root level variable - var_name = parts[0] - # Modify the dataset in-place using xarray's drop_vars - tree.ds = tree.ds.drop_vars([var_name], errors="ignore") - else: - # Group variable - group_path = "/".join(parts[:-1]) - var_name = parts[-1] - try: - node = tree[group_path] - node.ds = node.ds.drop_vars([var_name], errors="ignore") - except KeyError: - pass - - return tree - - def prepare_basic_encoding(datasets: DataTree, time_encoding) -> dict: """ Prepare basic encoding dictionary for DataTree organized by groups. 
diff --git a/podaac/subsetter/subset.py b/podaac/subsetter/subset.py index 26c16092..085e81ba 100755 --- a/podaac/subsetter/subset.py +++ b/podaac/subsetter/subset.py @@ -184,7 +184,6 @@ def subset_with_bbox( iterator = zip(lat_var_names, lon_var_names, time_var_names) for lat_var_name, lon_var_name, time_var_name in iterator: - lat_path = file_utils.get_path(lat_var_name) lon_path = file_utils.get_path(lon_var_name) @@ -376,13 +375,11 @@ def subset( if args["decode_times"]: with xr.open_datatree(file_to_subset, decode_times=False) as dataset: - lat_var_names, lon_var_names, time_var_names = coordinate_utils.get_coordinate_variable_names( dataset=dataset, lat_var_names=lat_var_names, lon_var_names=lon_var_names, time_var_names=time_var_names ) for time in time_var_names: - time_var = dataset[time] var_name = os.path.basename(time) group_path = os.path.dirname(time) @@ -408,7 +405,6 @@ def subset( args["decode_times"] = False with xr.open_datatree(file_to_subset, **args) as dataset: - if hdf_type: dataset = hdf_utils.rename_phony_dims(dataset) @@ -439,21 +435,21 @@ def subset( if hdf_type and (min_time or max_time): dataset, _ = tree_time_converting.convert_to_datetime(dataset, time_var_names, hdf_type) - all_vars = variables_utils.get_all_variable_names_from_dtree(dataset) + all_vars = variables_utils.get_vars_with_paths(dataset) if variables: - # Drop variables that aren't explicitly requested, except lat_var_name and - # lon_var_name which are needed for subsetting - normalized_variables = [f"/{s.replace('__', '/').lstrip('/')}".upper() for s in variables] + # add in root "/" to variable path if not present so that + # matching against `all_vars` works correctly + normalized_variables = variables_utils.normalize_candidate_paths_against_dtree(variables, all_vars) keep_variables = normalized_variables + lon_var_names + lat_var_names + time_var_names - keep_variables = variables_utils.normalize_candidate_paths_against_dtree(keep_variables, all_vars) - 
all_data_variables = datatree_subset.get_vars_with_paths(dataset) - drop_variables = [ - var for var in all_data_variables if var not in keep_variables and var.upper() not in keep_variables - ] + keep_coords = coordinate_utils.collect_coordinate_variables(dataset, keep_variables) + + keep_set = set(keep_variables) | keep_coords + + drop_variables = all_vars - keep_set - dataset = datatree_subset.drop_vars_by_path(dataset, drop_variables) + variables_utils.drop_vars_by_path(dataset, drop_variables) lon_var_names = variables_utils.normalize_candidate_paths_against_dtree(lon_var_names, all_vars) lat_var_names = variables_utils.normalize_candidate_paths_against_dtree(lat_var_names, all_vars) diff --git a/podaac/subsetter/utils/coordinate_utils.py b/podaac/subsetter/utils/coordinate_utils.py index 405ce340..e6fe96e6 100644 --- a/podaac/subsetter/utils/coordinate_utils.py +++ b/podaac/subsetter/utils/coordinate_utils.py @@ -234,6 +234,89 @@ def get_coordinate_variable_names( return lat_var_names, lon_var_names, time_var_names +def find_coordinate_origin_node( + tree: xr.DataTree, + node_path: str, + coord_name: str, +) -> str | None: + """ + Find the path of the DataTree node where a coordinate is actually defined. + + Walks up the ancestry chain from the given node, checking each node's + own dataset (``node.ds``) for the coordinate. Returns the path of the + first ancestor that owns it, or ``None`` if the coordinate does not + exist anywhere in the ancestry. + + Parameters + ---------- + tree : DataTree + The root DataTree. + node_path : str + The path of the node to start searching from (e.g. "/group/subgroup"). + coord_name : str + The name of the coordinate to locate. + + Returns + ------- + str or None + The path string of the node that defines the coordinate, or ``None``. 
+ + Examples + -------- + >>> origin = find_coordinate_origin_node(dt, "/group/subgroup", "time") + >>> # Returns "/" if time is defined at the root + """ + node: xr.DataTree | None = tree[node_path] + + # iterate from current node to root via closest parents (inclusive + # of current node) + for n in (node, *node.parents): + # have to use to_dataset so that we can specify *not* to + # include the inherited coords + if coord_name in n.to_dataset(inherit=False).coords: + return n.path + + return None + + +def collect_coordinate_variables(tree: xr.DataTree, variables: list[str]) -> set[str]: + """ + Collect and construct the full set of paths to coordinate + variables (if any) which each variable depends on. + + Parameters + ---------- + tree : DataTree + The root DataTree. + variables : list[str] + Paths of the variables whose coordinate variables should be collected. + + Returns + ------- + set[str] + A set containing the paths to the coordinate variables + """ + keep_coords: set[str] = set() + for var in variables: + try: + var_node = tree[var] + except KeyError: + continue + + node_path = var.rsplit("/", 1)[0] # get the prefix path + for leaf in var_node.coords: + # want to find where the dimension variable + # actually lives, continuing if none present + owning_node = find_coordinate_origin_node(tree, node_path, leaf) + if not owning_node: + continue + # strip root "/", otherwise we end up with something like "//corner" + if owning_node == "/": + owning_node = "" + keep_coords.add(f"{owning_node}/{leaf}") + return keep_coords + + def _compute_utc_name(dataset: xr.Dataset) -> str | None: """ Get the name of the utc variable if it is there to determine origine time diff --git a/podaac/subsetter/utils/variables_utils.py b/podaac/subsetter/utils/variables_utils.py index 25d3008b..e30f9176 100644 --- a/podaac/subsetter/utils/variables_utils.py +++ b/podaac/subsetter/utils/variables_utils.py @@ -9,35 +9,63 @@ import xarray as xr -def get_all_variable_names_from_dtree(dtree: xr.DataTree) -> 
list[str]: +def get_vars_with_paths(tree: xr.DataTree) -> set[str]: """ - Recursively extract all variable names (with full paths) from an xarray DataTree. + Get all variables and coordinates with their full paths from a DataTree Parameters ---------- - dtree : xr.DataTree - The root of the DataTree. + tree : DataTree + The input DataTree Returns ------- - List[str] - A list of variable full paths (e.g. '/group1/var'). + set[str] + Unordered set of variable and coordinate paths in format + '/group/var' or '/var' for root level. + + Examples + -------- + >>> ds = xr.Dataset({'var1': [1], 'var2': [2], 'time': ('time', [0])}) + >>> tree = DataTree(data=ds) + >>> tree['group1'] = DataTree(data=ds.copy()) + >>> paths = get_vars_with_paths(tree) + >>> print(paths) + {'/time', '/var1', '/var2', '/group1/var1', '/group1/var2'} + """ + paths: set[str] = set() + for node in tree.subtree: + prefix = node.path.rstrip("/") + "/" + for name in set(node.data_vars) | set(node.to_dataset(inherit=False).coords): + paths.add(f"{prefix}{name}") + return paths + + +def drop_vars_by_path(tree: xr.DataTree, var_paths: str | list[str] | set[str]) -> None: + """ + Drop variables *in place* from a DataTree using paths in the + format '/group/var' or '/var' for root level. 
+ + Parameters + ---------- + tree : DataTree + The input DataTree + var_paths : str or list[str] or set[str] + Paths to variables to drop in format '/group/var' or '/var' for root level + Examples: + - '/var1' # root level variable + - '/group1/var1' # variable in group1 + - '/group1/subgroup/var1' # variable in nested group + """ - var_names = [] - - def recurse(node: xr.DataTree): - group_path = node.path - for var_name in node.data_vars: - if group_path in ("", "/"): - full_path = f"/{var_name}" - else: - full_path = f"{group_path}/{var_name}" - var_names.append(full_path) - for child in node.children.values(): - recurse(child) - - recurse(dtree) - return var_names + # guard for single string being passed + drop: set[str] = {var_paths} if isinstance(var_paths, str) else set(var_paths) + + for node in tree.subtree: + prefix = node.path.rstrip("/") + "/" + to_drop = [name for name in node.variables if f"{prefix}{name}" in drop] + if to_drop: + node.dataset = node.dataset.drop_vars(to_drop, errors="ignore") def _normalize_for_matching(path: str) -> str: diff --git a/tests/test_subset/test_specified_variables.py b/tests/test_subset/test_specified_variables.py index 448609cf..522602e2 100644 --- a/tests/test_subset/test_specified_variables.py +++ b/tests/test_subset/test_specified_variables.py @@ -1,98 +1,173 @@ -import os -import shutil -from os.path import join from pathlib import Path +from typing import NamedTuple import numpy as np import pytest import xarray as xr -from conftest import data_files - from podaac.subsetter import subset -from podaac.subsetter.utils.coordinate_utils import get_coordinate_variable_names -from podaac.subsetter.utils.variables_utils import get_all_variable_names_from_dtree - - -def get_non_variable_names_from_dtree(dtree: xr.DataTree): - """ - Recursively extract all non-variable names (with full paths) from an xarray DataTree. - This includes coordinates, dimensions, and other variables that are not data_vars. 
- - Parameters - ---------- - dtree : xr.DataTree - The root of the DataTree. - Returns - ------- - List[str] - A list of non-variable full paths (e.g. '/group1/coord'). - """ - non_var_names = [] - - def recurse(node: xr.DataTree): - group_path = node.path - - # Get all variables that are NOT data_vars - if node.ds is not None: # Check if node has a dataset - for var_name in node.ds.variables: - if var_name not in node.ds.data_vars: - if group_path in ("", "/"): - full_path = f"/{var_name}" - else: - full_path = f"{group_path}/{var_name}" - non_var_names.append(full_path) - - for child in node.children.values(): - recurse(child) - - recurse(dtree) - return non_var_names - - - -@pytest.mark.parametrize("test_file", data_files()) -def test_specified_variables(test_file, data_dir, subset_output_dir, request): +from podaac.subsetter.utils.variables_utils import get_vars_with_paths + + +class VariableTestCase(NamedTuple): + input: str + want_var: set[str] + want_coord: set[str] + + +_test_table: list[VariableTestCase] = [ + VariableTestCase( + input="MODIS_T-JPL-L2P-v2014.0.nc", + want_var={"/sst_dtime"}, + want_coord={"/lat", "/lon", "/time"}, + ), + VariableTestCase( + input="MODIS_A-JPL-L2P-v2014.0.nc", + want_var={"/sea_surface_temperature"}, + want_coord={"/lat", "/lon", "/time"}, + ), + VariableTestCase( + input="cyg04.ddmi.s20210228-000000-e20210228-235959.l1.power-brcs-cdr.a10.d10.nc", + want_var={ + "/sc_pos_z", + "/sc_vel_y", + "/sp_vel_z", + }, + want_coord={"/sp_lat", "/sp_lon", "/ddm_timestamp_utc", "/ddm", "/sample"}, + ), + VariableTestCase( + input="SWOT_L2_LR_SSH_Expert_368_012_20121111T235910_20121112T005015_DG10_01.nc", + want_var={ + "/mean_sea_surface_dtu", + "/latitude_avg_ssh", + "/geoid", + "/x_factor", + "/mean_sea_surface_cnescls_uncert", + "/simulated_error_orbital", + "/internal_tide_hret", + }, + want_coord={"/latitude", "/longitude", "/time"}, + ), + VariableTestCase( + 
input="20200101000001-JPL-L2P_GHRSST-SSTskin-MODIS_T-N-v02.0-fv01.0.nc", + want_var={ + "/quality_level", + "/sst_dtime", + "/sea_surface_temperature_4um", + "/quality_level_4um", + "/l2p_flags", + "/sses_standard_deviation_4um", + }, + want_coord={"/lat", "/lon", "/time"}, + ), + VariableTestCase( + input="JA1_GPN_2PeP001_002_20020115_060706_20020115_070316.nc", + want_var={ + "/alt_state_flag_oper", + "/qual_inst_corr_1hz_swh_c", + "/sea_state_bias_ku", + "/range_used_20hz_ku", + }, + want_coord={"/lat", "/lon", "/time", "/meas_ind"}, + ), + VariableTestCase( + input="AMSR2-L2B_v08_r38622-v02.0-fv01.0.nc", + want_var={ + "/quality_level", + "/sses_standard_deviation", + "/diurnal_amplitude", + "/wind_speed", + "/rain_rate", + "/l2p_flags", + "/dt_analysis", + }, + want_coord={"/lat", "/lon", "/time"}, + ), + VariableTestCase( + input="Merged_TOPEX_Jason_OSTM_Jason-3_Cycle_002.V4_2.nc", + want_var={"/Surface_Type", "/reference_orbit", "/Distance_to_coast", "/index"}, + want_coord={"/latitude", "/longitude", "/time"}, + ), + VariableTestCase( + input="ascat_20150702_084200_metopa_45145_eps_o_250_2300_ovw.l2.nc", + want_var={"/wvc_index", "/wind_speed", "/ice_age", "/ice_prob", "/wind_dir"}, + want_coord={"/lat", "/lon", "/time"}, + ), + VariableTestCase( + input="ascat_20150702_102400_metopa_45146_eps_o_250_2300_ovw.l2.nc", + want_var={"/wvc_index", "/wind_speed", "/ice_age", "/ice_prob", "/wind_dir"}, + want_coord={"/lat", "/lon", "/time"}, + ), + VariableTestCase( + input="20180101005944-REMSS-L2P_GHRSST-SSTsubskin-AMSR2-L2B_rt_r29918-v02.0-fv01.0.nc", + want_var={ + "/quality_level", + "/sses_standard_deviation", + "/diurnal_amplitude", + "/wind_speed", + "/rain_rate", + "/l2p_flags", + "/dt_analysis", + }, + want_coord={"/lat", "/lon", "/time"}, + ), + VariableTestCase( + input="TEMPO_HCHO_L2_V01_20240110T170237Z_S005G08.nc", + want_var={ + "/support_data/amf_cloud_fraction", + "/geolocation/longitude_bounds", + "/support_data/amf_cloud_pressure", + 
"/geolocation/viewing_azimuth_angle", + }, + want_coord={"/mirror_step", "/xtrack", "/geolocation/latitude", "/geolocation/longitude", "/geolocation/time"}, + ), + VariableTestCase( + input="VIIRS_NPP-NAVO-L2P-v3.0.nc", + want_var={ + "/quality_level", + "/brightness_temperature_12um", + "/sea_surface_temperature", + "/sses_bias", + "/adi_dtime_from_sst", + }, + want_coord={"/lat", "/lon", "/time"}, + ), + VariableTestCase( + input="20190927000500-JPL-L2P_GHRSST-SSTskin-MODIS_A-D-v02.0-fv01.0.nc", + want_var={"/quality_level", "/wind_speed", "/sea_surface_temperature", "/sses_bias"}, + want_coord={"/lat", "/lon", "/time"}, + ), +] + + +@pytest.mark.parametrize("case", _test_table, ids=lambda c: c.input) +def test_specified_variables(case, data_dir: str, tmp_path: Path): """ Test that the variables which are specified when calling the subset - operation are present in the resulting subsetted data file, - and that the variables which are specified are not present. + operation are present in the resulting subsetted data file plus + their required dimension scale/coordinate variables """ - nc_copy_for_expected_results = os.path.join(subset_output_dir, Path(test_file).stem + "_dup.nc") - shutil.copyfile(os.path.join(data_dir, test_file), nc_copy_for_expected_results) - - bbox = np.array(((-180, 180), (-90, 90))) - output_file = f"{request.node.name}_{test_file}" - - in_ds_tree = xr.open_datatree(nc_copy_for_expected_results, decode_times=False, decode_coords=False) - - # Coordinate variables are always included in the result - lat_var_names, lon_var_names, time_var_names = get_coordinate_variable_names(in_ds_tree) - - coordinate_variables = lat_var_names + lon_var_names + time_var_names - all_variables = get_all_variable_names_from_dtree(in_ds_tree) - non_coordinate_vars = [ - var for var in all_variables if var not in coordinate_variables - ] - - included_variables = non_coordinate_vars[::2] + coordinate_variables - non_vars = 
get_non_variable_names_from_dtree(in_ds_tree) + output_path = tmp_path / case.input subset.subset( - file_to_subset=join(data_dir, test_file), - bbox=bbox, - output_file=join(subset_output_dir, output_file), - variables=included_variables + file_to_subset=Path(data_dir) / case.input, + bbox=np.array(((-180, 180), (-90, 90))), + output_file=output_path, + variables=list(case.want_var), # only specify wanted data variables ) - out_ds_tree = xr.open_datatree(join(subset_output_dir, output_file), decode_times=False, decode_coords=False) - out_lat_var_names, out_lon_var_names, out_time_var_names = get_coordinate_variable_names(out_ds_tree) - out_coordinate_variables = out_lat_var_names + out_lon_var_names + out_time_var_names + with xr.open_datatree(output_path, decode_times=False, decode_coords=False) as out_tree: + # all vars is the super set containing data + coord vars + all_vars = get_vars_with_paths(out_tree) - subsetted_vars = get_all_variable_names_from_dtree(out_ds_tree) - subsetted_non_vars = get_non_variable_names_from_dtree(out_ds_tree) + # wanted variable should be a subset of all vars + assert case.want_var <= all_vars - assert set(subsetted_vars + subsetted_non_vars) == set(included_variables + non_vars) + # wanted coordinate vars should be a subset of all vars as well + assert case.want_coord <= all_vars - in_ds_tree.close() - out_ds_tree.close() + # and the symmetric difference of the variable super set and + # the union of data and coordinate vars should be an empty set + # indicating that nothing is present that is not expected to + # be present. E.g. 
extra dimension scale vars + assert (case.want_var | case.want_coord) ^ all_vars == set() diff --git a/tests/test_subset_illegal_name_recovery.py b/tests/test_subset_illegal_name_recovery.py index 969c4c98..ae7df9f9 100644 --- a/tests/test_subset_illegal_name_recovery.py +++ b/tests/test_subset_illegal_name_recovery.py @@ -45,7 +45,7 @@ def _patch_subset_dependencies(monkeypatch, subsetted_dataset, spatial_bounds): subset.coordinate_utils, "get_coordinate_variable_names", lambda **_kwargs: (["/lat"], ["/lon"], []) ) monkeypatch.setattr(subset.file_utils, "chunk_datatree", lambda dt: dt) - monkeypatch.setattr(subset.variables_utils, "get_all_variable_names_from_dtree", lambda *_args, **_kwargs: []) + monkeypatch.setattr(subset.variables_utils, "get_vars_with_paths", lambda *_args, **_kwargs: []) monkeypatch.setattr( subset.variables_utils, "normalize_candidate_paths_against_dtree", lambda paths, *_args, **_kwargs: paths )