From 47edaafb7fe9e1991cb9cadad4391b523fcebd4b Mon Sep 17 00:00:00 2001 From: agsaru Date: Mon, 8 Jun 2026 19:14:55 +0000 Subject: [PATCH 1/2] refactor(tests): storage, serializer, and cloud based unit tests to follow test conventions --- test/unit/test_argo_workflows_cli.py | 194 ++-- test/unit/test_artifact_serializer.py | 459 ++++----- test/unit/test_aws_util.py | 46 +- test/unit/test_compute_resource_attributes.py | 160 +-- test/unit/test_content_addressed_store.py | 55 +- test/unit/test_kubernetes.py | 50 +- test/unit/test_local_metadata_provider.py | 61 +- test/unit/test_pickle_serializer.py | 32 +- test/unit/test_s3_empty_input.py | 244 ++--- test/unit/test_s3_storage.py | 37 +- test/unit/test_serializer_integration.py | 382 +++---- test/unit/test_serializer_lifecycle.py | 928 +++++++----------- test/unit/test_serializer_public_api.py | 92 +- test/unit/test_to_pod.py | 75 +- 14 files changed, 1373 insertions(+), 1442 deletions(-) diff --git a/test/unit/test_argo_workflows_cli.py b/test/unit/test_argo_workflows_cli.py index a9fa8379132..c91c782f379 100644 --- a/test/unit/test_argo_workflows_cli.py +++ b/test/unit/test_argo_workflows_cli.py @@ -8,37 +8,15 @@ ArgoWorkflowsDeployedFlow, ) - -@pytest.mark.parametrize( - "name, expected", - [ - ("a-valid-name", "a-valid-name"), - ("removing---@+_characters@_+", "removing---characters"), - ("numb3rs-4r3-0k-123", "numb3rs-4r3-0k-123"), - ("proj3ct.br4nch.flow_name", "proj3ct.br4nch.flowname"), - # should not break RFC 1123 subdomain requirements, - # though trailing characters do not need to be sanitized due to a hash being appended to them. - ( - "---1breaking1---.--2subdomain2--.-3rules3----", - "1breaking1.2subdomain2.3rules3----", - ), - ( - "1brea---king1.2sub---domain2.-3ru-les3--", - "1brea---king1.2sub---domain2.3ru-les3--", - ), - ("project.branch-cut-short-.flowname", "project.branch-cut-short.flowname"), - ("test...name", "test.name"), - ], -) -def test_sanitize_for_argo(name, expected): - sanitized = sanitize_for_argo(name) - assert sanitized == expected +# --------------------------------------------------------------------------- +# Shared Fixtures +# --------------------------------------------------------------------------- @pytest.fixture def make_argo_with_schedule(): - """ - Factory fixture: returns a callable that builds a minimal ArgoWorkflows-like + """Factory fixture: returns a callable that builds a minimal ArgoWorkflows-like + object whose _get_schedule() can be called without instantiating the full class (which requires a live graph, environment, datastore, etc.). @@ -67,6 +45,47 @@ def argo_without_schedule(): return instance +# --------------------------------------------------------------------------- +# Argo Sanitization and Scheduling Tests +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "name, expected", + [ + ("a-valid-name", "a-valid-name"), + ("removing---@+_characters@_+", "removing---characters"), + ("numb3rs-4r3-0k-123", "numb3rs-4r3-0k-123"), + ("proj3ct.br4nch.flow_name", "proj3ct.br4nch.flowname"), + # should not break RFC 1123 subdomain requirements, + # though trailing characters do not need to be sanitized due to a hash being appended to them. + ( + "---1breaking1---.--2subdomain2--.-3rules3----", + "1breaking1.2subdomain2.3rules3----", + ), + ( + "1brea---king1.2sub---domain2.-3ru-les3--", + "1brea---king1.2sub---domain2.3ru-les3--", + ), + ("project.branch-cut-short-.flowname", "project.branch-cut-short.flowname"), + ("test...name", "test.name"), + ], + ids=[ + "valid-string", + "strip-special-chars", + "alphanumeric-with-dashes", + "strip-subdomain-underscores", + "rfc1123-subdomain-edge-dashes", + "rfc1123-subdomain-internal-dashes", + "strip-trailing-dash-before-dot", + "collapse-consecutive-dots", + ], +) +def test_sanitize_for_argo(name, expected): + """Verify string processing safely conforms to Argo resource naming limitations.""" + assert sanitize_for_argo(name) == expected + + def test_get_schedule_no_decorator_returns_none(argo_without_schedule): """No @schedule decorator → (None, None).""" assert argo_without_schedule._get_schedule() == (None, None) @@ -94,70 +113,95 @@ def test_get_schedule_no_decorator_returns_none(argo_without_schedule): def test_get_schedule( make_argo_with_schedule, schedule_value, timezone_value, expected ): + """Verify cron parsing extraction, payload truncation, and timezone options.""" argo = make_argo_with_schedule( schedule_value=schedule_value, timezone_value=timezone_value ) assert argo._get_schedule() == expected -def test_trigger_explanation_no_schedule_does_not_claim_cronworkflow( +@pytest.mark.parametrize( + "has_decorator, schedule_value, internal_schedule, flow_name, expected_substrings, unexpected_substrings", + [ + (False, None, None, None, [], ["CronWorkflow"]), + (True, None, None, None, [], ["CronWorkflow"]), + ( + True, + "0 0 * * ? *", + "0 0 * * ?", + "myflow", + ["CronWorkflow", "myflow"], + [], + ), + ], + ids=[ + "no-schedule-decorator", + "schedule-decorator-resolves-to-none", + "active-schedule-claims-cronworkflow", + ], +) +def test_trigger_explanation_behavior( + make_argo_with_schedule, argo_without_schedule, + has_decorator, + schedule_value, + internal_schedule, + flow_name, + expected_substrings, + unexpected_substrings, ): - """With no schedule, trigger_explanation() must not mention CronWorkflow.""" - argo_without_schedule._schedule = None - argo_without_schedule.triggers = [] - result = argo_without_schedule.trigger_explanation() - assert "CronWorkflow" not in result - + """Verify if trigger explanations correctly list or omit CronWorkflow rules based on state.""" + argo = ( + make_argo_with_schedule(schedule_value=schedule_value) + if has_decorator + else argo_without_schedule + ) -def test_trigger_explanation_schedule_none_does_not_claim_cronworkflow( - make_argo_with_schedule, -): - """ - When @schedule is present but resolved to None, trigger_explanation() - must not claim the workflow triggers via a CronWorkflow. - """ - argo = make_argo_with_schedule(schedule_value=None) - argo._schedule = None # mirrors what _get_schedule() would set + argo._schedule = internal_schedule argo.triggers = [] - result = argo.trigger_explanation() - assert "CronWorkflow" not in result + if flow_name: + argo.name = flow_name - -def test_trigger_explanation_active_schedule_claims_cronworkflow( - make_argo_with_schedule, -): - """When a real schedule is set, trigger_explanation() names the CronWorkflow.""" - argo = make_argo_with_schedule(schedule_value="0 0 * * ? *") - argo._schedule = "0 0 * * ?" - argo.name = "myflow" result = argo.trigger_explanation() - assert result is not None - assert "CronWorkflow" in result - assert "myflow" in result - - -def test_deployed_flow_workflow_template_returns_only_json_payload(): - workflow_template = {"kind": "WorkflowTemplate", "metadata": {"name": "myflow"}} - deployer = types.SimpleNamespace( - name="myflow", - flow_name="MyFlow", - metadata="local@user:test", - additional_info={"workflow_template": workflow_template}, - ) - deployed_flow = ArgoWorkflowsDeployedFlow(deployer) + for substring in expected_substrings: + assert substring in result + for substring in unexpected_substrings: + assert substring not in result - assert deployed_flow.workflow_template == workflow_template +# --------------------------------------------------------------------------- +# Deployed Flow Object Tests +# --------------------------------------------------------------------------- -def test_deployed_flow_workflow_template_returns_none_without_payload(): - deployer = types.SimpleNamespace( - name="myflow", - flow_name="MyFlow", - metadata="local@user:test", - ) +@pytest.mark.parametrize( + "additional_info, expected_template", + [ + ( + { + "workflow_template": { + "kind": "WorkflowTemplate", + "metadata": {"name": "myflow"}, + } + }, + {"kind": "WorkflowTemplate", "metadata": {"name": "myflow"}}, + ), + (None, None), + ], + ids=["with-json-payload", "without-payload"], +) +def test_deployed_flow_workflow_template_resolution(additional_info, expected_template): + """Verify workflow template extraction handles missing or present payloads cleanly.""" + fields = { + "name": "myflow", + "flow_name": "MyFlow", + "metadata": "local@user:test", + } + if additional_info is not None: + fields["additional_info"] = additional_info + + deployer = types.SimpleNamespace(**fields) deployed_flow = ArgoWorkflowsDeployedFlow(deployer) - assert deployed_flow.workflow_template is None + assert deployed_flow.workflow_template == expected_template diff --git a/test/unit/test_artifact_serializer.py b/test/unit/test_artifact_serializer.py index 80b2c15dcb0..6033c64ffca 100644 --- a/test/unit/test_artifact_serializer.py +++ b/test/unit/test_artifact_serializer.py @@ -8,17 +8,19 @@ SerializerStore, ) +# --------------------------------------------------------------------------- +# Registry Isolation Setup & Fixtures +# --------------------------------------------------------------------------- # Snapshot the registry before this module's classes are defined. Module-level -# test serializers (_HighPrioritySerializer, ...) self-register at class -# definition time; the module-scoped fixture below removes them at teardown so -# other test modules see an unpolluted registry. +# test serializers self-register at class definition time; the module-scoped +# fixture below removes them at teardown so other modules see an unpolluted registry. _PRE_IMPORT_SNAPSHOT = dict(SerializerStore._all_serializers) _PRE_IMPORT_ACTIVE_SNAPSHOT = set(SerializerStore._active_serializers) @pytest.fixture(scope="module", autouse=True) -def _restore_serializer_registry(): +def _restore_module_serializer_registry(): yield SerializerStore._all_serializers.clear() SerializerStore._all_serializers.update(_PRE_IMPORT_SNAPSHOT) @@ -27,8 +29,21 @@ def _restore_serializer_registry(): SerializerStore._ordered_cache = None +@pytest.fixture +def clean_store(): + """Fixture to cleanly revert mutations to the SerializerStore registry per test.""" + all_snapshot = dict(SerializerStore._all_serializers) + active_snapshot = set(SerializerStore._active_serializers) + yield + SerializerStore._all_serializers.clear() + SerializerStore._all_serializers.update(all_snapshot) + SerializerStore._active_serializers.clear() + SerializerStore._active_serializers.update(active_snapshot) + SerializerStore._ordered_cache = None + + # --------------------------------------------------------------------------- -# Helpers — test serializer subclasses defined inside the test module +# Helpers — Shared Test Serializer Subclasses # --------------------------------------------------------------------------- @@ -105,9 +120,39 @@ def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): raise NotImplementedError -# Dispatch is now driven by _active_serializers (post-Phase-6). The metaclass -# only populates _all_serializers; tests that assert against the ordered -# dispatch list must also mark their classes as active. +class _DualFormatSerializer(ArtifactSerializer): + """Toy serializer that implements both formats for str objects.""" + + TYPE = "test_dual_format" + PRIORITY = 40 + + @classmethod + def can_serialize(cls, obj): + return isinstance(obj, str) + + @classmethod + def can_deserialize(cls, metadata): + return metadata.encoding == "test_dual_format" + + @classmethod + def serialize(cls, obj, format=SerializationFormat.STORAGE): + if format == SerializationFormat.WIRE: + return obj + blob = obj.encode("utf-8") + return ( + [SerializedBlob(blob)], + SerializationMetadata("str", len(blob), "test_dual_format", {}), + ) + + @classmethod + def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): + if format == SerializationFormat.WIRE: + return data + return data[0].decode("utf-8") + + +# Dispatch is driven by _active_serializers. The metaclass populates _all_serializers; +# tests asserting against the ordered dispatch list must also mark these classes active. SerializerStore._active_serializers.add(_HighPrioritySerializer) SerializerStore._active_serializers.add(_LowPrioritySerializer) SerializerStore._active_serializers.add(_SamePrioritySerializer) @@ -115,7 +160,7 @@ def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): # --------------------------------------------------------------------------- -# SerializerStore tests +# SerializerStore Tests # --------------------------------------------------------------------------- @@ -132,37 +177,30 @@ def test_base_class_not_registered(): assert None not in SerializerStore._all_serializers -def test_re_registration_overwrites(): +def test_re_registration_overwrites(clean_store): """A second class with the same TYPE overwrites the first (notebook-friendly).""" - original = SerializerStore._all_serializers["test_high"] - try: - class _ReplacementSerializer(ArtifactSerializer): - TYPE = "test_high" # same as _HighPrioritySerializer - PRIORITY = 1 + class _ReplacementSerializer(ArtifactSerializer): + TYPE = "test_high" # same as _HighPrioritySerializer + PRIORITY = 1 - @classmethod - def can_serialize(cls, obj): - return False + @classmethod + def can_serialize(cls, obj): + return False - @classmethod - def can_deserialize(cls, metadata): - return False + @classmethod + def can_deserialize(cls, metadata): + return False - @classmethod - def serialize(cls, obj, format=SerializationFormat.STORAGE): - raise NotImplementedError + @classmethod + def serialize(cls, obj, format=SerializationFormat.STORAGE): + raise NotImplementedError - @classmethod - def deserialize( - cls, data, metadata=None, format=SerializationFormat.STORAGE - ): - raise NotImplementedError + @classmethod + def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): + raise NotImplementedError - assert SerializerStore._all_serializers["test_high"] is _ReplacementSerializer - finally: - SerializerStore._all_serializers["test_high"] = original - SerializerStore._ordered_cache = None + assert SerializerStore._all_serializers["test_high"] is _ReplacementSerializer def test_priority_ordering(): @@ -172,7 +210,7 @@ def test_priority_ordering(): assert priorities == sorted(priorities) -def test_priority_tie_last_wins(): +def test_priority_tie_last_wins(clean_store): """When PRIORITY is equal, last-registered wins the tie.""" class _TieFirst(ArtifactSerializer): @@ -215,24 +253,19 @@ def serialize(cls, obj, format=SerializationFormat.STORAGE): def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): raise NotImplementedError - try: - SerializerStore._active_serializers.add(_TieFirst) - SerializerStore._active_serializers.add(_TieSecond) - SerializerStore._ordered_cache = None - ordered = SerializerStore.get_ordered_serializers() - idx_first = ordered.index(_TieFirst) - idx_second = ordered.index(_TieSecond) - # _TieSecond was registered LAST, so it should appear BEFORE _TieFirst. - assert idx_second < idx_first, ( - "Expected last-registered (_TieSecond) to come first; got " - "_TieFirst at index %d, _TieSecond at index %d" % (idx_first, idx_second) - ) - finally: - SerializerStore._all_serializers.pop("test_tie_first", None) - SerializerStore._all_serializers.pop("test_tie_second", None) - SerializerStore._active_serializers.discard(_TieFirst) - SerializerStore._active_serializers.discard(_TieSecond) - SerializerStore._ordered_cache = None + SerializerStore._active_serializers.add(_TieFirst) + SerializerStore._active_serializers.add(_TieSecond) + SerializerStore._ordered_cache = None + + ordered = SerializerStore.get_ordered_serializers() + idx_first = ordered.index(_TieFirst) + idx_second = ordered.index(_TieSecond) + + # _TieSecond was registered LAST, so it should appear BEFORE _TieFirst. + assert idx_second < idx_first, ( + f"Expected last-registered (_TieSecond) to come first; got " + f"_TieFirst at index {idx_first}, _TieSecond at index {idx_second}" + ) def test_deterministic_ordering(): @@ -249,12 +282,39 @@ def test_high_priority_before_low(): assert types.index("test_high") < types.index("test_low") +def test_priority_tie_lexicographic_fallback(): + """When PRIORITY and registration index both tie (simulated), class_path lex-sort wins.""" + + # Within a single process, registration indices are unique. To test the tertiary + # key fallback, we manually simulate identical prefixes on the internal sort key. + class _AClass: + __module__ = "z.module" + __qualname__ = "AClass" + PRIORITY = 100 + + class _BClass: + __module__ = "a.module" + __qualname__ = "BClass" + PRIORITY = 100 + + keys = [ + (_AClass.PRIORITY, 0, f"{_AClass.__module__}.{_AClass.__qualname__}"), + (_BClass.PRIORITY, 0, f"{_BClass.__module__}.{_BClass.__qualname__}"), + ] + sorted_keys = sorted(keys) + + # "a.module.BClass" < "z.module.AClass" lexicographically + assert sorted_keys[0][2] == "a.module.BClass" + assert sorted_keys[1][2] == "z.module.AClass" + + # --------------------------------------------------------------------------- -# SerializationMetadata tests +# SerializationMetadata Tests # --------------------------------------------------------------------------- def test_metadata_fields(): + """Verify attributes are assigned and mapped properly on metadata container.""" meta = SerializationMetadata( obj_type="dict", size=1024, @@ -268,105 +328,78 @@ def test_metadata_fields(): def test_metadata_is_namedtuple(): + """Verify that SerializationMetadata preserves namedtuple traits.""" meta = SerializationMetadata("str", 10, "utf-8", {}) assert isinstance(meta, tuple) assert len(meta) == 4 # --------------------------------------------------------------------------- -# SerializedBlob tests +# SerializedBlob Tests # --------------------------------------------------------------------------- -def test_blob_bytes_auto_detect(): - """bytes value auto-detects as not a reference.""" - blob = SerializedBlob(b"hello") - assert blob.is_reference is False - assert blob.needs_save is True - - -def test_blob_str_auto_detect(): - """str value auto-detects as a reference.""" - blob = SerializedBlob("sha1_key_abc123") - assert blob.is_reference is True - assert blob.needs_save is False - - -def test_blob_explicit_is_reference_override(): - """Explicit is_reference overrides auto-detection.""" - # bytes but marked as reference (edge case) - blob = SerializedBlob(b"data", is_reference=True) - assert blob.is_reference is True - assert blob.needs_save is False - - # str but marked as not a reference (edge case) - blob = SerializedBlob("inline_data", is_reference=False) - assert blob.is_reference is False - assert blob.needs_save is True - - -def test_blob_value_preserved(): - data = b"\x00\x01\x02\x03" +@pytest.mark.parametrize( + "value, kwargs, expected_is_reference, expected_needs_save", + [ + (b"hello", {}, False, True), + ("sha1_key_abc123", {}, True, False), + (b"data", {"is_reference": True}, True, False), + ("inline_data", {"is_reference": False}, False, True), + ], + ids=[ + "bytes-auto-detect-payload", + "str-auto-detect-reference", + "bytes-explicit-reference-override", + "str-explicit-payload-override", + ], +) +def test_blob_reference_detection( + value, kwargs, expected_is_reference, expected_needs_save +): + """Verify SerializedBlob reference tracking and explicit override behaviors.""" + blob = SerializedBlob(value, **kwargs) + assert blob.is_reference is expected_is_reference + assert blob.needs_save is expected_needs_save + + +@pytest.mark.parametrize( + "data", + [b"\x00\x01\x02\x03", "abc123def456"], + ids=["bytes-payload", "str-reference"], +) +def test_blob_value_preserved(data): + """Verify values given to the SerializedBlob initialization are kept identical.""" blob = SerializedBlob(data) assert blob.value is data - key = "abc123def456" - blob = SerializedBlob(key) - assert blob.value is key - -def test_blob_rejects_invalid_types(): +@pytest.mark.parametrize( + "bad_value", + [123, 3.14, None, [], {}], + ids=["int", "float", "none", "list", "dict"], +) +def test_blob_rejects_invalid_types(bad_value): """SerializedBlob must be str or bytes — reject everything else.""" - for bad_value in [123, 3.14, None, [], {}]: - with pytest.raises(TypeError, match="must be str or bytes"): - SerializedBlob(bad_value) + with pytest.raises(TypeError, match="must be str or bytes"): + SerializedBlob(bad_value) # --------------------------------------------------------------------------- -# Wire vs storage format dispatch +# Wire vs Storage Format Dispatch # --------------------------------------------------------------------------- -class _DualFormatSerializer(ArtifactSerializer): - """Toy serializer that implements both formats for str objects.""" - - TYPE = "test_dual_format" - PRIORITY = 40 - - @classmethod - def can_serialize(cls, obj): - return isinstance(obj, str) - - @classmethod - def can_deserialize(cls, metadata): - return metadata.encoding == "test_dual_format" - - @classmethod - def serialize(cls, obj, format=SerializationFormat.STORAGE): - if format == SerializationFormat.WIRE: - return obj - blob = obj.encode("utf-8") - return ( - [SerializedBlob(blob)], - SerializationMetadata("str", len(blob), "test_dual_format", {}), - ) - - @classmethod - def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): - if format == SerializationFormat.WIRE: - return data - return data[0].decode("utf-8") - - def test_format_enum_values(): + """Verify the string mapping properties of the SerializationFormat enum.""" assert SerializationFormat.STORAGE.value == "storage" assert SerializationFormat.WIRE.value == "wire" - # str-backed Enum, so direct string comparison still works. assert SerializationFormat.STORAGE == "storage" assert SerializationFormat.WIRE == "wire" def test_dual_format_storage_roundtrip(): + """Verify the basic storage format workflow serialization loop.""" blobs, meta = _DualFormatSerializer.serialize("hello") assert meta.encoding == "test_dual_format" assert ( @@ -376,6 +409,7 @@ def test_dual_format_storage_roundtrip(): def test_dual_format_wire_roundtrip(): + """Verify the wire format workflow serialization loop.""" wire = _DualFormatSerializer.serialize("hello", format=SerializationFormat.WIRE) assert isinstance(wire, str) assert ( @@ -385,6 +419,7 @@ def test_dual_format_wire_roundtrip(): def test_pickle_serializer_rejects_wire(): + """Verify standard PickleSerializer errors explicitly when utilizing the wire format.""" from metaflow.plugins.datastores.serializers.pickle_serializer import ( PickleSerializer, ) @@ -395,36 +430,12 @@ def test_pickle_serializer_rejects_wire(): PickleSerializer.deserialize("42", format=SerializationFormat.WIRE) -def test_priority_tie_lexicographic_fallback(): - """When PRIORITY and registration index both tie (simulated), class_path lex-sort wins.""" - - # Within a single process, registration indices are always unique, so - # to actually exercise the tertiary key we construct two classes with - # identical (PRIORITY, registration_index) by manipulating the combined - # list passed to the sort logic. We do this by calling the internal - # sort key directly. - class _AClass: - __module__ = "z.module" - __qualname__ = "AClass" - PRIORITY = 100 - - class _BClass: - __module__ = "a.module" - __qualname__ = "BClass" - PRIORITY = 100 - - # Same registration index (simulated): the (priority, -idx) prefix ties. - keys = [ - (_AClass.PRIORITY, 0, "%s.%s" % (_AClass.__module__, _AClass.__qualname__)), - (_BClass.PRIORITY, 0, "%s.%s" % (_BClass.__module__, _BClass.__qualname__)), - ] - sorted_keys = sorted(keys) - # "a.module.BClass" < "z.module.AClass" lexicographically - assert sorted_keys[0][2] == "a.module.BClass" - assert sorted_keys[1][2] == "z.module.AClass" +# --------------------------------------------------------------------------- +# Setup and Lazy Import Infrastructure Tests +# --------------------------------------------------------------------------- -def test_setup_imports_default_is_noop(): +def test_setup_imports_default_is_noop(clean_store): """Default setup_imports should be callable and do nothing.""" class _NoOverride(ArtifactSerializer): @@ -446,21 +457,23 @@ def serialize(cls, obj, format=SerializationFormat.STORAGE): def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): raise NotImplementedError - try: - result = _NoOverride.setup_imports() - assert result is None - result = _NoOverride.setup_imports(context="anything") - assert result is None - finally: - SerializerStore._all_serializers.pop("test_no_override", None) - SerializerStore._ordered_cache = None + assert _NoOverride.setup_imports() is None + assert _NoOverride.setup_imports(context="anything") is None -def test_lazy_import_happy_path(): - """lazy_import imports the module, stashes on cls at the leaf alias, and returns it.""" +@pytest.mark.parametrize( + "module_name, alias, target_attr", + [ + ("json", None, "json"), + ("json", "j", "j"), + ], + ids=["default-leaf-alias", "custom-alias"], +) +def test_lazy_import_success(clean_store, module_name, alias, target_attr): + """lazy_import imports the module, stashes on cls at the given alias, and returns it.""" - class _LazyOk(ArtifactSerializer): - TYPE = "test_lazy_ok" + class _LazyTarget(ArtifactSerializer): + TYPE = "test_lazy_target" @classmethod def can_serialize(cls, obj): @@ -478,54 +491,30 @@ def serialize(cls, obj, format=SerializationFormat.STORAGE): def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): raise NotImplementedError - try: - mod = _LazyOk.lazy_import("json") - import json as _json - - assert mod is _json - assert _LazyOk.json is _json - finally: - SerializerStore._all_serializers.pop("test_lazy_ok", None) - SerializerStore._ordered_cache = None - if hasattr(_LazyOk, "json"): - delattr(_LazyOk, "json") + import json as _json + kwargs = {"alias": alias} if alias else {} + mod = _LazyTarget.lazy_import(module_name, **kwargs) -def test_lazy_import_custom_alias(): - """alias= overrides the default leaf-name stash key.""" + assert mod is _json + assert getattr(_LazyTarget, target_attr) is _json - class _LazyAlias(ArtifactSerializer): - TYPE = "test_lazy_alias" - @classmethod - def can_serialize(cls, obj): - return False - - @classmethod - def can_deserialize(cls, metadata): - return False - - @classmethod - def serialize(cls, obj, format=SerializationFormat.STORAGE): - raise NotImplementedError - - @classmethod - def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): - raise NotImplementedError - - try: - _LazyAlias.lazy_import("json", alias="j") - import json as _json - - assert _LazyAlias.j is _json - finally: - SerializerStore._all_serializers.pop("test_lazy_alias", None) - SerializerStore._ordered_cache = None - if hasattr(_LazyAlias, "j"): - delattr(_LazyAlias, "j") - - -def test_lazy_import_rejects_reserved_names(): +@pytest.mark.parametrize( + "bad_alias", + [ + "TYPE", + "PRIORITY", + "serialize", + "deserialize", + "can_serialize", + "can_deserialize", + "setup_imports", + "lazy_import", + "_secret", + ], +) +def test_lazy_import_rejects_reserved_names(clean_store, bad_alias): """Attempting to shadow TYPE / PRIORITY / dispatch methods raises.""" class _LazyReserved(ArtifactSerializer): @@ -547,26 +536,11 @@ def serialize(cls, obj, format=SerializationFormat.STORAGE): def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): raise NotImplementedError - try: - for bad in [ - "TYPE", - "PRIORITY", - "serialize", - "deserialize", - "can_serialize", - "can_deserialize", - "setup_imports", - "lazy_import", - "_secret", - ]: - with pytest.raises(ValueError, match="reserved or invalid"): - _LazyReserved.lazy_import("json", alias=bad) - finally: - SerializerStore._all_serializers.pop("test_lazy_reserved", None) - SerializerStore._ordered_cache = None - - -def test_lazy_import_rejects_double_assignment(): + with pytest.raises(ValueError, match="reserved or invalid"): + _LazyReserved.lazy_import("json", alias=bad_alias) + + +def test_lazy_import_rejects_double_assignment(clean_store): """Calling lazy_import twice with the same alias on the same cls raises.""" class _LazyDup(ArtifactSerializer): @@ -588,20 +562,12 @@ def serialize(cls, obj, format=SerializationFormat.STORAGE): def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): raise NotImplementedError - try: - _LazyDup.lazy_import("json") - with pytest.raises(ValueError, match="already set"): - _LazyDup.lazy_import("sys", alias="json") - finally: - SerializerStore._all_serializers.pop("test_lazy_dup", None) - SerializerStore._ordered_cache = None - if hasattr(_LazyDup, "json"): - delattr(_LazyDup, "json") - if hasattr(_LazyDup, "_lazy_imported_names"): - delattr(_LazyDup, "_lazy_imported_names") + _LazyDup.lazy_import("json") + with pytest.raises(ValueError, match="already set"): + _LazyDup.lazy_import("sys", alias="json") -def test_setup_imports_accepts_both_signatures(): +def test_setup_imports_accepts_both_signatures(clean_store): """Bootstrap calls setup_imports correctly whether author writes (cls) or (cls, context=None).""" class _OneArg(ArtifactSerializer): @@ -654,13 +620,8 @@ def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): from metaflow.datastore.artifacts.serializer import _call_setup_imports - try: - _call_setup_imports(_OneArg, context=None) - assert _OneArg.called is True + _call_setup_imports(_OneArg, context=None) + assert _OneArg.called is True - _call_setup_imports(_TwoArg, context="some-ctx") - assert _TwoArg.called_with == "some-ctx" - finally: - SerializerStore._all_serializers.pop("test_setup_one_arg", None) - SerializerStore._all_serializers.pop("test_setup_two_arg", None) - SerializerStore._ordered_cache = None + _call_setup_imports(_TwoArg, context="some-ctx") + assert _TwoArg.called_with == "some-ctx" diff --git a/test/unit/test_aws_util.py b/test/unit/test_aws_util.py index a212996c4eb..f28ee1247d7 100644 --- a/test/unit/test_aws_util.py +++ b/test/unit/test_aws_util.py @@ -1,3 +1,4 @@ +import contextlib import pytest from metaflow.plugins.aws.aws_utils import validate_aws_tag @@ -8,32 +9,29 @@ [ ("test", "value", False), ("test-with@chars+ - = ._/", "value@with.chars-+ - = ._/", False), - ( - "a" * 128, - "ok", - False, - ), # <=128 char key should work. - ("a" * 129, "ok", True), # >128 char key should fail. - ( - "ok", - "a" * 256, - False, - ), # <=256 char value should work. - ("ok", "a" * 257, True), # >256 char value should fail. + ("a" * 128, "ok", False), # <=128 char key should work + ("a" * 129, "ok", True), # >128 char key should fail + ("ok", "a" * 256, False), # <=256 char value should work + ("ok", "a" * 257, True), # >256 char value should fail ("aWs:not-allowed", "ok", True), # 'aws:' prefix should not be allowed as key ("ok", "AWS:not-allowed", True), # 'aws:' prefix should not be allowed as value - ( - "ok-aws:", - "middleaWs:not-allowed", - False, - ), # 'aws:' itself is not a restricted pattern + ("ok-aws:", "middleaWs:not-allowed", False), # 'aws:' substring itself is fine + ], + ids=[ + "simple-valid", + "allowed-special-chars", + "key-max-length-128", + "key-length-exceeded-129", + "value-max-length-256", + "value-length-exceeded-257", + "aws-prefix-key-rejected", + "aws-prefix-value-rejected", + "aws-substring-allowed", ], ) -def test_validate_aws_tag(key, value, should_raise): - did_raise = False - try: - validate_aws_tag(key, value) - except Exception as e: - did_raise = True +def test_aws_tag_validation_rules(key, value, should_raise): + """Verify AWS tag validation enforces character sets, length limits, and prefix restrictions.""" + expectation = pytest.raises(Exception) if should_raise else contextlib.nullcontext() - assert did_raise == should_raise + with expectation: + validate_aws_tag(key, value) diff --git a/test/unit/test_compute_resource_attributes.py b/test/unit/test_compute_resource_attributes.py index adb21c521b5..6d052b4f6dd 100644 --- a/test/unit/test_compute_resource_attributes.py +++ b/test/unit/test_compute_resource_attributes.py @@ -1,77 +1,95 @@ from collections import namedtuple -from metaflow.plugins.aws.aws_utils import compute_resource_attributes - - -MockDeco = namedtuple("MockDeco", ["name", "attributes"]) - - -def test_compute_resource_attributes(): - - # use default if nothing is set - assert compute_resource_attributes([], MockDeco("batch", {}), {"cpu": "1"}) == { - "cpu": "1" - } - - # @batch overrides default and you can use ints as attributes - assert compute_resource_attributes( - [], MockDeco("batch", {"cpu": 1}), {"cpu": "2"} - ) == {"cpu": "1"} - # Same but value set as str not int - assert compute_resource_attributes( - [], MockDeco("batch", {"cpu": "1"}), {"cpu": "2"} - ) == {"cpu": "1"} +import pytest - # same but use default memory - assert compute_resource_attributes( - [], MockDeco("batch", {"cpu": "1"}), {"cpu": "2", "memory": "100"} - ) == {"cpu": "1", "memory": "100"} - - # same but cpu set via @resources - assert compute_resource_attributes( - [], MockDeco("resources", {"cpu": "1"}), {"cpu": "2", "memory": "100"} - ) == {"cpu": "1", "memory": "100"} - - # take largest of @resources and @batch if both are present - assert compute_resource_attributes( - [MockDeco("resources", {"cpu": "2"})], - MockDeco("batch", {"cpu": 1}), - {"cpu": "3"}, - ) == {"cpu": "2.0"} - - # take largest of @resources and @batch if both are present - assert compute_resource_attributes( - [MockDeco("resources", {"cpu": 0.83})], - MockDeco("batch", {"cpu": "0.5"}), - {"cpu": "1"}, - ) == {"cpu": "0.83"} - - -def test_compute_resource_attributes_string(): - """Test string-valued resource attributes""" - - # if default is None and the value is not set in @batch, the value is not included in computed attributes in the end - assert compute_resource_attributes( - [], MockDeco("batch", {}), {"cpu": "1", "instance_type": None} - ) == {"cpu": "1"} +from metaflow.plugins.aws.aws_utils import compute_resource_attributes - # use string value from deco if set (default is None) - assert compute_resource_attributes( - [], - MockDeco("batch", {"instance_type": "p3.xlarge"}), - {"cpu": "1", "instance_type": None}, - ) == {"cpu": "1", "instance_type": "p3.xlarge"} +MockDeco = namedtuple("MockDeco", ["name", "attributes"]) - # use string value from deco if set (default is not None) - assert compute_resource_attributes( - [], - MockDeco("batch", {"instance_type": "p3.xlarge"}), - {"cpu": "1", "instance_type": "p4.xlarge"}, - ) == {"cpu": "1", "instance_type": "p3.xlarge"} - # use string value from defaults if @batch has it set to None - assert compute_resource_attributes( - [], - MockDeco("batch", {"instance_type": None}), - {"cpu": "1", "instance_type": "p4.xlarge"}, - ) == {"cpu": "1", "instance_type": "p4.xlarge"} +@pytest.mark.parametrize( + "decorators, primary_deco, defaults, expected", + [ + # --- Numeric attribute resolution --- + # use default if nothing is set + ([], MockDeco("batch", {}), {"cpu": "1"}, {"cpu": "1"}), + # @batch overrides default and you can use ints as attributes + ([], MockDeco("batch", {"cpu": 1}), {"cpu": "2"}, {"cpu": "1"}), + # Same but value set as str not int + ([], MockDeco("batch", {"cpu": "1"}), {"cpu": "2"}, {"cpu": "1"}), + # same but use default memory + ( + [], + MockDeco("batch", {"cpu": "1"}), + {"cpu": "2", "memory": "100"}, + {"cpu": "1", "memory": "100"}, + ), + # same but cpu set via @resources + ( + [], + MockDeco("resources", {"cpu": "1"}), + {"cpu": "2", "memory": "100"}, + {"cpu": "1", "memory": "100"}, + ), + # --- Max/Largest resource resolution across decorators --- + # take largest of @resources and @batch if both are present + ( + [MockDeco("resources", {"cpu": "2"})], + MockDeco("batch", {"cpu": 1}), + {"cpu": "3"}, + {"cpu": "2.0"}, + ), + # take largest of @resources and @batch if both are present (floats) + ( + [MockDeco("resources", {"cpu": 0.83})], + MockDeco("batch", {"cpu": "0.5"}), + {"cpu": "1"}, + {"cpu": "0.83"}, + ), + # --- String attribute resolution --- + # if default is None and the value is not set in @batch, it is omitted + ( + [], + MockDeco("batch", {}), + {"cpu": "1", "instance_type": None}, + {"cpu": "1"}, + ), + # use string value from deco if set (default is None) + ( + [], + MockDeco("batch", {"instance_type": "p3.xlarge"}), + {"cpu": "1", "instance_type": None}, + {"cpu": "1", "instance_type": "p3.xlarge"}, + ), + # use string value from deco if set (default is not None) + ( + [], + MockDeco("batch", {"instance_type": "p3.xlarge"}), + {"cpu": "1", "instance_type": "p4.xlarge"}, + {"cpu": "1", "instance_type": "p3.xlarge"}, + ), + # use string value from defaults if @batch has it set to None + ( + [], + MockDeco("batch", {"instance_type": None}), + {"cpu": "1", "instance_type": "p4.xlarge"}, + {"cpu": "1", "instance_type": "p4.xlarge"}, + ), + ], + ids=[ + "fallback-to-default", + "batch-int-overrides-default", + "batch-str-overrides-default", + "merge-batch-with-default-memory", + "merge-resources-with-default-memory", + "resolve-max-between-resources-and-batch-int", + "resolve-max-between-resources-and-batch-float", + "omit-none-default-if-missing-in-batch", + "batch-str-overrides-none-default", + "batch-str-overrides-str-default", + "none-in-batch-falls-back-to-default-str", + ], +) +def test_compute_resource_attributes(decorators, primary_deco, defaults, expected): + """Verify that resource attributes are correctly merged and resolved from decorators and defaults.""" + assert compute_resource_attributes(decorators, primary_deco, defaults) == expected diff --git a/test/unit/test_content_addressed_store.py b/test/unit/test_content_addressed_store.py index 2042f8e4216..097a5971eed 100644 --- a/test/unit/test_content_addressed_store.py +++ b/test/unit/test_content_addressed_store.py @@ -1,17 +1,25 @@ -from contextlib import contextmanager +import contextlib +from pathlib import Path import pytest from metaflow.datastore.content_addressed_store import ContentAddressedStore from metaflow.datastore.exceptions import DataException +# --------------------------------------------------------------------------- +# Mocks & Helpers +# --------------------------------------------------------------------------- -@contextmanager + +@contextlib.contextmanager def _loaded_bytes(entries): + """Context manager to simulate loading bytes iteratively.""" yield iter(entries) -class _FakeStorageImpl(object): +class _FakeStorageImpl: + """A minimal fake storage implementation to support CAS loading tests.""" + TYPE = "fake" def __init__(self, entries): @@ -27,26 +35,33 @@ def path_split(path): @staticmethod def full_uri(path): - return "fake://" + path + return f"fake://{path}" def load_bytes(self, paths): - expected_paths = [entry[0] for entry in self._entries] - assert set(expected_paths).issubset( - set(paths) - ), "expected paths %s not all in %s" % (expected_paths, paths) + expected_paths = {entry[0] for entry in self._entries} + assert expected_paths.issubset( + paths + ), f"expected paths {expected_paths} not all in {paths}" return _loaded_bytes(self._entries) def _make_store(entries): + """Helper to initialize a ContentAddressedStore with fake storage.""" return ContentAddressedStore("prefix", _FakeStorageImpl(entries)) -def _write_blob_file(tmp_path, name="blob.bin", data=b"not-a-valid-gzip-stream"): +def _write_blob_file(tmp_path: Path, name="blob.bin", data=b"not-a-valid-gzip-stream"): + """Helper to write a temporary binary blob file.""" blob_file = tmp_path / name blob_file.write_bytes(data) return str(blob_file) +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + @pytest.mark.parametrize( "meta, unpack_error, expected_substrings", [ @@ -62,15 +77,20 @@ def _write_blob_file(tmp_path, name="blob.bin", data=b"not-a-valid-gzip-stream") ["Could not unpack artifact", "boom"], ), ], - ids=["missing_version", "unknown_version", "unpack_failure"], + ids=["missing-version", "unknown-version", "unpack-failure"], ) def test_load_blobs_error_message_uses_current_path_key( tmp_path, monkeypatch, meta, unpack_error, expected_substrings ): + """ + Verify that load_blobs error messages accurately reference the *current* + failing path, preventing regression of a bug where a stale outer-loop + variable caused misleading exception messages. + """ stale_key = "aaaaaaaaaa" current_key = "bbbbbbbbbb" - stale_path = "prefix/aa/%s" % stale_key - current_path = "prefix/bb/%s" % current_key + stale_path = f"prefix/aa/{stale_key}" + current_path = f"prefix/bb/{current_key}" file_path = _write_blob_file(tmp_path) store = _make_store([(current_path, file_path, meta)]) @@ -89,7 +109,14 @@ def _raise_unpack_error(_fileobj): list(store.load_blobs([current_key, stale_key])) message = str(exc.value) - assert current_path in message - assert stale_path not in message + + # Assertions + assert ( + current_path in message + ), f"Expected active path {current_path} in error message." + assert ( + stale_path not in message + ), f"Stale path {stale_path} leaked into the error message." + for expected in expected_substrings: assert expected in message diff --git a/test/unit/test_kubernetes.py b/test/unit/test_kubernetes.py index 6ede190d428..2eb8d1448e5 100644 --- a/test/unit/test_kubernetes.py +++ b/test/unit/test_kubernetes.py @@ -7,6 +7,11 @@ ) +# --------------------------------------------------------------------------- +# validate_kube_labels +# --------------------------------------------------------------------------- + + @pytest.mark.parametrize( "labels", [ @@ -25,7 +30,7 @@ "1234567890" "1234567890" "123" - ) + ) # 63 characters (max valid length) }, { "label": ( @@ -39,8 +44,18 @@ ) }, ], + ids=[ + "none", + "single_label", + "multiple_labels", + "none_value", + "single_char", + "empty_string", + "max_length_63_chars", + "max_length_with_allowed_special_chars", + ], ) -def test_kubernetes_decorator_validate_kube_labels(labels): +def test_validate_kube_labels_accepts_valid_inputs(labels): assert validate_kube_labels(labels) @@ -59,18 +74,31 @@ def test_kubernetes_decorator_validate_kube_labels(labels): "1234567890" "1234567890" "1234" - ) + ) # 64 characters (exceeds max length) }, {"label": "(){}??"}, {"valid": "test", "invalid": "bißchen"}, ], + ids=[ + "ends_with_hyphen", + "starts_with_dot", + "invalid_chars_parentheses", + "exceeds_max_length_64_chars", + "only_invalid_chars", + "invalid_unicode_chars", + ], ) -def test_kubernetes_decorator_validate_kube_labels_fail(labels): +def test_validate_kube_labels_rejects_invalid_inputs(labels): """Fail if label contains invalid characters or is too long""" with pytest.raises(KubernetesException): validate_kube_labels(labels) +# --------------------------------------------------------------------------- +# parse_kube_keyvalue_list +# --------------------------------------------------------------------------- + + @pytest.mark.parametrize( "items,requires_both,expected", [ @@ -79,8 +107,14 @@ def test_kubernetes_decorator_validate_kube_labels_fail(labels): (["key"], False, {"key": None}), (["key=value", "key2=value2"], True, {"key": "value", "key2": "value2"}), ], + ids=[ + "single_kv_requires_both", + "single_kv_optional_both", + "key_only_optional_both", + "multiple_kv_requires_both", + ], ) -def test_kubernetes_parse_keyvalue_list(items, requires_both, expected): +def test_parse_kube_keyvalue_list_success(items, requires_both, expected): ret = parse_kube_keyvalue_list(items, requires_both) assert ret == expected @@ -91,7 +125,11 @@ def test_kubernetes_parse_keyvalue_list(items, requires_both, expected): (["key=value", "key=value2"], True), (["key"], True), ], + ids=[ + "duplicate_keys_not_allowed", + "missing_value_when_requires_both", + ], ) -def test_kubernetes_parse_keyvalue_list(items, requires_both): +def test_parse_kube_keyvalue_list_raises_exception(items, requires_both): with pytest.raises(KubernetesException): parse_kube_keyvalue_list(items, requires_both) diff --git a/test/unit/test_local_metadata_provider.py b/test/unit/test_local_metadata_provider.py index be080c5f115..fb4de8c9000 100644 --- a/test/unit/test_local_metadata_provider.py +++ b/test/unit/test_local_metadata_provider.py @@ -1,31 +1,36 @@ +import pytest + from metaflow.plugins.metadata_providers.local import LocalMetadataProvider -def test_deduce_run_id_from_meta_dir(): - test_cases = [ - { - "meta_path": ".metaflow/BasicParameterTestFlow/1652384326805262/start/1/_meta", - "sub_type": "task", - "expected_run_id": "1652384326805262", - }, - { - "meta_path": ".metaflow/BasicParameterTestFlow/1652384326805262/start/_meta", - "sub_type": "step", - "expected_run_id": "1652384326805262", - }, - { - "meta_path": ".metaflow/BasicParameterTestFlow/1652384326805262/_meta", - "sub_type": "run", - "expected_run_id": "1652384326805262", - }, - { - "meta_path": ".metaflow/BasicParameterTestFlow/_meta", - "sub_type": "flow", - "expected_run_id": None, - }, - ] - for case in test_cases: - actual_run_id = LocalMetadataProvider._deduce_run_id_from_meta_dir( - case["meta_path"], case["sub_type"] - ) - assert case["expected_run_id"] == actual_run_id +@pytest.mark.parametrize( + "meta_path, sub_type, expected_run_id", + [ + ( + ".metaflow/BasicParameterTestFlow/1652384326805262/start/1/_meta", + "task", + "1652384326805262", + ), + ( + ".metaflow/BasicParameterTestFlow/1652384326805262/start/_meta", + "step", + "1652384326805262", + ), + ( + ".metaflow/BasicParameterTestFlow/1652384326805262/_meta", + "run", + "1652384326805262", + ), + ( + ".metaflow/BasicParameterTestFlow/_meta", + "flow", + None, + ), + ], + ids=["task_level", "step_level", "run_level", "flow_level"], +) +def test_deduce_run_id_from_meta_dir(meta_path, sub_type, expected_run_id): + actual_run_id = LocalMetadataProvider._deduce_run_id_from_meta_dir( + meta_path, sub_type + ) + assert actual_run_id == expected_run_id diff --git a/test/unit/test_pickle_serializer.py b/test/unit/test_pickle_serializer.py index d08206a88dd..a50a7545a85 100644 --- a/test/unit/test_pickle_serializer.py +++ b/test/unit/test_pickle_serializer.py @@ -1,5 +1,4 @@ import pickle - import pytest from metaflow.datastore.artifacts.serializer import ( @@ -27,21 +26,20 @@ def test_registered_in_store(): assert SerializerStore._all_serializers["pickle"] is PickleSerializer -def test_last_in_ordering(): +def test_last_in_ordering(monkeypatch): """PickleSerializer should be last (highest PRIORITY) among registered serializers.""" - # Dispatch is driven by _active_serializers (post-Phase-6). Ensure Pickle - # is active for this test regardless of whether bootstrap() has already - # run in the current process. - was_active = PickleSerializer in SerializerStore._active_serializers - SerializerStore._active_serializers.add(PickleSerializer) - SerializerStore._ordered_cache = None - try: - ordered = SerializerStore.get_ordered_serializers() - assert ordered[-1] is PickleSerializer - finally: - if not was_active: - SerializerStore._active_serializers.discard(PickleSerializer) - SerializerStore._ordered_cache = None + # Use monkeypatch to safely append PickleSerializer to active state and clear cache. + # This ensures automatic cleanup after the test runs. + updated_serializers = type(SerializerStore._active_serializers)( + SerializerStore._active_serializers + ) + updated_serializers.add(PickleSerializer) + + monkeypatch.setattr(SerializerStore, "_active_serializers", updated_serializers) + monkeypatch.setattr(SerializerStore, "_ordered_cache", None) + + ordered = SerializerStore.get_ordered_serializers() + assert ordered[-1] is PickleSerializer # --------------------------------------------------------------------------- @@ -90,6 +88,7 @@ def test_can_serialize_any_object(obj): @pytest.mark.parametrize( "encoding", ["pickle-v2", "pickle-v4", "gzip+pickle-v2", "gzip+pickle-v4"], + ids=["pickle-v2", "pickle-v4", "gzip-pickle-v2", "gzip-pickle-v4"], ) def test_can_deserialize_valid_encodings(encoding): meta = SerializationMetadata("object", 100, encoding, {}) @@ -99,6 +98,7 @@ def test_can_deserialize_valid_encodings(encoding): @pytest.mark.parametrize( "encoding", ["json", "iotype:text", "msgpack", "unknown", ""], + ids=["json", "iotype-text", "msgpack", "unknown", "empty-string"], ) def test_cannot_deserialize_unknown_encodings(encoding): meta = SerializationMetadata("object", 100, encoding, {}) @@ -181,6 +181,8 @@ def test_round_trip(obj): class _CustomObj: + """Helper class to verify custom object serialization properties.""" + def __init__(self, x): self.x = x diff --git a/test/unit/test_s3_empty_input.py b/test/unit/test_s3_empty_input.py index 14b89722071..8344a185948 100644 --- a/test/unit/test_s3_empty_input.py +++ b/test/unit/test_s3_empty_input.py @@ -14,16 +14,17 @@ Fix: return early from both methods when the input list is empty. """ -import os -import tempfile -from unittest.mock import MagicMock, patch - import pytest from metaflow.plugins.datatools.s3.s3 import S3 +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + -def _make_s3(tmp_path): +@pytest.fixture +def s3_instance(tmp_path): """Create a minimal S3 instance without a real S3 connection.""" s3 = object.__new__(S3) s3._tmproot = str(tmp_path) @@ -36,111 +37,128 @@ def _make_s3(tmp_path): return s3 -class TestPutManyFilesEmptyInput: - def test_returns_empty_list(self, tmp_path): - """_put_many_files with no items returns [] without calling s3op.""" - s3 = _make_s3(tmp_path) - result = s3._put_many_files(iter([]), overwrite=True) - assert result == [] - - def test_does_not_call_s3op(self, tmp_path): - """s3op subprocess must not be spawned when there is nothing to upload.""" - s3 = _make_s3(tmp_path) - with patch.object(s3, "_s3op_with_retries") as mock_op: - s3._put_many_files(iter([]), overwrite=True) - mock_op.assert_not_called() - - def test_error_handler_would_crash_without_guard(self): - """ - Regression: document the latent IndexError in the error handler. - - Before the guard was added, _put_many_files called s3op even with - empty url_info. s3op exits 0 but writes "Uploading 0 files.." to - stderr. The `if stderr:` block then executed: - url_info[0][2]["key"] # IndexError — url_info is [] - This test proves that the error handler is unsafe with an empty list, - which is why the early-return guard is necessary. - """ - url_info = [] # what packing_iter() produces on a CAS cache hit - stderr = b"Uploading 0 files..\nUploaded 0 files." # s3op info output - with pytest.raises(IndexError): - # Reproduce the exact expression from the old error handler: - _ = url_info[0][2]["key"] - - def test_non_empty_input_still_calls_s3op(self, tmp_path): - """Non-empty input should still go through s3op (not short-circuit).""" - s3 = _make_s3(tmp_path) - - with tempfile.NamedTemporaryFile( - dir=str(tmp_path), delete=False, mode="wb" - ) as tf: - tf.write(b"data") - local = tf.name - - try: - with patch.object( - s3, - "_s3op_with_retries", - return_value=([b"s3://bucket/key " + tf.name.encode() + b" "], b"", 0), - ) as mock_op: - - def _gen(): - yield local, "s3://bucket/key", {"key": "mykey"} - - s3._put_many_files(_gen(), overwrite=True) - mock_op.assert_called_once() - finally: - os.unlink(local) - - -class TestReadManyFilesEmptyInput: - def test_returns_empty_generator(self, tmp_path): - """_read_many_files with no prefixes yields nothing.""" - s3 = _make_s3(tmp_path) - result = list(s3._read_many_files("get", iter([]))) - assert result == [] - - def test_does_not_call_s3op(self, tmp_path): - """s3op subprocess must not be spawned when there is nothing to read.""" - s3 = _make_s3(tmp_path) - with patch.object(s3, "_s3op_with_retries") as mock_op: - list(s3._read_many_files("get", iter([]))) - mock_op.assert_not_called() - - def test_error_handler_would_crash_without_guard(self): - """ - Regression: document the latent IndexError in the error handler. - - Before the guard was added, _read_many_files called s3op even with - empty prefixes_and_ranges. s3op exits 0 but writes - "Info_downloading 0 files.." to stderr. The `if stderr:` block then - executed: - prefixes_and_ranges[0] # IndexError — list is [] - This test proves that the error handler is unsafe with an empty list, - which is why the early-return guard is necessary. - """ - prefixes_and_ranges = [] # empty input materialised as a list - stderr = b"Info_downloading 0 files..\nInfo_downloaded 0 files." - with pytest.raises(IndexError): - # Reproduce the exact expression from the old error handler: - _ = prefixes_and_ranges[0] - - def test_empty_input_for_all_ops(self, tmp_path): - """All s3op modes (get, list, info, put) are safe with empty input.""" - s3 = _make_s3(tmp_path) - with patch.object(s3, "_s3op_with_retries") as mock_op: - for op in ("get", "list", "info", "put"): - result = list(s3._read_many_files(op, iter([]))) - assert result == [], f"op={op} should yield nothing for empty input" - mock_op.assert_not_called() - - def test_non_empty_input_still_calls_s3op(self, tmp_path): - """Non-empty input should still go through s3op (not short-circuit).""" - s3 = _make_s3(tmp_path) - with patch.object( - s3, - "_s3op_with_retries", - return_value=([b"s3://b/k /tmp/f 100"], b"", 0), - ) as mock_op: - list(s3._read_many_files("get", iter([("s3://b/k", None)]))) - mock_op.assert_called_once() +# --------------------------------------------------------------------------- +# Test Functions: Put Many Files +# --------------------------------------------------------------------------- + + +def test_put_many_files_empty_input_returns_empty_list(s3_instance): + """_put_many_files with no items returns [] without calling s3op.""" + result = s3_instance._put_many_files(iter([]), overwrite=True) + assert result == [] + + +def test_put_many_files_empty_input_does_not_call_s3op(mocker, s3_instance): + """s3op subprocess must not be spawned when there is nothing to upload.""" + mock_op = mocker.patch.object(s3_instance, "_s3op_with_retries") + + s3_instance._put_many_files(iter([]), overwrite=True) + + mock_op.assert_not_called() + + +def test_put_many_files_error_handler_would_crash_without_guard(): + """ + Regression: document the latent IndexError in the error handler. + + Before the guard was added, _put_many_files called s3op even with + empty url_info. s3op exits 0 but writes "Uploading 0 files.." to + stderr. The `if stderr:` block then executed: + url_info[0][2]["key"] # IndexError — url_info is [] + This test proves that the error handler is unsafe with an empty list, + which is why the early-return guard is necessary. + """ + url_info = [] # what packing_iter() produces on a CAS cache hit + stderr = b"Uploading 0 files..\nUploaded 0 files." # s3op info output + + with pytest.raises(IndexError): + # Reproduce the exact expression from the old error handler: + _ = url_info[0][2]["key"] + + +def test_put_many_files_non_empty_input_calls_s3op(mocker, s3_instance, tmp_path): + """Non-empty input should still go through s3op (not short-circuit).""" + # Use tmp_path instead of NamedTemporaryFile to avoid manual cleanup + local_file = tmp_path / "test_data.txt" + local_file.write_bytes(b"data") + + mock_op = mocker.patch.object( + s3_instance, + "_s3op_with_retries", + return_value=([b"s3://bucket/key " + str(local_file).encode() + b" "], b"", 0), + ) + + def _gen(): + yield str(local_file), "s3://bucket/key", {"key": "mykey"} + + s3_instance._put_many_files(_gen(), overwrite=True) + + mock_op.assert_called_once() + + +# --------------------------------------------------------------------------- +# Test Functions: Read Many Files +# --------------------------------------------------------------------------- + + +def test_read_many_files_empty_input_returns_empty_generator(s3_instance): + """_read_many_files with no prefixes yields nothing.""" + result = list(s3_instance._read_many_files("get", iter([]))) + assert result == [] + + +def test_read_many_files_empty_input_does_not_call_s3op(mocker, s3_instance): + """s3op subprocess must not be spawned when there is nothing to read.""" + mock_op = mocker.patch.object(s3_instance, "_s3op_with_retries") + + list(s3_instance._read_many_files("get", iter([]))) + + mock_op.assert_not_called() + + +def test_read_many_files_error_handler_would_crash_without_guard(): + """ + Regression: document the latent IndexError in the error handler. + + Before the guard was added, _read_many_files called s3op even with + empty prefixes_and_ranges. s3op exits 0 but writes + "Info_downloading 0 files.." to stderr. The `if stderr:` block then + executed: + prefixes_and_ranges[0] # IndexError — list is [] + This test proves that the error handler is unsafe with an empty list, + which is why the early-return guard is necessary. + """ + prefixes_and_ranges = [] # empty input materialised as a list + stderr = b"Info_downloading 0 files..\nInfo_downloaded 0 files." + + with pytest.raises(IndexError): + # Reproduce the exact expression from the old error handler: + _ = prefixes_and_ranges[0] + + +@pytest.mark.parametrize( + "s3_op", + ["get", "list", "info", "put"], + ids=["op_get", "op_list", "op_info", "op_put"], +) +def test_read_many_files_empty_input_safe_for_all_ops(mocker, s3_instance, s3_op): + """All s3op modes (get, list, info, put) are safe with empty input.""" + mock_op = mocker.patch.object(s3_instance, "_s3op_with_retries") + + result = list(s3_instance._read_many_files(s3_op, iter([]))) + + assert result == [], f"op={s3_op} should yield nothing for empty input" + mock_op.assert_not_called() + + +def test_read_many_files_non_empty_input_calls_s3op(mocker, s3_instance): + """Non-empty input should still go through s3op (not short-circuit).""" + mock_op = mocker.patch.object( + s3_instance, + "_s3op_with_retries", + return_value=([b"s3://b/k /tmp/f 100"], b"", 0), + ) + + list(s3_instance._read_many_files("get", iter([("s3://b/k", None)]))) + + mock_op.assert_called_once() diff --git a/test/unit/test_s3_storage.py b/test/unit/test_s3_storage.py index 6f762a9d87b..c88da4ab2aa 100644 --- a/test/unit/test_s3_storage.py +++ b/test/unit/test_s3_storage.py @@ -1,3 +1,5 @@ +"""Tests for S3 datastore byte saving and metadata mapping.""" + from io import BytesIO import pytest @@ -6,7 +8,12 @@ from metaflow.plugins.datastores.s3_storage import S3Storage -def _make_storage(): +# --- Fixtures --- + + +@pytest.fixture +def s3_storage(): + """Fixture providing a minimal, uninitialized S3Storage instance.""" storage = object.__new__(S3Storage) storage.datastore_root = "s3://unit-test-root" storage.s3_client = object() @@ -33,12 +40,21 @@ def patched_s3(mocker): return s3 -def test_save_bytes_put_many_preserves_metadata_slot(patched_s3, test_items): - storage = _make_storage() - storage.save_bytes(iter(test_items), overwrite=True, len_hint=11) +# --- Tests --- + + +def test_save_bytes_put_many_preserves_metadata_slot( + s3_storage, patched_s3, test_items +): + """ + When len_hint > 10, save_bytes optimizes by delegating to put_many. + This test ensures the metadata payload survives the batch translation. + """ + s3_storage.save_bytes(iter(test_items), overwrite=True, len_hint=11) put_objs, overwrite = patched_s3.put_many.call_args[0] put_objs = list(put_objs) + assert overwrite is True assert put_objs[0].encryption is None assert put_objs[0].metadata == {"k": "v"} @@ -46,13 +62,20 @@ def test_save_bytes_put_many_preserves_metadata_slot(patched_s3, test_items): assert put_objs[1].metadata is None -def test_save_bytes_sequential_preserves_metadata(patched_s3, test_items): - storage = _make_storage() - storage.save_bytes(iter(test_items), overwrite=False, len_hint=2) +def test_save_bytes_sequential_preserves_metadata(s3_storage, patched_s3, test_items): + """ + When len_hint <= 10, save_bytes falls back to sequential put() calls. + This test ensures the metadata kwargs are passed correctly per item. + """ + s3_storage.save_bytes(iter(test_items), overwrite=False, len_hint=2) put_calls = patched_s3.put.call_args_list assert len(put_calls) == 2 + + # Check first item ("a") assert put_calls[0][0][0] == "a" assert put_calls[0][1]["metadata"] == {"k": "v"} + + # Check second item ("b") assert put_calls[1][0][0] == "b" assert put_calls[1][1]["metadata"] is None diff --git a/test/unit/test_serializer_integration.py b/test/unit/test_serializer_integration.py index c8e09d46c00..d3885ffd0d5 100644 --- a/test/unit/test_serializer_integration.py +++ b/test/unit/test_serializer_integration.py @@ -9,12 +9,14 @@ """ import json -import os -import shutil -import tempfile +import threading import pytest +from metaflow.datastore.artifacts.diagnostic import ( + SerializerRecord, + SerializerState, +) from metaflow.datastore.artifacts.serializer import ( ArtifactSerializer, SerializationFormat, @@ -22,21 +24,51 @@ SerializedBlob, SerializerStore, ) +from metaflow.datastore.exceptions import ( + DataException, + UnpicklableArtifactException, +) +from metaflow.datastore.flow_datastore import FlowDataStore +from metaflow.exception import MetaflowException +from metaflow.plugins.datastores.local_storage import LocalStorage from metaflow.plugins.datastores.serializers.pickle_serializer import PickleSerializer + # --------------------------------------------------------------------------- -# Test PickleSerializer round-trip through save/load artifacts +# Fixtures # --------------------------------------------------------------------------- +@pytest.fixture +def isolated_store(): + """ + Fixture to isolate SerializerStore global state per test. + Prevents tests from poisoning the registry if they fail mid-execution. + """ + saved_all = dict(SerializerStore._all_serializers) + saved_active = set(SerializerStore._active_serializers) + saved_records = dict(SerializerStore._records) + saved_pending = dict(SerializerStore._pending_by_module) + saved_cache = SerializerStore._ordered_cache + + yield + + SerializerStore._all_serializers.clear() + SerializerStore._all_serializers.update(saved_all) + SerializerStore._active_serializers.clear() + SerializerStore._active_serializers.update(saved_active) + SerializerStore._records.clear() + SerializerStore._records.update(saved_records) + SerializerStore._pending_by_module.clear() + SerializerStore._pending_by_module.update(saved_pending) + SerializerStore._ordered_cache = saved_cache + + @pytest.fixture def task_datastore(tmp_path): """Create a minimal TaskDataStore wired to a local storage backend.""" - from metaflow.datastore.flow_datastore import FlowDataStore - from metaflow.plugins.datastores.local_storage import LocalStorage - - storage_root = str(tmp_path / "datastore") - os.makedirs(storage_root, exist_ok=True) + storage_root = tmp_path / "datastore" + storage_root.mkdir() flow_ds = FlowDataStore( flow_name="TestFlow", @@ -45,7 +77,7 @@ def task_datastore(tmp_path): event_logger=None, monitor=None, storage_impl=LocalStorage, - ds_root=storage_root, + ds_root=str(storage_root), ) task_ds = flow_ds.get_task_datastore( @@ -62,30 +94,35 @@ def task_datastore(tmp_path): return task_ds -def test_save_load_pickle_round_trip(task_datastore): - """Standard Python objects go through PickleSerializer and round-trip.""" - artifacts = [ +# --------------------------------------------------------------------------- +# Test PickleSerializer round-trip through save/load artifacts +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "name, value", + [ ("my_dict", {"key": "value", "nested": [1, 2, 3]}), ("my_int", 42), ("my_str", "hello world"), ("my_none", None), - ] - task_datastore.save_artifacts(iter(artifacts)) + ], + ids=["dict", "int", "str", "none"], +) +def test_save_load_pickle_round_trip(task_datastore, name, value): + """Standard Python objects go through PickleSerializer and round-trip.""" + task_datastore.save_artifacts(iter([(name, value)])) # Verify metadata - for name, _ in artifacts: - info = task_datastore._info[name] - assert "encoding" in info - assert info["encoding"] == "pickle-v4" - assert info["size"] > 0 - assert "type" in info + info = task_datastore._info[name] + assert "encoding" in info + assert info["encoding"] == "pickle-v4" + assert info["size"] > 0 + assert "type" in info # Load and verify - loaded = dict(task_datastore.load_artifacts([name for name, _ in artifacts])) - assert loaded["my_dict"] == {"key": "value", "nested": [1, 2, 3]} - assert loaded["my_int"] == 42 - assert loaded["my_str"] == "hello world" - assert loaded["my_none"] is None + loaded = dict(task_datastore.load_artifacts([name])) + assert loaded[name] == value def test_distinct_objects_on_load(task_datastore): @@ -107,14 +144,9 @@ def test_metadata_auto_populates_source_for_pickle(task_datastore): assert info.get("serializer_info", {}).get("source") == "metaflow" -def test_author_source_is_not_overridden(task_datastore): +def test_author_source_is_not_overridden(task_datastore, isolated_store): """A serializer that sets its own ``source`` in serializer_info should not have it overridden by the auto-injected bootstrap source.""" - from metaflow.datastore.artifacts import SerializationFormat, SerializerStore - from metaflow.datastore.artifacts.diagnostic import ( - SerializerRecord, - SerializerState, - ) class _ExplicitSourceSerializer(ArtifactSerializer): TYPE = "test_explicit_source" @@ -159,15 +191,9 @@ def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): task_datastore._serializers = [_ExplicitSourceSerializer, PickleSerializer] - try: - task_datastore.save_artifacts(iter([("hello", "world")])) - info = task_datastore._info["hello"] - assert info["serializer_info"]["source"] == "i-picked-this-myself" - finally: - SerializerStore._records.pop("test_explicit_source", None) - SerializerStore._active_serializers.discard(_ExplicitSourceSerializer) - SerializerStore._all_serializers.pop("test_explicit_source", None) - SerializerStore._ordered_cache = None + task_datastore.save_artifacts(iter([("hello", "world")])) + info = task_datastore._info["hello"] + assert info["serializer_info"]["source"] == "i-picked-this-myself" # --------------------------------------------------------------------------- @@ -175,10 +201,9 @@ def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): # --------------------------------------------------------------------------- -def test_custom_serializer_takes_priority(task_datastore): +def test_custom_serializer_takes_priority(task_datastore, isolated_store): """A custom serializer with lower PRIORITY claims matching objects over pickle.""" - # Define and register a custom serializer inside the test class _JsonStringSerializer(ArtifactSerializer): TYPE = "test_json_str" PRIORITY = 50 @@ -208,29 +233,23 @@ def serialize(cls, obj, format=SerializationFormat.STORAGE): def deserialize(cls, data, metadata=None, format="storage"): return json.loads(data[0].decode("utf-8")) - # Explicitly set serializers: custom first, then pickle fallback. - # Don't use get_ordered_serializers() to avoid pollution from other test files. task_datastore._serializers = [_JsonStringSerializer, PickleSerializer] - try: - task_datastore.save_artifacts(iter([("msg", "hello"), ("num", 42)])) + task_datastore.save_artifacts(iter([("msg", "hello"), ("num", 42)])) - # "msg" should use our custom serializer (str → _JsonStringSerializer) - msg_info = task_datastore._info["msg"] - assert msg_info["encoding"] == "test_json_str" - assert msg_info["serializer_info"] == {"format": "json-utf8"} + # "msg" should use our custom serializer (str → _JsonStringSerializer) + msg_info = task_datastore._info["msg"] + assert msg_info["encoding"] == "test_json_str" + assert msg_info["serializer_info"] == {"format": "json-utf8"} - # "num" should fall through to PickleSerializer (int → not claimed by custom) - num_info = task_datastore._info["num"] - assert num_info["encoding"] == "pickle-v4" + # "num" should fall through to PickleSerializer (int → not claimed by custom) + num_info = task_datastore._info["num"] + assert num_info["encoding"] == "pickle-v4" - # Both round-trip correctly - loaded = dict(task_datastore.load_artifacts(["msg", "num"])) - assert loaded["msg"] == "hello" - assert loaded["num"] == 42 - finally: - SerializerStore._all_serializers.pop("test_json_str", None) - SerializerStore._ordered_cache = None + # Both round-trip correctly + loaded = dict(task_datastore.load_artifacts(["msg", "num"])) + assert loaded["msg"] == "hello" + assert loaded["num"] == 42 # --------------------------------------------------------------------------- @@ -240,7 +259,6 @@ def deserialize(cls, data, metadata=None, format="storage"): def test_backward_compat_old_metadata(task_datastore): """Artifacts saved with old metadata format (no serializer_info) still load.""" - # Save normally first task_datastore.save_artifacts(iter([("old_artifact", {"a": 1})])) # Simulate old metadata format: no serializer_info, old encoding @@ -248,27 +266,22 @@ def test_backward_compat_old_metadata(task_datastore): "size": 100, "type": "", "encoding": "gzip+pickle-v4", - # no "serializer_info" key } - # Should still load via PickleSerializer (can_deserialize handles gzip+pickle-v4) + # Should still load via PickleSerializer loaded = dict(task_datastore.load_artifacts(["old_artifact"])) assert loaded["old_artifact"] == {"a": 1} def test_backward_compat_no_encoding(task_datastore): """Very old artifacts without encoding field default to gzip+pickle-v2.""" - # Save an artifact task_datastore.save_artifacts(iter([("ancient", 99)])) - # Simulate very old metadata: no encoding, no serializer_info task_datastore._info["ancient"] = { "size": 10, "type": "", - # no "encoding" key — defaults to gzip+pickle-v2 } - # Should still load loaded = dict(task_datastore.load_artifacts(["ancient"])) assert loaded["ancient"] == 99 @@ -289,7 +302,6 @@ def test_missing_info_with_object_uses_pickle_defaults(task_datastore): del task_datastore._info["present"] loaded = dict(task_datastore.load_artifacts(["present"])) - assert loaded["present"] == {"value": 1} @@ -311,13 +323,12 @@ def test_info_without_object_raises_key_error(task_datastore): # --------------------------------------------------------------------------- -def test_post_init_registration_reaches_existing_datastore(task_datastore): +def test_post_init_registration_reaches_existing_datastore( + task_datastore, isolated_store +): """A serializer registered AFTER the datastore was constructed must still be visible. Without the dynamic ``_serializers`` property, lazy imports - (e.g. ``import torch`` after ``TaskDataStore.__init__``) would be silently - ignored for that instance. - """ - # Drop the test override so the property falls back to the live registry. + would be silently ignored for that instance.""" task_datastore._serializers = None class _PostInitSerializer(ArtifactSerializer): @@ -340,15 +351,10 @@ def serialize(cls, obj, format="storage"): def deserialize(cls, data, metadata=None, format="storage"): raise NotImplementedError - # Dispatch reads from _active_serializers now (post-Phase-6). SerializerStore._active_serializers.add(_PostInitSerializer) SerializerStore._ordered_cache = None - try: - assert _PostInitSerializer in task_datastore._serializers - finally: - SerializerStore._all_serializers.pop("test_post_init_registration", None) - SerializerStore._active_serializers.discard(_PostInitSerializer) - SerializerStore._ordered_cache = None + + assert _PostInitSerializer in task_datastore._serializers # --------------------------------------------------------------------------- @@ -356,13 +362,11 @@ def deserialize(cls, data, metadata=None, format="storage"): # --------------------------------------------------------------------------- -def test_info_not_populated_when_serializer_returns_no_blobs(task_datastore): - """ - Regression for the "_info[name] poisoned on validation failure" bug: if a - serializer returns an empty blob list, ``save_artifacts`` must raise - without leaving partial metadata in ``_info``. - """ - from metaflow.datastore.exceptions import DataException +def test_info_not_populated_when_serializer_returns_no_blobs( + task_datastore, isolated_store +): + """If a serializer returns an empty blob list, ``save_artifacts`` must raise + without leaving partial metadata in ``_info``.""" class _EmptyBlobSerializer(ArtifactSerializer): TYPE = "test_empty_blob" @@ -378,28 +382,23 @@ def can_deserialize(cls, metadata): @classmethod def serialize(cls, obj, format="storage"): - return ( - [], - SerializationMetadata("x", 0, "test_empty_blob", {}), - ) + return ([], SerializationMetadata("x", 0, "test_empty_blob", {})) @classmethod def deserialize(cls, data, metadata=None, format="storage"): raise NotImplementedError task_datastore._serializers = [_EmptyBlobSerializer, PickleSerializer] - try: - with pytest.raises(DataException, match="returned no blobs"): - task_datastore.save_artifacts(iter([("bad", object())])) - assert "bad" not in task_datastore._info - finally: - SerializerStore._all_serializers.pop("test_empty_blob", None) - SerializerStore._ordered_cache = None + + with pytest.raises(DataException, match="returned no blobs"): + task_datastore.save_artifacts(iter([("bad", object())])) + assert "bad" not in task_datastore._info -def test_info_not_populated_when_serializer_returns_multi_blob(task_datastore): +def test_info_not_populated_when_serializer_returns_multi_blob( + task_datastore, isolated_store +): """Same guarantee as above for the multi-blob rejection path.""" - from metaflow.datastore.exceptions import DataException class _MultiBlobSerializer(ArtifactSerializer): TYPE = "test_multi_blob" @@ -425,31 +424,20 @@ def deserialize(cls, data, metadata=None, format="storage"): raise NotImplementedError task_datastore._serializers = [_MultiBlobSerializer, PickleSerializer] - try: - with pytest.raises(DataException, match="single-blob serializers"): - task_datastore.save_artifacts(iter([("bad", object())])) - assert "bad" not in task_datastore._info - finally: - SerializerStore._all_serializers.pop("test_multi_blob", None) - SerializerStore._ordered_cache = None + + with pytest.raises(DataException, match="single-blob serializers"): + task_datastore.save_artifacts(iter([("bad", object())])) + assert "bad" not in task_datastore._info # --------------------------------------------------------------------------- -# Exception flow: PickleSerializer owns its own UnpicklableArtifactException; -# extension MetaflowExceptions pass through; everything else gets wrapped. +# Exception flow mapping # --------------------------------------------------------------------------- def test_pickle_serializer_raises_unpicklable_with_artifact_name(task_datastore): - """PickleSerializer raises ``UnpicklableArtifactException`` from inside - ``serialize()`` (no name); ``save_artifacts`` re-raises it with the - artifact name attached so users see the original "named X" message.""" - import threading - - from metaflow.datastore.exceptions import UnpicklableArtifactException - - # ``threading.Lock`` raises ``TypeError`` from ``pickle.dumps``, which is - # the path that turns into ``UnpicklableArtifactException``. + """PickleSerializer raises ``UnpicklableArtifactException``; + ``save_artifacts`` re-raises it with the artifact name attached.""" unpicklable = threading.Lock() with pytest.raises(UnpicklableArtifactException, match='named "bad_one"'): @@ -457,11 +445,9 @@ def test_pickle_serializer_raises_unpicklable_with_artifact_name(task_datastore) assert "bad_one" not in task_datastore._info -def test_extension_metaflow_exception_passes_through(task_datastore): +def test_extension_metaflow_exception_passes_through(task_datastore, isolated_store): """An extension serializer raising a ``MetaflowException`` subclass must - propagate as-is — wrapping it in ``DataException`` would obscure the - original headline/message that is already user-facing.""" - from metaflow.exception import MetaflowException + propagate as-is.""" class _ExtensionError(MetaflowException): headline = "Extension validation failed" @@ -487,22 +473,17 @@ def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): raise NotImplementedError task_datastore._serializers = [_RaisingSerializer, PickleSerializer] - try: - with pytest.raises(_ExtensionError, match="schema mismatch on field 'foo'"): - task_datastore.save_artifacts(iter([("x", object())])) - assert "x" not in task_datastore._info - finally: - SerializerStore._all_serializers.pop("test_passthrough_ser", None) - SerializerStore._ordered_cache = None + + with pytest.raises(_ExtensionError, match="schema mismatch on field 'foo'"): + task_datastore.save_artifacts(iter([("x", object())])) + assert "x" not in task_datastore._info -def test_extension_type_error_is_not_mislabeled_unpicklable(task_datastore): +def test_extension_type_error_is_not_mislabeled_unpicklable( + task_datastore, isolated_store +): """A non-pickle serializer raising ``TypeError`` must NOT be reported as - ``UnpicklableArtifactException`` — that wrapper is pickle-specific now.""" - from metaflow.datastore.exceptions import ( - DataException, - UnpicklableArtifactException, - ) + ``UnpicklableArtifactException``.""" class _TypeErrorSerializer(ArtifactSerializer): TYPE = "test_typeerror_ser" @@ -525,31 +506,23 @@ def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): raise NotImplementedError task_datastore._serializers = [_TypeErrorSerializer, PickleSerializer] - try: - with pytest.raises(DataException, match="_TypeErrorSerializer") as exc_info: - task_datastore.save_artifacts(iter([("x", object())])) - # Critically: NOT UnpicklableArtifactException — that name would lie - # to users about which serializer rejected the object. - assert not isinstance(exc_info.value, UnpicklableArtifactException) - assert "x" not in task_datastore._info - finally: - SerializerStore._all_serializers.pop("test_typeerror_ser", None) - SerializerStore._ordered_cache = None - - -def test_can_serialize_exception_falls_through_to_pickle(task_datastore): + + with pytest.raises(DataException, match="_TypeErrorSerializer") as exc_info: + task_datastore.save_artifacts(iter([("x", object())])) + + assert not isinstance(exc_info.value, UnpicklableArtifactException) + assert "x" not in task_datastore._info + + +def test_can_serialize_exception_falls_through_to_pickle( + task_datastore, isolated_store +): """A buggy custom serializer's can_serialize exception must NOT crash - save_artifacts. The buggy serializer is skipped; pickle fallback handles - the artifact; dispatch_error_count is incremented.""" - from metaflow.datastore.artifacts import SerializationFormat - from metaflow.datastore.artifacts.diagnostic import ( - SerializerRecord, - SerializerState, - ) + save_artifacts. Pickle fallback handles it; dispatch_error_count is incremented.""" class _BuggyCanSerialize(ArtifactSerializer): TYPE = "test_buggy_cs" - PRIORITY = 1 # tried first + PRIORITY = 1 @classmethod def can_serialize(cls, obj): @@ -567,7 +540,6 @@ def serialize(cls, obj, format=SerializationFormat.STORAGE): def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): raise NotImplementedError - # Seed a diagnostic record so dispatch_error_count has somewhere to go. rec = SerializerRecord( name="test_buggy_cs", class_path="test.inline.BuggyCanSerialize", @@ -580,27 +552,15 @@ def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): task_datastore._serializers = [_BuggyCanSerialize, PickleSerializer] - try: - # Must NOT raise. - task_datastore.save_artifacts(iter([("x", 42)])) - assert task_datastore._info["x"]["encoding"] == "pickle-v4" - assert rec.dispatch_error_count == 1 - assert rec.last_error is not None - assert "RuntimeError" in rec.last_error - finally: - SerializerStore._all_serializers.pop("test_buggy_cs", None) - SerializerStore._active_serializers.discard(_BuggyCanSerialize) - SerializerStore._records.pop("test_buggy_cs", None) - SerializerStore._ordered_cache = None - - -def test_can_deserialize_exception_falls_through(task_datastore): + # Must NOT raise. + task_datastore.save_artifacts(iter([("x", 42)])) + assert task_datastore._info["x"]["encoding"] == "pickle-v4" + assert rec.dispatch_error_count == 1 + assert "RuntimeError" in rec.last_error + + +def test_can_deserialize_exception_falls_through(task_datastore, isolated_store): """Same guarantee for can_deserialize during load_artifacts.""" - from metaflow.datastore.artifacts import SerializationFormat - from metaflow.datastore.artifacts.diagnostic import ( - SerializerRecord, - SerializerState, - ) class _BuggyCanDeserialize(ArtifactSerializer): TYPE = "test_buggy_cd" @@ -632,35 +592,22 @@ def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): SerializerStore._records["test_buggy_cd"] = rec SerializerStore._active_serializers.add(_BuggyCanDeserialize) - # First save an artifact normally via pickle so load has something to load. + # First save an artifact normally via pickle task_datastore._serializers = [PickleSerializer] task_datastore.save_artifacts(iter([("y", "hello")])) - # Now install the buggy serializer and try to load — buggy can_deserialize - # should be skipped and pickle should take over. + # Now install the buggy serializer and try to load task_datastore._serializers = [_BuggyCanDeserialize, PickleSerializer] - try: - loaded = dict(task_datastore.load_artifacts(["y"])) - assert loaded["y"] == "hello" - assert rec.dispatch_error_count == 1 - assert rec.last_error is not None - assert "RuntimeError" in rec.last_error - finally: - SerializerStore._all_serializers.pop("test_buggy_cd", None) - SerializerStore._active_serializers.discard(_BuggyCanDeserialize) - SerializerStore._records.pop("test_buggy_cd", None) - SerializerStore._ordered_cache = None + loaded = dict(task_datastore.load_artifacts(["y"])) + assert loaded["y"] == "hello" + assert rec.dispatch_error_count == 1 + assert "RuntimeError" in rec.last_error -def test_subclass_lazy_import_stashes_on_child_not_parent(): +def test_subclass_lazy_import_stashes_on_child_not_parent(isolated_store): """lazy_import on a subclass should set attrs on the subclass, not the parent. Parent and children should each have their own _lazy_imported_names set.""" - from metaflow.datastore.artifacts import ( - ArtifactSerializer, - SerializationFormat, - SerializerStore, - ) class _ParentSer(ArtifactSerializer): TYPE = "test_inherit_parent" @@ -692,29 +639,16 @@ class _ChildSer(_ParentSer): def setup_imports(cls, context=None): cls.lazy_import("sys") - try: - _ParentSer.setup_imports() - _ChildSer.setup_imports() - import json as _json - import sys as _sys - - # Parent has json; child has sys - assert _ParentSer.json is _json - assert _ChildSer.sys is _sys - - # Each class should have its OWN _lazy_imported_names set - # (not a shared inherited one) - parent_names = _ParentSer.__dict__.get("_lazy_imported_names", set()) - child_names = _ChildSer.__dict__.get("_lazy_imported_names", set()) - assert parent_names == {"json"} - assert child_names == {"sys"} - finally: - for t in ("test_inherit_parent", "test_inherit_child"): - SerializerStore._all_serializers.pop(t, None) - SerializerStore._ordered_cache = None - for c, attr in ((_ParentSer, "json"), (_ChildSer, "sys")): - if attr in c.__dict__: - delattr(c, attr) - for c in (_ParentSer, _ChildSer): - if "_lazy_imported_names" in c.__dict__: - delattr(c, "_lazy_imported_names") + _ParentSer.setup_imports() + _ChildSer.setup_imports() + + import json as _json + import sys as _sys + + assert _ParentSer.json is _json + assert _ChildSer.sys is _sys + + parent_names = _ParentSer.__dict__.get("_lazy_imported_names", set()) + child_names = _ChildSer.__dict__.get("_lazy_imported_names", set()) + assert parent_names == {"json"} + assert child_names == {"sys"} diff --git a/test/unit/test_serializer_lifecycle.py b/test/unit/test_serializer_lifecycle.py index d201f06587a..4e7fcef4e0b 100644 --- a/test/unit/test_serializer_lifecycle.py +++ b/test/unit/test_serializer_lifecycle.py @@ -9,6 +9,68 @@ SerializerRecord, SerializerState, ) +from metaflow.datastore.artifacts.serializer import ( + ArtifactSerializer, + SerializerStore, + SerializationFormat, +) +from metaflow import metaflow_config + + +# --- Fixtures --- + + +@pytest.fixture +def isolated_store(): + """ + Fixture to isolate SerializerStore global state per test. + Provides a clean slate and restores original state afterward. + """ + # Snapshot original state + saved_all = dict(SerializerStore._all_serializers) + saved_active = set(SerializerStore._active_serializers) + saved_records = dict(SerializerStore._records) + saved_pending = dict(SerializerStore._pending_by_module) + saved_cache = SerializerStore._ordered_cache + + # Clear state for test isolation + SerializerStore._all_serializers.clear() + SerializerStore._active_serializers.clear() + SerializerStore._records.clear() + SerializerStore._pending_by_module.clear() + SerializerStore._ordered_cache = None + + yield + + # Restore original state + SerializerStore._all_serializers.clear() + SerializerStore._all_serializers.update(saved_all) + SerializerStore._active_serializers.clear() + SerializerStore._active_serializers.update(saved_active) + SerializerStore._records.clear() + SerializerStore._records.update(saved_records) + SerializerStore._pending_by_module.clear() + SerializerStore._pending_by_module.update(saved_pending) + SerializerStore._ordered_cache = saved_cache + + +@pytest.fixture +def make_serializer_module(monkeypatch): + """ + Factory fixture to dynamically create a module from source code + and register it. Automatically cleans up sys.modules via monkeypatch. + """ + + def _make(module_name, source_code): + mod = types.ModuleType(module_name) + exec(source_code, mod.__dict__) + monkeypatch.setitem(sys.modules, module_name, mod) + return mod + + return _make + + +# --- Tests --- def test_serializer_record_default_fields(): @@ -45,14 +107,7 @@ def test_serializer_record_as_dict(): assert d["import_trigger"] == "eager" -from metaflow.datastore.artifacts.serializer import ( - ArtifactSerializer, - SerializerStore, - SerializationFormat, -) - - -def test_store_separates_all_vs_active(): +def test_store_separates_all_vs_active(isolated_store): """_all_serializers is the known-classes index; _active_serializers is dispatch pool.""" class _Known(ArtifactSerializer): @@ -74,23 +129,18 @@ def serialize(cls, obj, format=SerializationFormat.STORAGE): def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): raise NotImplementedError - try: - # Metaclass registers on class body execution - assert _Known in SerializerStore._all_serializers.values() - # But without bootstrap, it is NOT in the dispatch pool - assert _Known not in SerializerStore._active_serializers - # _records is an empty dict initially (for entries from DESC tuples) - assert isinstance(SerializerStore._records, dict) - finally: - SerializerStore._all_serializers.pop("test_known_not_active", None) - SerializerStore._active_serializers.discard(_Known) - SerializerStore._ordered_cache = None + # Metaclass registers on class body execution + assert _Known in SerializerStore._all_serializers.values() + # But without bootstrap, it is NOT in the dispatch pool + assert _Known not in SerializerStore._active_serializers + # _records is an empty dict initially (for entries from DESC tuples) + assert isinstance(SerializerStore._records, dict) -def test_bootstrap_activates_dependency_free_serializer(): +def test_bootstrap_activates_dependency_free_serializer( + isolated_store, make_serializer_module +): """bootstrap_entries with an in-process serializer moves it to ACTIVE.""" - - mod = types.ModuleType("_test_bootstrap_mod") source = """ from metaflow.datastore.artifacts import ArtifactSerializer, SerializationFormat @@ -103,40 +153,26 @@ def can_serialize(cls, obj): return False @classmethod def can_deserialize(cls, metadata): return False @classmethod - def serialize(cls, obj, format=SerializationFormat.STORAGE): - raise NotImplementedError + def serialize(cls, obj, format=SerializationFormat.STORAGE): raise NotImplementedError @classmethod - def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): - raise NotImplementedError + def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): raise NotImplementedError """ - exec(source, mod.__dict__) - sys.modules["_test_bootstrap_mod"] = mod - try: - SerializerStore.bootstrap_entries( - [ - ("test_bootstrap_probe", "_test_bootstrap_mod._BootProbe"), - ] - ) - rec = SerializerStore._records["test_bootstrap_probe"] - assert rec.state == SerializerState.ACTIVE - assert rec.priority == 60 - assert rec.type == "test_bootstrap_probe" - assert rec.import_trigger == "eager" - assert mod._BootProbe in SerializerStore._active_serializers - finally: - SerializerStore._all_serializers.pop("test_bootstrap_probe", None) - SerializerStore._active_serializers.discard(mod._BootProbe) - SerializerStore._records.pop("test_bootstrap_probe", None) - SerializerStore._ordered_cache = None - del sys.modules["_test_bootstrap_mod"] - - -def test_bootstrap_rejects_name_type_mismatch(): - """Tuple first element MUST equal class.TYPE.""" + mod = make_serializer_module("_test_bootstrap_mod", source) + SerializerStore.bootstrap_entries( + [("test_bootstrap_probe", "_test_bootstrap_mod._BootProbe")] + ) + + rec = SerializerStore._records["test_bootstrap_probe"] + assert rec.state == SerializerState.ACTIVE + assert rec.priority == 60 + assert rec.type == "test_bootstrap_probe" + assert rec.import_trigger == "eager" + assert mod._BootProbe in SerializerStore._active_serializers - mod = types.ModuleType("_test_mismatch_mod") - exec( - """ + +def test_bootstrap_rejects_name_type_mismatch(isolated_store, make_serializer_module): + """Tuple first element MUST equal class.TYPE.""" + source = """ from metaflow.datastore.artifacts import ArtifactSerializer, SerializationFormat class _Mismatch(ArtifactSerializer): @@ -146,161 +182,113 @@ def can_serialize(cls, obj): return False @classmethod def can_deserialize(cls, metadata): return False @classmethod - def serialize(cls, obj, format=SerializationFormat.STORAGE): - raise NotImplementedError + def serialize(cls, obj, format=SerializationFormat.STORAGE): raise NotImplementedError @classmethod - def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): - raise NotImplementedError -""", - mod.__dict__, + def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): raise NotImplementedError +""" + make_serializer_module("_test_mismatch_mod", source) + SerializerStore.bootstrap_entries( + [("declared_name", "_test_mismatch_mod._Mismatch")] ) - sys.modules["_test_mismatch_mod"] = mod - try: - SerializerStore.bootstrap_entries( - [ - ("declared_name", "_test_mismatch_mod._Mismatch"), - ] - ) - rec = SerializerStore._records["declared_name"] - assert rec.state == SerializerState.BROKEN - assert "tuple name" in rec.last_error - assert "actual_type" in rec.last_error - finally: - SerializerStore._all_serializers.pop("actual_type", None) - SerializerStore._records.pop("declared_name", None) - del sys.modules["_test_mismatch_mod"] - - -def test_bootstrap_missing_module_parks_entry(): - """ModuleNotFoundError during import_module moves entry to PENDING_ON_IMPORTS.""" + rec = SerializerStore._records["declared_name"] + assert rec.state == SerializerState.BROKEN + assert "tuple name" in rec.last_error + assert "actual_type" in rec.last_error + + +def test_bootstrap_missing_module_parks_entry(isolated_store): + """ModuleNotFoundError during import_module moves entry to PENDING_ON_IMPORTS.""" SerializerStore.bootstrap_entries( - [ - ("test_absent", "_never_created_module._Absent"), - ] + [("test_absent", "_never_created_module._Absent")] ) - try: - rec = SerializerStore._records["test_absent"] - assert rec.state == SerializerState.PENDING_ON_IMPORTS - assert "_never_created_module" in rec.awaiting_modules - # _pending_by_module should track it - assert "test_absent" in SerializerStore._pending_by_module.get( - "_never_created_module", [] - ) - finally: - SerializerStore._records.pop("test_absent", None) - SerializerStore._pending_by_module.pop("_never_created_module", None) - - -def test_bootstrap_missing_class_in_module_broken(): + + rec = SerializerStore._records["test_absent"] + assert rec.state == SerializerState.PENDING_ON_IMPORTS + assert "_never_created_module" in rec.awaiting_modules + assert "test_absent" in SerializerStore._pending_by_module.get( + "_never_created_module", [] + ) + + +def test_bootstrap_missing_class_in_module_broken( + isolated_store, make_serializer_module +): """getattr failure after successful module import moves to BROKEN.""" - mod = types.ModuleType("_test_no_class_mod") - # Intentionally empty — no class inside - sys.modules["_test_no_class_mod"] = mod - try: - SerializerStore.bootstrap_entries( - [ - ("test_no_class", "_test_no_class_mod._Missing"), - ] - ) - rec = SerializerStore._records["test_no_class"] - assert rec.state == SerializerState.BROKEN - assert "class" in rec.last_error.lower() - assert "_Missing" in rec.last_error - finally: - SerializerStore._records.pop("test_no_class", None) - del sys.modules["_test_no_class_mod"] - - -def test_bootstrap_setup_imports_missing_dep_parks_entry(): + # Intentionally empty module — no class inside + make_serializer_module("_test_no_class_mod", "") + SerializerStore.bootstrap_entries( + [("test_no_class", "_test_no_class_mod._Missing")] + ) + + rec = SerializerStore._records["test_no_class"] + assert rec.state == SerializerState.BROKEN + assert "class" in rec.last_error.lower() + assert "_Missing" in rec.last_error + + +def test_bootstrap_setup_imports_missing_dep_parks_entry( + isolated_store, make_serializer_module +): """ImportError inside setup_imports parks on the missing module name.""" - mod = types.ModuleType("_test_setup_missing_mod") - exec( - """ + source = """ from metaflow.datastore.artifacts import ArtifactSerializer, SerializationFormat class _WantsMissing(ArtifactSerializer): TYPE = "test_setup_wants_missing" @classmethod - def setup_imports(cls, context=None): - cls.lazy_import("absent_at_setup_time_xyz") + def setup_imports(cls, context=None): cls.lazy_import("absent_at_setup_time_xyz") @classmethod def can_serialize(cls, obj): return False @classmethod def can_deserialize(cls, metadata): return False @classmethod - def serialize(cls, obj, format=SerializationFormat.STORAGE): - raise NotImplementedError + def serialize(cls, obj, format=SerializationFormat.STORAGE): raise NotImplementedError @classmethod - def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): - raise NotImplementedError -""", - mod.__dict__, + def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): raise NotImplementedError +""" + make_serializer_module("_test_setup_missing_mod", source) + SerializerStore.bootstrap_entries( + [("test_setup_wants_missing", "_test_setup_missing_mod._WantsMissing")] ) - sys.modules["_test_setup_missing_mod"] = mod - try: - SerializerStore.bootstrap_entries( - [ - ("test_setup_wants_missing", "_test_setup_missing_mod._WantsMissing"), - ] - ) - rec = SerializerStore._records["test_setup_wants_missing"] - assert rec.state == SerializerState.PENDING_ON_IMPORTS - assert "absent_at_setup_time_xyz" in rec.awaiting_modules - finally: - SerializerStore._all_serializers.pop("test_setup_wants_missing", None) - SerializerStore._records.pop("test_setup_wants_missing", None) - SerializerStore._pending_by_module.pop("absent_at_setup_time_xyz", None) - del sys.modules["_test_setup_missing_mod"] - - -def test_bootstrap_setup_imports_other_exception_broken(): + + rec = SerializerStore._records["test_setup_wants_missing"] + assert rec.state == SerializerState.PENDING_ON_IMPORTS + assert "absent_at_setup_time_xyz" in rec.awaiting_modules + + +def test_bootstrap_setup_imports_other_exception_broken( + isolated_store, make_serializer_module +): """Non-ImportError from setup_imports moves entry to BROKEN.""" - mod = types.ModuleType("_test_setup_boom_mod") - exec( - """ + source = """ from metaflow.datastore.artifacts import ArtifactSerializer, SerializationFormat class _Boom(ArtifactSerializer): TYPE = "test_boom" @classmethod - def setup_imports(cls, context=None): - raise RuntimeError("explicit boom from test") + def setup_imports(cls, context=None): raise RuntimeError("explicit boom from test") @classmethod def can_serialize(cls, obj): return False @classmethod def can_deserialize(cls, metadata): return False @classmethod - def serialize(cls, obj, format=SerializationFormat.STORAGE): - raise NotImplementedError + def serialize(cls, obj, format=SerializationFormat.STORAGE): raise NotImplementedError @classmethod - def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): - raise NotImplementedError -""", - mod.__dict__, - ) - sys.modules["_test_setup_boom_mod"] = mod - try: - SerializerStore.bootstrap_entries( - [ - ("test_boom", "_test_setup_boom_mod._Boom"), - ] - ) - rec = SerializerStore._records["test_boom"] - assert rec.state == SerializerState.BROKEN - assert "RuntimeError" in rec.last_error - assert "explicit boom" in rec.last_error - finally: - SerializerStore._all_serializers.pop("test_boom", None) - SerializerStore._records.pop("test_boom", None) - del sys.modules["_test_setup_boom_mod"] - - -def test_bootstrap_disabled_toggle(): + def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): raise NotImplementedError +""" + make_serializer_module("_test_setup_boom_mod", source) + SerializerStore.bootstrap_entries([("test_boom", "_test_setup_boom_mod._Boom")]) + + rec = SerializerStore._records["test_boom"] + assert rec.state == SerializerState.BROKEN + assert "RuntimeError" in rec.last_error + assert "explicit boom" in rec.last_error + + +def test_bootstrap_disabled_toggle(isolated_store, make_serializer_module): """Entries whose name is in disabled_names land in DISABLED state.""" - mod = types.ModuleType("_test_disable_mod") - exec( - """ + source = """ from metaflow.datastore.artifacts import ArtifactSerializer, SerializationFormat class _DisableMe(ArtifactSerializer): @@ -310,280 +298,185 @@ def can_serialize(cls, obj): return False @classmethod def can_deserialize(cls, metadata): return False @classmethod - def serialize(cls, obj, format=SerializationFormat.STORAGE): - raise NotImplementedError + def serialize(cls, obj, format=SerializationFormat.STORAGE): raise NotImplementedError @classmethod - def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): - raise NotImplementedError -""", - mod.__dict__, + def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): raise NotImplementedError +""" + mod = make_serializer_module("_test_disable_mod", source) + SerializerStore.bootstrap_entries( + [("test_disable", "_test_disable_mod._DisableMe")], + disabled_names={"test_disable"}, ) - sys.modules["_test_disable_mod"] = mod - try: - SerializerStore.bootstrap_entries( - [("test_disable", "_test_disable_mod._DisableMe")], - disabled_names={"test_disable"}, - ) - rec = SerializerStore._records["test_disable"] - assert rec.state == SerializerState.DISABLED - # Class should NOT be in active pool - assert mod._DisableMe not in SerializerStore._active_serializers - finally: - SerializerStore._all_serializers.pop("test_disable", None) - SerializerStore._active_serializers.discard(mod._DisableMe) - SerializerStore._records.pop("test_disable", None) - del sys.modules["_test_disable_mod"] - - -def test_retry_activates_pending_record_on_module_import(): + + rec = SerializerStore._records["test_disable"] + assert rec.state == SerializerState.DISABLED + assert mod._DisableMe not in SerializerStore._active_serializers + + +def test_retry_activates_pending_record_on_module_import( + isolated_store, make_serializer_module, monkeypatch +): """When a pending record's awaited module imports, the record retries to ACTIVE.""" - ser_mod = types.ModuleType("_test_retry_ser_mod") - exec( - """ + source = """ from metaflow.datastore.artifacts import ArtifactSerializer, SerializationFormat class _Pending(ArtifactSerializer): TYPE = "test_retry_pending" @classmethod - def setup_imports(cls, context=None): - cls.lazy_import("retry_dep_mod_name") + def setup_imports(cls, context=None): cls.lazy_import("retry_dep_mod_name") @classmethod def can_serialize(cls, obj): return False @classmethod def can_deserialize(cls, metadata): return False @classmethod - def serialize(cls, obj, format=SerializationFormat.STORAGE): - raise NotImplementedError + def serialize(cls, obj, format=SerializationFormat.STORAGE): raise NotImplementedError @classmethod - def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): - raise NotImplementedError -""", - ser_mod.__dict__, + def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): raise NotImplementedError +""" + ser_mod = make_serializer_module("_test_retry_ser_mod", source) + monkeypatch.delitem(sys.modules, "retry_dep_mod_name", raising=False) + + SerializerStore.bootstrap_entries( + [("test_retry_pending", "_test_retry_ser_mod._Pending")] + ) + + rec = SerializerStore._records["test_retry_pending"] + assert rec.state == SerializerState.PENDING_ON_IMPORTS + assert "retry_dep_mod_name" in rec.awaiting_modules + assert ( + "test_retry_pending" in SerializerStore._pending_by_module["retry_dep_mod_name"] + ) + + # Simulate the dep becoming available, then fire the retry hook. + dep_mod = types.ModuleType("retry_dep_mod_name") + monkeypatch.setitem(sys.modules, "retry_dep_mod_name", dep_mod) + SerializerStore._on_module_imported("retry_dep_mod_name", dep_mod) + + assert rec.state == SerializerState.ACTIVE + assert rec.import_trigger == "hook" + assert ser_mod._Pending in SerializerStore._active_serializers + assert "test_retry_pending" not in SerializerStore._pending_by_module.get( + "retry_dep_mod_name", [] ) - sys.modules["_test_retry_ser_mod"] = ser_mod - sys.modules.pop("retry_dep_mod_name", None) - - try: - SerializerStore.bootstrap_entries( - [ - ("test_retry_pending", "_test_retry_ser_mod._Pending"), - ] - ) - rec = SerializerStore._records["test_retry_pending"] - assert rec.state == SerializerState.PENDING_ON_IMPORTS - assert "retry_dep_mod_name" in rec.awaiting_modules - assert ( - "test_retry_pending" - in SerializerStore._pending_by_module["retry_dep_mod_name"] - ) - - # Simulate the dep becoming available, then fire the retry hook. - dep_mod = types.ModuleType("retry_dep_mod_name") - sys.modules["retry_dep_mod_name"] = dep_mod - SerializerStore._on_module_imported("retry_dep_mod_name", dep_mod) - - assert rec.state == SerializerState.ACTIVE - assert rec.import_trigger == "hook" - assert ser_mod._Pending in SerializerStore._active_serializers - # _pending_by_module should no longer list this record under that module - assert "test_retry_pending" not in SerializerStore._pending_by_module.get( - "retry_dep_mod_name", [] - ) - finally: - SerializerStore._all_serializers.pop("test_retry_pending", None) - SerializerStore._active_serializers.discard(ser_mod._Pending) - SerializerStore._records.pop("test_retry_pending", None) - SerializerStore._pending_by_module.pop("retry_dep_mod_name", None) - SerializerStore._ordered_cache = None - sys.modules.pop("_test_retry_ser_mod", None) - sys.modules.pop("retry_dep_mod_name", None) - - -def test_retry_hits_loop_guard_after_repeated_failure(): + + +def test_retry_hits_loop_guard_after_repeated_failure( + isolated_store, make_serializer_module, monkeypatch +): """Calling _on_module_imported when setup_imports still fails on the same module name should transition to BROKEN via the loop guard.""" - ser_mod = types.ModuleType("_test_loop_ser_mod") - exec( - """ + source = """ from metaflow.datastore.artifacts import ArtifactSerializer, SerializationFormat class _Loopy(ArtifactSerializer): TYPE = "test_loopy" @classmethod - def setup_imports(cls, context=None): - # Always raises on the same module name even if retried. - cls.lazy_import("never_resolves_mod") + def setup_imports(cls, context=None): cls.lazy_import("never_resolves_mod") @classmethod def can_serialize(cls, obj): return False @classmethod def can_deserialize(cls, metadata): return False @classmethod - def serialize(cls, obj, format=SerializationFormat.STORAGE): - raise NotImplementedError + def serialize(cls, obj, format=SerializationFormat.STORAGE): raise NotImplementedError @classmethod - def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): - raise NotImplementedError -""", - ser_mod.__dict__, - ) - sys.modules["_test_loop_ser_mod"] = ser_mod - sys.modules.pop("never_resolves_mod", None) - - try: - SerializerStore.bootstrap_entries( - [ - ("test_loopy", "_test_loop_ser_mod._Loopy"), - ] - ) - rec = SerializerStore._records["test_loopy"] - assert rec.state == SerializerState.PENDING_ON_IMPORTS - - # Fake the dep appearing but NOT actually installing it — the retry - # will re-run setup_imports, which will ImportError again on the same name. - dep_mod = types.ModuleType("never_resolves_mod") - # We deliberately DO NOT put dep_mod in sys.modules, so lazy_import - # inside setup_imports will raise ModuleNotFoundError again. - SerializerStore._on_module_imported("never_resolves_mod", dep_mod) - - assert rec.state == SerializerState.BROKEN - assert "repeated" in rec.last_error.lower() - finally: - SerializerStore._all_serializers.pop("test_loopy", None) - SerializerStore._records.pop("test_loopy", None) - SerializerStore._pending_by_module.pop("never_resolves_mod", None) - sys.modules.pop("_test_loop_ser_mod", None) - sys.modules.pop("never_resolves_mod", None) - - -def test_retry_fires_via_real_import_hook(tmp_path, monkeypatch): + def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): raise NotImplementedError +""" + make_serializer_module("_test_loop_ser_mod", source) + monkeypatch.delitem(sys.modules, "never_resolves_mod", raising=False) + + SerializerStore.bootstrap_entries([("test_loopy", "_test_loop_ser_mod._Loopy")]) + + rec = SerializerStore._records["test_loopy"] + assert rec.state == SerializerState.PENDING_ON_IMPORTS + + # Fake the dep appearing but DO NOT actually put it in sys.modules, + # so lazy_import will raise ModuleNotFoundError again. + dep_mod = types.ModuleType("never_resolves_mod") + SerializerStore._on_module_imported("never_resolves_mod", dep_mod) + + assert rec.state == SerializerState.BROKEN + assert "repeated" in rec.last_error.lower() + + +def test_retry_fires_via_real_import_hook( + tmp_path, monkeypatch, isolated_store, make_serializer_module +): """End-to-end: park a serializer on a missing module, install the hook, actually import the module, verify the serializer activates.""" + # Setup test package directory to act as an importable module pkg_dir = tmp_path / "fixture_retry_dep" pkg_dir.mkdir() (pkg_dir / "__init__.py").write_text("VALUE = 42\n") - # NOTE: we prepend syspath AFTER bootstrap_entries below so the first - # lazy_import() fails and the record parks on the missing module. - ser_mod = types.ModuleType("_test_e2e_ser_mod") - exec( - """ + source = """ from metaflow.datastore.artifacts import ArtifactSerializer, SerializationFormat class _E2E(ArtifactSerializer): TYPE = "test_e2e_retry" @classmethod - def setup_imports(cls, context=None): - cls.lazy_import("fixture_retry_dep") + def setup_imports(cls, context=None): cls.lazy_import("fixture_retry_dep") @classmethod def can_serialize(cls, obj): return False @classmethod def can_deserialize(cls, metadata): return False @classmethod - def serialize(cls, obj, format=SerializationFormat.STORAGE): - raise NotImplementedError + def serialize(cls, obj, format=SerializationFormat.STORAGE): raise NotImplementedError @classmethod - def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): - raise NotImplementedError -""", - ser_mod.__dict__, - ) - sys.modules["_test_e2e_ser_mod"] = ser_mod - # Make sure the dep module isn't pre-imported - sys.modules.pop("fixture_retry_dep", None) - - try: - SerializerStore.bootstrap_entries( - [ - ("test_e2e_retry", "_test_e2e_ser_mod._E2E"), - ] - ) - rec = SerializerStore._records["test_e2e_retry"] - assert rec.state == SerializerState.PENDING_ON_IMPORTS - - # Now make the dep discoverable on sys.path and actually import it. - monkeypatch.syspath_prepend(str(tmp_path)) - import fixture_retry_dep # noqa: F401 - - # After real import, the hook chain should have fired. - assert rec.state == SerializerState.ACTIVE - assert rec.import_trigger == "hook" - assert ser_mod._E2E in SerializerStore._active_serializers - finally: - SerializerStore._all_serializers.pop("test_e2e_retry", None) - SerializerStore._active_serializers.discard(ser_mod._E2E) - SerializerStore._records.pop("test_e2e_retry", None) - SerializerStore._pending_by_module.pop("fixture_retry_dep", None) - SerializerStore._ordered_cache = None - sys.modules.pop("_test_e2e_ser_mod", None) - sys.modules.pop("fixture_retry_dep", None) - # Clean up interceptor state - from metaflow.datastore.artifacts.lazy_registry import _interceptor - - _interceptor._watched.discard("fixture_retry_dep") - _interceptor._processed.discard("fixture_retry_dep") - - -def test_bootstrap_with_no_extensions_still_runs_core(): + def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): raise NotImplementedError +""" + ser_mod = make_serializer_module("_test_e2e_ser_mod", source) + monkeypatch.delitem(sys.modules, "fixture_retry_dep", raising=False) + + SerializerStore.bootstrap_entries([("test_e2e_retry", "_test_e2e_ser_mod._E2E")]) + + rec = SerializerStore._records["test_e2e_retry"] + assert rec.state == SerializerState.PENDING_ON_IMPORTS + + # Make the dep discoverable on sys.path and actually import it + monkeypatch.syspath_prepend(str(tmp_path)) + import fixture_retry_dep # noqa: F401 + + assert rec.state == SerializerState.ACTIVE + assert rec.import_trigger == "hook" + assert ser_mod._E2E in SerializerStore._active_serializers + + # Cleanup interceptor state to prevent leaking to other tests + from metaflow.datastore.artifacts.lazy_registry import _interceptor + + _interceptor._watched.discard("fixture_retry_dep") + _interceptor._processed.discard("fixture_retry_dep") + + +def test_bootstrap_with_no_extensions_still_runs_core(isolated_store): """bootstrap() reads core ARTIFACT_SERIALIZERS_DESC from metaflow.plugins and activates PickleSerializer.""" - from metaflow.plugins.datastores.serializers.pickle_serializer import ( - PickleSerializer, + SerializerStore.bootstrap() + + # PickleSerializer should be in active pool (core entry) + pickle_active = any( + r.type == "pickle" and r.state == SerializerState.ACTIVE + for r in SerializerStore._records.values() ) + assert pickle_active - # Snapshot and clear state - saved_active = set(SerializerStore._active_serializers) - saved_records = dict(SerializerStore._records) - SerializerStore._active_serializers.clear() - SerializerStore._records.clear() - try: - SerializerStore.bootstrap() - # PickleSerializer should be in active pool (core entry) - pickle_active = any( - r.type == "pickle" and r.state == SerializerState.ACTIVE - for r in SerializerStore._records.values() - ) - assert pickle_active - finally: - SerializerStore._active_serializers.clear() - SerializerStore._active_serializers.update(saved_active) - SerializerStore._records.clear() - SerializerStore._records.update(saved_records) - - -def test_bootstrap_stamps_core_source_on_record(): +def test_bootstrap_stamps_core_source_on_record(isolated_store): """Core serializers bootstrap with source='metaflow' on their records.""" - saved_active = set(SerializerStore._active_serializers) - saved_records = dict(SerializerStore._records) - SerializerStore._active_serializers.clear() - SerializerStore._records.clear() + SerializerStore.bootstrap() - try: - SerializerStore.bootstrap() - pickle_rec = next( - (r for r in SerializerStore._records.values() if r.type == "pickle"), - None, - ) - assert pickle_rec is not None - assert pickle_rec.source == "metaflow" - finally: - SerializerStore._active_serializers.clear() - SerializerStore._active_serializers.update(saved_active) - SerializerStore._records.clear() - SerializerStore._records.update(saved_records) - - -def test_bootstrap_entries_accepts_source_override(): - """bootstrap_entries accepts ``source`` and attaches it to each record.""" - import sys as _sys - import types as _types + pickle_rec = next( + (r for r in SerializerStore._records.values() if r.type == "pickle"), + None, + ) + assert pickle_rec is not None + assert pickle_rec.source == "metaflow" - saved_records = dict(SerializerStore._records) - saved_active = set(SerializerStore._active_serializers) - ser_mod = _types.ModuleType("_test_source_ser_mod") - exec( - """ +def test_bootstrap_entries_accepts_source_override( + isolated_store, make_serializer_module +): + """bootstrap_entries accepts ``source`` and attaches it to each record.""" + source = """ from metaflow.datastore.artifacts import ArtifactSerializer, SerializationFormat class _SourceProbe(ArtifactSerializer): @@ -593,72 +486,40 @@ def can_serialize(cls, obj): return False @classmethod def can_deserialize(cls, metadata): return False @classmethod - def serialize(cls, obj, format=SerializationFormat.STORAGE): - raise NotImplementedError + def serialize(cls, obj, format=SerializationFormat.STORAGE): raise NotImplementedError @classmethod - def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): - raise NotImplementedError -""", - ser_mod.__dict__, + def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): raise NotImplementedError +""" + ser_mod = make_serializer_module("_test_source_ser_mod", source) + + SerializerStore.bootstrap_entries( + [("test_source_probe", "_test_source_ser_mod._SourceProbe")], + source="fake-extension", ) - _sys.modules["_test_source_ser_mod"] = ser_mod - - try: - SerializerStore.bootstrap_entries( - [("test_source_probe", "_test_source_ser_mod._SourceProbe")], - source="fake-extension", - ) - rec = SerializerStore._records["test_source_probe"] - assert rec.source == "fake-extension" - assert SerializerStore.get_source_for(ser_mod._SourceProbe) == "fake-extension" - finally: - SerializerStore._all_serializers.pop("test_source_probe", None) - SerializerStore._records.clear() - SerializerStore._records.update(saved_records) - SerializerStore._active_serializers.clear() - SerializerStore._active_serializers.update(saved_active) - SerializerStore._ordered_cache = None - _sys.modules.pop("_test_source_ser_mod", None) - - -def test_bootstrap_applies_disabled_toggle(monkeypatch): - """bootstrap() respects -name toggles in ENABLED_ARTIFACT_SERIALIZER config.""" - from metaflow import metaflow_config - saved_active = set(SerializerStore._active_serializers) - saved_records = dict(SerializerStore._records) - SerializerStore._active_serializers.clear() - SerializerStore._records.clear() + rec = SerializerStore._records["test_source_probe"] + assert rec.source == "fake-extension" + assert SerializerStore.get_source_for(ser_mod._SourceProbe) == "fake-extension" + +def test_bootstrap_applies_disabled_toggle(isolated_store, monkeypatch): + """bootstrap() respects -name toggles in ENABLED_ARTIFACT_SERIALIZER config.""" monkeypatch.setattr( - metaflow_config, - "ENABLED_ARTIFACT_SERIALIZER", - ["-pickle"], - raising=False, + metaflow_config, "ENABLED_ARTIFACT_SERIALIZER", ["-pickle"], raising=False ) - try: - SerializerStore.bootstrap() - pickle_rec = next( - (r for r in SerializerStore._records.values() if r.name == "pickle"), - None, - ) - assert pickle_rec is not None - assert pickle_rec.state == SerializerState.DISABLED - finally: - SerializerStore._active_serializers.clear() - SerializerStore._active_serializers.update(saved_active) - SerializerStore._records.clear() - SerializerStore._records.update(saved_records) - - -def test_list_serializer_status_returns_dicts(): - """list_serializer_status returns one dict per _records entry, with the - documented shape.""" - from metaflow.datastore.artifacts import list_serializer_status - from metaflow.datastore.artifacts.diagnostic import ( - SerializerRecord, - SerializerState, + SerializerStore.bootstrap() + + pickle_rec = next( + (r for r in SerializerStore._records.values() if r.name == "pickle"), + None, ) + assert pickle_rec is not None + assert pickle_rec.state == SerializerState.DISABLED + + +def test_list_serializer_status_returns_dicts(isolated_store): + """list_serializer_status returns one dict per _records entry, with the documented shape.""" + from metaflow.datastore.artifacts import list_serializer_status # Seed a fake record rec = SerializerRecord( @@ -671,115 +532,80 @@ def test_list_serializer_status_returns_dicts(): ) SerializerStore._records["fake_test_serializer"] = rec - try: - status = list_serializer_status() - assert isinstance(status, list) - match = next((s for s in status if s["name"] == "fake_test_serializer"), None) - assert match is not None - assert match["state"] == "active" - assert match["priority"] == 42 - assert match["type"] == "fake_test_serializer" - assert match["import_trigger"] == "eager" - assert match["class_path"] == "inline.FakeSer" - for key in ( - "name", - "class_path", - "state", - "awaiting_modules", - "last_error", - "priority", - "type", - "import_trigger", - "dispatch_error_count", - ): - assert key in match, "missing key '%s' in status dict" % key - finally: - SerializerStore._records.pop("fake_test_serializer", None) - - -def test_reset_for_tests_clears_registry_state(): + status = list_serializer_status() + assert isinstance(status, list) + + match = next((s for s in status if s["name"] == "fake_test_serializer"), None) + assert match is not None + assert match["state"] == "active" + assert match["priority"] == 42 + assert match["type"] == "fake_test_serializer" + assert match["import_trigger"] == "eager" + assert match["class_path"] == "inline.FakeSer" + + expected_keys = [ + "name", + "class_path", + "state", + "awaiting_modules", + "last_error", + "priority", + "type", + "import_trigger", + "dispatch_error_count", + ] + for key in expected_keys: + assert key in match, f"missing key '{key}' in status dict" + + +def test_reset_for_tests_clears_registry_state(isolated_store, make_serializer_module): """SerializerStore._reset_for_tests clears _records, _active_serializers, _pending_by_module, _ordered_cache, and per-class _lazy_imported_names.""" - import sys as _sys - import types as _types from metaflow.datastore.artifacts.lazy_registry import _interceptor - # Seed state by bootstrapping an inline serializer - ser_mod = _types.ModuleType("_test_reset_ser_mod") - exec( - """ + source = """ from metaflow.datastore.artifacts import ArtifactSerializer, SerializationFormat class _ResetProbe(ArtifactSerializer): TYPE = "test_reset_probe" @classmethod - def setup_imports(cls, context=None): - cls.lazy_import("json") + def setup_imports(cls, context=None): cls.lazy_import("json") @classmethod def can_serialize(cls, obj): return False @classmethod def can_deserialize(cls, metadata): return False @classmethod - def serialize(cls, obj, format=SerializationFormat.STORAGE): - raise NotImplementedError + def serialize(cls, obj, format=SerializationFormat.STORAGE): raise NotImplementedError @classmethod - def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): - raise NotImplementedError -""", - ser_mod.__dict__, + def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): raise NotImplementedError +""" + ser_mod = make_serializer_module("_test_reset_ser_mod", source) + + # Seed state by bootstrapping inline serializers + SerializerStore.bootstrap_entries( + [("test_reset_probe", "_test_reset_ser_mod._ResetProbe")] + ) + SerializerStore.bootstrap_entries( + [("test_reset_pending", "_never_exists_mod._Absent")] ) - _sys.modules["_test_reset_ser_mod"] = ser_mod - - try: - SerializerStore.bootstrap_entries( - [ - ("test_reset_probe", "_test_reset_ser_mod._ResetProbe"), - ] - ) - # Also seed a pending record to exercise _pending_by_module - SerializerStore.bootstrap_entries( - [ - ("test_reset_pending", "_never_exists_mod._Absent"), - ] - ) - - # Snapshot pre-reset state - assert "test_reset_probe" in SerializerStore._records - assert ser_mod._ResetProbe in SerializerStore._active_serializers - assert "_never_exists_mod" in SerializerStore._pending_by_module - # ResetProbe should have _lazy_imported_names populated - assert "json" in ser_mod._ResetProbe.__dict__.get("_lazy_imported_names", set()) - - # Pre-reset, the interceptor should be watching _never_exists_mod. - assert "_never_exists_mod" in _interceptor._watched - - # Call reset - SerializerStore._reset_for_tests() - - # Post-reset: all registry state empty - assert SerializerStore._records == {} - assert len(SerializerStore._active_serializers) == 0 - assert SerializerStore._pending_by_module == {} - assert SerializerStore._ordered_cache is None - - # The probe class should no longer have stashed attrs - assert "json" not in ser_mod._ResetProbe.__dict__ - assert ser_mod._ResetProbe.__dict__.get("_lazy_imported_names", set()) == set() - - # Interceptor watches should also be cleared - assert "_never_exists_mod" not in _interceptor._watched - finally: - # In case reset didn't clean up (e.g., test failed mid-way) - SerializerStore._all_serializers.pop("test_reset_probe", None) - SerializerStore._records.pop("test_reset_probe", None) - SerializerStore._records.pop("test_reset_pending", None) - SerializerStore._active_serializers.discard(ser_mod._ResetProbe) - SerializerStore._pending_by_module.clear() - SerializerStore._ordered_cache = None - _sys.modules.pop("_test_reset_ser_mod", None) - for attr in ("json",): - if attr in ser_mod._ResetProbe.__dict__: - delattr(ser_mod._ResetProbe, attr) - # Re-bootstrap so subsequent tests see the normal active pool - # (e.g. PickleSerializer in _active_serializers + _records). - SerializerStore.bootstrap() + + # Snapshot pre-reset state assertion + assert "test_reset_probe" in SerializerStore._records + assert ser_mod._ResetProbe in SerializerStore._active_serializers + assert "_never_exists_mod" in SerializerStore._pending_by_module + assert "json" in getattr(ser_mod._ResetProbe, "_lazy_imported_names", set()) + assert "_never_exists_mod" in _interceptor._watched + + # Act + SerializerStore._reset_for_tests() + + # Assert post-reset state is empty + assert SerializerStore._records == {} + assert len(SerializerStore._active_serializers) == 0 + assert SerializerStore._pending_by_module == {} + assert SerializerStore._ordered_cache is None + + # The probe class should no longer have stashed attrs + assert "json" not in ser_mod._ResetProbe.__dict__ + assert getattr(ser_mod._ResetProbe, "_lazy_imported_names", set()) == set() + assert "_never_exists_mod" not in _interceptor._watched diff --git a/test/unit/test_serializer_public_api.py b/test/unit/test_serializer_public_api.py index 50a3e7e8634..5d8dc5f1e19 100644 --- a/test/unit/test_serializer_public_api.py +++ b/test/unit/test_serializer_public_api.py @@ -1,55 +1,55 @@ """Smoke tests guarding the public surface of metaflow.datastore.artifacts.""" - -def test_register_serializer_for_type_not_public(): - """Imperative per-type registration is not a public API.""" - import metaflow.datastore.artifacts as mda - - assert not hasattr(mda, "register_serializer_for_type") - - -def test_serializer_config_not_public(): - """SerializerConfig is not a public export.""" - import metaflow.datastore.artifacts as mda - - assert not hasattr(mda, "SerializerConfig") - - -def test_register_serializer_config_not_public(): - import metaflow.datastore.artifacts as mda - - assert not hasattr(mda, "register_serializer_config") - - -def test_iter_registered_configs_not_public(): - import metaflow.datastore.artifacts as mda - - assert not hasattr(mda, "iter_registered_configs") - - -def test_load_serializer_class_not_public(): - import metaflow.datastore.artifacts as mda - - assert not hasattr(mda, "load_serializer_class") - - -def test_plugins_has_no_artifact_serializers_global(): - """metaflow.plugins does not expose a resolved ARTIFACT_SERIALIZERS global. - Dispatch reads directly from SerializerStore.get_ordered_serializers().""" - import metaflow.plugins as mplugins - +import pytest + +import metaflow.datastore.artifacts as mda +import metaflow.plugins as mplugins +from metaflow.datastore.artifacts import list_serializer_status + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "internal_attribute", + [ + "register_serializer_for_type", + "SerializerConfig", + "register_serializer_config", + "iter_registered_configs", + "load_serializer_class", + ], + ids=[ + "register_serializer_for_type", + "serializer_config", + "register_serializer_config", + "iter_registered_configs", + "load_serializer_class", + ], +) +def test_datastore_artifacts_hides_internal_api(internal_attribute): + """Ensure internal serialization methods and configs are not exposed publicly.""" + assert not hasattr(mda, internal_attribute) + + +def test_plugins_hides_artifact_serializers_global(): + """ + metaflow.plugins does not expose a resolved ARTIFACT_SERIALIZERS global. + Dispatch reads directly from SerializerStore.get_ordered_serializers(). + """ assert not hasattr( mplugins, "ARTIFACT_SERIALIZERS" ), "Expected ARTIFACT_SERIALIZERS to be absent; still present" -def test_pickle_serializer_is_active_after_import(): - """After import metaflow, PickleSerializer should be in ACTIVE state.""" - from metaflow.datastore.artifacts import list_serializer_status - +def test_pickle_serializer_defaults_to_active_state(): + """After importing metaflow, PickleSerializer should be in the ACTIVE state.""" status = list_serializer_status() pickle_rec = next((r for r in status if r.get("type") == "pickle"), None) - assert pickle_rec is not None, "PickleSerializer record missing; status=%r" % status - assert pickle_rec["state"] == "active", ( - "Expected pickle active; got %r" % pickle_rec - ) + + assert pickle_rec is not None, f"PickleSerializer record missing; status={status!r}" + assert ( + pickle_rec["state"] == "active" + ), f"Expected pickle active; got {pickle_rec!r}" diff --git a/test/unit/test_to_pod.py b/test/unit/test_to_pod.py index 3968dc1fbbe..70d1f5c5439 100644 --- a/test/unit/test_to_pod.py +++ b/test/unit/test_to_pod.py @@ -4,8 +4,13 @@ (used by DAGNode.node_info serialization for extensions like FunctionSpec). """ +import pytest from metaflow.util import to_pod +# --------------------------------------------------------------------------- +# Dummy Callables for Testing +# --------------------------------------------------------------------------- + def _top_level_fn(): pass @@ -20,45 +25,77 @@ def instance_method(self): pass -def test_to_pod_primitives(): - assert to_pod("abc") == "abc" - assert to_pod(42) == 42 - assert to_pod(3.14) == 3.14 - - -def test_to_pod_list_set_tuple(): - assert to_pod([1, 2, 3]) == [1, 2, 3] - assert sorted(to_pod({1, 2, 3})) == [1, 2, 3] - assert to_pod((1, 2, 3)) == [1, 2, 3] - - -def test_to_pod_dict(): +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "input_val, expected", + [ + ("abc", "abc"), + (42, 42), + (3.14, 3.14), + ], + ids=["string", "integer", "float"], +) +def test_to_pod_converts_primitives_unchanged(input_val, expected): + """Test that basic primitives pass through POD conversion unchanged.""" + assert to_pod(input_val) == expected + + +@pytest.mark.parametrize( + "input_val, expected", + [ + ([1, 2, 3], [1, 2, 3]), + ((1, 2, 3), [1, 2, 3]), + ({1, 2, 3}, [1, 2, 3]), + ], + ids=["list", "tuple", "set"], +) +def test_to_pod_converts_iterable_collections_to_lists(input_val, expected): + """Test that lists, tuples, and sets are converted to standard lists.""" + result = to_pod(input_val) + + # Sets are unordered, so we must sort the result before comparing + if isinstance(input_val, set): + assert sorted(result) == expected + else: + assert result == expected + + +def test_to_pod_converts_dicts_unchanged(): + """Test that simple dictionaries pass through unchanged.""" assert to_pod({"a": 1, "b": 2}) == {"a": 1, "b": 2} -def test_to_pod_nested(): +def test_to_pod_handles_nested_structures(): + """Test that to_pod recursively converts nested collections.""" value = {"outer": [{"inner": (1, 2)}], "other": {"k": "v"}} - assert to_pod(value) == { + expected = { "outer": [{"inner": [1, 2]}], "other": {"k": "v"}, } + assert to_pod(value) == expected -def test_to_pod_callable_uses_qualname(): +def test_to_pod_converts_callable_to_qualname(): """Callables serialize to their __qualname__ for _graph_info persistence.""" result = to_pod(_top_level_fn) assert result == "_top_level_fn" -def test_to_pod_callable_in_dict(): +def test_to_pod_converts_callables_inside_dictionaries(): """Simulates DAGNode.node_info with function references (FunctionSpec use case).""" result = to_pod({"init_func": _top_level_fn, "call_func": _Wrapper.static_method}) + assert result["init_func"] == "_top_level_fn" assert result["call_func"] == "_Wrapper.static_method" -def test_to_pod_lambda_uses_qualname(): - """Lambdas have __qualname__ like 'test_to_pod_lambda_uses_qualname..'.""" +def test_to_pod_converts_lambda_to_qualname_string(): + """Lambdas have __qualname__ like 'test_to_pod_converts_lambda_to_qualname_string..'.""" fn = lambda x: x # noqa: E731 result = to_pod(fn) + assert "" in result From e44b129f708cfb3138488d299ccd7e88fbce03e6 Mon Sep 17 00:00:00 2001 From: agsaru Date: Tue, 9 Jun 2026 04:12:51 +0000 Subject: [PATCH 2/2] added localbatch tests --- test/unit/localbatch/test_localbatch.py | 561 ++++++++++++------------ 1 file changed, 281 insertions(+), 280 deletions(-) diff --git a/test/unit/localbatch/test_localbatch.py b/test/unit/localbatch/test_localbatch.py index de726a06b87..f9778c2bc8a 100644 --- a/test/unit/localbatch/test_localbatch.py +++ b/test/unit/localbatch/test_localbatch.py @@ -2,13 +2,13 @@ Integration tests for localbatch — the local AWS Batch emulator. Test layout mirrors the spin tests: - - TestBatchAPI : validates every Batch REST endpoint using boto3. - No Docker required — jobs fail gracefully when unavailable. - - TestDockerExecution : validates actual container execution. - Skipped when Docker is not running. - - TestMetaflowE2E : runs a real Metaflow flow via the @batch decorator - and verifies artifacts, mirroring the spin test style. - Requires Docker. + - Batch REST API Tests : validates every Batch REST endpoint using boto3. + No Docker required — jobs fail gracefully when unavailable. + - Docker Execution Tests: validates actual container execution. + Skipped when Docker is not running. + - Metaflow E2E Tests : runs a real Metaflow flow via the @batch decorator + and verifies artifacts, mirroring the spin test style. + Requires Docker. Run only the API tests (no Docker needed): pytest test/unit/localbatch/ -m "not docker" @@ -18,7 +18,6 @@ """ import time - import pytest import requests @@ -26,7 +25,7 @@ # --------------------------------------------------------------------------- -# Helpers +# Helpers & Shared Fixtures # --------------------------------------------------------------------------- @@ -58,278 +57,292 @@ def _register(batch_client, name, command=None): return resp["jobDefinitionArn"] +@pytest.fixture(autouse=True) +def _require_docker_if_marked(request): + """Automatically skip any test marked with @pytest.mark.docker if daemon is unavailable.""" + if "docker" in request.keywords: + try: + import docker + + docker.from_env().ping() + except Exception: + pytest.skip("Docker not available") + + # --------------------------------------------------------------------------- -# Batch REST API tests (no Docker required) +# Batch REST API Tests (No Docker Required) # --------------------------------------------------------------------------- -class TestBatchAPI: - def test_health(self, localbatch_server): - resp = requests.get(f"{localbatch_server.base_url}/health") - assert resp.status_code == 200 - assert resp.json()["status"] == "ok" - - def test_describe_queues_returns_default(self, batch_client): - resp = batch_client.describe_job_queues() - queues = resp["jobQueues"] - assert len(queues) >= 1 - names = [q["jobQueueName"] for q in queues] - assert "localbatch-default" in names - - def test_default_queue_is_healthy(self, batch_client): - resp = batch_client.describe_job_queues(jobQueues=["localbatch-default"]) - q = resp["jobQueues"][0] - assert q["state"] == "ENABLED" - assert q["status"] == "VALID" - assert len(q["computeEnvironmentOrder"]) >= 1 - - def test_describe_compute_environments(self, batch_client): - resp = batch_client.describe_compute_environments() - envs = resp["computeEnvironments"] - assert len(envs) >= 1 - assert envs[0]["status"] == "VALID" - - def test_register_job_definition(self, batch_client): +def test_health(localbatch_server): + resp = requests.get(f"{localbatch_server.base_url}/health") + assert resp.status_code == 200 + assert resp.json()["status"] == "ok" + + +def test_describe_queues_returns_default(batch_client): + resp = batch_client.describe_job_queues() + queues = resp["jobQueues"] + assert len(queues) >= 1 + names = [q["jobQueueName"] for q in queues] + assert "localbatch-default" in names + + +def test_default_queue_is_healthy(batch_client): + resp = batch_client.describe_job_queues(jobQueues=["localbatch-default"]) + q = resp["jobQueues"][0] + assert q["state"] == "ENABLED" + assert q["status"] == "VALID" + assert len(q["computeEnvironmentOrder"]) >= 1 + + +def test_describe_compute_environments(batch_client): + resp = batch_client.describe_compute_environments() + envs = resp["computeEnvironments"] + assert len(envs) >= 1 + assert envs[0]["status"] == "VALID" + + +def test_register_job_definition(batch_client): + resp = batch_client.register_job_definition( + jobDefinitionName="api-test-jobdef", + type="container", + containerProperties={ + "image": "alpine:latest", + "command": ["echo", "hello"], + "resourceRequirements": [ + {"type": "VCPU", "value": "1"}, + {"type": "MEMORY", "value": "256"}, + ], + }, + ) + assert resp["jobDefinitionName"] == "api-test-jobdef" + assert resp["revision"] == 1 + assert resp["jobDefinitionArn"].startswith("arn:aws:batch:") + + +def test_revision_increments_on_re_register(batch_client): + name = "revision-increment-jobdef" + for expected in (1, 2, 3): resp = batch_client.register_job_definition( - jobDefinitionName="api-test-jobdef", + jobDefinitionName=name, type="container", containerProperties={ "image": "alpine:latest", - "command": ["echo", "hello"], + "command": ["echo", str(expected)], "resourceRequirements": [ {"type": "VCPU", "value": "1"}, {"type": "MEMORY", "value": "256"}, ], }, ) - assert resp["jobDefinitionName"] == "api-test-jobdef" - assert resp["revision"] == 1 - assert resp["jobDefinitionArn"].startswith("arn:aws:batch:") - - def test_revision_increments_on_re_register(self, batch_client): - name = "revision-increment-jobdef" - for expected in (1, 2, 3): - resp = batch_client.register_job_definition( - jobDefinitionName=name, - type="container", - containerProperties={ - "image": "alpine:latest", - "command": ["echo", str(expected)], - "resourceRequirements": [ - {"type": "VCPU", "value": "1"}, - {"type": "MEMORY", "value": "256"}, - ], - }, - ) - assert resp["revision"] == expected - - def test_describe_job_definitions_by_name(self, batch_client): - _register(batch_client, "describe-by-name-job") - resp = batch_client.describe_job_definitions( - jobDefinitionName="describe-by-name-job", status="ACTIVE" - ) - defs = resp["jobDefinitions"] - assert len(defs) >= 1 - assert defs[0]["jobDefinitionName"] == "describe-by-name-job" - - def test_submit_job_returns_id_and_name(self, batch_client): - arn = _register(batch_client, "submit-test-job") - resp = batch_client.submit_job( - jobName="submit-test-run", - jobQueue="localbatch-default", - jobDefinition=arn, - ) - assert resp["jobId"] - assert resp["jobName"] == "submit-test-run" - - def test_job_reaches_terminal_state(self, batch_client): - """Without Docker the job fails; with Docker it succeeds — both are valid.""" - arn = _register(batch_client, "terminal-test-job") - resp = batch_client.submit_job( - jobName="terminal-run", - jobQueue="localbatch-default", - jobDefinition=arn, - ) - job = _poll_job(batch_client, resp["jobId"]) - assert job["status"] in ("SUCCEEDED", "FAILED") - assert job["jobName"] == "terminal-run" - assert job["jobQueue"] == "localbatch-default" + assert resp["revision"] == expected + - def test_describe_jobs_returns_correct_shape(self, batch_client): - arn = _register(batch_client, "describe-shape-job") - job_id = batch_client.submit_job( - jobName="describe-shape-run", - jobQueue="localbatch-default", - jobDefinition=arn, - )["jobId"] +def test_describe_job_definitions_by_name(batch_client): + _register(batch_client, "describe-by-name-job") + resp = batch_client.describe_job_definitions( + jobDefinitionName="describe-by-name-job", status="ACTIVE" + ) + defs = resp["jobDefinitions"] + assert len(defs) >= 1 + assert defs[0]["jobDefinitionName"] == "describe-by-name-job" - _poll_job(batch_client, job_id) - resp = batch_client.describe_jobs(jobs=[job_id]) - job = resp["jobs"][0] - for field in ( - "jobId", - "jobName", - "jobQueue", - "jobDefinition", - "status", - "createdAt", - ): - assert field in job, f"Missing field: {field}" - - def test_terminate_job_transitions_to_failed(self, batch_client): - arn = _register(batch_client, "terminate-test-job", command=["sleep", "999"]) - job_id = batch_client.submit_job( - jobName="terminate-run", - jobQueue="localbatch-default", - jobDefinition=arn, - )["jobId"] - - batch_client.terminate_job(jobId=job_id, reason="cancelled by test") - - job = _poll_job(batch_client, job_id) - assert job["status"] == "FAILED" - assert "cancelled by test" in job.get("statusReason", "") - - def test_list_jobs_returns_summary_list(self, batch_client): - resp = batch_client.list_jobs(jobQueue="localbatch-default", jobStatus="FAILED") - assert "jobSummaryList" in resp - for entry in resp["jobSummaryList"]: - assert "jobId" in entry - assert "status" in entry - - def test_ecs_metadata_endpoint(self, localbatch_server): - job_id = "deadbeef-1234-5678-abcd-000000000000" - resp = requests.get(f"{localbatch_server.base_url}/metadata/{job_id}/task") - assert resp.status_code == 200 - data = resp.json() - container = data["Containers"][0] - opts = container["LogOptions"] - assert opts["awslogs-group"] == "/localbatch/batch/job" - assert job_id in opts["awslogs-stream"] - assert container["LogDriver"] == "awslogs" +def test_submit_job_returns_id_and_name(batch_client): + arn = _register(batch_client, "submit-test-job") + resp = batch_client.submit_job( + jobName="submit-test-run", + jobQueue="localbatch-default", + jobDefinition=arn, + ) + assert resp["jobId"] + assert resp["jobName"] == "submit-test-run" + + +def test_job_reaches_terminal_state(batch_client): + """Without Docker the job fails; with Docker it succeeds — both are valid.""" + arn = _register(batch_client, "terminal-test-job") + resp = batch_client.submit_job( + jobName="terminal-run", + jobQueue="localbatch-default", + jobDefinition=arn, + ) + job = _poll_job(batch_client, resp["jobId"]) + assert job["status"] in ("SUCCEEDED", "FAILED") + assert job["jobName"] == "terminal-run" + assert job["jobQueue"] == "localbatch-default" + + +def test_describe_jobs_returns_correct_shape(batch_client): + arn = _register(batch_client, "describe-shape-job") + job_id = batch_client.submit_job( + jobName="describe-shape-run", + jobQueue="localbatch-default", + jobDefinition=arn, + )["jobId"] + + _poll_job(batch_client, job_id) + + resp = batch_client.describe_jobs(jobs=[job_id]) + job = resp["jobs"][0] + for field in ( + "jobId", + "jobName", + "jobQueue", + "jobDefinition", + "status", + "createdAt", + ): + assert field in job, f"Missing field: {field}" + + +def test_terminate_job_transitions_to_failed(batch_client): + arn = _register(batch_client, "terminate-test-job", command=["sleep", "999"]) + job_id = batch_client.submit_job( + jobName="terminate-run", + jobQueue="localbatch-default", + jobDefinition=arn, + )["jobId"] + + batch_client.terminate_job(jobId=job_id, reason="cancelled by test") + + job = _poll_job(batch_client, job_id) + assert job["status"] == "FAILED" + assert "cancelled by test" in job.get("statusReason", "") + + +def test_list_jobs_returns_summary_list(batch_client): + resp = batch_client.list_jobs(jobQueue="localbatch-default", jobStatus="FAILED") + assert "jobSummaryList" in resp + for entry in resp["jobSummaryList"]: + assert "jobId" in entry + assert "status" in entry + + +def test_ecs_metadata_endpoint(localbatch_server): + job_id = "deadbeef-1234-5678-abcd-000000000000" + resp = requests.get(f"{localbatch_server.base_url}/metadata/{job_id}/task") + assert resp.status_code == 200 + data = resp.json() + container = data["Containers"][0] + opts = container["LogOptions"] + assert opts["awslogs-group"] == "/localbatch/batch/job" + assert job_id in opts["awslogs-stream"] + assert container["LogDriver"] == "awslogs" # --------------------------------------------------------------------------- -# Docker execution tests +# Docker Execution Tests # --------------------------------------------------------------------------- @pytest.mark.docker -class TestDockerExecution: - @pytest.fixture(autouse=True) - def _require_docker(self): - try: - import docker +def test_successful_container(batch_client): + arn = _register( + batch_client, + "docker-success-job", + command=["sh", "-c", "echo success && exit 0"], + ) + job_id = batch_client.submit_job( + jobName="docker-success-run", + jobQueue="localbatch-default", + jobDefinition=arn, + )["jobId"] - docker.from_env().ping() - except Exception: - pytest.skip("Docker not available") + job = _poll_job(batch_client, job_id, timeout=60) + assert job["status"] == "SUCCEEDED" - def test_successful_container(self, batch_client): - arn = _register( - batch_client, - "docker-success-job", - command=["sh", "-c", "echo success && exit 0"], - ) - job_id = batch_client.submit_job( - jobName="docker-success-run", - jobQueue="localbatch-default", - jobDefinition=arn, - )["jobId"] - - job = _poll_job(batch_client, job_id, timeout=60) - assert job["status"] == "SUCCEEDED" - - def test_failed_container_reports_exit_code(self, batch_client): - arn = _register( - batch_client, "docker-fail-job", command=["sh", "-c", "exit 42"] - ) - job_id = batch_client.submit_job( - jobName="docker-fail-run", - jobQueue="localbatch-default", - jobDefinition=arn, - )["jobId"] - - job = _poll_job(batch_client, job_id, timeout=60) - assert job["status"] == "FAILED" - assert job["container"]["exitCode"] == 42 - - def test_inject_env_is_visible_inside_container(self): - """ - Spin up a separate localbatch instance with inject_env set and confirm - the container can see the injected variable. - """ - import threading - - import boto3 - import uvicorn - from localbatch.runner import DockerRunner - from localbatch.server import create_app - from localbatch.store import Store - - port = 18766 - store = Store(queue_name="inject-queue") - runner = DockerRunner( - store, - host_addr="host.docker.internal", - port=port, - inject_env={"LOCALBATCH_CANARY": "canary-value"}, - ) - app = create_app(store, runner) - server = uvicorn.Server( - uvicorn.Config(app, host="127.0.0.1", port=port, log_level="warning") - ) - t = threading.Thread(target=server.run, daemon=True) - t.start() - - for _ in range(100): - try: - requests.get(f"http://127.0.0.1:{port}/health", timeout=0.1) - break - except Exception: - time.sleep(0.05) - - client = boto3.client( - "batch", - endpoint_url=f"http://127.0.0.1:{port}", - region_name="us-east-1", - aws_access_key_id="test", - aws_secret_access_key="test", - ) - client.register_job_definition( - jobDefinitionName="inject-canary-job", - type="container", - containerProperties={ - "image": "alpine:latest", - "command": [ - "sh", - "-c", - 'test "$LOCALBATCH_CANARY" = "canary-value"', - ], - "resourceRequirements": [ - {"type": "VCPU", "value": "1"}, - {"type": "MEMORY", "value": "256"}, - ], - }, - ) - job_id = client.submit_job( - jobName="inject-canary-run", - jobQueue="inject-queue", - jobDefinition="inject-canary-job:1", - )["jobId"] - job = _poll_job(client, job_id, timeout=60) - assert ( - job["status"] == "SUCCEEDED" - ), "LOCALBATCH_CANARY was not visible or had wrong value inside container" +@pytest.mark.docker +def test_failed_container_reports_exit_code(batch_client): + arn = _register(batch_client, "docker-fail-job", command=["sh", "-c", "exit 42"]) + job_id = batch_client.submit_job( + jobName="docker-fail-run", + jobQueue="localbatch-default", + jobDefinition=arn, + )["jobId"] + + job = _poll_job(batch_client, job_id, timeout=60) + assert job["status"] == "FAILED" + assert job["container"]["exitCode"] == 42 - server.should_exit = True - t.join(timeout=5) + +@pytest.mark.docker +def test_inject_env_is_visible_inside_container(): + """Spin up a separate localbatch instance with inject_env set and confirm + + the container can see the injected variable. + """ + import threading + import boto3 + import uvicorn + from localbatch.runner import DockerRunner + from localbatch.server import create_app + from localbatch.store import Store + + port = 18766 + store = Store(queue_name="inject-queue") + runner = DockerRunner( + store, + host_addr="host.docker.internal", + port=port, + inject_env={"LOCALBATCH_CANARY": "canary-value"}, + ) + app = create_app(store, runner) + server = uvicorn.Server( + uvicorn.Config(app, host="127.0.0.1", port=port, log_level="warning") + ) + t = threading.Thread(target=server.run, daemon=True) + t.start() + + for _ in range(100): + try: + requests.get(f"http://127.0.0.1:{port}/health", timeout=0.1) + break + except Exception: + time.sleep(0.05) + + client = boto3.client( + "batch", + endpoint_url=f"http://127.0.0.1:{port}", + region_name="us-east-1", + aws_access_key_id="test", + aws_secret_access_key="test", + ) + client.register_job_definition( + jobDefinitionName="inject-canary-job", + type="container", + containerProperties={ + "image": "alpine:latest", + "command": [ + "sh", + "-c", + 'test "$LOCALBATCH_CANARY" = "canary-value"', + ], + "resourceRequirements": [ + {"type": "VCPU", "value": "1"}, + {"type": "MEMORY", "value": "256"}, + ], + }, + ) + job_id = client.submit_job( + jobName="inject-canary-run", + jobQueue="inject-queue", + jobDefinition="inject-canary-job:1", + )["jobId"] + + job = _poll_job(client, job_id, timeout=60) + assert ( + job["status"] == "SUCCEEDED" + ), "LOCALBATCH_CANARY was not visible or had wrong value inside container" + + server.should_exit = True + t.join(timeout=5) # --------------------------------------------------------------------------- -# Metaflow end-to-end test (mirrors spin test style) +# Metaflow End-to-End Tests (Mirrors Spin Test Style) # --------------------------------------------------------------------------- @@ -341,37 +354,25 @@ def test_inject_env_is_visible_inside_container(self): @pytest.mark.docker -class TestMetaflowE2E: - """ - Runs a real Metaflow flow against localbatch and verifies that artifacts - produced inside the @batch step are correctly persisted and readable - from the client side — the same contract the spin tests enforce. +@_NEEDS_CORE_BATCH_PARAMS +def test_batch_step_artifacts_are_persisted(simple_batch_run): + """The @batch step writes message='hello from localbatch' and value=42. + + Both must be readable via the Metaflow client after the run finishes. """ + task = simple_batch_run["start"].task + assert task["message"].data == "hello from localbatch" + assert task["value"].data == 42 - @pytest.fixture(autouse=True) - def _require_docker(self): - try: - import docker - docker.from_env().ping() - except Exception: - pytest.skip("Docker not available") +@pytest.mark.docker +@_NEEDS_CORE_BATCH_PARAMS +def test_run_succeeds(simple_batch_run): + assert simple_batch_run.successful - @_NEEDS_CORE_BATCH_PARAMS - def test_batch_step_artifacts_are_persisted(self, simple_batch_run): - """ - The @batch step writes message='hello from localbatch' and value=42. - Both must be readable via the Metaflow client after the run finishes. - """ - task = simple_batch_run["start"].task - assert task["message"].data == "hello from localbatch" - assert task["value"].data == 42 - - @_NEEDS_CORE_BATCH_PARAMS - def test_run_succeeds(self, simple_batch_run): - assert simple_batch_run.successful - - @_NEEDS_CORE_BATCH_PARAMS - def test_all_steps_have_tasks(self, simple_batch_run): - step_names = {step.id for step in simple_batch_run.steps()} - assert {"start", "end"} <= step_names + +@pytest.mark.docker +@_NEEDS_CORE_BATCH_PARAMS +def test_all_steps_have_tasks(simple_batch_run): + step_names = {step.id for step in simple_batch_run.steps()} + assert {"start", "end"} <= step_names