# config.py
#
# Command-line configuration for the ICESat-2 HLS Kd analysis.
#
# NOTE: this module used to hold hard-coded module-level settings (dataset
# paths, GEBCO tile list, resolutions, thresholds).  They have been superseded
# by the argparse options below; consult version control for the old values.

import argparse


def get_args(argv=None):
    """
    Build the command-line parser for the analysis and parse the arguments.

    Parameters
    ----------
    argv : list[str] or None, optional
        Argument strings to parse.  ``None`` (default) makes argparse read
        ``sys.argv[1:]``, which is the original behaviour.  Passing an explicit
        list (e.g. ``[]``) lets tests, notebooks and batch drivers obtain the
        defaults without touching the process command line.

    Returns
    -------
    argparse.Namespace
        One attribute per option defined below.
    """
    parser = argparse.ArgumentParser(
        description="Configure ICESat-2 HLS analysis script."
    )

    # ------------------------------------------------------------------
    # General paths
    # ------------------------------------------------------------------
    parser.add_argument(
        "--workspace_path",
        type=str,
        default="C:/Workstation/ICESat2_HLS/Kd_ComparisionPaper/",
        help="Root path for all data processing.",
    )

    # Default: Wax Delta (v007) — required for IR/AP filter sensitivity test.
    # Other locations used previously:
    #   Dataset/ATL03_ICESat2/New/
    #   Dataset/ATL03_ICESat2/            (Bohai Sea, v006)
    parser.add_argument(
        "--atl03_path",
        type=str,
        default="Dataset/ATL03_ICESat2/Wax_Delta/",
        help="Base path for ICESat-2 ATL03 data.",
    )

    # Default: Wax Delta (v007) — required for IR/AP filter; use with
    # atl03_path = Dataset/ATL03_ICESat2/Wax_Delta/
    # Granules used previously (pass via --atl03_file):
    #   Chesapeake Bay : processed_ATL03_20230825074121_10102002_006_02.h5
    #   (unlabelled)   : processed_ATL03_20221001113813_01641706_006_01.h5
    #   India          : processed_ATL03_20200331204156_00810707_006_01.h5
    #   (unlabelled)   : processed_ATL03_20210628230109_00811207_006_01.h5
    #   Wax Delta v006 : ATL03_20231103172707_06982106_006_01_subsetted.h5
    #   Bohai v006     : ATL03_20190829161305_09560402_006_02_subsetted.h5
    #                    (baseline sensitivity test site)
    #   Cook Inlet     : ATL03_20220507112145_06931503_006_01_subsetted.h5
    #   Cook Inlet     : processed_ATL03_20210801125753_05941205_005_01.h5
    #   Core Sound     : 105430185_ATL03_20220423191113_04841506_006_02_subsetted.h5
    #   Pamlico Sound  : ATL03_20240415083629_04232306_006_01_subsetted.h5
    parser.add_argument(
        "--atl03_file",
        type=str,
        default="ATL03_20231103172707_06982106_007_01_subsetted.h5",
        help="Name of the ATL03 H5 file to process.",
    )

    parser.add_argument(
        "--other_data_path",
        type=str,
        default="Dataset/",
        help="Path to the current working directory.",
    )

    parser.add_argument(
        "--output_path",
        type=str,
        default="Results/",
        help="Directory for saving results.",
    )

    # Shoreline data (previous default: Shorelines/GeoPkgGlobalShoreline.gpkg)
    parser.add_argument(
        "--shoreline_data",
        type=str,
        default="Shorelines/ne_10m_land/ne_10m_land.shp",
        help="Path to the global shoreline dataset.",
    )

    # Bathymetry datasets
    # NOTE(review): nargs="+" yields a *list* when the option is given on the
    # command line, while the default stays a plain *string*.  Left unchanged
    # because the consumers are not visible here — confirm which type the
    # downstream code expects and align the two.
    parser.add_argument(
        "--gebco_path",
        nargs="+",
        type=str,
        default="Bathy/gebco_2024_geotiff/",
        help="Path to GEBCO bathymetry datasets.",
    )

    # ------------------------------------------------------------------
    # Resolution settings
    # ------------------------------------------------------------------
    parser.add_argument(
        "--horizontal_res",
        type=int,
        default=500,
        help="Horizontal resolution for the analysis (in meters).",
    )

    parser.add_argument(
        "--vertical_res",
        type=float,
        default=0.25,
        help="Vertical resolution for the analysis.",
    )

    # ------------------------------------------------------------------
    # Analysis parameters
    # ------------------------------------------------------------------
    parser.add_argument(
        "--subsurface_thresh",
        type=float,
        default=1.0,
        help="Threshold for photons below the sea surface.",
    )

    parser.add_argument(
        "--ignore_subsurface_height_thres",
        type=float,
        default=-6,
        help="Maximum depth thres for Kd calculations (in meters).",
    )  # 5m

    # Empty string '' auto-selects the first strong beam. Pass comma-separated
    # IDs to override.  Valid IDs: gt1l, gt1r, gt2l, gt2r, gt3l, gt3r — only
    # strong beams for the granule are kept.  To process all strong beams, use
    # --target_beams "gt1r,gt2r,gt3r".
    parser.add_argument(
        "--target_beams",
        type=str,
        default="",
        help="Comma-separated beam IDs to process. "
        "Empty means auto-select first strong beam.",
    )

    # ------------------------------------------------------------------
    # Solar background filter
    # ------------------------------------------------------------------
    parser.add_argument(
        "--disable_solar_background_filter",
        action="store_true",
        help="Disable solar background filtering (enabled by default; "
        "self-gates on solar elevation so has no effect on nighttime passes).",
    )
    parser.add_argument(
        "--solar_elevation_day_threshold",
        type=float,
        default=6.0,
        help="Solar elevation threshold (deg) to define daytime for optional filtering.",
    )
    parser.add_argument(
        "--solar_bg_median_window_deg",
        type=float,
        default=10.0,
        help="Rolling median window width (degrees of solar elevation) for "
        "the solar-elevation-aware background filter.",
    )
    parser.add_argument(
        "--solar_bg_noise_multiplier",
        type=float,
        default=1.5,
        help="Ratio of actual/expected background rate above which a photon "
        "is flagged as noisy by the solar background filter.",
    )
    parser.add_argument(
        "--solar_background_min_signal_conf",
        type=int,
        default=2,
        help="Minimum photon signal confidence kept in high daytime background.",
    )

    # ------------------------------------------------------------------
    # IR / afterpulse filter
    # ------------------------------------------------------------------
    parser.add_argument(
        "--enable_ir_ap_filter",
        action="store_true",
        help="Optional: remove photons using IR/afterpulse proxy flags.",
    )
    parser.add_argument(
        "--ir_ap_quality_max",
        type=int,
        default=0,
        help="Maximum allowed quality_ph value when IR/AP filter is enabled.",
    )
    parser.add_argument(
        "--ir_ap_min_signal_conf",
        type=int,
        default=0,
        help="Minimum allowed photon_conf value when IR/AP filter is enabled.",
    )

    # ------------------------------------------------------------------
    # GEBCO filter and sea-surface flattening
    # ------------------------------------------------------------------
    parser.add_argument(
        "--enable_gebco_filter",
        action="store_true",
        help="Optional: remove photons below GEBCO seafloor and shallow bins.",
    )
    parser.add_argument(
        "--enable_sea_surface_flattening",
        action="store_true",
        help="Optional: flatten small-bin sea-surface variation using larger along-track windows.",
    )
    parser.add_argument(
        "--sea_surface_flattening_window_m",
        type=int,
        default=5000,
        help="Window size in meters for sea-surface flattening mean level. "
        "Must be larger than horizontal_res to have any effect.",
    )

    # ------------------------------------------------------------------
    # Convex hull filter
    # ------------------------------------------------------------------
    parser.add_argument(
        "--enable_convex_hull_filter",
        action="store_true",
        help="Optional: apply convex hull area filtering before Kd fitting.",
    )
    parser.add_argument(
        "--convex_hull_area_threshold",
        type=float,
        default=100,
        help="Minimum convex hull area to keep a horizontal bin when enabled.",
    )

    # ------------------------------------------------------------------
    # Histogram quality filter
    # ------------------------------------------------------------------
    parser.add_argument(
        "--enable_histogram_quality_filter",
        action="store_true",
        help="Optional: discard along-track bins with weak 6-7 m signal.",
    )
    parser.add_argument(
        "--histogram_quality_min_ratio",
        type=float,
        default=0.05,
        help="Minimum ratio of photons in depth band to keep a bin.",
    )
    parser.add_argument(
        "--histogram_quality_depth_min",
        type=float,
        default=6.0,
        help="Minimum depth (m below surface) of quality-check band.",
    )
    parser.add_argument(
        "--histogram_quality_depth_max",
        type=float,
        default=7.0,
        help="Maximum depth (m below surface) of quality-check band.",
    )
    parser.add_argument(
        "--histogram_quality_ref_depth_min",
        type=float,
        default=0.0,
        help="Min depth (m) of near-surface reference band for decay ratio.",
    )
    parser.add_argument(
        "--histogram_quality_ref_depth_max",
        type=float,
        default=1.0,
        help="Max depth (m) of near-surface reference band for decay ratio.",
    )

    # ------------------------------------------------------------------
    # Surface sigma filter
    # ------------------------------------------------------------------
    parser.add_argument(
        "--enable_surface_sigma_filter",
        action="store_true",
        help="Optional: discard bins where Gaussian surface sigma is too large.",
    )
    parser.add_argument(
        "--surface_sigma_max",
        type=float,
        default=0.5,
        help="Maximum allowed Gaussian sigma (m) for sea-surface peak.",
    )

    # ------------------------------------------------------------------
    # Refraction correction
    # ------------------------------------------------------------------
    parser.add_argument(
        "--enable_refraction_correction",
        action="store_true",
        help="Optional: apply refraction correction to subsurface photons.",
    )
    parser.add_argument(
        "--refraction_water_temp_c",
        type=float,
        default=20.0,
        help="Water temperature (deg C) used for refraction correction.",
    )
    parser.add_argument(
        "--refraction_wavelength_nm",
        type=float,
        default=532.0,
        help="Laser wavelength (nm) used for refraction correction.",
    )
    parser.add_argument(
        "--enable_post_refraction_refit",
        action="store_true",
        help="Optional: rebuild histograms and re-fit surface after refraction correction.",
    )

    # ------------------------------------------------------------------
    # ATL24 bathymetry matching
    # ------------------------------------------------------------------
    parser.add_argument(
        "--enable_atl24_filter",
        action="store_true",
        help="Optional: use ATL24 bathymetry matching before GEBCO fallback.",
    )
    parser.add_argument(
        "--atl24_file",
        type=str,
        default="",
        help="Path to ATL24 point dataset (CSV/Parquet/GPKG/SHP) for bathymetry matching.",
    )
    parser.add_argument(
        "--atl24_max_match_distance_deg",
        type=float,
        default=0.01,
        help="Maximum lon/lat degree distance for ATL24 nearest-point matching.",
    )

    # ------------------------------------------------------------------
    # Kd fitting
    # ------------------------------------------------------------------
    # "%%" is required: argparse %-formats help strings.
    parser.add_argument(
        "--decay_zone_threshold",
        type=float,
        default=0.0,
        help="Fraction of peak photon count below which depth bins are excluded "
        "in Step 16 decay zone detection. 0.0 = original behaviour (only "
        "zero-count bins removed). E.g. 0.01 removes bins with < 1%% of peak.",
    )

    parser.add_argument(
        "--kd_fit_method",
        type=str,
        default="log_linear",
        choices=["log_linear", "bg_subtract", "breakpoint", "nonlinear", "hybrid"],
        help="Beer's Law fitting strategy for Kd calculation. "
        "log_linear = original log-space linear regression; "
        "bg_subtract = subtract noise floor then log-linear; "
        "breakpoint = segmented regression with breakpoint detection; "
        "nonlinear = fit C(z) = A*exp(-Kd*z) + N directly; "
        "hybrid = BIC-selected breakpoint, then log-linear fit on the "
        "decay segment only.",
    )

    parser.add_argument(
        "--enable_wave_adaptive_fit",
        action="store_true",
        help="Optional: trim shallow depth bins by wave amplitude before Beer's Law fit. "
        "Uses per-bin surface sigma to skip the top k*sigma metres in wavy bins.",
    )
    parser.add_argument(
        "--wave_exclusion_multiplier",
        type=float,
        default=4.0,
        help="Number of surface sigmas to skip from the actual water surface when "
        "wave-adaptive fit is enabled. The adaptive surface detection already "
        "removes 3*sigma; this adds (k-3)*sigma additional margin for bubble "
        "injection zone. Default 4.0 adds 1*sigma beyond the 3-sigma cutoff.",
    )
    parser.add_argument(
        "--wave_sigma_calm_threshold",
        type=float,
        default=0.1,
        help="Surface sigma (m) below which a bin is considered calm and no "
        "wave trimming is applied. Default 0.1 m.",
    )

    parser.add_argument(
        "--enable_paired_beam_combine",
        action="store_true",
        help="Optional: combine paired beams (gt1, gt2, gt3) before Kd fitting.",
    )

    parser.add_argument(
        "--no_plot",
        action="store_true",
        help="Suppress all plot generation (useful for batch runs).",
    )

    # TODO: switch between coastal-water and inland-water processing
    #       (each needs its own mask).
    # TODO: manual setting for whether bathymetry is considered.
    # TODO: night vs daytime handling.
    # TODO: let solar elevation decide whether solar background is removed.
    # (original note: "1-2 hours")

    return parser.parse_args(argv)


if __name__ == "__main__":
    args = get_args()
    # Access parameters like this.  The granule lives under
    # workspace_path/atl03_path, so all three parts are needed for the path.
    atl03_file_path = args.workspace_path + args.atl03_path + args.atl03_file

    print("Processing file:", atl03_file_path)
    print("Output path:", args.output_path)
+ +import logging + +import numpy as np +import pandas as pd +from scipy.optimize import curve_fit +from sklearn.linear_model import LinearRegression + +# def log_model(z, kd, e0): +# return np.log(e0) - kd * z + +## This is wrong because the input is already a bined data, +## it is not necessary to do a histogram again +# def CalculateKdFromFilteredSubsurfacePhoton(df, vertical_res=0.8): +# if df.empty or 'lat_bins' not in df.columns: +# return pd.DataFrame({'lat_bins': [np.nan], 'kd': [np.nan], 'e0': [np.nan], 'latitude': [np.nan], 'longitude': [np.nan]}) + +# # Get the latitude bin value +# lat_bin_value = df['lat_bins'].iloc[0] if not df['lat_bins'].empty else np.nan +# latitude = df['latitude'].mean() if 'latitude' in df.columns else df['lat'].mean() if 'lat' in df.columns else np.nan +# longitude = df['longitude'].mean() if 'longitude' in df.columns else df['lon'].mean() if 'lon' in df.columns else np.nan + +# # Calculate photon height range +# photon_height_min = df['photon_height'].min() +# photon_height_max = df['photon_height'].max() + +# # Check for sufficient data range +# if np.isnan(photon_height_min) or np.isnan(photon_height_max) or photon_height_min == photon_height_max: +# return pd.DataFrame({'lat_bins': [lat_bin_value], 'kd': [np.nan], 'e0': [np.nan]}) + +# height_bins_range = abs(photon_height_max - photon_height_min) +# height_bins_number = round(height_bins_range / vertical_res) + +# # Ensure there are enough bins +# if height_bins_number < 5: +# return pd.DataFrame({'lat_bins': [lat_bin_value], 'kd': [np.nan], 'e0': [np.nan]}) + +# # Create histogram of photon heights +# bin_edges = np.linspace(photon_height_min, photon_height_max, num=height_bins_number) +# counts, _ = np.histogram(df['photon_height'], bins=bin_edges) +# bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2 + +# # Store histogram data +# hist_df = pd.DataFrame({'zdepth': bin_centers, 'photon_counts': counts}) + +# # x value for model +# # Reverse zdepth to align with the 
MATLAB approach +# hist_df['zdepth'] = hist_df['zdepth'].max() - hist_df['zdepth'] + +# # Log-transform photon counts, replacing zeros with NaN +# # y value for model +# hist_df['log_photon_counts'] = np.log(hist_df['photon_counts'].replace(0, np.nan)) +# hist_df.loc[np.isinf(hist_df['log_photon_counts']), 'log_photon_counts'] = np.nan + +# # Filter out rows with NaNs in either column for regression +# valid_data = hist_df.dropna(subset=['zdepth', 'log_photon_counts']) + +# # Check for enough valid data points +# # Skip the regression if there are fewer than 5 datapoints +# if valid_data['log_photon_counts'].notna().sum() > 3: +# # Drop NaNs for regression +# zdepth_valid = valid_data['zdepth'].values.reshape(-1, 1) +# log_counts_valid = valid_data['log_photon_counts'].values + +# # Perform linear regression +# model = LinearRegression() +# model.fit(zdepth_valid, log_counts_valid) + +# # Extract kd as the negative of the slope and e0 from the intercept +# kd = -model.coef_[0] + +# print('zdepth_valid:',zdepth_valid) +# print('photon_counts:',hist_df['photon_counts']) +# print('kd:',kd) + +# e0 = np.exp(model.intercept_) + +# # Set kd to NaN if negative +# if kd < 0: +# kd = np.nan +# else: +# kd, e0 = np.nan, np.nan + +# return pd.DataFrame({ +# 'lat_bins': [lat_bin_value], +# 'kd': [kd], +# 'e0': [e0], +# 'latitude': [latitude], +# 'longitude': [longitude] +# }) + + +# # one solution is to adjust the vertical_res to 0.25 +# def CalculateKdFromFilteredSubsurfacePhoton(df, vertical_res=0.25): +# if df.empty or 'lat_bins' not in df.columns: +# return pd.DataFrame({'lat_bins': [np.nan], 'kd': [np.nan], 'e0': [np.nan], 'latitude': [np.nan], 'longitude': [np.nan]}) + +# # Get the latitude bin value +# lat_bin_value = df['lat_bins'].iloc[0] if not df['lat_bins'].empty else np.nan +# latitude = df['latitude'].mean() if 'latitude' in df.columns else df['lat'].mean() if 'lat' in df.columns else np.nan +# longitude = df['longitude'].mean() if 'longitude' in df.columns else 
df['lon'].mean() if 'lon' in df.columns else np.nan + +# # Calculate photon height range +# photon_height_min = df['photon_height'].min() +# photon_height_max = df['photon_height'].max() + +# # Check for sufficient data range +# if np.isnan(photon_height_min) or np.isnan(photon_height_max) or photon_height_min == photon_height_max: +# return pd.DataFrame({'lat_bins': [lat_bin_value], 'kd': [np.nan], 'e0': [np.nan]}) + +# height_bins_range = abs(photon_height_max - photon_height_min) +# height_bins_number = round(height_bins_range / vertical_res) + +# # Ensure there are enough bins +# if height_bins_number < 5: +# return pd.DataFrame({'lat_bins': [lat_bin_value], 'kd': [np.nan], 'e0': [np.nan]}) + +# # Create histogram of photon heights +# bin_edges = np.linspace(photon_height_min, photon_height_max, num=height_bins_number) +# counts, _ = np.histogram(df['photon_height'], bins=bin_edges) +# bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2 + +# # Store histogram data +# hist_df = pd.DataFrame({'zdepth': bin_centers, 'photon_counts': counts}) + +# # x value for model +# # Reverse zdepth to align with the MATLAB approach +# hist_df['zdepth'] = hist_df['zdepth'].max() - hist_df['zdepth'] + +# # Log-transform photon counts, replacing zeros with NaN +# # y value for model +# hist_df['log_photon_counts'] = np.log(hist_df['photon_counts'].replace(0, np.nan)) +# hist_df.loc[np.isinf(hist_df['log_photon_counts']), 'log_photon_counts'] = np.nan + +# # Filter out rows with NaNs in either column for regression +# valid_data = hist_df.dropna(subset=['zdepth', 'log_photon_counts']) + +# # Check for enough valid data points +# # Skip the regression if there are fewer than 5 datapoints +# if valid_data['log_photon_counts'].notna().sum() > 3: +# # Drop NaNs for regression +# zdepth_valid = valid_data['zdepth'].values.reshape(-1, 1) +# log_counts_valid = valid_data['log_photon_counts'].values + +# # Perform linear regression +# model = LinearRegression() +# model.fit(zdepth_valid, 
def find_exponential_decay_zone(hist_df, decay_threshold=0.0):
    """
    Select the depth bins usable for the Beer's Law fit (flowchart step 16).

    A bin is rejected when its photon count is zero (the logarithm is
    undefined) or, if ``decay_threshold`` is positive, when its count is below
    ``decay_threshold`` times the peak count.  Rejected bins are simply
    dropped; gaps between the surviving bins are not closed or checked.

    Parameters
    ----------
    hist_df : pd.DataFrame
        Histogram with columns 'zdepth' and 'photon_counts'.  Not modified.
    decay_threshold : float, optional
        Fraction of the peak count below which bins are rejected.  The default
        0.0 rejects zero-count bins only.

    Returns
    -------
    pd.DataFrame
        The surviving rows (original index kept) with an extra
        'log_photon_counts' column holding ``log(photon_counts)``.
    """
    zone = hist_df.copy()
    counts = zone["photon_counts"]

    # Start with the bins whose logarithm is undefined.
    rejected = counts == 0

    # Optionally reject weak bins relative to the strongest one.
    if decay_threshold > 0.0:
        peak = counts.max()
        if peak > 0:
            rejected = rejected | (counts < decay_threshold * peak)

    # Rejected bins become NaN before the log so they never produce -inf.
    log_counts = np.log(counts.mask(rejected))
    log_counts = log_counts.mask(np.isinf(log_counts))
    zone["log_photon_counts"] = log_counts

    keep = zone["zdepth"].notna() & zone["log_photon_counts"].notna()
    return zone[keep]
def fit_beers_law(valid_data):
    """
    Fit Beer's Law to the decay zone by linear regression in log-space.

    Flowchart step 17: regress log(photon_counts) on depth; kd is the negated
    slope and e0 is exp(intercept).

    Rows whose 'zdepth' or 'log_photon_counts' is NaN or infinite are ignored,
    both when counting the available points and when fitting.  (Previously
    only NaN log-counts were counted out, and the regression was still run on
    every row, so a stray NaN/inf made the regression raise.)

    Parameters
    ----------
    valid_data : pd.DataFrame
        Normally the output of find_exponential_decay_zone; needs the columns
        'zdepth' and 'log_photon_counts'.

    Returns
    -------
    tuple[float, float]
        (kd, e0).  Both are np.nan when fewer than 4 usable points exist;
        kd alone is np.nan when the fitted slope is positive (kd < 0).
    """
    zdepth = np.asarray(valid_data["zdepth"], dtype=float)
    log_counts = np.asarray(valid_data["log_photon_counts"], dtype=float)

    # LinearRegression rejects NaN/inf input, so filter before fitting.
    usable = np.isfinite(zdepth) & np.isfinite(log_counts)
    if usable.sum() <= 3:
        return np.nan, np.nan

    model = LinearRegression()
    model.fit(zdepth[usable].reshape(-1, 1), log_counts[usable])

    kd = -model.coef_[0]
    e0 = np.exp(model.intercept_)

    # A negative attenuation coefficient is not physical.
    if kd < 0:
        kd = np.nan
    return kd, e0


