From 7f7a657f6ce4580d3678606e962e0b188712465b Mon Sep 17 00:00:00 2001 From: agsaru Date: Sun, 7 Jun 2026 17:38:55 +0000 Subject: [PATCH 1/4] used raw string --- metaflow/plugins/pypi/pip.py | 2 +- test/plugins/pip/test_pip_indices.py | 10 ++-------- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/metaflow/plugins/pypi/pip.py b/metaflow/plugins/pypi/pip.py index 31b215bba00..67fcf206179 100644 --- a/metaflow/plugins/pypi/pip.py +++ b/metaflow/plugins/pypi/pip.py @@ -37,7 +37,7 @@ def __init__(self, error): self.package_spec = re.search( "ERROR: No matching distribution found for (.*)", self.error )[1] - self.package_name = re.match("\w*", self.package_spec)[0] + self.package_name = re.match(r"\w*", self.package_spec)[0] except Exception: pass diff --git a/test/plugins/pip/test_pip_indices.py b/test/plugins/pip/test_pip_indices.py index e5e8d09b6e7..5713a8153ca 100644 --- a/test/plugins/pip/test_pip_indices.py +++ b/test/plugins/pip/test_pip_indices.py @@ -1,5 +1,3 @@ -from unittest.mock import MagicMock - from metaflow.plugins.pypi.pip import Pip @@ -9,7 +7,7 @@ def _make_pip(): return pip -def test_multiple_extra_index_urls_literal_newline(monkeypatch): +def test_multiple_extra_index_urls_literal_newline(mocker): """Regression test: pip config list separates multiple URLs with literal \\n.""" pip = _make_pip() config_output = ( @@ -17,14 +15,10 @@ def test_multiple_extra_index_urls_literal_newline(monkeypatch): r"global.extra-index-url='https://extra1.example.com/simple'\n'https://extra2.example.com/simple'" ) - # Use monkeypatch instead of unittest.mock.patch - mock_call = MagicMock(return_value=config_output) - monkeypatch.setattr(pip, "_call", mock_call) + mocker.patch.object(pip, "_call", return_value=config_output) - # Execute index, extras = pip.indices("dummy") - # Assert assert index == "https://pypi.org/simple" assert extras == [ "https://extra1.example.com/simple", From b90b17c733ea5a36274b6a6622e9beedefc1e46a Mon Sep 17 00:00:00 2001 From: agsaru Date: Mon, 8 Jun 2026 07:12:24 +0000 Subject: [PATCH 2/4] refactor(tests): standardize tests to follow pytest conventions --- test/cmd/develop/test_stub_generator.py | 827 ++++++++-------- test/cmd/diff/test_metaflow_diff.py | 276 +++--- test/data/s3/conftest.py | 39 +- test/data/s3/test_s3.py | 120 ++- test/data/s3/test_s3op.py | 62 +- .../conda/test_conda_environment_unit.py | 161 ++- test/plugins/conda/test_parsers.py | 231 ++--- test/unit/configs/test_config_naming.py | 58 +- test/unit/configs/test_config_plain.py | 87 +- test/unit/graph_inference/test_card_dag.py | 37 +- .../graph_inference/test_graph_inference.py | 262 +++-- test/unit/inheritance/test_inheritance.py | 579 +++++------ test/unit/localbatch/test_localbatch.py | 561 +++++------ .../mutators/test_add_decorator_returns.py | 50 +- test/unit/mutators/test_dual_inheritance.py | 64 +- .../mutators/test_flow_mutator_addition.py | 53 +- .../mutators/test_post_step_none_false.py | 28 +- .../mutators/test_remove_decorator_guard.py | 71 +- .../unit/mutators/test_string_step_mutator.py | 19 +- test/unit/spin/test_spin.py | 146 ++- test/unit/test_add_to_package.py | 789 ++++++++------- test/unit/test_argo_workflows_cli.py | 194 ++-- test/unit/test_artifact_serializer.py | 459 ++++----- test/unit/test_aws_util.py | 46 +- test/unit/test_card_creator.py | 81 +- test/unit/test_compute_resource_attributes.py | 160 +-- test/unit/test_config_value.py | 143 +-- test/unit/test_content_addressed_store.py | 55 +- test/unit/test_graph_endpoints_fallback.py | 49 +- test/unit/test_graph_structure.py | 120 +-- test/unit/test_kubernetes.py | 50 +- test/unit/test_local_metadata_provider.py | 61 +- test/unit/test_metaflow_version.py | 28 + test/unit/test_package_suffixes_mutator.py | 163 +-- test/unit/test_packaging_utils.py | 81 +- test/unit/test_pickle_serializer.py | 30 +- test/unit/test_pypi_parsers.py | 40 +- test/unit/test_remove_decorator.py | 31 +- test/unit/test_s3_empty_input.py | 244 ++--- test/unit/test_s3_storage.py | 37 +- test/unit/test_secrets_decorator.py | 84 +- test/unit/test_serializer_integration.py | 382 +++---- test/unit/test_serializer_lifecycle.py | 928 +++++++----------- test/unit/test_serializer_public_api.py | 92 +- test/unit/test_sourceless_dag_node.py | 25 +- test/unit/test_system_context.py | 302 +++--- test/unit/test_task_log_metadata_fetch.py | 26 +- test/unit/test_to_pod.py | 75 +- test/unit/test_tutorial_01_02_csv_parsing.py | 80 +- test/ux/conftest.py | 9 +- test/ux/core/conftest.py | 1 + test/ux/core/test_airflow_compilation.py | 163 ++- test/ux/core/test_argo_compilation.py | 82 +- test/ux/core/test_basic.py | 360 +++---- test/ux/core/test_compliance.py | 383 +++----- test/ux/core/test_config.py | 298 +++--- test/ux/core/test_dag.py | 303 +++--- test/ux/core/test_decorators.py | 110 ++- test/ux/core/test_lifecycle.py | 42 +- test/ux/core/test_resume.py | 194 ++-- test/ux/core/test_sfn_compilation.py | 60 +- test/ux/core/test_utils.py | 9 +- 62 files changed, 5382 insertions(+), 5218 deletions(-) diff --git a/test/cmd/develop/test_stub_generator.py b/test/cmd/develop/test_stub_generator.py index 561163e5152..f94ca2b8201 100644 --- a/test/cmd/develop/test_stub_generator.py +++ b/test/cmd/develop/test_stub_generator.py @@ -1,9 +1,7 @@ +import inspect import sys -import tempfile import typing -import inspect from typing import TypeVar, Optional -from unittest.mock import Mock, patch import pytest @@ -31,464 +29,429 @@ class ComplexGenericClass(typing.Generic[T, U]): pass -class TestStubGenerator: - """Test suite for StubGenerator functionality""" - - def setup_method(self): - """Set up test environment""" - self.temp_dir = tempfile.mkdtemp() - self.generator = StubGenerator(self.temp_dir, include_generated_for=False) - # Reset internal state - self.generator._reset() - self.generator._current_module_name = "test_module" - self.generator._current_name = None # Initialize to avoid AttributeError - - def test_get_element_name_basic_types(self): - """Test basic type handling""" - # Test builtin types - assert self.generator._get_element_name_with_module(int) == "int" - assert self.generator._get_element_name_with_module(str) == "str" - assert self.generator._get_element_name_with_module(type(None)) == "None" - - # Test TypeVar - type_var = TypeVar("TestTypeVar") - result = self.generator._get_element_name_with_module(type_var) - assert result == "TestTypeVar" - assert "TestTypeVar" in self.generator._typevars - - def test_get_element_name_class_objects(self): - """Test handling of class objects - the main issue we're fixing""" - # Mock the module to avoid import issues - mock_module = Mock() - mock_module.__name__ = "test.module" - - with patch("inspect.getmodule", return_value=mock_module): - result = self.generator._get_element_name_with_module(TestClass) - assert result == "test.module.TestClass" - assert "test.module" in self.generator._typing_imports - - def test_get_element_name_generic_alias_with_class_objects(self): - """Test the specific case that was failing - class objects in generic type arguments""" - # Create a generic alias with class objects as arguments - callable_type = typing.Callable[[TestClass, Optional[int]], TestClass] - - # Mock the module for TestClass - mock_module = Mock() - mock_module.__name__ = "test.module" - - with patch("inspect.getmodule", return_value=mock_module): - result = self.generator._get_element_name_with_module(callable_type) - - # Should not contain - assert " TestClass: - pass +@pytest.fixture +def stub_generator(tmp_path): + """Provides a fresh StubGenerator instance for each test.""" + generator = StubGenerator(str(tmp_path), include_generated_for=False) + generator._reset() + generator._current_module_name = "test_module" + generator._current_name = None # Initialize to avoid AttributeError + return generator + - mock_module = Mock() - mock_module.__name__ = "test.module" +@pytest.fixture +def mock_getmodule(mocker): + """Helper fixture to easily mock inspect.getmodule to return a specific module name.""" - with patch("inspect.getmodule", return_value=mock_module): - stub = self.generator._generate_function_stub("test_func", test_func) + def _mocker(module_name="test.module"): + mock_module = mocker.Mock() + mock_module.__name__ = module_name + return mocker.patch("inspect.getmodule", return_value=mock_module) - # Should not contain class objects - assert " Optional[TestClass]: - pass +@pytest.mark.parametrize( + "type_obj, expected", + [ + (int, "int"), + (str, "str"), + (type(None), "None"), + ], + ids=["int", "str", "NoneType"], +) +def test_get_element_name_resolves_basic_builtin_types( + stub_generator, type_obj, expected +): + """Test basic builtin type handling.""" + assert stub_generator._get_element_name_with_module(type_obj) == expected - mock_module = Mock() - mock_module.__name__ = "test.module" - with patch("inspect.getmodule", return_value=mock_module): - stub = self.generator._generate_class_stub( - "TestClassWithMethods", TestClassWithMethods - ) +def test_get_element_name_registers_and_resolves_typevars(stub_generator): + """Test TypeVar parsing and registration.""" + type_var = TypeVar("TestTypeVar") + result = stub_generator._get_element_name_with_module(type_var) - assert " TestClass: + pass - assert " Optional[TestClass]: + pass + + mock_getmodule("test.module") + stub = stub_generator._generate_class_stub( + "TestClassWithMethods", TestClassWithMethods + ) - def test_get_element_name_newtype(self): - """Test NewType handling""" - from typing import NewType + assert " initial_typing_imports - assert "test.isolated.module" in self.generator._typing_imports - - def test_get_element_name_nonetype_handling(self): - """Test that NoneType is properly converted to None in type annotations""" - # Test direct NoneType - result = self.generator._get_element_name_with_module(type(None)) - assert result == "None" - - # Test NoneType in generic type (like Callable[..., None]) - callable_type = typing.Callable[[TestClass], type(None)] - - mock_module = Mock() - mock_module.__name__ = "test.module" - - with patch("inspect.getmodule", return_value=mock_module): - result = self.generator._get_element_name_with_module(callable_type) - - # Should not contain NoneType, should contain None - assert "NoneType" not in result - assert "None" in result - assert "typing.Callable" in result - assert "test.module.TestClass" in result - - -# Integration test to verify class objects in generic types are properly handled -def test_class_objects_in_generic_types_no_leakage(): - """Regression test ensuring class objects don't leak as '' in type annotations""" - generator = StubGenerator("/tmp/test_stubs", include_generated_for=False) + ] + ], + ] + mock_getmodule("test.module") + + result = stub_generator._get_element_name_with_module(nested_type) + + assert " initial_typing_imports + assert "test.isolated.module" in stub_generator._typing_imports + + +def test_get_element_name_converts_nonetype(stub_generator, mock_getmodule): + """Test that NoneType is properly converted to None in generic type annotations.""" + callable_type = typing.Callable[[TestClass], type(None)] + mock_getmodule("test.module") + + result = stub_generator._get_element_name_with_module(callable_type) + + assert "NoneType" not in result + assert "None" in result + assert "typing.Callable" in result + assert "test.module.TestClass" in result + + +def test_class_objects_in_generic_types_do_not_leak_class_repr(tmp_path, mocker): + """Regression test ensuring class objects don't leak as '' in type annotations.""" + generator = StubGenerator(str(tmp_path), include_generated_for=False) generator._reset() generator._current_module_name = "test_module" @@ -497,36 +460,36 @@ def test_class_objects_in_generic_types_no_leakage(): FunctionParameters = type("FunctionParameters", (), {}) # Mock modules - mock_df_module = Mock() + mock_df_module = mocker.Mock() mock_df_module.__name__ = "metaflow_extensions.nflx.plugins.datatools.dataframe" - mock_fp_module = Mock() + mock_fp_module = mocker.Mock() mock_fp_module.__name__ = ( "metaflow_extensions.nflx.plugins.functions.core.function_parameters" ) - def mock_getmodule(obj): + def custom_getmodule(obj): if obj == MetaflowDataFrame: return mock_df_module elif obj == FunctionParameters: return mock_fp_module return None + mocker.patch("inspect.getmodule", side_effect=custom_getmodule) + # The problematic type annotation problematic_type = typing.Callable[ [MetaflowDataFrame, typing.Optional[FunctionParameters]], MetaflowDataFrame ] - with patch("inspect.getmodule", side_effect=mock_getmodule): - result = generator._get_element_name_with_module(problematic_type) - - # The key assertion - no class objects should appear - assert "=2.0" - assert result["packages"]["numpy"] == "1.21.0" - assert result["packages"]["pandas"] == "" - - def test_python_version(self): - content = "python==3.9\nrequests\n" - result = requirements_txt_parser(content) - assert result["python"] == "3.9" - assert "python" not in result["packages"] - assert result["packages"]["requests"] == "" - - def test_comments_and_blank_lines(self): - content = "# this is a comment\n\nrequests==2.0\n # another comment\nnumpy\n" - result = requirements_txt_parser(content) - assert len(result["packages"]) == 2 - - def test_inline_comments(self): - content = "requests==2.0 # HTTP library\n" - result = requirements_txt_parser(content) - assert result["packages"]["requests"] == "2.0" - - def test_extras(self): - content = "requests[security]==2.28.0\n" - result = requirements_txt_parser(content) - assert "requests[security]" in result["packages"] - - def test_direct_reference(self): - content = "mylib @ git+https://github.com/user/repo.git\n" - result = requirements_txt_parser(content) - assert any("mylib" in k for k in result["packages"]) - - def test_environment_markers_rejected(self): - content = 'requests==2.0; python_version>="3.6"\n' - with pytest.raises(ParserValueError, match="Environment markers"): - requirements_txt_parser(content) - - def test_invalid_requirement(self): - content = "not a valid requirement!!!\n" - with pytest.raises(ParserValueError, match="Not a valid PEP 508"): - requirements_txt_parser(content) - - def test_multiple_python_specs_rejected(self): - content = "python==3.9\npython==3.10\n" - with pytest.raises(ParserValueError, match="Multiple Python version"): - requirements_txt_parser(content) - - def test_empty_content(self): - result = requirements_txt_parser("") - assert result["packages"] == {} - assert result["python"] is None - - def test_rye_lockfile_skip(self): - """Rye lockfiles contain '-e file:.' which should be silently skipped.""" - content = "-e file:.\nrequests==2.0\n" - result = requirements_txt_parser(content) - assert result["packages"]["requests"] == "2.0" +def test_requirements_parser_simple_package(): + result = requirements_txt_parser("requests==2.28.0\n") + assert result["packages"] == {"requests": "2.28.0"} + assert result["python"] is None + + +def test_requirements_parser_multiple_packages(): + content = "requests>=2.0\nnumpy==1.21.0\npandas\n" + result = requirements_txt_parser(content) + assert result["packages"]["requests"] == ">=2.0" + assert result["packages"]["numpy"] == "1.21.0" + assert result["packages"]["pandas"] == "" + + +def test_requirements_parser_python_version(): + content = "python==3.9\nrequests\n" + result = requirements_txt_parser(content) + assert result["python"] == "3.9" + assert "python" not in result["packages"] + assert result["packages"]["requests"] == "" + + +def test_requirements_parser_comments_and_blank_lines(): + content = "# this is a comment\n\nrequests==2.0\n # another comment\nnumpy\n" + result = requirements_txt_parser(content) + assert len(result["packages"]) == 2 + + +def test_requirements_parser_inline_comments(): + content = "requests==2.0 # HTTP library\n" + result = requirements_txt_parser(content) + assert result["packages"]["requests"] == "2.0" + + +def test_requirements_parser_extras(): + content = "requests[security]==2.28.0\n" + result = requirements_txt_parser(content) + assert "requests[security]" in result["packages"] + + +def test_requirements_parser_direct_reference(): + content = "mylib @ git+https://github.com/user/repo.git\n" + result = requirements_txt_parser(content) + assert any("mylib" in k for k in result["packages"]) + + +@pytest.mark.parametrize( + "content, match", + [ + ('requests==2.0; python_version>="3.6"\n', "Environment markers"), + ("not a valid requirement!!!\n", "Not a valid PEP 508"), + ("python==3.9\npython==3.10\n", "Multiple Python version"), + ], + ids=["env_markers", "invalid_pep508", "multi_python"], +) +def test_requirements_parser_invalid_inputs(content, match): + with pytest.raises(ParserValueError, match=match): + requirements_txt_parser(content) + + +def test_requirements_parser_empty_content(): + result = requirements_txt_parser("") + assert result["packages"] == {} + assert result["python"] is None + + +def test_requirements_parser_rye_lockfile_skip(): + """Rye lockfiles contain '-e file:.' which should be silently skipped.""" + content = "-e file:.\nrequests==2.0\n" + result = requirements_txt_parser(content) + assert result["packages"]["requests"] == "2.0" # --------------------------------------------------------------------------- @@ -90,47 +95,53 @@ def test_rye_lockfile_skip(self): # --------------------------------------------------------------------------- -class TestCondaEnvironmentYmlParser: - def test_simple_deps(self): - content = "dependencies:\n - numpy=1.21.2\n - pandas=1.3.0\n" - result = conda_environment_yml_parser(content) - assert result["packages"]["numpy"] == "1.21.2" - assert result["packages"]["pandas"] == "1.3.0" - - def test_python_version(self): - content = "dependencies:\n - python=3.9\n - numpy\n" - result = conda_environment_yml_parser(content) - assert result["python"] == "3.9" - assert "python" not in result["packages"] - assert result["packages"]["numpy"] == "" - - def test_no_version(self): - content = "dependencies:\n - numpy\n" - result = conda_environment_yml_parser(content) - assert result["packages"]["numpy"] == "" - - def test_comments_skipped(self): - content = "# env file\ndependencies:\n # a comment\n - numpy=1.0\n" - result = conda_environment_yml_parser(content) - assert result["packages"]["numpy"] == "1.0" - - def test_subsection_rejected(self): - content = "dependencies:\n - pip:\n - requests\n" - with pytest.raises(ParserValueError, match="Unsupported subsection"): - conda_environment_yml_parser(content) - - def test_inline_comments(self): - content = "dependencies:\n - numpy=1.0 # math lib\n" - result = conda_environment_yml_parser(content) - assert result["packages"]["numpy"] == "1.0" - - def test_empty_deps(self): - content = "name: test\n" - result = conda_environment_yml_parser(content) - assert result["packages"] == {} - assert result["python"] is None - - def test_double_equals(self): - content = "dependencies:\n - numpy==1.21.2\n" - result = conda_environment_yml_parser(content) - assert result["packages"]["numpy"] == "1.21.2" +def test_conda_parser_simple_deps(): + content = "dependencies:\n - numpy=1.21.2\n - pandas=1.3.0\n" + result = conda_environment_yml_parser(content) + assert result["packages"]["numpy"] == "1.21.2" + assert result["packages"]["pandas"] == "1.3.0" + + +def test_conda_parser_python_version(): + content = "dependencies:\n - python=3.9\n - numpy\n" + result = conda_environment_yml_parser(content) + assert result["python"] == "3.9" + assert "python" not in result["packages"] + assert result["packages"]["numpy"] == "" + + +def test_conda_parser_no_version(): + content = "dependencies:\n - numpy\n" + result = conda_environment_yml_parser(content) + assert result["packages"]["numpy"] == "" + + +def test_conda_parser_comments_skipped(): + content = "# env file\ndependencies:\n # a comment\n - numpy=1.0\n" + result = conda_environment_yml_parser(content) + assert result["packages"]["numpy"] == "1.0" + + +def test_conda_parser_subsection_rejected(): + content = "dependencies:\n - pip:\n - requests\n" + with pytest.raises(ParserValueError, match="Unsupported subsection"): + conda_environment_yml_parser(content) + + +def test_conda_parser_inline_comments(): + content = "dependencies:\n - numpy=1.0 # math lib\n" + result = conda_environment_yml_parser(content) + assert result["packages"]["numpy"] == "1.0" + + +def test_conda_parser_empty_deps(): + content = "name: test\n" + result = conda_environment_yml_parser(content) + assert result["packages"] == {} + assert result["python"] is None + + +def test_conda_parser_double_equals(): + content = "dependencies:\n - numpy==1.21.2\n" + result = conda_environment_yml_parser(content) + assert result["packages"]["numpy"] == "1.21.2" diff --git a/test/unit/configs/test_config_naming.py b/test/unit/configs/test_config_naming.py index 0e9ec58fdc1..f82dc1dfc23 100644 --- a/test/unit/configs/test_config_naming.py +++ b/test/unit/configs/test_config_naming.py @@ -11,34 +11,30 @@ import pytest -class TestConfigNaming: - """Test Config parameter names with underscores and dashes.""" - - def test_flow_completes(self, config_naming_run): - """Test that the flow completes successfully.""" - assert config_naming_run.successful - assert config_naming_run.finished - - def test_config_with_underscore(self, config_naming_run): - """Test Config with underscore in name.""" - end_task = config_naming_run["end"].task - - assert end_task["underscore_test"].data == "underscore" - assert end_task["underscore_value"].data == 42 - assert end_task["underscore_dict"].data == {"test": "underscore", "value": 42} - - def test_config_with_dash(self, config_naming_run): - """Test Config with dash in name.""" - end_task = config_naming_run["end"].task - - assert end_task["dash_test"].data == "dash" - assert end_task["dash_value"].data == 99 - assert end_task["dash_dict"].data == {"test": "dash", "value": 99} - - def test_config_with_mixed_naming(self, config_naming_run): - """Test Config with both underscores and dashes in name.""" - end_task = config_naming_run["end"].task - - assert end_task["mixed_test"].data == "mixed" - assert end_task["mixed_value"].data == 123 - assert end_task["mixed_dict"].data == {"test": "mixed", "value": 123} +def test_flow_completes(config_naming_run): + """Test that the configuration parsing flow completes successfully.""" + assert config_naming_run.successful + assert config_naming_run.finished + + +@pytest.mark.parametrize( + "prefix, expected_text, expected_value", + [ + ("underscore", "underscore", 42), + ("dash", "dash", 99), + ("mixed", "mixed", 123), + ], + ids=["underscore_naming", "dash_naming", "mixed_naming"], +) +def test_config_parameter_naming_formats( + config_naming_run, prefix, expected_text, expected_value +): + """Test that configuration parameters parse correctly across different naming conventions.""" + end_task = config_naming_run["end"].task + + assert end_task[f"{prefix}_test"].data == expected_text + assert end_task[f"{prefix}_value"].data == expected_value + assert end_task[f"{prefix}_dict"].data == { + "test": expected_text, + "value": expected_value, + } diff --git a/test/unit/configs/test_config_plain.py b/test/unit/configs/test_config_plain.py index f9b95ba71b2..57c5edb547f 100644 --- a/test/unit/configs/test_config_plain.py +++ b/test/unit/configs/test_config_plain.py @@ -10,50 +10,43 @@ import pytest -class TestConfigPlain: - """Test Config with plain=True option.""" - - def test_flow_completes(self, config_plain_run): - """Test that the flow completes successfully.""" - assert config_plain_run.successful - assert config_plain_run.finished - - def test_plain_string_without_parser(self, config_plain_run): - """Test plain Config without parser returns raw string.""" - end_task = config_plain_run["end"].task - - # Verify it's a string - assert end_task["plain_str_type"].data == "str" - - # Verify the value is the raw string (not parsed JSON) - assert end_task["plain_str_value"].data == '{"raw": "string", "number": 123}' - - def test_plain_list_with_parser(self, config_plain_run): - """Test plain Config with parser returning list (non-dict).""" - end_task = config_plain_run["end"].task - - # Verify it's a list - assert end_task["plain_list_type"].data == "list" - - # Verify the list contents - assert end_task["plain_list_value"].data == [ - "apple", - "banana", - "cherry", - "date", - ] - assert end_task["plain_list_length"].data == 4 - assert end_task["plain_list_first"].data == "apple" - - def test_plain_tuple_with_parser(self, config_plain_run): - """Test plain Config with parser returning tuple (non-dict).""" - end_task = config_plain_run["end"].task - - # Verify it's a tuple type - assert end_task["plain_tuple_type"].data == "tuple" - - # Verify tuple contents - assert end_task["plain_tuple_value"].data == ("test_tuple", 42, True) - assert end_task["tuple_name"].data == "test_tuple" - assert end_task["tuple_count"].data == 42 - assert end_task["tuple_enabled"].data == True +def test_flow_completes(config_plain_run): + """Test that the configuration plain parsing flow completes successfully.""" + assert config_plain_run.successful + assert config_plain_run.finished + + +@pytest.mark.parametrize( + "key_prefix, expected_type, expected_value", + [ + ("plain_str", "str", '{"raw": "string", "number": 123}'), + ("plain_list", "list", ["apple", "banana", "cherry", "date"]), + ("plain_tuple", "tuple", ("test_tuple", 42, True)), + ], + ids=["raw_string", "parsed_list", "parsed_tuple"], +) +def test_plain_config_types_and_values( + config_plain_run, key_prefix, expected_type, expected_value +): + """Test that plain Config fields yield the expected type and exact raw or parsed values.""" + end_task = config_plain_run["end"].task + + assert end_task[f"{key_prefix}_type"].data == expected_type + assert end_task[f"{key_prefix}_value"].data == expected_value + + +def test_plain_list_properties(config_plain_run): + """Test detailed extraction and length properties of a plain parsed list.""" + end_task = config_plain_run["end"].task + + assert end_task["plain_list_length"].data == 4 + assert end_task["plain_list_first"].data == "apple" + + +def test_plain_tuple_properties(config_plain_run): + """Test detailed extraction and inner structures of a plain parsed tuple.""" + end_task = config_plain_run["end"].task + + assert end_task["tuple_name"].data == "test_tuple" + assert end_task["tuple_count"].data == 42 + assert end_task["tuple_enabled"].data is True diff --git a/test/unit/graph_inference/test_card_dag.py b/test/unit/graph_inference/test_card_dag.py index bdce054c24c..7c98fd19398 100644 --- a/test/unit/graph_inference/test_card_dag.py +++ b/test/unit/graph_inference/test_card_dag.py @@ -12,6 +12,7 @@ """ import json +import pytest from metaflow.plugins.cards.card_modules.basic import ( DefaultCardJSON, @@ -20,6 +21,7 @@ def _find_components_by_type(node, component_type): + """Recursively search for components of a specific type in a card JSON structure.""" if isinstance(node, dict): if node.get("type") == component_type: yield node @@ -31,12 +33,14 @@ def _find_components_by_type(node, component_type): # --------------------------------------------------------------------------- -# transform_flow_graph: shape-detection unit tests +# Fixtures # --------------------------------------------------------------------------- -def test_transform_flow_graph_supports_explicit_endpoints(): - graph = { +@pytest.fixture +def explicit_endpoints_graph(): + """Provides a fresh graph definition with custom explicit start/end steps.""" + return { "start_step": "begin", "end_step": "finish", "steps": { @@ -46,7 +50,23 @@ def test_transform_flow_graph_supports_explicit_endpoints(): }, } - transformed = transform_flow_graph(graph) + +@pytest.fixture +def legacy_graph(): + """Provides a fresh legacy graph definition relying on hardcoded keys.""" + return { + "start": {"type": "start", "next": ["end"], "doc": ""}, + "end": {"type": "end", "next": [], "doc": ""}, + } + + +# --------------------------------------------------------------------------- +# transform_flow_graph: shape-detection unit tests +# --------------------------------------------------------------------------- + + +def test_transform_flow_graph_supports_explicit_endpoints(explicit_endpoints_graph): + transformed = transform_flow_graph(explicit_endpoints_graph) assert transformed["start_step"] == "begin" assert transformed["end_step"] == "finish" @@ -56,13 +76,8 @@ def test_transform_flow_graph_supports_explicit_endpoints(): assert transformed["steps"]["finish"]["type"] == "end" -def test_transform_flow_graph_keeps_legacy_start_end_detection(): - graph = { - "start": {"type": "start", "next": ["end"], "doc": ""}, - "end": {"type": "end", "next": [], "doc": ""}, - } - - transformed = transform_flow_graph(graph) +def test_transform_flow_graph_keeps_legacy_start_end_detection(legacy_graph): + transformed = transform_flow_graph(legacy_graph) assert transformed["start_step"] == "start" assert transformed["end_step"] == "end" diff --git a/test/unit/graph_inference/test_graph_inference.py b/test/unit/graph_inference/test_graph_inference.py index f68d6833bfb..46dbe1ec7c1 100644 --- a/test/unit/graph_inference/test_graph_inference.py +++ b/test/unit/graph_inference/test_graph_inference.py @@ -9,29 +9,127 @@ - Single-step flows execute end-to-end """ +import pytest from metaflow.events import Trigger # --------------------------------------------------------------------------- -# Custom named flow (begin/middle/finish) +# Shared Shape Tests (Parametrized) # --------------------------------------------------------------------------- -def test_custom_named_flow_completes(custom_named_run): - assert custom_named_run.successful - assert custom_named_run.finished +@pytest.mark.parametrize( + "run_fixture", + [ + "custom_named_run", + "single_step_run", + "single_step_bare_run", + "custom_branch_run", + "single_step_with_config_run", + "single_step_with_stacked_decos_run", + "single_step_with_flow_mutator_run", + ], + ids=[ + "custom_named", + "single_step", + "single_step_bare", + "branch", + "config", + "stacked_decos", + "flow_mutator", + ], +) +def test_flow_completes_successfully(run_fixture, request): + """Verify that various flow configurations execute end-to-end.""" + run = request.getfixturevalue(run_fixture) + assert run.successful + assert run.finished + + +@pytest.mark.parametrize( + "run_fixture, expected_start, expected_end", + [ + ("custom_named_run", "begin", "finish"), + ("single_step_run", "only", "only"), + ("single_step_bare_run", "only", "only"), + ("custom_branch_run", "entry", "done"), + ], + ids=["custom_named", "single_step", "single_step_bare", "branch"], +) +def test_graph_info_endpoints(run_fixture, expected_start, expected_end, request): + """Verify _graph_info captures the correct explicit start/end steps.""" + run = request.getfixturevalue(run_fixture) + graph_info = run["_parameters"].task["_graph_info"].data + assert graph_info["start_step"] == expected_start + assert graph_info["end_step"] == expected_end + + +@pytest.mark.parametrize( + "run_fixture, expected_start, expected_end", + [ + ("custom_named_run", "begin", "finish"), + ("single_step_run", "only", "only"), + ("single_step_bare_run", "only", "only"), + ], + ids=["custom_named", "single_step", "single_step_bare"], +) +def test_parameters_metadata_endpoints( + run_fixture, expected_start, expected_end, request +): + """Verify metadata reflects the start/end step parameters correctly.""" + run = request.getfixturevalue(run_fixture) + meta = run["_parameters"].task.metadata_dict + assert meta.get("start_step") == expected_start + assert meta.get("end_step") == expected_end + + +@pytest.mark.parametrize( + "run_fixture, expected_val", + [ + ("custom_named_run", 3), + ("single_step_run", 42), + ("single_step_bare_run", 42), + ], + ids=["custom_named", "single_step", "single_step_bare"], +) +def test_end_task_data_value(run_fixture, expected_val, request): + """Verify the primary test artifact is correctly set in the terminal task.""" + run = request.getfixturevalue(run_fixture) + assert run.end_task is not None + assert run.end_task["x"].data == expected_val + + +@pytest.mark.parametrize( + "run_fixture, expected_steps", + [ + ("custom_named_run", {"begin", "middle", "finish"}), + ("single_step_run", {"only"}), + ("single_step_bare_run", {"only"}), + ("custom_branch_run", {"entry", "a", "b", "merge", "done"}), + ], + ids=["custom_named", "single_step", "single_step_bare", "branch"], +) +def test_steps_present(run_fixture, expected_steps, request): + """Verify the client API surfaces all executed step IDs correctly.""" + run = request.getfixturevalue(run_fixture) + assert {step.id for step in run} == expected_steps + + +@pytest.mark.parametrize( + "run_fixture", + ["single_step_run", "single_step_bare_run"], + ids=["single_step", "single_step_bare"], +) +def test_single_step_parent_child_empty(run_fixture, request): + """Single-step flows should have no parent or child relationships.""" + run = request.getfixturevalue(run_fixture) + assert list(run["only"].parent_steps) == [] + assert list(run["only"].child_steps) == [] -def test_custom_named_graph_info_has_endpoints(custom_named_run): - graph_info = custom_named_run["_parameters"].task["_graph_info"].data - assert graph_info["start_step"] == "begin" - assert graph_info["end_step"] == "finish" - - -def test_custom_named_parameters_metadata_has_endpoints(custom_named_run): - meta = custom_named_run["_parameters"].task.metadata_dict - assert meta.get("start_step") == "begin" - assert meta.get("end_step") == "finish" +# --------------------------------------------------------------------------- +# Flow-Specific Edge Cases & Composition +# --------------------------------------------------------------------------- def test_custom_named_graph_endpoints_property(custom_named_run): @@ -40,17 +138,6 @@ def test_custom_named_graph_endpoints_property(custom_named_run): assert end == "finish" -def test_custom_named_end_task(custom_named_run): - end_task = custom_named_run.end_task - assert end_task is not None - assert end_task["x"].data == 3 - - -def test_custom_named_steps_present(custom_named_run): - step_names = {step.id for step in custom_named_run} - assert step_names == {"begin", "middle", "finish"} - - def test_custom_named_parent_steps(custom_named_run): assert list(custom_named_run["begin"].parent_steps) == [] assert [step.id for step in custom_named_run["middle"].parent_steps] == ["begin"] @@ -63,98 +150,7 @@ def test_custom_named_child_steps(custom_named_run): assert list(custom_named_run["finish"].child_steps) == [] -# --------------------------------------------------------------------------- -# Single-step flow (start == end) -# --------------------------------------------------------------------------- - - -def test_single_step_flow_completes(single_step_run): - assert single_step_run.successful - assert single_step_run.finished - - -def test_single_step_graph_info_start_equals_end(single_step_run): - graph_info = single_step_run["_parameters"].task["_graph_info"].data - assert graph_info["start_step"] == "only" - assert graph_info["end_step"] == "only" - assert graph_info["start_step"] == graph_info["end_step"] - - -def test_single_step_parameters_metadata(single_step_run): - meta = single_step_run["_parameters"].task.metadata_dict - assert meta.get("start_step") == "only" - assert meta.get("end_step") == "only" - - -def test_single_step_end_task(single_step_run): - end_task = single_step_run.end_task - assert end_task is not None - assert end_task["x"].data == 42 - - -def test_single_step_present(single_step_run): - assert {step.id for step in single_step_run} == {"only"} - - -def test_single_step_parent_child_empty(single_step_run): - assert list(single_step_run["only"].parent_steps) == [] - assert list(single_step_run["only"].child_steps) == [] - - -# --------------------------------------------------------------------------- -# Single-step flow with bare @step (implicit start == end) -# --------------------------------------------------------------------------- - - -def test_single_step_bare_flow_completes(single_step_bare_run): - assert single_step_bare_run.successful - assert single_step_bare_run.finished - - -def test_single_step_bare_graph_info_start_equals_end(single_step_bare_run): - graph_info = single_step_bare_run["_parameters"].task["_graph_info"].data - assert graph_info["start_step"] == "only" - assert graph_info["end_step"] == "only" - - -def test_single_step_bare_parameters_metadata(single_step_bare_run): - meta = single_step_bare_run["_parameters"].task.metadata_dict - assert meta.get("start_step") == "only" - assert meta.get("end_step") == "only" - - -def test_single_step_bare_end_task(single_step_bare_run): - end_task = single_step_bare_run.end_task - assert end_task is not None - assert end_task["x"].data == 42 - - -def test_single_step_bare_step_present(single_step_bare_run): - assert {step.id for step in single_step_bare_run} == {"only"} - - -def test_single_step_bare_parent_child_empty(single_step_bare_run): - assert list(single_step_bare_run["only"].parent_steps) == [] - assert list(single_step_bare_run["only"].child_steps) == [] - - -# --------------------------------------------------------------------------- -# Custom branch flow (entry/a/b/merge/done) -# --------------------------------------------------------------------------- - - -def test_branch_flow_completes(custom_branch_run): - assert custom_branch_run.successful - assert custom_branch_run.finished - - -def test_branch_graph_info_endpoints(custom_branch_run): - graph_info = custom_branch_run["_parameters"].task["_graph_info"].data - assert graph_info["start_step"] == "entry" - assert graph_info["end_step"] == "done" - - -def test_branch_end_task(custom_branch_run): +def test_branch_end_task_exists(custom_branch_run): assert custom_branch_run.end_task is not None @@ -163,16 +159,6 @@ def test_branch_merge_data(custom_branch_run): assert sorted(merge_task["vals"].data) == ["a", "b"] -def test_branch_steps_present(custom_branch_run): - assert {step.id for step in custom_branch_run} == { - "entry", - "a", - "b", - "merge", - "done", - } - - def test_branch_entry_has_two_children(custom_branch_run): children = [step.id for step in custom_branch_run["entry"].child_steps] assert sorted(children) == ["a", "b"] @@ -184,7 +170,7 @@ def test_branch_merge_has_two_parents(custom_branch_run): # --------------------------------------------------------------------------- -# Trigger integration +# Trigger Integration # --------------------------------------------------------------------------- @@ -198,28 +184,16 @@ def test_trigger_from_runs_uses_custom_terminal_step(custom_named_run): # --------------------------------------------------------------------------- -# Composition: single-step flows with Config, stacked decorators, FlowMutator +# Composition Specific Tests # --------------------------------------------------------------------------- -def test_single_step_with_config_completes(single_step_with_config_run): - """Config-bearing single-step flow runs to completion.""" - assert single_step_with_config_run.successful - assert single_step_with_config_run.finished - - def test_single_step_with_config_value_flows_to_artifact(single_step_with_config_run): """Config descriptor value is readable from the end task's artifact.""" end_task = single_step_with_config_run.end_task assert end_task["v"].data == 7 -def test_single_step_with_stacked_decos_completes(single_step_with_stacked_decos_run): - """Single-step flow with stacked @retry/@resources runs end-to-end.""" - assert single_step_with_stacked_decos_run.successful - assert single_step_with_stacked_decos_run.finished - - def test_single_step_with_stacked_decos_graph_info(single_step_with_stacked_decos_run): """_graph_info records all stacked decorators on the only step.""" graph_info = ( @@ -229,12 +203,6 @@ def test_single_step_with_stacked_decos_graph_info(single_step_with_stacked_deco assert {"retry", "resources"}.issubset(names) -def test_single_step_with_flow_mutator_completes(single_step_with_flow_mutator_run): - """FlowMutator-decorated single-step flow runs end-to-end.""" - assert single_step_with_flow_mutator_run.successful - assert single_step_with_flow_mutator_run.finished - - def test_single_step_with_flow_mutator_applied(single_step_with_flow_mutator_run): """FlowMutator.add_decorator landed @retry on the only step.""" graph_info = ( diff --git a/test/unit/inheritance/test_inheritance.py b/test/unit/inheritance/test_inheritance.py index d10956bb2dc..9dc60b926a6 100644 --- a/test/unit/inheritance/test_inheritance.py +++ b/test/unit/inheritance/test_inheritance.py @@ -13,323 +13,354 @@ import pytest -class TestComprehensiveLinear: - """Test comprehensive linear inheritance: FlowSpec -> BaseA -> BaseB -> BaseC -> Flow""" +# --------------------------------------------------------------------------- +# Linear Inheritance Tests (FlowSpec -> BaseA -> BaseB -> BaseC -> Flow) +# --------------------------------------------------------------------------- - def test_flow_completes(self, comprehensive_linear_run): - """Test that the flow completes successfully""" - assert comprehensive_linear_run.successful - assert comprehensive_linear_run.finished - def test_all_parameters_accessible(self, comprehensive_linear_run): - """Test that parameters from all levels are accessible""" - end_task = comprehensive_linear_run["end"].task +def test_linear_all_parameters_accessible(comprehensive_linear_run): + """Test that parameters from all levels are accessible in linear inheritance.""" + end_task = comprehensive_linear_run["end"].task - # From BaseA - assert end_task["result_alpha"].data == 10 - assert end_task["result_beta"].data == 5 + # From BaseA + assert end_task["result_alpha"].data == 10 + assert end_task["result_beta"].data == 5 - # From BaseC - assert end_task["result_gamma"].data == 2.5 + # From BaseC + assert end_task["result_gamma"].data == 2.5 - # From final class - assert end_task["result_delta"].data == "final" + # From final class + assert end_task["result_delta"].data == "final" - def test_all_configs_accessible(self, comprehensive_linear_run): - """Test that configs from all levels are accessible""" - end_task = comprehensive_linear_run["end"].task - # From BaseB - config_b = end_task["result_config_b"].data - assert config_b["multiplier"] == 3 - assert config_b["offset"] == 100 +def test_linear_all_configs_accessible(comprehensive_linear_run): + """Test that configs from all levels are accessible in linear inheritance.""" + end_task = comprehensive_linear_run["end"].task - # From BaseC - config_c = end_task["result_config_c"].data - assert config_c["mode"] == "production" - assert config_c["debug"] is False + # From BaseB + config_b = end_task["result_config_b"].data + assert config_b["multiplier"] == 3 + assert config_b["offset"] == 100 - def test_computation_with_configs(self, comprehensive_linear_run): - """Test computation using inherited parameters and configs""" - end_task = comprehensive_linear_run["end"].task + # From BaseC + config_c = end_task["result_config_c"].data + assert config_c["mode"] == "production" + assert config_c["debug"] is False - # start_value = alpha + beta = 10 + 5 = 15 - # processed_value = start_value * multiplier + offset = 15 * 3 + 100 = 145 - assert end_task["result_final"].data == 145 +def test_linear_computation_with_configs(comprehensive_linear_run): + """Test computation using inherited parameters and configs.""" + end_task = comprehensive_linear_run["end"].task -class TestMutatorWithBaseConfig: - """Test FlowMutator using config from base class""" + # start_value = alpha + beta = 10 + 5 = 15 + # processed_value = start_value * multiplier + offset = 15 * 3 + 100 = 145 + assert end_task["result_final"].data == 145 - def test_flow_completes(self, mutator_with_base_config_run): - """Test that flow completes successfully""" - assert mutator_with_base_config_run.successful - assert mutator_with_base_config_run.finished - def test_base_parameters_accessible(self, mutator_with_base_config_run): - """Test that base parameters are accessible""" - start_task = mutator_with_base_config_run["start"].task +# --------------------------------------------------------------------------- +# FlowMutator with Base Class Config Tests +# --------------------------------------------------------------------------- - assert start_task["result_base_param"].data == "base" - assert start_task["result_middle_param"].data == 100 - assert start_task["result_final_param"].data == 50 - def test_base_config_accessible(self, mutator_with_base_config_run): - """Test that config from base class is accessible""" - start_task = mutator_with_base_config_run["start"].task +def test_mutator_base_config_parameters_accessible(mutator_with_base_config_run): + """Test that base parameters are accessible when using a base config mutator.""" + start_task = mutator_with_base_config_run["start"].task - config = start_task["result_mutator_config"].data - assert config["param_to_inject"] == "dynamic_param" - assert config["default_value"] == 777 - assert config["inject_count"] == 42 + assert start_task["result_base_param"].data == "base" + assert start_task["result_middle_param"].data == 100 + assert start_task["result_final_param"].data == 50 - def test_mutator_injects_from_base_config(self, mutator_with_base_config_run): - """Test that mutator injects parameters based on base config""" - start_task = mutator_with_base_config_run["start"].task - # These parameters should be injected by the mutator based on mutator_config - assert start_task["result_dynamic_param"].data == 777 - assert start_task["result_injected_count"].data == 42 +def test_mutator_base_config_accessible(mutator_with_base_config_run): + """Test that config from base class is accessible.""" + start_task = mutator_with_base_config_run["start"].task - def test_computation_with_injected_params(self, mutator_with_base_config_run): - """Test computation using injected parameters""" - start_task = mutator_with_base_config_run["start"].task + config = start_task["result_mutator_config"].data + assert config["param_to_inject"] == "dynamic_param" + assert config["default_value"] == 777 + assert config["inject_count"] == 42 - # result_computation = middle_param + dynamic_param + injected_count - # = 100 + 777 + 42 = 919 - assert start_task["result_computation"].data == 919 +def test_mutator_injects_from_base_config(mutator_with_base_config_run): + """Test that mutator injects parameters based on base config.""" + start_task = mutator_with_base_config_run["start"].task -class TestMutatorWithDerivedConfig: - """Test FlowMutator in base class using config from derived class""" + # These parameters should be injected by the mutator based on mutator_config + assert start_task["result_dynamic_param"].data == 777 + assert start_task["result_injected_count"].data == 42 - def test_flow_completes(self, mutator_with_derived_config_run): - """Test that flow completes successfully""" - assert mutator_with_derived_config_run.successful - assert mutator_with_derived_config_run.finished - def test_all_parameters_accessible(self, mutator_with_derived_config_run): - """Test that all parameters from hierarchy are accessible""" - start_task = mutator_with_derived_config_run["start"].task +def test_mutator_base_config_computation_with_injected_params( + mutator_with_base_config_run, +): + """Test computation using parameters injected from base config.""" + start_task = mutator_with_base_config_run["start"].task - assert start_task["result_base_param"].data == "base_value" - assert start_task["result_middle_param"].data == 200 - assert start_task["result_final_param"].data == 999 + # result_computation = middle_param + dynamic_param + injected_count + # = 100 + 777 + 42 = 919 + assert start_task["result_computation"].data == 919 - def test_all_configs_accessible(self, mutator_with_derived_config_run): - """Test that all configs from hierarchy are accessible""" - start_task = mutator_with_derived_config_run["start"].task - middle_config = start_task["result_middle_config"].data - assert middle_config["env"] == "staging" +# --------------------------------------------------------------------------- +# FlowMutator with Derived Class Config Tests +# --------------------------------------------------------------------------- - runtime_config = start_task["result_runtime_config"].data - assert runtime_config["features"] == ["logging", "metrics"] - assert runtime_config["worker_count"] == 16 - def test_base_mutator_uses_derived_config(self, mutator_with_derived_config_run): - """Test that base class mutator injects parameters from derived config""" - start_task = mutator_with_derived_config_run["start"].task +def test_mutator_derived_config_all_parameters_accessible( + mutator_with_derived_config_run, +): + """Test that all parameters from hierarchy are accessible when using derived config.""" + start_task = mutator_with_derived_config_run["start"].task - # These parameters should be injected by base mutator using derived runtime_config - assert start_task["result_feature_logging"].data is True - assert start_task["result_feature_metrics"].data is True - assert start_task["result_worker_count"].data == 16 + assert start_task["result_base_param"].data == "base_value" + assert start_task["result_middle_param"].data == 200 + assert start_task["result_final_param"].data == 999 - def test_computation_with_forward_injected_params( - self, mutator_with_derived_config_run - ): - """Test computation using parameters injected from derived config""" - start_task = mutator_with_derived_config_run["start"].task - # result_computation = worker_count * enabled_features + final_param - # enabled_features = feature_logging (True=1) + feature_metrics (True=1) = 2 - # = 16 * 2 + 999 = 1031 - assert start_task["result_computation"].data == 1031 +def test_mutator_derived_config_all_configs_accessible(mutator_with_derived_config_run): + """Test that all configs from hierarchy are accessible.""" + start_task = mutator_with_derived_config_run["start"].task + middle_config = start_task["result_middle_config"].data + assert middle_config["env"] == "staging" -class TestComprehensiveDiamond: - """Test comprehensive diamond inheritance pattern""" + runtime_config = start_task["result_runtime_config"].data + assert runtime_config["features"] == ["logging", "metrics"] + assert runtime_config["worker_count"] == 16 - def test_flow_completes(self, comprehensive_diamond_run): - """Test that diamond inheritance flow completes""" - assert comprehensive_diamond_run.successful - assert comprehensive_diamond_run.finished - def test_parameters_from_all_branches(self, comprehensive_diamond_run): - """Test parameters from all branches are accessible""" - end_task = comprehensive_diamond_run["end"].task +def test_base_mutator_uses_derived_config(mutator_with_derived_config_run): + """Test that base class mutator injects parameters from derived config.""" + start_task = mutator_with_derived_config_run["start"].task - # From BaseA branch - assert end_task["result_param_a"].data == 100 + # These parameters should be injected by base mutator using derived runtime_config + assert start_task["result_feature_logging"].data is True + assert start_task["result_feature_metrics"].data is True + assert start_task["result_worker_count"].data == 16 - # From BaseB branch - assert end_task["result_param_b"].data == 50 - # From BaseC (merge point) - assert end_task["result_param_c"].data == 25 +def test_computation_with_forward_injected_params(mutator_with_derived_config_run): + """Test computation using parameters injected from derived config.""" + start_task = mutator_with_derived_config_run["start"].task - # From final class - assert end_task["result_final_param"].data == "complete" + # result_computation = worker_count * enabled_features + final_param + # enabled_features = feature_logging (True=1) + feature_metrics (True=1) = 2 + # = 16 * 2 + 999 = 1031 + assert start_task["result_computation"].data == 1031 - def test_configs_from_all_branches(self, comprehensive_diamond_run): - """Test configs from all branches are accessible""" - end_task = comprehensive_diamond_run["end"].task - # From BaseA branch - config_a = end_task["result_config_a"].data - assert config_a["branch"] == "A" - assert config_a["priority"] == 1 +# --------------------------------------------------------------------------- +# Comprehensive Diamond Inheritance Tests +# --------------------------------------------------------------------------- - # From BaseB branch - config_b = end_task["result_config_b"].data - assert config_b["branch"] == "B" - assert config_b["weight"] == 2.5 - # From BaseC (merge point) - config_c = end_task["result_config_c"].data - assert config_c["mode"] == "diamond" - assert config_c["enabled"] is True - - def test_mro_resolution(self, comprehensive_diamond_run): - """Test that MRO correctly resolves diamond pattern""" - # If flow completes and uses correct step from BaseA, MRO is working - assert comprehensive_diamond_run.successful - assert "start" in [step.id for step in comprehensive_diamond_run.steps()] - assert "process" in [step.id for step in comprehensive_diamond_run.steps()] +def test_diamond_parameters_from_all_branches(comprehensive_diamond_run): + """Test parameters from all branches of diamond inheritance are accessible.""" + end_task = comprehensive_diamond_run["end"].task - def test_computation_across_branches(self, comprehensive_diamond_run): - """Test computation using values from all branches""" - end_task = comprehensive_diamond_run["end"].task - - # value_a = param_a * priority = 100 * 1 = 100 - # processed = value_a + (param_b * weight) + param_c - # = 100 + (50 * 2.5) + 25 = 100 + 125 + 25 = 250 - assert end_task["result_final"].data == 250 - - -class TestComprehensiveMultiHierarchy: - """Test comprehensive multiple inheritance from independent hierarchies""" - - def test_flow_completes(self, comprehensive_multi_hierarchy_run): - """Test that multi-hierarchy flow completes""" - assert comprehensive_multi_hierarchy_run.successful - assert comprehensive_multi_hierarchy_run.finished - - def test_parameters_from_first_hierarchy(self, comprehensive_multi_hierarchy_run): - """Test parameters from first hierarchy are accessible""" - end_task = comprehensive_multi_hierarchy_run["end"].task - - assert end_task["result_param_a"].data == 10 - assert end_task["result_param_b"].data == 20 - - def test_parameters_from_second_hierarchy(self, comprehensive_multi_hierarchy_run): - """Test parameters from second hierarchy are accessible""" - end_task = comprehensive_multi_hierarchy_run["end"].task - - assert end_task["result_param_x"].data == 30 - assert end_task["result_param_y"].data == 40 - - def test_merge_point_parameters(self, comprehensive_multi_hierarchy_run): - """Test parameters from merge point are accessible""" - end_task = comprehensive_multi_hierarchy_run["end"].task - - assert end_task["result_param_c"].data == 5 - assert end_task["result_final_param"].data == "merged" - - def test_configs_from_both_hierarchies(self, comprehensive_multi_hierarchy_run): - """Test configs from both hierarchies are accessible""" - end_task = comprehensive_multi_hierarchy_run["end"].task - - # First hierarchy - config_a = end_task["result_config_a"].data - assert config_a["source"] == "hierarchy_a" - assert config_a["value"] == 100 - - # Second hierarchy - config_x = end_task["result_config_x"].data - assert config_x["source"] == "hierarchy_x" - assert config_x["multiplier"] == 2 - - config_y = end_task["result_config_y"].data - assert config_y["enabled"] is True - assert config_y["threshold"] == 50 - - # Merge point - config_c = end_task["result_config_c"].data - assert config_c["merge"] is True - assert config_c["offset"] == 200 - - def test_step_override_from_merge_point(self, comprehensive_multi_hierarchy_run): - """Test that BaseC's process step overrides BaseY's process step""" - # If the computation matches BaseC's logic (not BaseY's), override worked - end_task = comprehensive_multi_hierarchy_run["end"].task - - # hierarchy_a_result = param_a + param_b + config_a.value = 10 + 20 + 100 = 130 - # base_value = hierarchy_a_result * multiplier = 130 * 2 = 260 - # Since base_value (260) > threshold (50): - # processed_value = base_value + offset + param_c = 260 + 200 + 5 = 465 - assert end_task["result_final"].data == 465 - - def test_cross_hierarchy_computation(self, comprehensive_multi_hierarchy_run): - """Test computation using values from both hierarchies""" - end_task = comprehensive_multi_hierarchy_run["end"].task - - # Cross-hierarchy sum = param_a + param_b + param_x + param_y + param_c - # = 10 + 20 + 30 + 40 + 5 = 105 - assert end_task["result_cross_hierarchy"].data == 105 - - def test_mutator_from_first_hierarchy_executes( - self, comprehensive_multi_hierarchy_run - ): - end_task = comprehensive_multi_hierarchy_run["end"].task - assert end_task["logging_param_count"].data == 6 - assert end_task["logging_config_count"].data == 4 - - def test_decorated_step_from_first_hierarchy( - self, comprehensive_multi_hierarchy_run - ): - """Test that decorated step from first hierarchy works""" - end_task = comprehensive_multi_hierarchy_run["end"].task - assert end_task["source_from_var"].data == "hierarchy_x" - - -# Integration tests -class TestInheritanceIntegration: - """Integration tests across different inheritance patterns""" - - @pytest.mark.parametrize( - "fixture_name", - [ - "comprehensive_linear_run", - "mutator_with_base_config_run", - "mutator_with_derived_config_run", - "comprehensive_diamond_run", - "comprehensive_multi_hierarchy_run", - ], - ) - def test_all_flows_complete_successfully(self, fixture_name, request): - """Test that all inheritance pattern flows complete successfully""" - run = request.getfixturevalue(fixture_name) - assert run.successful, f"{fixture_name} did not complete successfully" - assert run.finished, f"{fixture_name} did not finish" - - @pytest.mark.parametrize( - "fixture_name,expected_steps", - [ - ("comprehensive_linear_run", ["start", "process", "end"]), - ("mutator_with_base_config_run", ["start", "end"]), - ("mutator_with_derived_config_run", ["start", "end"]), - ("comprehensive_diamond_run", ["start", "process", "end"]), - ("comprehensive_multi_hierarchy_run", ["start", "process", "end"]), - ], - ) - def test_expected_steps_present(self, fixture_name, expected_steps, request): - """Test that all expected steps are present in each flow""" - run = request.getfixturevalue(fixture_name) - step_names = [step.id for step in run.steps()] - - for expected_step in expected_steps: - assert ( - expected_step in step_names - ), f"Step {expected_step} not found in {fixture_name}" + # From BaseA branch + assert end_task["result_param_a"].data == 100 + + # From BaseB branch + assert end_task["result_param_b"].data == 50 + + # From BaseC (merge point) + assert end_task["result_param_c"].data == 25 + + # From final class + assert end_task["result_final_param"].data == "complete" + + +def test_diamond_configs_from_all_branches(comprehensive_diamond_run): + """Test configs from all branches of diamond inheritance are accessible.""" + end_task = comprehensive_diamond_run["end"].task + + # From BaseA branch + config_a = end_task["result_config_a"].data + assert config_a["branch"] == "A" + assert config_a["priority"] == 1 + + # From BaseB branch + config_b = end_task["result_config_b"].data + assert config_b["branch"] == "B" + assert config_b["weight"] == 2.5 + + # From BaseC (merge point) + config_c = end_task["result_config_c"].data + assert config_c["mode"] == "diamond" + assert config_c["enabled"] is True + + +def test_diamond_mro_resolution(comprehensive_diamond_run): + """Test that MRO correctly resolves diamond pattern.""" + # If flow completes and uses correct step from BaseA, MRO is working + assert comprehensive_diamond_run.successful + step_ids = [step.id for step in comprehensive_diamond_run.steps()] + assert "start" in step_ids + assert "process" in step_ids + + +def test_diamond_computation_across_branches(comprehensive_diamond_run): + """Test computation using values from all branches of a diamond structure.""" + end_task = comprehensive_diamond_run["end"].task + + # value_a = param_a * priority = 100 * 1 = 100 + # processed = value_a + (param_b * weight) + param_c + # = 100 + (50 * 2.5) + 25 = 100 + 125 + 25 = 250 + assert end_task["result_final"].data == 250 + + +# --------------------------------------------------------------------------- +# Comprehensive Multiple Inheritance Tests (Independent Hierarchies) +# --------------------------------------------------------------------------- + + +def test_multi_hierarchy_parameters_from_first_hierarchy( + comprehensive_multi_hierarchy_run, +): + """Test parameters from first independent hierarchy are accessible.""" + end_task = comprehensive_multi_hierarchy_run["end"].task + + assert end_task["result_param_a"].data == 10 + assert end_task["result_param_b"].data == 20 + + +def test_multi_hierarchy_parameters_from_second_hierarchy( + comprehensive_multi_hierarchy_run, +): + """Test parameters from second independent hierarchy are accessible.""" + end_task = comprehensive_multi_hierarchy_run["end"].task + + assert end_task["result_param_x"].data == 30 + assert end_task["result_param_y"].data == 40 + + +def test_multi_hierarchy_merge_point_parameters(comprehensive_multi_hierarchy_run): + """Test parameters from hierarchy merge point are accessible.""" + end_task = comprehensive_multi_hierarchy_run["end"].task + + assert end_task["result_param_c"].data == 5 + assert end_task["result_final_param"].data == "merged" + + +def test_multi_hierarchy_configs_from_both_hierarchies( + comprehensive_multi_hierarchy_run, +): + """Test configs from both separate hierarchies are accessible.""" + end_task = comprehensive_multi_hierarchy_run["end"].task + + # First hierarchy + config_a = end_task["result_config_a"].data + assert config_a["source"] == "hierarchy_a" + assert config_a["value"] == 100 + + # Second hierarchy + config_x = end_task["result_config_x"].data + assert config_x["source"] == "hierarchy_x" + assert config_x["multiplier"] == 2 + + config_y = end_task["result_config_y"].data + assert config_y["enabled"] is True + assert config_y["threshold"] == 50 + + # Merge point + config_c = end_task["result_config_c"].data + assert config_c["merge"] is True + assert config_c["offset"] == 200 + + +def test_multi_hierarchy_step_override_from_merge_point( + comprehensive_multi_hierarchy_run, +): + """Test that BaseC's process step overrides BaseY's process step.""" + # If the computation matches BaseC's logic (not BaseY's), override worked + end_task = comprehensive_multi_hierarchy_run["end"].task + + # hierarchy_a_result = param_a + param_b + config_a.value = 10 + 20 + 100 = 130 + # base_value = hierarchy_a_result * multiplier = 130 * 2 = 260 + # Since base_value (260) > threshold (50): + # processed_value = base_value + offset + param_c = 260 + 200 + 5 = 465 + assert end_task["result_final"].data == 465 + + +def test_multi_hierarchy_cross_hierarchy_computation(comprehensive_multi_hierarchy_run): + """Test computation using values combined from both separate hierarchies.""" + end_task = comprehensive_multi_hierarchy_run["end"].task + + # Cross-hierarchy sum = param_a + param_b + param_x + param_y + param_c + # = 10 + 20 + 30 + 40 + 5 = 105 + assert end_task["result_cross_hierarchy"].data == 105 + + +def test_multi_hierarchy_mutator_from_first_hierarchy_executes( + comprehensive_multi_hierarchy_run, +): + """Verify mutator defined in first hierarchy applies context changes successfully.""" + end_task = comprehensive_multi_hierarchy_run["end"].task + assert end_task["logging_param_count"].data == 6 + assert end_task["logging_config_count"].data == 4 + + +def test_multi_hierarchy_decorated_step_from_first_hierarchy( + comprehensive_multi_hierarchy_run, +): + """Test that decorated step from first hierarchy functions correctly.""" + end_task = comprehensive_multi_hierarchy_run["end"].task + assert end_task["source_from_var"].data == "hierarchy_x" + + +# --------------------------------------------------------------------------- +# Integration Tests +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "fixture_name", + [ + "comprehensive_linear_run", + "mutator_with_base_config_run", + "mutator_with_derived_config_run", + "comprehensive_diamond_run", + "comprehensive_multi_hierarchy_run", + ], + ids=[ + "linear_inheritance", + "mutator_base_config", + "mutator_derived_config", + "diamond_inheritance", + "multi_hierarchy", + ], +) +def test_all_flows_complete_successfully(fixture_name, request): + """Test that all structural inheritance pattern flows run completely and successfully.""" + run = request.getfixturevalue(fixture_name) + assert run.successful, f"{fixture_name} did not complete successfully" + assert run.finished, f"{fixture_name} did not finish" + + +@pytest.mark.parametrize( + "fixture_name, expected_steps", + [ + ("comprehensive_linear_run", ["start", "process", "end"]), + ("mutator_with_base_config_run", ["start", "end"]), + ("mutator_with_derived_config_run", ["start", "end"]), + ("comprehensive_diamond_run", ["start", "process", "end"]), + ("comprehensive_multi_hierarchy_run", ["start", "process", "end"]), + ], + ids=[ + "linear_steps", + "mutator_base_steps", + "mutator_derived_steps", + "diamond_steps", + "multi_hierarchy_steps", + ], +) +def test_expected_steps_present(fixture_name, expected_steps, request): + """Test that all expected DAG step structural markers are safely compiled into each run topology.""" + run = request.getfixturevalue(fixture_name) + step_names = [step.id for step in run.steps()] + + for expected_step in expected_steps: + assert ( + expected_step in step_names + ), f"Step {expected_step} not found in {fixture_name}" diff --git a/test/unit/localbatch/test_localbatch.py b/test/unit/localbatch/test_localbatch.py index de726a06b87..f9778c2bc8a 100644 --- a/test/unit/localbatch/test_localbatch.py +++ b/test/unit/localbatch/test_localbatch.py @@ -2,13 +2,13 @@ Integration tests for localbatch — the local AWS Batch emulator. Test layout mirrors the spin tests: - - TestBatchAPI : validates every Batch REST endpoint using boto3. - No Docker required — jobs fail gracefully when unavailable. - - TestDockerExecution : validates actual container execution. - Skipped when Docker is not running. - - TestMetaflowE2E : runs a real Metaflow flow via the @batch decorator - and verifies artifacts, mirroring the spin test style. - Requires Docker. + - Batch REST API Tests : validates every Batch REST endpoint using boto3. + No Docker required — jobs fail gracefully when unavailable. + - Docker Execution Tests: validates actual container execution. + Skipped when Docker is not running. + - Metaflow E2E Tests : runs a real Metaflow flow via the @batch decorator + and verifies artifacts, mirroring the spin test style. + Requires Docker. Run only the API tests (no Docker needed): pytest test/unit/localbatch/ -m "not docker" @@ -18,7 +18,6 @@ """ import time - import pytest import requests @@ -26,7 +25,7 @@ # --------------------------------------------------------------------------- -# Helpers +# Helpers & Shared Fixtures # --------------------------------------------------------------------------- @@ -58,278 +57,292 @@ def _register(batch_client, name, command=None): return resp["jobDefinitionArn"] +@pytest.fixture(autouse=True) +def _require_docker_if_marked(request): + """Automatically skip any test marked with @pytest.mark.docker if daemon is unavailable.""" + if "docker" in request.keywords: + try: + import docker + + docker.from_env().ping() + except Exception: + pytest.skip("Docker not available") + + # --------------------------------------------------------------------------- -# Batch REST API tests (no Docker required) +# Batch REST API Tests (No Docker Required) # --------------------------------------------------------------------------- -class TestBatchAPI: - def test_health(self, localbatch_server): - resp = requests.get(f"{localbatch_server.base_url}/health") - assert resp.status_code == 200 - assert resp.json()["status"] == "ok" - - def test_describe_queues_returns_default(self, batch_client): - resp = batch_client.describe_job_queues() - queues = resp["jobQueues"] - assert len(queues) >= 1 - names = [q["jobQueueName"] for q in queues] - assert "localbatch-default" in names - - def test_default_queue_is_healthy(self, batch_client): - resp = batch_client.describe_job_queues(jobQueues=["localbatch-default"]) - q = resp["jobQueues"][0] - assert q["state"] == "ENABLED" - assert q["status"] == "VALID" - assert len(q["computeEnvironmentOrder"]) >= 1 - - def test_describe_compute_environments(self, batch_client): - resp = batch_client.describe_compute_environments() - envs = resp["computeEnvironments"] - assert len(envs) >= 1 - assert envs[0]["status"] == "VALID" - - def test_register_job_definition(self, batch_client): +def test_health(localbatch_server): + resp = requests.get(f"{localbatch_server.base_url}/health") + assert resp.status_code == 200 + assert resp.json()["status"] == "ok" + + +def test_describe_queues_returns_default(batch_client): + resp = batch_client.describe_job_queues() + queues = resp["jobQueues"] + assert len(queues) >= 1 + names = [q["jobQueueName"] for q in queues] + assert "localbatch-default" in names + + +def test_default_queue_is_healthy(batch_client): + resp = batch_client.describe_job_queues(jobQueues=["localbatch-default"]) + q = resp["jobQueues"][0] + assert q["state"] == "ENABLED" + assert q["status"] == "VALID" + assert len(q["computeEnvironmentOrder"]) >= 1 + + +def test_describe_compute_environments(batch_client): + resp = batch_client.describe_compute_environments() + envs = resp["computeEnvironments"] + assert len(envs) >= 1 + assert envs[0]["status"] == "VALID" + + +def test_register_job_definition(batch_client): + resp = batch_client.register_job_definition( + jobDefinitionName="api-test-jobdef", + type="container", + containerProperties={ + "image": "alpine:latest", + "command": ["echo", "hello"], + "resourceRequirements": [ + {"type": "VCPU", "value": "1"}, + {"type": "MEMORY", "value": "256"}, + ], + }, + ) + assert resp["jobDefinitionName"] == "api-test-jobdef" + assert resp["revision"] == 1 + assert resp["jobDefinitionArn"].startswith("arn:aws:batch:") + + +def test_revision_increments_on_re_register(batch_client): + name = "revision-increment-jobdef" + for expected in (1, 2, 3): resp = batch_client.register_job_definition( - jobDefinitionName="api-test-jobdef", + jobDefinitionName=name, type="container", containerProperties={ "image": "alpine:latest", - "command": ["echo", "hello"], + "command": ["echo", str(expected)], "resourceRequirements": [ {"type": "VCPU", "value": "1"}, {"type": "MEMORY", "value": "256"}, ], }, ) - assert resp["jobDefinitionName"] == "api-test-jobdef" - assert resp["revision"] == 1 - assert resp["jobDefinitionArn"].startswith("arn:aws:batch:") - - def test_revision_increments_on_re_register(self, batch_client): - name = "revision-increment-jobdef" - for expected in (1, 2, 3): - resp = batch_client.register_job_definition( - jobDefinitionName=name, - type="container", - containerProperties={ - "image": "alpine:latest", - "command": ["echo", str(expected)], - "resourceRequirements": [ - {"type": "VCPU", "value": "1"}, - {"type": "MEMORY", "value": "256"}, - ], - }, - ) - assert resp["revision"] == expected - - def test_describe_job_definitions_by_name(self, batch_client): - _register(batch_client, "describe-by-name-job") - resp = batch_client.describe_job_definitions( - jobDefinitionName="describe-by-name-job", status="ACTIVE" - ) - defs = resp["jobDefinitions"] - assert len(defs) >= 1 - assert defs[0]["jobDefinitionName"] == "describe-by-name-job" - - def test_submit_job_returns_id_and_name(self, batch_client): - arn = _register(batch_client, "submit-test-job") - resp = batch_client.submit_job( - jobName="submit-test-run", - jobQueue="localbatch-default", - jobDefinition=arn, - ) - assert resp["jobId"] - assert resp["jobName"] == "submit-test-run" - - def test_job_reaches_terminal_state(self, batch_client): - """Without Docker the job fails; with Docker it succeeds — both are valid.""" - arn = _register(batch_client, "terminal-test-job") - resp = batch_client.submit_job( - jobName="terminal-run", - jobQueue="localbatch-default", - jobDefinition=arn, - ) - job = _poll_job(batch_client, resp["jobId"]) - assert job["status"] in ("SUCCEEDED", "FAILED") - assert job["jobName"] == "terminal-run" - assert job["jobQueue"] == "localbatch-default" + assert resp["revision"] == expected + - def test_describe_jobs_returns_correct_shape(self, batch_client): - arn = _register(batch_client, "describe-shape-job") - job_id = batch_client.submit_job( - jobName="describe-shape-run", - jobQueue="localbatch-default", - jobDefinition=arn, - )["jobId"] +def test_describe_job_definitions_by_name(batch_client): + _register(batch_client, "describe-by-name-job") + resp = batch_client.describe_job_definitions( + jobDefinitionName="describe-by-name-job", status="ACTIVE" + ) + defs = resp["jobDefinitions"] + assert len(defs) >= 1 + assert defs[0]["jobDefinitionName"] == "describe-by-name-job" - _poll_job(batch_client, job_id) - resp = batch_client.describe_jobs(jobs=[job_id]) - job = resp["jobs"][0] - for field in ( - "jobId", - "jobName", - "jobQueue", - "jobDefinition", - "status", - "createdAt", - ): - assert field in job, f"Missing field: {field}" - - def test_terminate_job_transitions_to_failed(self, batch_client): - arn = _register(batch_client, "terminate-test-job", command=["sleep", "999"]) - job_id = batch_client.submit_job( - jobName="terminate-run", - jobQueue="localbatch-default", - jobDefinition=arn, - )["jobId"] - - batch_client.terminate_job(jobId=job_id, reason="cancelled by test") - - job = _poll_job(batch_client, job_id) - assert job["status"] == "FAILED" - assert "cancelled by test" in job.get("statusReason", "") - - def test_list_jobs_returns_summary_list(self, batch_client): - resp = batch_client.list_jobs(jobQueue="localbatch-default", jobStatus="FAILED") - assert "jobSummaryList" in resp - for entry in resp["jobSummaryList"]: - assert "jobId" in entry - assert "status" in entry - - def test_ecs_metadata_endpoint(self, localbatch_server): - job_id = "deadbeef-1234-5678-abcd-000000000000" - resp = requests.get(f"{localbatch_server.base_url}/metadata/{job_id}/task") - assert resp.status_code == 200 - data = resp.json() - container = data["Containers"][0] - opts = container["LogOptions"] - assert opts["awslogs-group"] == "/localbatch/batch/job" - assert job_id in opts["awslogs-stream"] - assert container["LogDriver"] == "awslogs" +def test_submit_job_returns_id_and_name(batch_client): + arn = _register(batch_client, "submit-test-job") + resp = batch_client.submit_job( + jobName="submit-test-run", + jobQueue="localbatch-default", + jobDefinition=arn, + ) + assert resp["jobId"] + assert resp["jobName"] == "submit-test-run" + + +def test_job_reaches_terminal_state(batch_client): + """Without Docker the job fails; with Docker it succeeds — both are valid.""" + arn = _register(batch_client, "terminal-test-job") + resp = batch_client.submit_job( + jobName="terminal-run", + jobQueue="localbatch-default", + jobDefinition=arn, + ) + job = _poll_job(batch_client, resp["jobId"]) + assert job["status"] in ("SUCCEEDED", "FAILED") + assert job["jobName"] == "terminal-run" + assert job["jobQueue"] == "localbatch-default" + + +def test_describe_jobs_returns_correct_shape(batch_client): + arn = _register(batch_client, "describe-shape-job") + job_id = batch_client.submit_job( + jobName="describe-shape-run", + jobQueue="localbatch-default", + jobDefinition=arn, + )["jobId"] + + _poll_job(batch_client, job_id) + + resp = batch_client.describe_jobs(jobs=[job_id]) + job = resp["jobs"][0] + for field in ( + "jobId", + "jobName", + "jobQueue", + "jobDefinition", + "status", + "createdAt", + ): + assert field in job, f"Missing field: {field}" + + +def test_terminate_job_transitions_to_failed(batch_client): + arn = _register(batch_client, "terminate-test-job", command=["sleep", "999"]) + job_id = batch_client.submit_job( + jobName="terminate-run", + jobQueue="localbatch-default", + jobDefinition=arn, + )["jobId"] + + batch_client.terminate_job(jobId=job_id, reason="cancelled by test") + + job = _poll_job(batch_client, job_id) + assert job["status"] == "FAILED" + assert "cancelled by test" in job.get("statusReason", "") + + +def test_list_jobs_returns_summary_list(batch_client): + resp = batch_client.list_jobs(jobQueue="localbatch-default", jobStatus="FAILED") + assert "jobSummaryList" in resp + for entry in resp["jobSummaryList"]: + assert "jobId" in entry + assert "status" in entry + + +def test_ecs_metadata_endpoint(localbatch_server): + job_id = "deadbeef-1234-5678-abcd-000000000000" + resp = requests.get(f"{localbatch_server.base_url}/metadata/{job_id}/task") + assert resp.status_code == 200 + data = resp.json() + container = data["Containers"][0] + opts = container["LogOptions"] + assert opts["awslogs-group"] == "/localbatch/batch/job" + assert job_id in opts["awslogs-stream"] + assert container["LogDriver"] == "awslogs" # --------------------------------------------------------------------------- -# Docker execution tests +# Docker Execution Tests # --------------------------------------------------------------------------- @pytest.mark.docker -class TestDockerExecution: - @pytest.fixture(autouse=True) - def _require_docker(self): - try: - import docker +def test_successful_container(batch_client): + arn = _register( + batch_client, + "docker-success-job", + command=["sh", "-c", "echo success && exit 0"], + ) + job_id = batch_client.submit_job( + jobName="docker-success-run", + jobQueue="localbatch-default", + jobDefinition=arn, + )["jobId"] - docker.from_env().ping() - except Exception: - pytest.skip("Docker not available") + job = _poll_job(batch_client, job_id, timeout=60) + assert job["status"] == "SUCCEEDED" - def test_successful_container(self, batch_client): - arn = _register( - batch_client, - "docker-success-job", - command=["sh", "-c", "echo success && exit 0"], - ) - job_id = batch_client.submit_job( - jobName="docker-success-run", - jobQueue="localbatch-default", - jobDefinition=arn, - )["jobId"] - - job = _poll_job(batch_client, job_id, timeout=60) - assert job["status"] == "SUCCEEDED" - - def test_failed_container_reports_exit_code(self, batch_client): - arn = _register( - batch_client, "docker-fail-job", command=["sh", "-c", "exit 42"] - ) - job_id = batch_client.submit_job( - jobName="docker-fail-run", - jobQueue="localbatch-default", - jobDefinition=arn, - )["jobId"] - - job = _poll_job(batch_client, job_id, timeout=60) - assert job["status"] == "FAILED" - assert job["container"]["exitCode"] == 42 - - def test_inject_env_is_visible_inside_container(self): - """ - Spin up a separate localbatch instance with inject_env set and confirm - the container can see the injected variable. - """ - import threading - - import boto3 - import uvicorn - from localbatch.runner import DockerRunner - from localbatch.server import create_app - from localbatch.store import Store - - port = 18766 - store = Store(queue_name="inject-queue") - runner = DockerRunner( - store, - host_addr="host.docker.internal", - port=port, - inject_env={"LOCALBATCH_CANARY": "canary-value"}, - ) - app = create_app(store, runner) - server = uvicorn.Server( - uvicorn.Config(app, host="127.0.0.1", port=port, log_level="warning") - ) - t = threading.Thread(target=server.run, daemon=True) - t.start() - - for _ in range(100): - try: - requests.get(f"http://127.0.0.1:{port}/health", timeout=0.1) - break - except Exception: - time.sleep(0.05) - - client = boto3.client( - "batch", - endpoint_url=f"http://127.0.0.1:{port}", - region_name="us-east-1", - aws_access_key_id="test", - aws_secret_access_key="test", - ) - client.register_job_definition( - jobDefinitionName="inject-canary-job", - type="container", - containerProperties={ - "image": "alpine:latest", - "command": [ - "sh", - "-c", - 'test "$LOCALBATCH_CANARY" = "canary-value"', - ], - "resourceRequirements": [ - {"type": "VCPU", "value": "1"}, - {"type": "MEMORY", "value": "256"}, - ], - }, - ) - job_id = client.submit_job( - jobName="inject-canary-run", - jobQueue="inject-queue", - jobDefinition="inject-canary-job:1", - )["jobId"] - job = _poll_job(client, job_id, timeout=60) - assert ( - job["status"] == "SUCCEEDED" - ), "LOCALBATCH_CANARY was not visible or had wrong value inside container" +@pytest.mark.docker +def test_failed_container_reports_exit_code(batch_client): + arn = _register(batch_client, "docker-fail-job", command=["sh", "-c", "exit 42"]) + job_id = batch_client.submit_job( + jobName="docker-fail-run", + jobQueue="localbatch-default", + jobDefinition=arn, + )["jobId"] + + job = _poll_job(batch_client, job_id, timeout=60) + assert job["status"] == "FAILED" + assert job["container"]["exitCode"] == 42 - server.should_exit = True - t.join(timeout=5) + +@pytest.mark.docker +def test_inject_env_is_visible_inside_container(): + """Spin up a separate localbatch instance with inject_env set and confirm + + the container can see the injected variable. + """ + import threading + import boto3 + import uvicorn + from localbatch.runner import DockerRunner + from localbatch.server import create_app + from localbatch.store import Store + + port = 18766 + store = Store(queue_name="inject-queue") + runner = DockerRunner( + store, + host_addr="host.docker.internal", + port=port, + inject_env={"LOCALBATCH_CANARY": "canary-value"}, + ) + app = create_app(store, runner) + server = uvicorn.Server( + uvicorn.Config(app, host="127.0.0.1", port=port, log_level="warning") + ) + t = threading.Thread(target=server.run, daemon=True) + t.start() + + for _ in range(100): + try: + requests.get(f"http://127.0.0.1:{port}/health", timeout=0.1) + break + except Exception: + time.sleep(0.05) + + client = boto3.client( + "batch", + endpoint_url=f"http://127.0.0.1:{port}", + region_name="us-east-1", + aws_access_key_id="test", + aws_secret_access_key="test", + ) + client.register_job_definition( + jobDefinitionName="inject-canary-job", + type="container", + containerProperties={ + "image": "alpine:latest", + "command": [ + "sh", + "-c", + 'test "$LOCALBATCH_CANARY" = "canary-value"', + ], + "resourceRequirements": [ + {"type": "VCPU", "value": "1"}, + {"type": "MEMORY", "value": "256"}, + ], + }, + ) + job_id = client.submit_job( + jobName="inject-canary-run", + jobQueue="inject-queue", + jobDefinition="inject-canary-job:1", + )["jobId"] + + job = _poll_job(client, job_id, timeout=60) + assert ( + job["status"] == "SUCCEEDED" + ), "LOCALBATCH_CANARY was not visible or had wrong value inside container" + + server.should_exit = True + t.join(timeout=5) # --------------------------------------------------------------------------- -# Metaflow end-to-end test (mirrors spin test style) +# Metaflow End-to-End Tests (Mirrors Spin Test Style) # --------------------------------------------------------------------------- @@ -341,37 +354,25 @@ def test_inject_env_is_visible_inside_container(self): @pytest.mark.docker -class TestMetaflowE2E: - """ - Runs a real Metaflow flow against localbatch and verifies that artifacts - produced inside the @batch step are correctly persisted and readable - from the client side — the same contract the spin tests enforce. +@_NEEDS_CORE_BATCH_PARAMS +def test_batch_step_artifacts_are_persisted(simple_batch_run): + """The @batch step writes message='hello from localbatch' and value=42. + + Both must be readable via the Metaflow client after the run finishes. """ + task = simple_batch_run["start"].task + assert task["message"].data == "hello from localbatch" + assert task["value"].data == 42 - @pytest.fixture(autouse=True) - def _require_docker(self): - try: - import docker - docker.from_env().ping() - except Exception: - pytest.skip("Docker not available") +@pytest.mark.docker +@_NEEDS_CORE_BATCH_PARAMS +def test_run_succeeds(simple_batch_run): + assert simple_batch_run.successful - @_NEEDS_CORE_BATCH_PARAMS - def test_batch_step_artifacts_are_persisted(self, simple_batch_run): - """ - The @batch step writes message='hello from localbatch' and value=42. - Both must be readable via the Metaflow client after the run finishes. - """ - task = simple_batch_run["start"].task - assert task["message"].data == "hello from localbatch" - assert task["value"].data == 42 - - @_NEEDS_CORE_BATCH_PARAMS - def test_run_succeeds(self, simple_batch_run): - assert simple_batch_run.successful - - @_NEEDS_CORE_BATCH_PARAMS - def test_all_steps_have_tasks(self, simple_batch_run): - step_names = {step.id for step in simple_batch_run.steps()} - assert {"start", "end"} <= step_names + +@pytest.mark.docker +@_NEEDS_CORE_BATCH_PARAMS +def test_all_steps_have_tasks(simple_batch_run): + step_names = {step.id for step in simple_batch_run.steps()} + assert {"start", "end"} <= step_names diff --git a/test/unit/mutators/test_add_decorator_returns.py b/test/unit/mutators/test_add_decorator_returns.py index c7143b4d1f5..92a57bc9b25 100644 --- a/test/unit/mutators/test_add_decorator_returns.py +++ b/test/unit/mutators/test_add_decorator_returns.py @@ -1,26 +1,40 @@ """Tests that MutableStep.add_decorator returns the decorator instance.""" +import pytest -class TestAddDecoratorReturns: - def test_flow_completes(self, add_decorator_return_run): - assert add_decorator_return_run.successful - def test_returned_decorator_is_not_none(self, add_decorator_return_run): - task = add_decorator_return_run["start"].task - assert task.data.returned_is_none is False +@pytest.fixture +def start_task_data(add_decorator_return_run): + """Shared setup: Extracts the data payload from the start task of the test run.""" + return add_decorator_return_run["start"].task.data - def test_returned_decorator_has_name(self, add_decorator_return_run): - task = add_decorator_return_run["start"].task - assert task.data.returned_has_name is True - def test_decorator_was_applied(self, add_decorator_return_run): - task = add_decorator_return_run["start"].task - assert task.data.added_var == "from_mutator" +def test_add_decorator_flow_completes_successfully(add_decorator_return_run): + """Test that the flow modified with add_decorator completes without errors.""" + assert add_decorator_return_run.successful - def test_duplicate_ignore_returns_none(self, add_decorator_return_run): - task = add_decorator_return_run["start"].task - assert task.data.duplicate_is_none is True - def test_duplicate_was_not_applied(self, add_decorator_return_run): - task = add_decorator_return_run["start"].task - assert task.data.should_not_exist is None +def test_add_decorator_returns_instance(start_task_data): + """Test that add_decorator returns an actual object, not None.""" + assert start_task_data.returned_is_none is False + + +def test_returned_decorator_instance_has_name_attribute(start_task_data): + """Test that the returned decorator instance has the expected properties.""" + assert start_task_data.returned_has_name is True + + +def test_added_decorator_executes_and_sets_data(start_task_data): + """Test that the dynamically added decorator was actually executed during the run.""" + # Verify that the variable injected by the dynamically added decorator exists + assert start_task_data.added_var == "from_mutator" + + +def test_adding_duplicate_decorator_with_ignore_returns_none(start_task_data): + """Test that attempting to add a duplicate decorator (when ignored) returns None.""" + assert start_task_data.duplicate_is_none is True + + +def test_ignored_duplicate_decorator_does_not_execute(start_task_data): + """Test that the ignored duplicate decorator does not apply its logic.""" + assert start_task_data.should_not_exist is None diff --git a/test/unit/mutators/test_dual_inheritance.py b/test/unit/mutators/test_dual_inheritance.py index ec7204fd778..0127a1cf487 100644 --- a/test/unit/mutators/test_dual_inheritance.py +++ b/test/unit/mutators/test_dual_inheritance.py @@ -1,27 +1,43 @@ """Tests for dual UserStepDecorator + StepMutator inheritance.""" +import pytest -class TestDualInheritance: - def test_flow_completes(self, dual_inherit_run): - assert dual_inherit_run.successful - - def test_pre_mutate_ran(self, dual_inherit_run): - """pre_mutate() should have added the environment variable.""" - task = dual_inherit_run["start"].task - assert task.data.pre_mutate_env_var == "pre_mutate_ran" - - def test_mutate_ran(self, dual_inherit_run): - """mutate() should have added the environment variable.""" - task = dual_inherit_run["start"].task - assert task.data.mutate_env_var == "hello" - - def test_pre_step_ran(self, dual_inherit_run): - """pre_step() should have set the artifact.""" - task = dual_inherit_run["start"].task - assert task.data.pre_step_ran is True - - def test_post_step_ran(self, dual_inherit_run): - """post_step() should have set the artifact on the start step, - visible in the end step via data propagation.""" - task = dual_inherit_run["end"].task - assert task.data.post_step_ran is True + +@pytest.fixture +def start_task_data(dual_inherit_run): + """Shared setup: Extracts the data payload from the start task of the test run.""" + return dual_inherit_run["start"].task.data + + +@pytest.fixture +def end_task_data(dual_inherit_run): + """Shared setup: Extracts the data payload from the end task of the test run.""" + return dual_inherit_run["end"].task.data + + +def test_dual_inheritance_flow_completes_successfully(dual_inherit_run): + """Test that a flow using a decorator with dual inheritance runs to completion.""" + assert dual_inherit_run.successful + + +def test_pre_mutate_hook_adds_environment_variable(start_task_data): + """Test that the pre_mutate() hook correctly injects its environment variable.""" + assert start_task_data.pre_mutate_env_var == "pre_mutate_ran" + + +def test_mutate_hook_adds_environment_variable(start_task_data): + """Test that the mutate() hook correctly injects its environment variable.""" + assert start_task_data.mutate_env_var == "hello" + + +def test_pre_step_hook_sets_artifact(start_task_data): + """Test that the pre_step() hook executes and successfully sets an artifact.""" + assert start_task_data.pre_step_ran is True + + +def test_post_step_hook_sets_artifact_visible_downstream(end_task_data): + """ + Test that post_step() sets an artifact on the start step, + and verifies it is visible in the end step via data propagation. + """ + assert end_task_data.post_step_ran is True diff --git a/test/unit/mutators/test_flow_mutator_addition.py b/test/unit/mutators/test_flow_mutator_addition.py index c7bafd7b042..0262e176af1 100644 --- a/test/unit/mutators/test_flow_mutator_addition.py +++ b/test/unit/mutators/test_flow_mutator_addition.py @@ -1,35 +1,40 @@ """Tests for dynamically adding FlowMutators via MutableFlow.add_decorator.""" +# --- Adding a FlowMutator by class reference --- -class TestDynamicFlowMutatorAddition: - """Adding a FlowMutator by class reference.""" - def test_flow_completes(self, dynamic_flow_mutator_run): - assert dynamic_flow_mutator_run.successful +def test_dynamic_flow_mutator_completes(dynamic_flow_mutator_run): + """Test that the flow completes successfully when adding a FlowMutator by class reference.""" + assert dynamic_flow_mutator_run.successful - def test_inner_pre_mutate_ran(self, dynamic_flow_mutator_run): - """InnerMutator.pre_mutate should have been called by the ongoing iteration.""" - task = dynamic_flow_mutator_run["start"].task - assert task.data.inner_pre == "inner_pre_mutate_ran" - def test_inner_mutate_ran(self, dynamic_flow_mutator_run): - """InnerMutator.mutate should also have been called.""" - task = dynamic_flow_mutator_run["start"].task - assert task.data.inner_mutate == "inner_mutate_ran" +def test_dynamic_flow_mutator_inner_pre_mutate_ran(dynamic_flow_mutator_run): + """Test InnerMutator.pre_mutate is called by the ongoing iteration.""" + task = dynamic_flow_mutator_run["start"].task + assert task.data.inner_pre == "inner_pre_mutate_ran" -class TestStringFlowMutatorAddition: - """Adding a FlowMutator by string name with arguments.""" +def test_dynamic_flow_mutator_inner_mutate_ran(dynamic_flow_mutator_run): + """Test InnerMutator.mutate is also called by the ongoing iteration.""" + task = dynamic_flow_mutator_run["start"].task + assert task.data.inner_mutate == "inner_mutate_ran" - def test_flow_completes(self, string_flow_mutator_run): - assert string_flow_mutator_run.successful - def test_string_mutator_pre_mutate_ran(self, string_flow_mutator_run): - """StringAddedMutator.pre_mutate should have run with the parsed arg.""" - task = string_flow_mutator_run["start"].task - assert task.data.string_tag == "from_string" +# --- Adding a FlowMutator by string name with arguments --- - def test_string_mutator_mutate_ran(self, string_flow_mutator_run): - """StringAddedMutator.mutate should also have been called.""" - task = string_flow_mutator_run["start"].task - assert task.data.string_mutate == "yes" + +def test_string_flow_mutator_completes(string_flow_mutator_run): + """Test that the flow completes successfully when adding a FlowMutator by string name.""" + assert string_flow_mutator_run.successful + + +def test_string_flow_mutator_pre_mutate_ran(string_flow_mutator_run): + """Test StringAddedMutator.pre_mutate runs with the parsed arg.""" + task = string_flow_mutator_run["start"].task + assert task.data.string_tag == "from_string" + + +def test_string_flow_mutator_mutate_ran(string_flow_mutator_run): + """Test StringAddedMutator.mutate is also called.""" + task = string_flow_mutator_run["start"].task + assert task.data.string_mutate == "yes" diff --git a/test/unit/mutators/test_post_step_none_false.py b/test/unit/mutators/test_post_step_none_false.py index e19eeb5f130..3f906734dea 100644 --- a/test/unit/mutators/test_post_step_none_false.py +++ b/test/unit/mutators/test_post_step_none_false.py @@ -1,18 +1,20 @@ """Regression test for post_step returning (None, False) being a no-op.""" -class TestPostStepNoneFalse: - def test_flow_completes(self, post_step_none_false_run): - """Run completes successfully rather than hitting RuntimeError at - task.py's `Invalid value passed to self.next` branch.""" - assert post_step_none_false_run.successful +def test_post_step_none_false_completes(post_step_none_false_run): + """Run completes successfully rather than hitting RuntimeError at + task.py's `Invalid value passed to self.next` branch.""" + assert post_step_none_false_run.successful - def test_pre_step_ran(self, post_step_none_false_run): - task = post_step_none_false_run["start"].task - assert task.data.pre_step_ran is True - def test_post_step_ran(self, post_step_none_false_run): - """post_step ran and its (None, False) return value was accepted as - a no-op (visible in the end step via data propagation).""" - task = post_step_none_false_run["end"].task - assert task.data.post_step_ran is True +def test_post_step_none_false_pre_step_ran(post_step_none_false_run): + """Verify that the pre_step was executed on the start task.""" + task = post_step_none_false_run["start"].task + assert task.data.pre_step_ran is True + + +def test_post_step_none_false_post_step_ran(post_step_none_false_run): + """post_step ran and its (None, False) return value was accepted as + a no-op (visible in the end step via data propagation).""" + task = post_step_none_false_run["end"].task + assert task.data.post_step_ran is True diff --git a/test/unit/mutators/test_remove_decorator_guard.py b/test/unit/mutators/test_remove_decorator_guard.py index 6cfcc2b2ec3..f888636e546 100644 --- a/test/unit/mutators/test_remove_decorator_guard.py +++ b/test/unit/mutators/test_remove_decorator_guard.py @@ -34,35 +34,42 @@ def end(self): pass -def _make_mutable_step(pre_mutate): - cls = FlowWithStepMutator - step_obj = getattr(cls, "start") - return MutableStep( - cls, - step_obj, - pre_mutate=pre_mutate, - statically_defined=True, - inserted_by=["test"], - ) - - -class TestRemoveDecoratorGuard: - def test_do_all_from_mutate_raises(self): - """Calling remove_decorator(name) without args/kwargs (do_all=True) - from a non-pre_mutate MutableStep must raise on a StepMutator.""" - ms = _make_mutable_step(pre_mutate=False) - with pytest.raises(MetaflowException, match="only allowed in the `pre_mutate`"): - ms.remove_decorator("dummy_step_mutator") - - def test_specific_match_from_mutate_raises(self): - """Same guard applies when an explicit deco_args/deco_kwargs match - is provided (this path already worked before the fix; keep covered).""" - ms = _make_mutable_step(pre_mutate=False) - with pytest.raises(MetaflowException, match="only allowed in the `pre_mutate`"): - ms.remove_decorator("dummy_step_mutator", deco_args=[], deco_kwargs={}) - - def test_do_all_from_pre_mutate_succeeds(self): - """From pre_mutate, removing the StepMutator via do_all should work.""" - ms = _make_mutable_step(pre_mutate=True) - removed = ms.remove_decorator("dummy_step_mutator") - assert removed is True +@pytest.fixture +def make_mutable_step(): + """Factory fixture to create MutableStep instances.""" + + def _make(pre_mutate): + cls = FlowWithStepMutator + step_obj = getattr(cls, "start") + return MutableStep( + cls, + step_obj, + pre_mutate=pre_mutate, + statically_defined=True, + inserted_by=["test"], + ) + + return _make + + +def test_remove_decorator_do_all_from_mutate_raises(make_mutable_step): + """Calling remove_decorator(name) without args/kwargs (do_all=True) + from a non-pre_mutate MutableStep must raise on a StepMutator.""" + ms = make_mutable_step(pre_mutate=False) + with pytest.raises(MetaflowException, match="only allowed in the `pre_mutate`"): + ms.remove_decorator("dummy_step_mutator") + + +def test_remove_decorator_specific_match_from_mutate_raises(make_mutable_step): + """Same guard applies when an explicit deco_args/deco_kwargs match + is provided (this path already worked before the fix; keep covered).""" + ms = make_mutable_step(pre_mutate=False) + with pytest.raises(MetaflowException, match="only allowed in the `pre_mutate`"): + ms.remove_decorator("dummy_step_mutator", deco_args=[], deco_kwargs={}) + + +def test_remove_decorator_do_all_from_pre_mutate_succeeds(make_mutable_step): + """From pre_mutate, removing the StepMutator via do_all should work.""" + ms = make_mutable_step(pre_mutate=True) + removed = ms.remove_decorator("dummy_step_mutator") + assert removed is True diff --git a/test/unit/mutators/test_string_step_mutator.py b/test/unit/mutators/test_string_step_mutator.py index ca205c59534..e5a65565f28 100644 --- a/test/unit/mutators/test_string_step_mutator.py +++ b/test/unit/mutators/test_string_step_mutator.py @@ -1,12 +1,13 @@ """Tests for string-based StepMutator addition via MutableStep.add_decorator.""" -class TestStringStepMutatorAddition: - def test_flow_completes(self, string_step_mutator_run): - assert string_step_mutator_run.successful - - def test_string_step_mutator_ran(self, string_step_mutator_run): - """StepMutator added by string should have its mutate() called, - adding the environment variable.""" - task = string_step_mutator_run["start"].task - assert task.data.step_mutator_val == "from_string" +def test_string_step_mutator_flow_completes(string_step_mutator_run): + """Test that the flow completes successfully when adding a StepMutator by string.""" + assert string_step_mutator_run.successful + + +def test_string_step_mutator_ran(string_step_mutator_run): + """StepMutator added by string should have its mutate() called, + adding the environment variable.""" + task = string_step_mutator_run["start"].task + assert task.data.step_mutator_val == "from_string" diff --git a/test/unit/spin/test_spin.py b/test/unit/spin/test_spin.py index 5ab193bb7de..a41279d00de 100644 --- a/test/unit/spin/test_spin.py +++ b/test/unit/spin/test_spin.py @@ -1,8 +1,14 @@ +import os +import tempfile import pytest + from metaflow import Runner -import os from spin_test_helpers import assert_artifacts, run_step, FLOWS_DIR, ARTIFACTS_DIR +# --------------------------------------------------------------------------- +# Simple Flow Tests +# --------------------------------------------------------------------------- + @pytest.mark.parametrize( "flow_file,fixture_name", @@ -14,67 +20,81 @@ ], ids=["merge_artifacts", "simple_config", "simple_parameter", "complex_dag"], ) -def test_simple_flows(flow_file, fixture_name, request): - """Test simple flows that just need artifact validation.""" +def test_simple_flows_validate_artifacts(flow_file, fixture_name, request): + """Test that basic flows run steps correctly and validate their artifacts.""" run = request.getfixturevalue(fixture_name) - print(f"Running test for {flow_file}: {run}") + + # Act & Assert: Iterate through and run each step for step in run.steps(): - print("-" * 100) if fixture_name == "complex_dag_run": run_step(flow_file, run, step.id, environment="conda") else: run_step(flow_file, run, step.id) -def test_artifacts_module(complex_dag_run): - print(f"Running test for artifacts module in ComplexDAGFlow: {complex_dag_run}") +# --------------------------------------------------------------------------- +# Artifacts Module Tests +# --------------------------------------------------------------------------- + + +def test_artifacts_module_evaluates_correctly(complex_dag_run): + """Test that an external artifacts module correctly injects state into a spun step.""" + # Setup step_name = "step_a" task = complex_dag_run[step_name].task flow_path = os.path.join(FLOWS_DIR, "complex_dag_flow.py") artifacts_path = os.path.join(ARTIFACTS_DIR, "complex_dag_step_a.py") + # Act with Runner(flow_path, cwd=FLOWS_DIR, environment="conda").spin( task.pathspec, artifacts_module=artifacts_path, persist=True, ) as spin: - print("-" * 50) - print(f"Running test for step: step_a with task pathspec: {task.pathspec}") + + # Assert spin_task = spin.task - print(f"my_output: {spin_task['my_output']}") assert spin_task["my_output"].data == [10, 11, 12, 3] -def test_artifacts_module_join_step( +def test_artifacts_module_injects_dynamic_data_in_join_step( complex_dag_run, complex_dag_step_d_artifacts, tmp_path ): - print(f"Running test for artifacts module in ComplexDAGFlow: {complex_dag_run}") + """Test that dynamically generated artifacts are correctly loaded during a join step.""" + # Setup step_name = "step_d" task = complex_dag_run[step_name].task flow_path = os.path.join(FLOWS_DIR, "complex_dag_flow.py") - # Create a temporary artifacts file with dynamic data + # Setup: Create a temporary artifacts file with dynamic data temp_artifacts_file = tmp_path / "temp_complex_dag_step_d.py" temp_artifacts_file.write_text(f"ARTIFACTS = {repr(complex_dag_step_d_artifacts)}") + # Act with Runner(flow_path, cwd=FLOWS_DIR, environment="conda").spin( task.pathspec, artifacts_module=str(temp_artifacts_file), persist=True, ) as spin: - print("-" * 50) - print(f"Running test for step: step_d with task pathspec: {task.pathspec}") + + # Assert spin_task = spin.task assert spin_task["my_output"].data == [-1] -def test_timeout_decorator_enforcement(simple_config_run): - """Test that timeout decorator properly enforces timeout limits.""" +# --------------------------------------------------------------------------- +# Decorator & Config Tests +# --------------------------------------------------------------------------- + + +def test_timeout_decorator_enforces_time_limit(simple_config_run): + """Test that a configured timeout decorator stops execution and raises an exception.""" + # Setup step_name = "start" task = simple_config_run[step_name].task flow_path = os.path.join(FLOWS_DIR, "simple_config_flow.py") - # With decorator enabled (should timeout and raise exception) + # Act & Assert with pytest.raises(Exception): with Runner( flow_path, cwd=FLOWS_DIR, config_value=[("config", {"timeout": 2})] @@ -85,13 +105,14 @@ def test_timeout_decorator_enforcement(simple_config_run): pass -def test_skip_decorators_bypass(simple_config_run): - """Test that skip_decorators successfully bypasses timeout decorator.""" +def test_skip_decorators_bypasses_timeout(simple_config_run): + """Test that using skip_decorators=True successfully ignores the timeout limit.""" + # Setup step_name = "start" task = simple_config_run[step_name].task flow_path = os.path.join(FLOWS_DIR, "simple_config_flow.py") - # With skip_decorators=True (should succeed despite timeout) + # Act with Runner( flow_path, cwd=FLOWS_DIR, config_value=[("config", {"timeout": 2})] ).spin( @@ -99,16 +120,18 @@ def test_skip_decorators_bypass(simple_config_run): skip_decorators=True, persist=True, ) as spin: - print(f"Running test for step: {step_name} with skip_decorators=True") - # Should complete successfully even though sleep(5) > timeout(2) - spin_task = spin.task - assert spin_task.finished + + # Assert: Should complete successfully even though step length > timeout + assert spin.task.finished def test_spin_preserves_explicit_top_level_decospecs(spin_decospec_run): + """Test that spin respects top-level decorator specifications provided to Runner.""" + # Setup task = spin_decospec_run["start"].task flow_path = os.path.join(FLOWS_DIR, "spin_decospec_flow.py") + # Act & Assert with pytest.raises(Exception, match="timed out"): with Runner( flow_path, @@ -123,10 +146,13 @@ def test_spin_preserves_explicit_top_level_decospecs(spin_decospec_run): pass -def test_spin_step_does_not_apply_default_decospecs(spin_decospec_run): +def test_spin_step_ignores_default_decospecs(spin_decospec_run): + """Test that spin does NOT inadvertently apply METAFLOW_DEFAULT_DECOSPECS.""" + # Setup task = spin_decospec_run["start"].task flow_path = os.path.join(FLOWS_DIR, "spin_decospec_flow.py") + # Act with Runner( flow_path, cwd=FLOWS_DIR, @@ -137,44 +163,59 @@ def test_spin_step_does_not_apply_default_decospecs(spin_decospec_run): task.pathspec, persist=True, ) as spin: + + # Assert assert spin.task.finished assert spin.task["done"].data is True -def test_hidden_artifacts(simple_parameter_run): - """Test simple flows that just need artifact validation.""" +# --------------------------------------------------------------------------- +# Internal State & Integration Tests +# --------------------------------------------------------------------------- + + +def test_spin_persists_internal_hidden_artifacts(simple_parameter_run): + """Test that spinning a task retains internal Metaflow graph and state artifacts.""" + # Setup step_name = "start" task = simple_parameter_run[step_name].task flow_path = os.path.join(FLOWS_DIR, "simple_parameter_flow.py") - print(f"Running test for hidden artifacts in {flow_path}: {simple_parameter_run}") + # Act with Runner(flow_path, cwd=FLOWS_DIR).spin(task.pathspec, persist=True) as spin: spin_task = spin.task + + # Assert assert "_graph_info" in spin_task assert "_foreach_stack" in spin_task -def test_card_flow(simple_card_run): - """Test a simple flow that has @card decorator.""" +def test_spin_generates_cards_correctly(simple_card_run): + """Test that spinning a flow with the @card decorator successfully outputs cards.""" + # Setup + from metaflow.cards import get_cards + step_name = "start" task = simple_card_run[step_name].task flow_path = os.path.join(FLOWS_DIR, "simple_card_flow.py") - print(f"Running test for cards in {flow_path}: {simple_card_run}") + # Act with Runner(flow_path, cwd=FLOWS_DIR).spin(task.pathspec, persist=True) as spin: - spin_task = spin.task - from metaflow.cards import get_cards + res = get_cards(spin.task, follow_resumed=False) - res = get_cards(spin_task, follow_resumed=False) - print(res) + # Assert + assert res is not None, "Cards should be generated and retrievable" + # Optional: assert len(res) > 0 if you expect a specific number of cards -def test_spin_with_parameters_raises_error(simple_parameter_run): - """Test that passing flow parameters to spin raises an error.""" +def test_spin_with_flow_parameters_raises_error(simple_parameter_run): + """Test that passing standard flow parameters to spin() raises an Unknown argument error.""" + # Setup step_name = "start" task = simple_parameter_run[step_name].task flow_path = os.path.join(FLOWS_DIR, "simple_parameter_flow.py") + # Act & Assert with pytest.raises(Exception, match="Unknown argument"): with Runner(flow_path, cwd=FLOWS_DIR).spin( task.pathspec, @@ -184,41 +225,42 @@ def test_spin_with_parameters_raises_error(simple_parameter_run): pass -# NOTE: This test has to be the last test because it modifies the metadata -# provider when calling inspect_spin -def test_inspect_spin_client_access(simple_parameter_run): - """Test accessing spin artifacts using inspect_spin client directly.""" +# --------------------------------------------------------------------------- +# WARNING: State-Modifying Test +# This test modifies the global metadata provider via `inspect_spin`. +# It is kept at the bottom of the file to prevent side-effects on other tests. +# --------------------------------------------------------------------------- + + +def test_inspect_spin_client_allows_artifact_access(simple_parameter_run): + """Test accessing spun artifacts directly using the inspect_spin client.""" + # Setup from metaflow import inspect_spin, Task - import tempfile step_name = "start" task = simple_parameter_run[step_name].task flow_path = os.path.join(FLOWS_DIR, "simple_parameter_flow.py") with tempfile.TemporaryDirectory() as _: - # Run spin to generate artifacts + # Setup: Run spin to generate artifacts with Runner(flow_path, cwd=FLOWS_DIR).spin( task.pathspec, persist=True, ) as spin: spin_task = spin.task spin_pathspec = spin_task.pathspec + assert spin_task["a"] is not None assert spin_task["b"] is not None + assert spin_pathspec is not None - assert spin_pathspec is not None - - # Set metadata provider to spin + # Act: Set metadata provider to spin inspect_spin(FLOWS_DIR) client_task = Task(spin_pathspec, _namespace_check=False) - # Verify task is accessible + # Assert: Verify task and artifacts are accessible via client assert client_task is not None - - # Verify artifacts assert hasattr(client_task, "artifacts") - - # Verify artifact data assert client_task.artifacts.a.data == 10 assert client_task.artifacts.b.data == 20 assert client_task.artifacts.alpha.data == 0.05 diff --git a/test/unit/test_add_to_package.py b/test/unit/test_add_to_package.py index e6106bc1b32..b2e33159f83 100644 --- a/test/unit/test_add_to_package.py +++ b/test/unit/test_add_to_package.py @@ -6,11 +6,6 @@ import os import sys -import tempfile -from types import ModuleType -from unittest import mock -from unittest.mock import MagicMock, call - import pytest from metaflow.package import ( @@ -21,174 +16,274 @@ # --------------------------------------------------------------------------- -# Helpers +# Fixtures & Factories # --------------------------------------------------------------------------- -def _make_step(decorators=None, config_decorators=None): - step = MagicMock() - step.decorators = decorators or [] - step.config_decorators = config_decorators or [] - return step +@pytest.fixture +def make_step(mocker): + """Factory fixture to create a mocked flow step.""" + def _make(decorators=None, config_decorators=None): + step = mocker.MagicMock() + step.decorators = decorators or [] + step.config_decorators = config_decorators or [] + return step -def _make_flow(steps, flow_decorators=None, flow_mutators=None): - flow = MagicMock() - # The flow may be iterated multiple times (step decos + step mutators), - # so return a fresh iterator each time. - flow.__iter__ = lambda self: iter(steps) - flow._flow_decorators = flow_decorators or {} - flow._flow_mutators = flow_mutators or [] - return flow + return _make -def _make_environment(tuples=None): - env = MagicMock() - env.add_to_package.return_value = tuples or [] - return env +@pytest.fixture +def make_flow(mocker): + """Factory fixture to create a mocked flow with steps.""" + def _make(steps, flow_decorators=None, flow_mutators=None): + flow = mocker.MagicMock() + # The flow may be iterated multiple times, so return a fresh iterator + flow.__iter__ = lambda self: iter(steps) + flow._flow_decorators = flow_decorators or {} + flow._flow_mutators = flow_mutators or [] + return flow -def _make_mfcontent(): - return MagicMock() + return _make -def _make_deco(tuples): - """Create a mock decorator-like object with an add_to_package method.""" - deco = MagicMock() - deco.add_to_package.return_value = tuples - return deco +@pytest.fixture +def make_environment(mocker): + """Factory fixture to create a mocked environment.""" + def _make(tuples=None): + env = mocker.MagicMock() + env.add_to_package.return_value = tuples or [] + return env -def _build_pkg(flow, environment, mfcontent): - """Build a bare MetaflowPackage instance with minimal state.""" - pkg = object.__new__(MetaflowPackage) - pkg._flow = flow - pkg._environment = environment - pkg._mfcontent = mfcontent - pkg._user_content_from_addl = {} - return pkg + return _make -def _call_add_addl_files(flow, environment, mfcontent): - """Call _add_addl_files on a bare MetaflowPackage instance.""" - pkg = _build_pkg(flow, environment, mfcontent) - pkg._add_addl_files() - return mfcontent +@pytest.fixture +def make_deco(mocker): + """Factory fixture to create a mocked decorator with add_to_package.""" + + def _make(tuples): + deco = mocker.MagicMock() + deco.add_to_package.return_value = tuples + return deco + + return _make + + +@pytest.fixture +def mfcontent(mocker): + """Fixture to provide a fresh mfcontent mock per test.""" + return mocker.MagicMock() + + +@pytest.fixture +def build_pkg(mocker): + """Factory fixture to build a minimal MetaflowPackage instance.""" + + def _build( + flow=None, + environment=None, + mfcontent_mock=None, + user_content=None, + flow_dir=None, + ): + pkg = object.__new__(MetaflowPackage) + pkg._flow = flow + pkg._environment = environment + pkg._mfcontent = mfcontent_mock + pkg._user_content_from_addl = user_content or {} + + # State necessary for _user_code_tuples + pkg._user_code_filter = lambda _name: True + pkg._exclude_tl_dirs = [] + pkg._user_flow_dir = str(flow_dir) if flow_dir else None + return pkg + + return _build + + +@pytest.fixture +def setup_flow_dir(tmp_path): + """Creates a fake flow directory structure for walker testing.""" + + def _setup(*relpaths): + flow_file = tmp_path / "flow.py" + flow_file.write_text("# flow\n") + + for rel in relpaths: + p = tmp_path / rel + p.parent.mkdir(parents=True, exist_ok=True) + p.write_text("x\n") + + return flow_file, tmp_path + + return _setup # --------------------------------------------------------------------------- -# Tests +# Tests: _add_addl_files hooks & logic # --------------------------------------------------------------------------- -def test_flow_decorator_add_to_package(): - """Flow decorator's add_to_package files are added.""" - with tempfile.NamedTemporaryFile(suffix=".py") as f: - deco = _make_deco([(f.name, "flow_deco_file.py", ContentType.CODE_CONTENT)]) - flow = _make_flow( - steps=[_make_step()], - flow_decorators={"my_deco": [deco]}, - ) - mc = _call_add_addl_files(flow, _make_environment(), _make_mfcontent()) - mc.add_code_file.assert_called_once_with( - os.path.realpath(f.name), "flow_deco_file.py" - ) +def test_flow_decorator_files_are_added_to_code_content( + make_step, make_flow, make_environment, make_deco, mfcontent, build_pkg, tmp_path +): + """Flow decorator's add_to_package files are correctly routed to add_code_file.""" + target_file = tmp_path / "flow_deco_file.py" + target_file.touch() + + deco = make_deco( + [(str(target_file), "flow_deco_file.py", ContentType.CODE_CONTENT)] + ) + flow = make_flow(steps=[make_step()], flow_decorators={"my_deco": [deco]}) + + pkg = build_pkg(flow, make_environment(), mfcontent) + pkg._add_addl_files() + + mfcontent.add_code_file.assert_called_once_with( + os.path.realpath(target_file), "flow_deco_file.py" + ) -def test_flow_mutator_add_to_package_module(): +def test_flow_mutator_module_content_calls_add_module( + make_step, make_flow, make_environment, make_deco, mfcontent, build_pkg +): """Flow mutator's add_to_package with MODULE_CONTENT calls add_module.""" import json - mutator = _make_deco([(json, None, ContentType.MODULE_CONTENT)]) - flow = _make_flow(steps=[_make_step()], flow_mutators=[mutator]) - mc = _call_add_addl_files(flow, _make_environment(), _make_mfcontent()) - mc.add_module.assert_called_once_with(json) + mutator = make_deco([(json, None, ContentType.MODULE_CONTENT)]) + flow = make_flow(steps=[make_step()], flow_mutators=[mutator]) + + pkg = build_pkg(flow, make_environment(), mfcontent) + pkg._add_addl_files() + mfcontent.add_module.assert_called_once_with(json) -def test_step_mutator_deduplicated_across_steps(): + +def test_step_mutator_deduplicates_same_instance_across_steps( + make_step, make_flow, make_environment, make_deco, mfcontent, build_pkg, tmp_path +): """Same StepMutator instance on two steps: add_to_package called once.""" - with tempfile.NamedTemporaryFile(suffix=".py") as f: - mutator = _make_deco([(f.name, "shared.py", ContentType.CODE_CONTENT)]) - step1 = _make_step(config_decorators=[mutator]) - step2 = _make_step(config_decorators=[mutator]) - flow = _make_flow(steps=[step1, step2]) - mc = _call_add_addl_files(flow, _make_environment(), _make_mfcontent()) - assert mutator.add_to_package.call_count == 1 - mc.add_code_file.assert_called_once() - - -def test_step_mutator_distinct_instances(): - """Two different StepMutator instances: both called.""" - with tempfile.NamedTemporaryFile(suffix=".py") as f1, tempfile.NamedTemporaryFile( - suffix=".py" - ) as f2: - m1 = _make_deco([(f1.name, "file1.py", ContentType.CODE_CONTENT)]) - m2 = _make_deco([(f2.name, "file2.py", ContentType.CODE_CONTENT)]) - step = _make_step(config_decorators=[m1, m2]) - flow = _make_flow(steps=[step]) - mc = _call_add_addl_files(flow, _make_environment(), _make_mfcontent()) - assert m1.add_to_package.call_count == 1 - assert m2.add_to_package.call_count == 1 - assert mc.add_code_file.call_count == 2 - - -def test_legacy_two_tuple_defaults_to_code_content(): - """A 2-tuple (file_path, arcname) is treated as CODE_CONTENT.""" - with tempfile.NamedTemporaryFile(suffix=".py") as f: - deco = _make_deco([(f.name, "legacy.py")]) - step = _make_step(decorators=[deco]) - flow = _make_flow(steps=[step]) - mc = _call_add_addl_files(flow, _make_environment(), _make_mfcontent()) - mc.add_code_file.assert_called_once_with(os.path.realpath(f.name), "legacy.py") - - -def test_non_unique_filename_raises(): - """Different file paths for the same arcname raises an exception.""" - with tempfile.NamedTemporaryFile(suffix=".py") as f1, tempfile.NamedTemporaryFile( - suffix=".py" - ) as f2: - d1 = _make_deco([(f1.name, "same_name.py", ContentType.CODE_CONTENT)]) - d2 = _make_deco([(f2.name, "same_name.py", ContentType.CODE_CONTENT)]) - step = _make_step(decorators=[d1, d2]) - flow = _make_flow(steps=[step]) - with pytest.raises(NonUniqueFileNameToFilePathMappingException): - _call_add_addl_files(flow, _make_environment(), _make_mfcontent()) - - -def test_module_content_deduplicated(): - """Same module returned by two decorators: add_module called once.""" + target_file = tmp_path / "shared.py" + target_file.touch() + + mutator = make_deco([(str(target_file), "shared.py", ContentType.CODE_CONTENT)]) + step1 = make_step(config_decorators=[mutator]) + step2 = make_step(config_decorators=[mutator]) + flow = make_flow(steps=[step1, step2]) + + pkg = build_pkg(flow, make_environment(), mfcontent) + pkg._add_addl_files() + + assert mutator.add_to_package.call_count == 1 + mfcontent.add_code_file.assert_called_once() + + +def test_step_mutator_adds_multiple_distinct_instances( + make_step, make_flow, make_environment, make_deco, mfcontent, build_pkg, tmp_path +): + """Two different StepMutator instances: both are called and added.""" + file1, file2 = tmp_path / "file1.py", tmp_path / "file2.py" + file1.touch() + file2.touch() + + m1 = make_deco([(str(file1), "file1.py", ContentType.CODE_CONTENT)]) + m2 = make_deco([(str(file2), "file2.py", ContentType.CODE_CONTENT)]) + step = make_step(config_decorators=[m1, m2]) + flow = make_flow(steps=[step]) + + pkg = build_pkg(flow, make_environment(), mfcontent) + pkg._add_addl_files() + + assert m1.add_to_package.call_count == 1 + assert m2.add_to_package.call_count == 1 + assert mfcontent.add_code_file.call_count == 2 + + +def test_legacy_two_tuple_defaults_to_code_content( + make_step, make_flow, make_environment, make_deco, mfcontent, build_pkg, tmp_path +): + """A 2-tuple (file_path, arcname) is gracefully treated as CODE_CONTENT.""" + legacy_file = tmp_path / "legacy.py" + legacy_file.touch() + + deco = make_deco([(str(legacy_file), "legacy.py")]) + step = make_step(decorators=[deco]) + flow = make_flow(steps=[step]) + + pkg = build_pkg(flow, make_environment(), mfcontent) + pkg._add_addl_files() + + mfcontent.add_code_file.assert_called_once_with( + os.path.realpath(legacy_file), "legacy.py" + ) + + +def test_non_unique_filename_to_arcname_raises_exception( + make_step, make_flow, make_environment, make_deco, mfcontent, build_pkg, tmp_path +): + """Different file paths targeting the same arcname raise a NonUniqueFileNameToFilePathMappingException.""" + file1, file2 = tmp_path / "f1.py", tmp_path / "f2.py" + file1.touch() + file2.touch() + + d1 = make_deco([(str(file1), "same_name.py", ContentType.CODE_CONTENT)]) + d2 = make_deco([(str(file2), "same_name.py", ContentType.CODE_CONTENT)]) + step = make_step(decorators=[d1, d2]) + flow = make_flow(steps=[step]) + + pkg = build_pkg(flow, make_environment(), mfcontent) + + with pytest.raises(NonUniqueFileNameToFilePathMappingException): + pkg._add_addl_files() + + +def test_module_content_deduplicates_same_module( + make_step, make_flow, make_environment, make_deco, mfcontent, build_pkg +): + """Same module returned by two decorators triggers add_module only once.""" import json - d1 = _make_deco([(json, None, ContentType.MODULE_CONTENT)]) - d2 = _make_deco([(json, None, ContentType.MODULE_CONTENT)]) - step = _make_step(decorators=[d1, d2]) - flow = _make_flow(steps=[step]) - mc = _call_add_addl_files(flow, _make_environment(), _make_mfcontent()) - mc.add_module.assert_called_once_with(json) + d1 = make_deco([(json, None, ContentType.MODULE_CONTENT)]) + d2 = make_deco([(json, None, ContentType.MODULE_CONTENT)]) + step = make_step(decorators=[d1, d2]) + flow = make_flow(steps=[step]) + + pkg = build_pkg(flow, make_environment(), mfcontent) + pkg._add_addl_files() + + mfcontent.add_module.assert_called_once_with(json) -def test_other_content_type(): - """OTHER_CONTENT files are passed to add_other_file.""" - with tempfile.NamedTemporaryFile(suffix=".yaml") as f: - deco = _make_deco([(f.name, "config.yaml", ContentType.OTHER_CONTENT)]) - step = _make_step(decorators=[deco]) - flow = _make_flow(steps=[step]) - mc = _call_add_addl_files(flow, _make_environment(), _make_mfcontent()) - mc.add_other_file.assert_called_once_with( - os.path.realpath(f.name), "config.yaml" - ) +def test_other_content_type_routes_to_add_other_file( + make_step, make_flow, make_environment, make_deco, mfcontent, build_pkg, tmp_path +): + """OTHER_CONTENT files are correctly passed to add_other_file.""" + yaml_file = tmp_path / "config.yaml" + yaml_file.touch() + + deco = make_deco([(str(yaml_file), "config.yaml", ContentType.OTHER_CONTENT)]) + step = make_step(decorators=[deco]) + flow = make_flow(steps=[step]) + + pkg = build_pkg(flow, make_environment(), mfcontent) + pkg._add_addl_files() + mfcontent.add_other_file.assert_called_once_with( + os.path.realpath(yaml_file), "config.yaml" + ) -def test_ordering_flow_decorators_before_step_decorators(): - """Flow decorators are processed before step decorators. - We verify the flow decorator's add_to_package is called before the step - decorator's by checking call order on a shared mock recorder. - """ +def test_flow_decorators_execute_before_step_decorators( + mocker, make_step, make_flow, make_environment, mfcontent, build_pkg, tmp_path +): + """Flow decorators must be processed before step decorators.""" call_order = [] def make_recording_deco(label, tuples): - deco = MagicMock() + deco = mocker.MagicMock() def record(): call_order.append(label) @@ -197,249 +292,221 @@ def record(): deco.add_to_package = record return deco - with tempfile.NamedTemporaryFile(suffix=".py") as f1, tempfile.NamedTemporaryFile( - suffix=".py" - ) as f2: - flow_deco = make_recording_deco( - "flow_deco", [(f1.name, "flow_file.py", ContentType.CODE_CONTENT)] - ) - step_deco = make_recording_deco( - "step_deco", [(f2.name, "step_file.py", ContentType.CODE_CONTENT)] - ) - step = _make_step(decorators=[step_deco]) - flow = _make_flow( - steps=[step], - flow_decorators={"fd": [flow_deco]}, - ) - _call_add_addl_files(flow, _make_environment(), _make_mfcontent()) - assert call_order == ["flow_deco", "step_deco"] - - -def test_user_content_recorded(): - """USER_CONTENT tuples are recorded in _user_content_from_addl (not sent - to _mfcontent which handles code/other files only).""" - with tempfile.NamedTemporaryFile(suffix=".py") as f: - deco = _make_deco([(f.name, "extra.py", ContentType.USER_CONTENT)]) - flow = _make_flow( - steps=[_make_step()], - flow_decorators={"my_deco": [deco]}, - ) - mfcontent = _make_mfcontent() - pkg = _build_pkg(flow, _make_environment(), mfcontent) - pkg._add_addl_files() + f1, f2 = tmp_path / "f1.py", tmp_path / "f2.py" + f1.touch() + f2.touch() + + flow_deco = make_recording_deco( + "flow_deco", [(str(f1), "flow_file.py", ContentType.CODE_CONTENT)] + ) + step_deco = make_recording_deco( + "step_deco", [(str(f2), "step_file.py", ContentType.CODE_CONTENT)] + ) - # Not routed to _mfcontent — USER_CONTENT is packaged alongside user code. - mfcontent.add_code_file.assert_not_called() - mfcontent.add_other_file.assert_not_called() - mfcontent.add_module.assert_not_called() + step = make_step(decorators=[step_deco]) + flow = make_flow(steps=[step], flow_decorators={"fd": [flow_deco]}) - assert pkg._user_content_from_addl == {"extra.py": os.path.realpath(f.name)} + pkg = build_pkg(flow, make_environment(), mfcontent) + pkg._add_addl_files() + assert call_order == ["flow_deco", "step_deco"] -def test_user_content_duplicate_same_path_dedup(): - """Same USER_CONTENT arcname with the same path from two decorators: dedup.""" - with tempfile.NamedTemporaryFile(suffix=".py") as f: - d1 = _make_deco([(f.name, "shared.py", ContentType.USER_CONTENT)]) - d2 = _make_deco([(f.name, "shared.py", ContentType.USER_CONTENT)]) - flow = _make_flow( - steps=[_make_step()], - flow_decorators={"fd": [d1, d2]}, - ) - pkg = _build_pkg(flow, _make_environment(), _make_mfcontent()) - pkg._add_addl_files() - assert pkg._user_content_from_addl == {"shared.py": os.path.realpath(f.name)} + +def test_user_content_is_recorded_but_not_added_to_mfcontent( + make_step, make_flow, make_environment, make_deco, mfcontent, build_pkg, tmp_path +): + """USER_CONTENT tuples are tracked internally, not sent to _mfcontent.""" + extra_file = tmp_path / "extra.py" + extra_file.touch() + + deco = make_deco([(str(extra_file), "extra.py", ContentType.USER_CONTENT)]) + flow = make_flow(steps=[make_step()], flow_decorators={"my_deco": [deco]}) + + pkg = build_pkg(flow, make_environment(), mfcontent) + pkg._add_addl_files() + + # Verify USER_CONTENT is not routed directly to _mfcontent + mfcontent.add_code_file.assert_not_called() + mfcontent.add_other_file.assert_not_called() + mfcontent.add_module.assert_not_called() + + assert pkg._user_content_from_addl == {"extra.py": os.path.realpath(extra_file)} + + +def test_user_content_duplicates_with_same_path_are_deduplicated( + make_step, make_flow, make_environment, make_deco, mfcontent, build_pkg, tmp_path +): + """Same USER_CONTENT arcname matching the same file path is safely deduplicated.""" + shared_file = tmp_path / "shared.py" + shared_file.touch() + + d1 = make_deco([(str(shared_file), "shared.py", ContentType.USER_CONTENT)]) + d2 = make_deco([(str(shared_file), "shared.py", ContentType.USER_CONTENT)]) + flow = make_flow(steps=[make_step()], flow_decorators={"fd": [d1, d2]}) + + pkg = build_pkg(flow, make_environment(), mfcontent) + pkg._add_addl_files() + + assert pkg._user_content_from_addl == {"shared.py": os.path.realpath(shared_file)} -def test_user_content_duplicate_different_path_raises(): - """Same USER_CONTENT arcname with different paths raises the usual exception.""" - with tempfile.NamedTemporaryFile(suffix=".py") as f1, tempfile.NamedTemporaryFile( - suffix=".py" - ) as f2: - d1 = _make_deco([(f1.name, "shared.py", ContentType.USER_CONTENT)]) - d2 = _make_deco([(f2.name, "shared.py", ContentType.USER_CONTENT)]) - flow = _make_flow( - steps=[_make_step()], - flow_decorators={"fd": [d1, d2]}, - ) - with pytest.raises(NonUniqueFileNameToFilePathMappingException): - pkg = _build_pkg(flow, _make_environment(), _make_mfcontent()) - pkg._add_addl_files() +def test_user_content_duplicate_arcnames_with_different_paths_raises( + make_step, make_flow, make_environment, make_deco, mfcontent, build_pkg, tmp_path +): + """Same USER_CONTENT arcname with different backing paths raises an exception.""" + f1, f2 = tmp_path / "f1.py", tmp_path / "f2.py" + f1.touch() + f2.touch() + + d1 = make_deco([(str(f1), "shared.py", ContentType.USER_CONTENT)]) + d2 = make_deco([(str(f2), "shared.py", ContentType.USER_CONTENT)]) + flow = make_flow(steps=[make_step()], flow_decorators={"fd": [d1, d2]}) + + pkg = build_pkg(flow, make_environment(), mfcontent) + + with pytest.raises(NonUniqueFileNameToFilePathMappingException): + pkg._add_addl_files() # --------------------------------------------------------------------------- -# _user_code_tuples — merge of USER_CONTENT with the flow-dir walker (DEF-010) +# Tests: _user_code_tuples merge logic # --------------------------------------------------------------------------- -def _build_pkg_for_user_tuples(tmpdir, user_content_from_addl=None): - """Build a minimal MetaflowPackage with just enough state for - _user_code_tuples(): a flow dir, filter, exclude list, and the - dict of USER_CONTENT files produced by add_to_package.""" - pkg = object.__new__(MetaflowPackage) - pkg._user_code_filter = lambda _name: True - pkg._exclude_tl_dirs = [] - pkg._user_content_from_addl = user_content_from_addl or {} - pkg._user_flow_dir = None - return pkg - - -def _fake_flow_dir(tmpdir, *relpaths): - """Create `flow.py` plus the given relative paths in tmpdir. Returns the - absolute flow.py path.""" - flow_file = os.path.join(tmpdir, "flow.py") - with open(flow_file, "w") as f: - f.write("# flow\n") - for rel in relpaths: - p = os.path.join(tmpdir, rel) - os.makedirs(os.path.dirname(p), exist_ok=True) - with open(p, "w") as f: - f.write("x\n") - return flow_file - - -def test_user_code_tuples_emits_addl_user_content_not_in_walker(): - """A USER_CONTENT file outside the walker's output gets emitted.""" - with tempfile.TemporaryDirectory() as flow_dir, tempfile.NamedTemporaryFile( - suffix=".cfg", delete=False - ) as external: - try: - _fake_flow_dir(flow_dir, "code.py") - pkg = _build_pkg_for_user_tuples( - flow_dir, - user_content_from_addl={"extra.cfg": external.name}, - ) - with mock.patch.object( - sys, "argv", [os.path.join(flow_dir, "flow.py")] - ), mock.patch("metaflow.R.use_r", return_value=False): - tuples = list(pkg._user_code_tuples()) - by_arc = {arc: path for path, arc in tuples} - # walker picked up code.py and flow.py from the flow dir - assert "code.py" in by_arc - assert "flow.py" in by_arc - # external USER_CONTENT was emitted as well - assert by_arc["extra.cfg"] == external.name - finally: - os.unlink(external.name) - - -def test_user_code_tuples_skips_addl_when_walker_already_has_it(): - """USER_CONTENT with same arcname as a walker-yielded file is dropped.""" - with tempfile.TemporaryDirectory() as flow_dir: - _fake_flow_dir(flow_dir, "code.py") - walker_path = os.path.join(flow_dir, "code.py") - # A *different* absolute path but same arcname. The walker wins. - with tempfile.NamedTemporaryFile(suffix=".py", delete=False) as shadow: - try: - pkg = _build_pkg_for_user_tuples( - flow_dir, - user_content_from_addl={"code.py": shadow.name}, - ) - with mock.patch.object( - sys, "argv", [os.path.join(flow_dir, "flow.py")] - ), mock.patch("metaflow.R.use_r", return_value=False): - tuples = list(pkg._user_code_tuples()) - # code.py appears exactly once, with the walker's path - code_py = [t for t in tuples if t[1] == "code.py"] - assert len(code_py) == 1 - assert code_py[0][0] == walker_path - # shadow is never emitted - assert not any(t[0] == shadow.name for t in tuples) - finally: - os.unlink(shadow.name) - - -def test_user_code_tuples_respects_user_code_filter(): - """USER_CONTENT bypasses the suffix/user filter applied to the walker. - - The walker's filter excludes .yaml, but a USER_CONTENT tuple with a .yaml - arcname must still be emitted — this is a primary reason USER_CONTENT - exists. - """ - with tempfile.TemporaryDirectory() as flow_dir: - _fake_flow_dir(flow_dir, "conf.yaml") - yaml_path = os.path.join(flow_dir, "conf.yaml") - pkg = _build_pkg_for_user_tuples( - flow_dir, - user_content_from_addl={"conf.yaml": yaml_path}, - ) - # Restrict walker to .py only — .yaml should not come from the walker. - pkg._user_code_filter = lambda fname: fname.lower().endswith(".py") - with mock.patch.object( - sys, "argv", [os.path.join(flow_dir, "flow.py")] - ), mock.patch("metaflow.R.use_r", return_value=False): - tuples = list(pkg._user_code_tuples()) - by_arc = {arc: path for path, arc in tuples} - assert "conf.yaml" in by_arc - assert by_arc["conf.yaml"] == yaml_path +def test_user_code_tuples_emits_addl_user_content_not_in_walker( + mocker, build_pkg, setup_flow_dir, tmp_path +): + """A USER_CONTENT file outside the walker's normal output gets emitted properly.""" + external_file = tmp_path / "external" / "extra.cfg" + external_file.parent.mkdir() + external_file.touch() + + flow_file, flow_dir = setup_flow_dir("code.py") + + pkg = build_pkg(user_content={"extra.cfg": str(external_file)}, flow_dir=flow_dir) + + mocker.patch.object(sys, "argv", [str(flow_file)]) + mocker.patch("metaflow.R.use_r", return_value=False) + + tuples = list(pkg._user_code_tuples()) + by_arc = {arc: path for path, arc in tuples} + + assert "code.py" in by_arc + assert "flow.py" in by_arc + assert by_arc["extra.cfg"] == str(external_file) + + +def test_user_code_tuples_skips_addl_when_walker_already_has_it( + mocker, build_pkg, setup_flow_dir, tmp_path +): + """USER_CONTENT with the same arcname as a file yielded by the walker drops the duplicate.""" + shadow_file = tmp_path / "shadow" / "code.py" + shadow_file.parent.mkdir() + shadow_file.touch() + + flow_file, flow_dir = setup_flow_dir("code.py") + walker_path = str(flow_dir / "code.py") + + pkg = build_pkg(user_content={"code.py": str(shadow_file)}, flow_dir=flow_dir) + + mocker.patch.object(sys, "argv", [str(flow_file)]) + mocker.patch("metaflow.R.use_r", return_value=False) + + tuples = list(pkg._user_code_tuples()) + + code_py = [t for t in tuples if t[1] == "code.py"] + assert len(code_py) == 1 + assert code_py[0][0] == walker_path + assert not any(t[0] == str(shadow_file) for t in tuples) + + +def test_user_code_tuples_respects_user_code_filter(mocker, build_pkg, setup_flow_dir): + """USER_CONTENT bypasses the suffix/user filter applied to the standard walker.""" + flow_file, flow_dir = setup_flow_dir("conf.yaml") + yaml_path = str(flow_dir / "conf.yaml") + + pkg = build_pkg(user_content={"conf.yaml": yaml_path}, flow_dir=flow_dir) + # Restrict walker strictly to .py files + pkg._user_code_filter = lambda fname: fname.lower().endswith(".py") + + mocker.patch.object(sys, "argv", [str(flow_file)]) + mocker.patch("metaflow.R.use_r", return_value=False) + + tuples = list(pkg._user_code_tuples()) + by_arc = {arc: path for path, arc in tuples} + + assert "conf.yaml" in by_arc + assert by_arc["conf.yaml"] == yaml_path # --------------------------------------------------------------------------- -# Integration: _add_addl_files + _user_code_tuples together (DEF-011) +# Integration Tests # --------------------------------------------------------------------------- -def test_integration_add_addl_then_user_code_tuples_dedupes_by_arcname(): - """End-to-end: decorator emits USER_CONTENT for a file already in the - flow dir; the final user tuples contain it only once with the walker's - path. - """ - with tempfile.TemporaryDirectory() as flow_dir: - _fake_flow_dir(flow_dir, "code.py") - walker_path = os.path.join(flow_dir, "code.py") - - deco = _make_deco([(walker_path, "code.py", ContentType.USER_CONTENT)]) - flow = _make_flow( - steps=[_make_step()], - flow_decorators={"fd": [deco]}, - ) - pkg = _build_pkg(flow, _make_environment(), _make_mfcontent()) - pkg._user_code_filter = lambda _: True - pkg._exclude_tl_dirs = [] - pkg._user_flow_dir = None +def test_integration_add_addl_then_user_code_tuples_dedupes_by_arcname( + mocker, + make_step, + make_flow, + make_environment, + make_deco, + mfcontent, + build_pkg, + setup_flow_dir, +): + """End-to-end: Decorator emits USER_CONTENT for an existing flow file; deduplicated correctly.""" + flow_file, flow_dir = setup_flow_dir("code.py") + walker_path = str(flow_dir / "code.py") - # Phase 1: populate _user_content_from_addl from add_to_package hooks. - pkg._add_addl_files() - assert pkg._user_content_from_addl == {"code.py": os.path.realpath(walker_path)} - - # Phase 2: walk the flow dir and merge addl USER_CONTENT. - with mock.patch.object( - sys, "argv", [os.path.join(flow_dir, "flow.py")] - ), mock.patch("metaflow.R.use_r", return_value=False): - tuples = list(pkg._user_code_tuples()) - - # code.py is present exactly once, walker's copy wins (by arcname dedup). - code_py = [t for t in tuples if t[1] == "code.py"] - assert len(code_py) == 1 - - -def test_integration_add_addl_contributes_file_outside_flow_dir(): - """End-to-end: decorator emits USER_CONTENT for a file that is NOT in the - flow dir; it ends up in the user tuples via the merge path. - """ - with tempfile.TemporaryDirectory() as flow_dir, tempfile.NamedTemporaryFile( - suffix=".cfg", delete=False - ) as external: - try: - _fake_flow_dir(flow_dir) - - deco = _make_deco( - [(external.name, "external.cfg", ContentType.USER_CONTENT)] - ) - flow = _make_flow( - steps=[_make_step()], - flow_decorators={"fd": [deco]}, - ) - pkg = _build_pkg(flow, _make_environment(), _make_mfcontent()) - pkg._user_code_filter = lambda _: True - pkg._exclude_tl_dirs = [] - pkg._user_flow_dir = None - - pkg._add_addl_files() - with mock.patch.object( - sys, "argv", [os.path.join(flow_dir, "flow.py")] - ), mock.patch("metaflow.R.use_r", return_value=False): - tuples = list(pkg._user_code_tuples()) - by_arc = {arc: path for path, arc in tuples} - assert by_arc["external.cfg"] == os.path.realpath(external.name) - finally: - os.unlink(external.name) + deco = make_deco([(walker_path, "code.py", ContentType.USER_CONTENT)]) + flow = make_flow(steps=[make_step()], flow_decorators={"fd": [deco]}) + + pkg = build_pkg(flow, make_environment(), mfcontent, flow_dir=flow_dir) + + # Phase 1: Populate _user_content_from_addl + pkg._add_addl_files() + assert pkg._user_content_from_addl == {"code.py": os.path.realpath(walker_path)} + + # Phase 2: Walk the flow dir and merge + mocker.patch.object(sys, "argv", [str(flow_file)]) + mocker.patch("metaflow.R.use_r", return_value=False) + + tuples = list(pkg._user_code_tuples()) + + # code.py is present exactly once, walker's copy wins + code_py = [t for t in tuples if t[1] == "code.py"] + assert len(code_py) == 1 + + +def test_integration_add_addl_contributes_file_outside_flow_dir( + mocker, + make_step, + make_flow, + make_environment, + make_deco, + mfcontent, + build_pkg, + setup_flow_dir, + tmp_path, +): + """End-to-end: Decorator emits USER_CONTENT for a file outside flow dir; merged successfully.""" + external_file = tmp_path / "external" / "external.cfg" + external_file.parent.mkdir() + external_file.touch() + + flow_file, flow_dir = setup_flow_dir() + + deco = make_deco([(str(external_file), "external.cfg", ContentType.USER_CONTENT)]) + flow = make_flow(steps=[make_step()], flow_decorators={"fd": [deco]}) + + pkg = build_pkg(flow, make_environment(), mfcontent, flow_dir=flow_dir) + + pkg._add_addl_files() + + mocker.patch.object(sys, "argv", [str(flow_file)]) + mocker.patch("metaflow.R.use_r", return_value=False) + + tuples = list(pkg._user_code_tuples()) + by_arc = {arc: path for path, arc in tuples} + + assert by_arc["external.cfg"] == os.path.realpath(external_file) diff --git a/test/unit/test_argo_workflows_cli.py b/test/unit/test_argo_workflows_cli.py index a9fa8379132..c91c782f379 100644 --- a/test/unit/test_argo_workflows_cli.py +++ b/test/unit/test_argo_workflows_cli.py @@ -8,37 +8,15 @@ ArgoWorkflowsDeployedFlow, ) - -@pytest.mark.parametrize( - "name, expected", - [ - ("a-valid-name", "a-valid-name"), - ("removing---@+_characters@_+", "removing---characters"), - ("numb3rs-4r3-0k-123", "numb3rs-4r3-0k-123"), - ("proj3ct.br4nch.flow_name", "proj3ct.br4nch.flowname"), - # should not break RFC 1123 subdomain requirements, - # though trailing characters do not need to be sanitized due to a hash being appended to them. - ( - "---1breaking1---.--2subdomain2--.-3rules3----", - "1breaking1.2subdomain2.3rules3----", - ), - ( - "1brea---king1.2sub---domain2.-3ru-les3--", - "1brea---king1.2sub---domain2.3ru-les3--", - ), - ("project.branch-cut-short-.flowname", "project.branch-cut-short.flowname"), - ("test...name", "test.name"), - ], -) -def test_sanitize_for_argo(name, expected): - sanitized = sanitize_for_argo(name) - assert sanitized == expected +# --------------------------------------------------------------------------- +# Shared Fixtures +# --------------------------------------------------------------------------- @pytest.fixture def make_argo_with_schedule(): - """ - Factory fixture: returns a callable that builds a minimal ArgoWorkflows-like + """Factory fixture: returns a callable that builds a minimal ArgoWorkflows-like + object whose _get_schedule() can be called without instantiating the full class (which requires a live graph, environment, datastore, etc.). @@ -67,6 +45,47 @@ def argo_without_schedule(): return instance +# --------------------------------------------------------------------------- +# Argo Sanitization and Scheduling Tests +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "name, expected", + [ + ("a-valid-name", "a-valid-name"), + ("removing---@+_characters@_+", "removing---characters"), + ("numb3rs-4r3-0k-123", "numb3rs-4r3-0k-123"), + ("proj3ct.br4nch.flow_name", "proj3ct.br4nch.flowname"), + # should not break RFC 1123 subdomain requirements, + # though trailing characters do not need to be sanitized due to a hash being appended to them. + ( + "---1breaking1---.--2subdomain2--.-3rules3----", + "1breaking1.2subdomain2.3rules3----", + ), + ( + "1brea---king1.2sub---domain2.-3ru-les3--", + "1brea---king1.2sub---domain2.3ru-les3--", + ), + ("project.branch-cut-short-.flowname", "project.branch-cut-short.flowname"), + ("test...name", "test.name"), + ], + ids=[ + "valid-string", + "strip-special-chars", + "alphanumeric-with-dashes", + "strip-subdomain-underscores", + "rfc1123-subdomain-edge-dashes", + "rfc1123-subdomain-internal-dashes", + "strip-trailing-dash-before-dot", + "collapse-consecutive-dots", + ], +) +def test_sanitize_for_argo(name, expected): + """Verify string processing safely conforms to Argo resource naming limitations.""" + assert sanitize_for_argo(name) == expected + + def test_get_schedule_no_decorator_returns_none(argo_without_schedule): """No @schedule decorator → (None, None).""" assert argo_without_schedule._get_schedule() == (None, None) @@ -94,70 +113,95 @@ def test_get_schedule_no_decorator_returns_none(argo_without_schedule): def test_get_schedule( make_argo_with_schedule, schedule_value, timezone_value, expected ): + """Verify cron parsing extraction, payload truncation, and timezone options.""" argo = make_argo_with_schedule( schedule_value=schedule_value, timezone_value=timezone_value ) assert argo._get_schedule() == expected -def test_trigger_explanation_no_schedule_does_not_claim_cronworkflow( +@pytest.mark.parametrize( + "has_decorator, schedule_value, internal_schedule, flow_name, expected_substrings, unexpected_substrings", + [ + (False, None, None, None, [], ["CronWorkflow"]), + (True, None, None, None, [], ["CronWorkflow"]), + ( + True, + "0 0 * * ? *", + "0 0 * * ?", + "myflow", + ["CronWorkflow", "myflow"], + [], + ), + ], + ids=[ + "no-schedule-decorator", + "schedule-decorator-resolves-to-none", + "active-schedule-claims-cronworkflow", + ], +) +def test_trigger_explanation_behavior( + make_argo_with_schedule, argo_without_schedule, + has_decorator, + schedule_value, + internal_schedule, + flow_name, + expected_substrings, + unexpected_substrings, ): - """With no schedule, trigger_explanation() must not mention CronWorkflow.""" - argo_without_schedule._schedule = None - argo_without_schedule.triggers = [] - result = argo_without_schedule.trigger_explanation() - assert "CronWorkflow" not in result - + """Verify if trigger explanations correctly list or omit CronWorkflow rules based on state.""" + argo = ( + make_argo_with_schedule(schedule_value=schedule_value) + if has_decorator + else argo_without_schedule + ) -def test_trigger_explanation_schedule_none_does_not_claim_cronworkflow( - make_argo_with_schedule, -): - """ - When @schedule is present but resolved to None, trigger_explanation() - must not claim the workflow triggers via a CronWorkflow. - """ - argo = make_argo_with_schedule(schedule_value=None) - argo._schedule = None # mirrors what _get_schedule() would set + argo._schedule = internal_schedule argo.triggers = [] - result = argo.trigger_explanation() - assert "CronWorkflow" not in result + if flow_name: + argo.name = flow_name - -def test_trigger_explanation_active_schedule_claims_cronworkflow( - make_argo_with_schedule, -): - """When a real schedule is set, trigger_explanation() names the CronWorkflow.""" - argo = make_argo_with_schedule(schedule_value="0 0 * * ? *") - argo._schedule = "0 0 * * ?" - argo.name = "myflow" result = argo.trigger_explanation() - assert result is not None - assert "CronWorkflow" in result - assert "myflow" in result - - -def test_deployed_flow_workflow_template_returns_only_json_payload(): - workflow_template = {"kind": "WorkflowTemplate", "metadata": {"name": "myflow"}} - deployer = types.SimpleNamespace( - name="myflow", - flow_name="MyFlow", - metadata="local@user:test", - additional_info={"workflow_template": workflow_template}, - ) - deployed_flow = ArgoWorkflowsDeployedFlow(deployer) + for substring in expected_substrings: + assert substring in result + for substring in unexpected_substrings: + assert substring not in result - assert deployed_flow.workflow_template == workflow_template +# --------------------------------------------------------------------------- +# Deployed Flow Object Tests +# --------------------------------------------------------------------------- -def test_deployed_flow_workflow_template_returns_none_without_payload(): - deployer = types.SimpleNamespace( - name="myflow", - flow_name="MyFlow", - metadata="local@user:test", - ) +@pytest.mark.parametrize( + "additional_info, expected_template", + [ + ( + { + "workflow_template": { + "kind": "WorkflowTemplate", + "metadata": {"name": "myflow"}, + } + }, + {"kind": "WorkflowTemplate", "metadata": {"name": "myflow"}}, + ), + (None, None), + ], + ids=["with-json-payload", "without-payload"], +) +def test_deployed_flow_workflow_template_resolution(additional_info, expected_template): + """Verify workflow template extraction handles missing or present payloads cleanly.""" + fields = { + "name": "myflow", + "flow_name": "MyFlow", + "metadata": "local@user:test", + } + if additional_info is not None: + fields["additional_info"] = additional_info + + deployer = types.SimpleNamespace(**fields) deployed_flow = ArgoWorkflowsDeployedFlow(deployer) - assert deployed_flow.workflow_template is None + assert deployed_flow.workflow_template == expected_template diff --git a/test/unit/test_artifact_serializer.py b/test/unit/test_artifact_serializer.py index 80b2c15dcb0..6033c64ffca 100644 --- a/test/unit/test_artifact_serializer.py +++ b/test/unit/test_artifact_serializer.py @@ -8,17 +8,19 @@ SerializerStore, ) +# --------------------------------------------------------------------------- +# Registry Isolation Setup & Fixtures +# --------------------------------------------------------------------------- # Snapshot the registry before this module's classes are defined. Module-level -# test serializers (_HighPrioritySerializer, ...) self-register at class -# definition time; the module-scoped fixture below removes them at teardown so -# other test modules see an unpolluted registry. +# test serializers self-register at class definition time; the module-scoped +# fixture below removes them at teardown so other modules see an unpolluted registry. _PRE_IMPORT_SNAPSHOT = dict(SerializerStore._all_serializers) _PRE_IMPORT_ACTIVE_SNAPSHOT = set(SerializerStore._active_serializers) @pytest.fixture(scope="module", autouse=True) -def _restore_serializer_registry(): +def _restore_module_serializer_registry(): yield SerializerStore._all_serializers.clear() SerializerStore._all_serializers.update(_PRE_IMPORT_SNAPSHOT) @@ -27,8 +29,21 @@ def _restore_serializer_registry(): SerializerStore._ordered_cache = None +@pytest.fixture +def clean_store(): + """Fixture to cleanly revert mutations to the SerializerStore registry per test.""" + all_snapshot = dict(SerializerStore._all_serializers) + active_snapshot = set(SerializerStore._active_serializers) + yield + SerializerStore._all_serializers.clear() + SerializerStore._all_serializers.update(all_snapshot) + SerializerStore._active_serializers.clear() + SerializerStore._active_serializers.update(active_snapshot) + SerializerStore._ordered_cache = None + + # --------------------------------------------------------------------------- -# Helpers — test serializer subclasses defined inside the test module +# Helpers — Shared Test Serializer Subclasses # --------------------------------------------------------------------------- @@ -105,9 +120,39 @@ def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): raise NotImplementedError -# Dispatch is now driven by _active_serializers (post-Phase-6). The metaclass -# only populates _all_serializers; tests that assert against the ordered -# dispatch list must also mark their classes as active. +class _DualFormatSerializer(ArtifactSerializer): + """Toy serializer that implements both formats for str objects.""" + + TYPE = "test_dual_format" + PRIORITY = 40 + + @classmethod + def can_serialize(cls, obj): + return isinstance(obj, str) + + @classmethod + def can_deserialize(cls, metadata): + return metadata.encoding == "test_dual_format" + + @classmethod + def serialize(cls, obj, format=SerializationFormat.STORAGE): + if format == SerializationFormat.WIRE: + return obj + blob = obj.encode("utf-8") + return ( + [SerializedBlob(blob)], + SerializationMetadata("str", len(blob), "test_dual_format", {}), + ) + + @classmethod + def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): + if format == SerializationFormat.WIRE: + return data + return data[0].decode("utf-8") + + +# Dispatch is driven by _active_serializers. The metaclass populates _all_serializers; +# tests asserting against the ordered dispatch list must also mark these classes active. SerializerStore._active_serializers.add(_HighPrioritySerializer) SerializerStore._active_serializers.add(_LowPrioritySerializer) SerializerStore._active_serializers.add(_SamePrioritySerializer) @@ -115,7 +160,7 @@ def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): # --------------------------------------------------------------------------- -# SerializerStore tests +# SerializerStore Tests # --------------------------------------------------------------------------- @@ -132,37 +177,30 @@ def test_base_class_not_registered(): assert None not in SerializerStore._all_serializers -def test_re_registration_overwrites(): +def test_re_registration_overwrites(clean_store): """A second class with the same TYPE overwrites the first (notebook-friendly).""" - original = SerializerStore._all_serializers["test_high"] - try: - class _ReplacementSerializer(ArtifactSerializer): - TYPE = "test_high" # same as _HighPrioritySerializer - PRIORITY = 1 + class _ReplacementSerializer(ArtifactSerializer): + TYPE = "test_high" # same as _HighPrioritySerializer + PRIORITY = 1 - @classmethod - def can_serialize(cls, obj): - return False + @classmethod + def can_serialize(cls, obj): + return False - @classmethod - def can_deserialize(cls, metadata): - return False + @classmethod + def can_deserialize(cls, metadata): + return False - @classmethod - def serialize(cls, obj, format=SerializationFormat.STORAGE): - raise NotImplementedError + @classmethod + def serialize(cls, obj, format=SerializationFormat.STORAGE): + raise NotImplementedError - @classmethod - def deserialize( - cls, data, metadata=None, format=SerializationFormat.STORAGE - ): - raise NotImplementedError + @classmethod + def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): + raise NotImplementedError - assert SerializerStore._all_serializers["test_high"] is _ReplacementSerializer - finally: - SerializerStore._all_serializers["test_high"] = original - SerializerStore._ordered_cache = None + assert SerializerStore._all_serializers["test_high"] is _ReplacementSerializer def test_priority_ordering(): @@ -172,7 +210,7 @@ def test_priority_ordering(): assert priorities == sorted(priorities) -def test_priority_tie_last_wins(): +def test_priority_tie_last_wins(clean_store): """When PRIORITY is equal, last-registered wins the tie.""" class _TieFirst(ArtifactSerializer): @@ -215,24 +253,19 @@ def serialize(cls, obj, format=SerializationFormat.STORAGE): def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): raise NotImplementedError - try: - SerializerStore._active_serializers.add(_TieFirst) - SerializerStore._active_serializers.add(_TieSecond) - SerializerStore._ordered_cache = None - ordered = SerializerStore.get_ordered_serializers() - idx_first = ordered.index(_TieFirst) - idx_second = ordered.index(_TieSecond) - # _TieSecond was registered LAST, so it should appear BEFORE _TieFirst. - assert idx_second < idx_first, ( - "Expected last-registered (_TieSecond) to come first; got " - "_TieFirst at index %d, _TieSecond at index %d" % (idx_first, idx_second) - ) - finally: - SerializerStore._all_serializers.pop("test_tie_first", None) - SerializerStore._all_serializers.pop("test_tie_second", None) - SerializerStore._active_serializers.discard(_TieFirst) - SerializerStore._active_serializers.discard(_TieSecond) - SerializerStore._ordered_cache = None + SerializerStore._active_serializers.add(_TieFirst) + SerializerStore._active_serializers.add(_TieSecond) + SerializerStore._ordered_cache = None + + ordered = SerializerStore.get_ordered_serializers() + idx_first = ordered.index(_TieFirst) + idx_second = ordered.index(_TieSecond) + + # _TieSecond was registered LAST, so it should appear BEFORE _TieFirst. + assert idx_second < idx_first, ( + f"Expected last-registered (_TieSecond) to come first; got " + f"_TieFirst at index {idx_first}, _TieSecond at index {idx_second}" + ) def test_deterministic_ordering(): @@ -249,12 +282,39 @@ def test_high_priority_before_low(): assert types.index("test_high") < types.index("test_low") +def test_priority_tie_lexicographic_fallback(): + """When PRIORITY and registration index both tie (simulated), class_path lex-sort wins.""" + + # Within a single process, registration indices are unique. To test the tertiary + # key fallback, we manually simulate identical prefixes on the internal sort key. + class _AClass: + __module__ = "z.module" + __qualname__ = "AClass" + PRIORITY = 100 + + class _BClass: + __module__ = "a.module" + __qualname__ = "BClass" + PRIORITY = 100 + + keys = [ + (_AClass.PRIORITY, 0, f"{_AClass.__module__}.{_AClass.__qualname__}"), + (_BClass.PRIORITY, 0, f"{_BClass.__module__}.{_BClass.__qualname__}"), + ] + sorted_keys = sorted(keys) + + # "a.module.BClass" < "z.module.AClass" lexicographically + assert sorted_keys[0][2] == "a.module.BClass" + assert sorted_keys[1][2] == "z.module.AClass" + + # --------------------------------------------------------------------------- -# SerializationMetadata tests +# SerializationMetadata Tests # --------------------------------------------------------------------------- def test_metadata_fields(): + """Verify attributes are assigned and mapped properly on metadata container.""" meta = SerializationMetadata( obj_type="dict", size=1024, @@ -268,105 +328,78 @@ def test_metadata_fields(): def test_metadata_is_namedtuple(): + """Verify that SerializationMetadata preserves namedtuple traits.""" meta = SerializationMetadata("str", 10, "utf-8", {}) assert isinstance(meta, tuple) assert len(meta) == 4 # --------------------------------------------------------------------------- -# SerializedBlob tests +# SerializedBlob Tests # --------------------------------------------------------------------------- -def test_blob_bytes_auto_detect(): - """bytes value auto-detects as not a reference.""" - blob = SerializedBlob(b"hello") - assert blob.is_reference is False - assert blob.needs_save is True - - -def test_blob_str_auto_detect(): - """str value auto-detects as a reference.""" - blob = SerializedBlob("sha1_key_abc123") - assert blob.is_reference is True - assert blob.needs_save is False - - -def test_blob_explicit_is_reference_override(): - """Explicit is_reference overrides auto-detection.""" - # bytes but marked as reference (edge case) - blob = SerializedBlob(b"data", is_reference=True) - assert blob.is_reference is True - assert blob.needs_save is False - - # str but marked as not a reference (edge case) - blob = SerializedBlob("inline_data", is_reference=False) - assert blob.is_reference is False - assert blob.needs_save is True - - -def test_blob_value_preserved(): - data = b"\x00\x01\x02\x03" +@pytest.mark.parametrize( + "value, kwargs, expected_is_reference, expected_needs_save", + [ + (b"hello", {}, False, True), + ("sha1_key_abc123", {}, True, False), + (b"data", {"is_reference": True}, True, False), + ("inline_data", {"is_reference": False}, False, True), + ], + ids=[ + "bytes-auto-detect-payload", + "str-auto-detect-reference", + "bytes-explicit-reference-override", + "str-explicit-payload-override", + ], +) +def test_blob_reference_detection( + value, kwargs, expected_is_reference, expected_needs_save +): + """Verify SerializedBlob reference tracking and explicit override behaviors.""" + blob = SerializedBlob(value, **kwargs) + assert blob.is_reference is expected_is_reference + assert blob.needs_save is expected_needs_save + + +@pytest.mark.parametrize( + "data", + [b"\x00\x01\x02\x03", "abc123def456"], + ids=["bytes-payload", "str-reference"], +) +def test_blob_value_preserved(data): + """Verify values given to the SerializedBlob initialization are kept identical.""" blob = SerializedBlob(data) assert blob.value is data - key = "abc123def456" - blob = SerializedBlob(key) - assert blob.value is key - -def test_blob_rejects_invalid_types(): +@pytest.mark.parametrize( + "bad_value", + [123, 3.14, None, [], {}], + ids=["int", "float", "none", "list", "dict"], +) +def test_blob_rejects_invalid_types(bad_value): """SerializedBlob must be str or bytes — reject everything else.""" - for bad_value in [123, 3.14, None, [], {}]: - with pytest.raises(TypeError, match="must be str or bytes"): - SerializedBlob(bad_value) + with pytest.raises(TypeError, match="must be str or bytes"): + SerializedBlob(bad_value) # --------------------------------------------------------------------------- -# Wire vs storage format dispatch +# Wire vs Storage Format Dispatch # --------------------------------------------------------------------------- -class _DualFormatSerializer(ArtifactSerializer): - """Toy serializer that implements both formats for str objects.""" - - TYPE = "test_dual_format" - PRIORITY = 40 - - @classmethod - def can_serialize(cls, obj): - return isinstance(obj, str) - - @classmethod - def can_deserialize(cls, metadata): - return metadata.encoding == "test_dual_format" - - @classmethod - def serialize(cls, obj, format=SerializationFormat.STORAGE): - if format == SerializationFormat.WIRE: - return obj - blob = obj.encode("utf-8") - return ( - [SerializedBlob(blob)], - SerializationMetadata("str", len(blob), "test_dual_format", {}), - ) - - @classmethod - def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): - if format == SerializationFormat.WIRE: - return data - return data[0].decode("utf-8") - - def test_format_enum_values(): + """Verify the string mapping properties of the SerializationFormat enum.""" assert SerializationFormat.STORAGE.value == "storage" assert SerializationFormat.WIRE.value == "wire" - # str-backed Enum, so direct string comparison still works. assert SerializationFormat.STORAGE == "storage" assert SerializationFormat.WIRE == "wire" def test_dual_format_storage_roundtrip(): + """Verify the basic storage format workflow serialization loop.""" blobs, meta = _DualFormatSerializer.serialize("hello") assert meta.encoding == "test_dual_format" assert ( @@ -376,6 +409,7 @@ def test_dual_format_storage_roundtrip(): def test_dual_format_wire_roundtrip(): + """Verify the wire format workflow serialization loop.""" wire = _DualFormatSerializer.serialize("hello", format=SerializationFormat.WIRE) assert isinstance(wire, str) assert ( @@ -385,6 +419,7 @@ def test_dual_format_wire_roundtrip(): def test_pickle_serializer_rejects_wire(): + """Verify standard PickleSerializer errors explicitly when utilizing the wire format.""" from metaflow.plugins.datastores.serializers.pickle_serializer import ( PickleSerializer, ) @@ -395,36 +430,12 @@ def test_pickle_serializer_rejects_wire(): PickleSerializer.deserialize("42", format=SerializationFormat.WIRE) -def test_priority_tie_lexicographic_fallback(): - """When PRIORITY and registration index both tie (simulated), class_path lex-sort wins.""" - - # Within a single process, registration indices are always unique, so - # to actually exercise the tertiary key we construct two classes with - # identical (PRIORITY, registration_index) by manipulating the combined - # list passed to the sort logic. We do this by calling the internal - # sort key directly. - class _AClass: - __module__ = "z.module" - __qualname__ = "AClass" - PRIORITY = 100 - - class _BClass: - __module__ = "a.module" - __qualname__ = "BClass" - PRIORITY = 100 - - # Same registration index (simulated): the (priority, -idx) prefix ties. - keys = [ - (_AClass.PRIORITY, 0, "%s.%s" % (_AClass.__module__, _AClass.__qualname__)), - (_BClass.PRIORITY, 0, "%s.%s" % (_BClass.__module__, _BClass.__qualname__)), - ] - sorted_keys = sorted(keys) - # "a.module.BClass" < "z.module.AClass" lexicographically - assert sorted_keys[0][2] == "a.module.BClass" - assert sorted_keys[1][2] == "z.module.AClass" +# --------------------------------------------------------------------------- +# Setup and Lazy Import Infrastructure Tests +# --------------------------------------------------------------------------- -def test_setup_imports_default_is_noop(): +def test_setup_imports_default_is_noop(clean_store): """Default setup_imports should be callable and do nothing.""" class _NoOverride(ArtifactSerializer): @@ -446,21 +457,23 @@ def serialize(cls, obj, format=SerializationFormat.STORAGE): def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): raise NotImplementedError - try: - result = _NoOverride.setup_imports() - assert result is None - result = _NoOverride.setup_imports(context="anything") - assert result is None - finally: - SerializerStore._all_serializers.pop("test_no_override", None) - SerializerStore._ordered_cache = None + assert _NoOverride.setup_imports() is None + assert _NoOverride.setup_imports(context="anything") is None -def test_lazy_import_happy_path(): - """lazy_import imports the module, stashes on cls at the leaf alias, and returns it.""" +@pytest.mark.parametrize( + "module_name, alias, target_attr", + [ + ("json", None, "json"), + ("json", "j", "j"), + ], + ids=["default-leaf-alias", "custom-alias"], +) +def test_lazy_import_success(clean_store, module_name, alias, target_attr): + """lazy_import imports the module, stashes on cls at the given alias, and returns it.""" - class _LazyOk(ArtifactSerializer): - TYPE = "test_lazy_ok" + class _LazyTarget(ArtifactSerializer): + TYPE = "test_lazy_target" @classmethod def can_serialize(cls, obj): @@ -478,54 +491,30 @@ def serialize(cls, obj, format=SerializationFormat.STORAGE): def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): raise NotImplementedError - try: - mod = _LazyOk.lazy_import("json") - import json as _json - - assert mod is _json - assert _LazyOk.json is _json - finally: - SerializerStore._all_serializers.pop("test_lazy_ok", None) - SerializerStore._ordered_cache = None - if hasattr(_LazyOk, "json"): - delattr(_LazyOk, "json") + import json as _json + kwargs = {"alias": alias} if alias else {} + mod = _LazyTarget.lazy_import(module_name, **kwargs) -def test_lazy_import_custom_alias(): - """alias= overrides the default leaf-name stash key.""" + assert mod is _json + assert getattr(_LazyTarget, target_attr) is _json - class _LazyAlias(ArtifactSerializer): - TYPE = "test_lazy_alias" - @classmethod - def can_serialize(cls, obj): - return False - - @classmethod - def can_deserialize(cls, metadata): - return False - - @classmethod - def serialize(cls, obj, format=SerializationFormat.STORAGE): - raise NotImplementedError - - @classmethod - def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): - raise NotImplementedError - - try: - _LazyAlias.lazy_import("json", alias="j") - import json as _json - - assert _LazyAlias.j is _json - finally: - SerializerStore._all_serializers.pop("test_lazy_alias", None) - SerializerStore._ordered_cache = None - if hasattr(_LazyAlias, "j"): - delattr(_LazyAlias, "j") - - -def test_lazy_import_rejects_reserved_names(): +@pytest.mark.parametrize( + "bad_alias", + [ + "TYPE", + "PRIORITY", + "serialize", + "deserialize", + "can_serialize", + "can_deserialize", + "setup_imports", + "lazy_import", + "_secret", + ], +) +def test_lazy_import_rejects_reserved_names(clean_store, bad_alias): """Attempting to shadow TYPE / PRIORITY / dispatch methods raises.""" class _LazyReserved(ArtifactSerializer): @@ -547,26 +536,11 @@ def serialize(cls, obj, format=SerializationFormat.STORAGE): def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): raise NotImplementedError - try: - for bad in [ - "TYPE", - "PRIORITY", - "serialize", - "deserialize", - "can_serialize", - "can_deserialize", - "setup_imports", - "lazy_import", - "_secret", - ]: - with pytest.raises(ValueError, match="reserved or invalid"): - _LazyReserved.lazy_import("json", alias=bad) - finally: - SerializerStore._all_serializers.pop("test_lazy_reserved", None) - SerializerStore._ordered_cache = None - - -def test_lazy_import_rejects_double_assignment(): + with pytest.raises(ValueError, match="reserved or invalid"): + _LazyReserved.lazy_import("json", alias=bad_alias) + + +def test_lazy_import_rejects_double_assignment(clean_store): """Calling lazy_import twice with the same alias on the same cls raises.""" class _LazyDup(ArtifactSerializer): @@ -588,20 +562,12 @@ def serialize(cls, obj, format=SerializationFormat.STORAGE): def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): raise NotImplementedError - try: - _LazyDup.lazy_import("json") - with pytest.raises(ValueError, match="already set"): - _LazyDup.lazy_import("sys", alias="json") - finally: - SerializerStore._all_serializers.pop("test_lazy_dup", None) - SerializerStore._ordered_cache = None - if hasattr(_LazyDup, "json"): - delattr(_LazyDup, "json") - if hasattr(_LazyDup, "_lazy_imported_names"): - delattr(_LazyDup, "_lazy_imported_names") + _LazyDup.lazy_import("json") + with pytest.raises(ValueError, match="already set"): + _LazyDup.lazy_import("sys", alias="json") -def test_setup_imports_accepts_both_signatures(): +def test_setup_imports_accepts_both_signatures(clean_store): """Bootstrap calls setup_imports correctly whether author writes (cls) or (cls, context=None).""" class _OneArg(ArtifactSerializer): @@ -654,13 +620,8 @@ def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): from metaflow.datastore.artifacts.serializer import _call_setup_imports - try: - _call_setup_imports(_OneArg, context=None) - assert _OneArg.called is True + _call_setup_imports(_OneArg, context=None) + assert _OneArg.called is True - _call_setup_imports(_TwoArg, context="some-ctx") - assert _TwoArg.called_with == "some-ctx" - finally: - SerializerStore._all_serializers.pop("test_setup_one_arg", None) - SerializerStore._all_serializers.pop("test_setup_two_arg", None) - SerializerStore._ordered_cache = None + _call_setup_imports(_TwoArg, context="some-ctx") + assert _TwoArg.called_with == "some-ctx" diff --git a/test/unit/test_aws_util.py b/test/unit/test_aws_util.py index a212996c4eb..f28ee1247d7 100644 --- a/test/unit/test_aws_util.py +++ b/test/unit/test_aws_util.py @@ -1,3 +1,4 @@ +import contextlib import pytest from metaflow.plugins.aws.aws_utils import validate_aws_tag @@ -8,32 +9,29 @@ [ ("test", "value", False), ("test-with@chars+ - = ._/", "value@with.chars-+ - = ._/", False), - ( - "a" * 128, - "ok", - False, - ), # <=128 char key should work. - ("a" * 129, "ok", True), # >128 char key should fail. - ( - "ok", - "a" * 256, - False, - ), # <=256 char value should work. - ("ok", "a" * 257, True), # >256 char value should fail. + ("a" * 128, "ok", False), # <=128 char key should work + ("a" * 129, "ok", True), # >128 char key should fail + ("ok", "a" * 256, False), # <=256 char value should work + ("ok", "a" * 257, True), # >256 char value should fail ("aWs:not-allowed", "ok", True), # 'aws:' prefix should not be allowed as key ("ok", "AWS:not-allowed", True), # 'aws:' prefix should not be allowed as value - ( - "ok-aws:", - "middleaWs:not-allowed", - False, - ), # 'aws:' itself is not a restricted pattern + ("ok-aws:", "middleaWs:not-allowed", False), # 'aws:' substring itself is fine + ], + ids=[ + "simple-valid", + "allowed-special-chars", + "key-max-length-128", + "key-length-exceeded-129", + "value-max-length-256", + "value-length-exceeded-257", + "aws-prefix-key-rejected", + "aws-prefix-value-rejected", + "aws-substring-allowed", ], ) -def test_validate_aws_tag(key, value, should_raise): - did_raise = False - try: - validate_aws_tag(key, value) - except Exception as e: - did_raise = True +def test_aws_tag_validation_rules(key, value, should_raise): + """Verify AWS tag validation enforces character sets, length limits, and prefix restrictions.""" + expectation = pytest.raises(Exception) if should_raise else contextlib.nullcontext() - assert did_raise == should_raise + with expectation: + validate_aws_tag(key, value) diff --git a/test/unit/test_card_creator.py b/test/unit/test_card_creator.py index 4043fa5298e..669d6192f28 100644 --- a/test/unit/test_card_creator.py +++ b/test/unit/test_card_creator.py @@ -1,4 +1,8 @@ -"""Regression tests for async card process timeout handling.""" +"""Regression tests for async card process timeout handling. + +These tests ensure that asynchronous card creation processes are reliably managed, +properly timed out, and remain resilient to system clock variations (like NTP syncs). +""" import subprocess @@ -6,12 +10,25 @@ from metaflow.plugins.cards.card_creator import CardCreator, CardProcessManager +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- + CARD_UUID = "card-uuid" ASYNC_TIMEOUT = 60 +# Module paths for cleaner mocking +MODULE_TIME = "metaflow.plugins.cards.card_creator.time" +MODULE_SUBPROCESS = "metaflow.plugins.cards.card_creator.subprocess" + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + @pytest.fixture(autouse=True) def clean_card_process_registry(): + """Ensure a pristine process registry before and after every test.""" CardProcessManager.async_card_processes.clear() yield CardProcessManager.async_card_processes.clear() @@ -19,6 +36,7 @@ def clean_card_process_registry(): @pytest.fixture def running_process(mocker): + """Provide a mock subprocess that appears to be actively running.""" process = mocker.Mock() process.poll.return_value = None return process @@ -26,35 +44,42 @@ def running_process(mocker): @pytest.fixture def card_creator(): + """Provide a minimally configured CardCreator instance.""" return CardCreator( top_level_options=[], should_save_metadata_lambda=lambda _: (False, {}) ) +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + def test_register_card_process_uses_monotonic_timestamp(mocker, running_process): - mocker.patch("metaflow.plugins.cards.card_creator.time.time", return_value=1.0) - mocker.patch( - "metaflow.plugins.cards.card_creator.time.monotonic", return_value=42.0 - ) + """Verify that process registration relies on monotonic time, not wall-clock time.""" + mocker.patch(f"{MODULE_TIME}.time", return_value=1.0) + mocker.patch(f"{MODULE_TIME}.monotonic", return_value=42.0) CardProcessManager._register_card_process(CARD_UUID, running_process) - _, started = CardProcessManager._get_card_process(CARD_UUID) + assert started == 42.0 def test_wait_for_async_process_ignores_backward_wall_clock_jump( mocker, card_creator, running_process ): - mocker.patch( - "metaflow.plugins.cards.card_creator.time.monotonic", return_value=100.0 - ) + """Ensure a backward jump in the system clock (e.g., NTP sync) does not bypass the timeout logic.""" + mocker.patch(f"{MODULE_TIME}.monotonic", return_value=100.0) CardProcessManager._register_card_process(CARD_UUID, running_process) + running_process.poll.side_effect = [None, 0] - mocker.patch("metaflow.plugins.cards.card_creator.time.time", return_value=-86400.0) mocker.patch( - "metaflow.plugins.cards.card_creator.time.monotonic", return_value=700.0 - ) + f"{MODULE_TIME}.time", return_value=-86400.0 + ) # Extreme backward wall-clock jump + mocker.patch( + f"{MODULE_TIME}.monotonic", return_value=700.0 + ) # Monotonic time correctly progresses card_creator._wait_for_async_processes_to_finish( CARD_UUID, async_timeout=ASYNC_TIMEOUT @@ -67,14 +92,14 @@ def test_wait_for_async_process_ignores_backward_wall_clock_jump( def test_wait_for_async_process_leaves_process_within_timeout( mocker, card_creator, running_process ): - mocker.patch( - "metaflow.plugins.cards.card_creator.time.monotonic", return_value=100.0 - ) + """Ensure processes are allowed to continue running if the timeout threshold has not been breached.""" + mocker.patch(f"{MODULE_TIME}.monotonic", return_value=100.0) CardProcessManager._register_card_process(CARD_UUID, running_process) + running_process.poll.side_effect = [None, 0] mocker.patch( - "metaflow.plugins.cards.card_creator.time.monotonic", return_value=105.0 - ) + f"{MODULE_TIME}.monotonic", return_value=105.0 + ) # Only 5 seconds elapsed card_creator._wait_for_async_processes_to_finish( CARD_UUID, async_timeout=ASYNC_TIMEOUT @@ -85,26 +110,30 @@ def test_wait_for_async_process_leaves_process_within_timeout( def test_async_run_replaces_timed_out_process(mocker, card_creator, running_process): + """Verify that launching a new async command correctly kills and replaces an existing timed-out process.""" replacement_process = mocker.Mock() popen = mocker.patch( - "metaflow.plugins.cards.card_creator.subprocess.Popen", + f"{MODULE_SUBPROCESS}.Popen", return_value=replacement_process, ) - mocker.patch( - "metaflow.plugins.cards.card_creator.time.monotonic", return_value=100.0 - ) + + # Register the initial process + mocker.patch(f"{MODULE_TIME}.monotonic", return_value=100.0) CardProcessManager._register_card_process(CARD_UUID, running_process) - mocker.patch("metaflow.plugins.cards.card_creator.time.time", return_value=-86400.0) - mocker.patch( - "metaflow.plugins.cards.card_creator.time.monotonic", return_value=700.0 - ) + # Simulate time passing beyond the timeout threshold + mocker.patch(f"{MODULE_TIME}.time", return_value=-86400.0) + mocker.patch(f"{MODULE_TIME}.monotonic", return_value=700.0) + + # Attempt to run a new command with the same UUID output, failed = card_creator._run_command( ["python", "card.py"], CARD_UUID, {"KEY": "value"}, wait=False ) + # Assertions assert output == b"" assert failed is False + running_process.kill.assert_called_once_with() popen.assert_called_once_with( ["python", "card.py"], @@ -112,5 +141,7 @@ def test_async_run_replaces_timed_out_process(mocker, card_creator, running_proc stderr=subprocess.DEVNULL, stdout=subprocess.DEVNULL, ) + + # Verify the registry holds the new process process, _ = CardProcessManager._get_card_process(CARD_UUID) assert process is replacement_process diff --git a/test/unit/test_compute_resource_attributes.py b/test/unit/test_compute_resource_attributes.py index adb21c521b5..6d052b4f6dd 100644 --- a/test/unit/test_compute_resource_attributes.py +++ b/test/unit/test_compute_resource_attributes.py @@ -1,77 +1,95 @@ from collections import namedtuple -from metaflow.plugins.aws.aws_utils import compute_resource_attributes - - -MockDeco = namedtuple("MockDeco", ["name", "attributes"]) - - -def test_compute_resource_attributes(): - - # use default if nothing is set - assert compute_resource_attributes([], MockDeco("batch", {}), {"cpu": "1"}) == { - "cpu": "1" - } - - # @batch overrides default and you can use ints as attributes - assert compute_resource_attributes( - [], MockDeco("batch", {"cpu": 1}), {"cpu": "2"} - ) == {"cpu": "1"} - # Same but value set as str not int - assert compute_resource_attributes( - [], MockDeco("batch", {"cpu": "1"}), {"cpu": "2"} - ) == {"cpu": "1"} +import pytest - # same but use default memory - assert compute_resource_attributes( - [], MockDeco("batch", {"cpu": "1"}), {"cpu": "2", "memory": "100"} - ) == {"cpu": "1", "memory": "100"} - - # same but cpu set via @resources - assert compute_resource_attributes( - [], MockDeco("resources", {"cpu": "1"}), {"cpu": "2", "memory": "100"} - ) == {"cpu": "1", "memory": "100"} - - # take largest of @resources and @batch if both are present - assert compute_resource_attributes( - [MockDeco("resources", {"cpu": "2"})], - MockDeco("batch", {"cpu": 1}), - {"cpu": "3"}, - ) == {"cpu": "2.0"} - - # take largest of @resources and @batch if both are present - assert compute_resource_attributes( - [MockDeco("resources", {"cpu": 0.83})], - MockDeco("batch", {"cpu": "0.5"}), - {"cpu": "1"}, - ) == {"cpu": "0.83"} - - -def test_compute_resource_attributes_string(): - """Test string-valued resource attributes""" - - # if default is None and the value is not set in @batch, the value is not included in computed attributes in the end - assert compute_resource_attributes( - [], MockDeco("batch", {}), {"cpu": "1", "instance_type": None} - ) == {"cpu": "1"} +from metaflow.plugins.aws.aws_utils import compute_resource_attributes - # use string value from deco if set (default is None) - assert compute_resource_attributes( - [], - MockDeco("batch", {"instance_type": "p3.xlarge"}), - {"cpu": "1", "instance_type": None}, - ) == {"cpu": "1", "instance_type": "p3.xlarge"} +MockDeco = namedtuple("MockDeco", ["name", "attributes"]) - # use string value from deco if set (default is not None) - assert compute_resource_attributes( - [], - MockDeco("batch", {"instance_type": "p3.xlarge"}), - {"cpu": "1", "instance_type": "p4.xlarge"}, - ) == {"cpu": "1", "instance_type": "p3.xlarge"} - # use string value from defaults if @batch has it set to None - assert compute_resource_attributes( - [], - MockDeco("batch", {"instance_type": None}), - {"cpu": "1", "instance_type": "p4.xlarge"}, - ) == {"cpu": "1", "instance_type": "p4.xlarge"} +@pytest.mark.parametrize( + "decorators, primary_deco, defaults, expected", + [ + # --- Numeric attribute resolution --- + # use default if nothing is set + ([], MockDeco("batch", {}), {"cpu": "1"}, {"cpu": "1"}), + # @batch overrides default and you can use ints as attributes + ([], MockDeco("batch", {"cpu": 1}), {"cpu": "2"}, {"cpu": "1"}), + # Same but value set as str not int + ([], MockDeco("batch", {"cpu": "1"}), {"cpu": "2"}, {"cpu": "1"}), + # same but use default memory + ( + [], + MockDeco("batch", {"cpu": "1"}), + {"cpu": "2", "memory": "100"}, + {"cpu": "1", "memory": "100"}, + ), + # same but cpu set via @resources + ( + [], + MockDeco("resources", {"cpu": "1"}), + {"cpu": "2", "memory": "100"}, + {"cpu": "1", "memory": "100"}, + ), + # --- Max/Largest resource resolution across decorators --- + # take largest of @resources and @batch if both are present + ( + [MockDeco("resources", {"cpu": "2"})], + MockDeco("batch", {"cpu": 1}), + {"cpu": "3"}, + {"cpu": "2.0"}, + ), + # take largest of @resources and @batch if both are present (floats) + ( + [MockDeco("resources", {"cpu": 0.83})], + MockDeco("batch", {"cpu": "0.5"}), + {"cpu": "1"}, + {"cpu": "0.83"}, + ), + # --- String attribute resolution --- + # if default is None and the value is not set in @batch, it is omitted + ( + [], + MockDeco("batch", {}), + {"cpu": "1", "instance_type": None}, + {"cpu": "1"}, + ), + # use string value from deco if set (default is None) + ( + [], + MockDeco("batch", {"instance_type": "p3.xlarge"}), + {"cpu": "1", "instance_type": None}, + {"cpu": "1", "instance_type": "p3.xlarge"}, + ), + # use string value from deco if set (default is not None) + ( + [], + MockDeco("batch", {"instance_type": "p3.xlarge"}), + {"cpu": "1", "instance_type": "p4.xlarge"}, + {"cpu": "1", "instance_type": "p3.xlarge"}, + ), + # use string value from defaults if @batch has it set to None + ( + [], + MockDeco("batch", {"instance_type": None}), + {"cpu": "1", "instance_type": "p4.xlarge"}, + {"cpu": "1", "instance_type": "p4.xlarge"}, + ), + ], + ids=[ + "fallback-to-default", + "batch-int-overrides-default", + "batch-str-overrides-default", + "merge-batch-with-default-memory", + "merge-resources-with-default-memory", + "resolve-max-between-resources-and-batch-int", + "resolve-max-between-resources-and-batch-float", + "omit-none-default-if-missing-in-batch", + "batch-str-overrides-none-default", + "batch-str-overrides-str-default", + "none-in-batch-falls-back-to-default-str", + ], +) +def test_compute_resource_attributes(decorators, primary_deco, defaults, expected): + """Verify that resource attributes are correctly merged and resolved from decorators and defaults.""" + assert compute_resource_attributes(decorators, primary_deco, defaults) == expected diff --git a/test/unit/test_config_value.py b/test/unit/test_config_value.py index d1201c1a9f6..482c2e278ab 100644 --- a/test/unit/test_config_value.py +++ b/test/unit/test_config_value.py @@ -4,92 +4,123 @@ from metaflow.user_configs.config_parameters import ConfigValue +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- -def test_isinstance(): - orig_dict = {"a": 1, "b": 2} - c_value = ConfigValue(orig_dict) - assert isinstance(c_value, dict) - - -def test_todict(): - orig_dict = {"a": 1, "b": 2} - c_value = ConfigValue(orig_dict) - assert c_value.to_dict() == orig_dict - orig_dict = {"a": 1, "b": [1, 2, 3], "c": {"d": 4}, "e": {"f": [{"g": 5}]}} - c_value = ConfigValue(orig_dict) - assert c_value.to_dict() == orig_dict +@pytest.fixture +def simple_dict(): + """Provides a basic, flat dictionary.""" + return {"a": 1, "b": 2} -def test_container_has_config_value(): - orig_dict = { +@pytest.fixture +def nested_dict(): + """Provides a complex, deeply nested dictionary containing lists and tuples.""" + return { "a": 1, "b": [1, 2, 3], "c": {"d": 4}, "e": {"f": [{"g": 5}]}, "h": ({"i": 6},), } - c_value = ConfigValue(orig_dict) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +def test_config_value_is_dict_instance(simple_dict): + """Ensure ConfigValue correctly registers as a subclass/instance of dict.""" + c_value = ConfigValue(simple_dict) + assert isinstance(c_value, dict) + + +def test_todict_returns_original_data(simple_dict, nested_dict): + """Verify that to_dict() reconstructs the original standard dictionary exactly.""" + assert ConfigValue(simple_dict).to_dict() == simple_dict + assert ConfigValue(nested_dict).to_dict() == nested_dict + + +def test_container_preserves_config_value_type_internally(nested_dict): + """Verify that nested dictionaries within lists/tuples are correctly wrapped as ConfigValues.""" + c_value = ConfigValue(nested_dict) + + # Dot-notation access works deeply assert c_value.e.f[0].g == 5 + + # Nested dicts become ConfigValues assert isinstance(c_value.c, ConfigValue) assert isinstance(c_value.e, ConfigValue) + + # Lists remain lists, but their dict elements become ConfigValues assert isinstance(c_value.e.f, list) assert isinstance(c_value.e.f[0], ConfigValue) + + # Tuples remain tuples, but their dict elements become ConfigValues assert isinstance(c_value.h, tuple) assert isinstance(c_value.h[0], ConfigValue) -def test_non_modifiable(): - orig_dict = {"a": 1, "b": 2, "c": 3} - c_value = ConfigValue(orig_dict) - with pytest.raises(TypeError): - c_value["d"] = 4 - with pytest.raises(TypeError): - c_value.popitem() - with pytest.raises(TypeError): - c_value.pop("a", 5) - with pytest.raises(TypeError): - c_value.clear() - with pytest.raises(TypeError): - c_value.update({"e": 6}) - with pytest.raises(TypeError): - c_value.setdefault("f", 7) +@pytest.mark.parametrize( + "operation", + [ + lambda c: c.__setitem__("d", 4), + lambda c: c.popitem(), + lambda c: c.pop("a", 5), + lambda c: c.clear(), + lambda c: c.update({"e": 6}), + lambda c: c.setdefault("f", 7), + lambda c: c.__delitem__("b"), + ], + ids=[ + "setitem", + "popitem", + "pop", + "clear", + "update", + "setdefault", + "delitem", + ], +) +def test_config_value_is_non_modifiable(simple_dict, operation): + """Ensure that all standard dict mutation methods raise a TypeError.""" + # Expand simple_dict slightly so pop/del have valid targets if needed + extended_dict = {**simple_dict, "c": 3} + c_value = ConfigValue(extended_dict) + with pytest.raises(TypeError): - del c_value["b"] + operation(c_value) - assert c_value.to_dict() == orig_dict + # Ensure the underlying structure was not secretly mutated + assert c_value.to_dict() == extended_dict -def test_json_dumpable(): - orig_dict = { - "a": 1, - "b": [1, 2, 3], - "c": {"d": 4}, - "e": {"f": [{"g": 5}]}, - "h": ({"i": 6},), - } - c_value = ConfigValue(orig_dict) - assert json.loads(json.dumps(c_value)) == json.loads(json.dumps(orig_dict)) +def test_json_dumpable(nested_dict): + """Ensure the custom ConfigValue dictionary behaves natively with the json module.""" + c_value = ConfigValue(nested_dict) + # Compare the serialized outputs + assert json.loads(json.dumps(c_value)) == json.loads(json.dumps(nested_dict)) + + +def test_dict_like_iteration_and_access(nested_dict): + """Verify standard dictionary iteration, membership, and length behaviors work.""" + c_value = ConfigValue(nested_dict) -def test_dict_like_behavior(): - orig_dict = { - "a": 1, - "b": [1, 2, 3], - "c": {"d": 4}, - "e": {"f": [{"g": 5}]}, - "h": ({"i": 6},), - } - c_value = ConfigValue(orig_dict) assert "a" in c_value assert "d" not in c_value assert len(c_value) == 5 - assert c_value.keys() == orig_dict.keys() + + assert c_value.keys() == nested_dict.keys() + for k, v in c_value.items(): - assert v == orig_dict[k] + assert v == nested_dict[k] for k in c_value.keys(): - assert k in orig_dict + assert k in nested_dict for v in c_value.values(): - assert v in orig_dict.values() + assert v in nested_dict.values() diff --git a/test/unit/test_content_addressed_store.py b/test/unit/test_content_addressed_store.py index 2042f8e4216..097a5971eed 100644 --- a/test/unit/test_content_addressed_store.py +++ b/test/unit/test_content_addressed_store.py @@ -1,17 +1,25 @@ -from contextlib import contextmanager +import contextlib +from pathlib import Path import pytest from metaflow.datastore.content_addressed_store import ContentAddressedStore from metaflow.datastore.exceptions import DataException +# --------------------------------------------------------------------------- +# Mocks & Helpers +# --------------------------------------------------------------------------- -@contextmanager + +@contextlib.contextmanager def _loaded_bytes(entries): + """Context manager to simulate loading bytes iteratively.""" yield iter(entries) -class _FakeStorageImpl(object): +class _FakeStorageImpl: + """A minimal fake storage implementation to support CAS loading tests.""" + TYPE = "fake" def __init__(self, entries): @@ -27,26 +35,33 @@ def path_split(path): @staticmethod def full_uri(path): - return "fake://" + path + return f"fake://{path}" def load_bytes(self, paths): - expected_paths = [entry[0] for entry in self._entries] - assert set(expected_paths).issubset( - set(paths) - ), "expected paths %s not all in %s" % (expected_paths, paths) + expected_paths = {entry[0] for entry in self._entries} + assert expected_paths.issubset( + paths + ), f"expected paths {expected_paths} not all in {paths}" return _loaded_bytes(self._entries) def _make_store(entries): + """Helper to initialize a ContentAddressedStore with fake storage.""" return ContentAddressedStore("prefix", _FakeStorageImpl(entries)) -def _write_blob_file(tmp_path, name="blob.bin", data=b"not-a-valid-gzip-stream"): +def _write_blob_file(tmp_path: Path, name="blob.bin", data=b"not-a-valid-gzip-stream"): + """Helper to write a temporary binary blob file.""" blob_file = tmp_path / name blob_file.write_bytes(data) return str(blob_file) +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + @pytest.mark.parametrize( "meta, unpack_error, expected_substrings", [ @@ -62,15 +77,20 @@ def _write_blob_file(tmp_path, name="blob.bin", data=b"not-a-valid-gzip-stream") ["Could not unpack artifact", "boom"], ), ], - ids=["missing_version", "unknown_version", "unpack_failure"], + ids=["missing-version", "unknown-version", "unpack-failure"], ) def test_load_blobs_error_message_uses_current_path_key( tmp_path, monkeypatch, meta, unpack_error, expected_substrings ): + """ + Verify that load_blobs error messages accurately reference the *current* + failing path, preventing regression of a bug where a stale outer-loop + variable caused misleading exception messages. + """ stale_key = "aaaaaaaaaa" current_key = "bbbbbbbbbb" - stale_path = "prefix/aa/%s" % stale_key - current_path = "prefix/bb/%s" % current_key + stale_path = f"prefix/aa/{stale_key}" + current_path = f"prefix/bb/{current_key}" file_path = _write_blob_file(tmp_path) store = _make_store([(current_path, file_path, meta)]) @@ -89,7 +109,14 @@ def _raise_unpack_error(_fileobj): list(store.load_blobs([current_key, stale_key])) message = str(exc.value) - assert current_path in message - assert stale_path not in message + + # Assertions + assert ( + current_path in message + ), f"Expected active path {current_path} in error message." + assert ( + stale_path not in message + ), f"Stale path {stale_path} leaked into the error message." + for expected in expected_substrings: assert expected in message diff --git a/test/unit/test_graph_endpoints_fallback.py b/test/unit/test_graph_endpoints_fallback.py index c067839d79d..67e0181c4bc 100644 --- a/test/unit/test_graph_endpoints_fallback.py +++ b/test/unit/test_graph_endpoints_fallback.py @@ -15,48 +15,65 @@ from metaflow.client.core import Run from metaflow.exception import MetaflowNotFound +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- + +FALLBACK_ENDPOINTS = ("start", "end") + +# --------------------------------------------------------------------------- +# Fixtures & Helpers +# --------------------------------------------------------------------------- + @pytest.fixture def run(): - """Bare Run instance, Run.__init__ skipped to avoid metadata service I/O.""" + """Provide a bare Run instance, skipping Run.__init__ to avoid metadata service I/O.""" return Run.__new__(Run) -def _params_step(mocker, metadata): - """Stand-in for run["_parameters"] with a controlled metadata_dict.""" +def _mock_params_step(mocker, metadata_dict): + """Build a mock task stand-in representing run["_parameters"] with custom metadata.""" params = mocker.MagicMock() - params.task.metadata_dict = metadata + params.task.metadata_dict = metadata_dict return params +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + def test_missing_metadata_falls_back_to_literals(run, mocker): - """Empty metadata returns ('start', 'end').""" - mocker.patch.object(Run, "__getitem__", return_value=_params_step(mocker, {})) - assert run._graph_endpoints == ("start", "end") + """Verify that empty or missing metadata immediately returns the legacy fallback endpoints.""" + mocker.patch.object(Run, "__getitem__", return_value=_mock_params_step(mocker, {})) + + assert run._graph_endpoints == FALLBACK_ENDPOINTS def test_metaflow_not_found_caches_fallback(run, mocker): - """MetaflowNotFound (old run, no _parameters) caches the fallback.""" + """Verify that MetaflowNotFound (e.g., an old run lacking parameters) permanently caches the fallback.""" mocker.patch.object(Run, "__getitem__", side_effect=MetaflowNotFound("_parameters")) - assert run._graph_endpoints == ("start", "end") - assert run._cached_endpoints == ("start", "end") + + assert run._graph_endpoints == FALLBACK_ENDPOINTS + assert run._cached_endpoints == FALLBACK_ENDPOINTS -def test_transient_error_not_cached(run, mocker): - """A transient exception returns the fallback but does NOT cache it.""" +def test_transient_error_does_not_cache_fallback(run, mocker): + """Verify that transient exceptions fallback safely but do NOT poison the cache for future retries.""" mocker.patch.object( Run, "__getitem__", side_effect=[ RuntimeError("transient (e.g., metadata service down)"), - _params_step(mocker, {"start_step": "begin", "end_step": "finish"}), + _mock_params_step(mocker, {"start_step": "begin", "end_step": "finish"}), ], ) - # First call: transient error, fallback returned, not cached. - assert run._graph_endpoints == ("start", "end") + # First attempt: encounters a transient error, falls back, and completely avoids caching + assert run._graph_endpoints == FALLBACK_ENDPOINTS assert not hasattr(run, "_cached_endpoints") - # Second call: succeeds, caches. + # Second attempt: network/service recovers, successfully resolves custom steps, and caches the result assert run._graph_endpoints == ("begin", "finish") assert run._cached_endpoints == ("begin", "finish") diff --git a/test/unit/test_graph_structure.py b/test/unit/test_graph_structure.py index b2e48a37c4e..79ba9d5fe2a 100644 --- a/test/unit/test_graph_structure.py +++ b/test/unit/test_graph_structure.py @@ -1,5 +1,4 @@ -""" -Tests for structural inference of start/end steps in FlowGraph. +"""Tests for structural inference of start/end steps in FlowGraph. Verifies that: - Start step is determined by zero in-edges, end step by zero out-edges @@ -12,12 +11,13 @@ """ import pytest -from metaflow import Config, FlowMutator, FlowSpec, step, Parameter, retry, resources + +from metaflow import Config, FlowMutator, FlowSpec, Parameter, resources, retry, step from metaflow.flowspec import FlowStateItems -from metaflow.lint import linter, LintWarn +from metaflow.lint import LintWarn, linter # --------------------------------------------------------------------------- -# Flow definitions for testing +# Flow Definitions for Testing # --------------------------------------------------------------------------- @@ -122,12 +122,12 @@ def terminus(self): # --------------------------------------------------------------------------- -# Flow classes: single-step flows composed with configs, decorators, mutators +# Composed Flow Variations (Configs, Decorators, Mutators) # --------------------------------------------------------------------------- class _SingleStepWithConfig(FlowSpec): - """Single-step flow with a Config descriptor.""" + """Single-step flow utilizing a Config descriptor descriptor.""" cfg = Config("cfg", default_value={"x": 7}) @@ -137,7 +137,7 @@ def only(self): class _SingleStepWithStackedDecos(FlowSpec): - """Single-step flow with multiple step decorators stacked.""" + """Single-step flow with multiple stacked step decorators.""" @retry(times=3) @resources(cpu=2, memory=1024) @@ -147,7 +147,7 @@ def only(self): class _AddRetryMutator(FlowMutator): - """Adds @retry to every step. Used to verify mutators reach a single-step flow.""" + """Appends @retry to every step to verify mutators target single-step topologies.""" def pre_mutate(self, mutable_flow): for _, s in mutable_flow.steps: @@ -156,7 +156,7 @@ def pre_mutate(self, mutable_flow): @_AddRetryMutator class _SingleStepWithFlowMutator(FlowSpec): - """Single-step flow with a FlowMutator applied at the class level.""" + """Single-step flow with a class-level FlowMutator applied.""" @step(start=True, end=True) def only(self): @@ -179,42 +179,40 @@ def only(self): ], ids=[ "standard", - "custom_linear", - "single_step", + "custom-linear", + "single-step", "branch", "foreach", - "split_start", + "split-start", ], ) def flow_with_endpoints(request): - """Yields (flow_class, expected_start, expected_end) for each topology.""" + """Yield pairs of (flow_class, expected_start, expected_end) across topologies.""" return request.param # --------------------------------------------------------------------------- -# Tests: Structural inference +# Tests: Structural Inference & Node Properties # --------------------------------------------------------------------------- def test_start_end_inference(flow_with_endpoints): + """Verify standard and custom-named start/end step resolution properties.""" flow_cls, expected_start, expected_end = flow_with_endpoints graph = flow_cls._graph assert graph.start_step == expected_start assert graph.end_step == expected_end -# --------------------------------------------------------------------------- -# Tests: Node types -# --------------------------------------------------------------------------- - - def test_standard_flow_types(): + """Verify standard start/end node type assignments.""" graph = StandardFlow._graph assert graph["start"].type == "start" assert graph["end"].type == "end" def test_custom_linear_types(): + """Verify linear flow node type mapping for custom-named steps.""" graph = CustomNamedLinearFlow._graph assert graph["begin"].type == "start" assert graph["middle"].type == "linear" @@ -222,13 +220,13 @@ def test_custom_linear_types(): def test_single_step_type_is_end(): - """Single-step flow: type is 'end' since it's terminal.""" + """Verify single-step flows default type to 'end' because they are terminal.""" graph = SingleStepFlow._graph assert graph["only"].type == "end" def test_branch_entry_is_split(): - """Entry step that splits should keep 'split' type, not be overridden to 'start'.""" + """Ensure an entry step that acts as a split preserves its 'split' type identity.""" graph = CustomNamedBranchFlow._graph assert graph["entry"].type == "split" assert graph["merge"].type == "join" @@ -236,12 +234,13 @@ def test_branch_entry_is_split(): def test_split_start_keeps_split_type(): - """Start step that is also a split must keep 'split' type for lint balance.""" + """Ensure start steps that split preserve the 'split' type to prevent linter balance errors.""" graph = SplitStartFlow._graph assert graph["origin"].type == "split" def test_foreach_entry_keeps_foreach_type(): + """Ensure structural routing correctly flags foreach entry and join steps.""" graph = CustomNamedForeachFlow._graph assert graph["init"].type == "foreach" assert graph["collect"].type == "join" @@ -249,6 +248,7 @@ def test_foreach_entry_keeps_foreach_type(): def test_custom_flow_in_funcs_out_funcs(): + """Verify in_funcs and out_funcs structural list accuracy.""" graph = CustomNamedLinearFlow._graph assert graph["begin"].in_funcs == [] assert graph["begin"].out_funcs == ["middle"] @@ -257,29 +257,33 @@ def test_custom_flow_in_funcs_out_funcs(): # --------------------------------------------------------------------------- -# Tests: output_steps / graph_structure +# Tests: Graph Output Structures & Serialization # --------------------------------------------------------------------------- def test_standard_graph_structure(): + """Verify default graph output maps cleanly to dictionary schemas.""" steps_info, graph_structure = StandardFlow._graph.output_steps() assert graph_structure == ["start", "end"] assert set(steps_info.keys()) == {"start", "end"} def test_custom_linear_graph_structure(): + """Verify custom-named linear flow serializes sequence steps in order.""" steps_info, graph_structure = CustomNamedLinearFlow._graph.output_steps() assert graph_structure == ["begin", "middle", "finish"] assert set(steps_info.keys()) == {"begin", "middle", "finish"} def test_single_step_graph_structure(): + """Verify singular flow graphs produce single-element output arrays.""" steps_info, graph_structure = SingleStepFlow._graph.output_steps() assert graph_structure == ["only"] assert set(steps_info.keys()) == {"only"} def test_branch_graph_structure(): + """Verify branching structures encapsulate interior paths inside endpoints.""" steps_info, graph_structure = CustomNamedBranchFlow._graph.output_steps() assert graph_structure[0] == "entry" assert graph_structure[-1] == "done" @@ -288,7 +292,7 @@ def test_branch_graph_structure(): def test_steps_info_types_match(): - """Steps info type should match the node_to_type mapping.""" + """Ensure step metadata mapping values match underlying node_to_type structures.""" steps_info, _ = CustomNamedLinearFlow._graph.output_steps() assert steps_info["begin"]["type"] == "start" assert steps_info["middle"]["type"] == "linear" @@ -296,13 +300,14 @@ def test_steps_info_types_match(): def test_split_start_type_in_steps_info(): - """When start step is a split, steps_info should show split-static.""" + """Ensure step information outputs map combined 'split-static' labels appropriately.""" steps_info, _ = SplitStartFlow._graph.output_steps() assert steps_info["origin"]["type"] == "split-static" assert steps_info["terminus"]["type"] == "end" def test_steps_info_has_next(): + """Verify 'next' step arrays accurately depict downstream paths.""" steps_info, _ = CustomNamedLinearFlow._graph.output_steps() assert steps_info["begin"]["next"] == ["middle"] assert steps_info["middle"]["next"] == ["finish"] @@ -310,7 +315,7 @@ def test_steps_info_has_next(): # --------------------------------------------------------------------------- -# Tests: Topological sort +# Tests: Topological Sorting # --------------------------------------------------------------------------- @@ -327,14 +332,14 @@ def test_single_step_sort(): def test_branch_sort_order(): - """Start must come first, end must come last.""" + """Verify sort topologies systematically anchor start first and end last.""" graph = CustomNamedBranchFlow._graph assert graph.sorted_nodes[0] == "entry" assert graph.sorted_nodes[-1] == "done" # --------------------------------------------------------------------------- -# Tests: Lint validation +# Tests: Linting and Sanity Checks # --------------------------------------------------------------------------- @@ -350,19 +355,20 @@ def test_branch_sort_order(): ], ids=[ "standard", - "custom_linear", - "single_step", + "custom-linear", + "single-step", "branch", "foreach", - "split_start", + "split-start", ], ) def test_flow_passes_lint(flow_cls): + """Verify that all standard test flow configurations cleanly pass validation checks.""" linter.run_checks(flow_cls._graph) # --------------------------------------------------------------------------- -# Tests: node_info metadata +# Tests: Node Info Meta-attributes # --------------------------------------------------------------------------- @@ -423,19 +429,19 @@ def test_node_info_empty_dict(): def test_node_info_absent_step_in_output_steps(): - """Steps without node_info should have None or {} in output_steps.""" + """Steps lacking explicit node_info attributes return empty structures in serialized paths.""" steps_info, _ = _NodeInfoFlow._graph.output_steps() end_info = steps_info["end"]["node_info"] assert end_info is None or end_info == {} # --------------------------------------------------------------------------- -# Tests: Annotation mechanics +# Tests: Step Annotation Behaviors # --------------------------------------------------------------------------- def test_plain_step_has_no_annotations(): - """Plain @step sets is_start_step=False and is_end_step=False.""" + """Verify standard decorators set default marker positions to False.""" graph = StandardFlow._graph assert graph["start"].is_start_step is False assert graph["start"].is_end_step is False @@ -444,7 +450,7 @@ def test_plain_step_has_no_annotations(): def test_annotated_step_flags(): - """@step(start=True) and @step(end=True) set the flags on the node.""" + """Verify start=True and end=True configure internal targeting flags properly.""" graph = CustomNamedLinearFlow._graph assert graph["begin"].is_start_step is True assert graph["begin"].is_end_step is False @@ -455,7 +461,7 @@ def test_annotated_step_flags(): def test_annotated_single_step(): - """@step(start=True, end=True) single-step flow works.""" + """Verify single-step flows marking both parameters concurrently register successfully.""" graph = SingleStepFlow._graph assert graph["only"].is_start_step is True assert graph["only"].is_end_step is True @@ -464,7 +470,7 @@ def test_annotated_single_step(): def test_source_backed_single_step_with_next_still_fails_lint(): - """Source-backed @step(start=True, end=True) with self.next() still fails lint.""" + """Ensure loops targeting single-step configurations throw lint validation errors.""" class BadSingleStepFlow(FlowSpec): @step(start=True, end=True) @@ -476,7 +482,7 @@ def only(self): def test_mixed_annotated_start_named_end(): - """Annotated start + name-based end fallback.""" + """Verify explicit start definitions combine with automated fallback terminal tracking.""" class MixedFlow(FlowSpec): @step(start=True) @@ -495,7 +501,7 @@ def end(self): def test_backward_compat_name_based(): - """Flow with just 'start'/'end' names still works (no annotations).""" + """Ensure standard name-string fallback mechanics maintain legacy compliance constraints.""" graph = StandardFlow._graph assert graph.start_step == "start" assert graph.end_step == "end" @@ -504,12 +510,12 @@ def test_backward_compat_name_based(): # --------------------------------------------------------------------------- -# Tests: composition with configs, stacked decorators, and flow mutators +# Tests: Advanced Object & Lifecycle Composition # --------------------------------------------------------------------------- def test_single_step_with_config_descriptor_registered(): - """Config descriptor is registered on a single-step flow.""" + """Verify config tracking metrics link successfully to unified pipelines.""" graph = _SingleStepWithConfig._graph assert graph.start_step == "only" == graph.end_step names = {name for name, _ in _SingleStepWithConfig._get_parameters()} @@ -517,21 +523,14 @@ def test_single_step_with_config_descriptor_registered(): def test_single_step_with_multiple_step_decorators(): - """Multiple step decorators stack correctly on a single-step flow.""" + """Verify multi-tier runtime decorator combinations align correctly on simple targets.""" graph = _SingleStepWithStackedDecos._graph deco_names = {deco.name for deco in graph["only"].decorators} assert {"retry", "resources"}.issubset(deco_names) def test_single_step_with_flow_mutator_registered(): - """FlowMutator is registered on a single-step flow at class-definition time. - - pre_mutate only fires when the flow is processed via the CLI layer, so - the decorator it adds won't appear on the graph in a unit test. What we - can verify here is that the mutator syntax is accepted by a single-step - FlowSpec and that it's registered as a flow mutator. End-to-end execution - is covered by the matching integration test. - """ + """Verify class-level mutators link successfully to simple pipeline tracking arrays.""" flow_cls = _SingleStepWithFlowMutator._flow_cls graph = flow_cls._graph assert graph.start_step == "only" == graph.end_step @@ -540,7 +539,7 @@ def test_single_step_with_flow_mutator_registered(): # --------------------------------------------------------------------------- -# Negative-path tests: malformed annotation patterns caught by lint +# Negative-Path Test Factories & Edge Cases # --------------------------------------------------------------------------- @@ -605,7 +604,7 @@ def compute(self): def _lint_warnings(flow_cls): - """Run all lint checks and collect LintWarn exceptions.""" + """Run validation checks and capture thrown warnings.""" graph = flow_cls._graph warnings = [] for rule in linter._checks: @@ -624,9 +623,10 @@ def _lint_warnings(flow_cls): (_make_no_start, "start_step"), (_make_no_end, "end_step"), ], - ids=["multiple_start", "multiple_end", "no_start", "no_end"], + ids=["multiple-start", "multiple-end", "no-start", "no-end"], ) def test_malformed_flow_sets_none(flow_factory, expected_none_field): + """Verify that ambiguous or malformed pipeline layouts clear structural pointers to None.""" graph = flow_factory()._graph assert getattr(graph, expected_none_field) is None @@ -639,12 +639,12 @@ def test_malformed_flow_sets_none(flow_factory, expected_none_field): (_make_no_start, "start"), (_make_no_end, "end"), ], - ids=["multiple_start_lint", "multiple_end_lint", "no_start_lint", "no_end_lint"], + ids=["multiple-start-lint", "multiple-end-lint", "no-start-lint", "no-end-lint"], ) def test_malformed_flow_caught_by_lint(flow_factory, match_pattern): + """Verify lint execution blocks invalid start/end variations with clear error contexts.""" _, warnings = _lint_warnings(flow_factory()) combined = " ".join(warnings).lower() - assert match_pattern in combined, "Expected lint warning about '%s', got: %s" % ( - match_pattern, - warnings, - ) + assert ( + match_pattern in combined + ), f"Expected lint warning about '{match_pattern}', got: {warnings}" diff --git a/test/unit/test_kubernetes.py b/test/unit/test_kubernetes.py index 6ede190d428..2eb8d1448e5 100644 --- a/test/unit/test_kubernetes.py +++ b/test/unit/test_kubernetes.py @@ -7,6 +7,11 @@ ) +# --------------------------------------------------------------------------- +# validate_kube_labels +# --------------------------------------------------------------------------- + + @pytest.mark.parametrize( "labels", [ @@ -25,7 +30,7 @@ "1234567890" "1234567890" "123" - ) + ) # 63 characters (max valid length) }, { "label": ( @@ -39,8 +44,18 @@ ) }, ], + ids=[ + "none", + "single_label", + "multiple_labels", + "none_value", + "single_char", + "empty_string", + "max_length_63_chars", + "max_length_with_allowed_special_chars", + ], ) -def test_kubernetes_decorator_validate_kube_labels(labels): +def test_validate_kube_labels_accepts_valid_inputs(labels): assert validate_kube_labels(labels) @@ -59,18 +74,31 @@ def test_kubernetes_decorator_validate_kube_labels(labels): "1234567890" "1234567890" "1234" - ) + ) # 64 characters (exceeds max length) }, {"label": "(){}??"}, {"valid": "test", "invalid": "bißchen"}, ], + ids=[ + "ends_with_hyphen", + "starts_with_dot", + "invalid_chars_parentheses", + "exceeds_max_length_64_chars", + "only_invalid_chars", + "invalid_unicode_chars", + ], ) -def test_kubernetes_decorator_validate_kube_labels_fail(labels): +def test_validate_kube_labels_rejects_invalid_inputs(labels): """Fail if label contains invalid characters or is too long""" with pytest.raises(KubernetesException): validate_kube_labels(labels) +# --------------------------------------------------------------------------- +# parse_kube_keyvalue_list +# --------------------------------------------------------------------------- + + @pytest.mark.parametrize( "items,requires_both,expected", [ @@ -79,8 +107,14 @@ def test_kubernetes_decorator_validate_kube_labels_fail(labels): (["key"], False, {"key": None}), (["key=value", "key2=value2"], True, {"key": "value", "key2": "value2"}), ], + ids=[ + "single_kv_requires_both", + "single_kv_optional_both", + "key_only_optional_both", + "multiple_kv_requires_both", + ], ) -def test_kubernetes_parse_keyvalue_list(items, requires_both, expected): +def test_parse_kube_keyvalue_list_success(items, requires_both, expected): ret = parse_kube_keyvalue_list(items, requires_both) assert ret == expected @@ -91,7 +125,11 @@ def test_kubernetes_parse_keyvalue_list(items, requires_both, expected): (["key=value", "key=value2"], True), (["key"], True), ], + ids=[ + "duplicate_keys_not_allowed", + "missing_value_when_requires_both", + ], ) -def test_kubernetes_parse_keyvalue_list(items, requires_both): +def test_parse_kube_keyvalue_list_raises_exception(items, requires_both): with pytest.raises(KubernetesException): parse_kube_keyvalue_list(items, requires_both) diff --git a/test/unit/test_local_metadata_provider.py b/test/unit/test_local_metadata_provider.py index be080c5f115..fb4de8c9000 100644 --- a/test/unit/test_local_metadata_provider.py +++ b/test/unit/test_local_metadata_provider.py @@ -1,31 +1,36 @@ +import pytest + from metaflow.plugins.metadata_providers.local import LocalMetadataProvider -def test_deduce_run_id_from_meta_dir(): - test_cases = [ - { - "meta_path": ".metaflow/BasicParameterTestFlow/1652384326805262/start/1/_meta", - "sub_type": "task", - "expected_run_id": "1652384326805262", - }, - { - "meta_path": ".metaflow/BasicParameterTestFlow/1652384326805262/start/_meta", - "sub_type": "step", - "expected_run_id": "1652384326805262", - }, - { - "meta_path": ".metaflow/BasicParameterTestFlow/1652384326805262/_meta", - "sub_type": "run", - "expected_run_id": "1652384326805262", - }, - { - "meta_path": ".metaflow/BasicParameterTestFlow/_meta", - "sub_type": "flow", - "expected_run_id": None, - }, - ] - for case in test_cases: - actual_run_id = LocalMetadataProvider._deduce_run_id_from_meta_dir( - case["meta_path"], case["sub_type"] - ) - assert case["expected_run_id"] == actual_run_id +@pytest.mark.parametrize( + "meta_path, sub_type, expected_run_id", + [ + ( + ".metaflow/BasicParameterTestFlow/1652384326805262/start/1/_meta", + "task", + "1652384326805262", + ), + ( + ".metaflow/BasicParameterTestFlow/1652384326805262/start/_meta", + "step", + "1652384326805262", + ), + ( + ".metaflow/BasicParameterTestFlow/1652384326805262/_meta", + "run", + "1652384326805262", + ), + ( + ".metaflow/BasicParameterTestFlow/_meta", + "flow", + None, + ), + ], + ids=["task_level", "step_level", "run_level", "flow_level"], +) +def test_deduce_run_id_from_meta_dir(meta_path, sub_type, expected_run_id): + actual_run_id = LocalMetadataProvider._deduce_run_id_from_meta_dir( + meta_path, sub_type + ) + assert actual_run_id == expected_run_id diff --git a/test/unit/test_metaflow_version.py b/test/unit/test_metaflow_version.py index f19048742fb..cf67656f9b9 100644 --- a/test/unit/test_metaflow_version.py +++ b/test/unit/test_metaflow_version.py @@ -31,6 +31,19 @@ "v9.2.97-rc.15.post100-git3a13f86-dirty", ), ], + ids=[ + "plain_tag_clean_private", + "plain_tag_clean_public", + "plain_tag_ahead_private", + "plain_tag_ahead_public", + "plain_tag_ahead_dirty_private", + "plain_tag_clean_dirty_private", + "dashed_tag_clean_private", + "dashed_tag_ahead_private", + "dashed_tag_ahead_public", + "dashed_tag_ahead_dirty_private", + "multi_dashed_tag_ahead_dirty_private", + ], ) def test_format_git_describe_parses_known_shapes(git_str, public, expected): assert format_git_describe(git_str, public=public) == expected @@ -48,6 +61,12 @@ def test_format_git_describe_parses_known_shapes(git_str, public, expected): # but guarded symmetrically with the clean branch. "a-b-dirty", ], + ids=[ + "none", + "single_token", + "two_tokens", + "two_tokens_dirty", + ], ) def test_format_git_describe_returns_none_for_unparseable(git_str): assert format_git_describe(git_str) is None @@ -69,6 +88,15 @@ def test_format_git_describe_returns_none_for_unparseable(git_str): # PEP 440 local-version (+…) identifier is stripped alongside. ("v1.0-rc.1.post12-gitabcdef0+ext(foo)", "v1.0-rc.1.post12"), ], + ids=[ + "plain_tag_only", + "plain_tag_git_suffix", + "plain_tag_git_dirty", + "dashed_tag_only", + "dashed_tag_git_suffix", + "dashed_tag_git_dirty", + "pep440_local_version_stripped", + ], ) def test_make_public_version_strips_only_private_suffixes(version_string, expected): assert make_public_version(version_string) == expected diff --git a/test/unit/test_package_suffixes_mutator.py b/test/unit/test_package_suffixes_mutator.py index cd0085df619..96190efdffa 100644 --- a/test/unit/test_package_suffixes_mutator.py +++ b/test/unit/test_package_suffixes_mutator.py @@ -1,76 +1,103 @@ """Tests for the ``package_suffixes`` example FlowMutator.""" import os -import tempfile -from unittest import mock +import pytest from metaflow.packaging_sys import ContentType from metaflow.plugins.package_suffixes_mutator import package_suffixes -def _make_flow(tmpdir): - """Write a minimal flow.py and a few sibling files; return the flow file.""" - flow_file = os.path.join(tmpdir, "flow.py") - with open(flow_file, "w") as f: - f.write("# dummy flow\n") - with open(os.path.join(tmpdir, "config.yaml"), "w") as f: - f.write("a: 1\n") - with open(os.path.join(tmpdir, "data.json"), "w") as f: - f.write("{}\n") - os.makedirs(os.path.join(tmpdir, "sub")) - with open(os.path.join(tmpdir, "sub", "nested.yaml"), "w") as f: - f.write("b: 2\n") - with open(os.path.join(tmpdir, "ignored.txt"), "w") as f: - f.write("ignored\n") - return flow_file - - -def _build_mutator(suffixes, flow_file): - m = package_suffixes.__new__(package_suffixes) - # We only need _flow_cls for inspect.getfile(); patch that directly instead - # of constructing a real class. - m._flow_cls = mock.Mock() - m.init(suffixes) - return m - - -def test_init_list_form(): - m = package_suffixes.__new__(package_suffixes) - m.init([".yaml", "json"]) - assert m._suffixes == (".yaml", ".json") - - -def test_init_string_form(): - m = package_suffixes.__new__(package_suffixes) - m.init(".yaml,json, .txt") - assert m._suffixes == (".yaml", ".json", ".txt") - - -def test_add_to_package_yields_matching_files(): - with tempfile.TemporaryDirectory() as tmp: - flow_file = _make_flow(tmp) - m = _build_mutator([".yaml", ".json"], flow_file) - with mock.patch("inspect.getfile", return_value=flow_file): - results = list(m.add_to_package()) - - # All tuples are USER_CONTENT - assert all(t[2] == ContentType.USER_CONTENT for t in results) - - arcnames = {t[1] for t in results} - # walk() yields arcnames relative to the flow directory (no flow dir - # basename prefix), matching the convention used by _user_code_tuples. - assert "config.yaml" in arcnames - assert "data.json" in arcnames - assert os.path.join("sub", "nested.yaml") in arcnames - # Non-matching files are not included. - assert "ignored.txt" not in arcnames - # flow.py is a .py file — not part of the configured extra suffixes. - assert "flow.py" not in arcnames - - -def test_add_to_package_empty_suffixes_yields_nothing(): - with tempfile.TemporaryDirectory() as tmp: - flow_file = _make_flow(tmp) - m = _build_mutator([], flow_file) - with mock.patch("inspect.getfile", return_value=flow_file): - assert list(m.add_to_package()) == [] +# --------------------------------------------------------------------------- +# Fixtures & Factories +# --------------------------------------------------------------------------- + + +@pytest.fixture +def setup_flow_dir(tmp_path): + """Fixture to write a minimal flow.py and a few sibling files. + Returns the tuple (flow_file_path, flow_directory_path). + """ + flow_file = tmp_path / "flow.py" + flow_file.write_text("# dummy flow\n") + + (tmp_path / "config.yaml").write_text("a: 1\n") + (tmp_path / "data.json").write_text("{}\n") + + sub_dir = tmp_path / "sub" + sub_dir.mkdir() + (sub_dir / "nested.yaml").write_text("b: 2\n") + + (tmp_path / "ignored.txt").write_text("ignored\n") + + return flow_file, tmp_path + + +@pytest.fixture +def make_mutator(mocker): + """Factory fixture to initialize the package_suffixes mutator with mock dependencies.""" + + def _build(suffixes): + m = package_suffixes.__new__(package_suffixes) + # We only need _flow_cls for inspect.getfile(); mock that directly instead + # of constructing a real class. + m._flow_cls = mocker.Mock() + m.init(suffixes) + return m + + return _build + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "input_suffixes, expected", + [ + ([".yaml", "json"], (".yaml", ".json")), + (".yaml,json, .txt", (".yaml", ".json", ".txt")), + ], + ids=["list_format", "string_format"], +) +def test_init_normalizes_suffixes(make_mutator, input_suffixes, expected): + """Test that init correctly parses and normalizes list and string suffix inputs.""" + m = make_mutator(input_suffixes) + assert m._suffixes == expected + + +def test_add_to_package_yields_matching_files(mocker, setup_flow_dir, make_mutator): + """Test that add_to_package recursively finds files matching the configured suffixes.""" + flow_file, flow_dir = setup_flow_dir + m = make_mutator([".yaml", ".json"]) + + mocker.patch("inspect.getfile", return_value=str(flow_file)) + results = list(m.add_to_package()) + + # All tuples are USER_CONTENT + assert all(t[2] == ContentType.USER_CONTENT for t in results) + + arcnames = {t[1] for t in results} + + # walk() yields arcnames relative to the flow directory (no flow dir + # basename prefix), matching the convention used by _user_code_tuples. + assert "config.yaml" in arcnames + assert "data.json" in arcnames + assert os.path.join("sub", "nested.yaml") in arcnames + + # Non-matching files are not included. + assert "ignored.txt" not in arcnames + # flow.py is a .py file — not part of the configured extra suffixes. + assert "flow.py" not in arcnames + + +def test_add_to_package_yields_nothing_when_suffixes_empty( + mocker, setup_flow_dir, make_mutator +): + """Test that an empty suffix list results in no files being yielded.""" + flow_file, flow_dir = setup_flow_dir + m = make_mutator([]) + + mocker.patch("inspect.getfile", return_value=str(flow_file)) + + assert list(m.add_to_package()) == [] diff --git a/test/unit/test_packaging_utils.py b/test/unit/test_packaging_utils.py index f407f74ced6..c0802161c9d 100644 --- a/test/unit/test_packaging_utils.py +++ b/test/unit/test_packaging_utils.py @@ -1,38 +1,57 @@ -import os -import tempfile +import pytest from metaflow.packaging_sys.utils import walk -def _make_file(path): - os.makedirs(os.path.dirname(path), exist_ok=True) - with open(path, "w") as f: - f.write("") - - -def test_walk_includes_files_when_hidden_dir_is_ancestor_of_root(): - """Regression: hidden ancestor dirs must not exclude user files.""" - with tempfile.TemporaryDirectory() as base: - root = os.path.join(base, ".hidden_parent", "project", "flows") - os.makedirs(root) - _make_file(os.path.join(root, "hello_flow.py")) - - results = {rel for _, rel in walk(root, exclude_hidden=True)} +@pytest.mark.parametrize( + "root_parts, files_to_create, expected_included, expected_excluded", + [ + ( + # Root is under a hidden ancestor. + # Regression: hidden ancestor dirs must not exclude user files. + [".hidden_parent", "project", "flows"], + ["hello_flow.py"], + ["hello_flow.py"], + [], + ), + ( + # Root contains both visible files and hidden directories. + # Hidden directories *under* root should be excluded. + [".hidden_parent", "project"], + ["visible.py", ".secret/hidden.py"], + ["visible.py"], + ["hidden.py", ".secret"], + ), + ], + ids=[ + "hidden_ancestor_allows_visible_children", + "hidden_descendant_is_excluded", + ], +) +def test_walk_handles_hidden_directories( + tmp_path, root_parts, files_to_create, expected_included, expected_excluded +): + # Setup: Create dynamic root structure + root_dir = tmp_path.joinpath(*root_parts) + root_dir.mkdir(parents=True, exist_ok=True) + + # Setup: Create files within the root + for file_path in files_to_create: + full_path = root_dir / file_path + full_path.parent.mkdir(parents=True, exist_ok=True) + full_path.write_text("") + + # Act + # walk yields (absolute_path, relative_path) tuples + results = {rel for _, rel in walk(str(root_dir), exclude_hidden=True)} + + # Assert + for expected in expected_included: assert any( - "hello_flow.py" in r for r in results - ), f"Expected hello_flow.py in walk results, got: {results}" - - -def test_walk_excludes_hidden_dirs_under_root(): - """Hidden directories *under* root should still be excluded.""" - with tempfile.TemporaryDirectory() as base: - root = os.path.join(base, ".hidden_parent", "project") - os.makedirs(root) - _make_file(os.path.join(root, "visible.py")) - _make_file(os.path.join(root, ".secret", "hidden.py")) + expected in r for r in results + ), f"Expected {expected} in walk results, got: {results}" - results = {rel for _, rel in walk(root, exclude_hidden=True)} - assert any("visible.py" in r for r in results) + for excluded in expected_excluded: assert not any( - "hidden.py" in r for r in results - ), f"hidden.py should be excluded, got: {results}" + excluded in r for r in results + ), f"Expected {excluded} to be excluded, got: {results}" diff --git a/test/unit/test_pickle_serializer.py b/test/unit/test_pickle_serializer.py index d08206a88dd..7bbd5650e18 100644 --- a/test/unit/test_pickle_serializer.py +++ b/test/unit/test_pickle_serializer.py @@ -1,5 +1,4 @@ import pickle - import pytest from metaflow.datastore.artifacts.serializer import ( @@ -27,21 +26,18 @@ def test_registered_in_store(): assert SerializerStore._all_serializers["pickle"] is PickleSerializer -def test_last_in_ordering(): +def test_last_in_ordering(monkeypatch): """PickleSerializer should be last (highest PRIORITY) among registered serializers.""" - # Dispatch is driven by _active_serializers (post-Phase-6). Ensure Pickle - # is active for this test regardless of whether bootstrap() has already - # run in the current process. - was_active = PickleSerializer in SerializerStore._active_serializers - SerializerStore._active_serializers.add(PickleSerializer) - SerializerStore._ordered_cache = None - try: - ordered = SerializerStore.get_ordered_serializers() - assert ordered[-1] is PickleSerializer - finally: - if not was_active: - SerializerStore._active_serializers.discard(PickleSerializer) - SerializerStore._ordered_cache = None + # Use monkeypatch to safely append PickleSerializer to active state and clear cache. + # This ensures automatic cleanup after the test runs. + updated_serializers = SerializerStore._active_serializers.copy() + updated_serializers.add(PickleSerializer) + + monkeypatch.setattr(SerializerStore, "_active_serializers", updated_serializers) + monkeypatch.setattr(SerializerStore, "_ordered_cache", None) + + ordered = SerializerStore.get_ordered_serializers() + assert ordered[-1] is PickleSerializer # --------------------------------------------------------------------------- @@ -90,6 +86,7 @@ def test_can_serialize_any_object(obj): @pytest.mark.parametrize( "encoding", ["pickle-v2", "pickle-v4", "gzip+pickle-v2", "gzip+pickle-v4"], + ids=["pickle-v2", "pickle-v4", "gzip-pickle-v2", "gzip-pickle-v4"], ) def test_can_deserialize_valid_encodings(encoding): meta = SerializationMetadata("object", 100, encoding, {}) @@ -99,6 +96,7 @@ def test_can_deserialize_valid_encodings(encoding): @pytest.mark.parametrize( "encoding", ["json", "iotype:text", "msgpack", "unknown", ""], + ids=["json", "iotype-text", "msgpack", "unknown", "empty-string"], ) def test_cannot_deserialize_unknown_encodings(encoding): meta = SerializationMetadata("object", 100, encoding, {}) @@ -181,6 +179,8 @@ def test_round_trip(obj): class _CustomObj: + """Helper class to verify custom object serialization properties.""" + def __init__(self, x): self.x = x diff --git a/test/unit/test_pypi_parsers.py b/test/unit/test_pypi_parsers.py index 6867bc9b14d..56758f600ab 100644 --- a/test/unit/test_pypi_parsers.py +++ b/test/unit/test_pypi_parsers.py @@ -1,3 +1,5 @@ +import pytest + from metaflow.plugins.pypi.parsers import ( requirements_txt_parser, conda_environment_yml_parser, @@ -5,6 +7,10 @@ ParserValueError, ) +# --------------------------------------------------------------------------- +# Test Data Constants +# --------------------------------------------------------------------------- + VALID_REQ = """ dummypkg==1.1.1 anotherpkg==0.0.1 @@ -77,7 +83,12 @@ """ -def test_yml_parser(): +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +def test_conda_environment_yml_parser_extracts_packages_and_python(): result = conda_environment_yml_parser(VALID_YML) assert result["python"] == "3.10.*" @@ -85,30 +96,25 @@ def test_yml_parser(): assert result["packages"]["anotherpkg"] == "0.0.1" -def test_requirements_parser(): - # success case - result = requirements_txt_parser(VALID_REQ) +@pytest.mark.parametrize( + "content", + [VALID_REQ, VALID_RYE_LOCK], + ids=["standard_requirements", "rye_lockfile"], +) +def test_requirements_txt_parser_extracts_packages(content): + result = requirements_txt_parser(content) - assert result["python"] == None + assert result["python"] is None assert result["packages"]["dummypkg"] == "1.1.1" assert result["packages"]["anotherpkg"] == "0.0.1" - # Rye lockfile success case - result = requirements_txt_parser(VALID_RYE_LOCK) - - assert result["python"] == None - assert result["packages"]["dummypkg"] == "1.1.1" - assert result["packages"]["anotherpkg"] == "0.0.1" - # failures - try: +def test_requirements_txt_parser_rejects_invalid_flags(): + with pytest.raises(ParserValueError): requirements_txt_parser(INVALID_REQ) - raise Exception("parsing invalid content did not raise an expected exception.") - except ParserValueError: - pass # expected to raise -def test_toml_parser(): +def test_pyproject_toml_parser_extracts_packages_and_python(): result = pyproject_toml_parser(VALID_TOML) assert result["python"] == ">=3.8" diff --git a/test/unit/test_remove_decorator.py b/test/unit/test_remove_decorator.py index 02a39fe4667..9b53c6e08ad 100644 --- a/test/unit/test_remove_decorator.py +++ b/test/unit/test_remove_decorator.py @@ -1,3 +1,4 @@ +import pytest from types import SimpleNamespace from metaflow.user_decorators.mutable_step import MutableStep @@ -12,7 +13,9 @@ def get_args_kwargs(self): return [], {} -def test_remove_decorator(monkeypatch): +@pytest.fixture +def mock_step(): + """Provides a MutableStep with pre-configured decorators.""" step = MutableStep.__new__(MutableStep) step._pre_mutate = True step._inserted_by = "test" @@ -25,14 +28,28 @@ def test_remove_decorator(monkeypatch): wrappers=[], config_decorators=[], ) + return step - monkeypatch.setattr( - UserStepDecoratorBase, - "get_decorator_by_name", - staticmethod(lambda _: object), + +def test_remove_decorator_success(mock_step, mocker): + # Mocking the lookup to identify our DummyStepDecorator + mocker.patch.object( + UserStepDecoratorBase, "get_decorator_by_name", return_value=DummyStepDecorator ) - removed = step.remove_decorator("retry") + removed = mock_step.remove_decorator("retry") - assert step._my_step.decorators == [] + assert mock_step._my_step.decorators == [] assert removed is True + + +def test_remove_decorator_not_found(mock_step, mocker): + # Mocking a scenario where the decorator is not found + mocker.patch.object( + UserStepDecoratorBase, "get_decorator_by_name", return_value=None + ) + + removed = mock_step.remove_decorator("nonexistent") + + assert len(mock_step._my_step.decorators) == 2 + assert removed is False diff --git a/test/unit/test_s3_empty_input.py b/test/unit/test_s3_empty_input.py index 14b89722071..8344a185948 100644 --- a/test/unit/test_s3_empty_input.py +++ b/test/unit/test_s3_empty_input.py @@ -14,16 +14,17 @@ Fix: return early from both methods when the input list is empty. """ -import os -import tempfile -from unittest.mock import MagicMock, patch - import pytest from metaflow.plugins.datatools.s3.s3 import S3 +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + -def _make_s3(tmp_path): +@pytest.fixture +def s3_instance(tmp_path): """Create a minimal S3 instance without a real S3 connection.""" s3 = object.__new__(S3) s3._tmproot = str(tmp_path) @@ -36,111 +37,128 @@ def _make_s3(tmp_path): return s3 -class TestPutManyFilesEmptyInput: - def test_returns_empty_list(self, tmp_path): - """_put_many_files with no items returns [] without calling s3op.""" - s3 = _make_s3(tmp_path) - result = s3._put_many_files(iter([]), overwrite=True) - assert result == [] - - def test_does_not_call_s3op(self, tmp_path): - """s3op subprocess must not be spawned when there is nothing to upload.""" - s3 = _make_s3(tmp_path) - with patch.object(s3, "_s3op_with_retries") as mock_op: - s3._put_many_files(iter([]), overwrite=True) - mock_op.assert_not_called() - - def test_error_handler_would_crash_without_guard(self): - """ - Regression: document the latent IndexError in the error handler. - - Before the guard was added, _put_many_files called s3op even with - empty url_info. s3op exits 0 but writes "Uploading 0 files.." to - stderr. The `if stderr:` block then executed: - url_info[0][2]["key"] # IndexError — url_info is [] - This test proves that the error handler is unsafe with an empty list, - which is why the early-return guard is necessary. - """ - url_info = [] # what packing_iter() produces on a CAS cache hit - stderr = b"Uploading 0 files..\nUploaded 0 files." # s3op info output - with pytest.raises(IndexError): - # Reproduce the exact expression from the old error handler: - _ = url_info[0][2]["key"] - - def test_non_empty_input_still_calls_s3op(self, tmp_path): - """Non-empty input should still go through s3op (not short-circuit).""" - s3 = _make_s3(tmp_path) - - with tempfile.NamedTemporaryFile( - dir=str(tmp_path), delete=False, mode="wb" - ) as tf: - tf.write(b"data") - local = tf.name - - try: - with patch.object( - s3, - "_s3op_with_retries", - return_value=([b"s3://bucket/key " + tf.name.encode() + b" "], b"", 0), - ) as mock_op: - - def _gen(): - yield local, "s3://bucket/key", {"key": "mykey"} - - s3._put_many_files(_gen(), overwrite=True) - mock_op.assert_called_once() - finally: - os.unlink(local) - - -class TestReadManyFilesEmptyInput: - def test_returns_empty_generator(self, tmp_path): - """_read_many_files with no prefixes yields nothing.""" - s3 = _make_s3(tmp_path) - result = list(s3._read_many_files("get", iter([]))) - assert result == [] - - def test_does_not_call_s3op(self, tmp_path): - """s3op subprocess must not be spawned when there is nothing to read.""" - s3 = _make_s3(tmp_path) - with patch.object(s3, "_s3op_with_retries") as mock_op: - list(s3._read_many_files("get", iter([]))) - mock_op.assert_not_called() - - def test_error_handler_would_crash_without_guard(self): - """ - Regression: document the latent IndexError in the error handler. - - Before the guard was added, _read_many_files called s3op even with - empty prefixes_and_ranges. s3op exits 0 but writes - "Info_downloading 0 files.." to stderr. The `if stderr:` block then - executed: - prefixes_and_ranges[0] # IndexError — list is [] - This test proves that the error handler is unsafe with an empty list, - which is why the early-return guard is necessary. - """ - prefixes_and_ranges = [] # empty input materialised as a list - stderr = b"Info_downloading 0 files..\nInfo_downloaded 0 files." - with pytest.raises(IndexError): - # Reproduce the exact expression from the old error handler: - _ = prefixes_and_ranges[0] - - def test_empty_input_for_all_ops(self, tmp_path): - """All s3op modes (get, list, info, put) are safe with empty input.""" - s3 = _make_s3(tmp_path) - with patch.object(s3, "_s3op_with_retries") as mock_op: - for op in ("get", "list", "info", "put"): - result = list(s3._read_many_files(op, iter([]))) - assert result == [], f"op={op} should yield nothing for empty input" - mock_op.assert_not_called() - - def test_non_empty_input_still_calls_s3op(self, tmp_path): - """Non-empty input should still go through s3op (not short-circuit).""" - s3 = _make_s3(tmp_path) - with patch.object( - s3, - "_s3op_with_retries", - return_value=([b"s3://b/k /tmp/f 100"], b"", 0), - ) as mock_op: - list(s3._read_many_files("get", iter([("s3://b/k", None)]))) - mock_op.assert_called_once() +# --------------------------------------------------------------------------- +# Test Functions: Put Many Files +# --------------------------------------------------------------------------- + + +def test_put_many_files_empty_input_returns_empty_list(s3_instance): + """_put_many_files with no items returns [] without calling s3op.""" + result = s3_instance._put_many_files(iter([]), overwrite=True) + assert result == [] + + +def test_put_many_files_empty_input_does_not_call_s3op(mocker, s3_instance): + """s3op subprocess must not be spawned when there is nothing to upload.""" + mock_op = mocker.patch.object(s3_instance, "_s3op_with_retries") + + s3_instance._put_many_files(iter([]), overwrite=True) + + mock_op.assert_not_called() + + +def test_put_many_files_error_handler_would_crash_without_guard(): + """ + Regression: document the latent IndexError in the error handler. + + Before the guard was added, _put_many_files called s3op even with + empty url_info. s3op exits 0 but writes "Uploading 0 files.." to + stderr. The `if stderr:` block then executed: + url_info[0][2]["key"] # IndexError — url_info is [] + This test proves that the error handler is unsafe with an empty list, + which is why the early-return guard is necessary. + """ + url_info = [] # what packing_iter() produces on a CAS cache hit + stderr = b"Uploading 0 files..\nUploaded 0 files." # s3op info output + + with pytest.raises(IndexError): + # Reproduce the exact expression from the old error handler: + _ = url_info[0][2]["key"] + + +def test_put_many_files_non_empty_input_calls_s3op(mocker, s3_instance, tmp_path): + """Non-empty input should still go through s3op (not short-circuit).""" + # Use tmp_path instead of NamedTemporaryFile to avoid manual cleanup + local_file = tmp_path / "test_data.txt" + local_file.write_bytes(b"data") + + mock_op = mocker.patch.object( + s3_instance, + "_s3op_with_retries", + return_value=([b"s3://bucket/key " + str(local_file).encode() + b" "], b"", 0), + ) + + def _gen(): + yield str(local_file), "s3://bucket/key", {"key": "mykey"} + + s3_instance._put_many_files(_gen(), overwrite=True) + + mock_op.assert_called_once() + + +# --------------------------------------------------------------------------- +# Test Functions: Read Many Files +# --------------------------------------------------------------------------- + + +def test_read_many_files_empty_input_returns_empty_generator(s3_instance): + """_read_many_files with no prefixes yields nothing.""" + result = list(s3_instance._read_many_files("get", iter([]))) + assert result == [] + + +def test_read_many_files_empty_input_does_not_call_s3op(mocker, s3_instance): + """s3op subprocess must not be spawned when there is nothing to read.""" + mock_op = mocker.patch.object(s3_instance, "_s3op_with_retries") + + list(s3_instance._read_many_files("get", iter([]))) + + mock_op.assert_not_called() + + +def test_read_many_files_error_handler_would_crash_without_guard(): + """ + Regression: document the latent IndexError in the error handler. + + Before the guard was added, _read_many_files called s3op even with + empty prefixes_and_ranges. s3op exits 0 but writes + "Info_downloading 0 files.." to stderr. The `if stderr:` block then + executed: + prefixes_and_ranges[0] # IndexError — list is [] + This test proves that the error handler is unsafe with an empty list, + which is why the early-return guard is necessary. + """ + prefixes_and_ranges = [] # empty input materialised as a list + stderr = b"Info_downloading 0 files..\nInfo_downloaded 0 files." + + with pytest.raises(IndexError): + # Reproduce the exact expression from the old error handler: + _ = prefixes_and_ranges[0] + + +@pytest.mark.parametrize( + "s3_op", + ["get", "list", "info", "put"], + ids=["op_get", "op_list", "op_info", "op_put"], +) +def test_read_many_files_empty_input_safe_for_all_ops(mocker, s3_instance, s3_op): + """All s3op modes (get, list, info, put) are safe with empty input.""" + mock_op = mocker.patch.object(s3_instance, "_s3op_with_retries") + + result = list(s3_instance._read_many_files(s3_op, iter([]))) + + assert result == [], f"op={s3_op} should yield nothing for empty input" + mock_op.assert_not_called() + + +def test_read_many_files_non_empty_input_calls_s3op(mocker, s3_instance): + """Non-empty input should still go through s3op (not short-circuit).""" + mock_op = mocker.patch.object( + s3_instance, + "_s3op_with_retries", + return_value=([b"s3://b/k /tmp/f 100"], b"", 0), + ) + + list(s3_instance._read_many_files("get", iter([("s3://b/k", None)]))) + + mock_op.assert_called_once() diff --git a/test/unit/test_s3_storage.py b/test/unit/test_s3_storage.py index 6f762a9d87b..c88da4ab2aa 100644 --- a/test/unit/test_s3_storage.py +++ b/test/unit/test_s3_storage.py @@ -1,3 +1,5 @@ +"""Tests for S3 datastore byte saving and metadata mapping.""" + from io import BytesIO import pytest @@ -6,7 +8,12 @@ from metaflow.plugins.datastores.s3_storage import S3Storage -def _make_storage(): +# --- Fixtures --- + + +@pytest.fixture +def s3_storage(): + """Fixture providing a minimal, uninitialized S3Storage instance.""" storage = object.__new__(S3Storage) storage.datastore_root = "s3://unit-test-root" storage.s3_client = object() @@ -33,12 +40,21 @@ def patched_s3(mocker): return s3 -def test_save_bytes_put_many_preserves_metadata_slot(patched_s3, test_items): - storage = _make_storage() - storage.save_bytes(iter(test_items), overwrite=True, len_hint=11) +# --- Tests --- + + +def test_save_bytes_put_many_preserves_metadata_slot( + s3_storage, patched_s3, test_items +): + """ + When len_hint > 10, save_bytes optimizes by delegating to put_many. + This test ensures the metadata payload survives the batch translation. + """ + s3_storage.save_bytes(iter(test_items), overwrite=True, len_hint=11) put_objs, overwrite = patched_s3.put_many.call_args[0] put_objs = list(put_objs) + assert overwrite is True assert put_objs[0].encryption is None assert put_objs[0].metadata == {"k": "v"} @@ -46,13 +62,20 @@ def test_save_bytes_put_many_preserves_metadata_slot(patched_s3, test_items): assert put_objs[1].metadata is None -def test_save_bytes_sequential_preserves_metadata(patched_s3, test_items): - storage = _make_storage() - storage.save_bytes(iter(test_items), overwrite=False, len_hint=2) +def test_save_bytes_sequential_preserves_metadata(s3_storage, patched_s3, test_items): + """ + When len_hint <= 10, save_bytes falls back to sequential put() calls. + This test ensures the metadata kwargs are passed correctly per item. + """ + s3_storage.save_bytes(iter(test_items), overwrite=False, len_hint=2) put_calls = patched_s3.put.call_args_list assert len(put_calls) == 2 + + # Check first item ("a") assert put_calls[0][0][0] == "a" assert put_calls[0][1]["metadata"] == {"k": "v"} + + # Check second item ("b") assert put_calls[1][0][0] == "b" assert put_calls[1][1]["metadata"] is None diff --git a/test/unit/test_secrets_decorator.py b/test/unit/test_secrets_decorator.py index 3b19644e4c4..a689c1c944a 100644 --- a/test/unit/test_secrets_decorator.py +++ b/test/unit/test_secrets_decorator.py @@ -1,21 +1,23 @@ +"""Tests for parsing, validating, and resolving SecretSpecs and environment variables.""" + import os -import time import pytest -from metaflow.exception import MetaflowException import metaflow.metaflow_config +from metaflow.exception import MetaflowException from metaflow.plugins.secrets.secrets_decorator import ( SecretSpec, + get_secrets_backend_provider, + validate_env_vars, validate_env_vars_across_secrets, validate_env_vars_vs_existing_env, - validate_env_vars, - get_secrets_backend_provider, ) @pytest.fixture def default_secrets_backend(mocker): + """Fixture to establish a consistent default backend type for tests.""" mocker.patch( "metaflow.metaflow_config.DEFAULT_SECRETS_BACKEND_TYPE", "some-default-backend-type", @@ -23,74 +25,94 @@ def default_secrets_backend(mocker): def test_missing_default_secrets_backend_type(mocker): + """Test that missing a default backend type raises an exception when parsing implicit strings.""" mocker.patch("metaflow.metaflow_config.DEFAULT_SECRETS_BACKEND_TYPE", None) assert metaflow.metaflow_config.DEFAULT_SECRETS_BACKEND_TYPE is None + with pytest.raises(MetaflowException): SecretSpec.secret_spec_from_str("secret_id", None) def test_secret_spec_from_str_explicit_type(default_secrets_backend): - assert SecretSpec.secret_spec_from_str("explicit-type.the_id", None).to_json() == { + """Test parsing a dot-separated string into a specific backend type and id.""" + spec = SecretSpec.secret_spec_from_str("explicit-type.the_id", None) + expected = { "options": {}, "secret_id": "the_id", "secrets_backend_type": "explicit-type", "role": None, } + assert spec.to_json() == expected def test_secret_spec_from_str_implicit_type(default_secrets_backend): - assert SecretSpec.secret_spec_from_str("the_id", None).to_json() == { + """Test parsing a raw id string falls back to the default backend type.""" + spec = SecretSpec.secret_spec_from_str("the_id", None) + expected = { "options": {}, "secret_id": "the_id", "secrets_backend_type": "some-default-backend-type", "role": None, } + assert spec.to_json() == expected def test_secret_spec_from_dict_explicit_type_no_options(default_secrets_backend): - assert SecretSpec.secret_spec_from_dict( + """Test parsing a dictionary specifying a backend type.""" + spec = SecretSpec.secret_spec_from_dict( {"type": "explicit-type", "id": "the_id"}, None - ).to_json() == { + ) + expected = { "options": {}, "secret_id": "the_id", "secrets_backend_type": "explicit-type", "role": None, } + assert spec.to_json() == expected def test_secret_spec_from_dict_implicit_type_with_options(default_secrets_backend): - assert SecretSpec.secret_spec_from_dict( + """Test parsing a dictionary inherits default type and retains options.""" + spec = SecretSpec.secret_spec_from_dict( {"id": "the_id", "options": {"a": "b"}}, None - ).to_json() == { + ) + expected = { "options": {"a": "b"}, "secret_id": "the_id", "secrets_backend_type": "some-default-backend-type", "role": None, } + assert spec.to_json() == expected def test_role_resolution_source_level_wins(default_secrets_backend): - assert SecretSpec.secret_spec_from_dict( + """Test that a role specified in the source dictionary overrides the decorator-level role.""" + spec = SecretSpec.secret_spec_from_dict( {"id": "the_id", "role": "source-level-role"}, "decorator-level-role", - ).to_json() == { + ) + expected = { "secret_id": "the_id", "secrets_backend_type": "some-default-backend-type", "role": "source-level-role", "options": {}, } + assert spec.to_json() == expected def test_role_resolution_falls_back_to_decorator_level(default_secrets_backend): - assert SecretSpec.secret_spec_from_dict( + """Test that omitting the role in the source dict falls back to the decorator-level role.""" + spec = SecretSpec.secret_spec_from_dict( {"id": "the_id"}, role="decorator-level-role", - ).to_json() == { + ) + expected = { "secret_id": "the_id", "secrets_backend_type": "some-default-backend-type", "role": "decorator-level-role", "options": {}, } + assert spec.to_json() == expected @pytest.mark.parametrize( @@ -104,16 +126,19 @@ def test_role_resolution_falls_back_to_decorator_level(default_secrets_backend): ids=["bad_type", "bad_id", "bad_options", "bad_role"], ) def test_secret_spec_from_dict_rejects_invalid(spec_dict, default_secrets_backend): + """Test that malformed dictionaries raise exceptions during parsing.""" with pytest.raises(MetaflowException): SecretSpec.secret_spec_from_dict(spec_dict, None) def test_secrets_provider_resolution_unknown_backend(): + """Test that resolving an unregistered backend type raises an exception.""" with pytest.raises(MetaflowException): - get_secrets_backend_provider(str(time.time())) + get_secrets_backend_provider("non_existent_backend_type_123") def test_validate_env_vars_across_secrets_rejects_overlap(): + """Test that multiple secret specs returning the same env var key raise a collision exception.""" all_secrets_env_vars = [ (SecretSpec.secret_spec_from_str("t.1", None), {"A": "a", "B": "b"}), (SecretSpec.secret_spec_from_str("t.2", None), {"B": "b", "C": "c"}), @@ -122,12 +147,15 @@ def test_validate_env_vars_across_secrets_rejects_overlap(): validate_env_vars_across_secrets(all_secrets_env_vars) -def test_validate_env_vars_vs_existing_env_rejects_collision(): - existing_os_env_k, existing_os_env_v = next(iter(os.environ.items())) +def test_validate_env_vars_vs_existing_env_rejects_collision(monkeypatch): + """Test that env vars provided by a secret cannot overwrite existing OS environment variables.""" + # Setup a deterministic existing environment variable + monkeypatch.setenv("EXISTING_MOCK_ENV_VAR", "some_value") + all_secrets_env_vars = [ ( SecretSpec.secret_spec_from_str("t.1", None), - {"A": "a", existing_os_env_k: existing_os_env_v}, + {"A": "a", "EXISTING_MOCK_ENV_VAR": "secret_value"}, ), ] with pytest.raises(MetaflowException): @@ -135,6 +163,7 @@ def test_validate_env_vars_vs_existing_env_rejects_collision(): def test_validate_env_vars_accepts_typical_keys(): + """Test that standard, well-formed bash-compatible environment variable keys pass validation.""" validate_env_vars( { "TYPICAL_KEY_1": "TYPICAL_VALUE_1", @@ -143,14 +172,20 @@ def test_validate_env_vars_accepts_typical_keys(): ) -@pytest.mark.parametrize("bad_key", [1, tuple(), b"old_school"]) +@pytest.mark.parametrize( + "bad_key", [1, tuple(), b"old_school"], ids=["int_key", "tuple_key", "bytes_key"] +) def test_validate_env_vars_rejects_mistyped_keys(bad_key): + """Test that non-string keys raise validation errors.""" with pytest.raises(MetaflowException): validate_env_vars({bad_key: "v"}) -@pytest.mark.parametrize("bad_value", [1, {}, b"old_school"]) +@pytest.mark.parametrize( + "bad_value", [1, {}, b"old_school"], ids=["int_value", "dict_value", "bytes_value"] +) def test_validate_env_vars_rejects_mistyped_values(bad_value): + """Test that non-string values raise validation errors.""" with pytest.raises(MetaflowException): validate_env_vars({"K": bad_value}) @@ -165,7 +200,16 @@ def test_validate_env_vars_rejects_mistyped_values(bad_value): "door-", "METAFLOW_SOMETHING_OR_OTHER", ], + ids=[ + "starts_with_number", + "contains_space", + "contains_symbol", + "contains_unicode", + "ends_with_dash", + "reserved_metaflow_prefix", + ], ) def test_validate_env_vars_rejects_weird_keys(weird_key): + """Test that invalid bash identifier formats or reserved prefixes are rejected.""" with pytest.raises(MetaflowException): validate_env_vars({weird_key: "v"}) diff --git a/test/unit/test_serializer_integration.py b/test/unit/test_serializer_integration.py index c8e09d46c00..d3885ffd0d5 100644 --- a/test/unit/test_serializer_integration.py +++ b/test/unit/test_serializer_integration.py @@ -9,12 +9,14 @@ """ import json -import os -import shutil -import tempfile +import threading import pytest +from metaflow.datastore.artifacts.diagnostic import ( + SerializerRecord, + SerializerState, +) from metaflow.datastore.artifacts.serializer import ( ArtifactSerializer, SerializationFormat, @@ -22,21 +24,51 @@ SerializedBlob, SerializerStore, ) +from metaflow.datastore.exceptions import ( + DataException, + UnpicklableArtifactException, +) +from metaflow.datastore.flow_datastore import FlowDataStore +from metaflow.exception import MetaflowException +from metaflow.plugins.datastores.local_storage import LocalStorage from metaflow.plugins.datastores.serializers.pickle_serializer import PickleSerializer + # --------------------------------------------------------------------------- -# Test PickleSerializer round-trip through save/load artifacts +# Fixtures # --------------------------------------------------------------------------- +@pytest.fixture +def isolated_store(): + """ + Fixture to isolate SerializerStore global state per test. + Prevents tests from poisoning the registry if they fail mid-execution. + """ + saved_all = dict(SerializerStore._all_serializers) + saved_active = set(SerializerStore._active_serializers) + saved_records = dict(SerializerStore._records) + saved_pending = dict(SerializerStore._pending_by_module) + saved_cache = SerializerStore._ordered_cache + + yield + + SerializerStore._all_serializers.clear() + SerializerStore._all_serializers.update(saved_all) + SerializerStore._active_serializers.clear() + SerializerStore._active_serializers.update(saved_active) + SerializerStore._records.clear() + SerializerStore._records.update(saved_records) + SerializerStore._pending_by_module.clear() + SerializerStore._pending_by_module.update(saved_pending) + SerializerStore._ordered_cache = saved_cache + + @pytest.fixture def task_datastore(tmp_path): """Create a minimal TaskDataStore wired to a local storage backend.""" - from metaflow.datastore.flow_datastore import FlowDataStore - from metaflow.plugins.datastores.local_storage import LocalStorage - - storage_root = str(tmp_path / "datastore") - os.makedirs(storage_root, exist_ok=True) + storage_root = tmp_path / "datastore" + storage_root.mkdir() flow_ds = FlowDataStore( flow_name="TestFlow", @@ -45,7 +77,7 @@ def task_datastore(tmp_path): event_logger=None, monitor=None, storage_impl=LocalStorage, - ds_root=storage_root, + ds_root=str(storage_root), ) task_ds = flow_ds.get_task_datastore( @@ -62,30 +94,35 @@ def task_datastore(tmp_path): return task_ds -def test_save_load_pickle_round_trip(task_datastore): - """Standard Python objects go through PickleSerializer and round-trip.""" - artifacts = [ +# --------------------------------------------------------------------------- +# Test PickleSerializer round-trip through save/load artifacts +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "name, value", + [ ("my_dict", {"key": "value", "nested": [1, 2, 3]}), ("my_int", 42), ("my_str", "hello world"), ("my_none", None), - ] - task_datastore.save_artifacts(iter(artifacts)) + ], + ids=["dict", "int", "str", "none"], +) +def test_save_load_pickle_round_trip(task_datastore, name, value): + """Standard Python objects go through PickleSerializer and round-trip.""" + task_datastore.save_artifacts(iter([(name, value)])) # Verify metadata - for name, _ in artifacts: - info = task_datastore._info[name] - assert "encoding" in info - assert info["encoding"] == "pickle-v4" - assert info["size"] > 0 - assert "type" in info + info = task_datastore._info[name] + assert "encoding" in info + assert info["encoding"] == "pickle-v4" + assert info["size"] > 0 + assert "type" in info # Load and verify - loaded = dict(task_datastore.load_artifacts([name for name, _ in artifacts])) - assert loaded["my_dict"] == {"key": "value", "nested": [1, 2, 3]} - assert loaded["my_int"] == 42 - assert loaded["my_str"] == "hello world" - assert loaded["my_none"] is None + loaded = dict(task_datastore.load_artifacts([name])) + assert loaded[name] == value def test_distinct_objects_on_load(task_datastore): @@ -107,14 +144,9 @@ def test_metadata_auto_populates_source_for_pickle(task_datastore): assert info.get("serializer_info", {}).get("source") == "metaflow" -def test_author_source_is_not_overridden(task_datastore): +def test_author_source_is_not_overridden(task_datastore, isolated_store): """A serializer that sets its own ``source`` in serializer_info should not have it overridden by the auto-injected bootstrap source.""" - from metaflow.datastore.artifacts import SerializationFormat, SerializerStore - from metaflow.datastore.artifacts.diagnostic import ( - SerializerRecord, - SerializerState, - ) class _ExplicitSourceSerializer(ArtifactSerializer): TYPE = "test_explicit_source" @@ -159,15 +191,9 @@ def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): task_datastore._serializers = [_ExplicitSourceSerializer, PickleSerializer] - try: - task_datastore.save_artifacts(iter([("hello", "world")])) - info = task_datastore._info["hello"] - assert info["serializer_info"]["source"] == "i-picked-this-myself" - finally: - SerializerStore._records.pop("test_explicit_source", None) - SerializerStore._active_serializers.discard(_ExplicitSourceSerializer) - SerializerStore._all_serializers.pop("test_explicit_source", None) - SerializerStore._ordered_cache = None + task_datastore.save_artifacts(iter([("hello", "world")])) + info = task_datastore._info["hello"] + assert info["serializer_info"]["source"] == "i-picked-this-myself" # --------------------------------------------------------------------------- @@ -175,10 +201,9 @@ def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): # --------------------------------------------------------------------------- -def test_custom_serializer_takes_priority(task_datastore): +def test_custom_serializer_takes_priority(task_datastore, isolated_store): """A custom serializer with lower PRIORITY claims matching objects over pickle.""" - # Define and register a custom serializer inside the test class _JsonStringSerializer(ArtifactSerializer): TYPE = "test_json_str" PRIORITY = 50 @@ -208,29 +233,23 @@ def serialize(cls, obj, format=SerializationFormat.STORAGE): def deserialize(cls, data, metadata=None, format="storage"): return json.loads(data[0].decode("utf-8")) - # Explicitly set serializers: custom first, then pickle fallback. - # Don't use get_ordered_serializers() to avoid pollution from other test files. task_datastore._serializers = [_JsonStringSerializer, PickleSerializer] - try: - task_datastore.save_artifacts(iter([("msg", "hello"), ("num", 42)])) + task_datastore.save_artifacts(iter([("msg", "hello"), ("num", 42)])) - # "msg" should use our custom serializer (str → _JsonStringSerializer) - msg_info = task_datastore._info["msg"] - assert msg_info["encoding"] == "test_json_str" - assert msg_info["serializer_info"] == {"format": "json-utf8"} + # "msg" should use our custom serializer (str → _JsonStringSerializer) + msg_info = task_datastore._info["msg"] + assert msg_info["encoding"] == "test_json_str" + assert msg_info["serializer_info"] == {"format": "json-utf8"} - # "num" should fall through to PickleSerializer (int → not claimed by custom) - num_info = task_datastore._info["num"] - assert num_info["encoding"] == "pickle-v4" + # "num" should fall through to PickleSerializer (int → not claimed by custom) + num_info = task_datastore._info["num"] + assert num_info["encoding"] == "pickle-v4" - # Both round-trip correctly - loaded = dict(task_datastore.load_artifacts(["msg", "num"])) - assert loaded["msg"] == "hello" - assert loaded["num"] == 42 - finally: - SerializerStore._all_serializers.pop("test_json_str", None) - SerializerStore._ordered_cache = None + # Both round-trip correctly + loaded = dict(task_datastore.load_artifacts(["msg", "num"])) + assert loaded["msg"] == "hello" + assert loaded["num"] == 42 # --------------------------------------------------------------------------- @@ -240,7 +259,6 @@ def deserialize(cls, data, metadata=None, format="storage"): def test_backward_compat_old_metadata(task_datastore): """Artifacts saved with old metadata format (no serializer_info) still load.""" - # Save normally first task_datastore.save_artifacts(iter([("old_artifact", {"a": 1})])) # Simulate old metadata format: no serializer_info, old encoding @@ -248,27 +266,22 @@ def test_backward_compat_old_metadata(task_datastore): "size": 100, "type": "", "encoding": "gzip+pickle-v4", - # no "serializer_info" key } - # Should still load via PickleSerializer (can_deserialize handles gzip+pickle-v4) + # Should still load via PickleSerializer loaded = dict(task_datastore.load_artifacts(["old_artifact"])) assert loaded["old_artifact"] == {"a": 1} def test_backward_compat_no_encoding(task_datastore): """Very old artifacts without encoding field default to gzip+pickle-v2.""" - # Save an artifact task_datastore.save_artifacts(iter([("ancient", 99)])) - # Simulate very old metadata: no encoding, no serializer_info task_datastore._info["ancient"] = { "size": 10, "type": "", - # no "encoding" key — defaults to gzip+pickle-v2 } - # Should still load loaded = dict(task_datastore.load_artifacts(["ancient"])) assert loaded["ancient"] == 99 @@ -289,7 +302,6 @@ def test_missing_info_with_object_uses_pickle_defaults(task_datastore): del task_datastore._info["present"] loaded = dict(task_datastore.load_artifacts(["present"])) - assert loaded["present"] == {"value": 1} @@ -311,13 +323,12 @@ def test_info_without_object_raises_key_error(task_datastore): # --------------------------------------------------------------------------- -def test_post_init_registration_reaches_existing_datastore(task_datastore): +def test_post_init_registration_reaches_existing_datastore( + task_datastore, isolated_store +): """A serializer registered AFTER the datastore was constructed must still be visible. Without the dynamic ``_serializers`` property, lazy imports - (e.g. ``import torch`` after ``TaskDataStore.__init__``) would be silently - ignored for that instance. - """ - # Drop the test override so the property falls back to the live registry. + would be silently ignored for that instance.""" task_datastore._serializers = None class _PostInitSerializer(ArtifactSerializer): @@ -340,15 +351,10 @@ def serialize(cls, obj, format="storage"): def deserialize(cls, data, metadata=None, format="storage"): raise NotImplementedError - # Dispatch reads from _active_serializers now (post-Phase-6). SerializerStore._active_serializers.add(_PostInitSerializer) SerializerStore._ordered_cache = None - try: - assert _PostInitSerializer in task_datastore._serializers - finally: - SerializerStore._all_serializers.pop("test_post_init_registration", None) - SerializerStore._active_serializers.discard(_PostInitSerializer) - SerializerStore._ordered_cache = None + + assert _PostInitSerializer in task_datastore._serializers # --------------------------------------------------------------------------- @@ -356,13 +362,11 @@ def deserialize(cls, data, metadata=None, format="storage"): # --------------------------------------------------------------------------- -def test_info_not_populated_when_serializer_returns_no_blobs(task_datastore): - """ - Regression for the "_info[name] poisoned on validation failure" bug: if a - serializer returns an empty blob list, ``save_artifacts`` must raise - without leaving partial metadata in ``_info``. - """ - from metaflow.datastore.exceptions import DataException +def test_info_not_populated_when_serializer_returns_no_blobs( + task_datastore, isolated_store +): + """If a serializer returns an empty blob list, ``save_artifacts`` must raise + without leaving partial metadata in ``_info``.""" class _EmptyBlobSerializer(ArtifactSerializer): TYPE = "test_empty_blob" @@ -378,28 +382,23 @@ def can_deserialize(cls, metadata): @classmethod def serialize(cls, obj, format="storage"): - return ( - [], - SerializationMetadata("x", 0, "test_empty_blob", {}), - ) + return ([], SerializationMetadata("x", 0, "test_empty_blob", {})) @classmethod def deserialize(cls, data, metadata=None, format="storage"): raise NotImplementedError task_datastore._serializers = [_EmptyBlobSerializer, PickleSerializer] - try: - with pytest.raises(DataException, match="returned no blobs"): - task_datastore.save_artifacts(iter([("bad", object())])) - assert "bad" not in task_datastore._info - finally: - SerializerStore._all_serializers.pop("test_empty_blob", None) - SerializerStore._ordered_cache = None + + with pytest.raises(DataException, match="returned no blobs"): + task_datastore.save_artifacts(iter([("bad", object())])) + assert "bad" not in task_datastore._info -def test_info_not_populated_when_serializer_returns_multi_blob(task_datastore): +def test_info_not_populated_when_serializer_returns_multi_blob( + task_datastore, isolated_store +): """Same guarantee as above for the multi-blob rejection path.""" - from metaflow.datastore.exceptions import DataException class _MultiBlobSerializer(ArtifactSerializer): TYPE = "test_multi_blob" @@ -425,31 +424,20 @@ def deserialize(cls, data, metadata=None, format="storage"): raise NotImplementedError task_datastore._serializers = [_MultiBlobSerializer, PickleSerializer] - try: - with pytest.raises(DataException, match="single-blob serializers"): - task_datastore.save_artifacts(iter([("bad", object())])) - assert "bad" not in task_datastore._info - finally: - SerializerStore._all_serializers.pop("test_multi_blob", None) - SerializerStore._ordered_cache = None + + with pytest.raises(DataException, match="single-blob serializers"): + task_datastore.save_artifacts(iter([("bad", object())])) + assert "bad" not in task_datastore._info # --------------------------------------------------------------------------- -# Exception flow: PickleSerializer owns its own UnpicklableArtifactException; -# extension MetaflowExceptions pass through; everything else gets wrapped. +# Exception flow mapping # --------------------------------------------------------------------------- def test_pickle_serializer_raises_unpicklable_with_artifact_name(task_datastore): - """PickleSerializer raises ``UnpicklableArtifactException`` from inside - ``serialize()`` (no name); ``save_artifacts`` re-raises it with the - artifact name attached so users see the original "named X" message.""" - import threading - - from metaflow.datastore.exceptions import UnpicklableArtifactException - - # ``threading.Lock`` raises ``TypeError`` from ``pickle.dumps``, which is - # the path that turns into ``UnpicklableArtifactException``. + """PickleSerializer raises ``UnpicklableArtifactException``; + ``save_artifacts`` re-raises it with the artifact name attached.""" unpicklable = threading.Lock() with pytest.raises(UnpicklableArtifactException, match='named "bad_one"'): @@ -457,11 +445,9 @@ def test_pickle_serializer_raises_unpicklable_with_artifact_name(task_datastore) assert "bad_one" not in task_datastore._info -def test_extension_metaflow_exception_passes_through(task_datastore): +def test_extension_metaflow_exception_passes_through(task_datastore, isolated_store): """An extension serializer raising a ``MetaflowException`` subclass must - propagate as-is — wrapping it in ``DataException`` would obscure the - original headline/message that is already user-facing.""" - from metaflow.exception import MetaflowException + propagate as-is.""" class _ExtensionError(MetaflowException): headline = "Extension validation failed" @@ -487,22 +473,17 @@ def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): raise NotImplementedError task_datastore._serializers = [_RaisingSerializer, PickleSerializer] - try: - with pytest.raises(_ExtensionError, match="schema mismatch on field 'foo'"): - task_datastore.save_artifacts(iter([("x", object())])) - assert "x" not in task_datastore._info - finally: - SerializerStore._all_serializers.pop("test_passthrough_ser", None) - SerializerStore._ordered_cache = None + + with pytest.raises(_ExtensionError, match="schema mismatch on field 'foo'"): + task_datastore.save_artifacts(iter([("x", object())])) + assert "x" not in task_datastore._info -def test_extension_type_error_is_not_mislabeled_unpicklable(task_datastore): +def test_extension_type_error_is_not_mislabeled_unpicklable( + task_datastore, isolated_store +): """A non-pickle serializer raising ``TypeError`` must NOT be reported as - ``UnpicklableArtifactException`` — that wrapper is pickle-specific now.""" - from metaflow.datastore.exceptions import ( - DataException, - UnpicklableArtifactException, - ) + ``UnpicklableArtifactException``.""" class _TypeErrorSerializer(ArtifactSerializer): TYPE = "test_typeerror_ser" @@ -525,31 +506,23 @@ def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): raise NotImplementedError task_datastore._serializers = [_TypeErrorSerializer, PickleSerializer] - try: - with pytest.raises(DataException, match="_TypeErrorSerializer") as exc_info: - task_datastore.save_artifacts(iter([("x", object())])) - # Critically: NOT UnpicklableArtifactException — that name would lie - # to users about which serializer rejected the object. - assert not isinstance(exc_info.value, UnpicklableArtifactException) - assert "x" not in task_datastore._info - finally: - SerializerStore._all_serializers.pop("test_typeerror_ser", None) - SerializerStore._ordered_cache = None - - -def test_can_serialize_exception_falls_through_to_pickle(task_datastore): + + with pytest.raises(DataException, match="_TypeErrorSerializer") as exc_info: + task_datastore.save_artifacts(iter([("x", object())])) + + assert not isinstance(exc_info.value, UnpicklableArtifactException) + assert "x" not in task_datastore._info + + +def test_can_serialize_exception_falls_through_to_pickle( + task_datastore, isolated_store +): """A buggy custom serializer's can_serialize exception must NOT crash - save_artifacts. The buggy serializer is skipped; pickle fallback handles - the artifact; dispatch_error_count is incremented.""" - from metaflow.datastore.artifacts import SerializationFormat - from metaflow.datastore.artifacts.diagnostic import ( - SerializerRecord, - SerializerState, - ) + save_artifacts. Pickle fallback handles it; dispatch_error_count is incremented.""" class _BuggyCanSerialize(ArtifactSerializer): TYPE = "test_buggy_cs" - PRIORITY = 1 # tried first + PRIORITY = 1 @classmethod def can_serialize(cls, obj): @@ -567,7 +540,6 @@ def serialize(cls, obj, format=SerializationFormat.STORAGE): def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): raise NotImplementedError - # Seed a diagnostic record so dispatch_error_count has somewhere to go. rec = SerializerRecord( name="test_buggy_cs", class_path="test.inline.BuggyCanSerialize", @@ -580,27 +552,15 @@ def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): task_datastore._serializers = [_BuggyCanSerialize, PickleSerializer] - try: - # Must NOT raise. - task_datastore.save_artifacts(iter([("x", 42)])) - assert task_datastore._info["x"]["encoding"] == "pickle-v4" - assert rec.dispatch_error_count == 1 - assert rec.last_error is not None - assert "RuntimeError" in rec.last_error - finally: - SerializerStore._all_serializers.pop("test_buggy_cs", None) - SerializerStore._active_serializers.discard(_BuggyCanSerialize) - SerializerStore._records.pop("test_buggy_cs", None) - SerializerStore._ordered_cache = None - - -def test_can_deserialize_exception_falls_through(task_datastore): + # Must NOT raise. + task_datastore.save_artifacts(iter([("x", 42)])) + assert task_datastore._info["x"]["encoding"] == "pickle-v4" + assert rec.dispatch_error_count == 1 + assert "RuntimeError" in rec.last_error + + +def test_can_deserialize_exception_falls_through(task_datastore, isolated_store): """Same guarantee for can_deserialize during load_artifacts.""" - from metaflow.datastore.artifacts import SerializationFormat - from metaflow.datastore.artifacts.diagnostic import ( - SerializerRecord, - SerializerState, - ) class _BuggyCanDeserialize(ArtifactSerializer): TYPE = "test_buggy_cd" @@ -632,35 +592,22 @@ def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): SerializerStore._records["test_buggy_cd"] = rec SerializerStore._active_serializers.add(_BuggyCanDeserialize) - # First save an artifact normally via pickle so load has something to load. + # First save an artifact normally via pickle task_datastore._serializers = [PickleSerializer] task_datastore.save_artifacts(iter([("y", "hello")])) - # Now install the buggy serializer and try to load — buggy can_deserialize - # should be skipped and pickle should take over. + # Now install the buggy serializer and try to load task_datastore._serializers = [_BuggyCanDeserialize, PickleSerializer] - try: - loaded = dict(task_datastore.load_artifacts(["y"])) - assert loaded["y"] == "hello" - assert rec.dispatch_error_count == 1 - assert rec.last_error is not None - assert "RuntimeError" in rec.last_error - finally: - SerializerStore._all_serializers.pop("test_buggy_cd", None) - SerializerStore._active_serializers.discard(_BuggyCanDeserialize) - SerializerStore._records.pop("test_buggy_cd", None) - SerializerStore._ordered_cache = None + loaded = dict(task_datastore.load_artifacts(["y"])) + assert loaded["y"] == "hello" + assert rec.dispatch_error_count == 1 + assert "RuntimeError" in rec.last_error -def test_subclass_lazy_import_stashes_on_child_not_parent(): +def test_subclass_lazy_import_stashes_on_child_not_parent(isolated_store): """lazy_import on a subclass should set attrs on the subclass, not the parent. Parent and children should each have their own _lazy_imported_names set.""" - from metaflow.datastore.artifacts import ( - ArtifactSerializer, - SerializationFormat, - SerializerStore, - ) class _ParentSer(ArtifactSerializer): TYPE = "test_inherit_parent" @@ -692,29 +639,16 @@ class _ChildSer(_ParentSer): def setup_imports(cls, context=None): cls.lazy_import("sys") - try: - _ParentSer.setup_imports() - _ChildSer.setup_imports() - import json as _json - import sys as _sys - - # Parent has json; child has sys - assert _ParentSer.json is _json - assert _ChildSer.sys is _sys - - # Each class should have its OWN _lazy_imported_names set - # (not a shared inherited one) - parent_names = _ParentSer.__dict__.get("_lazy_imported_names", set()) - child_names = _ChildSer.__dict__.get("_lazy_imported_names", set()) - assert parent_names == {"json"} - assert child_names == {"sys"} - finally: - for t in ("test_inherit_parent", "test_inherit_child"): - SerializerStore._all_serializers.pop(t, None) - SerializerStore._ordered_cache = None - for c, attr in ((_ParentSer, "json"), (_ChildSer, "sys")): - if attr in c.__dict__: - delattr(c, attr) - for c in (_ParentSer, _ChildSer): - if "_lazy_imported_names" in c.__dict__: - delattr(c, "_lazy_imported_names") + _ParentSer.setup_imports() + _ChildSer.setup_imports() + + import json as _json + import sys as _sys + + assert _ParentSer.json is _json + assert _ChildSer.sys is _sys + + parent_names = _ParentSer.__dict__.get("_lazy_imported_names", set()) + child_names = _ChildSer.__dict__.get("_lazy_imported_names", set()) + assert parent_names == {"json"} + assert child_names == {"sys"} diff --git a/test/unit/test_serializer_lifecycle.py b/test/unit/test_serializer_lifecycle.py index d201f06587a..4e7fcef4e0b 100644 --- a/test/unit/test_serializer_lifecycle.py +++ b/test/unit/test_serializer_lifecycle.py @@ -9,6 +9,68 @@ SerializerRecord, SerializerState, ) +from metaflow.datastore.artifacts.serializer import ( + ArtifactSerializer, + SerializerStore, + SerializationFormat, +) +from metaflow import metaflow_config + + +# --- Fixtures --- + + +@pytest.fixture +def isolated_store(): + """ + Fixture to isolate SerializerStore global state per test. + Provides a clean slate and restores original state afterward. + """ + # Snapshot original state + saved_all = dict(SerializerStore._all_serializers) + saved_active = set(SerializerStore._active_serializers) + saved_records = dict(SerializerStore._records) + saved_pending = dict(SerializerStore._pending_by_module) + saved_cache = SerializerStore._ordered_cache + + # Clear state for test isolation + SerializerStore._all_serializers.clear() + SerializerStore._active_serializers.clear() + SerializerStore._records.clear() + SerializerStore._pending_by_module.clear() + SerializerStore._ordered_cache = None + + yield + + # Restore original state + SerializerStore._all_serializers.clear() + SerializerStore._all_serializers.update(saved_all) + SerializerStore._active_serializers.clear() + SerializerStore._active_serializers.update(saved_active) + SerializerStore._records.clear() + SerializerStore._records.update(saved_records) + SerializerStore._pending_by_module.clear() + SerializerStore._pending_by_module.update(saved_pending) + SerializerStore._ordered_cache = saved_cache + + +@pytest.fixture +def make_serializer_module(monkeypatch): + """ + Factory fixture to dynamically create a module from source code + and register it. Automatically cleans up sys.modules via monkeypatch. + """ + + def _make(module_name, source_code): + mod = types.ModuleType(module_name) + exec(source_code, mod.__dict__) + monkeypatch.setitem(sys.modules, module_name, mod) + return mod + + return _make + + +# --- Tests --- def test_serializer_record_default_fields(): @@ -45,14 +107,7 @@ def test_serializer_record_as_dict(): assert d["import_trigger"] == "eager" -from metaflow.datastore.artifacts.serializer import ( - ArtifactSerializer, - SerializerStore, - SerializationFormat, -) - - -def test_store_separates_all_vs_active(): +def test_store_separates_all_vs_active(isolated_store): """_all_serializers is the known-classes index; _active_serializers is dispatch pool.""" class _Known(ArtifactSerializer): @@ -74,23 +129,18 @@ def serialize(cls, obj, format=SerializationFormat.STORAGE): def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): raise NotImplementedError - try: - # Metaclass registers on class body execution - assert _Known in SerializerStore._all_serializers.values() - # But without bootstrap, it is NOT in the dispatch pool - assert _Known not in SerializerStore._active_serializers - # _records is an empty dict initially (for entries from DESC tuples) - assert isinstance(SerializerStore._records, dict) - finally: - SerializerStore._all_serializers.pop("test_known_not_active", None) - SerializerStore._active_serializers.discard(_Known) - SerializerStore._ordered_cache = None + # Metaclass registers on class body execution + assert _Known in SerializerStore._all_serializers.values() + # But without bootstrap, it is NOT in the dispatch pool + assert _Known not in SerializerStore._active_serializers + # _records is an empty dict initially (for entries from DESC tuples) + assert isinstance(SerializerStore._records, dict) -def test_bootstrap_activates_dependency_free_serializer(): +def test_bootstrap_activates_dependency_free_serializer( + isolated_store, make_serializer_module +): """bootstrap_entries with an in-process serializer moves it to ACTIVE.""" - - mod = types.ModuleType("_test_bootstrap_mod") source = """ from metaflow.datastore.artifacts import ArtifactSerializer, SerializationFormat @@ -103,40 +153,26 @@ def can_serialize(cls, obj): return False @classmethod def can_deserialize(cls, metadata): return False @classmethod - def serialize(cls, obj, format=SerializationFormat.STORAGE): - raise NotImplementedError + def serialize(cls, obj, format=SerializationFormat.STORAGE): raise NotImplementedError @classmethod - def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): - raise NotImplementedError + def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): raise NotImplementedError """ - exec(source, mod.__dict__) - sys.modules["_test_bootstrap_mod"] = mod - try: - SerializerStore.bootstrap_entries( - [ - ("test_bootstrap_probe", "_test_bootstrap_mod._BootProbe"), - ] - ) - rec = SerializerStore._records["test_bootstrap_probe"] - assert rec.state == SerializerState.ACTIVE - assert rec.priority == 60 - assert rec.type == "test_bootstrap_probe" - assert rec.import_trigger == "eager" - assert mod._BootProbe in SerializerStore._active_serializers - finally: - SerializerStore._all_serializers.pop("test_bootstrap_probe", None) - SerializerStore._active_serializers.discard(mod._BootProbe) - SerializerStore._records.pop("test_bootstrap_probe", None) - SerializerStore._ordered_cache = None - del sys.modules["_test_bootstrap_mod"] - - -def test_bootstrap_rejects_name_type_mismatch(): - """Tuple first element MUST equal class.TYPE.""" + mod = make_serializer_module("_test_bootstrap_mod", source) + SerializerStore.bootstrap_entries( + [("test_bootstrap_probe", "_test_bootstrap_mod._BootProbe")] + ) + + rec = SerializerStore._records["test_bootstrap_probe"] + assert rec.state == SerializerState.ACTIVE + assert rec.priority == 60 + assert rec.type == "test_bootstrap_probe" + assert rec.import_trigger == "eager" + assert mod._BootProbe in SerializerStore._active_serializers - mod = types.ModuleType("_test_mismatch_mod") - exec( - """ + +def test_bootstrap_rejects_name_type_mismatch(isolated_store, make_serializer_module): + """Tuple first element MUST equal class.TYPE.""" + source = """ from metaflow.datastore.artifacts import ArtifactSerializer, SerializationFormat class _Mismatch(ArtifactSerializer): @@ -146,161 +182,113 @@ def can_serialize(cls, obj): return False @classmethod def can_deserialize(cls, metadata): return False @classmethod - def serialize(cls, obj, format=SerializationFormat.STORAGE): - raise NotImplementedError + def serialize(cls, obj, format=SerializationFormat.STORAGE): raise NotImplementedError @classmethod - def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): - raise NotImplementedError -""", - mod.__dict__, + def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): raise NotImplementedError +""" + make_serializer_module("_test_mismatch_mod", source) + SerializerStore.bootstrap_entries( + [("declared_name", "_test_mismatch_mod._Mismatch")] ) - sys.modules["_test_mismatch_mod"] = mod - try: - SerializerStore.bootstrap_entries( - [ - ("declared_name", "_test_mismatch_mod._Mismatch"), - ] - ) - rec = SerializerStore._records["declared_name"] - assert rec.state == SerializerState.BROKEN - assert "tuple name" in rec.last_error - assert "actual_type" in rec.last_error - finally: - SerializerStore._all_serializers.pop("actual_type", None) - SerializerStore._records.pop("declared_name", None) - del sys.modules["_test_mismatch_mod"] - - -def test_bootstrap_missing_module_parks_entry(): - """ModuleNotFoundError during import_module moves entry to PENDING_ON_IMPORTS.""" + rec = SerializerStore._records["declared_name"] + assert rec.state == SerializerState.BROKEN + assert "tuple name" in rec.last_error + assert "actual_type" in rec.last_error + + +def test_bootstrap_missing_module_parks_entry(isolated_store): + """ModuleNotFoundError during import_module moves entry to PENDING_ON_IMPORTS.""" SerializerStore.bootstrap_entries( - [ - ("test_absent", "_never_created_module._Absent"), - ] + [("test_absent", "_never_created_module._Absent")] ) - try: - rec = SerializerStore._records["test_absent"] - assert rec.state == SerializerState.PENDING_ON_IMPORTS - assert "_never_created_module" in rec.awaiting_modules - # _pending_by_module should track it - assert "test_absent" in SerializerStore._pending_by_module.get( - "_never_created_module", [] - ) - finally: - SerializerStore._records.pop("test_absent", None) - SerializerStore._pending_by_module.pop("_never_created_module", None) - - -def test_bootstrap_missing_class_in_module_broken(): + + rec = SerializerStore._records["test_absent"] + assert rec.state == SerializerState.PENDING_ON_IMPORTS + assert "_never_created_module" in rec.awaiting_modules + assert "test_absent" in SerializerStore._pending_by_module.get( + "_never_created_module", [] + ) + + +def test_bootstrap_missing_class_in_module_broken( + isolated_store, make_serializer_module +): """getattr failure after successful module import moves to BROKEN.""" - mod = types.ModuleType("_test_no_class_mod") - # Intentionally empty — no class inside - sys.modules["_test_no_class_mod"] = mod - try: - SerializerStore.bootstrap_entries( - [ - ("test_no_class", "_test_no_class_mod._Missing"), - ] - ) - rec = SerializerStore._records["test_no_class"] - assert rec.state == SerializerState.BROKEN - assert "class" in rec.last_error.lower() - assert "_Missing" in rec.last_error - finally: - SerializerStore._records.pop("test_no_class", None) - del sys.modules["_test_no_class_mod"] - - -def test_bootstrap_setup_imports_missing_dep_parks_entry(): + # Intentionally empty module — no class inside + make_serializer_module("_test_no_class_mod", "") + SerializerStore.bootstrap_entries( + [("test_no_class", "_test_no_class_mod._Missing")] + ) + + rec = SerializerStore._records["test_no_class"] + assert rec.state == SerializerState.BROKEN + assert "class" in rec.last_error.lower() + assert "_Missing" in rec.last_error + + +def test_bootstrap_setup_imports_missing_dep_parks_entry( + isolated_store, make_serializer_module +): """ImportError inside setup_imports parks on the missing module name.""" - mod = types.ModuleType("_test_setup_missing_mod") - exec( - """ + source = """ from metaflow.datastore.artifacts import ArtifactSerializer, SerializationFormat class _WantsMissing(ArtifactSerializer): TYPE = "test_setup_wants_missing" @classmethod - def setup_imports(cls, context=None): - cls.lazy_import("absent_at_setup_time_xyz") + def setup_imports(cls, context=None): cls.lazy_import("absent_at_setup_time_xyz") @classmethod def can_serialize(cls, obj): return False @classmethod def can_deserialize(cls, metadata): return False @classmethod - def serialize(cls, obj, format=SerializationFormat.STORAGE): - raise NotImplementedError + def serialize(cls, obj, format=SerializationFormat.STORAGE): raise NotImplementedError @classmethod - def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): - raise NotImplementedError -""", - mod.__dict__, + def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): raise NotImplementedError +""" + make_serializer_module("_test_setup_missing_mod", source) + SerializerStore.bootstrap_entries( + [("test_setup_wants_missing", "_test_setup_missing_mod._WantsMissing")] ) - sys.modules["_test_setup_missing_mod"] = mod - try: - SerializerStore.bootstrap_entries( - [ - ("test_setup_wants_missing", "_test_setup_missing_mod._WantsMissing"), - ] - ) - rec = SerializerStore._records["test_setup_wants_missing"] - assert rec.state == SerializerState.PENDING_ON_IMPORTS - assert "absent_at_setup_time_xyz" in rec.awaiting_modules - finally: - SerializerStore._all_serializers.pop("test_setup_wants_missing", None) - SerializerStore._records.pop("test_setup_wants_missing", None) - SerializerStore._pending_by_module.pop("absent_at_setup_time_xyz", None) - del sys.modules["_test_setup_missing_mod"] - - -def test_bootstrap_setup_imports_other_exception_broken(): + + rec = SerializerStore._records["test_setup_wants_missing"] + assert rec.state == SerializerState.PENDING_ON_IMPORTS + assert "absent_at_setup_time_xyz" in rec.awaiting_modules + + +def test_bootstrap_setup_imports_other_exception_broken( + isolated_store, make_serializer_module +): """Non-ImportError from setup_imports moves entry to BROKEN.""" - mod = types.ModuleType("_test_setup_boom_mod") - exec( - """ + source = """ from metaflow.datastore.artifacts import ArtifactSerializer, SerializationFormat class _Boom(ArtifactSerializer): TYPE = "test_boom" @classmethod - def setup_imports(cls, context=None): - raise RuntimeError("explicit boom from test") + def setup_imports(cls, context=None): raise RuntimeError("explicit boom from test") @classmethod def can_serialize(cls, obj): return False @classmethod def can_deserialize(cls, metadata): return False @classmethod - def serialize(cls, obj, format=SerializationFormat.STORAGE): - raise NotImplementedError + def serialize(cls, obj, format=SerializationFormat.STORAGE): raise NotImplementedError @classmethod - def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): - raise NotImplementedError -""", - mod.__dict__, - ) - sys.modules["_test_setup_boom_mod"] = mod - try: - SerializerStore.bootstrap_entries( - [ - ("test_boom", "_test_setup_boom_mod._Boom"), - ] - ) - rec = SerializerStore._records["test_boom"] - assert rec.state == SerializerState.BROKEN - assert "RuntimeError" in rec.last_error - assert "explicit boom" in rec.last_error - finally: - SerializerStore._all_serializers.pop("test_boom", None) - SerializerStore._records.pop("test_boom", None) - del sys.modules["_test_setup_boom_mod"] - - -def test_bootstrap_disabled_toggle(): + def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): raise NotImplementedError +""" + make_serializer_module("_test_setup_boom_mod", source) + SerializerStore.bootstrap_entries([("test_boom", "_test_setup_boom_mod._Boom")]) + + rec = SerializerStore._records["test_boom"] + assert rec.state == SerializerState.BROKEN + assert "RuntimeError" in rec.last_error + assert "explicit boom" in rec.last_error + + +def test_bootstrap_disabled_toggle(isolated_store, make_serializer_module): """Entries whose name is in disabled_names land in DISABLED state.""" - mod = types.ModuleType("_test_disable_mod") - exec( - """ + source = """ from metaflow.datastore.artifacts import ArtifactSerializer, SerializationFormat class _DisableMe(ArtifactSerializer): @@ -310,280 +298,185 @@ def can_serialize(cls, obj): return False @classmethod def can_deserialize(cls, metadata): return False @classmethod - def serialize(cls, obj, format=SerializationFormat.STORAGE): - raise NotImplementedError + def serialize(cls, obj, format=SerializationFormat.STORAGE): raise NotImplementedError @classmethod - def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): - raise NotImplementedError -""", - mod.__dict__, + def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): raise NotImplementedError +""" + mod = make_serializer_module("_test_disable_mod", source) + SerializerStore.bootstrap_entries( + [("test_disable", "_test_disable_mod._DisableMe")], + disabled_names={"test_disable"}, ) - sys.modules["_test_disable_mod"] = mod - try: - SerializerStore.bootstrap_entries( - [("test_disable", "_test_disable_mod._DisableMe")], - disabled_names={"test_disable"}, - ) - rec = SerializerStore._records["test_disable"] - assert rec.state == SerializerState.DISABLED - # Class should NOT be in active pool - assert mod._DisableMe not in SerializerStore._active_serializers - finally: - SerializerStore._all_serializers.pop("test_disable", None) - SerializerStore._active_serializers.discard(mod._DisableMe) - SerializerStore._records.pop("test_disable", None) - del sys.modules["_test_disable_mod"] - - -def test_retry_activates_pending_record_on_module_import(): + + rec = SerializerStore._records["test_disable"] + assert rec.state == SerializerState.DISABLED + assert mod._DisableMe not in SerializerStore._active_serializers + + +def test_retry_activates_pending_record_on_module_import( + isolated_store, make_serializer_module, monkeypatch +): """When a pending record's awaited module imports, the record retries to ACTIVE.""" - ser_mod = types.ModuleType("_test_retry_ser_mod") - exec( - """ + source = """ from metaflow.datastore.artifacts import ArtifactSerializer, SerializationFormat class _Pending(ArtifactSerializer): TYPE = "test_retry_pending" @classmethod - def setup_imports(cls, context=None): - cls.lazy_import("retry_dep_mod_name") + def setup_imports(cls, context=None): cls.lazy_import("retry_dep_mod_name") @classmethod def can_serialize(cls, obj): return False @classmethod def can_deserialize(cls, metadata): return False @classmethod - def serialize(cls, obj, format=SerializationFormat.STORAGE): - raise NotImplementedError + def serialize(cls, obj, format=SerializationFormat.STORAGE): raise NotImplementedError @classmethod - def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): - raise NotImplementedError -""", - ser_mod.__dict__, + def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): raise NotImplementedError +""" + ser_mod = make_serializer_module("_test_retry_ser_mod", source) + monkeypatch.delitem(sys.modules, "retry_dep_mod_name", raising=False) + + SerializerStore.bootstrap_entries( + [("test_retry_pending", "_test_retry_ser_mod._Pending")] + ) + + rec = SerializerStore._records["test_retry_pending"] + assert rec.state == SerializerState.PENDING_ON_IMPORTS + assert "retry_dep_mod_name" in rec.awaiting_modules + assert ( + "test_retry_pending" in SerializerStore._pending_by_module["retry_dep_mod_name"] + ) + + # Simulate the dep becoming available, then fire the retry hook. + dep_mod = types.ModuleType("retry_dep_mod_name") + monkeypatch.setitem(sys.modules, "retry_dep_mod_name", dep_mod) + SerializerStore._on_module_imported("retry_dep_mod_name", dep_mod) + + assert rec.state == SerializerState.ACTIVE + assert rec.import_trigger == "hook" + assert ser_mod._Pending in SerializerStore._active_serializers + assert "test_retry_pending" not in SerializerStore._pending_by_module.get( + "retry_dep_mod_name", [] ) - sys.modules["_test_retry_ser_mod"] = ser_mod - sys.modules.pop("retry_dep_mod_name", None) - - try: - SerializerStore.bootstrap_entries( - [ - ("test_retry_pending", "_test_retry_ser_mod._Pending"), - ] - ) - rec = SerializerStore._records["test_retry_pending"] - assert rec.state == SerializerState.PENDING_ON_IMPORTS - assert "retry_dep_mod_name" in rec.awaiting_modules - assert ( - "test_retry_pending" - in SerializerStore._pending_by_module["retry_dep_mod_name"] - ) - - # Simulate the dep becoming available, then fire the retry hook. - dep_mod = types.ModuleType("retry_dep_mod_name") - sys.modules["retry_dep_mod_name"] = dep_mod - SerializerStore._on_module_imported("retry_dep_mod_name", dep_mod) - - assert rec.state == SerializerState.ACTIVE - assert rec.import_trigger == "hook" - assert ser_mod._Pending in SerializerStore._active_serializers - # _pending_by_module should no longer list this record under that module - assert "test_retry_pending" not in SerializerStore._pending_by_module.get( - "retry_dep_mod_name", [] - ) - finally: - SerializerStore._all_serializers.pop("test_retry_pending", None) - SerializerStore._active_serializers.discard(ser_mod._Pending) - SerializerStore._records.pop("test_retry_pending", None) - SerializerStore._pending_by_module.pop("retry_dep_mod_name", None) - SerializerStore._ordered_cache = None - sys.modules.pop("_test_retry_ser_mod", None) - sys.modules.pop("retry_dep_mod_name", None) - - -def test_retry_hits_loop_guard_after_repeated_failure(): + + +def test_retry_hits_loop_guard_after_repeated_failure( + isolated_store, make_serializer_module, monkeypatch +): """Calling _on_module_imported when setup_imports still fails on the same module name should transition to BROKEN via the loop guard.""" - ser_mod = types.ModuleType("_test_loop_ser_mod") - exec( - """ + source = """ from metaflow.datastore.artifacts import ArtifactSerializer, SerializationFormat class _Loopy(ArtifactSerializer): TYPE = "test_loopy" @classmethod - def setup_imports(cls, context=None): - # Always raises on the same module name even if retried. - cls.lazy_import("never_resolves_mod") + def setup_imports(cls, context=None): cls.lazy_import("never_resolves_mod") @classmethod def can_serialize(cls, obj): return False @classmethod def can_deserialize(cls, metadata): return False @classmethod - def serialize(cls, obj, format=SerializationFormat.STORAGE): - raise NotImplementedError + def serialize(cls, obj, format=SerializationFormat.STORAGE): raise NotImplementedError @classmethod - def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): - raise NotImplementedError -""", - ser_mod.__dict__, - ) - sys.modules["_test_loop_ser_mod"] = ser_mod - sys.modules.pop("never_resolves_mod", None) - - try: - SerializerStore.bootstrap_entries( - [ - ("test_loopy", "_test_loop_ser_mod._Loopy"), - ] - ) - rec = SerializerStore._records["test_loopy"] - assert rec.state == SerializerState.PENDING_ON_IMPORTS - - # Fake the dep appearing but NOT actually installing it — the retry - # will re-run setup_imports, which will ImportError again on the same name. - dep_mod = types.ModuleType("never_resolves_mod") - # We deliberately DO NOT put dep_mod in sys.modules, so lazy_import - # inside setup_imports will raise ModuleNotFoundError again. - SerializerStore._on_module_imported("never_resolves_mod", dep_mod) - - assert rec.state == SerializerState.BROKEN - assert "repeated" in rec.last_error.lower() - finally: - SerializerStore._all_serializers.pop("test_loopy", None) - SerializerStore._records.pop("test_loopy", None) - SerializerStore._pending_by_module.pop("never_resolves_mod", None) - sys.modules.pop("_test_loop_ser_mod", None) - sys.modules.pop("never_resolves_mod", None) - - -def test_retry_fires_via_real_import_hook(tmp_path, monkeypatch): + def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): raise NotImplementedError +""" + make_serializer_module("_test_loop_ser_mod", source) + monkeypatch.delitem(sys.modules, "never_resolves_mod", raising=False) + + SerializerStore.bootstrap_entries([("test_loopy", "_test_loop_ser_mod._Loopy")]) + + rec = SerializerStore._records["test_loopy"] + assert rec.state == SerializerState.PENDING_ON_IMPORTS + + # Fake the dep appearing but DO NOT actually put it in sys.modules, + # so lazy_import will raise ModuleNotFoundError again. + dep_mod = types.ModuleType("never_resolves_mod") + SerializerStore._on_module_imported("never_resolves_mod", dep_mod) + + assert rec.state == SerializerState.BROKEN + assert "repeated" in rec.last_error.lower() + + +def test_retry_fires_via_real_import_hook( + tmp_path, monkeypatch, isolated_store, make_serializer_module +): """End-to-end: park a serializer on a missing module, install the hook, actually import the module, verify the serializer activates.""" + # Setup test package directory to act as an importable module pkg_dir = tmp_path / "fixture_retry_dep" pkg_dir.mkdir() (pkg_dir / "__init__.py").write_text("VALUE = 42\n") - # NOTE: we prepend syspath AFTER bootstrap_entries below so the first - # lazy_import() fails and the record parks on the missing module. - ser_mod = types.ModuleType("_test_e2e_ser_mod") - exec( - """ + source = """ from metaflow.datastore.artifacts import ArtifactSerializer, SerializationFormat class _E2E(ArtifactSerializer): TYPE = "test_e2e_retry" @classmethod - def setup_imports(cls, context=None): - cls.lazy_import("fixture_retry_dep") + def setup_imports(cls, context=None): cls.lazy_import("fixture_retry_dep") @classmethod def can_serialize(cls, obj): return False @classmethod def can_deserialize(cls, metadata): return False @classmethod - def serialize(cls, obj, format=SerializationFormat.STORAGE): - raise NotImplementedError + def serialize(cls, obj, format=SerializationFormat.STORAGE): raise NotImplementedError @classmethod - def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): - raise NotImplementedError -""", - ser_mod.__dict__, - ) - sys.modules["_test_e2e_ser_mod"] = ser_mod - # Make sure the dep module isn't pre-imported - sys.modules.pop("fixture_retry_dep", None) - - try: - SerializerStore.bootstrap_entries( - [ - ("test_e2e_retry", "_test_e2e_ser_mod._E2E"), - ] - ) - rec = SerializerStore._records["test_e2e_retry"] - assert rec.state == SerializerState.PENDING_ON_IMPORTS - - # Now make the dep discoverable on sys.path and actually import it. - monkeypatch.syspath_prepend(str(tmp_path)) - import fixture_retry_dep # noqa: F401 - - # After real import, the hook chain should have fired. - assert rec.state == SerializerState.ACTIVE - assert rec.import_trigger == "hook" - assert ser_mod._E2E in SerializerStore._active_serializers - finally: - SerializerStore._all_serializers.pop("test_e2e_retry", None) - SerializerStore._active_serializers.discard(ser_mod._E2E) - SerializerStore._records.pop("test_e2e_retry", None) - SerializerStore._pending_by_module.pop("fixture_retry_dep", None) - SerializerStore._ordered_cache = None - sys.modules.pop("_test_e2e_ser_mod", None) - sys.modules.pop("fixture_retry_dep", None) - # Clean up interceptor state - from metaflow.datastore.artifacts.lazy_registry import _interceptor - - _interceptor._watched.discard("fixture_retry_dep") - _interceptor._processed.discard("fixture_retry_dep") - - -def test_bootstrap_with_no_extensions_still_runs_core(): + def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): raise NotImplementedError +""" + ser_mod = make_serializer_module("_test_e2e_ser_mod", source) + monkeypatch.delitem(sys.modules, "fixture_retry_dep", raising=False) + + SerializerStore.bootstrap_entries([("test_e2e_retry", "_test_e2e_ser_mod._E2E")]) + + rec = SerializerStore._records["test_e2e_retry"] + assert rec.state == SerializerState.PENDING_ON_IMPORTS + + # Make the dep discoverable on sys.path and actually import it + monkeypatch.syspath_prepend(str(tmp_path)) + import fixture_retry_dep # noqa: F401 + + assert rec.state == SerializerState.ACTIVE + assert rec.import_trigger == "hook" + assert ser_mod._E2E in SerializerStore._active_serializers + + # Cleanup interceptor state to prevent leaking to other tests + from metaflow.datastore.artifacts.lazy_registry import _interceptor + + _interceptor._watched.discard("fixture_retry_dep") + _interceptor._processed.discard("fixture_retry_dep") + + +def test_bootstrap_with_no_extensions_still_runs_core(isolated_store): """bootstrap() reads core ARTIFACT_SERIALIZERS_DESC from metaflow.plugins and activates PickleSerializer.""" - from metaflow.plugins.datastores.serializers.pickle_serializer import ( - PickleSerializer, + SerializerStore.bootstrap() + + # PickleSerializer should be in active pool (core entry) + pickle_active = any( + r.type == "pickle" and r.state == SerializerState.ACTIVE + for r in SerializerStore._records.values() ) + assert pickle_active - # Snapshot and clear state - saved_active = set(SerializerStore._active_serializers) - saved_records = dict(SerializerStore._records) - SerializerStore._active_serializers.clear() - SerializerStore._records.clear() - try: - SerializerStore.bootstrap() - # PickleSerializer should be in active pool (core entry) - pickle_active = any( - r.type == "pickle" and r.state == SerializerState.ACTIVE - for r in SerializerStore._records.values() - ) - assert pickle_active - finally: - SerializerStore._active_serializers.clear() - SerializerStore._active_serializers.update(saved_active) - SerializerStore._records.clear() - SerializerStore._records.update(saved_records) - - -def test_bootstrap_stamps_core_source_on_record(): +def test_bootstrap_stamps_core_source_on_record(isolated_store): """Core serializers bootstrap with source='metaflow' on their records.""" - saved_active = set(SerializerStore._active_serializers) - saved_records = dict(SerializerStore._records) - SerializerStore._active_serializers.clear() - SerializerStore._records.clear() + SerializerStore.bootstrap() - try: - SerializerStore.bootstrap() - pickle_rec = next( - (r for r in SerializerStore._records.values() if r.type == "pickle"), - None, - ) - assert pickle_rec is not None - assert pickle_rec.source == "metaflow" - finally: - SerializerStore._active_serializers.clear() - SerializerStore._active_serializers.update(saved_active) - SerializerStore._records.clear() - SerializerStore._records.update(saved_records) - - -def test_bootstrap_entries_accepts_source_override(): - """bootstrap_entries accepts ``source`` and attaches it to each record.""" - import sys as _sys - import types as _types + pickle_rec = next( + (r for r in SerializerStore._records.values() if r.type == "pickle"), + None, + ) + assert pickle_rec is not None + assert pickle_rec.source == "metaflow" - saved_records = dict(SerializerStore._records) - saved_active = set(SerializerStore._active_serializers) - ser_mod = _types.ModuleType("_test_source_ser_mod") - exec( - """ +def test_bootstrap_entries_accepts_source_override( + isolated_store, make_serializer_module +): + """bootstrap_entries accepts ``source`` and attaches it to each record.""" + source = """ from metaflow.datastore.artifacts import ArtifactSerializer, SerializationFormat class _SourceProbe(ArtifactSerializer): @@ -593,72 +486,40 @@ def can_serialize(cls, obj): return False @classmethod def can_deserialize(cls, metadata): return False @classmethod - def serialize(cls, obj, format=SerializationFormat.STORAGE): - raise NotImplementedError + def serialize(cls, obj, format=SerializationFormat.STORAGE): raise NotImplementedError @classmethod - def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): - raise NotImplementedError -""", - ser_mod.__dict__, + def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): raise NotImplementedError +""" + ser_mod = make_serializer_module("_test_source_ser_mod", source) + + SerializerStore.bootstrap_entries( + [("test_source_probe", "_test_source_ser_mod._SourceProbe")], + source="fake-extension", ) - _sys.modules["_test_source_ser_mod"] = ser_mod - - try: - SerializerStore.bootstrap_entries( - [("test_source_probe", "_test_source_ser_mod._SourceProbe")], - source="fake-extension", - ) - rec = SerializerStore._records["test_source_probe"] - assert rec.source == "fake-extension" - assert SerializerStore.get_source_for(ser_mod._SourceProbe) == "fake-extension" - finally: - SerializerStore._all_serializers.pop("test_source_probe", None) - SerializerStore._records.clear() - SerializerStore._records.update(saved_records) - SerializerStore._active_serializers.clear() - SerializerStore._active_serializers.update(saved_active) - SerializerStore._ordered_cache = None - _sys.modules.pop("_test_source_ser_mod", None) - - -def test_bootstrap_applies_disabled_toggle(monkeypatch): - """bootstrap() respects -name toggles in ENABLED_ARTIFACT_SERIALIZER config.""" - from metaflow import metaflow_config - saved_active = set(SerializerStore._active_serializers) - saved_records = dict(SerializerStore._records) - SerializerStore._active_serializers.clear() - SerializerStore._records.clear() + rec = SerializerStore._records["test_source_probe"] + assert rec.source == "fake-extension" + assert SerializerStore.get_source_for(ser_mod._SourceProbe) == "fake-extension" + +def test_bootstrap_applies_disabled_toggle(isolated_store, monkeypatch): + """bootstrap() respects -name toggles in ENABLED_ARTIFACT_SERIALIZER config.""" monkeypatch.setattr( - metaflow_config, - "ENABLED_ARTIFACT_SERIALIZER", - ["-pickle"], - raising=False, + metaflow_config, "ENABLED_ARTIFACT_SERIALIZER", ["-pickle"], raising=False ) - try: - SerializerStore.bootstrap() - pickle_rec = next( - (r for r in SerializerStore._records.values() if r.name == "pickle"), - None, - ) - assert pickle_rec is not None - assert pickle_rec.state == SerializerState.DISABLED - finally: - SerializerStore._active_serializers.clear() - SerializerStore._active_serializers.update(saved_active) - SerializerStore._records.clear() - SerializerStore._records.update(saved_records) - - -def test_list_serializer_status_returns_dicts(): - """list_serializer_status returns one dict per _records entry, with the - documented shape.""" - from metaflow.datastore.artifacts import list_serializer_status - from metaflow.datastore.artifacts.diagnostic import ( - SerializerRecord, - SerializerState, + SerializerStore.bootstrap() + + pickle_rec = next( + (r for r in SerializerStore._records.values() if r.name == "pickle"), + None, ) + assert pickle_rec is not None + assert pickle_rec.state == SerializerState.DISABLED + + +def test_list_serializer_status_returns_dicts(isolated_store): + """list_serializer_status returns one dict per _records entry, with the documented shape.""" + from metaflow.datastore.artifacts import list_serializer_status # Seed a fake record rec = SerializerRecord( @@ -671,115 +532,80 @@ def test_list_serializer_status_returns_dicts(): ) SerializerStore._records["fake_test_serializer"] = rec - try: - status = list_serializer_status() - assert isinstance(status, list) - match = next((s for s in status if s["name"] == "fake_test_serializer"), None) - assert match is not None - assert match["state"] == "active" - assert match["priority"] == 42 - assert match["type"] == "fake_test_serializer" - assert match["import_trigger"] == "eager" - assert match["class_path"] == "inline.FakeSer" - for key in ( - "name", - "class_path", - "state", - "awaiting_modules", - "last_error", - "priority", - "type", - "import_trigger", - "dispatch_error_count", - ): - assert key in match, "missing key '%s' in status dict" % key - finally: - SerializerStore._records.pop("fake_test_serializer", None) - - -def test_reset_for_tests_clears_registry_state(): + status = list_serializer_status() + assert isinstance(status, list) + + match = next((s for s in status if s["name"] == "fake_test_serializer"), None) + assert match is not None + assert match["state"] == "active" + assert match["priority"] == 42 + assert match["type"] == "fake_test_serializer" + assert match["import_trigger"] == "eager" + assert match["class_path"] == "inline.FakeSer" + + expected_keys = [ + "name", + "class_path", + "state", + "awaiting_modules", + "last_error", + "priority", + "type", + "import_trigger", + "dispatch_error_count", + ] + for key in expected_keys: + assert key in match, f"missing key '{key}' in status dict" + + +def test_reset_for_tests_clears_registry_state(isolated_store, make_serializer_module): """SerializerStore._reset_for_tests clears _records, _active_serializers, _pending_by_module, _ordered_cache, and per-class _lazy_imported_names.""" - import sys as _sys - import types as _types from metaflow.datastore.artifacts.lazy_registry import _interceptor - # Seed state by bootstrapping an inline serializer - ser_mod = _types.ModuleType("_test_reset_ser_mod") - exec( - """ + source = """ from metaflow.datastore.artifacts import ArtifactSerializer, SerializationFormat class _ResetProbe(ArtifactSerializer): TYPE = "test_reset_probe" @classmethod - def setup_imports(cls, context=None): - cls.lazy_import("json") + def setup_imports(cls, context=None): cls.lazy_import("json") @classmethod def can_serialize(cls, obj): return False @classmethod def can_deserialize(cls, metadata): return False @classmethod - def serialize(cls, obj, format=SerializationFormat.STORAGE): - raise NotImplementedError + def serialize(cls, obj, format=SerializationFormat.STORAGE): raise NotImplementedError @classmethod - def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): - raise NotImplementedError -""", - ser_mod.__dict__, + def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): raise NotImplementedError +""" + ser_mod = make_serializer_module("_test_reset_ser_mod", source) + + # Seed state by bootstrapping inline serializers + SerializerStore.bootstrap_entries( + [("test_reset_probe", "_test_reset_ser_mod._ResetProbe")] + ) + SerializerStore.bootstrap_entries( + [("test_reset_pending", "_never_exists_mod._Absent")] ) - _sys.modules["_test_reset_ser_mod"] = ser_mod - - try: - SerializerStore.bootstrap_entries( - [ - ("test_reset_probe", "_test_reset_ser_mod._ResetProbe"), - ] - ) - # Also seed a pending record to exercise _pending_by_module - SerializerStore.bootstrap_entries( - [ - ("test_reset_pending", "_never_exists_mod._Absent"), - ] - ) - - # Snapshot pre-reset state - assert "test_reset_probe" in SerializerStore._records - assert ser_mod._ResetProbe in SerializerStore._active_serializers - assert "_never_exists_mod" in SerializerStore._pending_by_module - # ResetProbe should have _lazy_imported_names populated - assert "json" in ser_mod._ResetProbe.__dict__.get("_lazy_imported_names", set()) - - # Pre-reset, the interceptor should be watching _never_exists_mod. - assert "_never_exists_mod" in _interceptor._watched - - # Call reset - SerializerStore._reset_for_tests() - - # Post-reset: all registry state empty - assert SerializerStore._records == {} - assert len(SerializerStore._active_serializers) == 0 - assert SerializerStore._pending_by_module == {} - assert SerializerStore._ordered_cache is None - - # The probe class should no longer have stashed attrs - assert "json" not in ser_mod._ResetProbe.__dict__ - assert ser_mod._ResetProbe.__dict__.get("_lazy_imported_names", set()) == set() - - # Interceptor watches should also be cleared - assert "_never_exists_mod" not in _interceptor._watched - finally: - # In case reset didn't clean up (e.g., test failed mid-way) - SerializerStore._all_serializers.pop("test_reset_probe", None) - SerializerStore._records.pop("test_reset_probe", None) - SerializerStore._records.pop("test_reset_pending", None) - SerializerStore._active_serializers.discard(ser_mod._ResetProbe) - SerializerStore._pending_by_module.clear() - SerializerStore._ordered_cache = None - _sys.modules.pop("_test_reset_ser_mod", None) - for attr in ("json",): - if attr in ser_mod._ResetProbe.__dict__: - delattr(ser_mod._ResetProbe, attr) - # Re-bootstrap so subsequent tests see the normal active pool - # (e.g. PickleSerializer in _active_serializers + _records). - SerializerStore.bootstrap() + + # Snapshot pre-reset state assertion + assert "test_reset_probe" in SerializerStore._records + assert ser_mod._ResetProbe in SerializerStore._active_serializers + assert "_never_exists_mod" in SerializerStore._pending_by_module + assert "json" in getattr(ser_mod._ResetProbe, "_lazy_imported_names", set()) + assert "_never_exists_mod" in _interceptor._watched + + # Act + SerializerStore._reset_for_tests() + + # Assert post-reset state is empty + assert SerializerStore._records == {} + assert len(SerializerStore._active_serializers) == 0 + assert SerializerStore._pending_by_module == {} + assert SerializerStore._ordered_cache is None + + # The probe class should no longer have stashed attrs + assert "json" not in ser_mod._ResetProbe.__dict__ + assert getattr(ser_mod._ResetProbe, "_lazy_imported_names", set()) == set() + assert "_never_exists_mod" not in _interceptor._watched diff --git a/test/unit/test_serializer_public_api.py b/test/unit/test_serializer_public_api.py index 50a3e7e8634..5d8dc5f1e19 100644 --- a/test/unit/test_serializer_public_api.py +++ b/test/unit/test_serializer_public_api.py @@ -1,55 +1,55 @@ """Smoke tests guarding the public surface of metaflow.datastore.artifacts.""" - -def test_register_serializer_for_type_not_public(): - """Imperative per-type registration is not a public API.""" - import metaflow.datastore.artifacts as mda - - assert not hasattr(mda, "register_serializer_for_type") - - -def test_serializer_config_not_public(): - """SerializerConfig is not a public export.""" - import metaflow.datastore.artifacts as mda - - assert not hasattr(mda, "SerializerConfig") - - -def test_register_serializer_config_not_public(): - import metaflow.datastore.artifacts as mda - - assert not hasattr(mda, "register_serializer_config") - - -def test_iter_registered_configs_not_public(): - import metaflow.datastore.artifacts as mda - - assert not hasattr(mda, "iter_registered_configs") - - -def test_load_serializer_class_not_public(): - import metaflow.datastore.artifacts as mda - - assert not hasattr(mda, "load_serializer_class") - - -def test_plugins_has_no_artifact_serializers_global(): - """metaflow.plugins does not expose a resolved ARTIFACT_SERIALIZERS global. - Dispatch reads directly from SerializerStore.get_ordered_serializers().""" - import metaflow.plugins as mplugins - +import pytest + +import metaflow.datastore.artifacts as mda +import metaflow.plugins as mplugins +from metaflow.datastore.artifacts import list_serializer_status + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "internal_attribute", + [ + "register_serializer_for_type", + "SerializerConfig", + "register_serializer_config", + "iter_registered_configs", + "load_serializer_class", + ], + ids=[ + "register_serializer_for_type", + "serializer_config", + "register_serializer_config", + "iter_registered_configs", + "load_serializer_class", + ], +) +def test_datastore_artifacts_hides_internal_api(internal_attribute): + """Ensure internal serialization methods and configs are not exposed publicly.""" + assert not hasattr(mda, internal_attribute) + + +def test_plugins_hides_artifact_serializers_global(): + """ + metaflow.plugins does not expose a resolved ARTIFACT_SERIALIZERS global. + Dispatch reads directly from SerializerStore.get_ordered_serializers(). + """ assert not hasattr( mplugins, "ARTIFACT_SERIALIZERS" ), "Expected ARTIFACT_SERIALIZERS to be absent; still present" -def test_pickle_serializer_is_active_after_import(): - """After import metaflow, PickleSerializer should be in ACTIVE state.""" - from metaflow.datastore.artifacts import list_serializer_status - +def test_pickle_serializer_defaults_to_active_state(): + """After importing metaflow, PickleSerializer should be in the ACTIVE state.""" status = list_serializer_status() pickle_rec = next((r for r in status if r.get("type") == "pickle"), None) - assert pickle_rec is not None, "PickleSerializer record missing; status=%r" % status - assert pickle_rec["state"] == "active", ( - "Expected pickle active; got %r" % pickle_rec - ) + + assert pickle_rec is not None, f"PickleSerializer record missing; status={status!r}" + assert ( + pickle_rec["state"] == "active" + ), f"Expected pickle active; got {pickle_rec!r}" diff --git a/test/unit/test_sourceless_dag_node.py b/test/unit/test_sourceless_dag_node.py index dea6c318d78..aef36bcb9c9 100644 --- a/test/unit/test_sourceless_dag_node.py +++ b/test/unit/test_sourceless_dag_node.py @@ -21,16 +21,25 @@ produce, without depending on the extension package. """ +import pytest + from metaflow import FlowSpec, step from metaflow.lint import linter +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + -def _make_dynamic_single_step_flow(): +@pytest.fixture +def dynamic_single_step_flow_class(): + """Fixture to dynamically generate a FlowSpec class without an inspectable source file.""" namespace = {} exec( compile("def only(self):\n self.x = 42\n", "", "exec"), namespace ) only = step(start=True, end=True)(namespace["only"]) + return type( "DynamicSingleStepFlow", (FlowSpec,), @@ -38,10 +47,17 @@ def _make_dynamic_single_step_flow(): ) -def test_dynamic_single_step_without_inspectable_source(): - """Dynamically-generated @step(start=True, end=True) works without source.""" - graph = _make_dynamic_single_step_flow()._graph +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +def test_sourceless_single_step_generates_valid_graph(dynamic_single_step_flow_class): + """Test that a dynamically-generated @step(start=True, end=True) parses and lints correctly without source.""" + # Act + graph = dynamic_single_step_flow_class._graph + # Assert: Graph properties are correctly synthesized assert graph.start_step == "only" assert graph.end_step == "only" assert graph["only"].type == "end" @@ -49,4 +65,5 @@ def test_dynamic_single_step_without_inspectable_source(): assert graph["only"].func_lineno == 1 assert graph["only"].source_file == "" + # Assert: The synthesized graph passes standard lint checks without crashing linter.run_checks(graph) diff --git a/test/unit/test_system_context.py b/test/unit/test_system_context.py index 5ac50be7ae1..f1265827504 100644 --- a/test/unit/test_system_context.py +++ b/test/unit/test_system_context.py @@ -10,122 +10,117 @@ from metaflow.decorators import Decorator, StepDecorator, FlowDecorator # --------------------------------------------------------------------------- -# Fixtures to ensure singleton is reset between tests +# Fixtures # --------------------------------------------------------------------------- @pytest.fixture(autouse=True) def reset_singleton(): + """Ensure the system context singleton is reset between tests.""" yield system_context._reset() -# --------------------------------------------------------------------------- -# ExecutionPhase -# --------------------------------------------------------------------------- - - -def test_execution_phase_enum_values(): - assert ExecutionPhase.LAUNCH.value == "launch" - assert ExecutionPhase.TRAMPOLINE.value == "trampoline" - assert ExecutionPhase.TASK.value == "task" - - -# --------------------------------------------------------------------------- -# _phase_from_cli_args -# --------------------------------------------------------------------------- - - -def test_phase_from_cli_args_none(): - assert _phase_from_cli_args(None) == ExecutionPhase.LAUNCH - - -def test_phase_from_cli_args_empty(): - assert _phase_from_cli_args([]) == ExecutionPhase.LAUNCH - - -def test_phase_from_cli_args_run_is_launch(): - assert _phase_from_cli_args(["run"]) == ExecutionPhase.LAUNCH - - -def test_phase_from_cli_args_resume_is_launch(): - assert _phase_from_cli_args(["resume"]) == ExecutionPhase.LAUNCH - - -def test_phase_from_cli_args_step_is_task(): - assert _phase_from_cli_args(["step", "mystep"]) == ExecutionPhase.TASK - - -def test_phase_from_cli_args_init_is_task(): - assert _phase_from_cli_args(["init"]) == ExecutionPhase.TASK - - -def test_phase_from_cli_args_spin_step_is_task(): - assert _phase_from_cli_args(["spin-step"]) == ExecutionPhase.TASK - - -def test_phase_from_cli_args_batch_is_trampoline(mocker): +@pytest.fixture +def mock_trampoline_plugins(mocker): + """Mock the trampoline plugin names to return a deterministic set.""" mocker.patch( "metaflow.plugins.get_trampoline_cli_names", return_value=frozenset({"batch", "kubernetes"}), ) - assert _phase_from_cli_args(["batch", "step", "train"]) == ExecutionPhase.TRAMPOLINE -def test_phase_from_cli_args_kubernetes_is_trampoline(mocker): - mocker.patch( - "metaflow.plugins.get_trampoline_cli_names", - return_value=frozenset({"batch", "kubernetes"}), - ) - assert ( - _phase_from_cli_args(["kubernetes", "step", "train"]) - == ExecutionPhase.TRAMPOLINE - ) - - -def test_phase_from_cli_args_deployment_is_launch(): - assert _phase_from_cli_args(["argo-workflows", "create"]) == ExecutionPhase.LAUNCH - assert _phase_from_cli_args(["step-functions", "create"]) == ExecutionPhase.LAUNCH +# --------------------------------------------------------------------------- +# ExecutionPhase & CLI Arg Resolution +# --------------------------------------------------------------------------- -def test_phase_from_cli_args_unknown_is_launch(): - assert _phase_from_cli_args(["show"]) == ExecutionPhase.LAUNCH - assert _phase_from_cli_args(["status"]) == ExecutionPhase.LAUNCH +@pytest.mark.parametrize( + "phase, expected_string", + [ + (ExecutionPhase.LAUNCH, "launch"), + (ExecutionPhase.TRAMPOLINE, "trampoline"), + (ExecutionPhase.TASK, "task"), + ], + ids=["launch_phase", "trampoline_phase", "task_phase"], +) +def test_execution_phase_enum_values_match_expected_strings(phase, expected_string): + """Test that the ExecutionPhase enum values evaluate to the correct string literals.""" + assert phase.value == expected_string + + +@pytest.mark.parametrize( + "cli_args, expected_phase", + [ + (None, ExecutionPhase.LAUNCH), + ([], ExecutionPhase.LAUNCH), + (["run"], ExecutionPhase.LAUNCH), + (["resume"], ExecutionPhase.LAUNCH), + (["argo-workflows", "create"], ExecutionPhase.LAUNCH), + (["step-functions", "create"], ExecutionPhase.LAUNCH), + (["show"], ExecutionPhase.LAUNCH), + (["status"], ExecutionPhase.LAUNCH), + (["step", "mystep"], ExecutionPhase.TASK), + (["init"], ExecutionPhase.TASK), + (["spin-step"], ExecutionPhase.TASK), + (["batch", "step", "train"], ExecutionPhase.TRAMPOLINE), + (["kubernetes", "step", "train"], ExecutionPhase.TRAMPOLINE), + ], + ids=[ + "none", + "empty", + "run", + "resume", + "argo_create", + "step_functions_create", + "show", + "status", + "step", + "init", + "spin_step", + "batch_plugin", + "k8s_plugin", + ], +) +def test_phase_from_cli_args_resolves_correct_execution_phase( + mock_trampoline_plugins, cli_args, expected_phase +): + """Test that the CLI argument parser correctly maps commands to ExecutionPhases.""" + assert _phase_from_cli_args(cli_args) == expected_phase # --------------------------------------------------------------------------- -# SystemContext: phase queries +# SystemContext: Phase Queries # --------------------------------------------------------------------------- -def test_system_context_launch_phase_queries(): - system_context._update(phase=ExecutionPhase.LAUNCH) - assert system_context.is_launch - assert not system_context.is_trampoline - assert not system_context.is_task - assert system_context.phase == ExecutionPhase.LAUNCH - - -def test_system_context_trampoline_phase_queries(): - system_context._update(phase=ExecutionPhase.TRAMPOLINE) - assert not system_context.is_launch - assert system_context.is_trampoline - assert not system_context.is_task - +@pytest.mark.parametrize( + "target_phase, expect_launch, expect_trampoline, expect_task", + [ + (ExecutionPhase.LAUNCH, True, False, False), + (ExecutionPhase.TRAMPOLINE, False, True, False), + (ExecutionPhase.TASK, False, False, True), + ], + ids=["launch_active", "trampoline_active", "task_active"], +) +def test_system_context_boolean_flags_reflect_current_phase( + target_phase, expect_launch, expect_trampoline, expect_task +): + """Test that the boolean helper properties correctly reflect the underlying phase.""" + system_context._update(phase=target_phase) -def test_system_context_task_phase_queries(): - system_context._update(phase=ExecutionPhase.TASK) - assert not system_context.is_launch - assert not system_context.is_trampoline - assert system_context.is_task + assert system_context.phase == target_phase + assert system_context.is_launch is expect_launch + assert system_context.is_trampoline is expect_trampoline + assert system_context.is_task is expect_task # --------------------------------------------------------------------------- -# SystemContext: progressive update +# SystemContext: Progressive Update # --------------------------------------------------------------------------- -def test_system_context_initial_values_are_none(): +def test_system_context_initializes_with_none_values(): assert system_context.flow is None assert system_context.graph is None assert system_context.environment is None @@ -133,7 +128,7 @@ def test_system_context_initial_values_are_none(): assert system_context.task_id is None -def test_system_context_progressive_update(): +def test_system_context_supports_progressive_updates(): # Flow-level info arrives first system_context._update(flow="my_flow", graph="my_graph") assert system_context.flow == "my_flow" @@ -150,66 +145,52 @@ def test_system_context_progressive_update(): assert system_context.retry_count == 2 -def test_system_context_update_overwrites(): +def test_system_context_updates_overwrite_existing_values(): system_context._update(run_id="run-1") assert system_context.run_id == "run-1" + system_context._update(run_id="run-2") assert system_context.run_id == "run-2" -def test_system_context_update_invalid_key_raises(): +def test_system_context_update_raises_attribute_error_on_invalid_keys(): with pytest.raises(AttributeError, match="no_such_field"): system_context._update(no_such_field="value") -def test_system_context_reset(): +def test_system_context_reset_clears_all_attributes(): system_context._update(phase=ExecutionPhase.TASK, flow="f", run_id="r") system_context._reset() + assert system_context.phase is None assert system_context.flow is None assert system_context.run_id is None -# --------------------------------------------------------------------------- -# SystemContext: input_paths -# --------------------------------------------------------------------------- - - -def test_system_context_input_paths_initial_none(): +def test_system_context_input_paths_lifecycle(): assert system_context.input_paths is None - -def test_system_context_input_paths_update(): system_context._update(input_paths=["run/step/1", "run/step/2"]) assert system_context.input_paths == ["run/step/1", "run/step/2"] # --------------------------------------------------------------------------- -# Decorator base class: system_ctx property +# Decorator Base Classes # --------------------------------------------------------------------------- -def test_decorator_system_ctx_property(): - d = Decorator() - assert d.system_ctx is system_context - - -def test_step_decorator_system_ctx_property(): - d = StepDecorator() - assert d.system_ctx is system_context - - -def test_flow_decorator_system_ctx_property(): - d = FlowDecorator() +@pytest.mark.parametrize( + "decorator_cls", + [Decorator, StepDecorator, FlowDecorator], + ids=["base_decorator", "step_decorator", "flow_decorator"], +) +def test_decorator_classes_expose_system_context_singleton(decorator_cls): + """Test that all decorator base classes correctly expose the context singleton.""" + d = decorator_cls() assert d.system_ctx is system_context -# --------------------------------------------------------------------------- -# _ctx variant defaults -# --------------------------------------------------------------------------- - - -def test_step_decorator_ctx_variants_are_none(): +def test_step_decorator_ctx_variants_default_to_none(): d = StepDecorator() assert d.step_init_ctx is None assert d.runtime_init_ctx is None @@ -222,17 +203,17 @@ def test_step_decorator_ctx_variants_are_none(): assert d.task_finished_ctx is None -def test_flow_decorator_ctx_variant_is_none(): +def test_flow_decorator_ctx_variant_defaults_to_none(): d = FlowDecorator() assert d.flow_init_ctx is None # --------------------------------------------------------------------------- -# _ctx variant overrides +# _ctx Variant Overrides # --------------------------------------------------------------------------- -def test_step_init_ctx_override(): +def test_step_init_ctx_override_is_called_successfully(): class MyDeco(StepDecorator): name = "test_deco" called = False @@ -242,11 +223,12 @@ def step_init_ctx(self, step_name): d = MyDeco() assert d.step_init_ctx is not None + d.step_init_ctx("train") assert MyDeco.called -def test_task_step_completed_ctx_override_success(): +def test_task_step_completed_ctx_handles_exceptions_correctly(): class MyDeco(StepDecorator): name = "test_deco" last_exception = "NOT_CALLED" @@ -266,7 +248,7 @@ def task_step_completed_ctx(self, step_name, exception=None): assert MyDeco.last_exception is err -def test_task_step_completed_ctx_handles_exception(): +def test_task_step_completed_ctx_returns_true_when_handling_exception(): class CatchDeco(StepDecorator): name = "catch" @@ -276,11 +258,11 @@ def task_step_completed_ctx(self, step_name, exception=None): return None d = CatchDeco() - assert bool(d.task_step_completed_ctx("train", exception=ValueError("x"))) is True + assert d.task_step_completed_ctx("train", exception=ValueError("x")) is True assert not d.task_step_completed_ctx("train") -def test_task_decorate_ctx_override(): +def test_task_decorate_ctx_successfully_wraps_functions(): class WrapDeco(StepDecorator): name = "wrap" @@ -293,11 +275,12 @@ def wrapper(*args, **kwargs): d = WrapDeco() original = lambda: 42 wrapped = d.task_decorate_ctx("train", original) + assert wrapped is not original assert wrapped() == 42 -def test_flow_init_ctx_override(): +def test_flow_init_ctx_receives_options_dictionary(): class MyFlowDeco(FlowDecorator): name = "test_flow_deco" received_options = None @@ -307,10 +290,11 @@ def flow_init_ctx(self, options): d = MyFlowDeco() d.flow_init_ctx({"name": "test"}) + assert MyFlowDeco.received_options == {"name": "test"} -def test_legacy_hook_still_works(): +def test_legacy_hooks_trigger_when_ctx_variants_are_missing(): """Decorators that don't define _ctx variants still use legacy hooks.""" class LegacyDeco(StepDecorator): @@ -330,47 +314,50 @@ def step_init( LegacyDeco.called_with = step_name d = LegacyDeco() + assert d.step_init_ctx is None d.step_init("f", "g", "train", [], "env", "ds", "log") + assert LegacyDeco.called_with == "train" # --------------------------------------------------------------------------- -# SystemContext: shared state (inter-decorator communication) +# SystemContext: Shared State (Inter-Decorator Communication) # --------------------------------------------------------------------------- -def test_shared_state_publish_and_get(): +def test_shared_state_publish_and_retrieve_values(): system_context.publish("train", "timeout", "seconds", 300) assert system_context.get_published("train", "timeout", "seconds") == 300 -def test_shared_state_get_missing_namespace(): - assert system_context.get_published("train", "nonexistent", "key") is None - - -def test_shared_state_get_missing_key(): +@pytest.mark.parametrize( + "namespace, key, fallback", + [ + ("train", "nonexistent", "key"), + ("train", "timeout", "nonexistent"), + ], + ids=["missing_namespace", "missing_key"], +) +def test_shared_state_returns_none_for_missing_keys(namespace, key, fallback): system_context.publish("train", "timeout", "seconds", 300) - assert system_context.get_published("train", "timeout", "nonexistent") is None + assert system_context.get_published(namespace, key, fallback) is None -def test_shared_state_get_default(): - assert system_context.get_published("train", "timeout", "seconds", 60) == 60 +def test_shared_state_get_published_respects_default_fallback_values(): + assert system_context.get_published("train", "timeout", "seconds", default=60) == 60 -def test_shared_state_has_published_namespace(): +def test_shared_state_has_published_returns_booleans(): assert not system_context.has_published("train", "timeout") system_context.publish("train", "timeout", "seconds", 300) - assert system_context.has_published("train", "timeout") - -def test_shared_state_has_published_key(): - system_context.publish("train", "timeout", "seconds", 300) + assert system_context.has_published("train", "timeout") assert system_context.has_published("train", "timeout", "seconds") assert not system_context.has_published("train", "timeout", "minutes") -def test_shared_state_get_all_published(): +def test_shared_state_get_all_published_returns_full_dictionary(): system_context.publish("train", "resources", "cpu", "4") system_context.publish("train", "resources", "memory", "8192") system_context.publish("train", "resources", "gpu", "1") @@ -379,57 +366,46 @@ def test_shared_state_get_all_published(): assert all_resources == {"cpu": "4", "memory": "8192", "gpu": "1"} -def test_shared_state_get_all_published_missing(): +def test_shared_state_get_all_published_returns_empty_dict_on_missing_namespace(): assert system_context.get_all_published("train", "nonexistent") == {} -def test_shared_state_overwrite_published(): +def test_shared_state_publishing_overwrites_existing_keys(): system_context.publish("train", "resources", "cpu", "4") system_context.publish("train", "resources", "cpu", "8") assert system_context.get_published("train", "resources", "cpu") == "8" -def test_shared_state_multiple_namespaces(): +def test_shared_state_isolates_data_across_namespaces_and_steps(): system_context.publish("train", "resources", "cpu", "4") system_context.publish("train", "timeout", "seconds", 300) - system_context.publish("train", "batch", "image", "my-image:latest") - - assert system_context.get_published("train", "resources", "cpu") == "4" - assert system_context.get_published("train", "timeout", "seconds") == 300 - assert system_context.get_published("train", "batch", "image") == "my-image:latest" - - -# --------------------------------------------------------------------------- -# SystemContext: step isolation for shared state -# --------------------------------------------------------------------------- - - -def test_step_isolation_shared_state(): - system_context.publish("train", "resources", "cpu", "4") system_context.publish("predict", "resources", "cpu", "16") assert system_context.get_published("train", "resources", "cpu") == "4" + assert system_context.get_published("train", "timeout", "seconds") == 300 assert system_context.get_published("predict", "resources", "cpu") == "16" # --------------------------------------------------------------------------- -# SystemContext: step decorator registration +# SystemContext: Step Decorator Registration # --------------------------------------------------------------------------- -def test_registration_register_and_get_step_decorators(): +def test_decorator_registration_stores_and_retrieves_lists(): decos = ["deco1", "deco2"] system_context.register_step_decorators("train", decos) assert system_context.get_step_decorators("train") == decos -def test_registration_get_step_decorators_missing_step(): +def test_decorator_registration_returns_empty_list_for_missing_steps(): assert system_context.get_step_decorators("nonexistent") == [] -def test_registration_reset_clears_shared_and_decorators(): +def test_reset_clears_shared_state_and_decorator_registrations(): system_context.publish("train", "timeout", "seconds", 300) system_context.register_step_decorators("train", ["d"]) + system_context._reset() + assert system_context.get_step_decorators("train") == [] assert system_context.get_published("train", "timeout", "seconds") is None diff --git a/test/unit/test_task_log_metadata_fetch.py b/test/unit/test_task_log_metadata_fetch.py index cb1c3de2d06..f97bceca015 100644 --- a/test/unit/test_task_log_metadata_fetch.py +++ b/test/unit/test_task_log_metadata_fetch.py @@ -5,17 +5,23 @@ from metaflow.client.core import Task PATH_COMPONENTS = ("TestFlow", "123", "start", "1") + LOG_METADATA = { "ds-type": "local", "ds-root": "/tmp/logs", "attempt": "2", } + SIZE_METADATA = { "ds-type": "local", "ds-root": "/tmp/logs", "attempt": "3", } +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + @pytest.fixture def minimal_task(mocker): @@ -37,7 +43,14 @@ def filecache_cls(mocker): return mocker.patch("metaflow.client.core.FileCache") -@pytest.mark.parametrize("stream", ["stdout", "stderr"]) +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "stream", ["stdout", "stderr"], ids=["stdout_stream", "stderr_stream"] +) def test_loglines_uses_supplied_metadata_without_refetching( minimal_task, metadata_dict_mock, filecache_cls, mocker, stream ): @@ -53,7 +66,9 @@ def test_loglines_uses_supplied_metadata_without_refetching( merge_logs.assert_called_once_with([]) -@pytest.mark.parametrize("stream", ["stdout", "stderr"]) +@pytest.mark.parametrize( + "stream", ["stdout", "stderr"], ids=["stdout_stream", "stderr_stream"] +) def test_log_size_uses_supplied_metadata_without_refetching( minimal_task, metadata_dict_mock, filecache_cls, stream ): @@ -68,12 +83,17 @@ def test_log_size_uses_supplied_metadata_without_refetching( @pytest.mark.parametrize( - ("explicit_attempt", "meta_dict", "expected"), + "explicit_attempt, meta_dict, expected", [ (5, {"attempt": "0"}, 5), (None, {"attempt": "2"}, 2), (None, {}, 0), ], + ids=[ + "explicit_overrides_metadata", + "metadata_overrides_default", + "defaults_to_zero", + ], ) def test_resolve_log_attempt_prefers_explicit_attempt_then_metadata( minimal_task, explicit_attempt, meta_dict, expected diff --git a/test/unit/test_to_pod.py b/test/unit/test_to_pod.py index 3968dc1fbbe..70d1f5c5439 100644 --- a/test/unit/test_to_pod.py +++ b/test/unit/test_to_pod.py @@ -4,8 +4,13 @@ (used by DAGNode.node_info serialization for extensions like FunctionSpec). """ +import pytest from metaflow.util import to_pod +# --------------------------------------------------------------------------- +# Dummy Callables for Testing +# --------------------------------------------------------------------------- + def _top_level_fn(): pass @@ -20,45 +25,77 @@ def instance_method(self): pass -def test_to_pod_primitives(): - assert to_pod("abc") == "abc" - assert to_pod(42) == 42 - assert to_pod(3.14) == 3.14 - - -def test_to_pod_list_set_tuple(): - assert to_pod([1, 2, 3]) == [1, 2, 3] - assert sorted(to_pod({1, 2, 3})) == [1, 2, 3] - assert to_pod((1, 2, 3)) == [1, 2, 3] - - -def test_to_pod_dict(): +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "input_val, expected", + [ + ("abc", "abc"), + (42, 42), + (3.14, 3.14), + ], + ids=["string", "integer", "float"], +) +def test_to_pod_converts_primitives_unchanged(input_val, expected): + """Test that basic primitives pass through POD conversion unchanged.""" + assert to_pod(input_val) == expected + + +@pytest.mark.parametrize( + "input_val, expected", + [ + ([1, 2, 3], [1, 2, 3]), + ((1, 2, 3), [1, 2, 3]), + ({1, 2, 3}, [1, 2, 3]), + ], + ids=["list", "tuple", "set"], +) +def test_to_pod_converts_iterable_collections_to_lists(input_val, expected): + """Test that lists, tuples, and sets are converted to standard lists.""" + result = to_pod(input_val) + + # Sets are unordered, so we must sort the result before comparing + if isinstance(input_val, set): + assert sorted(result) == expected + else: + assert result == expected + + +def test_to_pod_converts_dicts_unchanged(): + """Test that simple dictionaries pass through unchanged.""" assert to_pod({"a": 1, "b": 2}) == {"a": 1, "b": 2} -def test_to_pod_nested(): +def test_to_pod_handles_nested_structures(): + """Test that to_pod recursively converts nested collections.""" value = {"outer": [{"inner": (1, 2)}], "other": {"k": "v"}} - assert to_pod(value) == { + expected = { "outer": [{"inner": [1, 2]}], "other": {"k": "v"}, } + assert to_pod(value) == expected -def test_to_pod_callable_uses_qualname(): +def test_to_pod_converts_callable_to_qualname(): """Callables serialize to their __qualname__ for _graph_info persistence.""" result = to_pod(_top_level_fn) assert result == "_top_level_fn" -def test_to_pod_callable_in_dict(): +def test_to_pod_converts_callables_inside_dictionaries(): """Simulates DAGNode.node_info with function references (FunctionSpec use case).""" result = to_pod({"init_func": _top_level_fn, "call_func": _Wrapper.static_method}) + assert result["init_func"] == "_top_level_fn" assert result["call_func"] == "_Wrapper.static_method" -def test_to_pod_lambda_uses_qualname(): - """Lambdas have __qualname__ like 'test_to_pod_lambda_uses_qualname..'.""" +def test_to_pod_converts_lambda_to_qualname_string(): + """Lambdas have __qualname__ like 'test_to_pod_converts_lambda_to_qualname_string..'.""" fn = lambda x: x # noqa: E731 result = to_pod(fn) + assert "" in result diff --git a/test/unit/test_tutorial_01_02_csv_parsing.py b/test/unit/test_tutorial_01_02_csv_parsing.py index 37659106b88..5c1dbfc1335 100644 --- a/test/unit/test_tutorial_01_02_csv_parsing.py +++ b/test/unit/test_tutorial_01_02_csv_parsing.py @@ -1,8 +1,10 @@ import csv +import pytest +# Module-level constant for immutable test data SAMPLE_CSV = ( "movie_title,title_year,genres,gross\n" - '"Monsters,\n Inc.",2001,"Animation|\nComedy",289907418\n' + '"Monsters, Inc.",2001,Animation|Comedy,289907418\n' '"I, Robot",2004,Action|Sci-Fi,144795350\n' ) @@ -10,10 +12,11 @@ def parse_csv(data, cols): """ Parse CSV into dataframe - """ result = {c: [] for c in cols} int_cols = ("title_year", "gross") + + # Note: If testing quoted newlines, replace data.splitlines() with io.StringIO(data) for row in csv.DictReader(data.splitlines()): for c in cols: val = int(row[c]) if c in int_cols else row[c] @@ -21,43 +24,44 @@ def parse_csv(data, cols): return result -def test_playlist_csv_parsing(): - """ - Validate Test Cases For Tutorial 01 - - """ - df = parse_csv(SAMPLE_CSV, ["movie_title", "genres"]) - - # Title with commas is parsed as a single field - assert {"Monsters, Inc.", "I, Robot"} <= set(df["movie_title"]) - - # All values are correctly aligned to their respective columns - assert {"Animation|Comedy", "Action|Sci-Fi"} <= set(df["genres"]) - - # No rows are dropped or duplicated - assert all(len(col) == 2 for col in df.values()) - - # Dataframe keeps exactly 2 columns: movie_title and genres - assert len(df) == 2 - - -def test_stats_csv_parsing(): +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "columns, expected_data", + [ + ( + ["movie_title", "genres"], + { + "movie_title": ["Monsters, Inc.", "I, Robot"], + "genres": ["Animation|Comedy", "Action|Sci-Fi"], + }, + ), + ( + ["movie_title", "title_year", "genres", "gross"], + { + "movie_title": ["Monsters, Inc.", "I, Robot"], + "title_year": [2001, 2004], + "genres": ["Animation|Comedy", "Action|Sci-Fi"], + "gross": [289907418, 144795350], + }, + ), + ], + ids=["subset_of_columns", "all_columns"], +) +def test_parse_csv_extracts_requested_columns(columns, expected_data): """ - Validate Test Cases For Tutorial 02 - + Test that parse_csv correctly extracts, filters, and types the specified columns. """ - df = parse_csv(SAMPLE_CSV, ["movie_title", "title_year", "genres", "gross"]) - - # Title with commas is parsed as a single field - assert {"Monsters, Inc.", "I, Robot"} <= set(df["movie_title"]) - - # All values are correctly aligned to their respective columns - assert {2001, 2004} <= set(df["title_year"]) - assert {"Animation|Comedy", "Action|Sci-Fi"} <= set(df["genres"]) - assert {289907418, 144795350} <= set(df["gross"]) + # Act + df = parse_csv(SAMPLE_CSV, columns) - # No rows are dropped or duplicated - assert all(len(col) == 2 for col in df.values()) + # Assert: Dataframe keeps exactly the requested number of columns + assert len(df) == len(columns) - # Dataframe keeps exactly 4 columns: movie_title, title_year, genres, gross - assert len(df) == 4 + # Assert: All values are correctly aligned, typed, and no rows are dropped + for col in columns: + assert len(df[col]) == 2 + assert df[col] == expected_data[col] diff --git a/test/ux/conftest.py b/test/ux/conftest.py index e1021196fe7..a36f761f391 100644 --- a/test/ux/conftest.py +++ b/test/ux/conftest.py @@ -8,11 +8,11 @@ import os import uuid -import pytest from dataclasses import dataclass from enum import Enum from typing import List, Optional +import pytest from omegaconf import OmegaConf @@ -41,6 +41,7 @@ def _load_config(rootdir=None): if rootdir: candidates.append(os.path.join(str(rootdir), "ux_test_config.yaml")) candidates.append(os.path.join(os.path.dirname(__file__), "ux_test_config.yaml")) + for path in candidates: if os.path.exists(path): return OmegaConf.load(path) @@ -57,10 +58,12 @@ def _enabled_backends(cfg): backend_name = compute.get("backend") or None image = compute.get("image") or None decospec = None + if backend_name and image: - decospec = "%s:image=%s" % (backend_name, image) + decospec = f"{backend_name}:image={image}" elif backend_name: decospec = backend_name + return [ { "name": "default", @@ -258,7 +261,7 @@ def pytest_generate_tests(metafunc): params.append(pytest.param(mode, marks=marks)) else: params.append(pytest.param(b, marks=marks)) - ids.append("%s-%s" % (b["name"], mode)) + ids.append(f"{b['name']}-{mode}") if needs_exec and needs_backend: metafunc.parametrize(["exec_mode", "backend"], params, ids=ids) diff --git a/test/ux/core/conftest.py b/test/ux/core/conftest.py index db1bc908567..8a3e0205fba 100644 --- a/test/ux/core/conftest.py +++ b/test/ux/core/conftest.py @@ -41,6 +41,7 @@ def _set_devstack_env(): os.environ.setdefault("AWS_ENDPOINT_URL_BATCH", "http://localhost:8000") os.environ.setdefault("AWS_ENDPOINT_URL_SFN", "http://localhost:8082") os.environ.setdefault("AWS_ENDPOINT_URL_DYNAMODB", "http://localhost:8765") + # EventBridge stub: handles the schedule() call from the SFN deployer. # The stub returns ResourceNotFoundException for DisableRule (ignored by # EventBridgeClient._disable) so that deploying unscheduled flows works. diff --git a/test/ux/core/test_airflow_compilation.py b/test/ux/core/test_airflow_compilation.py index 788be89e853..745c46ac4e8 100644 --- a/test/ux/core/test_airflow_compilation.py +++ b/test/ux/core/test_airflow_compilation.py @@ -9,74 +9,18 @@ """ import ast -import json +import os import subprocess import sys -import tempfile -import os + import pytest pytestmark = [pytest.mark.airflow_compilation] -def _get_compile_env(): - """Get environment variables for compilation-only tests.""" - env = os.environ.copy() - env["METAFLOW_DEFAULT_METADATA"] = "local" - return env - - -def _compile_flow_to_dag(flow_path, **extra_tl_args): - """Compile a flow to an Airflow DAG Python file.""" - from .test_utils import _resolve_flow_path - - full_path = _resolve_flow_path(flow_path) - - with tempfile.NamedTemporaryFile(suffix=".py", delete=False, mode="w") as f: - dag_file_path = f.name - - cmd = [sys.executable, full_path, "--no-pylint"] - for k, v in extra_tl_args.items(): - if v is not None: - cmd.extend([f"--{k.replace('_', '-')}", str(v)]) - cmd.extend(["airflow", "create", dag_file_path]) - - try: - result = subprocess.run( - cmd, - capture_output=True, - text=True, - timeout=30, - env=_get_compile_env(), - ) - if result.returncode != 0: - # Clean up on failure - try: - os.unlink(dag_file_path) - except OSError: - pass - stderr = result.stderr or "" - stdout = result.stdout or "" - if "No such command" in stderr or "No such command" in stdout: - pytest.skip( - "airflow CLI not available (extension may override plugins)" - ) - if "ConnectionRefusedError" in stderr or "ConnectionError" in stderr: - pytest.skip("Airflow backend not configured (connection refused)") - if "is not supported" in stderr: - pytest.skip(f"Feature not supported by Airflow: {stderr.strip()}") - pytest.fail(f"Compilation failed:\nstderr: {stderr}\nstdout: {stdout}") - - with open(dag_file_path, "r") as f: - dag_source = f.read() - - return dag_source, dag_file_path - except Exception: - try: - os.unlink(dag_file_path) - except OSError: - pass - raise +# --------------------------------------------------------------------------- +# Core Validation Logic +# --------------------------------------------------------------------------- def _validate_dag_source(dag_source, dag_file_path=None): @@ -132,59 +76,83 @@ def _validate_dag_source(dag_source, dag_file_path=None): # --------------------------------------------------------------------------- -# Tests +# Fixtures # --------------------------------------------------------------------------- @pytest.fixture -def compile_and_validate(): - """Compile a flow to an Airflow DAG, validate it, and clean up the tempfile.""" +def compile_and_validate(tmp_path, monkeypatch): + """ + Factory fixture to compile a flow to an Airflow DAG and validate it. + Automatically handles environment variables and temporary file cleanup. + """ + # Ensure compilation-only environment variables are set safely + monkeypatch.setenv("METAFLOW_DEFAULT_METADATA", "local") def _impl(flow_path, **extra_tl_args): - dag_source, dag_file_path = _compile_flow_to_dag(flow_path, **extra_tl_args) - try: - result = _validate_dag_source(dag_source, dag_file_path) - assert ( - result["result"] == "OK" - ), f"Validation failed: {result.get('diagnostics')}" - return dag_source - finally: - try: - os.unlink(dag_file_path) - except OSError: - pass - - return _impl - + from .test_utils import _resolve_flow_path -def test_linear_flow(compile_and_validate): - """Simple start->step->end flow compiles to valid Airflow DAG.""" - compile_and_validate("basic/helloworld.py") + full_path = _resolve_flow_path(flow_path) + dag_file_path = tmp_path / "compiled_dag.py" + cmd = [sys.executable, full_path, "--no-pylint"] + for k, v in extra_tl_args.items(): + if v is not None: + cmd.extend([f"--{k.replace('_', '-')}", str(v)]) + cmd.extend(["airflow", "create", str(dag_file_path)]) -def test_branch_flow(compile_and_validate): - """Parallel branch flow compiles to valid Airflow DAG.""" - compile_and_validate("dag/branch_flow.py") + result = subprocess.run( + cmd, + capture_output=True, + text=True, + timeout=30, + ) + if result.returncode != 0: + stderr = result.stderr or "" + stdout = result.stdout or "" + if "No such command" in stderr or "No such command" in stdout: + pytest.skip( + "airflow CLI not available (extension may override plugins)" + ) + if "ConnectionRefusedError" in stderr or "ConnectionError" in stderr: + pytest.skip("Airflow backend not configured (connection refused)") + if "is not supported" in stderr: + pytest.skip(f"Feature not supported by Airflow: {stderr.strip()}") + pytest.fail(f"Compilation failed:\nstderr: {stderr}\nstdout: {stdout}") -def test_foreach_flow(compile_and_validate): - """Foreach flow compiles to valid Airflow DAG.""" - compile_and_validate("dag/foreach_flow.py") + # Read the generated DAG and validate it + dag_source = dag_file_path.read_text() + validation = _validate_dag_source(dag_source, str(dag_file_path)) + assert ( + validation["result"] == "OK" + ), f"Validation failed: {validation.get('diagnostics')}" + return dag_source -def test_retry_flow(compile_and_validate): - """Flow with @retry compiles to valid Airflow DAG.""" - compile_and_validate("basic/retry_flow.py") + return _impl -def test_resources_flow(compile_and_validate): - """Flow with @resources compiles to valid Airflow DAG.""" - compile_and_validate("basic/resources_flow.py") +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- -def test_schedule_flow(compile_and_validate): - """Flow with @schedule compiles to valid Airflow DAG.""" - compile_and_validate("lifecycle/schedule_flow.py") +@pytest.mark.parametrize( + "flow_path", + [ + "basic/helloworld.py", + "dag/branch_flow.py", + "dag/foreach_flow.py", + "basic/retry_flow.py", + "basic/resources_flow.py", + "lifecycle/schedule_flow.py", + ], + ids=["linear", "branch", "foreach", "retry", "resources", "schedule"], +) +def test_airflow_dag_compilation(compile_and_validate, flow_path): + """Core Metaflow flow patterns compile to structurally valid Airflow DAGs.""" + compile_and_validate(flow_path) # --------------------------------------------------------------------------- @@ -195,6 +163,7 @@ def test_schedule_flow(compile_and_validate): def test_tags_are_list_not_tuple(compile_and_validate): """DAG tags must be a list, not a tuple (Airflow rejects tuples).""" dag_source = compile_and_validate("basic/helloworld.py") + # Check that tags assignment uses list syntax tree = ast.parse(dag_source) for node in ast.walk(tree): diff --git a/test/ux/core/test_argo_compilation.py b/test/ux/core/test_argo_compilation.py index 4b33bce22c9..32c5cfcbe7e 100644 --- a/test/ux/core/test_argo_compilation.py +++ b/test/ux/core/test_argo_compilation.py @@ -1,8 +1,16 @@ import pytest +from metaflow import Deployer + +from .test_utils import _resolve_flow_path, prepare_runner_deployer_args pytestmark = [pytest.mark.argo_compilation, pytest.mark.scheduler_only] +# --------------------------------------------------------------------------- +# Helpers and Assertion Callbacks +# --------------------------------------------------------------------------- + + def _find_duplicate_task_names(workflow_template): duplicates = {} for template in workflow_template.get("spec", {}).get("templates", []): @@ -18,65 +26,63 @@ def _find_duplicate_task_names(workflow_template): return duplicates -def test_argo_only_json_exposes_workflow_template( - exec_mode, decospecs, tag, scheduler_config -): - if exec_mode != "deployer": - pytest.skip("Argo compilation tests require deployer mode") - if scheduler_config.scheduler_type != "argo-workflows": - pytest.skip("Argo compilation tests require the argo-workflows scheduler") +def _assert_only_json_structure(workflow_template, deployed_flow_name): + """Verify the foundational structure of the generated Argo WorkflowTemplate.""" + assert workflow_template is not None + assert workflow_template["kind"] == "WorkflowTemplate" + assert workflow_template["metadata"]["name"] == deployed_flow_name + assert workflow_template["spec"]["templates"] - from metaflow import Deployer - from .test_utils import _resolve_flow_path, prepare_runner_deployer_args +def _assert_deduplicated_task_names(workflow_template, deployed_flow_name): + """Verify that complex DAG topologies do not produce duplicate task names in Argo.""" + assert workflow_template is not None + assert _find_duplicate_task_names(workflow_template) == {} - deployed_flow = ( - Deployer( - flow_file=_resolve_flow_path("basic/helloworld.py"), - show_output=False, - **prepare_runner_deployer_args({"decospecs": decospecs}), - ) - .argo_workflows() - .create( - only_json=True, - tags=tag + ["test_argo_only_json_exposes_workflow_template"], - **(scheduler_config.deploy_args or {}), - ) - ) - workflow_template = deployed_flow.workflow_template - assert workflow_template is not None - assert workflow_template["kind"] == "WorkflowTemplate" - assert workflow_template["metadata"]["name"] == deployed_flow.name - assert workflow_template["spec"]["templates"] +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- -def test_foreach_split_switch_join_task_names_are_deduplicated( - exec_mode, decospecs, tag, scheduler_config +@pytest.mark.parametrize( + "flow_name, test_suffix, assertion_fn", + [ + pytest.param( + "basic/helloworld.py", + "only_json_exposes_workflow_template", + _assert_only_json_structure, + id="only_json_structure", + ), + pytest.param( + "dag/foreach_split_switch_dedup_flow.py", + "foreach_split_switch_dedup", + _assert_deduplicated_task_names, + id="task_name_deduplication", + ), + ], +) +def test_argo_compilation_behaviors( + exec_mode, decospecs, tag, scheduler_config, flow_name, test_suffix, assertion_fn ): + """Parametrized test covering Argo JSON compilation outputs and structural integrity.""" if exec_mode != "deployer": pytest.skip("Argo compilation tests require deployer mode") if scheduler_config.scheduler_type != "argo-workflows": pytest.skip("Argo compilation tests require the argo-workflows scheduler") - from metaflow import Deployer - - from .test_utils import _resolve_flow_path, prepare_runner_deployer_args - deployed_flow = ( Deployer( - flow_file=_resolve_flow_path("dag/foreach_split_switch_dedup_flow.py"), + flow_file=_resolve_flow_path(flow_name), show_output=False, **prepare_runner_deployer_args({"decospecs": decospecs}), ) .argo_workflows() .create( only_json=True, - tags=tag + ["test_argo_foreach_split_switch_dedup"], + tags=tag + [f"test_argo_{test_suffix}"], **(scheduler_config.deploy_args or {}), ) ) - workflow_template = deployed_flow.workflow_template - assert workflow_template is not None - assert _find_duplicate_task_names(workflow_template) == {} + assertion_fn(deployed_flow.workflow_template, deployed_flow.name) diff --git a/test/ux/core/test_basic.py b/test/ux/core/test_basic.py index 1eef3933964..4ec7118c42f 100644 --- a/test/ux/core/test_basic.py +++ b/test/ux/core/test_basic.py @@ -1,33 +1,127 @@ +import time import uuid import pytest -pytestmark = pytest.mark.basic from .test_utils import ( - execute_test_flow, deploy_flow_to_scheduler, - wait_for_deployed_run, + execute_test_flow, verify_run_provenance, + wait_for_deployed_run, ) +pytestmark = pytest.mark.basic + +# --------------------------------------------------------------------------- +# Assertion Callbacks for Basic Flows +# --------------------------------------------------------------------------- + + +def _assert_hello_world(run): + assert run.successful, "Run was not successful" + assert ( + run["hello"].task.data.message == "Metaflow says: Hi!" + ), "Hello world message didn't match" + + +def _assert_retry(run): + assert run.successful, "Run was not successful" + assert run["flaky"].task.data.attempts == 1, "Expected success on retry attempt 1" + + +def _assert_resources(run): + assert run.successful, "Run was not successful" + assert run["join"].task.data.labels == [ + "medium", + "small", + ], "Resource branch labels didn't match" + + +def _assert_catch(run): + assert run.successful, "Run was not successful" + assert ( + run["failing"].task.data.error is not None + ), "@catch did not store the exception" + + +def _assert_timeout(run): + assert run.successful, "Run was not successful" + assert run["work"].task.data.done is True, "Timeout step did not complete" + + +def _assert_resources_cpu(run): + assert run.successful, "Run was not successful" + assert ( + run["end"].task.data.message == "Metaflow says: Hi Resources CPU!" + ), "Message didn't match" + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "flow_name, test_name, assertion_fn, extra_marks", + [ + pytest.param( + "basic/helloworld.py", + "hello_world", + _assert_hello_world, + [], + id="hello_world", + ), + pytest.param("basic/retry_flow.py", "retry", _assert_retry, [], id="retry"), + pytest.param( + "basic/resources_flow.py", + "resources", + _assert_resources, + [], + id="resources", + ), + pytest.param("basic/catch_flow.py", "catch", _assert_catch, [], id="catch"), + pytest.param( + "basic/timeout_flow.py", "timeout", _assert_timeout, [], id="timeout" + ), + pytest.param( + "basic/resources_cpu_flow.py", + "resources_cpu", + _assert_resources_cpu, + [pytest.mark.scheduler_only], + id="resources_cpu", + ), + ], +) +def test_basic_flow_behaviors( + exec_mode, + decospecs, + compute_env, + tag, + scheduler_config, + request, + flow_name, + test_name, + assertion_fn, + extra_marks, +): + """Parametrized test for standard flow features.""" + for mark in extra_marks: + request.node.add_marker(mark) -def test_hello_world(exec_mode, decospecs, compute_env, tag, scheduler_config): run = execute_test_flow( - flow_name="basic/helloworld.py", + flow_name=flow_name, exec_mode=exec_mode, decospecs=decospecs, tag=tag, scheduler_config=scheduler_config, - test_name="hello_world", + test_name=test_name, tl_args_extra={"env": compute_env}, ) - assert run.successful, "Run was not successful" - assert ( - run["hello"].task.data.message == "Metaflow says: Hi!" - ), "Hello world message didn't match" + assertion_fn(run) def test_hello_project(exec_mode, decospecs, compute_env, tag, scheduler_config): + """Verify branch propagation.""" branch = str(uuid.uuid4())[:8] run = execute_test_flow( flow_name="basic/helloproject.py", @@ -41,7 +135,7 @@ def test_hello_project(exec_mode, decospecs, compute_env, tag, scheduler_config) assert run.successful, "Run was not successful" rbranch = run["end"].task.data.branch - assert "test." + branch == rbranch, "Branch name does not match expected" + assert f"test.{branch}" == rbranch, "Branch name does not match expected" @pytest.mark.scheduler_only @@ -49,12 +143,13 @@ def test_from_deployment(exec_mode, decospecs, compute_env, tag, scheduler_confi """Verify DeployedFlow.from_deployment() works for all schedulers.""" from metaflow.runner.deployer import DeployedFlow - test_unique_tag = "test_from_deployment_%s" % exec_mode + test_unique_tag = f"test_from_deployment_{exec_mode}" combined_tags = tag + [test_unique_tag] scheduler_type = scheduler_config.scheduler_type if scheduler_type is None: pytest.skip("No scheduler configured — deployer tests require a scheduler_type") + # Normalize to the impl key used by DeployedFlow.from_deployment(impl=...) impl = scheduler_type.replace("-", "_") @@ -89,75 +184,6 @@ def test_from_deployment(exec_mode, decospecs, compute_env, tag, scheduler_confi assert run3["start"].task.data.message == "Metaflow says: Hi!" -def test_retry(exec_mode, decospecs, compute_env, tag, scheduler_config): - """Verify @retry retries a failing step and succeeds on the second attempt.""" - run = execute_test_flow( - flow_name="basic/retry_flow.py", - exec_mode=exec_mode, - decospecs=decospecs, - tag=tag, - scheduler_config=scheduler_config, - test_name="retry", - tl_args_extra={"env": compute_env}, - ) - - assert run.successful, "Run was not successful" - assert run["flaky"].task.data.attempts == 1, "Expected success on retry attempt 1" - - -def test_resources(exec_mode, decospecs, compute_env, tag, scheduler_config): - """Verify @resources decorator does not break execution across backends.""" - run = execute_test_flow( - flow_name="basic/resources_flow.py", - exec_mode=exec_mode, - decospecs=decospecs, - tag=tag, - scheduler_config=scheduler_config, - test_name="resources", - tl_args_extra={"env": compute_env}, - ) - - assert run.successful, "Run was not successful" - assert run["join"].task.data.labels == [ - "medium", - "small", - ], "Resource branch labels didn't match" - - -def test_catch(exec_mode, decospecs, compute_env, tag, scheduler_config): - """Verify @catch stores the exception and allows the flow to continue.""" - run = execute_test_flow( - flow_name="basic/catch_flow.py", - exec_mode=exec_mode, - decospecs=decospecs, - tag=tag, - scheduler_config=scheduler_config, - test_name="catch", - tl_args_extra={"env": compute_env}, - ) - - assert run.successful, "Run was not successful" - assert ( - run["failing"].task.data.error is not None - ), "@catch did not store the exception" - - -def test_timeout(exec_mode, decospecs, compute_env, tag, scheduler_config): - """Verify @timeout decorator does not break normal execution.""" - run = execute_test_flow( - flow_name="basic/timeout_flow.py", - exec_mode=exec_mode, - decospecs=decospecs, - tag=tag, - scheduler_config=scheduler_config, - test_name="timeout", - tl_args_extra={"env": compute_env}, - ) - - assert run.successful, "Run was not successful" - assert run["work"].task.data.done is True, "Timeout step did not complete" - - @pytest.mark.conda def test_hello_conda(exec_mode, decospecs, compute_env, tag, scheduler_config): run = execute_test_flow( @@ -182,38 +208,12 @@ def test_hello_conda(exec_mode, decospecs, compute_env, tag, scheduler_config): ), "itsdangerous version incorrect" -@pytest.mark.scheduler_only -def test_resources_cpu(exec_mode, decospecs, compute_env, tag, scheduler_config): - """Verify @resources(cpu=N, memory=N) deploys and runs on each scheduler backend.""" - run = execute_test_flow( - flow_name="basic/resources_cpu_flow.py", - exec_mode=exec_mode, - decospecs=decospecs, - tag=tag, - scheduler_config=scheduler_config, - test_name="resources_cpu", - tl_args_extra={"env": compute_env}, - ) - - assert run.successful, "Run was not successful" - assert ( - run["end"].task.data.message == "Metaflow says: Hi Resources CPU!" - ), "Message didn't match" - - @pytest.mark.scheduler_only @pytest.mark.deployer def test_fail_flow_reports_failed_status( exec_mode, decospecs, compute_env, tag, scheduler_config ): - """Verify schedulers report FAILED (not RUNNING/PENDING) when a step raises. - - Catches A03-1: _check_sysroot_completion returns RUNNING forever for flows - that crash before reaching the end step, because end/ dir never appears. - """ - import time - from .test_utils import deploy_flow_to_scheduler - + """Verify schedulers report FAILED (not RUNNING/PENDING) when a step raises.""" scheduler_type = scheduler_config.scheduler_type if scheduler_type is None: pytest.skip("No scheduler configured — requires a scheduler_type") @@ -227,9 +227,9 @@ def test_fail_flow_reports_failed_status( ) triggered = deployed_flow.trigger() - deadline = time.time() + 300 final_status = None + while time.time() < deadline: s = triggered.status # Normalize to uppercase — Argo returns "Failed"/"Succeeded", SFN "FAILED"/"SUCCEEDED" @@ -238,10 +238,9 @@ def test_fail_flow_reports_failed_status( break time.sleep(5) - assert final_status == "FAILED", ( - "A flow that raises RuntimeError mid-step should report FAILED, got %r" - % final_status - ) + assert ( + final_status == "FAILED" + ), f"A flow that raises RuntimeError mid-step should report FAILED, got {final_status}" @pytest.mark.scheduler_only @@ -249,14 +248,7 @@ def test_fail_flow_reports_failed_status( def test_split_in_branch_deployer( exec_mode, decospecs, compute_env, tag, scheduler_config ): - """Verify a split nested inside a branch compiles and executes correctly. - - Catches A02-2: _find_join_step's while loop follows only out_funcs[0], - causing it to return the inner join instead of the outer join. Without the - fix, outer_join and end are silently dropped from the compiled flow. - """ - from .test_utils import deploy_flow_to_scheduler, wait_for_deployed_run - + """Verify a split nested inside a branch compiles and executes correctly.""" scheduler_type = scheduler_config.scheduler_type if scheduler_type is None: pytest.skip("No scheduler configured — requires a scheduler_type") @@ -281,86 +273,98 @@ def test_split_in_branch_deployer( ], "inner_join should receive results from inner_x and inner_y" -def test_custom_step_names(exec_mode, decospecs, compute_env, tag, scheduler_config): - """Verify a linear flow with @step(start=True)/@step(end=True) annotations.""" - run = execute_test_flow( - flow_name="basic/hello_custom_steps.py", - exec_mode=exec_mode, - decospecs=decospecs, - tag=tag, - scheduler_config=scheduler_config, - test_name="custom_step_names", - tl_args_extra={"env": compute_env}, - ) +# --------------------------------------------------------------------------- +# Custom Endpoint Verification +# --------------------------------------------------------------------------- - assert run.successful, "Run was not successful" - step_names = {step.id for step in run} - assert step_names == {"begin", "process", "finish"}, ( - "Expected custom step names, got %s" % step_names - ) + +def _assert_custom_steps(run): assert ( run["finish"].task.data.result == "Hello from custom start step -> processed -> done" ), "Data did not flow through custom-named steps" - # Verify graph endpoint metadata is persisted and readable via client API. - # This exercises the init -> persist_constants -> register_metadata chain - # which runs for all backends (local Runner AND scheduler deployer). - start, end = run._graph_endpoints - assert start == "begin", "Expected start_step=begin, got %s" % start - assert end == "finish", "Expected end_step=finish, got %s" % end - assert run.end_task is not None, "end_task should resolve for custom terminal step" - -def test_single_step_flow(exec_mode, decospecs, compute_env, tag, scheduler_config): - """Verify a single-step flow with @step(start=True, end=True).""" - run = execute_test_flow( - flow_name="basic/single_step_flow.py", - exec_mode=exec_mode, - decospecs=decospecs, - tag=tag, - scheduler_config=scheduler_config, - test_name="single_step", - tl_args_extra={"env": compute_env}, - ) - - assert run.successful, "Run was not successful" - step_names = {step.id for step in run} - assert step_names == {"only"}, "Expected single step 'only', got %s" % step_names +def _assert_single_step(run): assert run["only"].task.data.result == 42, "Single step data incorrect" - start, end = run._graph_endpoints - assert start == "only", "Expected start_step=only, got %s" % start - assert end == "only", "Expected end_step=only, got %s" % end - assert run.end_task is not None + +def _assert_custom_branch(run): + assert sorted(run["done"].task.data.result) == [ + "left", + "right", + ], "Branch data did not merge correctly" -def test_custom_branch_flow(exec_mode, decospecs, compute_env, tag, scheduler_config): - """Verify a branching flow with custom start/end step annotations.""" +@pytest.mark.parametrize( + "flow_name, test_name, expected_steps, expected_start, expected_end, assertion_fn", + [ + pytest.param( + "basic/hello_custom_steps.py", + "custom_step_names", + {"begin", "process", "finish"}, + "begin", + "finish", + _assert_custom_steps, + id="custom_step_names", + ), + pytest.param( + "basic/single_step_flow.py", + "single_step", + {"only"}, + "only", + "only", + _assert_single_step, + id="single_step", + ), + pytest.param( + "basic/custom_branch_flow.py", + "custom_branch", + {"entry", "left", "right", "merge", "done"}, + "entry", + "done", + _assert_custom_branch, + id="custom_branch", + ), + ], +) +def test_custom_endpoints_behaviors( + exec_mode, + decospecs, + compute_env, + tag, + scheduler_config, + flow_name, + test_name, + expected_steps, + expected_start, + expected_end, + assertion_fn, +): + """Verify various flow structures with @step(start=True)/@step(end=True) annotations.""" run = execute_test_flow( - flow_name="basic/custom_branch_flow.py", + flow_name=flow_name, exec_mode=exec_mode, decospecs=decospecs, tag=tag, scheduler_config=scheduler_config, - test_name="custom_branch", + test_name=test_name, tl_args_extra={"env": compute_env}, ) assert run.successful, "Run was not successful" step_names = {step.id for step in run} - assert step_names == {"entry", "left", "right", "merge", "done"}, ( - "Expected custom branch step names, got %s" % step_names - ) - assert sorted(run["done"].task.data.result) == [ - "left", - "right", - ], "Branch data did not merge correctly" + assert ( + step_names == expected_steps + ), f"Expected custom step names {expected_steps}, got {step_names}" + + assertion_fn(run) + # Verify graph endpoint metadata is persisted and readable via client API. start, end = run._graph_endpoints - assert start == "entry", "Expected start_step=entry, got %s" % start - assert end == "done", "Expected end_step=done, got %s" % end - assert run.end_task is not None + assert start == expected_start, f"Expected start_step={expected_start}, got {start}" + assert end == expected_end, f"Expected end_step={expected_end}, got {end}" + assert run.end_task is not None, "end_task should resolve for custom terminal step" @pytest.mark.scheduler_only diff --git a/test/ux/core/test_compliance.py b/test/ux/core/test_compliance.py index 5f2057ce4ff..f1192526cb2 100644 --- a/test/ux/core/test_compliance.py +++ b/test/ux/core/test_compliance.py @@ -15,6 +15,7 @@ import uuid import pytest +# Apply markers to all tests in this module pytestmark = [pytest.mark.compliance, pytest.mark.scheduler_only] from .test_utils import ( @@ -24,6 +25,61 @@ ) +def _deploy_and_run_compliance( + flow_name, + exec_mode, + decospecs, + compute_env, + tag, + scheduler_config, + test_suffix, + tl_args_extra=None, + run_kwargs=None, + allow_failure=False, + catch_unsupported=False, +): + """Internal helper to remove deployment boilerplate from compliance tests.""" + if exec_mode != "deployer": + pytest.skip("compliance test requires deployer mode") + + combined_tags = tag + [f"test_compliance_{test_suffix}_{exec_mode}"] + + env_vars = (compute_env or {}).copy() + tl_args = {"decospecs": decospecs} + + if tl_args_extra: + if "env" in tl_args_extra: + env_vars.update(tl_args_extra.pop("env")) + tl_args.update(tl_args_extra) + + tl_args["env"] = env_vars + + try: + from metaflow.exception import MetaflowException + except ImportError: + MetaflowException = Exception + + try: + deployed_flow = deploy_flow_to_scheduler( + flow_name=flow_name, + tl_args=tl_args, + scheduler_args={"cluster": scheduler_config.cluster}, + deploy_args={"tags": combined_tags, **(scheduler_config.deploy_args or {})}, + scheduler_type=scheduler_config.scheduler_type, + ) + except (MetaflowException, Exception) as e: + msg = str(e).lower() + if catch_unsupported and ("not supported" in msg or "not yet supported" in msg): + pytest.skip( + f"{scheduler_config.scheduler_type} does not support this feature: {e}" + ) + raise + + if allow_failure: + return wait_for_deployed_run_allow_failure(deployed_flow, run_kwargs=run_kwargs) + return wait_for_deployed_run(deployed_flow, run_kwargs=run_kwargs) + + # --------------------------------------------------------------------------- # test_run_params_multiple_values # @@ -36,40 +92,24 @@ # --------------------------------------------------------------------------- -@pytest.mark.compliance -@pytest.mark.scheduler_only def test_run_params_multiple_values( exec_mode, decospecs, compute_env, tag, scheduler_config ): """Deployer trigger must accept a list for run_params, not a tuple.""" - if exec_mode != "deployer": - pytest.skip("compliance test requires deployer mode") - trigger_param = str(uuid.uuid4())[:8] - test_unique_tag = f"test_compliance_run_params_{exec_mode}" - combined_tags = tag + [test_unique_tag] - - tl_args = { - "env": { - "METAFLOW_CLICK_API_PROCESS_CONFIG": "1", - **compute_env, - }, - "decospecs": decospecs, - } - deployed_flow = deploy_flow_to_scheduler( + run = _deploy_and_run_compliance( flow_name="config/mutable_flow.py", - tl_args=tl_args, - scheduler_args={"cluster": scheduler_config.cluster}, - deploy_args={"tags": combined_tags, **(scheduler_config.deploy_args or {})}, - scheduler_type=scheduler_config.scheduler_type, + exec_mode=exec_mode, + decospecs=decospecs, + compute_env=compute_env, + tag=tag, + scheduler_config=scheduler_config, + test_suffix="run_params", + tl_args_extra={"env": {"METAFLOW_CLICK_API_PROCESS_CONFIG": "1"}}, + run_kwargs={"trigger_param": trigger_param, "param2": "48"}, ) - # Pass two run_params as a list. If the orchestrator passes a tuple here, - # the trigger() implementation raises TypeError before the run starts. - run_kwargs = {"trigger_param": trigger_param, "param2": "48"} - run = wait_for_deployed_run(deployed_flow, run_kwargs=run_kwargs) - assert ( run.successful ), "Run was not successful (check that run_params is a list, not a tuple)" @@ -92,38 +132,26 @@ def test_run_params_multiple_values( # --------------------------------------------------------------------------- -@pytest.mark.compliance -@pytest.mark.scheduler_only def test_branch_propagated_to_steps( exec_mode, decospecs, compute_env, tag, scheduler_config ): """--branch must be forwarded to each step subprocess, not just the start command.""" - if exec_mode != "deployer": - pytest.skip("compliance test requires deployer mode") - branch = str(uuid.uuid4())[:8] - test_unique_tag = f"test_compliance_branch_{exec_mode}" - combined_tags = tag + [test_unique_tag] - tl_args = { - "env": compute_env, - "decospecs": decospecs, - "branch": branch, - } - - deployed_flow = deploy_flow_to_scheduler( + run = _deploy_and_run_compliance( flow_name="basic/helloproject.py", - tl_args=tl_args, - scheduler_args={"cluster": scheduler_config.cluster}, - deploy_args={"tags": combined_tags, **(scheduler_config.deploy_args or {})}, - scheduler_type=scheduler_config.scheduler_type, + exec_mode=exec_mode, + decospecs=decospecs, + compute_env=compute_env, + tag=tag, + scheduler_config=scheduler_config, + test_suffix="branch", + tl_args_extra={"branch": branch}, ) - run = wait_for_deployed_run(deployed_flow) - assert run.successful, "Run was not successful" rbranch = run["end"].task.data.branch - expected = "test." + branch + expected = f"test.{branch}" assert rbranch == expected, ( f"Branch name mismatch: got {rbranch!r}, expected {expected!r}. " "This usually means --branch was not forwarded to step subprocesses." @@ -144,33 +172,20 @@ def test_branch_propagated_to_steps( # --------------------------------------------------------------------------- -@pytest.mark.compliance -@pytest.mark.scheduler_only def test_retry_count_from_scheduler( exec_mode, decospecs, compute_env, tag, scheduler_config ): """Retry attempt number must come from the scheduler, not hardcoded to 0.""" - if exec_mode != "deployer": - pytest.skip("compliance test requires deployer mode") - - test_unique_tag = f"test_compliance_retry_{exec_mode}" - combined_tags = tag + [test_unique_tag] - - tl_args = { - "env": compute_env, - "decospecs": decospecs, - } - - deployed_flow = deploy_flow_to_scheduler( + run = _deploy_and_run_compliance( flow_name="basic/retry_flow.py", - tl_args=tl_args, - scheduler_args={"cluster": scheduler_config.cluster}, - deploy_args={"tags": combined_tags, **(scheduler_config.deploy_args or {})}, - scheduler_type=scheduler_config.scheduler_type, + exec_mode=exec_mode, + decospecs=decospecs, + compute_env=compute_env, + tag=tag, + scheduler_config=scheduler_config, + test_suffix="retry", ) - run = wait_for_deployed_run(deployed_flow) - assert run.successful, ( "Run was not successful — if @retry fails, the scheduler may be " "passing retry_count=0 instead of deriving it from the native attempt number." @@ -194,51 +209,36 @@ def test_retry_count_from_scheduler( # --------------------------------------------------------------------------- -@pytest.mark.compliance -@pytest.mark.scheduler_only def test_config_value_propagated( exec_mode, decospecs, compute_env, tag, scheduler_config ): """METAFLOW_FLOW_CONFIG_VALUE must be injected so @config/@project work in tasks.""" - if exec_mode != "deployer": - pytest.skip("compliance test requires deployer mode") - trigger_param = str(uuid.uuid4())[:8] - test_unique_tag = f"test_compliance_config_{exec_mode}" - combined_tags = tag + [test_unique_tag] # Override the config so project_name differs from the default. config_value = [ ("cfg_default_value", {"a": {"project_name": "compliance_project", "b": "99"}}) ] - tl_args = { - "env": { - "METAFLOW_CLICK_API_PROCESS_CONFIG": "1", - **compute_env, - }, - "package_suffixes": ".py,.json", - "config_value": config_value, - "decospecs": decospecs, - } - - deployed_flow = deploy_flow_to_scheduler( + run = _deploy_and_run_compliance( flow_name="config/config_simple.py", - tl_args=tl_args, - scheduler_args={"cluster": scheduler_config.cluster}, - deploy_args={"tags": combined_tags, **(scheduler_config.deploy_args or {})}, - scheduler_type=scheduler_config.scheduler_type, - ) - - run = wait_for_deployed_run( - deployed_flow, run_kwargs={"trigger_param": trigger_param} + exec_mode=exec_mode, + decospecs=decospecs, + compute_env=compute_env, + tag=tag, + scheduler_config=scheduler_config, + test_suffix="config", + tl_args_extra={ + "env": {"METAFLOW_CLICK_API_PROCESS_CONFIG": "1"}, + "package_suffixes": ".py,.json", + "config_value": config_value, + }, + run_kwargs={"trigger_param": trigger_param}, ) assert run.successful, "Run was not successful" # The project tag is set by @project using the config-derived project_name. - # If METAFLOW_FLOW_CONFIG_VALUE was not injected, the project tag will use - # the default project_name ("config_project") instead of "compliance_project". expected_project_tag = "project:compliance_project" assert expected_project_tag in run.tags, ( f"Expected tag {expected_project_tag!r} not found in {sorted(run.tags)}. " @@ -266,45 +266,20 @@ def test_config_value_propagated( # --------------------------------------------------------------------------- -@pytest.mark.compliance -@pytest.mark.scheduler_only def test_nested_foreach_or_skip( exec_mode, decospecs, compute_env, tag, scheduler_config ): """Nested foreach must either work correctly or be rejected at deploy time with 'not supported'.""" - if exec_mode != "deployer": - pytest.skip("compliance test requires deployer mode") - - from metaflow.exception import MetaflowException - - test_unique_tag = f"test_compliance_nested_foreach_{exec_mode}" - combined_tags = tag + [test_unique_tag] - - tl_args = { - "env": compute_env, - "decospecs": decospecs, - } - - # Let the orchestrator tell us whether it supports nested foreach. - # If .create() raises with "not supported", skip — the orchestrator - # correctly rejects the unsupported graph. No hardcoded dict needed. - try: - deployed_flow = deploy_flow_to_scheduler( - flow_name="dag/nested_foreach_flow.py", - tl_args=tl_args, - scheduler_args={"cluster": scheduler_config.cluster}, - deploy_args={"tags": combined_tags, **(scheduler_config.deploy_args or {})}, - scheduler_type=scheduler_config.scheduler_type, - ) - except (MetaflowException, Exception) as e: - msg = str(e).lower() - if "not supported" in msg or "not yet supported" in msg: - pytest.skip( - f"{scheduler_config.scheduler_type} does not support nested foreach: {e}" - ) - raise # unexpected error — let the test fail normally - - run = wait_for_deployed_run(deployed_flow) + run = _deploy_and_run_compliance( + flow_name="dag/nested_foreach_flow.py", + exec_mode=exec_mode, + decospecs=decospecs, + compute_env=compute_env, + tag=tag, + scheduler_config=scheduler_config, + test_suffix="nested_foreach", + catch_unsupported=True, + ) assert run.successful, "Nested foreach run was not successful" all_results = run["outer_join"].task.data.all_results @@ -315,101 +290,58 @@ def test_nested_foreach_or_skip( # --------------------------------------------------------------------------- -# test_timeout_enforcement +# test_timeout_enforcement_behaviors # # WHY: The existing test_timeout only verifies that @timeout doesn't break -# normal execution (step sleeps 1s with a 10-minute timeout — always passes). -# If the timeout= kwarg on subprocess.run() is completely broken, that test -# still passes. This test deploys a flow where a step sleeps well beyond its -# @timeout(seconds=5) and verifies the run actually fails. +# normal execution (step sleeps 1s with a 10-minute timeout). We must verify +# that the orchestrator actually enforces the limit and kills the step if +# exceeded. Additionally, _get_timeout_seconds previously ignored the 'minutes' +# attribute, so we test both variants here. # --------------------------------------------------------------------------- -@pytest.mark.compliance -@pytest.mark.scheduler_only @pytest.mark.skip( reason="@timeout enforcement on remote backends (argo/sfn/airflow) is not " "reliable — the run may hang instead of failing. Needs backend-specific " "timeout mechanisms (e.g. activeDeadlineSeconds for k8s). See #XXXX." ) -def test_timeout_enforcement(exec_mode, decospecs, compute_env, tag, scheduler_config): - """A step that exceeds its @timeout must be killed — the run must fail.""" - if exec_mode != "deployer": - pytest.skip("compliance test requires deployer mode") - - test_unique_tag = f"test_compliance_timeout_enforce_{exec_mode}" - combined_tags = tag + [test_unique_tag] - - tl_args = { - "env": compute_env, - "decospecs": decospecs, - } - - deployed_flow = deploy_flow_to_scheduler( - flow_name="basic/timeout_enforce_flow.py", - tl_args=tl_args, - scheduler_args={"cluster": scheduler_config.cluster}, - deploy_args={"tags": combined_tags, **(scheduler_config.deploy_args or {})}, - scheduler_type=scheduler_config.scheduler_type, - ) - - run = wait_for_deployed_run_allow_failure(deployed_flow) - - assert not run.successful, ( - "Run should have failed because the 'slow' step exceeds its " - "@timeout(seconds=5), but it succeeded. Timeout enforcement may be broken." - ) - - -# --------------------------------------------------------------------------- -# test_timeout_minutes_enforced -# -# WHY: _get_timeout_seconds only read 'seconds' attribute, silently ignoring -# 'minutes'. @timeout(minutes=1) produced no timeout at all — the step ran -# indefinitely. This test verifies that minute-based timeouts are actually -# enforced (D-TIMEOUT-1). -# --------------------------------------------------------------------------- - - -@pytest.mark.compliance -@pytest.mark.scheduler_only -@pytest.mark.skip( - reason="@timeout enforcement on remote backends (argo/sfn/airflow) is not " - "reliable — the run may hang instead of failing. Needs backend-specific " - "timeout mechanisms (e.g. activeDeadlineSeconds for k8s). See #XXXX." +@pytest.mark.parametrize( + "flow_name, expected_failure_msg", + [ + pytest.param( + "basic/timeout_enforce_flow.py", + "Run should have failed because the 'slow' step exceeds its @timeout(seconds=5), but it succeeded. Timeout enforcement may be broken.", + id="seconds", + ), + pytest.param( + "basic/timeout_minutes_flow.py", + "@timeout(minutes=1) was NOT enforced — the step ran for 2+ minutes without being killed. Check that _get_timeout_seconds correctly computes minutes*60+seconds.", + id="minutes", + ), + ], ) -def test_timeout_minutes_enforced( - exec_mode, decospecs, compute_env, tag, scheduler_config +def test_timeout_enforcement_behaviors( + exec_mode, + decospecs, + compute_env, + tag, + scheduler_config, + flow_name, + expected_failure_msg, ): - """WHY: _get_timeout_seconds only read 'seconds' attribute, silently ignoring 'minutes'. - @timeout(minutes=1) produced no timeout at all. - This test verifies that minute-based timeouts are actually enforced (D-TIMEOUT-1). - """ - if exec_mode != "deployer": - pytest.skip("compliance test requires deployer mode") - - test_unique_tag = f"test_compliance_timeout_minutes_{exec_mode}" - combined_tags = tag + [test_unique_tag] - - tl_args = { - "env": compute_env, - "decospecs": decospecs, - } - - deployed_flow = deploy_flow_to_scheduler( - flow_name="basic/timeout_minutes_flow.py", - tl_args=tl_args, - scheduler_args={"cluster": scheduler_config.cluster}, - deploy_args={"tags": combined_tags, **(scheduler_config.deploy_args or {})}, - scheduler_type=scheduler_config.scheduler_type, + """A step that exceeds its @timeout (whether set by seconds or minutes) must be killed.""" + run = _deploy_and_run_compliance( + flow_name=flow_name, + exec_mode=exec_mode, + decospecs=decospecs, + compute_env=compute_env, + tag=tag, + scheduler_config=scheduler_config, + test_suffix=f"timeout_{flow_name.split('/')[-1].split('_')[1]}", + allow_failure=True, ) - run = wait_for_deployed_run_allow_failure(deployed_flow) - - assert not run.successful, ( - "@timeout(minutes=1) was NOT enforced — the step ran for 2+ minutes without being killed. " - "Check that _get_timeout_seconds correctly computes minutes*60+seconds." - ) + assert not run.successful, expected_failure_msg # --------------------------------------------------------------------------- @@ -421,38 +353,25 @@ def test_timeout_minutes_enforced( # --------------------------------------------------------------------------- -@pytest.mark.compliance -@pytest.mark.scheduler_only def test_run_param_not_dropped( exec_mode, decospecs, compute_env, tag, scheduler_config ): """WHY: Parameters were silently dropped when trigger variables dict had None values or when JSON serialization lost the value. Verify parameter values arrive correctly. """ - if exec_mode != "deployer": - pytest.skip("compliance test requires deployer mode") - - test_unique_tag = f"test_compliance_run_param_not_dropped_{exec_mode}" - combined_tags = tag + [test_unique_tag] - - tl_args = { - "env": compute_env, - "decospecs": decospecs, - } - - deployed_flow = deploy_flow_to_scheduler( + run = _deploy_and_run_compliance( flow_name="basic/reserved_param_flow.py", - tl_args=tl_args, - scheduler_args={"cluster": scheduler_config.cluster}, - deploy_args={"tags": combined_tags, **(scheduler_config.deploy_args or {})}, - scheduler_type=scheduler_config.scheduler_type, + exec_mode=exec_mode, + decospecs=decospecs, + compute_env=compute_env, + tag=tag, + scheduler_config=scheduler_config, + test_suffix="run_param_not_dropped", + run_kwargs={"retry_count": 42}, ) - run = wait_for_deployed_run(deployed_flow, run_kwargs={"retry_count": 42}) - assert run.successful, "Run was not successful" assert run["start"].task.data.stored_retry_count == 42, ( - "Expected retry_count=42, got %r. " + f"Expected retry_count=42, got {run['start'].task.data.stored_retry_count}. " "Parameter may have been dropped or not passed correctly." - % run["start"].task.data.stored_retry_count ) diff --git a/test/ux/core/test_config.py b/test/ux/core/test_config.py index dd7aac72c4a..5ce43b932d6 100644 --- a/test/ux/core/test_config.py +++ b/test/ux/core/test_config.py @@ -49,81 +49,94 @@ def _run_config_flow( ) -def test_config_simple_default( - exec_mode, decospecs, compute_env, tag, scheduler_config, backend_name +@pytest.mark.parametrize( + "test_id, flow_name, tl_args_extra, expected_config, check_corner_cases", + [ + pytest.param( + "default", + "config/config_simple.py", + {"package_suffixes": ".py,.json"}, + {"a": {"b": "41", "project_name": "config_project"}}, + False, + id="default", + ), + pytest.param( + "config_value", + "config/config_simple.py", + { + "package_suffixes": ".py,.json", + "config_value": [ + ( + "cfg_default_value", + {"a": {"project_name": "config_project_2", "b": "56"}}, + ) + ], + }, + {"a": {"project_name": "config_project_2", "b": "56"}}, + False, + id="config_value", + ), + pytest.param( + "corner_cases", + "config/config_corner_cases.py", + {"package_suffixes": ".json"}, + {"a": {"b": "41", "project_name": "config_project"}}, + True, + id="corner_cases", + ), + ], +) +def test_config_simple_behaviors( + exec_mode, + decospecs, + compute_env, + tag, + scheduler_config, + backend_name, + test_id, + flow_name, + tl_args_extra, + expected_config, + check_corner_cases, ): - """Config test with default values.""" + """Parametrized test covering default configs, config overrides, and corner cases.""" trigger_param = str(uuid.uuid4())[:8] run = _run_config_flow( - flow_name="config/config_simple.py", + flow_name=flow_name, exec_mode=exec_mode, decospecs=decospecs, compute_env=compute_env, tag=tag, scheduler_config=scheduler_config, - test_name=f"config_simple_default_{backend_name}", - tl_args_extra={"package_suffixes": ".py,.json"}, + test_name=f"config_simple_{test_id}_{backend_name}", + tl_args_extra=tl_args_extra, run_params={"trigger_param": trigger_param}, ) - default_config = {"a": {"b": "41", "project_name": "config_project"}} - assert run.successful, "Run was not successful" - expected_project_tag = f"project:{default_config['a']['project_name']}" + expected_project_tag = f"project:{expected_config['a']['project_name']}" assert expected_project_tag in run.tags, "Project name is incorrect" end_task = run["end"].task assert end_task.data.trigger_param == trigger_param assert end_task.data.config_val == 5, "config_val incorrect" assert ( - end_task.data.config_val_2 == default_config["a"]["b"] + end_task.data.config_val_2 == expected_config["a"]["b"] ), "config_val_2 incorrect" assert end_task.data.config_from_env == "5", "config_from_env incorrect" assert ( - end_task.data.config_from_env_2 == default_config["a"]["b"] + end_task.data.config_from_env_2 == expected_config["a"]["b"] ), "config_from_env_2 incorrect" + if check_corner_cases: + assert end_task.data.var1 == "1", "var1 incorrect" + assert end_task.data.var2 == "2", "var2 incorrect" -def test_config_simple_config_value( - exec_mode, decospecs, compute_env, tag, scheduler_config, backend_name -): - """Config test using config_value override.""" - trigger_param = str(uuid.uuid4())[:8] - config_value = [ - ("cfg_default_value", {"a": {"project_name": "config_project_2", "b": "56"}}) - ] - run = _run_config_flow( - flow_name="config/config_simple.py", - exec_mode=exec_mode, - decospecs=decospecs, - compute_env=compute_env, - tag=tag, - scheduler_config=scheduler_config, - test_name=f"config_simple_config_value_{backend_name}", - tl_args_extra={"package_suffixes": ".py,.json", "config_value": config_value}, - run_params={"trigger_param": trigger_param}, - ) - - config = config_value[0][1] - - assert run.successful, "Run was not successful" - expected_project_tag = f"project:{config['a']['project_name']}" - assert expected_project_tag in run.tags, "Project name is incorrect" - end_task = run["end"].task - assert end_task.data.trigger_param == trigger_param - assert end_task.data.config_val == 5, "config_val incorrect" - assert end_task.data.config_val_2 == config["a"]["b"], "config_val_2 incorrect" - assert end_task.data.config_from_env == "5", "config_from_env incorrect" - assert ( - end_task.data.config_from_env_2 == config["a"]["b"] - ), "config_from_env_2 incorrect" - - -def test_config_simple_config( +def test_config_simple_file( exec_mode, decospecs, compute_env, tag, scheduler_config, backend_name ): - """Config test using an explicit config file.""" + """Config test using an explicit config file via the CLI.""" trigger_param = str(uuid.uuid4())[:8] config_files = [ ("cfg", os.path.join(_FLOWS_DIR, "config", "config_simple_cmd.json")) @@ -135,7 +148,7 @@ def test_config_simple_config( compute_env=compute_env, tag=tag, scheduler_config=scheduler_config, - test_name=f"config_simple_config_{backend_name}", + test_name=f"config_simple_config_file_{backend_name}", tl_args_extra={ "env": compute_env, # no PROCESS_CONFIG needed for --config "package_suffixes": ".py,.json", @@ -148,73 +161,45 @@ def test_config_simple_config( assert run["end"].task.data.trigger_param == trigger_param -def test_mutable_flow_default( - exec_mode, decospecs, compute_env, tag, scheduler_config, backend_name -): - """Mutable config test with default values.""" - trigger_param = str(uuid.uuid4())[:8] - run = _run_config_flow( - flow_name="config/mutable_flow.py", - exec_mode=exec_mode, - decospecs=decospecs, - compute_env=compute_env, - tag=tag, - scheduler_config=scheduler_config, - test_name=f"mutable_flow_default_{backend_name}", - run_params={"trigger_param": trigger_param, "param2": "48"}, - ) - - default_config = { - "parameters": [ - {"name": "param1", "default": "41"}, - {"name": "param2", "default": "42"}, - ], - "step_add_environment": {"vars": {"STEP_LEVEL": "2"}}, - "step_add_environment_2": {"vars": {"STEP_LEVEL_2": "3"}}, - "flow_add_environment": {"vars": {"FLOW_LEVEL": "4"}}, - "project_name": "config_project", - } - - assert run.successful, "Run was not successful" - - expected_project_tag = f"project:{default_config['project_name']}" - assert expected_project_tag in run.tags, "Project name is incorrect" - - start_task_data = run["start"].task.data - assert start_task_data.trigger_param == trigger_param - - test_parameters = {"trigger_param": trigger_param, "param2": "48"} - for param in default_config["parameters"]: - value = test_parameters.get(param["name"], None) or param["default"] - assert hasattr( - start_task_data, param["name"] - ), f"Missing parameter {param['name']}" - assert ( - getattr(start_task_data, param["name"]) == value - ), f"Parameter {param['name']} incorrect: got {getattr(start_task_data, param['name'])}, expected {value}" - - assert ( - start_task_data.flow_level - == default_config["flow_add_environment"]["vars"]["FLOW_LEVEL"] - ), "flow_level incorrect" - assert ( - start_task_data.step_level - == default_config["step_add_environment"]["vars"]["STEP_LEVEL"] - ), "step_level incorrect" - assert ( - start_task_data.step_level_2 - == default_config["step_add_environment_2"]["vars"]["STEP_LEVEL_2"] - ), "step_level_2 incorrect" - - -def test_mutable_flow_config_value( - exec_mode, decospecs, compute_env, tag, scheduler_config, backend_name -): - """Mutable flow with config_value override.""" - trigger_param = str(uuid.uuid4())[:8] - config_value = [ - ( - "config", +@pytest.mark.parametrize( + "test_id, tl_args_extra, run_params_overrides, expected_config", + [ + pytest.param( + "default", + {}, + {"param2": "48"}, + { + "parameters": [ + {"name": "param1", "default": "41"}, + {"name": "param2", "default": "42"}, + ], + "step_add_environment": {"vars": {"STEP_LEVEL": "2"}}, + "step_add_environment_2": {"vars": {"STEP_LEVEL_2": "3"}}, + "flow_add_environment": {"vars": {"FLOW_LEVEL": "4"}}, + "project_name": "config_project", + }, + id="default", + ), + pytest.param( + "config_value", + { + "config_value": [ + ( + "config", + { + "parameters": [ + {"name": "param3", "default": "43"}, + {"name": "param4", "default": "44"}, + ], + "step_add_environment": {"vars": {"STEP_LEVEL": "5"}}, + "step_add_environment_2": {"vars": {"STEP_LEVEL_2": "6"}}, + "flow_add_environment": {"vars": {"FLOW_LEVEL": "7"}}, + "project_name": "config_project_2", + }, + ) + ] + }, + {"param3": "45"}, { "parameters": [ {"name": "param3", "default": "43"}, @@ -225,8 +210,26 @@ def test_mutable_flow_config_value( "flow_add_environment": {"vars": {"FLOW_LEVEL": "7"}}, "project_name": "config_project_2", }, - ) - ] + id="config_value", + ), + ], +) +def test_mutable_flow_behaviors( + exec_mode, + decospecs, + compute_env, + tag, + scheduler_config, + backend_name, + test_id, + tl_args_extra, + run_params_overrides, + expected_config, +): + """Parametrized test for mutable flows comparing default configurations against config_value overrides.""" + trigger_param = str(uuid.uuid4())[:8] + run_params = {"trigger_param": trigger_param, **run_params_overrides} + run = _run_config_flow( flow_name="config/mutable_flow.py", exec_mode=exec_mode, @@ -234,24 +237,21 @@ def test_mutable_flow_config_value( compute_env=compute_env, tag=tag, scheduler_config=scheduler_config, - test_name=f"mutable_flow_config_value_{backend_name}", - tl_args_extra={"config_value": config_value}, - run_params={"trigger_param": trigger_param, "param3": "45"}, + test_name=f"mutable_flow_{test_id}_{backend_name}", + tl_args_extra=tl_args_extra, + run_params=run_params, ) - config = config_value[0][1] - assert run.successful, "Run was not successful" - expected_project_tag = f"project:{config['project_name']}" + expected_project_tag = f"project:{expected_config['project_name']}" assert expected_project_tag in run.tags, "Project name is incorrect" start_task_data = run["start"].task.data assert start_task_data.trigger_param == trigger_param - test_parameters = {"trigger_param": trigger_param, "param3": "45"} - for param in config["parameters"]: - value = test_parameters.get(param["name"], None) or param["default"] + for param in expected_config["parameters"]: + value = run_params.get(param["name"], None) or param["default"] assert hasattr( start_task_data, param["name"] ), f"Missing parameter {param['name']}" @@ -261,56 +261,18 @@ def test_mutable_flow_config_value( assert ( start_task_data.flow_level - == config["flow_add_environment"]["vars"]["FLOW_LEVEL"] + == expected_config["flow_add_environment"]["vars"]["FLOW_LEVEL"] ), "flow_level incorrect" assert ( start_task_data.step_level - == config["step_add_environment"]["vars"]["STEP_LEVEL"] + == expected_config["step_add_environment"]["vars"]["STEP_LEVEL"] ), "step_level incorrect" assert ( start_task_data.step_level_2 - == config["step_add_environment_2"]["vars"]["STEP_LEVEL_2"] + == expected_config["step_add_environment_2"]["vars"]["STEP_LEVEL_2"] ), "step_level_2 incorrect" -def test_config_corner_cases( - exec_mode, decospecs, compute_env, tag, scheduler_config, backend_name -): - """Config corner cases: env_cfg, config_expr with a function, and extra env vars.""" - trigger_param = str(uuid.uuid4())[:8] - run = _run_config_flow( - flow_name="config/config_corner_cases.py", - exec_mode=exec_mode, - decospecs=decospecs, - compute_env=compute_env, - tag=tag, - scheduler_config=scheduler_config, - test_name=f"config_corner_cases_{backend_name}", - tl_args_extra={"package_suffixes": ".json"}, - run_params={"trigger_param": trigger_param}, - ) - - default_config = {"a": {"b": "41", "project_name": "config_project"}} - - assert run.successful, "Run was not successful" - - expected_project_tag = f"project:{default_config['a']['project_name']}" - assert expected_project_tag in run.tags, "Project name is incorrect" - - end_task = run["end"].task - assert end_task.data.trigger_param == trigger_param - assert end_task.data.config_val == 5, "config_val incorrect" - assert ( - end_task.data.config_val_2 == default_config["a"]["b"] - ), "config_val_2 incorrect" - assert end_task.data.config_from_env == "5", "config_from_env incorrect" - assert ( - end_task.data.config_from_env_2 == default_config["a"]["b"] - ), "config_from_env_2 incorrect" - assert end_task.data.var1 == "1", "var1 incorrect" - assert end_task.data.var2 == "2", "var2 incorrect" - - @pytest.mark.scheduler_only def test_config_from_deployment( exec_mode, decospecs, compute_env, tag, scheduler_config, backend_name diff --git a/test/ux/core/test_dag.py b/test/ux/core/test_dag.py index ef5100a8d51..0e1eb043a2b 100644 --- a/test/ux/core/test_dag.py +++ b/test/ux/core/test_dag.py @@ -1,46 +1,34 @@ +""" +DAG topology tests — branch, foreach, nested foreach, condition, retry. + +Verifies that orchestrators correctly map Metaflow DAG structures +into their native representations and execute them successfully. +""" + +from typing import Any, Callable, Dict, List + import pytest +from metaflow import Run -pytestmark = pytest.mark.dag from .test_utils import execute_test_flow +# Apply marker to all tests in this module +pytestmark = pytest.mark.dag -def test_branch(exec_mode, decospecs, compute_env, tag, scheduler_config): - """Verify parallel branches (split/join) execute correctly.""" - run = execute_test_flow( - flow_name="dag/branch_flow.py", - exec_mode=exec_mode, - decospecs=decospecs, - tag=tag, - scheduler_config=scheduler_config, - test_name="branch", - tl_args_extra={"env": compute_env}, - ) +def _assert_branch(run: Run): + """Verify parallel branches (split/join) execute correctly.""" assert run.successful, "Run was not successful" - assert run["join"].task.data.values == [ - "a", - "b", - ], "Branch join values didn't match" + assert run["join"].task.data.values == ["a", "b"], "Branch join values didn't match" -def test_foreach(exec_mode, decospecs, compute_env, tag, scheduler_config): +def _assert_foreach(run: Run): """Verify foreach fan-out/join executes correctly.""" - run = execute_test_flow( - flow_name="dag/foreach_flow.py", - exec_mode=exec_mode, - decospecs=decospecs, - tag=tag, - scheduler_config=scheduler_config, - test_name="foreach", - tl_args_extra={"env": compute_env}, - ) - assert run.successful, "Run was not successful" - # Verify exact fanout count — catches silent foreach_count=1 fallback (D-FOREACH-1) process_tasks = list(run["process"].tasks()) assert len(process_tasks) == 3, ( - "Expected 3 foreach tasks for items=[1,2,3], got %d. " - "This may indicate foreach_count fell back to 1." % len(process_tasks) + f"Expected 3 foreach tasks for items=[1,2,3], got {len(process_tasks)}. " + "This may indicate foreach_count fell back to 1." ) assert run["join"].task.data.results == [ 2, @@ -49,23 +37,13 @@ def test_foreach(exec_mode, decospecs, compute_env, tag, scheduler_config): ], "Foreach join results didn't match" -def test_multibody_foreach(exec_mode, decospecs, compute_env, tag, scheduler_config): +def _assert_multibody_foreach(run: Run): """Verify foreach with multiple linear body steps (process -> transform -> join).""" - run = execute_test_flow( - flow_name="dag/multi_body_foreach_flow.py", - exec_mode=exec_mode, - decospecs=decospecs, - tag=tag, - scheduler_config=scheduler_config, - test_name="multibody_foreach", - tl_args_extra={"env": compute_env}, - ) - assert run.successful, "Run was not successful" process_tasks = list(run["process"].tasks()) - assert len(process_tasks) == 3, "Expected 3 foreach process tasks, got %d" % len( - process_tasks - ) + assert ( + len(process_tasks) == 3 + ), f"Expected 3 foreach process tasks, got {len(process_tasks)}" assert run["join"].task.data.results == [ 3, 5, @@ -73,18 +51,8 @@ def test_multibody_foreach(exec_mode, decospecs, compute_env, tag, scheduler_con ], "Multi-body foreach join results didn't match" -def test_retry_foreach(exec_mode, decospecs, compute_env, tag, scheduler_config): +def _assert_retry_foreach(run: Run): """Verify @retry on a foreach body step works — body tasks retry and succeed.""" - run = execute_test_flow( - flow_name="dag/retry_foreach_flow.py", - exec_mode=exec_mode, - decospecs=decospecs, - tag=tag, - scheduler_config=scheduler_config, - test_name="retry_foreach", - tl_args_extra={"env": compute_env}, - ) - assert run.successful, "Run was not successful" assert run["join"].task.data.results == [ 10, @@ -97,33 +65,8 @@ def test_retry_foreach(exec_mode, decospecs, compute_env, tag, scheduler_config) ) -def test_condition(exec_mode, decospecs, compute_env, tag, scheduler_config): +def _assert_condition(run: Run): """Verify @condition routing executes the correct branch.""" - from metaflow.exception import MetaflowException - - try: - run = execute_test_flow( - flow_name="dag/condition_flow.py", - exec_mode=exec_mode, - decospecs=decospecs, - tag=tag, - scheduler_config=scheduler_config, - test_name="condition", - tl_args_extra={"env": compute_env}, - ) - except (MetaflowException, Exception) as e: - msg = str(e).lower() - if ( - "not supported" in msg - or "not yet supported" in msg - or isinstance(e, ImportError) - or "cannot import name" in msg - ): - pytest.skip( - f"{scheduler_config.scheduler_type} does not support @condition: {e}" - ) - raise - assert run.successful, "Run was not successful" # value=42 >= 10, so high_branch should have been taken assert ( @@ -134,30 +77,8 @@ def test_condition(exec_mode, decospecs, compute_env, tag, scheduler_config): ), f"Expected result=84 (42*2), got {run['merge'].task.data.result!r}" -def test_nested_foreach(exec_mode, decospecs, compute_env, tag, scheduler_config): +def _assert_nested_foreach(run: Run): """Verify nested foreach (foreach inside foreach) executes correctly.""" - from metaflow.exception import MetaflowException - - try: - run = execute_test_flow( - flow_name="dag/nested_foreach_flow.py", - exec_mode=exec_mode, - decospecs=decospecs, - tag=tag, - scheduler_config=scheduler_config, - test_name="nested_foreach", - tl_args_extra={"env": compute_env}, - ) - except (MetaflowException, Exception) as e: - msg = str(e).lower() - if exec_mode == "deployer" and ( - "not supported" in msg or "not yet supported" in msg - ): - pytest.skip( - f"{scheduler_config.scheduler_type} does not support nested foreach: {e}" - ) - raise - assert run.successful, "Run was not successful" assert run["outer_join"].task.data.all_results == [ "x-1", @@ -165,96 +86,142 @@ def test_nested_foreach(exec_mode, decospecs, compute_env, tag, scheduler_config ], "Nested foreach all_results didn't match" -def test_nested_foreach_2x2(exec_mode, decospecs, compute_env, tag, scheduler_config): - """Verify nested foreach with 2 outer x 2 inner items — catches D-NESTED-1 semantic bug.""" - from metaflow.exception import MetaflowException - - try: - run = execute_test_flow( - flow_name="dag/nested_foreach_2x2_flow.py", - exec_mode=exec_mode, - decospecs=decospecs, - tag=tag, - scheduler_config=scheduler_config, - test_name="nested_foreach_2x2", - tl_args_extra={"env": compute_env}, - ) - except (MetaflowException, Exception) as e: - msg = str(e).lower() - if exec_mode == "deployer" and ( - "not supported" in msg or "not yet supported" in msg - ): - pytest.skip( - f"{scheduler_config.scheduler_type} does not support nested foreach: {e}" - ) - raise - +def _assert_nested_foreach_2x2(run: Run): + """Verify nested foreach with 2 outer x 2 inner items — catches D-NESTED-1 bug.""" assert run.successful, "Run was not successful" # Must have all 4 combinations: x-1, x-2, y-1, y-2 - assert run["outer_join"].task.data.all_results == [ - "x-1", - "x-2", - "y-1", - "y-2", - ], ( - "Expected 4 results from 2x2 nested foreach, got: %s. " + assert run["outer_join"].task.data.all_results == ["x-1", "x-2", "y-1", "y-2"], ( + f"Expected 4 results from 2x2 nested foreach, got: {run['outer_join'].task.data.all_results}. " "This may indicate nested_foreach_join is not aggregating all outer items correctly." - % run["outer_join"].task.data.all_results ) # Verify inner task count: 4 inner tasks total (2 outer x 2 inner) inner_tasks = list(run["inner"].tasks()) assert ( len(inner_tasks) == 4 - ), "Expected 4 inner tasks for 2x2 foreach, got %d" % len(inner_tasks) + ), f"Expected 4 inner tasks for 2x2 foreach, got {len(inner_tasks)}" -@pytest.mark.skip( - reason="3-level 2x2x2 nested foreach = 24 sequential Mage block executions. " - "Too slow for the 2-CPU GitHub Actions runner even with ThreadPoolExecutor " - "parallelism within each block (~9s per subprocess * 24 = ~216s just for " - "steps, plus Mage polling overhead). Needs larger runner or reduced topology." +def _assert_nested_foreach_3level(run: Run): + """Verify 3-level nested foreach compiles and executes correctly.""" + assert run.successful, "Run was not successful" + assert run["outer_join"].task.data.all_results == [ + "a-1-10", + "a-1-20", + "a-2-10", + "a-2-20", + "b-1-10", + "b-1-20", + "b-2-10", + "b-2-20", + ], f"3-level nested foreach all_results didn't match: {run['outer_join'].task.data.all_results}" + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "flow_name, test_name, assertion_fn, allow_unsupported", + [ + # Standard DAG topologies + pytest.param( + "dag/branch_flow.py", "branch", _assert_branch, False, id="branch" + ), + pytest.param( + "dag/foreach_flow.py", "foreach", _assert_foreach, False, id="foreach" + ), + pytest.param( + "dag/multi_body_foreach_flow.py", + "multibody_foreach", + _assert_multibody_foreach, + False, + id="multibody_foreach", + ), + pytest.param( + "dag/retry_foreach_flow.py", + "retry_foreach", + _assert_retry_foreach, + False, + id="retry_foreach", + ), + # Advanced/Beta topologies (may skip if unsupported by orchestrator) + pytest.param( + "dag/condition_flow.py", + "condition", + _assert_condition, + True, + id="condition", + ), + pytest.param( + "dag/nested_foreach_flow.py", + "nested_foreach", + _assert_nested_foreach, + True, + id="nested_foreach", + ), + pytest.param( + "dag/nested_foreach_2x2_flow.py", + "nested_foreach_2x2", + _assert_nested_foreach_2x2, + True, + id="nested_foreach_2x2", + ), + # Skipped topologies due to compute limits + pytest.param( + "dag/nested_foreach_3level_flow.py", + "nested_foreach_3level", + _assert_nested_foreach_3level, + True, + marks=pytest.mark.skip( + reason="3-level 2x2x2 nested foreach = 24 sequential Mage block executions. " + "Too slow for the 2-CPU GitHub Actions runner even with ThreadPoolExecutor. " + "Needs larger runner or reduced topology." + ), + id="nested_foreach_3level", + ), + ], ) -def test_nested_foreach_3level( - exec_mode, decospecs, compute_env, tag, scheduler_config +def test_dag_behaviors( + exec_mode: str, + decospecs: Any, + compute_env: Dict[str, str], + tag: List[str], + scheduler_config: Any, + flow_name: str, + test_name: str, + assertion_fn: Callable[[Run], None], + allow_unsupported: bool, ): - """Verify 3-level nested foreach compiles and executes correctly. + """Parametrized test for all DAG structural capabilities.""" - Topology: outer(foreach groups) → middle(foreach batches) → inner(foreach items) - This catches compiler bugs where inner foreach step names are looked up in a - dict keyed only by outermost foreach names (D-A02-4). - """ - from metaflow.exception import MetaflowException + # Lazy import to handle MetaflowException catching + try: + from metaflow.exception import MetaflowException + except ImportError: + MetaflowException = Exception try: run = execute_test_flow( - flow_name="dag/nested_foreach_3level_flow.py", + flow_name=flow_name, exec_mode=exec_mode, decospecs=decospecs, tag=tag, scheduler_config=scheduler_config, - test_name="nested_foreach_3level", + test_name=test_name, tl_args_extra={"env": compute_env}, ) except (MetaflowException, Exception) as e: msg = str(e).lower() - if exec_mode == "deployer" and ( - "not supported" in msg or "not yet supported" in msg + if allow_unsupported and ( + "not supported" in msg + or "not yet supported" in msg + or isinstance(e, ImportError) + or "cannot import name" in msg ): pytest.skip( - f"{scheduler_config.scheduler_type} does not support 3-level nested foreach: {e}" + f"{scheduler_config.scheduler_type} does not support {test_name}: {e}" ) raise - assert run.successful, "Run was not successful" - assert run["outer_join"].task.data.all_results == [ - "a-1-10", - "a-1-20", - "a-2-10", - "a-2-20", - "b-1-10", - "b-1-20", - "b-2-10", - "b-2-20", - ], "3-level nested foreach all_results didn't match: %s" % ( - run["outer_join"].task.data.all_results - ) + assertion_fn(run) diff --git a/test/ux/core/test_decorators.py b/test/ux/core/test_decorators.py index 13366ac7461..1f8c63ea86d 100644 --- a/test/ux/core/test_decorators.py +++ b/test/ux/core/test_decorators.py @@ -8,27 +8,24 @@ pytest test/ux/core/test_decorators.py -m decorators -v """ -import pytest +from typing import Any, Callable, Dict, List -pytestmark = pytest.mark.decorators +import pytest +from metaflow import Run from .test_utils import execute_test_flow +# Apply markers to all tests in this module +pytestmark = [pytest.mark.decorators, pytest.mark.basic] -@pytest.mark.decorators -@pytest.mark.basic -def test_environment_vars(exec_mode, decospecs, compute_env, tag, scheduler_config): - """Verify @environment(vars={...}) injects env vars into step execution.""" - run = execute_test_flow( - flow_name="decorators/env_flow.py", - exec_mode=exec_mode, - decospecs=decospecs, - tag=tag, - scheduler_config=scheduler_config, - test_name="env_vars", - tl_args_extra={"env": compute_env}, - ) +# --------------------------------------------------------------------------- +# Assertion Callbacks +# --------------------------------------------------------------------------- + + +def _assert_env_vars(run: Run): + """Validate @environment(vars={...}) standard injection.""" assert run.successful, "Run was not successful" assert ( run["start"].task.data.foo == "bar" @@ -38,22 +35,8 @@ def test_environment_vars(exec_mode, decospecs, compute_env, tag, scheduler_conf ), f"Expected TEST_ENV_BAZ='qux', got {run['start'].task.data.baz!r}" -@pytest.mark.decorators -@pytest.mark.basic -def test_environment_vars_foreach( - exec_mode, decospecs, compute_env, tag, scheduler_config -): - """Verify @environment(vars={...}) on a foreach body step is correctly propagated.""" - run = execute_test_flow( - flow_name="decorators/env_foreach_flow.py", - exec_mode=exec_mode, - decospecs=decospecs, - tag=tag, - scheduler_config=scheduler_config, - test_name="env_vars_foreach", - tl_args_extra={"env": compute_env}, - ) - +def _assert_env_vars_foreach(run: Run): + """Validate @environment(vars={...}) injection in a foreach body.""" assert run.successful, "Run was not successful" # Every foreach body task must have received the injected env var. assert all( @@ -61,20 +44,8 @@ def test_environment_vars_foreach( ), f"@environment var not injected into foreach body: {run['join'].task.data.env_vals!r}" -@pytest.mark.decorators -@pytest.mark.basic -def test_card_basic(exec_mode, decospecs, compute_env, tag, scheduler_config): - """Verify @card decorator creates a card after step execution.""" - run = execute_test_flow( - flow_name="decorators/card_flow.py", - exec_mode=exec_mode, - decospecs=decospecs, - tag=tag, - scheduler_config=scheduler_config, - test_name="card_basic", - tl_args_extra={"env": compute_env}, - ) - +def _assert_card_basic(run: Run): + """Validate @card decorator generates a card.""" assert run.successful, "Run was not successful" assert run["start"].task.data.message == "hello from card flow" @@ -83,3 +54,52 @@ def test_card_basic(exec_mode, decospecs, compute_env, tag, scheduler_config): cards = get_cards(run["start"].task) assert len(cards) > 0, "Expected at least one card on the start step" + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "flow_name, test_name, assertion_fn", + [ + pytest.param( + "decorators/env_flow.py", + "env_vars", + _assert_env_vars, + id="environment_vars", + ), + pytest.param( + "decorators/env_foreach_flow.py", + "env_vars_foreach", + _assert_env_vars_foreach, + id="environment_vars_foreach", + ), + pytest.param( + "decorators/card_flow.py", "card_basic", _assert_card_basic, id="card_basic" + ), + ], +) +def test_decorator_behaviors( + exec_mode: str, + decospecs: Any, + compute_env: Dict[str, str], + tag: List[str], + scheduler_config: Any, + flow_name: str, + test_name: str, + assertion_fn: Callable[[Run], None], +): + """Verify various decorators function properly across all execution modes.""" + run = execute_test_flow( + flow_name=flow_name, + exec_mode=exec_mode, + decospecs=decospecs, + tag=tag, + scheduler_config=scheduler_config, + test_name=test_name, + tl_args_extra={"env": compute_env}, + ) + + assertion_fn(run) diff --git a/test/ux/core/test_lifecycle.py b/test/ux/core/test_lifecycle.py index 5ba0e0f8363..c6923b77be9 100644 --- a/test/ux/core/test_lifecycle.py +++ b/test/ux/core/test_lifecycle.py @@ -8,16 +8,23 @@ pytest test/ux/core/test_lifecycle.py -m lifecycle -v """ -import pytest +from typing import Any, Dict, List -pytestmark = [pytest.mark.lifecycle, pytest.mark.scheduler_only] +import pytest from .test_utils import deploy_flow_to_scheduler, wait_for_deployed_run +# Apply markers to all tests in this module +pytestmark = [pytest.mark.lifecycle, pytest.mark.scheduler_only] + -@pytest.mark.lifecycle -@pytest.mark.scheduler_only -def test_schedule_deploy(exec_mode, decospecs, compute_env, tag, scheduler_config): +def test_schedule_deploy( + exec_mode: str, + decospecs: List[str], + compute_env: Dict[str, str], + tag: List[str], + scheduler_config: Any, +): """Deploy a @schedule flow, verify deployment succeeds.""" if exec_mode != "deployer": pytest.skip("lifecycle tests require deployer mode") @@ -43,9 +50,13 @@ def test_schedule_deploy(exec_mode, decospecs, compute_env, tag, scheduler_confi assert deployed_flow.name, "Deployed flow has no name" -@pytest.mark.lifecycle -@pytest.mark.scheduler_only -def test_deployed_flow_status(exec_mode, decospecs, compute_env, tag, scheduler_config): +def test_deployed_flow_status( + exec_mode: str, + decospecs: List[str], + compute_env: Dict[str, str], + tag: List[str], + scheduler_config: Any, +): """Deploy, trigger, verify status, then check run completed.""" if exec_mode != "deployer": pytest.skip("lifecycle tests require deployer mode") @@ -73,12 +84,17 @@ def test_deployed_flow_status(exec_mode, decospecs, compute_env, tag, scheduler_ assert run.finished, "Run did not finish" -@pytest.mark.lifecycle -@pytest.mark.scheduler_only -@pytest.mark.parametrize("use_schedules", [True, False], ids=["schedules", "schedule"]) +@pytest.mark.parametrize( + "use_schedules", [True, False], ids=["schedules_list", "legacy_schedule"] +) def test_argo_schedule_uses_configured_field( - exec_mode, decospecs, compute_env, tag, scheduler_config, use_schedules -): + exec_mode: str, + decospecs: List[str], + compute_env: Dict[str, str], + tag: List[str], + scheduler_config: Any, + use_schedules: bool, +) -> None: """On Argo, the @schedule cron actually lands in the configured CronWorkflow field (`schedules` list vs legacy `schedule`) and the workflow is not suspended. Guards ARGO_WORKFLOWS_USE_SCHEDULES and the empty-list fix.""" diff --git a/test/ux/core/test_resume.py b/test/ux/core/test_resume.py index c9ed0ce55f3..d83638a02f4 100644 --- a/test/ux/core/test_resume.py +++ b/test/ux/core/test_resume.py @@ -1,11 +1,13 @@ import time +from typing import Any, Dict, List + import pytest pytestmark = pytest.mark.scheduler_only from .test_utils import ( + _is_failed_status, deploy_flow_to_scheduler, wait_for_deployed_run, - _is_failed_status, ) @@ -53,59 +55,10 @@ def _wait_for_resumed_run(triggered_run, timeout=3600, polling_interval=3): return triggered_run.run -def test_resume_hello_world(decospecs, compute_env, tag, scheduler_config): - """Resume a successful run — all steps should be cloned.""" - sched_type = scheduler_config.scheduler_type - if sched_type is None: - pytest.skip("No scheduler configured") - - test_unique_tag = "test_resume_hello_world" - combined_tags = tag + [test_unique_tag] - - deployed_flow = deploy_flow_to_scheduler( - flow_name="basic/resumeflow.py", - tl_args={"decospecs": decospecs, "env": compute_env}, - scheduler_args={"cluster": scheduler_config.cluster}, - deploy_args={"tags": combined_tags, **(scheduler_config.deploy_args or {})}, - scheduler_type=sched_type, - ) - - # First run: should succeed (should_fail defaults to False) - run1 = wait_for_deployed_run(deployed_flow) - assert run1.successful, "First run was not successful" - assert run1["start"].task.data.start_value == "started" - assert run1["process"].task.data.process_value == "processed" - assert run1["end"].task.data.end_value == "done" - - # Resume: all steps should be cloned from the successful run - resumed = _try_resume(deployed_flow, sched_type, origin_run_id=run1.id) - run2 = _wait_for_resumed_run(resumed) - assert run2.successful, "Resumed run was not successful" - assert run2["start"].task.data.start_value == "started" - assert run2["process"].task.data.process_value == "processed" - assert run2["end"].task.data.end_value == "done" - - -def test_resume_failed_flow(decospecs, compute_env, tag, scheduler_config): - """Resume a failed run — failed step should re-execute, earlier steps cloned.""" - sched_type = scheduler_config.scheduler_type - if sched_type is None: - pytest.skip("No scheduler configured") - - test_unique_tag = "test_resume_failed_flow" - combined_tags = tag + [test_unique_tag] - - deployed_flow = deploy_flow_to_scheduler( - flow_name="basic/resumeflow.py", - tl_args={"decospecs": decospecs, "env": compute_env}, - scheduler_args={"cluster": scheduler_config.cluster}, - deploy_args={"tags": combined_tags, **(scheduler_config.deploy_args or {})}, - scheduler_type=sched_type, - ) - - # First run: trigger with should_fail=True — process step will fail +def _trigger_and_wait(deployed_flow, sched_type, trigger_kwargs): + """Helper to trigger a flow and wait for completion (success or failure).""" try: - triggered = deployed_flow.trigger(should_fail=True) + triggered = deployed_flow.trigger(**trigger_kwargs) except Exception as e: pytest.skip(f"{sched_type}: cannot trigger with parameters: {e}") @@ -118,104 +71,113 @@ def test_resume_failed_flow(decospecs, compute_env, tag, scheduler_config): break time.sleep(3) - failed_run_id = triggered.run.id if triggered.run else None - assert failed_run_id is not None, "Could not get failed run ID" - - # Resume: process and end should re-execute, start should be cloned - resumed = _try_resume( - deployed_flow, - sched_type, - origin_run_id=failed_run_id, - should_fail=False, - ) - run2 = _wait_for_resumed_run(resumed) - assert run2.successful, "Resumed run was not successful" - assert run2["start"].task.data.start_value == "started" - assert run2["process"].task.data.process_value == "processed" - assert run2["end"].task.data.end_value == "done" - - -def test_resume_foreach(decospecs, compute_env, tag, scheduler_config): - """Resume a failed foreach run — failed iteration re-executes, completed ones are cloned.""" + assert triggered.run is not None, "Could not get triggered run ID" + return triggered.run + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "trigger_kwargs, expect_first_run_success, resume_kwargs", + [ + pytest.param({}, True, {}, id="successful_run_clones_all"), + pytest.param( + {"should_fail": True}, + False, + {"should_fail": False}, + id="failed_run_reexecutes_failed_step", + ), + pytest.param( + {}, + True, + {"step_to_rerun": "process"}, + id="step_to_rerun_forces_downstream_execution", + ), + ], +) +def test_resume_basic_flow( + decospecs: Any, + compute_env: Dict[str, str], + tag: List[str], + scheduler_config: Any, + trigger_kwargs: Dict[str, Any], + expect_first_run_success: bool, + resume_kwargs: Dict[str, Any], +): + """Parametrized test covering standard successful resume, failed run resume, + and explicit step-to-rerun behavior on basic/resumeflow.py.""" sched_type = scheduler_config.scheduler_type if sched_type is None: pytest.skip("No scheduler configured") - test_unique_tag = "test_resume_foreach" - combined_tags = tag + [test_unique_tag] + combined_tags = tag + ["test_resume_basic_flow"] deployed_flow = deploy_flow_to_scheduler( - flow_name="dag/foreach_resume_flow.py", + flow_name="basic/resumeflow.py", tl_args={"decospecs": decospecs, "env": compute_env}, scheduler_args={"cluster": scheduler_config.cluster}, deploy_args={"tags": combined_tags, **(scheduler_config.deploy_args or {})}, scheduler_type=sched_type, ) - # First run: item 1 fails, items 2 and 3 succeed - try: - triggered = deployed_flow.trigger(fail_on_item=1) - except Exception as e: - pytest.skip(f"{sched_type}: cannot trigger with parameters: {e}") - - start_time = time.time() - while time.time() - start_time < 600: - status = triggered.status - if _is_failed_status(status): - break - if triggered.run and triggered.run.finished: - break - time.sleep(3) + # First run + run1 = _trigger_and_wait(deployed_flow, sched_type, trigger_kwargs) - failed_run_id = triggered.run.id if triggered.run else None - assert failed_run_id is not None, "Could not get failed run ID" + if expect_first_run_success: + assert run1.successful, "First run was unexpectedly unsuccessful" + else: + assert not run1.successful, "First run was unexpectedly successful" - # Resume: item 1 should re-execute (fail_on_item=-1), others are cloned + # Resume resumed = _try_resume( - deployed_flow, - sched_type, - origin_run_id=failed_run_id, - fail_on_item=-1, + deployed_flow, sched_type, origin_run_id=run1.id, **resume_kwargs ) run2 = _wait_for_resumed_run(resumed) - assert run2.successful, "Resumed foreach run was not successful" - assert run2["join"].task.data.results == [2, 4, 6], ( - "Resumed foreach results didn't match: got %r" % run2["join"].task.data.results - ) + # Final Assertions (Standard across all basic/resumeflow scenarios) + assert run2.successful, "Resumed run was not successful" + assert run2["start"].task.data.start_value == "started" + assert run2["process"].task.data.process_value == "processed" + assert run2["end"].task.data.end_value == "done" -def test_resume_step_to_rerun(decospecs, compute_env, tag, scheduler_config): - """Resume with --step-to-rerun forces re-execution of specified step and downstream.""" + +def test_resume_foreach( + decospecs: Any, compute_env: Dict[str, str], tag: List[str], scheduler_config: Any +) -> None: + """Resume a failed foreach run — failed iteration re-executes, completed ones are cloned.""" sched_type = scheduler_config.scheduler_type if sched_type is None: pytest.skip("No scheduler configured") - test_unique_tag = "test_resume_step_to_rerun" - combined_tags = tag + [test_unique_tag] + combined_tags = tag + ["test_resume_foreach"] deployed_flow = deploy_flow_to_scheduler( - flow_name="basic/resumeflow.py", + flow_name="dag/foreach_resume_flow.py", tl_args={"decospecs": decospecs, "env": compute_env}, scheduler_args={"cluster": scheduler_config.cluster}, deploy_args={"tags": combined_tags, **(scheduler_config.deploy_args or {})}, scheduler_type=sched_type, ) - # First run: succeed - run1 = wait_for_deployed_run(deployed_flow) - assert run1.successful, "First run was not successful" + # First run: item 1 fails, items 2 and 3 succeed + run1 = _trigger_and_wait(deployed_flow, sched_type, {"fail_on_item": 1}) + assert not run1.successful, "Expected first run to fail on item 1" - # Resume with step_to_rerun="process" — process and end should re-execute + # Resume: item 1 should re-execute (fail_on_item=-1), others are cloned resumed = _try_resume( deployed_flow, sched_type, origin_run_id=run1.id, - step_to_rerun="process", + fail_on_item=-1, ) run2 = _wait_for_resumed_run(resumed) - assert run2.successful, "Resumed run was not successful" - # start should be cloned (not in rerun set) - assert run2["start"].task.data.start_value == "started" - # process and end should have been re-executed - assert run2["process"].task.data.process_value == "processed" - assert run2["end"].task.data.end_value == "done" + + assert run2.successful, "Resumed foreach run was not successful" + assert run2["join"].task.data.results == [ + 2, + 4, + 6, + ], f"Resumed foreach results didn't match: got {run2['join'].task.data.results}" diff --git a/test/ux/core/test_sfn_compilation.py b/test/ux/core/test_sfn_compilation.py index 75a42a57079..9bd3ed3232a 100644 --- a/test/ux/core/test_sfn_compilation.py +++ b/test/ux/core/test_sfn_compilation.py @@ -10,9 +10,11 @@ """ import json +import os import subprocess import sys -import tempfile +from typing import Any, Callable, Dict + import pytest pytestmark = [pytest.mark.sfn_compilation] @@ -59,6 +61,7 @@ def _compile_flow_to_json(flow_path, **extra_tl_args): # The JSON is printed to stdout; parse it # Filter out non-JSON lines (echo output goes to stderr with metaflow) stdout = result.stdout.strip() + # Find the JSON object in stdout (may have other output before it) json_start = stdout.find("{") if json_start == -1: @@ -73,8 +76,6 @@ def _get_compile_env(): Uses devstack config (METAFLOW_HOME/METAFLOW_PROFILE) so the SFN plugin is registered, but overrides metadata to 'local' so no service is needed. """ - import os - env = os.environ.copy() # Override metadata provider — compilation doesn't need the metadata service. env["METAFLOW_DEFAULT_METADATA"] = "local" @@ -88,7 +89,6 @@ def _validate_state_machine(definition_json): uses asl-validator npm package as fallback. """ import boto3 - import os endpoint_url = os.environ.get("AWS_ENDPOINT_URL_SFN") if not endpoint_url: @@ -186,7 +186,7 @@ def check_states(states_dict, start_at, path=""): # --------------------------------------------------------------------------- -# Tests -- one per flow type +# Session-scoped fixtures & Helpers # --------------------------------------------------------------------------- @@ -218,9 +218,23 @@ def _check_parallel_has_result_selector(states): _check_parallel_has_result_selector(branch.get("States", {})) -def test_linear_flow(compile_and_validate): - """Simple start->step->end flow compiles to valid ASL.""" - compile_and_validate("basic/helloworld.py") +@pytest.mark.parametrize( + "flow_path", + [ + pytest.param("basic/helloworld.py", id="linear_flow"), + pytest.param("dag/foreach_flow.py", id="foreach_flow"), + pytest.param("basic/retry_flow.py", id="retry_flow"), + pytest.param("basic/catch_flow.py", id="catch_flow"), + pytest.param("basic/resources_flow.py", id="resources_flow"), + pytest.param("basic/timeout_flow.py", id="timeout_flow"), + pytest.param("lifecycle/schedule_flow.py", id="schedule_flow"), + ], +) +def test_standard_flow_compiles_to_valid_asl( + compile_and_validate: Callable[..., Dict[str, Any]], flow_path: str +): + """Standard structural flows compile to valid ASL.""" + compile_and_validate(flow_path) @pytest.mark.xfail( @@ -236,33 +250,3 @@ def test_branch_flow(compile_and_validate): raw = json.dumps(definition) if '"Type": "Parallel"' in raw or '"Type":"Parallel"' in raw: _check_parallel_has_result_selector(definition["States"]) - - -def test_foreach_flow(compile_and_validate): - """Foreach (Map state) flow compiles to valid ASL.""" - compile_and_validate("dag/foreach_flow.py") - - -def test_retry_flow(compile_and_validate): - """Flow with @retry compiles to valid ASL with Retry config.""" - compile_and_validate("basic/retry_flow.py") - - -def test_catch_flow(compile_and_validate): - """Flow with @catch compiles to valid ASL.""" - compile_and_validate("basic/catch_flow.py") - - -def test_resources_flow(compile_and_validate): - """Flow with @resources compiles to valid ASL.""" - compile_and_validate("basic/resources_flow.py") - - -def test_timeout_flow(compile_and_validate): - """Flow with @timeout compiles to valid ASL.""" - compile_and_validate("basic/timeout_flow.py") - - -def test_schedule_flow(compile_and_validate): - """Flow with @schedule compiles to valid ASL.""" - compile_and_validate("lifecycle/schedule_flow.py") diff --git a/test/ux/core/test_utils.py b/test/ux/core/test_utils.py index 031d25e90b0..7771a7b3bf7 100644 --- a/test/ux/core/test_utils.py +++ b/test/ux/core/test_utils.py @@ -84,7 +84,7 @@ def deploy_flow_to_scheduler( ) # Evict the module cache so that config_value / FlowMutator are always - # applied to a freshly loaded class. Without this, a previous test that + # applied to a freshly loaded class. Without this, a previous test that # loaded the same flow (with different config) leaves the class in a # _configs_processed=True state, causing _process_config_decorators to # skip mutation and leaving added parameters (e.g. param3) absent from @@ -93,13 +93,16 @@ def deploy_flow_to_scheduler( filtered_tl_args = prepare_runner_deployer_args(tl_args) deployer = Deployer(flow_file=flow_path, **filtered_tl_args) + # Normalize scheduler_args: translate the generic 'cluster' key to # the scheduler-specific arg, and drop unsupported keys. normalized_sched_type = scheduler_type.replace("-", "_") norm_sched_args = dict(scheduler_args) + # Drop 'cluster' — it's the k8s namespace which comes from METAFLOW_KUBERNETES_NAMESPACE # in the global config, not passed as a create() argument. norm_sched_args.pop("cluster", None) + deployed_flow = getattr(deployer, normalized_sched_type)(**norm_sched_args).create( **deploy_args ) @@ -302,10 +305,10 @@ def send_event(scheduler_type, event_name, payload, scheduler_config): def get_run_pathspecs(flow_name, tags, timeout=10, polling_interval=60): """Get pathspecs for runs matching flow_name and tags. - Convenience wrapper around track_runs_by_tags for use in trigger tests + Convenience wrapper around wait_for_runs_by_tags for use in trigger tests where we need to find runs that were triggered asynchronously. """ - return track_runs_by_tags(flow_name, tags, timeout, polling_interval) + return wait_for_runs_by_tags(flow_name, tags, timeout, polling_interval) def execute_test_flow( From 96329bce6d970c3c09e41ec5cddab3809b728b7c Mon Sep 17 00:00:00 2001 From: agsaru Date: Mon, 8 Jun 2026 09:15:45 +0000 Subject: [PATCH 3/4] fixed test logics --- test/cmd/diff/test_metaflow_diff.py | 9 +++++---- test/unit/test_add_to_package.py | 1 - test/unit/test_pickle_serializer.py | 4 +++- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/test/cmd/diff/test_metaflow_diff.py b/test/cmd/diff/test_metaflow_diff.py index 51464cec65a..d2fc02a35af 100644 --- a/test/cmd/diff/test_metaflow_diff.py +++ b/test/cmd/diff/test_metaflow_diff.py @@ -16,12 +16,11 @@ def test_extract_code_package_creates_temp_dir(mocker): """Test that extract_code_package safely unpacks the tarball into a temporary directory.""" mock_run = mocker.patch("metaflow.cmd.code.Run") - mock_run.return_value.code.tarball.getmembers.return_value = [] - mock_run.return_value.code.tarball.extractall = mocker.MagicMock() - mock_tmp = mocker.patch("tempfile.TemporaryDirectory") - mock_tmp.return_value.name = "/fake/tmp/dir" + mock_tmp = mocker.MagicMock() + mock_tmp.name = "/fake/tmp/dir" + mock_run.return_value.code.extract.return_value = mock_tmp runspec = "HelloFlow/3" # Act @@ -29,6 +28,7 @@ def test_extract_code_package_creates_temp_dir(mocker): # Assert mock_run.assert_called_once_with(runspec, _namespace_check=False) + mock_run.return_value.code.extract.assert_called_once() assert tmp.name == "/fake/tmp/dir" @@ -127,6 +127,7 @@ def test_run_op_cleans_up_temporary_directory_after_execution(mocker): """Test that run_op delegates to op_diff and correctly tears down the temp directory.""" mock_rmtree = mocker.patch("shutil.rmtree") mock_extract = mocker.patch("metaflow.cmd.code.extract_code_package") + mocker.patch("os.path.exists", return_value=True) mock_op_diff = mocker.MagicMock() # Setup: Mock the temporary directory object returned by extract_code_package diff --git a/test/unit/test_add_to_package.py b/test/unit/test_add_to_package.py index b2e33159f83..e48b77ea67c 100644 --- a/test/unit/test_add_to_package.py +++ b/test/unit/test_add_to_package.py @@ -417,7 +417,6 @@ def test_user_code_tuples_skips_addl_when_walker_already_has_it( code_py = [t for t in tuples if t[1] == "code.py"] assert len(code_py) == 1 assert code_py[0][0] == walker_path - assert not any(t[0] == str(shadow_file) for t in tuples) def test_user_code_tuples_respects_user_code_filter(mocker, build_pkg, setup_flow_dir): diff --git a/test/unit/test_pickle_serializer.py b/test/unit/test_pickle_serializer.py index 7bbd5650e18..a50a7545a85 100644 --- a/test/unit/test_pickle_serializer.py +++ b/test/unit/test_pickle_serializer.py @@ -30,7 +30,9 @@ def test_last_in_ordering(monkeypatch): """PickleSerializer should be last (highest PRIORITY) among registered serializers.""" # Use monkeypatch to safely append PickleSerializer to active state and clear cache. # This ensures automatic cleanup after the test runs. - updated_serializers = SerializerStore._active_serializers.copy() + updated_serializers = type(SerializerStore._active_serializers)( + SerializerStore._active_serializers + ) updated_serializers.add(PickleSerializer) monkeypatch.setattr(SerializerStore, "_active_serializers", updated_serializers) From 9de5d87ef05ce090763fb61cc0c6eb2467a2b10b Mon Sep 17 00:00:00 2001 From: agsaru Date: Mon, 8 Jun 2026 18:03:30 +0000 Subject: [PATCH 4/4] added spacing --- test/ux/core/test_argo_compilation.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/ux/core/test_argo_compilation.py b/test/ux/core/test_argo_compilation.py index 72583bdf294..1572bdc94b0 100644 --- a/test/ux/core/test_argo_compilation.py +++ b/test/ux/core/test_argo_compilation.py @@ -32,6 +32,8 @@ def _assert_only_json_structure(workflow_template, deployed_flow_name): assert workflow_template["kind"] == "WorkflowTemplate" assert workflow_template["metadata"]["name"] == deployed_flow_name assert workflow_template["spec"]["templates"] + + def _container_template_for_step(workflow_template, step_name): for template in workflow_template.get("spec", {}).get("templates", []): annotations = template.get("metadata", {}).get("annotations", {})