Skip to content
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion monai/apps/detection/transforms/box_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ def convert_box_to_mask(
boxes_only_mask = np.ones(box_size, dtype=np.int16) * np.int16(labels_np[b])
# apply to global mask
slicing = [b]
slicing.extend(slice(boxes_np[b, d], boxes_np[b, d + spatial_dims]) for d in range(spatial_dims)) # type:ignore
slicing.extend(slice(boxes_np[b, d], boxes_np[b, d + spatial_dims]) for d in range(spatial_dims)) # type: ignore
boxes_mask_np[tuple(slicing)] = boxes_only_mask
return convert_to_dst_type(src=boxes_mask_np, dst=boxes, dtype=torch.int16)[0]

Expand Down
23 changes: 15 additions & 8 deletions monai/auto3dseg/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def update_ops_nested_label(self, nested_key: str, op: Operations) -> None:
raise ValueError("Nested_key input format is wrong. Please ensure it is like key1#0#key2")
root: str
child_key: str
(root, _, child_key) = keys
root, _, child_key = keys
if root not in self.ops:
self.ops[root] = [{}]
self.ops[root][0].update({child_key: None})
Expand Down Expand Up @@ -468,21 +468,28 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTe
"""
d: dict[Hashable, MetaTensor] = dict(data)
start = time.time()
if isinstance(d[self.image_key], (torch.Tensor, MetaTensor)) and d[self.image_key].device.type == "cuda":
using_cuda = True
else:
using_cuda = False
image_tensor = d[self.image_key]
label_tensor = d[self.label_key]
using_cuda = any(
isinstance(t, (torch.Tensor, MetaTensor)) and t.device.type == "cuda" for t in (image_tensor, label_tensor)
)
Comment thread
coderabbitai[bot] marked this conversation as resolved.
restore_grad_state = torch.is_grad_enabled()
torch.set_grad_enabled(False)

ndas: list[MetaTensor] = [d[self.image_key][i] for i in range(d[self.image_key].shape[0])] # type: ignore
ndas_label: MetaTensor = d[self.label_key].astype(torch.int16) # (H,W,D)
if isinstance(image_tensor, (MetaTensor, torch.Tensor)) and isinstance(
label_tensor, (MetaTensor, torch.Tensor)
):
if label_tensor.device != image_tensor.device:
label_tensor = label_tensor.to(image_tensor.device)
Comment thread
benediktjohannes marked this conversation as resolved.
Outdated
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just re-reviewed the merged commits and I think we might have a potential issue here.

We set label_tensor's device to image_tensor's device here, but we already determined using_cuda above by checking whether either tensor uses CUDA. So if label_tensor is on CUDA and image_tensor is not, we first set using_cuda to True; then, because label_tensor.device != image_tensor.device, we move label_tensor to a non-CUDA device (e.g. CPU). As a result, both tensors end up on the CPU while using_cuda is still wrongly set to True.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes I see the issue here, the expectation was that image_tensor would always be a CUDA tensor if either were. Instead something like:

if label_tensor.device != image_tensor.device:
    # using_cuda would have to be True here, unless non-CUDA device types like mps have been mixed in
    device = image_tensor.device if image_tensor.device.type == "cuda" else label_tensor.device
    label_tensor = label_tensor.to(device)  # type: ignore
    image_tensor = image_tensor.to(device)  # type: ignore

This would select the device for whichever is a CUDA tensor. CC @garciadias

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @benediktjohannes, I am not sure if I understand. Have you looked into the proposed tests? https://github.com/benediktjohannes/MONAI/blob/1a2598b5ab0624fce4e9ab74cd4eed9f922fb801/tests/apps/test_auto3dseg.py#L81

This is my understanding of how this should behave. If you think any of the test cases are wrong, then we will need to correct them.


ndas: list[MetaTensor] = [image_tensor[i] for i in range(image_tensor.shape[0])] # type: ignore
ndas_label: MetaTensor = label_tensor.astype(torch.int16) # (H,W,D)
Comment thread
coderabbitai[bot] marked this conversation as resolved.

if ndas_label.shape != ndas[0].shape:
raise ValueError(f"Label shape {ndas_label.shape} is different from image shape {ndas[0].shape}")

nda_foregrounds: list[torch.Tensor] = [get_foreground_label(nda, ndas_label) for nda in ndas]
nda_foregrounds = [nda if nda.numel() > 0 else torch.Tensor([0]) for nda in nda_foregrounds]
nda_foregrounds = [nda if nda.numel() > 0 else MetaTensor([0.0]) for nda in nda_foregrounds]

unique_label = unique(ndas_label)
if isinstance(ndas_label, (MetaTensor, torch.Tensor)):
Expand Down
2 changes: 1 addition & 1 deletion monai/bundle/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -1948,7 +1948,7 @@ def create_workflow(

"""
_args = update_kwargs(args=args_file, workflow_name=workflow_name, config_file=config_file, **kwargs)
(workflow_name, config_file) = _pop_args(
workflow_name, config_file = _pop_args(
_args, workflow_name=ConfigWorkflow, config_file=None
) # the default workflow name is "ConfigWorkflow"
if isinstance(workflow_name, str):
Expand Down
4 changes: 2 additions & 2 deletions monai/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ class DatasetFunc(Dataset):
"""

def __init__(self, data: Any, func: Callable, **kwargs) -> None:
super().__init__(data=None, transform=None) # type:ignore
super().__init__(data=None, transform=None) # type: ignore
self.src = data
self.func = func
self.kwargs = kwargs
Expand Down Expand Up @@ -1635,7 +1635,7 @@ def _cachecheck(self, item_transformed):
return (_data, _meta)
return _data
else:
item: list[dict[Any, Any]] = [{} for _ in range(len(item_transformed))] # type:ignore
item: list[dict[Any, Any]] = [{} for _ in range(len(item_transformed))] # type: ignore
for i, _item in enumerate(item_transformed):
for k in _item:
meta_i_k = self._load_meta_cache(meta_hash_file_name=f"{hashfile.name}-{k}-meta-{i}")
Expand Down
2 changes: 1 addition & 1 deletion monai/handlers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def stopping_fn_from_loss() -> Callable[[Engine], Any]:
"""

def stopping_fn(engine: Engine) -> Any:
return -engine.state.output # type:ignore
return -engine.state.output # type: ignore

return stopping_fn

Expand Down
2 changes: 1 addition & 1 deletion monai/metrics/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ def get_edge_surface_distance(
edges_spacing = None
if use_subvoxels:
edges_spacing = spacing if spacing is not None else ([1] * len(y_pred.shape))
(edges_pred, edges_gt, *areas) = get_mask_edges(
edges_pred, edges_gt, *areas = get_mask_edges(
y_pred, y, crop=True, spacing=edges_spacing, always_return_as_numpy=False
)
if not edges_gt.any():
Expand Down
1 change: 1 addition & 0 deletions monai/transforms/io/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
"""
A collection of "vanilla" transforms for IO functions.
"""

from __future__ import annotations

import inspect
Expand Down
2 changes: 1 addition & 1 deletion monai/transforms/utility/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,7 +702,7 @@ def __init__(
# if the root log level is higher than INFO, set a separate stream handler to record
console = logging.StreamHandler(sys.stdout)
console.setLevel(logging.INFO)
console.is_data_stats_handler = True # type:ignore[attr-defined]
console.is_data_stats_handler = True # type: ignore[attr-defined]
_logger.addHandler(console)

def __call__(
Expand Down
6 changes: 3 additions & 3 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ mccabe
pep8-naming
pycodestyle
pyflakes
black>=25.1.0
isort>=5.1, !=6.0.0
ruff
black==25.1.0
isort>=5.1, <6, !=6.0.0
ruff>=0.14.11,<0.15
pytype>=2020.6.1, <=2024.4.11; platform_system != "Windows"
types-setuptools
mypy>=1.5.0, <1.12.0
Expand Down
53 changes: 52 additions & 1 deletion tests/apps/test_auto3dseg.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
SqueezeDimd,
ToDeviced,
)
from monai.utils.enums import DataStatsKeys
from monai.utils.enums import DataStatsKeys, LabelStatsKeys
from tests.test_utils import skip_if_no_cuda

device = "cpu"
Expand All @@ -78,6 +78,13 @@

SIM_GPU_TEST_CASES = [[{"sim_dim": (32, 32, 32), "label_key": "label"}], [{"sim_dim": (32, 32, 32), "label_key": None}]]

LABEL_STATS_DEVICE_TEST_CASES = [
[{"image_device": "cpu", "label_device": "cpu", "image_meta": False}],
[{"image_device": "cuda", "label_device": "cuda", "image_meta": True}],
[{"image_device": "cpu", "label_device": "cuda", "image_meta": True}],
[{"image_device": "cuda", "label_device": "cpu", "image_meta": False}],
]


def create_sim_data(dataroot: str, sim_datalist: dict, sim_dim: tuple, image_only: bool = False, **kwargs) -> None:
"""
Expand Down Expand Up @@ -360,6 +367,50 @@ def test_label_stats_case_analyzer(self):
report_format = analyzer.get_report_format()
assert verify_report_format(d["label_stats"], report_format)

@parameterized.expand(LABEL_STATS_DEVICE_TEST_CASES)
def test_label_stats_mixed_device_analyzer(self, input_params):
image_device = torch.device(input_params["image_device"])
label_device = torch.device(input_params["label_device"])

if (image_device.type == "cuda" or label_device.type == "cuda") and not torch.cuda.is_available():
self.skipTest("CUDA is not available for mixed-device LabelStats tests.")

analyzer = LabelStats(image_key="image", label_key="label")

image_tensor = torch.tensor(
[
[[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]],
[[[11.0, 12.0], [13.0, 14.0]], [[15.0, 16.0], [17.0, 18.0]]],
],
dtype=torch.float32,
).to(image_device)
label_tensor = torch.tensor([[[0, 1], [1, 0]], [[0, 1], [0, 1]]], dtype=torch.int64).to(label_device)

if input_params["image_meta"]:
image_tensor = MetaTensor(image_tensor)
label_tensor = MetaTensor(label_tensor)

result = analyzer({"image": image_tensor, "label": label_tensor})
report = result["label_stats"]

assert verify_report_format(report, analyzer.get_report_format())
assert report[LabelStatsKeys.LABEL_UID] == [0, 1]

label_stats = report[LabelStatsKeys.LABEL]
self.assertAlmostEqual(label_stats[0][LabelStatsKeys.PIXEL_PCT], 0.5)
self.assertAlmostEqual(label_stats[1][LabelStatsKeys.PIXEL_PCT], 0.5)

label0_intensity = label_stats[0][LabelStatsKeys.IMAGE_INTST]
label1_intensity = label_stats[1][LabelStatsKeys.IMAGE_INTST]
self.assertAlmostEqual(label0_intensity[0]["mean"], 4.25)
self.assertAlmostEqual(label1_intensity[0]["mean"], 4.75)
self.assertAlmostEqual(label0_intensity[1]["mean"], 14.25)
self.assertAlmostEqual(label1_intensity[1]["mean"], 14.75)

foreground_stats = report[LabelStatsKeys.IMAGE_INTST]
self.assertAlmostEqual(foreground_stats[0]["mean"], 4.75)
self.assertAlmostEqual(foreground_stats[1]["mean"], 14.75)

def test_filename_case_analyzer(self):
analyzer_image = FilenameStats("image", DataStatsKeys.BY_CASE_IMAGE_PATH)
analyzer_label = FilenameStats("label", DataStatsKeys.BY_CASE_IMAGE_PATH)
Expand Down
1 change: 1 addition & 0 deletions tests/integration/test_loader_semaphore.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
# limitations under the License.
"""this test should not generate errors or
UserWarning: semaphore_tracker: There appear to be 1 leaked semaphores"""

from __future__ import annotations

import multiprocessing as mp
Expand Down
1 change: 1 addition & 0 deletions tests/profile_subclass/profiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
Comparing torch.Tensor, SubTensor, SubWithTorchFunc, MetaTensor
Adapted from https://github.com/pytorch/pytorch/tree/v1.11.0/benchmarks/overrides_benchmark
"""

from __future__ import annotations

import argparse
Expand Down
1 change: 1 addition & 0 deletions tests/profile_subclass/pyspy_profiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
To be used with py-spy, comparing torch.Tensor, SubTensor, SubWithTorchFunc, MetaTensor
Adapted from https://github.com/pytorch/pytorch/tree/v1.11.0/benchmarks/overrides_benchmark
"""

from __future__ import annotations

import argparse
Expand Down
1 change: 1 addition & 0 deletions versioneer.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,7 @@
[travis-url]: https://travis-ci.com/github/python-versioneer/python-versioneer

"""

# pylint:disable=invalid-name,import-outside-toplevel,missing-function-docstring
# pylint:disable=missing-class-docstring,too-many-branches,too-many-statements
# pylint:disable=raise-missing-from,too-many-lines,too-many-locals,import-error
Expand Down
Loading