# ---------------------------------------------------------------------------
# Strategy A: Background subtraction + log-linear fit
# ---------------------------------------------------------------------------
def fit_beers_law_bg_subtract(hist_df, bg_fraction=0.2):
    """
    Subtract a noise floor estimated from the deepest bins, then fit Beer's
    Law in log-space on the remaining positive signal.

    Parameters
    ----------
    hist_df : pd.DataFrame
        Columns 'zdepth' and 'photon_counts'.  Sorted by 'zdepth' internally.
    bg_fraction : float
        Fraction of the deepest bins whose median count is taken as the noise
        floor.  At least two bins are always used.

    Returns
    -------
    tuple[float, float, float]
        (kd, e0, noise_floor).  All np.nan when fewer than 5 bins are given.
        kd and e0 are np.nan (noise_floor still reported) when fewer than 4
        bins stay above the noise floor.  kd alone is np.nan when the fitted
        slope is positive.
    """
    df = hist_df.copy().sort_values("zdepth")
    n_bins = len(df)
    if n_bins < 5:
        return np.nan, np.nan, np.nan

    # Noise floor = median count of the deepest bg_fraction of bins (>= 2 bins).
    n_bg = max(int(n_bins * bg_fraction), 2)
    noise_floor = float(df.tail(n_bg)["photon_counts"].median())

    # Only strictly positive residuals have a defined logarithm.
    df["signal"] = df["photon_counts"] - noise_floor
    df = df[df["signal"] > 0].copy()
    if len(df) < 4:
        return np.nan, np.nan, noise_floor

    df["log_signal"] = np.log(df["signal"])

    zdepth = df["zdepth"].values.reshape(-1, 1)
    log_signal = df["log_signal"].values

    model = LinearRegression()
    model.fit(zdepth, log_signal)

    kd = -model.coef_[0]
    e0 = np.exp(model.intercept_)

    if kd < 0:
        kd = np.nan
    return kd, e0, noise_floor
# ---------------------------------------------------------------------------
# Strategy B: Breakpoint / segmented regression
# ---------------------------------------------------------------------------
def fit_beers_law_breakpoint(hist_df):
    """
    Split the log-count profile into a decaying part and a flat noise part,
    then report the Beer's Law fit of the decaying part.

    Every admissible split index is tried.  For each one a straight line is
    fitted to the bins above the split and a constant (their mean) to the bins
    from the split downwards; the split with the smallest combined residual
    sum of squares wins (the first one on ties).

    Parameters
    ----------
    hist_df : pd.DataFrame
        Columns 'zdepth' and 'photon_counts'.  Sorted by 'zdepth' internally;
        bins with non-positive counts are discarded.

    Returns
    -------
    tuple[float, float, float, float]
        (kd, e0, breakpoint_depth, noise_floor).  All np.nan when fewer than
        5 positive bins exist or no split index is admissible.  kd alone is
        np.nan when the fitted slope is positive.
    """
    no_result = (np.nan, np.nan, np.nan, np.nan)

    bins = hist_df.copy().sort_values("zdepth")
    bins = bins[bins["photon_counts"] > 0].copy()
    if len(bins) < 5:
        return no_result

    depths = bins["zdepth"].values
    log_counts = np.log(bins["photon_counts"]).values
    n_bins = len(depths)

    lowest_rss = np.inf
    split_at = None
    decay_fit = None

    # The split index runs from 4 to n_bins - 2, so the decay segment always
    # has at least 4 bins and the noise segment at least 2.
    for candidate in range(4, n_bins - 1):
        z_above = depths[:candidate].reshape(-1, 1)
        y_above = log_counts[:candidate]
        line = LinearRegression()
        line.fit(z_above, y_above)
        rss_line = float(np.sum((line.predict(z_above) - y_above) ** 2))

        y_below = log_counts[candidate:]
        rss_flat = float(np.sum((y_below - y_below.mean()) ** 2))

        combined = rss_line + rss_flat
        if combined < lowest_rss:
            lowest_rss = combined
            split_at = candidate
            decay_fit = line  # keep the winning fit instead of refitting later

    if split_at is None:
        return no_result

    kd = -decay_fit.coef_[0]
    e0 = np.exp(decay_fit.intercept_)
    breakpoint_depth = float(depths[split_at])
    noise_floor = float(np.exp(log_counts[split_at:].mean()))

    if kd < 0:
        kd = np.nan
    return kd, e0, breakpoint_depth, noise_floor
# ---------------------------------------------------------------------------
# Strategy C: Nonlinear fit C(z) = A * exp(-Kd * z) + N
# ---------------------------------------------------------------------------
def _beer_plus_noise(z, A, kd, N):
    """Exponential decay of amplitude A and rate kd on top of a constant N."""
    attenuation = np.exp(-kd * z)
    return A * attenuation + N


def fit_beers_law_nonlinear(hist_df):
    """
    Fit C(z) = A*exp(-Kd*z) + N to the raw photon counts with bounded
    nonlinear least squares (scipy ``curve_fit``), i.e. without a log
    transform.

    Parameters
    ----------
    hist_df : pd.DataFrame
        Columns 'zdepth' and 'photon_counts'.  Sorted by 'zdepth' internally.

    Returns
    -------
    tuple[float, float, float]
        (kd, e0, noise_floor), where e0 is the fitted amplitude A and
        noise_floor the fitted constant N.  All np.nan when fewer than 5 bins
        are given or the optimiser raises RuntimeError/ValueError.  kd alone
        is np.nan when the fit ends at kd <= 0.
    """
    failed = (np.nan, np.nan, np.nan)

    ordered = hist_df.copy().sort_values("zdepth")
    z = ordered["zdepth"].values
    c = ordered["photon_counts"].values.astype(float)
    if len(z) < 5:
        return failed

    # --- starting point for the optimiser ---
    amplitude_guess = float(c.max())
    deepest_quarter = c[len(c) * 3 // 4 :]
    noise_guess = float(np.median(deepest_quarter))

    # Rough decay rate from the first and last positive bins (noise ignored);
    # floored at 0.01 so the optimiser does not start on the lower bound.
    positive = c > 0
    if positive.sum() >= 2:
        log_c = np.log(c[positive])
        z_pos = z[positive]
        rate_guess = max(
            float(-(log_c[-1] - log_c[0]) / (z_pos[-1] - z_pos[0] + 1e-9)), 0.01
        )
    else:
        rate_guess = 0.1

    try:
        params, _ = curve_fit(
            _beer_plus_noise,
            z,
            c,
            p0=[amplitude_guess, rate_guess, noise_guess],
            bounds=([0, 0, 0], [np.inf, np.inf, np.inf]),
            maxfev=5000,
        )
    except (RuntimeError, ValueError):
        return failed

    amplitude, rate, noise = params
    if rate <= 0:
        rate = np.nan
    return float(rate), float(amplitude), float(noise)
# ---------------------------------------------------------------------------
# Hybrid: Stabilised breakpoint zone detection + log-linear fit on decay only
# ---------------------------------------------------------------------------
def fit_beers_law_hybrid(
    hist_df,
    min_breakpoint_depth=2.0,
    expected_noise_floor=None,
    noise_floor_tolerance=3.0,
):
    """
    Locate the decay/noise transition with a BIC-scored breakpoint search and
    fit Beer's Law (log-linear) on the decay segment only.

    Procedure
    ---------
    1. A single straight line through all log-counts gives the reference BIC
       (2 parameters).
    2. Each candidate breakpoint index is scored with the BIC of a piecewise
       model (3 parameters): a line above the breakpoint and a constant at or
       below it.  A candidate is skipped when its depth is shallower than
       ``min_breakpoint_depth`` or when its decay line does not slope
       downwards.
    3. If ``expected_noise_floor`` is given and positive, a candidate whose
       noise level differs from it by more than a factor
       ``noise_floor_tolerance`` (either direction) is penalised by
       ``n * |log(ratio)|``.
    4. The lowest-scoring candidate is used only if it beats the single-line
       BIC; otherwise the single-line fit is returned.

    Parameters
    ----------
    hist_df : pd.DataFrame
        Columns 'zdepth' and 'photon_counts'.  Sorted by 'zdepth' internally;
        bins with non-positive counts are discarded.
    min_breakpoint_depth : float
        Shallowest 'zdepth' value accepted for a breakpoint.  Default 2.0.
    expected_noise_floor : float or None
        Reference noise level in photon counts (per the original notes, a
        beam-level median).  ``None`` or a non-positive value disables the
        consistency penalty.
    noise_floor_tolerance : float
        Allowed multiplicative deviation from ``expected_noise_floor`` before
        the penalty applies.  Default 3.0.

    Returns
    -------
    tuple[float, float, float, float]
        (kd, e0, breakpoint_depth, noise_floor).  All np.nan when fewer than
        6 positive bins exist.  breakpoint_depth and noise_floor are np.nan
        when the single-line fallback is used.  kd is np.nan when negative.
    """
    bins = hist_df.copy().sort_values("zdepth")
    bins = bins[bins["photon_counts"] > 0].copy()
    n = len(bins)
    if n < 6:
        return np.nan, np.nan, np.nan, np.nan

    depths = bins["zdepth"].values
    log_counts = np.log(bins["photon_counts"]).values

    def _rss(residuals):
        """Residual sum of squares as a plain float."""
        return float(np.sum(residuals**2))

    def _bic(rss, n_params):
        """BIC of a least-squares model; 1e-10 guards log(0) on perfect fits."""
        return n * np.log(rss / n + 1e-10) + n_params * np.log(n)

    # --- reference: one straight line, no breakpoint (slope + intercept) ---
    all_depths = depths.reshape(-1, 1)
    single_line = LinearRegression()
    single_line.fit(all_depths, log_counts)
    score_to_beat = _bic(_rss(single_line.predict(all_depths) - log_counts), 2)

    penalise_noise = expected_noise_floor is not None and expected_noise_floor > 0

    chosen_idx = None
    chosen_fit = None

    # --- search: candidate breakpoints from index 4 to n - 2 ---
    for idx in range(4, n - 1):
        if depths[idx] < min_breakpoint_depth:
            continue

        z_decay = depths[:idx].reshape(-1, 1)
        y_decay = log_counts[:idx]
        decay_line = LinearRegression()
        decay_line.fit(z_decay, y_decay)

        # Counts must fall with depth for the segment to be a decay.
        if decay_line.coef_[0] >= 0:
            continue

        rss_decay = _rss(decay_line.predict(z_decay) - y_decay)

        y_noise = log_counts[idx:]
        noise_level = float(y_noise.mean())
        rss_noise = _rss(y_noise - noise_level)

        # Piecewise model: slope, intercept, noise level.
        score = _bic(rss_decay + rss_noise, 3)

        if penalise_noise:
            ratio = float(np.exp(noise_level)) / expected_noise_floor
            if ratio > noise_floor_tolerance or ratio < 1.0 / noise_floor_tolerance:
                score += n * abs(np.log(ratio))

        if score < score_to_beat:
            score_to_beat = score
            chosen_idx = idx
            chosen_fit = decay_line

    # --- result ---
    if chosen_idx is None:
        # Nothing beat the single line: report its fit, no breakpoint.
        kd = -single_line.coef_[0]
        e0 = np.exp(single_line.intercept_)
        if kd < 0:
            kd = np.nan
        return kd, e0, np.nan, np.nan

    kd = -chosen_fit.coef_[0]
    e0 = np.exp(chosen_fit.intercept_)
    breakpoint_depth = float(depths[chosen_idx])
    noise_floor = float(np.exp(log_counts[chosen_idx:].mean()))

    if kd < 0:
        kd = np.nan
    return kd, e0, breakpoint_depth, noise_floor
noise floor, penalise + # candidates whose noise segment deviates substantially. + if expected_noise_floor is not None and expected_noise_floor > 0: + candidate_nf = float(np.exp(noise_mean)) + ratio = candidate_nf / expected_noise_floor + if ratio > noise_floor_tolerance or ratio < 1.0 / noise_floor_tolerance: + # Add a penalty proportional to log-deviation + bic_bp += n * abs(np.log(ratio)) + + if bic_bp < best_bic: + best_bic = bic_bp + best_bp_idx = bp_idx + best_model = model + + # ------------------------------------------------------------------ + # Result + # ------------------------------------------------------------------ + if best_bp_idx is None: + # No breakpoint beats the single-line model — fall back + kd = -model_full.coef_[0] + e0 = np.exp(model_full.intercept_) + if kd < 0: + kd = np.nan + return kd, e0, np.nan, np.nan + + kd = -best_model.coef_[0] + e0 = np.exp(best_model.intercept_) + breakpoint_depth = float(depths[best_bp_idx]) + noise_floor = float(np.exp(log_counts[best_bp_idx:].mean())) + + if kd < 0: + kd = np.nan + return kd, e0, breakpoint_depth, noise_floor + + +# --------------------------------------------------------------------------- +# Dispatcher: select fitting strategy by name +# --------------------------------------------------------------------------- +KD_FIT_METHODS = ("log_linear", "bg_subtract", "breakpoint", "nonlinear", "hybrid") + + +def _fit_kd_with_method( + hist_df, method="log_linear", decay_threshold=0.0, expected_noise_floor=None +): + """ + Run the requested fitting strategy on a single histogram. + + Returns + ------- + tuple[float, float, float] + (kd, e0, noise_floor). noise_floor is np.nan for log_linear. 
+ """ + if method == "log_linear": + valid = find_exponential_decay_zone(hist_df, decay_threshold=decay_threshold) + kd, e0 = fit_beers_law(valid) + return kd, e0, np.nan + + if method == "bg_subtract": + return fit_beers_law_bg_subtract(hist_df) + + if method == "breakpoint": + kd, e0, bp, nf = fit_beers_law_breakpoint(hist_df) + return kd, e0, nf + + if method == "nonlinear": + return fit_beers_law_nonlinear(hist_df) + + if method == "hybrid": + kd, e0, bp, nf = fit_beers_law_hybrid( + hist_df, expected_noise_floor=expected_noise_floor + ) + return kd, e0, nf + + raise ValueError( + f"Unknown kd_fit_method: {method!r}. " f"Choose from {KD_FIT_METHODS}" + ) + + +# another solution is to calculate kd without hist +def CalculateKdFromFilteredSubsurfacePhoton( + df, + vertical_res=0.8, + decay_zone_threshold=0.0, + kd_fit_method="log_linear", + expected_noise_floor=None, + surface_sigma=None, + wave_exclusion_multiplier=0.0, + wave_sigma_calm_threshold=0.1, +): + """ + Calculate Kd for a single along-track bin by fitting Beer's Law in log-space. + + Orchestrates steps 16 and 17 by calling find_exponential_decay_zone and + fit_beers_law in sequence. + + Flowchart steps: + 16 — find_exponential_decay_zone: determine depth range of exponential decay. + 17 — fit_beers_law: fit Beer's Law curve; kd = -slope, e0 = exp(intercept). 
+ """ + # Early exit if DataFrame is empty or missing required column + if df.empty or "lat_bins" not in df.columns: + return pd.DataFrame( + { + "lat_bins": [np.nan], + "kd": [np.nan], + "e0": [np.nan], + "noise_floor": [np.nan], + "surface_sigma": [np.nan], + "latitude": [np.nan], + "longitude": [np.nan], + } + ) + + # Retrieve latitude and longitude + lat_bin_value = df["lat_bins"].iloc[0] if not df["lat_bins"].empty else np.nan + latitude = ( + df["latitude"].mean() + if "latitude" in df.columns + else df["lat"].mean() + if "lat" in df.columns + else np.nan + ) + longitude = ( + df["longitude"].mean() + if "longitude" in df.columns + else df["lon"].mean() + if "lon" in df.columns + else np.nan + ) + + # Use value_counts to get photon counts in each height bin + height_counts = df["height_bins"].value_counts().sort_index() + bin_centers = height_counts.index.astype(float) + + # Create a DataFrame for the height bins and counts + hist_df = pd.DataFrame( + {"zdepth": bin_centers, "photon_counts": height_counts.values} + ) + + # Reverse zdepth for model alignment + hist_df["zdepth"] = hist_df["zdepth"].max() - hist_df["zdepth"] + + # Wave-adaptive fit: skip shallow bins contaminated by wave smearing/bubbles. + # zdepth=0 in the histogram corresponds to the adaptive surface detection + # threshold: max(3*sigma, 1.0 m) below the Gaussian surface peak + # (set by get_sea_surface_height_adaptive). The wave-adaptive trim adds + # an additional margin of (k - 3)*sigma beyond the 3-sigma cutoff to + # cover bubble injection beneath wave troughs and compensate for sigma + # underestimation from the truncated ±1 m Gaussian fit window. 
# Hybrid pass-2 smoothing window, expressed in along-track bins.
# +/-10 bins is a 21-bin window: +/-5 km (about 10 km in total) when the
# along-track bin size is 500 m.
_HYBRID_HALF_WINDOW_BINS = 10
# Minimum number of valid (positive, non-NaN) noise floors needed to trust a median.
_HYBRID_MIN_NOISE_FLOORS = 3

_KD_RESULT_COLUMNS = (
    "lat_bins",
    "kd",
    "e0",
    "noise_floor",
    "surface_sigma",
    "latitude",
    "longitude",
)


def _bin_surface_sigma(group):
    """Return the bin's surface sigma (taken from its first row), or None if unavailable."""
    if "surface_sigma" in group.columns and len(group) > 0:
        return float(group["surface_sigma"].iloc[0])
    return None


def _fit_all_bins(groups, fit_kwargs, expected_noise_floors=None):
    """Fit every along-track bin once and stack the one-row results in group order."""
    results = []
    for idx, (_, group) in enumerate(groups):
        expected = (
            None
            if expected_noise_floors is None
            else float(expected_noise_floors[idx])
        )
        results.append(
            CalculateKdFromFilteredSubsurfacePhoton(
                group,
                expected_noise_floor=expected,
                surface_sigma=_bin_surface_sigma(group),
                **fit_kwargs,
            )
        )
    return pd.concat(results, ignore_index=True)


def _expected_noise_floor_per_bin(noise_floors, half_window, min_valid):
    """
    Smooth per-bin noise floors with a sliding median.

    Returns ``(expected, beam_median, n_valid)``. ``expected`` is None (and
    ``beam_median`` is None) when fewer than ``min_valid`` positive, non-NaN
    noise floors exist for the whole beam. Otherwise each bin gets the median
    of the valid values within +/-``half_window`` bins, falling back to the
    beam-level median where the local window holds fewer than ``min_valid``.
    """
    noise_floors = np.asarray(noise_floors, dtype=float)
    usable = ~np.isnan(noise_floors) & (noise_floors > 0)
    n_valid = int(usable.sum())
    if n_valid < min_valid:
        return None, None, n_valid

    beam_median = float(np.median(noise_floors[usable]))
    n_bins = len(noise_floors)
    expected = np.full(n_bins, np.nan)
    for i in range(n_bins):
        lo = max(0, i - half_window)
        hi = min(n_bins, i + half_window + 1)
        local = noise_floors[lo:hi][usable[lo:hi]]
        expected[i] = float(np.median(local)) if len(local) >= min_valid else beam_median
    return expected, beam_median, n_valid


def _mask_kd_outliers(kd_df, iqr_factor=3.0, kd_cap=5.0, min_values=5):
    """Set Kd values outside [Q1 - f*IQR, min(Q3 + f*IQR, kd_cap)] to NaN, in place."""
    kd_vals = kd_df["kd"].dropna()
    if len(kd_vals) < min_values:
        return
    q1 = float(kd_vals.quantile(0.25))
    q3 = float(kd_vals.quantile(0.75))
    iqr = q3 - q1
    lower = q1 - iqr_factor * iqr
    # Physical cap: Kd above kd_cap (m^-1) is treated as unrealistic.
    upper = min(q3 + iqr_factor * iqr, kd_cap)
    outlier_mask = (kd_df["kd"] < lower) | (kd_df["kd"] > upper)
    n_outliers = int(outlier_mask.sum())
    if n_outliers > 0:
        logging.info(
            "Hybrid IQR filter: removed %d outliers (bounds: [%.4f, %.4f])",
            n_outliers,
            lower,
            upper,
        )
        kd_df.loc[outlier_mask, "kd"] = np.nan


# Original kd calculation function remains unchanged
def calculate_kd(
    filtered_seafloor_subsurface_photon_dataset,
    decay_zone_threshold=0.0,
    kd_fit_method="log_linear",
    wave_exclusion_multiplier=0.0,
    wave_sigma_calm_threshold=0.1,
):
    """
    Loop over along-track bins and call CalculateKdFromFilteredSubsurfacePhoton
    for each bin.

    For the 'hybrid' method, a two-pass approach is used:
      Pass 1 — Fit each bin independently to estimate per-bin noise floors.
      Pass 2 — Build a per-bin expected noise floor as the sliding-window
               median of the pass-1 noise floors (beam-level median where the
               local window has too few valid estimates), then re-fit each
               bin with that value as a stabilisation constraint. Pass 2 is
               skipped when the beam has too few valid noise floors.

    After fitting, IQR-based outlier filtering (3 x IQR, capped at 5.0) sets
    extreme Kd values to NaN (hybrid method only).

    Flowchart steps:
        16 — Determine depth of exponential decay zone (per bin, inside
             CalculateKdFromFilteredSubsurfacePhoton).
        17 — Fit Beer's Law exponential decay curve (per bin).
        18 — Save K_dph attenuation coefficient value and statistical terms
             (kd, e0 columns in the returned DataFrame).

    Parameters
    ----------
    filtered_seafloor_subsurface_photon_dataset : pd.DataFrame
        Photon table with a 'lat_bins' column; an optional 'surface_sigma'
        column supplies the per-bin surface sigma (first row of each bin).
    decay_zone_threshold : float
        Forwarded to the per-bin fit.
    kd_fit_method : str
        Fitting strategy name forwarded to the per-bin fit.
    wave_exclusion_multiplier, wave_sigma_calm_threshold : float
        Forwarded to the per-bin fit (wave-adaptive trimming).

    Returns
    -------
    pd.DataFrame
        One row per along-track bin with columns lat_bins, kd, e0,
        noise_floor, surface_sigma, latitude, longitude. Empty (same columns)
        when the input contains no bins.
    """
    logging.info(
        "Calculating Kd from filtered subsurface photon dataset (method=%s)",
        kd_fit_method,
    )

    groups = list(
        filtered_seafloor_subsurface_photon_dataset.groupby("lat_bins", observed=False)
    )
    if not groups:
        return pd.DataFrame({column: [] for column in _KD_RESULT_COLUMNS})

    fit_kwargs = {
        "decay_zone_threshold": decay_zone_threshold,
        "kd_fit_method": kd_fit_method,
        "wave_exclusion_multiplier": wave_exclusion_multiplier,
        "wave_sigma_calm_threshold": wave_sigma_calm_threshold,
    }

    if kd_fit_method != "hybrid":
        # Non-hybrid methods: single pass, no IQR filtering.
        return _fit_all_bins(groups, fit_kwargs)

    # Pass 1: per-bin noise floors without any expected-noise-floor constraint.
    logging.info("Hybrid pass 1: estimating per-bin noise floors")
    kd_df = _fit_all_bins(groups, fit_kwargs)

    expected_nf_per_bin, beam_median_nf, n_valid = _expected_noise_floor_per_bin(
        kd_df["noise_floor"].to_numpy(dtype=float),
        half_window=_HYBRID_HALF_WINDOW_BINS,
        min_valid=_HYBRID_MIN_NOISE_FLOORS,
    )

    if expected_nf_per_bin is None:
        logging.info(
            "Hybrid: insufficient noise floor estimates (%d), " "skipping pass 2",
            n_valid,
        )
    else:
        logging.info(
            "Hybrid pass 2: sliding window noise floor "
            "(half_window=%d bins, beam_median=%.3f, "
            "local range=%.3f-%.3f)",
            _HYBRID_HALF_WINDOW_BINS,
            beam_median_nf,
            float(np.nanmin(expected_nf_per_bin)),
            float(np.nanmax(expected_nf_per_bin)),
        )
        # Pass 2: re-fit with the per-bin expected noise floor constraint.
        # Each fit returns exactly one row, so rows and groups stay aligned.
        kd_df = _fit_all_bins(
            groups, fit_kwargs, expected_noise_floors=expected_nf_per_bin
        )

    # 3 x IQR accommodates real spatial variation (e.g. turbidity gradients
    # near river mouths) while removing fitting artefacts.
    _mask_kd_outliers(kd_df)
    return kd_df
+ ) + return pd.DataFrame( + columns=[ + "lat_bins", + "kd", + "e0", + "noise_floor", + "surface_sigma", + "latitude", + "longitude", + "beam_id", + ] + ) + + # Combine results from all beams into a single DataFrame + combined_kd_dataset = pd.concat(kd_beam_datasets, ignore_index=True) + + return combined_kd_dataset diff --git a/aok/core/kd_utils/README.md b/aok/core/kd_utils/README.md new file mode 100644 index 0000000..8cefa84 --- /dev/null +++ b/aok/core/kd_utils/README.md @@ -0,0 +1,197 @@ +# icesat-2_kdph_py + + +This repository contains Python code developed in collaboration to replicate the functionality of the original MATLAB scripts (https://github.com/emilyeidam/icesat-2_kdph +) by Dr. Emily Eidam (emily.eidam@oregonstate.edu). The code facilitates processing ICESat-2 data to calculate the diffuse attenuation coefficient (Kd) based on space-based lidar photon profiles, following methods described in the original MATLAB code. Please cite appropriately both Python and the MATLAB version under the GNU GPLv3 license if you use this code in your research. + +## Repository Structure + +Current structure: + +```text +icesat-2_kdph_py/ ++-- main.py ++-- config.py ++-- requirements.txt ++-- README.md ++-- kd_utils/ + +-- __init__.py + +-- data_processing.py + +-- sea_photons_analysis.py + +-- bathy_processing.py + +-- Kd_analysis.py + +-- interpolation.py + +-- visualization.py + +-- sliderule_adapter.py + +-- SeaSurfaceFlattening.py + +-- SolarBckgrd/ +``` + +Key modules: + +- `main.py`: End-to-end pipeline entrypoint for local ATL03 workflows. +- `config.py`: All CLI switches, paths, and thresholds. +- `kd_utils/data_processing.py`: ATL03 ingestion, land/sea masking, optional pre-filters. +- `kd_utils/sea_photons_analysis.py`: Binning and sea-surface detection. +- `kd_utils/bathy_processing.py`: Subsurface filtering and optional gap steps (GEBCO/ATL24/refraction/etc.). +- `kd_utils/Kd_analysis.py`: Kd fitting. 
+- `kd_utils/sliderule_adapter.py`: SlideRule ATL03 fetch + conversion to framework schema for notebook/script imports. + +### Getting Started + +### Prerequisites + +- Python 3.x +- Libraries: `numpy`, `pandas`, `scipy`, `h5py`, `matplotlib` +- Ensure you have ICESat-2 ATL03 `.h5` files available in the directory specified in `config.py`; see also the original MATLAB code for reference. + +### Installation + +1. Clone the repository: + + ```bash + git clone https://github.com/ChaoEcohydroRS/icesat-2_kdph_py.git + cd icesat-2_kdph_py + ``` + + +2. Install the required packages: + ``` + pip install -r requirements.txt + ``` + +3. Ensure .h5 files are placed in the appropriate default folder structure. + +### Usage +1. Edit the config.py file to set up file paths and any parameters. +2. Run the entire workflow through main.py: + + ```bash + python main.py + ``` + +### SlideRule Ingestion (Notebook Import Path) + +You can reuse the same framework without local ATL03 `.h5` files by importing the SlideRule adapter: + +```python +from kd_utils.sliderule_adapter import fetch_and_prepare_sliderule_dataset +from kd_utils.sea_photons_analysis import process_sea_photon_binning +from kd_utils.bathy_processing import process_subsurface_photon_filtering +from kd_utils.Kd_analysis import process_kd_calculation + +region = [ + {"lon": -84.029, "lat": 29.732}, + {"lon": -84.029, "lat": 29.954}, + {"lon": -83.969, "lat": 29.954}, + {"lon": -83.969, "lat": 29.732}, + {"lon": -84.029, "lat": 29.732}, +] + +sea_photon_dataset = fetch_and_prepare_sliderule_dataset( + region=region, + t0="2025-05-22T00:00:00Z", + t1="2025-05-24T00:00:00Z", + rgt=1033, + strong_beams_only=True, +) + +binned = process_sea_photon_binning(sea_photon_dataset, horizontal_res=500, vertical_res=0.25) +sea_h, sea_lbl, subsurface = process_subsurface_photon_filtering( + binned, GEBCO_paths=[], subsurface_thresh=1.0, Ignore_Subsurface_Height_Thres=-6 +) +kd_df = process_kd_calculation(subsurface) +``` + 
+This path reuses your existing processing framework; only ingestion switches from local HDF5 to SlideRule API. + +### Optional Gap Steps (Sensitivity On/Off) + +All extra processing steps are optional and disabled by default. + +- IR/AP proxy photon filtering: + ```bash + python main.py --enable_ir_ap_filter + ``` +- Solar background filtering: + ```bash + python main.py --enable_solar_background_filter + ``` +- GEBCO seafloor filtering: + ```bash + python main.py --enable_gebco_filter + ``` +- ATL24-first with GEBCO fallback: + ```bash + python main.py --enable_atl24_filter --atl24_file path/to/atl24_points.csv --enable_gebco_filter + ``` +- Sea surface flattening: + ```bash + python main.py --enable_sea_surface_flattening + ``` +- Convex hull reasonableness filter: + ```bash + python main.py --enable_convex_hull_filter + ``` +- Paired-beam combine before Kd fit: + ```bash + python main.py --enable_paired_beam_combine + ``` +- Histogram quality filter (<5% signal at 6-7 m depth band): + ```bash + python main.py --enable_histogram_quality_filter + ``` +- Surface Gaussian sigma filter: + ```bash + python main.py --enable_surface_sigma_filter + ``` +- Refraction correction: + ```bash + python main.py --enable_refraction_correction + ``` +- Post-refraction rebuild + re-fit: + ```bash + python main.py --enable_refraction_correction --enable_post_refraction_refit + ``` + +### Gap Implementation Progress + +The flowchart gaps are being implemented one-by-one with optional switches. 
+ +- [x] Gap 1: Remove photons flagged as IR/AP (proxy implementation via `quality_ph` and `photon_conf`) + - Switch: `--enable_ir_ap_filter` + - Tuning: `--ir_ap_quality_max`, `--ir_ap_min_signal_conf` +- [x] Gap 2: Histogram quality review (<5% signal at 6-7 m) + - Switch: `--enable_histogram_quality_filter` + - Tuning: + - `--histogram_quality_min_ratio` (default `0.05`) + - `--histogram_quality_depth_min` (default `6.0`) + - `--histogram_quality_depth_max` (default `7.0`) +- [x] Gap 3: Discard bins with high surface Gaussian std dev + - Switch: `--enable_surface_sigma_filter` + - Tuning: `--surface_sigma_max` (default `0.5`) +- [x] Gap 4: Refraction correction integration + - Switch: `--enable_refraction_correction` + - Tuning: + - `--refraction_water_temp_c` (default `20.0`) + - `--refraction_wavelength_nm` (default `532.0`) +- [x] Gap 5: Rebuild histogram and re-fit after refraction + - Switch: `--enable_post_refraction_refit` (requires `--enable_refraction_correction`) +- [x] Gap 6: ATL24 query + fallback orchestration + - Switch: `--enable_atl24_filter` + - Note: current implementation uses local ATL24 point files provided via `--atl24_file`. + - Inputs: + - `--atl24_file` (CSV/Parquet/GPKG/SHP with lon/lat + bathymetry field) + - `--atl24_max_match_distance_deg` (default `0.01`) + - Fallback: if ATL24 has no usable matches and `--enable_gebco_filter` is on, GEBCO filtering is used. +- [x] Gap 7: Combine paired beams + - Switch: `--enable_paired_beam_combine` + - Method: maps left/right beams to `gt1_pair`, `gt2_pair`, `gt3_pair` and re-bins by `relative_AT_dist` using `horizontal_res`. + +### Citation +If you use this code, please cite: + +Eidam, E.F., K. Bisson, C. Wang, C. Walker, and A. Gibbons (2024). ICESat-2 and ocean particulates: A roadmap for calculating Kd from space-based lidar photon profiles. Remote Sensing of Environment. Vol. 311. https://doi.org/10.1016/j.rse.2024.114222 + +### License +This project is licensed under the GNU GPLv3 license. 
def _band_decay_ratio(
    depth, depth_min, depth_max, reference_depth_min, reference_depth_max
):
    """Return the width-normalised photon-count ratio of the target band to the reference band."""
    underwater = depth > 0

    # Reference band: near-surface photon count (incoming signal).
    ref_count = float(
        (
            underwater
            & (depth >= reference_depth_min)
            & (depth <= reference_depth_max)
        ).sum()
    )
    if ref_count == 0:
        return 0.0

    # Target band: photon count at the check depth.
    band_count = float(
        (underwater & (depth >= depth_min) & (depth <= depth_max)).sum()
    )

    # Normalise by band width so bands of different sizes are comparable;
    # the 0.01 floor guards against zero-width (or inverted) bands.
    ref_width = max(reference_depth_max - reference_depth_min, 0.01)
    band_width = max(depth_max - depth_min, 0.01)
    return (band_count / band_width) / (ref_count / ref_width)


