Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions openhtf/core/phase_descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@ class PhaseOptions(object):
timeout will still apply when under the debugger.
phase_name_case: Case formatting options for phase name.
stop_on_measurement_fail: Whether to stop the test if any measurements fail.
prerequisites: List of phases that must be completed before this phase can
be run.
Example Usages: @PhaseOptions(timeout_s=1)
def PhaseFunc(test): pass @PhaseOptions(name='Phase({port})')
def PhaseFunc(test, port, other_info): pass
Expand All @@ -140,6 +142,7 @@ def PhaseFunc(test, port, other_info): pass
run_under_pdb = attr.ib(type=bool, default=False)
phase_name_case = attr.ib(type=PhaseNameCase, default=PhaseNameCase.KEEP)
stop_on_measurement_fail = attr.ib(type=bool, default=False)
prerequisites = attr.ib(type=Optional[List[Any]], default=None)

def format_strings(self, **kwargs: Any) -> 'PhaseOptions':
"""String substitution of name."""
Expand Down Expand Up @@ -173,6 +176,8 @@ def __call__(self, phase_func: PhaseT) -> 'PhaseDescriptor':
phase.options.stop_on_measurement_fail = self.stop_on_measurement_fail
if self.phase_name_case:
phase.options.phase_name_case = self.phase_name_case
if self.prerequisites is not None:
phase.options.prerequisites = self.prerequisites
return phase


Expand Down
49 changes: 32 additions & 17 deletions openhtf/core/phase_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
import time
import traceback
import types
from typing import Any, Dict, Optional, Text, Tuple, Type, TYPE_CHECKING, Union
from typing import Any, Dict, Optional, Set, TYPE_CHECKING, Text, Tuple, Type, Union

