From 683217934fd962357909dd675a0e5c85419098fc Mon Sep 17 00:00:00 2001 From: Oli Wenman Date: Tue, 9 Jun 2026 13:24:38 +0000 Subject: [PATCH] Update undulator logic to use movable_logic --- .../insertion_device/apple2_undulator.py | 216 +++++++++++------- src/dodal/testing/fixtures/devices/apple2.py | 13 +- .../insertion_device/test_apple2_undulator.py | 28 +-- tests/devices/insertion_device/test_energy.py | 2 +- 4 files changed, 159 insertions(+), 100 deletions(-) diff --git a/src/dodal/devices/insertion_device/apple2_undulator.py b/src/dodal/devices/insertion_device/apple2_undulator.py index 1bc30b5f8b0..6a1daf8a912 100644 --- a/src/dodal/devices/insertion_device/apple2_undulator.py +++ b/src/dodal/devices/insertion_device/apple2_undulator.py @@ -7,19 +7,16 @@ import numpy as np from bluesky.protocols import Movable from ophyd_async.core import ( - DEFAULT_TIMEOUT, AsyncStatus, - Device, FlyMotorInfo, MovableLogic, Reference, SignalR, + SignalRW, SignalW, StandardReadable, StandardReadableFormat, - WatchableAsyncStatus, - WatcherUpdate, - observe_value, + derived_signal_rw, wait_for_value, ) from ophyd_async.epics.core import epics_signal_r, epics_signal_rw @@ -55,41 +52,46 @@ def extract_phase_val(self): return self.phase -async def estimate_motor_timeout( - setpoint: SignalR, curr_pos: SignalR, velocity: SignalR +async def estimate_motor_timeout_from_sigs( + current_pos: SignalR, new_pos: SignalR, velocity: SignalR ): vel = await velocity.get_value() - cur_pos = await curr_pos.get_value() - target_pos = float(await setpoint.get_value()) + cur_pos = await current_pos.get_value() + target_pos = float(await new_pos.get_value()) return abs((target_pos - cur_pos) * 2.0 / vel) + DEFAULT_MOTOR_MIN_TIMEOUT -class UndulatorBase(abc.ABC, Device, Generic[T]): +def estimate_motor_timeout( + current_pos: float, new_pos: float, velocity: float +) -> float: + return abs((new_pos - current_pos) * 2.0 / velocity) + DEFAULT_MOTOR_MIN_TIMEOUT + + +class UndulatorBase(abc.ABC, Generic[T]): """Abstract base class for Apple2 undulator devices. This class provides common functionality for undulator devices, including gate and status signal management, safety checks before motion, and abstract methods for setting demand positions and estimating move timeouts. """ - def __init__(self, name: str = ""): + def __init__(self): # Gate keeper open when move is requested, closed when move is completed self.gate: SignalR[UndulatorGateStatus] self.status: SignalR[EnabledDisabledUpper] - super().__init__(name=name) @abc.abstractmethod async def set_demand_positions(self, value: T) -> None: """Set the demand positions on the device without actually hitting move.""" @abc.abstractmethod - async def get_timeout(self) -> float: + async def get_timeout_for_apple2(self) -> float: """Get the timeout for the move based on an estimate of how long it will take.""" async def raise_if_cannot_move(self) -> None: if await self.status.get_value() is EnabledDisabledUpper.DISABLED: - raise RuntimeError(f"{self.name} is DISABLED and cannot move.") + raise RuntimeError(f"{self.status.name} is DISABLED and cannot move.") if await self.gate.get_value() is UndulatorGateStatus.OPEN: - raise RuntimeError(f"{self.name} is already in motion.") + raise RuntimeError(f"{self.status.name} is already in motion.") class SafeUndulatorMover(StandardReadable, UndulatorBase, Generic[T]): @@ -109,7 +111,7 @@ async def set(self, value: T) -> None: LOGGER.info(f"Setting {self.name} to {value}") await self.raise_if_cannot_move() await self.set_demand_positions(value) - timeout = await self.get_timeout() + timeout = await self.get_timeout_for_apple2() LOGGER.info(f"Moving {self.name} to {value} with timeout = {timeout}") await self.set_move.set(value=1, timeout=timeout) await wait_for_value(self.gate, UndulatorGateStatus.CLOSE, timeout=timeout) @@ -145,7 +147,65 @@ def movable_logic(self) -> MovableLogic: ) -class GapSafeMotorNoStop(UnstoppableMotor, UndulatorBase[float]): +@dataclass +class GapSafeMotorMoveLogic(UnstoppableMotorMoveLogic, UndulatorBase[float]): + gate: SignalR[UndulatorGateStatus] + status: SignalR[EnabledDisabledUpper] + set_move: SignalW[int] + + async def check_move(self, new_position: float) -> None: + await super().check_move(new_position) + await self.raise_if_cannot_move() + + async def move(self, new_position: float, timeout: float | None) -> None: + await self.set_demand_positions(new_position) + await self.set_move.set(1, timeout=timeout) + + await wait_for_value( + self.gate, + UndulatorGateStatus.CLOSE, + timeout=timeout, + ) + + async def calculate_timeout( + self, old_position: float, new_position: float + ) -> float: + vel = await self.velocity.get_value() + return estimate_motor_timeout(old_position, new_position, vel) + + async def get_timeout_for_apple2(self) -> float: + return await estimate_motor_timeout_from_sigs( + self.readback, self.setpoint, self.velocity + ) + + async def set_demand_positions(self, value: float) -> None: + await self.setpoint.set(value) + + +class UserSetpointWrapperUnstoppableMotor(UnstoppableMotor): + """Replace the motor setpoint with a derived signal user_setpoint. Used when the raw + underlying signal is a str rather than a float and the conversion is handled via + the derived signal so it works seemlessly like a normal motor using float. This + allows for plans and devices interacting with this device not needing to worry about + type checking or converting the values. + """ + + user_setpoint_str: SignalRW[str] + + def __init__(self, prefix: str, name: str = ""): + super().__init__(prefix, name) + self.user_setpoint = derived_signal_rw( + self._get_setpoint, self._set_setpoint, setpoint_str=self.user_setpoint_str + ) + + async def _set_setpoint(self, value: float) -> None: + await self.user_setpoint_str.set(str(value)) + + def _get_setpoint(self, setpoint_str: str) -> float: + return float(setpoint_str) + + +class GapSafeMotorNoStop(UserSetpointWrapperUnstoppableMotor): """Update gap safe motor that checks it's safe to move before moving.""" def __init__(self, set_move: SignalW[int], prefix: str, name: str = ""): @@ -155,42 +215,8 @@ def __init__(self, set_move: SignalW[int], prefix: str, name: str = ""): self.set_move = set_move super().__init__(prefix=prefix + "BLGAPMTR", name=name) - @WatchableAsyncStatus.wrap - async def set(self, new_position: float, timeout=DEFAULT_TIMEOUT): - ( - old_position, - units, - precision, - ) = await asyncio.gather( - self.user_setpoint.get_value(), - self.motor_egu.get_value(), - self.precision.get_value(), - ) - LOGGER.info(f"Setting {self.name} to {new_position}") - await self.raise_if_cannot_move() - await self.set_demand_positions(new_position) - timeout = await self.get_timeout() - LOGGER.info(f"Moving {self.name} to {new_position} with timeout = {timeout}") - - await self.set_move.set(value=1, timeout=timeout) - move_status = AsyncStatus( - wait_for_value(self.gate, UndulatorGateStatus.CLOSE, timeout=timeout) - ) - - async for current_position in observe_value( - self.user_readback, done_status=move_status - ): - yield WatcherUpdate( - current=current_position, - initial=old_position, - target=new_position, - name=self.name, - unit=units, - precision=precision, - ) - -class UndulatorGap(GapSafeMotorNoStop): +class UndulatorGap(GapSafeMotorNoStop, UndulatorBase): """Apple 2 undulator gap motor device. With PV corrections. Args: @@ -201,11 +227,13 @@ class UndulatorGap(GapSafeMotorNoStop): def __init__(self, prefix: str, name: str = ""): self.set_move = epics_signal_rw(int, prefix + "BLGSETP") # Nothing move until this is set to 1 and it will return to 0 when done. + + self.user_setpoint_str = epics_signal_rw(str, prefix + "BLGSET") super().__init__(self.set_move, prefix, name) self.max_velocity = epics_signal_r(float, prefix + "BLGSETVEL.HOPR") self.min_velocity = epics_signal_r(float, prefix + "BLGSETVEL.LOPR") - self.user_setpoint = epics_signal_rw(str, prefix + "BLGSET") + """ Clear the motor config_signal as we need new PV for velocity.""" self._describe_config_funcs = () self._read_config_funcs = () @@ -233,16 +261,31 @@ async def prepare(self, value: FlyMotorInfo) -> None: ) await super().prepare(value) - async def get_timeout(self) -> float: - return await estimate_motor_timeout( - self.user_setpoint, self.user_readback, self.velocity - ) + async def get_timeout_for_apple2(self) -> float: + return await self.movable_logic.get_timeout_for_apple2() async def set_demand_positions(self, value: float) -> None: - await self.user_setpoint.set(str(value)) + return await self.movable_logic.set_demand_positions(value) + @cached_property + def movable_logic(self) -> GapSafeMotorMoveLogic: + return GapSafeMotorMoveLogic( + readback=self.user_readback, + setpoint=self.user_setpoint, + low_limit_travel=self.low_limit_travel, + high_limit_travel=self.high_limit_travel, + motor_stop=None, # type: ignore + dial_low_limit_travel=self.dial_low_limit_travel, + dial_high_limit_travel=self.dial_high_limit_travel, + velocity=self.velocity, + acceleration_time=self.acceleration_time, + gate=self.gate, + status=self.status, + set_move=self.set_move, + ) -class UndulatorPhaseMotor(UnstoppableMotor): + +class UndulatorPhaseMotor(UserSetpointWrapperUnstoppableMotor): """Phase motor that will not stop. Args: @@ -252,9 +295,9 @@ class UndulatorPhaseMotor(UnstoppableMotor): def __init__(self, prefix: str, name: str = ""): motor_pv = f"{prefix}MTR" - super().__init__(prefix=motor_pv, name=name) - self.user_setpoint = epics_signal_rw(str, prefix + "SET") + self.user_setpoint_str = epics_signal_rw(str, prefix + "SET") self.user_setpoint_readback = epics_signal_r(float, prefix + "DMD") + super().__init__(prefix=motor_pv, name=name) Apple2PhaseValType = TypeVar("Apple2PhaseValType", bound=Apple2LockedPhasesVal) @@ -281,19 +324,19 @@ def __init__( async def set_demand_positions(self, value: Apple2PhaseValType) -> None: await asyncio.gather( - self.top_outer.user_setpoint.set(value=str(value.top_outer)), - self.btm_inner.user_setpoint.set(value=str(value.btm_inner)), + self.top_outer.user_setpoint.set(value.top_outer), + self.btm_inner.user_setpoint.set(value.btm_inner), ) - async def get_timeout(self) -> float: + async def get_timeout_for_apple2(self) -> float: """Get all motor speed, current positions and target positions to calculate required timeout. """ timeouts = await asyncio.gather( *[ - estimate_motor_timeout( - axis.user_setpoint_readback, + estimate_motor_timeout_from_sigs( axis.user_readback, + axis.user_setpoint_readback, axis.velocity, ) for axis in self.axes @@ -334,10 +377,10 @@ def __init__( async def set_demand_positions(self, value: Apple2PhasesVal) -> None: await asyncio.gather( - self.top_outer.user_setpoint.set(value=str(value.top_outer)), - self.top_inner.user_setpoint.set(value=str(value.top_inner)), - self.btm_inner.user_setpoint.set(value=str(value.btm_inner)), - self.btm_outer.user_setpoint.set(value=str(value.btm_outer)), + self.top_outer.user_setpoint.set(value.top_outer), + self.top_inner.user_setpoint.set(value.top_inner), + self.btm_inner.user_setpoint.set(value.btm_inner), + self.btm_outer.user_setpoint.set(value.btm_outer), ) @@ -362,17 +405,27 @@ def __init__( super().__init__(self.set_move, prefix, name) async def set_demand_positions(self, value: float) -> None: - await self.jaw_phase.user_setpoint.set(value=str(value)) - - async def get_timeout(self) -> float: + await self.jaw_phase.user_setpoint.set(value) + + # async def get_timeout_for_apple2(self) -> float: + # """Get motor speed, current position and target position to calculate required + # timeout. + # """ + # return await estimate_motor_timeout( + # self.jaw_phase.user_setpoint_readback, + # self.jaw_phase.user_readback, + # self.jaw_phase.velocity, + # ) + + async def get_timeout_for_apple2(self) -> float: """Get motor speed, current position and target position to calculate required timeout. """ - return await estimate_motor_timeout( - self.jaw_phase.user_setpoint_readback, - self.jaw_phase.user_readback, - self.jaw_phase.velocity, + readback, setpoint = await asyncio.gather( + self.jaw_phase.user_readback.get_value(), + self.jaw_phase.user_setpoint_readback.get_value(), ) + return await self.jaw_phase.movable_logic.calculate_timeout(readback, setpoint) PhaseAxesType = TypeVar("PhaseAxesType", bound=UndulatorLockedPhaseAxes) @@ -404,16 +457,19 @@ async def set(self, id_motor_values: Apple2Val) -> None: them all at the same time. """ # Only need to check gap as the phase motors share both status and gate with gap. - await self.gap().raise_if_cannot_move() + await self.gap().movable_logic.raise_if_cannot_move() await asyncio.gather( self.phase().set_demand_positions( value=id_motor_values.extract_phase_val() ), - self.gap().set_demand_positions(value=float(id_motor_values.gap)), + self.gap().set_demand_positions(id_motor_values.gap), ) timeout = np.max( - await asyncio.gather(self.gap().get_timeout(), self.phase().get_timeout()) + await asyncio.gather( + self.gap().get_timeout_for_apple2(), + self.phase().get_timeout_for_apple2(), + ) ) LOGGER.info( f"Moving {self.name} apple2 motors to {id_motor_values}, timeout = {timeout}" diff --git a/src/dodal/testing/fixtures/devices/apple2.py b/src/dodal/testing/fixtures/devices/apple2.py index a8d06cedc14..e79fb2ad612 100644 --- a/src/dodal/testing/fixtures/devices/apple2.py +++ b/src/dodal/testing/fixtures/devices/apple2.py @@ -38,12 +38,12 @@ def my_side_effect(file_path, reset_cached_result) -> str: @pytest.fixture async def mock_id_gap(prefix: str = "BLXX-EA-DET-007:") -> UndulatorGap: async with init_devices(mock=True): - mock_id_gap = UndulatorGap(prefix, "mock_id_gap") + mock_id_gap = UndulatorGap(prefix) assert mock_id_gap.name == "mock_id_gap" set_mock_value(mock_id_gap.gate, UndulatorGateStatus.CLOSE) set_mock_value(mock_id_gap.velocity, 1) set_mock_value(mock_id_gap.user_readback, 1) - set_mock_value(mock_id_gap.user_setpoint, "1") + set_mock_value(mock_id_gap.user_setpoint_str, "1") set_mock_value(mock_id_gap.status, EnabledDisabledUpper.ENABLED) return mock_id_gap @@ -109,8 +109,9 @@ async def mock_locked_apple2( mock_id_gap: UndulatorGap, mock_locked_phase_axes: UndulatorLockedPhaseAxes, ) -> Apple2[UndulatorLockedPhaseAxes]: - mock_locked_apple2 = Apple2[UndulatorLockedPhaseAxes]( - id_gap=mock_id_gap, - id_phase=mock_locked_phase_axes, - ) + with init_devices(mock=True): + mock_locked_apple2 = Apple2[UndulatorLockedPhaseAxes]( + id_gap=mock_id_gap, + id_phase=mock_locked_phase_axes, + ) return mock_locked_apple2 diff --git a/tests/devices/insertion_device/test_apple2_undulator.py b/tests/devices/insertion_device/test_apple2_undulator.py index 80820441057..7fae6242f6a 100644 --- a/tests/devices/insertion_device/test_apple2_undulator.py +++ b/tests/devices/insertion_device/test_apple2_undulator.py @@ -89,19 +89,21 @@ async def test_gap_cal_timout( ): set_mock_value(mock_id_gap.velocity, velocity) set_mock_value(mock_id_gap.user_readback, readback) - set_mock_value(mock_id_gap.user_setpoint, str(target)) + set_mock_value(mock_id_gap.user_setpoint_str, str(target)) - assert await mock_id_gap.get_timeout() == pytest.approx(expected_timeout, rel=0.1) + assert await mock_id_gap.get_timeout_for_apple2() == pytest.approx( + expected_timeout, rel=0.1 + ) async def test_given_gate_never_closes_then_setting_gaps_times_out( mock_id_gap: UndulatorGap, ): callback_on_mock_put( - mock_id_gap.user_setpoint, + mock_id_gap.user_setpoint_str, lambda *_, **__: set_mock_value(mock_id_gap.gate, UndulatorGateStatus.OPEN), ) - mock_id_gap.get_timeout = AsyncMock(return_value=0.002) + mock_id_gap.movable_logic.get_timeout_for_apple2 = AsyncMock(return_value=0.002) with pytest.raises(TimeoutError): await mock_id_gap.set(2) @@ -150,7 +152,7 @@ async def test_given_gate_never_closes_then_setting_phases_times_out( mock_phase_axes.top_outer.user_setpoint, lambda *_, **__: set_mock_value(mock_phase_axes.gate, UndulatorGateStatus.OPEN), ) - mock_phase_axes.get_timeout = AsyncMock(return_value=0.002) + mock_phase_axes.get_timeout_for_apple2 = AsyncMock(return_value=0.002) with pytest.raises(TimeoutError): await mock_phase_axes.set(set_value) @@ -190,7 +192,7 @@ async def test_gap_prepare_success(mock_id_gap: UndulatorGap): set_mock_value(mock_id_gap.acceleration_time, 0.5) fly_info = FlyMotorInfo(start_position=25, end_position=35, time_for_move=1) await mock_id_gap.prepare(fly_info) - get_mock_put(mock_id_gap.user_setpoint).assert_awaited_once_with( + get_mock_put(mock_id_gap.user_setpoint_str).assert_awaited_once_with( str(fly_info.ramp_up_start_pos(0.5)) ) @@ -248,7 +250,7 @@ async def test_phase_cal_timout( set_mock_value(mock_phase_axes.btm_inner.user_setpoint_readback, target[2]) set_mock_value(mock_phase_axes.btm_outer.user_setpoint_readback, target[3]) - assert await mock_phase_axes.get_timeout() == pytest.approx( + assert await mock_phase_axes.get_timeout_for_apple2() == pytest.approx( expected_timeout, rel=0.01 ) @@ -284,16 +286,16 @@ def set_complete_move(): callback_on_mock_put(mock_phase_axes.set_move, lambda *_, **__: set_complete_move()) run_engine(bps.abs_set(mock_phase_axes, set_value, wait=True)) get_mock_put(mock_phase_axes.set_move).assert_called_once_with(1) - get_mock_put(mock_phase_axes.top_inner.user_setpoint).assert_called_once_with( + get_mock_put(mock_phase_axes.top_inner.user_setpoint_str).assert_called_once_with( str(set_value.top_inner) ) - get_mock_put(mock_phase_axes.top_outer.user_setpoint).assert_called_once_with( + get_mock_put(mock_phase_axes.top_outer.user_setpoint_str).assert_called_once_with( str(set_value.top_outer) ) - get_mock_put(mock_phase_axes.btm_inner.user_setpoint).assert_called_once_with( + get_mock_put(mock_phase_axes.btm_inner.user_setpoint_str).assert_called_once_with( str(set_value.btm_inner) ) - get_mock_put(mock_phase_axes.btm_outer.user_setpoint).assert_called_once_with( + get_mock_put(mock_phase_axes.btm_outer.user_setpoint_str).assert_called_once_with( str(set_value.btm_outer) ) @@ -315,7 +317,7 @@ async def test_given_gate_never_closes_then_setting_jaw_phases_times_out( mock_jaw_phase.jaw_phase.user_setpoint, lambda *_, **__: set_mock_value(mock_jaw_phase.gate, UndulatorGateStatus.OPEN), ) - mock_jaw_phase.get_timeout = AsyncMock(return_value=0.002) + mock_jaw_phase.get_timeout_for_apple2 = AsyncMock(return_value=0.002) with pytest.raises(TimeoutError): await mock_jaw_phase.set(2) @@ -346,7 +348,7 @@ async def test_jaw_phase_cal_timout( set_mock_value(mock_jaw_phase.jaw_phase.user_readback, readback) set_mock_value(mock_jaw_phase.jaw_phase.user_setpoint_readback, target) - assert await mock_jaw_phase.get_timeout() == pytest.approx( + assert await mock_jaw_phase.get_timeout_for_apple2() == pytest.approx( expected_timeout, rel=0.01 ) diff --git a/tests/devices/insertion_device/test_energy.py b/tests/devices/insertion_device/test_energy.py index 1afd2421eef..fa9ef2bfc39 100644 --- a/tests/devices/insertion_device/test_energy.py +++ b/tests/devices/insertion_device/test_energy.py @@ -102,7 +102,7 @@ async def test_insertion_device_energy_prepare_success( ramp_up_start = start_gap - acceleration_time * velocity / 2.0 mock_id_energy.set.assert_awaited_once_with(energy=750) get_mock_put( - mock_id_controller.apple2().gap().user_setpoint + mock_id_controller.apple2().gap().user_setpoint_str ).assert_awaited_once_with(str(ramp_up_start)) assert await mock_id_controller.apple2().gap().velocity.get_value() == abs(velocity)