def apply_optional_histogram_quality_filter(
    binned_dataset_sea_surface,
    subsurface_photon_dataset,
    sea_surface_height,
    min_ratio=0.05,
    depth_min=6.0,
    depth_max=7.0,
    reference_depth_min=0.0,
    reference_depth_max=1.0,
):
    """
    Optional histogram quality check: keep along-track bins only if enough
    signal remains at the target depth band relative to near-surface signal.

    Flowchart step: 6 — Review histograms for quality: if <5% of signal remains
    at an uncorrected water depth of 6-7 m below the surface, quality-flag this
    along-track bin as low confidence and exclude it from further analyses
    (Possible discard).

    The decay ratio is computed per bin as:
        (count_in_depth_band / depth_band_width)
            / (count_in_reference_band / reference_band_width)

    where depth is ``sea_surface_height - photon_height`` and only photons
    with depth > 0 are counted. With the default 1 m wide bands this reduces
    to a plain count ratio. Bins with no reference-band photons, or no photons
    at all, get a ratio of 0.

    Parameters
    ----------
    binned_dataset_sea_surface : pd.DataFrame
        Full binned photon dataset with 'lat_bins' and 'photon_height'.
    subsurface_photon_dataset : pd.DataFrame
        Subsurface photon dataset to filter (needs 'lat_bins').
    sea_surface_height : list
        Sea surface height per along-track bin, in the order of the
        ``groupby("lat_bins")`` group keys of ``binned_dataset_sea_surface``.
    min_ratio : float
        Minimum decay ratio to keep a bin (default 0.05 = 5%).
    depth_min : float
        Minimum depth of the quality-check band (default 6.0 m).
    depth_max : float
        Maximum depth of the quality-check band (default 7.0 m).
    reference_depth_min : float
        Minimum depth of the near-surface reference band (default 0.0 m).
    reference_depth_max : float
        Maximum depth of the near-surface reference band (default 1.0 m).

    Returns
    -------
    pd.DataFrame
        Copy of the rows of ``subsurface_photon_dataset`` whose bin passed the
        check. Bins whose sea surface height is NaN are dropped. The input is
        returned unchanged when it is empty or when the number of bins does
        not match ``len(sea_surface_height)``.
    """
    if subsurface_photon_dataset.empty:
        return subsurface_photon_dataset

    # Group by the scalar column name (not a one-element list) so that group
    # keys are always plain scalars that compare equal to 'lat_bins' values.
    grouped = binned_dataset_sea_surface.groupby("lat_bins", observed=False)
    lat_bin_keys = list(grouped.groups.keys())
    if len(lat_bin_keys) != len(sea_surface_height):
        logging.warning(
            "Histogram quality filter skipped: %d along-track bins but %d "
            "sea surface heights.",
            len(lat_bin_keys),
            len(sea_surface_height),
        )
        return subsurface_photon_dataset

    surface_by_bin = {
        lat_bin: surface
        for lat_bin, surface in zip(lat_bin_keys, sea_surface_height)
        if not pd.isna(surface)
    }
    # Every bin with a usable surface starts at ratio 0 and is only raised if
    # photons are actually found for it below.
    ratios = dict.fromkeys(surface_by_bin, 0.0)

    # One pass over the groups instead of re-filtering the whole photon table
    # for every bin.
    for lat_bin, bin_data in grouped:
        surface = surface_by_bin.get(lat_bin)
        if surface is None or bin_data.empty:
            continue
        ratios[lat_bin] = _band_decay_ratio(
            surface - bin_data["photon_height"],
            depth_min,
            depth_max,
            reference_depth_min,
            reference_depth_max,
        )

    valid_bins = {lat_bin for lat_bin, ratio in ratios.items() if ratio >= min_ratio}
    return subsurface_photon_dataset[
        subsurface_photon_dataset["lat_bins"].isin(valid_bins)
    ].copy()
