Skip to content
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 14 additions & 7 deletions monai/auto3dseg/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,21 +468,28 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTe
"""
d: dict[Hashable, MetaTensor] = dict(data)
start = time.time()
if isinstance(d[self.image_key], (torch.Tensor, MetaTensor)) and d[self.image_key].device.type == "cuda":
using_cuda = True
else:
using_cuda = False
image_tensor = d[self.image_key]
label_tensor = d[self.label_key]
using_cuda = any(
isinstance(t, (torch.Tensor, MetaTensor)) and t.device.type == "cuda" for t in (image_tensor, label_tensor)
)
Comment thread
coderabbitai[bot] marked this conversation as resolved.
restore_grad_state = torch.is_grad_enabled()
torch.set_grad_enabled(False)

ndas: list[MetaTensor] = [d[self.image_key][i] for i in range(d[self.image_key].shape[0])] # type: ignore
ndas_label: MetaTensor = d[self.label_key].astype(torch.int16) # (H,W,D)
if isinstance(image_tensor, (MetaTensor, torch.Tensor)) and isinstance(
label_tensor, (MetaTensor, torch.Tensor)
):
if label_tensor.device != image_tensor.device:
label_tensor = label_tensor.to(image_tensor.device)
Comment thread
benediktjohannes marked this conversation as resolved.
Outdated
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just re-reviewed the merged commits and I think we might have a potential issue here.

We set the label_tensor device to the device of the image_tensor here, but we already determined whether cuda is set or not (using_cuda) above by checking whether any of both tensors uses cuda. So in case we got label_tensor to use cuda and image_tensor not, then we would first set using_cuda to True, but then we check whether if label_tensor.device is != image_tensor.device which is True and therefore we set the device of label_tensor to non-cuda (e.g. cpu) which means that now both tensors use cpu while using_cuda is still wrongly set to True.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes I see the issue here, the expectation was that image_tensor would always be a CUDA tensor if either were. Instead something like:

if label_tensor.device != image_tensor.device:
    # using_cuda would have to be True here, unless non-CUDA device types like mps have been mixed in
    device = image_tensor.device if image_tensor.device.type == "cuda" else label_tensor.device
    label_tensor = label_tensor.to(device)  # type: ignore
    image_tensor= image_tensor.to(device)  # type: ignore

This would select the device for whichever is a CUDA tensor. CC @garciadias

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @benediktjohannes, I am not sure if I understand. Have you looked into the proposed tests? https://github.com/benediktjohannes/MONAI/blob/1a2598b5ab0624fce4e9ab74cd4eed9f922fb801/tests/apps/test_auto3dseg.py#L81

This is my understanding of how this should behave. If you think any of the test cases are wrong, then we will need to correct them.


ndas: list[MetaTensor] = [image_tensor[i] for i in range(image_tensor.shape[0])] # type: ignore
ndas_label: MetaTensor = label_tensor.astype(torch.int16) # (H,W,D)
Comment thread
coderabbitai[bot] marked this conversation as resolved.

if ndas_label.shape != ndas[0].shape:
raise ValueError(f"Label shape {ndas_label.shape} is different from image shape {ndas[0].shape}")

nda_foregrounds: list[torch.Tensor] = [get_foreground_label(nda, ndas_label) for nda in ndas]
nda_foregrounds = [nda if nda.numel() > 0 else torch.Tensor([0]) for nda in nda_foregrounds]
nda_foregrounds = [nda if nda.numel() > 0 else MetaTensor([0.0]) for nda in nda_foregrounds]

unique_label = unique(ndas_label)
if isinstance(ndas_label, (MetaTensor, torch.Tensor)):
Expand Down
6 changes: 3 additions & 3 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ mccabe
pep8-naming
pycodestyle
pyflakes
black>=25.1.0
isort>=5.1, !=6.0.0
ruff
black==25.1.0
isort>=5.1, <6, !=6.0.0
ruff>=0.14.11,<0.15
pytype>=2020.6.1, <=2024.4.11; platform_system != "Windows"
types-setuptools
mypy>=1.5.0, <1.12.0
Expand Down
53 changes: 52 additions & 1 deletion tests/apps/test_auto3dseg.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
SqueezeDimd,
ToDeviced,
)
from monai.utils.enums import DataStatsKeys
from monai.utils.enums import DataStatsKeys, LabelStatsKeys
from tests.test_utils import skip_if_no_cuda

device = "cpu"
Expand All @@ -78,6 +78,13 @@

SIM_GPU_TEST_CASES = [[{"sim_dim": (32, 32, 32), "label_key": "label"}], [{"sim_dim": (32, 32, 32), "label_key": None}]]

LABEL_STATS_DEVICE_TEST_CASES = [
[{"image_device": "cpu", "label_device": "cpu", "image_meta": False}],
[{"image_device": "cuda", "label_device": "cuda", "image_meta": True}],
[{"image_device": "cpu", "label_device": "cuda", "image_meta": True}],
[{"image_device": "cuda", "label_device": "cpu", "image_meta": False}],
]


def create_sim_data(dataroot: str, sim_datalist: dict, sim_dim: tuple, image_only: bool = False, **kwargs) -> None:
"""
Expand Down Expand Up @@ -360,6 +367,50 @@ def test_label_stats_case_analyzer(self):
report_format = analyzer.get_report_format()
assert verify_report_format(d["label_stats"], report_format)

