Skip to content
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions flytekit/core/container_task.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, Type
from typing import Any, Dict, List, Optional, OrderedDict, Tuple, Type, Union

from flytekit.configuration import SerializationSettings
from flytekit.core.base_task import PythonTask, TaskMetadata
Expand Down Expand Up @@ -36,7 +36,7 @@ def __init__(
name: str,
image: str,
command: List[str],
inputs: Optional[Dict[str, Tuple[Type, Any]]] = None,
inputs: Optional[Union[Dict[str, Tuple[Type, Any]], OrderedDict[str, Type]]] = None,
metadata: Optional[TaskMetadata] = None,
arguments: Optional[List[str]] = None,
outputs: Optional[Dict[str, Type]] = None,
Expand Down
34 changes: 25 additions & 9 deletions flytekit/core/map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from itertools import count
from typing import Any, Dict, List, Optional

from flytekit import ContainerTask
from flytekit.configuration import SerializationSettings
from flytekit.core import tracker
from flytekit.core.base_task import PythonTask
Expand All @@ -35,7 +36,7 @@ class MapPythonTask(PythonTask):

def __init__(
self,
python_function_task: PythonFunctionTask,
python_function_task: typing.Union[PythonFunctionTask, ContainerTask],
concurrency: Optional[int] = None,
min_success_ratio: Optional[float] = None,
**kwargs,
Expand All @@ -55,8 +56,11 @@ def __init__(

collection_interface = transform_interface_to_list_interface(python_function_task.python_interface)
instance = next(self._ids)
_, mod, f, _ = tracker.extract_task_module(python_function_task.task_function)
name = f"{mod}.mapper_{f}_{instance}"
if isinstance(python_function_task, ContainerTask):
name = f"raw_container_task.mapper_{python_function_task.name}_{instance}"
else:
_, mod, f, _ = tracker.extract_task_module(python_function_task.task_function)
name = f"{mod}.mapper_{f}_{instance}"

self._cmd_prefix = None
self._run_task = python_function_task
Expand Down Expand Up @@ -114,14 +118,20 @@ def prepare_target(self):
self._run_task.reset_command_fn()

def get_container(self, settings: SerializationSettings) -> Container:
if isinstance(self._run_task, ContainerTask):
return self._run_task.get_container(settings)
with self.prepare_target():
return self._run_task.get_container(settings)

def get_k8s_pod(self, settings: SerializationSettings) -> K8sPod:
if isinstance(self._run_task, ContainerTask):
return self._run_task.get_k8s_pod(settings)
with self.prepare_target():
return self._run_task.get_k8s_pod(settings)

def get_sql(self, settings: SerializationSettings) -> Sql:
if isinstance(self._run_task, ContainerTask):
return self._run_task.get_sql(settings)
with self.prepare_target():
return self._run_task.get_sql(settings)

Expand Down Expand Up @@ -221,7 +231,12 @@ def _raw_execute(self, **kwargs) -> Any:
return outputs


def map_task(task_function: PythonFunctionTask, concurrency: int = 0, min_success_ratio: float = 1.0, **kwargs):
def map_task(
task_function: typing.Union[PythonFunctionTask, ContainerTask],
concurrency: int = 0,
min_success_ratio: float = 1.0,
**kwargs,
):
"""
Use a map task for parallelizable tasks that run across a list of an input type. A map task can be composed of
any individual :py:class:`flytekit.PythonFunctionTask`.
Expand Down Expand Up @@ -267,8 +282,9 @@ def map_task(task_function: PythonFunctionTask, concurrency: int = 0, min_succes
successfully before terminating this task and marking it successful.

"""
if not isinstance(task_function, PythonFunctionTask):
raise ValueError(
f"Only Flyte python task types are supported in map tasks currently, received {type(task_function)}"
)
return MapPythonTask(task_function, concurrency=concurrency, min_success_ratio=min_success_ratio, **kwargs)
if isinstance(task_function, PythonFunctionTask) or isinstance(task_function, ContainerTask):
return MapPythonTask(task_function, concurrency=concurrency, min_success_ratio=min_success_ratio, **kwargs)

raise ValueError(
f"Only Flyte python-task, and raw-container types are supported in map tasks currently, received {type(task_function)}"
)
34 changes: 33 additions & 1 deletion tests/flytekit/unit/core/test_map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest

import flytekit.configuration
from flytekit import LaunchPlan, map_task
from flytekit import ContainerTask, LaunchPlan, kwtypes, map_task
from flytekit.configuration import Image, ImageConfig
from flytekit.core.map_task import MapPythonTask
from flytekit.core.task import TaskMetadata, task
Expand All @@ -24,6 +24,22 @@ def serialization_settings():
)


raw_container = ContainerTask(
name="ellipse-area-metadata-python",
input_data_dir="/var/inputs",
output_data_dir="/var/outputs",
inputs=kwtypes(a=int),
outputs=kwtypes(area=float),
image="flyte/raw-container:v1",
command=[
"python",
"test.py",
"{{.inputs.a}}",
"/var/outputs",
],
)


@task
def t1(a: int) -> str:
b = a + 2
Expand Down Expand Up @@ -96,6 +112,22 @@ def test_serialization(serialization_settings):
]


def test_serialization_with_raw_container(serialization_settings):
maptask = map_task(raw_container, metadata=TaskMetadata(retries=1))
task_spec = get_serializable(OrderedDict(), serialization_settings, maptask)

# By default all map_task tasks will have their custom fields set.
assert task_spec.template.custom["minSuccessRatio"] == 1.0
assert task_spec.template.type == "container_array"
assert task_spec.template.task_type_version == 1
assert task_spec.template.container.command == [
"python",
"test.py",
"{{.inputs.a}}",
"/var/outputs",
]


@pytest.mark.parametrize(
"custom_fields_dict, expected_custom_fields",
[
Expand Down