+ """ + grouped = binned_dataset_sea_surface.groupby(["lat_bins"], observed=False) + lat_bin_keys = list(grouped.groups.keys()) + if len(lat_bin_keys) != len(sea_surface_height): + return {} + + sigma_map = {} + for lat_bin, surface in zip(lat_bin_keys, sea_surface_height, strict=False): + if pd.isna(surface): + continue + bin_data = binned_dataset_sea_surface[ + binned_dataset_sea_surface["lat_bins"] == lat_bin + ] + if bin_data.empty: + continue + peak_data = bin_data[ + (bin_data["photon_height"] > surface - 1.0) + & (bin_data["photon_height"] < surface + 1.0) + ] + if len(peak_data) > 10: + mu, sigma = norm.fit(peak_data["photon_height"]) + if sigma > 0.4: + half_win2 = min(2.5 * sigma, 3.0) + peak_data2 = bin_data[ + (bin_data["photon_height"] > mu - half_win2) + & (bin_data["photon_height"] < mu + half_win2) + ] + if len(peak_data2) > 10: + _, sigma = norm.fit(peak_data2["photon_height"]) + else: + sigma = np.nan + sigma_map[lat_bin] = sigma + + return sigma_map + + +def apply_optional_surface_sigma_filter( + binned_dataset_sea_surface, + subsurface_photon_dataset, + sea_surface_height, + sigma_max=0.5, +): + """ + Optional Gaussian surface sigma filter: discard along-track bins if the + fitted surface Gaussian sigma exceeds sigma_max. + + Flowchart step: 8 — If standard deviation of Gaussian peak is >X m, + quality-flag this along-track bin as low confidence and discard it + (addresses surface waves) (Possible discard). 
+ """ + if subsurface_photon_dataset.empty: + return subsurface_photon_dataset + + sigma_map = compute_surface_sigma(binned_dataset_sea_surface, sea_surface_height) + if not sigma_map: + return subsurface_photon_dataset + + sigma_df = pd.DataFrame( + list(sigma_map.items()), columns=["lat_bins", "surface_sigma"] + ) + valid_bins = set( + sigma_df.loc[sigma_df["surface_sigma"] <= sigma_max, "lat_bins"].tolist() + ) + return subsurface_photon_dataset[ + subsurface_photon_dataset["lat_bins"].isin(valid_bins) + ].copy() + + +def apply_optional_refraction_correction( + binned_dataset_sea_surface, + subsurface_photon_dataset, + sea_surface_height, + water_temp_c=20.0, + wavelength_nm=532.0, +): + """ + Optional refraction correction for subsurface photons. + Updates photon positions/heights using Snell-based geometry. + + Flowchart step: 9 — Compute water depths (distance below surface) and + correct depths for refraction. + """ + required_cols = { + "lat_bins", + "photon_height", + "ref_elevation", + "ref_azimuth", + "lon", + "lat", + } + if subsurface_photon_dataset.empty or ( + not required_cols.issubset(set(subsurface_photon_dataset.columns)) + ): + return subsurface_photon_dataset + + grouped = binned_dataset_sea_surface.groupby(["lat_bins"], observed=False) + lat_bin_keys = list(grouped.groups.keys()) + if len(lat_bin_keys) != len(sea_surface_height): + return subsurface_photon_dataset + + surface_df = pd.DataFrame( + {"lat_bins": lat_bin_keys, "sea_surface_height": sea_surface_height} + ) + corrected = subsurface_photon_dataset.merge( + surface_df, on="lat_bins", how="left" + ).copy() + if corrected["sea_surface_height"].isna().all(): + return subsurface_photon_dataset + + corrected["photon_height_pre_refraction"] = corrected["photon_height"] + corrected["lon_pre_refraction"] = corrected["lon"] + corrected["lat_pre_refraction"] = corrected["lat"] + + # Refraction index parameterization from legacy module. 
+ a = -0.000001501562500 + b = 0.000000107084865 + c = -0.000042759374989 + d = -0.000160475520686 + e = 1.398067112092424 + n1 = 1.00029 + n2 = ( + (a * water_temp_c**2) + + (b * wavelength_nm**2) + + (c * water_temp_c) + + (d * wavelength_nm) + + e + ) + + ref_elev = pd.to_numeric(corrected["ref_elevation"], errors="coerce").to_numpy( + dtype=float + ) + ref_az = pd.to_numeric(corrected["ref_azimuth"], errors="coerce").to_numpy( + dtype=float + ) + z = pd.to_numeric(corrected["photon_height"], errors="coerce").to_numpy(dtype=float) + ws = pd.to_numeric(corrected["sea_surface_height"], errors="coerce").to_numpy( + dtype=float + ) + x = pd.to_numeric(corrected["lon"], errors="coerce").to_numpy(dtype=float) + y = pd.to_numeric(corrected["lat"], errors="coerce").to_numpy(dtype=float) + + # Convert to radians if values look like degrees. + if np.nanmax(np.abs(ref_elev)) > (2 * np.pi): + ref_elev = np.deg2rad(ref_elev) + if np.nanmax(np.abs(ref_az)) > (2 * np.pi): + ref_az = np.deg2rad(ref_az) + + valid_mask = ( + np.isfinite(ref_elev) + & np.isfinite(ref_az) + & np.isfinite(z) + & np.isfinite(ws) + & (z <= ws) + ) + if not np.any(valid_mask): + return subsurface_photon_dataset + + theta1 = (np.pi / 2.0) - ref_elev[valid_mask] + theta2_arg = (n1 * np.sin(theta1)) / n2 + theta2_arg = np.clip(theta2_arg, -1.0, 1.0) + theta2 = np.arcsin(theta2_arg) + + D = ws[valid_mask] - z[valid_mask] + cos_theta1 = np.cos(theta1) + cos_theta1 = np.where(np.abs(cos_theta1) < 1e-6, np.nan, cos_theta1) + S = D / cos_theta1 + R = (S * n1) / n2 + Gamma = (np.pi / 2.0) - theta1 + phi = theta1 - theta2 + P_sq = np.maximum(R**2 + S**2 - 2 * R * S * np.cos(phi), 0.0) + P = np.sqrt(P_sq) + alpha_arg = np.divide(R * np.sin(phi), P, out=np.zeros_like(P), where=(P != 0)) + alpha_arg = np.clip(alpha_arg, -1.0, 1.0) + alpha = np.arcsin(alpha_arg) + Beta = Gamma - alpha + + DY = P * np.cos(Beta) + DZ = P * np.sin(Beta) + DE = DY * np.sin(ref_az[valid_mask]) + DN = DY * np.cos(ref_az[valid_mask]) + + 
def _rebin_height_bins(flattened):
    """Return a fresh ``height_bins`` column computed from ``photon_height``.

    Returns ``None`` when no finite photon height exists, in which case the
    caller should leave the existing ``height_bins`` untouched.
    """
    # Infer the vertical resolution from the spacing of the existing bin
    # labels; fall back to 0.25 when fewer than two distinct labels exist.
    orig_bins = (
        pd.to_numeric(flattened["height_bins"], errors="coerce").dropna().unique()
    )
    v_res = 0.25
    if len(orig_bins) >= 2:
        diffs = np.diff(np.sort(orig_bins))
        positive = diffs[diffs > 0]
        if positive.size:
            v_res = float(np.median(positive))

    h_min = flattened["photon_height"].min()
    h_max = flattened["photon_height"].max()
    if not (np.isfinite(h_min) and np.isfinite(h_max)):
        # All heights are NaN: int(round(nan)) would raise, so skip re-binning.
        return None

    if h_max > h_min:
        n_hbins = max(1, int(round((h_max - h_min) / v_res)))
        bin_edges = np.linspace(h_min, h_max, n_hbins + 1)
    else:
        # Every photon has the same height. pd.cut rejects duplicate edges,
        # so build a single bin of width v_res centred on that height.
        bin_edges = np.array([h_min - v_res / 2.0, h_min + v_res / 2.0])

    centers = (bin_edges[:-1] + bin_edges[1:]) / 2
    # Labels keep the historical one-decimal rounding whenever that is
    # unambiguous; precision is only increased when rounding would create
    # duplicate labels (which pd.cut refuses for ordered categories).
    decimals = 1
    bin_labels = np.round(centers, decimals=decimals)
    while np.unique(bin_labels).size < bin_labels.size and decimals < 9:
        decimals += 1
        bin_labels = np.round(centers, decimals=decimals)

    return pd.cut(
        flattened["photon_height"],
        bins=bin_edges,
        labels=bin_labels,
        include_lowest=True,
    )


def apply_sea_surface_flattening(
    subsurface_photon_dataset,
    sea_surface_height,
    lat_bin_keys,
    horizontal_res,
    flattening_window_m=500,
):
    """
    Flatten small-bin sea-surface variations to the mean sea level of larger
    along-track windows.

    Each small along-track bin has its own fitted surface height. Bins are
    grouped into windows of ``flattening_window_m / horizontal_res`` bins, the
    window-mean surface is computed, and every photon is shifted by
    ``window_mean - local_surface`` so that the local surface of every bin in
    a window lands on the same level.

    Parameters
    ----------
    subsurface_photon_dataset : pd.DataFrame
        Photons with at least ``lat_bins`` and ``photon_height`` columns.
        ``height_bins`` is re-computed when present.
    sea_surface_height : sequence of float
        Surface height per along-track bin, aligned with ``lat_bin_keys``.
    lat_bin_keys : sequence
        Bin labels matching ``subsurface_photon_dataset["lat_bins"]``; they
        must be convertible to numbers.
    horizontal_res : float
        Along-track size of one small bin (m).
    flattening_window_m : float, default 500
        Along-track size of the averaging window (m).

    Returns
    -------
    pd.DataFrame
        A new frame with shifted ``photon_height`` plus the columns
        ``photon_height_original``, ``sea_surface_height_local``,
        ``sea_surface_height_bigbin_mean`` and
        ``sea_surface_flattening_offset``. The input is returned unchanged
        when it is empty, when the two surface sequences differ in length, or
        when no usable surface height exists.
    """
    if subsurface_photon_dataset.empty:
        return subsurface_photon_dataset

    if len(sea_surface_height) != len(lat_bin_keys):
        return subsurface_photon_dataset

    surface_df = (
        pd.DataFrame(
            {"lat_bins": lat_bin_keys, "sea_surface_height_local": sea_surface_height}
        )
        .dropna(subset=["sea_surface_height_local"])
        .copy()
    )
    if surface_df.empty:
        return subsurface_photon_dataset

    surface_df["lat_bins_num"] = pd.to_numeric(surface_df["lat_bins"], errors="coerce")
    surface_df = surface_df.dropna(subset=["lat_bins_num"])
    if surface_df.empty:
        return subsurface_photon_dataset

    big_bin_size = max(1, int(round(flattening_window_m / max(horizontal_res, 1))))
    if big_bin_size <= 1:
        logging.warning(
            "Sea surface flattening window (%d m) <= horizontal_res (%d m) → "
            "big_bin_size=1, flattening has no effect. "
            "Use --sea_surface_flattening_window_m > --horizontal_res.",
            flattening_window_m,
            horizontal_res,
        )

    # Assign each small bin to its averaging window and compute window means.
    surface_df["big_bin"] = (
        surface_df["lat_bins_num"].astype(int) // big_bin_size
    ).astype(int)
    mean_surface = (
        surface_df.groupby("big_bin", observed=False)["sea_surface_height_local"]
        .mean()
        .rename("sea_surface_height_bigbin_mean")
        .reset_index()
    )
    surface_df = surface_df.merge(mean_surface, on="big_bin", how="left")
    surface_df["sea_surface_flattening_offset"] = (
        surface_df["sea_surface_height_bigbin_mean"]
        - surface_df["sea_surface_height_local"]
    )

    flattened = subsurface_photon_dataset.copy().merge(
        surface_df[
            [
                "lat_bins",
                "sea_surface_height_local",
                "sea_surface_height_bigbin_mean",
                "sea_surface_flattening_offset",
            ]
        ],
        on="lat_bins",
        how="left",
    )
    flattened["photon_height_original"] = flattened["photon_height"]
    # The offset is (window mean - local surface), so it must be ADDED: a
    # photon sitting on the local surface then lands exactly on the window
    # mean. Subtracting it (as the code previously did) doubles the deviation
    # instead of removing it. Bins without a surface estimate are not shifted.
    flattened["photon_height"] = flattened["photon_height"] + flattened[
        "sea_surface_flattening_offset"
    ].fillna(0.0)

    # Re-bin height_bins from the flattened photon_height so that downstream
    # code reading height_bins sees the shifted heights.
    if "height_bins" in flattened.columns:
        rebinned = _rebin_height_bins(flattened)
        if rebinned is not None:
            flattened["height_bins"] = rebinned

    # Kept for backward compatibility: these helper columns were always
    # stripped from the result, even when the caller supplied them.
    return flattened.drop(columns=["lat_bins_num", "big_bin"], errors="ignore")
def get_seafloor_elevation(lons, lats, relevant_rasters):
    """
    Get seafloor elevations for a batch of points from the relevant rasters.

    Each raster is sampled only at the points that fall inside its bounds and
    that have not yet received a value from an earlier raster. Sampling a
    raster outside its extent yields a fill value rather than an error, so
    without the bounds test the first raster in the list would "answer" for
    every point and later tiles would never be consulted.

    Parameters
    ----------
    lons : array-like
        Longitudes of the query points.
    lats : array-like
        Latitudes of the query points (same length as ``lons``).
    relevant_rasters : list of tuple
        ``(raster, raster_data)`` pairs. ``raster`` must expose ``bounds``
        (with ``left``, ``bottom``, ``right``, ``top``) and
        ``sample(points)``; an optional ``nodata`` attribute is honoured.
        ``raster_data`` is not used here.

    Returns
    -------
    numpy.ndarray
        Float array with one elevation per point. Points covered by no raster,
        points with non-finite coordinates, and points whose sampled value
        equals the raster's nodata value are NaN.
    """
    lons = np.asarray(lons, dtype=float)
    lats = np.asarray(lats, dtype=float)
    seafloor_elevations = np.full(lons.shape[0], np.nan)

    for gebco_raster, _raster_data in relevant_rasters:
        pending = np.isnan(seafloor_elevations)
        if not pending.any():
            break

        # Half-open on the right/bottom edge to mirror the pixel grid: a point
        # exactly on a shared tile edge belongs to the neighbouring tile.
        bounds = gebco_raster.bounds
        inside = (
            pending
            & (lons >= bounds.left)
            & (lons < bounds.right)
            & (lats > bounds.bottom)
            & (lats <= bounds.top)
        )
        target_idx = np.flatnonzero(inside)
        if target_idx.size == 0:
            continue

        coords = np.column_stack((lons[target_idx], lats[target_idx]))
        sampled = np.array(
            [value[0] for value in gebco_raster.sample(coords)], dtype=float
        )

        # Treat the raster's nodata marker as "no answer" so that another
        # overlapping raster can still fill the point.
        nodata = getattr(gebco_raster, "nodata", None)
        if nodata is not None:
            sampled[sampled == nodata] = np.nan

        seafloor_elevations[target_idx] = sampled

    return seafloor_elevations
def discard_shallow_seabed_bins(sea_photon_dataset, min_depth_m):
    """
    Keep only photons located where the seabed is at least ``min_depth_m`` deep.

    Flowchart step: 14 — Discard all photon data in horizontal bins where the
    seabed is too shallow; Kd is not calculated for these bins, which are
    treated as low confidence by exclusion.

    Parameters
    ----------
    sea_photon_dataset : pd.DataFrame
        Must contain a 'depth' column.
    min_depth_m : float
        Smallest seabed depth that is still retained.

    Returns
    -------
    pd.DataFrame
        Independent copy of the rows where depth >= min_depth_m. Rows whose
        depth is NaN fail the comparison and are dropped.
    """
    deep_enough = sea_photon_dataset["depth"].ge(min_depth_m)
    return sea_photon_dataset.loc[deep_enough].copy()
def _empty_atl24_points():
    """Return an empty frame carrying the normalized ATL24 column layout."""
    return pd.DataFrame(columns=["longitude", "latitude", "seafloor_elevation_atl24"])


def _pick_column(columns, candidates):
    """Return the first candidate found in *columns*, else None.

    Exact matches are tried first, in candidate order, so existing files keep
    resolving to the same column; a case-insensitive pass follows as fallback.
    """
    for name in candidates:
        if name in columns:
            return name
    by_lower = {}
    for col in columns:
        by_lower.setdefault(str(col).lower(), col)
    for name in candidates:
        hit = by_lower.get(name.lower())
        if hit is not None:
            return hit
    return None


def _read_atl24_hdf5_points(atl24_file_path):
    """Read bathymetry-classified photons from every beam of an ATL24 HDF5 file."""
    import h5py

    beams = ["gt1l", "gt1r", "gt2l", "gt2r", "gt3l", "gt3r"]
    required = ("lon_ph", "lat_ph", "ortho_h", "class_ph")
    frames = []
    with h5py.File(atl24_file_path, "r") as f:
        for beam in beams:
            if beam not in f:
                continue
            grp = f[beam]
            if not all(k in grp for k in required):
                continue
            bathy_mask = grp["class_ph"][:] == 40  # ATL24 bathymetry class
            # Apply low_confidence_flag filter if available
            if "low_confidence_flag" in grp:
                bathy_mask = bathy_mask & (grp["low_confidence_flag"][:] == 0)
            if bathy_mask.sum() == 0:
                continue
            frames.append(
                pd.DataFrame(
                    {
                        "longitude": grp["lon_ph"][:][bathy_mask].astype(float),
                        "latitude": grp["lat_ph"][:][bathy_mask].astype(float),
                        "seafloor_elevation_atl24": grp["ortho_h"][:][
                            bathy_mask
                        ].astype(float),
                    }
                )
            )
    if not frames:
        return _empty_atl24_points()
    return pd.concat(frames, ignore_index=True).dropna()


def _read_atl24_vector_table(atl24_file_path):
    """Read a vector file and expose point geometry as longitude/latitude columns."""
    import geopandas as gpd

    atl24_gdf = gpd.read_file(atl24_file_path)
    atl24_df = pd.DataFrame(atl24_gdf)
    if (
        "longitude" not in atl24_df.columns or "latitude" not in atl24_df.columns
    ) and ("geometry" in atl24_gdf.columns):
        atl24_df["longitude"] = atl24_gdf.geometry.x
        atl24_df["latitude"] = atl24_gdf.geometry.y
    return atl24_df


def load_atl24_points(atl24_file_path):
    """
    Load ATL24 point dataset and normalize to columns: longitude, latitude,
    seafloor_elevation_atl24.

    Flowchart step: 12 — Query ATL24 data for bathymetry match; if no match,
    retrieve GEBCO data (or other regional raster bathymetry data) for region.

    Supported inputs, chosen by file extension: HDF5 (``.h5``/``.hdf5``),
    delimited text (``.csv``/``.txt``), ``.parquet``, and vector files
    (``.gpkg``/``.shp``/``.geojson``).

    Loading is best-effort: a missing path, an unsupported extension, a file
    that cannot be read, or a table without recognizable longitude / latitude
    / elevation columns all yield an EMPTY frame (never an exception), which
    callers interpret as "no ATL24 match". Read failures are logged as
    warnings so that they are not mistaken for genuinely empty data.

    Parameters
    ----------
    atl24_file_path : str or os.PathLike
        Path to the ATL24 file; may be empty.

    Returns
    -------
    pd.DataFrame
        Columns ``longitude``, ``latitude``, ``seafloor_elevation_atl24``.
        For tabular inputs whose elevation column name contains "depth" (and
        not "elev"), values are converted to negative elevations; otherwise
        they are taken as-is.
    """
    if (not atl24_file_path) or (not os.path.exists(atl24_file_path)):
        return _empty_atl24_points()

    lower_path = os.fspath(atl24_file_path).lower()
    try:
        if lower_path.endswith((".h5", ".hdf5")):
            return _read_atl24_hdf5_points(atl24_file_path)
        if lower_path.endswith((".csv", ".txt")):
            atl24_df = pd.read_csv(atl24_file_path)
        elif lower_path.endswith(".parquet"):
            atl24_df = pd.read_parquet(atl24_file_path)
        elif lower_path.endswith((".gpkg", ".shp", ".geojson")):
            atl24_df = _read_atl24_vector_table(atl24_file_path)
        else:
            return _empty_atl24_points()
    except Exception:
        # Deliberate best-effort boundary: any reader failure (missing optional
        # dependency, corrupt file, unexpected layout) degrades to "no ATL24
        # data" so the caller can fall back, but it must leave a trace.
        logging.warning(
            "Could not read ATL24 file %s; continuing without ATL24 points.",
            atl24_file_path,
            exc_info=True,
        )
        return _empty_atl24_points()

    lon_candidates = ["longitude", "lon", "x", "LONGITUDE", "LON"]
    lat_candidates = ["latitude", "lat", "y", "LATITUDE", "LAT"]
    elev_candidates = [
        "seafloor_elevation",
        "bottom_elevation",
        "bathymetry",
        "elevation",
        "z",
        "depth",
        "water_depth",
    ]

    lon_col = _pick_column(atl24_df.columns, lon_candidates)
    lat_col = _pick_column(atl24_df.columns, lat_candidates)
    elev_col = _pick_column(atl24_df.columns, elev_candidates)
    if lon_col is None or lat_col is None or elev_col is None:
        return _empty_atl24_points()

    out_df = (
        pd.DataFrame(
            {
                "longitude": pd.to_numeric(atl24_df[lon_col], errors="coerce"),
                "latitude": pd.to_numeric(atl24_df[lat_col], errors="coerce"),
                "atl24_raw_bathy": pd.to_numeric(atl24_df[elev_col], errors="coerce"),
            }
        )
        .dropna(subset=["longitude", "latitude", "atl24_raw_bathy"])
        .copy()
    )

    # A "depth" column is positive-down by name; store it as a negative
    # elevation regardless of the sign convention used in the file.
    elev_name = str(elev_col).lower()
    if "depth" in elev_name and ("elev" not in elev_name):
        out_df["seafloor_elevation_atl24"] = -out_df["atl24_raw_bathy"].abs()
    else:
        out_df["seafloor_elevation_atl24"] = out_df["atl24_raw_bathy"]

    return out_df[["longitude", "latitude", "seafloor_elevation_atl24"]]
rebuild_and_refit_surface_after_refraction( + binned_dataset_sea_surface, + sea_surface_height, + water_temp_c, + wavelength_nm, + horizontal_res, + vertical_res, +): + """ + Apply refraction correction to the full binned dataset, rebuild histograms + on the corrected depths, and re-fit the Gaussian surface peak. + + Flowchart steps: + 10 — Re-build histograms based on corrected depths + (apply_optional_refraction_correction → horizontal_vertical_bin_dataset). + 11 — Re-fit Gaussian curve to surface to identify surface peak; compute + standard deviation and remove histogram data within three standard + deviations (get_sea_surface_height_adaptive on the rebinned data). + + Parameters + ---------- + binned_dataset_sea_surface : pd.DataFrame + Full binned photon dataset before refraction correction. + sea_surface_height : list + Per-bin surface heights from step 7. + water_temp_c : float + Water temperature for the refraction index calculation. + wavelength_nm : float + Laser wavelength (nm) for the refraction index calculation. + horizontal_res : float + Along-track bin size (m) for histogram rebuild. + vertical_res : float + Vertical bin size (m) for histogram rebuild. + + Returns + ------- + tuple + (sea_surface_height, sea_surface_height_abnormal_label, + solo_sea_surface_label, subsurface_photon_dataset) — same 4-tuple as + get_sea_surface_height_adaptive. + """ + corrected_full_dataset = apply_optional_refraction_correction( + binned_dataset_sea_surface, + binned_dataset_sea_surface.copy(), + sea_surface_height, + water_temp_c=water_temp_c, + wavelength_nm=wavelength_nm, + ) + rebinned_corrected = horizontal_vertical_bin_dataset( # step 10 + corrected_full_dataset, horizontal_res, vertical_res + ) + return get_sea_surface_height_adaptive(rebinned_corrected) # step 11 + + +# 5. 
The function to get the subsurface photon dataset +def get_subsurface_photon( + binned_dataset_sea_surface, + GEBCO_paths, + subsurface_thresh, + Ignore_Subsurface_Height_Thres, + use_atl24_filter=False, + atl24_file_path="", + atl24_max_match_distance_deg=0.01, + use_gebco_filter=False, + apply_histogram_quality_filter=False, + histogram_quality_min_ratio=0.05, + histogram_quality_depth_min=6.0, + histogram_quality_depth_max=7.0, + histogram_quality_ref_depth_min=0.0, + histogram_quality_ref_depth_max=1.0, + apply_surface_sigma_filter=False, + surface_sigma_max=0.5, + apply_refraction_correction=False, + refraction_water_temp_c=20.0, + refraction_wavelength_nm=532.0, + apply_post_refraction_refit=False, + apply_flattening=False, + flattening_window_m=500, + horizontal_res=500, + vertical_res=0.25, +): + """ + Orchestrate per-beam subsurface photon extraction and all optional quality filters. + + Flowchart steps executed in order: + 7 — Fit Gaussian curve to identify surface elevation; compute std dev of + Gaussian peak (via get_sea_surface_height_adaptive). + 6 — Review histograms for quality; quality-flag low-confidence bins and + exclude (optional, apply_histogram_quality_filter). + 8 — If Gaussian std dev > X m, quality-flag and discard bin (optional, + apply_surface_sigma_filter). + 9 — Compute water depths and correct for refraction (optional, + apply_refraction_correction). + 10 — Re-build histograms based on corrected depths (optional, + apply_post_refraction_refit → rebuild_and_refit_surface_after_refraction). + 11 — Re-fit Gaussian curve to identify surface peak; remove histogram data + within three standard deviations (optional, + rebuild_and_refit_surface_after_refraction). + 12 — Query ATL24 / retrieve GEBCO bathymetry data (optional, + use_atl24_filter / use_gebco_filter). + 13 — Remove data below seabed (optional). + 14 — Discard bins where seabed < 5 m deep; label as low confidence + (applied unconditionally via Ignore_Subsurface_Height_Thres). 
+ """ + # adaptive threshold + ( + sea_surface_height, + sea_surface_label, + solo_sea_surface_label, + subsurface_photon_dataset, + ) = get_sea_surface_height_adaptive(binned_dataset_sea_surface) + + # Always compute per-bin surface sigma for quality flagging and wave-adaptive fit + sigma_map = compute_surface_sigma(binned_dataset_sea_surface, sea_surface_height) + + if apply_histogram_quality_filter: + subsurface_photon_dataset = apply_optional_histogram_quality_filter( + binned_dataset_sea_surface, + subsurface_photon_dataset, + sea_surface_height, + min_ratio=histogram_quality_min_ratio, + depth_min=histogram_quality_depth_min, + depth_max=histogram_quality_depth_max, + reference_depth_min=histogram_quality_ref_depth_min, + reference_depth_max=histogram_quality_ref_depth_max, + ) + + if apply_surface_sigma_filter: + subsurface_photon_dataset = apply_optional_surface_sigma_filter( + binned_dataset_sea_surface, + subsurface_photon_dataset, + sea_surface_height, + sigma_max=surface_sigma_max, + ) + + if apply_refraction_correction: + subsurface_photon_dataset = apply_optional_refraction_correction( + binned_dataset_sea_surface, + subsurface_photon_dataset, + sea_surface_height, + water_temp_c=refraction_water_temp_c, + wavelength_nm=refraction_wavelength_nm, + ) + if apply_post_refraction_refit: + ( + sea_surface_height, + sea_surface_label, + solo_sea_surface_label, + subsurface_photon_dataset, + ) = rebuild_and_refit_surface_after_refraction( # steps 10 & 11 + binned_dataset_sea_surface, + sea_surface_height, + water_temp_c=refraction_water_temp_c, + wavelength_nm=refraction_wavelength_nm, + horizontal_res=horizontal_res, + vertical_res=vertical_res, + ) + + # Apply minimum depth threshold unconditionally — keeps photons at height + # >= Ignore_Subsurface_Height_Thres regardless of whether GEBCO/ATL24 is on. + # Without this, deep noise photons (down to -70 m) contaminate the Beer's Law fit + # and produce near-zero Kd values, especially at shallow turbid sites. 
def process_subsurface_photon_filtering(
    binned_dataset_sea_surface,
    GEBCO_paths,
    subsurface_thresh,
    Ignore_Subsurface_Height_Thres,
    use_atl24_filter=False,
    atl24_file_path="",
    atl24_max_match_distance_deg=0.01,
    use_gebco_filter=False,
    apply_histogram_quality_filter=False,
    histogram_quality_min_ratio=0.05,
    histogram_quality_depth_min=6.0,
    histogram_quality_depth_max=7.0,
    histogram_quality_ref_depth_min=0.0,
    histogram_quality_ref_depth_max=1.0,
    apply_surface_sigma_filter=False,
    surface_sigma_max=0.5,
    apply_refraction_correction=False,
    refraction_water_temp_c=20.0,
    refraction_wavelength_nm=532.0,
    apply_post_refraction_refit=False,
    apply_flattening=False,
    flattening_window_m=500,
    horizontal_res=500,
    vertical_res=0.25,
):
    """
    Beam-by-beam wrapper that calls get_subsurface_photon for every beam and
    concatenates the results.

    Flowchart steps 6–14 are all executed inside this function (delegated to
    get_subsurface_photon per beam). See get_subsurface_photon for the
    step-by-step breakdown.

    Parameters
    ----------
    binned_dataset_sea_surface : pd.DataFrame
        Binned photon dataset; must contain a 'beam_id' column, which is used
        to split the data into per-beam groups.
    GEBCO_paths, subsurface_thresh, Ignore_Subsurface_Height_Thres
        Forwarded positionally to get_subsurface_photon.
    **remaining keyword parameters**
        Forwarded unchanged, by name, to get_subsurface_photon.

    Returns
    -------
    tuple
        ``(sea_surface_heights, sea_surface_labels, combined_filtered_dataset)``
        where the first two are lists with one entry per beam (in groupby
        order) and the third is the row-wise concatenation of the per-beam
        filtered datasets with a fresh index. When the input contains no
        beams, both lists are empty and the dataset is an empty frame with
        the input's columns.
    """
    # Everything except the three positional arguments is passed through
    # verbatim; collecting it once keeps the per-beam call readable.
    beam_options = {
        "use_atl24_filter": use_atl24_filter,
        "atl24_file_path": atl24_file_path,
        "atl24_max_match_distance_deg": atl24_max_match_distance_deg,
        "use_gebco_filter": use_gebco_filter,
        "apply_histogram_quality_filter": apply_histogram_quality_filter,
        "histogram_quality_min_ratio": histogram_quality_min_ratio,
        "histogram_quality_depth_min": histogram_quality_depth_min,
        "histogram_quality_depth_max": histogram_quality_depth_max,
        "histogram_quality_ref_depth_min": histogram_quality_ref_depth_min,
        "histogram_quality_ref_depth_max": histogram_quality_ref_depth_max,
        "apply_surface_sigma_filter": apply_surface_sigma_filter,
        "surface_sigma_max": surface_sigma_max,
        "apply_refraction_correction": apply_refraction_correction,
        "refraction_water_temp_c": refraction_water_temp_c,
        "refraction_wavelength_nm": refraction_wavelength_nm,
        "apply_post_refraction_refit": apply_post_refraction_refit,
        "apply_flattening": apply_flattening,
        "flattening_window_m": flattening_window_m,
        "horizontal_res": horizontal_res,
        "vertical_res": vertical_res,
    }

    sea_surface_heights = []
    sea_surface_labels = []
    filtered_beam_datasets = []

    # Group the binned dataset by 'beam_id' and process each group separately
    for beam_id, beam_data in binned_dataset_sea_surface.groupby("beam_id"):
        print(f"Processing subsurface filtering for beam: {beam_id}")

        sea_surface_height, sea_surface_label, beam_filtered = get_subsurface_photon(
            beam_data,
            GEBCO_paths,
            subsurface_thresh,
            Ignore_Subsurface_Height_Thres,
            **beam_options,
        )

        sea_surface_heights.append(sea_surface_height)
        sea_surface_labels.append(sea_surface_label)
        filtered_beam_datasets.append(beam_filtered)

    if not filtered_beam_datasets:
        # pd.concat raises ValueError on an empty list; an input without any
        # beam is a valid "nothing to process" case, not an error.
        return (
            sea_surface_heights,
            sea_surface_labels,
            binned_dataset_sea_surface.iloc[0:0].copy(),
        )

    # Combine all filtered beam datasets into a single DataFrame
    combined_filtered_dataset = pd.concat(filtered_beam_datasets, ignore_index=True)

    return sea_surface_heights, sea_surface_labels, combined_filtered_dataset
except KeyError: + pass + else: + IS2_atl03_beams.append(gtx) + + # for each included beam + for gtx in IS2_atl03_beams: + # ------------------------------------------- + # 1. make sure the beam-level dict exists + IS2_atl03_attrs.setdefault(gtx, {}) + # 2. always save the two “must-have” attributes + for key in ("atlas_beam_type", "atlas_spot_number"): + IS2_atl03_attrs[gtx][key] = fileID[gtx].attrs[key] + + # get each HDF5 variable + IS2_atl03_mds[gtx] = {} + IS2_atl03_mds[gtx]["heights"] = {} + IS2_atl03_mds[gtx]["geolocation"] = {} + IS2_atl03_mds[gtx]["bckgrd_atlas"] = {} + IS2_atl03_mds[gtx]["geophys_corr"] = {} + # ICESat-2 Measurement Group + for key, val in fileID[gtx]["heights"].items(): + IS2_atl03_mds[gtx]["heights"][key] = val[:] + # ICESat-2 Geolocation Group + for key, val in fileID[gtx]["geolocation"].items(): + IS2_atl03_mds[gtx]["geolocation"][key] = val[:] + # ICESat-2 Background Photon Rate Group + for key, val in fileID[gtx]["bckgrd_atlas"].items(): + IS2_atl03_mds[gtx]["bckgrd_atlas"][key] = val[:] + # ICESat-2 Geophysical Corrections Group: Values for tides (ocean, + # solid earth, pole, load, and equilibrium), inverted barometer (IB) + # effects, and range corrections for tropospheric delays + for key, val in fileID[gtx]["geophys_corr"].items(): + IS2_atl03_mds[gtx]["geophys_corr"][key] = val[:] + + # Getting attributes of included variables + if ATTRIBUTES: + # Getting attributes of IS2_atl03_mds beam variables + IS2_atl03_attrs[gtx] = {} + IS2_atl03_attrs[gtx]["heights"] = {} + IS2_atl03_attrs[gtx]["geolocation"] = {} + IS2_atl03_attrs[gtx]["bckgrd_atlas"] = {} + IS2_atl03_attrs[gtx]["geophys_corr"] = {} + + # Global Group Attributes + for att_name, att_val in fileID[gtx].attrs.items(): + IS2_atl03_attrs[gtx][att_name] = att_val + # ICESat-2 Measurement Group + for key, val in fileID[gtx]["heights"].items(): + IS2_atl03_attrs[gtx]["heights"][key] = {} + for att_name, att_val in val.attrs.items(): + 
IS2_atl03_attrs[gtx]["heights"][key][att_name] = att_val + # ICESat-2 Geolocation Group + for key, val in fileID[gtx]["geolocation"].items(): + IS2_atl03_attrs[gtx]["geolocation"][key] = {} + for att_name, att_val in val.attrs.items(): + IS2_atl03_attrs[gtx]["geolocation"][key][att_name] = att_val + # ICESat-2 Background Photon Rate Group + for key, val in fileID[gtx]["bckgrd_atlas"].items(): + IS2_atl03_attrs[gtx]["bckgrd_atlas"][key] = {} + for att_name, att_val in val.attrs.items(): + IS2_atl03_attrs[gtx]["bckgrd_atlas"][key][att_name] = att_val + # ICESat-2 Geophysical Corrections Group + for key, val in fileID[gtx]["geophys_corr"].items(): + IS2_atl03_attrs[gtx]["geophys_corr"][key] = {} + for att_name, att_val in val.attrs.items(): + IS2_atl03_attrs[gtx]["geophys_corr"][key][att_name] = att_val + + # ICESat-2 spacecraft orientation at time + IS2_atl03_mds["orbit_info"] = {} + IS2_atl03_attrs["orbit_info"] = {} + for key, val in fileID["orbit_info"].items(): + IS2_atl03_mds["orbit_info"][key] = val[:] + # Getting attributes of group and included variables + if ATTRIBUTES: + # Global Group Attributes + for att_name, att_val in fileID["orbit_info"].attrs.items(): + IS2_atl03_attrs["orbit_info"][att_name] = att_val + # Variable Attributes + IS2_atl03_attrs["orbit_info"][key] = {} + for att_name, att_val in val.attrs.items(): + IS2_atl03_attrs["orbit_info"][key][att_name] = att_val + + # information ancillary to the data product + # number of GPS seconds between the GPS epoch (1980-01-06T00:00:00Z UTC) + # and ATLAS Standard Data Product (SDP) epoch (2018-01-01T00:00:00Z UTC) + # Add this value to delta time parameters to compute full gps_seconds + # could alternatively use the Julian day of the ATLAS SDP epoch: 2458119.5 + # and add leap seconds since 2018-01-01T00:00:00Z UTC (ATLAS SDP epoch) + IS2_atl03_mds["ancillary_data"] = {} + IS2_atl03_attrs["ancillary_data"] = {} + ancillary_keys = [ + "atlas_sdp_gps_epoch", + "data_end_utc", + "data_start_utc", + 
"end_cycle", + "end_geoseg", + "end_gpssow", + "end_gpsweek", + "end_orbit", + "end_region", + "end_rgt", + "granule_end_utc", + "granule_start_utc", + "release", + "start_cycle", + "start_geoseg", + "start_gpssow", + "start_gpsweek", + "start_orbit", + "start_region", + "start_rgt", + "version", + ] + for key in ancillary_keys: + # get each HDF5 variable + IS2_atl03_mds["ancillary_data"][key] = fileID["ancillary_data"][key][:] + # Getting attributes of group and included variables + if ATTRIBUTES: + # Variable Attributes + IS2_atl03_attrs["ancillary_data"][key] = {} + for att_name, att_val in fileID["ancillary_data"][key].attrs.items(): + IS2_atl03_attrs["ancillary_data"][key][att_name] = att_val + + # transmit-echo-path (tep) parameters + IS2_atl03_mds["ancillary_data"]["tep"] = {} + IS2_atl03_attrs["ancillary_data"]["tep"] = {} + for key, val in fileID["ancillary_data"]["tep"].items(): + # get each HDF5 variable + IS2_atl03_mds["ancillary_data"]["tep"][key] = val[:] + # Getting attributes of group and included variables + if ATTRIBUTES: + # Variable Attributes + IS2_atl03_attrs["ancillary_data"]["tep"][key] = {} + for att_name, att_val in val.attrs.items(): + IS2_atl03_attrs["ancillary_data"]["tep"][key][att_name] = att_val + + # channel dead time and first photon bias derived from ATLAS calibration + cal1, cal2 = ("ancillary_data", "calibrations") + for var in ["dead_time", "first_photon_bias"]: + IS2_atl03_mds[cal1][var] = {} + IS2_atl03_attrs[cal1][var] = {} + for key, val in fileID[cal1][cal2][var].items(): + # get each HDF5 variable + if isinstance(val, h5py.Dataset): + IS2_atl03_mds[cal1][var][key] = val[:] + elif isinstance(val, h5py.Group): + IS2_atl03_mds[cal1][var][key] = {} + for k, v in val.items(): + IS2_atl03_mds[cal1][var][key][k] = v[:] + # Getting attributes of group and included variables + if ATTRIBUTES: + # Variable Attributes + IS2_atl03_attrs[cal1][var][key] = {} + for att_name, att_val in val.attrs.items(): + 
# convert_wgs_to_utm function, see https://stackoverflow.com/a/40140326/4556479
def convert_wgs_to_utm(lon: float, lat: float):
    """
    Return the EPSG code string of the UTM zone that contains a WGS84 point.

    Parameters
    ----------
    lon : float
        Longitude in decimal degrees.
    lat : float
        Latitude in decimal degrees; its sign selects the hemisphere.

    Returns
    -------
    str
        ``"epsg:326NN"`` for ``lat >= 0`` and ``"epsg:327NN"`` otherwise,
        where ``NN`` is the zero-padded zone number (01-60).
    """
    # Zones are 6 degrees wide counted eastwards from -180; the modulo wraps
    # lon == 180 back onto zone 1.
    zone_number = math.floor((lon + 180) / 6) % 60 + 1
    hemisphere_prefix = "epsg:326" if lat >= 0 else "epsg:327"
    return f"{hemisphere_prefix}{zone_number:02d}"
+ellps=WGS84 +datum=WGS84 +vunits=m +no_defs +geoidgrids=egm2008-1.gtx' + # # Define the Proj string for WGS84 ellipsoidal height + # wgs84_proj_string = '+proj=latlong +ellps=WGS84 +datum=WGS84 +no_defs' + + # # Define the Proj string for EGM2008 orthometric height: egm08_25,egm2008-1 + # egm2008_proj_string = \ + # '+proj=latlong +ellps=WGS84 +datum=WGS84 +no_defs ' \ + # '+geoidgrids=C:/Workstation/ICESat2_HLS/Code/Geoids/egm08_25.gtx' + + # transform ellipsoid (WGS84) height to orthometric height + # transformer = Transformer.from_crs(wgs84_proj_string, egm2008_proj_string, always_xy=True) + transformer = Transformer.from_crs("EPSG:4326", "EPSG:3857", always_xy=True) + X_egm08, Y_egm08, Z_egm08 = transformer.transform(lon, lat, Z) + + # transform WGS84 proj to local UTM + myProj = Proj(epsg) + X_utm, Y_utm = myProj(lon, lat) + + return Y_utm, X_utm, Z_egm08 + + +def load_data(file_path, read_attributes=False): + """ + Wrapper around read_granule that lets you decide + whether to pull the full attribute tree. + + Flowchart step: 1 — Acquire ATL03 data (download or use cloud services). + + Parameters + ---------- + file_path : str + Path to the ATL03 HDF5 granule. + read_attributes : bool, optional + If True, read_granule returns the full IS2_atl03_attrs tree. + Defaults to False. 
# NOTE: land membership is resolved by index label, so the input frame must
# have a unique index (a default RangeIndex works); this needs revisiting if
# the index is ever changed to time or similar.
# The point-in-polygon join is evaluated for EVERY point. Evaluating on a
# coarser spacing (e.g. every 10 m) would be much faster; fine for now.
def isolate_sea_land_photons(shoreline_data_path, ICESat2_GDF):
    """
    Label every point of ``ICESat2_GDF`` as land (1) or not land (0).

    The shoreline polygons are read from ``shoreline_data_path`` restricted to
    the bounding box of the points, and a point is labelled land when it lies
    within one of those polygons.

    Parameters
    ----------
    shoreline_data_path : str
        Path to the shoreline polygon dataset, read with ``gpd.read_file``.
    ICESat2_GDF : geopandas.GeoDataFrame
        Point geometries in lon/lat. The frame is modified in place: helper
        columns ``lat``, ``lon`` and ``is_land`` (initialised to -1) are
        added, and the CRS is set to EPSG:4326.

    Returns
    -------
    numpy.ndarray
        int64 array with one entry per row: 1 for land, 0 otherwise.
        If anything fails (missing shoreline file, read/join error, input
        without a geometry column) the error is printed and an array of
        -1 ("unknown") of the same length is returned instead.
    """
    n_points = len(ICESat2_GDF)

    try:
        geometry = ICESat2_GDF.geometry

        # Only add the helper columns when they are missing, so calling this
        # function twice on the same frame does not fail on a duplicate
        # insert and silently degrade to the "-1" fallback.
        if "lat" not in ICESat2_GDF.columns:
            ICESat2_GDF.insert(0, "lat", geometry.y, False)
        if "lon" not in ICESat2_GDF.columns:
            ICESat2_GDF.insert(0, "lon", geometry.x, False)

        # Land flag initialised as -1 (unknown) until the shoreline is loaded.
        unknown_flags = np.full(n_points, -1, dtype=np.int64)
        if "is_land" in ICESat2_GDF.columns:
            ICESat2_GDF["is_land"] = unknown_flags
        else:
            ICESat2_GDF.insert(0, "is_land", unknown_flags, False)

        # set the projection
        ICESat2_GDF.set_crs("EPSG:4326", inplace=True)

        # Load only the shoreline features intersecting the points' bounding
        # box. The 'pyogrio' engine is used because 'fiona' sometimes errors.
        land_polygon_gdf = gpd.read_file(
            shoreline_data_path, bbox=ICESat2_GDF, engine="pyogrio"
        )

        # Points that fall inside any land polygon.
        pts_in_land = gpd.sjoin(ICESat2_GDF, land_polygon_gdf, predicate="within")

        # Boolean land membership per input row, returned as 0/1 labels.
        land_loc = ICESat2_GDF.index.isin(pts_in_land.index)
        return land_loc.astype(np.int64)

    except Exception as e:
        # Deliberate best-effort: the caller continues with "unknown" labels.
        print(e)
        print("Error loading shoreline data, returning -1s for is_land flag")

        # Build the fallback from the row count rather than from the
        # ``is_land`` column, which does not exist when the failure happened
        # before that column was added.
        return np.full(n_points, -1, dtype=np.int64)