@parameterized.expand(LABEL_STATS_DEVICE_TEST_CASES)
def test_label_stats_mixed_device_analyzer(self, input_params):
image_device = torch.device(input_params["image_device"])
label_device = torch.device(input_params["label_device"])

if (image_device.type == "cuda" or label_device.type == "cuda") and not torch.cuda.is_available():
self.skipTest("CUDA is not available for mixed-device LabelStats tests.")

analyzer = LabelStats(image_key="image", label_key="label")

image_tensor = torch.tensor(
[
[[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]],
[[[11.0, 12.0], [13.0, 14.0]], [[15.0, 16.0], [17.0, 18.0]]],
],
dtype=torch.float32,
).to(image_device)
label_tensor = torch.tensor([[[0, 1], [1, 0]], [[0, 1], [0, 1]]], dtype=torch.int64).to(label_device)

if input_params["image_meta"]:
image_tensor = MetaTensor(image_tensor)
label_tensor = MetaTensor(label_tensor)

result = analyzer({"image": image_tensor, "label": label_tensor})
report = result["label_stats"]

assert verify_report_format(report, analyzer.get_report_format())
assert report[LabelStatsKeys.LABEL_UID] == [0, 1]

label_stats = report[LabelStatsKeys.LABEL]
self.assertAlmostEqual(label_stats[0][LabelStatsKeys.PIXEL_PCT], 0.5)
self.assertAlmostEqual(label_stats[1][LabelStatsKeys.PIXEL_PCT], 0.5)

label0_intensity = label_stats[0][LabelStatsKeys.IMAGE_INTST]
label1_intensity = label_stats[1][LabelStatsKeys.IMAGE_INTST]
self.assertAlmostEqual(label0_intensity[0]["mean"], 4.25)
self.assertAlmostEqual(label1_intensity[0]["mean"], 4.75)
self.assertAlmostEqual(label0_intensity[1]["mean"], 14.25)
self.assertAlmostEqual(label1_intensity[1]["mean"], 14.75)

foreground_stats = report[LabelStatsKeys.IMAGE_INTST]
self.assertAlmostEqual(foreground_stats[0]["mean"], 4.75)
self.assertAlmostEqual(foreground_stats[1]["mean"], 14.75)

def test_filename_case_analyzer(self):
analyzer_image = FilenameStats("image", DataStatsKeys.BY_CASE_IMAGE_PATH)
analyzer_label = FilenameStats("label", DataStatsKeys.BY_CASE_IMAGE_PATH)
Expand Down
Loading