diff --git a/.gitignore b/.gitignore
index 4212d5d..c4e413e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -155,6 +155,9 @@ dmypy.json
 # Cython debug symbols
 cython_debug/
 
+# Notebook data (symlinked datasets, not committed)
+docs/notebooks/data/
+
 # PyCharm
 # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
 # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
diff --git a/CLAUDE.md b/CLAUDE.md
new file mode 100644
index 0000000..8a0c5ba
--- /dev/null
+++ b/CLAUDE.md
@@ -0,0 +1,135 @@
+# Python Coding & Editing Guidelines
+
+> **Living document – PRs welcome!**
+> Last updated: 2026-01-15
+
+## Table of Contents
+
+1. Philosophy
+1. Code Style
+1. Docstrings & Comments
+1. Tools
+1. Documentation
+
+---
+
+## Philosophy
+
+- **Readability, reproducibility, performance – in that order.**
+- Prefer explicit over implicit; avoid hidden state and global flags.
+- Measure before you optimize (`time.perf_counter`, `line_profiler`).
+- Each module holds a **single responsibility**; keep public APIs minimal.
+- Before merging pull requests, ensure that
+  - All tests pass
+  - All pre-commit checks pass
+  - Any new functionality is accompanied by a new test
+  - Any bug fix has a new test that fails on the main branch but passes on your PR
+
+## Code Style
+
+- Annotate all public functions (PEP 484).
+- In general, write code that raises an exception early if something isn't expected.
+- Raise exceptions/errors for user-facing problems. Only use asserts to help fix mypy errors or to document developer expectations.
+- Prefer Pydantic models over `dataclasses.dataclass`.
+- Aim for zero "dead" code: do not leave commented-out code in place unless it is part of a very descriptive comment that illustrates something specific.
+- Follow the "parse, don't validate" adage: parse unknown inputs at the serialization boundaries, not scattered throughout the code.
+- If you need to add an ignore, ignore a specific check, like `# type: ignore[specific]`. Use sparingly.
+- Don't write error-handling code or smooth over exceptions/errors unless they are expected as part of control flow.
+- Prefer `Protocol` over `ABC`s when only an interface is needed.
+- Use `from loguru import logger` for logging instead of `print` statements (e.g. `logger.info`).
+
+## Docstrings & Comments
+
+- Style: NumPyDoc.
+- Start with a one-sentence summary in the imperative mood.
+- Sections: Parameters, Returns, Raises, Examples, References.
+- Use backticks for code or when referring to variables (e.g. `xarray.DataArray`).
+- Do not use emojis or other non-ASCII characters in comments/print statements.
+- Cite peer-reviewed papers with DOI links when relevant.
+- Write code that explains itself rather than needs comments.
+- For the inline comments you do add, explain *why*, not *what*. For example, *don't* write:
+
+```python
+# open the file
+f = open(filename)
+```
+
+- Comments should cover things that are not obvious to a reader with typical background knowledge. Aim to write code that explains itself.
+
+
+## Tools
+
+- You can run `pre-commit run -a` to run all pre-commit hooks and check for style violations.
+- Ruff is used for most code maintenance, Black for formatting, mypy for type checking, and pytest for testing.
+
+
+
+## Documentation
+
+- mkdocs + Jupyter. Hosted on ReadTheDocs.
+- Auto-generated API reference from type hints.
+- Provide tutorial notebooks covering common workflows.
+- Include examples in docstrings.
+- Add high-level guides for key functionality.
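+
+For example, a short NumPyDoc docstring following these conventions might look like the sketch below (the `wrap_phase` helper is hypothetical and is shown only to illustrate the format):
+
+```python
+import numpy as np
+
+
+def wrap_phase(phase: np.ndarray) -> np.ndarray:
+    """Wrap phase values into the interval [-pi, pi].
+
+    Parameters
+    ----------
+    phase : np.ndarray
+        Unwrapped phase values in radians.
+
+    Returns
+    -------
+    np.ndarray
+        Wrapped phase, same shape as `phase`.
+
+    Examples
+    --------
+    >>> wrap_phase(np.array([3.5 * np.pi]))
+    array([-1.57079633])
+    """
+    return phase - 2.0 * np.pi * np.round(phase / (2.0 * np.pi))
+```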
+ +--- + +## Codebase Architecture + +SPURT implements Extended Minimum Cost Flow (EMCF) for 3D InSAR phase unwrapping. The algorithm decomposes 3D unwrapping into two sequential 2D MCF problems. + +### Module Structure (`src/spurt/`) + +| Module | Purpose | +|--------|---------| +| `graph/` | Graph representations (Delaunay, Hop3, Regular2D) for spatial/temporal domains | +| `mcf/` | Minimum Cost Flow solver (OR-Tools based) and utilities | +| `links/` | Per-link model estimation (DEM errors, velocities) via grid search | +| `workflows/emcf/` | Main EMCF algorithm orchestration, tiling, merging | +| `io/` | Input/output interfaces for SLC stacks and 3D data | +| `utils/` | Logging, CPU utilities, tiling helpers | + +### Algorithm Flow + +**Stage 1: Temporal Unwrapping** (`EMCFSolver.unwrap_gradients_in_time`) +- For each spatial edge (pixel-to-pixel link), unwrap phase gradients across interferograms +- Uses temporal graph `G_t` (typically Hop3 or Delaunay in time-baseline space) +- Outputs temporally-unwrapped spatial gradients + +**Stage 2: Spatial Unwrapping** (`EMCFSolver.unwrap_gradients_in_space`) +- For each interferogram, unwrap spatial gradients using MCF +- Uses spatial graph `G_s` (typically Delaunay triangulation) +- Outputs unwrapped interferograms + +### Key Files + +- `mcf/_ortools.py` - Core MCF solver using OR-Tools (`ORMCFSolver`) +- `mcf/utils.py` - `phase_diff()`, `flood_fill()`, cost functions +- `links/_grid_search.py` - `GridSearchLinearModel` for velocity/DEM error estimation +- `links/_common.py` - Temporal coherence objective function +- `workflows/emcf/_solver.py` - `EMCFSolver` main algorithm class +- `graph/_delaunay.py` - Delaunay triangulation and regular 2D grid graphs +- `graph/_hop3.py` - Hop-3 temporal graph for narrow-baseline time-series + +### Data Structures + +**Edges/Links**: `(nedges, 2)` integer arrays indexing into points, always ordered `links[i, 0] < links[i, 1]` + +**Dual Graph**: Maps primal edges to cycles they participate in: +- `dual_edges[i]` = `[cycle1_idx, cycle2_idx]` (1-indexed, 0 = boundary) +- `dual_edge_dir[i]` = `[+1/-1, +1/-1]` orientation within each cycle + +**Design Matrix for Link Models**: `amat` of shape `(nifgs, ndim)` where: +- Column 0: temporal sensitivity (rad per velocity unit, e.g., mm/yr) +- Column 1: baseline sensitivity (rad per DEM error unit, e.g., meters) + +### DEM Error / Velocity Estimation + +The `links/` module provides per-link model estimation via `GridSearchLinearModel`: +1. Build design matrix `amat` from temporal baselines and perpendicular baselines +2. Define search ranges as `slice(start, stop, step)` for each parameter +3. For each edge, grid search + Nelder-Mead finds parameters maximizing temporal coherence +4. Model prediction: `model_phase = amat @ [velocity, dem_error]` +5. Use `phase_diff(data0, data1, model=model_phase)` to "flatten" edges before MCF + +Integration point: `EMCFSolver.unwrap_gradients_in_time()` in `_solver.py` (lines 268-299) diff --git a/docs/background-theory.md b/docs/background-theory.md index aab83c4..d3b1a3a 100644 --- a/docs/background-theory.md +++ b/docs/background-theory.md @@ -2,3 +2,4 @@ 1. [Notation](./theory/notation.md) 1. [Residue computation](./theory/residues.md) +1. 
[DEM error and velocity estimation](./theory/dem-error-velocity.md) diff --git a/docs/notebooks/dem_error_estimation.ipynb b/docs/notebooks/dem_error_estimation.ipynb new file mode 100644 index 0000000..de7bf53 --- /dev/null +++ b/docs/notebooks/dem_error_estimation.ipynb @@ -0,0 +1,360 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "intro", + "metadata": {}, + "source": [ + "# DEM Error Estimation with Spurt\n", + "\n", + "This notebook demonstrates how spurt estimates and applies DEM error corrections during 3D phase unwrapping. We use pre-run results from a **Capella X-band** dataset over **Mexico City** (11 dates, 30 days, perpendicular baselines from -789 to +635 m) to visualize the estimation outputs.\n", + "\n", + "The notebook does **not** re-run the full EMCF pipeline. It shows how to:\n", + "\n", + "1. Build and inspect the design matrix from baseline metadata\n", + "2. Visualize the interferogram network in time-baseline space\n", + "3. Interpret the estimated DEM error and velocity maps\n", + "4. Compare unwrapped phase with and without model guidance\n", + "\n", + "For the theory behind these steps, see the [DEM Error and Velocity Estimation](../theory/dem-error-velocity.md) page.\n", + "\n", + "**Reproducing these results:**\n", + "```bash\n", + "# The results were produced by running dolphin + spurt with DEM error estimation enabled.\n", + "# See the dolphin_config.yaml in the data directory for the exact configuration.\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "imports", + "metadata": {}, + "outputs": [], + "source": [ + "from pathlib import Path\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import pandas as pd\n", + "import rasterio\n", + "\n", + "from spurt.links import build_design_matrix\n", + "\n", + "plt.rcParams.update({\"figure.dpi\": 120, \"font.size\": 10})\n", + "\n", + "# Path to pre-run results (symlinked into docs/notebooks/data/)\n", + "DATA_DIR = Path(\"data/mexico_city\")\n", + "assert DATA_DIR.exists(), f\"Data directory not found: {DATA_DIR}. Create a symlink.\"" + ] + }, + { + "cell_type": "markdown", + "id": "design-matrix-header", + "metadata": {}, + "source": [ + "## Load baselines and build the design matrix\n", + "\n", + "The design matrix $\\mathbf{A}$ relates model parameters (velocity, DEM error) to interferometric phase. Each row corresponds to one interferogram; column 0 encodes velocity sensitivity and column 1 encodes DEM error sensitivity." 
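+    ,
+    "\n",
+    "Explicitly, for interferogram $k$ (the same expressions are derived on the theory page; $\Delta t_k$ is in days and $\Delta B_{\perp,k}$ in meters):\n",
+    "\n",
+    "$$\n",
+    "A_{k,0} = \frac{4\pi}{\lambda} \cdot \frac{\Delta t_k}{365.25 \cdot 1000}\n",
+    "\qquad\n",
+    "A_{k,1} = \frac{4\pi}{\lambda} \cdot \frac{\Delta B_{\perp,k}}{R \sin\theta}\n",
+    "$$"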
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "load-baselines", + "metadata": {}, + "outputs": [], + "source": [ + "# Load per-SLC baselines\n", + "slc_df = pd.read_csv(DATA_DIR / \"baselines_per_slc.csv\", parse_dates=[\"date\"])\n", + "dates = slc_df[\"date\"].values.astype(\"datetime64[D]\")\n", + "bperp_m = slc_df[\"bperp_m\"].values\n", + "\n", + "print(f\"Number of SLCs: {len(dates)}\")\n", + "print(f\"Date range: {dates[0]} to {dates[-1]}\")\n", + "print(f\"Bperp range: {bperp_m.min():.0f} to {bperp_m.max():.0f} m\")\n", + "slc_df" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "load-ifg-baselines", + "metadata": {}, + "outputs": [], + "source": "# Load interferogram baselines to get the network edges\nifg_df = pd.read_csv(DATA_DIR / \"baselines.csv\")\n\n# Extract SLC index pairs from the filenames\n# Map date strings to SLC indices\ndate_to_idx = {str(d): i for i, d in enumerate(dates)}\n\nref_dates = pd.to_datetime(ifg_df[\"reference_time_utc\"]).dt.strftime(\"%Y-%m-%d\")\nsec_dates = pd.to_datetime(ifg_df[\"secondary_time_utc\"]).dt.strftime(\"%Y-%m-%d\")\nifg_edges = np.array(\n [[date_to_idx[r], date_to_idx[s]] for r, s in zip(ref_dates, sec_dates)]\n)\nprint(f\"Number of interferograms: {len(ifg_edges)}\")\nbperp_min = ifg_df[\"bperp_m\"].min()\nbperp_max = ifg_df[\"bperp_m\"].max()\nprint(f\"Bperp range in ifgs: {bperp_min:.0f} to {bperp_max:.0f} m\")" + }, + { + "cell_type": "code", + "execution_count": null, + "id": "build-design-matrix", + "metadata": {}, + "outputs": [], + "source": "# Capella X-band parameters\nWAVELENGTH_M = 0.031 # 3.1 cm X-band\nSLANT_RANGE_M = 550_000\nLOOK_ANGLE_RAD = 0.61 # ~35 degrees\n\namat = build_design_matrix(\n ifg_edges=ifg_edges,\n dates=dates,\n bperp_m=bperp_m,\n wavelength_m=WAVELENGTH_M,\n slant_range_m=SLANT_RANGE_M,\n look_angle_rad=LOOK_ANGLE_RAD,\n)\n\nprint(f\"Design matrix shape: {amat.shape}\")\nvel_min, vel_max = amat[:, 0].min(), amat[:, 0].max()\ndem_min, dem_max = amat[:, 1].min(), amat[:, 1].max()\nprint(f\"Velocity sensitivity range: {vel_min:.4f} to {vel_max:.4f} rad/(mm/yr)\")\nprint(f\"DEM error sensitivity range: {dem_min:.4f} to {dem_max:.4f} rad/m\")" + }, + { + "cell_type": "code", + "execution_count": null, + "id": "plot-sensitivities", + "metadata": {}, + "outputs": [], + "source": [ + "fig, axes = plt.subplots(1, 2, figsize=(10, 4))\n", + "\n", + "# Temporal baselines in days\n", + "delta_t = np.array(\n", + " [(dates[s] - dates[r]).astype(float) for r, s in ifg_edges]\n", + ")\n", + "delta_bperp = np.array(\n", + " [bperp_m[s] - bperp_m[r] for r, s in ifg_edges]\n", + ")\n", + "\n", + "ax = axes[0]\n", + "ax.scatter(delta_t, amat[:, 0], c=\"steelblue\", s=30, edgecolor=\"k\", linewidth=0.5)\n", + "ax.set_xlabel(\"Temporal baseline (days)\")\n", + "ax.set_ylabel(\"Sensitivity (rad per mm/yr)\")\n", + "ax.set_title(\"Velocity sensitivity\")\n", + "ax.axhline(0, color=\"gray\", linewidth=0.5)\n", + "\n", + "ax = axes[1]\n", + "ax.scatter(delta_bperp, amat[:, 1], c=\"firebrick\", s=30, edgecolor=\"k\", linewidth=0.5)\n", + "ax.set_xlabel(\"Perpendicular baseline difference (m)\")\n", + "ax.set_ylabel(\"Sensitivity (rad per m)\")\n", + "ax.set_title(\"DEM error sensitivity\")\n", + "ax.axhline(0, color=\"gray\", linewidth=0.5)\n", + "\n", + "fig.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "network-header", + "metadata": {}, + "source": [ + "## Interferogram network\n", + "\n", + "The interferogram network is typically a Hop-3 
graph in time-Bperp space: each SLC is connected to its 3 nearest temporal neighbors. The wide spread of perpendicular baselines in this dataset gives strong DEM error sensitivity." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "plot-network", + "metadata": {}, + "outputs": [], + "source": "fig, ax = plt.subplots(figsize=(8, 5))\n\n# Plot edges\nfor r, s in ifg_edges:\n ax.plot(\n [dates[r], dates[s]],\n [bperp_m[r], bperp_m[s]],\n \"k-\", linewidth=0.8, alpha=0.5,\n )\n\n# Plot SLC dates as points\nax.scatter(dates, bperp_m, c=\"steelblue\", s=60, zorder=5, edgecolor=\"k\", linewidth=0.5)\nfor _i, (d, b) in enumerate(zip(dates, bperp_m)):\n ax.annotate(\n f\"{b:.0f} m\", (d, b),\n textcoords=\"offset points\", xytext=(5, 5),\n fontsize=7, color=\"gray\",\n )\n\nax.set_xlabel(\"Date\")\nax.set_ylabel(\"Perpendicular baseline (m)\")\nax.set_title(\"Interferogram network in time-Bperp space\")\nfig.autofmt_xdate()\nfig.tight_layout()\nplt.show()" + }, + { + "cell_type": "markdown", + "id": "results-header", + "metadata": {}, + "source": [ + "## Load pre-run results\n", + "\n", + "The pixelwise DEM error and velocity maps were computed by spurt's `write_link_params()` function, which integrates per-edge gradients to per-pixel values via weighted least-squares." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "load-results", + "metadata": {}, + "outputs": [], + "source": [ + "# Load pre-computed results\n", + "results = np.load(DATA_DIR / \"pixelwise_dem_error.npz\")\n", + "dem_error = results[\"dem_error_m\"]\n", + "velocity = results[\"velocity_mm_yr\"]\n", + "\n", + "print(f\"DEM error shape: {dem_error.shape}\")\n", + "print(f\"Velocity shape: {velocity.shape}\")\n", + "\n", + "# Load temporal coherence\n", + "tcoh_path = DATA_DIR / \"interferograms\" / \"temporal_coherence_20240626_20240726.tif\"\n", + "with rasterio.open(tcoh_path) as src:\n", + " tcoh = src.read(1)\n", + "\n", + "print(f\"Temporal coherence shape: {tcoh.shape}\")" + ] + }, + { + "cell_type": "markdown", + "id": "dem-error-map-header", + "metadata": {}, + "source": [ + "## DEM error map\n", + "\n", + "The estimated DEM error shows spatial structure correlated with topography. Positive values indicate the DEM underestimates true elevation; negative values indicate overestimation." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "plot-dem-error", + "metadata": {}, + "outputs": [], + "source": [ + "# Mask NaN values for statistics\n", + "valid = np.isfinite(dem_error)\n", + "vmax = np.nanpercentile(np.abs(dem_error[valid]), 95)\n", + "\n", + "fig, ax = plt.subplots(figsize=(8, 7))\n", + "im = ax.imshow(dem_error, cmap=\"RdBu_r\", vmin=-vmax, vmax=vmax, interpolation=\"nearest\")\n", + "cbar = fig.colorbar(im, ax=ax, shrink=0.8, label=\"DEM error (m)\")\n", + "ax.set_title(\"Estimated DEM error\")\n", + "ax.set_xlabel(\"Column\")\n", + "ax.set_ylabel(\"Row\")\n", + "fig.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "histogram-header", + "metadata": {}, + "source": [ + "## DEM error distribution\n", + "\n", + "The DEM error distribution is expected to be roughly centered near zero with standard deviation depending on the DEM quality and scene topography." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "plot-histogram", + "metadata": {}, + "outputs": [], + "source": "from scipy.stats import norm\n\nvalid_dem = dem_error[valid]\nmu, sigma = np.nanmean(valid_dem), np.nanstd(valid_dem)\np5, p50, p95 = np.nanpercentile(valid_dem, [5, 50, 95])\n\nfig, ax = plt.subplots(figsize=(7, 4))\ncounts, bins, _ = ax.hist(\n valid_dem, bins=200, density=True, alpha=0.7, color=\"steelblue\", edgecolor=\"none\",\n range=(-vmax, vmax),\n)\n\n# Gaussian fit overlay\nx_fit = np.linspace(-vmax, vmax, 500)\nax.plot(x_fit, norm.pdf(x_fit, mu, sigma), \"r-\", linewidth=1.5, label=\"Gaussian fit\")\n\nax.set_xlabel(\"DEM error (m)\")\nax.set_ylabel(\"Density\")\nax.set_title(\"DEM error distribution\")\nstats_text = (\n f\"Mean: {mu:.2f} m\\n\"\n f\"Std: {sigma:.2f} m\\n\"\n f\"Median: {p50:.2f} m\\n\"\n f\"5th/95th: {p5:.1f} / {p95:.1f} m\"\n)\nax.text(\n 0.97, 0.95, stats_text, transform=ax.transAxes,\n fontsize=9, verticalalignment=\"top\", horizontalalignment=\"right\",\n bbox={\"boxstyle\": \"round,pad=0.3\", \"facecolor\": \"wheat\", \"alpha\": 0.8},\n)\nax.legend(loc=\"upper left\")\nfig.tight_layout()\nplt.show()" + }, + { + "cell_type": "markdown", + "id": "coherence-header", + "metadata": {}, + "source": [ + "## DEM error vs temporal coherence\n", + "\n", + "High-coherence pixels have well-constrained DEM error estimates (small spread), while low-coherence pixels show a wide range of estimated values. This 2D histogram illustrates the relationship." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "plot-dem-vs-coh", + "metadata": {}, + "outputs": [], + "source": [ + "# Use the same spatial extent for both arrays\n", + "valid_both = valid & np.isfinite(tcoh) & (tcoh > 0)\n", + "\n", + "fig, ax = plt.subplots(figsize=(7, 5))\n", + "h = ax.hist2d(\n", + " tcoh[valid_both].ravel(),\n", + " dem_error[valid_both].ravel(),\n", + " bins=[100, 200],\n", + " range=[[0, 1], [-vmax, vmax]],\n", + " cmap=\"inferno\",\n", + " cmin=1,\n", + ")\n", + "fig.colorbar(h[3], ax=ax, label=\"Count\")\n", + "ax.set_xlabel(\"Temporal coherence\")\n", + "ax.set_ylabel(\"DEM error (m)\")\n", + "ax.set_title(\"DEM error vs temporal coherence\")\n", + "fig.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "comparison-header", + "metadata": {}, + "source": [ + "## Unwrapped phase: with vs without model\n", + "\n", + "For a large-Bperp interferogram (2024-06-26 to 2024-07-02, $B_\\perp = -789$ m), the model captures the long-wavelength DEM-correlated signal. 
Spurt writes three files per interferogram:\n", + "\n", + "- `*.unw.tif` -- full unwrapped phase\n", + "- `*.unw_model.tif` -- model component (velocity + DEM error prediction)\n", + "- `*.unw_diff.tif` -- max difference between tiles (overlap consistency diagnostic)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "plot-comparison", + "metadata": {}, + "outputs": [], + "source": [ + "# Load the large-Bperp pair\n", + "pair = \"20240626_20240702\"\n", + "output_dir = DATA_DIR / \"spurt_output\"\n", + "\n", + "with rasterio.open(output_dir / f\"{pair}.unw.tif\") as src:\n", + " unw = src.read(1)\n", + "with rasterio.open(output_dir / f\"{pair}.unw_model.tif\") as src:\n", + " unw_model = src.read(1)\n", + "\n", + "# Residual = total unwrapped - model\n", + "residual = unw - unw_model\n", + "\n", + "fig, axes = plt.subplots(1, 3, figsize=(14, 5))\n", + "\n", + "for ax, data, title, cmap in zip(\n", + " axes,\n", + " [unw, unw_model, residual],\n", + " [\"Unwrapped phase\", \"Model component\", \"Residual (unw - model)\"],\n", + " [\"RdBu_r\", \"RdBu_r\", \"RdBu_r\"],\n", + "):\n", + " vabs = np.nanpercentile(np.abs(data[np.isfinite(data)]), 95)\n", + " im = ax.imshow(data, cmap=cmap, vmin=-vabs, vmax=vabs, interpolation=\"nearest\")\n", + " fig.colorbar(im, ax=ax, shrink=0.7, label=\"rad\")\n", + " ax.set_title(title)\n", + " ax.set_xlabel(\"Column\")\n", + " ax.set_ylabel(\"Row\")\n", + "\n", + "fig.suptitle(f\"Interferogram {pair} (Bperp = -789 m)\", fontsize=12, y=1.02)\n", + "fig.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "summary", + "metadata": {}, + "source": [ + "## Summary\n", + "\n", + "Key takeaways from this tutorial:\n", + "\n", + "- The design matrix translates physical baselines (time, perpendicular baseline) into phase sensitivities for velocity and DEM error.\n", + "- DEM error estimation is performed per-edge on spatial gradients, naturally canceling common atmospheric signals.\n", + "- The estimated model is used to guide phase wrapping disambiguation, reducing the number of residues MCF must resolve.\n", + "- Temporal coherence serves as both the optimization objective and a quality metric for the final estimates.\n", + "\n", + "For the mathematical derivation behind each step, see the [DEM Error and Velocity Estimation](../theory/dem-error-velocity.md) theory page.\n", + "\n", + "For API details, see the [spurt.links](../reference/spurt/links/) reference documentation." 
+ ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbformat_minor": 5, + "version": "3.11.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/references.bib b/docs/references.bib index 20f4154..0abc768 100644 --- a/docs/references.bib +++ b/docs/references.bib @@ -1,3 +1,15 @@ +@article{Ferretti2001PermanentScatterers, + title = {Permanent Scatterers in {{SAR}} Interferometry}, + author = {Ferretti, Alessandro and Prati, Claudio and Rocca, Fabio}, + year = {2001}, + journal = {IEEE Transactions on Geoscience and Remote Sensing}, + volume = {39}, + number = {1}, + pages = {8--20}, + publisher = {IEEE}, + doi = {10.1109/36.898661}, +} + @article{Ferretti2011NewAlgorithmProcessing, title = {A New Algorithm for Processing Interferometric Data-Stacks: {{SqueeSAR}}}, author = {Ferretti, Alessandro and Fumagalli, Alfio and Novali, Fabrizio and Prati, Claudio and Rocca, Fabio and Rucci, Alessio}, diff --git a/docs/theory/dem-error-velocity.md b/docs/theory/dem-error-velocity.md new file mode 100644 index 0000000..b0c2073 --- /dev/null +++ b/docs/theory/dem-error-velocity.md @@ -0,0 +1,180 @@ +# DEM Error and Velocity Estimation + +Spurt estimates per-pixel velocity and DEM error as part of 3D phase unwrapping. This page derives the underlying model and explains how each component maps to code in the `spurt.links` module. + +--- + +## The linear phase model + +An interferometric phase observation between two SAR acquisitions can be decomposed as: + +$$ +\phi = \underbrace{\frac{4\pi}{\lambda} \, v \, \Delta t}_{\text{displacement}} \;+\; \underbrace{\frac{4\pi}{\lambda} \, \frac{B_\perp}{R \sin\theta} \, \varepsilon_{\text{DEM}}}_{\text{DEM error}} \;+\; \phi_{\text{atmo}} + \phi_{\text{noise}} +$$ + +where: + +| Symbol | Meaning | Units | +|--------|---------|-------| +| $\lambda$ | Radar wavelength | m | +| $v$ | Line-of-sight velocity | mm/yr | +| $\Delta t$ | Temporal baseline | days | +| $B_\perp$ | Perpendicular baseline (secondary minus reference) | m | +| $R$ | Slant range distance | m | +| $\theta$ | Look angle | rad | +| $\varepsilon_{\text{DEM}}$ | DEM error | m | + +We collect the two deterministic terms into a **design matrix** $\mathbf{A}$ of shape $(N_{\text{ifg}}, 2)$ and a parameter vector $\mathbf{x} = [v, \, \varepsilon_{\text{DEM}}]^T$: + +$$ +\boldsymbol{\phi} = \mathbf{A} \, \mathbf{x} + \boldsymbol{\phi}_{\text{atmo}} + \boldsymbol{\phi}_{\text{noise}} +$$ + +Each row of $\mathbf{A}$ corresponds to one interferogram: + +$$ +A_{k,0} = \frac{4\pi}{\lambda} \cdot \frac{\Delta t_k}{365.25 \cdot 1000} +\qquad +A_{k,1} = \frac{4\pi}{\lambda} \cdot \frac{\Delta B_{\perp,k}}{R \sin\theta} +$$ + +Column 0 gives **radians per mm/yr of velocity**; column 1 gives **radians per meter of DEM error** \[cite:Ferretti2001PermanentScatterers\]. + +--- + +## Why per-edge estimation + +Rather than fitting the model to individual pixel phases, spurt fits to **spatial gradients** (phase differences between neighboring pixels on the Delaunay graph). 
For an edge connecting pixels $i$ and $j$:
+
+$$
+\Delta\phi_k^{(ij)} = \phi_k^{(j)} - \phi_k^{(i)} = \mathbf{A}_k \, (\mathbf{x}^{(j)} - \mathbf{x}^{(i)}) + \Delta\phi_{\text{atmo}} + \Delta\phi_{\text{noise}}
+$$
+
+On short baselines, the atmospheric contribution $\Delta\phi_{\text{atmo}}$ nearly cancels. This is the same data structure that MCF operates on, so per-edge estimation slots directly into the EMCF pipeline without any additional spatial processing.
+
+---
+
+## The design matrix in code
+
+`spurt.links.build_design_matrix()` in `src/spurt/links/_design_matrix.py` constructs $\mathbf{A}$:
+
+```python
+from spurt.links import build_design_matrix
+
+amat = build_design_matrix(
+    ifg_edges=ifg_edges,  # (nifgs, 2) array of SLC index pairs
+    dates=dates,  # datetime64[D] per SLC
+    bperp_m=bperp_m,  # perpendicular baseline per SLC (meters)
+    wavelength_m=0.031,  # Capella X-band: 3.1 cm
+    slant_range_m=550_000,
+    look_angle_rad=0.61,
+)
+# amat.shape == (nifgs, 2)
+# amat[:, 0] -> rad per mm/yr (velocity sensitivity)
+# amat[:, 1] -> rad per meter (DEM error sensitivity)
+```
+
+The function iterates over interferogram pairs, computing the temporal baseline $\Delta t$ and perpendicular baseline difference $\Delta B_\perp$ for each. The common factor $4\pi/\lambda$ is applied once, and the geometric DEM scaling $1/(R \sin\theta)$ is computed once for the scene.
+
+---
+
+## Temporal coherence as objective
+
+The model parameters are estimated by maximizing **temporal coherence** \[cite:Ferretti2001PermanentScatterers\], defined as:
+
+$$
+\gamma = \frac{\left|\sum_{k=1}^{N} w_k \, \exp\!\bigl(j\,(\mathbf{A}_k \mathbf{x} - \phi_k)\bigr)\right|}{\sum_{k=1}^{N} w_k}
+$$
+
+where $w_k$ are per-interferogram weights and $\phi_k$ are the observed (wrapped) phase gradients. When the model perfectly explains the data, all residual phasors align and $\gamma = 1$. When the model is wrong, they scatter and $\gamma \to 0$.
+
+For use as a minimization objective with `scipy.optimize`, spurt implements the **negative** temporal coherence in `src/spurt/links/_common.py`:
+
+```python
+def neg_temporal_coherence(x, amat, b, wts):
+    res = amat.dot(x) - b
+    return -np.abs(np.sum(wts * np.exp(1j * res)))
+```
+
+A variant with an analytic Jacobian (`neg_temporal_coherence_with_jacobian`) is also provided for gradient-based refinement.
+
+---
+
+## Grid search + local refinement
+
+The temporal coherence surface is **non-convex** because the observed phases are wrapped: a local optimizer started from a poor initial guess converges to the wrong basin. Spurt therefore uses a two-stage strategy via `GridSearchLinearModel` in `src/spurt/links/_grid_search.py`:
+
+1. **Brute-force grid search**: evaluate $\gamma$ at every point of a coarse grid over user-defined parameter ranges. The ranges are specified as Python `slice` objects, e.g. `slice(-50, 50, 0.5)` for DEM error in meters. This finds the global basin.
+
+2. **Local refinement**: the grid optimum is refined to sub-grid precision by fitting a quadratic surface to the 3x3 coherence stencil around it and solving for the analytic peak (earlier versions refined with Nelder-Mead via `scipy.optimize.fmin`). The result is clipped to the search bounds to prevent aliasing.
+
+This is conceptually different from \[cite:Pepe2006ExtensionMinimumCost\], which solves a full MCF problem for each parameter hypothesis. The per-link approach is simpler, parallelizes trivially, and avoids the need to define a spatial cost function over parameter space.
+
+The solver processes links in batches via `estimate_model_many()`, which evaluates all grid points for a whole batch of links with a single matrix multiply (chunked to bound memory use); parallelism comes from NumPy's BLAS threading rather than Python multiprocessing.
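+
+The coarse stage can be written in a few lines of NumPy. The sketch below is illustrative (the `coarse_grid_search` name and the example ranges are not part of spurt's API): it evaluates the temporal coherence of every grid point for a single link with one matrix product, which is the same idea `GridSearchLinearModel` applies to whole batches of links.
+
+```python
+import numpy as np
+
+
+def coarse_grid_search(amat, wrapdata, ranges, wts=1.0):
+    """Return the grid point maximizing temporal coherence for one link."""
+    # All parameter combinations on the regular grid: (ngrid, ndim)
+    axes = [np.arange(s.start, s.stop, s.step) for s in ranges]
+    grid = np.column_stack([g.ravel() for g in np.meshgrid(*axes, indexing="ij")])
+
+    # Residual phase for every grid point: (nobs, ngrid)
+    residuals = amat @ grid.T - wrapdata[:, np.newaxis]
+
+    # Temporal coherence |sum_k w_k exp(j res_k)| per grid point (wts is a
+    # scalar weight here; per-interferogram weights are omitted for brevity)
+    coherence = np.abs((wts * np.exp(1j * residuals)).sum(axis=0))
+    best = np.argmax(coherence)
+    return grid[best], coherence[best]
+
+
+# Example ranges: velocity in mm/yr, DEM error in meters
+ranges = (slice(-100.0, 100.0, 1.0), slice(-50.0, 50.0, 0.5))
+# params, coh = coarse_grid_search(amat, wrapdata, ranges)
+```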
+ +--- + +## Model-guided wrapping + +After estimating per-edge parameters, the model prediction is used to **guide phase wrapping disambiguation**. This is the key integration point between the `links` module and the EMCF solver. + +The function `phase_diff()` in `src/spurt/mcf/utils.py` computes: + +$$ +d = \phi_1 - \phi_0 - m +$$ + +$$ +\text{phase\_diff} = m + d - 2\pi \left\lfloor \frac{d}{2\pi} \right\rceil +$$ + +where $m$ is the model prediction. Without a model ($m = 0$), this is standard wrapped phase differencing: the result lies in $(-\pi, \pi]$. With a model, the result lies in $(m - \pi, \, m + \pi]$. The gradient is wrapped **around the model prediction** rather than around zero. + +In code: + +```python +def phase_diff(z0, z1, model=0.0): + d = z1 - z0 - model + return model + d - np.round(d / (2 * np.pi)) * 2 * np.pi +``` + +This "flattening" means MCF only needs to resolve **residual** ambiguities---the small integer cycles left after accounting for velocity and DEM error. For interferograms with large perpendicular baselines, this can dramatically reduce the number of residues MCF must handle. + +The integration point is `EMCFSolver.unwrap_gradients_in_time()` in `src/spurt/workflows/emcf/_solver.py` (lines 271--308). After estimating model parameters for a batch of links, it recomputes the spatial gradients with the model prediction: + +```python +# Recompute gradients using model to guide wrapping +grad_space[:, i_start:i_end] = utils.phase_diff( + wrap_data[:, inds[:, 0]], + wrap_data[:, inds[:, 1]], + model=model_pred, +) +``` + +--- + +## From edges to pixels + +The per-edge estimated parameters (velocity gradient, DEM error gradient) are integrated to per-point values via **weighted least-squares** (WLS). Given the incidence matrix $\mathbf{D}$ of the spatial graph where each row represents an edge $i \to j$: + +$$ +\mathbf{D} \, \mathbf{v} \approx \mathbf{g} +$$ + +where $\mathbf{g}$ is the vector of per-edge gradients. This is solved by `scipy.sparse.linalg.lsqr` with weights $\sqrt{\gamma}$ (square root of temporal coherence), so that well-fitting edges contribute more to the solution. + +The implementation in `src/spurt/workflows/emcf/_output.py` (function `_integrate_tile_link_params`, lines 323--395): + +1. Builds a weighted incidence matrix: multiply each row by $\sqrt{\gamma_{\text{edge}}}$ +2. Solves via LSQR for each parameter dimension (velocity, DEM error) +3. Centers the result by removing the median (the integration constant is arbitrary) +4. Computes mean coherence per point for quality assessment + +The result is a per-pixel map of velocity (mm/yr) and DEM error (m), plus a coherence map indicating estimation reliability. + +--- + +## References + +- \[cite:Ferretti2001PermanentScatterers\] +- \[cite:Pepe2006ExtensionMinimumCost\] diff --git a/docs/tutorials.md b/docs/tutorials.md index ceb04ec..5949be7 100644 --- a/docs/tutorials.md +++ b/docs/tutorials.md @@ -3,6 +3,7 @@ 1. [Planar graphs](./tutorials/planar-graph.md) 2. [2D Minimum Cost Flow](./tutorials/mcf-2d.md) 3. [3D Extended MCF](./tutorials/emcf-3d.md) +4. 
[DEM error estimation (notebook)](./notebooks/dem_error_estimation.ipynb) diff --git a/mkdocs.yml b/mkdocs.yml index b6c64e4..b8d6385 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -65,12 +65,12 @@ nav: - index.md - getting-started.md - tutorials.md -# - Tutorials: -# - Notebook page: notebooks/walkthrough-basic.ipynb -# # - Notebook page2: notebooks/walkthrough-basic.html -# - how-to-guides.md +- Notebooks: + - DEM Error Estimation: notebooks/dem_error_estimation.ipynb +- Background Theory: + - background-theory.md + - theory/dem-error-velocity.md # https://mkdocstrings.github.io/recipes/#generate-a-literate-navigation-file # trailing slash: that mkdocs-literate-nav knows a summary.md file is in that folder. - developer-setup.md - Code Reference: reference/ -# - background-theory.md diff --git a/pyproject.toml b/pyproject.toml index ce0a0e2..03e69e0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,7 +58,7 @@ Issues = "https://github.com/isce-framework/spurt/issues" # Entry points for the command line interface [project.scripts] -spurt-emcf = "spurt.workflows.emcf:__main__" +spurt-emcf = "spurt.workflows.emcf._cli:main" [tool.black] preview = true @@ -120,6 +120,7 @@ ignore = [ "D105", # Missing docstring in magic method "PLR", # Pylint Refactor "PLC0415", # `import` should be at the top-level of a file + "PTH123", # Path.open ] [tool.ruff.lint.per-file-ignores] diff --git a/src/spurt/io/__init__.py b/src/spurt/io/__init__.py index b1077fe..943c496 100644 --- a/src/spurt/io/__init__.py +++ b/src/spurt/io/__init__.py @@ -1,5 +1,6 @@ from typing import Any +from ._baseline import BaselineData, load_baseline_csv from ._interface import ( InputInterface, InputStackInterface, @@ -9,12 +10,14 @@ from ._three_d import Irreg3DInput, Reg3DInput __all__ = [ + "BaselineData", "InputInterface", "InputStackInterface", "Irreg3DInput", "OutputInterface", "OutputStackInterface", "Reg3DInput", + "load_baseline_csv", ] diff --git a/src/spurt/io/_baseline.py b/src/spurt/io/_baseline.py new file mode 100644 index 0000000..dfd64b9 --- /dev/null +++ b/src/spurt/io/_baseline.py @@ -0,0 +1,250 @@ +"""Load perpendicular baseline data from CSV files.""" + +from __future__ import annotations + +import csv +import re +from dataclasses import dataclass +from pathlib import Path + +import numpy as np + +__all__ = [ + "BaselineData", + "load_baseline_csv", +] + + +@dataclass +class BaselineData: + """Container for per-SLC baseline information. + + Parameters + ---------- + dates : np.ndarray + Array of dates as datetime64[D], shape (n_slc,). + bperp_m : np.ndarray + Perpendicular baseline in meters for each SLC, shape (n_slc,). + Values are relative to first SLC (reference). + """ + + dates: np.ndarray + bperp_m: np.ndarray + + def __post_init__(self): + if self.dates.shape != self.bperp_m.shape: + errmsg = ( + f"Shape mismatch: dates {self.dates.shape}" + f" vs bperp_m {self.bperp_m.shape}" + ) + raise ValueError(errmsg) + if self.dates.ndim != 1: + errmsg = f"Expected 1D arrays, got dates.ndim={self.dates.ndim}" + raise ValueError(errmsg) + + +def load_baseline_csv(filepath: str | Path) -> BaselineData: + """Load per-SLC baselines from CSV file. + + Supports two CSV formats: + + 1. Per-SLC format with columns: date,bperp_m + Each row represents one SLC acquisition. + + 2. Per-IFG format with columns: reference,secondary,...,bperp_m + Each row represents one interferogram. This format is automatically + detected and converted to per-SLC baselines via least-squares, + with the first SLC as reference (bperp=0). 
+ + Parameters + ---------- + filepath : str | Path + Path to CSV file. + + Returns + ------- + BaselineData + Container with dates and perpendicular baselines per SLC. + """ + filepath = Path(filepath) + + with open(filepath) as f: + reader = csv.reader(f) + header = next(reader) + + columns = [c.strip().lower() for c in header] + + if "date" in columns and "bperp_m" in columns: + return _load_per_slc_csv(filepath, columns) + if "reference" in columns and "secondary" in columns and "bperp_m" in columns: + return _load_per_ifg_csv(filepath, columns) + errmsg = ( + f"Unrecognized CSV format in {filepath}. " + "Expected either 'date,bperp_m' (per-SLC) or " + "'reference,secondary,...,bperp_m' (per-IFG) columns." + ) + raise ValueError(errmsg) + + +def _load_per_slc_csv(filepath: Path, columns: list[str]) -> BaselineData: + """Load per-SLC format CSV.""" + date_col = columns.index("date") + bperp_col = columns.index("bperp_m") + + dates_list = [] + bperp_list = [] + + with open(filepath) as f: + reader = csv.reader(f) + next(reader) # skip header + for row in reader: + if not row or not row[0].strip(): + continue + dates_list.append(_parse_date(row[date_col].strip())) + bperp_list.append(float(row[bperp_col].strip())) + + dates = np.array(dates_list, dtype="datetime64[D]") + bperp_m = np.array(bperp_list, dtype=np.float64) + + # Sort by date + sort_idx = np.argsort(dates) + dates = dates[sort_idx] + bperp_m = bperp_m[sort_idx] + + return BaselineData(dates=dates, bperp_m=bperp_m) + + +# Regex to pull the first YYYYMMDD from a string (works on filenames, paths, etc.) +_DATE8_RE = re.compile(r"(\d{8})") + + +def _parse_date(date_str: str) -> str: + """Parse date string to ISO format (YYYY-MM-DD). + + Supports formats: + - YYYYMMDD + - YYYY-MM-DD + - ISO 8601 timestamps (e.g. 2026-01-31T04:30:30.757Z) + - File paths containing a YYYYMMDD substring + (e.g. slcs/CAPELLA_C13_SP_SLC_HH_20260131043025_20260131043034.tif) + """ + date_str = date_str.strip() + + # Already ISO date + if len(date_str) == 10 and date_str[4] == "-" and date_str[7] == "-": + return date_str + + # Compact YYYYMMDD + if len(date_str) == 8 and date_str.isdigit(): + return f"{date_str[:4]}-{date_str[4:6]}-{date_str[6:8]}" + + # ISO 8601 timestamp — take the date part + # Handles both 'T' separator (2024-07-20T05:26:48Z) and space separator + # (2024-07-20 05:26:48.445709+00:00, Python's default datetime.__str__()). + # Require dashes so we don't accidentally match compact formats like + # "20200101T120000" (those fall through to the regex fallback). + if ( + len(date_str) > 10 + and date_str[:4].isdigit() + and date_str[4] == "-" + and date_str[7] == "-" + and ("T" in date_str or " " in date_str) + ): + return date_str[:10] + + # Fallback: extract first 8-digit sequence from the string (e.g. filename) + m = _DATE8_RE.search(date_str) + if m: + d = m.group(1) + return f"{d[:4]}-{d[4:6]}-{d[6:8]}" + + errmsg = f"Cannot parse date from: {date_str!r}" + raise ValueError(errmsg) + + +def _load_per_ifg_csv(filepath: Path, columns: list[str]) -> BaselineData: + """Load per-IFG format CSV and convert to per-SLC baselines. + + The interferometric baseline relationship is: + bperp_ifg[i,j] = bperp[j] - bperp[i] + + If all interferograms share a single reference, the bperp values are + already per-SLC baselines and are used directly. Otherwise, per-SLC + baselines are recovered via least-squares with the first SLC as + reference (bperp=0). 
+ + If 'reference_time_utc' and 'secondary_time_utc' columns are present, + dates are taken from those (more reliable). Otherwise dates are extracted + from the 'reference' and 'secondary' columns (which may be file paths). + """ + ref_col = columns.index("reference") + sec_col = columns.index("secondary") + bperp_col = columns.index("bperp_m") + + # Prefer the explicit UTC time columns when available + has_time_cols = "reference_time_utc" in columns and "secondary_time_utc" in columns + if has_time_cols: + ref_time_col = columns.index("reference_time_utc") + sec_time_col = columns.index("secondary_time_utc") + + ref_dates = [] + sec_dates = [] + bperp_ifg = [] + + with open(filepath) as f: + reader = csv.reader(f) + next(reader) # skip header + for row in reader: + if not row or not row[0].strip(): + continue + if has_time_cols: + ref_str = row[ref_time_col].strip() + sec_str = row[sec_time_col].strip() + else: + ref_str = row[ref_col].strip() + sec_str = row[sec_col].strip() + ref_dates.append(_parse_date(ref_str)) + sec_dates.append(_parse_date(sec_str)) + bperp_ifg.append(float(row[bperp_col].strip())) + + # Get unique dates sorted + all_dates = sorted(set(ref_dates) | set(sec_dates)) + date_to_idx = {d: i for i, d in enumerate(all_dates)} + n_slc = len(all_dates) + n_ifg = len(bperp_ifg) + + # Single-reference shortcut: if all IFGs share the same reference, + # bperp values are already per-SLC baselines relative to that reference. + unique_refs = set(ref_dates) + if len(unique_refs) == 1: + bperp_m = np.zeros(n_slc, dtype=np.float64) + for sec, bp in zip(sec_dates, bperp_ifg): + bperp_m[date_to_idx[sec]] = bp + dates = np.array(all_dates, dtype="datetime64[D]") + return BaselineData(dates=dates, bperp_m=bperp_m) + + # Build design matrix: A[ifg, :] has -1 at reference, +1 at secondary + # bperp_ifg = A @ bperp_slc + # First column is reference (bperp=0), so we solve for columns 1: + amat = np.zeros((n_ifg, n_slc - 1), dtype=np.float64) + bvec = np.array(bperp_ifg, dtype=np.float64) + + for i, (ref, sec) in enumerate(zip(ref_dates, sec_dates)): + ref_idx = date_to_idx[ref] + sec_idx = date_to_idx[sec] + # bperp_ifg[i] = bperp[sec] - bperp[ref] + if ref_idx > 0: + amat[i, ref_idx - 1] = -1.0 + if sec_idx > 0: + amat[i, sec_idx - 1] = 1.0 + + # Solve least-squares + result, *_ = np.linalg.lstsq(amat, bvec, rcond=None) + + # Prepend zero for reference SLC + bperp_m = np.zeros(n_slc, dtype=np.float64) + bperp_m[1:] = result + + dates = np.array(all_dates, dtype="datetime64[D]") + + return BaselineData(dates=dates, bperp_m=bperp_m) diff --git a/src/spurt/io/_slc_stack.py b/src/spurt/io/_slc_stack.py index a2f6465..2c00c89 100644 --- a/src/spurt/io/_slc_stack.py +++ b/src/spurt/io/_slc_stack.py @@ -4,6 +4,7 @@ import os from collections.abc import Mapping +from datetime import datetime from pathlib import Path import numpy as np @@ -15,6 +16,30 @@ "SLCStackReader", ] +DEFAULT_DATE_FMT = "%Y%m%d" + + +def _date_str_length(fmt: str) -> int: + """Return the length of a date string formatted with `fmt`. + + Used to slice the date portion out of an SLC filename when ``fmt`` + includes a time-of-day component (e.g. ``"%Y%m%d%H%M%S"``). + """ + return len(datetime(2020, 1, 1, 12, 30, 45).strftime(fmt)) + + +def _extract_date_str(token: str, date_fmt: str, date_len: int) -> str: + """Slice a date prefix from `token` and validate it against `date_fmt`. + + Raises ``ValueError`` (from :func:`datetime.datetime.strptime`) if the + sliced substring does not match the format. 
This catches a mismatched + ``--date-fmt`` at scan time rather than letting a garbage prefix flow + through into output filenames. + """ + s = token[:date_len] + datetime.strptime(s, date_fmt) + return s + class SLCStackReader: """Read a SLC stack from single files. @@ -65,6 +90,7 @@ def from_phase_linked_directory( cls, folder: str | os.PathLike[str], temp_coh_threshold: float = 0.6, + date_fmt: str = DEFAULT_DATE_FMT, ) -> SLCStackReader: """Initialize stack by scanning a folder. @@ -72,6 +98,19 @@ def from_phase_linked_directory( for now. metadata and spatial coherence to be dealt with later. This folder structure corresponds to current test data for `spurt` and will likely evolve. + + Parameters + ---------- + folder : str or PathLike + Directory containing the phase-linked SLC stack. + temp_coh_threshold : float, optional + Minimum temporal coherence to consider a pixel stable. + date_fmt : str, optional + ``strftime``-compatible format used to extract acquisition dates + from SLC filenames. Default ``"%Y%m%d"``. Use a longer format such + as ``"%Y%m%d%H%M%S"`` to preserve a time-of-day component (e.g. + for non-Sentinel cadences with same-day repeats); the unwrapped + output filenames will then carry the same component. """ p = Path(folder) @@ -83,7 +122,10 @@ def from_phase_linked_directory( # Then list individual SLCs slclist = sorted(p.glob("*.int.tif")) - first_date = slclist[0].name.split("_")[0][:8] + date_len = _date_str_length(date_fmt) + first_date = _extract_date_str( + slclist[0].name.split("_")[0], date_fmt, date_len + ) # Start with first date - set to None # None is special case for reference epoch @@ -100,7 +142,7 @@ def from_phase_linked_directory( ) raise ValueError(errmsg) - acq_date = slc.name.split("_")[1][:8] + acq_date = _extract_date_str(slc.name.split("_")[1], date_fmt, date_len) if acq_date in slc_files: errmsg = ( f"Error scanning {folder}." @@ -120,6 +162,7 @@ def from_slc_directory( cls, folder: str | os.PathLike[str], temp_coh_threshold: float = 0.6, + date_fmt: str = DEFAULT_DATE_FMT, ) -> SLCStackReader: """Initialize stack by scanning a folder. @@ -129,6 +172,18 @@ def from_slc_directory( and will likely evolve. This is a totally made up directory structure to demonstrate use of same data structure. The temporal coherence file could just be a mask file as well for good pixels here. + + Parameters + ---------- + folder : str or PathLike + Directory containing the SLC stack. + temp_coh_threshold : float, optional + Minimum temporal coherence (or quality value) to consider a pixel + stable. + date_fmt : str, optional + ``strftime``-compatible format used to extract acquisition dates + from SLC filenames. Default ``"%Y%m%d"``. See + :meth:`from_phase_linked_directory` for details. """ p = Path(folder) @@ -140,10 +195,11 @@ def from_slc_directory( # Then list individual SLCs slclist = sorted(p.glob("*.slc.tif")) + date_len = _date_str_length(date_fmt) slc_files = {} for slc in slclist: - acq_date = slc.name.split("_")[0][:8] + acq_date = _extract_date_str(slc.name.split("_")[0], date_fmt, date_len) if acq_date in slc_files: errmsg = ( f"Error scanning {folder}." 
diff --git a/src/spurt/links/__init__.py b/src/spurt/links/__init__.py index 4a43a70..403e6b0 100644 --- a/src/spurt/links/__init__.py +++ b/src/spurt/links/__init__.py @@ -1,7 +1,9 @@ +from ._design_matrix import build_design_matrix from ._grid_search import GridSearchLinearModel from ._interface import LinkModelInterface __all__ = [ "GridSearchLinearModel", "LinkModelInterface", + "build_design_matrix", ] diff --git a/src/spurt/links/_design_matrix.py b/src/spurt/links/_design_matrix.py new file mode 100644 index 0000000..c3ff2bf --- /dev/null +++ b/src/spurt/links/_design_matrix.py @@ -0,0 +1,93 @@ +"""Build design matrices for link model estimation.""" + +from __future__ import annotations + +import numpy as np + +__all__ = [ + "build_design_matrix", +] + + +def build_design_matrix( + ifg_edges: np.ndarray, + dates: np.ndarray, + bperp_m: np.ndarray, + wavelength_m: float, + slant_range_m: float, + look_angle_rad: float, +) -> np.ndarray: + """Build design matrix for velocity and DEM error estimation. + + The design matrix relates model parameters (velocity, DEM error) to + interferometric phase via the linear model: + + phase = amat @ [velocity, dem_error] + + where: + - Column 0: velocity sensitivity (rad per mm/yr) + - Column 1: DEM error sensitivity (rad per meter) + + Parameters + ---------- + ifg_edges : np.ndarray + Interferogram edges as (nifgs, 2) array of SLC indices. + Each row [i, j] represents an interferogram from SLC i to SLC j. + dates : np.ndarray + SLC acquisition dates as datetime64[D], shape (n_slc,). + bperp_m : np.ndarray + Perpendicular baseline in meters per SLC, shape (n_slc,). + wavelength_m : float + Radar wavelength in meters (e.g., 0.055465 for Sentinel-1 C-band). + slant_range_m : float + Slant range distance in meters (e.g., 900000 for Sentinel-1). + look_angle_rad : float + Look angle in radians (e.g., 0.68 rad = 39 deg for Sentinel-1). + + Returns + ------- + amat : np.ndarray + Design matrix of shape (nifgs, 2). + Column 0: temporal sensitivity for velocity (rad per mm/yr). + Column 1: baseline sensitivity for DEM error (rad per meter). + + Notes + ----- + The phase model is: + + phi = (4 * pi / wavelength) * velocity * delta_t / 1000 / 365.25 + + (4 * pi / wavelength) * (bperp / (slant_range * sin(look))) * dem_error + + where velocity is in mm/yr and dem_error is in meters. + + References + ---------- + .. [1] Ferretti, A., Prati, C. and Rocca, F., 2001. Permanent scatterers + in SAR interferometry. IEEE Transactions on geoscience and remote + sensing, 39(1), pp.8-20. 
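+
+    Examples
+    --------
+    Build the matrix for a single Sentinel-1-like interferogram (the geometry
+    values are the illustrative ones listed under Parameters):
+
+    >>> import numpy as np
+    >>> dates = np.array(["2024-01-01", "2024-01-13"], dtype="datetime64[D]")
+    >>> bperp_m = np.array([0.0, 100.0])
+    >>> ifg_edges = np.array([[0, 1]])
+    >>> amat = build_design_matrix(
+    ...     ifg_edges, dates, bperp_m,
+    ...     wavelength_m=0.055465, slant_range_m=900_000, look_angle_rad=0.68,
+    ... )
+    >>> amat.shape
+    (1, 2)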
+ """ + nifgs = len(ifg_edges) + amat = np.zeros((nifgs, 2), dtype=np.float64) + + # Common factor: 4 * pi / wavelength + phase_factor = 4.0 * np.pi / wavelength_m + + # DEM error factor: bperp / (slant_range * sin(incidence_angle)) + dem_scale = 1.0 / (slant_range_m * np.sin(look_angle_rad)) + + for ii, (ref_idx, sec_idx) in enumerate(ifg_edges): + # Temporal baseline in days + delta_t_days = (dates[sec_idx] - dates[ref_idx]).astype("timedelta64[D]") + delta_t_days = float(delta_t_days.astype(np.float64)) + + # Perpendicular baseline difference in meters + delta_bperp = bperp_m[sec_idx] - bperp_m[ref_idx] + + # Column 0: velocity sensitivity (rad per mm/yr) + # Convert days to years, mm to m: delta_t_days / 365.25 * 0.001 + amat[ii, 0] = phase_factor * delta_t_days / 365.25 * 0.001 + + # Column 1: DEM error sensitivity (rad per meter) + amat[ii, 1] = phase_factor * delta_bperp * dem_scale + + return amat diff --git a/src/spurt/links/_grid_search.py b/src/spurt/links/_grid_search.py index ab2e839..2776df0 100644 --- a/src/spurt/links/_grid_search.py +++ b/src/spurt/links/_grid_search.py @@ -1,13 +1,9 @@ from __future__ import annotations from dataclasses import dataclass -from multiprocessing import get_context import numpy as np -from scipy import optimize -from ..utils import get_cpu_count, logger -from ._common import neg_temporal_coherence from ._interface import LinkModelInterface @@ -41,6 +37,57 @@ class GridSearchLinearModel(Parameters, LinkModelInterface): s.t: ranges[i][0] <= x_i <= ranges[i][1] """ + def __post_init__(self): + super().__post_init__() + self._precompute() + + def _precompute(self) -> None: + """Precompute grid, forward model, and complex exponentials. + + These are constant for a given design matrix and search ranges, + so computing them once avoids redundant work per link. 
+ """ + axes = [np.arange(s.start, s.stop, s.step) for s in self.ranges] + self._grid_shape = tuple(len(a) for a in axes) + grids = np.meshgrid(*axes, indexing="ij") + self._grid_flat = np.column_stack([g.ravel() for g in grids]) + self._grid_steps = np.array([s.step for s in self.ranges]) + + # Clip bounds from actual grid range + self._param_lo = self._grid_flat.min(axis=0) + self._param_hi = self._grid_flat.max(axis=0) + + # Forward model for all grid points: (nobs, ngrid) + self._pred = self.matrix @ self._grid_flat.T + + # Complex exponential of forward model: (nobs, ngrid) + self._E = np.exp(1j * self._pred) + + # Quadratic refinement setup (2D only) + if self.ndim == 2: + self._init_quadratic_refinement() + + def _init_quadratic_refinement(self) -> None: + """Precompute pseudoinverse for 3x3 quadratic surface fit.""" + offsets = np.array( + [ + [-1, -1], + [-1, 0], + [-1, 1], + [0, -1], + [0, 0], + [0, 1], + [1, -1], + [1, 0], + [1, 1], + ] + ) + self._stencil_offsets = offsets + dx = offsets[:, 0].astype(np.float64) + dy = offsets[:, 1].astype(np.float64) + design = np.column_stack([np.ones(9), dx, dy, dx**2, dx * dy, dy**2]) + self._refine_pinv = np.linalg.pinv(design) # (6, 9) + @property def nobs(self) -> int: return self.matrix.shape[0] @@ -49,6 +96,11 @@ def nobs(self) -> int: def ndim(self) -> int: return self.matrix.shape[1] + @property + def ngrid(self) -> int: + """Total number of grid points in the search space.""" + return len(self._grid_flat) + def fwd_model(self, x: np.ndarray) -> np.ndarray: return np.dot(self.matrix, x) @@ -57,21 +109,21 @@ def estimate_model( wrapdata: np.ndarray, weights: np.ndarray | float | None = None, ) -> tuple[np.ndarray, float]: - """Fit model parameters via grid search followed by Nelder-Mead optimization. + """Fit model parameters via grid search + quadratic refinement. Parameters ---------- - wrapdata: np.ndarray - Real-valued array of wrapped phase gradient - weights: np.ndarray | None - Real-valued weights - assumed to be normalized to 1. + wrapdata : np.ndarray + Real-valued array of wrapped phase gradient, shape ``(nobs,)``. + weights : np.ndarray | float | None + Real-valued weights, assumed normalized to 1. Returns ------- - params: np.ndarray - 1D array of length ndim - coh: float - Temporal coherence + params : np.ndarray + 1D array of length ``ndim``. + coh : float + Temporal coherence. """ if wrapdata.ndim != 1: errmsg = f"Input data must be a 1D array. Got {wrapdata.shape}." @@ -84,18 +136,40 @@ def estimate_model( errmsg = f"Weights shape mismatch: {weights.shape} vs {wrapdata.shape}" raise ValueError(errmsg) - return solve(self.matrix, self.ranges, wrapdata, weights) + params, coh = self.estimate_model_many(wrapdata[:, np.newaxis], weights=weights) + return params[:, 0], float(coh[0]) def estimate_model_many( self, wrapdata: np.ndarray, weights: np.ndarray | float | None = None, - worker_count: int | None = None, + worker_count: int | None = None, # noqa: ARG002 ) -> tuple[np.ndarray, np.ndarray]: - """Grid search followed by fmin in parallel.""" - if (worker_count is None) or (worker_count <= 0): - worker_count = max(1, get_cpu_count() - 1) + """Batched grid search + quadratic refinement for many links. + + Evaluates all grid points for all links via matrix multiply, then + refines with quadratic interpolation. Links are processed in chunks + to keep the intermediate ``(ngrid, nlinks)`` coherence matrix within + a reasonable memory budget. 
+ + Parameters + ---------- + wrapdata : np.ndarray + Wrapped phase, shape ``(nobs, nlinks)``. + weights : np.ndarray | float | None + Weights. Scalar for uniform, 1D for per-observation, + 2D for per-observation-per-link. + worker_count : int | None + Accepted for ``LinkModelInterface`` compatibility but not used. + The batched approach uses NumPy's internal BLAS threading. + Returns + ------- + params : np.ndarray + Shape ``(ndim, nlinks)``. + coh : np.ndarray + Shape ``(nlinks,)``. + """ if wrapdata.ndim != 2: errmsg = f"Input data must be a 2D array. Got {wrapdata.shape}." raise ValueError(errmsg) @@ -104,83 +178,212 @@ def estimate_model_many( errmsg = f"Input shape mismatch. Got {wrapdata.shape} vs {self.nobs}" raise ValueError(errmsg) - const_weights: bool = True + wts: np.ndarray | float if weights is None: - weights = 1.0 / self.nobs - elif isinstance(weights, np.ndarray): - if weights.ndim == 2 and (weights.shape != wrapdata.shape): - errmsg = ( - f"Weights shape mismatch. Got {weights.shape} vs {wrapdata.shape}" - ) - raise ValueError(errmsg) - const_weights = False - arr_weights: np.ndarray = weights - - if weights.shape[0] != self.nobs: - errmsg = f"Weights shape mismatch. Got {weights.shape} vs {self.nobs}" - raise ValueError(errmsg) - - # Return arrays - nruns: int = wrapdata.shape[1] - params: np.ndarray = np.zeros((self.ndim, nruns)) - tcoh: np.ndarray = np.zeros(nruns) - - # Run sequentially when only 1 worker available - if worker_count == 1: - for ii in range(nruns): - wts = weights if const_weights else arr_weights[:, ii] - res = self.estimate_model( - wrapdata[:, ii], - wts, - ) - params[:, ii] = res[0] - tcoh[ii] = res[1] + wts = 1.0 / self.nobs else: - logger.info(f"Modeling batch of {nruns} with {worker_count} threads") - - def inv_inputs(idxs): - for ii in idxs: - wts = weights if const_weights else arr_weights[:, ii] - yield ( - ii, - self.matrix, - self.ranges, - wrapdata[:, ii], - wts, + wts = weights + if isinstance(wts, np.ndarray): + if wts.ndim == 2 and wts.shape != wrapdata.shape: + errmsg = ( + f"Weights shape mismatch." + f" Got {wts.shape} vs {wrapdata.shape}" ) + raise ValueError(errmsg) + if wts.ndim <= 1 and wts.shape[0] != self.nobs: + errmsg = f"Weights shape mismatch. Got {wts.shape} vs {self.nobs}" + raise ValueError(errmsg) + + nlinks = wrapdata.shape[1] + + # Weighted conjugate of data: v[k, l] = wts_k * exp(-1j * wdata[k, l]) + d_conj = np.exp(-1j * wrapdata) + if isinstance(wts, np.ndarray) and wts.ndim == 2: + weighted_conj = wts * d_conj + elif isinstance(wts, np.ndarray): + weighted_conj = wts[:, np.newaxis] * d_conj + else: + weighted_conj = wts * d_conj + del d_conj + + # Chunk links so the intermediate (ngrid, chunk) coherence matrix + # stays under ~512 MB. Each link needs ngrid * 24 bytes + # (16 for complex128 matmul result + 8 for float64 abs). 
+ bytes_per_link = self.ngrid * 24 + chunk_size = max(1, int(512e6 / bytes_per_link)) + + params = np.zeros((self.ndim, nlinks), dtype=np.float64) + coh = np.zeros(nlinks, dtype=np.float64) + + for start in range(0, nlinks, chunk_size): + end = min(start + chunk_size, nlinks) + sl = slice(start, end) + + # Grid search for this chunk + coh_grid = self._E.T @ weighted_conj[:, sl] # (ngrid, chunk) + coherence = np.abs(coh_grid) # (ngrid, chunk) + del coh_grid + + best_idx = np.argmax(coherence, axis=0) + chunk_n = end - start + + # Quadratic refinement (2D case) + if self.ndim == 2 and chunk_n > 0: + chunk_wts = ( + wts[:, sl] if isinstance(wts, np.ndarray) and wts.ndim == 2 else wts + ) + params[:, sl], coh[sl] = self._quadratic_refine_batch( + coherence, best_idx, wrapdata[:, sl], chunk_wts + ) + else: + params[:, sl] = self._grid_flat[best_idx].T + coh[sl] = coherence[best_idx, np.arange(chunk_n)] + + return params, coh + + def _quadratic_refine_batch( + self, + coherence: np.ndarray, + best_idx: np.ndarray, + wrapdata: np.ndarray, + wts: np.ndarray | float, + ) -> tuple[np.ndarray, np.ndarray]: + """Refine grid search via 2D quadratic interpolation. + + Fits a quadratic surface to the 3x3 coherence stencil around + each best grid point and solves for the analytic peak. + """ + nlinks = len(best_idx) + + # Convert flat grid indices to 2D + i0, i1 = np.unravel_index(best_idx, self._grid_shape) + + # Build stencil grid indices: (nlinks, 9) + di = self._stencil_offsets[:, 0] # (9,) + dj = self._stencil_offsets[:, 1] # (9,) + si = np.clip(i0[:, None] + di[None, :], 0, self._grid_shape[0] - 1) + sj = np.clip(i1[:, None] + dj[None, :], 0, self._grid_shape[1] - 1) + + # Flat stencil indices and extract coherence + flat_idx = np.ravel_multi_index((si, sj), self._grid_shape) + link_col = np.broadcast_to(np.arange(nlinks)[:, None], flat_idx.shape) + stencil_coh = coherence[flat_idx, link_col] # (nlinks, 9) + + # Fit quadratic: f(dx,dy) = a + b*dx + c*dy + d*dx^2 + e*dx*dy + f*dy^2 + coeffs = stencil_coh @ self._refine_pinv.T # (nlinks, 6) + + b_c = coeffs[:, 1] + c_c = coeffs[:, 2] + d_c = coeffs[:, 3] + e_c = coeffs[:, 4] + f_c = coeffs[:, 5] + + # Solve for quadratic peak: H @ [dx, dy] = -[b, c] + # where H = [[2d, e], [e, 2f]] + det = 4.0 * d_c * f_c - e_c**2 + is_max = (d_c < 0) & (det > 0) & (np.abs(det) > 1e-12) + + dx = np.zeros(nlinks) + dy = np.zeros(nlinks) + m = is_max + dx[m] = -(2.0 * f_c[m] * b_c[m] - e_c[m] * c_c[m]) / det[m] + dy[m] = -(2.0 * d_c[m] * c_c[m] - e_c[m] * b_c[m]) / det[m] + + # Clip refinement to one grid step + dx = np.clip(dx, -1.0, 1.0) + dy = np.clip(dy, -1.0, 1.0) + + # Build refined parameters + params = self._grid_flat[best_idx].T.copy() # (ndim, nlinks) + params[0] += dx * self._grid_steps[0] + params[1] += dy * self._grid_steps[1] - # Create a pool and dispatch - with get_context("fork").Pool(processes=worker_count) as p: - mp_tasks = p.imap_unordered(wrap_solve, inv_inputs(range(nruns))) + # Clip to search bounds + for d in range(self.ndim): + params[d] = np.clip(params[d], self._param_lo[d], self._param_hi[d]) - # Gather results - for res in mp_tasks: # type: ignore[assignment] - params[:, res[0]] = res[1] - tcoh[res[0]] = res[2] # type: ignore[misc] + # Evaluate actual coherence at refined parameters + coh = self._eval_coherence(params, wrapdata, wts) - return params, tcoh + return params, coh + + def _eval_coherence( + self, + params: np.ndarray, + wrapdata: np.ndarray, + wts: np.ndarray | float, + ) -> np.ndarray: + """Evaluate temporal coherence at given 
parameters. + + Parameters + ---------- + params : np.ndarray + Shape ``(ndim, nlinks)``. + wrapdata : np.ndarray + Shape ``(nobs, nlinks)``. + wts : np.ndarray | float + Weights. + + Returns + ------- + np.ndarray + Coherence values, shape ``(nlinks,)``. + """ + residuals = self.matrix @ params - wrapdata # (nobs, nlinks) + weighted_exp = np.exp(1j * residuals) + if isinstance(wts, np.ndarray) and wts.ndim == 2: + weighted_exp *= wts + elif isinstance(wts, np.ndarray): + weighted_exp *= wts[:, np.newaxis] + else: + weighted_exp *= wts + return np.abs(weighted_exp.sum(axis=0)) -def solve( +def _vectorized_grid_search( matrix: np.ndarray, rngs: tuple[slice, ...], wdata: np.ndarray, wts: np.ndarray | float, -) -> tuple[np.ndarray, float]: - """Actual call to the solver.""" - resbrute = optimize.brute( - neg_temporal_coherence, - rngs, - args=(matrix, wdata, wts), - full_output=True, - finish=optimize.fmin, - ) - return (resbrute[0], -resbrute[1]) - - -def wrap_solve( - args: tuple[int, np.ndarray, tuple[slice], np.ndarray, np.ndarray | float], -) -> tuple[int, np.ndarray, float]: - ind, ma, rg, wd, wt = args - out = solve(ma, rg, wd, wt) - return (ind, out[0], out[1]) +) -> np.ndarray: + """Find parameters maximizing temporal coherence over a regular grid. + + Evaluates all grid points in a single batched numpy operation + instead of calling the objective function once per grid point. + + Parameters + ---------- + matrix : np.ndarray + Design matrix of shape ``(nifgs, ndim)``. + rngs : tuple[slice, ...] + One ``slice(start, stop, step)`` per parameter dimension. + wdata : np.ndarray + Wrapped phase data of shape ``(nifgs,)``. + wts : np.ndarray | float + Weights, either scalar or array of shape ``(nifgs,)``. + + Returns + ------- + np.ndarray + Best-fit parameter vector of shape ``(ndim,)``. 
+ """ + # Build 1D coordinate arrays for each parameter dimension + axes = [np.arange(s.start, s.stop, s.step) for s in rngs] + + # Flattened grid of all parameter combinations: (ngrid, ndim) + grids = np.meshgrid(*axes, indexing="ij") + grid_flat = np.column_stack([g.ravel() for g in grids]) + + # Evaluate all grid points at once + # matrix @ grid_flat.T: (nifgs, ndim) @ (ndim, ngrid) -> (nifgs, ngrid) + residuals = matrix @ grid_flat.T - wdata[:, np.newaxis] + + # Temporal coherence for all grid points + weighted_exp = np.exp(1j * residuals) + if isinstance(wts, np.ndarray): + weighted_exp *= wts[:, np.newaxis] + else: + weighted_exp *= wts + coherence = np.abs(weighted_exp.sum(axis=0)) + + return grid_flat[np.argmax(coherence)] diff --git a/src/spurt/mcf/_ortools.py b/src/spurt/mcf/_ortools.py index 8baca6e..0f1c763 100644 --- a/src/spurt/mcf/_ortools.py +++ b/src/spurt/mcf/_ortools.py @@ -9,7 +9,7 @@ from ortools.graph.python import min_cost_flow from ..graph import PlanarGraphInterface, order_points -from ..utils import get_cpu_count +from ..utils import get_cpu_count, logger from ._interface import MCFSolverInterface from .utils import flood_fill, phase_diff, sign_nonzero @@ -253,32 +253,28 @@ def residues_to_flows_many( flows[ii, :] = self.residues_to_flows(res, cost, revcost=revcost) else: - print(f"Processing batch of {nruns} with {worker_count} threads") + logger.info(f"Processing batch of {nruns} with {worker_count} workers") def uw_inputs(idxs): for ii in idxs: - # Only solve if needed if not np.any(residues[ii] != 0): continue - - yield ( - ii, - self._dual_edges, - self._dual_edge_dir, - residues[ii], - cost, - revcost, - ) - - # Create a pool and dispatch - # We explicitly use fork here as osx has switched to using spawn - # and that really slows down the use of multiprocessing - with get_context("fork").Pool(processes=worker_count) as p: + yield (ii, residues[ii]) + + # Use forkserver to avoid inheriting the parent's full memory. + # With fork, each worker gets the parent's entire address space + # (SLC data, gradients, etc.) causing memory explosion on macOS + # where COW page sharing breaks down due to reference counting. + # Constant arrays are shared once per worker via initializer. 
+ with get_context("forkserver").Pool( + processes=worker_count, + initializer=_init_mcf_worker, + initargs=(self._dual_edges, self._dual_edge_dir, cost, revcost), + ) as p: mp_tasks = p.imap_unordered( - wrap_solve_mcf, uw_inputs(range(nruns)), chunksize=chunksize + _worker_solve_mcf, uw_inputs(range(nruns)), chunksize=chunksize ) - # Gather results for res in mp_tasks: flows[res[0], :] = res[1] @@ -352,14 +348,40 @@ def solve_mcf( flows = np.zeros(num_edges, dtype=int) for ii in range(num_edges): # Sign accounts for orientation of edge in cycles - flows[ii] = first_cycle_dir[ii] * (smcf.flow(ii + num_edges) - smcf.flow(ii)) + # Cast to int to avoid int8 overflow when multiplying + flows[ii] = int(first_cycle_dir[ii]) * ( + smcf.flow(ii + num_edges) - smcf.flow(ii) + ) return flows -def wrap_solve_mcf( - args: tuple[int, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray | None], -) -> tuple[int, np.ndarray]: - """Parallel version of solve_mcf.""" - ind, es, ed, rr, cc, rc = args - return (ind, solve_mcf(es, ed, rr, cc, rc)) +_mcf_worker_state: dict = {} + + +def _init_mcf_worker( + dual_edges: np.ndarray, + dual_edge_dir: np.ndarray, + cost: np.ndarray, + revcost: np.ndarray, +) -> None: + """Initialize MCF worker with constant solver data.""" + _mcf_worker_state["dual_edges"] = dual_edges + _mcf_worker_state["dual_edge_dir"] = dual_edge_dir + _mcf_worker_state["cost"] = cost + _mcf_worker_state["revcost"] = revcost + + +def _worker_solve_mcf(args: tuple[int, np.ndarray]) -> tuple[int, np.ndarray]: + """Solve a single MCF problem using shared worker state.""" + ii, residues = args + return ( + ii, + solve_mcf( + _mcf_worker_state["dual_edges"], + _mcf_worker_state["dual_edge_dir"], + residues, + _mcf_worker_state["cost"], + _mcf_worker_state["revcost"], + ), + ) diff --git a/src/spurt/utils/merge.py b/src/spurt/utils/merge.py index 74a8613..2eee214 100644 --- a/src/spurt/utils/merge.py +++ b/src/spurt/utils/merge.py @@ -344,7 +344,7 @@ def dirichlet_graph( corrections[kk, ~mask] = 0 continue - x, info = cg( + x, _info = cg( mat, b, rtol=1e-7, diff --git a/src/spurt/workflows/emcf/__init__.py b/src/spurt/workflows/emcf/__init__.py index 330d4e9..666829a 100644 --- a/src/spurt/workflows/emcf/__init__.py +++ b/src/spurt/workflows/emcf/__init__.py @@ -1,13 +1,20 @@ from ._bulk_offset import get_bulk_offsets from ._merge import merge_tiles from ._overlap import compute_phasediff_deciles -from ._settings import GeneralSettings, MergerSettings, SolverSettings, TilerSettings +from ._settings import ( + GeneralSettings, + LinkModelSettings, + MergerSettings, + SolverSettings, + TilerSettings, +) from ._solver import EMCFSolver as Solver from ._tiling import get_tiles from ._unwrap import unwrap_tiles __all__ = [ "GeneralSettings", + "LinkModelSettings", "MergerSettings", "Solver", "SolverSettings", diff --git a/src/spurt/workflows/emcf/__main__.py b/src/spurt/workflows/emcf/__main__.py index 9fcc018..95c155e 100644 --- a/src/spurt/workflows/emcf/__main__.py +++ b/src/spurt/workflows/emcf/__main__.py @@ -2,4 +2,5 @@ from ._cli import main -sys.exit(main()) +if __name__ == "__main__": + sys.exit(main()) diff --git a/src/spurt/workflows/emcf/_cli.py b/src/spurt/workflows/emcf/_cli.py index 8f4cad3..fa49745 100644 --- a/src/spurt/workflows/emcf/_cli.py +++ b/src/spurt/workflows/emcf/_cli.py @@ -5,8 +5,15 @@ from ._bulk_offset import get_bulk_offsets from ._merge import merge_tiles +from ._output import write_link_params from ._overlap import compute_phasediff_deciles -from ._settings import 
GeneralSettings, MergerSettings, SolverSettings, TilerSettings +from ._settings import ( + GeneralSettings, + LinkModelSettings, + MergerSettings, + SolverSettings, + TilerSettings, +) from ._tiling import get_tiles from ._unwrap import unwrap_tiles @@ -105,6 +112,64 @@ def main(args=None): "--log-file", help="Path to save the log file (in addition to printing to stderr).", ) + parser.add_argument( + "--date-fmt", + default="%Y%m%d", + help=( + "strftime format used to extract acquisition dates from SLC" + " filenames and to write the date portion of unwrapped output" + " filenames. Use a longer format such as '%%Y%%m%%d%%H%%M%%S' to" + " preserve a time-of-day component (e.g. for non-Sentinel cadences" + " with same-day repeats)." + ), + ) + + # Link model / velocity estimation arguments + parser.add_argument( + "--baseline-csv", + type=str, + default=None, + help="Path to CSV with perpendicular baselines. Enables velocity estimation.", + ) + parser.add_argument( + "--no-velocity-estimation", + action="store_true", + help="Disable velocity/DEM error estimation even if baseline CSV is provided.", + ) + parser.add_argument( + "--wavelength", + type=float, + default=0.055465, + help="Radar wavelength in meters.", + ) + parser.add_argument( + "--slant-range", + type=float, + default=900000.0, + help="Slant range distance in meters.", + ) + parser.add_argument( + "--look-angle-deg", + type=float, + default=39.0, + help="Look angle in degrees.", + ) + parser.add_argument( + "--velocity-range", + type=float, + nargs=3, + default=[-100.0, 100.0, 5.0], + metavar=("MIN", "MAX", "STEP"), + help="Velocity search range in mm/yr: min max step.", + ) + parser.add_argument( + "--dem-error-range", + type=float, + nargs=3, + default=[-50.0, 50.0, 2.5], + metavar=("MIN", "MAX", "STEP"), + help="DEM error search range in meters: min max step.", + ) # Parse arguments parsed_args = parser.parse_args(args=args) @@ -118,6 +183,7 @@ def main(args=None): stack = spurt.io.SLCStackReader.from_phase_linked_directory( parsed_args.inputdir, temp_coh_threshold=parsed_args.coh, + date_fmt=parsed_args.date_fmt, ) # Create general settings @@ -148,6 +214,20 @@ def main(args=None): num_parallel_ifgs=parsed_args.merge_parallel_ifgs, ) + # Create link model settings if baseline CSV is provided and not disabled + link_model_settings: LinkModelSettings | None = None + if parsed_args.baseline_csv and not parsed_args.no_velocity_estimation: + link_model_settings = LinkModelSettings( + enabled=True, + wavelength_m=parsed_args.wavelength, + slant_range_m=parsed_args.slant_range, + look_angle_deg=parsed_args.look_angle_deg, + velocity_range=tuple(parsed_args.velocity_range), + dem_error_range=tuple(parsed_args.dem_error_range), + baseline_csv=parsed_args.baseline_csv, + ) + logger.info(f"Link model enabled with baselines: {parsed_args.baseline_csv}") + # Using default Hop3Graph logger.info(f"Using Hop3 Graph in time with {len(stack.slc_files)} epochs.") g_time = spurt.graph.Hop3Graph(len(stack.slc_files)) @@ -157,7 +237,7 @@ def main(args=None): get_tiles(stack, gen_settings, tile_settings) # Unwrap tiles - unwrap_tiles(stack, g_time, gen_settings, slv_settings) + unwrap_tiles(stack, g_time, gen_settings, slv_settings, link_model_settings) # Compute overlap stats compute_phasediff_deciles(gen_settings, mrg_settings) @@ -168,4 +248,9 @@ def main(args=None): # Merge tiles and write output merge_tiles(stack, g_time, gen_settings, mrg_settings) + # Write link model parameters (velocity, DEM error) if available + if link_model_settings 
is not None: + like_slc_file = stack.slc_files[stack.dates[-1]] + write_link_params(gen_settings, stack.raster_shape, like=like_slc_file) + logger.info("Completed EMCF workflow.") diff --git a/src/spurt/workflows/emcf/_output.py b/src/spurt/workflows/emcf/_output.py index ce8bb9f..70c483d 100644 --- a/src/spurt/workflows/emcf/_output.py +++ b/src/spurt/workflows/emcf/_output.py @@ -296,3 +296,210 @@ def write_merged_band( like=like_raster, ) as raster: raster[sidx] = model + + +def _write_raster( + fname: Path, + arr: np.ndarray, + like_raster: Any, +) -> None: + """Write a 2D float32 array to a GeoTIFF.""" + with spurt.io.Raster.create( + str(fname), + width=arr.shape[1], + height=arr.shape[0], + dtype=np.float32, + nodata=np.nan, + driver="GTiff", + tiled=True, + blockxsize=512, + blockysize=512, + compress="DEFLATE", + like=like_raster, + ) as raster: + raster[np.s_[:, :]] = arr + + +def _integrate_tile_link_params( + tile_file: str, +) -> tuple[np.ndarray, np.ndarray, np.ndarray, int]: + """Integrate per-link gradients to per-point values for one tile. + + Parameters + ---------- + tile_file : str + Path to tile HDF5 file. + + Returns + ------- + coords : np.ndarray + Global (row, col) coordinates of shape (npoints, 2). + point_params : np.ndarray + Integrated parameter values, shape (ndim, npoints). + point_coh : np.ndarray + Mean link coherence per point, shape (npoints,). + ndim : int + Number of parameter dimensions. + """ + from scipy.sparse import csr_matrix + from scipy.sparse.linalg import lsqr + + with h5py.File(tile_file, "r") as fid: + link_params = fid["link_params"][...] # (ndim, nlinks) + link_coherence = fid["link_coherence"][...] # (nlinks,) + points = fid["points"][...] # (npoints, 2) + tile_offset = fid["tile"][...] # (4,) -> [row_start, col_start, ...] + + coords = points + tile_offset[None, :2] + + # Build Delaunay graph for this tile to get edges + g_space = spurt.graph.DelaunayGraph(points) + edges = g_space.links + nlinks = len(edges) + npoints = g_space.npoints + ndim = link_params.shape[0] + + # Build weighted incidence matrix for integration. + # Use sqrt(coherence) as weights for proper WLS: + # minimize sum_i w_i * (A_i @ x - b_i)^2 + # Transforming to standard LS: sqrt(W) * A @ x = sqrt(W) * b + w = np.sqrt(np.clip(link_coherence, 0, 1)).astype(np.float64) + data = np.empty(2 * nlinks, dtype=np.float64) + data[0::2] = -w + data[1::2] = w + row_indices = np.repeat(np.arange(nlinks), 2) + col_indices = edges.flatten() + w_incidence = csr_matrix( + (data, (row_indices, col_indices)), shape=(nlinks, npoints) + ) + + # Integrate each parameter dimension via WLS + point_params = np.zeros((ndim, npoints), dtype=np.float32) + for dd in range(ndim): + rhs = link_params[dd, :] * w + result = lsqr(w_incidence[:, 1:], rhs) + point_params[dd, 1:] = result[0].astype(np.float32) + # Remove median: integration is relative to an arbitrary + # reference point, so center the result. 
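+        # (The median, rather than the mean, keeps the shift robust to
+        # outlier points.)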
+ point_params[dd] -= np.median(point_params[dd]) + + # Compute mean coherence per point + coh_sum = np.zeros(npoints, dtype=np.float64) + coh_cnt = np.zeros(npoints, dtype=np.int32) + np.add.at(coh_sum, edges[:, 0], link_coherence) + np.add.at(coh_sum, edges[:, 1], link_coherence) + np.add.at(coh_cnt, edges[:, 0], 1) + np.add.at(coh_cnt, edges[:, 1], 1) + point_coh = np.where(coh_cnt > 0, coh_sum / coh_cnt, 0).astype(np.float32) + + return coords, point_params, point_coh, ndim + + +def write_link_params( + gen_settings: Any, + shape: tuple[int, int], + param_names: list[str] | None = None, + like: str | os.PathLike[str] | None = None, +) -> list[Path]: + """Integrate per-link model parameters to per-point rasters. + + Reads link_params (spatial gradients) and link_coherence from tile HDF5 + files, integrates per-link gradients to per-point values via weighted + least-squares (WLS), and writes GeoTIFFs for each parameter plus the + model coherence. + + Parameters + ---------- + gen_settings : GeneralSettings + General settings with tile filenames and output folder. + shape : tuple[int, int] + Output raster shape (rows, cols). + param_names : list[str] | None + Names for each parameter dimension. + Default: ["velocity_mm_yr", "dem_error_m"]. + like : str | os.PathLike | None + Reference raster for georeferencing. + + Returns + ------- + list[Path] + Paths to written GeoTIFFs. + """ + if param_names is None: + param_names = ["velocity_mm_yr", "dem_error_m"] + + tile_json = gen_settings.tiles_jsonname + tiledata = spurt.utils.TileSet.from_json(tile_json) + + output_dir = Path(gen_settings.output_folder) + like_raster = None if like is None else spurt.io.Raster(like) + written: list[Path] = [] + + # Check which output files already exist so we can skip work + # We need to peek at ndim from the first tile to build the full file list + first_tile = str(gen_settings.tile_filename(0)) + with h5py.File(first_tile, "r") as fid: + if "link_params" not in fid: + logger.info("No link_params in tile files. Skipping link param output.") + return written + ndim_check = fid["link_params"].shape[0] + + expected_names = [ + param_names[dd] if dd < len(param_names) else f"param_{dd}" + for dd in range(ndim_check) + ] + expected_files = [output_dir / f"{name}.tif" for name in expected_names] + expected_files.append(output_dir / "link_model_coherence.tif") + + if all(f.is_file() for f in expected_files): + for f in expected_files: + logger.info(f"{f!s} already exists. Skipping writing ...") + written.append(f) + return written + + # Accumulate all tiles into full-size rasters before writing. + # Each tile is integrated independently, then placed into global arrays. 
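+    # This holds (ndim + 1) full-resolution float32 rasters in memory at once.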
+ param_arrays: list[np.ndarray] | None = None + coh_array: np.ndarray | None = None + ndim = 0 + + for tt in range(tiledata.ntiles): + tile_file = str(gen_settings.tile_filename(tt)) + coords, point_params, point_coh, ndim = _integrate_tile_link_params(tile_file) + + # Initialize output arrays on first tile + if param_arrays is None: + param_arrays = [ + np.full(shape, np.nan, dtype=np.float32) for _ in range(ndim) + ] + coh_array = np.full(shape, np.nan, dtype=np.float32) + + # Place tile results into global arrays (overlapping regions + # get overwritten — last tile wins, same as unwrapped phase) + assert param_arrays is not None + assert coh_array is not None + r, c = coords[:, 0], coords[:, 1] + for dd in range(ndim): + param_arrays[dd][r, c] = point_params[dd] + coh_array[r, c] = point_coh + + if param_arrays is None: + return written + + assert coh_array is not None + + # Write accumulated parameter rasters + for dd in range(ndim): + name = param_names[dd] if dd < len(param_names) else f"param_{dd}" + fname = output_dir / f"{name}.tif" + logger.info(f"Writing {name} to {fname}") + _write_raster(fname, param_arrays[dd], like_raster) + written.append(fname) + + # Write model coherence + coh_fname = output_dir / "link_model_coherence.tif" + logger.info(f"Writing link model coherence to {coh_fname}") + _write_raster(coh_fname, coh_array, like_raster) + written.append(coh_fname) + + return written diff --git a/src/spurt/workflows/emcf/_settings.py b/src/spurt/workflows/emcf/_settings.py index 45b4754..13e6d81 100644 --- a/src/spurt/workflows/emcf/_settings.py +++ b/src/spurt/workflows/emcf/_settings.py @@ -182,6 +182,61 @@ def __post_init__(self): raise ValueError(errmsg) +@dataclass +class LinkModelSettings: + """Settings for per-link velocity and DEM error estimation. + + Parameters + ---------- + enabled : bool + Whether link model estimation is enabled. Default True when + baseline_csv is provided via CLI. + wavelength_m : float + Radar wavelength in meters. Default 0.055465 (Sentinel-1 C-band). + slant_range_m : float + Slant range distance in meters. Default 900000.0. + look_angle_deg : float + Look angle in degrees. Default 39.0. + velocity_range : tuple[float, float, float] + Grid search range for velocity in mm/yr as (min, max, step). + Default (-100.0, 100.0, 5.0). + dem_error_range : tuple[float, float, float] + Grid search range for DEM error in meters as (min, max, step). + Default (-50.0, 50.0, 2.5). + baseline_csv : str | None + Path to CSV file with perpendicular baselines. Required when enabled. 
+ """ + + enabled: bool = True + wavelength_m: float = 0.055465 # Sentinel-1 C-band + slant_range_m: float = 900000.0 + look_angle_deg: float = 39.0 + velocity_range: tuple[float, float, float] = (-100.0, 100.0, 5.0) + dem_error_range: tuple[float, float, float] = (-50.0, 50.0, 2.5) + baseline_csv: str | None = None + + @property + def look_angle_rad(self) -> float: + """Return look angle in radians.""" + import math + + return math.radians(self.look_angle_deg) + + def __post_init__(self): + if self.enabled and self.baseline_csv is None: + errmsg = "baseline_csv is required when link model is enabled" + raise ValueError(errmsg) + if self.wavelength_m <= 0: + errmsg = f"wavelength_m must be > 0, got {self.wavelength_m}" + raise ValueError(errmsg) + if self.slant_range_m <= 0: + errmsg = f"slant_range_m must be > 0, got {self.slant_range_m}" + raise ValueError(errmsg) + if not (0 < self.look_angle_deg < 90): + errmsg = f"look_angle_deg must be in (0, 90), got {self.look_angle_deg}" + raise ValueError(errmsg) + + @dataclass class MergerSettings: """Class for holding tile merging settings. diff --git a/src/spurt/workflows/emcf/_solver.py b/src/spurt/workflows/emcf/_solver.py index bb39585..215d7c4 100644 --- a/src/spurt/workflows/emcf/_solver.py +++ b/src/spurt/workflows/emcf/_solver.py @@ -54,9 +54,10 @@ def __init__( self._settings = settings self._link_model = link_model - if link_model is not None: - errmsg = "Not implemented yet." - raise NotImplementedError(errmsg) + # Estimated link parameters (velocity, DEM error, etc.) + # Populated during unwrap_gradients_in_time when link_model is provided + self.link_params: np.ndarray | None = None + self.link_coherence: np.ndarray | None = None @property def npoints(self) -> int: @@ -88,6 +89,53 @@ def link_model(self) -> LinkModelInterface | None: """Retrieve the link model for the workflow.""" return self._link_model + def integrate_link_params(self, param_idx: int = 0) -> np.ndarray: + """Integrate link parameters to get point values via least-squares. + + Converts per-link parameters (e.g., velocity gradients) to per-point + values by solving the least-squares problem: find v such that + v[j] - v[i] ~= grad[edge] for all edges. The first point is used + as reference (value = 0). + + Parameters + ---------- + param_idx: int + Index of the parameter to integrate (default 0, typically velocity). + + Returns + ------- + point_values: np.ndarray + 1D array of shape (npoints,) with integrated parameter values. + """ + if self.link_params is None: + errmsg = "No link parameters available. Run unwrap with link_model first." 
+ raise RuntimeError(errmsg) + + from scipy.sparse import csr_matrix + from scipy.sparse.linalg import lsqr + + link_gradients = self.link_params[param_idx, :] + edges = self._solver_space.edges + + # Build incidence matrix: A[edge, :] has -1 at source, +1 at dest + nlinks = len(edges) + data = np.ones(2 * nlinks, dtype=np.float64) + data[0::2] = -1.0 # -1 for source node + data[1::2] = 1.0 # +1 for dest node + row_indices = np.repeat(np.arange(nlinks), 2) + col_indices = edges.flatten() + incidence = csr_matrix( + (data, (row_indices, col_indices)), shape=(nlinks, self.npoints) + ) + + # Solve least-squares: incidence @ point_values = link_gradients + # Add constraint: point_values[0] = 0 by dropping first column + result = lsqr(incidence[:, 1:], link_gradients.astype(np.float64)) + point_values = np.zeros(self.npoints, dtype=np.float64) + point_values[1:] = result[0] + + return point_values + def unwrap_cube(self, wrap_data: Irreg3DInput) -> np.ndarray: """Unwrap a 3D cube of data. @@ -126,7 +174,10 @@ def unwrap_cube(self, wrap_data: Irreg3DInput) -> np.ndarray: wrap_data.data, input_is_ifg=input_is_ifg ) - # Then unwrap spatial gradients + # Then unwrap spatial gradients. + # Note: phase_diff(z0, z1, model=m) returns the FULL gradient + # (z1-z0) wrapped around the model — not the residual. The model + # guides wrapping disambiguation but does not need to be restored. return self.unwrap_gradients_in_space(grad_space) def unwrap_gradients_in_time( @@ -160,6 +211,13 @@ def unwrap_gradients_in_time( # Create output array grad_space: np.ndarray = np.zeros((self.nifgs, self.nlinks), dtype=np.float32) + # Initialize storage for estimated parameters if link_model is provided + if self._link_model is not None: + self.link_params = np.zeros( + (self._link_model.ndim, self.nlinks), dtype=np.float32 + ) + self.link_coherence = np.zeros(self.nlinks, dtype=np.float32) + logger.info(f"Temporal: Number of interferograms: {self.nifgs}") logger.info(f"Temporal: Number of links: {self.nlinks}") logger.info(f"Temporal: Number of cycles: {self._solver_time.ncycles}") @@ -178,9 +236,6 @@ def unwrap_gradients_in_time( # Get indices of points forming links from spatial graph inds = self._solver_space.edges[i_start:i_end, :] - # TODO: Incorporate link_model here when ready - # Add self._modeled_phase_diff to replace phase_diff - # Compute spatial gradients for each link # If input data is already interferograms if input_is_ifg: @@ -193,6 +248,45 @@ def unwrap_gradients_in_time( wrap_data, inds, grad_space, np.s_[i_start:i_end] ) + # If link_model is provided, estimate parameters and flatten gradients + if self._link_model is not None: + assert self.link_params is not None + assert self.link_coherence is not None + + logger.info(f"Temporal: Estimating model for batch {bb + 1}/{nbatches}") + + # Estimate model parameters for this batch of links + batch_params, batch_coh = self._link_model.estimate_model_many( + grad_space[:, i_start:i_end], + worker_count=self.settings.t_worker_count, + ) + + # Store estimated parameters and coherence + self.link_params[:, i_start:i_end] = batch_params + self.link_coherence[i_start:i_end] = batch_coh + + # Compute model prediction for each interferogram and link + # fwd_model: (nifgs, ndim) @ (ndim, nlinks) -> (nifgs, nlinks) + model_pred = self._link_model.fwd_model(batch_params) + assert model_pred.shape == (self.nifgs, links_in_batch) + + # Recompute gradients using model to guide wrapping + # phase_diff with model returns gradient wrapped around model_pred + if 
input_is_ifg: + grad_space[:, i_start:i_end] = utils.phase_diff( + wrap_data[:, inds[:, 0]], + wrap_data[:, inds[:, 1]], + model=model_pred, + ) + else: + self._ifg_spatial_gradients_from_slc( + wrap_data, + inds, + grad_space, + np.s_[i_start:i_end], + model=model_pred, + ) + # Compute residues for each cycle in temporal graph # Easier to loop over interferograms here ncycles: int = len(self._solver_time.cycles) @@ -258,18 +352,17 @@ def unwrap_gradients_in_space(self, grad_space: np.ndarray) -> np.ndarray: if nworkers < 1: nworkers = get_cpu_count() - 1 - mp_context = mp.get_context("fork") + # Use forkserver to avoid inheriting the parent's full memory. + # Constant data (solver, cost) is shared once per worker via initializer. + mp_context = mp.get_context("forkserver") with ProcessPoolExecutor( - max_workers=nworkers, mp_context=mp_context + max_workers=nworkers, + mp_context=mp_context, + initializer=_init_spatial_worker, + initargs=(self._solver_space, cost), ) as executor: futures = { - executor.submit( - _unwrap_ifg_in_space, - grad_space[ii, :], - self._solver_space, - cost, - ii, - ): ii + executor.submit(_unwrap_ifg_in_space, ii, grad_space[ii, :]): ii for ii in range(self.nifgs) } for fut in as_completed(futures): @@ -284,6 +377,7 @@ def _ifg_spatial_gradients_from_slc( edges: np.ndarray, grad_space: np.ndarray, link_slice: slice, + model: float | np.ndarray = 0.0, ) -> None: """Compute interferometric spatial gradients from slc data. @@ -299,6 +393,10 @@ def _ifg_spatial_gradients_from_slc( This array gets updated in place. link_slice: slice Slice corresponding to edges within the array of all links. + model: float | np.ndarray + Model prediction for spatial gradients. When provided, wraps + the gradient around the model to guide disambiguation. + Shape (nifg, nlinks_in_batch) or scalar 0.0 (default). """ # Interferogram edges ifg_inds = self._solver_time.edges @@ -315,18 +413,37 @@ def _ifg_spatial_gradients_from_slc( slc_data1[ifg_inds[:, 0], :], slc_data1[ifg_inds[:, 1], :] ) - # Update gradient in place - grad_space[:, link_slice] = utils.phase_diff(ifg_data0, ifg_data1) + # Update gradient in place, using model to guide wrapping + grad_space[:, link_slice] = utils.phase_diff(ifg_data0, ifg_data1, model=model) -def _unwrap_ifg_in_space(ifg_grad, solver_space, cost, ii): +_spatial_worker_state: dict = {} + + +def _init_spatial_worker(solver_space, cost): + """Initialize spatial unwrapping worker with solver and cost data.""" + _spatial_worker_state["solver"] = solver_space + _spatial_worker_state["cost"] = cost + + +def _unwrap_ifg_in_space(ii, ifg_grad): + solver_space = _spatial_worker_state["solver"] + cost = _spatial_worker_state["cost"] + # Compute residues residues = solver_space.compute_residues_from_gradients(ifg_grad) # Unwrap the interferogram - sequential flows = solver_space.residues_to_flows(residues, cost) - # Flood fill - out = utils.flood_fill(ifg_grad, solver_space.edges, flows, mode="gradients") + # Flood fill - tolerate closure errors by filling with NaN + try: + out = utils.flood_fill(ifg_grad, solver_space.edges, flows, mode="gradients") + except ValueError as e: + if "closure errors" in str(e): + logger.warning(f"Spatial unwrapping {ii + 1}: {e}. 
Filling with NaN.") + out = np.full(solver_space.npoints, np.nan, dtype=np.float32) + else: + raise logger.info(f"Completed spatial unwrapping {ii + 1}") return ii, out diff --git a/src/spurt/workflows/emcf/_unwrap.py b/src/spurt/workflows/emcf/_unwrap.py index f7651eb..c8de615 100644 --- a/src/spurt/workflows/emcf/_unwrap.py +++ b/src/spurt/workflows/emcf/_unwrap.py @@ -7,7 +7,7 @@ import spurt -from ._settings import GeneralSettings, SolverSettings +from ._settings import GeneralSettings, LinkModelSettings, SolverSettings from ._solver import EMCFSolver logger = spurt.utils.logger @@ -20,13 +20,14 @@ def unwrap_tiles( g_time: spurt.graph.PlanarGraphInterface, gen_settings: GeneralSettings, solv_settings: SolverSettings, + link_model_settings: LinkModelSettings | None = None, ) -> None: """Unwrap each tile and save to h5.""" # Load tile set tile_json = gen_settings.tiles_jsonname tiledata = spurt.utils.TileSet.from_json(tile_json) - mp_context = mp.get_context("fork") + mp_context = mp.get_context("forkserver") with ProcessPoolExecutor( max_workers=solv_settings.num_parallel_tiles, mp_context=mp_context ) as executor: @@ -48,6 +49,7 @@ def unwrap_tiles( g_time, solv_settings, tt, + link_model_settings, ) ] = tt @@ -63,6 +65,7 @@ def _unwrap_one_tile( g_time: spurt.graph.PlanarGraphInterface, solv_settings: SolverSettings, tile_num: int, + link_model_settings: LinkModelSettings | None = None, ) -> None: """Unwrap tile-by-tile.""" # Get tile information @@ -83,11 +86,16 @@ def _unwrap_one_tile( ) s_space = spurt.mcf.ORMCFSolver(g_space) # type: ignore[abstract] + # Build link model if settings provided + link_model = None + if link_model_settings is not None and link_model_settings.enabled: + link_model = _build_link_model(g_time, stack.dates, link_model_settings) + # EMCF solver - solver = EMCFSolver(s_space, s_time, solv_settings) + solver = EMCFSolver(s_space, s_time, solv_settings, link_model) wrap_data = stack.read_tile(tile.space) assert wrap_data.shape[1] == g_space.npoints - logger.info(f"Time steps: {solver.nifgs}") + logger.info(f"Interferograms: {solver.nifgs}") logger.info(f"Number of points: {solver.npoints}") uw_data = solver.unwrap_cube(wrap_data) @@ -103,19 +111,109 @@ def _unwrap_one_tile( wrap_data.data[ifgs[:, 0], 0], wrap_data.data[ifgs[:, 1], 0] ) - _dump_tile_to_h5(tile_output, uw_data, phase_offset, g_space, tile) + _dump_tile_to_h5( + tile_output, + uw_data, + phase_offset, + g_space, + tile, + solver.link_params, + solver.link_coherence, + ) logger.info(f"Wrote tile {tt + 1} to {tile_output}") +def _build_link_model( + g_time: spurt.graph.PlanarGraphInterface, + dates: list[str], + settings: LinkModelSettings, +) -> spurt.links.GridSearchLinearModel: + """Build link model from settings and baseline data.""" + from spurt.io import load_baseline_csv + from spurt.io._baseline import _parse_date + from spurt.links import GridSearchLinearModel, build_design_matrix + + assert settings.baseline_csv is not None + baseline_data = load_baseline_csv(settings.baseline_csv) + + # Convert stack dates to datetime64 (dates may be YYYYMMDD strings, + # which numpy misparses; normalize to ISO YYYY-MM-DD first) + stack_dates = np.array([_parse_date(d) for d in dates], dtype="datetime64[D]") + + # Build design matrix + amat = build_design_matrix( + ifg_edges=g_time.links, + dates=stack_dates, + bperp_m=_interpolate_baselines(stack_dates, baseline_data), + wavelength_m=settings.wavelength_m, + slant_range_m=settings.slant_range_m, + look_angle_rad=settings.look_angle_rad, + ) + + # 
Create grid search model + return GridSearchLinearModel( + matrix=amat, + ranges=(slice(*settings.velocity_range), slice(*settings.dem_error_range)), + ) + + +def _interpolate_baselines( + stack_dates: np.ndarray, + baseline_data: spurt.io.BaselineData, +) -> np.ndarray: + """Interpolate baselines to match stack dates. + + If stack dates match baseline dates exactly, returns baselines directly. + Otherwise, linearly interpolates baselines for missing dates. + + Parameters + ---------- + stack_dates : np.ndarray + SLC dates from the stack as datetime64[D]. + baseline_data : spurt.io.BaselineData + Baseline data loaded from CSV. + + Returns + ------- + np.ndarray + Perpendicular baselines matched to stack dates. + """ + # Check for exact match + if len(stack_dates) == len(baseline_data.dates) and np.all( + stack_dates == baseline_data.dates + ): + return baseline_data.bperp_m + + # Perpendicular baselines depend on orbital geometry, not time, so + # linear interpolation is only a rough approximation. + n_missing = np.sum(~np.isin(stack_dates, baseline_data.dates)) + logger.warning( + f"Baseline dates do not match stack dates ({n_missing} dates missing" + f" from baseline CSV). Linearly interpolating baselines; this is only" + f" approximate since Bperp depends on orbital geometry, not time." + ) + + stack_days = stack_dates.astype("datetime64[D]").astype(np.float64) + baseline_days = baseline_data.dates.astype("datetime64[D]").astype(np.float64) + return np.interp(stack_days, baseline_days, baseline_data.bperp_m) + + def _dump_tile_to_h5( fname: str, uw: np.ndarray, off: np.ndarray, gspace: spurt.graph.PlanarGraphInterface, tile: spurt.utils.BBox, + link_params: np.ndarray | None = None, + link_coherence: np.ndarray | None = None, ) -> None: with h5py.File(fname, "w") as fid: fid["uw_data"] = uw fid["points"] = gspace.points.astype(np.int32) fid["tile"] = np.array(tile.tolist()).astype(np.int32) fid["phase_offset"] = off.astype(np.float32) + + if link_params is not None: + fid["link_params"] = link_params.astype(np.float32) + if link_coherence is not None: + fid["link_coherence"] = link_coherence.astype(np.float32) diff --git a/test/io/test_baseline.py b/test/io/test_baseline.py new file mode 100644 index 0000000..1a68314 --- /dev/null +++ b/test/io/test_baseline.py @@ -0,0 +1,459 @@ +"""Tests for baseline CSV loading functionality.""" + +import tempfile +from pathlib import Path + +import numpy as np +import pytest + +from spurt.io import BaselineData, load_baseline_csv + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _write_tmp_csv(content: str) -> Path: + """Write content to a temporary CSV and return the path.""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".csv", delete=False) as f: + f.write(content) + return Path(f.name) + + +# --------------------------------------------------------------------------- +# BaselineData validation +# --------------------------------------------------------------------------- + + +def test_baseline_data_shape_mismatch(): + """Test that BaselineData validates shape consistency.""" + dates = np.array(["2020-01-01", "2020-01-13"], dtype="datetime64[D]") + bperp = np.array([0.0, 100.0, 200.0]) + + with pytest.raises(ValueError, match="Shape mismatch"): + BaselineData(dates=dates, bperp_m=bperp) + + +def test_baseline_data_not_1d(): + """Test that BaselineData rejects non-1D arrays.""" + dates = np.array([["2020-01-01"], 
["2020-01-13"]], dtype="datetime64[D]") + bperp = np.array([[0.0], [100.0]]) + + with pytest.raises(ValueError, match="Expected 1D arrays"): + BaselineData(dates=dates, bperp_m=bperp) + + +# --------------------------------------------------------------------------- +# Invalid / unrecognized format +# --------------------------------------------------------------------------- + + +def test_invalid_csv_raises(): + """Test that invalid CSV format raises ValueError.""" + csv_path = _write_tmp_csv("foo,bar,baz\n1,2,3\n4,5,6\n") + try: + with pytest.raises(ValueError, match="Unrecognized CSV format"): + load_baseline_csv(csv_path) + finally: + csv_path.unlink() + + +# --------------------------------------------------------------------------- +# Per-SLC format +# --------------------------------------------------------------------------- + + +def test_load_per_slc_csv(): + """Test loading per-SLC format CSV with YYYYMMDD dates.""" + csv_content = ( + "date,bperp_m\n20200101,0.0\n20200113,150.5\n20200125,-200.3\n20200206,50.0\n" + ) + csv_path = _write_tmp_csv(csv_content) + try: + bd = load_baseline_csv(csv_path) + + assert isinstance(bd, BaselineData) + assert len(bd.dates) == 4 + + expected_dates = np.array( + ["2020-01-01", "2020-01-13", "2020-01-25", "2020-02-06"], + dtype="datetime64[D]", + ) + np.testing.assert_array_equal(bd.dates, expected_dates) + np.testing.assert_array_almost_equal(bd.bperp_m, [0.0, 150.5, -200.3, 50.0]) + finally: + csv_path.unlink() + + +def test_load_per_slc_csv_unsorted(): + """Test that per-SLC CSV with unsorted dates gets sorted.""" + csv_content = ( + "date,bperp_m\n20200125,-200.3\n20200101,0.0\n20200206,50.0\n20200113,150.5\n" + ) + csv_path = _write_tmp_csv(csv_content) + try: + bd = load_baseline_csv(csv_path) + + expected_dates = np.array( + ["2020-01-01", "2020-01-13", "2020-01-25", "2020-02-06"], + dtype="datetime64[D]", + ) + np.testing.assert_array_equal(bd.dates, expected_dates) + np.testing.assert_array_almost_equal(bd.bperp_m, [0.0, 150.5, -200.3, 50.0]) + finally: + csv_path.unlink() + + +# --------------------------------------------------------------------------- +# Per-IFG format - simple YYYYMMDD dates +# --------------------------------------------------------------------------- + + +def test_load_per_ifg_csv(): + """Test loading per-IFG format CSV and conversion to per-SLC.""" + csv_content = ( + "reference,secondary,bperp_m\n" + "20200101,20200113,100.0\n" + "20200113,20200125,100.0\n" + "20200101,20200125,200.0\n" + "20200125,20200206,100.0\n" + ) + csv_path = _write_tmp_csv(csv_content) + try: + bd = load_baseline_csv(csv_path) + + assert len(bd.dates) == 4 + assert bd.bperp_m[0] == 0.0 + np.testing.assert_allclose( + np.diff(bd.bperp_m), [100.0, 100.0, 100.0], atol=1e-6 + ) + finally: + csv_path.unlink() + + +# --------------------------------------------------------------------------- +# Per-IFG format +# --------------------------------------------------------------------------- + +_CAPELLA_CSV = """\ +reference,secondary,reference_time_utc,secondary_time_utc,btemp_days,bperp_m,bpar_m,btotal_m,bradial_m,bnormal_m,look_angle_diff_rad +slcs/CAPELLA_C13_SP_SLC_HH_20260131043025_20260131043034.tif,slcs/CAPELLA_C13_SP_SLC_HH_20260206022355_20260206022404.tif,2026-01-31T04:30:30.757865503Z,2026-02-06T02:23:59.736190396Z,5.912,79.512,65.021,106.861,18.250,105.291,0.00012 
+slcs/CAPELLA_C13_SP_SLC_HH_20260131043025_20260131043034.tif,slcs/CAPELLA_C13_SP_SLC_HH_20260209012040_20260209012049.tif,2026-01-31T04:30:30.757865503Z,2026-02-09T01:20:44.804243296Z,8.868,369.187,285.877,469.717,70.814,464.348,0.00053 +slcs/CAPELLA_C13_SP_SLC_HH_20260206022355_20260206022404.tif,slcs/CAPELLA_C13_SP_SLC_HH_20260209012040_20260209012049.tif,2026-02-06T02:23:59.736190396Z,2026-02-09T01:20:44.804243296Z,2.956,289.675,220.856,362.856,52.564,359.057,0.00041 +""" + + +def test_load_per_ifg_capella_filenames(): + """Test per-IFG loading with Capella-style filenames and UTC time columns.""" + csv_path = _write_tmp_csv(_CAPELLA_CSV) + try: + bd = load_baseline_csv(csv_path) + + assert len(bd.dates) == 3 + expected_dates = np.array( + ["2026-01-31", "2026-02-06", "2026-02-09"], + dtype="datetime64[D]", + ) + np.testing.assert_array_equal(bd.dates, expected_dates) + + # Reference SLC should be zero + assert bd.bperp_m[0] == 0.0 + + # IFG baselines should be self-consistent (redundant network) + np.testing.assert_allclose(bd.bperp_m[1] - bd.bperp_m[0], 79.512, atol=0.1) + np.testing.assert_allclose(bd.bperp_m[2] - bd.bperp_m[0], 369.187, atol=0.1) + np.testing.assert_allclose(bd.bperp_m[2] - bd.bperp_m[1], 289.675, atol=0.1) + finally: + csv_path.unlink() + + +def test_load_per_ifg_capella_without_time_columns(): + """Test that Capella filenames work even without *_time_utc columns. + + Dates should be extracted from the filename's first YYYYMMDD substring. + """ + csv_content = ( + "reference,secondary,bperp_m\n" + "slcs/CAPELLA_C13_SP_SLC_HH_20260131043025_20260131043034.tif," + "slcs/CAPELLA_C13_SP_SLC_HH_20260206022355_20260206022404.tif,79.512\n" + "slcs/CAPELLA_C13_SP_SLC_HH_20260131043025_20260131043034.tif," + "slcs/CAPELLA_C13_SP_SLC_HH_20260209012040_20260209012049.tif,369.187\n" + ) + csv_path = _write_tmp_csv(csv_content) + try: + bd = load_baseline_csv(csv_path) + + assert len(bd.dates) == 3 + expected_dates = np.array( + ["2026-01-31", "2026-02-06", "2026-02-09"], + dtype="datetime64[D]", + ) + np.testing.assert_array_equal(bd.dates, expected_dates) + assert bd.bperp_m[0] == 0.0 + finally: + csv_path.unlink() + + +# --------------------------------------------------------------------------- +# Single-reference IFG format (per-SLC baselines in IFG CSV) +# --------------------------------------------------------------------------- + +_SINGLE_REF_CSV = """\ +reference,secondary,bperp_m +20230105,20230117,95.3 +20230105,20230129,210.7 +20230105,20230210,305.1 +""" + + +def test_load_single_reference_ifg(): + """Test single-reference IFG CSV uses direct assignment (no lstsq).""" + csv_path = _write_tmp_csv(_SINGLE_REF_CSV) + try: + bd = load_baseline_csv(csv_path) + + assert len(bd.dates) == 4 + assert bd.bperp_m[0] == 0.0 + # Values should be exact — no least-squares involved + np.testing.assert_array_equal(bd.bperp_m, [0.0, 95.3, 210.7, 305.1]) + finally: + csv_path.unlink() + + +_SINGLE_REF_CAPELLA_CSV = """\ +reference,secondary,reference_time_utc,secondary_time_utc,btemp_days,bperp_m,bpar_m,btotal_m,bradial_m,bnormal_m,look_angle_diff_rad +slcs/CAPELLA_C13_SP_SLC_HH_20260131043025_20260131043034.tif,slcs/CAPELLA_C13_SP_SLC_HH_20260206022355_20260206022404.tif,2026-01-31T04:30:30.757865503Z,2026-02-06T02:23:59.736190396Z,5.912,79.512,65.021,106.861,18.250,105.291,0.00012 
+slcs/CAPELLA_C13_SP_SLC_HH_20260131043025_20260131043034.tif,slcs/CAPELLA_C13_SP_SLC_HH_20260209012040_20260209012049.tif,2026-01-31T04:30:30.757865503Z,2026-02-09T01:20:44.804243296Z,8.868,369.187,285.877,469.717,70.814,464.348,0.00053 +slcs/CAPELLA_C13_SP_SLC_HH_20260131043025_20260131043034.tif,slcs/CAPELLA_C13_SP_SLC_HH_20260212001725_20260212001734.tif,2026-01-31T04:30:30.757865503Z,2026-02-12T00:17:29.220234759Z,11.824,441.233,335.754,557.799,79.465,552.110,0.00063 +slcs/CAPELLA_C13_SP_SLC_HH_20260131043025_20260131043034.tif,slcs/CAPELLA_C13_SP_SLC_HH_20260214231409_20260214231418.tif,2026-01-31T04:30:30.757865503Z,2026-02-14T23:14:13.365825911Z,14.780,443.027,325.736,554.367,69.824,549.952,0.00064 +slcs/CAPELLA_C13_SP_SLC_HH_20260131043025_20260131043034.tif,slcs/CAPELLA_C13_SP_SLC_HH_20260217221050_20260217221059.tif,2026-01-31T04:30:30.757865503Z,2026-02-17T22:10:55.094275436Z,17.736,4.588,16.477,38.075,12.147,36.086,4.86e-05 +""" + + +def test_load_single_reference_capella_full(): + """Test the actual Capella single-reference CSV format end-to-end. + + This mirrors the real output from the Capella baseline processor: + all rows share one reference, with UTC time columns and extra baseline cols. + """ + csv_path = _write_tmp_csv(_SINGLE_REF_CAPELLA_CSV) + try: + bd = load_baseline_csv(csv_path) + + assert len(bd.dates) == 6 + expected_dates = np.array( + [ + "2026-01-31", + "2026-02-06", + "2026-02-09", + "2026-02-12", + "2026-02-14", + "2026-02-17", + ], + dtype="datetime64[D]", + ) + np.testing.assert_array_equal(bd.dates, expected_dates) + + # Reference should be zero, values should be exact (single-ref path) + assert bd.bperp_m[0] == 0.0 + np.testing.assert_allclose(bd.bperp_m[1], 79.512, atol=1e-3) + np.testing.assert_allclose(bd.bperp_m[2], 369.187, atol=1e-3) + np.testing.assert_allclose(bd.bperp_m[3], 441.233, atol=1e-3) + np.testing.assert_allclose(bd.bperp_m[4], 443.027, atol=1e-3) + np.testing.assert_allclose(bd.bperp_m[5], 4.588, atol=1e-3) + finally: + csv_path.unlink() + + +def test_single_ref_gives_same_as_lstsq(): + """Verify single-reference shortcut matches what lstsq would give.""" + # Build a small single-reference CSV that also works as a valid network + csv_content = ( + "reference,secondary,bperp_m\n" + "20230101,20230113,100.0\n" + "20230101,20230125,250.0\n" + "20230101,20230206,400.0\n" + ) + csv_path = _write_tmp_csv(csv_content) + try: + bd = load_baseline_csv(csv_path) + + # Should be exact (no lstsq noise) + np.testing.assert_array_equal(bd.bperp_m, [0.0, 100.0, 250.0, 400.0]) + finally: + csv_path.unlink() + + +# --------------------------------------------------------------------------- +# _parse_date edge cases +# --------------------------------------------------------------------------- + + +def test_parse_date_iso_timestamp(): + """Test parsing ISO 8601 timestamps.""" + from spurt.io._baseline import _parse_date + + assert _parse_date("2026-01-31T04:30:30.757865503Z") == "2026-01-31" + assert _parse_date("2026-02-06T02:23:59.736190396Z") == "2026-02-06" + + +def test_parse_date_filename(): + """Test date extraction from Capella-style filenames.""" + from spurt.io._baseline import _parse_date + + fname = "slcs/CAPELLA_C13_SP_SLC_HH_20260131043025_20260131043034.tif" + assert _parse_date(fname) == "2026-01-31" + + +def test_parse_date_unparseable(): + """Test that unparseable strings raise ValueError.""" + from spurt.io._baseline import _parse_date + + with pytest.raises(ValueError, match="Cannot parse date"): + _parse_date("no-date-here") + + 
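+def test_parse_date_bare_yyyymmdd():
+    """Sketch: bare ``YYYYMMDD`` strings should normalize to ISO dates.
+
+    Assumes the same regex fallback that handles filenames also accepts a
+    bare date token; ``_build_link_model`` relies on this when normalizing
+    ``SLCStackReader`` date keys such as ``"20240626"``.
+    """
+    from spurt.io._baseline import _parse_date
+
+    assert _parse_date("20240626") == "2024-06-26"
+
+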
+def test_parse_date_compact_with_time(): + """Test parsing compact YYYYMMDD[T]HHMMSS strings used as stack date keys. + + The bare ``T`` separator without dashes (e.g. ``20200101T120000``) used + to be misclassified as an ISO 8601 timestamp and truncated to + ``20200101T1`` before regex fallback could see it. + """ + from spurt.io._baseline import _parse_date + + assert _parse_date("20200101120000") == "2020-01-01" + assert _parse_date("20200101T120000") == "2020-01-01" + + +# --------------------------------------------------------------------------- +# OPERA CSLC-S1 filenames +# --------------------------------------------------------------------------- + +_OPERA_CSLC_CSV = """\ +reference,secondary,bperp_m +OPERA_L2_CSLC-S1_T078-165495-IW2_20230105T120000Z_20230120T000000Z_S1A_VV_v1.0.h5,OPERA_L2_CSLC-S1_T078-165495-IW2_20230117T120000Z_20230201T000000Z_S1A_VV_v1.0.h5,95.3 +OPERA_L2_CSLC-S1_T078-165495-IW2_20230105T120000Z_20230120T000000Z_S1A_VV_v1.0.h5,OPERA_L2_CSLC-S1_T078-165495-IW2_20230129T120000Z_20230212T000000Z_S1A_VV_v1.0.h5,210.7 +OPERA_L2_CSLC-S1_T078-165495-IW2_20230117T120000Z_20230201T000000Z_S1A_VV_v1.0.h5,OPERA_L2_CSLC-S1_T078-165495-IW2_20230129T120000Z_20230212T000000Z_S1A_VV_v1.0.h5,115.4 +""" + + +def test_load_per_ifg_opera_cslc(): + """Test per-IFG loading with OPERA CSLC-S1 filenames (no time columns).""" + csv_path = _write_tmp_csv(_OPERA_CSLC_CSV) + try: + bd = load_baseline_csv(csv_path) + + assert len(bd.dates) == 3 + expected_dates = np.array( + ["2023-01-05", "2023-01-17", "2023-01-29"], + dtype="datetime64[D]", + ) + np.testing.assert_array_equal(bd.dates, expected_dates) + assert bd.bperp_m[0] == 0.0 + np.testing.assert_allclose(bd.bperp_m[1] - bd.bperp_m[0], 95.3, atol=0.1) + np.testing.assert_allclose(bd.bperp_m[2] - bd.bperp_m[0], 210.7, atol=0.1) + finally: + csv_path.unlink() + + +def test_parse_date_opera_cslc(): + """Test date extraction from OPERA CSLC-S1 filenames.""" + from spurt.io._baseline import _parse_date + + f = "OPERA_L2_CSLC-S1_T078-165495-IW2_20230105T120000Z_20230120T000000Z_S1A_VV_v1.0.h5" # noqa: E501 + assert _parse_date(f) == "2023-01-05" + + +_OPERA_COMPRESSED_CSLC_CSV = """\ +reference,secondary,bperp_m +OPERA_L2_COMPRESSED-CSLC-S1_F23148_T078-165495-IW2_20230105T120000Z_20221001T000000Z_20230101T000000Z_20230115T000000Z_VV_v1.1.h5,OPERA_L2_COMPRESSED-CSLC-S1_F23148_T078-165495-IW2_20230117T120000Z_20230101T000000Z_20230201T000000Z_20230210T000000Z_VV_v1.1.h5,88.2 +OPERA_L2_COMPRESSED-CSLC-S1_F23148_T078-165495-IW2_20230105T120000Z_20221001T000000Z_20230101T000000Z_20230115T000000Z_VV_v1.1.h5,OPERA_L2_COMPRESSED-CSLC-S1_F23148_T078-165495-IW2_20230129T120000Z_20230101T000000Z_20230301T000000Z_20230310T000000Z_VV_v1.1.h5,195.0 +""" + + +def test_load_per_ifg_opera_compressed_cslc(): + """Test per-IFG loading with OPERA COMPRESSED-CSLC-S1 filenames.""" + csv_path = _write_tmp_csv(_OPERA_COMPRESSED_CSLC_CSV) + try: + bd = load_baseline_csv(csv_path) + + assert len(bd.dates) == 3 + expected_dates = np.array( + ["2023-01-05", "2023-01-17", "2023-01-29"], + dtype="datetime64[D]", + ) + np.testing.assert_array_equal(bd.dates, expected_dates) + assert bd.bperp_m[0] == 0.0 + np.testing.assert_allclose(bd.bperp_m[1] - bd.bperp_m[0], 88.2, atol=0.1) + finally: + csv_path.unlink() + + +def test_parse_date_opera_compressed_cslc(): + """Test date extraction from OPERA COMPRESSED-CSLC-S1 filenames.""" + from spurt.io._baseline import _parse_date + + f = 
"OPERA_L2_COMPRESSED-CSLC-S1_F23148_T078-165495-IW2_20230105T120000Z_20221001T000000Z_20230101T000000Z_20230115T000000Z_VV_v1.1.h5" # noqa: E501 + assert _parse_date(f) == "2023-01-05" + + +# --------------------------------------------------------------------------- +# Sentinel-1 IW SAFE filenames +# --------------------------------------------------------------------------- + +_S1_SAFE_CSV = """\ +reference,secondary,bperp_m +S1A_IW_SLC__1SDV_20230105T120000_20230105T120030_046801_059E44_A1B2.SAFE,S1A_IW_SLC__1SDV_20230117T120000_20230117T120030_046976_05A3F1_C3D4.SAFE,102.5 +S1A_IW_SLC__1SDV_20230105T120000_20230105T120030_046801_059E44_A1B2.SAFE,S1B_IW_SLC__1SDV_20230129T120000_20230129T120030_047151_05A9BE_E5F6.SAFE,215.8 +S1A_IW_SLC__1SDV_20230117T120000_20230117T120030_046976_05A3F1_C3D4.SAFE,S1B_IW_SLC__1SDV_20230129T120000_20230129T120030_047151_05A9BE_E5F6.SAFE,113.3 +""" + + +def test_load_per_ifg_sentinel1_safe(): + """Test per-IFG loading with Sentinel-1 SAFE directory names.""" + csv_path = _write_tmp_csv(_S1_SAFE_CSV) + try: + bd = load_baseline_csv(csv_path) + + assert len(bd.dates) == 3 + expected_dates = np.array( + ["2023-01-05", "2023-01-17", "2023-01-29"], + dtype="datetime64[D]", + ) + np.testing.assert_array_equal(bd.dates, expected_dates) + assert bd.bperp_m[0] == 0.0 + np.testing.assert_allclose(bd.bperp_m[1] - bd.bperp_m[0], 102.5, atol=0.1) + np.testing.assert_allclose(bd.bperp_m[2] - bd.bperp_m[0], 215.8, atol=0.1) + np.testing.assert_allclose(bd.bperp_m[2] - bd.bperp_m[1], 113.3, atol=0.1) + finally: + csv_path.unlink() + + +def test_parse_date_sentinel1_safe(): + """Test date extraction from Sentinel-1 SAFE names.""" + from spurt.io._baseline import _parse_date + + f = "S1A_IW_SLC__1SDV_20230105T120000_20230105T120030_046801_059E44_A1B2.SAFE" + assert _parse_date(f) == "2023-01-05" + + +def test_load_per_ifg_sentinel1_safe_paths(): + """Test with full path-like Sentinel-1 SAFE references.""" + csv_content = ( + "reference,secondary,bperp_m\n" + "/data/slcs/S1A_IW_SLC__1SDV_20230105T120000_20230105T120030_046801_059E44_A1B2.SAFE/measurement/s1a-iw1-slc-vv.tiff," + "/data/slcs/S1A_IW_SLC__1SDV_20230117T120000_20230117T120030_046976_05A3F1_C3D4.SAFE/measurement/s1a-iw1-slc-vv.tiff,102.5\n" + ) + csv_path = _write_tmp_csv(csv_content) + try: + bd = load_baseline_csv(csv_path) + + assert len(bd.dates) == 2 + expected_dates = np.array( + ["2023-01-05", "2023-01-17"], + dtype="datetime64[D]", + ) + np.testing.assert_array_equal(bd.dates, expected_dates) + finally: + csv_path.unlink() diff --git a/test/io/test_slcstack.py b/test/io/test_slcstack.py index fd54a80..0a9249a 100644 --- a/test/io/test_slcstack.py +++ b/test/io/test_slcstack.py @@ -27,6 +27,40 @@ def has_testdata() -> bool: return p.is_dir() +class TestExtractDateStr: + """Tests for the date-string extractor used during stack scanning.""" + + def test_default_format_extracts_eight_chars(self): + from spurt.io._slc_stack import _date_str_length, _extract_date_str + + date_len = _date_str_length("%Y%m%d") + assert _extract_date_str("20240709_extra", "%Y%m%d", date_len) == "20240709" + + def test_with_time_of_day(self): + from spurt.io._slc_stack import _date_str_length, _extract_date_str + + fmt = "%Y%m%d%H%M%S" + date_len = _date_str_length(fmt) + token = "20240709040329_secondary" + assert _extract_date_str(token, fmt, date_len) == "20240709040329" + + def test_invalid_date_raises(self): + from spurt.io._slc_stack import _date_str_length, _extract_date_str + + date_len = _date_str_length("%Y%m%d") + with 
pytest.raises(ValueError, match="does not match format"):
+            _extract_date_str("notadate_xx", "%Y%m%d", date_len)
+
+    def test_format_mismatch_raises(self):
+        """A format that doesn't match the filename token should fail loudly."""
+        from spurt.io._slc_stack import _date_str_length, _extract_date_str
+
+        # Token has compact YYYYMMDD but user passed dashed format of same length
+        date_len = _date_str_length("%Y-%m-%d")
+        with pytest.raises(ValueError, match="does not match format"):
+            _extract_date_str("20240709aa", "%Y-%m-%d", date_len)
+
+
 @pytest.mark.skipif(
     (not has_rasterio()) or (not has_testdata()),
     reason="Either rasterio or test data not available",
diff --git a/test/links/test_design_matrix.py b/test/links/test_design_matrix.py
new file mode 100644
index 0000000..444a9a7
--- /dev/null
+++ b/test/links/test_design_matrix.py
@@ -0,0 +1,146 @@
+"""Tests for design matrix construction."""
+
+import numpy as np
+
+from spurt.links import build_design_matrix
+
+
+def test_design_matrix_shape():
+    """Test that the design matrix has the correct shape."""
+    # 4 SLCs -> 5 IFGs: 0-1, 0-2, 1-2, 1-3, 2-3
+    ifg_edges = np.array([[0, 1], [0, 2], [1, 2], [1, 3], [2, 3]])
+    dates = np.array(
+        ["2020-01-01", "2020-01-13", "2020-01-25", "2020-02-06"],
+        dtype="datetime64[D]",
+    )
+    bperp_m = np.array([0.0, 100.0, 200.0, 150.0])
+
+    amat = build_design_matrix(
+        ifg_edges=ifg_edges,
+        dates=dates,
+        bperp_m=bperp_m,
+        wavelength_m=0.055465,
+        slant_range_m=900000.0,
+        look_angle_rad=np.radians(39.0),
+    )
+
+    assert amat.shape == (5, 2)
+
+
+def test_velocity_sensitivity_calculation():
+    """Test velocity sensitivity column calculation.
+
+    For velocity in mm/yr, the phase sensitivity is:
+    dphi/dv = 4*pi/wavelength * dt_days/365.25 * 0.001
+    """
+    ifg_edges = np.array([[0, 1]])
+    dates = np.array(["2020-01-01", "2021-01-01"], dtype="datetime64[D]")  # 366 days
+    bperp_m = np.array([0.0, 0.0])  # No baseline -> no DEM error sensitivity
+
+    wavelength_m = 0.055465
+
+    amat = build_design_matrix(
+        ifg_edges=ifg_edges,
+        dates=dates,
+        bperp_m=bperp_m,
+        wavelength_m=wavelength_m,
+        slant_range_m=900000.0,
+        look_angle_rad=np.radians(39.0),
+    )
+
+    # Expected: 4*pi/0.055465 * 366/365.25 * 0.001 ~= 0.2270 rad/(mm/yr)
+    expected_vel_sens = 4.0 * np.pi / wavelength_m * 366.0 / 365.25 * 0.001
+    np.testing.assert_almost_equal(amat[0, 0], expected_vel_sens, decimal=6)
+
+
+def test_dem_error_sensitivity_calculation():
+    """Test DEM error sensitivity column calculation.
+ + For DEM error in meters, the phase sensitivity is: + dphi/dh = 4*pi/wavelength * bperp / (slant_range * sin(look)) + """ + ifg_edges = np.array([[0, 1]]) + dates = np.array(["2020-01-01", "2020-01-01"], dtype="datetime64[D]") # Same date + bperp_m = np.array([0.0, 100.0]) # 100m baseline difference + + wavelength_m = 0.055465 + slant_range_m = 900000.0 + look_angle_rad = np.radians(39.0) + + amat = build_design_matrix( + ifg_edges=ifg_edges, + dates=dates, + bperp_m=bperp_m, + wavelength_m=wavelength_m, + slant_range_m=slant_range_m, + look_angle_rad=look_angle_rad, + ) + + # Velocity sensitivity should be 0 (same date) + np.testing.assert_almost_equal(amat[0, 0], 0.0, decimal=10) + + # Expected DEM error sensitivity: 4*pi/wavelength * 100 / (900000 * sin(39deg)) + expected_dem_sens = ( + 4.0 * np.pi / wavelength_m * 100.0 / (slant_range_m * np.sin(look_angle_rad)) + ) + np.testing.assert_almost_equal(amat[0, 1], expected_dem_sens, decimal=10) + + +def test_negative_baseline_difference(): + """Test that negative baseline differences are handled correctly.""" + ifg_edges = np.array([[0, 1]]) + dates = np.array(["2020-01-01", "2020-01-13"], dtype="datetime64[D]") + bperp_m = np.array([100.0, 0.0]) # Secondary has smaller bperp + + amat = build_design_matrix( + ifg_edges=ifg_edges, + dates=dates, + bperp_m=bperp_m, + wavelength_m=0.055465, + slant_range_m=900000.0, + look_angle_rad=np.radians(39.0), + ) + + # DEM error sensitivity should be negative + assert amat[0, 1] < 0 + + +def test_multiple_ifgs(): + """Test design matrix for multiple interferograms.""" + # Hop3-style edges for 5 SLCs + ifg_edges = np.array( + [ + [0, 1], + [0, 2], + [0, 3], + [1, 2], + [1, 3], + [1, 4], + [2, 3], + [2, 4], + [3, 4], + ] + ) + dates = np.array( + ["2020-01-01", "2020-01-13", "2020-01-25", "2020-02-06", "2020-02-18"], + dtype="datetime64[D]", + ) + bperp_m = np.array([0.0, 50.0, 100.0, -50.0, 25.0]) + + amat = build_design_matrix( + ifg_edges=ifg_edges, + dates=dates, + bperp_m=bperp_m, + wavelength_m=0.055465, + slant_range_m=900000.0, + look_angle_rad=np.radians(39.0), + ) + + assert amat.shape == (9, 2) + + # All velocity sensitivities should be positive (time always increases) + assert np.all(amat[:, 0] > 0) + + # Longer temporal baselines should have larger velocity sensitivity + # IFG 0-3 (36 days) should have larger sensitivity than 0-1 (12 days) + assert amat[2, 0] > amat[0, 0] diff --git a/test/links/test_grid_search.py b/test/links/test_grid_search.py index a13cafd..11f2aec 100644 --- a/test/links/test_grid_search.py +++ b/test/links/test_grid_search.py @@ -1,6 +1,7 @@ import numpy as np import spurt +from spurt.links._grid_search import _vectorized_grid_search # Fix the seed for repeatability np.random.seed(32) @@ -71,6 +72,35 @@ def test_grid_estimate(): assert coh > 0.99 +def test_vectorized_grid_search(): + """Test that vectorized grid search finds the correct grid point.""" + amat, vel_range, demerr_range = gen_data() + + # Noiseless data with known parameters on grid points + vel_true = 10.0 + dem_true = 2.0 + fwd_phase = amat @ np.array([vel_true, dem_true]) + wts = 1.0 / amat.shape[0] + + x0 = _vectorized_grid_search(amat, (vel_range, demerr_range), fwd_phase, wts) + np.testing.assert_allclose(x0[0], vel_true, atol=1e-10) + np.testing.assert_allclose(x0[1], dem_true, atol=1e-10) + + +def test_vectorized_grid_search_with_weights(): + """Test vectorized grid search with per-observation weight arrays.""" + amat, vel_range, demerr_range = gen_data() + + vel_true = -15.0 + dem_true = 1.5 
+    fwd_phase = amat @ np.array([vel_true, dem_true])
+    wts = np.ones(amat.shape[0]) / amat.shape[0]
+
+    x0 = _vectorized_grid_search(amat, (vel_range, demerr_range), fwd_phase, wts)
+    np.testing.assert_allclose(x0[0], vel_true, atol=1e-10)
+    np.testing.assert_allclose(x0[1], dem_true, atol=1e-10)
+
+
 def test_grid_estimate_many():
     """Test estimate_many."""
     amat, vel_range, demerr_range = gen_data()
diff --git a/test/workflows/test_emcf_baseline_integration.py b/test/workflows/test_emcf_baseline_integration.py
new file mode 100644
index 0000000..0e3f4dd
--- /dev/null
+++ b/test/workflows/test_emcf_baseline_integration.py
@@ -0,0 +1,286 @@
+"""Integration tests for EMCF workflow with baseline/velocity estimation."""
+
+import tempfile
+from pathlib import Path
+
+import numpy as np
+import pytest
+
+import spurt
+from spurt.io import load_baseline_csv
+from spurt.links import build_design_matrix
+from spurt.workflows.emcf import LinkModelSettings
+
+
+def test_link_model_settings_validation():
+    """Test LinkModelSettings validation."""
+    # Should work with baseline_csv provided
+    settings = LinkModelSettings(
+        enabled=True,
+        baseline_csv="/path/to/file.csv",
+    )
+    assert settings.enabled is True
+
+    # Should fail without baseline_csv when enabled
+    with pytest.raises(ValueError, match="baseline_csv is required"):
+        LinkModelSettings(enabled=True, baseline_csv=None)
+
+    # Should work when disabled without baseline_csv
+    settings = LinkModelSettings(enabled=False, baseline_csv=None)
+    assert settings.enabled is False
+
+
+def test_link_model_settings_look_angle_conversion():
+    """Test look angle conversion from degrees to radians."""
+    settings = LinkModelSettings(
+        enabled=True,
+        baseline_csv="/path/to/file.csv",
+        look_angle_deg=45.0,
+    )
+
+    expected_rad = np.radians(45.0)
+    np.testing.assert_almost_equal(settings.look_angle_rad, expected_rad)
+
+
+def test_baseline_and_design_matrix_integration():
+    """Test that baseline loading and design matrix construction work together."""
+    # Create a per-SLC baseline CSV
+    csv_content = """date,bperp_m
+20200101,0.0
+20200113,100.0
+20200125,200.0
+20200206,150.0
+20200218,250.0
+"""
+    with tempfile.NamedTemporaryFile(mode="w", suffix=".csv", delete=False) as f:
+        f.write(csv_content)
+        f.flush()
+        csv_path = Path(f.name)
+
+    try:
+        # Load baselines
+        baseline_data = load_baseline_csv(csv_path)
+        assert len(baseline_data.dates) == 5
+
+        # Create temporal graph
+        g_time = spurt.graph.Hop3Graph(5)
+
+        # Build design matrix
+        amat = build_design_matrix(
+            ifg_edges=g_time.links,
+            dates=baseline_data.dates,
+            bperp_m=baseline_data.bperp_m,
+            wavelength_m=0.055465,
+            slant_range_m=900000.0,
+            look_angle_rad=np.radians(39.0),
+        )
+
+        # Should have correct shape
+        assert amat.shape == (len(g_time.links), 2)
+
+        # All velocity sensitivities should be positive
+        assert np.all(amat[:, 0] > 0)
+
+    finally:
+        csv_path.unlink()
+
+
+def test_emcf_with_baseline_csv_full_workflow():
+    """Test EMCF solver with baseline CSV and link model estimation.
+
+    This is a full workflow test that:
+    1. Creates synthetic phase data with known velocity
+    2. Creates a baseline CSV file
+    3. Configures LinkModelSettings
+    4. Runs the solver with link model
+    5. Verifies results
+    """
+    n_sar = 10
+    y, x = np.ogrid[-3:3:32j, -3:3:32j]
+
+    # Velocity field in mm/yr (converted to radians for phase)
+    vel_mm_yr = -20.0 * np.exp(-(x**2 + y**2) / 5)
+    vel_mm_yr -= vel_mm_yr.max()
+
+    # Create dates (12 day repeat cycle)
+    base_date = np.datetime64("2020-01-01")
+    dates = [base_date + np.timedelta64(i * 12, "D") for i in range(n_sar)]
+    dates_str = [str(d) for d in dates]
+
+    # Create baselines (no DEM error in this test)
+    bperp_m = np.zeros(n_sar)
+
+    # Convert velocity to phase
+    wavelength_m = 0.055465
+    times_days = np.array([i * 12 for i in range(n_sar)], dtype=np.float64)
+    times_years = times_days / 365.25
+
+    # Phase = 4*pi/wavelength * velocity_m * time_years
+    # velocity_m = velocity_mm_yr * 0.001
+    phase_rate = 4.0 * np.pi / wavelength_m * 0.001  # rad per (mm/yr * year)
+    phase = times_years[:, None, None] * vel_mm_yr[None, :, :] * phase_rate
+
+    igram = np.exp(1j * phase)
+
+    # Create baseline CSV
+    csv_content = "date,bperp_m\n"
+    for d, b in zip(dates_str, bperp_m):
+        date_str = str(d).replace("-", "")
+        csv_content += f"{date_str},{b}\n"
+
+    with tempfile.NamedTemporaryFile(mode="w", suffix=".csv", delete=False) as f:
+        f.write(csv_content)
+        f.flush()
+        csv_path = Path(f.name)
+
+    try:
+        # Set up time processing
+        g_time = spurt.graph.Hop3Graph(n_sar)
+        s_time = spurt.mcf.ORMCFSolver(g_time)
+
+        # Set up spatial processing
+        g_space = spurt.graph.Reg2DGraph(igram.shape[1:])
+        s_space = spurt.mcf.ORMCFSolver(g_space)
+
+        # Load baselines and build design matrix
+        baseline_data = load_baseline_csv(csv_path)
+        stack_dates = np.array(dates, dtype="datetime64[D]")
+
+        amat = build_design_matrix(
+            ifg_edges=g_time.links,
+            dates=stack_dates,
+            bperp_m=baseline_data.bperp_m,
+            wavelength_m=wavelength_m,
+            slant_range_m=900000.0,
+            look_angle_rad=np.radians(39.0),
+        )
+
+        # Create link model
+        vel_range = slice(-50.0, 10.0, 2.0)
+        dem_range = slice(-5.0, 5.0, 1.0)
+        link_model = spurt.links.GridSearchLinearModel(
+            matrix=amat,
+            ranges=(vel_range, dem_range),
+        )
+
+        # Create EMCF solver with link model
+        settings = spurt.workflows.emcf.SolverSettings(
+            s_worker_count=1,
+            t_worker_count=1,
+            links_per_batch=500,
+        )
+        solver = spurt.workflows.emcf.Solver(s_space, s_time, settings, link_model)
+
+        w_data = spurt.io.Irreg3DInput(
+            igram.reshape((n_sar, g_space.npoints)), g_space.points
+        )
+        uw_data = solver.unwrap_cube(w_data)
+
+        # Verify link parameters were estimated
+        assert solver.link_params is not None
+        assert solver.link_coherence is not None
+
+        # Verify unwrapping succeeded
+        for ii, edge in enumerate(g_time.links):
+            orig = phase[edge[1]] - phase[edge[0]]
+            recon = uw_data[ii].reshape(phase.shape[1:])
+            assert np.allclose(orig - orig[0, 0], recon - recon[0, 0], atol=1.0e-3)
+
+        # High coherence expected for clean synthetic data
+        assert np.mean(solver.link_coherence) > 0.95
+
+    finally:
+        csv_path.unlink()
+
+
+def test_build_link_model_yyyymmdd_dates():
+    """Test that _build_link_model handles YYYYMMDD stack dates.
+
+    The SLCStackReader stores dates as YYYYMMDD strings, but numpy's
+    datetime64 misparses these (treating '20240626' as year 20240626).
+    This test verifies the dates are normalized before comparison.
+ """ + from spurt.workflows.emcf._unwrap import _build_link_model + + n_slc = 5 + g_time = spurt.graph.Hop3Graph(n_slc) + + # Stack dates in YYYYMMDD format (as SLCStackReader provides) + dates_yyyymmdd = ["20200101", "20200113", "20200125", "20200206", "20200218"] + + csv_content = ( + "date,bperp_m\n" + "20200101,0.0\n20200113,100.0\n20200125,200.0\n" + "20200206,150.0\n20200218,250.0\n" + ) + with tempfile.NamedTemporaryFile(mode="w", suffix=".csv", delete=False) as f: + f.write(csv_content) + f.flush() + csv_path = Path(f.name) + + try: + settings = LinkModelSettings( + enabled=True, + baseline_csv=str(csv_path), + ) + model = _build_link_model(g_time, dates_yyyymmdd, settings) + + # Model should have correct shape: (nifgs, 2) + assert model.matrix.shape == (len(g_time.links), 2) + + # All velocity sensitivities should be positive (time moves forward) + assert np.all(model.matrix[:, 0] > 0) + + # DEM error sensitivities should be nonzero (baselines vary) + assert not np.allclose(model.matrix[:, 1], 0.0) + finally: + csv_path.unlink() + + +def test_cli_argument_parsing(): + """Test that CLI argument parsing works for baseline arguments.""" + import argparse + + test_args = [ + "-i", + "/nonexistent/path", + "--baseline-csv", + "/path/to/baselines.csv", + "--wavelength", + "0.031", + "--slant-range", + "800000", + "--look-angle-deg", + "35.0", + "--velocity-range", + "-80", + "80", + "4", + "--dem-error-range", + "-40", + "40", + "2", + ] + + parser = argparse.ArgumentParser() + parser.add_argument("-i", "--inputdir", required=True) + parser.add_argument("--baseline-csv", type=str, default=None) + parser.add_argument("--no-velocity-estimation", action="store_true") + parser.add_argument("--wavelength", type=float, default=0.055465) + parser.add_argument("--slant-range", type=float, default=900000.0) + parser.add_argument("--look-angle-deg", type=float, default=39.0) + parser.add_argument( + "--velocity-range", type=float, nargs=3, default=[-100.0, 100.0, 5.0] + ) + parser.add_argument( + "--dem-error-range", type=float, nargs=3, default=[-50.0, 50.0, 2.5] + ) + + parsed = parser.parse_args(test_args) + + assert parsed.baseline_csv == "/path/to/baselines.csv" + assert parsed.wavelength == 0.031 + assert parsed.slant_range == 800000 + assert parsed.look_angle_deg == 35.0 + assert parsed.velocity_range == [-80.0, 80.0, 4.0] + assert parsed.dem_error_range == [-40.0, 40.0, 2.0] diff --git a/test/workflows/test_emcf_link_model.py b/test/workflows/test_emcf_link_model.py new file mode 100644 index 0000000..e3e4028 --- /dev/null +++ b/test/workflows/test_emcf_link_model.py @@ -0,0 +1,191 @@ +"""Test EMCF solver with link model for DEM error and velocity estimation.""" + +import numpy as np +import pytest + +import spurt + + +def gen_data_with_velocity(): + """Generate a regular 3D dataset with known velocity field.""" + n_collects = 20 + y, x = np.ogrid[-3:3:32j, -3:3:32j] + + # Velocity field in radians per time unit + vel = -np.pi * np.exp(-(x**2 + y**2) / 5) / 12 + vel -= vel.max() + + times = np.arange(n_collects) * 12 + phase = times[:, None, None] * vel[None, :, :] + + return n_collects, times, phase, vel + + +def test_emcf_with_link_model(): + """Test EMCF unwrapping with link model for velocity estimation.""" + n_sar, times, phase, true_vel = gen_data_with_velocity() + igram = np.exp(1j * phase) + + # Set up time processing + g_time = spurt.graph.Hop3Graph(n_sar) + s_time = spurt.mcf.ORMCFSolver(g_time) + + # Set up spatial processing + g_space = spurt.graph.Reg2DGraph(igram.shape[1:]) 
+    s_space = spurt.mcf.ORMCFSolver(g_space)
+
+    # Build design matrix for velocity estimation
+    # For each interferogram, the temporal sensitivity is delta_time
+    nifgs = len(g_time.links)
+    amat = np.zeros((nifgs, 1))
+    for ii, edge in enumerate(g_time.links):
+        amat[ii, 0] = times[edge[1]] - times[edge[0]]
+
+    # Create link model with velocity search range
+    vel_range = slice(-0.5, 0.1, 0.02)
+    link_model = spurt.links.GridSearchLinearModel(matrix=amat, ranges=(vel_range,))
+
+    # Create EMCF solver with link model
+    settings = spurt.workflows.emcf.SolverSettings(
+        s_worker_count=1,
+        t_worker_count=1,
+        links_per_batch=10000,
+    )
+    solver = spurt.workflows.emcf.Solver(s_space, s_time, settings, link_model)
+
+    w_data = spurt.io.Irreg3DInput(
+        igram.reshape((n_sar, g_space.npoints)), g_space.points
+    )
+    uw_data = solver.unwrap_cube(w_data)
+
+    # Verify unwrapping succeeded
+    for ii, edge in enumerate(g_time.links):
+        orig = phase[edge[1]] - phase[edge[0]]
+        recon = uw_data[ii].reshape(phase.shape[1:])
+        assert np.allclose(orig - orig[0, 0], recon - recon[0, 0], atol=1.0e-3)
+
+    # Verify link parameters were estimated
+    assert solver.link_params is not None
+    assert solver.link_coherence is not None
+    assert solver.link_params.shape == (1, solver.nlinks)
+    assert solver.link_coherence.shape == (solver.nlinks,)
+
+    # Coherence should be high for clean synthetic data (no noise added).
+    # 0.95 threshold ensures model fits well; lower values indicate estimation issues.
+    assert np.mean(solver.link_coherence) > 0.95
+
+    # Test integrate_link_params to get point velocities
+    point_vel = solver.integrate_link_params(param_idx=0)
+    assert point_vel.shape == (solver.npoints,)
+
+    # Point velocities should correlate well with true velocity field.
+    # Exact values may differ due to grid search resolution (0.02) and integration.
+    # 0.95 correlation threshold ensures spatial pattern is recovered correctly.
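+    # Reference both fields to their first point to remove the integration constant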
+    true_vel_flat = true_vel.flatten()
+    point_vel_ref = point_vel - point_vel[0]
+    true_vel_ref = true_vel_flat - true_vel_flat[0]
+    correlation = np.corrcoef(point_vel_ref, true_vel_ref)[0, 1]
+    assert correlation > 0.95
+
+
+def test_emcf_with_link_model_ifg_input():
+    """Test EMCF with link model using interferogram input."""
+    n_sar, times, phase, _ = gen_data_with_velocity()
+
+    # Set up time processing
+    g_time = spurt.graph.Hop3Graph(n_sar)
+    s_time = spurt.mcf.ORMCFSolver(g_time)
+
+    # Set up spatial processing
+    g_space = spurt.graph.Reg2DGraph(phase.shape[1:])
+    s_space = spurt.mcf.ORMCFSolver(g_space)
+
+    # Create interferograms from SLC phases
+    nifgs = len(g_time.links)
+    ifg_phase = np.zeros((nifgs, *phase.shape[1:]))
+    for ii, edge in enumerate(g_time.links):
+        ifg_phase[ii] = phase[edge[1]] - phase[edge[0]]
+    ifg = np.exp(1j * ifg_phase)
+
+    # Build design matrix for velocity estimation
+    amat = np.zeros((nifgs, 1))
+    for ii, edge in enumerate(g_time.links):
+        amat[ii, 0] = times[edge[1]] - times[edge[0]]
+
+    # Create link model
+    vel_range = slice(-0.5, 0.1, 0.02)
+    link_model = spurt.links.GridSearchLinearModel(matrix=amat, ranges=(vel_range,))
+
+    # Create EMCF solver with link model
+    settings = spurt.workflows.emcf.SolverSettings(
+        s_worker_count=1,
+        t_worker_count=1,
+    )
+    solver = spurt.workflows.emcf.Solver(s_space, s_time, settings, link_model)
+
+    w_data = spurt.io.Irreg3DInput(
+        ifg.reshape((nifgs, g_space.npoints)), g_space.points
+    )
+    uw_data = solver.unwrap_cube(w_data)
+
+    # Verify link parameters were estimated
+    assert solver.link_params is not None
+    assert solver.link_coherence is not None
+
+    # Verify unwrapping succeeded
+    for ii in range(nifgs):
+        orig = ifg_phase[ii]
+        recon = uw_data[ii].reshape(phase.shape[1:])
+        assert np.allclose(orig - orig[0, 0], recon - recon[0, 0], atol=1.0e-3)
+
+
+def test_link_model_not_provided():
+    """Test that link_params is None when no link model is provided."""
+    n_sar, _, phase, _ = gen_data_with_velocity()
+    igram = np.exp(1j * phase)
+
+    g_time = spurt.graph.Hop3Graph(n_sar)
+    s_time = spurt.mcf.ORMCFSolver(g_time)
+
+    g_space = spurt.graph.Reg2DGraph(igram.shape[1:])
+    s_space = spurt.mcf.ORMCFSolver(g_space)
+
+    settings = spurt.workflows.emcf.SolverSettings(
+        s_worker_count=1,
+        t_worker_count=1,
+    )
+    solver = spurt.workflows.emcf.Solver(s_space, s_time, settings)
+
+    w_data = spurt.io.Irreg3DInput(
+        igram.reshape((n_sar, g_space.npoints)), g_space.points
+    )
+    solver.unwrap_cube(w_data)
+
+    assert solver.link_params is None
+    assert solver.link_coherence is None
+
+
+def test_integrate_link_params_without_model():
+    """Test that integrate_link_params raises error when no link model was used."""
+    n_sar, _, phase, _ = gen_data_with_velocity()
+    igram = np.exp(1j * phase)
+
+    g_time = spurt.graph.Hop3Graph(n_sar)
+    s_time = spurt.mcf.ORMCFSolver(g_time)
+
+    g_space = spurt.graph.Reg2DGraph(igram.shape[1:])
+    s_space = spurt.mcf.ORMCFSolver(g_space)
+
+    settings = spurt.workflows.emcf.SolverSettings(
+        s_worker_count=1,
+        t_worker_count=1,
+    )
+    solver = spurt.workflows.emcf.Solver(s_space, s_time, settings)
+
+    w_data = spurt.io.Irreg3DInput(
+        igram.reshape((n_sar, g_space.npoints)), g_space.points
+    )
+    solver.unwrap_cube(w_data)
+
+    with pytest.raises(RuntimeError, match="No link parameters available"):
+        solver.integrate_link_params()