def interpolate_by_time(source_time, source_values, target_time):
    """
    Linearly resample ``source_values`` from ``source_time`` onto ``target_time``.

    Samples whose time or value is not finite are ignored, the remainder is
    ordered by time, and only the first sample of each distinct time is kept.
    Targets outside the sampled span take the nearest end value.

    Returns a float array shaped like ``target_time``: all NaN when fewer
    than two finite samples exist, and a constant when only one distinct
    time remains.
    """
    t_src = np.asarray(source_time)
    v_src = np.asarray(source_values)
    t_out = np.asarray(target_time)

    finite = np.isfinite(t_src) & np.isfinite(v_src)
    if np.count_nonzero(finite) < 2:
        return np.full_like(t_out, np.nan, dtype=float)

    # Order the usable samples by time before collapsing duplicates.
    t_ok = t_src[finite]
    v_ok = v_src[finite]
    order = np.argsort(t_ok)
    t_ok = t_ok[order]
    v_ok = v_ok[order]

    knots, first_seen = np.unique(t_ok, return_index=True)
    knot_values = v_ok[first_seen]

    if knots.size >= 2:
        return np.interp(
            t_out,
            knots,
            knot_values,
            left=knot_values[0],
            right=knot_values[-1],
        )

    # A single distinct time cannot define a slope: broadcast its value.
    return np.full_like(t_out, knot_values[0], dtype=float)
