Skip to content
Open
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 103 additions & 0 deletions test/nodes/test_get_worker_info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest
from typing import Any, Dict

import torch
from torch.testing._internal.common_utils import IS_WINDOWS, TestCase
from torchdata.nodes import get_worker_info, IterableWrapper, ParallelMapper

from .utils import MockSource


def _capture_worker_info(item: Dict[str, Any]) -> Dict[str, Any]:
"""UDF that augments the item with the current worker's WorkerInfo fields."""
info = get_worker_info()
item = dict(item)
if info is None:
item["worker_id"] = None
item["num_workers"] = None
else:
item["worker_id"] = info.id
item["num_workers"] = info.num_workers
return item


def _capture_torch_worker_info(item: Dict[str, Any]) -> Dict[str, Any]:
"""UDF that reads worker info via torch.utils.data.get_worker_info() (process workers only)."""
info = torch.utils.data.get_worker_info()
item = dict(item)
item["worker_id"] = info.id if info is not None else None
item["num_workers"] = info.num_workers if info is not None else None
return item


class TestGetWorkerInfo(TestCase):
def test_none_outside_worker(self) -> None:
self.assertIsNone(get_worker_info())

def test_thread_workers(self) -> None:
num_workers = 4
src = MockSource(num_samples=40)
node = ParallelMapper(src, _capture_worker_info, num_workers=num_workers, method="thread", in_order=False)

results = list(node)
self.assertEqual(len(results), 40)

# All reported worker ids must be in [0, num_workers)
for r in results:
self.assertIn(r["worker_id"], set(range(num_workers)))
self.assertEqual(r["num_workers"], num_workers)

@unittest.skipIf(IS_WINDOWS, "forkserver not supported on Windows")
def test_process_workers_get_worker_info(self) -> None:
"""torch.utils.data.get_worker_info() works correctly in process workers."""
num_workers = 2
src = MockSource(num_samples=8)
node = ParallelMapper(
src,
_capture_torch_worker_info,
num_workers=num_workers,
method="process",
multiprocessing_context="forkserver",
in_order=False,
)

results = list(node)
self.assertEqual(len(results), 8)
for r in results:
self.assertIn(r["worker_id"], set(range(num_workers)))
self.assertEqual(r["num_workers"], num_workers)

@unittest.skipIf(IS_WINDOWS, "forkserver not supported on Windows")
def test_process_workers_torchdata_get_worker_info(self) -> None:
"""torchdata.nodes.get_worker_info() works correctly in process workers."""
num_workers = 2
src = MockSource(num_samples=8)
node = ParallelMapper(
src,
_capture_worker_info,
num_workers=num_workers,
method="process",
multiprocessing_context="forkserver",
in_order=False,
)

results = list(node)
self.assertEqual(len(results), 8)
for r in results:
self.assertIn(r["worker_id"], set(range(num_workers)))
self.assertEqual(r["num_workers"], num_workers)

def test_num_workers_zero_no_worker_info(self) -> None:
"""With num_workers=0 (inline), get_worker_info() returns None inside UDF."""
src = MockSource(num_samples=5)
node = ParallelMapper(src, _capture_worker_info, num_workers=0)
results = list(node)
for r in results:
self.assertIsNone(r["worker_id"])
self.assertIsNone(r["num_workers"])
3 changes: 2 additions & 1 deletion torchdata/nodes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from ._apply_udf import get_worker_info
from .adapters import IterableWrapper, MapStyleWrapper, SamplerWrapper
from .base_node import BaseNode, T
from .batch import Batcher, Unbatcher
Expand All @@ -19,7 +20,6 @@
from .shuffler import Shuffler
from .types import Stateful


__all__ = [
"BaseNode",
"Batcher",
Expand All @@ -40,6 +40,7 @@
"StopCriteria",
"T",
"Unbatcher",
"get_worker_info",
]

assert sorted(__all__) == __all__
45 changes: 43 additions & 2 deletions torchdata/nodes/_apply_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,30 +7,71 @@
import multiprocessing.synchronize as python_mp_synchronize
import queue
import threading
from typing import Callable, Union
from typing import Any, Callable, Optional, Union

import torch
import torch.multiprocessing as mp

