Skip to content
Merged
26 changes: 18 additions & 8 deletions monai/auto3dseg/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from abc import ABC, abstractmethod
from collections.abc import Hashable, Mapping
from copy import deepcopy
from typing import Any
from typing import Any, cast

import numpy as np
import torch
Expand Down Expand Up @@ -408,9 +408,10 @@ def __init__(
}

if self.do_ccp:
report_format[LabelStatsKeys.LABEL][0].update(
{LabelStatsKeys.LABEL_SHAPE: None, LabelStatsKeys.LABEL_NCOMP: None}
)
report_format[LabelStatsKeys.LABEL][0].update({
LabelStatsKeys.LABEL_SHAPE: None,
LabelStatsKeys.LABEL_NCOMP: None,
})

super().__init__(stats_name, report_format)
self.update_ops(LabelStatsKeys.IMAGE_INTST, SampleOperations())
Expand Down Expand Up @@ -470,6 +471,7 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTe
start = time.time()
image_tensor = d[self.image_key]
label_tensor = d[self.label_key]
# Check if either tensor is on CUDA to determine if we should move both to CUDA for processing
using_cuda = any(
isinstance(t, (torch.Tensor, MetaTensor)) and t.device.type == "cuda" for t in (image_tensor, label_tensor)
)
Expand All @@ -479,8 +481,15 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTe
if isinstance(image_tensor, (MetaTensor, torch.Tensor)) and isinstance(
label_tensor, (MetaTensor, torch.Tensor)
):
# If there's a device mismatch, move both to CUDA if either is on CUDA, otherwise sync to image device
if label_tensor.device != image_tensor.device:
label_tensor = label_tensor.to(image_tensor.device) # type: ignore
if using_cuda:
# Prefer CUDA for performance when there's a mix
cuda_device = image_tensor.device if image_tensor.device.type == "cuda" else label_tensor.device
image_tensor = cast(MetaTensor, image_tensor.to(cuda_device))
label_tensor = cast(MetaTensor, label_tensor.to(cuda_device))
else:
label_tensor = cast(MetaTensor, label_tensor.to(image_tensor.device))
Comment thread
ericspod marked this conversation as resolved.

Comment thread
ericspod marked this conversation as resolved.
ndas: list[MetaTensor] = [image_tensor[i] for i in range(image_tensor.shape[0])] # type: ignore
ndas_label: MetaTensor = label_tensor.astype(torch.int16) # (H,W,D)
Expand Down Expand Up @@ -724,9 +733,10 @@ def __init__(
LabelStatsKeys.LABEL: [{LabelStatsKeys.PIXEL_PCT: None, LabelStatsKeys.IMAGE_INTST: None}],
}
if self.do_ccp:
report_format[LabelStatsKeys.LABEL][0].update(
{LabelStatsKeys.LABEL_SHAPE: None, LabelStatsKeys.LABEL_NCOMP: None}
)
report_format[LabelStatsKeys.LABEL][0].update({
LabelStatsKeys.LABEL_SHAPE: None,
LabelStatsKeys.LABEL_NCOMP: None,
})

super().__init__(stats_name, report_format)
self.update_ops(LabelStatsKeys.IMAGE_INTST, SummaryOperations())
Expand Down
41 changes: 19 additions & 22 deletions tests/apps/test_auto3dseg.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,16 +303,14 @@ def test_transform_analyzer_class(self):

def test_image_stats_case_analyzer(self):
analyzer = ImageStats(image_key="image")
transform = Compose(
[
LoadImaged(keys=["image"]),
EnsureChannelFirstd(keys=["image"]), # this creates label to be (1,H,W,D)
ToDeviced(keys=["image"], device=device, non_blocking=True),
Orientationd(keys=["image"], axcodes="RAS"),
EnsureTyped(keys=["image"], data_type="tensor"),
analyzer,
]
)
transform = Compose([
LoadImaged(keys=["image"]),
EnsureChannelFirstd(keys=["image"]), # this creates label to be (1,H,W,D)
ToDeviced(keys=["image"], device=device, non_blocking=True),
Orientationd(keys=["image"], axcodes="RAS"),
EnsureTyped(keys=["image"], data_type="tensor"),
analyzer,
])
create_sim_data(self.dataroot_dir, sim_datalist, (32, 32, 32), rad_max=8, rad_min=1, num_seg_classes=1)
files, _ = datafold_read(sim_datalist, self.dataroot_dir, fold=-1)
ds = Dataset(data=files)
Expand Down Expand Up @@ -346,18 +344,16 @@ def test_foreground_image_stats_cases_analyzer(self):

def test_label_stats_case_analyzer(self):
analyzer = LabelStats(image_key="image", label_key="label")
transform = Compose(
[
LoadImaged(keys=["image", "label"]),
EnsureChannelFirstd(keys=["image", "label"]), # this creates label to be (1,H,W,D)
ToDeviced(keys=["image", "label"], device=device, non_blocking=True),
Orientationd(keys=["image", "label"], axcodes="RAS"),
EnsureTyped(keys=["image", "label"], data_type="tensor"),
Lambdad(keys=["label"], func=lambda x: torch.argmax(x, dim=0, keepdim=True) if x.shape[0] > 1 else x),
SqueezeDimd(keys=["label"], dim=0),
analyzer,
]
)
transform = Compose([
LoadImaged(keys=["image", "label"]),
EnsureChannelFirstd(keys=["image", "label"]), # this creates label to be (1,H,W,D)
ToDeviced(keys=["image", "label"], device=device, non_blocking=True),
Orientationd(keys=["image", "label"], axcodes="RAS"),
EnsureTyped(keys=["image", "label"], data_type="tensor"),
Lambdad(keys=["label"], func=lambda x: torch.argmax(x, dim=0, keepdim=True) if x.shape[0] > 1 else x),
SqueezeDimd(keys=["label"], dim=0),
analyzer,
])
create_sim_data(self.dataroot_dir, sim_datalist, (32, 32, 32), rad_max=8, rad_min=1, num_seg_classes=1)
files, _ = datafold_read(sim_datalist, self.dataroot_dir, fold=-1)
ds = Dataset(data=files)
Expand Down Expand Up @@ -393,6 +389,7 @@ def test_label_stats_mixed_device_analyzer(self, input_params):
result = analyzer({"image": image_tensor, "label": label_tensor})
report = result["label_stats"]

# Verify report format and computation succeeded despite mixed/unified devices
assert verify_report_format(report, analyzer.get_report_format())
assert report[LabelStatsKeys.LABEL_UID] == [0, 1]

Expand Down
Loading