+ Noisy photons are removed unless their signal confidence >= min_signal_conf. + + Flowchart step: 5 — Compute solar background from atmospheric signal (using + selected x and z bin length scales) and subtract from subsurface photon + histogram data. + """ + if not enabled: + return sea_photon_dataset + + required_cols = {"solar_elevation", "background_rate", "photon_conf"} + if not required_cols.issubset(set(sea_photon_dataset.columns)): + logger.warning( + "Solar background filter skipped because required columns are missing." + ) + return sea_photon_dataset + + filtered_dataset = sea_photon_dataset.copy() + daytime_mask = filtered_dataset["solar_elevation"] >= day_threshold + + # Select valid daytime photons for rolling median computation + daytime = filtered_dataset.loc[daytime_mask].copy() + valid = np.isfinite(daytime["background_rate"]) & np.isfinite( + daytime["solar_elevation"] + ) + daytime = daytime.loc[valid] + + if len(daytime) < 10: + logger.info( + "Solar background filter enabled, but fewer than 10 valid daytime " + "photons were found. Skipping filter." + ) + return filtered_dataset + + # Clamp solar elevations to physical range [DAY_THRESHOLD, 90] to avoid + # HDF5 fill values (e.g. 3.4e+38) blowing up the bin array. + daytime = daytime.copy() + daytime["solar_elevation"] = daytime["solar_elevation"].clip( + lower=day_threshold, upper=90.0 + ) + + # Vectorized degree-based rolling median: + # Bin elevations into fine steps, compute one median per bin using + # np.searchsorted on the sorted array, then map back to photons. 
+ daytime = daytime.sort_values("solar_elevation") + elevations = daytime["solar_elevation"].values + bg_rates = daytime["background_rate"].values + half_window = median_window_deg / 2.0 + + bin_step = 0.1 # degrees — fine enough for smooth curve + bin_centers = np.arange(elevations[0], elevations[-1] + bin_step, bin_step) + bin_medians = np.empty(len(bin_centers)) + for i in range(len(bin_centers)): + lo = np.searchsorted(elevations, bin_centers[i] - half_window, side="left") + hi = np.searchsorted(elevations, bin_centers[i] + half_window, side="right") + if lo < hi: + bin_medians[i] = np.median(bg_rates[lo:hi]) + else: + bin_medians[i] = np.nan + + # Map bin medians back to each photon via nearest-bin lookup (vectorized) + bin_indices = np.searchsorted(bin_centers, elevations, side="right") - 1 + bin_indices = np.clip(bin_indices, 0, len(bin_centers) - 1) + expected_bg = bin_medians[bin_indices] + + # Map expected background onto the full DataFrame. Non-daytime and + # non-finite daytime rows stay NaN, so they are never flagged noisy. 
def apply_optional_ir_ap_filter(
    sea_photon_dataset, enabled=False, quality_max=0, min_signal_conf=0
):
    """
    Optionally drop afterpulse photons and photons of low signal confidence.

    Flowchart step: 2 — Remove photons flagged as afterpulse or impulse response.

    ATL03 v007 quality_ph bit encoding:
      bit 0 (value 1) = possible afterpulse  <- targeted by this filter
      bit 1 (value 2) = possible impulse response effect
      bit 2 (value 4) = possible TEP

    When ``enabled`` is False the input object is returned untouched.
    Otherwise a filtered copy is returned in which

    * photons whose ``quality_ph`` has bit 0 set are removed (values that
      cannot be parsed as numbers count as 0, i.e. unflagged); photons
      flagged only for impulse response or TEP are kept, and
    * photons whose ``photon_conf`` is below ``min_signal_conf`` or cannot
      be parsed as a number are removed.

    A missing column only produces a warning and skips that criterion.
    ``quality_max`` is accepted but not used by this implementation.
    """
    if not enabled:
        return sea_photon_dataset

    working = sea_photon_dataset.copy()
    total_before = len(working)
    retain = np.ones(total_before, dtype=bool)
    available = working.columns

    if "quality_ph" not in available:
        logger.warning("IR/AP filter enabled, but quality_ph is missing.")
    else:
        flags = pd.to_numeric(working["quality_ph"], errors="coerce")
        flags = flags.fillna(0).astype(int)
        is_afterpulse = (flags & 1) == 1  # bit 0 set = afterpulse
        retain &= ~is_afterpulse.to_numpy()
        logger.info(
            "IR/AP filter: %d photons flagged as afterpulse (quality_ph bit 0) will be removed.",
            int(is_afterpulse.sum()),
        )

    if "photon_conf" not in available:
        logger.warning("IR/AP filter enabled, but photon_conf is missing.")
    else:
        confidence = pd.to_numeric(working["photon_conf"], errors="coerce")
        # NaN compares False, so unparsable confidences are dropped as well.
        retain &= (confidence >= min_signal_conf).to_numpy()

    working = working.loc[retain].copy()
    total_after = len(working)
    logger.info(
        "IR/AP proxy filter removed %s photons (from %s to %s).",
        total_before - total_after,
        total_before,
        total_after,
    )
    return working
IS2_val["geolocation"]["segment_id"] + n_seg = len(Segment_ID[gtx]) + (n_pe,) = IS2_val["heights"]["delta_time"].shape + Segment_Index_begin[gtx] = IS2_val["geolocation"]["ph_index_beg"] - 1 + Segment_PE_count[gtx] = IS2_val["geolocation"]["segment_ph_cnt"] + Equator_Segment_Distance[gtx] = IS2_val["geolocation"]["segment_dist_x"] + Segment_Length[gtx] = IS2_val["geolocation"]["segment_length"] + delta_time = IS2_val["geolocation"]["delta_time"] + segment_lat = IS2_val["geolocation"]["reference_photon_lat"][:].copy() + segment_lon = IS2_val["geolocation"]["reference_photon_lon"][:].copy() + ref_elev = IS2_val["geolocation"]["ref_elev"][:].copy() + ref_azimuth = IS2_val["geolocation"]["ref_azimuth"][:].copy() + geoid = IS2_val["geophys_corr"]["geoid"][:].copy() + h_ph = IS2_val["heights"]["h_ph"][:].copy() + photon_delta_time = IS2_val["heights"]["delta_time"][:].copy() + lat_ph = IS2_val["heights"]["lat_ph"][:].copy() + lon_ph = IS2_val["heights"]["lon_ph"][:].copy() + signal_conf_photon = IS2_val["heights"]["signal_conf_ph"][..., 0].copy() + x_atc = IS2_val["heights"]["dist_ph_along"][:].copy() + y_atc = IS2_val["heights"]["dist_ph_across"][:].copy() + quality_ph = IS2_val["heights"]["quality_ph"] + + # Optional variables for solar background sensitivity testing + photon_solar_elevation = np.full_like(photon_delta_time, np.nan, dtype=float) + photon_background_rate = np.full_like(photon_delta_time, np.nan, dtype=float) + if "solar_elevation" in IS2_val["geolocation"]: + segment_solar_elevation = IS2_val["geolocation"]["solar_elevation"][ + : + ].copy() + photon_solar_elevation = interpolate_by_time( + delta_time, segment_solar_elevation, photon_delta_time + ) + if ( + "bckgrd_atlas" in IS2_val + and "bckgrd_rate" in IS2_val["bckgrd_atlas"] + and "delta_time" in IS2_val["bckgrd_atlas"] + ): + bckgrd_rate = IS2_val["bckgrd_atlas"]["bckgrd_rate"][:].copy() + bckgrd_time = IS2_val["bckgrd_atlas"]["delta_time"][:].copy() + photon_background_rate = interpolate_by_time( 
+ bckgrd_time, bckgrd_rate, photon_delta_time + ) + + # Adjust x_atc based on segment distances + for seg_index in range(n_seg): + idx = Segment_Index_begin[gtx][seg_index] + cnt = Segment_PE_count[gtx][seg_index] + x_atc[idx : idx + cnt] += Equator_Segment_Distance[gtx][seg_index] + + # Calculate relative distances + relative_AT_dist = (x_atc - x_atc[0]) / 1000 + relative_seg_dist = ( + Equator_Segment_Distance[gtx] - Equator_Segment_Distance[gtx][0] + ) / 1000 + + # Create a GeoDataFrame to hold segment data for shoreline check + Segment_Is_Land["geometry"] = gpd.points_from_xy(segment_lon, segment_lat) + ICESat2_GDF = gpd.GeoDataFrame(Segment_Is_Land, crs="EPSG:4326") + + # Determine if it is land by the land/sea mask + Segment_Is_Land_Labels = isolate_sea_land_photons( + shoreline_data_path, ICESat2_GDF + ) + ICESat2_GDF.loc[:, "is_land"] = Segment_Is_Land_Labels + + # Apply interpolations for required data + is_land_label_interp1d = apply_interpolation( + interpolate_labels(segment_lat, Segment_Is_Land_Labels), lat_ph + ) + ph_ref_elev = apply_interpolation( + interpolate_labels(segment_lat, ref_elev), lat_ph + ) + ph_ref_azimuth = apply_interpolation( + interpolate_labels(segment_lat, ref_azimuth), lat_ph + ) + ph_geoid = apply_interpolation(interpolate_labels(segment_lat, geoid), lat_ph) + + # Create photon DataFrame for the current beam + sea_photon_dataset = create_photon_dataframe( + lat_ph=lat_ph, + lon_ph=lon_ph, + ref_elev=ph_ref_elev, + ref_azimuth=ph_ref_azimuth, + geoid=ph_geoid, + h_ph=h_ph, + quality_ph=quality_ph, + is_land_label_interp1d=is_land_label_interp1d, + signal_conf_photon=signal_conf_photon, + x_atc=x_atc, # Note: x_atc is not used in the function. Remove if unnecessary. 
def filter_photon_dataset_by_hull_area(photon_dataset, hull_area_threshold=3000):
    """
    Drop ``lat_bins`` groups whose photon cloud covers too small an area.

    For every ``lat_bins`` group with at least three photons, the convex hull
    of its (``lat``, ``photon_height``) points is computed. Groups whose hull
    area is below ``hull_area_threshold`` are removed from the result.
    Groups with fewer than three photons are not evaluated and are left in
    the result unchanged.

    Parameters
    ----------
    photon_dataset : pandas.DataFrame
        Must contain the columns ``lat_bins``, ``lat`` and ``photon_height``.
        The input frame is not modified.
    hull_area_threshold : float, optional
        Minimum hull area (in the units of ``lat`` times ``photon_height``)
        a group needs in order to be kept. Defaults to 3000.

    Returns
    -------
    filtered_dataset : pandas.DataFrame
        Copy of the input without the rejected groups.
    convex_hull_areas : dict
        Hull area for every evaluated group, keyed by ``lat_bins`` value.
        Degenerate groups (see below) are reported with area 0.0.
    convex_hulls : dict
        Hull vertex coordinates (``(n, 2)`` arrays) of the accepted groups.
    """
    # Imported here because only this function needs it.
    from scipy.spatial import QhullError

    convex_hulls = {}
    convex_hull_areas = {}
    rejected_bins = []

    for lat_bin, bin_data in photon_dataset.groupby("lat_bins", observed=False):
        if len(bin_data) < 3:
            continue

        points = bin_data[["lat", "photon_height"]].to_numpy()
        try:
            hull = ConvexHull(points)
        except QhullError:
            # Collinear or coincident photons span no area. Treat the bin as
            # having zero area instead of aborting the whole filtering run.
            convex_hull_areas[lat_bin] = 0.0
            if hull_area_threshold > 0:
                rejected_bins.append(lat_bin)
            continue

        # For 2-D input, ConvexHull.volume is the enclosed area
        # (ConvexHull.area would be the perimeter).
        area = hull.volume
        convex_hull_areas[lat_bin] = area

        if area >= hull_area_threshold:
            convex_hulls[lat_bin] = points[hull.vertices]
        else:
            rejected_bins.append(lat_bin)

    filtered_dataset = photon_dataset.copy()
    if rejected_bins:
        # Remove all rejected bins in one pass rather than re-filtering the
        # full frame once per rejected bin.
        filtered_dataset = filtered_dataset[
            ~filtered_dataset["lat_bins"].isin(rejected_bins)
        ]

    return filtered_dataset, convex_hull_areas, convex_hulls
def extract_file_params(file_path):
    """
    Split an ICESat-2 granule file name into its components.

    The name must end in ``.h5`` and follow the pattern
    ``[processed_]ATLxx_YYYYMMDDhhmmss_NNNNNNNN_RRR_VV[suffix].h5``.
    Leading directories are allowed.

    Parameters
    ----------
    file_path : str
        File name or path of the granule.

    Returns
    -------
    tuple of str
        14 strings: optional ``"processed_"`` prefix (``""`` when absent),
        product id, year, month, day, hour, minute, second, the 4/2/2 digit
        split of the eight-digit field, the three-digit field, the two-digit
        field, and any suffix before ``.h5`` (``""`` when absent).
        NOTE(review): for ATL03 names the 4/2/2 digits are presumably RGT,
        cycle and region, followed by release and version — confirm against
        the product naming convention.

    Raises
    ------
    IndexError
        If ``file_path`` does not match the pattern.

    Example
    -------
    input  ``"processed_ATL06_20231115120000_20231115_001_12_data.h5"``
    output ``('processed_', 'ATL06', '2023', '11', '15', '12', '00', '00',
    '2023', '11', '15', '001', '12', '_data')``
    """
    # The dot before "h5" is escaped so that only a real ".h5" extension is
    # accepted; an unescaped "." would also match names such as "..._01_h5".
    match = re.search(
        r"(processed_)?(ATL\d{2})_(\d{4})(\d{2})(\d{2})(\d{2})"
        r"(\d{2})(\d{2})_(\d{4})(\d{2})(\d{2})_(\d{3})_(\d{2})(.*?)\.h5$",
        file_path,
    )
    if match is None:
        # IndexError is kept as the failure type for existing callers, but
        # now names the offending path.
        raise IndexError(
            f"File name does not match the ICESat-2 granule pattern: {file_path!r}"
        )
    # default="" makes a missing optional prefix an empty string, not None.
    return match.groups(default="")
def interpolate_labels(segment_lat, labels):
    """
    Build a linear interpolator of ``labels`` over ``segment_lat``.

    The returned callable extrapolates linearly outside the sampled range.
    """
    return scipy.interpolate.interp1d(segment_lat, labels, fill_value="extrapolate")


def apply_interpolation(model, lat_ph):
    """Evaluate an interpolator from ``interpolate_labels`` at ``lat_ph``."""
    interpolated = model(lat_ph)
    return interpolated


def geoid_correction(lat_ph, segment_lat, geoid):
    """Interpolate per-segment ``geoid`` values to the photon latitudes."""
    return apply_interpolation(interpolate_labels(segment_lat, geoid), lat_ph)


def refraction_correction(lat_ph, segment_lat, ref_elev, ref_azimuth):
    """
    Interpolate per-segment reference elevation and azimuth to the photons.

    Returns a tuple ``(elevation_at_photons, azimuth_at_photons)``.
    """
    elev_ph, azimuth_ph = (
        apply_interpolation(interpolate_labels(segment_lat, per_segment), lat_ph)
        for per_segment in (ref_elev, ref_azimuth)
    )
    return elev_ph, azimuth_ph
# Apply the binning beam by beam via horizontal_vertical_bin_dataset.
def process_sea_photon_binning(sea_photon_dataset, horizontal_res, vertical_res):
    """
    Apply horizontal and vertical binning to each beam.

    Flowchart steps:
      3 — Choose along-track (x) and vertical (z) bin sizes
          (e.g., 500 m in x, 0.25 m in z).
      4 — Build vertical histograms of photon counts using selected x, z bin sizes.

    Parameters
    ----------
    sea_photon_dataset : pandas.DataFrame
        Photon table with at least ``beam_id``, ``lat`` and ``photon_height``.
    horizontal_res : float
        Target width of the ``lat`` bins, in the units of the ``lat`` column.
    vertical_res : float
        Target height of the ``photon_height`` bins.

    Returns
    -------
    pandas.DataFrame
        Binned photons of all beams concatenated with a fresh index. When
        the input holds no beams, an empty frame carrying the ``lat_bins``
        and ``height_bins`` columns is returned.
    """
    binned_beam_datasets = []

    # Bin edges are derived per beam, so each beam is processed separately.
    for beam_id, beam_data in sea_photon_dataset.groupby("beam_id"):
        print(f"Processing binning for beam: {beam_id}")
        binned_beam_datasets.append(
            horizontal_vertical_bin_dataset(beam_data, horizontal_res, vertical_res)
        )

    if not binned_beam_datasets:
        # pd.concat raises on an empty list; return an empty frame that has
        # the same columns a non-empty result would have.
        return horizontal_vertical_bin_dataset(
            sea_photon_dataset, horizontal_res, vertical_res
        )

    return pd.concat(binned_beam_datasets, ignore_index=True)


# Bin data along vertical and horizontal scales
def horizontal_vertical_bin_dataset(
    dataset, lat_res, vertical_res, valid_range=(-70, 5)
):
    """
    Bin data along vertical and horizontal scales for later segmentation.

    Flowchart steps:
      3 — Choose along-track (x) and vertical (z) bin sizes
          (e.g., 500 m in x, 0.25 m in z).
      4 — Build vertical histograms of photon counts using selected x, z bin sizes.
      10 — Re-build histograms based on corrected depths (called again after
           refraction correction when apply_post_refraction_refit is enabled).

    Parameters
    ----------
    dataset : pandas.DataFrame
        Photon table with at least the columns ``lat`` and ``photon_height``.
        NOTE(review): ``lat`` is presumably the UTM northing in metres, which
        would make ``lat_res`` a distance in metres — confirm with the caller.
    lat_res : float
        Target width of the ``lat`` bins, in the units of ``lat``.
    vertical_res : float
        Target height of the ``photon_height`` bins.
    valid_range : tuple of float, optional
        Exclusive ``(low, high)`` limits for ``photon_height``; photons
        outside are treated as noise and dropped. Defaults to ``(-70, 5)``.

    Returns
    -------
    pandas.DataFrame
        Copy of the in-range photons with a fresh index and two added
        columns: ``lat_bins`` (labels ``0 .. n-1``) and ``height_bins``
        (rounded heights spread evenly from the lowest to the highest
        photon). The frame is empty, but still has both columns, when no
        photon lies inside ``valid_range``.
    """
    # Keep only photons inside the valid height range; anything outside is
    # treated as noise. A copy avoids SettingWithCopyWarning further down.
    valid_mask = (dataset["photon_height"] > valid_range[0]) & (
        dataset["photon_height"] < valid_range[1]
    )
    filtered_dataset = dataset[valid_mask].copy()

    if filtered_dataset.empty:
        # min()/max() of an empty column are NaN and round(NaN) raises, so
        # short-circuit with empty bin columns instead.
        filtered_dataset["lat_bins"] = pd.Categorical([])
        filtered_dataset["height_bins"] = pd.Categorical([])
        return filtered_dataset.reset_index(drop=True)

    # Number of height bins (at least one)
    height_min = filtered_dataset["photon_height"].min()
    height_max = filtered_dataset["photon_height"].max()
    height_bin_number = max(1, round(abs(height_max - height_min) / vertical_res))

    # Number of latitude bins (at least one)
    lat_range = abs(filtered_dataset["lat"].max() - filtered_dataset["lat"].min())
    lat_bin_number = max(1, round(lat_range / lat_res))

    # Equal-width latitude bins labelled 0 .. n-1
    lat_bins = pd.cut(
        filtered_dataset["lat"], bins=lat_bin_number, labels=np.arange(lat_bin_number)
    )

    # Height labels are rounded to one decimal. pd.cut requires unique
    # labels, so for fine resolutions (spacing below 0.1) the precision is
    # raised until they are unique; coarser resolutions keep one decimal.
    raw_height_labels = np.linspace(height_min, height_max, num=height_bin_number)
    decimals = 1
    height_labels = np.round(raw_height_labels, decimals=decimals)
    while len(np.unique(height_labels)) < height_bin_number and decimals < 12:
        decimals += 1
        height_labels = np.round(raw_height_labels, decimals=decimals)

    height_bins = pd.cut(
        filtered_dataset["photon_height"],
        bins=height_bin_number,
        labels=height_labels,
    )

    # Add bins to dataframe using .loc to avoid SettingWithCopyWarning
    filtered_dataset.loc[:, "lat_bins"] = lat_bins
    filtered_dataset.loc[:, "height_bins"] = height_bins
    filtered_dataset = filtered_dataset.reset_index(drop=True)

    return filtered_dataset
+ + Parameters: + binned_data (pandas.DataFrame): Binned photon data with lat, height_bins, photon_height + + Returns: + final_sea_surface_height (list): Detected sea surface heights + sea_surface_height_abnormal_label (array): Labels for abnormal heights + sea_surface_dominated_label (array): Labels for surface-dominated bins + PhotonDFBelowSurface (pandas.DataFrame): Photons below detected surface + """ + + # set flag for the df save + firstTimeIndex = True + + # Create sea height list + sea_surface_height = [] + mean_lat_bins_seq = [] + sea_surface_subsurface_photons_ratio = [] + + # Group dataset along horizental (latitude bins) + grouped_data = binned_data.groupby(["lat_bins"], group_keys=True, observed=False) + data_groups = dict(list(grouped_data)) + + # Loop through groups to detect sea surface + for k, v in data_groups.items(): + # based on lat_utm + lat_bin_average = v["lat"].mean() + + # Create new dataframe based on occurrence of photons per height bin + new_df = pd.DataFrame(v.groupby(["height_bins"], observed=False).count()) + + # Check if new_df is not empty before finding the bin with the highest photon count + if not new_df.empty: + # Find the vertical bin with the highest photon count + largest_h_bin = new_df["lat"].argmax() + + # Select the index of the bin with the highest count + largest_h_index = new_df.index[largest_h_bin] + + # Calculate the median value of all photon height values within this bin + photons_sea_surface = v.loc[ + v["height_bins"] == largest_h_index, "photon_height" + ] + lat_bin_sea_median = photons_sea_surface.median() + + # Append to sea height list + sea_surface_height.append(lat_bin_sea_median) + mean_lat_bins_seq.append(lat_bin_average) + del new_df + + # Get all photons below sea surface + # to determine segment type of each subsurface water column + # Use calculated sea height to determine photons at 0.5m below peak + photons_sea_surface_up = v.loc[ + (v["photon_height"] > (lat_bin_sea_median - threshold)) + & 
(v["photon_height"] < (lat_bin_sea_median + 2 * threshold)) + ] + + # Calculate the photon ratio between surface and whole photons + if v["photon_height"].shape[0] > 0: + new_photons_ratio_sea_surface = ( + photons_sea_surface_up.shape[0] / v["photon_height"].shape[0] + ) + else: + new_photons_ratio_sea_surface = np.nan + + sea_surface_subsurface_photons_ratio.append( + 1 - new_photons_ratio_sea_surface + ) + + else: + # Append NaNs if the group is empty + sea_surface_height.append(np.nan) + mean_lat_bins_seq.append(np.nan) + sea_surface_subsurface_photons_ratio.append(np.nan) + + # Filter out sea height bin values outside 2 SD of mean. + mean = np.nanmean(sea_surface_height, axis=0) + sd = np.nanstd(sea_surface_height, axis=0) + + final_sea_surface_height = np.where( + (sea_surface_height > (mean + 2 * sd)) | (sea_surface_height < (mean - 2 * sd)), + np.nan, + sea_surface_height, + ).tolist() + + sea_surface_height_abnormal_label = np.where( + np.isnan(final_sea_surface_height), 0, 1 + ) + + # Determine label based on ratio of sea surface photons and subsurface photons + sea_surface_dominated_label = np.where( + np.array(sea_surface_subsurface_photons_ratio) >= 0.2, 0, 1 + ) + + # Loop through groups again and return photons below 0.5m of sea height + PhotonDFBelowThresholdPeak = pd.DataFrame() + for i, (k, v) in enumerate(data_groups.items()): + # Get all values below this bin + if not np.isnan(final_sea_surface_height[i]): + NewPhotonDFBelowThresholdPeak = v.loc[ + v["photon_height"] < (final_sea_surface_height[i] - threshold) + ] + + if firstTimeIndex: + PhotonDFBelowThresholdPeak = NewPhotonDFBelowThresholdPeak + firstTimeIndex = False + else: + PhotonDFBelowThresholdPeak = pd.concat( + [PhotonDFBelowThresholdPeak, NewPhotonDFBelowThresholdPeak] + ) + + return ( + final_sea_surface_height, + sea_surface_height_abnormal_label, + sea_surface_dominated_label, + PhotonDFBelowThresholdPeak, + ) + + +def get_sea_surface_height_adaptive(binned_data): + """ + 
def _fit_surface_gaussian(photon_heights, center, half_win=1.0):
    """Fit a Gaussian to the photon heights surrounding ``center``.

    A first fit uses the photons within ``+/- half_win`` metres of ``center``.
    When that fit yields ``sigma > 0.4`` the narrow window has probably
    truncated the distribution, so the fit is repeated with a window of
    ``min(2.5 * sigma, 3.0)`` metres around the first estimate of the mean.

    Parameters:
        photon_heights (pandas.Series): Photon heights of one latitude bin.
        center (float): Height around which the first window is centred.
        half_win (float): Half-width of the first window in metres.

    Returns:
        tuple | None: ``(mu, sigma)``, or ``None`` when 10 or fewer photons
        fall inside the first window (too few for a meaningful fit).
    """
    near_peak = photon_heights[
        (photon_heights > center - half_win) & (photon_heights < center + half_win)
    ]
    if len(near_peak) <= 10:
        return None

    mu, sigma = norm.fit(near_peak)
    if sigma > 0.4:
        wide_win = min(2.5 * sigma, 3.0)
        widened = photon_heights[
            (photon_heights > mu - wide_win) & (photon_heights < mu + wide_win)
        ]
        # Keep the first fit when the widened window is itself too sparse.
        if len(widened) > 10:
            mu, sigma = norm.fit(widened)
    return mu, sigma


def get_sea_surface_height_adaptive(binned_data):
    """
    Calculate sea surface height and filter subsurface photons with enhanced wave removal.

    Flowchart steps:
        7  — Fit Gaussian curve to identify surface elevation; compute standard
             deviation of Gaussian peak.
        11 — Re-fit Gaussian curve to surface to identify surface peak; compute
             standard deviation and remove histogram data within three standard
             deviations. (This function is called a second time after refraction
             correction and histogram rebuild when apply_post_refraction_refit
             is enabled.)

    Parameters:
        binned_data (pandas.DataFrame): Binned photon data with lat_bins and
            photon_height columns.

    Returns:
        final_sea_surface_height (list): Surface height per latitude bin; NaN where
            the bin had fewer than 20 photons, no histogram peak, or a height
            further than 2 standard deviations from the mean of all bins.
        sea_surface_height_abnormal_label (array): 0 where the final height is NaN,
            1 otherwise.
        sea_surface_dominated_label (array): 0 where at least 20 % of the bin's
            photons lie outside the surface layer, 1 otherwise (including bins
            whose ratio is NaN).
        PhotonDFBelowSurface (pandas.DataFrame): Photons below the adaptive
            threshold (at least 1 m, or 3 sigma, below the fitted surface).
    """
    grouped_data = binned_data.groupby(["lat_bins"], group_keys=True, observed=False)
    lat_bin_groups = [group for _, group in grouped_data]

    sea_surface_height = []
    subsurface_photons_ratio = []

    for v in lat_bin_groups:
        heights = v["photon_height"]

        # Require a minimum photon count so the histogram and fit are robust.
        if len(v) < 20:
            sea_surface_height.append(np.nan)
            subsurface_photons_ratio.append(np.nan)
            continue

        # Fine (100-bin) histogram for better peak resolution.
        hist, bin_edges = np.histogram(
            heights,
            bins=100,
            density=True,
            range=(heights.min(), heights.max()),
        )
        bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2

        # Only tall, well separated and prominent peaks qualify as a surface.
        peaks, _ = find_peaks(
            hist,
            height=np.max(hist) * 0.3,
            distance=10,
            prominence=np.max(hist) * 0.1,
        )
        if len(peaks) == 0:
            sea_surface_height.append(np.nan)
            subsurface_photons_ratio.append(np.nan)
            continue

        # The strongest peak is taken as the surface.
        surface_height = bin_centers[peaks[np.argmax(hist[peaks])]]

        fit = _fit_surface_gaussian(heights, surface_height)
        # Fall back to the histogram peak with a conservative default spread.
        mu, sigma = fit if fit is not None else (surface_height, 0.2)

        # Lower edge of the surface layer: 3 sigma, but never less than 1 m.
        adaptive_threshold = min(mu - 3.0 * sigma, mu - 1.0)
        surface_photons = v[
            (heights > adaptive_threshold) & (heights < mu + 3.0 * sigma)
        ]
        # len(v) >= 20 is guaranteed by the guard above.
        ratio = len(surface_photons) / len(v)

        sea_surface_height.append(mu)
        subsurface_photons_ratio.append(1 - ratio)

    # Outlier filtering: discard bins further than 2 SD from the mean height.
    heights_arr = np.asarray(sea_surface_height, dtype=float)
    mean = np.nanmean(heights_arr)
    sd = np.nanstd(heights_arr)
    final_sea_surface_height = np.where(
        (heights_arr > mean + 2 * sd) | (heights_arr < mean - 2 * sd),
        np.nan,
        heights_arr,
    ).tolist()

    sea_surface_height_abnormal_label = np.where(
        np.isnan(final_sea_surface_height), 0, 1
    )
    sea_surface_dominated_label = np.where(
        np.array(subsurface_photons_ratio) >= 0.2, 0, 1
    )

    # Collect the subsurface photons of every valid bin and concatenate once;
    # concatenating inside the loop re-copies all earlier rows on every pass.
    below_surface_parts = []
    for surface, v in zip(final_sea_surface_height, lat_bin_groups):
        if np.isnan(surface):
            continue
        fit = _fit_surface_gaussian(v["photon_height"], surface)
        if fit is not None:
            mu, sigma = fit
            adaptive_threshold = min(mu - 3.0 * sigma, mu - 1.0)
        else:
            adaptive_threshold = surface - 1.0
        below_surface_parts.append(v[v["photon_height"] < adaptive_threshold])

    PhotonDFBelowSurface = (
        pd.concat(below_surface_parts) if below_surface_parts else pd.DataFrame()
    )

    return (
        final_sea_surface_height,
        sea_surface_height_abnormal_label,
        sea_surface_dominated_label,
        PhotonDFBelowSurface,
    )
