diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f998f324ce..bc3ca79eac 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -64,7 +64,7 @@ jobs: run: ./bin/run-doc-codeblocks --ci --no-cache tests: - timeout-minutes: 20 + timeout-minutes: 25 strategy: matrix: pyver: ['3.10', '3.11', '3.12', '3.13', '3.14'] diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ea8769d946..3afc56dec3 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -117,7 +117,7 @@ repos: git ls-files "*Cargo.toml" | while read -r m; do cargo locate-project --manifest-path "$m" --workspace --message-format plain done | sort -u | while read -r root; do - cargo fmt --manifest-path "$root" --all --check + cargo fmt --manifest-path "$root" --all done ' language: system diff --git a/dimos/conftest.py b/dimos/conftest.py index 6d7d1f509a..8f50ee8ec4 100644 --- a/dimos/conftest.py +++ b/dimos/conftest.py @@ -17,6 +17,7 @@ import hashlib import os import platform +import shutil import tempfile import threading @@ -74,6 +75,10 @@ def _has_ros() -> bool: return False +def _has_nix() -> bool: + return shutil.which("nix") is not None + + def _is_macos() -> bool: return platform.system() == "Darwin" @@ -90,6 +95,7 @@ def pytest_configure(config): config.addinivalue_line("markers", "skipif_no_openai: skip when OPENAI_API_KEY is not set") config.addinivalue_line("markers", "skipif_no_alibaba: skip when ALIBABA_API_KEY is not set") config.addinivalue_line("markers", "skipif_no_ros: skip when ROS dependencies are not present") + config.addinivalue_line("markers", "skipif_no_nix: skip when the `nix` binary is not on PATH") config.addinivalue_line("markers", "skipif_macos_bug: skip known-buggy tests on macOS") config.addinivalue_line("markers", "skipif_macos: skip tests not intended to run on macOS") @@ -122,6 +128,7 @@ def pytest_collection_modifyitems(config, items): "skipif_no_openai": (not os.getenv("OPENAI_API_KEY"), "OPENAI_API_KEY not 
set"), "skipif_no_alibaba": (not os.getenv("ALIBABA_API_KEY"), "ALIBABA_API_KEY not set"), "skipif_no_ros": (not _has_ros(), "ROS dependencies are not present"), + "skipif_no_nix": (not _has_nix(), "nix binary is not on PATH"), "skipif_macos_bug": (_is_macos(), "Some tests are buggy on Mac OS"), "skipif_macos": (_is_macos(), "Not intended to run on macOS"), } diff --git a/dimos/control/blueprints/mobile.py b/dimos/control/blueprints/mobile.py index c5065ea8d4..81fb085617 100644 --- a/dimos/control/blueprints/mobile.py +++ b/dimos/control/blueprints/mobile.py @@ -198,8 +198,7 @@ def _flowbase_twist_base( ) .remappings( [ - (FastLio2, "lidar", "registered_scan"), - (FastLio2, "global_map", "global_map_fastlio"), + (FastLio2, "global_map", "_global_map_fastlio"), # SimplePlanner / FarPlanner owns way_point — disconnect MovementManager's # redundant pass-through copy (matches unitree-g1-nav-onboard). (MovementManager, "way_point", "_mgr_way_point_unused"), diff --git a/dimos/core/module.py b/dimos/core/module.py index 259118098f..939d5de3f1 100644 --- a/dimos/core/module.py +++ b/dimos/core/module.py @@ -247,11 +247,13 @@ def __setstate__(self, state) -> None: # type: ignore[no-untyped-def] self._tools = {} self._tools_lock = threading.Lock() + _tf_lock: threading.Lock = threading.Lock() + @property def tf(self): # type: ignore[no-untyped-def] - if self._tf is None: - # self._tf = self.config.tf_transport() - self._tf = LCMTF() + with self._tf_lock: + if self._tf is None: + self._tf = LCMTF() return self._tf @tf.setter diff --git a/dimos/core/native_module.py b/dimos/core/native_module.py index e24e425460..23eda3156e 100644 --- a/dimos/core/native_module.py +++ b/dimos/core/native_module.py @@ -179,6 +179,15 @@ def __init__(self, **kwargs: Any) -> None: if not Path(self.config.executable).is_absolute() and self.config.cwd is not None: self.config.executable = str(Path(self.config.cwd) / self.config.executable) + @rpc + def build(self) -> None: + # Heavy one-time work 
(cargo/cmake/nix builds, LFS) belongs in build(), + # not start(). Running it in start() blocks Popen and lets upstream + # publishers pump messages before the subprocess's LCM subscriptions + # are live, which causes flaky data loss in tests. + super().build() + self._maybe_build() + @rpc def start(self) -> None: super().start() @@ -190,8 +199,6 @@ def start(self) -> None: ) return - self._maybe_build() - topics = self._collect_topics() cmd = [self.config.executable] diff --git a/dimos/core/test_core.py b/dimos/core/test_core.py index 3d4ade88d4..22cbca31b3 100644 --- a/dimos/core/test_core.py +++ b/dimos/core/test_core.py @@ -77,7 +77,7 @@ def test_classmethods() -> None: # Check that we have the expected RPC methods assert "navigate_to" in class_rpcs, "navigate_to should be in rpcs" assert "start" in class_rpcs, "start should be in rpcs" - assert len(class_rpcs) == 7 + assert len(class_rpcs) == 8 # Check that the values are callable assert callable(class_rpcs["navigate_to"]), "navigate_to should be callable" diff --git a/dimos/hardware/sensors/lidar/fastlio2/cpp/main.cpp b/dimos/hardware/sensors/lidar/fastlio2/cpp/main.cpp index 5c53381aa3..d91f447133 100644 --- a/dimos/hardware/sensors/lidar/fastlio2/cpp/main.cpp +++ b/dimos/hardware/sensors/lidar/fastlio2/cpp/main.cpp @@ -510,6 +510,11 @@ int main(int argc, char** argv) { if (debug) printf("[fastlio2] SDK started, waiting for device...\n"); + // NativeModule.start() in Python reads stderr for this marker and only + // returns once it sees it. 
+ fprintf(stderr, "[DIMOS_NATIVE_READY]\n"); + fflush(stderr); + // Main loop auto frame_interval = std::chrono::microseconds( static_cast(1e6 / g_frequency)); diff --git a/dimos/hardware/sensors/lidar/fastlio2/fastlio_blueprints.py b/dimos/hardware/sensors/lidar/fastlio2/fastlio_blueprints.py index 89f3e82ab8..c28c6ea8f8 100644 --- a/dimos/hardware/sensors/lidar/fastlio2/fastlio_blueprints.py +++ b/dimos/hardware/sensors/lidar/fastlio2/fastlio_blueprints.py @@ -12,20 +12,109 @@ # See the License for the specific language governing permissions and # limitations under the License. +from pathlib import Path +from typing import TYPE_CHECKING + +from reactivex.disposable import Disposable from dimos.core.coordination.blueprints import autoconnect +from dimos.core.core import rpc +from dimos.core.stream import In, Out from dimos.hardware.sensors.lidar.fastlio2.module import FastLio2 +from dimos.mapping.ray_tracing.module import RayTracingVoxelMap from dimos.mapping.voxels import VoxelGridMapper +from dimos.memory2.module import MemoryModule, MemoryModuleConfig, Recorder, RecorderConfig +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.nav_msgs.Odometry import Odometry +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 +from dimos.utils.testing.replay import timed_playback from dimos.visualization.vis_module import vis_module +if TYPE_CHECKING: + from rerun._baseclasses import Archetype + + +class FastlioMemoryConfig(RecorderConfig): + db_path: str | Path = "recording_fastlio.db" + default_frame_id: str = "base_link" + + voxel_size = 0.05 +class FastlioMemory(Recorder): + config: FastlioMemoryConfig + lidar: In[PointCloud2] + odometry: In[Odometry] + + @rpc + def start(self) -> None: + super().start() + + def _on_odom(msg: Odometry) -> None: + self.tf.publish(Transform.from_odometry(msg)) + + self.register_disposable(Disposable(self.odometry.subscribe(_on_odom))) + + +class FastlioReplayConfig(MemoryModuleConfig): + db_path: str | Path = 
"recording_fastlio.db" + speed: float = 1.0 + + +class FastlioReplay(MemoryModule): + """Replays a FastLIO2 recording (lidar + odometry) at real-time speed. + + Drop-in replacement for ``FastLio2`` when feeding rerun off a recorded session. + Publishes odometry to tf so downstream visualizers see robot pose. + """ + + config: FastlioReplayConfig + lidar: Out[PointCloud2] + odometry: Out[Odometry] + + @rpc + def start(self) -> None: + super().start() + + lidar_stream = self.store.stream("lidar", PointCloud2) + odom_stream = self.store.stream("odometry", Odometry) + + def _publish_odom(msg: Odometry) -> None: + self.tf.publish(Transform.from_odometry(msg)) + self.odometry.publish(msg) + + speed = self.config.speed + + self.register_disposable( + timed_playback( + lambda: ((obs.ts, obs.data) for obs in lidar_stream), + speed=speed, + ).subscribe(self.lidar.publish) + ) + self.register_disposable( + timed_playback( + lambda: ((obs.ts, obs.data) for obs in odom_stream), + speed=speed, + ).subscribe(_publish_odom) + ) + + +def _convert_global_map(msg: PointCloud2) -> "Archetype": + return msg.to_rerun(voxel_size=voxel_size) + + mid360_fastlio = autoconnect( FastLio2.blueprint(voxel_size=voxel_size, map_voxel_size=voxel_size, map_freq=-1), vis_module("rerun"), ).global_config(n_workers=2, robot_model="mid360_fastlio2") +mid360_fastlio_memory = autoconnect( + FastLio2.blueprint(voxel_size=voxel_size, map_voxel_size=voxel_size, map_freq=-1), + vis_module("rerun"), + FastlioMemory.blueprint(), +).global_config(n_workers=3, robot_model="mid360_fastlio2_memory") + mid360_fastlio_voxels = autoconnect( FastLio2.blueprint(), VoxelGridMapper.blueprint(voxel_size=voxel_size, carve_columns=False), @@ -39,6 +128,31 @@ ), ).global_config(n_workers=3, robot_model="mid360_fastlio2_voxels") +mid360_fastlio_replay = autoconnect( + FastlioReplay.blueprint(), + vis_module( + "rerun", + rerun_config={ + "visual_override": { + "world/global_map": _convert_global_map, + }, + }, + ), 
+).global_config(n_workers=2, robot_model="mid360_fastlio2_replay") + +mid360_fastlio_replay_voxels = autoconnect( + FastlioReplay.blueprint(), + VoxelGridMapper.blueprint(voxel_size=voxel_size, carve_columns=True), + vis_module( + "rerun", + rerun_config={ + "visual_override": { + "world/global_map": _convert_global_map, + }, + }, + ), +).global_config(n_workers=2, robot_model="mid360_fastlio2_replay") + mid360_fastlio_voxels_native = autoconnect( FastLio2.blueprint(voxel_size=voxel_size, map_voxel_size=voxel_size, map_freq=3.0), vis_module( @@ -50,3 +164,31 @@ }, ), ).global_config(n_workers=2, robot_model="mid360_fastlio2") + + +mid360_fastlio_ray_trace_replay = autoconnect( + FastlioReplay.blueprint(), + RayTracingVoxelMap.blueprint(voxel_size=voxel_size), + vis_module( + "rerun", + rerun_config={ + "visual_override": { + "world/lidar": None, + }, + }, + ), +).global_config(n_workers=3, robot_model="mid360_fastlio2_ray_trace_replay") + + +mid360_fastlio_ray_trace = autoconnect( + FastLio2.blueprint(voxel_size=voxel_size, map_voxel_size=voxel_size, map_freq=-1), + RayTracingVoxelMap.blueprint(voxel_size=voxel_size), + vis_module( + "rerun", + rerun_config={ + "visual_override": { + "world/lidar": None, + }, + }, + ), +).global_config(n_workers=3, robot_model="mid360_fastlio2_ray_trace") diff --git a/dimos/hardware/sensors/lidar/fastlio2/module.py b/dimos/hardware/sensors/lidar/fastlio2/module.py index 2c8ab22dfc..8b1d03db87 100644 --- a/dimos/hardware/sensors/lidar/fastlio2/module.py +++ b/dimos/hardware/sensors/lidar/fastlio2/module.py @@ -61,7 +61,6 @@ from dimos.msgs.geometry_msgs.Vector3 import Vector3 from dimos.msgs.nav_msgs.Odometry import Odometry from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 -from dimos.navigation.nav_stack.frames import FRAME_BODY, FRAME_ODOM from dimos.spec import mapping, perception from dimos.utils.generic import get_local_ips from dimos.utils.logging_config import setup_logger @@ -83,11 +82,9 @@ class 
FastLio2Config(NativeModuleConfig): # Converted to init_pose CLI arg [x, y, z, qx, qy, qz, qw] in model_post_init. mount: Pose = Pose() - # Frame IDs for output messages. "odom" reflects that FastLio2 provides - # locally-smooth, continuous odometry (no loop-closure jumps). PGO - # publishes the map→odom correction via TF. - frame_id: str = FRAME_ODOM - child_frame_id: str = FRAME_BODY + frame_id: str = "start_point" + child_frame_id: str = "current_point" + sensor_frame: str = "mid360_link" # FAST-LIO internal processing rates msr_freq: float = 50.0 @@ -127,7 +124,7 @@ class FastLio2Config(NativeModuleConfig): host_imu_data_port: int = SDK_HOST_IMU_DATA_PORT host_log_data_port: int = SDK_HOST_LOG_DATA_PORT - # Resolved in __post_init__, passed as --config_path to the binary + # Resolved from `config` in model_post_init, passed as --config_path to the binary config_path: str | None = None # init_pose is computed from mount; config is resolved to config_path @@ -169,10 +166,11 @@ def start(self) -> None: ) def _on_odom_for_tf(self, msg: Odometry) -> None: + ts = msg.ts or time.time() self.tf.publish( Transform( - frame_id=FRAME_ODOM, - child_frame_id=FRAME_BODY, + frame_id=self.config.frame_id, + child_frame_id=self.config.child_frame_id, translation=Vector3( msg.pose.position.x, msg.pose.position.y, @@ -184,7 +182,23 @@ def _on_odom_for_tf(self, msg: Odometry) -> None: msg.pose.orientation.z, msg.pose.orientation.w, ), - ts=msg.ts or time.time(), + ts=ts, + ) + ) + # Static sensor mount + mount = self.config.mount + self.tf.publish( + Transform( + frame_id=self.config.child_frame_id, + child_frame_id=self.config.sensor_frame, + translation=Vector3(mount.x, mount.y, mount.z), + rotation=Quaternion( + mount.orientation.x, + mount.orientation.y, + mount.orientation.z, + mount.orientation.w, + ), + ts=ts, ) ) diff --git a/dimos/hardware/sensors/lidar/livox/cpp/main.cpp b/dimos/hardware/sensors/lidar/livox/cpp/main.cpp index cdf083ef3b..fd1d47d64f 100644 --- 
a/dimos/hardware/sensors/lidar/livox/cpp/main.cpp +++ b/dimos/hardware/sensors/lidar/livox/cpp/main.cpp @@ -297,6 +297,11 @@ int main(int argc, char** argv) { printf("[mid360] SDK started, waiting for device...\n"); + // NativeModule.start() in Python reads stderr for this marker and only + // returns once it sees it. + fprintf(stderr, "[DIMOS_NATIVE_READY]\n"); + fflush(stderr); + // Main loop: periodically emit accumulated point clouds auto frame_interval = std::chrono::microseconds( static_cast(1e6 / g_frequency)); diff --git a/dimos/mapping/ray_tracing/demo_clearing_scene.py b/dimos/mapping/ray_tracing/demo_clearing_scene.py new file mode 100644 index 0000000000..c22e91b048 --- /dev/null +++ b/dimos/mapping/ray_tracing/demo_clearing_scene.py @@ -0,0 +1,419 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright 2026 Dimensional Inc. +# SPDX-License-Identifier: Apache-2.0 +"""Synthetic test scene for the ray-tracing voxel module — Rerun preview. + +Scene layout (all in world frame, meters): + + Sensor at (0, 0, 1.0), looking down +x. + + Floor : z = 0, x ∈ [0.5, 6], y ∈ [-3, 3] + Wall : x = 6, y ∈ [-3, 3], z ∈ [0, 2.5] + Box : axis-aligned, x ∈ [4.0, 4.5], y ∈ [0.3, 1.1], z ∈ [0, 0.5] + — a static obstacle on the floor, sitting between the person + and the back wall. The person walks past it and partially + occludes it for several frames; the ray tracer must NOT + erase it during occlusion. 
+ Person : a thin vertical "wall of points" at x = 3, + y ∈ [person_y ± 0.3], z ∈ [0, 1.8] + +Per frame: + * The person walks across the field of view in the +y direction. + * The "lidar return" is the union of all voxel-grid surface points + that are not occluded by a closer surface from the sensor. + +This is a no-dimos sanity check — once the input geometry looks right +in Rerun, the same generator will be wrapped to feed PointCloud2 + +Odometry into RayTracingVoxelMap. +""" + +from __future__ import annotations + +import argparse +from collections.abc import Iterator +from dataclasses import dataclass +import time + +import numpy as np +import rerun as rr + +VOXEL_SIZE = 0.1 # meters per voxel edge + +SENSOR_ORIGIN = np.array([0.0, 0.0, 1.0], dtype=np.float32) + +WALL_X = 6.0 +WALL_Y = (-3.0, 3.0) +WALL_Z = (0.0, 2.5) + +# Floor stops at the wall — we never observe floor behind a wall from a +# sensor in front of it, so simulating it would just create ghost lidar +# returns past the wall. 
+FLOOR_X = (0.5, WALL_X) +FLOOR_Y = (-3.0, 3.0) +FLOOR_Z = 0.0 + +BOX_X = (4.0, 4.5) +BOX_Y = (0.3, 1.1) +BOX_Z = (0.0, 0.5) + +PERSON_X = 3.0 +PERSON_HALF_WIDTH = 0.3 +PERSON_HEIGHT = 1.8 +PERSON_Y_START = -2.0 +PERSON_Y_END = 2.0 + + +@dataclass +class Frame: + index: int + timestamp_s: float + sensor_origin: np.ndarray # (3,) float32 + points: np.ndarray # (N, 3) float32, world-frame + person_y: float | None # None on frame 0 (no person yet) + + +def _grid_axis(lo: float, hi: float, step: float) -> np.ndarray: + """Voxel-center positions covering [lo, hi).""" + return np.arange(lo, hi, step, dtype=np.float32) + np.float32(step / 2) + + +def _floor_points() -> np.ndarray: + xs = _grid_axis(FLOOR_X[0], FLOOR_X[1], VOXEL_SIZE) + ys = _grid_axis(FLOOR_Y[0], FLOOR_Y[1], VOXEL_SIZE) + grid_x, grid_y = np.meshgrid(xs, ys, indexing="ij") + z = np.full_like(grid_x, FLOOR_Z) + return np.stack([grid_x.ravel(), grid_y.ravel(), z.ravel()], axis=-1) + + +def _wall_points() -> np.ndarray: + ys = _grid_axis(WALL_Y[0], WALL_Y[1], VOXEL_SIZE) + zs = _grid_axis(WALL_Z[0], WALL_Z[1], VOXEL_SIZE) + grid_y, grid_z = np.meshgrid(ys, zs, indexing="ij") + x = np.full_like(grid_y, WALL_X) + return np.stack([x.ravel(), grid_y.ravel(), grid_z.ravel()], axis=-1) + + +def _person_points(person_y: float) -> np.ndarray: + ys = _grid_axis(person_y - PERSON_HALF_WIDTH, person_y + PERSON_HALF_WIDTH, VOXEL_SIZE) + zs = _grid_axis(0.0, PERSON_HEIGHT, VOXEL_SIZE) + grid_y, grid_z = np.meshgrid(ys, zs, indexing="ij") + x = np.full_like(grid_y, PERSON_X) + return np.stack([x.ravel(), grid_y.ravel(), grid_z.ravel()], axis=-1) + + +def _box_visible_face_points() -> np.ndarray: + """Three sensor-facing faces of the box (front, top, near side). + + From a sensor at (0, 0, +z), only the -x face, the +z face, and the + -y face are visible. The other three faces are hidden behind the box + itself, so we don't generate them — that way no self-occlusion check + is needed for the box. 
+ """ + # Front face: x = BOX_X[0], spans y × z + ys = _grid_axis(BOX_Y[0], BOX_Y[1], VOXEL_SIZE) + zs = _grid_axis(BOX_Z[0], BOX_Z[1], VOXEL_SIZE) + gy, gz = np.meshgrid(ys, zs, indexing="ij") + front = np.stack([np.full_like(gy, BOX_X[0]).ravel(), gy.ravel(), gz.ravel()], axis=-1) + + # Top face: z = BOX_Z[1], spans x × y + xs = _grid_axis(BOX_X[0], BOX_X[1], VOXEL_SIZE) + ys = _grid_axis(BOX_Y[0], BOX_Y[1], VOXEL_SIZE) + gx, gy = np.meshgrid(xs, ys, indexing="ij") + top = np.stack([gx.ravel(), gy.ravel(), np.full_like(gx, BOX_Z[1]).ravel()], axis=-1) + + # Near side: y = BOX_Y[0], spans x × z + xs = _grid_axis(BOX_X[0], BOX_X[1], VOXEL_SIZE) + zs = _grid_axis(BOX_Z[0], BOX_Z[1], VOXEL_SIZE) + gx, gz = np.meshgrid(xs, zs, indexing="ij") + near = np.stack([gx.ravel(), np.full_like(gx, BOX_Y[0]).ravel(), gz.ravel()], axis=-1) + + return np.concatenate([front, top, near], axis=0) + + +def _occluded_by_box(origin: np.ndarray, targets: np.ndarray) -> np.ndarray: + """Boolean mask, True for targets whose ray from `origin` is blocked + by the box AABB. Vectorized AABB slab test. + + Tolerance bands at t≈0 and t≈1 mean points sitting exactly on the box + surface (or behind it past the back face) aren't flagged as + self-occluding. + """ + deltas = targets - origin # (N, 3) + box_min = np.array([BOX_X[0], BOX_Y[0], BOX_Z[0]], dtype=np.float32) + box_max = np.array([BOX_X[1], BOX_Y[1], BOX_Z[1]], dtype=np.float32) + + eps = 1e-9 + safe_d = np.where(np.abs(deltas) < eps, eps, deltas) + t1 = (box_min - origin) / safe_d + t2 = (box_max - origin) / safe_d + t_min = np.minimum(t1, t2) + t_max = np.maximum(t1, t2) + t_enter = t_min.max(axis=1) + t_exit = t_max.min(axis=1) + + hits_box = t_enter <= t_exit + # The box is in front of the target if it enters at t < 1, after the + # origin (t > 0). A 1e-4 margin keeps box-surface points themselves + # from being flagged as occluded. 
+ return hits_box & (t_enter > 1e-4) & (t_enter < 1.0 - 1e-4) # type: ignore[no-any-return] + + +def _occluded_by_person(origin: np.ndarray, targets: np.ndarray, person_y: float) -> np.ndarray: + """Boolean mask, True for targets whose ray from `origin` is blocked + by the person standing at `person_y`. Vectorized over targets. + + A target is occluded if the ray origin→target crosses the person's + front plane (x = PERSON_X) at a (y, z) point inside the person's + rectangle, AND the crossing happens before the target itself. + """ + deltas = targets - origin # (N, 3) + dx = deltas[:, 0] + # Rays moving in -x or staying still can't be blocked by a +x plane. + forward = dx > 0 + safe_dx = np.where(forward, dx, 1.0) + t_p = (PERSON_X - origin[0]) / safe_dx # parametric distance to person plane + crosses_in_front = forward & (t_p > 0.0) & (t_p < 1.0) + y_at = origin[1] + t_p * deltas[:, 1] + z_at = origin[2] + t_p * deltas[:, 2] + inside_person = ( + (np.abs(y_at - person_y) < PERSON_HALF_WIDTH) & (z_at >= 0.0) & (z_at < PERSON_HEIGHT) + ) + return crosses_in_front & inside_person # type: ignore[no-any-return] + + +def _visible_points(person_y: float | None) -> np.ndarray: + """Lidar return from SENSOR_ORIGIN: floor + wall + box + (optional) + person. Floor / wall / box points are dropped if a closer surface + (the box itself, for floor/wall; or the person, for any of them) + blocks their ray from the sensor. 
+ """ + floor = _floor_points() + wall = _wall_points() + box = _box_visible_face_points() + + box_occ_floor = _occluded_by_box(SENSOR_ORIGIN, floor) + box_occ_wall = _occluded_by_box(SENSOR_ORIGIN, wall) + + if person_y is None: + return np.concatenate( # type: ignore[no-any-return] + [floor[~box_occ_floor], wall[~box_occ_wall], box], axis=0 + ).astype(np.float32) + + person_occ_floor = _occluded_by_person(SENSOR_ORIGIN, floor, person_y) + person_occ_wall = _occluded_by_person(SENSOR_ORIGIN, wall, person_y) + person_occ_box = _occluded_by_person(SENSOR_ORIGIN, box, person_y) + + person = _person_points(person_y) + return np.concatenate( # type: ignore[no-any-return] + [ + floor[~(box_occ_floor | person_occ_floor)], + wall[~(box_occ_wall | person_occ_wall)], + box[~person_occ_box], + person, + ], + axis=0, + ).astype(np.float32) + + +def synthetic_scene(num_frames: int = 60, frame_dt: float = 0.1) -> Iterator[Frame]: + """Yield frames one at a time. + + Frame 0: static scene (floor + back wall + box, no person). + Frames 1..num_frames-1: person walks from PERSON_Y_START to PERSON_Y_END. + """ + yield Frame( + index=0, + timestamp_s=0.0, + sensor_origin=SENSOR_ORIGIN.copy(), + points=_visible_points(person_y=None), + person_y=None, + ) + + if num_frames < 2: + return + walking_frames = num_frames - 1 + for i in range(walking_frames): + t = i / max(walking_frames - 1, 1) + person_y = PERSON_Y_START + t * (PERSON_Y_END - PERSON_Y_START) + frame_idx = i + 1 + yield Frame( + index=frame_idx, + timestamp_s=frame_idx * frame_dt, + sensor_origin=SENSOR_ORIGIN.copy(), + points=_visible_points(person_y=person_y), + person_y=person_y, + ) + + +def _classify_points(points: np.ndarray, person_y: float | None) -> np.ndarray: + """Per-point class id: 0=floor, 1=wall, 2=person, 3=box. Coloring only. + + Classification by which surface generated the point — floor/wall both + have voxels at z=0.05 (lowest row) so we can't disambiguate by z alone. 
+ """ + is_wall = np.abs(points[:, 0] - WALL_X) < 1e-3 + is_floor = np.abs(points[:, 2] - FLOOR_Z) < 1e-3 + in_box = ( + (points[:, 0] >= BOX_X[0] - 1e-3) + & (points[:, 0] <= BOX_X[1] + 1e-3) + & (points[:, 1] >= BOX_Y[0] - 1e-3) + & (points[:, 1] <= BOX_Y[1] + 1e-3) + & (points[:, 2] >= BOX_Z[0] - 1e-3) + & (points[:, 2] <= BOX_Z[1] + 1e-3) + ) + + classes = np.empty(len(points), dtype=np.uint8) + classes[:] = 1 # default to wall + classes[is_floor & ~is_wall] = 0 # floor (but a wall point at z=0.05 stays wall) + classes[in_box] = 3 # box overrides + + if person_y is not None: + is_person = (np.abs(points[:, 0] - PERSON_X) < 1e-3) & ( + np.abs(points[:, 1] - person_y) < PERSON_HALF_WIDTH + 1e-3 + ) + classes[is_person] = 2 # person overrides everything else + + return classes + + +CLASS_COLORS = np.array( + [ + [120, 120, 120], # floor — gray + [80, 160, 255], # wall — blue + [255, 80, 80], # person — red + [255, 180, 60], # box — orange + ], + dtype=np.uint8, +) + + +def log_to_rerun(num_frames: int, frame_dt: float, realtime: bool = True) -> None: + rr.log( + "world/sensor", + rr.Points3D( + positions=SENSOR_ORIGIN.reshape(1, 3), + colors=np.array([[0, 255, 0]], dtype=np.uint8), + radii=0.08, + labels=["sensor"], + ), + static=True, + ) + # Show the un-occluded reference scene as a faint backdrop so the user + # can see what the back-wall "should" look like behind the person. 
+ reference_floor = _floor_points().astype(np.float32) + reference_wall = _wall_points().astype(np.float32) + rr.log( + "world/reference/floor", + rr.Points3D( + positions=reference_floor, + colors=np.tile(np.array([[60, 60, 60]], dtype=np.uint8), (len(reference_floor), 1)), + radii=VOXEL_SIZE / 2 * 0.6, + ), + static=True, + ) + rr.log( + "world/reference/wall", + rr.Points3D( + positions=reference_wall, + colors=np.tile(np.array([[40, 60, 90]], dtype=np.uint8), (len(reference_wall), 1)), + radii=VOXEL_SIZE / 2 * 0.6, + ), + static=True, + ) + reference_box = _box_visible_face_points().astype(np.float32) + rr.log( + "world/reference/box", + rr.Points3D( + positions=reference_box, + colors=np.tile(np.array([[90, 70, 40]], dtype=np.uint8), (len(reference_box), 1)), + radii=VOXEL_SIZE / 2 * 0.6, + ), + static=True, + ) + + for frame in synthetic_scene(num_frames=num_frames, frame_dt=frame_dt): + # Single timeline in seconds — viewer plays it at 1× wall-clock by + # default, so 60 frames @ dt=0.1s plays as a 6-second video. + rr.set_time("time", duration=frame.timestamp_s) + + classes = _classify_points(frame.points, frame.person_y) + colors = CLASS_COLORS[classes] + + rr.log( + "world/lidar_return", + rr.Points3D( + positions=frame.points, + colors=colors, + radii=VOXEL_SIZE / 2, + ), + ) + + # Visualize a few sample rays from the sensor toward the back wall + # to make occlusion easy to read at a glance. 
+ sample_ys = np.linspace(WALL_Y[0] + 0.2, WALL_Y[1] - 0.2, 9, dtype=np.float32) + sample_targets = np.stack( + [ + np.full_like(sample_ys, WALL_X), + sample_ys, + np.full_like(sample_ys, 1.2), + ], + axis=-1, + ) + ray_origins = np.tile(SENSOR_ORIGIN, (len(sample_targets), 1)) + ray_strips = np.stack([ray_origins, sample_targets], axis=1) # (R, 2, 3) + rr.log( + "world/sample_rays", + rr.LineStrips3D( + strips=list(ray_strips), + colors=np.tile( + np.array([[200, 200, 80, 60]], dtype=np.uint8), (len(ray_strips), 1) + ), + radii=0.005, + ), + ) + + # Stream live: sleep between frames so the viewer renders them + # in order at real-time pace. Skipped for offline .rrd captures. + if realtime: + time.sleep(frame_dt) + + +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__.splitlines()[0]) + parser.add_argument("--frames", type=int, default=60, help="total frames (incl. frame 0)") + parser.add_argument("--dt", type=float, default=0.1, help="seconds per frame") + parser.add_argument( + "--save", + type=str, + default=None, + help="if set, save to this .rrd path instead of spawning the viewer", + ) + args = parser.parse_args() + + if args.save: + rr.init("ray_tracing_clearing_scene", spawn=False) + rr.save(args.save) + log_to_rerun(num_frames=args.frames, frame_dt=args.dt, realtime=False) + else: + rr.init("ray_tracing_clearing_scene", spawn=True) + # Give the spawned viewer a moment to connect before we start + # streaming, otherwise the first few frames can be missed. + time.sleep(1.0) + log_to_rerun(num_frames=args.frames, frame_dt=args.dt, realtime=True) + + +if __name__ == "__main__": + main() diff --git a/dimos/mapping/ray_tracing/module.py b/dimos/mapping/ray_tracing/module.py new file mode 100644 index 0000000000..c1b07e90b2 --- /dev/null +++ b/dimos/mapping/ray_tracing/module.py @@ -0,0 +1,90 @@ +# Copyright 2026 Dimensional Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Native Rust voxel-map module with raycast clearing. + +Subscribes to a world-frame ``PointCloud2`` (e.g. from FastLio2's +``lidar`` output) and matching ``Odometry``, maintains a global +voxel hash set, and publishes the accumulated map on ``global_map`` +as a :class:`DynamicCloud` (per-voxel health + slow-clock sequence +stamp). + +Algorithm (v1): + * Insert the voxel of every point into the global hash set. + * For every point, walk the 3-D DDA ray from the latest + odometry position to the point and remove every intermediate + voxel from the map. The endpoint voxel is preserved. + * A "slow clock" sequence counter increments every + ``sequence_period_secs`` (default 1.0s). Any voxel touched + while still uncertain (health <= 0) is stamped with the + current sequence value; once health > 0 the stamp freezes, + capturing "when did this voxel become confirmed." + +Map override: + Publishing to ``map_override`` with a :class:`DynamicCloud` + fully replaces the internal voxel state with the override's + contents. The slow-clock counter snaps to + ``max(override.sequence)``, even if that's less than the + current value — the override is authoritative. + +The Rust binary at ``rust/`` does the heavy lifting. 
+""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from dimos.core.native_module import NativeModule, NativeModuleConfig +from dimos.core.stream import In, Out +from dimos.msgs.nav_msgs.DynamicCloud import DynamicCloud +from dimos.msgs.nav_msgs.Odometry import Odometry +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 + + +class RayTracingVoxelMapConfig(NativeModuleConfig): + cwd: str | None = "rust" + executable: str = "target/release/voxel_ray_tracing" + build_command: str | None = "cargo build --release" + stdin_config: bool = True + + voxel_size: float = 0.1 + # Skip rays longer than this (meters); 0 disables the limit. + max_range: float = 30.0 + # Controls what portion of rays we perform ray tracing on. + # Honestly we probably should always have this at 1 unless you don't care about a clean map. + # Higher num means less ray tracing. + ray_subsample: int = 1 + # Extend rays past the end point to clear shadows + shadow_depth: float = 0.2 + # Bounds for the health of voxels. Positive health means voxel is occupied. + min_health: int = -1 + max_health: int = 1 + # Seconds between sequence-counter increments ("slow clock"). 
+ sequence_period_secs: float = 1.0 + + +class RayTracingVoxelMap(NativeModule): + """Rust voxel-map module with raycast clearing of dynamic objects.""" + + config: RayTracingVoxelMapConfig + + lidar: In[PointCloud2] + odometry: In[Odometry] + map_override: In[DynamicCloud] + global_map: Out[DynamicCloud] + + +# Verify protocol port compliance (mypy will flag missing ports) +if TYPE_CHECKING: + RayTracingVoxelMap() diff --git a/dimos/mapping/ray_tracing/rust/.gitignore b/dimos/mapping/ray_tracing/rust/.gitignore new file mode 100644 index 0000000000..2f7896d1d1 --- /dev/null +++ b/dimos/mapping/ray_tracing/rust/.gitignore @@ -0,0 +1 @@ +target/ diff --git a/dimos/mapping/ray_tracing/rust/Cargo.lock b/dimos/mapping/ray_tracing/rust/Cargo.lock new file mode 100644 index 0000000000..d760804d1e --- /dev/null +++ b/dimos/mapping/ray_tracing/rust/Cargo.lock @@ -0,0 +1,428 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 4 + +[[package]] +name = "ahash" +version = "0.8.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a15f179cd60c4584b8a8c596927aadc462e27f2ca70c04e0071964a73ba7a75" +dependencies = [ + "cfg-if", + "getrandom", + "once_cell", + "version_check", + "zerocopy", +] + +[[package]] +name = "byteorder" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" + +[[package]] +name = "bytes" +version = "1.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33" + +[[package]] +name = "cfg-if" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" + +[[package]] +name = "dimos-lcm" +version = "0.1.0" +source = 
"git+https://github.com/dimensionalOS/dimos-lcm.git?branch=rust-codegen#e7c9428b7201cdfeadecd181c77c9e2d60a14503" +dependencies = [ + "byteorder", + "socket2 0.5.10", + "tokio", +] + +[[package]] +name = "dimos-module" +version = "0.1.0" +dependencies = [ + "dimos-lcm", + "dimos-module-macros", + "serde", + "serde_json", + "tokio", +] + +[[package]] +name = "dimos-module-macros" +version = "0.1.0" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "dimos-voxel-ray-tracing" +version = "0.1.0" +dependencies = [ + "ahash", + "dimos-module", + "lcm-msgs", + "serde", + "tokio", +] + +[[package]] +name = "errno" +version = "0.3.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" +dependencies = [ + "libc", + "windows-sys 0.61.2", +] + +[[package]] +name = "getrandom" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd" +dependencies = [ + "cfg-if", + "libc", + "r-efi", + "wasip2", +] + +[[package]] +name = "itoa" +version = "1.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f42a60cbdf9a97f5d2305f08a87dc4e09308d1276d28c869c684d7777685682" + +[[package]] +name = "lcm-msgs" +version = "0.1.0" +source = "git+https://github.com/dimensionalOS/dimos-lcm.git?branch=rust-codegen#e7c9428b7201cdfeadecd181c77c9e2d60a14503" +dependencies = [ + "byteorder", +] + +[[package]] +name = "libc" +version = "0.2.186" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68ab91017fe16c622486840e4c83c9a37afeff978bd239b5293d61ece587de66" + +[[package]] +name = "memchr" +version = "2.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" + +[[package]] +name = "mio" +version = "1.2.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "50b7e5b27aa02a74bac8c3f23f448f8d87ff11f92d3aac1a6ed369ee08cc56c1" +dependencies = [ + "libc", + "wasi", + "windows-sys 0.61.2", +] + +[[package]] +name = "once_cell" +version = "1.21.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f7c3e4beb33f85d45ae3e3a1792185706c8e16d043238c593331cc7cd313b50" + +[[package]] +name = "pin-project-lite" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a89322df9ebe1c1578d689c92318e070967d1042b512afbe49518723f4e6d5cd" + +[[package]] +name = "proc-macro2" +version = "1.0.106" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41f2619966050689382d2b44f664f4bc593e129785a36d6ee376ddf37259b924" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "r-efi" +version = "5.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" + +[[package]] +name = "serde" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" +dependencies = [ + "serde_core", + "serde_derive", +] + +[[package]] +name = "serde_core" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" +dependencies = [ + 
"proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_json" +version = "1.0.149" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86" +dependencies = [ + "itoa", + "memchr", + "serde", + "serde_core", + "zmij", +] + +[[package]] +name = "signal-hook-registry" +version = "1.4.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4db69cba1110affc0e9f7bcd48bbf87b3f4fc7c61fc9155afd4c469eb3d6c1b" +dependencies = [ + "errno", + "libc", +] + +[[package]] +name = "socket2" +version = "0.5.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e22376abed350d73dd1cd119b57ffccad95b4e585a7cda43e286245ce23c0678" +dependencies = [ + "libc", + "windows-sys 0.52.0", +] + +[[package]] +name = "socket2" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a766e1110788c36f4fa1c2b71b387a7815aa65f88ce0229841826633d93723e" +dependencies = [ + "libc", + "windows-sys 0.61.2", +] + +[[package]] +name = "syn" +version = "2.0.117" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e665b8803e7b1d2a727f4023456bbbbe74da67099c585258af0ad9c5013b9b99" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "tokio" +version = "1.52.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fc7f01b389ac15039e4dc9531aa973a135d7a4135281b12d7c1bc79fd57fffe" +dependencies = [ + "bytes", + "libc", + "mio", + "pin-project-lite", + "signal-hook-registry", + "socket2 0.6.3", + "tokio-macros", + "windows-sys 0.61.2", +] + +[[package]] +name = "tokio-macros" +version = "2.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "385a6cb71ab9ab790c5fe8d67f1645e6c450a7ce006a33de03daa956cf70a496" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "unicode-ident" 
+version = "1.0.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" + +[[package]] +name = "version_check" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" + +[[package]] +name = "wasi" +version = "0.11.1+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" + +[[package]] +name = "wasip2" +version = "1.0.3+wasi-0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "20064672db26d7cdc89c7798c48a0fdfac8213434a1186e5ef29fd560ae223d6" +dependencies = [ + "wit-bindgen", +] + +[[package]] +name = "windows-link" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" + +[[package]] +name = "windows-sys" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" +dependencies = [ + "windows-targets", +] + +[[package]] +name = "windows-sys" +version = "0.61.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc" +dependencies = [ + "windows-link", +] + +[[package]] +name = "windows-targets" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" +dependencies = [ + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_gnullvm", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_gnullvm" 
+version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" + +[[package]] +name = "windows_i686_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" + +[[package]] +name = "windows_i686_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" + +[[package]] +name = "wit-bindgen" +version = "0.57.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ebf944e87a7c253233ad6766e082e3cd714b5d03812acc24c318f549614536e" + +[[package]] +name = "zerocopy" +version = "0.8.48" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"eed437bf9d6692032087e337407a86f04cd8d6a16a37199ed57949d415bd68e9" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.8.48" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70e3cd084b1788766f53af483dd21f93881ff30d7320490ec3ef7526d203bad4" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "zmij" +version = "1.0.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8848ee67ecc8aedbaf3e4122217aff892639231befc6a1b58d29fff4c2cabaa" diff --git a/dimos/mapping/ray_tracing/rust/Cargo.toml b/dimos/mapping/ray_tracing/rust/Cargo.toml new file mode 100644 index 0000000000..7dc2ca4c53 --- /dev/null +++ b/dimos/mapping/ray_tracing/rust/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "dimos-voxel-ray-tracing" +version = "0.1.0" +edition = "2021" +description = "Native Rust voxel-map module with raycast clearing for dimos" +license = "Apache-2.0" + +[[bin]] +name = "voxel_ray_tracing" +path = "src/main.rs" + +[dependencies] +dimos-module = { path = "../../../../native/rust/dimos-module" } +lcm-msgs = { git = "https://github.com/dimensionalOS/dimos-lcm.git", branch = "rust-codegen" } +tokio = { version = "1", features = ["rt-multi-thread", "macros", "signal"] } +serde = { version = "1", features = ["derive"] } +ahash = "0.8" + +[profile.release] +lto = "thin" +codegen-units = 1 diff --git a/dimos/mapping/ray_tracing/rust/src/dynamic_cloud.rs b/dimos/mapping/ray_tracing/rust/src/dynamic_cloud.rs new file mode 100644 index 0000000000..0c960bd3b2 --- /dev/null +++ b/dimos/mapping/ray_tracing/rust/src/dynamic_cloud.rs @@ -0,0 +1,387 @@ +// Copyright 2026 Dimensional Inc. +// SPDX-License-Identifier: Apache-2.0 +// +// DynamicCloud: a per-voxel point cloud + sparse timestamped event log. +// +// Mirrors `dimos/msgs/nav_msgs/DynamicCloud.py`. 
// DynamicCloud: a per-voxel point cloud + sparse timestamped event log.
//
// Wire format (little-endian, packed):
//
//   u64       timestamp_nanos   // overall message timestamp
//   f32       voxel_size        // meters per voxel edge
//   u16       frame_id_len
//   bytes     frame_id          // utf-8, frame_id_len bytes
//   u32       num_points        // N
//   i32[N*3]  voxels            // (x, y, z) interleaved
//   u32[N]    quantity          // per-point quantity
//   u32       num_events        // M
//   u32[M]    event_indices     // indices into voxels (0 <= idx < N)
//   u64[M]    event_timestamps  // nanoseconds
//
// `num_events` is independent of `num_points`; events can be empty,
// can reference the same point multiple times, and don't need to cover
// every point.

#[derive(Debug, Clone, PartialEq)]
pub struct DynamicCloud {
    pub timestamp_nanos: u64,
    pub voxel_size: f32,
    pub frame_id: String,
    /// Voxel keys (signed integer coords in voxel-grid space).
    pub voxels: Vec<(i32, i32, i32)>,
    /// Per-point unsigned integer (e.g. voxel health/hit count).
    pub quantity: Vec<u32>,
    /// Sparse event log: indices into `voxels`.
    pub event_indices: Vec<u32>,
    /// Sparse event log: timestamp (nanoseconds) for each event.
    pub event_timestamps: Vec<u64>,
}

/// Reasons [`DynamicCloud::decode`] rejects a buffer.
#[derive(Debug)]
pub enum DecodeError {
    /// The buffer ends before a section that its own headers announce.
    Truncated { needed: usize, got: usize },
    InvalidUtf8(std::str::Utf8Error),
    /// The bytes after `num_events` do not match `num_events * 12` exactly.
    PayloadSizeMismatch { expected: usize, got: usize },
    EventIndexOutOfRange { index: u32, num_points: u32 },
}

impl std::fmt::Display for DecodeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            DecodeError::Truncated { needed, got } => {
                write!(f, "DynamicCloud: truncated (needed {needed}, got {got})")
            }
            DecodeError::InvalidUtf8(e) => write!(f, "DynamicCloud: invalid utf-8: {e}"),
            DecodeError::PayloadSizeMismatch { expected, got } => write!(
                f,
                "DynamicCloud: payload size mismatch (expected {expected} tail bytes, got {got})"
            ),
            DecodeError::EventIndexOutOfRange { index, num_points } => write!(
                f,
                "DynamicCloud: event index {index} out of range for {num_points} points"
            ),
        }
    }
}

impl std::error::Error for DecodeError {}

/// Reasons [`DynamicCloud::encode`] refuses to serialize a value.
#[derive(Debug)]
pub enum EncodeError {
    FrameIdTooLong(usize),
    EventLengthMismatch { indices: usize, timestamps: usize },
}

impl std::fmt::Display for EncodeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            EncodeError::FrameIdTooLong(n) => {
                write!(f, "DynamicCloud: frame_id too long ({n} > 65535 bytes)")
            }
            EncodeError::EventLengthMismatch { indices, timestamps } => write!(
                f,
                "DynamicCloud: event arrays length mismatch (indices={indices}, timestamps={timestamps})"
            ),
        }
    }
}

impl std::error::Error for EncodeError {}

/// timestamp_nanos (8) + voxel_size (4) + frame_id_len (2).
const HEADER_SIZE: usize = 8 + 4 + 2;
const U32_SIZE: usize = 4;
const U64_SIZE: usize = 8;
/// Three i32 coordinates per voxel.
const VOXEL_BYTES: usize = 3 * 4;

/// Read a little-endian u32 from the first 4 bytes of `b` (caller checks length).
#[inline]
fn le_u32(b: &[u8]) -> u32 {
    u32::from_le_bytes([b[0], b[1], b[2], b[3]])
}

/// Read a little-endian i32 from the first 4 bytes of `b` (caller checks length).
#[inline]
fn le_i32(b: &[u8]) -> i32 {
    i32::from_le_bytes([b[0], b[1], b[2], b[3]])
}

/// Read a little-endian u64 from the first 8 bytes of `b` (caller checks length).
#[inline]
fn le_u64(b: &[u8]) -> u64 {
    u64::from_le_bytes([b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]])
}

impl DynamicCloud {
    /// Create an empty cloud with timestamp 0.
    #[allow(dead_code)] // public API, used by tests
    pub fn new(voxel_size: f32, frame_id: impl Into<String>) -> Self {
        Self {
            timestamp_nanos: 0,
            voxel_size,
            frame_id: frame_id.into(),
            voxels: Vec::new(),
            quantity: Vec::new(),
            event_indices: Vec::new(),
            event_timestamps: Vec::new(),
        }
    }

    /// Number of voxels in the cloud.
    #[allow(dead_code)] // public API, used by tests
    pub fn len(&self) -> usize {
        self.voxels.len()
    }

    /// Serialize into the wire format described at the top of this file.
    ///
    /// Fails when `frame_id` is longer than 65535 bytes or when the two event
    /// arrays differ in length. If `voxels` and `quantity` differ in length,
    /// only the common prefix is written. Event indices are not validated
    /// here; `decode` rejects out-of-range ones.
    pub fn encode(&self) -> Result<Vec<u8>, EncodeError> {
        let frame_bytes = self.frame_id.as_bytes();
        if frame_bytes.len() > u16::MAX as usize {
            return Err(EncodeError::FrameIdTooLong(frame_bytes.len()));
        }
        if self.event_indices.len() != self.event_timestamps.len() {
            return Err(EncodeError::EventLengthMismatch {
                indices: self.event_indices.len(),
                timestamps: self.event_timestamps.len(),
            });
        }

        let num_points = self.voxels.len().min(self.quantity.len());
        let num_events = self.event_indices.len();

        let total = HEADER_SIZE
            + frame_bytes.len()
            + U32_SIZE
            + num_points * (VOXEL_BYTES + U32_SIZE)
            + U32_SIZE
            + num_events * (U32_SIZE + U64_SIZE);

        let mut buf = Vec::with_capacity(total);
        buf.extend_from_slice(&self.timestamp_nanos.to_le_bytes());
        buf.extend_from_slice(&self.voxel_size.to_le_bytes());
        buf.extend_from_slice(&(frame_bytes.len() as u16).to_le_bytes());
        buf.extend_from_slice(frame_bytes);
        buf.extend_from_slice(&(num_points as u32).to_le_bytes());

        for &(x, y, z) in &self.voxels[..num_points] {
            buf.extend_from_slice(&x.to_le_bytes());
            buf.extend_from_slice(&y.to_le_bytes());
            buf.extend_from_slice(&z.to_le_bytes());
        }
        for &q in &self.quantity[..num_points] {
            buf.extend_from_slice(&q.to_le_bytes());
        }

        buf.extend_from_slice(&(num_events as u32).to_le_bytes());
        for &idx in &self.event_indices {
            buf.extend_from_slice(&idx.to_le_bytes());
        }
        for &t in &self.event_timestamps {
            buf.extend_from_slice(&t.to_le_bytes());
        }

        Ok(buf)
    }

    /// Parse a buffer in the wire format described at the top of this file.
    ///
    /// Every length is validated before the bytes it covers are read, and
    /// nothing is allocated for a section until the buffer is known to contain
    /// it, so a bogus count in a short message cannot trigger a huge
    /// allocation or an out-of-bounds panic.
    pub fn decode(data: &[u8]) -> Result<Self, DecodeError> {
        let truncated = |needed: usize| DecodeError::Truncated {
            needed,
            got: data.len(),
        };

        if data.len() < HEADER_SIZE {
            return Err(truncated(HEADER_SIZE));
        }
        let timestamp_nanos = le_u64(&data[0..8]);
        let voxel_size = f32::from_bits(le_u32(&data[8..12]));
        let frame_id_len = usize::from(u16::from_le_bytes([data[12], data[13]]));
        let mut offset = HEADER_SIZE;

        // frame_id plus the num_points counter that follows it.
        let needed = offset + frame_id_len + U32_SIZE;
        if data.len() < needed {
            return Err(truncated(needed));
        }

        let frame_id = std::str::from_utf8(&data[offset..offset + frame_id_len])
            .map_err(DecodeError::InvalidUtf8)?
            .to_string();
        offset += frame_id_len;

        let num_points = le_u32(&data[offset..]) as usize;
        offset += U32_SIZE;

        // `num_points` comes straight off the wire. Checked arithmetic keeps a
        // hostile count from wrapping the size computation on narrow targets
        // and slipping past the length check below.
        let needed_after_points = num_points
            .checked_mul(VOXEL_BYTES + U32_SIZE)
            .and_then(|bytes| bytes.checked_add(offset + U32_SIZE))
            .unwrap_or(usize::MAX);
        if data.len() < needed_after_points {
            return Err(truncated(needed_after_points));
        }
        // Cannot overflow: both products are bounded by data.len() now.
        let voxels_bytes = num_points * VOXEL_BYTES;
        let quantity_bytes = num_points * U32_SIZE;

        let voxels: Vec<(i32, i32, i32)> = data[offset..offset + voxels_bytes]
            .chunks_exact(VOXEL_BYTES)
            .map(|c| (le_i32(&c[0..4]), le_i32(&c[4..8]), le_i32(&c[8..12])))
            .collect();
        offset += voxels_bytes;

        let quantity: Vec<u32> = data[offset..offset + quantity_bytes]
            .chunks_exact(U32_SIZE)
            .map(le_u32)
            .collect();
        offset += quantity_bytes;

        let num_events = le_u32(&data[offset..]) as usize;
        offset += U32_SIZE;

        // The event section must account for every remaining byte.
        let tail = &data[offset..];
        let expected_tail = num_events
            .checked_mul(U32_SIZE + U64_SIZE)
            .unwrap_or(usize::MAX);
        if tail.len() != expected_tail {
            return Err(DecodeError::PayloadSizeMismatch {
                expected: expected_tail,
                got: tail.len(),
            });
        }
        let (index_bytes, timestamp_bytes) = tail.split_at(num_events * U32_SIZE);

        let mut event_indices = Vec::with_capacity(num_events);
        for chunk in index_bytes.chunks_exact(U32_SIZE) {
            let idx = le_u32(chunk);
            // Also rejects every index when the cloud has no points.
            if idx as usize >= num_points {
                return Err(DecodeError::EventIndexOutOfRange {
                    index: idx,
                    num_points: num_points as u32,
                });
            }
            event_indices.push(idx);
        }

        let event_timestamps: Vec<u64> = timestamp_bytes
            .chunks_exact(U64_SIZE)
            .map(le_u64)
            .collect();

        Ok(Self {
            timestamp_nanos,
            voxel_size,
            frame_id,
            voxels,
            quantity,
            event_indices,
            event_timestamps,
        })
    }
}
+ const KNOWN_BYTES_HEX: &str = concat!( + "002f685900000000", // ts_ns = 1_500_000_000 LE + "0000803e", // voxel_size = 0.25 f32 LE + "0300", // frame_id_len = 3 + "6d6170", // "map" + "02000000", // num_points = 2 + "01000000feffffff03000000", + "0400000005000000faffffff", + "0700000008000000", + "03000000", // num_events = 3 + "000000000100000000000000", + "00ca9a3b00000000", + "0094357700000000", + "002f685900000000", + ); + + fn hex_to_bytes(s: &str) -> Vec { + (0..s.len()) + .step_by(2) + .map(|i| u8::from_str_radix(&s[i..i + 2], 16).unwrap()) + .collect() + } + + #[test] + fn roundtrip() { + let cloud = fixture(); + let bytes = cloud.encode().expect("encode"); + let decoded = DynamicCloud::decode(&bytes).expect("decode"); + assert_eq!(cloud, decoded); + } + + #[test] + fn known_bytes_matches_python() { + let cloud = fixture(); + let bytes = cloud.encode().expect("encode"); + let expected = hex_to_bytes(KNOWN_BYTES_HEX); + assert_eq!(bytes, expected, "encoded bytes drift from python fixture"); + } + + #[test] + fn decode_known_bytes() { + let bytes = hex_to_bytes(KNOWN_BYTES_HEX); + let decoded = DynamicCloud::decode(&bytes).expect("decode"); + assert_eq!(decoded, fixture()); + } + + #[test] + fn empty_cloud_roundtrip() { + let cloud = DynamicCloud::new(0.1, "world"); + let bytes = cloud.encode().expect("encode"); + let decoded = DynamicCloud::decode(&bytes).expect("decode"); + assert_eq!(cloud, decoded); + assert_eq!(decoded.len(), 0); + assert!(decoded.event_indices.is_empty()); + } + + #[test] + fn truncated_returns_err() { + assert!(matches!( + DynamicCloud::decode(&[0u8; 4]), + Err(DecodeError::Truncated { .. }) + )); + } + + #[test] + fn payload_size_mismatch_returns_err() { + let mut bytes = fixture().encode().unwrap(); + bytes.pop(); // chop a byte off the tail + assert!(matches!( + DynamicCloud::decode(&bytes), + Err(DecodeError::PayloadSizeMismatch { .. 
}) + )); + } + + #[test] + fn event_index_out_of_range_returns_err() { + let cloud = DynamicCloud { + timestamp_nanos: 0, + voxel_size: 0.1, + frame_id: "x".to_string(), + voxels: vec![(0, 0, 0)], + quantity: vec![1], + event_indices: vec![5], + event_timestamps: vec![123], + }; + let bytes = cloud.encode().unwrap(); + assert!(matches!( + DynamicCloud::decode(&bytes), + Err(DecodeError::EventIndexOutOfRange { + index: 5, + num_points: 1 + }) + )); + } +} diff --git a/dimos/mapping/ray_tracing/rust/src/main.rs b/dimos/mapping/ray_tracing/rust/src/main.rs new file mode 100644 index 0000000000..74f3247500 --- /dev/null +++ b/dimos/mapping/ray_tracing/rust/src/main.rs @@ -0,0 +1,862 @@ +// Copyright 2026 Dimensional Inc. +// SPDX-License-Identifier: Apache-2.0 +// +// Native Rust voxel-map module with raycast clearing. +// +// Algorithm (v1): +// * Raise the health of the voxel of every point (a hit) in the global hash map. +// * For every point, walk the 3D-DDA ray from the sensor origin +// (latest odometry pose) toward the point, lowering the health of every +// intermediate voxel (a miss); voxels that fall to `min_health` are removed. +// The endpoint voxel itself counts as a hit, never as a miss. +// +// Inputs (LCM topics, set by the dimos NativeModule coordinator): +// * `lidar` : sensor_msgs::PointCloud2 (world frame) +// * `odometry` : nav_msgs::Odometry (world frame) +// * `map_override` : DynamicCloud (replaces the whole map) +// Output: +// * `global_map` : DynamicCloud, see dynamic_cloud.rs (world frame) +// +// PointCloud2 input must be little-endian and carry float32 `x`/`y`/`z` +// fields; their byte offsets are read from the message's field table.
/// Coarse clock that only moves in steps of a configurable period.
///
/// `current_nanos` is the timestamp handed out for everything observed during
/// the current period; it changes only when `advance` crosses a period
/// boundary (or on `reset_to`).
#[derive(Default)]
struct SlowClock {
    /// Quantized nanosecond timestamp of the current slow tick.
    current_nanos: u64,
    /// Wall-clock seconds (from msg timestamps) when the next tick fires.
    /// `None` until the first lidar message arrives.
    next_tick_secs: Option<f64>,
}

impl SlowClock {
    /// Advance the clock to `now_secs`, snapping `current_nanos` forward to
    /// the most recent `period_secs` boundary that has been crossed.
    ///
    /// Periods below `f32::EPSILON` are raised to it. Non-finite `now_secs`
    /// is ignored. Going back in time is a no-op.
    fn advance(&mut self, now_secs: f64, period_secs: f32) {
        if !now_secs.is_finite() {
            // A NaN/inf stamp would wedge the schedule permanently.
            return;
        }
        let period = period_secs.max(f32::EPSILON) as f64;
        match self.next_tick_secs {
            None => {
                // First sample primes the schedule. Stamp anything observed
                // before the first boundary crossing with this initial time.
                self.current_nanos = secs_to_nanos(now_secs);
                self.next_tick_secs = Some(now_secs + period);
            }
            Some(next) if now_secs >= next => {
                // Jump straight to the last boundary at or before `now_secs`.
                // Stepping one period at a time costs O(gap / period), which
                // stalls the handler for seconds when the stamps jump (e.g.
                // from simulated time to epoch time).
                let whole_periods = ((now_secs - next) / period).floor();
                let last_tick = if whole_periods > 0.0 {
                    next + whole_periods * period
                } else {
                    // Avoids 0 * inf = NaN when the period is infinite.
                    next
                };
                self.current_nanos = secs_to_nanos(last_tick);
                self.next_tick_secs = Some(last_tick + period);
            }
            Some(_) => {}
        }
    }

    /// Force the current tick to `timestamp_nanos` and schedule the next tick
    /// one period after `now_secs`.
    fn reset_to(&mut self, timestamp_nanos: u64, now_secs: f64, period_secs: f32) {
        self.current_nanos = timestamp_nanos;
        self.next_tick_secs = Some(now_secs + period_secs.max(f32::EPSILON) as f64);
    }
}

/// Convert seconds to whole nanoseconds; zero and negative inputs map to 0.
fn secs_to_nanos(s: f64) -> u64 {
    if s <= 0.0 {
        0
    } else {
        // Float-to-int `as` saturates, so oversized values become u64::MAX.
        (s * 1e9) as u64
    }
}
+ return; + }; + + let voxel_size = self.config.voxel_size; + if voxel_size <= 0.0 { + eprintln!("voxel_ray_tracing: voxel_size must be > 0, got {voxel_size}"); + return; + } + + let points = match extract_xyz(&msg) { + Ok(p) => p, + Err(e) => { + eprintln!("voxel_ray_tracing: bad cloud, dropped: {e}"); + return; + } + }; + if points.is_empty() { + return; + } + + let now_secs = time_to_secs(&msg.header.stamp); + self.last_lidar_secs = Some(now_secs); + self.clock + .advance(now_secs, self.config.sequence_period_secs); + let timestamp_nanos = self.clock.current_nanos; + + let inv = 1.0_f32 / voxel_size; + let mut live: AHashSet = AHashSet::with_capacity(points.len()); + for &(x, y, z) in &points { + live.insert(world_to_voxel(x, y, z, inv)); + } + + update_map( + &mut self.map, + origin, + &points, + &self.config, + timestamp_nanos, + ); + + // Echo the input cloud's frame; the global map lives in the same + // world frame as the upstream lidar/odometry. + let cloud = build_dynamic_cloud( + &self.map, + &live, + voxel_size, + &msg.header.frame_id, + msg.header.stamp, + ); + if let Err(e) = self.global_map.publish(&cloud).await { + eprintln!("voxel_ray_tracing: publish failed: {e}"); + } + } + + async fn on_map_override(&mut self, msg: DynamicCloud) { + self.map.voxels.clear(); + self.map.voxels.reserve(msg.voxels.len()); + + let mut per_voxel_ts = vec![0u64; msg.voxels.len()]; + let mut max_ts: u64 = 0; + for (i, &idx) in msg.event_indices.iter().enumerate() { + let idx = idx as usize; + let t = msg.event_timestamps[i]; + if idx < per_voxel_ts.len() && t > per_voxel_ts[idx] { + per_voxel_ts[idx] = t; + } + if t > max_ts { + max_ts = t; + } + } + + for (i, &(x, y, z)) in msg.voxels.iter().enumerate() { + let health = (msg.quantity[i] as i32).min(self.config.max_health); + let timestamp_nanos = per_voxel_ts[i]; + self.map.voxels.insert( + (x, y, z), + VoxelState { + health, + timestamp_nanos, + }, + ); + } + + // Reset the slow clock. 
Prefer the last lidar timestamp; fall back + // to the override's own message timestamp if no lidar has been seen yet. + let now_secs = self + .last_lidar_secs + .unwrap_or(msg.timestamp_nanos as f64 * 1e-9); + self.clock + .reset_to(max_ts, now_secs, self.config.sequence_period_secs); + } +} + +fn decode_dynamic_cloud(buf: &[u8]) -> std::io::Result { + DynamicCloud::decode(buf).map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e)) +} + +fn encode_dynamic_cloud(msg: &DynamicCloud) -> Vec { + msg.encode() + .expect("DynamicCloud::encode: frame_id exceeds 65535 bytes") +} + +fn update_map( + map: &mut VoxelMap, + origin: (f32, f32, f32), + points: &[(f32, f32, f32)], + cfg: &Config, + timestamp_nanos: u64, +) { + let inv = 1.0_f32 / cfg.voxel_size; + let max_range_sq = if cfg.max_range > 0.0 { + cfg.max_range * cfg.max_range + } else { + f32::INFINITY + }; + + let mut hits: AHashSet = AHashSet::with_capacity(points.len()); + for &(x, y, z) in points { + hits.insert(world_to_voxel(x, y, z, inv)); + } + + let mut misses: AHashSet = AHashSet::new(); + let origin_voxel = world_to_voxel(origin.0, origin.1, origin.2, inv); + let step = cfg.ray_subsample.max(1) as usize; + for (i, &p) in points.iter().enumerate() { + if i % step != 0 { + continue; + } + let dx = p.0 - origin.0; + let dy = p.1 - origin.1; + let dz = p.2 - origin.2; + if dx * dx + dy * dy + dz * dz > max_range_sq { + continue; + } + let endpoint = world_to_voxel(p.0, p.1, p.2, inv); + walk_ray( + &mut misses, + origin, + p, + cfg.voxel_size, + cfg.shadow_depth, + origin_voxel, + endpoint, + ); + } + + // Apply hits first: a voxel that is both a hit and a miss this scan + // counts as a hit (the lidar return is the stronger signal). + // + // Sequence stamping: when an observation lands on a voxel whose + // current health <= 0 (i.e. still uncertain), stamp it with the + // current slow-clock value. 
The check is against PRE-update health, + // so the confirmation event itself (uncertain -> confirmed) also + // gets stamped — its sequence captures the moment of confirmation. + // Subsequent hits on already-confirmed voxels (pre-health > 0) leave + // the stamp frozen. + for v in &hits { + let state = map.voxels.entry(*v).or_insert(VoxelState { + health: cfg.min_health, + timestamp_nanos, + }); + let was_uncertain = state.health <= 0; + state.health = (state.health + 1).min(cfg.max_health); + if was_uncertain { + state.timestamp_nanos = timestamp_nanos; + } + } + for v in misses.difference(&hits) { + if let Some(state) = map.voxels.get_mut(v) { + let was_uncertain = state.health <= 0; + state.health -= 1; + if state.health <= cfg.min_health { + map.voxels.remove(v); + } else if was_uncertain { + state.timestamp_nanos = timestamp_nanos; + } + } + } +} + +#[inline] +fn world_to_voxel(x: f32, y: f32, z: f32, inv: f32) -> VoxelKey { + ( + (x * inv).floor() as i32, + (y * inv).floor() as i32, + (z * inv).floor() as i32, + ) +} + +/// Amanatides & Woo 3-D DDA. Records every voxel strictly between +/// `origin_voxel` and `endpoint` into `misses`, then continues past +/// `endpoint` for `shadow_depth` meters and records those voxels too. +/// The endpoint voxel itself is not added (it is a hit, handled by the +/// caller). 
+fn walk_ray( + misses: &mut AHashSet, + origin: (f32, f32, f32), + end: (f32, f32, f32), + voxel_size: f32, + shadow_depth: f32, + origin_voxel: VoxelKey, + endpoint: VoxelKey, +) { + if origin_voxel == endpoint { + return; + } + + let (ox, oy, oz) = origin; + let dx = end.0 - ox; + let dy = end.1 - oy; + let dz = end.2 - oz; + + let (mut x, mut y, mut z) = origin_voxel; + + let step_x = dx.signum() as i32; + let step_y = dy.signum() as i32; + let step_z = dz.signum() as i32; + + let t_max_init = |p: f32, d: f32, vox: i32, step: i32| -> f32 { + if step == 0 { + return f32::INFINITY; + } + let next_boundary = if step > 0 { + (vox + 1) as f32 * voxel_size + } else { + vox as f32 * voxel_size + }; + (next_boundary - p) / d + }; + + let mut tx = t_max_init(ox, dx, x, step_x); + let mut ty = t_max_init(oy, dy, y, step_y); + let mut tz = t_max_init(oz, dz, z, step_z); + + let dt_x = if step_x == 0 { + f32::INFINITY + } else { + voxel_size / dx.abs() + }; + let dt_y = if step_y == 0 { + f32::INFINITY + } else { + voxel_size / dy.abs() + }; + let dt_z = if step_z == 0 { + f32::INFINITY + } else { + voxel_size / dz.abs() + }; + + let half = voxel_size * 0.5; + let endpoint_center = ( + endpoint.0 as f32 * voxel_size + half, + endpoint.1 as f32 * voxel_size + half, + endpoint.2 as f32 * voxel_size + half, + ); + let shadow_sq = shadow_depth.max(0.0).powi(2); + + // FIXME: I don't know if we really need this + let max_iter = 4096; + let mut past_endpoint = false; + for _ in 0..max_iter { + if tx < ty { + if tx < tz { + x += step_x; + tx += dt_x; + } else { + z += step_z; + tz += dt_z; + } + } else if ty < tz { + y += step_y; + ty += dt_y; + } else { + z += step_z; + tz += dt_z; + } + + // FIXME: I don't like how this is written, come back and change this. 
+ // It would be more clear to do this in two loops, one for the normal tracing + // and a second for the shadow clearing + if (x, y, z) == endpoint { + past_endpoint = true; + continue; + } + + if past_endpoint { + let cx = x as f32 * voxel_size + half; + let cy = y as f32 * voxel_size + half; + let cz = z as f32 * voxel_size + half; + let ddx = cx - endpoint_center.0; + let ddy = cy - endpoint_center.1; + let ddz = cz - endpoint_center.2; + if ddx * ddx + ddy * ddy + ddz * ddz > shadow_sq { + return; + } + } + + misses.insert((x, y, z)); + } +} + +struct ExtractError(&'static str); +impl std::fmt::Display for ExtractError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(self.0) + } +} + +fn extract_xyz(msg: &PointCloud2) -> Result, ExtractError> { + let mut x_off: Option = None; + let mut y_off: Option = None; + let mut z_off: Option = None; + for f in &msg.fields { + if f.datatype != PointField::FLOAT32 as u8 { + continue; + } + match f.name.as_str() { + "x" => x_off = Some(f.offset as usize), + "y" => y_off = Some(f.offset as usize), + "z" => z_off = Some(f.offset as usize), + _ => {} + } + } + let xo = x_off.ok_or(ExtractError("missing float32 x field"))?; + let yo = y_off.ok_or(ExtractError("missing float32 y field"))?; + let zo = z_off.ok_or(ExtractError("missing float32 z field"))?; + + let n = (msg.width as usize) * (msg.height as usize); + let step = msg.point_step as usize; + if step == 0 { + return Err(ExtractError("point_step is 0")); + } + if msg.data.len() < n * step { + return Err(ExtractError( + "data buffer shorter than width*height*point_step", + )); + } + if msg.is_bigendian { + return Err(ExtractError("big-endian point data not supported")); + } + + let mut out = Vec::with_capacity(n); + for i in 0..n { + let base = i * step; + let x = read_f32_le(&msg.data, base + xo); + let y = read_f32_le(&msg.data, base + yo); + let z = read_f32_le(&msg.data, base + zo); + if x.is_finite() && y.is_finite() && 
z.is_finite() { + out.push((x, y, z)); + } + } + Ok(out) +} + +#[inline] +fn read_f32_le(buf: &[u8], off: usize) -> f32 { + let bytes: [u8; 4] = buf[off..off + 4] + .try_into() + .expect("bounds checked by caller"); + f32::from_le_bytes(bytes) +} + +fn build_dynamic_cloud( + map: &VoxelMap, + live: &AHashSet, + voxel_size: f32, + frame_id: &str, + stamp: Time, +) -> DynamicCloud { + // Include all voxels currently considered "live": those with health > 0 + // (confirmed) plus any voxel hit this scan (even if still uncertain). + // Live voxels were just inserted by update_map, so they're guaranteed + // in the map. + // + // Emit one event per published voxel — the sparse format permits + // tighter packings (e.g. only timestamping changed voxels) but this + // dense layout is simplest and lets downstream consumers always look + // up a per-voxel timestamp directly. + let mut voxels = Vec::with_capacity(map.voxels.len()); + let mut quantity = Vec::with_capacity(map.voxels.len()); + let mut event_indices = Vec::with_capacity(map.voxels.len()); + let mut event_timestamps = Vec::with_capacity(map.voxels.len()); + + for (&key, &state) in &map.voxels { + if state.health > 0 || live.contains(&key) { + let idx = voxels.len() as u32; + voxels.push(key); + quantity.push(state.health.max(0) as u32); + event_indices.push(idx); + event_timestamps.push(state.timestamp_nanos); + } + } + + let timestamp_nanos = (stamp.sec as i64 as u64) + .wrapping_mul(1_000_000_000) + .wrapping_add(stamp.nsec.max(0) as u64); + + DynamicCloud { + timestamp_nanos, + voxel_size, + frame_id: frame_id.to_string(), + voxels, + quantity, + event_indices, + event_timestamps, + } +} + +#[tokio::main] +async fn main() { + let transport = LcmTransport::new() + .await + .expect("failed to create LCM transport"); + run::(transport) + .await + .expect("voxel_ray_tracing run failed"); +} + +#[cfg(test)] +mod tests { + use super::*; + + fn basic_config() -> Config { + Config { + voxel_size: 1.0, + max_range: 
100.0, + ray_subsample: 1, + shadow_depth: 2.0, + min_health: 0, + max_health: 1, + sequence_period_secs: 1.0, + } + } + + fn health_of(map: &VoxelMap, key: VoxelKey) -> Option { + map.voxels.get(&key).map(|s| s.health) + } + + fn insert_health(map: &mut VoxelMap, key: VoxelKey, health: i32) { + map.voxels.insert( + key, + VoxelState { + health, + timestamp_nanos: 0, + }, + ); + } + + #[test] + fn walk_ray_hits_correct_voxels_1() { + let voxel_size = 1.0; + let shadow_depth = 2.0; + let origin = (0.5, 0.5, 0.5); + let end = (5.5, 0.5, 0.5); + let inv = 1.0 / voxel_size; + let origin_voxel = world_to_voxel(origin.0, origin.1, origin.2, inv); + let endpoint = world_to_voxel(end.0, end.1, end.2, inv); + + let mut misses: AHashSet = AHashSet::new(); + walk_ray( + &mut misses, + origin, + end, + voxel_size, + shadow_depth, + origin_voxel, + endpoint, + ); + + let expected: AHashSet = [ + (1, 0, 0), + (2, 0, 0), + (3, 0, 0), + (4, 0, 0), + (6, 0, 0), + (7, 0, 0), + ] + .into_iter() + .collect(); + assert_eq!(misses, expected); + } + + #[test] + fn walk_ray_hits_correct_voxels_2() { + let voxel_size = 1.0; + let shadow_depth = 2.0; + let origin = (0.5, 0.5, 0.5); + let end = (3.5, 2.5, 1.5); + let inv = 1.0 / voxel_size; + let origin_voxel = world_to_voxel(origin.0, origin.1, origin.2, inv); + let endpoint = world_to_voxel(end.0, end.1, end.2, inv); + + let mut misses: AHashSet = AHashSet::new(); + walk_ray( + &mut misses, + origin, + end, + voxel_size, + shadow_depth, + origin_voxel, + endpoint, + ); + + let expected: AHashSet = [ + (1, 0, 0), + (1, 1, 0), + (1, 1, 1), + (2, 1, 1), + (2, 2, 1), + (4, 2, 1), + (4, 3, 1), + (4, 3, 2), + ] + .into_iter() + .collect(); + assert_eq!(misses, expected); + } + + #[test] + fn hits_insert_voxels() { + let cfg = basic_config(); + let mut map = VoxelMap::default(); + update_map( + &mut map, + (0.0, 0.0, 0.0), + &[(5.5, 0.5, 0.5), (0.5, 5.5, 0.5)], + &cfg, + 0, + ); + assert_eq!(health_of(&map, (5, 0, 0)), Some(1)); + 
assert_eq!(health_of(&map, (0, 5, 0)), Some(1)); + assert_eq!(map.voxels.len(), 2); + } + + #[test] + fn voxels_on_ray_are_removed() { + let cfg = basic_config(); + let mut map = VoxelMap::default(); + insert_health(&mut map, (3, 0, 0), 1); + update_map(&mut map, (0.0, 0.0, 0.0), &[(5.5, 0.5, 0.5)], &cfg, 0); + // make sure the initial point got cleared by the new update + assert!(!map.voxels.contains_key(&(3, 0, 0))); + assert_eq!(health_of(&map, (5, 0, 0)), Some(1)); + } + + #[test] + fn voxels_not_on_ray_survive() { + let cfg = basic_config(); + let mut map = VoxelMap::default(); + insert_health(&mut map, (3, 5, 0), 1); + update_map(&mut map, (0.0, 0.0, 0.0), &[(5.5, 0.5, 0.5)], &cfg, 0); + assert_eq!(health_of(&map, (3, 5, 0)), Some(1)); + assert_eq!(health_of(&map, (5, 0, 0)), Some(1)); + } + + #[test] + fn voxels_within_shadow_region_are_removed() { + let cfg = basic_config(); + let mut map = VoxelMap::default(); + insert_health(&mut map, (6, 0, 0), 1); + update_map(&mut map, (0.0, 0.0, 0.0), &[(5.5, 0.5, 0.5)], &cfg, 0); + // point within the shadow is no longer included, new point is included + assert!(!map.voxels.contains_key(&(6, 0, 0))); + assert_eq!(health_of(&map, (5, 0, 0)), Some(1)); + } + + #[test] + fn voxels_beyond_shadow_region_survive() { + let cfg = basic_config(); + let mut map = VoxelMap::default(); + insert_health(&mut map, (8, 0, 0), 1); + update_map(&mut map, (0.0, 0.0, 0.0), &[(5.5, 0.5, 0.5)], &cfg, 0); + assert_eq!(health_of(&map, (8, 0, 0)), Some(1)); + assert_eq!(health_of(&map, (5, 0, 0)), Some(1)); + } + + #[test] + fn hit_caught_by_other_ray_is_not_removed() { + let cfg = basic_config(); + let mut map = VoxelMap::default(); + update_map( + &mut map, + (0.0, 0.0, 0.0), + &[(3.5, 0.5, 0.5), (5.5, 0.5, 0.5)], + &cfg, + 0, + ); + assert_eq!(health_of(&map, (3, 0, 0)), Some(1)); + assert_eq!(health_of(&map, (5, 0, 0)), Some(1)); + } + + #[test] + fn point_beyond_max_range_does_not_clear() { + let cfg = Config { + max_range: 3.0, + 
..basic_config() + }; + let mut map = VoxelMap::default(); + insert_health(&mut map, (3, 0, 0), 1); + update_map(&mut map, (0.0, 0.0, 0.0), &[(5.5, 0.5, 0.5)], &cfg, 0); + assert_eq!(health_of(&map, (3, 0, 0)), Some(1)); + } + + #[test] + fn two_hits_needed_when_min_health_is_negative() { + let cfg = Config { + min_health: -1, + ..basic_config() + }; + let mut map = VoxelMap::default(); + update_map(&mut map, (0.0, 0.0, 0.0), &[(5.5, 0.5, 0.5)], &cfg, 0); + assert_eq!(health_of(&map, (5, 0, 0)), Some(0)); + + update_map(&mut map, (0.0, 0.0, 0.0), &[(5.5, 0.5, 0.5)], &cfg, 0); + assert_eq!(health_of(&map, (5, 0, 0)), Some(1)); + } + + #[test] + fn two_misses_needed_when_max_health_is_two() { + let cfg = Config { + max_health: 2, + ..basic_config() + }; + let mut map = VoxelMap::default(); + update_map(&mut map, (0.0, 0.0, 0.0), &[(3.5, 0.5, 0.5)], &cfg, 0); + update_map(&mut map, (0.0, 0.0, 0.0), &[(3.5, 0.5, 0.5)], &cfg, 0); + assert_eq!(health_of(&map, (3, 0, 0)), Some(2)); + + update_map(&mut map, (0.0, 0.0, 0.0), &[(5.5, 0.5, 0.5)], &cfg, 0); + assert_eq!(health_of(&map, (3, 0, 0)), Some(1)); + + update_map(&mut map, (0.0, 0.0, 0.0), &[(5.5, 0.5, 0.5)], &cfg, 0); + assert!(!map.voxels.contains_key(&(3, 0, 0))); + } + + #[test] + fn unconfirmed_voxels_get_timestamp_stamp() { + // With min_health=-1, a fresh hit lands at health=0 — still + // uncertain — and so must be stamped with the supplied timestamp. + let cfg = Config { + min_health: -1, + ..basic_config() + }; + let mut map = VoxelMap::default(); + update_map( + &mut map, + (0.0, 0.0, 0.0), + &[(5.5, 0.5, 0.5)], + &cfg, + 42_000_000_000, + ); + let state = map.voxels.get(&(5, 0, 0)).copied().unwrap(); + assert_eq!(state.health, 0); + assert_eq!(state.timestamp_nanos, 42_000_000_000); + } + + #[test] + fn confirmed_voxels_freeze_their_timestamp() { + // With min_health=-1, max_health=1: two hits to reach health=1. 
+ // hit #1 (ts=10): -1 -> 0, pre-health was -1 (≤0), stamp 10 + // hit #2 (ts=99): 0 -> 1, pre-health was 0 (≤0), stamp 99 + // -- voxel now confirmed -- + // hit #3 (ts=1000): pre-health is 1 (>0), no stamp + let cfg = Config { + min_health: -1, + ..basic_config() + }; + let mut map = VoxelMap::default(); + update_map(&mut map, (0.0, 0.0, 0.0), &[(5.5, 0.5, 0.5)], &cfg, 10); + assert_eq!(map.voxels[&(5, 0, 0)].timestamp_nanos, 10); + update_map(&mut map, (0.0, 0.0, 0.0), &[(5.5, 0.5, 0.5)], &cfg, 99); + let confirming = map.voxels[&(5, 0, 0)]; + assert_eq!(confirming.health, 1); + assert_eq!(confirming.timestamp_nanos, 99); + update_map(&mut map, (0.0, 0.0, 0.0), &[(5.5, 0.5, 0.5)], &cfg, 1000); + let frozen = map.voxels[&(5, 0, 0)]; + assert_eq!(frozen.health, 1); + assert_eq!(frozen.timestamp_nanos, 99); + } + + #[test] + fn slow_clock_ticks_at_period() { + let mut clock = SlowClock::default(); + clock.advance(100.0, 1.0); + // First call primes the schedule and stamps with `now` itself. + assert_eq!(clock.current_nanos, secs_to_nanos(100.0)); + clock.advance(100.5, 1.0); + // Not yet a period elapsed. + assert_eq!(clock.current_nanos, secs_to_nanos(100.0)); + clock.advance(101.0, 1.0); + // Crossed first scheduled boundary (101.0). + assert_eq!(clock.current_nanos, secs_to_nanos(101.0)); + clock.advance(103.5, 1.0); + // Crossed boundaries at 102.0 and 103.0 — most recent wins. + assert_eq!(clock.current_nanos, secs_to_nanos(103.0)); + } + + #[test] + fn slow_clock_reset_snaps_backwards() { + let mut clock = SlowClock::default(); + clock.advance(100.0, 1.0); + clock.advance(110.0, 1.0); + let big = clock.current_nanos; + // Override is authoritative even if smaller. + clock.reset_to(42, 110.0, 1.0); + assert_eq!(clock.current_nanos, 42); + assert!(clock.current_nanos < big); + // Next tick still fires at the scheduled time. 
+ clock.advance(111.0, 1.0); + assert_eq!(clock.current_nanos, secs_to_nanos(111.0)); + } +} diff --git a/dimos/mapping/ray_tracing/test_clearing.py b/dimos/mapping/ray_tracing/test_clearing.py new file mode 100644 index 0000000000..7ff90b2f7e --- /dev/null +++ b/dimos/mapping/ray_tracing/test_clearing.py @@ -0,0 +1,440 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright 2026 Dimensional Inc. +# SPDX-License-Identifier: Apache-2.0 +"""End-to-end test for RayTracingVoxelMap using the synthetic clearing scene. + +Spins up three modules — a synthetic lidar source, the Rust ray tracer, +and a global-map collector — feeds the floor/wall/box/person sequence +through, then scores the published DynamicCloud frames against two +penalties: + + forget_box : per missing box voxel per frame (the static obstacle + should never disappear from the published map after + it's been confirmed) + ghost_person : per stale voxel sitting in the person plane (x=PERSON_X, + z above the floor zone where wall returns sweep through) + that doesn't belong to the current person position — + i.e. the ray tracer didn't clear the person's previous + footprint when the person moved. + +Lower is better. Always prints the score; --rerun adds a live visualization +of both the input lidar and the published global map side by side. 
+""" + +from __future__ import annotations + +import argparse +import threading +import time +from typing import Any + +import numpy as np +import pytest +import rerun as rr + +from dimos.core.coordination.module_coordinator import ModuleCoordinator +from dimos.core.core import rpc +from dimos.core.module import Module +from dimos.core.stream import In, Out +from dimos.core.transport import LCMTransport +from dimos.mapping.ray_tracing.demo_clearing_scene import ( + CLASS_COLORS, + PERSON_HALF_WIDTH, + PERSON_X, + VOXEL_SIZE, + Frame, + _box_visible_face_points, + _classify_points, + synthetic_scene, +) +from dimos.mapping.ray_tracing.module import RayTracingVoxelMap +from dimos.msgs.geometry_msgs.Pose import Pose +from dimos.msgs.nav_msgs.DynamicCloud import DynamicCloud +from dimos.msgs.nav_msgs.Odometry import Odometry +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 + +# A ray from sensor (0,0,1) to a wall point at (WALL_X=6, y_w, z_w) crosses +# the person plane (x=PERSON_X=3) at y = y_w/2 and z = 0.5*z_w + 0.5. +# The wall y range is [-3, 3] and z range is [0, 2.5], so the cells at +# the person column that the wall actually sweeps are bounded: +# +# y_at_person_plane ∈ [-1.5, 1.5] (voxel index [-15, 14]) +# z_at_person_plane ∈ [ 0.5, 1.75] (voxel index [ 5, 17]) +# +# Cells outside that bounding box can't be cleared by wall returns — +# scene limitation, not a ray-tracer bug — so we don't count them as +# ghosts. +_GHOST_CHECK_MIN_Z_VOXEL = 5 +_GHOST_CHECK_MAX_Z_VOXEL = 17 +_GHOST_CHECK_MIN_Y_VOXEL = -15 +_GHOST_CHECK_MAX_Y_VOXEL = 14 + +# A voxel needs `1 - min_health` hits to become confirmed and survive +# occlusion. Default config is min_health=-1, max_health=1 → 2 hits, so +# the first two frames are warmup for the box-presence check.
+_BOX_WARMUP_FRAMES = 2 + + +def _voxel_key(x: float, y: float, z: float, voxel_size: float) -> tuple[int, int, int]: + return ( + int(np.floor(x / voxel_size)), + int(np.floor(y / voxel_size)), + int(np.floor(z / voxel_size)), + ) + + +def _expected_box_voxel_keys(voxel_size: float) -> set[tuple[int, int, int]]: + return {_voxel_key(x, y, z, voxel_size) for x, y, z in _box_visible_face_points()} + + +def _person_voxel_y_range(person_y: float, voxel_size: float) -> tuple[int, int]: + return ( + int(np.floor((person_y - PERSON_HALF_WIDTH) / voxel_size)), + int(np.floor((person_y + PERSON_HALF_WIDTH - 1e-9) / voxel_size)), + ) + + +class SyntheticLidarSource(Module): + """Publishes the synthetic-scene PointCloud2 + Odometry pair per frame.""" + + lidar: Out[PointCloud2] + odometry: Out[Odometry] + + def __init__(self, num_frames: int = 30, frame_dt: float = 0.1, **kwargs: Any) -> None: + super().__init__(**kwargs) + self._num_frames = num_frames + self._frame_dt = frame_dt + self._stop = threading.Event() + self._done = threading.Event() + self._thread: threading.Thread | None = None + + @rpc + def start(self) -> None: + super().start() + self._thread = threading.Thread( + target=self._publish_loop, daemon=True, name="synthetic-lidar" + ) + self._thread.start() + + @rpc + def stop(self) -> None: + self._stop.set() + thread = self._thread + if thread is not None and thread.is_alive(): + thread.join(timeout=2.0) + super().stop() + + @rpc + def wait_done(self, timeout: float = 60.0) -> bool: + return self._done.wait(timeout) + + def _publish_loop(self) -> None: + for frame in synthetic_scene(num_frames=self._num_frames, frame_dt=self._frame_dt): + if self._stop.is_set(): + return + cloud = PointCloud2.from_numpy( + points=frame.points, + frame_id="world", + timestamp=frame.timestamp_s, + ) + ox, oy, oz = (float(v) for v in frame.sensor_origin) + odom = Odometry( + ts=frame.timestamp_s, + frame_id="world", + child_frame_id="sensor", + pose=Pose(ox, oy, oz), + ) + 
self.lidar.publish(cloud) + self.odometry.publish(odom) + if not self._stop.wait(self._frame_dt): + continue + return + self._done.set() + + +class GlobalMapCollector(Module): + """Subscribes to global_map and stores every frame for later inspection.""" + + global_map: In[DynamicCloud] + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self._lock = threading.Lock() + self._collected: list[DynamicCloud] = [] + self._unsub = None + + @rpc + def build(self) -> None: + super().build() + self._unsub = self.global_map.subscribe(self._on_msg) + + @rpc + def stop(self) -> None: + if self._unsub is not None: + self._unsub() + self._unsub = None + super().stop() + + def _on_msg(self, msg: DynamicCloud) -> None: + with self._lock: + self._collected.append(msg) + + @rpc + def get_frame_count(self) -> int: + with self._lock: + return len(self._collected) + + @rpc + def get_all_frames(self) -> list[DynamicCloud]: + with self._lock: + return list(self._collected) + + +def compute_loss( + collected: list[DynamicCloud], + expected_frames: list[Frame], + voxel_size: float, +) -> dict: + box_keys = _expected_box_voxel_keys(voxel_size) + person_x_voxel = int(np.floor(PERSON_X / voxel_size)) + + # Match by rounded ts (microsecond resolution); the Rust side writes + # the input PointCloud2 stamp through to the published DynamicCloud, + # and the Python `DynamicCloud.ts` is that same value re-decoded. 
+ by_ts: dict[int, DynamicCloud] = {round((out.ts or 0.0) * 1_000_000): out for out in collected} + + forget_box = 0 + ghost_person = 0 + matched = 0 + + for frame in expected_frames: + ts_us = round(frame.timestamp_s * 1_000_000) + out = by_ts.get(ts_us) + if out is None: + continue + matched += 1 + out_keys = {(int(v[0]), int(v[1]), int(v[2])) for v in out.voxels} + + if frame.index >= _BOX_WARMUP_FRAMES: + forget_box += len(box_keys - out_keys) + + if frame.person_y is not None: + y_lo, y_hi = _person_voxel_y_range(frame.person_y, voxel_size) + for vx, vy, vz in out_keys: + if vx != person_x_voxel: + continue + if vz < _GHOST_CHECK_MIN_Z_VOXEL or vz > _GHOST_CHECK_MAX_Z_VOXEL: + continue + if vy < _GHOST_CHECK_MIN_Y_VOXEL or vy > _GHOST_CHECK_MAX_Y_VOXEL: + continue + if vy < y_lo or vy > y_hi: + ghost_person += 1 + + return { + "score": float(forget_box + ghost_person), + "forget_box": forget_box, + "ghost_person": ghost_person, + "matched_frames": matched, + "expected_frames": len(expected_frames), + "received_frames": len(collected), + "box_voxel_count": len(box_keys), + } + + +def run( + num_frames: int = 30, + frame_dt: float = 0.1, + use_rerun: bool = False, + voxel_size: float = VOXEL_SIZE, + settle_secs: float = 1.5, +) -> dict: + """Spin up the modules, feed the scene, score the output.""" + coord = ModuleCoordinator() + coord.start() + collected: list[DynamicCloud] = [] + try: + source = coord.deploy( + SyntheticLidarSource, + num_frames=num_frames, + frame_dt=frame_dt, + ) + ray_tracer = coord.deploy( + RayTracingVoxelMap, + voxel_size=voxel_size, + auto_build=True, # always rebuild — picks up source changes between runs + ) + collector = coord.deploy(GlobalMapCollector) + + # Wire ports to LCM topics explicitly — `.connect()` doesn't always + # propagate transports through to In ports in deployed worker + # modules, and the NativeModule binary needs the topic names anyway. 
+ source.lidar.transport = LCMTransport("/test_lidar", PointCloud2) + source.odometry.transport = LCMTransport("/test_odometry", Odometry) + ray_tracer.lidar.transport = LCMTransport("/test_lidar", PointCloud2) + ray_tracer.odometry.transport = LCMTransport("/test_odometry", Odometry) + ray_tracer.global_map.transport = LCMTransport("/test_global_map", DynamicCloud) + collector.global_map.transport = LCMTransport("/test_global_map", DynamicCloud) + + ray_tracer.build() + collector.build() + + ray_tracer.start() + collector.start() + # Give the Rust binary a moment to bind LCM subscriptions, otherwise + # the first lidar frames are sent into the void. + time.sleep(0.5) + source.start() + + source.wait_done(timeout=num_frames * frame_dt + 30.0) + # Let the ray tracer finish processing trailing frames. + time.sleep(settle_secs) + collected = collector.get_all_frames() + + source.stop() + collector.stop() + ray_tracer.stop() + finally: + coord.stop() + + expected = list(synthetic_scene(num_frames=num_frames, frame_dt=frame_dt)) + loss = compute_loss(collected, expected, voxel_size) + + print() + print(f"score : {loss['score']:.0f} (lower is better)") + print( + f" forget_box : {loss['forget_box']} / target box voxels = {loss['box_voxel_count']}" + ) + print(f" ghost_person : {loss['ghost_person']}") + print(f" matched frames : {loss['matched_frames']} / {loss['expected_frames']} expected") + print(f" received frames : {loss['received_frames']}") + print() + + if use_rerun: + _visualize(collected, expected, voxel_size) + + return loss + + +def _visualize( + collected: list[DynamicCloud], + expected_frames: list[Frame], + voxel_size: float, +) -> None: + """Stream input + output side by side, color-coded so the ray tracer's + state is visually distinct from the sensor returns. + + Color scheme (intentionally non-overlapping): + input/by_class : floor=gray, wall=blue, person=red, box=orange + — what the synthetic lidar emits this frame. 
+ output/map : bright magenta — every voxel the ray tracer + currently holds. Sits on top of the input, + slightly smaller radius so the input class + colors stay visible beneath. + output/box : bright green — published voxels that fall inside + the box AABB. If the green stays solid through + the whole walk, the ray tracer is preserving + the static obstacle correctly. + """ + rr.init("ray_tracing_clearing_test", spawn=True) + time.sleep(1.0) + + box_min = np.array([4.0, 0.3, 0.0], dtype=np.float32) # mirrors BOX_X/Y/Z + box_max = np.array([4.5, 1.1, 0.5], dtype=np.float32) + + by_ts = {round((out.ts or 0.0) * 1_000_000): out for out in collected} + for frame in expected_frames: + rr.set_time("time", duration=frame.timestamp_s) + + # ---- input, colored by surface class + classes = _classify_points(frame.points, frame.person_y) + rr.log( + "input/by_class", + rr.Points3D( + positions=frame.points, + colors=CLASS_COLORS[classes], + radii=voxel_size / 2, + ), + ) + + # ---- output, two layers in distinct solid colors + ts_us = round(frame.timestamp_s * 1_000_000) + out = by_ts.get(ts_us) + if out is None or len(out) == 0: + # Clear stale points so the entity disappears on frames where + # we didn't receive output. (Logging an empty Points3D works.) + rr.log("output/map", rr.Points3D([])) + rr.log("output/box", rr.Points3D([])) + time.sleep(0.05) + continue + + world = out.world_positions() + in_box = np.all((world >= box_min) & (world <= box_max), axis=1) + + # Everything the tracer publishes — bright magenta. Includes box + # voxels too, so the pink consistently represents "is this in the + # published map" regardless of class. + rr.log( + "output/map", + rr.Points3D( + positions=world, + colors=np.array([[255, 0, 200]], dtype=np.uint8), + radii=voxel_size / 2 * 0.55, + ), + ) + # Box subset — small bright-green dot sitting inside the pink + # output voxel, so you can see at a glance whether the static + # obstacle is being preserved across occlusion. 
Drawn smaller + # than the pink so the pink shows around it. + rr.log( + "output/box", + rr.Points3D( + positions=world[in_box], + colors=np.array([[60, 255, 80]], dtype=np.uint8), + radii=voxel_size / 2 * 0.25, + ), + ) + time.sleep(0.05) + + +@pytest.mark.slow +def test_ray_tracing_clearing(): + loss = run(num_frames=20, frame_dt=0.1, use_rerun=False) + # Observed on a clean run: forget_box ≈ 15, ghost_person ≈ 78 over + # 19 matched frames. Thresholds are 3-4× the observed values — meant + # to flag outright regressions (ray tracer eats the box, never clears, + # etc.) without being flaky on timing jitter. + assert loss["matched_frames"] >= 15, f"too few matched frames: {loss}" + assert loss["forget_box"] < 80, f"too many missing box voxels: {loss}" + assert loss["ghost_person"] < 300, f"too many ghost person voxels: {loss}" + + +def main(): + parser = argparse.ArgumentParser(description="End-to-end test for RayTracingVoxelMap") + parser.add_argument("--rerun", action="store_true", help="visualize input + output in Rerun") + parser.add_argument("--frames", type=int, default=30) + parser.add_argument("--dt", type=float, default=0.1) + parser.add_argument("--voxel-size", type=float, default=VOXEL_SIZE) + args = parser.parse_args() + run( + num_frames=args.frames, + frame_dt=args.dt, + use_rerun=args.rerun, + voxel_size=args.voxel_size, + ) + + +if __name__ == "__main__": + main() diff --git a/dimos/memory2/module.py b/dimos/memory2/module.py index b584553bae..c467b68ead 100644 --- a/dimos/memory2/module.py +++ b/dimos/memory2/module.py @@ -305,7 +305,22 @@ def _port_to_stream(self, name: str, input_topic: In[Any], stream: Stream[Any]) def on_msg(msg: Any) -> None: ts = getattr(msg, "ts", None) or time.time() - frame_id = getattr(msg, "frame_id", None) or default_frame_id + # For msgs that carry a parent→child transform (Odometry, + # TransformStamped), child_frame_id is the body whose pose we + # want to anchor; frame_id is just the parent (often world). 
+ # Plain stamped msgs only have frame_id (the frame the data is in); + # if that's already 'world' the data carries no robot-pose info, + # so fall back to default_frame_id to still anchor it to the + # robot's world pose at this timestamp. + frame_id = ( + getattr(msg, "child_frame_id", None) + or getattr(msg, "frame_id", None) + or default_frame_id + ) + + if frame_id == "world": + frame_id = default_frame_id + transform = self.tf.get("world", frame_id, time_point=ts, time_tolerance=tf_tolerance) pose = transform.to_pose() if transform is not None else None diff --git a/dimos/msgs/geometry_msgs/Transform.py b/dimos/msgs/geometry_msgs/Transform.py index 9b08c8dadd..14d5b0c742 100644 --- a/dimos/msgs/geometry_msgs/Transform.py +++ b/dimos/msgs/geometry_msgs/Transform.py @@ -21,6 +21,7 @@ import rerun as rr from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped + from dimos.msgs.nav_msgs.Odometry import Odometry from dimos_lcm.geometry_msgs import ( Transform as LCMTransform, @@ -161,6 +162,17 @@ def __neg__(self) -> Transform: """Unary minus operator returns the inverse transform.""" return self.inverse() + @classmethod + def from_odometry(cls, odom: Odometry) -> Transform: # type: ignore[name-defined] + """Create a Transform from an Odometry message using its own frame names.""" + return cls( + translation=odom.pose.position, + rotation=odom.pose.orientation, + frame_id=odom.frame_id, + child_frame_id=odom.child_frame_id, + ts=odom.ts, + ) + @classmethod def from_pose(cls, frame_id: str, pose: Pose | PoseStamped) -> Transform: # type: ignore[name-defined] """Create a Transform from a Pose or PoseStamped. diff --git a/dimos/msgs/nav_msgs/DynamicCloud.py b/dimos/msgs/nav_msgs/DynamicCloud.py new file mode 100644 index 0000000000..598d863892 --- /dev/null +++ b/dimos/msgs/nav_msgs/DynamicCloud.py @@ -0,0 +1,249 @@ +# Copyright 2026 Dimensional Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""DynamicCloud: a per-voxel point cloud with a separate sparse event log. + +Designed for voxel-grid maps where each "point" is a voxel cell carrying +a quantity (occupancy / health / hit count), and a sparse second array +records timestamped events that reference points by index — useful for +expressing "this voxel was last seen at time T" without paying the +per-point cost when most voxels have no event. + +Wire format (little-endian, packed):: + + u64 timestamp_nanos # overall message timestamp + f32 voxel_size # meters per voxel edge + u16 frame_id_len + bytes frame_id # utf-8, frame_id_len bytes + u32 num_points + i32[N*3] voxels # (x, y, z) interleaved + u32[N] quantity # per-point quantity + u32 num_events + u32[M] event_indices # indices into voxels (0 ≤ idx < N) + u64[M] event_timestamps # nanoseconds + +`num_events` is independent of `num_points`: events can be empty, can +reference the same point multiple times, and don't need to cover every +point. The Rust mirror lives at +``dimos/mapping/ray_tracing/rust/src/dynamic_cloud.rs`` and must stay +in sync with this format. ``test_dynamic_cloud.py`` pins a known-bytes +fixture that both sides assert against. 
+""" + +from __future__ import annotations + +import struct +from typing import TYPE_CHECKING + +import numpy as np + +from dimos.types.timestamped import Timestamped + +if TYPE_CHECKING: + from rerun._baseclasses import Archetype + + +_HEADER_FMT = " None: + self.ts = ts if ts is not None else 0.0 # type: ignore[assignment] + self.frame_id = frame_id + self.voxel_size = float(voxel_size) + + if voxels is None: + voxels = np.zeros((0, 3), dtype=np.int32) + if quantity is None: + quantity = np.zeros(0, dtype=np.uint32) + if event_indices is None: + event_indices = np.zeros(0, dtype=np.uint32) + if event_timestamps is None: + event_timestamps = np.zeros(0, dtype=np.uint64) + + voxels = np.ascontiguousarray(voxels, dtype=np.int32) + if voxels.ndim != 2 or voxels.shape[1] != 3: + raise ValueError(f"voxels must have shape (N, 3), got {voxels.shape}") + + quantity = np.ascontiguousarray(quantity, dtype=np.uint32).reshape(-1) + event_indices = np.ascontiguousarray(event_indices, dtype=np.uint32).reshape(-1) + event_timestamps = np.ascontiguousarray(event_timestamps, dtype=np.uint64).reshape(-1) + + num_points = voxels.shape[0] + if quantity.shape[0] != num_points: + raise ValueError( + f"voxels/quantity length mismatch: {num_points} vs {quantity.shape[0]}" + ) + if event_indices.shape[0] != event_timestamps.shape[0]: + raise ValueError( + f"event_indices/event_timestamps length mismatch: " + f"{event_indices.shape[0]} vs {event_timestamps.shape[0]}" + ) + if num_points == 0 and event_indices.shape[0] > 0: + raise ValueError("event_indices nonempty but voxels is empty") + if num_points > 0 and event_indices.shape[0] > 0: + max_idx = int(event_indices.max()) + if max_idx >= num_points: + raise ValueError(f"event index {max_idx} out of range for {num_points} points") + + self.voxels = voxels + self.quantity = quantity + self.event_indices = event_indices + self.event_timestamps = event_timestamps + + def __len__(self) -> int: + return int(self.voxels.shape[0]) + + def 
world_positions(self) -> np.ndarray: + """Return points reprojected to world space as `(N, 3) float32`.""" + return self.voxels.astype(np.float32) * np.float32(self.voxel_size) + + def per_point_latest_timestamp(self) -> np.ndarray: + """Return the latest event timestamp per point, 0 if no events touch a point. + + Useful for visualization or "freshness" coloring. Shape ``(N,) uint64``. + """ + result = np.zeros(len(self), dtype=np.uint64) + if self.event_indices.size == 0: + return result + # For each event, keep the max timestamp per index. + np.maximum.at(result, self.event_indices, self.event_timestamps) + return result + + def lcm_encode(self) -> bytes: + frame_bytes = self.frame_id.encode("utf-8") + if len(frame_bytes) > 0xFFFF: + raise ValueError(f"frame_id too long: {len(frame_bytes)} > 65535 bytes") + timestamp_nanos = int(self.ts * 1_000_000_000) if self.ts else 0 + if timestamp_nanos < 0: + timestamp_nanos = 0 + + header = struct.pack(_HEADER_FMT, timestamp_nanos, self.voxel_size, len(frame_bytes)) + num_points_bytes = struct.pack(_U32_FMT, len(self)) + num_events_bytes = struct.pack(_U32_FMT, int(self.event_indices.shape[0])) + return b"".join( + [ + header, + frame_bytes, + num_points_bytes, + self.voxels.tobytes(), + self.quantity.tobytes(), + num_events_bytes, + self.event_indices.tobytes(), + self.event_timestamps.tobytes(), + ] + ) + + @classmethod + def lcm_decode(cls, data: bytes) -> DynamicCloud: + if len(data) < _HEADER_SIZE: + raise ValueError(f"DynamicCloud: data too short for header ({len(data)} bytes)") + timestamp_nanos, voxel_size, frame_id_len = struct.unpack_from(_HEADER_FMT, data, 0) + offset = _HEADER_SIZE + + if len(data) < offset + frame_id_len + _U32_SIZE: + raise ValueError("DynamicCloud: data too short for frame_id + num_points") + frame_id = data[offset : offset + frame_id_len].decode("utf-8") + offset += frame_id_len + + (num_points,) = struct.unpack_from(_U32_FMT, data, offset) + offset += _U32_SIZE + + voxels_size = 
num_points * 3 * 4 + quantity_size = num_points * 4 + if len(data) < offset + voxels_size + quantity_size + _U32_SIZE: + raise ValueError("DynamicCloud: data too short for voxels + quantity + num_events") + + voxels = np.frombuffer(data, dtype=np.int32, count=num_points * 3, offset=offset).reshape( + num_points, 3 + ) + offset += voxels_size + quantity = np.frombuffer(data, dtype=np.uint32, count=num_points, offset=offset) + offset += quantity_size + + (num_events,) = struct.unpack_from(_U32_FMT, data, offset) + offset += _U32_SIZE + + events_idx_size = num_events * 4 + events_ts_size = num_events * 8 + expected_tail = events_idx_size + events_ts_size + if len(data) - offset != expected_tail: + raise ValueError( + f"DynamicCloud: payload size mismatch " + f"(expected {expected_tail} tail bytes, got {len(data) - offset})" + ) + + event_indices = np.frombuffer(data, dtype=np.uint32, count=num_events, offset=offset) + offset += events_idx_size + event_timestamps = np.frombuffer(data, dtype=np.uint64, count=num_events, offset=offset) + + return cls( + voxels=voxels.copy(), + quantity=quantity.copy(), + event_indices=event_indices.copy(), + event_timestamps=event_timestamps.copy(), + voxel_size=voxel_size, + frame_id=frame_id, + ts=timestamp_nanos / 1_000_000_000 if timestamp_nanos > 0 else None, + ) + + def to_rerun( + self, + colormap: str = "turbo", + radii: float | None = None, + normalize_quantity: bool = True, + ) -> Archetype: + """Return an `rr.Points3D` archetype colored by `quantity`. + + Events are not visualized by default (use `per_point_latest_timestamp()` + if you need to derive a freshness-based visualization). 
+ """ + import rerun as rr + + positions = self.world_positions() + if len(positions) == 0: + return rr.Points3D([]) + + colors = self._quantity_colors(colormap, normalize=normalize_quantity) + radius = self.voxel_size / 2 if radii is None else radii + return rr.Points3D(positions=positions, colors=colors, radii=radius) + + def _quantity_colors(self, colormap: str, normalize: bool) -> np.ndarray: + import matplotlib.pyplot as plt + + quantity = self.quantity.astype(np.float32) + if normalize and quantity.size > 0: + lo, hi = float(quantity.min()), float(quantity.max()) + spread = hi - lo + t = (quantity - lo) / spread if spread > 0 else np.zeros_like(quantity) + else: + t = np.clip(quantity / 255.0, 0.0, 1.0) + rgba = plt.get_cmap(colormap)(t) + return np.asarray(rgba[:, :3] * 255, dtype=np.uint8) diff --git a/dimos/msgs/nav_msgs/Graph3D.hpp b/dimos/msgs/nav_msgs/Graph3D.hpp new file mode 100644 index 0000000000..a440127ae2 --- /dev/null +++ b/dimos/msgs/nav_msgs/Graph3D.hpp @@ -0,0 +1,175 @@ +// Copyright 2026 Dimensional Inc. +// SPDX-License-Identifier: Apache-2.0 +// +// Typed C++ helper mirroring the Python `dimos.msgs.nav_msgs.Graph3D`. +// Canonical schema lives in `dimos/msgs/nav_msgs/Graph3D.ksy` — keep +// encode() in sync with that file (and with Graph3D.py.lcm_decode). +// +// Wire format (big-endian): +// +// uint64 edge_count +// uint64 node_count +// double timestamp // seconds since epoch +// per node (node_count): +// pose_stamped: +// double ts +// uint32 frame_id_len +// bytes frame_id (utf-8, no terminator) +// 7×double pos_x, pos_y, pos_z, quat_x, quat_y, quat_z, quat_w +// uint64 id +// uint64 metadata_id +// per edge (edge_count): +// uint64 start_id +// uint64 end_id +// double timestamp +// uint64 metadata_id +// +// Edges reference nodes by `id`, not by index. + +#pragma once + +#include +#include +#include +#include +#include + +#include + +namespace dimos { + +namespace graph3d_detail { + +// Host-order → big-endian byte writers. 
Avoid for portability +// (macOS uses different names) — write byte-by-byte from the top. + +inline void write_u32_be(std::vector& out, uint32_t v) { + out.push_back(static_cast((v >> 24) & 0xFF)); + out.push_back(static_cast((v >> 16) & 0xFF)); + out.push_back(static_cast((v >> 8) & 0xFF)); + out.push_back(static_cast( v & 0xFF)); +} + +inline void write_u64_be(std::vector& out, uint64_t v) { + for (int shift = 56; shift >= 0; shift -= 8) { + out.push_back(static_cast((v >> shift) & 0xFF)); + } +} + +inline void write_double_be(std::vector& out, double v) { + uint64_t bits; + std::memcpy(&bits, &v, sizeof(bits)); + write_u64_be(out, bits); +} + +inline void write_bytes(std::vector& out, const std::string& s) { + out.insert(out.end(), s.begin(), s.end()); +} + +} // namespace graph3d_detail + +class Graph3D { +public: + struct PoseStamped { + double ts = 0.0; + std::string frame_id; + double pos_x = 0.0, pos_y = 0.0, pos_z = 0.0; + double quat_x = 0.0, quat_y = 0.0, quat_z = 0.0, quat_w = 1.0; + }; + + struct Node3D { + PoseStamped pose; + uint64_t id = 0; + uint64_t metadata_id = 0; + }; + + struct Edge { + uint64_t start_id = 0; + uint64_t end_id = 0; + double timestamp = 0.0; + uint64_t metadata_id = 0; + }; + + Graph3D(std::string frame_id, double timestamp) + : frame_id_(std::move(frame_id)), timestamp_(timestamp) {} + + void reserve_nodes(size_t capacity) { nodes_.reserve(capacity); } + void reserve_edges(size_t capacity) { edges_.reserve(capacity); } + + // Add a node. The pose's frame_id defaults to the graph's frame_id — + // override per-node only if a node lives in a different frame. + void add_node(uint64_t id, uint64_t metadata_id, double pose_ts, + double pos_x, double pos_y, double pos_z, + double quat_x, double quat_y, double quat_z, double quat_w, + std::string node_frame_id = "") { + PoseStamped pose; + pose.ts = pose_ts; + pose.frame_id = node_frame_id.empty() ? 
frame_id_ : std::move(node_frame_id); + pose.pos_x = pos_x; pose.pos_y = pos_y; pose.pos_z = pos_z; + pose.quat_x = quat_x; pose.quat_y = quat_y; pose.quat_z = quat_z; pose.quat_w = quat_w; + nodes_.push_back({pose, id, metadata_id}); + } + + // Position-only convenience (orientation defaults to identity). + void add_node_xyz(uint64_t id, uint64_t metadata_id, double pose_ts, + double pos_x, double pos_y, double pos_z) { + add_node(id, metadata_id, pose_ts, pos_x, pos_y, pos_z, 0.0, 0.0, 0.0, 1.0); + } + + void add_edge(uint64_t start_id, uint64_t end_id, double edge_ts, + uint64_t metadata_id = 0) { + edges_.push_back({start_id, end_id, edge_ts, metadata_id}); + } + + size_t node_count() const { return nodes_.size(); } + size_t edge_count() const { return edges_.size(); } + const std::string& frame_id() const { return frame_id_; } + + std::vector encode() const { + using namespace graph3d_detail; + std::vector out; + // Conservative reservation: header + per-node fixed bytes + per-edge. + // frame_id strings add variable length on top — that just causes a + // realloc, not correctness issues. 
+ out.reserve(24 + nodes_.size() * 84 + edges_.size() * 32); + write_u64_be(out, static_cast(edges_.size())); + write_u64_be(out, static_cast(nodes_.size())); + write_double_be(out, timestamp_); + for (const auto& n : nodes_) { + // pose_stamped first (per Graph3D.ksy) + write_double_be(out, n.pose.ts); + write_u32_be(out, static_cast(n.pose.frame_id.size())); + write_bytes(out, n.pose.frame_id); + write_double_be(out, n.pose.pos_x); + write_double_be(out, n.pose.pos_y); + write_double_be(out, n.pose.pos_z); + write_double_be(out, n.pose.quat_x); + write_double_be(out, n.pose.quat_y); + write_double_be(out, n.pose.quat_z); + write_double_be(out, n.pose.quat_w); + // then id, metadata_id + write_u64_be(out, n.id); + write_u64_be(out, n.metadata_id); + } + for (const auto& e : edges_) { + write_u64_be(out, e.start_id); + write_u64_be(out, e.end_id); + write_double_be(out, e.timestamp); + write_u64_be(out, e.metadata_id); + } + return out; + } + + int publish(lcm::LCM& lcm, const std::string& channel) const { + std::vector bytes = encode(); + return lcm.publish(channel, bytes.data(), static_cast(bytes.size())); + } + +private: + std::string frame_id_; + double timestamp_; + std::vector nodes_; + std::vector edges_; +}; + +} // namespace dimos diff --git a/dimos/msgs/nav_msgs/Graph3D.py b/dimos/msgs/nav_msgs/Graph3D.py new file mode 100644 index 0000000000..4e342e022a --- /dev/null +++ b/dimos/msgs/nav_msgs/Graph3D.py @@ -0,0 +1,237 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Graph3D: pose-graph / visibility-graph message with typed nodes and edges. + +Edges reference nodes by ``id`` (not list index), so producers are free +to reorder or re-emit nodes between snapshots. ``metadata_id`` is a +caller-defined enum — ex: for far_planner: 0=normal, 1=odom, 2=goal +""" + +from __future__ import annotations + +from dataclasses import dataclass +import struct +import time +from typing import TYPE_CHECKING, BinaryIO + +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.types.timestamped import Timestamped + +if TYPE_CHECKING: + from rerun._baseclasses import Archetype + + +# Default node metadata_id → RGBA. Callers can override via the +# `node_colors` kwarg on to_rerun*; these defaults match the far_planner +# node-type enum (0=normal, 1=odom, 2=goal, 3=frontier, 4=navpoint). +_DEFAULT_NODE_COLORS: dict[int, tuple[int, int, int, int]] = { + 0: (180, 180, 180, 200), + 1: (0, 255, 0, 255), + 2: (255, 0, 0, 255), + 3: (255, 165, 0, 200), + 4: (0, 200, 255, 200), +} +_DEFAULT_NODE_COLOR = (200, 200, 200, 180) + +# Edge-type → RGBA, soft default (caller can override via to_rerun args). 
+_DEFAULT_EDGE_COLORS: dict[int, tuple[int, int, int, int]] = { + 0: (0, 220, 100, 200), # odom / traversable — green + 1: (255, 180, 0, 220), # loop_closure / partial — yellow + 2: (255, 50, 50, 150), # blocked — red +} +_DEFAULT_EDGE_COLOR = (180, 180, 180, 180) + + +class Graph3D(Timestamped): + msg_name = "nav_msgs.Graph3D" + + @dataclass + class Node3D: + pose: PoseStamped + id: int = 0 + metadata_id: int = 0 + + @dataclass + class Edge: + start_id: int + end_id: int + timestamp: float = 0.0 + metadata_id: int = 0 + + ts: float + nodes: list[Node3D] + edges: list[Edge] + + def __init__( + self, + ts: float = 0.0, + nodes: list[Graph3D.Node3D] | None = None, + edges: list[Graph3D.Edge] | None = None, + ) -> None: + self.ts = ts if ts != 0 else time.time() + self.nodes = nodes if nodes is not None else [] + self.edges = edges if edges is not None else [] + + def lcm_encode(self) -> bytes: + # Field order matches Graph3D.ksy: edge_count, node_count, ts, + # nodes[] (pose, id, metadata_id), edges[]. 
+ parts: list[bytes] = [] + parts.append(struct.pack(">QQd", len(self.edges), len(self.nodes), self.ts)) + for node in self.nodes: + frame_id_bytes = node.pose.frame_id.encode("utf-8") + parts.append(struct.pack(">d", node.pose.ts)) + parts.append(struct.pack(">I", len(frame_id_bytes))) + parts.append(frame_id_bytes) + parts.append( + struct.pack( + ">7d", + node.pose.position.x, + node.pose.position.y, + node.pose.position.z, + node.pose.orientation.x, + node.pose.orientation.y, + node.pose.orientation.z, + node.pose.orientation.w, + ) + ) + parts.append(struct.pack(">QQ", node.id, node.metadata_id)) + for edge in self.edges: + parts.append( + struct.pack(">QQdQ", edge.start_id, edge.end_id, edge.timestamp, edge.metadata_id) + ) + return b"".join(parts) + + @classmethod + def lcm_decode(cls, data: bytes | BinaryIO) -> Graph3D: + buf = data if isinstance(data, (bytes, bytearray)) else data.read() + offset = 0 + edge_count, node_count, graph_ts = struct.unpack_from(">QQd", buf, offset) + offset += 24 + + nodes: list[Graph3D.Node3D] = [] + for _ in range(node_count): + (pose_ts,) = struct.unpack_from(">d", buf, offset) + offset += 8 + (frame_id_len,) = struct.unpack_from(">I", buf, offset) + offset += 4 + frame_id = buf[offset : offset + frame_id_len].decode("utf-8") + offset += frame_id_len + px, py, pz, qx, qy, qz, qw = struct.unpack_from(">7d", buf, offset) + offset += 56 + node_id, metadata_id = struct.unpack_from(">QQ", buf, offset) + offset += 16 + pose = PoseStamped( + ts=pose_ts, + frame_id=frame_id, + position=Vector3(px, py, pz), + orientation=Quaternion(qx, qy, qz, qw), + ) + nodes.append(cls.Node3D(pose=pose, id=node_id, metadata_id=metadata_id)) + + edges: list[Graph3D.Edge] = [] + for _ in range(edge_count): + start_id, end_id, edge_ts, edge_metadata_id = struct.unpack_from(">QQdQ", buf, offset) + offset += 32 + edges.append( + cls.Edge( + start_id=start_id, + end_id=end_id, + timestamp=edge_ts, + metadata_id=edge_metadata_id, + ) + ) + + return 
cls(ts=graph_ts, nodes=nodes, edges=edges) + + def to_rerun( + self, + z_offset: float = 0.0, + radii: float = 0.12, + node_colors: dict[int, tuple[int, int, int, int]] | None = None, + ) -> Archetype: + """Default visualization: ``rr.Points3D`` of just the nodes. + + For nodes + edges in separate entity sub-paths, use + ``to_rerun_multi`` from a ``visual_override`` callback. + """ + import rerun as rr + + nc = node_colors if node_colors is not None else _DEFAULT_NODE_COLORS + positions = [ + [n.pose.position.x, n.pose.position.y, n.pose.position.z + z_offset] for n in self.nodes + ] + colors = [nc.get(n.metadata_id, _DEFAULT_NODE_COLOR) for n in self.nodes] + node_radii = [radii * 2.0 if n.metadata_id in (1, 2) else radii for n in self.nodes] + return rr.Points3D(positions, colors=colors, radii=node_radii) + + def to_rerun_multi( + self, + base_path: str, + z_offset: float = 0.0, + node_radius: float = 0.12, + edge_radius: float = 0.04, + node_colors: dict[int, tuple[int, int, int, int]] | None = None, + edge_colors: dict[int, tuple[int, int, int, int]] | None = None, + ) -> list[tuple[str, Archetype]]: + """Return ``[(base_path/nodes, Points3D), (base_path/edges, LineStrips3D)]``. + + Intended for use from ``visual_override`` callbacks where the + bridge supports the ``RerunMulti`` list-of-tuples form. 
+ """ + import rerun as rr + + nc = node_colors if node_colors is not None else _DEFAULT_NODE_COLORS + ec = edge_colors if edge_colors is not None else _DEFAULT_EDGE_COLORS + + node_positions = [ + [n.pose.position.x, n.pose.position.y, n.pose.position.z + z_offset] for n in self.nodes + ] + node_colors_list = [nc.get(n.metadata_id, _DEFAULT_NODE_COLOR) for n in self.nodes] + node_radii = [ + node_radius * 2.0 if n.metadata_id in (1, 2) else node_radius for n in self.nodes + ] + nodes_archetype = rr.Points3D(node_positions, colors=node_colors_list, radii=node_radii) + + id_to_pose: dict[int, PoseStamped] = {n.id: n.pose for n in self.nodes} + strips: list[list[list[float]]] = [] + edge_colors_list: list[tuple[int, int, int, int]] = [] + for edge in self.edges: + start = id_to_pose.get(edge.start_id) + end = id_to_pose.get(edge.end_id) + if start is None or end is None: + continue + strips.append( + [ + [start.position.x, start.position.y, start.position.z + z_offset], + [end.position.x, end.position.y, end.position.z + z_offset], + ] + ) + edge_colors_list.append(ec.get(edge.metadata_id, _DEFAULT_EDGE_COLOR)) + edges_archetype = rr.LineStrips3D( + strips, colors=edge_colors_list, radii=[edge_radius] * len(strips) + ) + + return [ + (f"{base_path}/nodes", nodes_archetype), + (f"{base_path}/edges", edges_archetype), + ] + + def __len__(self) -> int: + return len(self.nodes) + + def __str__(self) -> str: + return f"Graph3D(nodes={len(self.nodes)}, edges={len(self.edges)})" diff --git a/dimos/msgs/nav_msgs/GraphDelta3D.hpp b/dimos/msgs/nav_msgs/GraphDelta3D.hpp new file mode 100644 index 0000000000..4db42eb71c --- /dev/null +++ b/dimos/msgs/nav_msgs/GraphDelta3D.hpp @@ -0,0 +1,168 @@ +// Copyright 2026 Dimensional Inc. +// SPDX-License-Identifier: Apache-2.0 +// +// Typed C++ helper mirroring the Python `dimos.msgs.nav_msgs.GraphDelta3D`. 
+// +// Wire format (big-endian): +// +// uint64 node_count +// double timestamp // seconds since epoch +// per node (node_count): +// pose_stamped: // (same as Graph3D's node3d pose) +// double ts +// uint32 frame_id_len +// bytes frame_id (utf-8, no terminator) +// 7×double pos_x, pos_y, pos_z, quat_x, quat_y, quat_z, quat_w +// uint64 id +// uint64 metadata_id +// per transform (node_count): +// 7×double translation_x, translation_y, translation_z, +// rotation_x, rotation_y, rotation_z, rotation_w +// +// Two aligned arrays: ``transforms[i]`` is the SE(3) delta about to +// be applied to ``nodes[i]``. ``post_pose = transforms[i] * nodes[i].pose`` +// is the convention (left-multiply). +// +// `GraphDelta3D.py.lcm_decode` reads exactly this layout — keep in sync. + +#pragma once + +#include +#include +#include +#include +#include + +#include + +namespace dimos { + +namespace graph_delta3d_detail { + +inline void write_u32_be(std::vector& out, uint32_t v) { + out.push_back(static_cast((v >> 24) & 0xFF)); + out.push_back(static_cast((v >> 16) & 0xFF)); + out.push_back(static_cast((v >> 8) & 0xFF)); + out.push_back(static_cast( v & 0xFF)); +} + +inline void write_u64_be(std::vector& out, uint64_t v) { + for (int shift = 56; shift >= 0; shift -= 8) { + out.push_back(static_cast((v >> shift) & 0xFF)); + } +} + +inline void write_double_be(std::vector& out, double v) { + uint64_t bits; + std::memcpy(&bits, &v, sizeof(bits)); + write_u64_be(out, bits); +} + +inline void write_bytes(std::vector& out, const std::string& s) { + out.insert(out.end(), s.begin(), s.end()); +} + +} // namespace graph_delta3d_detail + +class GraphDelta3D { +public: + struct PoseStamped { + double ts = 0.0; + std::string frame_id; + double pos_x = 0.0, pos_y = 0.0, pos_z = 0.0; + double quat_x = 0.0, quat_y = 0.0, quat_z = 0.0, quat_w = 1.0; + }; + + struct Node3D { + PoseStamped pose; + uint64_t id = 0; + uint64_t metadata_id = 0; + }; + + struct Transform { + double translation_x = 0.0, 
translation_y = 0.0, translation_z = 0.0; + double rotation_x = 0.0, rotation_y = 0.0, rotation_z = 0.0, rotation_w = 1.0; + }; + + GraphDelta3D(std::string frame_id, double timestamp) + : frame_id_(std::move(frame_id)), timestamp_(timestamp) {} + + void reserve(size_t capacity) { + nodes_.reserve(capacity); + transforms_.reserve(capacity); + } + + // Add a node + its SE(3) delta. Pass empty `node_frame_id` to inherit + // the graph's frame_id. + void add(uint64_t id, uint64_t metadata_id, double pose_ts, + double pos_x, double pos_y, double pos_z, + double quat_x, double quat_y, double quat_z, double quat_w, + double translation_x, double translation_y, double translation_z, + double rotation_x, double rotation_y, double rotation_z, double rotation_w, + std::string node_frame_id = "") { + Node3D node; + node.id = id; + node.metadata_id = metadata_id; + node.pose.ts = pose_ts; + node.pose.frame_id = node_frame_id.empty() ? frame_id_ : std::move(node_frame_id); + node.pose.pos_x = pos_x; node.pose.pos_y = pos_y; node.pose.pos_z = pos_z; + node.pose.quat_x = quat_x; node.pose.quat_y = quat_y; + node.pose.quat_z = quat_z; node.pose.quat_w = quat_w; + nodes_.push_back(node); + + Transform tf; + tf.translation_x = translation_x; tf.translation_y = translation_y; tf.translation_z = translation_z; + tf.rotation_x = rotation_x; tf.rotation_y = rotation_y; + tf.rotation_z = rotation_z; tf.rotation_w = rotation_w; + transforms_.push_back(tf); + } + + size_t size() const { return nodes_.size(); } + bool empty() const { return nodes_.empty(); } + const std::string& frame_id() const { return frame_id_; } + + std::vector encode() const { + using namespace graph_delta3d_detail; + std::vector out; + out.reserve(16 + nodes_.size() * (84 + 56)); + write_u64_be(out, static_cast(nodes_.size())); + write_double_be(out, timestamp_); + for (const auto& n : nodes_) { + write_double_be(out, n.pose.ts); + write_u32_be(out, static_cast(n.pose.frame_id.size())); + write_bytes(out, 
n.pose.frame_id); + write_double_be(out, n.pose.pos_x); + write_double_be(out, n.pose.pos_y); + write_double_be(out, n.pose.pos_z); + write_double_be(out, n.pose.quat_x); + write_double_be(out, n.pose.quat_y); + write_double_be(out, n.pose.quat_z); + write_double_be(out, n.pose.quat_w); + write_u64_be(out, n.id); + write_u64_be(out, n.metadata_id); + } + for (const auto& t : transforms_) { + write_double_be(out, t.translation_x); + write_double_be(out, t.translation_y); + write_double_be(out, t.translation_z); + write_double_be(out, t.rotation_x); + write_double_be(out, t.rotation_y); + write_double_be(out, t.rotation_z); + write_double_be(out, t.rotation_w); + } + return out; + } + + int publish(lcm::LCM& lcm, const std::string& channel) const { + std::vector bytes = encode(); + return lcm.publish(channel, bytes.data(), static_cast(bytes.size())); + } + +private: + std::string frame_id_; + double timestamp_; + std::vector nodes_; + std::vector transforms_; +}; + +} // namespace dimos diff --git a/dimos/msgs/nav_msgs/GraphDelta3D.py b/dimos/msgs/nav_msgs/GraphDelta3D.py new file mode 100644 index 0000000000..a61ba67604 --- /dev/null +++ b/dimos/msgs/nav_msgs/GraphDelta3D.py @@ -0,0 +1,198 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""GraphDelta3D: per-node SE(3) transforms about to be applied to a list of nodes. + +Two aligned arrays: ``nodes[i]`` is the node, ``transforms[i]`` is the +SE(3) delta about to be applied to it. 
``post_pose = transforms[i] * +nodes[i].pose`` is the convention (left-multiply). + +Use case: PGO publishes this on ``loop_closure_event`` when iSAM2 +smooths the pose graph — ``nodes[i]`` is the keyframe pre-smooth, +``transforms[i]`` is the delta iSAM2 just applied to it. Consumers can +re-derive post-poses or filter to large deltas. + +Wire format mirrors ``Graph3D`` conventions: big-endian, ``Node3D`` +serialization shared, ``Transform`` is just 7 f8s (translation + +quaternion). Custom binary, dispatched by the ``#nav_msgs.GraphDelta3D`` +channel-name suffix. +""" + +from __future__ import annotations + +from dataclasses import dataclass +import struct +import time +from typing import TYPE_CHECKING, BinaryIO + +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.nav_msgs.Graph3D import Graph3D +from dimos.types.timestamped import Timestamped + +if TYPE_CHECKING: + from rerun._baseclasses import Archetype + + +class GraphDelta3D(Timestamped): + msg_name = "nav_msgs.GraphDelta3D" + + # Reuse Graph3D's nested Node3D for wire-format consistency. A + # GraphDelta3D[i].node is byte-identical to a Graph3D.nodes[i]. 
+ Node3D = Graph3D.Node3D + + @dataclass + class Transform: + """SE(3) transform — translation + rotation quaternion (xyzw).""" + + translation: Vector3 + rotation: Quaternion + + ts: float + nodes: list[Graph3D.Node3D] + transforms: list[Transform] + + def __init__( + self, + ts: float = 0.0, + nodes: list[Graph3D.Node3D] | None = None, + transforms: list[Transform] | None = None, + ) -> None: + self.ts = ts if ts != 0 else time.time() + self.nodes = nodes if nodes is not None else [] + self.transforms = transforms if transforms is not None else [] + if len(self.nodes) != len(self.transforms): + raise ValueError( + f"nodes ({len(self.nodes)}) and transforms ({len(self.transforms)}) " + "must be the same length — they're aligned arrays" + ) + + def lcm_encode(self) -> bytes: + parts: list[bytes] = [] + parts.append(struct.pack(">Qd", len(self.nodes), self.ts)) + for node in self.nodes: + frame_id_bytes = node.pose.frame_id.encode("utf-8") + parts.append(struct.pack(">d", node.pose.ts)) + parts.append(struct.pack(">I", len(frame_id_bytes))) + parts.append(frame_id_bytes) + parts.append( + struct.pack( + ">7d", + node.pose.position.x, + node.pose.position.y, + node.pose.position.z, + node.pose.orientation.x, + node.pose.orientation.y, + node.pose.orientation.z, + node.pose.orientation.w, + ) + ) + parts.append(struct.pack(">QQ", node.id, node.metadata_id)) + for transform in self.transforms: + parts.append( + struct.pack( + ">7d", + transform.translation.x, + transform.translation.y, + transform.translation.z, + transform.rotation.x, + transform.rotation.y, + transform.rotation.z, + transform.rotation.w, + ) + ) + return b"".join(parts) + + @classmethod + def lcm_decode(cls, data: bytes | BinaryIO) -> GraphDelta3D: + buf = data if isinstance(data, (bytes, bytearray)) else data.read() + offset = 0 + node_count, graph_ts = struct.unpack_from(">Qd", buf, offset) + offset += 16 + + nodes: list[Graph3D.Node3D] = [] + for _ in range(node_count): + (pose_ts,) = 
struct.unpack_from(">d", buf, offset) + offset += 8 + (frame_id_len,) = struct.unpack_from(">I", buf, offset) + offset += 4 + frame_id = buf[offset : offset + frame_id_len].decode("utf-8") + offset += frame_id_len + px, py, pz, qx, qy, qz, qw = struct.unpack_from(">7d", buf, offset) + offset += 56 + node_id, metadata_id = struct.unpack_from(">QQ", buf, offset) + offset += 16 + pose = PoseStamped( + ts=pose_ts, + frame_id=frame_id, + position=Vector3(px, py, pz), + orientation=Quaternion(qx, qy, qz, qw), + ) + nodes.append(Graph3D.Node3D(pose=pose, id=node_id, metadata_id=metadata_id)) + + transforms: list[GraphDelta3D.Transform] = [] + for _ in range(node_count): + tx, ty, tz, qx, qy, qz, qw = struct.unpack_from(">7d", buf, offset) + offset += 56 + transforms.append( + cls.Transform( + translation=Vector3(tx, ty, tz), + rotation=Quaternion(qx, qy, qz, qw), + ) + ) + + return cls(ts=graph_ts, nodes=nodes, transforms=transforms) + + def to_rerun( + self, + z_offset: float = 0.0, + arrow_scale: float = 1.0, + ) -> Archetype: + """Render each (node, transform) pair as an arrow from node.pose to post_pose. + + The arrow origin is the node's current position; the vector is + the translation component of the transform (scaled by + ``arrow_scale``). Rotation deltas aren't visualized by default — + callers wanting to see those can subclass. 
+ """ + import rerun as rr + + if not self.nodes: + return rr.Arrows3D(origins=[], vectors=[]) + + origins = [] + vectors = [] + for node, transform in zip(self.nodes, self.transforms, strict=True): + origins.append( + [ + node.pose.position.x, + node.pose.position.y, + node.pose.position.z + z_offset, + ] + ) + vectors.append( + [ + transform.translation.x * arrow_scale, + transform.translation.y * arrow_scale, + transform.translation.z * arrow_scale, + ] + ) + return rr.Arrows3D(origins=origins, vectors=vectors) + + def __len__(self) -> int: + return len(self.nodes) + + def __str__(self) -> str: + return f"GraphDelta3D(nodes={len(self.nodes)})" diff --git a/dimos/msgs/nav_msgs/GraphNodes3D.py b/dimos/msgs/nav_msgs/GraphNodes3D.py deleted file mode 100644 index 95e74b8d14..0000000000 --- a/dimos/msgs/nav_msgs/GraphNodes3D.py +++ /dev/null @@ -1,160 +0,0 @@ -# Copyright 2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""GraphNodes3D: visibility-graph nodes for debug visualization. - -On the wire this reuses ``nav_msgs/Path``. Each pose is a node; the -``orientation.w`` field encodes the node type: - - 0 = normal nav node - 1 = odom (robot) node - 2 = goal node - 3 = frontier node - 4 = navpoint (trajectory) node - -Rerun visualization renders as ``rr.Points3D`` with type-based coloring. 
-""" - -from __future__ import annotations - -import time -from typing import TYPE_CHECKING, BinaryIO - -from dimos_lcm.geometry_msgs import ( - Point as LCMPoint, - Pose as LCMPose, - PoseStamped as LCMPoseStamped, - Quaternion as LCMQuaternion, -) -from dimos_lcm.nav_msgs import Path as LCMPath -from dimos_lcm.std_msgs import Header as LCMHeader, Time as LCMTime - -from dimos.types.timestamped import Timestamped - -if TYPE_CHECKING: - from rerun._baseclasses import Archetype - - -# Node type → RGBA color -TYPE_COLORS: dict[int, tuple[int, int, int, int]] = { - 0: (180, 180, 180, 200), # normal — grey - 1: (0, 255, 0, 255), # odom — green - 2: (255, 0, 0, 255), # goal — red - 3: (255, 165, 0, 200), # frontier — orange - 4: (0, 200, 255, 200), # navpoint — cyan -} -DEFAULT_COLOR = (200, 200, 200, 180) - - -class GraphNode: - """A single graph node with position and type.""" - - __slots__ = ("node_type", "x", "y", "z") - - def __init__(self, x: float, y: float, z: float, node_type: int = 0) -> None: - self.x = x - self.y = y - self.z = z - self.node_type = node_type - - -def _sec_nsec(ts: float) -> list[int]: - s = int(ts) - return [s, int((ts - s) * 1_000_000_000)] - - -class GraphNodes3D(Timestamped): - """Visibility-graph node positions for debug visualization.""" - - msg_name = "nav_msgs.GraphNodes3D" - ts: float - frame_id: str - nodes: list[GraphNode] - - def __init__( - self, - ts: float = 0.0, - frame_id: str = "map", - nodes: list[GraphNode] | None = None, - ) -> None: - self.frame_id = frame_id - self.ts = ts if ts != 0 else time.time() - self.nodes = nodes if nodes is not None else [] - - def lcm_encode(self) -> bytes: - lcm_msg = LCMPath() - lcm_msg.poses_length = len(self.nodes) - lcm_msg.poses = [] - - for node in self.nodes: - pose = LCMPoseStamped() - pose.header = LCMHeader() - pose.header.stamp = LCMTime() - [pose.header.stamp.sec, pose.header.stamp.nsec] = _sec_nsec(self.ts) - pose.header.frame_id = self.frame_id - pose.pose = LCMPose() - 
pose.pose.position = LCMPoint() - pose.pose.position.x = node.x - pose.pose.position.y = node.y - pose.pose.position.z = node.z - pose.pose.orientation = LCMQuaternion() - pose.pose.orientation.w = float(node.node_type) - lcm_msg.poses.append(pose) - - lcm_msg.header = LCMHeader() - lcm_msg.header.stamp = LCMTime() - [lcm_msg.header.stamp.sec, lcm_msg.header.stamp.nsec] = _sec_nsec(self.ts) - lcm_msg.header.frame_id = self.frame_id - return lcm_msg.lcm_encode() # type: ignore[no-any-return] - - @classmethod - def lcm_decode(cls, data: bytes | BinaryIO) -> GraphNodes3D: - lcm_msg = LCMPath.lcm_decode(data) - header_ts = lcm_msg.header.stamp.sec + lcm_msg.header.stamp.nsec / 1e9 - frame_id = lcm_msg.header.frame_id - - nodes: list[GraphNode] = [] - for pose in lcm_msg.poses: - nodes.append( - GraphNode( - x=pose.pose.position.x, - y=pose.pose.position.y, - z=pose.pose.position.z, - node_type=int(pose.pose.orientation.w), - ) - ) - return cls(ts=header_ts, frame_id=frame_id, nodes=nodes) - - def to_rerun( - self, - z_offset: float = 1.7, - radii: float = 0.12, - ) -> Archetype: - """Render as ``rr.Points3D`` with type-based coloring.""" - import rerun as rr - - if not self.nodes: - return rr.Points3D([]) - - positions = [[n.x, n.y, n.z + z_offset] for n in self.nodes] - colors = [TYPE_COLORS.get(n.node_type, DEFAULT_COLOR) for n in self.nodes] - node_radii = [radii * 2.0 if n.node_type in (1, 2) else radii for n in self.nodes] - - return rr.Points3D(positions, colors=colors, radii=node_radii) - - def __len__(self) -> int: - return len(self.nodes) - - def __str__(self) -> str: - return f"GraphNodes3D(frame_id='{self.frame_id}', nodes={len(self.nodes)})" diff --git a/dimos/msgs/nav_msgs/LineSegments3D.hpp b/dimos/msgs/nav_msgs/LineSegments3D.hpp new file mode 100644 index 0000000000..8748388bdd --- /dev/null +++ b/dimos/msgs/nav_msgs/LineSegments3D.hpp @@ -0,0 +1,95 @@ +// Copyright 2026 Dimensional Inc. 
+// SPDX-License-Identifier: Apache-2.0 +// +// Typed C++ helper mirroring the Python `dimos.msgs.nav_msgs.LineSegments3D` +// wrapper. Wire format is `nav_msgs::Path` where consecutive `PoseStamped` +// pairs form line segments; `orientation.w` on the first pose of each +// pair carries the segment's `traversability`. The Python +// `LineSegments3D.lcm_decode` reads exactly this layout — keep the two +// in sync. +// +// This type is for *standalone* line segments (e.g., collision-boundary +// polylines). Graph-structured edges with node-id references live in +// `Graph3D` instead. + +#pragma once + +#include +#include +#include +#include + +#include + +#include "geometry_msgs/PoseStamped.hpp" +#include "nav_msgs/Path.hpp" + +#include "dimos_native_module.hpp" + +namespace dimos { + +class LineSegments3D { +public: + LineSegments3D(std::string frame_id, double ts) + : frame_id_(std::move(frame_id)), ts_(ts) {} + + void reserve(size_t capacity) { segments_.reserve(capacity); } + + void add(float x1, float y1, float z1, + float x2, float y2, float z2, + float traversability = 1.0f) { + segments_.push_back({x1, y1, z1, x2, y2, z2, traversability}); + } + + size_t size() const { return segments_.size(); } + bool empty() const { return segments_.empty(); } + + nav_msgs::Path to_lcm_path() const { + nav_msgs::Path msg; + msg.header = make_header(frame_id_, ts_); + msg.poses_length = static_cast(segments_.size() * 2); + msg.poses.resize(segments_.size() * 2); + for (size_t i = 0; i < segments_.size(); ++i) { + const auto& s = segments_[i]; + auto& p1 = msg.poses[i * 2]; + auto& p2 = msg.poses[i * 2 + 1]; + p1.header = msg.header; + p2.header = msg.header; + p1.pose.position.x = s.x1; + p1.pose.position.y = s.y1; + p1.pose.position.z = s.z1; + p1.pose.orientation.x = 0.0; + p1.pose.orientation.y = 0.0; + p1.pose.orientation.z = 0.0; + // orientation.w on the first endpoint carries traversability + // (see LineSegments3D.py). 
+ p1.pose.orientation.w = s.traversability; + p2.pose.position.x = s.x2; + p2.pose.position.y = s.y2; + p2.pose.position.z = s.z2; + p2.pose.orientation.x = 0.0; + p2.pose.orientation.y = 0.0; + p2.pose.orientation.z = 0.0; + p2.pose.orientation.w = s.traversability; + } + return msg; + } + + int publish(lcm::LCM& lcm, const std::string& channel) const { + nav_msgs::Path msg = to_lcm_path(); + return lcm.publish(channel, &msg); + } + +private: + struct Segment { + float x1, y1, z1; + float x2, y2, z2; + float traversability; + }; + + std::string frame_id_; + double ts_; + std::vector segments_; +}; + +} // namespace dimos diff --git a/dimos/msgs/nav_msgs/LineSegments3D.py b/dimos/msgs/nav_msgs/LineSegments3D.py index 26e0c515ed..313be71978 100644 --- a/dimos/msgs/nav_msgs/LineSegments3D.py +++ b/dimos/msgs/nav_msgs/LineSegments3D.py @@ -31,30 +31,39 @@ from rerun._baseclasses import Archetype +Segment = tuple[tuple[float, float, float], tuple[float, float, float]] + + class LineSegments3D(Timestamped): """Line segments for graph edge visualization. Wire format: ``nav_msgs/Path`` — consecutive pose pairs are segments. ``orientation.w`` encodes traversability: 1.0=traversable, 0.5=partial, 0.0=unreachable. + Each endpoint's ``header.stamp`` is preserved into ``segment_timestamps`` so + consumers can correlate endpoints back to source events (e.g. keyframe + creation time for a pose-graph SLAM producer). 
""" msg_name = "nav_msgs.LineSegments3D" ts: float frame_id: str - _segments: list[tuple[tuple[float, float, float], tuple[float, float, float]]] - _traversability: list[float] + segments: list[Segment] + traversability: list[float] + segment_timestamps: list[tuple[float, float]] def __init__( self, ts: float = 0.0, frame_id: str = "map", - segments: list[tuple[tuple[float, float, float], tuple[float, float, float]]] | None = None, + segments: list[Segment] | None = None, traversability: list[float] | None = None, + segment_timestamps: list[tuple[float, float]] | None = None, ) -> None: self.frame_id = frame_id self.ts = ts if ts != 0 else time.time() - self._segments = segments or [] - self._traversability = traversability or [1.0] * len(self._segments) + self.segments = segments or [] + self.traversability = traversability or [1.0] * len(self.segments) + self.segment_timestamps = segment_timestamps or [(self.ts, self.ts)] * len(self.segments) def lcm_encode(self) -> bytes: raise NotImplementedError("Encoded on C++ side") @@ -65,8 +74,9 @@ def lcm_decode(cls, data: bytes | BinaryIO) -> LineSegments3D: header_ts = lcm_msg.header.stamp.sec + lcm_msg.header.stamp.nsec / 1e9 frame_id = lcm_msg.header.frame_id - segments = [] - traversability = [] + segments: list[Segment] = [] + traversability: list[float] = [] + segment_timestamps: list[tuple[float, float]] = [] poses = lcm_msg.poses for i in range(0, len(poses) - 1, 2): p1, p2 = poses[i], poses[i + 1] @@ -77,8 +87,15 @@ def lcm_decode(cls, data: bytes | BinaryIO) -> LineSegments3D: ) ) traversability.append(p1.pose.orientation.w) + start_ts = p1.header.stamp.sec + p1.header.stamp.nsec / 1e9 + end_ts = p2.header.stamp.sec + p2.header.stamp.nsec / 1e9 + segment_timestamps.append((start_ts, end_ts)) return cls( - ts=header_ts, frame_id=frame_id, segments=segments, traversability=traversability + ts=header_ts, + frame_id=frame_id, + segments=segments, + traversability=traversability, + 
segment_timestamps=segment_timestamps, ) def to_rerun( @@ -87,40 +104,33 @@ def to_rerun( color: tuple[int, int, int, int] = (0, 255, 150, 255), radii: float = 0.04, ) -> Archetype: - """Render as ``rr.LineStrips3D`` — color-coded by traversability. - - Green = traversable (reachable from robot), red = non-traversable. - """ + """Render as ``rr.LineStrips3D`` — color-coded by traversability.""" import rerun as rr - if not self._segments: + if not self.segments: return rr.LineStrips3D([]) strips = [] colors = [] - for idx, (p1, p2) in enumerate(self._segments): + for idx, (p1, p2) in enumerate(self.segments): strips.append( [ [p1[0], p1[1], p1[2] + z_offset], [p2[0], p2[1], p2[2] + z_offset], ] ) - trav = self._traversability[idx] if idx < len(self._traversability) else 1.0 + trav = self.traversability[idx] if idx < len(self.traversability) else 1.0 if trav >= 0.9: - colors.append((0, 220, 100, 200)) # green = fully traversable + colors.append((0, 220, 100, 200)) elif trav >= 0.4: - colors.append((255, 180, 0, 200)) # yellow = partially traversable + colors.append((255, 180, 0, 200)) else: - colors.append((255, 50, 50, 150)) # red = non-traversable + colors.append((255, 50, 50, 150)) - return rr.LineStrips3D( - strips, - colors=colors, - radii=[radii] * len(strips), - ) + return rr.LineStrips3D(strips, colors=colors, radii=[radii] * len(strips)) def __len__(self) -> int: - return len(self._segments) + return len(self.segments) def __str__(self) -> str: - return f"LineSegments3D(frame_id='{self.frame_id}', segments={len(self._segments)})" + return f"LineSegments3D(frame_id='{self.frame_id}', segments={len(self.segments)})" diff --git a/dimos/msgs/nav_msgs/test_Graph3D.py b/dimos/msgs/nav_msgs/test_Graph3D.py new file mode 100644 index 0000000000..79de37376c --- /dev/null +++ b/dimos/msgs/nav_msgs/test_Graph3D.py @@ -0,0 +1,128 @@ +# Copyright 2026 Dimensional Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the Graph3D message type. + +These pin the wire layout (per ``Graph3D.ksy``) so the hand-written +Python encoder/decoder and the matching C++ encoder in +``Graph3D.hpp`` (same directory) don't drift. +""" + +from __future__ import annotations + +import struct + +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.nav_msgs.Graph3D import Graph3D + + +def _make_graph() -> Graph3D: + return Graph3D( + ts=1234.5, + nodes=[ + Graph3D.Node3D( + pose=PoseStamped( + ts=10.5, + frame_id="map", + position=Vector3(1.0, 2.0, 3.0), + orientation=Quaternion(0.0, 0.0, 0.0, 1.0), + ), + id=100, + metadata_id=1, + ), + Graph3D.Node3D( + pose=PoseStamped( + ts=11.0, + frame_id="odom", + position=Vector3(4.0, 5.0, 6.0), + orientation=Quaternion(0.1, 0.2, 0.3, 0.9273618), + ), + id=200, + metadata_id=2, + ), + ], + edges=[ + Graph3D.Edge(start_id=100, end_id=200, timestamp=10.7, metadata_id=0), + ], + ) + + +def test_round_trip() -> None: + original = _make_graph() + decoded = Graph3D.lcm_decode(original.lcm_encode()) + assert decoded.ts == original.ts + assert len(decoded.nodes) == len(original.nodes) + assert len(decoded.edges) == len(original.edges) + for got, want in zip(decoded.nodes, original.nodes, strict=True): + assert got.id == want.id + assert got.metadata_id == 
want.metadata_id + assert got.pose.ts == want.pose.ts + assert got.pose.frame_id == want.pose.frame_id + assert got.pose.position.x == want.pose.position.x + assert got.pose.position.y == want.pose.position.y + assert got.pose.position.z == want.pose.position.z + assert got.pose.orientation.w == want.pose.orientation.w + for got, want in zip(decoded.edges, original.edges, strict=True): + assert got.start_id == want.start_id + assert got.end_id == want.end_id + assert got.timestamp == want.timestamp + assert got.metadata_id == want.metadata_id + + +def test_wire_layout_header() -> None: + """Header is `[edge_count u8][node_count u8][timestamp f8]` (big-endian).""" + graph = _make_graph() + encoded = graph.lcm_encode() + edge_count, node_count, timestamp = struct.unpack_from(">QQd", encoded, 0) + assert edge_count == 1 + assert node_count == 2 + assert timestamp == 1234.5 + + +def test_wire_layout_node_starts_with_pose() -> None: + """A node's first bytes are its pose, NOT id — matches Graph3D.ksy spec.""" + graph = _make_graph() + encoded = graph.lcm_encode() + # Header is 24 bytes; node starts at offset 24 with pose.ts (f8). + (pose_ts,) = struct.unpack_from(">d", encoded, 24) + assert pose_ts == 10.5, "first node's bytes must be pose.ts, not id" + # After ts comes a uint32 frame_id_len = 3 (utf-8 "map"). 
+ (frame_id_len,) = struct.unpack_from(">I", encoded, 24 + 8) + assert frame_id_len == 3 + assert encoded[24 + 12 : 24 + 12 + 3] == b"map" + + +def test_empty_graph() -> None: + empty = Graph3D(ts=0.0) + decoded = Graph3D.lcm_decode(empty.lcm_encode()) + assert decoded.nodes == [] + assert decoded.edges == [] + + +def test_edge_references_unknown_node_id_decodes_fine() -> None: + """Decoder shouldn't validate id-references — that's a consumer concern.""" + graph = Graph3D( + ts=1.0, + nodes=[ + Graph3D.Node3D(pose=PoseStamped(ts=0.0, frame_id="map"), id=1, metadata_id=0), + ], + edges=[ + Graph3D.Edge(start_id=1, end_id=999, timestamp=0.5, metadata_id=0), # 999 doesn't exist + ], + ) + decoded = Graph3D.lcm_decode(graph.lcm_encode()) + assert len(decoded.edges) == 1 + assert decoded.edges[0].end_id == 999 diff --git a/dimos/msgs/nav_msgs/test_GraphDelta3D.py b/dimos/msgs/nav_msgs/test_GraphDelta3D.py new file mode 100644 index 0000000000..440bbac74a --- /dev/null +++ b/dimos/msgs/nav_msgs/test_GraphDelta3D.py @@ -0,0 +1,146 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Unit tests for GraphDelta3D — pins wire layout vs the C++ encoder.""" + +from __future__ import annotations + +import struct + +import pytest + +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.nav_msgs.Graph3D import Graph3D +from dimos.msgs.nav_msgs.GraphDelta3D import GraphDelta3D + + +def _sample() -> GraphDelta3D: + return GraphDelta3D( + ts=1234.5, + nodes=[ + Graph3D.Node3D( + pose=PoseStamped( + ts=10.5, + frame_id="map", + position=Vector3(1.0, 2.0, 3.0), + orientation=Quaternion(0.0, 0.0, 0.0, 1.0), + ), + id=100, + metadata_id=1, + ), + Graph3D.Node3D( + pose=PoseStamped( + ts=11.0, + frame_id="odom", + position=Vector3(4.0, 5.0, 6.0), + orientation=Quaternion(0.0, 0.0, 0.0, 1.0), + ), + id=200, + metadata_id=0, + ), + ], + transforms=[ + GraphDelta3D.Transform( + translation=Vector3(0.1, 0.2, 0.3), + rotation=Quaternion(0.0, 0.0, 0.0, 1.0), + ), + GraphDelta3D.Transform( + translation=Vector3(0.4, 0.5, 0.6), + rotation=Quaternion(0.1, 0.2, 0.3, 0.9273618), + ), + ], + ) + + +def test_round_trip() -> None: + original = _sample() + decoded = GraphDelta3D.lcm_decode(original.lcm_encode()) + assert decoded.ts == original.ts + assert len(decoded.nodes) == len(original.nodes) + assert len(decoded.transforms) == len(original.transforms) + for got, want in zip(decoded.nodes, original.nodes, strict=True): + assert got.id == want.id + assert got.metadata_id == want.metadata_id + assert got.pose.frame_id == want.pose.frame_id + assert got.pose.position.x == want.pose.position.x + for got, want in zip(decoded.transforms, original.transforms, strict=True): + assert got.translation.x == want.translation.x + assert got.translation.y == want.translation.y + assert got.translation.z == want.translation.z + assert got.rotation.x == want.rotation.x + assert got.rotation.w == want.rotation.w + + +def test_wire_layout_header() 
-> None: + """Header is ``[node_count u8][timestamp f8]`` (big-endian).""" + encoded = _sample().lcm_encode() + node_count, timestamp = struct.unpack_from(">Qd", encoded, 0) + assert node_count == 2 + assert timestamp == 1234.5 + + +def test_empty() -> None: + empty = GraphDelta3D(ts=0.0) + decoded = GraphDelta3D.lcm_decode(empty.lcm_encode()) + assert decoded.nodes == [] + assert decoded.transforms == [] + + +def test_misaligned_lengths_rejected() -> None: + """nodes and transforms must be the same length — aligned arrays.""" + with pytest.raises(ValueError, match="aligned arrays"): + GraphDelta3D( + ts=0.0, + nodes=[ + Graph3D.Node3D(pose=PoseStamped(ts=0.0, frame_id="map"), id=1, metadata_id=0), + ], + transforms=[], + ) + + +def test_node_layout_matches_graph3d() -> None: + """A GraphDelta3D node's wire bytes should be identical to a Graph3D node.""" + node = Graph3D.Node3D( + pose=PoseStamped( + ts=42.0, + frame_id="map", + position=Vector3(1.0, 2.0, 3.0), + orientation=Quaternion(0.0, 0.0, 0.0, 1.0), + ), + id=7, + metadata_id=1, + ) + delta = GraphDelta3D( + ts=0.0, + nodes=[node], + transforms=[ + GraphDelta3D.Transform( + translation=Vector3(0.0, 0.0, 0.0), + rotation=Quaternion(0.0, 0.0, 0.0, 1.0), + ), + ], + ) + graph = Graph3D(ts=0.0, nodes=[node]) + + delta_bytes = delta.lcm_encode() + graph_bytes = graph.lcm_encode() + + # The Node3D body inside each is 8 (ts) + 4 (frame_id_len) + 3 (frame_id 'map') + # + 56 (7 doubles for pos/quat) + 8 (id) + 8 (metadata_id) = 87 bytes. + # In GraphDelta3D: header is 16 bytes (u8 node_count + f8 ts). + # In Graph3D: header is 24 bytes (u8 edge_count + u8 node_count + f8 ts). 
+ NODE_BYTES = 87 + assert delta_bytes[16 : 16 + NODE_BYTES] == graph_bytes[24 : 24 + NODE_BYTES] diff --git a/dimos/msgs/nav_msgs/test_dynamic_cloud.py b/dimos/msgs/nav_msgs/test_dynamic_cloud.py new file mode 100644 index 0000000000..959a686b39 --- /dev/null +++ b/dimos/msgs/nav_msgs/test_dynamic_cloud.py @@ -0,0 +1,158 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for DynamicCloud. + +The KNOWN_BYTES fixture below is the same fixture the Rust mirror's unit +test asserts against — keep both sides in sync. +""" + +from __future__ import annotations + +import numpy as np +import pytest + +from dimos.msgs.nav_msgs.DynamicCloud import DynamicCloud + + +def _make_fixture(): + """A small fixed-content cloud used for cross-language byte-equality.""" + voxels = np.array([[1, -2, 3], [4, 5, -6]], dtype=np.int32) + quantity = np.array([7, 8], dtype=np.uint32) + event_indices = np.array([0, 1, 0], dtype=np.uint32) + event_timestamps = np.array([1_000_000_000, 2_000_000_000, 1_500_000_000], dtype=np.uint64) + return DynamicCloud( + voxels=voxels, + quantity=quantity, + event_indices=event_indices, + event_timestamps=event_timestamps, + voxel_size=0.25, + frame_id="map", + ts=1.5, # 1_500_000_000 ns + ) + + +# Hand-computed expected encoding of _make_fixture(); the Rust unit test +# (dimos/mapping/ray_tracing/rust/src/dynamic_cloud.rs::tests) reproduces +# the exact same bytes. Any drift on either side fails both tests. 
+KNOWN_BYTES = bytes.fromhex( + "002f685900000000" # ts_ns = 1_500_000_000 LE (0x5968_2F00) + "0000803e" # voxel_size = 0.25 f32 LE + "0300" # frame_id_len = 3 + "6d6170" # frame_id "map" + "02000000" # num_points = 2 + "01000000feffffff03000000" # voxels: (1,-2,3) + "0400000005000000faffffff" # voxels: (4,5,-6) + "0700000008000000" # quantity: 7, 8 + "03000000" # num_events = 3 + "000000000100000000000000" # event_indices: 0, 1, 0 + "00ca9a3b00000000" # event_timestamps[0] = 1_000_000_000 LE (0x3B9A_CA00) + "0094357700000000" # event_timestamps[1] = 2_000_000_000 LE (0x7735_9400) + "002f685900000000" # event_timestamps[2] = 1_500_000_000 LE (0x5968_2F00) +) + + +def test_roundtrip(): + cloud = _make_fixture() + encoded = cloud.lcm_encode() + decoded = DynamicCloud.lcm_decode(encoded) + + assert decoded.frame_id == cloud.frame_id + assert decoded.voxel_size == cloud.voxel_size + assert decoded.ts == cloud.ts + np.testing.assert_array_equal(decoded.voxels, cloud.voxels) + np.testing.assert_array_equal(decoded.quantity, cloud.quantity) + np.testing.assert_array_equal(decoded.event_indices, cloud.event_indices) + np.testing.assert_array_equal(decoded.event_timestamps, cloud.event_timestamps) + + +def test_known_bytes(): + """Pinned wire format; mirrors the Rust unit test fixture exactly.""" + encoded = _make_fixture().lcm_encode() + assert encoded == KNOWN_BYTES, f"encoded:\n{encoded.hex()}\nexpected:\n{KNOWN_BYTES.hex()}" + + +def test_decode_known_bytes(): + decoded = DynamicCloud.lcm_decode(KNOWN_BYTES) + expected = _make_fixture() + assert decoded.frame_id == expected.frame_id + assert decoded.voxel_size == expected.voxel_size + np.testing.assert_array_equal(decoded.voxels, expected.voxels) + np.testing.assert_array_equal(decoded.quantity, expected.quantity) + np.testing.assert_array_equal(decoded.event_indices, expected.event_indices) + np.testing.assert_array_equal(decoded.event_timestamps, expected.event_timestamps) + + +def test_empty_cloud(): + # 0.125 is 
exactly representable in f32; 0.1 would round-trip with f32 drift. + cloud = DynamicCloud(voxel_size=0.125, frame_id="world", ts=0.0) + encoded = cloud.lcm_encode() + decoded = DynamicCloud.lcm_decode(encoded) + assert len(decoded) == 0 + assert decoded.event_indices.shape[0] == 0 + assert decoded.frame_id == "world" + assert decoded.voxel_size == 0.125 + + +def test_world_positions(): + cloud = DynamicCloud( + voxels=np.array([[2, 0, -1]], dtype=np.int32), + quantity=np.array([1], dtype=np.uint32), + voxel_size=0.5, + ) + world = cloud.world_positions() + np.testing.assert_array_almost_equal(world, [[1.0, 0.0, -0.5]]) + + +def test_per_point_latest_timestamp(): + # event_indices: [0, 1, 0] with timestamps [1, 2, 5] + # point 0 has events at t=1 and t=5 → latest is 5 + # point 1 has one event at t=2 → latest is 2 + # point 2 has no events → 0 + cloud = DynamicCloud( + voxels=np.zeros((3, 3), dtype=np.int32), + quantity=np.zeros(3, dtype=np.uint32), + event_indices=np.array([0, 1, 0], dtype=np.uint32), + event_timestamps=np.array([1, 2, 5], dtype=np.uint64), + ) + latest = cloud.per_point_latest_timestamp() + np.testing.assert_array_equal(latest, [5, 2, 0]) + + +def test_voxels_quantity_length_mismatch_raises(): + with pytest.raises(ValueError, match="voxels/quantity length mismatch"): + DynamicCloud( + voxels=np.zeros((3, 3), dtype=np.int32), + quantity=np.zeros(2, dtype=np.uint32), + ) + + +def test_event_arrays_length_mismatch_raises(): + with pytest.raises(ValueError, match="event_indices/event_timestamps length mismatch"): + DynamicCloud( + voxels=np.zeros((2, 3), dtype=np.int32), + quantity=np.zeros(2, dtype=np.uint32), + event_indices=np.array([0], dtype=np.uint32), + event_timestamps=np.array([1, 2], dtype=np.uint64), + ) + + +def test_event_index_out_of_range_raises(): + with pytest.raises(ValueError, match="event index 5 out of range"): + DynamicCloud( + voxels=np.zeros((2, 3), dtype=np.int32), + quantity=np.zeros(2, dtype=np.uint32), + 
event_indices=np.array([5], dtype=np.uint32), + event_timestamps=np.array([1], dtype=np.uint64), + ) diff --git a/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py b/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py index 338d10d9b0..34fbbf98df 100644 --- a/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py +++ b/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py @@ -90,6 +90,7 @@ class WavefrontConfig(ModuleConfig): info_gain_threshold: float = 0.03 num_no_gain_attempts: int = 2 goal_timeout: float = 15.0 + frame_id: str = "world" class WavefrontFrontierExplorer(Module): @@ -751,7 +752,7 @@ def stop_exploration(self) -> bool: goal = PoseStamped( position=self.latest_odometry.position, orientation=self.latest_odometry.orientation, - frame_id="world", + frame_id=self.config.frame_id, ts=self.latest_odometry.ts, ) self.goal_request.publish(goal) @@ -792,7 +793,7 @@ def _exploration_loop(self) -> None: goal_msg.position.y = goal.y goal_msg.position.z = 0.0 goal_msg.orientation.w = 1.0 # No rotation - goal_msg.frame_id = "world" + goal_msg.frame_id = self.config.frame_id goal_msg.ts = self.latest_costmap.ts self.goal_request.publish(goal_msg) diff --git a/dimos/navigation/nav_stack/benchmarks/pose_graph_kitti360/README.md b/dimos/navigation/nav_stack/benchmarks/pose_graph_kitti360/README.md new file mode 100644 index 0000000000..8d4f1cbf17 --- /dev/null +++ b/dimos/navigation/nav_stack/benchmarks/pose_graph_kitti360/README.md @@ -0,0 +1,114 @@ +# Pose-graph SLAM benchmark on KITTI-360 + +Generic loop-closure benchmark. Drop in any pose-graph SLAM module that +exposes the standard interface and the runner will replay a KITTI-360 +sequence at it, watch its loop-closure output, and score precision / +recall / F1 against KITTI's ground-truth pose trajectory. + +The module under test never sees KITTI — it only sees streams. 
The +runner provides two helper modules: + +| Module | Role | +|------------------------------|-----------------------------------------------------------| +| `Kitti360PlaybackModule` | Publishes `registered_scan` + `odometry` from disk | +| Your pose-graph SLAM module | Consumes those, publishes `pose_graph` + `loop_correction_delta` | +| `PoseGraphScoringModule` | Subscribes to the outputs, accumulates metrics | + +`autoconnect` wires the three together by stream name. + +## Required interface for the module under test + +```python +class YourPoseGraphModule(Module): + registered_scan: In[PointCloud2] + odometry: In[Odometry] + + pose_graph: Out[Graph3D] # nodes (keyframes) + edges (odom & loop-closure) + loop_correction_delta: Out[NavPath] # one message per loop-closure update +``` + +Edge convention on `pose_graph`: loop-closure edges have +``metadata_id == 1`` (odometry edges use ``0``). Each node carries the +keyframe's *creation* timestamp in ``pose.ts``, and edges reference +nodes by ``id``; the scorer looks up ``edge.start_id`` and +``edge.end_id`` against the node table to recover endpoint frame_ids. + +## Dataset + +Download from (Test SLAM 3D +split is enough). Expected layout: + +``` +/ +├── calibration/ +├── data_3d_raw/ +│ └── 2013_05_28_drive__sync/velodyne_points/data/*.bin +└── data_poses/ + └── 2013_05_28_drive__sync/cam0_to_world.txt +``` + +Sequence IDs map onto the drive numbers: `2 → drive_0002`, `4 → drive_0004`, +`8 → drive_0008`, etc. 
+ +## Quickstart + +```python +from pathlib import Path +from dimos.navigation.nav_stack.benchmarks.pose_graph_kitti360.runner import ( + run_benchmark, +) +from dimos.navigation.nav_stack.modules.pgo.pgo_module import PgoModule # your module + +results = run_benchmark( + module_under_test=PgoModule.blueprint(), + kitti360_root=Path("~/datasets/kitti360").expanduser(), + sequence_id=2, + max_scans=None, # None = full sequence (~3k frames for seq 2) + publish_interval_sec=0.02, +) + +print(results) +# { +# "true_positive": ..., "false_positive": ..., "false_negative": ..., +# "precision": ..., "recall": ..., "f1": ..., +# "detected_loop_edges": ..., "loop_correction_delta_events": ..., +# "wallclock_seconds": ..., "sequence_id": 2, +# } +``` + +## Ground-truth definition + +A loop pair `(i, j)` counts as ground truth if: + +* frame gap `|i − j| ≥ DEFAULT_MIN_FRAME_GAP` (default 50), AND +* lidar-position distance ≤ `DEFAULT_MAX_LOOP_DISTANCE_M` (default 4.0 m). + +Tune via `min_frame_gap=` and `max_loop_distance_m=` on `run_benchmark`. + +A detected edge `(i, j)` is a **true positive** if `j` is in the +ground-truth valid-loop set for `i` (or vice-versa). Otherwise it's a +false positive. Ground-truth queries with no detection in their valid +set become false negatives. + +## Files + +| File | What it does | +|------|--------------| +| `runner.py` | `run_benchmark()` — builds the blueprint, polls playback, returns scores | +| `playback.py` | `Kitti360PlaybackModule` — streams scan + odom messages from disk | +| `scoring.py` | `PoseGraphScoringModule`, `LoopMetrics` — accumulates TP/FP/FN | +| `kitti360_loader.py` | `load_kitti360_sequence()` — reads velodyne `.bin` + `cam0_to_world.txt` | +| `loop_groundtruth.py` | `compute_loop_groundtruth()` + thresholds | + +## Tips + +- Start with `max_scans=200` for a smoke test; you should see playback + hit ~95% and a couple of GT pairs before paying for the full 3000-scan + run (~2.5 min wall on a Mac). 
+- Recall is bounded by your module's loop-search radius. KITTI ground + truth uses 4 m; if your module searches a 1 m radius, recall floors + near zero by construction even on a perfect descriptor. +- The scorer maps edge endpoints back to frame_ids via timestamps. If + your module rewrites pose timestamps after iSAM2 optimization, keep + the **creation** timestamp on the `PoseStamped` header so the lookup + still works. diff --git a/dimos/navigation/nav_stack/benchmarks/pose_graph_kitti360/kitti360_loader.py b/dimos/navigation/nav_stack/benchmarks/pose_graph_kitti360/kitti360_loader.py new file mode 100644 index 0000000000..2dd78ec5f2 --- /dev/null +++ b/dimos/navigation/nav_stack/benchmarks/pose_graph_kitti360/kitti360_loader.py @@ -0,0 +1,193 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Read a KITTI-360 sequence from disk. + +Layout (from cvlibs.net/datasets/kitti-360): + + / + data_3d_raw/2013_05_28_drive__sync/velodyne_points/ + data/.bin + timestamps.txt + data_poses/2013_05_28_drive__sync/poses.txt + calibration/calib_cam_to_velo.txt + +poses.txt rows are cam0→world; we left-multiply by inv(cam0→lidar) to get +lidar→world. 
+""" + +from __future__ import annotations + +from collections.abc import Iterator +from dataclasses import dataclass +from datetime import datetime +from pathlib import Path + +import numpy as np + +KITTI360_DRIVE_TEMPLATE = "2013_05_28_drive_{seq:04d}_sync" + + +@dataclass(frozen=True) +class Kitti360Frame: + frame_id: int + timestamp: float + pose_world: np.ndarray + scan_path: Path + + +def _parse_kitti_calib_matrix(path: Path, key_prefix: str = "") -> np.ndarray: + """Parse a KITTI calibration matrix file. + + Two on-disk formats are accepted: a labelled ``: f1 f2 ... f12`` + line (picks the first row matching ``key_prefix``, or the first row + if no prefix) or a bare 16-float matrix block. + """ + text = path.read_text().strip() + if not text: + raise ValueError(f"Empty calibration file: {path}") + + for line in text.splitlines(): + line = line.strip() + if not line: + continue + if ":" in line: + key, _, rest = line.partition(":") + if key_prefix and not key.strip().startswith(key_prefix): + continue + values = [float(token) for token in rest.split()] + else: + values = [float(token) for token in line.split()] + if len(values) == 12: + matrix = np.eye(4, dtype=np.float64) + matrix[:3, :4] = np.array(values, dtype=np.float64).reshape(3, 4) + return matrix + if len(values) == 16: + return np.array(values, dtype=np.float64).reshape(4, 4) + + raise ValueError(f"No 12/16-float row found in {path}") + + +def _load_poses_file(path: Path) -> dict[int, np.ndarray]: + poses: dict[int, np.ndarray] = {} + for line in path.read_text().splitlines(): + tokens = line.split() + if len(tokens) < 13: + continue + frame_id = int(tokens[0]) + values = np.array([float(token) for token in tokens[1:13]], dtype=np.float64) + matrix = np.eye(4, dtype=np.float64) + matrix[:3, :4] = values.reshape(3, 4) + poses[frame_id] = matrix + return poses + + +def _load_timestamps_file(path: Path) -> dict[int, float]: + """Read ``timestamps.txt`` as line_index → seconds-since-first-sample. 
+ + Returned dict is line-keyed (not frame-keyed) so the caller can decide + how to align with the actual frame ids in the split — see the + rekeying in ``load_kitti360_sequence``. + """ + timestamps: dict[int, float] = {} + base: float | None = None + for index, line in enumerate(path.read_text().splitlines()): + line = line.strip() + if not line: + continue + try: + parsed = datetime.fromisoformat(line) + except ValueError: + # Some KITTI-360 files use a space instead of 'T'. + parsed = datetime.fromisoformat(line.replace(" ", "T", 1)) + unix_seconds = parsed.timestamp() + if base is None: + base = unix_seconds + timestamps[index] = unix_seconds - base + return timestamps + + +@dataclass +class Kitti360Sequence: + sequence_id: int + velodyne_dir: Path + timestamps: dict[int, float] + poses_world: dict[int, np.ndarray] + velo_to_cam: np.ndarray + + @property + def frame_ids(self) -> list[int]: + scan_ids = {int(scan.stem) for scan in self.velodyne_dir.glob("*.bin")} + return sorted(scan_ids & self.poses_world.keys()) + + def lidar_pose(self, frame_id: int) -> np.ndarray: + cam0_to_world = self.poses_world[frame_id] + return cam0_to_world @ np.linalg.inv(self.velo_to_cam) + + def scan_xyz(self, frame_id: int) -> np.ndarray: + scan_path = self.velodyne_dir / f"{frame_id:010d}.bin" + data = np.fromfile(str(scan_path), dtype=np.float32) + if data.size % 4 != 0: + raise ValueError(f"Scan {scan_path} has unexpected length {data.size}") + return data.reshape(-1, 4) + + def frames(self, frame_ids: list[int] | None = None) -> Iterator[Kitti360Frame]: + selected = frame_ids if frame_ids is not None else self.frame_ids + for frame_id in selected: + yield Kitti360Frame( + frame_id=frame_id, + timestamp=self.timestamps.get(frame_id, float(frame_id)), + pose_world=self.lidar_pose(frame_id), + scan_path=self.velodyne_dir / f"{frame_id:010d}.bin", + ) + + +def load_kitti360_sequence(root: Path, sequence_id: int) -> Kitti360Sequence: + drive = 
KITTI360_DRIVE_TEMPLATE.format(seq=sequence_id) + velodyne_dir = root / "data_3d_raw" / drive / "velodyne_points" / "data" + poses_path = root / "data_poses" / drive / "poses.txt" + timestamps_path = root / "data_3d_raw" / drive / "velodyne_points" / "timestamps.txt" + calib_path = root / "calibration" / "calib_cam_to_velo.txt" + + for required in (velodyne_dir, poses_path, calib_path): + if not required.exists(): + raise FileNotFoundError(f"KITTI-360 layout missing under {root}: {required}") + + velo_to_cam = _parse_kitti_calib_matrix(calib_path) + poses_world = _load_poses_file(poses_path) + + # Rekey timestamps by actual frame_id (the on-disk file is line-indexed + # but the Test SLAM split's frame_ids don't start at 0) + timestamps: dict[int, float] = {} + if timestamps_path.exists(): + sorted_scan_ids = sorted(int(scan.stem) for scan in velodyne_dir.glob("*.bin")) + line_indexed = _load_timestamps_file(timestamps_path) + if len(sorted_scan_ids) != len(line_indexed): + raise ValueError( + f"KITTI-360 timestamp count mismatch under {root}: " + f"{len(sorted_scan_ids)} .bin files in {velodyne_dir} but " + f"{len(line_indexed)} lines in {timestamps_path}." + ) + timestamps = { + frame_id: line_indexed[line_index] + for line_index, frame_id in enumerate(sorted_scan_ids) + } + + return Kitti360Sequence( + sequence_id=sequence_id, + velodyne_dir=velodyne_dir, + timestamps=timestamps, + poses_world=poses_world, + velo_to_cam=velo_to_cam, + ) diff --git a/dimos/navigation/nav_stack/benchmarks/pose_graph_kitti360/loop_groundtruth.py b/dimos/navigation/nav_stack/benchmarks/pose_graph_kitti360/loop_groundtruth.py new file mode 100644 index 0000000000..0cae8e70ae --- /dev/null +++ b/dimos/navigation/nav_stack/benchmarks/pose_graph_kitti360/loop_groundtruth.py @@ -0,0 +1,145 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
def compute_loop_groundtruth(
    frame_ids: list[int],
    positions_xyz: np.ndarray,
    min_frame_gap: int = DEFAULT_MIN_FRAME_GAP,
    max_distance_m: float = DEFAULT_MAX_LOOP_DISTANCE_M,
) -> LoopGroundtruth:
    """Find, for every frame, the earlier frames that close a loop with it.

    Args:
        frame_ids: ordered list of frame IDs (e.g. KITTI frame indices).
        positions_xyz: (N, 3) world-frame translation of each frame.
        min_frame_gap: minimum index distance (in this list) to count.
            Must be at least 1 so a frame can never match itself or a
            later frame.
        max_distance_m: spatial radius for a positive loop.

    Returns:
        ``LoopGroundtruth`` with ``valid_loops_per_query``: query frame_id
        → set of earlier frame_ids that satisfy both thresholds. Every
        frame_id gets an entry; frames with no loop map to an empty set.

    Raises:
        ValueError: if ``positions_xyz`` is not ``(len(frame_ids), 3)`` or
            ``min_frame_gap`` is smaller than 1.
    """
    if positions_xyz.shape != (len(frame_ids), 3):
        raise ValueError(
            f"positions_xyz shape {positions_xyz.shape} doesn't match "
            f"len(frame_ids)={len(frame_ids)}"
        )
    if min_frame_gap < 1:
        # A gap of 0 would make every frame its own loop (distance 0) and a
        # negative gap would admit *later* frames as candidates.
        raise ValueError(f"min_frame_gap must be >= 1, got {min_frame_gap}")

    valid: dict[int, set[int]] = {frame_id: set() for frame_id in frame_ids}
    # The first ``min_frame_gap`` frames have no candidate far enough back.
    for query_index in range(min_frame_gap, len(frame_ids)):
        # Candidates are indices 0 .. query_index - min_frame_gap inclusive.
        candidate_count = query_index - min_frame_gap + 1
        deltas = positions_xyz[:candidate_count] - positions_xyz[query_index]
        distances = np.linalg.norm(deltas, axis=1)
        matching_indices = np.flatnonzero(distances <= max_distance_m).tolist()
        valid[frame_ids[query_index]].update(
            frame_ids[candidate_index] for candidate_index in matching_indices
        )

    return LoopGroundtruth(
        min_frame_gap=min_frame_gap,
        max_distance_m=max_distance_m,
        valid_loops_per_query=valid,
    )
Duplicate detections + for the same query collapse. Match is order-agnostic — PGO may + report (target, source) or (source, target). + """ + seen_queries_with_hit: set[int] = set() + seen_queries_without_hit: set[int] = set() + queries_with_any_groundtruth = { + query_frame_id + for query_frame_id, valid in groundtruth.valid_loops_per_query.items() + if valid + } + + for source_frame_id, target_frame_id in detected_pairs: + source_valid = groundtruth.valid_loops_per_query.get(source_frame_id, set()) + target_valid = groundtruth.valid_loops_per_query.get(target_frame_id, set()) + query_frame_id = max(source_frame_id, target_frame_id) + if target_frame_id in source_valid or source_frame_id in target_valid: + seen_queries_with_hit.add(query_frame_id) + else: + seen_queries_without_hit.add(query_frame_id) + seen_queries_without_hit -= seen_queries_with_hit + + return LoopMetrics( + true_positive=len(seen_queries_with_hit), + false_positive=len(seen_queries_without_hit), + false_negative=len(queries_with_any_groundtruth - seen_queries_with_hit), + ) diff --git a/dimos/navigation/nav_stack/benchmarks/pose_graph_kitti360/playback.py b/dimos/navigation/nav_stack/benchmarks/pose_graph_kitti360/playback.py new file mode 100644 index 0000000000..62033f0770 --- /dev/null +++ b/dimos/navigation/nav_stack/benchmarks/pose_graph_kitti360/playback.py @@ -0,0 +1,170 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
async def main(self) -> AsyncIterator[None]:
    """Load the sequence, start the playback task, and cancel it on exit.

    Everything before the ``yield`` is setup: the sequence is loaded, the
    frame list is truncated to ``config.max_scans`` when that is set, the
    publish timestamps are computed, and ``_run_playback`` is scheduled as
    a task. The statement after the ``yield`` cancels that task.

    If setup raises, the failure is recorded in ``_playback_error`` and
    ``_playback_finished`` is set before re-raising. Without this, a setup
    failure would leave ``is_finished()`` false forever (only
    ``_run_playback`` sets it), and a caller polling ``is_finished()``
    would never see the error.
    """
    try:
        self._sequence = load_kitti360_sequence(
            Path(self.config.kitti360_root), self.config.sequence_id
        )
        frame_ids = self._sequence.frame_ids
        if self.config.max_scans is not None:
            frame_ids = frame_ids[: self.config.max_scans]
        self._frame_ids = frame_ids
        self._send_timestamps = compute_send_timestamps(self._sequence.timestamps, frame_ids)
    except Exception as exc:
        # Same "<Type>: <message>" format that _run_playback records.
        self._playback_error = f"{type(exc).__name__}: {exc}"
        self._playback_finished = True
        raise
    # Event lives on self._loop, the same loop _run_playback and
    # handle_corrected_odometry run on.
    self._first_response_event = asyncio.Event()
    self._playback_task = asyncio.create_task(self._run_playback())
    yield
    self._playback_task.cancel()
def compute_send_timestamps(
    raw_timestamps: dict[int, float], frame_ids_in_order: list[int]
) -> list[float]:
    """Compute strictly-monotonic publish timestamps from raw KITTI ones.

    PGO's Odometry constructor treats ``ts==0`` as "now", so clamp the first
    ts away from zero; subsequent values inherit at least a 1 ms floor.

    Frames missing from ``raw_timestamps`` fall back to their position in
    ``frame_ids_in_order`` (as seconds). If a raw timestamp repeats or goes
    backwards relative to the previously emitted value, it is bumped to
    1 ms after that value so the result is always strictly increasing —
    the scorer keys frames by millisecond-rounded timestamp, so two equal
    values would collapse into one entry.

    Args:
        raw_timestamps: frame_id → seconds, as loaded from the sequence.
        frame_ids_in_order: frame ids in publish order.

    Returns:
        One timestamp per entry of ``frame_ids_in_order``; the first is
        at least 1.0. Empty input yields an empty list.
    """
    if not frame_ids_in_order:
        return []
    step_sec = 0.001
    first_timestamp = max(raw_timestamps.get(frame_ids_in_order[0], 1.0), 1.0)
    send_timestamps: list[float] = []
    for index, frame_id in enumerate(frame_ids_in_order):
        raw_timestamp = raw_timestamps.get(frame_id, float(index))
        timestamp = max(raw_timestamp, first_timestamp + index * step_sec)
        # The floor alone cannot separate repeated/regressing raw values
        # once they exceed it, so enforce ordering against the last output.
        if send_timestamps and timestamp <= send_timestamps[-1]:
            timestamp = send_timestamps[-1] + step_sec
        send_timestamps.append(timestamp)
    return send_timestamps
def run_benchmark(
    module_under_test: type[LoopClosure],
    kitti360_root: Path,
    module_kwargs: dict[str, Any] | None = None,
    sequence_id: int = 2,
    max_scans: int | None = None,
    publish_interval_sec: float = 0.02,
    min_frame_gap: int = DEFAULT_MIN_FRAME_GAP,
    max_loop_distance_m: float = DEFAULT_MAX_LOOP_DISTANCE_M,
    drain_sec: float = 10.0,
    poll_interval_sec: float = 0.5,
) -> dict[str, Any]:
    """Run a pose-graph SLAM module against KITTI-360 and return scores.

    ``module_under_test`` is the module *class* (not an already-built
    Blueprint): the runner calls
    ``module_under_test.blueprint(**module_kwargs)`` itself. The runner adds
    a ``Kitti360PlaybackModule`` (publishes the inputs) and a
    ``PoseGraphScoringModule`` (subscribes to the outputs + scores), then
    auto-connects everything into one blueprint.

    Args:
        module_under_test: module class providing ``.blueprint(**kwargs)``.
        kitti360_root: dataset root passed to ``load_kitti360_sequence``.
        module_kwargs: keyword arguments forwarded to
            ``module_under_test.blueprint``; ``None`` means no arguments.
        sequence_id: KITTI-360 drive number.
        max_scans: if not ``None``, only the first ``max_scans`` frames are
            used for both groundtruth and playback.
        publish_interval_sec: forwarded to the playback module.
        min_frame_gap: forwarded to ``compute_loop_groundtruth``.
        max_loop_distance_m: forwarded to ``compute_loop_groundtruth`` as
            ``max_distance_m``.
        drain_sec: how long to sleep after playback finishes before the
            results are collected.
        poll_interval_sec: sleep between ``is_finished()`` polls.

    Returns:
        The dict returned by ``PoseGraphScoringModule.get_results()`` with
        ``wallclock_seconds`` and ``sequence_id`` added.

    Raises:
        RuntimeError: if the playback module reports a playback error.

    Note:
        The wait loop has no overall timeout; it runs until the playback
        module reports ``is_finished()``.
    """
    # Groundtruth and send timestamps are computed here, up front, from the
    # same frame list the playback module will derive on its side.
    sequence = load_kitti360_sequence(kitti360_root, sequence_id)
    frame_ids = sequence.frame_ids
    if max_scans is not None:
        frame_ids = frame_ids[:max_scans]
    positions = np.array([sequence.lidar_pose(frame_id)[:3, 3] for frame_id in frame_ids])
    groundtruth = compute_loop_groundtruth(
        frame_ids,
        positions,
        min_frame_gap=min_frame_gap,
        max_distance_m=max_loop_distance_m,
    )
    send_timestamps = compute_send_timestamps(sequence.timestamps, frame_ids)

    logger.info(
        f"KITTI-360 seq {sequence_id}: {len(frame_ids)} frames, "
        f"{groundtruth.queries_with_loop} GT queries with loops, "
        f"{groundtruth.total_loop_pairs} GT pairs."
    )

    playback_blueprint = Kitti360PlaybackModule.blueprint(
        kitti360_root=str(kitti360_root),
        sequence_id=sequence_id,
        max_scans=max_scans,
        publish_interval_sec=publish_interval_sec,
    )
    # The groundtruth sets are converted to lists for the scoring config,
    # whose field is declared as dict[int, list[int]].
    scoring_blueprint = PoseGraphScoringModule.blueprint(
        frame_ids=frame_ids,
        send_timestamps=send_timestamps,
        valid_loops_per_query={
            frame_id: list(valid) for frame_id, valid in groundtruth.valid_loops_per_query.items()
        },
    )

    sut_blueprint = module_under_test.blueprint(**(module_kwargs or {}))
    blueprint = autoconnect(playback_blueprint, scoring_blueprint, sut_blueprint)

    wallclock_start = time.monotonic()
    coordinator = ModuleCoordinator.build(blueprint)
    try:
        playback = coordinator.get_instance(Kitti360PlaybackModule)
        scoring = coordinator.get_instance(PoseGraphScoringModule)

        # Wait for the playback module to finish publishing all scans.
        while not playback.is_finished():
            published = playback.frames_published()
            logger.info(
                f" playback {published}/{len(frame_ids)} "
                f"({published / max(len(frame_ids), 1) * 100:.0f}%)"
            )
            time.sleep(poll_interval_sec)

        # "Finished" also covers aborted playback, so check for an error
        # before trusting the scores.
        playback_error = playback.playback_error()
        if playback_error is not None:
            raise RuntimeError(
                f"Kitti360PlaybackModule aborted at frame "
                f"{playback.frames_published()}/{len(frame_ids)}: {playback_error}"
            )

        # Drain remaining loop-closure / edge messages from PGO's backlog.
        logger.info(f"playback done, draining for {drain_sec:.1f}s")
        time.sleep(drain_sec)

        results: dict[str, Any] = scoring.get_results()
    finally:
        # Always stop the coordinator, including on the error path above.
        coordinator.stop()

    # Measured after stop(), so this covers build, playback, drain and
    # coordinator shutdown.
    results["wallclock_seconds"] = time.monotonic() - wallclock_start
    results["sequence_id"] = sequence_id
    return results
class PoseGraphScoringModule(Module):
    """Accumulates loop-closure detections and scores them against KITTI groundtruth.

    Detections are read from ``pose_graph`` snapshots: every edge whose
    ``metadata_id`` equals ``EDGE_LOOP_CLOSURE`` is mapped back to a pair of
    input frame ids via the node timestamps. ``loop_closure_event`` messages
    are only counted.
    """

    config: PoseGraphScoringConfig

    pose_graph: In[Graph3D]
    loop_closure_event: In[GraphDelta3D]

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        # Detected (start_frame_id, end_frame_id) pairs in first-seen order.
        self._detected_pairs: list[tuple[int, int]] = []
        # Mirror of ``_detected_pairs`` for O(1) duplicate checks: every
        # pose-graph message is a full snapshot, so the same edges are seen
        # again on each message.
        self._detected_pair_set: set[tuple[int, int]] = set()
        self._loop_closure_events: int = 0
        # Millisecond-rounded send timestamp → frame id of the scan sent then.
        self._timestamp_ms_to_frame_id: dict[int, int] = {
            round(send_timestamp * 1e3): frame_id
            for frame_id, send_timestamp in zip(
                self.config.frame_ids, self.config.send_timestamps, strict=True
            )
        }

    @rpc
    def start(self) -> None:
        """Subscribe to both input streams and register the subscriptions for disposal."""
        super().start()
        self.register_disposable(
            Disposable(self.loop_closure_event.subscribe(self._on_loop_closure_event))
        )
        self.register_disposable(Disposable(self.pose_graph.subscribe(self._on_pose_graph)))

    def _on_loop_closure_event(self, message: GraphDelta3D) -> None:
        """Count the event; its payload is not inspected."""
        del message
        self._loop_closure_events += 1

    def _on_pose_graph(self, message: Graph3D) -> None:
        """Record every not-yet-seen loop-closure edge of a pose-graph snapshot.

        Edges are skipped when either endpoint is missing from
        ``message.nodes`` or its timestamp cannot be mapped to a frame id.
        """
        id_to_node_ts: dict[int, float] = {n.id: n.pose.ts for n in message.nodes}
        for edge in message.edges:
            if edge.metadata_id != EDGE_LOOP_CLOSURE:
                continue
            start_ts = id_to_node_ts.get(edge.start_id)
            end_ts = id_to_node_ts.get(edge.end_id)
            if start_ts is None or end_ts is None:
                continue
            start_frame_id = self._timestamp_to_frame(start_ts)
            end_frame_id = self._timestamp_to_frame(end_ts)
            if start_frame_id is None or end_frame_id is None:
                continue
            pair = (start_frame_id, end_frame_id)
            if pair not in self._detected_pair_set:
                self._detected_pair_set.add(pair)
                self._detected_pairs.append(pair)

    def _timestamp_to_frame(self, timestamp_sec: float) -> int | None:
        """Return the frame id sent at ``timestamp_sec``, or ``None`` if unknown.

        The exact millisecond is tried first, then one millisecond either
        side.
        """
        timestamp_ms = round(timestamp_sec * 1e3)
        # ±1 ms slop: pose.ts round-trips through (int32 sec, uint32 nsec).
        for slop_ms in (0, -1, 1):
            frame_id = self._timestamp_ms_to_frame_id.get(timestamp_ms + slop_ms)
            if frame_id is not None:
                return frame_id
        return None

    @rpc
    def get_results(self) -> dict[str, Any]:
        """Score the detections collected so far against the configured groundtruth.

        Returns:
            Counts (scans, groundtruth queries/pairs, detected edges, loop
            events, TP/FP/FN) plus ``precision``, ``recall`` and ``f1``.
            ``precision`` and ``recall`` are ``None`` when they are not
            finite (i.e. their denominator was zero).
        """
        valid_loops_per_query: dict[int, set[int]] = {
            frame_id: set(loops) for frame_id, loops in self.config.valid_loops_per_query.items()
        }
        metrics = _score_pairs(self._detected_pairs, valid_loops_per_query)
        queries_with_loop = sum(1 for valid in valid_loops_per_query.values() if valid)
        total_pairs = sum(len(valid) for valid in valid_loops_per_query.values())
        return {
            "scans_played": len(self.config.frame_ids),
            "groundtruth_queries_with_loop": queries_with_loop,
            "groundtruth_total_loop_pairs": total_pairs,
            "detected_loop_edges": len(self._detected_pairs),
            "loop_closure_events": self._loop_closure_events,
            "true_positive": metrics.true_positive,
            "false_positive": metrics.false_positive,
            "false_negative": metrics.false_negative,
            "precision": (metrics.precision if math.isfinite(metrics.precision) else None),
            "recall": metrics.recall if math.isfinite(metrics.recall) else None,
            "f1": metrics.f1,
        }
+ seen_queries_with_hit: set[int] = set() + seen_queries_without_hit: set[int] = set() + queries_with_any_groundtruth = { + frame_id for frame_id, valid in valid_loops_per_query.items() if valid + } + for source_frame_id, target_frame_id in detected_pairs: + source_valid = valid_loops_per_query.get(source_frame_id, set()) + target_valid = valid_loops_per_query.get(target_frame_id, set()) + query_frame_id = max(source_frame_id, target_frame_id) + if target_frame_id in source_valid or source_frame_id in target_valid: + seen_queries_with_hit.add(query_frame_id) + else: + seen_queries_without_hit.add(query_frame_id) + # A query that fires both a TP and a FP edge is counted as TP only + # (one good detection is enough to say LoopClosure recognised the place). + seen_queries_without_hit -= seen_queries_with_hit + return LoopMetrics( + true_positive=len(seen_queries_with_hit), + false_positive=len(seen_queries_without_hit), + false_negative=len(queries_with_any_groundtruth - seen_queries_with_hit), + ) diff --git a/dimos/navigation/nav_stack/frames.py b/dimos/navigation/nav_stack/frames.py deleted file mode 100644 index b8c13f4fb9..0000000000 --- a/dimos/navigation/nav_stack/frames.py +++ /dev/null @@ -1,20 +0,0 @@ -# Copyright 2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -# NOTE: this will be deleted shortly - do not rely on - -FRAME_MAP = "map" -FRAME_ODOM = "odom" -FRAME_BODY = "body" -FRAME_SENSOR = "sensor" diff --git a/dimos/navigation/nav_stack/main.py b/dimos/navigation/nav_stack/main.py index 4729233507..7cd8034bf7 100644 --- a/dimos/navigation/nav_stack/main.py +++ b/dimos/navigation/nav_stack/main.py @@ -21,7 +21,8 @@ import numpy as np from dimos.core.coordination.blueprints import Blueprint, autoconnect -from dimos.core.module import ModuleBase +from dimos.mapping.ray_tracing.module import RayTracingVoxelMap +from dimos.navigation.nav_stack.modules.apply_closure.apply_closure import ApplyClosure from dimos.navigation.nav_stack.modules.far_planner.far_planner import FarPlanner from dimos.navigation.nav_stack.modules.local_planner.local_planner import LocalPlanner from dimos.navigation.nav_stack.modules.path_follower.path_follower import PathFollower @@ -31,7 +32,6 @@ from dimos.navigation.nav_stack.modules.terrain_analysis.terrain_analysis import TerrainAnalysis from dimos.navigation.nav_stack.modules.terrain_map_ext.terrain_map_ext import TerrainMapExt from dimos.protocol.pubsub.impl.lcmpubsub import LCM -from dimos.spec.utils import Spec from dimos.utils.logging_config import setup_logger logger = setup_logger() @@ -41,6 +41,8 @@ def create_nav_stack( *, use_tare: bool = False, use_terrain_map_ext: bool = True, + use_ray_tracing: bool = True, + use_apply_closure: bool = True, planner: str = "far", vehicle_height: float | None = None, max_speed: float | None = None, @@ -48,6 +50,10 @@ def create_nav_stack( terrain_voxel_size: float = 0.2, replan_rate: float = 0.5, record: bool = False, + world_frame: str = "world", + map_frame: str = "map", + start_point_frame: str = "start_point", + current_point_frame: str = "current_point", terrain_analysis: dict[str, Any] | None = None, terrain_map_ext: dict[str, Any] | None = None, local_planner: dict[str, Any] | None = None, @@ -57,19 +63,23 @@ def create_nav_stack( pgo: 
dict[str, Any] | None = None, tare_planner: dict[str, Any] | None = None, nav_record: dict[str, Any] | None = None, + ray_tracing: dict[str, Any] | None = None, + apply_closure: dict[str, Any] | None = None, ) -> Blueprint: """Compose a nav stack Blueprint. Per-module config dicts (``terrain_analysis``, ``local_planner``, etc.) - override defaults. ``vehicle_height`` and ``max_speed`` propagate to - the relevant modules automatically. + override defaults. ``vehicle_height``, ``max_speed`` and the ``*_frame`` + parameters propagate to the relevant modules automatically. """ far_planner_config = {**(far_planner or {})} far_planner_config.setdefault("is_static_env", False) + far_planner_config.setdefault("frame_id", map_frame) if vehicle_height is not None: far_planner_config.setdefault("vehicle_height", vehicle_height) local_planner_config = {**(local_planner or {})} + local_planner_config.setdefault("body_frame", current_point_frame) path_follower_config = {**(path_follower or {})} simple_planner_config = {**(simple_planner or {})} if waypoint_threshold is not None: @@ -77,8 +87,6 @@ def create_nav_stack( path_follower_config.setdefault("goal_tolerance", waypoint_threshold) simple_planner_config.setdefault("goal_reached_threshold", waypoint_threshold) - pgo_module: Blueprint = PGO.blueprint(**(pgo or {})) - modules: list[Blueprint] = [ TerrainAnalysis.blueprint( **{ @@ -113,6 +121,12 @@ def create_nav_stack( "vehicle_height": 1.5 if vehicle_height is None else vehicle_height, **(terrain_analysis or {}), } + ).remappings( + [ + # Inputs + (TerrainAnalysis, "registered_scan", "lidar"), + (TerrainAnalysis, "odometry", "corrected_odometry"), + ] ), LocalPlanner.blueprint( **{ @@ -127,6 +141,11 @@ def create_nav_stack( "publish_free_paths": False, **local_planner_config, } + ).remappings( + [ + (LocalPlanner, "registered_scan", "lidar"), + (LocalPlanner, "cancel_goal", "stop_movement"), + ] ), PathFollower.blueprint( **{ @@ -140,17 +159,39 @@ def create_nav_stack( 
"max_acceleration": 2.0, # important for smooth movement **path_follower_config, } + ).remappings( + [ + (PathFollower, "cmd_vel", "nav_cmd_vel"), + ] ), - pgo_module, + PGO.blueprint( + **{ + "parent_frame": world_frame, + "frame_id": map_frame, + "child_frame_id": start_point_frame, + "body_frame": current_point_frame, + **(pgo or {}), + } + ).remappings([(PGO, "registered_scan", "lidar"), (PGO, "global_map", "_pgo_global_map")]), ] if planner == "simple": - merged_simple_planner_config: dict[str, Any] = {"replan_rate": replan_rate} + merged_simple_planner_config: dict[str, Any] = { + "replan_rate": replan_rate, + "frame_id": map_frame, + "body_frame": current_point_frame, + } if vehicle_height is not None: merged_simple_planner_config["ground_offset_below_robot"] = vehicle_height merged_simple_planner_config.update(simple_planner_config) modules.append(SimplePlanner.blueprint(**merged_simple_planner_config)) elif planner == "far": - modules.append(FarPlanner.blueprint(**far_planner_config)) + modules.append( + FarPlanner.blueprint(**far_planner_config).remappings( + [ + (FarPlanner, "odometry", "corrected_odometry"), + ] + ) + ) else: raise Exception(f"invalid planner: {planner}") @@ -158,6 +199,7 @@ def create_nav_stack( modules.append( TerrainMapExt.blueprint( **{ + "frame_id": map_frame, "scan_voxel_size": 0.1, "decay_time": 4.0, "use_sorting": True, @@ -166,29 +208,39 @@ def create_nav_stack( "vehicle_height": 1.5 if vehicle_height is None else vehicle_height, **(terrain_map_ext or {}), } + ).remappings( + [ + (TerrainMapExt, "registered_scan", "lidar"), + (TerrainMapExt, "odometry", "corrected_odometry"), + ] ) ) if use_tare: modules.append(TarePlanner.blueprint(**(tare_planner or {}))) - record_remappings: list[tuple[type[ModuleBase], str, str | type[ModuleBase] | type[Spec]]] = [] + if use_ray_tracing: + modules.append( + RayTracingVoxelMap.blueprint(**(ray_tracing or {})).remappings( + [ + (RayTracingVoxelMap, "odometry", "corrected_odometry"), + ] + ) + 
) + if use_apply_closure: + modules.append(ApplyClosure.blueprint(**(apply_closure or {}))) if record: # Lazy: breaks on G1 onboard (linux-aarch64 TLS allocation failure) from dimos.navigation.nav_stack.modules.nav_record.nav_record import NavRecord - modules.append(NavRecord.blueprint(**(nav_record or {}))) - record_remappings.append((NavRecord, "global_map", "global_map_pgo")) - - remappings: list[tuple[type[ModuleBase], str, str | type[ModuleBase] | type[Spec]]] = [ - (PathFollower, "cmd_vel", "nav_cmd_vel"), - (TerrainAnalysis, "odometry", "corrected_odometry"), - (TerrainMapExt, "odometry", "corrected_odometry"), - (PGO, "global_map", "global_map_pgo"), - *record_remappings, - ] - if planner == "far": - remappings.append((FarPlanner, "odometry", "corrected_odometry")) + modules.append( + NavRecord.blueprint( + **{ + "default_frame_id": current_point_frame, + **(nav_record or {}), + } + ).remappings([(NavRecord, "global_map", "_pgo_global_map")]) + ) - return autoconnect(*modules).remappings(remappings) + return autoconnect(*modules) def nav_stack_rerun_config( @@ -218,6 +270,7 @@ def nav_stack_rerun_config( visual_override.setdefault("world/global_map", _global_map_colors) visual_override.setdefault("world/global_map_pgo", _global_map_colors) visual_override.setdefault("world/global_map_fastlio", _global_map_colors) + visual_override.setdefault("world/corrected_global_map", _global_map_colors) visual_override.setdefault( "world/registered_scan", _registered_scan_colors if show_registered_scan else _hide ) @@ -231,16 +284,15 @@ def nav_stack_rerun_config( visual_override.setdefault("world/goal_path", _goal_path_colors_debug) visual_override.setdefault("world/nav_boundary", _nav_boundary_colors_debug) visual_override.setdefault("world/contour_polygons", _contour_polygons_colors_debug) - visual_override.setdefault("world/graph_nodes", _graph_nodes_colors_debug) - visual_override.setdefault("world/graph_edges", _graph_edges_colors_debug) + 
visual_override.setdefault("world/graph", _graph_colors_debug) + visual_override.setdefault("world/pose_graph", _pose_graph_colors_debug) else: visual_override.setdefault("world/way_point", _waypoint_colors) visual_override.setdefault("world/goal", _goal_colors) visual_override.setdefault("world/goal_path", _goal_path_colors) visual_override.setdefault("world/nav_boundary", _nav_boundary_colors) visual_override.setdefault("world/contour_polygons", _contour_polygons_colors) - visual_override.setdefault("world/graph_nodes", _hide) - visual_override.setdefault("world/graph_edges", _hide) + visual_override.setdefault("world/graph", _hide) visual_override.setdefault("world/obstacle_cloud", _obstacle_cloud_colors) visual_override.setdefault("world/costmap_cloud", _costmap_cloud_colors) visual_override.setdefault("world/free_paths", _free_paths_colors) @@ -285,7 +337,13 @@ def _sensor_scan_colors(cloud: Any) -> Any: def _global_map_colors(cloud: Any) -> Any: import rerun as rr - points, _ = cloud.as_numpy() + # Polymorphic over PointCloud2 (``as_numpy``) and DynamicCloud + # (``world_positions``) so the same coloring rule applies to whichever + # stream is feeding ``world/global_map``. 
def _graph_colors_debug(graph: Any) -> Any:
    """Debug-view renderer for the ``world/graph`` entity.

    Delegates to ``graph.to_rerun_multi``, rooting the output at
    ``world/graph`` and lifting it by ``_AGENTIC_DEBUG_BOUNDARY_LIFT``.
    """
    lift = _AGENTIC_DEBUG_BOUNDARY_LIFT
    return graph.to_rerun_multi(base_path="world/graph", z_offset=lift)


def _pose_graph_colors_debug(pose_graph: Any) -> Any:
    """Debug-view renderer for the ``world/pose_graph`` entity.

    Delegates to ``pose_graph.to_rerun_multi`` with fixed node/edge radii and
    a ``_AGENTIC_DEBUG_LIFT`` z offset.
    """
    render_options = {
        "base_path": "world/pose_graph",
        "z_offset": _AGENTIC_DEBUG_LIFT,
        "node_radius": 0.15,
        "edge_radius": 0.06,
    }
    return pose_graph.to_rerun_multi(**render_options)
``nodes[i]`` is the pre-smooth keyframe; ``transforms[i]`` + is the SE(3) delta to apply (left-multiplied: ``post = T_delta @ T_pre``). + +Each voxel is bound to the pose-graph timeline by its latest event timestamp +(``per_point_latest_timestamp``), and its warp is a two-nearest-neighbor LBS +blend: lerp on translation, slerp on rotation between the bracketing nodes' +deltas. + +The effect: voxels with recent event timestamps follow the latest pose +corrections, older voxels barely move — matching the way pose-graph drift +accumulates along a trajectory. + +Voxels without any event (timestamp 0) clip to the earliest node and get the +smallest correction, which is the conservative choice for "unknown age". +""" + +from __future__ import annotations + +import threading +import time +from typing import Any + +import numpy as np +from reactivex.disposable import Disposable +from scipy.spatial.transform import Rotation, Slerp + +from dimos.core.core import rpc +from dimos.core.module import Module, ModuleConfig +from dimos.core.stream import In, Out +from dimos.msgs.nav_msgs.DynamicCloud import DynamicCloud +from dimos.msgs.nav_msgs.GraphDelta3D import GraphDelta3D +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +# Slerp requires strictly increasing input times. If two pose-graph nodes +# share a timestamp (degenerate input), bump later ones by this epsilon so +# the interpolator stays well-defined. 
+_TIME_DEDUP_EPS = 1e-9 + + +def transform_to_matrix(transform: GraphDelta3D.Transform) -> np.ndarray: + """Pack a ``GraphDelta3D.Transform`` (translation + quaternion) into 4x4.""" + quat = np.array( + [ + transform.rotation.x, + transform.rotation.y, + transform.rotation.z, + transform.rotation.w, + ], + dtype=np.float64, + ) + norm = float(np.linalg.norm(quat)) + if norm == 0.0: + quat = np.array([0.0, 0.0, 0.0, 1.0]) + else: + quat = quat / norm + out = np.eye(4, dtype=np.float64) + out[:3, :3] = Rotation.from_quat(quat).as_matrix() + out[:3, 3] = [ + transform.translation.x, + transform.translation.y, + transform.translation.z, + ] + return out + + +def graph_delta_to_arrays(graph_delta: GraphDelta3D) -> tuple[np.ndarray, np.ndarray]: + """Return (timestamps[N], deltas[N, 4, 4]) extracted from a GraphDelta3D. + + Node timestamps come from each ``node.pose.ts``; deltas come from each + ``transforms[i]`` (treated as a world-frame correction per the + GraphDelta3D ``post = T_delta @ T_pre`` convention). + """ + n = len(graph_delta.nodes) + timestamps = np.empty(n, dtype=np.float64) + deltas = np.empty((n, 4, 4), dtype=np.float64) + for i, (node, transform) in enumerate( + zip(graph_delta.nodes, graph_delta.transforms, strict=True) + ): + timestamps[i] = float(node.pose.ts) + deltas[i] = transform_to_matrix(transform) + return timestamps, deltas + + +def _dedupe_times(times: np.ndarray) -> np.ndarray: + """Bump any duplicate timestamps so the sequence is strictly increasing.""" + out = times.astype(np.float64).copy() + for i in range(1, out.size): + if out[i] <= out[i - 1]: + out[i] = out[i - 1] + _TIME_DEDUP_EPS + return out # type: ignore[no-any-return] + + +def lbs_warp_positions( + positions: np.ndarray, + position_times: np.ndarray, + node_times: np.ndarray, + node_deltas: np.ndarray, +) -> np.ndarray: + """Apply two-nearest-neighbor LBS to ``positions`` (M, 3) using ``node_deltas``. 
+ + For each point, find the two pose-graph nodes whose timestamps bracket + the point's time, slerp the rotations and lerp the translations between + them by the time-fraction, and apply the blended delta. Points outside + the node-time range clip to the nearest endpoint. + + Args: + positions: (M, 3) world-space positions. + position_times: (M,) per-position timestamps (seconds). + node_times: (N,) strictly increasing node timestamps. + node_deltas: (N, 4, 4) correction transforms per node. + + Returns: + (M, 3) warped positions. + """ + if positions.shape[0] == 0: + return positions.astype(np.float64, copy=True) + if node_times.shape[0] == 0: + return positions.astype(np.float64, copy=True) + if node_times.shape[0] == 1: + delta = node_deltas[0] + homog = np.concatenate( + [positions, np.ones((positions.shape[0], 1), dtype=positions.dtype)], axis=1 + ) + return (homog @ delta.T)[:, :3] # type: ignore[no-any-return] + + node_times_safe = _dedupe_times(node_times) + clipped = np.clip(position_times, node_times_safe[0], node_times_safe[-1]) + + node_R = Rotation.from_matrix(node_deltas[:, :3, :3]) + slerp = Slerp(node_times_safe, node_R) + point_R = slerp(clipped) + + node_t = node_deltas[:, :3, 3] + tx = np.interp(clipped, node_times_safe, node_t[:, 0]) + ty = np.interp(clipped, node_times_safe, node_t[:, 1]) + tz = np.interp(clipped, node_times_safe, node_t[:, 2]) + translation = np.stack([tx, ty, tz], axis=1) + + return point_R.apply(positions.astype(np.float64)) + translation # type: ignore[no-any-return] + + +def merge_duplicate_voxels( + voxels: np.ndarray, + quantity: np.ndarray, + event_indices: np.ndarray, +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """Collapse voxels that share an integer position. + + Two original voxels can warp into the same int32 grid cell. Sum their + ``quantity``, and remap ``event_indices`` so events still point at the + surviving merged voxel. + + Returns ``(unique_voxels, merged_quantity, remapped_event_indices)``. 
+ """ + if voxels.shape[0] == 0: + return voxels, quantity, event_indices + unique, inverse = np.unique(voxels, axis=0, return_inverse=True) + merged_q = np.zeros(unique.shape[0], dtype=np.uint64) + np.add.at(merged_q, inverse, quantity.astype(np.uint64)) + merged_q = np.minimum(merged_q, np.iinfo(np.uint32).max).astype(np.uint32) + if event_indices.size == 0: + new_events = event_indices + else: + new_events = inverse[event_indices.astype(np.intp)].astype(np.uint32) + return unique.astype(np.int32), merged_q, new_events + + +def apply_closure_to_cloud( + cloud: DynamicCloud, + graph_delta: GraphDelta3D, +) -> DynamicCloud: + """Warp ``cloud`` by the per-node deltas carried in ``graph_delta``. + + A pass-through if ``graph_delta`` has no nodes. + """ + if len(graph_delta.nodes) == 0: + return cloud + + node_times, deltas = graph_delta_to_arrays(graph_delta) + order = np.argsort(node_times, kind="stable") + node_times = node_times[order] + deltas = deltas[order] + + world = cloud.world_positions().astype(np.float64) + latest_ns = cloud.per_point_latest_timestamp() + point_times = latest_ns.astype(np.float64) / 1_000_000_000.0 + + new_world = lbs_warp_positions(world, point_times, node_times, deltas) + new_voxels = np.rint(new_world / cloud.voxel_size).astype(np.int32) + + voxels, quantity, event_indices = merge_duplicate_voxels( + new_voxels, cloud.quantity, cloud.event_indices + ) + + # event_timestamps is unchanged (the events still refer to the same physical + # observations, just at remapped voxel indices). DynamicCloud copies/normalizes + # the array internally so sharing the reference is safe. + return DynamicCloud( + voxels=voxels, + quantity=quantity, + event_indices=event_indices, + event_timestamps=cloud.event_timestamps, + voxel_size=cloud.voxel_size, + frame_id=cloud.frame_id, + ts=cloud.ts, + ) + + +class ApplyClosureConfig(ModuleConfig): + world_frame: str = "map" + # Log a one-line summary every time a correction is applied. 
class ApplyClosureConfig(ModuleConfig):
    # NOTE(review): not read anywhere in the visible module code; confirm
    # whether a consumer elsewhere relies on it.
    world_frame: str = "map"
    # Log a one-line summary every time a correction is applied.
    log_each_apply: bool = True


class ApplyClosure(Module):
    """Warp the global voxel map by a pose-graph loop-closure correction.

    The most recent ``global_map`` message is latched; every
    ``loop_closure_event`` warps that latched map with
    ``apply_closure_to_cloud`` and publishes the result on ``map_override``.
    """

    config: ApplyClosureConfig

    global_map: In[DynamicCloud]
    loop_closure_event: In[GraphDelta3D]
    map_override: Out[DynamicCloud]

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        # Guards ``_latest_map``, which is written and read by the two
        # subscription callbacks.
        self._lock = threading.Lock()
        self._latest_map: DynamicCloud | None = None

    @rpc
    def start(self) -> None:
        """Subscribe to both inputs; subscriptions are disposed with the module."""
        super().start()
        self.register_disposable(Disposable(self.global_map.subscribe(self._on_global_map)))
        self.register_disposable(
            Disposable(self.loop_closure_event.subscribe(self._on_loop_closure))
        )
        logger.info("ApplyClosure started")

    @rpc
    def stop(self) -> None:
        super().stop()

    def _on_global_map(self, msg: DynamicCloud) -> None:
        """Latch the newest global map for the next loop-closure event."""
        with self._lock:
            self._latest_map = msg

    def _on_loop_closure(self, msg: GraphDelta3D) -> None:
        """Loop-closure trigger: apply correction to the latched map."""
        with self._lock:
            cloud = self._latest_map
        if cloud is None:
            return
        t0 = time.monotonic()
        corrected = apply_closure_to_cloud(cloud, msg)
        # An empty delta makes apply_closure_to_cloud hand back the latched
        # message itself; re-stamping it would silently rewrite the timestamp
        # of the shared inbound map, so only stamp freshly built clouds.
        if corrected is not cloud:
            corrected.ts = time.time()
        self.map_override.publish(corrected)
        if self.config.log_each_apply:
            logger.info(
                "ApplyClosure applied",
                num_nodes=len(msg.nodes),
                num_points_in=len(cloud),
                num_points_out=len(corrected),
                elapsed_ms=round((time.monotonic() - t0) * 1000.0, 2),
            )
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""End-to-end visual stress test for ApplyClosure. + +A robot walks a closed rectangular loop inside a known room. Per-step odometry +adds a systematic yaw bias plus translational noise, so by the time it returns +to the start its drifted trajectory has accumulated several meters of position +error and ~10 degrees of yaw error. At each keyframe a synthetic lidar +"observes" visible ground-truth landmark points; observations are projected +into world space using the *drifted* pose (so the resulting global map is +smeared along the drift). + +The whole sequence runs inside a streaming ``while`` loop so each keyframe's +pose, observations, and accumulating voxels are rr.log'd as they're computed +— you can scrub the rerun timeline to verify the math step by step. + +After the loop, the demo synthesizes a pose-graph correction by linearly +blending each drifted pose toward its known ground-truth pose (translation +lerp + rotation slerp; alpha = i / (N-1)). This mimics the redistribution +that GTSAM iSAM2 would produce when a loop-closing edge nails the endpoint +back to the start. Then ApplyClosure warps the accumulated DynamicCloud and +the corrected map is logged. + +Run: + uv run python -m dimos.navigation.nav_stack.modules.apply_closure.demo_closure_scene --step-ms 200 + +Flags: + --no-spawn Do not auto-launch the rerun viewer. + --step-ms N Sleep N milliseconds between keyframes (default 80). 
+""" + +from __future__ import annotations + +import argparse +import math +import time + +import numpy as np +import rerun as rr +from scipy.spatial.transform import Rotation, Slerp + +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.nav_msgs.DynamicCloud import DynamicCloud +from dimos.msgs.nav_msgs.Graph3D import Graph3D +from dimos.msgs.nav_msgs.GraphDelta3D import GraphDelta3D +from dimos.navigation.nav_stack.modules.apply_closure.apply_closure import ( + apply_closure_to_cloud, +) + +ROOM_SIZE = 20.0 +WALL_HEIGHT = 3.0 +COLUMN_RADIUS = 0.4 +COLUMN_HEIGHT = 3.0 +COLUMN_CENTERS = [ + (5.0, 5.0), + (15.0, 5.0), + (5.0, 15.0), + (15.0, 15.0), +] + + +def sample_wall(p0: tuple[float, float], p1: tuple[float, float]) -> np.ndarray: + n_along = max(2, int(np.linalg.norm(np.array(p1) - np.array(p0)) * 8)) + n_vert = 12 + us = np.linspace(0.0, 1.0, n_along) + vs = np.linspace(0.0, WALL_HEIGHT, n_vert) + p0a = np.array(p0, dtype=np.float64) + p1a = np.array(p1, dtype=np.float64) + seg = p0a[None, :] + (p1a - p0a)[None, :] * us[:, None] + pts = np.zeros((n_along, n_vert, 3)) + pts[..., 0] = seg[:, 0:1] + pts[..., 1] = seg[:, 1:2] + pts[..., 2] = vs[None, :] + return pts.reshape(-1, 3) + + +def sample_cylinder(center: tuple[float, float]) -> np.ndarray: + n_around = 32 + n_vert = 12 + angles = np.linspace(0.0, 2.0 * math.pi, n_around, endpoint=False) + zs = np.linspace(0.0, COLUMN_HEIGHT, n_vert) + xs = center[0] + COLUMN_RADIUS * np.cos(angles) + ys = center[1] + COLUMN_RADIUS * np.sin(angles) + pts = np.zeros((n_around, n_vert, 3)) + pts[..., 0] = xs[:, None] + pts[..., 1] = ys[:, None] + pts[..., 2] = zs[None, :] + return pts.reshape(-1, 3) + + +def build_ground_truth() -> np.ndarray: + walls = np.concatenate( + [ + sample_wall((0.0, 0.0), (ROOM_SIZE, 0.0)), + sample_wall((ROOM_SIZE, 0.0), (ROOM_SIZE, ROOM_SIZE)), + 
sample_wall((ROOM_SIZE, ROOM_SIZE), (0.0, ROOM_SIZE)), + sample_wall((0.0, ROOM_SIZE), (0.0, 0.0)), + ], + axis=0, + ) + columns = np.concatenate([sample_cylinder(c) for c in COLUMN_CENTERS], axis=0) + return np.concatenate([walls, columns], axis=0) + + +PATH_INSET = 2.0 +KEYFRAMES_PER_SIDE = 12 # 48 keyframes total around the perimeter + + +def build_true_poses() -> tuple[np.ndarray, np.ndarray]: + """Return (times[N], poses[N, 4, 4]) tracing a closed rectangular loop.""" + corners = np.array( + [ + [PATH_INSET, PATH_INSET], + [ROOM_SIZE - PATH_INSET, PATH_INSET], + [ROOM_SIZE - PATH_INSET, ROOM_SIZE - PATH_INSET], + [PATH_INSET, ROOM_SIZE - PATH_INSET], + [PATH_INSET, PATH_INSET], + ], + dtype=np.float64, + ) + + positions: list[np.ndarray] = [] + yaws: list[float] = [] + for i in range(4): + a, b = corners[i], corners[i + 1] + for k in range(KEYFRAMES_PER_SIDE): + t = k / KEYFRAMES_PER_SIDE + positions.append(a + (b - a) * t) + yaws.append(math.atan2(b[1] - a[1], b[0] - a[0])) + n = len(positions) + poses = np.zeros((n, 4, 4)) + times = np.linspace(1.0, 1.0 + n * 0.5, n) # 0.5s per keyframe + for i, (xy, yaw) in enumerate(zip(positions, yaws, strict=True)): + T = np.eye(4) + T[:3, :3] = Rotation.from_euler("z", yaw).as_matrix() + T[:2, 3] = xy + T[2, 3] = 0.5 # robot sensor at 0.5m above the floor + poses[i] = T + return times, poses + + +YAW_BIAS_PER_STEP_DEG = 0.35 +TRANSLATION_NOISE_STD = 0.02 # m per step +LIDAR_MAX_RANGE = 6.0 +VOXEL_SIZE = 0.25 + + +def step_drift( + prev_drifted: np.ndarray, body_step: np.ndarray, rng: np.random.Generator +) -> np.ndarray: + """Compose the prev drifted pose with a noisy version of the true body step.""" + yaw_bias = math.radians(YAW_BIAS_PER_STEP_DEG) + R_bias = Rotation.from_euler("z", yaw_bias).as_matrix() + noisy = body_step.copy() + noisy[:3, :3] = R_bias @ noisy[:3, :3] + noisy[:3, 3] += rng.normal(0.0, TRANSLATION_NOISE_STD, 3) + return prev_drifted @ noisy # type: ignore[no-any-return] + + +def 
def visible_points(true_pose: np.ndarray, gt_points: np.ndarray) -> np.ndarray:
    """Return GT points within ``LIDAR_MAX_RANGE`` of the pose origin."""
    origin = true_pose[:3, 3]
    d = np.linalg.norm(gt_points - origin, axis=1)
    return gt_points[d <= LIDAR_MAX_RANGE]  # type: ignore[no-any-return]


def apply_delta(delta: np.ndarray, points: np.ndarray) -> np.ndarray:
    """Apply a 4x4 transform to (M, 3) points."""
    homog = np.concatenate([points, np.ones((points.shape[0], 1))], axis=1)
    return (homog @ delta.T)[:, :3]  # type: ignore[no-any-return]


def _blend_pose(
    rot_from: Rotation,
    rot_to: Rotation,
    t_from: np.ndarray,
    t_to: np.ndarray,
    alpha: float,
) -> np.ndarray:
    """Return the 4x4 pose at fraction ``alpha`` between two poses (slerp + lerp)."""
    key_rots = Rotation.concatenate([rot_from, rot_to])
    blended_rot = Slerp([0.0, 1.0], key_rots)([alpha])[0]
    pose = np.eye(4)
    pose[:3, :3] = blended_rot.as_matrix()
    pose[:3, 3] = (1.0 - alpha) * t_from + alpha * t_to
    return pose


def synthesize_closure_correction(drifted_poses: np.ndarray, true_poses: np.ndarray) -> np.ndarray:
    """Blend drifted toward true linearly along the trajectory.

    Mimics what GTSAM + ICP produces after the closing edge nails the endpoint
    back to the start: alpha = i / (N-1), slerp on rotation, lerp on translation.
    So pose 0 stays at its drifted value and the last pose lands on its true
    value.
    """
    n = drifted_poses.shape[0]
    corrected = np.empty_like(drifted_poses)
    alphas = np.linspace(0.0, 1.0, n)

    drifted_R = Rotation.from_matrix(drifted_poses[:, :3, :3])
    true_R = Rotation.from_matrix(true_poses[:, :3, :3])

    for i in range(n):
        corrected[i] = _blend_pose(
            drifted_R[i],
            true_R[i],
            drifted_poses[i, :3, 3],
            true_poses[i, :3, 3],
            alphas[i],
        )
    return corrected


def lerp_pose_arrays(A: np.ndarray, B: np.ndarray, alpha: float) -> np.ndarray:
    """Per-node lerp/slerp between two pose arrays at fraction ``alpha``.

    Translations are linearly interpolated; rotations use scipy's Slerp on
    each pair of quaternions independently. Used to animate the closure
    correction so we can watch the cloud snap from drifted to corrected.
    """
    R_A = Rotation.from_matrix(A[:, :3, :3])
    R_B = Rotation.from_matrix(B[:, :3, :3])
    out = np.empty_like(A)
    for i in range(A.shape[0]):
        out[i] = _blend_pose(R_A[i], R_B[i], A[i, :3, 3], B[i, :3, 3], alpha)
    return out


def _matrix_to_translation_quaternion(mat: np.ndarray) -> tuple[Vector3, Quaternion]:
    """Split a 4x4 pose into a ``Vector3`` translation and ``Quaternion`` rotation."""
    quat = Rotation.from_matrix(mat[:3, :3]).as_quat()  # [x, y, z, w]
    return (
        Vector3(float(mat[0, 3]), float(mat[1, 3]), float(mat[2, 3])),
        Quaternion(float(quat[0]), float(quat[1]), float(quat[2]), float(quat[3])),
    )


def make_graph_delta(
    times: np.ndarray, prev_poses: np.ndarray, target_poses: np.ndarray
) -> GraphDelta3D:
    """Build a GraphDelta3D carrying the per-node correction from prev → target.

    ``nodes[i].pose`` snapshots ``prev_poses[i]``; ``transforms[i]`` is the
    world-frame delta s.t. ``transforms[i] @ prev_poses[i] = target_poses[i]``.
    This is the message PGO would publish on a real loop-closure event.
    The message timestamp is the last entry of ``times``.
    """
    nodes: list[Graph3D.Node3D] = []
    transforms: list[GraphDelta3D.Transform] = []
    for i, ts in enumerate(times):
        prev_mat = prev_poses[i]
        # Left-multiplied correction: delta @ prev == target.
        delta_mat = target_poses[i] @ np.linalg.inv(prev_mat)

        prev_t, prev_q = _matrix_to_translation_quaternion(prev_mat)
        delta_t, delta_q = _matrix_to_translation_quaternion(delta_mat)

        pose = PoseStamped(
            ts=float(ts),
            frame_id="map",
            position=[prev_t.x, prev_t.y, prev_t.z],
            orientation=[prev_q.x, prev_q.y, prev_q.z, prev_q.w],
        )
        nodes.append(Graph3D.Node3D(pose=pose, id=i, metadata_id=0))
        transforms.append(GraphDelta3D.Transform(translation=delta_t, rotation=delta_q))
    return GraphDelta3D(ts=float(times[-1]), nodes=nodes, transforms=transforms)


def log_pose_arrow(name: str, T: np.ndarray, color: tuple[int, int, int]) -> None:
    """Log pose ``T`` as a 0.6 m arrow along its local +x axis."""
    origin = T[:3, 3]
    forward = T[:3, :3] @ np.array([0.6, 0.0, 0.0])
    rr.log(name, rr.Arrows3D(origins=[origin], vectors=[forward], colors=[color]))


def log_voxels(
    name: str,
    cloud: DynamicCloud,
    color: tuple[int, int, int],
    radii: float | None = None,
) -> None:
    """Log the cloud's world positions as points; nothing is logged when empty.

    ``radii`` defaults to half the cloud's voxel size.
    """
    pts = cloud.world_positions()
    if pts.shape[0] == 0:
        return
    r = cloud.voxel_size / 2 if radii is None else radii
    rr.log(name, rr.Points3D(pts, colors=[color], radii=r))


def mean_nearest_distance(cloud_points: np.ndarray, target_points: np.ndarray) -> float:
    """Mean nearest-neighbor distance from cloud_points to target_points.

    Returns NaN when either set is empty. Brute force, processed in chunks of
    2048 source points to bound the size of the pairwise-distance matrix.
    """
    if cloud_points.shape[0] == 0 or target_points.shape[0] == 0:
        return float("nan")
    chunk = 2048
    total = 0.0
    for i in range(0, cloud_points.shape[0], chunk):
        block = cloud_points[i : i + chunk]
        d2 = ((block[:, None, :] - target_points[None, :, :]) ** 2).sum(axis=2)
        total += float(np.sqrt(d2.min(axis=1)).sum())
    return total / cloud_points.shape[0]  # type: ignore[no-any-return]
def run_demo(spawn: bool, step_ms: int) -> None:
    """Run the drift -> closure -> warp demo and stream every step to rerun.

    Args:
        spawn: Whether ``rr.init`` should launch the rerun viewer.
        step_ms: Sleep between logged frames, in milliseconds (0 disables).
    """
    rr.init("apply_closure_demo", spawn=spawn)

    gt_points = build_ground_truth()
    rr.log(
        "world/ground_truth",
        rr.Points3D(gt_points, colors=[150, 150, 150], radii=0.04),
        static=True,
    )

    times, true_poses = build_true_poses()
    n = len(times)
    rr.log(
        "world/trajectory/true",
        rr.LineStrips3D([true_poses[:, :3, 3]], colors=[(60, 200, 80)], radii=0.06),
        static=True,
    )

    # Streaming state built up step by step inside the loop below.
    rng = np.random.default_rng(7)
    drifted_poses = np.empty_like(true_poses)
    drifted_poses[0] = true_poses[0]

    voxel_to_idx: dict[tuple[int, int, int], int] = {}
    quantity: list[int] = []
    event_idx: list[int] = []
    event_ts: list[int] = []
    accumulating: list[np.ndarray] = []

    for i in range(n):
        rr.set_time("step", sequence=i)
        rr.set_time("sim_time", duration=float(times[i] - times[0]))

        # Drifted pose: identity at i=0, accumulate noisy body steps otherwise.
        if i > 0:
            body_step = np.linalg.inv(true_poses[i - 1]) @ true_poses[i]
            drifted_poses[i] = step_drift(drifted_poses[i - 1], body_step, rng)

        true_T = true_poses[i]
        drifted_T = drifted_poses[i]

        # Visible GT points from the TRUE pose (what the robot actually sees);
        # project into world using the DRIFTED pose (what the robot thinks);
        # voxelize and accumulate.
        seen = visible_points(true_T, gt_points)
        log_pose_arrow("world/pose/true", true_T, (60, 200, 80))
        log_pose_arrow("world/pose/drifted", drifted_T, (220, 70, 70))
        rr.log(
            "world/trajectory/drifted_so_far",
            rr.LineStrips3D([drifted_poses[: i + 1, :3, 3]], colors=[(220, 70, 70)], radii=0.06),
        )

        if seen.shape[0] > 0:
            delta = drifted_T @ np.linalg.inv(true_T)
            observed = apply_delta(delta, seen)
            accumulating.append(observed)

            voxels = np.rint(observed / VOXEL_SIZE).astype(np.int32)
            ts_ns = int(times[i] * 1_000_000_000)
            for v in voxels:
                key = (int(v[0]), int(v[1]), int(v[2]))
                idx = voxel_to_idx.get(key)
                if idx is None:
                    idx = len(voxel_to_idx)
                    voxel_to_idx[key] = idx
                    quantity.append(0)
                quantity[idx] += 1
                event_idx.append(idx)
                event_ts.append(ts_ns)

            rr.log(
                "world/observations/this_frame",
                rr.Points3D(observed, colors=[255, 200, 60], radii=0.06),
            )
            cumulative = np.concatenate(accumulating, axis=0)
            rr.log(
                "world/observations/drifted_accum",
                rr.Points3D(cumulative, colors=[220, 70, 70], radii=0.05),
            )

        if step_ms > 0:
            time.sleep(step_ms / 1000.0)

    # Closure event: synthesize the target correction and apply it.
    rr.set_time("step", sequence=n)
    rr.set_time("sim_time", duration=float(times[-1] - times[0] + 1.0))

    # The per-frame yellow points were temporary; clear them so the voxel
    # global map is the dominant thing visible after the loop.
    rr.log("world/observations/this_frame", rr.Clear(recursive=False))

    corrected_poses = synthesize_closure_correction(drifted_poses, true_poses)
    rr.log(
        "world/closure/correction_arrows",
        rr.Arrows3D(
            origins=drifted_poses[:, :3, 3],
            vectors=corrected_poses[:, :3, 3] - drifted_poses[:, :3, 3],
            colors=[90, 140, 255],
        ),
        static=True,
    )

    # Materialize the accumulated DynamicCloud — this is what the running
    # system has produced just before the closure event fires. Indices were
    # handed out in insertion order, so the dict's key order already matches
    # them; the reshape keeps an empty map at shape (0, 3).
    unique = np.array(list(voxel_to_idx), dtype=np.int32).reshape(-1, 3)
    drifted_cloud = DynamicCloud(
        voxels=unique,
        quantity=np.array(quantity, dtype=np.uint32),
        event_indices=np.array(event_idx, dtype=np.uint32),
        event_timestamps=np.array(event_ts, dtype=np.uint64),
        voxel_size=VOXEL_SIZE,
        frame_id="map",
        ts=float(times[-1]),
    )
    # Snapshot the "before" state on the closure step so it's still visible
    # if you scrub back here.
    log_voxels("world/global_map/drifted", drifted_cloud, (220, 70, 70), radii=0.10)

    # Animate the closure: ramp alpha 0→1 across n_anim frames, applying
    # ApplyClosure each frame so the voxel map visibly snaps into place. Each
    # frame builds a fresh GraphDelta3D whose transforms[i] is the partial
    # correction needed at fraction alpha.
    n_anim = 24
    for j in range(n_anim + 1):
        alpha = j / n_anim
        rr.set_time("step", sequence=n + 1 + j)
        rr.set_time("sim_time", duration=float(times[-1] - times[0] + 1.0 + alpha))

        interp_poses = lerp_pose_arrays(drifted_poses, corrected_poses, alpha)
        closure_event = make_graph_delta(times, drifted_poses, interp_poses)
        corrected_at_alpha = apply_closure_to_cloud(drifted_cloud, closure_event)

        log_voxels("world/global_map/corrected", corrected_at_alpha, (60, 200, 80), radii=0.10)
        rr.log(
            "world/trajectory/corrected",
            rr.LineStrips3D([interp_poses[:, :3, 3]], colors=[(90, 140, 255)], radii=0.06),
        )
        if step_ms > 0:
            time.sleep(step_ms / 1000.0)

    # Final corrected cloud is whatever the full correction produces.
    final_closure_event = make_graph_delta(times, drifted_poses, corrected_poses)
    corrected_cloud = apply_closure_to_cloud(drifted_cloud, final_closure_event)

    err_before = mean_nearest_distance(drifted_cloud.world_positions(), gt_points)
    err_after = mean_nearest_distance(corrected_cloud.world_positions(), gt_points)
    # Compare the last drifted pose with the last TRUE pose: the true
    # trajectory's final keyframe sits one step short of the start corner, so
    # measuring against true_poses[0] would add that offset to the drift.
    endpoint_drift = float(np.linalg.norm(drifted_poses[-1, :3, 3] - true_poses[-1, :3, 3]))
    print(
        f"keyframes : {n}\n"
        f"endpoint position error : {endpoint_drift:.2f} m\n"
        f"drifted cloud → GT mean nn dist: {err_before:.3f} m\n"
        f"corrected cloud → GT mean nn dist: {err_after:.3f} m\n"
        f"voxels in cloud: {len(drifted_cloud)} drifted, {len(corrected_cloud)} corrected"
    )


def main() -> None:
    """Parse the command-line flags and run the demo."""
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--no-spawn", action="store_true", help="Do not auto-launch the rerun viewer."
    )
    parser.add_argument(
        "--step-ms",
        type=int,
        default=80,
        help="Sleep this many ms between keyframes so you can watch generation live.",
    )
    args = parser.parse_args()
    run_demo(spawn=not args.no_spawn, step_ms=args.step_ms)


if __name__ == "__main__":
    main()
+# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import math + +import numpy as np +from scipy.spatial.transform import Rotation + +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.nav_msgs.DynamicCloud import DynamicCloud +from dimos.msgs.nav_msgs.Graph3D import Graph3D +from dimos.msgs.nav_msgs.GraphDelta3D import GraphDelta3D +from dimos.navigation.nav_stack.modules.apply_closure.apply_closure import ( + apply_closure_to_cloud, + graph_delta_to_arrays, + lbs_warp_positions, + merge_duplicate_voxels, + transform_to_matrix, +) + + +def _pose(ts, x=0.0, y=0.0, z=0.0, yaw=0.0): + quat = Rotation.from_euler("z", yaw).as_quat() + return PoseStamped( + ts=ts, + frame_id="map", + position=[x, y, z], + orientation=[quat[0], quat[1], quat[2], quat[3]], + ) + + +def _node(ts, x=0.0, y=0.0, z=0.0, yaw=0.0, node_id=0): + return Graph3D.Node3D(pose=_pose(ts, x, y, z, yaw), id=node_id, metadata_id=0) + + +def _transform(tx=0.0, ty=0.0, tz=0.0, yaw=0.0): + quat = Rotation.from_euler("z", yaw).as_quat() + return GraphDelta3D.Transform( + translation=Vector3(tx, ty, tz), + rotation=Quaternion(quat[0], quat[1], quat[2], quat[3]), + ) + + +def _delta(*pairs): + """Build a GraphDelta3D from an iterable of (node, transform) pairs.""" + nodes = [pair[0] for pair in pairs] + transforms = [pair[1] for pair in pairs] + return GraphDelta3D(ts=1.0, nodes=nodes, transforms=transforms) + + +class TestTransformHelpers: + def test_transform_to_matrix_identity(self): + identity = GraphDelta3D.Transform( + translation=Vector3(0.0, 0.0, 0.0), + rotation=Quaternion(0.0, 0.0, 0.0, 1.0), + ) + np.testing.assert_allclose(transform_to_matrix(identity), np.eye(4), atol=1e-12) + + def test_transform_to_matrix_translation_only(self): + t = GraphDelta3D.Transform( + 
translation=Vector3(1.0, 2.0, 3.0), + rotation=Quaternion(0.0, 0.0, 0.0, 1.0), + ) + matrix = transform_to_matrix(t) + np.testing.assert_allclose(matrix[:3, :3], np.eye(3)) + np.testing.assert_allclose(matrix[:3, 3], [1.0, 2.0, 3.0]) + + def test_graph_delta_to_arrays(self): + delta = _delta( + (_node(1.0, 2.0, 3.0, 4.0), _transform(0.5, 0.0, 0.0)), + (_node(5.0, 6.0, 7.0, 8.0), _transform(0.0, 0.7, 0.0)), + ) + timestamps, deltas = graph_delta_to_arrays(delta) + np.testing.assert_array_equal(timestamps, [1.0, 5.0]) + assert deltas.shape == (2, 4, 4) + np.testing.assert_allclose(deltas[0, :3, 3], [0.5, 0.0, 0.0]) + np.testing.assert_allclose(deltas[1, :3, 3], [0.0, 0.7, 0.0]) + + +class TestLBSWarp: + def test_empty_positions_returns_empty(self): + out = lbs_warp_positions( + np.zeros((0, 3)), + np.zeros(0), + np.array([0.0, 1.0]), + np.stack([np.eye(4), np.eye(4)]), + ) + assert out.shape == (0, 3) + + def test_no_nodes_passes_through(self): + positions = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + out = lbs_warp_positions(positions, np.array([1.0, 2.0]), np.zeros(0), np.zeros((0, 4, 4))) + np.testing.assert_allclose(out, positions) + + def test_single_node_applies_rigidly(self): + delta = np.eye(4) + delta[:3, 3] = [10.0, 0.0, 0.0] + positions = np.array([[1.0, 0.0, 0.0], [2.0, 0.0, 0.0]]) + out = lbs_warp_positions( + positions, np.array([0.0, 100.0]), np.array([0.0]), delta[None, :, :] + ) + np.testing.assert_allclose(out, positions + np.array([10.0, 0.0, 0.0])) + + def test_before_range_clips_to_first_node(self): + deltas = np.stack([np.eye(4), np.eye(4)]) + deltas[0, :3, 3] = [1.0, 0.0, 0.0] + deltas[1, :3, 3] = [10.0, 0.0, 0.0] + positions = np.array([[0.0, 0.0, 0.0]]) + # point time well before node[0] should snap to delta[0] + out = lbs_warp_positions(positions, np.array([-100.0]), np.array([0.0, 1.0]), deltas) + np.testing.assert_allclose(out, [[1.0, 0.0, 0.0]]) + + def test_after_range_clips_to_last_node(self): + deltas = np.stack([np.eye(4), 
np.eye(4)]) + deltas[0, :3, 3] = [1.0, 0.0, 0.0] + deltas[1, :3, 3] = [10.0, 0.0, 0.0] + positions = np.array([[0.0, 0.0, 0.0]]) + out = lbs_warp_positions(positions, np.array([1e9]), np.array([0.0, 1.0]), deltas) + np.testing.assert_allclose(out, [[10.0, 0.0, 0.0]]) + + def test_midpoint_translation_lerps(self): + deltas = np.stack([np.eye(4), np.eye(4)]) + deltas[0, :3, 3] = [0.0, 0.0, 0.0] + deltas[1, :3, 3] = [10.0, 0.0, 0.0] + positions = np.array([[0.0, 0.0, 0.0]]) + out = lbs_warp_positions(positions, np.array([0.5]), np.array([0.0, 1.0]), deltas) + np.testing.assert_allclose(out, [[5.0, 0.0, 0.0]]) + + def test_midpoint_rotation_slerps(self): + deltas = np.stack([np.eye(4), np.eye(4)]) + deltas[1, :3, :3] = Rotation.from_euler("z", math.pi / 2).as_matrix() + # A point at (1, 0, 0) rotated by 45deg should land at (cos45, sin45, 0) + out = lbs_warp_positions( + np.array([[1.0, 0.0, 0.0]]), + np.array([0.5]), + np.array([0.0, 1.0]), + deltas, + ) + np.testing.assert_allclose( + out, [[math.cos(math.pi / 4), math.sin(math.pi / 4), 0.0]], atol=1e-9 + ) + + +class TestMergeDuplicates: + def test_no_duplicates_passes_through(self): + voxels = np.array([[0, 0, 0], [1, 1, 1]], dtype=np.int32) + quantity = np.array([2, 3], dtype=np.uint32) + events = np.array([0, 1, 1], dtype=np.uint32) + unique_voxels, merged_quantity, remapped_events = merge_duplicate_voxels( + voxels, quantity, events + ) + # np.unique sorts lexicographically — order may differ but contents must match + assert unique_voxels.shape == (2, 3) + assert int(merged_quantity.sum()) == 5 + # Build old → new index map and verify events were remapped correctly + old_to_new = {} + for old_i, original in enumerate(voxels): + matches = np.where((unique_voxels == original).all(axis=1))[0] + assert matches.size == 1 + old_to_new[old_i] = int(matches[0]) + expected_events = np.array([old_to_new[int(idx)] for idx in events], dtype=np.uint32) + np.testing.assert_array_equal(remapped_events, expected_events) + + def 
test_collision_sums_quantity(self): + voxels = np.array([[0, 0, 0], [0, 0, 0], [1, 0, 0]], dtype=np.int32) + quantity = np.array([5, 7, 9], dtype=np.uint32) + events = np.array([0, 1, 2], dtype=np.uint32) + unique_voxels, merged_quantity, remapped_events = merge_duplicate_voxels( + voxels, quantity, events + ) + assert unique_voxels.shape == (2, 3) + zero_row = np.where((unique_voxels == [0, 0, 0]).all(axis=1))[0][0] + one_row = np.where((unique_voxels == [1, 0, 0]).all(axis=1))[0][0] + assert int(merged_quantity[zero_row]) == 12 + assert int(merged_quantity[one_row]) == 9 + # The two events that referenced (0,0,0) should now reference zero_row + assert int(remapped_events[0]) == zero_row + assert int(remapped_events[1]) == zero_row + assert int(remapped_events[2]) == one_row + + def test_empty_inputs(self): + unique_voxels, merged_quantity, remapped_events = merge_duplicate_voxels( + np.zeros((0, 3), dtype=np.int32), + np.zeros(0, dtype=np.uint32), + np.zeros(0, dtype=np.uint32), + ) + assert unique_voxels.shape == (0, 3) + assert merged_quantity.shape == (0,) + assert remapped_events.shape == (0,) + + +class TestApplyClosureToCloud: + def test_empty_graph_delta_returns_input(self): + cloud = DynamicCloud( + voxels=np.array([[1, 2, 3]], dtype=np.int32), + quantity=np.array([1], dtype=np.uint32), + voxel_size=0.5, + ) + out = apply_closure_to_cloud(cloud, _delta()) + assert out is cloud + + def test_identity_deltas_preserve_voxels(self): + cloud = DynamicCloud( + voxels=np.array([[1, 0, 0], [-2, 3, 1]], dtype=np.int32), + quantity=np.array([4, 5], dtype=np.uint32), + voxel_size=0.5, + ) + delta = _delta( + (_node(1.0), _transform()), + (_node(10.0), _transform()), + ) + out = apply_closure_to_cloud(cloud, delta) + # Sort both for comparison since merge_duplicates may reorder + np.testing.assert_array_equal(np.sort(out.voxels, axis=0), np.sort(cloud.voxels, axis=0)) + assert int(out.quantity.sum()) == int(cloud.quantity.sum()) + + def 
test_uniform_translation_shifts_whole_cloud(self): + """All nodes carry the same translation delta → entire cloud shifts by it.""" + cloud = DynamicCloud( + voxels=np.array([[2, 0, 0], [4, 0, 0]], dtype=np.int32), + quantity=np.array([1, 1], dtype=np.uint32), + voxel_size=0.5, + ) + delta = _delta( + (_node(1.0), _transform(tx=1.0)), + (_node(2.0), _transform(tx=1.0)), + ) + out = apply_closure_to_cloud(cloud, delta) + # World positions: (1.0, 0, 0) and (2.0, 0, 0). +1m → (2,0,0), (3,0,0). + # voxel_size = 0.5, so voxels should be (4, 0, 0) and (6, 0, 0). + sorted_out = np.sort(out.voxels, axis=0) + np.testing.assert_array_equal(sorted_out, np.array([[4, 0, 0], [6, 0, 0]])) + + def test_recent_voxel_follows_latest_node_correction(self): + """A voxel with a recent event timestamp warps by the late-node delta. + + Older voxel (no event, effective ts=0) clips to the early node which has + zero correction; recent voxel clips to the late node which has a +5m shift. + + Note: PoseStamped maps ts=0 to time.time() as a "missing" sentinel, so + we use ts >= 1.0 throughout to keep the pose-graph timeline well-defined. 
+ """ + cloud = DynamicCloud( + voxels=np.array([[0, 0, 0], [10, 0, 0]], dtype=np.int32), + quantity=np.array([1, 1], dtype=np.uint32), + event_indices=np.array([1], dtype=np.uint32), + event_timestamps=np.array([100 * 1_000_000_000], dtype=np.uint64), + voxel_size=1.0, + ) + delta = _delta( + (_node(1.0), _transform()), + (_node(100.0), _transform(tx=5.0)), + ) + out = apply_closure_to_cloud(cloud, delta) + # Voxel 0 (no event → ts=0 → clipped to first node, ts=1 → identity delta): (0,0,0) + # Voxel 1 (event_ts=100s → clipped to last node → +5m): (10,0,0) → (15,0,0) + sorted_out = np.sort(out.voxels, axis=0) + np.testing.assert_array_equal(sorted_out, np.array([[0, 0, 0], [15, 0, 0]])) diff --git a/dimos/navigation/nav_stack/modules/far_planner/far_planner.py b/dimos/navigation/nav_stack/modules/far_planner/far_planner.py index e88cdb707c..348c362152 100644 --- a/dimos/navigation/nav_stack/modules/far_planner/far_planner.py +++ b/dimos/navigation/nav_stack/modules/far_planner/far_planner.py @@ -23,7 +23,7 @@ from dimos.core.stream import In, Out from dimos.msgs.geometry_msgs.PointStamped import PointStamped from dimos.msgs.nav_msgs.ContourPolygons3D import ContourPolygons3D -from dimos.msgs.nav_msgs.GraphNodes3D import GraphNodes3D +from dimos.msgs.nav_msgs.Graph3D import Graph3D from dimos.msgs.nav_msgs.LineSegments3D import LineSegments3D from dimos.msgs.nav_msgs.Odometry import Odometry from dimos.msgs.nav_msgs.Path import Path as NavPath @@ -34,12 +34,14 @@ class FarPlannerConfig(NativeModuleConfig): cwd: str | None = str(Path(__file__).resolve().parent) executable: str = "result/bin/far_planner_native" build_command: str | None = ( - "nix build github:dimensionalOS/dimos-module-far-planner/v0.5.0 --no-write-lock-file" + "nix build github:dimensionalOS/dimos-module-far-planner/v0.7.0 --no-write-lock-file" ) - # C++ binary uses snake_case CLI args. + # C++ binary uses snake_case CLI args. 
frame_id -> --world_frame maps + # the new Python builtin name back to the legacy C++ arg. cli_name_override: dict[str, str] = { "robot_dimension": "robot_dim", + "frame_id": "world_frame", } update_rate: float = 5.0 @@ -54,7 +56,7 @@ class FarPlannerConfig(NativeModuleConfig): is_multi_layer: bool = False is_debug_output: bool = False is_attempt_autoswitch: bool = True - world_frame: str = "map" + frame_id: str = "map" converge_dist: float = 1.5 goal_adjust_radius: float = 10.0 @@ -107,7 +109,6 @@ def stop(self) -> None: stop_movement: In[Bool] way_point: Out[PointStamped] goal_path: Out[NavPath] - graph_nodes: Out[GraphNodes3D] - graph_edges: Out[LineSegments3D] + graph: Out[Graph3D] contour_polygons: Out[ContourPolygons3D] nav_boundary: Out[LineSegments3D] diff --git a/dimos/navigation/nav_stack/modules/far_planner/test_far_planner_rosbag.py b/dimos/navigation/nav_stack/modules/far_planner/test_far_planner_rosbag.py index 7061f59523..ed4eb2c9dd 100644 --- a/dimos/navigation/nav_stack/modules/far_planner/test_far_planner_rosbag.py +++ b/dimos/navigation/nav_stack/modules/far_planner/test_far_planner_rosbag.py @@ -55,8 +55,7 @@ STOP_LCM = "/rbfp_stop#std_msgs.Bool" WAYPOINT_OUT_LCM = "/rbfp_wp#geometry_msgs.PointStamped" GOAL_PATH_LCM = "/rbfp_gp#nav_msgs.Path" -GRAPH_NODES_LCM = "/rbfp_gn#nav_msgs.GraphNodes3D" -GRAPH_EDGES_LCM = "/rbfp_ge#nav_msgs.LineSegments3D" +GRAPH_LCM = "/rbfp_g#nav_msgs.Graph3D" CONTOUR_LCM = "/rbfp_cp#nav_msgs.ContourPolygons3D" NAV_BOUNDARY_LCM = "/rbfp_nb#nav_msgs.LineSegments3D" @@ -79,10 +78,8 @@ def _far_planner_args() -> list[str]: WAYPOINT_OUT_LCM, "--goal_path", GOAL_PATH_LCM, - "--graph_nodes", - GRAPH_NODES_LCM, - "--graph_edges", - GRAPH_EDGES_LCM, + "--graph", + GRAPH_LCM, "--contour_polygons", CONTOUR_LCM, "--nav_boundary", @@ -217,7 +214,7 @@ def test_waypoint_accuracy(self) -> None: assert len(ref_wp) > 0, "No reference waypoints in fixture" lcm = lcmlib.LCM() - wp_collector = LcmCollector(topic=WAYPOINT_OUT_LCM, 
msg_type=PointStamped) + wp_collector = LcmCollector(topic=WAYPOINT_OUT_LCM, message_type=PointStamped) wp_collector.start(lcm) stop_event = threading.Event() diff --git a/dimos/navigation/nav_stack/modules/local_planner/local_planner.py b/dimos/navigation/nav_stack/modules/local_planner/local_planner.py index 217ca6c15a..d303977495 100644 --- a/dimos/navigation/nav_stack/modules/local_planner/local_planner.py +++ b/dimos/navigation/nav_stack/modules/local_planner/local_planner.py @@ -19,7 +19,10 @@ from pathlib import Path from dimos_lcm.geometry_msgs import PolygonStamped -from dimos_lcm.std_msgs import Float32 +from dimos_lcm.std_msgs import ( + Bool, # type: ignore[import-untyped] + Float32, +) from dimos.core.core import rpc from dimos.core.native_module import NativeModule, NativeModuleConfig @@ -30,7 +33,6 @@ from dimos.msgs.nav_msgs.Odometry import Odometry from dimos.msgs.nav_msgs.Path import Path as NavPath from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 -from dimos.msgs.std_msgs.Bool import Bool from dimos.msgs.std_msgs.Int8 import Int8 @@ -38,9 +40,9 @@ class LocalPlannerConfig(NativeModuleConfig): cwd: str | None = str(Path(__file__).resolve().parent) executable: str = "result/bin/local_planner" build_command: str | None = ( - "nix build github:dimensionalOS/dimos-module-local-planner/v0.6.0 --no-write-lock-file" + "nix build github:dimensionalOS/dimos-module-local-planner/feat/configurable-body-frame" + " --no-write-lock-file" ) - # C++ binary uses camelCase CLI args. 
cli_name_override: dict[str, str] = { "max_speed": "maxSpeed", @@ -95,6 +97,8 @@ class LocalPlannerConfig(NativeModuleConfig): paths_dir: str = "" + body_frame: str = "current_point" + vehicle_length: float = 0.5 # m vehicle_width: float = 0.5 # m sensor_offset_x: float | None = None # m diff --git a/dimos/navigation/nav_stack/modules/local_planner/test_local_planner_rosbag.py b/dimos/navigation/nav_stack/modules/local_planner/test_local_planner_rosbag.py index 5bee984088..e6c901e5e3 100644 --- a/dimos/navigation/nav_stack/modules/local_planner/test_local_planner_rosbag.py +++ b/dimos/navigation/nav_stack/modules/local_planner/test_local_planner_rosbag.py @@ -236,7 +236,7 @@ def test_path_accuracy(self) -> None: assert len(ref_paths) > 0, "No reference path data in fixture" lcm = lcmlib.LCM() - path_collector = LcmCollector(topic=PATH_LCM, msg_type=NavPath) + path_collector = LcmCollector(topic=PATH_LCM, message_type=NavPath) path_collector.start(lcm) stop_event = threading.Event() diff --git a/dimos/navigation/nav_stack/modules/nav_record/nav_record.py b/dimos/navigation/nav_stack/modules/nav_record/nav_record.py index e7a6db4870..4da8227af5 100644 --- a/dimos/navigation/nav_stack/modules/nav_record/nav_record.py +++ b/dimos/navigation/nav_stack/modules/nav_record/nav_record.py @@ -23,6 +23,10 @@ from dimos.memory2.module import Recorder, RecorderConfig from dimos.msgs.geometry_msgs.PointStamped import PointStamped from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.nav_msgs.ContourPolygons3D import ContourPolygons3D +from dimos.msgs.nav_msgs.Graph3D import Graph3D +from dimos.msgs.nav_msgs.GraphDelta3D import GraphDelta3D +from dimos.msgs.nav_msgs.LineSegments3D import LineSegments3D from dimos.msgs.nav_msgs.Odometry import Odometry from dimos.msgs.nav_msgs.Path import Path as NavPath from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 @@ -32,6 +36,10 @@ class NavRecordConfig(RecorderConfig): db_path: str = "nav_recording.db" + # Robot body 
frame, for unstamped messages. + default_frame_id: str = "current_point" + # Generous so PGO iSAM2 stalls (~500ms) don't cause lookup misses. + tf_tolerance: float = 3.0 class NavRecord(Recorder): @@ -47,23 +55,46 @@ def start(self) -> None: def stop(self) -> None: super().stop() - # Core nav outputs + # MovementManager outputs (muxed nav + teleop) cmd_vel: In[Twist] - corrected_odometry: In[Odometry] - path: In[NavPath] - goal_path: In[NavPath] - way_point: In[PointStamped] goal: In[PointStamped] stop_movement: In[LcmBool] - # LocalPlanner details + # PathFollower output (raw nav cmd before MovementManager mux; remapped from "cmd_vel") + nav_cmd_vel: In[Twist] + + # LocalPlanner outputs + path: In[NavPath] effective_cmd_vel: In[Twist] + free_paths: In[PointCloud2] slow_down: In[Int8] goal_reached: In[Bool] - # Point clouds + # SimplePlanner / FarPlanner / TarePlanner outputs + way_point: In[PointStamped] + goal_path: In[NavPath] + costmap_cloud: In[PointCloud2] # SimplePlanner only + # FarPlanner-specific + graph: In[Graph3D] + contour_polygons: In[ContourPolygons3D] + nav_boundary: In[LineSegments3D] + + # TerrainAnalysis / TerrainMapExt outputs terrain_map: In[PointCloud2] + terrain_map_ext: In[PointCloud2] + + # PGO outputs + corrected_odometry: In[Odometry] global_map: In[PointCloud2] + pose_graph: In[Graph3D] + loop_closure_event: In[GraphDelta3D] + # FastLio2 outputs (SLAM source; blueprints typically remap FastLio2's + # "lidar" -> "registered_scan" and "global_map" -> "global_map_fastlio") odometry: In[Odometry] registered_scan: In[PointCloud2] + global_map_fastlio: In[PointCloud2] + + # External inputs to the nav stack (recorded for context) + clicked_point: In[PointStamped] # from rerun click-to-drive + tele_cmd_vel: In[Twist] # from keyboard / quest / phone teleop diff --git a/dimos/navigation/nav_stack/modules/path_follower/path_follower.py b/dimos/navigation/nav_stack/modules/path_follower/path_follower.py index 9179ee4144..09367f965c 100644 --- 
a/dimos/navigation/nav_stack/modules/path_follower/path_follower.py +++ b/dimos/navigation/nav_stack/modules/path_follower/path_follower.py @@ -34,9 +34,9 @@ class PathFollowerConfig(NativeModuleConfig): cwd: str | None = str(Path(__file__).resolve().parent) executable: str = "result/bin/path_follower" build_command: str | None = ( - "nix build github:dimensionalOS/dimos-module-path-follower/v0.2.0 --no-write-lock-file" + "nix build github:dimensionalOS/dimos-module-path-follower/feat/dimos-native-ready" + " --no-write-lock-file" ) - cli_name_override: dict[str, str] = { "look_ahead_distance": "lookAheadDis", "max_speed": "maxSpeed", diff --git a/dimos/navigation/nav_stack/modules/path_follower/test_path_follower_rosbag.py b/dimos/navigation/nav_stack/modules/path_follower/test_path_follower_rosbag.py index 0a7893aad2..ddb6b206b5 100644 --- a/dimos/navigation/nav_stack/modules/path_follower/test_path_follower_rosbag.py +++ b/dimos/navigation/nav_stack/modules/path_follower/test_path_follower_rosbag.py @@ -116,7 +116,7 @@ def test_cmd_vel_accuracy(self) -> None: assert len(ref_cmd) > 0, "No reference cmd_vel in fixture" lcm = lcmlib.LCM() - cmd_collector = LcmCollector(topic=CMD_VEL_LCM, msg_type=Twist) + cmd_collector = LcmCollector(topic=CMD_VEL_LCM, message_type=Twist) cmd_collector.start(lcm) stop_event = threading.Event() diff --git a/dimos/navigation/nav_stack/modules/pgo/benchmark_kitti360.py b/dimos/navigation/nav_stack/modules/pgo/benchmark_kitti360.py new file mode 100644 index 0000000000..11b9e33e1e --- /dev/null +++ b/dimos/navigation/nav_stack/modules/pgo/benchmark_kitti360.py @@ -0,0 +1,70 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Evaluate PGO against KITTI-360. + +Usage: + uv run python -m dimos.navigation.nav_stack.modules.pgo.benchmark_kitti360 \\ + --kitti360-root ~/datasets/kitti360 --sequence 2 +""" + +from __future__ import annotations + +import argparse +import json +from pathlib import Path + +from dimos.navigation.nav_stack.benchmarks.pose_graph_kitti360.runner import ( + run_benchmark, +) +from dimos.navigation.nav_stack.modules.pgo.pgo import PGO + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Run the generic pose-graph KITTI-360 benchmark against PGO" + ) + parser.add_argument("--kitti360-root", type=Path, required=True) + parser.add_argument("--sequence", type=int, default=2) + parser.add_argument("--max-scans", type=int, default=None) + parser.add_argument("--scan-context-match-threshold", type=float, default=0.4) + parser.add_argument("--loop-score-thresh", type=float, default=0.5) + parser.add_argument("--loop-search-radius-m", type=float, default=1.0) + parser.add_argument("--key-pose-delta-trans", type=float, default=0.5) + parser.add_argument("--publish-interval-sec", type=float, default=0.02) + parser.add_argument("--output-json", type=Path, default=None) + args = parser.parse_args() + + results = run_benchmark( + module_under_test=PGO, + module_kwargs={ + "scan_context_match_threshold": args.scan_context_match_threshold, + "loop_score_thresh": args.loop_score_thresh, + "loop_search_radius": args.loop_search_radius_m, + "key_pose_delta_trans": args.key_pose_delta_trans, + }, + kitti360_root=args.kitti360_root, + 
sequence_id=args.sequence, + max_scans=args.max_scans, + publish_interval_sec=args.publish_interval_sec, + ) + + print(json.dumps(results, indent=2)) + if args.output_json is not None: + args.output_json.parent.mkdir(parents=True, exist_ok=True) + args.output_json.write_text(json.dumps(results, indent=2)) + + +if __name__ == "__main__": + main() diff --git a/dimos/navigation/nav_stack/modules/pgo/benchmark_kitti360_smoke.py b/dimos/navigation/nav_stack/modules/pgo/benchmark_kitti360_smoke.py new file mode 100644 index 0000000000..befc90bb42 --- /dev/null +++ b/dimos/navigation/nav_stack/modules/pgo/benchmark_kitti360_smoke.py @@ -0,0 +1,144 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Spins up a blueprint with PGO, the KITTI-360 playback module, and a +TopicCounter module that subscribes to every PGO output. Reports per-topic +message counts and a one-line verdict. 
+""" + +from __future__ import annotations + +import argparse +from pathlib import Path +import time +from typing import Any + +from reactivex.disposable import Disposable + +from dimos.core.coordination.blueprints import autoconnect +from dimos.core.coordination.module_coordinator import ModuleCoordinator +from dimos.core.core import rpc +from dimos.core.module import Module +from dimos.core.stream import In +from dimos.msgs.nav_msgs.Graph3D import Graph3D +from dimos.msgs.nav_msgs.GraphDelta3D import GraphDelta3D +from dimos.msgs.nav_msgs.Odometry import Odometry +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 +from dimos.navigation.nav_stack.benchmarks.pose_graph_kitti360.playback import ( + Kitti360PlaybackModule, +) +from dimos.navigation.nav_stack.modules.pgo.pgo import PGO + + +class TopicCounterModule(Module): + """Subscribes to every PGO output stream and counts arrivals per topic.""" + + corrected_odometry: In[Odometry] + global_map: In[PointCloud2] + pose_graph: In[Graph3D] + loop_closure_event: In[GraphDelta3D] + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self._counts: dict[str, int] = { + "corrected_odometry": 0, + "global_map": 0, + "pose_graph": 0, + "loop_closure_event": 0, + } + + @rpc + def start(self) -> None: + super().start() + for stream_name in self._counts: + stream = getattr(self, stream_name) + self.register_disposable(Disposable(stream.subscribe(self._make_counter(stream_name)))) + + def _make_counter(self, name: str) -> Any: + def _on_message(_message: Any) -> None: + self._counts[name] += 1 + + return _on_message + + @rpc + def counts(self) -> dict[str, int]: + return dict(self._counts) + + +def main() -> None: + parser = argparse.ArgumentParser(description="PGO liveness probe via DimOS modules") + parser.add_argument("--kitti360-root", type=Path, required=True) + parser.add_argument("--sequence", type=int, default=2) + parser.add_argument("--num-scans", type=int, default=200) + 
parser.add_argument( + "--loop-search-radius-m", + type=float, + default=4.0, + help="m; default 4.0 matches groundtruth radius", + ) + parser.add_argument("--publish-interval-sec", type=float, default=0.02) + parser.add_argument("--drain-sec", type=float, default=5.0) + parser.add_argument("--poll-interval-sec", type=float, default=0.5) + args = parser.parse_args() + + playback_blueprint = Kitti360PlaybackModule.blueprint( + kitti360_root=str(args.kitti360_root), + sequence_id=args.sequence, + max_scans=args.num_scans, + publish_interval_sec=args.publish_interval_sec, + ) + pgo_blueprint = PGO.blueprint( + loop_search_radius=args.loop_search_radius_m, + ) + counter_blueprint = TopicCounterModule.blueprint() + + blueprint = autoconnect(playback_blueprint, pgo_blueprint, counter_blueprint) + coordinator = ModuleCoordinator.build(blueprint) + try: + playback = coordinator.get_instance(Kitti360PlaybackModule) + counter = coordinator.get_instance(TopicCounterModule) + while not playback.is_finished(): + time.sleep(args.poll_interval_sec) + playback_error = playback.playback_error() + if playback_error is not None: + raise RuntimeError(f"Kitti360 playback aborted: {playback_error}") + time.sleep(args.drain_sec) + counts = counter.counts() + finally: + coordinator.stop() + + print("\n=== PGO topic message counts ===") + for name in ( + "corrected_odometry", + "global_map", + "pose_graph", + "loop_closure_event", + ): + print(f" {name:<24} {counts.get(name, 0):>6}") + + print("\nverdict:") + if counts.get("pose_graph", 0) == 0: + print(" ⚠ no pose graph — PGO never promoted a keyframe. Check --key_pose_delta_*.") + elif counts.get("loop_closure_event", 0) == 0: + print( + " ⚠ graph builds, no loop closure events — try wider --loop-search-radius " + "or lower --scan-context-match-threshold." 
+ ) + else: + print(" ✓ all topics firing — PGO is alive end-to-end.") + + +if __name__ == "__main__": + main() diff --git a/dimos/navigation/nav_stack/modules/pgo/benchmark_place_recognition.py b/dimos/navigation/nav_stack/modules/pgo/benchmark_place_recognition.py new file mode 100644 index 0000000000..7ad5dac925 --- /dev/null +++ b/dimos/navigation/nav_stack/modules/pgo/benchmark_place_recognition.py @@ -0,0 +1,304 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""How well can Scan Context tell "I've been here before" on KITTI-360? + +Score: Average Precision (AP), a single 0-1 number; higher is better. +Replays a real driven trajectory, and for every frame asks the +descriptor: of all the places I saw a while ago, which one looks most +like where I am now? If that "most similar" old place is actually +within a few metres of where I am, that's a correct revisit detection. + +The published Scan Context paper gets 0.65-0.78 on this sequence, so that's the bar. 
+ +Usage: + uv run python -m dimos.navigation.nav_stack.modules.pgo.benchmark_place_recognition \\ + --kitti360-root ~/datasets/kitti360 --sequence 2 +""" + +from __future__ import annotations + +import argparse +from dataclasses import dataclass +from pathlib import Path +import time + +import numpy as np +from scipy.spatial import cKDTree +from sklearn.metrics import average_precision_score # type: ignore[import-untyped] + +from dimos.navigation.nav_stack.benchmarks.pose_graph_kitti360.kitti360_loader import ( + load_kitti360_sequence, +) +from dimos.navigation.nav_stack.benchmarks.pose_graph_kitti360.loop_groundtruth import ( + DEFAULT_MAX_LOOP_DISTANCE_M, + DEFAULT_MIN_FRAME_GAP, + compute_loop_groundtruth, +) +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +@dataclass +class ScanContextConfig: + """Mirror of cpp/scan_context.h scan_context::Config defaults.""" + + num_rings: int = 20 + num_sectors: int = 60 + max_range_m: float = 80.0 + lidar_height_m: float = 2.0 + + +def make_descriptor(points_body: np.ndarray, config: ScanContextConfig) -> np.ndarray: + """Polar max-z descriptor — matches cpp/scan_context.cpp::make_descriptor. + + ``points_body``: (N, 3+) body-frame point cloud. + Returns: (num_rings, num_sectors) float32 with cell value = max(z + lidar_height, 0) + for points falling in that (range, azimuth) bin. 
+ """ + descriptor = np.zeros((config.num_rings, config.num_sectors), dtype=np.float32) + if len(points_body) == 0: + return descriptor + + x = points_body[:, 0] + y = points_body[:, 1] + z = points_body[:, 2] + + range_xy = np.sqrt(x * x + y * y) + valid = (range_xy < config.max_range_m) & (range_xy > 1e-6) + if not valid.any(): + return descriptor + + range_valid = range_xy[valid] + azimuth = np.arctan2(y[valid], x[valid]) + azimuth = np.where(azimuth < 0, azimuth + 2 * np.pi, azimuth) + z_shifted = np.maximum(z[valid] + config.lidar_height_m, 0.0) + + ring_step = config.max_range_m / config.num_rings + sector_step = 2 * np.pi / config.num_sectors + rings = np.clip(np.floor(range_valid / ring_step).astype(np.int32), 0, config.num_rings - 1) + sectors = np.clip(np.floor(azimuth / sector_step).astype(np.int32), 0, config.num_sectors - 1) + + flat_idx = rings * config.num_sectors + sectors + np.maximum.at(descriptor.ravel(), flat_idx, z_shifted.astype(np.float32)) + return descriptor + + +def best_scan_context_distance(query: np.ndarray, candidate: np.ndarray) -> tuple[float, int]: + """Min cosine distance over all column shifts — matches cpp::best_distance. + + Returns (min_distance, best_shift). 0 = identical, 2 = opposite. + Each shift's score is mean(1 - cosine_sim) across columns whose + norms are both non-zero (matches reference's "skip empty sector" logic). 
def main() -> None:
    """Run the KITTI-360 Scan Context place-recognition evaluation from the CLI.

    Steps:
      1. Load the requested sequence and derive loop-closure ground truth
         from the lidar positions.
      2. Build one Scan Context descriptor and one ring key (per-ring mean of
         the descriptor) for every frame.
      3. For every query frame that has at least ``--min-frame-gap`` earlier
         frames, find its best-matching past frame (top-1). Candidates come
         from a ring-key KD-tree shortlist of ``--candidate-top-k`` entries,
         or from all eligible past frames with ``--brute-force``.
      4. Print average precision over the top-1 match scores and a
         precision / recall / F1 table at fixed SC-distance thresholds.

    Raises:
        SystemExit: on invalid command-line values (via ``parser.error``) or
            when the selected sequence yields no frames.
    """
    parser = argparse.ArgumentParser(description="Place-recognition AP eval (KITTI-360)")
    parser.add_argument("--kitti360-root", type=Path, required=True)
    parser.add_argument("--sequence", type=int, default=2)
    parser.add_argument(
        "--max-scans",
        type=int,
        default=None,
        help="cap total frames evaluated (default: full sequence)",
    )
    parser.add_argument("--min-frame-gap", type=int, default=DEFAULT_MIN_FRAME_GAP)
    parser.add_argument("--max-loop-distance-m", type=float, default=DEFAULT_MAX_LOOP_DISTANCE_M)
    parser.add_argument(
        "--candidate-top-k",
        type=int,
        default=10,
        help="ring-key kd-tree prefilter size (Kim & Kim default)",
    )
    parser.add_argument(
        "--brute-force",
        action="store_true",
        help="skip ring-key prefilter; score every past candidate (slow)",
    )
    args = parser.parse_args()

    # Fail fast on values that would otherwise be misinterpreted silently:
    # a negative --max-scans slices frames off the END of the sequence, a
    # negative --min-frame-gap lets a query match frames from its own future,
    # and --candidate-top-k < 1 makes the KD-tree query raise mid-run.
    if args.max_scans is not None and args.max_scans < 0:
        parser.error("--max-scans must be >= 0")
    if args.min_frame_gap < 0:
        parser.error("--min-frame-gap must be >= 0")
    if args.candidate_top_k < 1:
        parser.error("--candidate-top-k must be >= 1")

    config = ScanContextConfig()

    logger.info(f"Loading KITTI-360 sequence {args.sequence} from {args.kitti360_root}")
    sequence = load_kitti360_sequence(args.kitti360_root, args.sequence)
    frame_ids = sequence.frame_ids
    if args.max_scans:  # None or 0 -> evaluate the full sequence
        frame_ids = frame_ids[: args.max_scans]
    num_frames = len(frame_ids)
    logger.info(f"{num_frames} frames")
    if num_frames == 0:
        # Everything below indexes positions[0] / positions[-1]; stop with a
        # readable message instead of an IndexError.
        raise SystemExit(f"no frames available for KITTI-360 sequence {args.sequence}")

    # Translation column of each 4x4 lidar pose.
    positions = np.array([sequence.lidar_pose(frame_id)[:3, 3] for frame_id in frame_ids])
    # Straight-line distance between first and last pose (not path length).
    travelled = float(np.linalg.norm(positions[-1] - positions[0]))
    logger.info(f"trajectory ~{travelled:.1f}m end-to-end")

    groundtruth = compute_loop_groundtruth(
        frame_ids,
        positions,
        min_frame_gap=args.min_frame_gap,
        max_distance_m=args.max_loop_distance_m,
    )
    queries_with_gt = sum(1 for v in groundtruth.valid_loops_per_query.values() if v)
    total_pairs = sum(len(v) for v in groundtruth.valid_loops_per_query.values())
    logger.info(
        f"GT: {queries_with_gt} queries have a valid loop "
        f"(min_gap={args.min_frame_gap}, radius={args.max_loop_distance_m}m), "
        f"{total_pairs} total valid pairs"
    )

    logger.info("Building SC descriptors...")
    # perf_counter is monotonic, so progress rates cannot go negative or
    # divide by a clock adjustment the way time.time() intervals can.
    build_start = time.perf_counter()
    descriptors = np.zeros((num_frames, config.num_rings, config.num_sectors), dtype=np.float32)
    ring_keys = np.zeros((num_frames, config.num_rings), dtype=np.float32)
    for i, frame_id in enumerate(frame_ids):
        scan = sequence.scan_xyz(frame_id)
        descriptors[i] = make_descriptor(scan, config)
        ring_keys[i] = descriptors[i].mean(axis=1)
        if (i + 1) % 500 == 0:
            rate = (i + 1) / (time.perf_counter() - build_start)
            logger.info(f"  {i + 1}/{num_frames} ({rate:.0f} scans/s)")
    logger.info(f"Built {num_frames} descriptors in {time.perf_counter() - build_start:.1f}s")

    logger.info("Computing top-1 SC matches per query...")
    score_start = time.perf_counter()
    # 2.0 is the "no match" sentinel: it is the initial best distance, so a
    # query keeps it when no candidate scores below it.
    top_match_distances = np.full(num_frames, 2.0, dtype=np.float64)
    is_true_positive = np.zeros(num_frames, dtype=bool)
    has_any_groundtruth = np.zeros(num_frames, dtype=bool)

    eval_count = 0
    for query_index, frame_id in enumerate(frame_ids):
        # Only frames at least min_frame_gap older than the query are eligible.
        max_candidate_index = query_index - args.min_frame_gap
        if max_candidate_index < 0:
            continue
        eval_count += 1
        valid_set = groundtruth.valid_loops_per_query.get(frame_id, set())
        has_any_groundtruth[query_index] = bool(valid_set)

        if args.brute_force:
            candidate_indices: list[int] = list(range(max_candidate_index + 1))
        else:
            # The eligible set grows with every query, so the tree is rebuilt
            # each time over exactly the frames this query may match.
            past_keys = ring_keys[: max_candidate_index + 1]
            tree = cKDTree(past_keys)
            top_k = min(args.candidate_top_k, max_candidate_index + 1)
            _, neighbor_indices = tree.query(ring_keys[query_index], k=top_k)
            # query() returns a scalar index for k == 1 and an array otherwise.
            candidate_indices = (
                [int(neighbor_indices)]
                if top_k == 1
                else [int(index) for index in neighbor_indices]
            )

        best_distance = 2.0
        best_candidate_index = -1
        for candidate_index in candidate_indices:
            distance, _shift = best_scan_context_distance(
                descriptors[query_index], descriptors[candidate_index]
            )
            if distance < best_distance:
                best_distance = distance
                best_candidate_index = candidate_index

        top_match_distances[query_index] = best_distance
        if best_candidate_index >= 0 and frame_ids[best_candidate_index] in valid_set:
            is_true_positive[query_index] = True

        if eval_count % 200 == 0:
            elapsed = time.perf_counter() - score_start
            logger.info(
                f"  scored {eval_count} queries ({eval_count / elapsed:.1f} q/s, "
                f"running TP={is_true_positive.sum()}, has_gt={has_any_groundtruth.sum()})"
            )

    logger.info(f"Scoring done in {time.perf_counter() - score_start:.1f}s")

    # AP: rank queries by score = -top_match_distances (high = more confident "this is a loop")
    # The mask selects exactly the queries the loop above did not skip.
    eval_mask = np.arange(num_frames) >= args.min_frame_gap
    y_true = is_true_positive[eval_mask].astype(np.int32)
    y_score = -top_match_distances[eval_mask]
    num_evaluated = int(eval_mask.sum())
    num_with_groundtruth = int(has_any_groundtruth[eval_mask].sum())
    num_true_positives = int(y_true.sum())

    # Without a single positive there is nothing to rank; report NaN.
    average_precision = (
        float(average_precision_score(y_true, y_score)) if y_true.any() else float("nan")
    )

    # Manual P/R sweep at representative SC-distance thresholds.
    # At threshold T: a query is "predicted loop" iff its top_match_distances <= T.
    #   precision = (#predicted ∧ is_true_positive) / #predicted
    #   recall    = (#predicted ∧ is_true_positive) / #queries_with_any_gt
    eval_distances = top_match_distances[eval_mask]
    pr_rows = []
    for threshold in (0.13, 0.20, 0.30, 0.40, 0.50, 0.60, 0.80, 1.00):
        predicted = eval_distances <= threshold
        true_positives_at_threshold = int(np.logical_and(predicted, y_true).sum())
        num_predicted = int(predicted.sum())
        precision = (
            true_positives_at_threshold / num_predicted if num_predicted > 0 else float("nan")
        )
        recall = (
            true_positives_at_threshold / num_with_groundtruth
            if num_with_groundtruth > 0
            else float("nan")
        )
        # F1 is reported as 0 when it is undefined (zero or NaN denominator).
        if (precision + recall) and not np.isnan(precision + recall):
            f1 = 2 * precision * recall / (precision + recall)
        else:
            f1 = 0.0
        pr_rows.append(
            (threshold, num_predicted, true_positives_at_threshold, precision, recall, f1)
        )

    print("")
    print(f"=== KITTI-360 seq {args.sequence} — Place Recognition (Scan Context) ===")
    print(f"frames evaluated:           {num_evaluated}")
    print(f"queries with any valid GT:  {num_with_groundtruth}")
    print(f"top-1 matches that are TP:  {num_true_positives}")
    print("")
    print(f"Average Precision (AP):     {average_precision:.4f}")
    print("")
    print("PR points (SC distance threshold):")
    print(
        f"  {'threshold':>9s}  {'predicted':>9s}  {'true_positives':>14s}  "
        f"{'precision':>9s}  {'recall':>8s}  {'F1':>6s}"
    )
    for threshold, num_predicted, true_positives, precision, recall, f1 in pr_rows:
        print(
            f"  {threshold:>9.2f}  {num_predicted:>9d}  {true_positives:>14d}  "
            f"{precision:>9.4f}  {recall:>8.4f}  {f1:>6.4f}"
        )


if __name__ == "__main__":
    main()
add_definitions(-DUSE_PCL) add_executable(pgo main.cpp simple_pgo.cpp + scan_context.cpp commons.cpp ) diff --git a/dimos/navigation/nav_stack/modules/pgo/cpp/main.cpp b/dimos/navigation/nav_stack/modules/pgo/cpp/main.cpp index aa08e29048..9833c67504 100644 --- a/dimos/navigation/nav_stack/modules/pgo/cpp/main.cpp +++ b/dimos/navigation/nav_stack/modules/pgo/cpp/main.cpp @@ -17,13 +17,20 @@ #include "commons.h" #include "simple_pgo.h" #include "dimos_native_module.hpp" +#include "msgs/Graph3D.hpp" +#include "msgs/GraphDelta3D.hpp" #include "point_cloud_utils.hpp" #include "nav_msgs/Odometry.hpp" +#include "nav_msgs/Path.hpp" #include "sensor_msgs/PointCloud2.hpp" #include "geometry_msgs/Pose.hpp" +#include "geometry_msgs/PoseStamped.hpp" #include "geometry_msgs/Quaternion.hpp" #include "geometry_msgs/Point.hpp" +#include "geometry_msgs/Transform.hpp" +#include "geometry_msgs/TransformStamped.hpp" +#include "tf2_msgs/TFMessage.hpp" static std::atomic g_running{true}; static void signal_handler(int) { g_running.store(false); } @@ -77,21 +84,55 @@ class Handlers { return; g_last_message_time = ts; - CloudWithPose cp; - cp.pose.r = g_latest_r; - cp.pose.t = g_latest_t; - cp.pose.setTime(static_cast(ts), + CloudWithPose cloud_with_pose; + cloud_with_pose.pose.r = g_latest_r; + cloud_with_pose.pose.t = g_latest_t; + cloud_with_pose.pose.setTime(static_cast(ts), static_cast((ts - static_cast(ts)) * 1e9)); // Parse PointCloud2 to PCL - cp.cloud = CloudType::Ptr(new CloudType); - smartnav::to_pcl(*msg, *cp.cloud); + cloud_with_pose.cloud = CloudType::Ptr(new CloudType); + smartnav::to_pcl(*msg, *cloud_with_pose.cloud); std::lock_guard buf_lock(g_buffer_mutex); - g_cloud_buffer.push(cp); + g_cloud_buffer.push(cloud_with_pose); } }; +static geometry_msgs::TransformStamped build_tf(const M3D& r, const V3D& t, double ts, + const std::string& frame_id, + const std::string& child_frame_id) { + geometry_msgs::TransformStamped ts_msg; + ts_msg.header = dimos::make_header(frame_id, 
ts); + ts_msg.child_frame_id = child_frame_id; + Eigen::Quaterniond q(r); + ts_msg.transform.translation.x = t.x(); + ts_msg.transform.translation.y = t.y(); + ts_msg.transform.translation.z = t.z(); + ts_msg.transform.rotation.x = q.x(); + ts_msg.transform.rotation.y = q.y(); + ts_msg.transform.rotation.z = q.z(); + ts_msg.transform.rotation.w = q.w(); + return ts_msg; +} + +static tf2_msgs::TFMessage build_tf_message(const M3D& correction_r, + const V3D& correction_t, + double ts, + const std::string& parent_frame, + const std::string& world_frame, + const std::string& local_frame) { + tf2_msgs::TFMessage msg; + // Identity anchor parent_frame -> world_frame. + msg.transforms.push_back( + build_tf(M3D::Identity(), V3D::Zero(), ts, parent_frame, world_frame)); + // SLAM correction world_frame -> local_frame. + msg.transforms.push_back( + build_tf(correction_r, correction_t, ts, world_frame, local_frame)); + msg.transforms_length = static_cast(msg.transforms.size()); + return msg; +} + static nav_msgs::Odometry build_odometry(const M3D& r, const V3D& t, double ts, const std::string& frame_id, const std::string& child_frame_id) { @@ -111,43 +152,136 @@ static nav_msgs::Odometry build_odometry(const M3D& r, const V3D& t, double ts, return odom; } +// Pose-graph snapshot encoded as a Graph3D: +// - one node per keyframe +static constexpr uint64_t NODE_KEYFRAME = 0; +static constexpr uint64_t EDGE_ODOMETRY = 0; +static constexpr uint64_t EDGE_LOOP_CLOSURE = 1; + +static dimos::Graph3D build_pose_graph( + const std::vector& key_poses, + const std::vector>& loop_pairs, + double ts, + const std::string& frame_id) { + dimos::Graph3D msg(frame_id, ts); + msg.reserve_nodes(key_poses.size()); + msg.reserve_edges(key_poses.size() + loop_pairs.size()); + for (size_t i = 0; i < key_poses.size(); i++) { + const auto& kp = key_poses[i]; + Eigen::Quaterniond q(kp.r_global); + msg.add_node( + static_cast(i), + NODE_KEYFRAME, + kp.time, + kp.t_global.x(), kp.t_global.y(), 
kp.t_global.z(), + q.x(), q.y(), q.z(), q.w()); + } + for (size_t i = 1; i < key_poses.size(); i++) { + msg.add_edge( + static_cast(i - 1), + static_cast(i), + key_poses[i].time, + EDGE_ODOMETRY); + } + for (const auto& pair : loop_pairs) { + if (pair.first >= key_poses.size() || pair.second >= key_poses.size()) continue; + msg.add_edge( + static_cast(pair.first), + static_cast(pair.second), + ts, + EDGE_LOOP_CLOSURE); + } + return msg; +} + +// Build a GraphDelta3D from paired pre/post keyframe lists. Each +// (node, transform) pair has: +// - node = the keyframe BEFORE iSAM2's smoothAndUpdate, with id = +// keyframe index and metadata_id = NODE_KEYFRAME. +// - transform = SE(3) delta such that post = transform * pre. +// Convention matches Python's GraphDelta3D.lcm_decode. +static constexpr uint64_t NODE_KEYFRAME_DELTA = 0; + +static dimos::GraphDelta3D build_loop_closure_event( + const std::vector>& pre_poses, + const std::vector& post_poses, + double ts, + const std::string& frame_id) { + dimos::GraphDelta3D msg(frame_id, ts); + size_t count = std::min(pre_poses.size(), post_poses.size()); + msg.reserve(count); + for (size_t i = 0; i < count; i++) { + const M3D& pre_r = pre_poses[i].first; + const V3D& pre_t = pre_poses[i].second; + const M3D& post_r = post_poses[i].r_global; + const V3D& post_t = post_poses[i].t_global; + + // SE(3) delta such that post = delta * pre. 
+ M3D r_delta = post_r * pre_r.transpose(); + V3D t_delta = post_t - r_delta * pre_t; + Eigen::Quaterniond q_pre(pre_r); + Eigen::Quaterniond q_delta(r_delta); + + msg.add( + /* id */ static_cast(i), + /* metadata_id */ NODE_KEYFRAME_DELTA, + /* pose_ts */ post_poses[i].time, + /* pos_x,y,z */ pre_t.x(), pre_t.y(), pre_t.z(), + /* quat_x,y,z,w */ q_pre.x(), q_pre.y(), q_pre.z(), q_pre.w(), + /* translation_x,y,z */ t_delta.x(), t_delta.y(), t_delta.z(), + /* rotation_x,y,z,w */ q_delta.x(), q_delta.y(), q_delta.z(), q_delta.w()); + } + return msg; +} + int main(int argc, char** argv) { signal(SIGTERM, signal_handler); signal(SIGINT, signal_handler); - dimos::NativeModule mod(argc, argv); + dimos::NativeModule native_module(argc, argv); // Port topics - std::string scan_topic = mod.topic("registered_scan"); - std::string odom_topic = mod.topic("odometry"); - std::string corrected_odom_topic = mod.topic("corrected_odometry"); - std::string global_map_topic = mod.topic("global_map"); - std::string tf_topic = mod.topic("pgo_tf"); + std::string scan_topic = native_module.topic("registered_scan"); + std::string odom_topic = native_module.topic("odometry"); + std::string corrected_odom_topic = native_module.topic("corrected_odometry"); + std::string global_map_topic = native_module.topic("global_map"); + std::string tf_channel = native_module.arg("tf_channel", "/tf#tf2_msgs.TFMessage"); + std::string pose_graph_topic = native_module.topic("pose_graph"); + std::string loop_closure_event_topic = native_module.topic("loop_closure_event"); // Config parameters Config config; - config.key_pose_delta_deg = mod.arg_float("key_pose_delta_deg", 10.0f); - config.key_pose_delta_trans = mod.arg_float("key_pose_delta_trans", 0.5f); - config.loop_search_radius = mod.arg_float("loop_search_radius", 1.0f); - config.loop_time_tresh = mod.arg_float("loop_time_thresh", 60.0f); - config.loop_score_tresh = mod.arg_float("loop_score_thresh", 0.15f); - config.loop_submap_half_range = 
mod.arg_int("loop_submap_half_range", 5); - config.submap_resolution = mod.arg_float("submap_resolution", 0.1f); - config.min_loop_detect_duration = mod.arg_float("min_loop_detect_duration", 5.0f); + config.key_pose_delta_deg = native_module.arg_float("key_pose_delta_deg", 10.0f); + config.key_pose_delta_trans = native_module.arg_float("key_pose_delta_trans", 0.5f); + config.loop_search_radius = native_module.arg_float("loop_search_radius", 1.0f); + config.loop_time_thresh = native_module.arg_float("loop_time_thresh", 60.0f); + config.loop_score_thresh = native_module.arg_float("loop_score_thresh", 0.15f); + config.loop_submap_half_range = native_module.arg_int("loop_submap_half_range", 5); + config.submap_resolution = native_module.arg_float("submap_resolution", 0.1f); + config.min_loop_detect_duration = native_module.arg_float("min_loop_detect_duration", 5.0f); + config.use_scan_context = native_module.arg_bool("use_scan_context", true); + config.scan_context_num_rings = native_module.arg_int("scan_context_num_rings", 20); + config.scan_context_num_sectors = native_module.arg_int("scan_context_num_sectors", 60); + config.scan_context_max_range_m = native_module.arg_float("scan_context_max_range_m", 80.0f); + config.scan_context_top_k = native_module.arg_int("scan_context_top_k", 10); + config.scan_context_match_threshold = native_module.arg_float("scan_context_match_threshold", 0.4f); + config.scan_context_lidar_height_m = native_module.arg_float("scan_context_lidar_height_m", 2.0f); // Node-level config - std::string world_frame = mod.arg("world_frame", "map"); - std::string local_frame = mod.arg("local_frame", "odom"); - float global_map_voxel_size = mod.arg_float("global_map_voxel_size", 0.1f); - float global_map_publish_rate = mod.arg_float("global_map_publish_rate", 1.0f); + std::string parent_frame = native_module.arg("parent_frame", "world"); + std::string world_frame = native_module.arg("world_frame", "map"); + std::string local_frame = 
native_module.arg("local_frame", "odom"); + std::string body_frame = native_module.arg("body_frame", "base_link"); + float global_map_voxel_size = native_module.arg_float("global_map_voxel_size", 0.1f); + float global_map_publish_rate = native_module.arg_float("global_map_publish_rate", 1.0f); double global_map_interval = global_map_publish_rate > 0 ? 1.0 / global_map_publish_rate : 2.0; // Unregister mode: transform world-frame scans to body-frame - bool unregister_input = mod.arg_bool("unregister_input", true); + bool unregister_input = native_module.arg_bool("unregister_input", true); - bool debug = mod.arg_bool("debug", false); + bool debug = native_module.arg_bool("debug", false); pcl::console::setVerbosityLevel( debug ? pcl::console::L_INFO : pcl::console::L_ERROR); @@ -164,13 +298,32 @@ int main(int argc, char** argv) lcm.subscribe(odom_topic, &Handlers::on_odometry, &handlers); lcm.subscribe(scan_topic, &Handlers::on_registered_scan, &handlers); + // NativeModule.start() in Python reads stderr for this marker and only + // returns once it sees it. Without this, upstream publishers can race + // ahead and emit messages before our LCM subscriptions are live. + fprintf(stderr, "[DIMOS_NATIVE_READY]\n"); + fflush(stderr); + if (debug) { fprintf(stderr, "PGO native module started\n"); fprintf(stderr, " registered_scan: %s\n", scan_topic.c_str()); fprintf(stderr, " odometry: %s\n", odom_topic.c_str()); fprintf(stderr, " corrected_odometry: %s\n", corrected_odom_topic.c_str()); fprintf(stderr, " global_map: %s\n", global_map_topic.c_str()); - fprintf(stderr, " pgo_tf: %s\n", tf_topic.c_str()); + fprintf(stderr, " tf_channel: %s\n", tf_channel.c_str()); + fprintf(stderr, " pose_graph: %s\n", pose_graph_topic.c_str()); + fprintf(stderr, " loop_closure_event: %s\n", loop_closure_event_topic.c_str()); + } + // Seed identity TF so consumers can query the chain before the first + // odom message arrives. 
+ { + double seed_ts = + std::chrono::duration( + std::chrono::system_clock::now().time_since_epoch()) + .count(); + auto seed = build_tf_message(M3D::Identity(), V3D::Zero(), seed_ts, + parent_frame, world_frame, local_frame); + lcm.publish(tf_channel, &seed); } double last_global_map_time = 0.0; @@ -181,12 +334,12 @@ int main(int argc, char** argv) while (lcm.handleTimeout(0) > 0) {} // Check buffer - CloudWithPose cp; + CloudWithPose cloud_with_pose; bool has_data = false; { std::lock_guard lock(g_buffer_mutex); if (!g_cloud_buffer.empty()) { - cp = g_cloud_buffer.front(); + cloud_with_pose = g_cloud_buffer.front(); // Drain entire queue (matching original: process oldest, discard rest) while (!g_cloud_buffer.empty()) { g_cloud_buffer.pop(); @@ -201,13 +354,13 @@ int main(int argc, char** argv) } // Optionally transform world-frame scan to body-frame - if (unregister_input && cp.cloud && cp.cloud->size() > 0) { + if (unregister_input && cloud_with_pose.cloud && cloud_with_pose.cloud->size() > 0) { CloudType::Ptr body_cloud(new CloudType); // body = R_odom^T * (world_pts - t_odom) - M3D r_inv = cp.pose.r.transpose(); - for (const auto& pt : *cp.cloud) { + M3D r_inv = cloud_with_pose.pose.r.transpose(); + for (const auto& pt : *cloud_with_pose.cloud) { V3D world_pt(pt.x, pt.y, pt.z); - V3D body_pt = r_inv * (world_pt - cp.pose.t); + V3D body_pt = r_inv * (world_pt - cloud_with_pose.pose.t); PointType bp; bp.x = static_cast(body_pt.x()); bp.y = static_cast(body_pt.y()); @@ -215,49 +368,77 @@ int main(int argc, char** argv) bp.intensity = pt.intensity; body_cloud->push_back(bp); } - cp.cloud = body_cloud; + cloud_with_pose.cloud = body_cloud; } - double cur_time = cp.pose.second; + double cur_time = cloud_with_pose.pose.second; - if (!pgo.addKeyPose(cp)) { + if (!pgo.addKeyPose(cloud_with_pose)) { // Not a keyframe — still broadcast TF and corrected odom - M3D corr_r = pgo.offsetR() * cp.pose.r; - V3D corr_t = pgo.offsetR() * cp.pose.t + pgo.offsetT(); + M3D corr_r 
= pgo.offsetR() * cloud_with_pose.pose.r; + V3D corr_t = pgo.offsetR() * cloud_with_pose.pose.t + pgo.offsetT(); nav_msgs::Odometry corrected = build_odometry( - corr_r, corr_t, cur_time, world_frame, "base_link"); + corr_r, corr_t, cur_time, world_frame, body_frame); lcm.publish(corrected_odom_topic, &corrected); - nav_msgs::Odometry tf_msg = build_odometry( - pgo.offsetR(), pgo.offsetT(), cur_time, world_frame, local_frame); - lcm.publish(tf_topic, &tf_msg); + auto tf_msg = build_tf_message( + pgo.offsetR(), pgo.offsetT(), cur_time, parent_frame, world_frame, local_frame); + lcm.publish(tf_channel, &tf_msg); std::this_thread::sleep_for(std::chrono::milliseconds(timer_period_ms)); continue; } - // Keyframe added + // Keyframe added. Snapshot keyframe global poses BEFORE search + + // smooth so we can publish the delta applied by iSAM2 if a loop + // closure actually fires. pgo.searchForLoopPairs(); + bool had_loop = pgo.hasLoop(); + + std::vector> pre_poses; + if (had_loop) { + pre_poses.reserve(pgo.keyPoses().size()); + for (const auto& kp : pgo.keyPoses()) { + pre_poses.emplace_back(kp.r_global, kp.t_global); + } + } + pgo.smoothAndUpdate(); + if (had_loop) { + dimos::GraphDelta3D loop_closure_event_msg = build_loop_closure_event( + pre_poses, pgo.keyPoses(), cur_time, world_frame); + loop_closure_event_msg.publish(lcm, loop_closure_event_topic); + if (debug) { + fprintf(stderr, + "PGO: loop_closure_event published — %zu keyframe deltas\n", + pre_poses.size()); + } + } + if (debug) { fprintf(stderr, "PGO: keyframe %zu at (%.1f, %.1f, %.1f)\n", pgo.keyPoses().size(), - cp.pose.t.x(), cp.pose.t.y(), cp.pose.t.z()); + cloud_with_pose.pose.t.x(), cloud_with_pose.pose.t.y(), cloud_with_pose.pose.t.z()); } // Publish corrected odometry - M3D corr_r = pgo.offsetR() * cp.pose.r; - V3D corr_t = pgo.offsetR() * cp.pose.t + pgo.offsetT(); + M3D corr_r = pgo.offsetR() * cloud_with_pose.pose.r; + V3D corr_t = pgo.offsetR() * cloud_with_pose.pose.t + pgo.offsetT(); 
nav_msgs::Odometry corrected = build_odometry( - corr_r, corr_t, cur_time, world_frame, "base_link"); + corr_r, corr_t, cur_time, world_frame, body_frame); lcm.publish(corrected_odom_topic, &corrected); - // Publish TF correction (map -> odom offset) - nav_msgs::Odometry tf_msg = build_odometry( - pgo.offsetR(), pgo.offsetT(), cur_time, world_frame, local_frame); - lcm.publish(tf_topic, &tf_msg); + auto tf_msg = build_tf_message( + pgo.offsetR(), pgo.offsetT(), cur_time, parent_frame, world_frame, local_frame); + lcm.publish(tf_channel, &tf_msg); + + // Publish pose graph (on every keyframe — iSAM2 may have + // re-optimized prior poses on loop closure). + dimos::Graph3D pose_graph_msg = build_pose_graph( + pgo.keyPoses(), pgo.historyPairs(), cur_time, world_frame); + pose_graph_msg.publish(lcm, pose_graph_topic); // Publish global map (throttled) double now = cur_time; diff --git a/dimos/navigation/nav_stack/modules/pgo/cpp/msgs/Graph3D.hpp b/dimos/navigation/nav_stack/modules/pgo/cpp/msgs/Graph3D.hpp new file mode 100644 index 0000000000..a440127ae2 --- /dev/null +++ b/dimos/navigation/nav_stack/modules/pgo/cpp/msgs/Graph3D.hpp @@ -0,0 +1,175 @@ +// Copyright 2026 Dimensional Inc. +// SPDX-License-Identifier: Apache-2.0 +// +// Typed C++ helper mirroring the Python `dimos.msgs.nav_msgs.Graph3D`. +// Canonical schema lives in `dimos/msgs/nav_msgs/Graph3D.ksy` — keep +// encode() in sync with that file (and with Graph3D.py.lcm_decode). +// +// Wire format (big-endian): +// +// uint64 edge_count +// uint64 node_count +// double timestamp // seconds since epoch +// per node (node_count): +// pose_stamped: +// double ts +// uint32 frame_id_len +// bytes frame_id (utf-8, no terminator) +// 7×double pos_x, pos_y, pos_z, quat_x, quat_y, quat_z, quat_w +// uint64 id +// uint64 metadata_id +// per edge (edge_count): +// uint64 start_id +// uint64 end_id +// double timestamp +// uint64 metadata_id +// +// Edges reference nodes by `id`, not by index. 
+ +#pragma once + +#include +#include +#include +#include +#include + +#include + +namespace dimos { + +namespace graph3d_detail { + +// Host-order → big-endian byte writers. Avoid for portability +// (macOS uses different names) — write byte-by-byte from the top. + +inline void write_u32_be(std::vector& out, uint32_t v) { + out.push_back(static_cast((v >> 24) & 0xFF)); + out.push_back(static_cast((v >> 16) & 0xFF)); + out.push_back(static_cast((v >> 8) & 0xFF)); + out.push_back(static_cast( v & 0xFF)); +} + +inline void write_u64_be(std::vector& out, uint64_t v) { + for (int shift = 56; shift >= 0; shift -= 8) { + out.push_back(static_cast((v >> shift) & 0xFF)); + } +} + +inline void write_double_be(std::vector& out, double v) { + uint64_t bits; + std::memcpy(&bits, &v, sizeof(bits)); + write_u64_be(out, bits); +} + +inline void write_bytes(std::vector& out, const std::string& s) { + out.insert(out.end(), s.begin(), s.end()); +} + +} // namespace graph3d_detail + +class Graph3D { +public: + struct PoseStamped { + double ts = 0.0; + std::string frame_id; + double pos_x = 0.0, pos_y = 0.0, pos_z = 0.0; + double quat_x = 0.0, quat_y = 0.0, quat_z = 0.0, quat_w = 1.0; + }; + + struct Node3D { + PoseStamped pose; + uint64_t id = 0; + uint64_t metadata_id = 0; + }; + + struct Edge { + uint64_t start_id = 0; + uint64_t end_id = 0; + double timestamp = 0.0; + uint64_t metadata_id = 0; + }; + + Graph3D(std::string frame_id, double timestamp) + : frame_id_(std::move(frame_id)), timestamp_(timestamp) {} + + void reserve_nodes(size_t capacity) { nodes_.reserve(capacity); } + void reserve_edges(size_t capacity) { edges_.reserve(capacity); } + + // Add a node. The pose's frame_id defaults to the graph's frame_id — + // override per-node only if a node lives in a different frame. 
+ void add_node(uint64_t id, uint64_t metadata_id, double pose_ts, + double pos_x, double pos_y, double pos_z, + double quat_x, double quat_y, double quat_z, double quat_w, + std::string node_frame_id = "") { + PoseStamped pose; + pose.ts = pose_ts; + pose.frame_id = node_frame_id.empty() ? frame_id_ : std::move(node_frame_id); + pose.pos_x = pos_x; pose.pos_y = pos_y; pose.pos_z = pos_z; + pose.quat_x = quat_x; pose.quat_y = quat_y; pose.quat_z = quat_z; pose.quat_w = quat_w; + nodes_.push_back({pose, id, metadata_id}); + } + + // Position-only convenience (orientation defaults to identity). + void add_node_xyz(uint64_t id, uint64_t metadata_id, double pose_ts, + double pos_x, double pos_y, double pos_z) { + add_node(id, metadata_id, pose_ts, pos_x, pos_y, pos_z, 0.0, 0.0, 0.0, 1.0); + } + + void add_edge(uint64_t start_id, uint64_t end_id, double edge_ts, + uint64_t metadata_id = 0) { + edges_.push_back({start_id, end_id, edge_ts, metadata_id}); + } + + size_t node_count() const { return nodes_.size(); } + size_t edge_count() const { return edges_.size(); } + const std::string& frame_id() const { return frame_id_; } + + std::vector encode() const { + using namespace graph3d_detail; + std::vector out; + // Conservative reservation: header + per-node fixed bytes + per-edge. + // frame_id strings add variable length on top — that just causes a + // realloc, not correctness issues. 
+ out.reserve(24 + nodes_.size() * 84 + edges_.size() * 32); + write_u64_be(out, static_cast(edges_.size())); + write_u64_be(out, static_cast(nodes_.size())); + write_double_be(out, timestamp_); + for (const auto& n : nodes_) { + // pose_stamped first (per Graph3D.ksy) + write_double_be(out, n.pose.ts); + write_u32_be(out, static_cast(n.pose.frame_id.size())); + write_bytes(out, n.pose.frame_id); + write_double_be(out, n.pose.pos_x); + write_double_be(out, n.pose.pos_y); + write_double_be(out, n.pose.pos_z); + write_double_be(out, n.pose.quat_x); + write_double_be(out, n.pose.quat_y); + write_double_be(out, n.pose.quat_z); + write_double_be(out, n.pose.quat_w); + // then id, metadata_id + write_u64_be(out, n.id); + write_u64_be(out, n.metadata_id); + } + for (const auto& e : edges_) { + write_u64_be(out, e.start_id); + write_u64_be(out, e.end_id); + write_double_be(out, e.timestamp); + write_u64_be(out, e.metadata_id); + } + return out; + } + + int publish(lcm::LCM& lcm, const std::string& channel) const { + std::vector bytes = encode(); + return lcm.publish(channel, bytes.data(), static_cast(bytes.size())); + } + +private: + std::string frame_id_; + double timestamp_; + std::vector nodes_; + std::vector edges_; +}; + +} // namespace dimos diff --git a/dimos/navigation/nav_stack/modules/pgo/cpp/msgs/GraphDelta3D.hpp b/dimos/navigation/nav_stack/modules/pgo/cpp/msgs/GraphDelta3D.hpp new file mode 100644 index 0000000000..4db42eb71c --- /dev/null +++ b/dimos/navigation/nav_stack/modules/pgo/cpp/msgs/GraphDelta3D.hpp @@ -0,0 +1,168 @@ +// Copyright 2026 Dimensional Inc. +// SPDX-License-Identifier: Apache-2.0 +// +// Typed C++ helper mirroring the Python `dimos.msgs.nav_msgs.GraphDelta3D`. 
+// +// Wire format (big-endian): +// +// uint64 node_count +// double timestamp // seconds since epoch +// per node (node_count): +// pose_stamped: // (same as Graph3D's node3d pose) +// double ts +// uint32 frame_id_len +// bytes frame_id (utf-8, no terminator) +// 7×double pos_x, pos_y, pos_z, quat_x, quat_y, quat_z, quat_w +// uint64 id +// uint64 metadata_id +// per transform (node_count): +// 7×double translation_x, translation_y, translation_z, +// rotation_x, rotation_y, rotation_z, rotation_w +// +// Two aligned arrays: ``transforms[i]`` is the SE(3) delta about to +// be applied to ``nodes[i]``. ``post_pose = transforms[i] * nodes[i].pose`` +// is the convention (left-multiply). +// +// `GraphDelta3D.py.lcm_decode` reads exactly this layout — keep in sync. + +#pragma once + +#include +#include +#include +#include +#include + +#include + +namespace dimos { + +namespace graph_delta3d_detail { + +inline void write_u32_be(std::vector& out, uint32_t v) { + out.push_back(static_cast((v >> 24) & 0xFF)); + out.push_back(static_cast((v >> 16) & 0xFF)); + out.push_back(static_cast((v >> 8) & 0xFF)); + out.push_back(static_cast( v & 0xFF)); +} + +inline void write_u64_be(std::vector& out, uint64_t v) { + for (int shift = 56; shift >= 0; shift -= 8) { + out.push_back(static_cast((v >> shift) & 0xFF)); + } +} + +inline void write_double_be(std::vector& out, double v) { + uint64_t bits; + std::memcpy(&bits, &v, sizeof(bits)); + write_u64_be(out, bits); +} + +inline void write_bytes(std::vector& out, const std::string& s) { + out.insert(out.end(), s.begin(), s.end()); +} + +} // namespace graph_delta3d_detail + +class GraphDelta3D { +public: + struct PoseStamped { + double ts = 0.0; + std::string frame_id; + double pos_x = 0.0, pos_y = 0.0, pos_z = 0.0; + double quat_x = 0.0, quat_y = 0.0, quat_z = 0.0, quat_w = 1.0; + }; + + struct Node3D { + PoseStamped pose; + uint64_t id = 0; + uint64_t metadata_id = 0; + }; + + struct Transform { + double translation_x = 0.0, 
translation_y = 0.0, translation_z = 0.0; + double rotation_x = 0.0, rotation_y = 0.0, rotation_z = 0.0, rotation_w = 1.0; + }; + + GraphDelta3D(std::string frame_id, double timestamp) + : frame_id_(std::move(frame_id)), timestamp_(timestamp) {} + + void reserve(size_t capacity) { + nodes_.reserve(capacity); + transforms_.reserve(capacity); + } + + // Add a node + its SE(3) delta. Pass empty `node_frame_id` to inherit + // the graph's frame_id. + void add(uint64_t id, uint64_t metadata_id, double pose_ts, + double pos_x, double pos_y, double pos_z, + double quat_x, double quat_y, double quat_z, double quat_w, + double translation_x, double translation_y, double translation_z, + double rotation_x, double rotation_y, double rotation_z, double rotation_w, + std::string node_frame_id = "") { + Node3D node; + node.id = id; + node.metadata_id = metadata_id; + node.pose.ts = pose_ts; + node.pose.frame_id = node_frame_id.empty() ? frame_id_ : std::move(node_frame_id); + node.pose.pos_x = pos_x; node.pose.pos_y = pos_y; node.pose.pos_z = pos_z; + node.pose.quat_x = quat_x; node.pose.quat_y = quat_y; + node.pose.quat_z = quat_z; node.pose.quat_w = quat_w; + nodes_.push_back(node); + + Transform tf; + tf.translation_x = translation_x; tf.translation_y = translation_y; tf.translation_z = translation_z; + tf.rotation_x = rotation_x; tf.rotation_y = rotation_y; + tf.rotation_z = rotation_z; tf.rotation_w = rotation_w; + transforms_.push_back(tf); + } + + size_t size() const { return nodes_.size(); } + bool empty() const { return nodes_.empty(); } + const std::string& frame_id() const { return frame_id_; } + + std::vector encode() const { + using namespace graph_delta3d_detail; + std::vector out; + out.reserve(16 + nodes_.size() * (84 + 56)); + write_u64_be(out, static_cast(nodes_.size())); + write_double_be(out, timestamp_); + for (const auto& n : nodes_) { + write_double_be(out, n.pose.ts); + write_u32_be(out, static_cast(n.pose.frame_id.size())); + write_bytes(out, 
n.pose.frame_id); + write_double_be(out, n.pose.pos_x); + write_double_be(out, n.pose.pos_y); + write_double_be(out, n.pose.pos_z); + write_double_be(out, n.pose.quat_x); + write_double_be(out, n.pose.quat_y); + write_double_be(out, n.pose.quat_z); + write_double_be(out, n.pose.quat_w); + write_u64_be(out, n.id); + write_u64_be(out, n.metadata_id); + } + for (const auto& t : transforms_) { + write_double_be(out, t.translation_x); + write_double_be(out, t.translation_y); + write_double_be(out, t.translation_z); + write_double_be(out, t.rotation_x); + write_double_be(out, t.rotation_y); + write_double_be(out, t.rotation_z); + write_double_be(out, t.rotation_w); + } + return out; + } + + int publish(lcm::LCM& lcm, const std::string& channel) const { + std::vector bytes = encode(); + return lcm.publish(channel, bytes.data(), static_cast(bytes.size())); + } + +private: + std::string frame_id_; + double timestamp_; + std::vector nodes_; + std::vector transforms_; +}; + +} // namespace dimos diff --git a/dimos/navigation/nav_stack/modules/pgo/cpp/scan_context.cpp b/dimos/navigation/nav_stack/modules/pgo/cpp/scan_context.cpp new file mode 100644 index 0000000000..370cd20b64 --- /dev/null +++ b/dimos/navigation/nav_stack/modules/pgo/cpp/scan_context.cpp @@ -0,0 +1,119 @@ +#include "scan_context.h" + +#include +#include + +namespace scan_context { + +Descriptor make_descriptor(const CloudType& cloud, const Config& config) { + // Empty cells stay at 0; we shift z by lidar_height so real points + // are strictly positive and "no point here" is distinguishable from + // ground level. Matches irapkaist/scancontext's NO_POINT convention + // closely enough that the column-wise cosine distance behaves. 
+ Descriptor descriptor = Descriptor::Constant(config.n_rings, config.n_sectors, 0.0f); + if (config.n_rings <= 0 || config.n_sectors <= 0 || config.max_range_m <= 0.0) { + return descriptor; + } + + const double ring_step = config.max_range_m / static_cast(config.n_rings); + const double sector_step = 2.0 * M_PI / static_cast(config.n_sectors); + const float height_offset = static_cast(config.lidar_height_m); + + for (const auto& point : cloud.points) { + const double x = point.x; + const double y = point.y; + const double z = point.z; + + const double range = std::sqrt(x * x + y * y); + if (range >= config.max_range_m || range <= 1e-6) { + continue; + } + + int ring = static_cast(std::floor(range / ring_step)); + if (ring < 0 || ring >= config.n_rings) { + continue; + } + + double azimuth = std::atan2(y, x); + if (azimuth < 0.0) { + azimuth += 2.0 * M_PI; + } + int sector = static_cast(std::floor(azimuth / sector_step)); + if (sector < 0) sector = 0; + if (sector >= config.n_sectors) sector = config.n_sectors - 1; + + const float shifted_z = static_cast(z) + height_offset; + // Clip to >= 0 — points slightly below the sensor frame (rare in + // properly-mounted lidars) shouldn't pull the cell negative. + const float cell_value = shifted_z > 0.0f ? 
shifted_z : 0.0f; + float& cell = descriptor(ring, sector); + if (cell_value > cell) { + cell = cell_value; + } + } + return descriptor; +} + +RingKey make_ring_key(const Descriptor& descriptor) { + RingKey key = RingKey::Zero(descriptor.rows()); + if (descriptor.cols() == 0) return key; + for (int i = 0; i < descriptor.rows(); i++) { + key(i) = descriptor.row(i).mean(); + } + return key; +} + +SectorKey make_sector_key(const Descriptor& descriptor) { + SectorKey key = SectorKey::Zero(descriptor.cols()); + if (descriptor.rows() == 0) return key; + for (int j = 0; j < descriptor.cols(); j++) { + key(j) = descriptor.col(j).mean(); + } + return key; +} + +float column_cosine_distance(const Descriptor& query, + const Descriptor& candidate, + int shift) { + if (query.rows() != candidate.rows() || query.cols() != candidate.cols()) { + return 2.0f; + } + const int cols = static_cast(query.cols()); + if (cols == 0) return 2.0f; + + float total = 0.0f; + int valid_cols = 0; + for (int j = 0; j < cols; j++) { + const int shifted_j = ((j + shift) % cols + cols) % cols; + const auto query_column = query.col(j); + const auto candidate_column = candidate.col(shifted_j); + const float query_norm = query_column.norm(); + const float candidate_norm = candidate_column.norm(); + if (query_norm <= 1e-6f || candidate_norm <= 1e-6f) { + continue; + } + const float cos_sim = query_column.dot(candidate_column) / + (query_norm * candidate_norm); + total += (1.0f - cos_sim); + valid_cols++; + } + if (valid_cols == 0) return 2.0f; + return total / static_cast(valid_cols); +} + +std::pair best_distance(const Descriptor& query, + const Descriptor& candidate) { + const int cols = static_cast(query.cols()); + float min_distance = 2.0f; + int best_shift = 0; + for (int shift = 0; shift < cols; shift++) { + const float distance = column_cosine_distance(query, candidate, shift); + if (distance < min_distance) { + min_distance = distance; + best_shift = shift; + } + } + return {min_distance, 
best_shift}; +} + +} // namespace scan_context diff --git a/dimos/navigation/nav_stack/modules/pgo/cpp/scan_context.h b/dimos/navigation/nav_stack/modules/pgo/cpp/scan_context.h new file mode 100644 index 0000000000..803626ba8b --- /dev/null +++ b/dimos/navigation/nav_stack/modules/pgo/cpp/scan_context.h @@ -0,0 +1,74 @@ +// Scan Context — polar-binned lidar place-recognition descriptor. +// +// Each scan becomes an (N_rings × N_sectors) matrix where cell [i, j] +// holds the max z value among points falling in the (range, azimuth) +// bin. The "ring key" — the per-row mean — is the coarse feature used +// for fast kd-tree retrieval; the full matrix is then column-shifted +// against the candidate to measure rotation-invariant cosine distance. +// +// Inspired by Kim & Kim 2018 "Scan Context: Egocentric Spatial +// Descriptor for Place Recognition within 3D Point Cloud Map" and the +// reference implementation at github.com/irapkaist/scancontext (MIT). +// Reimplemented locally to keep the PGO module self-contained and to +// avoid the OpenCV/external-yaml deps the upstream version carries. + +#pragma once + +#include "commons.h" + +#include + +#include + +namespace scan_context { + +struct Config { + int n_rings = 20; // radial bins + int n_sectors = 60; // azimuth bins + double max_range_m = 80.0; // ignore points beyond this + int candidate_top_k = 10; // kd-tree neighbours to score + double match_threshold = 0.4; // accepted cosine distance (0..2) + // Shifts body-frame z so all cells are positive before cosine distance, + // matching irapkaist/scancontext's LIDAR_HEIGHT convention. Ground points + // sit near -lidar_height_m in the body frame; without this shift, negative + // cells make cosine similarity meaningless for revisits. 
+ double lidar_height_m = 2.0; +}; + +using Descriptor = Eigen::MatrixXf; // (n_rings × n_sectors) +using RingKey = Eigen::VectorXf; // length n_rings +using SectorKey = Eigen::VectorXf; // length n_sectors + +// Build the polar-max-z descriptor for a body-frame scan. Points +// outside ``max_range_m`` or with negative ring index are ignored. +Descriptor make_descriptor(const CloudType& cloud, const Config& config); + +// Mean per row — the coarse feature used for kd-tree retrieval. +RingKey make_ring_key(const Descriptor& descriptor); + +// Mean per column — only used for the optional sector-key alignment. +SectorKey make_sector_key(const Descriptor& descriptor); + +// Cosine distance between two descriptors after column-shifting +// ``candidate`` by ``shift`` columns. 0 = identical, 2 = opposite. +float column_cosine_distance(const Descriptor& query, + const Descriptor& candidate, + int shift); + +// Best (min-distance, best-shift) pair across all column shifts. +// Returns {distance, shift_columns}. To recover yaw rotation from the +// shift: yaw_rad = -2*M_PI * shift / n_sectors. +std::pair best_distance(const Descriptor& query, + const Descriptor& candidate); + +// Convert sector shift to yaw rotation (radians). +// shift comes from best_distance, which scans [0, n_sectors-1], so +// the raw yaw lies in (-2pi, 0]; wrap into [-pi, pi]. 
// Convert a Scan Context sector shift into a yaw rotation in radians.
//
// yaw = -2*pi * shift / n_sectors, wrapped into [-pi, pi).
//
// `shift` is first reduced modulo `n_sectors`, so any integer (negative,
// or >= n_sectors) maps to the same yaw as its in-range equivalent; for
// shifts already in [0, n_sectors - 1] the result is unchanged from the
// plain formula. A non-positive `n_sectors` has no meaningful angle and
// returns 0 rather than dividing by zero.
inline double yaw_from_shift(int shift, int n_sectors) {
    if (n_sectors <= 0) {
        return 0.0;
    }
    // C++ `%` keeps the sign of the dividend; fold negatives back up.
    int wrapped_shift = shift % n_sectors;
    if (wrapped_shift < 0) {
        wrapped_shift += n_sectors;
    }
    // Raw yaw now lies in (-2*pi, 0]; bring the lower half up by one turn.
    double yaw = -2.0 * M_PI * static_cast<double>(wrapped_shift) /
                 static_cast<double>(n_sectors);
    if (yaw < -M_PI) yaw += 2.0 * M_PI;
    return yaw;
}
+ if (cloud_with_pose.cloud) { + scan_context::Descriptor descriptor = + scan_context::make_descriptor(*cloud_with_pose.cloud, m_scan_context_config); + m_scan_context_ring_keys.push_back(scan_context::make_ring_key(descriptor)); + m_scan_context_descriptors.push_back(std::move(descriptor)); + } else { + m_scan_context_descriptors.emplace_back(); + m_scan_context_ring_keys.emplace_back(); + } + return true; } @@ -90,21 +112,8 @@ CloudType::Ptr SimplePGO::getSubMap(int idx, int half_range, double resolution) return ret; } -void SimplePGO::searchForLoopPairs() +int SimplePGO::searchByPosition() const { - if (m_key_poses.size() < 10) - return; - if (m_config.min_loop_detect_duration > 0.0) - { - if (m_history_pairs.size() > 0) - { - double current_time = m_key_poses.back().time; - double last_time = m_key_poses[m_history_pairs.back().second].time; - if (current_time - last_time < m_config.min_loop_detect_duration) - return; - } - } - size_t cur_idx = m_key_poses.size() - 1; const KeyPoseWithCloud &last_item = m_key_poses.back(); pcl::PointXYZ last_pose_pt; @@ -113,7 +122,7 @@ void SimplePGO::searchForLoopPairs() last_pose_pt.z = last_item.t_global(2); pcl::PointCloud::Ptr key_poses_cloud(new pcl::PointCloud); - for (size_t i = 0; i < m_key_poses.size() - 1; i++) + for (size_t i = 0; i < cur_idx; i++) { pcl::PointXYZ pt; pt.x = m_key_poses[i].t_global(0); @@ -127,31 +136,136 @@ void SimplePGO::searchForLoopPairs() std::vector sqdists; int neighbors = kdtree.radiusSearch(last_pose_pt, m_config.loop_search_radius, ids, sqdists); if (neighbors == 0) - return; + return -1; - int loop_idx = -1; for (size_t i = 0; i < ids.size(); i++) { int idx = ids[i]; - if (std::abs(last_item.time - m_key_poses[idx].time) > m_config.loop_time_tresh) + if (std::abs(last_item.time - m_key_poses[idx].time) > m_config.loop_time_thresh) + { + return idx; + } + } + return -1; +} + +int SimplePGO::searchByScanContext(int& out_sector_shift) const +{ + out_sector_shift = 0; + if 
(m_scan_context_descriptors.empty() || m_scan_context_descriptors.back().size() == 0) { + return -1; + } + const auto& query = m_scan_context_descriptors.back(); + const auto& query_key = m_scan_context_ring_keys.back(); + const double current_time = m_key_poses.back().time; + + // Two-stage retrieval: first rank candidates by ring-key L2 distance + // (fast coarse filter), then score the top-K via column-shifted cosine + // distance on the full descriptor. + std::vector> ranked; // (ring-key dist, idx) + ranked.reserve(m_scan_context_descriptors.size()); + const size_t cur_idx = m_key_poses.size() - 1; + for (size_t i = 0; i < cur_idx; i++) { + if (m_scan_context_descriptors[i].size() == 0) continue; + if (std::abs(current_time - m_key_poses[i].time) <= m_config.loop_time_thresh) { + continue; // too recent — not a true loop candidate + } + const float key_dist = (m_scan_context_ring_keys[i] - query_key).norm(); + ranked.emplace_back(key_dist, static_cast(i)); + } + if (ranked.empty()) return -1; + + const int top_k_count = std::min( + static_cast(ranked.size()), m_scan_context_config.candidate_top_k); + std::partial_sort( + ranked.begin(), ranked.begin() + top_k_count, ranked.end(), + [](const std::pair& a, const std::pair& b) { + return a.first < b.first; + }); + + float best_dist = std::numeric_limits::max(); + int best_idx_unfiltered = -1; + float best_dist_filtered = static_cast(m_scan_context_config.match_threshold); + int best_idx = -1; + int best_shift = 0; + for (int rank = 0; rank < top_k_count; rank++) { + const int idx = ranked[rank].second; + const auto [distance, shift] = scan_context::best_distance( + query, m_scan_context_descriptors[idx]); + if (distance < best_dist) { + best_dist = distance; + best_idx_unfiltered = idx; + } + if (distance < best_dist_filtered) { + best_dist_filtered = distance; + best_idx = idx; + best_shift = shift; + } + } + + out_sector_shift = best_shift; + return best_idx; +} + +void SimplePGO::searchForLoopPairs() +{ + if 
(m_key_poses.size() < 10) + return; + if (m_config.min_loop_detect_duration > 0.0) + { + if (m_history_pairs.size() > 0) { - loop_idx = idx; - break; + double current_time = m_key_poses.back().time; + double last_time = m_key_poses[m_history_pairs.back().second].time; + if (current_time - last_time < m_config.min_loop_detect_duration) + return; } } - if (loop_idx == -1) + size_t cur_idx = m_key_poses.size() - 1; + + int loop_idx = -1; + int sector_shift = 0; + if (m_config.use_scan_context) { + loop_idx = searchByScanContext(sector_shift); + } + if (loop_idx < 0) { + // Fallback (or sole path if SC disabled): kdtree on past positions. + loop_idx = searchByPosition(); + } + + if (loop_idx < 0) return; + // Use Scan Context's column shift to seed ICP with a yaw-aligned initial + // guess, which dramatically improves convergence on revisits at + // different headings. Both submaps are in *global* frame, so a naive + // rotation about the world origin would translate the source cloud + // kilometers off (e.g. at world position (3500, 350), rotating by 90° + // sends it to (-350, 3500)). 
Build a transform that rotates about the + // source keyframe's own global position instead: + // init = T(source_position) · Rz(yaw) · T(-source_position) + // → init · p = rotation · (p - source_position) + source_position + Eigen::Matrix4f init_guess = Eigen::Matrix4f::Identity(); + if (m_config.use_scan_context && sector_shift != 0) { + const double yaw = scan_context::yaw_from_shift(sector_shift, m_scan_context_config.n_sectors); + Eigen::AngleAxisf yaw_axis_angle( + static_cast(yaw), Eigen::Vector3f::UnitZ()); + Eigen::Matrix3f rotation = yaw_axis_angle.toRotationMatrix(); + Eigen::Vector3f source_position = m_key_poses[cur_idx].t_global.cast(); + init_guess.block<3, 3>(0, 0) = rotation; + init_guess.block<3, 1>(0, 3) = source_position - rotation * source_position; + } + CloudType::Ptr target_cloud = getSubMap(loop_idx, m_config.loop_submap_half_range, m_config.submap_resolution); - CloudType::Ptr source_cloud = getSubMap(m_key_poses.size() - 1, 0, m_config.submap_resolution); + CloudType::Ptr source_cloud = getSubMap(cur_idx, 0, m_config.submap_resolution); CloudType::Ptr align_cloud(new CloudType); m_icp.setInputSource(source_cloud); m_icp.setInputTarget(target_cloud); - m_icp.align(*align_cloud); + m_icp.align(*align_cloud, init_guess); - if (!m_icp.hasConverged() || m_icp.getFitnessScore() > m_config.loop_score_tresh) + if (!m_icp.hasConverged() || m_icp.getFitnessScore() > m_config.loop_score_thresh) return; M4F loop_transform = m_icp.getFinalTransformation(); diff --git a/dimos/navigation/nav_stack/modules/pgo/cpp/simple_pgo.h b/dimos/navigation/nav_stack/modules/pgo/cpp/simple_pgo.h index 7f80c5b09a..7c9f08bf94 100644 --- a/dimos/navigation/nav_stack/modules/pgo/cpp/simple_pgo.h +++ b/dimos/navigation/nav_stack/modules/pgo/cpp/simple_pgo.h @@ -1,5 +1,6 @@ #pragma once #include "commons.h" +#include "scan_context.h" #include #include #include @@ -35,11 +36,20 @@ struct Config double key_pose_delta_deg = 10; double key_pose_delta_trans = 1.0; double 
loop_search_radius = 1.0; - double loop_time_tresh = 60.0; - double loop_score_tresh = 0.15; + double loop_time_thresh = 60.0; + double loop_score_thresh = 0.15; int loop_submap_half_range = 5; double submap_resolution = 0.1; double min_loop_detect_duration = 10.0; + + // Scan Context settings + bool use_scan_context = true; + int scan_context_num_rings = 20; + int scan_context_num_sectors = 60; + double scan_context_max_range_m = 80.0; + int scan_context_top_k = 10; + double scan_context_match_threshold = 0.4; + double scan_context_lidar_height_m = 2.0; }; class SimplePGO @@ -64,11 +74,24 @@ class SimplePGO M3D offsetR() { return m_r_offset; } V3D offsetT() { return m_t_offset; } + // Place recognition exposed for diagnostics / persistence. + const std::vector& descriptors() const { return m_scan_context_descriptors; } + const std::vector& ringKeys() const { return m_scan_context_ring_keys; } + private: + // Scan-context-based candidate search; returns -1 if no acceptable match. + int searchByScanContext(int& out_sector_shift) const; + // Original position-based fallback (radius search on past key-pose + // positions). Kept for ablation + when scan context is disabled. 
+ int searchByPosition() const; + Config m_config; + scan_context::Config m_scan_context_config; std::vector m_key_poses; std::vector> m_history_pairs; std::vector m_cache_pairs; + std::vector m_scan_context_descriptors; + std::vector m_scan_context_ring_keys; M3D m_r_offset; V3D m_t_offset; std::shared_ptr m_isam2; diff --git a/dimos/navigation/nav_stack/modules/pgo/pgo.py b/dimos/navigation/nav_stack/modules/pgo/pgo.py index becf3f5b7a..79d05d3488 100644 --- a/dimos/navigation/nav_stack/modules/pgo/pgo.py +++ b/dimos/navigation/nav_stack/modules/pgo/pgo.py @@ -20,19 +20,15 @@ from __future__ import annotations from pathlib import Path -import time - -from reactivex.disposable import Disposable from dimos.core.core import rpc from dimos.core.native_module import NativeModule, NativeModuleConfig from dimos.core.stream import In, Out -from dimos.msgs.geometry_msgs.Quaternion import Quaternion -from dimos.msgs.geometry_msgs.Transform import Transform -from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.nav_msgs.Graph3D import Graph3D +from dimos.msgs.nav_msgs.GraphDelta3D import GraphDelta3D from dimos.msgs.nav_msgs.Odometry import Odometry from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 -from dimos.navigation.nav_stack.frames import FRAME_MAP, FRAME_ODOM +from dimos.navigation.nav_stack.specs import LoopClosure from dimos.utils.logging_config import setup_logger logger = setup_logger() @@ -43,9 +39,17 @@ class PGOConfig(NativeModuleConfig): executable: str = "result/bin/pgo" build_command: str | None = "nix build .#default --no-write-lock-file" - # Frame names - world_frame: str = FRAME_MAP - local_frame: str = FRAME_ODOM + frame_id: str = "map" + child_frame_id: str = "start_point" + parent_frame: str = "world" + body_frame: str = "current_point" + tf_channel: str = "/tf#tf2_msgs.TFMessage" + + # The C++ binary's CLI args use the legacy frame names. 
+ cli_name_override: dict[str, str] = { + "frame_id": "world_frame", + "child_frame_id": "local_frame", + } # Keyframe detection key_pose_delta_deg: float = 10.0 @@ -66,10 +70,19 @@ class PGOConfig(NativeModuleConfig): global_map_voxel_size: float = 0.1 global_map_publish_rate: float = 1.0 + # Scan Context place recognition (used by loop closure search) + use_scan_context: bool = True + scan_context_num_rings: int = 20 + scan_context_num_sectors: int = 60 + scan_context_max_range_m: float = 80.0 + scan_context_top_k: int = 10 + scan_context_match_threshold: float = 0.4 + scan_context_lidar_height_m: float = 2.0 + debug: bool = False -class PGO(NativeModule): +class PGO(NativeModule, LoopClosure): """Pose graph optimization with loop closure using GTSAM iSAM2 + PCL ICP.""" config: PGOConfig @@ -78,55 +91,15 @@ class PGO(NativeModule): odometry: In[Odometry] corrected_odometry: Out[Odometry] global_map: Out[PointCloud2] - pgo_tf: Out[Odometry] + pose_graph: Out[Graph3D] + loop_closure_event: Out[GraphDelta3D] @rpc def start(self) -> None: super().start() - self.register_disposable( - Disposable(self.pgo_tf.transport.subscribe(self._on_tf_correction, self.pgo_tf)) - ) - # Seed identity TF so consumers can query map->body immediately. 
- self._publish_tf( - translation=(0.0, 0.0, 0.0), - rotation=(0.0, 0.0, 0.0, 1.0), - ts=time.time(), - ) if self.config.debug: logger.info("PGO native module started (C++ iSAM2 + PCL ICP)") @rpc def stop(self) -> None: super().stop() - - def _on_tf_correction(self, msg: Odometry) -> None: - self._publish_tf( - translation=( - msg.pose.position.x, - msg.pose.position.y, - msg.pose.position.z, - ), - rotation=( - msg.pose.orientation.x, - msg.pose.orientation.y, - msg.pose.orientation.z, - msg.pose.orientation.w, - ), - ts=msg.ts or time.time(), - ) - - def _publish_tf( - self, - translation: tuple[float, float, float], - rotation: tuple[float, float, float, float], - ts: float, - ) -> None: - self.tf.publish( - Transform( - frame_id=self.config.world_frame, - child_frame_id=self.config.local_frame, - translation=Vector3(*translation), - rotation=Quaternion(*rotation), - ts=ts, - ) - ) diff --git a/dimos/navigation/nav_stack/modules/pgo/test_pgo_loop_closure.py b/dimos/navigation/nav_stack/modules/pgo/test_pgo_loop_closure.py new file mode 100644 index 0000000000..299c40a743 --- /dev/null +++ b/dimos/navigation/nav_stack/modules/pgo/test_pgo_loop_closure.py @@ -0,0 +1,228 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""End-to-end check that PGO publishes valid loop_closure_event messages. 
+ +Replays the ``og_nav_60s`` rosbag through PGO with aggressive +loop-closure thresholds and asserts each emitted ``loop_closure_event`` +(a ``GraphDelta3D``) has positive-shape per-node SE(3) deltas with +unit-norm quaternions and finite translations. Wired with the DimOS +Module + Blueprint pipeline so no LCM topic strings live here. + +If the bag doesn't trigger any loop the test skips — the rosbag +trajectory is data-dependent, not a code defect. +""" + +from __future__ import annotations + +import math +import time +from typing import Any + +import pytest +from reactivex.disposable import Disposable + +from dimos.core.coordination.blueprints import autoconnect +from dimos.core.coordination.module_coordinator import ModuleCoordinator +from dimos.core.core import rpc +from dimos.core.module import Module +from dimos.core.stream import In +from dimos.msgs.nav_msgs.GraphDelta3D import GraphDelta3D +from dimos.navigation.nav_stack.modules.pgo.pgo import PGO +from dimos.navigation.nav_stack.tests.rosbag_fixtures import ( + RosbagScanOdomPlaybackModule, +) +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + +pytestmark = [pytest.mark.self_hosted, pytest.mark.skipif_no_nix, pytest.mark.skipif_macos_bug] + +POST_FEED_DRAIN_SEC = 5.0 +POLL_INTERVAL_SEC = 0.25 + +QUATERNION_UNIT_TOL = 0.05 +TRANSLATION_MAX_M = 100.0 + + +class LoopClosureEventRecorderModule(Module): + """Accumulates every loop_closure_event so the test can validate the shape.""" + + loop_closure_event: In[GraphDelta3D] + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self._events: list[dict[str, Any]] = [] + + @rpc + def start(self) -> None: + super().start() + self.register_disposable( + Disposable(self.loop_closure_event.subscribe(self._on_loop_closure_event)) + ) + + def _on_loop_closure_event(self, message: GraphDelta3D) -> None: + # JSON-friendly snapshot — Pydantic-friendly RPC return. 
+ self._events.append( + { + "ts": message.ts, + "transforms": [ + { + "translation": ( + transform.translation.x, + transform.translation.y, + transform.translation.z, + ), + "rotation": ( + transform.rotation.x, + transform.rotation.y, + transform.rotation.z, + transform.rotation.w, + ), + } + for transform in message.transforms + ], + } + ) + first_transform = message.transforms[0] if message.transforms else None + first_summary = ( + f"first=t=({first_transform.translation.x:.3f},{first_transform.translation.y:.3f}," + f"{first_transform.translation.z:.3f}) " + f"q=({first_transform.rotation.x:.3f},{first_transform.rotation.y:.3f}," + f"{first_transform.rotation.z:.3f},{first_transform.rotation.w:.3f})" + if first_transform + else "" + ) + logger.info( + f"[loop_closure_event] event #{len(self._events) - 1} received: " + f"node_count={len(message.nodes)}, ts={message.ts:.3f}, {first_summary}" + ) + + @rpc + def events(self) -> list[dict[str, Any]]: + return list(self._events) + + +def _validate_loop_closure_event(event: dict[str, Any], event_index: int) -> tuple[float, float]: + """Assert each transform has unit-norm rotation + finite translation. + + Returns aggregate ``(max_translation_norm, max_quaternion_drift)`` stats. 
+ """ + transforms = event["transforms"] + assert len(transforms) > 0, f"event {event_index}: loop-closure event has no transforms" + + max_translation_norm = 0.0 + max_quaternion_drift = 0.0 + for transform_index, transform in enumerate(transforms): + translation_x, translation_y, translation_z = transform["translation"] + rotation_x, rotation_y, rotation_z, rotation_w = transform["rotation"] + for value, name in [ + (translation_x, "translation_x"), + (translation_y, "translation_y"), + (translation_z, "translation_z"), + ]: + assert math.isfinite(value), ( + f"event {event_index} transform {transform_index}: {name}={value} not finite" + ) + for value, name in [ + (rotation_x, "rotation_x"), + (rotation_y, "rotation_y"), + (rotation_z, "rotation_z"), + (rotation_w, "rotation_w"), + ]: + assert math.isfinite(value), ( + f"event {event_index} transform {transform_index}: {name}={value} not finite" + ) + translation_norm = math.sqrt( + translation_x * translation_x + + translation_y * translation_y + + translation_z * translation_z + ) + assert translation_norm < TRANSLATION_MAX_M, ( + f"event {event_index} transform {transform_index}: " + f"|t|={translation_norm:.3f}m exceeds sanity cap {TRANSLATION_MAX_M}m" + ) + quaternion_norm = math.sqrt( + rotation_x * rotation_x + + rotation_y * rotation_y + + rotation_z * rotation_z + + rotation_w * rotation_w + ) + quaternion_drift = abs(quaternion_norm - 1.0) + assert quaternion_drift < QUATERNION_UNIT_TOL, ( + f"event {event_index} transform {transform_index}: " + f"|q|={quaternion_norm:.6f} drifts from unit by {quaternion_drift:.6f} " + f"(tol {QUATERNION_UNIT_TOL})" + ) + max_translation_norm = max(max_translation_norm, translation_norm) + max_quaternion_drift = max(max_quaternion_drift, quaternion_drift) + + return max_translation_norm, max_quaternion_drift + + +class TestPGOLoopClosure: + """End-to-end: PGO publishes loop_closure_event with valid SE(3) deltas.""" + + def test_loop_closure_events_published(self) -> 
None: + playback_blueprint = RosbagScanOdomPlaybackModule.blueprint() + # Aggressive loop-closure thresholds — bag is 60s, so we need short + # re-visit windows to actually fire events. + pgo_blueprint = PGO.blueprint( + key_pose_delta_trans=0.5, + loop_search_radius=2.0, + loop_time_thresh=5.0, + loop_score_thresh=0.5, + loop_submap_half_range=5, + submap_resolution=0.1, + min_loop_detect_duration=1.0, + global_map_voxel_size=0.1, + global_map_publish_rate=1.0, + unregister_input=True, + ) + recorder_blueprint = LoopClosureEventRecorderModule.blueprint() + + blueprint = autoconnect(playback_blueprint, pgo_blueprint, recorder_blueprint) + coordinator = ModuleCoordinator.build(blueprint) + try: + playback = coordinator.get_instance(RosbagScanOdomPlaybackModule) + recorder = coordinator.get_instance(LoopClosureEventRecorderModule) + while not playback.is_finished(): + time.sleep(POLL_INTERVAL_SEC) + time.sleep(POST_FEED_DRAIN_SEC) + events = recorder.events() + finally: + coordinator.stop() + + logger.info(f"\n[loop_closure_event] total events received: {len(events)}") + + if not events: + pytest.skip( + "rosbag trajectory didn't trigger any PGO loop closures " + "even with aggressive thresholds — this validates only the " + "publishing path's existence (verified via native log " + "lines), not the on-wire payload." 
+ ) + + for event_index, event in enumerate(events): + max_translation_norm, max_quaternion_drift = _validate_loop_closure_event( + event, event_index + ) + logger.info( + f"[loop_closure_event] event #{event_index} VALID: " + f"transform_count={len(event['transforms'])}, " + f"max|t|={max_translation_norm:.4f}m, " + f"max|q|-1|={max_quaternion_drift:.6f}" + ) + + assert all(len(event["transforms"]) > 0 for event in events) diff --git a/dimos/navigation/nav_stack/modules/pgo/test_pgo_rosbag.py b/dimos/navigation/nav_stack/modules/pgo/test_pgo_rosbag.py index b4f04bdcb2..5baec4d2c2 100644 --- a/dimos/navigation/nav_stack/modules/pgo/test_pgo_rosbag.py +++ b/dimos/navigation/nav_stack/modules/pgo/test_pgo_rosbag.py @@ -16,41 +16,69 @@ from __future__ import annotations -from pathlib import Path -import threading import time +from typing import Any -import lcm as lcmlib import numpy as np import pytest +from reactivex.disposable import Disposable -from dimos.constants import DEFAULT_THREAD_JOIN_TIMEOUT +from dimos.core.coordination.blueprints import autoconnect +from dimos.core.coordination.module_coordinator import ModuleCoordinator +from dimos.core.core import rpc +from dimos.core.module import Module +from dimos.core.stream import In from dimos.msgs.nav_msgs.Odometry import Odometry from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 +from dimos.navigation.nav_stack.modules.pgo.pgo import PGO from dimos.navigation.nav_stack.tests.rosbag_fixtures import ( - LcmCollector, - NativeProcessRunner, - feed_at_original_timing, - lcm_handle_loop, + RosbagScanOdomPlaybackModule, load_rosbag_window, ) from dimos.utils.logging_config import setup_logger logger = setup_logger() -pytestmark = [pytest.mark.self_hosted] +pytestmark = [pytest.mark.self_hosted, pytest.mark.skipif_no_nix, pytest.mark.skipif_macos_bug] -_PROCESS_STARTUP_SEC = 2.0 -_POST_FEED_DRAIN_SEC = 3.0 +POST_FEED_DRAIN_SEC = 3.0 +POLL_INTERVAL_SEC = 0.25 -PGO_BIN = Path(__file__).parent / "cpp" / "result" 
class PgoOutputCollectorModule(Module):
    """Test-side sink that records what arrives on PGO's output streams.

    For every ``corrected_odometry`` message the pose position is stored as
    an ``[x, y, z]`` list; for every ``global_map`` message whose
    ``as_numpy()`` yields points, the number of points is stored. Both
    histories are exposed through RPC getters that return copies.
    """

    corrected_odometry: In[Odometry]
    global_map: In[PointCloud2]

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self._corrected_positions: list[list[float]] = []
        self._global_map_point_counts: list[int] = []

    @rpc
    def start(self) -> None:
        """Start the module and subscribe to both input streams."""
        super().start()
        # Subscribe and register one stream at a time, odometry first.
        stream_handlers = (
            (self.corrected_odometry, self._on_corrected_odometry),
            (self.global_map, self._on_global_map),
        )
        for stream, handler in stream_handlers:
            self.register_disposable(Disposable(stream.subscribe(handler)))

    def _on_corrected_odometry(self, message: Odometry) -> None:
        """Record the position of one corrected-odometry message."""
        xyz = [
            message.pose.position.x,
            message.pose.position.y,
            message.pose.position.z,
        ]
        self._corrected_positions.append(xyz)

    def _on_global_map(self, message: PointCloud2) -> None:
        """Record the point count of one global-map message, if it has points."""
        points, _ = message.as_numpy()
        if points is None:
            return
        self._global_map_point_counts.append(len(points))

    @rpc
    def corrected_positions(self) -> list[list[float]]:
        """Return a copy of the recorded ``[x, y, z]`` positions."""
        return [*self._corrected_positions]

    @rpc
    def global_map_point_counts(self) -> list[int]:
        """Return a copy of the recorded global-map point counts."""
        return [*self._global_map_point_counts]
corrected_odom_collector = LcmCollector(topic=CORRECTED_ODOM_LCM, msg_type=Odometry) - global_map_collector = LcmCollector(topic=GLOBAL_MAP_LCM, msg_type=PointCloud2) - tf_collector = LcmCollector(topic=TF_LCM, msg_type=Odometry) - - corrected_odom_collector.start(lcm_instance) - global_map_collector.start(lcm_instance) - tf_collector.start(lcm_instance) - - stop_event = threading.Event() - handle_thread = threading.Thread( - target=lcm_handle_loop, args=(lcm_instance, stop_event), daemon=True - ) - handle_thread.start() - - runner = NativeProcessRunner( - binary_path=str(PGO_BIN), - args=[ - "--registered_scan", - SCAN_LCM, - "--odometry", - ODOM_LCM, - "--corrected_odometry", - CORRECTED_ODOM_LCM, - "--global_map", - GLOBAL_MAP_LCM, - "--pgo_tf", - TF_LCM, - # Config params matching pgo_unity_sim.yaml - "--key_pose_delta_deg", - "10.0", - "--key_pose_delta_trans", - "0.5", - "--loop_search_radius", - "1.0", - "--loop_time_thresh", - "60.0", - "--loop_score_thresh", - "0.15", - "--loop_submap_half_range", - "5", - "--submap_resolution", - "0.1", - "--min_loop_detect_duration", - "5.0", - "--global_map_voxel_size", - "0.1", - "--global_map_publish_rate", - "1.0", - "--unregister_input", - "true", - "--world_frame", - "map", - "--local_frame", - "odom", - ], + playback_blueprint = RosbagScanOdomPlaybackModule.blueprint() + # Config params matching pgo_unity_sim.yaml. 
+ pgo_blueprint = PGO.blueprint( + key_pose_delta_trans=0.5, + loop_search_radius=1.0, + loop_time_thresh=60.0, + loop_score_thresh=0.15, + loop_submap_half_range=5, + submap_resolution=0.1, + min_loop_detect_duration=5.0, + global_map_voxel_size=0.1, + global_map_publish_rate=1.0, + unregister_input=True, ) + collector_blueprint = PgoOutputCollectorModule.blueprint() + blueprint = autoconnect(playback_blueprint, pgo_blueprint, collector_blueprint) + coordinator = ModuleCoordinator.build(blueprint) try: - runner.start(capture_stderr=True) - assert runner.is_running, "PGO binary failed to start" - time.sleep(_PROCESS_STARTUP_SEC) - - feed_at_original_timing( - lcm_instance, - window, - topic_map={ - "odom": ODOM_LCM, - "scan": SCAN_LCM, - }, - ) - - time.sleep(_POST_FEED_DRAIN_SEC) - + playback = coordinator.get_instance(RosbagScanOdomPlaybackModule) + collector = coordinator.get_instance(PgoOutputCollectorModule) + while not playback.is_finished(): + time.sleep(POLL_INTERVAL_SEC) + time.sleep(POST_FEED_DRAIN_SEC) + corrected_positions = np.array(collector.corrected_positions()) + global_map_point_counts = collector.global_map_point_counts() finally: - runner.stop() - stop_event.set() - handle_thread.join(timeout=DEFAULT_THREAD_JOIN_TIMEOUT) - corrected_odom_collector.stop(lcm_instance) - global_map_collector.stop(lcm_instance) - tf_collector.stop(lcm_instance) - - # -- Analysis -- - corrected_count = len(corrected_odom_collector.messages) - global_map_count = len(global_map_collector.messages) - tf_count = len(tf_collector.messages) + coordinator.stop() + + corrected_count = len(corrected_positions) + global_map_count = len(global_map_point_counts) logger.info(f"\n{'=' * 60}") logger.info("PGO NATIVE ROSBAG DEVIATION SCORE") @@ -166,69 +134,36 @@ def test_pgo_corrected_odometry(self) -> None: logger.info(f" Input odom messages: {len(window.odom)}") logger.info(f" Corrected odom outputs: {corrected_count}") logger.info(f" Global map outputs: {global_map_count}") - 
logger.info(f" TF outputs: {tf_count}") - # Basic output checks assert corrected_count > 0, "PGO produced no corrected odometry" assert global_map_count > 0, "PGO produced no global map messages" - assert tf_count > 0, "PGO produced no TF messages" - - # Extract corrected trajectory - corrected_positions = np.array( - [ - [msg.pose.position.x, msg.pose.position.y, msg.pose.position.z] - for msg in corrected_odom_collector.messages - ] - ) - # Extract input trajectory (subsample to match) input_positions = window.odom[:, 1:4] # Corrected trajectory should be spatially close to input (no loop closures - # expected in 60s recording, so correction should be near-identity) + # expected in 60s recording, so correction should be near-identity). corrected_centroid = corrected_positions.mean(axis=0) input_centroid = input_positions.mean(axis=0) centroid_error = float(np.linalg.norm(corrected_centroid - input_centroid)) - # Check trajectory extent (PGO shouldn't collapse trajectory to a point) + # PGO shouldn't collapse the trajectory to a point or explode it. 
corrected_extent = corrected_positions.max(axis=0) - corrected_positions.min(axis=0) input_extent = input_positions.max(axis=0) - input_positions.min(axis=0) extent_ratio_xy = float( np.linalg.norm(corrected_extent[:2]) / max(np.linalg.norm(input_extent[:2]), 1e-6) ) - # Check global map point count - global_map_point_counts = [] - for msg in global_map_collector.messages: - points, _ = msg.as_numpy() - if points is not None: - global_map_point_counts.append(len(points)) - mean_map_points = ( float(np.mean(global_map_point_counts)) if global_map_point_counts else 0.0 ) last_map_points = global_map_point_counts[-1] if global_map_point_counts else 0 - # TF should be near-identity for a short recording without loop closures - last_tf = tf_collector.messages[-1] - tf_translation_norm = float( - np.linalg.norm( - [ - last_tf.pose.position.x, - last_tf.pose.position.y, - last_tf.pose.position.z, - ] - ) - ) - logger.info(f" Centroid error: {centroid_error:.3f} m") logger.info(f" Extent ratio (XY): {extent_ratio_xy:.3f}") logger.info(f" Mean global map points: {mean_map_points:.0f}") logger.info(f" Last global map points: {last_map_points}") - logger.info(f" Final TF translation: {tf_translation_norm:.4f} m") logger.info(f"{'=' * 60}\n") - # Assertions assert centroid_error < 5.0, ( f"Corrected trajectory centroid too far from input: {centroid_error:.3f} m" ) diff --git a/dimos/navigation/nav_stack/modules/pgo/test_pgo_synthetic_drift.py b/dimos/navigation/nav_stack/modules/pgo/test_pgo_synthetic_drift.py new file mode 100644 index 0000000000..328c79bf8c --- /dev/null +++ b/dimos/navigation/nav_stack/modules/pgo/test_pgo_synthetic_drift.py @@ -0,0 +1,449 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Demonstrate that Scan Context catches loop closures that the +position-based search would miss. + +Setup: build a synthetic point-cloud "room", drive a virtual robot +out-and-back along a corridor, and inject a linear drift into the +reported odometry. On the return leg the robot is *physically* back at +the start (so the body-frame scan is byte-identical to the first +scan), but the reported odom pose is offset by several metres. With +``loop_search_radius=1.0m`` the position-based search cannot match +the two visits; Scan Context, which works on the appearance of the +scan rather than the pose, can. + +This test runs PGO twice with the same input via the DimOS Module + +Blueprint pipeline (no direct LCM topic strings here): + +1. ``use_scan_context=true`` → expect ≥1 loop_closure_event message. +2. ``use_scan_context=false`` → expect 0 loop_closure_event messages. 
+""" + +from __future__ import annotations + +import asyncio +from collections.abc import AsyncIterator +import math +import time +from typing import Any + +import numpy as np +import pytest +from reactivex.disposable import Disposable + +from dimos.core.coordination.blueprints import autoconnect +from dimos.core.coordination.module_coordinator import ModuleCoordinator +from dimos.core.core import rpc +from dimos.core.module import Module, ModuleConfig +from dimos.core.stream import In, Out +from dimos.msgs.geometry_msgs.Pose import Pose +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.nav_msgs.GraphDelta3D import GraphDelta3D +from dimos.msgs.nav_msgs.Odometry import Odometry +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 +from dimos.navigation.nav_stack.modules.pgo.pgo import PGO +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + +pytestmark = [pytest.mark.self_hosted, pytest.mark.skipif_no_nix, pytest.mark.skipif_macos_bug] + +# Cross-trajectory drift injected at the revisit. Must be >> loop_search_radius +# so position-based search cannot accidentally find the loop. +DRIFT_AT_REVISIT_M = 5.0 + +# Loop closure thresholds passed to the binary. +LOOP_SEARCH_RADIUS_M = 1.0 +LOOP_TIME_THRESH_S = 5.0 +MIN_LOOP_DETECT_DURATION_S = 1.0 + +# Per-frame publish interval driving the synthetic playback module. +INTER_FRAME_SLEEP_SEC = 0.15 +# Drain after the playback module reports finished, so PGO can flush +# any pending loop closure events before the coordinator stops. +POST_FEED_DRAIN_SEC = 3.0 +# Poll period when waiting for the playback module to drain. 
+POLL_INTERVAL_SEC = 0.25 +# After the first scan goes out, wait this long for PGO to emit anything +PGO_FIRST_RESPONSE_TIMEOUT_SEC = 20.0 + + +def _make_room_points(half_size: float = 20.0, density: float = 0.15) -> np.ndarray: + """Sample points on the inside of a 4-wall square room.""" + points: list[np.ndarray] = [] + z_levels = np.arange(0.0, 3.0, density) + wall_axis = np.arange(-half_size, half_size, density) + + for wall_y in (half_size, -half_size): + grid_x, grid_z = np.meshgrid(wall_axis, z_levels) + block = np.column_stack([grid_x.ravel(), np.full(grid_x.size, wall_y), grid_z.ravel()]) + points.append(block) + for wall_x in (half_size, -half_size): + grid_y, grid_z = np.meshgrid(wall_axis, z_levels) + block = np.column_stack([np.full(grid_y.size, wall_x), grid_y.ravel(), grid_z.ravel()]) + points.append(block) + + # Distinctive interior columns so the scene isn't rotationally symmetric. + column_radius = 0.5 + for column_center_x, column_center_y in [(5.0, 0.0), (-5.0, 8.0)]: + angles = np.arange(0.0, 2.0 * math.pi, 0.2) + column_z_levels = np.arange(0.0, 3.0, density) + grid_angle, grid_z = np.meshgrid(angles, column_z_levels) + column_x = column_center_x + column_radius * np.cos(grid_angle.ravel()) + column_y = column_center_y + column_radius * np.sin(grid_angle.ravel()) + points.append(np.column_stack([column_x, column_y, grid_z.ravel()])) + + return np.concatenate(points).astype(np.float32) + + +def _make_pose(x: float, y: float, z: float, yaw: float) -> Pose: + pose = Pose() + pose.position = Vector3(x, y, z) + half_yaw = yaw * 0.5 + pose.orientation = Quaternion(0.0, 0.0, math.sin(half_yaw), math.cos(half_yaw)) + return pose + + +def _yaw_rotation(yaw: float) -> np.ndarray: + cos_yaw, sin_yaw = math.cos(yaw), math.sin(yaw) + return np.array( + [[cos_yaw, -sin_yaw, 0.0], [sin_yaw, cos_yaw, 0.0], [0.0, 0.0, 1.0]], + dtype=np.float64, + ) + + +def _world_to_body(points_world: np.ndarray, position: np.ndarray, yaw: float) -> np.ndarray: + rotation = 
_yaw_rotation(yaw).T + return (points_world - position) @ rotation.T + + +def _body_to_world(points_body: np.ndarray, position: np.ndarray, yaw: float) -> np.ndarray: + rotation = _yaw_rotation(yaw) + return points_body @ rotation.T + position + + +def _trajectory_with_drift( + num_outbound: int = 20, num_inbound: int = 20, leg_length: float = 8.0 +) -> list[tuple[float, np.ndarray, float, np.ndarray, float]]: + """``(t, true_position, true_yaw, drifted_position, drifted_yaw)`` waypoints + for an out-and-back trajectory that physically returns to the start. + + The drift is purely additive in (x, y) and ramps linearly with the total + travelled distance, so by the time the robot returns to (0, 0) the reported + odom pose is offset by ``DRIFT_AT_REVISIT_M``. + """ + samples: list[tuple[float, np.ndarray, float, np.ndarray, float]] = [] + # Start at timestamp=1.0 because Odometry(ts=0.0) is treated as "now" by + # the constructor — using 0.0 would inject wall-clock time and break the + # monotonic-ts assumption in PGO's on_registered_scan. 
+ timestamp = 1.0 + time_step = 0.5 + total_steps = num_outbound + num_inbound + for step in range(num_outbound + 1): + progress = step / max(num_outbound, 1) + x = progress * leg_length + true_position = np.array([x, 0.0, 0.5]) + yaw = 0.0 + drift_amount = (step / total_steps) * DRIFT_AT_REVISIT_M + drifted_position = true_position + np.array([0.0, drift_amount, 0.0]) + samples.append((timestamp, true_position, yaw, drifted_position, yaw)) + timestamp += time_step + for step in range(1, num_inbound + 1): + progress = step / max(num_inbound, 1) + x = leg_length * (1.0 - progress) + true_position = np.array([x, 0.0, 0.5]) + yaw = 0.0 # keep heading the same so descriptors are directly comparable + drift_amount = ((num_outbound + step) / total_steps) * DRIFT_AT_REVISIT_M + drifted_position = true_position + np.array([0.0, drift_amount, 0.0]) + samples.append((timestamp, true_position, yaw, drifted_position, yaw)) + timestamp += time_step + return samples + + +def _trajectory_reverse_loop( + num_outbound: int = 20, num_inbound: int = 20, leg_length: float = 8.0 +) -> list[tuple[float, np.ndarray, float, np.ndarray, float]]: + """Out-and-back where the robot turns 180° at the far end. + + Exercises ICP's yaw-around-source-keyframe init_guess fix in + ``simple_pgo.cpp::searchForLoopPairs``. 
+ """ + samples: list[tuple[float, np.ndarray, float, np.ndarray, float]] = [] + timestamp = 1.0 + time_step = 0.5 + for step in range(num_outbound + 1): + progress = step / max(num_outbound, 1) + x = progress * leg_length + position = np.array([x, 0.0, 0.5]) + yaw = 0.0 + samples.append((timestamp, position, yaw, position.copy(), yaw)) + timestamp += time_step + for step in range(1, num_inbound + 1): + progress = step / max(num_inbound, 1) + x = leg_length * (1.0 - progress) + position = np.array([x, 0.0, 0.5]) + yaw = math.pi + samples.append((timestamp, position, yaw, position.copy(), yaw)) + timestamp += time_step + return samples + + +def _trajectory_payload( + trajectory: list[tuple[float, np.ndarray, float, np.ndarray, float]], +) -> list[list[float]]: + """Flatten the trajectory into a JSON-serializable matrix for ModuleConfig. + + Each row is ``[timestamp, true_x, true_y, true_z, true_yaw, + drifted_x, drifted_y, drifted_z, drifted_yaw]``. + """ + rows: list[list[float]] = [] + for timestamp, true_position, true_yaw, drifted_position, drifted_yaw in trajectory: + rows.append( + [ + float(timestamp), + float(true_position[0]), + float(true_position[1]), + float(true_position[2]), + float(true_yaw), + float(drifted_position[0]), + float(drifted_position[1]), + float(drifted_position[2]), + float(drifted_yaw), + ] + ) + return rows + + +class SyntheticDriftPlaybackConfig(ModuleConfig): + trajectory: list[list[float]] + inter_frame_sleep_sec: float = INTER_FRAME_SLEEP_SEC + pgo_first_response_timeout_sec: float = PGO_FIRST_RESPONSE_TIMEOUT_SEC + room_half_size: float = 20.0 + room_density: float = 0.15 + + +class SyntheticDriftPlaybackModule(Module): + """Publishes synthetic scans + drifted odometry from a precomputed trajectory.""" + + config: SyntheticDriftPlaybackConfig + + registered_scan: Out[PointCloud2] + odometry: Out[Odometry] + # Subscribed only so we can detect when PGO has come up and processed the + # first scan — see _run_playback's "wait for PGO 
ack" gate. + corrected_odometry: In[Odometry] + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self._frames_published: int = 0 + self._playback_finished: bool = False + self._playback_error: str | None = None + self._pgo_first_response: asyncio.Event | None = None + + async def handle_corrected_odometry(self, value: Odometry) -> None: + if self._pgo_first_response is not None: + self._pgo_first_response.set() + + async def main(self) -> AsyncIterator[None]: + self._room_points = _make_room_points(self.config.room_half_size, self.config.room_density) + # Event lives on self._loop, the same loop _run_playback and + # handle_corrected_odometry run on. + self._pgo_first_response = asyncio.Event() + self._playback_task = asyncio.create_task(self._run_playback()) + yield + self._playback_task.cancel() + + async def _run_playback(self) -> None: + # finally guarantees is_finished() flips to True even if a + # publish raises. Without it, _run_pgo's poll loop hangs and + # the coordinator leaks. 
+ try: + assert self._pgo_first_response is not None + for frame_index, row in enumerate(self.config.trajectory): + ( + timestamp, + true_x, + true_y, + true_z, + true_yaw, + drifted_x, + drifted_y, + drifted_z, + drifted_yaw, + ) = row + true_position = np.array([true_x, true_y, true_z]) + drifted_position = np.array([drifted_x, drifted_y, drifted_z]) + body_points = _world_to_body(self._room_points, true_position, true_yaw) + world_points = _body_to_world(body_points, drifted_position, drifted_yaw) + scan_message = PointCloud2.from_numpy( + world_points.astype(np.float32), + frame_id="map", # FIXME: this should be derived from something + timestamp=timestamp, + ) + odometry_message = Odometry( + ts=timestamp, + frame_id="odom", # FIXME: this should be derived from something + child_frame_id="base_link", # FIXME: this should be derived from something + pose=_make_pose( + float(drifted_position[0]), + float(drifted_position[1]), + float(drifted_position[2]), + float(drifted_yaw), + ), + ) + self.odometry.publish(odometry_message) + self.registered_scan.publish(scan_message) + self._frames_published += 1 + if frame_index == 0: + # Wait for PGO to publish anything (corrected_odometry) + # before sending the rest of the trajectory, so we don't + # race PGO's startup. + try: + await asyncio.wait_for( + self._pgo_first_response.wait(), + timeout=self.config.pgo_first_response_timeout_sec, + ) + except asyncio.TimeoutError: + raise RuntimeError( + "PGO didn't start in time: no corrected_odometry " + f"received within {self.config.pgo_first_response_timeout_sec:.1f}s " + "of the first scan. Bump PGO_FIRST_RESPONSE_TIMEOUT_SEC " + "(top of test_pgo_synthetic_drift.py) if PGO needs longer to " + "start on this host." 
+ ) from None + if self.config.inter_frame_sleep_sec > 0: + await asyncio.sleep(self.config.inter_frame_sleep_sec) + except Exception as exc: + self._playback_error = f"{type(exc).__name__}: {exc}" + raise + finally: + self._playback_finished = True + + @rpc + def is_finished(self) -> bool: + return self._playback_finished + + @rpc + def frames_published(self) -> int: + return self._frames_published + + +class LoopClosureEventCounterModule(Module): + """Counts loop_closure_event messages from any pose-graph SLAM module.""" + + loop_closure_event: In[GraphDelta3D] + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self._count: int = 0 + + @rpc + def start(self) -> None: + super().start() + self.register_disposable( + Disposable(self.loop_closure_event.subscribe(self._on_loop_closure_event)) + ) + + def _on_loop_closure_event(self, message: GraphDelta3D) -> None: + self._count += 1 + logger.info( + f"[loop_closure_event_counter] event #{self._count - 1}: " + f"node_count={len(message.nodes)}, ts={message.ts:.3f}" + ) + + @rpc + def count(self) -> int: + return self._count + + +def _run_pgo( + use_scan_context: bool, + trajectory: list[tuple[float, np.ndarray, float, np.ndarray, float]] | None = None, +) -> int: + """Build the blueprint, run the synthetic trajectory through PGO, return loop count.""" + if trajectory is None: + trajectory = _trajectory_with_drift() + + playback_blueprint = SyntheticDriftPlaybackModule.blueprint( + trajectory=_trajectory_payload(trajectory), + ) + pgo_blueprint = PGO.blueprint( + debug=True, + use_scan_context=use_scan_context, + key_pose_delta_trans=0.4, + loop_search_radius=LOOP_SEARCH_RADIUS_M, + loop_time_thresh=LOOP_TIME_THRESH_S, + loop_score_thresh=1.0, + loop_submap_half_range=5, + submap_resolution=0.1, + min_loop_detect_duration=MIN_LOOP_DETECT_DURATION_S, + global_map_voxel_size=0.1, + global_map_publish_rate=1.0, + unregister_input=True, + scan_context_max_range_m=30.0, + 
scan_context_match_threshold=0.6, + ) + counter_blueprint = LoopClosureEventCounterModule.blueprint() + + blueprint = autoconnect(playback_blueprint, pgo_blueprint, counter_blueprint) + coordinator = ModuleCoordinator.build(blueprint) + try: + playback = coordinator.get_instance(SyntheticDriftPlaybackModule) + counter = coordinator.get_instance(LoopClosureEventCounterModule) + while not playback.is_finished(): + time.sleep(POLL_INTERVAL_SEC) + time.sleep(POST_FEED_DRAIN_SEC) + return counter.count() + finally: + coordinator.stop() + + +class TestPGOSyntheticDrift: + """Scan Context catches the loop; position search misses it.""" + + def test_scan_context_catches_drifted_loop(self) -> None: + scan_context_events = _run_pgo(use_scan_context=True) + logger.info(f"[synthetic_drift] scan_context=true → {scan_context_events} loop events") + assert scan_context_events >= 1, ( + f"Scan Context should catch the loop at the revisit point " + f"(drift={DRIFT_AT_REVISIT_M}m). Got {scan_context_events} events." + ) + + def test_position_search_misses_drifted_loop(self) -> None: + position_search_events = _run_pgo(use_scan_context=False) + logger.info(f"[synthetic_drift] scan_context=false → {position_search_events} loop events") + assert position_search_events == 0, ( + f"Position-based search shouldn't fire when drift " + f"({DRIFT_AT_REVISIT_M}m) >> loop_search_radius " + f"({LOOP_SEARCH_RADIUS_M}m). Got {position_search_events} events." + ) + + def test_scan_context_catches_reverse_loop(self) -> None: + """Robot drives 8m east facing east, turns 180°, drives back facing west. + + Regression test for the init_guess fix in + ``simple_pgo.cpp::searchForLoopPairs``: ICP must seed the yaw rotation + about the source keyframe (not the world origin) for the rotated source + cloud to stay co-located with the target. 
+ """ + events = _run_pgo(use_scan_context=True, trajectory=_trajectory_reverse_loop()) + logger.info(f"[reverse_loop] → {events} loop events") + assert events >= 1, ( + "Scan Context + ICP should catch the 180° reverse-heading loop. " + f"Got {events} events. This regresses the init_guess fix in " + "simple_pgo.cpp (rotation must be about the source keyframe, " + "not the world origin)." + ) diff --git a/dimos/navigation/nav_stack/modules/simple_planner/simple_planner.py b/dimos/navigation/nav_stack/modules/simple_planner/simple_planner.py index b6bb5966f9..906b2cce6a 100644 --- a/dimos/navigation/nav_stack/modules/simple_planner/simple_planner.py +++ b/dimos/navigation/nav_stack/modules/simple_planner/simple_planner.py @@ -34,9 +34,9 @@ from dimos.core.stream import In, Out from dimos.msgs.geometry_msgs.PointStamped import PointStamped from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.nav_msgs.DynamicCloud import DynamicCloud from dimos.msgs.nav_msgs.Path import Path from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 -from dimos.navigation.nav_stack.frames import FRAME_BODY, FRAME_MAP, FRAME_SENSOR from dimos.utils.logging_config import setup_logger logger = setup_logger() @@ -159,14 +159,6 @@ def progress_tick( return (state, False) -def resolve_tf_chain(tf_buffer: Any, queries: list[tuple[str, str]]) -> Any: - for parent, child in queries: - tf = tf_buffer.get(parent, child) - if tf is not None: - return tf - return None - - def plan_on_costmap( costmap: Costmap, rx: float, @@ -283,9 +275,8 @@ def heuristic(c: tuple[int, int]) -> float: class SimplePlannerConfig(ModuleConfig): - world_frame: str = FRAME_MAP - body_frame: str = FRAME_BODY - sensor_frame: str = FRAME_SENSOR + frame_id: str = "map" + body_frame: str = "current_point" cell_size: float = 0.3 # m per cell obstacle_height_threshold: float = 0.15 # m above ground @@ -322,7 +313,7 @@ class SimplePlanner(Module): config: SimplePlannerConfig - terrain_map_ext: 
In[PointCloud2] + global_map: In[DynamicCloud] terrain_map: In[PointCloud2] goal: In[PointStamped] stop_movement: In[Bool] @@ -375,9 +366,7 @@ def start(self) -> None: super().start() self.register_disposable(Disposable(self.goal.subscribe(self._on_goal))) self.register_disposable(Disposable(self.stop_movement.subscribe(self._on_stop_movement))) - self.register_disposable( - Disposable(self.terrain_map_ext.subscribe(self._on_terrain_map_ext)) - ) + self.register_disposable(Disposable(self.global_map.subscribe(self._on_global_map))) self.register_disposable(Disposable(self.terrain_map.subscribe(self._on_terrain_map))) self._running = True self._thread = threading.Thread(target=self._planning_loop, daemon=True) @@ -397,26 +386,9 @@ def stop(self) -> None: self._thread = None super().stop() - @property - def _tf_pose_queries(self) -> list[tuple[str, str]]: - """Ordered (parent, child) TF lookups for the robot pose. - The first successful lookup wins. ``sensor`` is used by the Unity sim bridge.""" - return [ - (self.config.world_frame, self.config.body_frame), - (self.config.world_frame, self.config.sensor_frame), - ] - def _query_pose(self) -> bool: - """Update cached robot position from the TF tree. - - Tries several ``(parent, child)`` pairs in priority order so the - planner works both on real hardware (``map → body`` via PGO + - FastLio2) and in simulation (``map → sensor`` from the Unity - bridge). - - Returns True if a pose was obtained from any chain. 
- """ - tf = resolve_tf_chain(self.tf, list(self._tf_pose_queries)) + """Update cached robot position from the TF tree.""" + tf = self.tf.get(self.config.frame_id, self.config.body_frame) if tf is None: now = time.monotonic() if now - self._last_tf_warn > _TF_WARN_THROTTLE: @@ -424,7 +396,7 @@ def _query_pose(self) -> bool: buffers = list(self.tf.buffers.keys()) if hasattr(self.tf, "buffers") else [] logger.warning( "TF lookup failed — no robot pose available", - tried=[(p, c) for p, c in self._tf_pose_queries], + tried=(self.config.frame_id, self.config.body_frame), available_frames=buffers, ) return False @@ -454,18 +426,18 @@ def _cancel_navigation(self, source: str) -> None: rx, ry, rz = self._robot_x, self._robot_y, self._robot_z now = time.time() self.way_point.publish( - PointStamped(ts=now, frame_id=self.config.world_frame, x=rx, y=ry, z=rz) + PointStamped(ts=now, frame_id=self.config.frame_id, x=rx, y=ry, z=rz) ) # Single-pose path at the robot — explicitly distinguishes "cancelled, # holding position" from "no goal_path message yet" in the viewer. self.goal_path.publish( Path( ts=now, - frame_id=self.config.world_frame, + frame_id=self.config.frame_id, poses=[ PoseStamped( ts=now, - frame_id=self.config.world_frame, + frame_id=self.config.frame_id, position=[rx, ry, rz], orientation=[0.0, 0.0, 0.0, 1.0], ) @@ -547,16 +519,17 @@ def _fresh_costmap(self) -> Costmap: inflation_radius=self.config.inflation_radius, ) - def _on_terrain_map_ext(self, msg: PointCloud2) -> None: - """Rebuild the costmap from scratch using the persistent world view. + def _on_global_map(self, msg: DynamicCloud) -> None: + """Rebuild the costmap from scratch using the persistent voxel world. - ``terrain_map_ext`` applies a decay window (8 s by default) on - the producer side, so each message represents the current world - state. Resetting here prevents stale obstacles from piling up - forever. 
+ ``global_map`` is the DynamicCloud produced by the ray-tracing voxel + map (post-closure-correction via ApplyClosure). Each message is + authoritative for the current world state, so the costmap is rebuilt + from scratch rather than accumulated — that way obstacles cleared by + raycasting on the producer side don't linger here. """ - points, _ = msg.as_numpy() - if points is None or len(points) == 0: + points = msg.world_positions() + if len(points) == 0: return new_cm = self._fresh_costmap() self._classify_points(points, new_cm) @@ -566,8 +539,8 @@ def _on_terrain_map_ext(self, msg: PointCloud2) -> None: def _on_terrain_map(self, msg: PointCloud2) -> None: """Layer fresh local terrain on top of the current costmap. - ``terrain_map`` is faster than ``terrain_map_ext`` so dynamic obstacles - appear here first; additions are wiped on the next ``terrain_map_ext`` rebuild. + ``terrain_map`` is faster than ``global_map`` so dynamic obstacles + appear here first; additions are wiped on the next ``global_map`` rebuild. 
""" points, _ = msg.as_numpy() if points is None or len(points) == 0: @@ -623,7 +596,7 @@ def _update_waypoint(self) -> None: self._current_wp_is_goal = is_goal now = time.time() self.way_point.publish( - PointStamped(ts=now, frame_id=self.config.world_frame, x=wx, y=wy, z=gz) + PointStamped(ts=now, frame_id=self.config.frame_id, x=wx, y=wy, z=gz) ) def _publish_costmap_cloud(self, rz: float, now: float) -> None: @@ -662,7 +635,7 @@ def _publish_costmap_cloud(self, rz: float, now: float) -> None: pcd_t.point["positions"] = o3c.Tensor(pts, dtype=o3c.float32) pcd_t.point["colors"] = o3c.Tensor(colors, dtype=o3c.float32) self.costmap_cloud.publish( - PointCloud2(pointcloud=pcd_t, ts=now, frame_id=self.config.world_frame) + PointCloud2(pointcloud=pcd_t, ts=now, frame_id=self.config.frame_id) ) def _replan_once(self) -> None: @@ -744,22 +717,22 @@ def _replan_once(self) -> None: self._current_wp = None self._current_wp_is_goal = False self.way_point.publish( - PointStamped(ts=now, frame_id=self.config.world_frame, x=rx, y=ry, z=rz) + PointStamped(ts=now, frame_id=self.config.frame_id, x=rx, y=ry, z=rz) ) self.goal_path.publish( Path( ts=now, - frame_id=self.config.world_frame, + frame_id=self.config.frame_id, poses=[ PoseStamped( ts=now, - frame_id=self.config.world_frame, + frame_id=self.config.frame_id, position=[rx, ry, rz], orientation=[0.0, 0.0, 0.0, 1.0], ), PoseStamped( ts=now, - frame_id=self.config.world_frame, + frame_id=self.config.frame_id, position=[gx, gy, gz], orientation=[0.0, 0.0, 0.0, 1.0], ), @@ -778,12 +751,12 @@ def _replan_once(self) -> None: poses.append( PoseStamped( ts=now, - frame_id=self.config.world_frame, + frame_id=self.config.frame_id, position=[wx, wy, rz], orientation=[0.0, 0.0, 0.0, 1.0], ) ) - self.goal_path.publish(Path(ts=now, frame_id=self.config.world_frame, poses=poses)) + self.goal_path.publish(Path(ts=now, frame_id=self.config.frame_id, poses=poses)) # 1 Hz diagnostic: cells in costmap, path length if now - self._last_diag_print 
>= 1.0: diff --git a/dimos/navigation/nav_stack/modules/tare_planner/tare_planner.py b/dimos/navigation/nav_stack/modules/tare_planner/tare_planner.py index 40e2f2db58..b090e8bfef 100644 --- a/dimos/navigation/nav_stack/modules/tare_planner/tare_planner.py +++ b/dimos/navigation/nav_stack/modules/tare_planner/tare_planner.py @@ -28,9 +28,9 @@ class TarePlannerConfig(NativeModuleConfig): cwd: str | None = "." executable: str = "result/bin/tare_planner" build_command: str | None = ( - "nix build github:dimensionalOS/dimos-module-tare-planner/v0.1.0 --no-write-lock-file" + "nix build github:dimensionalOS/dimos-module-tare-planner/feat/dimos-native-ready" + " --no-write-lock-file" ) - # Exploration parameters exploration_range: float = 20.0 update_rate: float = 1.0 diff --git a/dimos/navigation/nav_stack/modules/terrain_analysis/terrain_analysis.py b/dimos/navigation/nav_stack/modules/terrain_analysis/terrain_analysis.py index ca27a6f309..c99f4730a0 100644 --- a/dimos/navigation/nav_stack/modules/terrain_analysis/terrain_analysis.py +++ b/dimos/navigation/nav_stack/modules/terrain_analysis/terrain_analysis.py @@ -25,7 +25,8 @@ class TerrainAnalysisConfig(NativeModuleConfig): cwd: str | None = "." 
executable: str = "result/bin/terrain_analysis" build_command: str | None = ( - "nix build github:dimensionalOS/dimos-module-terrain-analysis/v0.1.1 --no-write-lock-file" + "nix build github:dimensionalOS/dimos-module-terrain-analysis/feat/dimos-native-ready" + " --no-write-lock-file" ) cli_name_override: dict[str, str] = { "sensor_range": "sensorRange", diff --git a/dimos/navigation/nav_stack/modules/terrain_analysis/test_terrain_analysis_rosbag.py b/dimos/navigation/nav_stack/modules/terrain_analysis/test_terrain_analysis_rosbag.py index 223ebf8ee9..5d89b4c1e5 100644 --- a/dimos/navigation/nav_stack/modules/terrain_analysis/test_terrain_analysis_rosbag.py +++ b/dimos/navigation/nav_stack/modules/terrain_analysis/test_terrain_analysis_rosbag.py @@ -62,7 +62,7 @@ def test_terrain_map_accuracy(self) -> None: assert len(ref_tmaps) > 0, "No reference terrain maps in fixture" lcm = lcmlib.LCM() - terrain_collector = LcmCollector(topic=TERRAIN_OUT_LCM, msg_type=PointCloud2) + terrain_collector = LcmCollector(topic=TERRAIN_OUT_LCM, message_type=PointCloud2) terrain_collector.start(lcm) stop_event = threading.Event() diff --git a/dimos/navigation/nav_stack/modules/terrain_map_ext/terrain_map_ext.py b/dimos/navigation/nav_stack/modules/terrain_map_ext/terrain_map_ext.py index 40816dbedb..29a4e37a85 100644 --- a/dimos/navigation/nav_stack/modules/terrain_map_ext/terrain_map_ext.py +++ b/dimos/navigation/nav_stack/modules/terrain_map_ext/terrain_map_ext.py @@ -29,7 +29,6 @@ from dimos.core.stream import In, Out from dimos.msgs.nav_msgs.Odometry import Odometry from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 -from dimos.navigation.nav_stack.frames import FRAME_MAP from dimos.utils.logging_config import setup_logger logger = setup_logger() @@ -45,7 +44,7 @@ class TerrainMapExtConfig(ModuleConfig): - world_frame: str = FRAME_MAP + frame_id: str = "map" # Scan voxel size for downsampling (PCL VoxelGrid leaf size equivalent) scan_voxel_size: float = 0.1 @@ -458,7 
+457,7 @@ def _process_loop(self) -> None: self.terrain_map_ext.publish( PointCloud2.from_numpy( output_array[:, :3], - frame_id=config.world_frame, + frame_id=config.frame_id, timestamp=laser_cloud_time, intensities=output_array[:, 3], ) diff --git a/dimos/navigation/nav_stack/specs.py b/dimos/navigation/nav_stack/specs.py new file mode 100644 index 0000000000..e76e354809 --- /dev/null +++ b/dimos/navigation/nav_stack/specs.py @@ -0,0 +1,36 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Spec protocols for nav-stack producer/consumer pairs.""" + +from typing import TYPE_CHECKING, Any, Protocol + +from dimos.core.stream import In, Out +from dimos.msgs.nav_msgs.Graph3D import Graph3D +from dimos.msgs.nav_msgs.GraphDelta3D import GraphDelta3D +from dimos.msgs.nav_msgs.Odometry import Odometry +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 + +if TYPE_CHECKING: + from dimos.core.coordination.blueprints import Blueprint + + +class LoopClosure(Protocol): + registered_scan: In[PointCloud2] + odometry: In[Odometry] + loop_closure_event: Out[GraphDelta3D] + pose_graph: Out[Graph3D] + + @classmethod + def blueprint(cls, **kwargs: Any) -> "Blueprint": ... 
diff --git a/dimos/navigation/nav_stack/tests/rosbag_fixtures.py b/dimos/navigation/nav_stack/tests/rosbag_fixtures.py index e231196922..39c3c0e7f4 100644 --- a/dimos/navigation/nav_stack/tests/rosbag_fixtures.py +++ b/dimos/navigation/nav_stack/tests/rosbag_fixtures.py @@ -21,7 +21,11 @@ from __future__ import annotations +import asyncio +from collections.abc import AsyncIterator from dataclasses import dataclass, field +import itertools +import os from pathlib import Path import subprocess import threading @@ -32,6 +36,9 @@ import numpy as np import pytest +from dimos.core.core import rpc +from dimos.core.module import Module, ModuleConfig +from dimos.core.stream import Out from dimos.msgs.geometry_msgs.PointStamped import PointStamped from dimos.msgs.geometry_msgs.Pose import Pose from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped @@ -135,28 +142,28 @@ class LcmCollector: """Subscribes to an LCM topic and collects decoded messages with timestamps.""" topic: str - msg_type: type + message_type: type messages: list[Any] = field(default_factory=list) timestamps: list[float] = field(default_factory=list) - _sub: Any = field(default=None, repr=False) + _subscription: Any = field(default=None, repr=False) def start(self, lcm: lcmlib.LCM) -> None: - msg_cls = self.msg_type + message_class = self.message_type def handler(_channel: str, data: bytes) -> None: try: - msg = msg_cls.lcm_decode(data) # type: ignore[attr-defined] - self.messages.append(msg) + message = message_class.lcm_decode(data) # type: ignore[attr-defined] + self.messages.append(message) self.timestamps.append(time.monotonic()) except Exception as exc: logger.error(f"LcmCollector decode error on {self.topic}: {exc}") - self._sub = lcm.subscribe(self.topic, handler) + self._subscription = lcm.subscribe(self.topic, handler) def stop(self, lcm: lcmlib.LCM) -> None: - if self._sub is not None: - lcm.unsubscribe(self._sub) - self._sub = None + if self._subscription is not None: + 
lcm.unsubscribe(self._subscription) + self._subscription = None def lcm_handle_loop(lcm: lcmlib.LCM, stop_event: threading.Event, timeout_ms: int = 50) -> None: @@ -165,6 +172,25 @@ def lcm_handle_loop(lcm: lcmlib.LCM, stop_event: threading.Event, timeout_ms: in lcm.handle_timeout(timeout_ms) +_isolated_lcm_url_counter = itertools.count() + + +def make_isolated_lcm_url() -> str: + """Return an LCM URL that should not collide with concurrent runs. + + Uses the standard multicast group with TTL=0 (so traffic never escapes + the local host) and a port picked from the high ephemeral range. The + monotonic counter guarantees back-to-back calls in the same process get + distinct ports; mixing in the PID disjoint-ifies concurrent CI workers. + """ + # ports 49152..65535 are the IANA "dynamic/private" range + span = 16000 + pid_offset = os.getpid() % span + call_offset = next(_isolated_lcm_url_counter) + port = 49152 + ((pid_offset + call_offset) % span) + return f"udpm://239.255.76.67:{port}?ttl=0" + + @dataclass class NativeProcessRunner: """Start and manage a native module C++ process for testing.""" @@ -173,12 +199,22 @@ class NativeProcessRunner: args: list[str] process: subprocess.Popen[bytes] | None = field(default=None, repr=False) - def start(self, capture_stderr: bool = False) -> None: + def start( + self, + capture_stderr: bool = False, + env: dict[str, str] | None = None, + ) -> None: + process_env: dict[str, str] | None + if env is None: + process_env = None + else: + process_env = {**os.environ, **env} self.process = subprocess.Popen( [self.binary_path, *self.args], stdout=subprocess.DEVNULL, stderr=subprocess.PIPE if capture_stderr else subprocess.DEVNULL, start_new_session=True, + env=process_env, ) def stop(self, timeout: float = 3.0) -> None: @@ -277,3 +313,85 @@ def feed_at_original_timing( if target_offset > elapsed: time.sleep(target_offset - elapsed) lcm.publish(topic, msg.lcm_encode()) + + +PLAYBACK_STARTUP_DELAY_SEC = 2.0 + + +class 
RosbagScanOdomPlaybackConfig(ModuleConfig): + rosbag_path: str | None = None + odom_subsample: int = 4 + startup_delay_sec: float = PLAYBACK_STARTUP_DELAY_SEC + + +class RosbagScanOdomPlaybackModule(Module): + """Replays scan + odom from a rosbag fixture at original inter-message timing.""" + + config: RosbagScanOdomPlaybackConfig + + registered_scan: Out[PointCloud2] + odometry: Out[Odometry] + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self._frames_published: int = 0 + self._playback_finished: bool = False + self._playback_error: str | None = None + + async def main(self) -> AsyncIterator[None]: + rosbag_path = Path(self.config.rosbag_path) if self.config.rosbag_path else None + self._window = load_rosbag_window(rosbag_path) + self._playback_task = asyncio.create_task(self._run_playback()) + yield + self._playback_task.cancel() + + async def _run_playback(self) -> None: + # finally guarantees is_finished() flips to True even if a frame + # raises. Without it, callers polling is_finished() would spin + # forever and the coordinator would never tear down. 
+ try: + if self.config.startup_delay_sec > 0: + await asyncio.sleep(self.config.startup_delay_sec) + timeline: list[tuple[str, float, Any]] = [] + for odom_index in range(0, len(self._window.odom), self.config.odom_subsample): + row = self._window.odom[odom_index] + timeline.append( + ("odom", float(row[0]), make_odometry_msg(row[1:4], row[4:8], ts=row[0])) + ) + for timestamp, points in self._window.scans: + timeline.append( + ("scan", float(timestamp), make_pointcloud_msg(points, ts=timestamp)) + ) + timeline.sort(key=lambda entry: entry[1]) + if not timeline: + return + + bag_start_time = timeline[0][1] + wallclock_start = time.monotonic() + for kind, bag_timestamp, message in timeline: + target_wallclock = wallclock_start + (bag_timestamp - bag_start_time) + now = time.monotonic() + if target_wallclock > now: + await asyncio.sleep(target_wallclock - now) + if kind == "odom": + self.odometry.publish(message) + else: + self.registered_scan.publish(message) + self._frames_published += 1 + except Exception as exc: + self._playback_error = f"{type(exc).__name__}: {exc}" + raise + finally: + self._playback_finished = True + + @rpc + def is_finished(self) -> bool: + return self._playback_finished + + @rpc + def playback_error(self) -> str | None: + return self._playback_error + + @rpc + def frames_published(self) -> int: + return self._frames_published diff --git a/dimos/navigation/replanning_a_star/min_cost_astar.py b/dimos/navigation/replanning_a_star/min_cost_astar.py index 025045c2c9..fb240eff68 100644 --- a/dimos/navigation/replanning_a_star/min_cost_astar.py +++ b/dimos/navigation/replanning_a_star/min_cost_astar.py @@ -67,12 +67,13 @@ def _reconstruct_path( costmap: OccupancyGrid, start_tuple: tuple[int, int], goal_tuple: tuple[int, int], + frame_id: str, ) -> Path: waypoints: list[PoseStamped] = [] while current in parents: world_point = costmap.grid_to_world(current) pose = PoseStamped( - frame_id="world", + frame_id=frame_id, position=[world_point.x, 
world_point.y, 0.0], orientation=Quaternion(0, 0, 0, 1), # Identity quaternion ) @@ -81,7 +82,7 @@ def _reconstruct_path( start_world_point = costmap.grid_to_world(start_tuple) start_pose = PoseStamped( - frame_id="world", + frame_id=frame_id, position=[start_world_point.x, start_world_point.y, 0.0], orientation=Quaternion(0, 0, 0, 1), ) @@ -97,31 +98,32 @@ def _reconstruct_path( or (waypoints[-1].x - goal_point.x) ** 2 + (waypoints[-1].y - goal_point.y) ** 2 > 1e-10 ): goal_pose = PoseStamped( - frame_id="world", + frame_id=frame_id, position=[goal_point.x, goal_point.y, 0.0], orientation=Quaternion(0, 0, 0, 1), ) waypoints.append(goal_pose) - return Path(frame_id="world", poses=waypoints) + return Path(frame_id=frame_id, poses=waypoints) def _reconstruct_path_from_coords( path_coords: list[tuple[int, int]], costmap: OccupancyGrid, + frame_id: str, ) -> Path: waypoints: list[PoseStamped] = [] for gx, gy in path_coords: world_point = costmap.grid_to_world((gx, gy)) pose = PoseStamped( - frame_id="world", + frame_id=frame_id, position=[world_point.x, world_point.y, 0.0], orientation=Quaternion(0, 0, 0, 1), ) waypoints.append(pose) - return Path(frame_id="world", poses=waypoints) + return Path(frame_id=frame_id, poses=waypoints) def min_cost_astar( @@ -131,6 +133,7 @@ def min_cost_astar( cost_threshold: int = 100, unknown_penalty: float = 0.8, use_cpp: bool = True, + frame_id: str = "world", ) -> Path | None: start_vector = costmap.world_to_grid(start) goal_vector = costmap.world_to_grid(goal) @@ -154,7 +157,7 @@ def min_cost_astar( ) if not path_coords: return None - return _reconstruct_path_from_coords(path_coords, costmap) + return _reconstruct_path_from_coords(path_coords, costmap, frame_id) else: logger.warning( "C++ A* module could not be imported (%s). 
Using Python.", @@ -183,7 +186,7 @@ def min_cost_astar( continue if current == goal_tuple: - return _reconstruct_path(parents, current, costmap, start_tuple, goal_tuple) + return _reconstruct_path(parents, current, costmap, start_tuple, goal_tuple, frame_id) closed_set.add(current) diff --git a/dimos/protocol/tf/tf.py b/dimos/protocol/tf/tf.py index 6e25af7704..c9110c3b07 100644 --- a/dimos/protocol/tf/tf.py +++ b/dimos/protocol/tf/tf.py @@ -155,6 +155,14 @@ def get_transform( return None def get(self, *args, **kwargs) -> Transform | None: # type: ignore[no-untyped-def] + parent_frame = args[0] if args else kwargs.get("parent_frame") + child_frame = args[1] if len(args) > 1 else kwargs.get("child_frame") + if parent_frame is not None and parent_frame == child_frame: + raise ValueError( + f"tf.get() called with same parent and child frame {parent_frame!r}; " + "this is almost always a caller bug — the data is already in that frame" + ) + simple = self.get_transform(*args, **kwargs) if simple is not None: return simple diff --git a/dimos/robot/all_blueprints.py b/dimos/robot/all_blueprints.py index 73c4930e9e..8da397ae2d 100644 --- a/dimos/robot/all_blueprints.py +++ b/dimos/robot/all_blueprints.py @@ -65,6 +65,11 @@ "keyboard-teleop-xarm7": "dimos.robot.manipulators.xarm.blueprints:keyboard_teleop_xarm7", "mid360": "dimos.hardware.sensors.lidar.livox.livox_blueprints:mid360", "mid360-fastlio": "dimos.hardware.sensors.lidar.fastlio2.fastlio_blueprints:mid360_fastlio", + "mid360-fastlio-memory": "dimos.hardware.sensors.lidar.fastlio2.fastlio_blueprints:mid360_fastlio_memory", + "mid360-fastlio-ray-trace": "dimos.hardware.sensors.lidar.fastlio2.fastlio_blueprints:mid360_fastlio_ray_trace", + "mid360-fastlio-ray-trace-replay": "dimos.hardware.sensors.lidar.fastlio2.fastlio_blueprints:mid360_fastlio_ray_trace_replay", + "mid360-fastlio-replay": "dimos.hardware.sensors.lidar.fastlio2.fastlio_blueprints:mid360_fastlio_replay", + "mid360-fastlio-replay-voxels": 
"dimos.hardware.sensors.lidar.fastlio2.fastlio_blueprints:mid360_fastlio_replay_voxels", "mid360-fastlio-voxels": "dimos.hardware.sensors.lidar.fastlio2.fastlio_blueprints:mid360_fastlio_voxels", "mid360-fastlio-voxels-native": "dimos.hardware.sensors.lidar.fastlio2.fastlio_blueprints:mid360_fastlio_voxels_native", "openarm-mock-planner-coordinator": "dimos.robot.manipulators.openarm.blueprints:openarm_mock_planner_coordinator", @@ -101,6 +106,7 @@ "unitree-go2-fleet": "dimos.robot.unitree.go2.blueprints.basic.unitree_go2_fleet:unitree_go2_fleet", "unitree-go2-keyboard-teleop": "dimos.robot.unitree.go2.blueprints.basic.unitree_go2_keyboard_teleop:unitree_go2_keyboard_teleop", "unitree-go2-memory": "dimos.robot.unitree.go2.blueprints.smart.unitree_go2:unitree_go2_memory", + "unitree-go2-nav": "dimos.robot.unitree.go2.blueprints.navigation.unitree_go2_nav:unitree_go2_nav", "unitree-go2-ros": "dimos.robot.unitree.go2.blueprints.smart.unitree_go2_ros:unitree_go2_ros", "unitree-go2-security": "dimos.robot.unitree.go2.blueprints.agentic.unitree_go2_security:unitree_go2_security", "unitree-go2-spatial": "dimos.robot.unitree.go2.blueprints.smart.unitree_go2_spatial:unitree_go2_spatial", @@ -121,6 +127,7 @@ all_modules = { "alfred-high-level": "dimos.robot.diy.alfred.effector_high_level.AlfredHighLevel", + "apply-closure": "dimos.navigation.nav_stack.modules.apply_closure.apply_closure.ApplyClosure", "arm-teleop-module": "dimos.teleop.quest.quest_extensions.ArmTeleopModule", "b-box-navigation-module": "dimos.navigation.bbox_navigation.BBoxNavigationModule", "b1-connection-module": "dimos.robot.unitree.b1.connection.B1ConnectionModule", @@ -140,6 +147,8 @@ "emitter-module": "dimos.utils.demo_image_encoding.EmitterModule", "far-planner": "dimos.navigation.nav_stack.modules.far_planner.far_planner.FarPlanner", "fast-lio2": "dimos.hardware.sensors.lidar.fastlio2.module.FastLio2", + "fastlio-memory": "dimos.hardware.sensors.lidar.fastlio2.fastlio_blueprints.FastlioMemory", + 
"fastlio-replay": "dimos.hardware.sensors.lidar.fastlio2.fastlio_blueprints.FastlioReplay", "foxglove-bridge": "dimos.robot.foxglove_bridge.FoxgloveBridge", "g1-connection": "dimos.robot.unitree.g1.connection.G1Connection", "g1-connection-base": "dimos.robot.unitree.g1.connection.G1ConnectionBase", @@ -159,6 +168,7 @@ "joystick-module": "dimos.robot.unitree.b1.joystick_module.JoystickModule", "keyboard-teleop": "dimos.robot.unitree.keyboard_teleop.KeyboardTeleop", "keyboard-teleop-module": "dimos.teleop.keyboard.keyboard_teleop_module.KeyboardTeleopModule", + "kitti360-playback-module": "dimos.navigation.nav_stack.benchmarks.pose_graph_kitti360.playback.Kitti360PlaybackModule", "local-planner": "dimos.navigation.nav_stack.modules.local_planner.local_planner.LocalPlanner", "manipulation-module": "dimos.manipulation.manipulation_module.ManipulationModule", "map": "dimos.robot.unitree.type.map.Map", @@ -187,7 +197,9 @@ "pgo": "dimos.navigation.nav_stack.modules.pgo.pgo.PGO", "phone-teleop-module": "dimos.teleop.phone.phone_teleop_module.PhoneTeleopModule", "pick-and-place-module": "dimos.manipulation.pick_and_place_module.PickAndPlaceModule", + "pose-graph-scoring-module": "dimos.navigation.nav_stack.benchmarks.pose_graph_kitti360.scoring.PoseGraphScoringModule", "quest-teleop-module": "dimos.teleop.quest.quest_teleop_module.QuestTeleopModule", + "ray-tracing-voxel-map": "dimos.mapping.ray_tracing.module.RayTracingVoxelMap", "real-sense-camera": "dimos.hardware.sensors.camera.realsense.camera.RealSenseCamera", "receiver-module": "dimos.utils.demo_image_encoding.ReceiverModule", "recorder": "dimos.memory2.module.Recorder", @@ -195,6 +207,7 @@ "replanning-a-star-planner": "dimos.navigation.replanning_a_star.module.ReplanningAStarPlanner", "rerun-bridge-module": "dimos.visualization.rerun.bridge.RerunBridgeModule", "rerun-web-socket-server": "dimos.visualization.rerun.websocket_server.RerunWebSocketServer", + "rosbag-scan-odom-playback-module": 
"dimos.navigation.nav_stack.tests.rosbag_fixtures.RosbagScanOdomPlaybackModule", "security-module": "dimos.experimental.security_demo.security_module.SecurityModule", "semantic-search": "dimos.memory2.module.SemanticSearch", "simple-phone-teleop": "dimos.teleop.phone.phone_extensions.SimplePhoneTeleop", @@ -205,6 +218,7 @@ "temporal-memory": "dimos.perception.experimental.temporal_memory.temporal_memory.TemporalMemory", "terrain-analysis": "dimos.navigation.nav_stack.modules.terrain_analysis.terrain_analysis.TerrainAnalysis", "terrain-map-ext": "dimos.navigation.nav_stack.modules.terrain_map_ext.terrain_map_ext.TerrainMapExt", + "topic-counter-module": "dimos.navigation.nav_stack.modules.pgo.benchmark_kitti360_smoke.TopicCounterModule", "twist-teleop-module": "dimos.teleop.quest.quest_extensions.TwistTeleopModule", "unitree-g1-skill-container": "dimos.robot.unitree.g1.skill_container.UnitreeG1SkillContainer", "unitree-skill-container": "dimos.robot.unitree.unitree_skill_container.UnitreeSkillContainer", diff --git a/dimos/robot/diy/alfred/blueprints/alfred_nav.py b/dimos/robot/diy/alfred/blueprints/alfred_nav.py index 91a8db9bec..2a474b5bc8 100644 --- a/dimos/robot/diy/alfred/blueprints/alfred_nav.py +++ b/dimos/robot/diy/alfred/blueprints/alfred_nav.py @@ -71,9 +71,7 @@ ) .remappings( [ - # nav stack needs "registered_scan" - (FastLio2, "lidar", "registered_scan"), - (FastLio2, "global_map", "global_map_fastlio"), + (FastLio2, "global_map", "_fastlio_global_map"), # SimplePlanner / FarPlanner owns way_point — disconnect MovementManager's (MovementManager, "way_point", "_mgr_way_point_unused"), ] diff --git a/dimos/robot/unitree/g1/blueprints/navigation/unitree_g1_nav_onboard.py b/dimos/robot/unitree/g1/blueprints/navigation/unitree_g1_nav_onboard.py index d64b2c8aa0..9c13419240 100644 --- a/dimos/robot/unitree/g1/blueprints/navigation/unitree_g1_nav_onboard.py +++ b/dimos/robot/unitree/g1/blueprints/navigation/unitree_g1_nav_onboard.py @@ -80,9 +80,7 @@ ) 
.remappings( [ - # FastLio2 outputs "lidar"; SmartNav modules expect "registered_scan" - (FastLio2, "lidar", "registered_scan"), - (FastLio2, "global_map", "global_map_fastlio"), + (FastLio2, "global_map", "_fastlio_global_map"), # Planner owns way_point — disconnect MovementManager's click relay (MovementManager, "way_point", "_mgr_way_point_unused"), ] diff --git a/dimos/robot/unitree/go2/blueprints/navigation/unitree_go2_nav.py b/dimos/robot/unitree/go2/blueprints/navigation/unitree_go2_nav.py new file mode 100644 index 0000000000..d11ff3ffa5 --- /dev/null +++ b/dimos/robot/unitree/go2/blueprints/navigation/unitree_go2_nav.py @@ -0,0 +1,87 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import os +from typing import Any + +from dimos.core.coordination.blueprints import autoconnect +from dimos.core.global_config import global_config +from dimos.hardware.sensors.lidar.fastlio2.module import FastLio2 +from dimos.navigation.movement_manager.movement_manager import MovementManager +from dimos.navigation.nav_stack.main import create_nav_stack, nav_stack_rerun_config +from dimos.protocol.service.system_configurator.clock_sync import ClockSyncConfigurator +from dimos.robot.unitree.go2.config import GO2, GO2_LOCAL_PLANNER_PRECOMPUTED_PATHS +from dimos.robot.unitree.go2.connection import GO2Connection +from dimos.visualization.vis_module import vis_module + +nav_config: dict[str, Any] = dict( + planner="simple", + vehicle_height=GO2.height_clearance, + max_speed=0.8, + terrain_analysis={ + "obstacle_height_threshold": 0.15, + "ground_height_threshold": 0.10, + "sensor_range": 20, + }, + local_planner={ + "paths_dir": str(GO2_LOCAL_PLANNER_PRECOMPUTED_PATHS), + "publish_free_paths": False, + }, + simple_planner={ + "cell_size": 0.2, + "obstacle_height_threshold": 0.15, + "inflation_radius": 0.3, + "lookahead_distance": 2.0, + "replan_rate": 5.0, + "replan_cooldown": 2.0, + }, +) + +unitree_go2_nav = ( + autoconnect( + FastLio2.blueprint( + host_ip=os.getenv("LIDAR_HOST_IP", "192.168.123.18"), + lidar_ip=os.getenv("LIDAR_IP", "192.168.123.120"), + mount=GO2.internal_odom_offsets["mid360_link"], + map_freq=1.0, + config="default.yaml", + ), + create_nav_stack(**nav_config), + MovementManager.blueprint(), + GO2Connection.blueprint(), + vis_module( + global_config.viewer, + rerun_config={ + **nav_stack_rerun_config({"memory_limit": "1GB"}, vis_throttle=0.5), + "rerun_open": "native", + }, + ), + ) + .remappings( + [ + (FastLio2, "global_map", "_fastlio_global_map"), + # disambiguate lidar + (GO2Connection, "lidar", "_go2_onboard_lidar"), + # SimplePlanner / FarPlanner own way_point — disconnect MovementManager's + # 
click-relay so it doesn't fight the planner. + (MovementManager, "way_point", "_mgr_way_point_unused"), + ] + ) + .configurators(ClockSyncConfigurator()) + .global_config(n_workers=8, robot_model="unitree_go2") +) + +__all__ = ["nav_config", "unitree_go2_nav"] diff --git a/dimos/robot/unitree/go2/config.py b/dimos/robot/unitree/go2/config.py new file mode 100644 index 0000000000..a6e3811d70 --- /dev/null +++ b/dimos/robot/unitree/go2/config.py @@ -0,0 +1,44 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Go2 physical description and sensor odometry offsets.""" + +from __future__ import annotations + +from pathlib import Path + +from dimos.msgs.geometry_msgs.Pose import Pose +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.robot.config import RobotConfig +from dimos.robot.unitree.g1.config import G1_LOCAL_PLANNER_PRECOMPUTED_PATHS + +# Reuse G1's precomputed local-planner paths until a Go2-specific set is +# generated. The Go2's turning radius is tighter than the G1's, so a future +# regen via CMU's pathGenerator would yield smoother trajectories (especially +# in rage mode), but G1's set is workable as a starting point. 
+GO2_LOCAL_PLANNER_PRECOMPUTED_PATHS = G1_LOCAL_PLANNER_PRECOMPUTED_PATHS + +GO2 = RobotConfig( + name="unitree_go2", + model_path=Path(__file__).parent / "go2.urdf", + # base_link box from go2.urdf is 0.70 (length) x 0.31 (width) x 0.40 (height). + height_clearance=0.5, + width_clearance=0.4, + internal_odom_offsets={ + # Mid-360 lidar mounted on top of the Go2's back, centered laterally, + # ~0.4 m above the floor. + "mid360_link": Pose(0.0, 0.0, 0.4, *Quaternion.from_euler(Vector3(0, 0, 0))), + }, +) diff --git a/dimos/simulation/unity/module.py b/dimos/simulation/unity/module.py index 80f735ad70..70b7a64c03 100644 --- a/dimos/simulation/unity/module.py +++ b/dimos/simulation/unity/module.py @@ -169,6 +169,10 @@ class UnityBridgeConfig(ModuleConfig): # Extra CLI args to pass to the Unity binary. unity_extra_args: list[str] = Field(default_factory=list) + frame_id: str = "map" + child_frame_id: str = "current_point" + parent_frame: str = "world" + # Vehicle parameters vehicle_height: float = 0.75 @@ -233,7 +237,7 @@ class UnityBridgeModule(Module): cmd_vel (In[Twist]): Velocity commands. terrain_map (In[PointCloud2]): Terrain for Z adjustment. odometry (Out[Odometry]): Vehicle state at sim_rate. - registered_scan (Out[PointCloud2]): Lidar from Unity. + lidar (Out[PointCloud2]): Lidar from Unity. color_image (Out[Image]): RGB camera from Unity (1920x640 panoramic). semantic_image (Out[Image]): Semantic segmentation from Unity. camera_info (Out[CameraInfo]): Camera intrinsics. 
@@ -244,7 +248,7 @@ class UnityBridgeModule(Module): cmd_vel: In[Twist] terrain_map: In[PointCloud2] odometry: Out[Odometry] - registered_scan: Out[PointCloud2] + lidar: Out[PointCloud2] color_image: Out[Image] semantic_image: Out[Image] camera_info: Out[CameraInfo] @@ -663,12 +667,12 @@ def _handle_syscommand(self, dest: str, data: bytes) -> None: self._send_queue.put(("__raw__", frame)) def _handle_unity_message(self, topic: str, data: bytes) -> None: - if topic == "/registered_scan": + if topic == "/lidar": pc_result = deserialize_pointcloud2(data) if pc_result is not None: points, frame_id, ts = pc_result if len(points) > 0: - self.registered_scan.publish( + self.lidar.publish( PointCloud2.from_numpy(points, frame_id=frame_id, timestamp=ts) ) @@ -764,8 +768,8 @@ def _sim_step(self, dt: float) -> None: self.odometry.publish( Odometry( ts=now, - frame_id="map", - child_frame_id="sensor", + frame_id=self.config.frame_id, + child_frame_id=self.config.child_frame_id, pose=Pose( position=[odom_x, odom_y, z], orientation=[quat.x, quat.y, quat.z, quat.w], @@ -785,15 +789,15 @@ def _sim_step(self, dt: float) -> None: Transform( translation=Vector3(x, y, z), rotation=quat, - frame_id="map", - child_frame_id="sensor", + frame_id=self.config.frame_id, + child_frame_id=self.config.child_frame_id, ts=now, ), Transform( translation=Vector3(0.0, 0.0, 0.0), rotation=Quaternion(0.0, 0.0, 0.0, 1.0), - frame_id="map", - child_frame_id="world", + frame_id=self.config.parent_frame, + child_frame_id=self.config.frame_id, ts=now, ), ) diff --git a/dimos/visualization/rerun/bridge.py b/dimos/visualization/rerun/bridge.py index 5bea65afcb..57a4d10c63 100644 --- a/dimos/visualization/rerun/bridge.py +++ b/dimos/visualization/rerun/bridge.py @@ -251,6 +251,13 @@ def final_convert(msg: Any) -> RerunData | None: return msg if is_rerun_multi(msg): return msg + # Prefer to_rerun_multi when available — it lets the message + # render itself across nodes+edges sub-paths (or any other + # 
compound shape) without forcing every consumer to wire up + # an explicit visual_override. + to_rerun_multi = getattr(msg, "to_rerun_multi", None) + if callable(to_rerun_multi): + return cast("RerunData | None", to_rerun_multi(base_path=entity_path)) if isinstance(msg, RerunConvertible): return msg.to_rerun() return None diff --git a/docs/development/conventions.md b/docs/development/conventions.md index 2b25a7c3c6..577cfdaf52 100644 --- a/docs/development/conventions.md +++ b/docs/development/conventions.md @@ -1,5 +1,6 @@ This mostly to track when conventions change (with regard to codebase updates) because this codebase is under heavy development. Note: this is a non-exhaustive list of conventions. +- Instead of using threading (especially for tests) try to use async in modules - Instead of using `RerunBridge` in blueprints we always use `vis_module` which allows the CLI to control if its foxglove, rerun, or no-vis at all - When global_config.py shouldn't accidentally/indirectly import heavy libraries like rerun. But sometimes global_config needs the type definition or default value from a module. Preferably we import from the module file directly, however when thats not possible, we create a config.py for just that module's config and import that into global_config.py. - When adding visualization tools to a blueprint/autoconnect, instead of using RerunBridge or WebsocketVisModule directly we should always use `vis_module`, which right now should look something like `vis_module(viewer_backend=global_config.viewer, rerun_config={}),`