import torch.utils.data._utils.worker as _worker_module # type: ignore[import]
from torch._utils import ExceptionWrapper

from .constants import QUEUE_TIMEOUT

_thread_local = threading.local()


def get_worker_info() -> Optional[Any]:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This shouldn't live inside _apply_udf. maybe a utils file in nodes/ ?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. Moved get_worker_info(), the thread-local storage, and the _set_worker_info() helper into a dedicated torchdata/nodes/_worker_info.py. _apply_udf.py now just calls _set_worker_info(worker_id, num_workers) and has no worker-info logic of its own.

"""Return a :class:`~torch.utils.data.WorkerInfo` for the current
:class:`~torchdata.nodes.ParallelMapper` worker, or ``None`` if called
from outside a worker context.

Unlike :func:`torch.utils.data.get_worker_info`, this function uses
thread-local storage and is therefore correct for both thread-based and
process-based :class:`~torchdata.nodes.ParallelMapper` workers.

The returned object has the following attributes:

* ``id`` (int): the worker index (0 to num_workers - 1)
* ``num_workers`` (int): total number of workers
* ``seed`` (int): per-worker seed derived from the initial RNG seed
* ``dataset``: always ``None`` for :class:`~torchdata.nodes.ParallelMapper`

Returns:
A ``WorkerInfo`` object, or ``None`` when called from outside a worker.
"""
return getattr(_thread_local, "worker_info", None)


def _apply_udf(
worker_id: int,
in_q: Union[queue.Queue, mp.Queue],
out_q: Union[queue.Queue, mp.Queue],
udf: Callable,
stop_event: Union[threading.Event, python_mp_synchronize.Event],
num_workers: int = 1,
):
"""_apply_udf assumes in_q emits tuples of (x, idx) where x is the
payload, idx is the index of the result, potentially used for maintaining
ordered outputs. For every input it pulls, a tuple (y, idx) is put on the out_q
where the output of udf(x), an ExceptionWrapper, or StopIteration (if it pulled
StopIteration from in_q).

Sets up worker info before entering the processing loop so that
:func:`torchdata.nodes.get_worker_info` returns a valid
:class:`~torch.utils.data.WorkerInfo` from inside the UDF. For process
workers, :func:`torch.utils.data.get_worker_info` also works because each
process has its own memory space. For thread workers, prefer
:func:`torchdata.nodes.get_worker_info` which uses thread-local storage.
"""
torch.set_num_threads(1)
seed = torch.initial_seed() + worker_id
worker_info = _worker_module.WorkerInfo(id=worker_id, num_workers=num_workers, seed=seed, dataset=None) # type: ignore[attr-defined,arg-type]
# Thread-local: always returns the correct info for this worker, regardless of
# whether other workers (threads) have set their own worker info concurrently.
_thread_local.worker_info = worker_info
# Module-level global: correct for process workers (isolated memory); for thread
# workers this may race, so callers should use torchdata.nodes.get_worker_info().
_worker_module._worker_info = worker_info # type: ignore[attr-defined]

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure of this change. Overall it is safe, but I don't think its very clean to have this inside _apply_udf. Plus we are leaving dataset=None.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed on the placement, addressed in the same commit by moving everything to ``_worker_info.py.`

On dataset=None: ParallelMapper maps over arbitrary items rather than a PyTorch Dataset object, so there is no meaningful dataset to pass. We still construct a torch.utils.data.WorkerInfo (rather than a custom dataclass) so that torch.utils.data.get_worker_info() continues to work inside process workers, keeping compatibility with existing IterableDataset code that calls it. The # type: ignore[arg-type] on that line suppresses the mypy complaint. Happy to discuss if there's a cleaner alternative you have in mind.


while True:
if stop_event.is_set() and in_q.empty():
break
Expand Down
2 changes: 2 additions & 0 deletions torchdata/nodes/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@ def __init__(
self._intermed_q,
self.map_fn,
self._stop,
self.num_workers,
)

elif self.method == "process":
Expand All @@ -213,6 +214,7 @@ def __init__(
self._intermed_q,
self.map_fn,
self._mp_stop,
self.num_workers,
)
self._workers.append(mp_context.Process(target=_apply_udf, args=_args, daemon=True))
for t in self._workers:
Expand Down
Loading