-
Notifications
You must be signed in to change notification settings - Fork 178
Nodes/get worker info parallel mapper #1546
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 10 commits
d4eed6b
a88490b
dc28fcd
715db80
ff748df
b6f0f19
06d32ec
9ce8925
9eb1d04
659593e
b46af14
58779ef
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,103 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| # All rights reserved. | ||
| # | ||
| # This source code is licensed under the BSD-style license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
|
|
||
| import unittest | ||
| from typing import Any, Dict | ||
|
|
||
| import torch | ||
| from torch.testing._internal.common_utils import IS_WINDOWS, TestCase | ||
| from torchdata.nodes import get_worker_info, IterableWrapper, ParallelMapper | ||
|
|
||
| from .utils import MockSource | ||
|
|
||
|
|
||
| def _capture_worker_info(item: Dict[str, Any]) -> Dict[str, Any]: | ||
| """UDF that augments the item with the current worker's WorkerInfo fields.""" | ||
| info = get_worker_info() | ||
| item = dict(item) | ||
| if info is None: | ||
| item["worker_id"] = None | ||
| item["num_workers"] = None | ||
| else: | ||
| item["worker_id"] = info.id | ||
| item["num_workers"] = info.num_workers | ||
| return item | ||
|
|
||
|
|
||
| def _capture_torch_worker_info(item: Dict[str, Any]) -> Dict[str, Any]: | ||
| """UDF that reads worker info via torch.utils.data.get_worker_info() (process workers only).""" | ||
| info = torch.utils.data.get_worker_info() | ||
| item = dict(item) | ||
| item["worker_id"] = info.id if info is not None else None | ||
| item["num_workers"] = info.num_workers if info is not None else None | ||
| return item | ||
|
|
||
|
|
||
| class TestGetWorkerInfo(TestCase): | ||
| def test_none_outside_worker(self) -> None: | ||
| self.assertIsNone(get_worker_info()) | ||
|
|
||
| def test_thread_workers(self) -> None: | ||
| num_workers = 4 | ||
| src = MockSource(num_samples=40) | ||
| node = ParallelMapper(src, _capture_worker_info, num_workers=num_workers, method="thread", in_order=False) | ||
|
|
||
| results = list(node) | ||
| self.assertEqual(len(results), 40) | ||
|
|
||
| # All reported worker ids must be in [0, num_workers) | ||
| for r in results: | ||
| self.assertIn(r["worker_id"], set(range(num_workers))) | ||
| self.assertEqual(r["num_workers"], num_workers) | ||
|
|
||
| @unittest.skipIf(IS_WINDOWS, "forkserver not supported on Windows") | ||
| def test_process_workers_get_worker_info(self) -> None: | ||
| """torch.utils.data.get_worker_info() works correctly in process workers.""" | ||
| num_workers = 2 | ||
| src = MockSource(num_samples=8) | ||
| node = ParallelMapper( | ||
| src, | ||
| _capture_torch_worker_info, | ||
| num_workers=num_workers, | ||
| method="process", | ||
| multiprocessing_context="forkserver", | ||
| in_order=False, | ||
| ) | ||
|
|
||
| results = list(node) | ||
| self.assertEqual(len(results), 8) | ||
| for r in results: | ||
| self.assertIn(r["worker_id"], set(range(num_workers))) | ||
| self.assertEqual(r["num_workers"], num_workers) | ||
|
|
||
| @unittest.skipIf(IS_WINDOWS, "forkserver not supported on Windows") | ||
| def test_process_workers_torchdata_get_worker_info(self) -> None: | ||
| """torchdata.nodes.get_worker_info() works correctly in process workers.""" | ||
| num_workers = 2 | ||
| src = MockSource(num_samples=8) | ||
| node = ParallelMapper( | ||
| src, | ||
| _capture_worker_info, | ||
| num_workers=num_workers, | ||
| method="process", | ||
| multiprocessing_context="forkserver", | ||
| in_order=False, | ||
| ) | ||
|
|
||
| results = list(node) | ||
| self.assertEqual(len(results), 8) | ||
| for r in results: | ||
| self.assertIn(r["worker_id"], set(range(num_workers))) | ||
| self.assertEqual(r["num_workers"], num_workers) | ||
|
|
||
| def test_num_workers_zero_no_worker_info(self) -> None: | ||
| """With num_workers=0 (inline), get_worker_info() returns None inside UDF.""" | ||
| src = MockSource(num_samples=5) | ||
| node = ParallelMapper(src, _capture_worker_info, num_workers=0) | ||
| results = list(node) | ||
| for r in results: | ||
| self.assertIsNone(r["worker_id"]) | ||
| self.assertIsNone(r["num_workers"]) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -7,30 +7,71 @@ | |
| import multiprocessing.synchronize as python_mp_synchronize | ||
| import queue | ||
| import threading | ||
| from typing import Callable, Union | ||
| from typing import Any, Callable, Optional, Union | ||
|
|
||
| import torch | ||
| import torch.multiprocessing as mp | ||
|
|
||
| import torch.utils.data._utils.worker as _worker_module # type: ignore[import] | ||
| from torch._utils import ExceptionWrapper | ||
|
|
||
| from .constants import QUEUE_TIMEOUT | ||
|
|
||
| _thread_local = threading.local() | ||
|
|
||
|
|
||
| def get_worker_info() -> Optional[Any]: | ||
| """Return a :class:`~torch.utils.data.WorkerInfo` for the current | ||
| :class:`~torchdata.nodes.ParallelMapper` worker, or ``None`` if called | ||
| from outside a worker context. | ||
|
|
||
| Unlike :func:`torch.utils.data.get_worker_info`, this function uses | ||
| thread-local storage and is therefore correct for both thread-based and | ||
| process-based :class:`~torchdata.nodes.ParallelMapper` workers. | ||
|
|
||
| The returned object has the following attributes: | ||
|
|
||
| * ``id`` (int): the worker index (0 to num_workers - 1) | ||
| * ``num_workers`` (int): total number of workers | ||
| * ``seed`` (int): per-worker seed derived from the initial RNG seed | ||
| * ``dataset``: always ``None`` for :class:`~torchdata.nodes.ParallelMapper` | ||
|
|
||
| Returns: | ||
| A ``WorkerInfo`` object, or ``None`` when called from outside a worker. | ||
| """ | ||
| return getattr(_thread_local, "worker_info", None) | ||
|
|
||
|
|
||
| def _apply_udf( | ||
| worker_id: int, | ||
| in_q: Union[queue.Queue, mp.Queue], | ||
| out_q: Union[queue.Queue, mp.Queue], | ||
| udf: Callable, | ||
| stop_event: Union[threading.Event, python_mp_synchronize.Event], | ||
| num_workers: int = 1, | ||
| ): | ||
| """_apply_udf assumes in_q emits tuples of (x, idx) where x is the | ||
| payload, idx is the index of the result, potentially used for maintaining | ||
| ordered outputs. For every input it pulls, a tuple (y, idx) is put on the out_q | ||
| where the output of udf(x), an ExceptionWrapper, or StopIteration (if it pulled | ||
| StopIteration from in_q). | ||
|
|
||
| Sets up worker info before entering the processing loop so that | ||
| :func:`torchdata.nodes.get_worker_info` returns a valid | ||
| :class:`~torch.utils.data.WorkerInfo` from inside the UDF. For process | ||
| workers, :func:`torch.utils.data.get_worker_info` also works because each | ||
| process has its own memory space. For thread workers, prefer | ||
| :func:`torchdata.nodes.get_worker_info` which uses thread-local storage. | ||
| """ | ||
| torch.set_num_threads(1) | ||
| seed = torch.initial_seed() + worker_id | ||
| worker_info = _worker_module.WorkerInfo(id=worker_id, num_workers=num_workers, seed=seed, dataset=None) # type: ignore[attr-defined,arg-type] | ||
| # Thread-local: always returns the correct info for this worker, regardless of | ||
| # whether other workers (threads) have set their own worker info concurrently. | ||
| _thread_local.worker_info = worker_info | ||
| # Module-level global: correct for process workers (isolated memory); for thread | ||
| # workers this may race, so callers should use torchdata.nodes.get_worker_info(). | ||
| _worker_module._worker_info = worker_info # type: ignore[attr-defined] | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am not sure of this change. Overall it is safe, but I don't think its very clean to have this inside _apply_udf. Plus we are leaving dataset=None.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Agreed on the placement, addressed in the same commit by moving everything to ``_worker_info.py.` On |
||
|
|
||
| while True: | ||
| if stop_event.is_set() and in_q.empty(): | ||
| break | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This shouldn't live inside _apply_udf. maybe a utils file in nodes/ ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point. Moved
get_worker_info(), the thread-local storage, and the_set_worker_info()helper into a dedicatedtorchdata/nodes/_worker_info.py._apply_udf.pynow just calls_set_worker_info(worker_id, num_workers)and has no worker-info logic of its own.