diff --git a/README.md b/README.md index 467db66..4482b88 100644 --- a/README.md +++ b/README.md @@ -44,7 +44,14 @@ It has simple controlls to move the device around, while also displaying a live The lightweight Python API handles all serial communication and provides convenient command execution and debug message printing. The interface includes functions to home, move, and calibrate the device, as well as to query device information. -Simply copy the [open_micro_stage_api.py](software/PythonAPI/open_micro_stage_api.py) file into your project (also install the dependencies in requirements.txt), and you’re ready to get started. + +This library has not been published to pypi yet but it can be installed by running the following pip command. This can also be added to your projects requirements.txt: + +```bash +pip3 install "git+https://github.com/0x23/MicroManipulatorStepper/#subdirectory=software/PythonAPI" +``` + +If you would like to use the calibration plotter then you will need the additional `plotter` optional dependency. ## Usage Example ```python diff --git a/software/PythonAPI/.gitignore b/software/PythonAPI/.gitignore new file mode 100644 index 0000000..297dd9a --- /dev/null +++ b/software/PythonAPI/.gitignore @@ -0,0 +1,135 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST +.ruff_cache/ + +# PyInstaller +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +Pipfile.lock + +# PEP 582 +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# IDE +.vscode/ +.idea/ +*.swp +*.swo +*~ + +# OS +.DS_Store +Thumbs.db diff --git a/software/PythonAPI/LICENSE b/software/PythonAPI/LICENSE new file mode 100644 index 0000000..14ebf02 --- /dev/null +++ b/software/PythonAPI/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025 Github User '0x23' (https://github.com/0x23/) + +Permission is hereby granted, free of charge, to any person obtaining a copy of this +software, hardware design, documentation, or concept (the "Work"), to deal in the Work +without restriction, including without limitation the rights to use, copy, modify, merge, +publish, distribute, sublicense, and/or sell copies of the Work, and to permit persons +to whom the Work is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies +or substantial portions of the Work, including but not limited to products or +derivative works based on the presented concepts, designs, or arrangements, even if +the original design files are not directly used. + +THE WORK IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT +HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE WORK +OR THE USE OR OTHER DEALINGS IN THE WORK. diff --git a/software/PythonAPI/usage_example.py b/software/PythonAPI/examples/usage_example.py similarity index 86% rename from software/PythonAPI/usage_example.py rename to software/PythonAPI/examples/usage_example.py index 212a305..ce3e557 100644 --- a/software/PythonAPI/usage_example.py +++ b/software/PythonAPI/examples/usage_example.py @@ -2,7 +2,7 @@ # create interface and connect oms = OpenMicroStageInterface(show_communication=True, show_log_messages=True) -oms.connect('/dev/ttyACM0') +oms.connect("/dev/ttyACM0") # run this once to calibrate joints # for i in range(3): oms.calibrate_joint(i, save_result=True) @@ -15,4 +15,4 @@ oms.wait_for_stop() # print some info -oms.read_device_state_info() \ No newline at end of file +oms.read_device_state_info() diff --git a/software/PythonAPI/open_micro_stage_api/__init__.py b/software/PythonAPI/open_micro_stage_api/__init__.py new file mode 100644 index 0000000..9c52e66 --- /dev/null +++ b/software/PythonAPI/open_micro_stage_api/__init__.py @@ -0,0 +1,5 @@ +"""Open Micro Stage - Python API for micro-manipulator control.""" + +from .api import OpenMicroStageInterface + +__all__ = ["OpenMicroStageInterface"] diff --git a/software/PythonAPI/open_micro_stage_api/__main__.py b/software/PythonAPI/open_micro_stage_api/__main__.py new file mode 100644 index 0000000..25afde7 --- /dev/null +++ b/software/PythonAPI/open_micro_stage_api/__main__.py @@ -0,0 +1,5 @@ +"""Main file for open micro stage""" + +from open_micro_stage.calibration_plotter import main + +main() diff --git a/software/PythonAPI/open_micro_stage_api.py b/software/PythonAPI/open_micro_stage_api/api.py similarity index 79% rename from software/PythonAPI/open_micro_stage_api.py rename to software/PythonAPI/open_micro_stage_api/api.py index 3dc8835..02b77fb 100644 --- a/software/PythonAPI/open_micro_stage_api.py +++ b/software/PythonAPI/open_micro_stage_api/api.py @@ -1,27 +1,27 @@ +import re import threading import time -import re from enum import Enum -import serial import numpy as np -from colorama import Fore, Style, init +import serial +from colorama import Fore, Style # --- SerialInterface -------------------------------------------------------------------------------------------------- -class SerialInterface: +class SerialInterface: class ReplyStatus(Enum): - OK = 'ok' - ERROR = 'error' - TIMEOUT = 'timeout' - BUSY = 'busy' + OK = "ok" + ERROR = "error" + TIMEOUT = "timeout" + BUSY = "busy" class LogLevel(Enum): - DEBUG = 'debug' - INFO = 'info' - WARNING = 'warning' - ERROR = 'error' + DEBUG = "debug" + INFO = "info" + WARNING = "warning" + ERROR = "error" # Static mapping from prefix to LogLevel log_level_prefix_map = { @@ -31,11 +31,15 @@ class LogLevel(Enum): "E)": LogLevel.ERROR, } - def __init__(self, port: str, baud_rate: int = 115200, - command_msg_callback=None, - log_msg_callback=None, - unsolicited_msg_callback=None, - reconnect_timeout: int = 5): + def __init__( + self, + port: str, + baud_rate: int = 115200, + command_msg_callback=None, + log_msg_callback=None, + unsolicited_msg_callback=None, + reconnect_timeout: int = 5, + ): """ Initializes the serial connection and starts background reader. :param port: Serial port name (e.g., 'COM3' or '/dev/ttyUSB0'). @@ -66,27 +70,26 @@ def __init__(self, port: str, baud_rate: int = 115200, self._reader_thread = threading.Thread(target=self._reader_loop, daemon=True) self._reader_thread.start() - def connect(self, timeout): """ Try to open the serial port. Retry until timeout expires. """ deadline = time.time() + timeout - print(Fore.MAGENTA, end='') - print(f"[SerialInterface] Connecting to port '{self.port}'...", end='') + print(Fore.MAGENTA, end="") + print(f"[SerialInterface] Connecting to port '{self.port}'...", end="") while time.time() < deadline: try: self.serial = serial.Serial(self.port, self.baud_rate, timeout=2) - print(f" [OK]") - print(Style.RESET_ALL, end='') + print(" [OK]") + print(Style.RESET_ALL, end="") return True - except (serial.SerialException, OSError) as e: - print('.', end='') + except (serial.SerialException, OSError): + print(".", end="") time.sleep(0.2) print(f" [FAILED] Timeout after {timeout} seconds.") - print(f"[SerialInterface] Connection is permanently closed") - print(Style.RESET_ALL, end='') + print("[SerialInterface] Connection is permanently closed") + print(Style.RESET_ALL, end="") self.serial = None return False @@ -98,8 +101,8 @@ def _reader_loop(self): while True: try: if self.serial is not None and self.serial.in_waiting: - char = self.serial.read(1).decode('ascii', errors='ignore') - if char in ['\n', '\r']: + char = self.serial.read(1).decode("ascii", errors="ignore") + if char in ["\n", "\r"]: if len(buffer) > 0: self._handle_line(buffer) buffer = "" @@ -108,7 +111,7 @@ def _reader_loop(self): else: time.sleep(0.001) except (serial.SerialException, OSError) as e: - print(Fore.MAGENTA+f"[SerialInterface] Lost connection: {e}"+Style.RESET_ALL) + print(Fore.MAGENTA + f"[SerialInterface] Lost connection: {e}" + Style.RESET_ALL) try: if self.serial is not None and self.serial.is_open: self.serial.close() @@ -128,7 +131,8 @@ def _handle_line(self, line: str): # print(line) # log message if log_level is not None: - if self.log_message_callback: self.log_message_callback(log_level, log_msg) + if self.log_message_callback: + self.log_message_callback(log_level, log_msg) # response elif self._waiting_for_response: line_lower = line.lower() @@ -144,15 +148,16 @@ def _handle_line(self, line: str): if self._response_status is not None: self._condition.notify() else: - self._response_string += line + '\n' + self._response_string += line + "\n" # unsolicited message else: - if self.unsolicited_msg_callback: self.unsolicited_msg_callback(line) + if self.unsolicited_msg_callback: + self.unsolicited_msg_callback(line) def _check_log_msg(self, msg: str): if len(msg) < 2: - return None, '' + return None, "" return self.log_level_prefix_map.get(msg[:2]), msg[2:] def send_command(self, cmd: str, timeout=2) -> tuple[ReplyStatus, str]: @@ -164,7 +169,7 @@ def send_command(self, cmd: str, timeout=2) -> tuple[ReplyStatus, str]: """ with self._lock: if not self.serial or not self.serial.is_open: - return SerialInterface.ReplyStatus.ERROR, 'Serial not open' + return SerialInterface.ReplyStatus.ERROR, "Serial not open" # Reset state self._waiting_for_response = True @@ -172,11 +177,11 @@ def send_command(self, cmd: str, timeout=2) -> tuple[ReplyStatus, str]: self._response_error_msg = "" self._response_status = None - cmd = (cmd.strip() + "\n") - self.command_msg_callback(cmd, None, '') + cmd = cmd.strip() + "\n" + self.command_msg_callback(cmd, None, "") # Send command - self.serial.write(cmd.encode('ascii')) + self.serial.write(cmd.encode("ascii")) self.serial.flush() # Wait for completion @@ -185,7 +190,11 @@ def send_command(self, cmd: str, timeout=2) -> tuple[ReplyStatus, str]: remaining = end_time - time.time() if remaining <= 0: self._waiting_for_response = False - print(Fore.MAGENTA + f"[SerialInterface] Command timeout, device didn't reply in time" + Style.RESET_ALL) + print( + Fore.MAGENTA + + "[SerialInterface] Command timeout, device didn't reply in time" + + Style.RESET_ALL + ) return SerialInterface.ReplyStatus.TIMEOUT, self._response_string self._condition.wait(timeout=remaining) @@ -198,12 +207,14 @@ def close(self): if self.serial and self.serial.is_open: self.serial.close() + # --- OpenMicroStageInterface ------------------------------------------------------------------------------------------ + class OpenMicroStageInterface: # Mapping log levels to colors LOG_COLORS = { - SerialInterface.LogLevel.DEBUG: Fore.WHITE+Style.DIM, + SerialInterface.LogLevel.DEBUG: Fore.WHITE + Style.DIM, SerialInterface.LogLevel.INFO: Style.RESET_ALL, SerialInterface.LogLevel.WARNING: Fore.YELLOW, SerialInterface.LogLevel.ERROR: Fore.RED, @@ -220,21 +231,27 @@ def connect(self, port: str, baud_rate: int = 921600): def version_to_str(v): return f"v{v[0]}.{v[1]}.{v[2]}" - if self.serial is not None: self.disconnect() - self.serial = SerialInterface(port, baud_rate, - log_msg_callback=self.log_msg_callback, - command_msg_callback=self.command_msg_callback, - unsolicited_msg_callback=self.unsolicited_msg_callback) + if self.serial is not None: + self.disconnect() + self.serial = SerialInterface( + port, + baud_rate, + log_msg_callback=self.log_msg_callback, + command_msg_callback=self.command_msg_callback, + unsolicited_msg_callback=self.unsolicited_msg_callback, + ) self.disable_message_callbacks = True fw_version = self.read_firmware_version() min_fw_version = (1, 0, 1) print(Fore.MAGENTA + f"Firmware version: {version_to_str(fw_version)}" + Style.RESET_ALL) if fw_version < min_fw_version: - print(Fore.MAGENTA + f"Firmware version {version_to_str(fw_version)} incompatible. " - f"At least {version_to_str(min_fw_version)} required" + Style.RESET_ALL) + print( + Fore.MAGENTA + f"Firmware version {version_to_str(fw_version)} incompatible. " + f"At least {version_to_str(min_fw_version)} required" + Style.RESET_ALL + ) self.serial = None - print('') + print("") self.disable_message_callbacks = False def disconnect(self): @@ -258,17 +275,17 @@ def command_msg_callback(self, msg, reply_status: SerialInterface.ReplyStatus, e if reply_status is not None: if msg: - msg = '\n'.join('> ' + line for line in msg.splitlines()) + msg = "\n".join("> " + line for line in msg.splitlines()) print(f"{msg.rstrip()}") if error_msg: print(f"{Style.BRIGHT}{str(reply_status.name)}:{Style.RESET_ALL} {error_msg}\n") else: print(f"{Style.BRIGHT}{str(reply_status.name)} {Style.RESET_ALL}\n") else: - print(f"{Fore.GREEN+Style.BRIGHT}{msg.rstrip()}{Style.RESET_ALL}") + print(f"{Fore.GREEN + Style.BRIGHT}{msg.rstrip()}{Style.RESET_ALL}") def unsolicited_msg_callback(self, msg): - print(Fore.CYAN+msg+Style.RESET_ALL) + print(Fore.CYAN + msg + Style.RESET_ALL) pass def set_workspace_transform(self, transform): @@ -282,8 +299,8 @@ def read_firmware_version(self): if ok != SerialInterface.ReplyStatus.OK or len(response) == 0: return 0, 0, 0 - major, minor, patch = map(int, re.match(r'v(\d+)\.(\d+)\.(\d+)', response).groups()) - return major,minor,patch + major, minor, patch = map(int, re.match(r"v(\d+)\.(\d+)\.(\d+)", response).groups()) + return major, minor, patch def home(self, axis_list=None): """ @@ -291,15 +308,15 @@ def home(self, axis_list=None): :param axis_list: Optional list of axis indices to home. If None, all axes are homed. :return: The status of the command (e.g. OK, ERROR, TIMEOUT). """ - cmd = 'G28' - axis_chars = ['A', 'B', 'C', 'D', 'E', 'F'] + cmd = "G28" + axis_chars = ["A", "B", "C", "D", "E", "F"] if axis_list is None: - axis_list = [i for i in range(len(axis_chars))] + axis_list = list(range(len(axis_chars))) for axis_idx in axis_list: - if 0 > axis_idx >= len(axis_chars): - raise ValueError('Axis index out of range') - cmd += ' '+axis_chars[axis_idx] + if (0 > axis_idx) or (axis_idx >= len(axis_chars)): + raise ValueError("Axis index out of range") + cmd += " " + axis_chars[axis_idx] res, msg = self.serial.send_command(cmd + "\n", 10) return res @@ -315,7 +332,8 @@ def calibrate_joint(self, joint_index: int, save_result: bool): :return: """ cmd = f"M56 J{joint_index} P" - if save_result: cmd += ' S' + if save_result: + cmd += " S" res, msg = self.serial.send_command(cmd, 30) calibration_data = self._parse_table_data(msg, 3) @@ -364,11 +382,13 @@ def set_max_acceleration(self, linear_accel, angular_accel): def wait_for_stop(self, polling_interval_ms=10, disable_callbacks=True): disable_message_callbacks_prev = self.disable_message_callbacks - if disable_callbacks: self.disable_message_callbacks = True + if disable_callbacks: + self.disable_message_callbacks = True while True: res, msg = self.serial.send_command("M53\n") - if res != SerialInterface.ReplyStatus.OK: return res + if res != SerialInterface.ReplyStatus.OK: + return res elif msg.strip() == "1": return SerialInterface.ReplyStatus.OK @@ -380,10 +400,7 @@ def read_current_position(self): return None, None, None # Match values with NO space between axis letter and number - match = re.search( - r"X([-+]?\d*\.?\d+)\s*Y([-+]?\d*\.?\d+)\s*Z([-+]?\d*\.?\d+)", - response - ) + match = re.search(r"X([-+]?\d*\.?\d+)\s*Y([-+]?\d*\.?\d+)\s*Z([-+]?\d*\.?\d+)", response) if not match: raise ValueError(f"Invalid format: {response}") @@ -406,7 +423,7 @@ def set_servo_parameter(self, pos_kp=150, pos_ki=50000, vel_kp=0.2, vel_ki=100, return res def enable_motors(self, enable): - cmd = f"M17" if enable else "M18" + cmd = "M17" if enable else "M18" res, msg = self.serial.send_command(cmd, timeout=5) return res @@ -415,11 +432,11 @@ def set_pose(self, x, y, z): transformed = self.workspace_transform @ np.array([x, y, z, 1.0]) x_t, y_t, z_t = transformed[:3] / transformed[3] - cmd = f"G24 X{x_t:.6f} Y{y_t:.6f} Z{z_t:.6f}" # TODO: A, B ,C + cmd = f"G24 X{x_t:.6f} Y{y_t:.6f} Z{z_t:.6f}" # TODO: A, B ,C res, msg = self.serial.send_command(cmd) return res - def send_command(self, cmd: str, timeout_s: float=5): + def send_command(self, cmd: str, timeout_s: float = 5): res, msg = self.serial.send_command(cmd, timeout_s) return res, msg @@ -429,7 +446,7 @@ def _parse_table_data(data_string, cols): data = [[] for _ in range(cols)] for line in data_string.strip().splitlines(): - parts = line.strip().split(',') + parts = line.strip().split(",") if len(parts) != cols: continue # skip malformed lines numbers = map(float, parts) diff --git a/software/PythonAPI/calibration_plotter.py b/software/PythonAPI/open_micro_stage_api/calibration_plotter.py similarity index 58% rename from software/PythonAPI/calibration_plotter.py rename to software/PythonAPI/open_micro_stage_api/calibration_plotter.py index c0e3f26..ff9e737 100644 --- a/software/PythonAPI/calibration_plotter.py +++ b/software/PythonAPI/open_micro_stage_api/calibration_plotter.py @@ -1,40 +1,42 @@ -from open_micro_stage_api import OpenMicroStageInterface import matplotlib.pyplot as plt -plt.rcParams['figure.dpi'] = 200 + +from open_micro_stage import OpenMicroStageInterface + +plt.rcParams["figure.dpi"] = 200 + def plot_calibration_data(ax_encoder_counts, ax_field_angel, label, data): # Plot on the provided Axes object if ax_encoder_counts is not None: ax_encoder_counts.plot(data[0], data[2], label=label) - ax_encoder_counts.set_xlabel('Motor Angle [rad]') - ax_encoder_counts.set_ylabel('Encoder Counts Raw') - ax_encoder_counts.set_title('Encoder Count Plot') + ax_encoder_counts.set_xlabel("Motor Angle [rad]") + ax_encoder_counts.set_ylabel("Encoder Counts Raw") + ax_encoder_counts.set_title("Encoder Count Plot") ax_encoder_counts.legend() ax_encoder_counts.grid(True) # Plot on the provided Axes object if ax_field_angel is not None: ax_field_angel.plot(data[0], data[1], label=label) - ax_field_angel.set_xlabel('Motor Angle [rad]') - ax_field_angel.set_ylabel('Motor Field Angle [rad]') - ax_field_angel.set_title('Field Angle Plot') + ax_field_angel.set_xlabel("Motor Angle [rad]") + ax_field_angel.set_ylabel("Motor Field Angle [rad]") + ax_field_angel.set_title("Field Angle Plot") ax_field_angel.legend() ax_field_angel.grid(True) + def main(): # create interface and connect oms = OpenMicroStageInterface(show_communication=True, show_log_messages=True) - oms.connect('/dev/ttyACM0') + oms.connect("/dev/ttyACM0") # Create subplots - fig, ax = plt.subplots(1, 1, figsize=(10, 7), sharex='all') + fig, ax = plt.subplots(1, 1, figsize=(10, 7), sharex="all") for i in range(3): res, data = oms.calibrate_joint(i, save_result=False) - plot_calibration_data(ax, None, f'Actuator {i}', data) + plot_calibration_data(ax, None, f"Actuator {i}", data) # Adjust layout and show plt.tight_layout() plt.show() - -main() \ No newline at end of file diff --git a/software/PythonAPI/pyproject.toml b/software/PythonAPI/pyproject.toml new file mode 100644 index 0000000..b6e4c0b --- /dev/null +++ b/software/PythonAPI/pyproject.toml @@ -0,0 +1,62 @@ +[project] +name = "open_micro_stage_api" +version = "0.1.0" +description = "Python API for controlling the Open Micro Stage manipulator" +readme = "README.md" +requires-python = ">=3.8" +license = "MIT" +authors = [ + {name = "MicroManipulatorStepper Contributors"} +] +keywords = ["micromanipulator", "stepper", "stage", "control", "robotics"] +classifiers = [ + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "Intended Audience :: Science/Research", +] +dependencies = [ + "numpy", + "pyserial", + "colorama", +] + +[build-system] +requires = ["setuptools>=65.0", "wheel"] +build-backend = "setuptools.build_meta" + +[project.optional-dependencies] +plotter = [ + "matplotlib", +] +dev = [ + "pytest>=7.0", + "pytest-cov", + "ruff", +] + +[project.urls] +Homepage = "https://github.com/HonakerM/MicroManipulatorStepper" +Repository = "https://github.com/HonakerM/MicroManipulatorStepper.git" +Documentation = "https://github.com/HonakerM/MicroManipulatorStepper" +Issues = "https://github.com/HonakerM/MicroManipulatorStepper/issues" + +[tool.setuptools] +packages = ["open_micro_stage_api"] + +[tool.ruff] +line-length = 120 +target-version = "py38" + +[tool.ruff.lint] +select = [ + "E", # pycodestyle errors + "W", # pycodestyle warnings + "F", # Pyflakes + "I", # isort + "B", # flake8-bugbear + "C4", # flake8-comprehensions + "UP", # pyupgrade +] +ignore = [ + "E501", # line too long (handled by formatter) +] \ No newline at end of file diff --git a/software/PythonAPI/requirements.txt b/software/PythonAPI/requirements.txt deleted file mode 100644 index 6e4d37a..0000000 --- a/software/PythonAPI/requirements.txt +++ /dev/null @@ -1,3 +0,0 @@ -numpy -pyserial -colorama diff --git a/software/PythonAPI/tests/__init__.py b/software/PythonAPI/tests/__init__.py new file mode 100644 index 0000000..cd71f58 --- /dev/null +++ b/software/PythonAPI/tests/__init__.py @@ -0,0 +1 @@ +"""Tests package for open_micro_stage.""" diff --git a/software/PythonAPI/tests/test_open_micro_stage.py b/software/PythonAPI/tests/test_open_micro_stage.py new file mode 100644 index 0000000..2ae0dac --- /dev/null +++ b/software/PythonAPI/tests/test_open_micro_stage.py @@ -0,0 +1,658 @@ +"""Tests for the OpenMicroStageInterface class.""" + +import unittest +from unittest.mock import patch + +import numpy as np + +from open_micro_stage_api.api import OpenMicroStageInterface, SerialInterface + +DESIRED_FIRMWARE_VERSION = "v1.0.1" + + +class MockSerialInterface: + """Generic mock implementation of SerialInterface for testing.""" + + ReplyStatus = SerialInterface.ReplyStatus + LogLevel = SerialInterface.LogLevel + + def __init__( + self, + port: str = "/dev/ttyACM0", + baud_rate: int = 115200, + command_msg_callback=None, + log_msg_callback=None, + unsolicited_msg_callback=None, + reconnect_timeout: int = 5, + ): + """Initialize mock serial interface.""" + self.port = port + self.baud_rate = baud_rate + self.reconnect_timeout = reconnect_timeout + self.command_msg_callback = command_msg_callback + self.log_message_callback = log_msg_callback + self.unsolicited_msg_callback = unsolicited_msg_callback + + # Command response queue + self.responses = [] + self.response_index = 0 + self.call_history = [] + + def set_response(self, status: SerialInterface.ReplyStatus, response: str): + """Set a single response for the next send_command call.""" + self.responses = [(status, response)] + self.response_index = 0 + + def set_responses(self, responses: list): + """Set multiple responses for sequential send_command calls. + + Args: + responses: List of (status, response) tuples + """ + self.responses = responses + self.response_index = 0 + + def send_command(self, cmd: str, timeout=2): + """Send a command and return a mocked response.""" + self.call_history.append((cmd, timeout)) + + if self.response_index >= len(self.responses): + # If we run out of responses, return the last one or OK + if self.responses: + status, response = self.responses[-1] + else: + status, response = SerialInterface.ReplyStatus.OK, "" + else: + status, response = self.responses[self.response_index] + self.response_index += 1 + + # Call the command callback if set + if self.command_msg_callback: + self.command_msg_callback(cmd, None, "") + + return status, response + + def close(self): + """Close the mock connection.""" + pass + + def reset(self): + """Reset mock state for a new test.""" + self.responses = [] + self.response_index = 0 + self.call_history = [] + + def get_last_command(self): + """Get the last command that was sent.""" + if self.call_history: + return self.call_history[-1][0] + return None + + def get_all_commands(self): + """Get all commands that were sent.""" + return [cmd for cmd, _ in self.call_history] + + def assert_command_called(self, cmd: str): + """Assert that a specific command was called.""" + if cmd not in self.get_all_commands(): + raise AssertionError(f"Command '{cmd}' was not called. Commands: {self.get_all_commands()}") + + def assert_command_called_with_args(self, cmd: str, timeout=None): + """Assert that a specific command was called with specific arguments.""" + for called_cmd, called_timeout in self.call_history: + if called_cmd == cmd and (timeout is None or called_timeout == timeout): + return + raise AssertionError( + f"Command '{cmd}' with timeout={timeout} was not called. Call history: {self.call_history}" + ) + + +class TestOpenMicroStageInterface(unittest.TestCase): + """End-to-end tests for OpenMicroStageInterface.""" + + def setUp(self): + """Set up test fixtures with mocked SerialInterface.""" + # Create a persistent mock instance + self.mock_serial_instance = MockSerialInterface() + + # Patch SerialInterface to return the same mock instance every time + self.patcher = patch("open_micro_stage.api.SerialInterface.__new__", return_value=self.mock_serial_instance) + self.patcher.start() + + # Create the interface + self.interface = OpenMicroStageInterface(show_communication=False, show_log_messages=False) + + # default connect for the majority of commands + self.mock_serial_instance.set_response( + SerialInterface.ReplyStatus.OK, + DESIRED_FIRMWARE_VERSION, + ) + self.interface.connect("/dev/ttyACM0") + self.calls_for_initailization = len(self.mock_serial_instance.call_history) + + def tearDown(self): + """Clean up patches.""" + self.patcher.stop() + + def test_initialization(self): + """Test that OpenMicroStageInterface initializes with correct defaults.""" + interface = OpenMicroStageInterface() + self.assertIsNone(interface.serial) + self.assertTrue(np.array_equal(interface.workspace_transform, np.eye(4))) + self.assertTrue(interface.show_communication) + self.assertTrue(interface.show_log_messages) + self.assertFalse(interface.disable_message_callbacks) + + def test_initialization_with_params(self): + """Test initialization with custom parameters.""" + interface = OpenMicroStageInterface(show_communication=False, show_log_messages=False) + self.assertFalse(interface.show_communication) + self.assertFalse(interface.show_log_messages) + + def test_connect_success(self): + """Test successful connection to device.""" + self.mock_serial_instance.set_response( + SerialInterface.ReplyStatus.OK, + DESIRED_FIRMWARE_VERSION, + ) + self.interface.connect("/dev/ttyACM0") + + self.assertIsNotNone(self.interface.serial) + self.assertEqual(self.interface.serial.port, "/dev/ttyACM0") + + def test_connect_with_custom_baud_rate(self): + """Test connection with custom baud rate.""" + self.mock_serial_instance.set_response( + SerialInterface.ReplyStatus.OK, + DESIRED_FIRMWARE_VERSION, + ) + + self.interface.connect("/dev/ttyACM0", baud_rate=115200) + + # Verify the mock was called with correct parameters + self.assertIsNotNone(self.interface.serial) + self.assertEqual(self.interface.serial.baud_rate, 115200) + + def test_connect_incompatible_firmware(self): + """Test connection fails with incompatible firmware version.""" + self.mock_serial_instance.set_response( + SerialInterface.ReplyStatus.OK, + "v0.9.0", + ) + + self.interface.connect("/dev/ttyACM0") + + # Serial should be set to None on incompatible version + self.assertIsNone(self.interface.serial) + + def test_disconnect(self): + """Test disconnection from device.""" + self.mock_serial_instance.set_response( + SerialInterface.ReplyStatus.OK, + DESIRED_FIRMWARE_VERSION, + ) + + self.interface.connect("/dev/ttyACM0") + self.interface.disconnect() + + self.assertIsNone(self.interface.serial) + + def test_disconnect_when_not_connected(self): + """Test disconnect gracefully handles when not connected.""" + # Should not raise exception + self.interface.disconnect() + self.assertIsNone(self.interface.serial) + + def test_set_and_get_workspace_transform(self): + """Test setting and getting workspace transform.""" + transform = np.array([[1, 0, 0, 1], [0, 1, 0, 2], [0, 0, 1, 3], [0, 0, 0, 1]]) + + self.interface.set_workspace_transform(transform) + result = self.interface.get_workspace_transform() + + self.assertTrue(np.array_equal(result, transform)) + + def test_read_firmware_version(self): + """Test reading firmware version.""" + self.interface.serial.set_response( + SerialInterface.ReplyStatus.OK, + "v1.2.3", + ) + + major, minor, patch = self.interface.read_firmware_version() + + self.assertEqual(major, 1) + self.assertEqual(minor, 2) + self.assertEqual(patch, 3) + self.interface.serial.assert_command_called("M58") + + def test_read_firmware_version_error(self): + """Test firmware version returns 0,0,0 on error.""" + self.interface.serial.set_response( + SerialInterface.ReplyStatus.ERROR, + "", + ) + + major, minor, patch = self.interface.read_firmware_version() + + self.assertEqual((major, minor, patch), (0, 0, 0)) + + def test_home_all_axes(self): + """Test homing all axes.""" + self.interface.serial.set_response( + SerialInterface.ReplyStatus.OK, + "", + ) + + result = self.interface.home() + + self.assertEqual(result, SerialInterface.ReplyStatus.OK) + self.interface.serial.assert_command_called("G28 A B C D E F\n") + + def test_home_specific_axes(self): + """Test homing specific axes.""" + self.interface.serial.set_response( + SerialInterface.ReplyStatus.OK, + "", + ) + + result = self.interface.home(axis_list=[0, 2]) + + self.assertEqual(result, SerialInterface.ReplyStatus.OK) + self.interface.serial.assert_command_called("G28 A C\n") + + def test_home_invalid_axis(self): + """Test homing with invalid axis index raises error.""" + with self.assertRaises(ValueError): + self.interface.home(axis_list=[10]) + + def test_calibrate_joint_no_save(self): + """Test calibrating a joint without saving results.""" + calibration_response = "0.5,1.0,100\n1.0,2.0,200\n1.5,3.0,300\n" + self.interface.serial.set_response( + SerialInterface.ReplyStatus.OK, + calibration_response, + ) + + result, data = self.interface.calibrate_joint(0, save_result=False) + + self.assertEqual(result, SerialInterface.ReplyStatus.OK) + self.assertEqual(len(data), 3) + self.assertEqual(data[0], [0.5, 1.0, 1.5]) # motor angles + self.assertEqual(data[1], [1.0, 2.0, 3.0]) # field angles + self.assertEqual(data[2], [100, 200, 300]) # encoder counts + self.interface.serial.assert_command_called("M56 J0 P") + + def test_calibrate_joint_with_save(self): + """Test calibrating a joint with saving results.""" + calibration_response = "0.5,1.0,100\n" + self.interface.serial.set_response( + SerialInterface.ReplyStatus.OK, + calibration_response, + ) + + result, data = self.interface.calibrate_joint(1, save_result=True) + + self.assertEqual(result, SerialInterface.ReplyStatus.OK) + self.interface.serial.assert_command_called("M56 J1 P S") + + def test_read_current_position(self): + """Test reading current position.""" + self.interface.serial.set_response( + SerialInterface.ReplyStatus.OK, + "X10.5 Y20.3 Z15.8", + ) + + x, y, z = self.interface.read_current_position() + + self.assertAlmostEqual(x, 10.5) + self.assertAlmostEqual(y, 20.3) + self.assertAlmostEqual(z, 15.8) + self.interface.serial.assert_command_called("M50") + + def test_read_current_position_error(self): + """Test reading current position returns None on error.""" + self.interface.serial.set_response( + SerialInterface.ReplyStatus.ERROR, + "", + ) + + x, y, z = self.interface.read_current_position() + + self.assertIsNone(x) + self.assertIsNone(y) + self.assertIsNone(z) + + def test_read_current_position_invalid_format(self): + """Test reading current position with invalid format raises error.""" + self.interface.serial.set_response( + SerialInterface.ReplyStatus.OK, + "invalid format", + ) + + with self.assertRaises(ValueError): + self.interface.read_current_position() + + def test_move_to_immediate(self): + """Test moving to position with immediate execution.""" + self.interface.serial.set_response( + SerialInterface.ReplyStatus.OK, + "", + ) + + result = self.interface.move_to(5.0, 10.0, 15.0, f=20.0, move_immediately=True) + + self.assertEqual(result, SerialInterface.ReplyStatus.OK) + cmd = self.interface.serial.get_last_command() + self.assertIn("G0 X5.000000 Y10.000000 Z15.000000 F20.000", cmd) + self.assertIn("I", cmd) + + def test_move_to_with_workspace_transform(self): + """Test move_to applies workspace transform correctly.""" + # Set a simple translation transform + transform = np.array([[1, 0, 0, 2], [0, 1, 0, 3], [0, 0, 1, 4], [0, 0, 0, 1]]) + self.interface.set_workspace_transform(transform) + + self.interface.serial.set_response( + SerialInterface.ReplyStatus.OK, + "", + ) + + result = self.interface.move_to(0, 0, 0, f=10.0) + + # Expected transformed position is (2, 3, 4) + cmd = self.interface.serial.get_last_command() + self.assertIn("X2.000000", cmd) + self.assertIn("Y3.000000", cmd) + self.assertIn("Z4.000000", cmd) + + def test_move_to_blocking_busy_retry(self): + """Test move_to retries on BUSY when blocking is True.""" + self.interface.serial.set_responses( + [ + (SerialInterface.ReplyStatus.BUSY, ""), + (SerialInterface.ReplyStatus.OK, ""), + ] + ) + + result = self.interface.move_to(5.0, 10.0, 15.0, f=20.0, blocking=True, timeout=0.01) + + self.assertEqual(result, SerialInterface.ReplyStatus.OK) + self.assertEqual(len(self.interface.serial.call_history), 2 + self.calls_for_initailization) + + def test_move_to_non_blocking_returns_busy(self): + """Test move_to returns BUSY immediately when blocking is False.""" + self.interface.serial.set_response( + SerialInterface.ReplyStatus.BUSY, + "", + ) + + result = self.interface.move_to(5.0, 10.0, 15.0, f=20.0, blocking=False) + + self.assertEqual(result, SerialInterface.ReplyStatus.BUSY) + self.assertEqual(len(self.interface.serial.call_history), 1 + self.calls_for_initailization) + + def test_dwell(self): + """Test dwell command.""" + self.interface.serial.set_response( + SerialInterface.ReplyStatus.OK, + "", + ) + + result = self.interface.dwell(time_s=2.5, blocking=True) + + self.assertEqual(result, SerialInterface.ReplyStatus.OK) + cmd = self.interface.serial.get_last_command() + self.assertIn("G4 S2.500000", cmd) + + def test_set_max_acceleration(self): + """Test setting max acceleration.""" + self.interface.serial.set_response( + SerialInterface.ReplyStatus.OK, + "", + ) + + result = self.interface.set_max_acceleration(linear_accel=100.0, angular_accel=50.0) + + self.assertEqual(result, SerialInterface.ReplyStatus.OK) + cmd = self.interface.serial.get_last_command() + self.assertIn("M204 L100.000000 A50.000000", cmd) + + def test_set_max_acceleration_minimum_values(self): + """Test max acceleration enforces minimum values.""" + self.interface.serial.set_response( + SerialInterface.ReplyStatus.OK, + "", + ) + + self.interface.set_max_acceleration(linear_accel=0.001, angular_accel=0.001) + + cmd = self.interface.serial.get_last_command() + # Should be clamped to 0.01 + self.assertIn("M204 L0.010000 A0.010000", cmd) + + def test_wait_for_stop_ready(self): + """Test wait_for_stop returns when device is ready.""" + self.interface.serial.set_response( + SerialInterface.ReplyStatus.OK, + "1", + ) + + result = self.interface.wait_for_stop() + + self.assertEqual(result, SerialInterface.ReplyStatus.OK) + self.interface.serial.assert_command_called("M53\n") + + def test_wait_for_stop_polls_until_ready(self): + """Test wait_for_stop polls until device is ready.""" + self.interface.serial.set_responses( + [ + (SerialInterface.ReplyStatus.OK, "0"), + (SerialInterface.ReplyStatus.OK, "0"), + (SerialInterface.ReplyStatus.OK, "1"), + ] + ) + + result = self.interface.wait_for_stop() + + self.assertEqual(result, SerialInterface.ReplyStatus.OK) + self.assertEqual(len(self.interface.serial.call_history), 3 + self.calls_for_initailization) + + def test_wait_for_stop_error(self): + """Test wait_for_stop returns error status.""" + self.interface.serial.set_response( + SerialInterface.ReplyStatus.ERROR, + "", + ) + + result = self.interface.wait_for_stop() + + self.assertEqual(result, SerialInterface.ReplyStatus.ERROR) + + def test_enable_motors(self): + """Test enabling motors.""" + self.interface.serial.set_response( + SerialInterface.ReplyStatus.OK, + "", + ) + + result = self.interface.enable_motors(enable=True) + + self.assertEqual(result, SerialInterface.ReplyStatus.OK) + self.interface.serial.assert_command_called_with_args("M17", timeout=5) + + def test_disable_motors(self): + """Test disabling motors.""" + self.interface.serial.set_response( + SerialInterface.ReplyStatus.OK, + "", + ) + + result = self.interface.enable_motors(enable=False) + + self.assertEqual(result, SerialInterface.ReplyStatus.OK) + self.interface.serial.assert_command_called_with_args("M18", timeout=5) + + def test_set_pose(self): + """Test setting pose.""" + self.interface.serial.set_response( + SerialInterface.ReplyStatus.OK, + "", + ) + + result = self.interface.set_pose(x=5.0, y=10.0, z=15.0) + + self.assertEqual(result, SerialInterface.ReplyStatus.OK) + cmd = self.interface.serial.get_last_command() + self.assertIn("G24 X5.000000 Y10.000000 Z15.000000", cmd) + + def test_send_custom_command(self): + """Test sending custom command.""" + self.interface.serial.set_response( + SerialInterface.ReplyStatus.OK, + "response data", + ) + + result, response = self.interface.send_command("M57", timeout_s=3.0) + + self.assertEqual(result, SerialInterface.ReplyStatus.OK) + self.assertEqual(response, "response data") + + def test_read_device_state_info(self): + """Test reading device state info.""" + self.interface.serial.set_response( + SerialInterface.ReplyStatus.OK, + "state info", + ) + + result = self.interface.read_device_state_info() + + self.assertEqual(result, SerialInterface.ReplyStatus.OK) + self.interface.serial.assert_command_called("M57") + + def test_set_servo_parameter_defaults(self): + """Test setting servo parameters with defaults.""" + self.interface.serial.set_response( + SerialInterface.ReplyStatus.OK, + "", + ) + + result = self.interface.set_servo_parameter() + + self.assertEqual(result, SerialInterface.ReplyStatus.OK) + cmd = self.interface.serial.get_last_command() + self.assertIn("M55", cmd) + self.assertIn("A150.000000", cmd) # pos_kp + self.assertIn("B50000.000000", cmd) # pos_ki + self.assertIn("C0.200000", cmd) # vel_kp + self.assertIn("D100.000000", cmd) # vel_ki + self.assertIn("F0.002500", cmd) # vel_filter_tc + + def test_set_servo_parameter_custom(self): + """Test setting servo parameters with custom values.""" + self.interface.serial.set_response( + SerialInterface.ReplyStatus.OK, + "", + ) + + result = self.interface.set_servo_parameter( + pos_kp=200, pos_ki=60000, vel_kp=0.3, vel_ki=120, vel_filter_tc=0.003 + ) + + self.assertEqual(result, SerialInterface.ReplyStatus.OK) + cmd = self.interface.serial.get_last_command() + self.assertIn("A200.000000", cmd) + self.assertIn("B60000.000000", cmd) + self.assertIn("C0.300000", cmd) + self.assertIn("D120.000000", cmd) + self.assertIn("F0.003000", cmd) + + def test_read_encoder_angles(self): + """Test reading encoder angles returns empty list.""" + self.interface.serial.set_response( + SerialInterface.ReplyStatus.OK, + "", + ) + + result = self.interface.read_encoder_angles() + + self.assertEqual(result, []) + self.interface.serial.assert_command_called("M51") + + def test_parse_table_data(self): + """Test parsing table data.""" + data_string = "1.0,2.0,3.0\n4.0,5.0,6.0\n7.0,8.0,9.0" + result = OpenMicroStageInterface._parse_table_data(data_string, 3) + + self.assertEqual(len(result), 3) + self.assertEqual(result[0], [1.0, 4.0, 7.0]) + self.assertEqual(result[1], [2.0, 5.0, 8.0]) + self.assertEqual(result[2], [3.0, 6.0, 9.0]) + + def test_parse_table_data_with_malformed_lines(self): + """Test parsing table data skips malformed lines.""" + data_string = "1.0,2.0,3.0\ninvalid\n4.0,5.0,6.0" + result = OpenMicroStageInterface._parse_table_data(data_string, 3) + + self.assertEqual(len(result), 3) + self.assertEqual(result[0], [1.0, 4.0]) + self.assertEqual(result[1], [2.0, 5.0]) + self.assertEqual(result[2], [3.0, 6.0]) + + def test_parse_table_data_single_row(self): + """Test parsing single row of data.""" + data_string = "10.5,20.3,15.8" + result = OpenMicroStageInterface._parse_table_data(data_string, 3) + + self.assertEqual(result[0], [10.5]) + self.assertEqual(result[1], [20.3]) + self.assertEqual(result[2], [15.8]) + + def test_workflow_connect_home_move_stop(self): + """Test end-to-end workflow: connect, home, move, wait for stop.""" + # Set responses for each command in sequence + self.mock_serial_instance.set_responses( + [ + (SerialInterface.ReplyStatus.OK, DESIRED_FIRMWARE_VERSION), # firmware version + (SerialInterface.ReplyStatus.OK, ""), # home + (SerialInterface.ReplyStatus.OK, ""), # move_to + (SerialInterface.ReplyStatus.OK, "1"), # wait_for_stop + ] + ) + + self.interface.connect("/dev/ttyACM0") + self.assertIsNotNone(self.interface.serial) + + home_result = self.interface.home() + self.assertEqual(home_result, SerialInterface.ReplyStatus.OK) + + move_result = self.interface.move_to(5.0, 10.0, 15.0, f=20.0) + self.assertEqual(move_result, SerialInterface.ReplyStatus.OK) + + stop_result = self.interface.wait_for_stop() + self.assertEqual(stop_result, SerialInterface.ReplyStatus.OK) + + def test_workflow_calibrate_and_move(self): + """Test end-to-end workflow: calibrate joint and then move.""" + calibration_data = "0.5,1.0,100\n1.0,2.0,200\n" + + self.mock_serial_instance.set_responses( + [ + (SerialInterface.ReplyStatus.OK, calibration_data), # calibrate + (SerialInterface.ReplyStatus.OK, "X5.0 Y10.0 Z15.0"), # read position + ] + ) + + result, data = self.interface.calibrate_joint(0, save_result=True) + self.assertEqual(result, SerialInterface.ReplyStatus.OK) + self.assertEqual(len(data), 3) + + x, y, z = self.interface.read_current_position() + self.assertAlmostEqual(x, 5.0) + self.assertAlmostEqual(y, 10.0) + self.assertAlmostEqual(z, 15.0) + + +# Manually run unittest +if __name__ == "__main__": + unittest.main()