import attr
from openhtf import util
Expand Down Expand Up @@ -170,10 +170,12 @@ def __init__(self, phase_desc: phase_descriptor.PhaseDescriptor,
self._phase_desc = phase_desc
self._test_state = test_state
self._subtest_rec = subtest_rec
self._phase_state = test_state.running_phase_state
self._phase_execution_outcome = None # type: Optional[PhaseExecutionOutcome]

def _thread_proc(self) -> None:
"""Execute the encompassed phase and save the result."""
self._test_state.running_phase_state = self._phase_state
# Call the phase, save the return value, or default it to CONTINUE.
phase_return = self._phase_desc(self._test_state)
if phase_return is None:
Expand Down Expand Up @@ -239,6 +241,7 @@ def __init__(self, test_state: 'htf_test_state.TestState'):
# _execute_phase_once is setting up the next phase thread.
self._current_phase_thread_lock = threading.Lock()
self._current_phase_thread = None # type: Optional[PhaseExecutorThread]
self._active_phase_threads = set() # type: Set[PhaseExecutorThread]
self._stopping = threading.Event()

def _should_repeat(self, phase: phase_descriptor.PhaseDescriptor,
Expand Down Expand Up @@ -320,9 +323,12 @@ def _execute_phase_once(
phase_desc.name)
return PhaseExecutionOutcome(phase_descriptor.PhaseResult.SKIP), None


override_result = None
with self.test_state.running_phase_context(phase_desc) as phase_state:
if id(phase_desc) in getattr(self.test_state, '_concurrent_nodes', set()):
ctx_mgr = self.test_state.concurrent_running_phase_context
else:
ctx_mgr = self.test_state.running_phase_context
with ctx_mgr(phase_desc) as phase_state:
if subtest_rec:
self.logger.debug('Executing phase %s under subtest %s (from %s)',
phase_desc.name, phase_desc.func_location,
Expand All @@ -347,14 +353,18 @@ def _execute_phase_once(
run_with_profiling, subtest_rec)
phase_thread.start()
self._current_phase_thread = phase_thread
self._active_phase_threads.add(phase_thread)

phase_state.result = phase_thread.join_or_die()
if phase_state.result.is_repeat and is_last_repeat:
self.logger.error('Phase returned REPEAT, exceeding repeat_limit.')
phase_state.hit_repeat_limit = True
override_result = PhaseExecutionOutcome(
phase_descriptor.PhaseResult.STOP)
self._current_phase_thread = None
with self._current_phase_thread_lock:
self._active_phase_threads.discard(phase_thread)
if self._current_phase_thread == phase_thread:
self._current_phase_thread = None

# Refresh the result in case a validation for a partially set measurement
# or phase diagnoser raised an exception.
Expand Down Expand Up @@ -433,18 +443,23 @@ def stop(
"""
self._stopping.set()
with self._current_phase_thread_lock:
phase_thread = self._current_phase_thread
if not phase_thread:
return

if phase_thread.is_alive():
phase_thread.kill()

self.logger.debug('Waiting for cancelled phase to exit: %s', phase_thread)
timeout = timeouts.PolledTimeout.from_seconds(timeout_s)
while phase_thread.is_alive() and not timeout.has_expired():
time.sleep(0.1)
self.logger.debug('Cancelled phase %s exit',
"didn't" if phase_thread.is_alive() else 'did')
threads_to_kill = list(self._active_phase_threads)

timeout = timeouts.PolledTimeout.from_seconds(timeout_s)
for phase_thread in threads_to_kill:
if phase_thread.is_alive():
phase_thread.kill()

for phase_thread in threads_to_kill:
if phase_thread.is_alive():
self.logger.debug(
'Waiting for cancelled phase to exit: %s', phase_thread
)
while phase_thread.is_alive() and not timeout.has_expired():
time.sleep(0.1)
self.logger.debug(
'Cancelled phase %s exit',
"didn't" if phase_thread.is_alive() else 'did',
)
# Clear the currently running phase, whether it finished or timed out.
self.test_state.stop_running_phase()
229 changes: 229 additions & 0 deletions openhtf/core/phase_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
# Copyright 2026 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Phase Graph support for OpenHTF.

PhaseGraph is a PhaseCollectionNode that manages its contained phases via
a topological sort based on their explicit prerequisites.
"""

from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Text, Tuple, Type

import attr
from openhtf import util
from openhtf.core import base_plugs
from openhtf.core import phase_collections
from openhtf.core import phase_descriptor


class CyclicDependencyError(Exception):
"""PhaseGraph phases have cyclic dependencies."""


class PhaseUnreachableError(Exception):
"""A prerequisite is not defined in the graph."""


class DuplicatePhaseNameError(Exception):
"""PhaseGraph phases have duplicate names."""


@attr.s(slots=True, frozen=True)
class PhaseEdge:
"""A dependent phase and its prerequisite phases.

Attributes:
dependent: The phase that depends on the prerequisites.
prerequisites: Phases that must run before the dependent phase.
"""

dependent = attr.ib(type=phase_descriptor.PhaseCallableOrNodeT)
prerequisites = attr.ib(type=Sequence[phase_descriptor.PhaseCallableOrNodeT])


@attr.s(slots=True, frozen=True, init=False)
class PhaseGraph(phase_collections.PhaseCollectionNode):
"""A phase collection whose execution order is defined by a DAG.

For each phase, the name must be unique within the PhaseGraph. The execution
order is determined by a topological sort of the phases based on their
prerequisites.

Attributes:
nodes: A tuple of PhaseDescriptor instances in topologically sorted order.
name: An optional name for this PhaseGraph.
"""

@classmethod
def from_edges(
cls,
edges: Sequence['PhaseEdge'],
name: Optional[Text] = None,
) -> 'PhaseGraph':
"""Constructs a PhaseGraph from explicit PhaseEdge objects."""
flat_unique_nodes = []
seen_ids = set()

def _add(node):
node_id = id(node)
if node_id not in seen_ids:
seen_ids.add(node_id)
flat_unique_nodes.append(node)

for edge in edges:
_add(edge.dependent)
for prereq in edge.prerequisites:
_add(prereq)

wrapped_nodes = []
wrapped_by_orig = {}
for n in flat_unique_nodes:
wrapped = phase_descriptor.PhaseDescriptor.wrap_or_copy(n)
wrapped_nodes.append(wrapped)
wrapped_by_orig[id(n)] = wrapped
wrapped_by_orig[wrapped.name] = wrapped

for edge in edges:
wrapped_dep = wrapped_by_orig[id(edge.dependent)]
prereq_names = []
for prereq in edge.prerequisites:
wrapped_prereq = wrapped_by_orig.get(id(prereq))
if wrapped_prereq:
prereq_names.append(wrapped_prereq.name)
wrapped_dep.options.prerequisites = prereq_names

return cls(*wrapped_nodes, name=name)

nodes = attr.ib(type=Tuple[phase_descriptor.PhaseDescriptor, ...])
name = attr.ib(type=Optional[Text], default=None)

def __init__(
self,
*args: phase_descriptor.PhaseCallableOrNodeT,
name: Optional[Text] = None,
nodes: Optional[Tuple[phase_descriptor.PhaseDescriptor, ...]] = None,
):
super(PhaseGraph, self).__init__()
object.__setattr__(self, 'name', name)

if nodes is not None:
args = args + tuple(nodes)

flattened = list(phase_collections._recursive_flatten(args))
# Verify elements are PhaseDescriptor instances for prerequisite matching
ph_desc_list = []
for n in flattened:
if isinstance(n, phase_descriptor.PhaseDescriptor):
ph_desc_list.append(n)
else:
# Wrap or copy standard callables / nodes
ph_desc_list.append(phase_descriptor.PhaseDescriptor.wrap_or_copy(n))

topologically_sorted = self._validate_and_toposort(ph_desc_list)
object.__setattr__(self, 'nodes', tuple(topologically_sorted))

def _validate_and_toposort(
self, nodes: List[phase_descriptor.PhaseDescriptor]
) -> List[phase_descriptor.PhaseDescriptor]:
"""Validates the DAG structure and returns topologically sorted nodes."""
# Verify unique phase names
node_by_name: Dict[str, phase_descriptor.PhaseDescriptor] = {}
for n in nodes:
if n.name in node_by_name:
raise DuplicatePhaseNameError(
f"Duplicate phase name '{n.name}' detected in PhaseGraph."
)
node_by_name[n.name] = n

# Match prerequisites to actual nodes
adjacency = {n.name: set() for n in nodes}
for n in nodes:
if n.options.prerequisites is not None:
for pr in n.options.prerequisites:
pr_name = pr if isinstance(pr, str) else getattr(pr, 'name', None)
if not pr_name or pr_name not in node_by_name:
raise PhaseUnreachableError(
f"Prerequisite '{pr_name}' for phase '{n.name}' not found in"
' PhaseGraph.'
)
adjacency[n.name].add(pr_name)

# Perform DFS-based topological sort with cycle detection.
visited = set()
temp_marked = set()
sorted_names = []

def _visit(node_name: str):
if node_name in temp_marked:
raise CyclicDependencyError(f"Cycle detected involving '{node_name}'")
if node_name in visited:
return
temp_marked.add(node_name)
for prereq_name in adjacency[node_name]:
_visit(prereq_name)
temp_marked.remove(node_name)
visited.add(node_name)
sorted_names.append(node_name)

for n in nodes:
if n.name not in visited:
_visit(n.name)

return [node_by_name[name] for name in sorted_names]

def _asdict(self) -> Dict[Text, Any]:
return {
'name': self.name,
'nodes': [n._asdict() for n in self.nodes],
}

def with_args(self, **kwargs: Any) -> 'PhaseGraph':
return attr.evolve(
self,
nodes=tuple(n.with_args(**kwargs) for n in self.nodes),
name=util.format_string(self.name, kwargs),
)

def with_plugs(self, **subplugs: Type[base_plugs.BasePlug]) -> 'PhaseGraph':
return attr.evolve(
self,
nodes=tuple(n.with_plugs(**subplugs) for n in self.nodes),
name=util.format_string(self.name, subplugs),
)

def load_code_info(self) -> 'PhaseGraph':
return attr.evolve(
self,
nodes=tuple(n.load_code_info() for n in self.nodes),
name=self.name,
)

def apply_to_all_phases(
self,
func: Callable[
[phase_descriptor.PhaseDescriptor], phase_descriptor.PhaseDescriptor
],
) -> 'PhaseGraph':
return attr.evolve(
self,
nodes=tuple(n.apply_to_all_phases(func) for n in self.nodes),
name=self.name,
)

def filter_by_type(self, node_cls: Type[Any]) -> Iterator[Any]:
for node in self.nodes:
if isinstance(node, node_cls):
yield node
if isinstance(node, phase_collections.PhaseCollectionNode):
for sub_n in node.filter_by_type(node_cls):
yield sub_n
Loading
Loading