def get_water_temp(date_year, date_month, date_day, latitude, longitude):
    """
    Pull down surface water temperature along the track from the JPL GHRSST opendap website.

    The mean latitude/longitude of the track is scaled linearly from
    [-90, 90] / [-180, 180] onto grid indices 0..17998 / 0..35999, and the
    ``analysed_sst`` value of that single cell is read.

    Parameters:
        date_year, date_month, date_day: Acquisition date components (strings
            or integers); month and day are zero-padded to two digits.
        latitude, longitude: Objects exposing ``.mean()`` (e.g. numpy arrays or
            pandas Series) holding the track coordinates in degrees.

    Returns:
        The ``analysed_sst`` value of the selected cell converted from Kelvin to
        degrees Celsius.

    Raises:
        ValueError: If the date components do not form a valid calendar date.

    NOTE(review): the index scaling assumes the grid's valid indices are
    0..17998 (latitude) and 0..35999 (longitude) — confirm against the
    dataset's actual dimensions.
    """
    # Zero-pad month/day so e.g. ("2021", "8", "1") still forms "20210801".
    date = f"{date_year}{str(date_month).zfill(2)}{str(date_day).zfill(2)}"
    acquisition = datetime.strptime(date, "%Y%m%d")
    year = acquisition.strftime("%Y")
    # The server's directory layout uses a three-digit day of year.
    zero_day_of_year = f"{acquisition.timetuple().tm_yday:03d}"

    # Grid row from the mid-point latitude of the IS2 track.
    new_lat = round(((latitude.mean() - (-90)) / (90 - (-90))) * 17998)

    # Grid column from the mid-point longitude of the IS2 track.
    new_lon = round(((longitude.mean() - (-180)) / (180 - (-180))) * 35999)

    # Access the SST data using the JPL OpenDap interface
    url = (
        "https://opendap.jpl.nasa.gov/opendap/OceanTemperature/ghrsst/data/GDS2/L4/GLOB/JPL/MUR/v4.1/"
        + year
        + "/"
        + zero_day_of_year
        + "/"
        + date
        + "090000-JPL-L4_GHRSST-SSTfnd-MUR-GLOB-v02.0-fv04.1.nc"
    )

    # The context manager closes the remote dataset even if the read fails.
    with netCDF4.Dataset(url) as dataset:
        # Access the data and convert the temperature from K to C
        water_temp = dataset["analysed_sst"][0, new_lat, new_lon] - 273.15
    return water_temp
def refraction_correction(
    WTemp,
    WSmodel,
    Wavelength,
    Photon_ref_elev,
    Ph_ref_azimuth,
    PhotonZ,
    PhotonX,
    PhotonY,
    Ph_Conf,
):
    """
    Apply an air-to-water refraction correction to photons at or below the surface.

    Only photons with ``PhotonZ <= WSmodel`` are processed; all per-photon
    inputs are reduced to that subset before the correction is computed.

    Parameters:
        WTemp: Water temperature used in the refractive-index polynomial.
        WSmodel: Water surface height. Either a scalar or an array with one
            value per photon (it is then masked together with the photons).
        Wavelength: Laser wavelength used in the refractive-index polynomial.
        Photon_ref_elev: Per-photon reference elevation angle (radians, since
            it is subtracted from pi/2).
        Ph_ref_azimuth: Per-photon reference azimuth (radians).
        PhotonZ, PhotonX, PhotonY: Per-photon height and horizontal coordinates.
        Ph_Conf: Per-photon confidence values (passed through, masked).

    Returns:
        tuple: ``(outX, outY, outZ, Ph_Conf, PhotonX, PhotonY, PhotonZ,
        Ph_ref_azimuth, Photon_ref_elev)`` — the corrected coordinates followed
        by the masked inputs. Photons lying exactly on the surface receive a
        zero correction.
    """
    # Only process photons below water surface model; build the mask once.
    below_surface = PhotonZ <= WSmodel
    if np.ndim(WSmodel) > 0:
        # A per-photon surface must stay aligned with the masked photons.
        WSmodel = np.asarray(WSmodel)[np.asarray(below_surface)]
    PhotonX = PhotonX[below_surface]
    PhotonY = PhotonY[below_surface]
    Photon_ref_elev = Photon_ref_elev[below_surface]
    Ph_ref_azimuth = Ph_ref_azimuth[below_surface]
    Ph_Conf = Ph_Conf[below_surface]
    PhotonZ = PhotonZ[below_surface]

    # Coefficients of the refractive-index polynomial in temperature/wavelength.
    a = -0.000001501562500
    b = 0.000000107084865
    c = -0.000042759374989
    d = -0.000160475520686
    e = 1.398067112092424

    # refractive index of air
    n1 = 1.00029

    # refractive index of water
    n2 = (a * WTemp**2) + (b * Wavelength**2) + (c * WTemp) + (d * Wavelength) + e

    # read photon ref_elev to get theta1 (incidence angle from vertical)
    theta1 = np.pi / 2 - Photon_ref_elev

    # eq 1. Theta2 (Snell's law)
    theta2 = np.arcsin((n1 * np.sin(theta1)) / n2)

    # D = raw uncorrected depth
    D = WSmodel - PhotonZ

    # eq 3. S, for triangle DTS
    S = D / np.cos(theta1)

    # eq 2. R
    R = (S * n1) / n2
    Gamma = (np.pi / 2) - theta1

    # For triangle RPS: phi is the angle between the raw and refracted paths
    phi = theta1 - theta2

    # P is the difference between raw and corrected YZ location
    P = np.sqrt(R**2 + S**2 - 2 * R * S * np.cos(phi))

    # alpha is an angle needed. A photon exactly on the surface has
    # D == S == R == P == 0, which would give 0/0 = NaN here; such photons
    # need no correction, so their ratio is forced to 0.
    with np.errstate(divide="ignore", invalid="ignore"):
        sin_alpha = np.where(P > 0, (R * np.sin(phi)) / P, 0.0)
    # Clip guards arcsin against rounding just outside [-1, 1].
    alpha = np.arcsin(np.clip(sin_alpha, -1.0, 1.0))

    # Beta angle needed for Delta Y and Delta Z
    Beta = Gamma - alpha

    # Delta Y
    DY = P * np.cos(Beta)

    # Delta Z
    DZ = P * np.sin(Beta)

    # Delta Easting
    DE = DY * np.sin(Ph_ref_azimuth)

    # Delta Northing
    DN = DY * np.cos(Ph_ref_azimuth)

    outX = PhotonX + DE
    outY = PhotonY + DN
    outZ = PhotonZ + DZ

    # We are most interested in out-x, out-y, out-z
    return (
        outX,
        outY,
        outZ,
        Ph_Conf,
        PhotonX,
        PhotonY,
        PhotonZ,
        Ph_ref_azimuth,
        Photon_ref_elev,
    )
# utils/visualization.py

import os
import time

import geopandas as gpd
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from pyproj import Transformer
from scipy.spatial import ConvexHull


def plot_photon_height(sea_photon_dataset, hlims=(-25, 10)):
    """
    Show a scatter plot of ATL03 photon heights.

    Parameters:
    - sea_photon_dataset: DataFrame with "latitude" and "photon_height" columns.
    - hlims: (bottom, top) limits of the height axis. The default is an
      immutable tuple so the shared default can never be modified in place.
    """
    fig, ax = plt.subplots()
    # NOTE(review): the x data is the "latitude" column while the axis label
    # below reads "Relative AT Distance" — confirm which one is intended.
    ax.scatter(
        sea_photon_dataset["latitude"],
        sea_photon_dataset["photon_height"],
        s=1,
        c="k",
        alpha=0.15,
        edgecolors="none",
        label="ATL03 Photons",
    )
    ax.set_xlabel("Relative AT Distance")
    ax.set_ylabel("Height (h_ph)")
    ax.set_title("Scatter Plot of ATL03 Photons")
    ax.set_ylim(hlims)
    ax.legend()
    plt.show()
def plot_filtered_seafloor_photons(
    filtered_seafloor_subsurface_dataset,
    sea_photon_dataset,
    sea_surface_height,
    output_path,
):
    """
    Draw the photon cloud with the seafloor profile and a line 0.5 m below the sea surface.

    The figure is written to ``output_path`` as a JPEG (400 dpi) and then shown.

    Parameters:
    - filtered_seafloor_subsurface_dataset: DataFrame with "relative_AT_dist" and
      "seafloor_elevation" columns.
    - sea_photon_dataset: DataFrame with "relative_AT_dist" and "photon_height" columns.
    - sea_surface_height: Sequence of sea surface height values.
    - output_path: Path to save the output plot.
    """
    track_dist = filtered_seafloor_subsurface_dataset["relative_AT_dist"]

    # One x position per surface value, spread evenly over the filtered track extent.
    surface_x = np.linspace(
        track_dist.min(), track_dist.max(), len(sea_surface_height)
    )
    surface_shifted_down = [height - 0.5 for height in sea_surface_height]

    fig, ax = plt.subplots()

    # Photon cloud.
    ax.scatter(
        sea_photon_dataset["relative_AT_dist"],
        sea_photon_dataset["photon_height"],
        s=1,
        c="k",
        alpha=0.15,
        edgecolors="none",
        label="Subsurface ATL03 Photons",
    )

    # Seafloor elevation profile.
    ax.plot(
        track_dist,
        filtered_seafloor_subsurface_dataset["seafloor_elevation"],
        linewidth=0.8,
        c="b",
        alpha=0.4,
        label="Seafloor Elevation",
    )

    # Sea surface lowered by half a metre.
    ax.plot(
        surface_x,
        surface_shifted_down,
        linewidth=0.8,
        color="#DD571C",
        alpha=0.4,
        label="0.5 m Below Surface Peak",
    )

    ax.set_xlabel("Distance From Start of Track (km)")
    ax.set_ylabel("Photon Height (m)")
    ax.set_title("Scatter Plot of Filtered Sea Photon Dataset")
    ax.set_ylim([-25, 10])

    ax.legend()

    # The pyplot-level call re-creates the legend in the upper right corner.
    plt.legend(loc="upper right")
    plt.savefig(output_path, dpi=400, format="jpeg")

    plt.show()
def plot_convex_hulls(photon_dataset, target_beam_ids, convex_hulls, convex_hull_areas):
    """
    Show the convex hull of every lat bin, annotated with its area, over the photon scatter.

    Parameters:
    - photon_dataset: DataFrame with "beam_id", "lat" and "photon_height" columns.
    - target_beam_ids: Beam ids whose photons are drawn.
    - convex_hulls: Mapping of lat bin -> 2-D point array the hull is built from.
    - convex_hull_areas: Mapping of lat bin -> area value printed at the hull centroid.
    """
    # Keep only photons of the requested beams.
    beam_photons = photon_dataset[photon_dataset["beam_id"].isin(target_beam_ids)]

    fig, ax = plt.subplots(figsize=(10, 6))

    for lat_bin, hull_points in convex_hulls.items():
        vertex_idx = ConvexHull(hull_points).vertices
        ax.fill(
            hull_points[vertex_idx, 0],
            hull_points[vertex_idx, 1],
            alpha=0.3,
            label=f"Bin {lat_bin}",
        )

        # The area label sits at the mean of the hull vertices.
        label_pos = np.mean(hull_points[vertex_idx], axis=0)
        ax.text(
            label_pos[0],
            label_pos[1],
            f"{convex_hull_areas[lat_bin]:.2f}",
            horizontalalignment="center",
            verticalalignment="center",
            fontsize=10,
            color="black",
        )

    ax.scatter(
        beam_photons["lat"],
        beam_photons["photon_height"],
        s=1,
        c="k",
        alpha=0.15,
        edgecolors="none",
        label="Subsurface ATL03 Photons",
    )
    ax.set_ylim([-10, 2])
    ax.set_xlabel("Latitude")
    ax.set_ylabel("Photon Height")
    ax.set_title("ConvexHull of Photons within each Lat Bin")
    plt.show()
# plot photons coloured by quality_ph flag for diagnostic purposes
def plot_photon_quality_flags(
    output_path, timestamp, sea_photon_dataset, target_beam_ids
):
    """
    Scatter plot of photon height vs. along-track distance, coloured by quality_ph value.

    The plot is saved as ``<timestamp>_quality_ph_flags.jpg`` inside
    ``output_path``, shown, and a per-flag count summary is printed.

    quality_ph bit meanings (ATL03 v007):
        bit 0 (value 1) = possible afterpulse
        bit 1 (value 2) = possible impulse response effect
        bit 2 (value 4) = possible TEP

    Parameters:
    - output_path: Directory the JPEG is written to.
    - timestamp: Prefix of the output file name.
    - sea_photon_dataset: DataFrame with "beam_id", "quality_ph",
      "relative_AT_dist" and "photon_height" columns.
    - target_beam_ids: Beam ids to include.

    Returns None without plotting when no photon with a quality_ph value
    remains for the requested beams.
    """
    dataset = sea_photon_dataset[sea_photon_dataset["beam_id"].isin(target_beam_ids)]
    # Rows without a flag cannot be cast to int; drop them once up front so the
    # plot, the flag list and the percentages all use the same set of rows.
    dataset = dataset.dropna(subset=["quality_ph"])
    if dataset.empty:
        return

    quality_colors = {
        0: ("dimgray", "nominal (0)"),
        1: ("red", "afterpulse (1)"),
        2: ("orange", "impulse response (2)"),
        3: ("darkred", "afterpulse + impulse (3)"),
        4: ("royalblue", "TEP (4)"),
        5: ("purple", "afterpulse + TEP (5)"),
        6: ("cyan", "impulse + TEP (6)"),
        7: ("black", "all flags (7)"),
    }

    # Integer flags are computed once and reused for every comparison.
    flags = dataset["quality_ph"].astype(int)
    unique_flags = sorted(flags.unique())

    fig, ax = plt.subplots(figsize=(12, 5))
    for flag in unique_flags:
        subset = dataset[flags == flag]
        color, label = quality_colors.get(flag, ("magenta", f"unknown ({flag})"))
        ax.scatter(
            subset["relative_AT_dist"],
            subset["photon_height"],
            s=0.5,
            c=color,
            alpha=0.3,
            edgecolors="none",
            label=label,
        )

    ax.set_xlabel("Relative Along-Track Distance (km)", fontsize=12)
    ax.set_ylabel("Photon Height (m)", fontsize=12)
    ax.set_title("Photon quality_ph flag distribution", fontsize=13)
    ax.set_ylim([-15, 5])
    ax.legend(loc="lower right", markerscale=6, fontsize=9)
    fig.tight_layout()
    save_path = os.path.join(output_path, f"{timestamp}_quality_ph_flags.jpg")
    plt.savefig(save_path, dpi=300, format="jpeg")
    plt.show()
    print(f"Quality flag plot saved: {save_path}")

    # Print summary counts
    print("\nquality_ph flag summary:")
    for flag in unique_flags:
        count = (flags == flag).sum()
        _, label = quality_colors.get(flag, ("", f"unknown ({flag})"))
        pct = 100.0 * count / len(dataset)
        print(f"  {label:35s}: {count:>8,d}  ({pct:.1f}%)")
# plot the kd and photon
def plot_kd_photons(
    OutputPath,
    timestamp,
    target_beam_ids,
    subsurface_photon_dataset,
    Kd_DF_MergedDistance,
):
    """
    Plot subsurface photon heights (left axis) together with Kd values (right axis).

    Both inputs are restricted to ``target_beam_ids`` via their "beam_id"
    column. The figure is saved as ``<timestamp>_IS2_subsurface_kd_2.jpg`` in
    ``OutputPath`` (400 dpi JPEG) and then shown.
    """
    # Restrict both tables to the requested beams.
    photons = subsurface_photon_dataset[
        subsurface_photon_dataset["beam_id"].isin(target_beam_ids)
    ]
    kd_rows = Kd_DF_MergedDistance[
        Kd_DF_MergedDistance["beam_id"].isin(target_beam_ids)
    ]

    fig, height_ax = plt.subplots(figsize=(10, 6))

    # Left axis: photon heights.
    height_ax.scatter(
        photons["relative_AT_dist"],
        photons["photon_height"],
        s=1.5,
        c="k",
        alpha=0.2,
        edgecolors="none",
        label="Subsurface ATL03 Photon Height",
    )
    height_ax.set_xlabel("Relative Along-Track Distance", fontsize=18)
    height_ax.tick_params(axis="x", labelsize=18)
    height_ax.set_ylabel("Photon Height", color="b", fontsize=18)
    height_ax.tick_params(axis="y", labelcolor="b", labelsize=18)

    # The seafloor profile is optional.
    if "seafloor_elevation" in photons.columns:
        height_ax.plot(
            photons["relative_AT_dist"],
            photons["seafloor_elevation"],
            linewidth=0.8,
            c="b",
            alpha=0.4,
            label="Seafloor Elevation",
        )
    height_ax.set_ylim([-45, 5])

    # Right axis: Kd values sharing the same x axis.
    kd_ax = height_ax.twinx()
    kd_ax.scatter(
        kd_rows["relative_AT_dist"],
        kd_rows["kd"],
        label="Kd values",
        color="r",
        alpha=0.6,
    )
    kd_ax.set_ylabel("Kd Value", color="r", fontsize=18)
    kd_ax.tick_params(axis="y", labelcolor="r", labelsize=18)

    # One legend combining the entries of both axes.
    height_handles, height_labels = height_ax.get_legend_handles_labels()
    kd_handles, kd_labels = kd_ax.get_legend_handles_labels()
    fig.legend(
        height_handles + kd_handles,
        height_labels + kd_labels,
        loc="upper right",
        bbox_to_anchor=(0.85, 0.85),
    )
    fig.tight_layout()

    figure_file = os.path.join(OutputPath, f"{timestamp}_IS2_subsurface_kd_2.jpg")
    plt.savefig(figure_file, dpi=400, format="jpeg")
    plt.show()
def plot_bin_polygon_data(
    lat,
    height,
    seafloor_height,
    mask,
    lat_bins,
    height_bins,
    valid_bins,
    polygons,
    y_min,
    y_max,
):
    """
    Show two side-by-side panels: all height data, and the masked region of interest.

    The right panel additionally outlines every polygon in ``polygons`` and
    draws dashed grid lines at the ``lat_bins`` / ``height_bins`` edges. Both
    panels overlay the seafloor as a red line ordered by latitude.
    ``valid_bins`` is accepted but not used by this function.
    """
    # Seafloor profile ordered by latitude; identical in both panels.
    lat_sorted = np.sort(lat)
    seafloor_sorted = seafloor_height[np.argsort(lat)]

    def _finish_panel():
        # Colour bar, axis labels, limits and legend shared by both panels.
        plt.colorbar(label="Height")
        plt.xlabel("Latitude")
        plt.ylabel("Height")
        plt.ylim(y_min, y_max)
        plt.legend()

    plt.figure(figsize=(12, 6))

    # Left panel: every point.
    plt.subplot(1, 2, 1)
    plt.title("Original Height Data")
    plt.scatter(lat, height, c=height, cmap="viridis", s=10)
    plt.plot(lat_sorted, seafloor_sorted, "r-", label="Seafloor")
    _finish_panel()

    # Right panel: masked points with polygon outlines and bin grid.
    plt.subplot(1, 2, 2)
    plt.title("Masked Region (ROI)")
    plt.scatter(lat[mask], height[mask], c=height[mask], cmap="viridis", s=10)
    plt.plot(lat_sorted, seafloor_sorted, "r-", label="Seafloor")

    for polygon in polygons:
        outline = plt.Polygon(polygon, edgecolor="red", facecolor="none", linewidth=1)
        plt.gca().add_patch(outline)

    for lat_edge in lat_bins:
        plt.axvline(x=lat_edge, color="gray", linestyle="--", linewidth=0.5)
    for height_edge in height_bins:
        plt.axhline(y=height_edge, color="gray", linestyle="--", linewidth=0.5)

    _finish_panel()

    plt.tight_layout()
    plt.show()
