Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions metaflow/plugins/aws/batch/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,6 @@ def create_job(
]:
if key in attrs:
job.tag(key, attrs.get(key))
# As some values can be affected by users, sanitize them so they adhere to AWS tagging restrictions.
for key in [
"metaflow.version",
"metaflow.user",
Expand All @@ -398,10 +397,12 @@ def create_job(
k, v = sanitize_batch_tag(key, attrs.get(key))
job.tag(k, v)

if aws_batch_tags is not None:
for key, value in aws_batch_tags.items():
job.tag(key, value)

# User-defined tags (e.g. for cost attribution or compliance) should
# always be applied, regardless of BATCH_EMIT_TAGS which controls
# Metaflow's internal observability tags.
if aws_batch_tags is not None:
for key, value in aws_batch_tags.items():
job.tag(key, value)
return job

def launch_job(
Expand Down
61 changes: 61 additions & 0 deletions test/unit/test_batch_user_tags.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
"""
Regression test for user-defined AWS Batch tags.

User-defined tags (aws_batch_tags) and METAFLOW_BATCH_DEFAULT_TAGS should
always be applied to Batch jobs, regardless of the BATCH_EMIT_TAGS setting.
BATCH_EMIT_TAGS controls only Metaflow's internal observability tags
(app=metaflow, metaflow.flow_name, etc.), not user-specified tags.

See: https://github.com/Netflix/metaflow/issues/3209
"""

import ast
import inspect
import textwrap

import pytest


def test_user_tags_not_inside_emit_tags_guard():
"""
Verify that the aws_batch_tags block in batch.py is NOT inside
the BATCH_EMIT_TAGS conditional. This is a structural test that
checks the source code's AST to ensure user tags are always applied.
"""
from metaflow.plugins.aws.batch.batch import Batch

source = inspect.getsource(Batch.create_job)
# dedent because inspect.getsource preserves class indentation
source = textwrap.dedent(source)
tree = ast.parse(source)

# Find all If nodes that test BATCH_EMIT_TAGS
def find_emit_tags_guards(node):
"""Find all 'if BATCH_EMIT_TAGS:' blocks and return their AST nodes."""
guards = []
for child in ast.walk(node):
if isinstance(child, ast.If):
# Check if the test is just 'BATCH_EMIT_TAGS'
test = child.test
if isinstance(test, ast.Name) and test.id == "BATCH_EMIT_TAGS":
guards.append(child)
return guards

def contains_aws_batch_tags_usage(node):
"""Check if an AST node contains reference to 'aws_batch_tags'."""
for child in ast.walk(node):
if isinstance(child, ast.Name) and child.id == "aws_batch_tags":
return True
return False

guards = find_emit_tags_guards(tree)
assert guards, "Could not find 'if BATCH_EMIT_TAGS:' block in create_job"
Comment thread
greptile-apps[bot] marked this conversation as resolved.

for guard in guards:
# Check the body of each BATCH_EMIT_TAGS guard
for stmt in guard.body:
assert not contains_aws_batch_tags_usage(stmt), (
"aws_batch_tags is used inside 'if BATCH_EMIT_TAGS:' block. "
"User-defined tags should be applied unconditionally, outside "
"the BATCH_EMIT_TAGS guard."
)
Comment thread
greptile-apps[bot] marked this conversation as resolved.
Outdated