def produce_figures(
    binned_data,
    bath_height,
    sea_height,
    solo_sea_surface_label,
    y_limit_top,
    y_limit_bottom,
    percentile,
    file,
    geo_df,
    ref_y,
    ref_z,
    beam,
    epsg_num,
    output_dir="C:/Workstation/ICESat2_HLS/",
):
    """
    Create the bathymetry overview figure, save it as a PDF and add WGS84 columns to geo_df.

    Parameters:
        binned_data: DataFrame whose ``lat`` column defines the x extent.
        bath_height: Sequence of median bathymetry heights (one per bin).
        sea_height: Sequence of sea surface heights (one per bin, NaN allowed).
        solo_sea_surface_label: Sequence of 0/1 labels aligned with sea_height.
        y_limit_top, y_limit_bottom: Height axis limits.
        percentile, beam, epsg_num: Values embedded in the output file name;
            epsg_num is also the source CRS of geo_df's lon/lat columns.
        file: Granule file name (".h5" is stripped for the output name).
        geo_df: DataFrame with ``lat``, ``lon`` and ``photon_height`` columns.
            It is modified in place: ``lon_wgs84`` / ``lat_wgs84`` are added.
        ref_y, ref_z: Reference photon positions drawn in black.
        output_dir: Directory the PDF is written to. Defaults to the path that
            was previously hard-coded.

    Returns:
        None.
    """
    lat_min = binned_data.lat.min()
    lat_max = binned_data.lat.max()

    # Create bins for latitude (offsets of 20 / 10 as in the original layout)
    bath_x_axis_bins = np.linspace(lat_min, lat_max, len(bath_height)) + 20
    sea_surface_x_axis_bins = np.linspace(lat_min, lat_max, len(sea_height)) + 10

    # Create new dataframes for median values
    bath_median_df = pd.DataFrame({"x": bath_x_axis_bins, "y": bath_height})

    # Create uniform sea surface based on median sea surface values and filter
    # out surface breaching. The median is computed once instead of per element;
    # ``i == i`` is False only for NaN, which therefore stays NaN.
    sea_median = np.nanmedian(sea_height)
    sea_height1 = [sea_median if i == i else np.nan for i in sea_height]
    sea_median_df = pd.DataFrame({"x": sea_surface_x_axis_bins, "y": sea_height1})

    # Create uniform solo sea surface label
    sea_surface_label_df = pd.DataFrame(
        {"x": sea_surface_x_axis_bins, "y": solo_sea_surface_label}
    )
    solo_rows = sea_surface_label_df[(sea_surface_label_df.y == 1).to_numpy()]
    non_solo_rows = sea_surface_label_df[(sea_surface_label_df.y == 0).to_numpy()]

    # Create a dedicated figure. Assigning rcParams["figure.figsize"] instead
    # would change the size of every later figure in the process and would draw
    # onto whatever figure happens to be current.
    # NOTE(review): the figure is left open after saving, as before — confirm
    # whether callers rely on that before adding plt.close().
    plt.figure(figsize=(40, 25))

    # Plot raw points
    plt.scatter(ref_y, ref_z, s=0.5, alpha=0.1, c="black")
    plt.scatter(
        geo_df.lat,
        geo_df.photon_height,
        s=0.8,
        marker="o",
        alpha=0.1,
        c="red",
        label="Classified Photons",
    )

    # Plot median values
    plt.scatter(
        bath_median_df.x,
        bath_median_df.y,
        marker="o",
        c="r",
        alpha=0.8,
        s=2,
        label="Median bathymetry",
    )

    plt.scatter(
        sea_median_df.x,
        sea_median_df.y,
        marker="o",
        c="b",
        alpha=1,
        s=2,
        label="Median sea surface",
    )

    plt.scatter(
        solo_rows.x,
        solo_rows.y,
        marker="o",
        c="pink",
        alpha=1,
        s=3,
        label="solo_sea_surface",
    )
    plt.scatter(
        non_solo_rows.x,
        non_solo_rows.y,
        marker="o",
        c="g",
        alpha=1,
        s=3,
        label="non_solo_sea_surface",
    )

    # Insert titles and subtitles
    plt.title("Icesat2 Bathymetry\n" + file)
    plt.xlabel("Latitude", fontsize=25)
    plt.ylabel("Photon Height (m)", fontsize=25)
    plt.xticks(fontsize=16)
    plt.yticks(fontsize=16)

    plt.legend(loc="upper left", prop={"size": 20})

    # Limit the x and y axes using parameters
    plt.xlim(left=lat_min, right=lat_max)
    plt.ylim(top=y_limit_top, bottom=y_limit_bottom)

    timestr = time.strftime("%Y%m%d_%H%M%S")
    file = file.replace(".h5", "")

    # Define where to save file
    plt.tight_layout()
    figure_name = (
        file
        + "_gt"
        + str(beam)
        + "_"
        + str(percentile)
        + "_EPSG"
        + str(epsg_num)
        + "_"
        + timestr
        + ".pdf"
    )
    plt.savefig(os.path.join(output_dir, figure_name))

    # convert corrected locations back to wgs84 (useful to contain)
    transformer = Transformer.from_crs(
        "EPSG:" + str(epsg_num), "EPSG:4326", always_xy=True
    )
    lon_wgs84, lat_wgs84 = transformer.transform(geo_df.lon.values, geo_df.lat.values)

    # These two columns are the only result visible to the caller.
    geo_df["lon_wgs84"] = lon_wgs84
    geo_df["lat_wgs84"] = lat_wgs84

    # Built for the (currently disabled) GeoPackage export; not returned.
    geodf = gpd.GeoDataFrame(
        geo_df, geometry=gpd.points_from_xy(geo_df.lon_wgs84, geo_df.lat_wgs84)
    )
    geodf.set_crs(epsg=4326, inplace=True)

    # geodf.to_file("C:/Workstation/ICESat2_HLS/" +file + '_gt' + '_' + str(percentile) + '_EPSG' +
    #               str(epsg_num) + '_' + timestr + ".gpkg",
    #               driver="GPKG")
bottom=y_limit_bottom) + + timestr = time.strftime("%Y%m%d_%H%M%S") + file = file.replace(".h5", "") + # Define where to save file + plt.tight_layout() + plt.savefig( + "C:/Workstation/ICESat2_HLS/" + + file + + "_gt" + + str(beam) + + "_" + + str(percentile) + + "_EPSG" + + str(epsg_num) + + "_" + + timestr + + ".pdf" + ) + # plt.show() + # plt.close() + + # convert corrected locations back to wgs84 (useful to contain) + transformer = Transformer.from_crs( + "EPSG:" + str(epsg_num), "EPSG:4326", always_xy=True + ) + print(transformer) + lon_wgs84, lat_wgs84 = transformer.transform(geo_df.lon.values, geo_df.lat.values) + + geo_df["lon_wgs84"] = lon_wgs84 + geo_df["lat_wgs84"] = lat_wgs84 + + geodf = gpd.GeoDataFrame( + geo_df, geometry=gpd.points_from_xy(geo_df.lon_wgs84, geo_df.lat_wgs84) + ) + + geodf.set_crs(epsg=4326, inplace=True) + + # geodf.to_file("C:/Workstation/ICESat2_HLS/" +file + '_gt' + '_' + str(percentile) + '_EPSG' + + # str(epsg_num) + '_' + timestr + ".gpkg", + # driver="GPKG") diff --git a/aok/core/main.py b/aok/core/main.py new file mode 100644 index 0000000..9168ef0 --- /dev/null +++ b/aok/core/main.py @@ -0,0 +1,329 @@ +import glob +import logging +import os +import re +import sys + +from config import get_args +from kd_utils.bathy_processing import process_subsurface_photon_filtering +from kd_utils.data_processing import ( + Extract_sea_photons, + apply_optional_ir_ap_filter, + apply_optional_solar_background_filter, + filter_photon_dataset_by_hull_area, + load_data, +) +from kd_utils.Kd_analysis import process_kd_calculation +from kd_utils.sea_photons_analysis import process_sea_photon_binning +from kd_utils.visualization import ( + plot_convex_hulls, + plot_kd_photons, + plot_photon_quality_flags, +) +import pandas as pd + +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +) +logger = logging.getLogger(__name__) + + +PAIR_ID_MAP = { + "gt1l": "gt1_pair", + "gt1r": "gt1_pair", + "gt2l": 
# Maps each ATL03 ground track to the id shared by its left/right beam pair.
PAIR_ID_MAP = {
    "gt1l": "gt1_pair",
    "gt1r": "gt1_pair",
    "gt2l": "gt2_pair",
    "gt2r": "gt2_pair",
    "gt3l": "gt3_pair",
    "gt3r": "gt3_pair",
}


def get_target_beams(all_beams, beam_attrs, target_beams_arg):
    """
    Select the strong beams to process.

    Parameters:
        all_beams: Iterable of beam names present in the granule.
        beam_attrs: Mapping beam name -> attributes; the "atlas_beam_type"
            entry may be ``bytes`` or ``str``.
        target_beams_arg: Comma-separated beam names, or an empty value to
            auto-select.

    Returns:
        list: The first strong beam when no beams were requested; otherwise the
        requested beams, in request order, that are strong beams. Requested
        beams that are not strong (or not present) are dropped, so the result
        may be empty.
    """

    def _beam_type(raw):
        # Accept both bytes and already-decoded str attribute values.
        return raw.decode("utf-8") if isinstance(raw, bytes) else raw

    strong_beams = [
        gtx
        for gtx in all_beams
        if _beam_type(beam_attrs[gtx]["atlas_beam_type"]) == "strong"
    ]
    if not target_beams_arg:
        return strong_beams[:1]  # auto-select first strong beam

    requested = [b.strip() for b in target_beams_arg.split(",") if b.strip()]
    return [b for b in requested if b in strong_beams]


def optionally_combine_paired_beams_for_kd(dataset, horizontal_res, enabled=False):
    """
    Optionally combine left/right beams into pair groups before Kd fitting.
    Uses shared along-track bins from relative_AT_dist when available.

    Flowchart step: 15 — Combine data from paired beams if appropriate (not
    recommended in optically complex water).

    Parameters:
        dataset (pandas.DataFrame): Photon data with a "beam_id" column and,
            optionally, "relative_AT_dist" (multiplied by 1000 before binning).
        horizontal_res: Along-track bin width; values below 1 are treated as 1.
        enabled (bool): When False, or when dataset is empty, the input object
            is returned unchanged.

    Returns:
        pandas.DataFrame: A copy with the original ids in "source_beam_id",
        "beam_id" replaced by the pair id (ids not in PAIR_ID_MAP are kept),
        and — if "relative_AT_dist" exists — integer "lat_bins" derived from
        it (-1 where the distance is missing or not numeric).
    """
    if (not enabled) or dataset.empty:
        return dataset

    combined = dataset.copy()
    combined["source_beam_id"] = combined["beam_id"]
    combined["beam_id"] = (
        combined["beam_id"].map(PAIR_ID_MAP).fillna(combined["beam_id"])
    )

    if "relative_AT_dist" in combined.columns:
        rel_m = pd.to_numeric(combined["relative_AT_dist"], errors="coerce") * 1000.0
        # Floor division gives whole-number floats; casting a float with a
        # fractional part directly to the nullable "Int64" dtype raises.
        pair_bins = (rel_m // max(horizontal_res, 1)).astype("Int64")
        combined["lat_bins"] = pair_bins.fillna(-1).astype(int)

    return combined
def run_pipeline(args):
    """
    Top-level pipeline orchestrator. Executes all 19 flowchart steps in sequence.

    Flowchart steps directly handled here:
        1     — Load ATL03 data (load_data / Extract_sea_photons).
        2     — Remove afterpulse/impulse-response photons (apply_optional_ir_ap_filter).
        3     — Bin sizes set via args.horizontal_res / args.vertical_res.
        4     — Build histograms (process_sea_photon_binning).
        5     — Solar background filter (apply_optional_solar_background_filter).
        6–14  — Subsurface filtering chain (process_subsurface_photon_filtering).
        15    — Paired beam combine (optionally_combine_paired_beams_for_kd).
        16–17 — Beer's Law Kd fit (process_kd_calculation).
        18    — Save Kd CSV output (subsurface_photon_df_added_kd.to_csv).
        19    — Check data for reasonableness (filter_photon_dataset_by_hull_area
                + plot_kd_photons).

    Parameters:
        args: Parsed command-line namespace (see config.get_args). Note that
            ``args.enable_ir_ap_filter`` is reset to False in place when the
            granule version does not support the IR/AP filter.

    Side effects:
        Creates the output directory, writes the subsurface-photon CSV and the
        Kd CSV, optionally shows/saves plots, and calls ``sys.exit(1)`` when no
        requested strong beam is available.
    """
    atl03_h5_file_path = os.path.join(
        args.workspace_path, args.atl03_path, args.atl03_file
    )
    shoreline_data_path = os.path.join(
        args.workspace_path, args.other_data_path, args.shoreline_data
    )
    gebco_full_path = os.path.join(
        args.workspace_path, args.other_data_path, args.gebco_path
    )
    # A relative ATL24 path is resolved against the workspace root.
    atl24_file_path = args.atl24_file
    if atl24_file_path and (not os.path.isabs(atl24_file_path)):
        atl24_file_path = os.path.join(args.workspace_path, atl24_file_path)
    output_path = os.path.join(args.workspace_path, args.output_path)
    os.makedirs(output_path, exist_ok=True)

    # The 14-digit acquisition timestamp prefixes every output file name.
    match = re.search(r"_(\d{14})_", atl03_h5_file_path)
    timestamp = match.group(1) if match else "unknown"

    # The three-digit product version follows the timestamp and track fields.
    version_match = re.search(
        r"ATL03_\d{14}_\d{8}_(\d{3})_\d{2}", os.path.basename(atl03_h5_file_path)
    )
    atl03_version = int(version_match.group(1)) if version_match else None
    version_label = f"{atl03_version:03d}" if atl03_version else "unknown"
    logger.info("ATL03 version detected: %s", version_label)

    if args.enable_ir_ap_filter:
        if atl03_version is None or atl03_version < 7:
            logger.error(
                "IR/AP filter requires ATL03 version 007 or later. "
                "Detected version: %s. "
                "Skipping IR/AP filter for this run. "
                "Switch to a version 007 file (e.g., Wax Delta dataset) to enable this step.",
                version_label,
            )
            args.enable_ir_ap_filter = False

    gebco_pattern = os.path.join(gebco_full_path, "gebco_*.tif")
    gebco_file_path_lists = glob.glob(gebco_pattern)

    is2_mds, is2_attrs, is2_beams = load_data(atl03_h5_file_path, False)
    target_strong_beams = get_target_beams(is2_beams, is2_attrs, args.target_beams)
    if not target_strong_beams:
        logger.error(
            "No target strong beams found. Check input file and --target_beams."
        )
        sys.exit(1)
    logger.info("Strong beams selected: %s", target_strong_beams)
    # Diagnostic plots are drawn for the first selected beam only.
    plot_target_beam = [target_strong_beams[0]]

    sea_photon_dataset = Extract_sea_photons(
        is2_mds, target_strong_beams, shoreline_data_path
    )

    # Plot the quality flags before the IR/AP filter removes photons.
    if args.enable_ir_ap_filter and not args.no_plot:
        plot_photon_quality_flags(
            output_path, timestamp, sea_photon_dataset, plot_target_beam
        )

    sea_photon_dataset = apply_optional_ir_ap_filter(
        sea_photon_dataset,
        enabled=args.enable_ir_ap_filter,
        quality_max=args.ir_ap_quality_max,
        min_signal_conf=args.ir_ap_min_signal_conf,
    )

    # Step 5 — Solar background filter is ON by default.
    # It self-gates on solar_elevation so has zero effect on nighttime passes.
    # Use --disable_solar_background_filter to turn it off.
    solar_bg_enabled = not args.disable_solar_background_filter
    if not solar_bg_enabled:
        if "solar_elevation" in sea_photon_dataset.columns:
            max_solar_elev = sea_photon_dataset["solar_elevation"].max()
            if max_solar_elev > args.solar_elevation_day_threshold:
                logger.warning(
                    "Solar background filter is DISABLED via --disable_solar_background_filter "
                    "but this granule is a daytime pass "
                    "(max solar elevation = %.1f°, daytime threshold = %.1f°). "
                    "This filter is recommended for daytime passes.",
                    max_solar_elev,
                    args.solar_elevation_day_threshold,
                )

    sea_photon_dataset = apply_optional_solar_background_filter(
        sea_photon_dataset,
        enabled=solar_bg_enabled,
        day_threshold=args.solar_elevation_day_threshold,
        median_window_deg=args.solar_bg_median_window_deg,
        noise_multiplier=args.solar_bg_noise_multiplier,
        min_signal_conf=args.solar_background_min_signal_conf,
    )

    binned_dataset_sea_surface = process_sea_photon_binning(
        sea_photon_dataset,
        horizontal_res=args.horizontal_res,
        vertical_res=args.vertical_res,
    )

    # The post-refraction refit only makes sense when refraction correction runs.
    post_refraction_refit_enabled = args.enable_post_refraction_refit
    if post_refraction_refit_enabled and (not args.enable_refraction_correction):
        logger.warning(
            "--enable_post_refraction_refit requested without --enable_refraction_correction. "
            "The post-refraction refit step will be skipped."
        )
        post_refraction_refit_enabled = False

    (
        sea_surface_height,
        sea_surface_label,
        filtered_seafloor_subsurface_photon_dataset,
    ) = process_subsurface_photon_filtering(
        binned_dataset_sea_surface,
        gebco_file_path_lists,
        args.subsurface_thresh,
        args.ignore_subsurface_height_thres,
        use_atl24_filter=args.enable_atl24_filter,
        atl24_file_path=atl24_file_path,
        atl24_max_match_distance_deg=args.atl24_max_match_distance_deg,
        use_gebco_filter=args.enable_gebco_filter,
        apply_histogram_quality_filter=args.enable_histogram_quality_filter,
        histogram_quality_min_ratio=args.histogram_quality_min_ratio,
        histogram_quality_depth_min=args.histogram_quality_depth_min,
        histogram_quality_depth_max=args.histogram_quality_depth_max,
        histogram_quality_ref_depth_min=args.histogram_quality_ref_depth_min,
        histogram_quality_ref_depth_max=args.histogram_quality_ref_depth_max,
        apply_surface_sigma_filter=args.enable_surface_sigma_filter,
        surface_sigma_max=args.surface_sigma_max,
        apply_refraction_correction=args.enable_refraction_correction,
        refraction_water_temp_c=args.refraction_water_temp_c,
        refraction_wavelength_nm=args.refraction_wavelength_nm,
        apply_post_refraction_refit=post_refraction_refit_enabled,
        apply_flattening=args.enable_sea_surface_flattening,
        flattening_window_m=args.sea_surface_flattening_window_m,
        horizontal_res=args.horizontal_res,
        vertical_res=args.vertical_res,
    )

    final_filtered_subsurface_photon_dataset = (
        filtered_seafloor_subsurface_photon_dataset[
            filtered_seafloor_subsurface_photon_dataset["beam_id"].isin(
                target_strong_beams
            )
        ].copy()
    )

    if args.enable_convex_hull_filter:
        final_filtered_subsurface_photon_dataset, convex_hull_areas, convex_hulls = (
            filter_photon_dataset_by_hull_area(
                final_filtered_subsurface_photon_dataset,
                hull_area_threshold=args.convex_hull_area_threshold,
            )
        )
        if plot_target_beam and not args.no_plot:
            plot_convex_hulls(
                final_filtered_subsurface_photon_dataset,
                plot_target_beam,
                convex_hulls,
                convex_hull_areas,
            )

    subsurface_output_path = os.path.join(
        output_path,
        f"{timestamp}_strongBeam_{'_'.join(target_strong_beams)}_subsurface_photons.csv",
    )
    final_filtered_subsurface_photon_dataset.to_csv(subsurface_output_path, index=False)

    kd_input_dataset = optionally_combine_paired_beams_for_kd(
        final_filtered_subsurface_photon_dataset,
        horizontal_res=args.horizontal_res,
        enabled=args.enable_paired_beam_combine,
    )
    # A multiplier of 0.0 is passed when the wave-adaptive fit is not enabled.
    wave_mult = args.wave_exclusion_multiplier if args.enable_wave_adaptive_fit else 0.0
    subsurface_photon_df_added_kd = process_kd_calculation(
        kd_input_dataset,
        decay_zone_threshold=args.decay_zone_threshold,
        kd_fit_method=args.kd_fit_method,
        wave_exclusion_multiplier=wave_mult,
        wave_sigma_calm_threshold=args.wave_sigma_calm_threshold,
    )
    kd_output_path = os.path.join(
        output_path, f"{timestamp}_AddedKdDataset_strongBeams_Further.csv"
    )
    subsurface_photon_df_added_kd.to_csv(kd_output_path, index=False)

    if not args.no_plot and "relative_AT_dist" in kd_input_dataset.columns:
        unique_photon_dataset = kd_input_dataset[
            ["relative_AT_dist", "lat_bins", "photon_height"]
        ].drop_duplicates()
        unique_photon_dataset["relative_AT_dist_center"] = (
            unique_photon_dataset.groupby(
                "lat_bins", observed=False
            )["relative_AT_dist"].transform("mean")
        )

        # For every bin keep the photon closest to the bin's mean along-track
        # distance; its distance positions the bin's Kd value on the plot.
        _closest_rows = []
        for _lat_bin, _group in unique_photon_dataset.groupby(
            "lat_bins", observed=False
        ):
            if _group.empty:
                continue
            _center = _group["relative_AT_dist_center"].iloc[0]
            _group = _group.copy()
            _group["dist_to_center"] = abs(_group["relative_AT_dist"] - _center)
            _closest_rows.append(_group.loc[[_group["dist_to_center"].idxmin()]])
        if not _closest_rows:
            logger.warning("No along-track bins survived filtering. Skipping plot.")
        else:
            closest_to_center = pd.concat(_closest_rows).reset_index(drop=True)
            kd_df_merged_distance = closest_to_center.merge(
                subsurface_photon_df_added_kd, on="lat_bins", how="left"
            ).drop(
                columns=["relative_AT_dist_center", "dist_to_center"], errors="ignore"
            )

            # When paired beams were combined, beam_id holds the pair id rather
            # than the raw beam name, so the plot filter must accept both;
            # otherwise every row would be filtered out and the plot left empty.
            kd_plot_beams = plot_target_beam + [
                PAIR_ID_MAP[beam] for beam in plot_target_beam if beam in PAIR_ID_MAP
            ]

            plot_kd_photons(
                output_path,
                timestamp,
                kd_plot_beams,
                kd_input_dataset,
                kd_df_merged_distance,
            )

    logger.info("SUCCESS! Kd output: %s", kd_output_path)
Kd output: %s", kd_output_path) + +
+if __name__ == "__main__":  # script entry point: runs only when executed directly, not on import
+    cli_args = get_args()
+    try:
+        run_pipeline(cli_args)
+    except Exception as e:  # top-level boundary: record the failure, then let it propagate
+        logger.exception("An error occurred: %s", e)  # ERROR level, same message, plus traceback
+        raise
diff --git a/pyproject.toml b/pyproject.toml index 22a7696..7bc1d22 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,9 +47,7 @@ py-modules = ["_aok_version"] ##### DEPENDENCIES ##### [tool.setuptools.dynamic] -dependencies = [ - "numpy", -] +dependencies = {file = ["requirements.txt"]} [project.optional-dependencies] dev = [ @@ -69,7 +67,7 @@ exclude = ["*tests"] version_file = "_aok_version.py" version_file_template = 'version = "{version}"' local_scheme = "node-and-date" -fallback_version = "unknown" +fallback_version = "0.0.0" ##### LINTING, FORMATTING, TYPING ##### diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..d359c9a --- /dev/null +++ b/requirements.txt @@ -0,0 +1,18 @@ + +geopandas +h5py + +# Need to confirm which of these remain dependencies +hdbscan +icepyx +matplotlib +netcdf4 +numpy +pandas +pyproj +rasterio +Rtree +scikit_learn +scipy +Shapely +utm