Skip to content

Commit 1a2598b

Browse files
authored
Add mixed-device LabelStats handling and tests (#1)
* Add mixed-device LabelStats coverage * Autofix to mirror CI: Restrict versions of isort, ruff, and black. Signed-off-by: R. Garcia-Dias <rafaelagd@gmail.com> --------- Signed-off-by: R. Garcia-Dias <rafaelagd@gmail.com>
1 parent e93a911 commit 1a2598b

10 files changed

Lines changed: 116 additions & 31 deletions

File tree

monai/apps/detection/transforms/dictionary.py

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -125,8 +125,10 @@ def __init__(self, box_keys: KeysCollection, box_ref_image_keys: str, allow_miss
125125
super().__init__(box_keys, allow_missing_keys)
126126
box_ref_image_keys_tuple = ensure_tuple(box_ref_image_keys)
127127
if len(box_ref_image_keys_tuple) > 1:
128-
raise ValueError("Please provide a single key for box_ref_image_keys.\
129-
All boxes of box_keys are attached to box_ref_image_keys.")
128+
raise ValueError(
129+
"Please provide a single key for box_ref_image_keys.\
130+
All boxes of box_keys are attached to box_ref_image_keys."
131+
)
130132
self.box_ref_image_keys = box_ref_image_keys
131133

132134
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
@@ -287,8 +289,10 @@ def __init__(
287289
super().__init__(box_keys, allow_missing_keys)
288290
box_ref_image_keys_tuple = ensure_tuple(box_ref_image_keys)
289291
if len(box_ref_image_keys_tuple) > 1:
290-
raise ValueError("Please provide a single key for box_ref_image_keys.\
291-
All boxes of box_keys are attached to box_ref_image_keys.")
292+
raise ValueError(
293+
"Please provide a single key for box_ref_image_keys.\
294+
All boxes of box_keys are attached to box_ref_image_keys."
295+
)
292296
self.box_ref_image_keys = box_ref_image_keys
293297
self.image_meta_key = image_meta_key or f"{box_ref_image_keys}_{image_meta_key_postfix}"
294298
self.converter_to_image_coordinate = AffineBox()
@@ -306,8 +310,10 @@ def extract_affine(self, data: Mapping[Hashable, torch.Tensor]) -> tuple[Ndarray
306310
else:
307311
raise ValueError(f"{meta_key} is not found. Please check whether it is the correct the image meta key.")
308312
if "affine" not in meta_dict:
309-
raise ValueError(f"'affine' is not found in {meta_key}. \
310-
Please check whether it is the correct the image meta key.")
313+
raise ValueError(
314+
f"'affine' is not found in {meta_key}. \
315+
Please check whether it is the correct the image meta key."
316+
)
311317
affine: NdarrayOrTensor = meta_dict["affine"]
312318

313319
if self.affine_lps_to_ras: # RAS affine
@@ -809,12 +815,16 @@ def __init__(
809815
) -> None:
810816
box_keys_tuple = ensure_tuple(box_keys)
811817
if len(box_keys_tuple) != 1:
812-
raise ValueError("Please provide a single key for box_keys.\
813-
All label_keys are attached to this box_keys.")
818+
raise ValueError(
819+
"Please provide a single key for box_keys.\
820+
All label_keys are attached to this box_keys."
821+
)
814822
box_ref_image_keys_tuple = ensure_tuple(box_ref_image_keys)
815823
if len(box_ref_image_keys_tuple) != 1:
816-
raise ValueError("Please provide a single key for box_ref_image_keys.\
817-
All box_keys and label_keys are attached to this box_ref_image_keys.")
824+
raise ValueError(
825+
"Please provide a single key for box_ref_image_keys.\
826+
All box_keys and label_keys are attached to this box_ref_image_keys."
827+
)
818828
self.label_keys = ensure_tuple(label_keys)
819829
super().__init__(box_keys_tuple, allow_missing_keys)
820830

@@ -1081,8 +1091,10 @@ def __init__(
10811091

10821092
box_keys_tuple = ensure_tuple(box_keys)
10831093
if len(box_keys_tuple) != 1:
1084-
raise ValueError("Please provide a single key for box_keys.\
1085-
All label_keys are attached to this box_keys.")
1094+
raise ValueError(
1095+
"Please provide a single key for box_keys.\
1096+
All label_keys are attached to this box_keys."
1097+
)
10861098
self.box_keys = box_keys_tuple[0]
10871099
self.label_keys = ensure_tuple(label_keys)
10881100

monai/apps/detection/utils/anchor_utils.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -124,8 +124,10 @@ def __init__(
124124
aspect_ratios = (aspect_ratios,) * len(self.sizes)
125125

126126
if len(self.sizes) != len(aspect_ratios):
127-
raise ValueError("len(sizes) and len(aspect_ratios) should be equal. \
128-
It represents the number of feature maps.")
127+
raise ValueError(
128+
"len(sizes) and len(aspect_ratios) should be equal. \
129+
It represents the number of feature maps."
130+
)
129131

130132
spatial_dims = len(ensure_tuple(aspect_ratios[0][0])) + 1
131133
spatial_dims = look_up_option(spatial_dims, [2, 3])
@@ -170,12 +172,16 @@ def generate_anchors(
170172
scales_t = torch.as_tensor(scales, dtype=dtype, device=device) # sized (N,)
171173
aspect_ratios_t = torch.as_tensor(aspect_ratios, dtype=dtype, device=device) # sized (M,) or (M,2)
172174
if (self.spatial_dims >= 3) and (len(aspect_ratios_t.shape) != 2):
173-
raise ValueError(f"In {self.spatial_dims}-D image, aspect_ratios for each level should be \
174-
{len(aspect_ratios_t.shape) - 1}-D. But got aspect_ratios with shape {aspect_ratios_t.shape}.")
175+
raise ValueError(
176+
f"In {self.spatial_dims}-D image, aspect_ratios for each level should be \
177+
{len(aspect_ratios_t.shape) - 1}-D. But got aspect_ratios with shape {aspect_ratios_t.shape}."
178+
)
175179

176180
if (self.spatial_dims >= 3) and (aspect_ratios_t.shape[1] != self.spatial_dims - 1):
177-
raise ValueError(f"In {self.spatial_dims}-D image, aspect_ratios for each level should has \
178-
shape (_,{self.spatial_dims - 1}). But got aspect_ratios with shape {aspect_ratios_t.shape}.")
181+
raise ValueError(
182+
f"In {self.spatial_dims}-D image, aspect_ratios for each level should has \
183+
shape (_,{self.spatial_dims - 1}). But got aspect_ratios with shape {aspect_ratios_t.shape}."
184+
)
179185

180186
# if 2d, w:h = 1:aspect_ratios
181187
if self.spatial_dims == 2:

monai/apps/reconstruction/transforms/array.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,10 @@ def __init__(
6161
real/imaginary parts.
6262
"""
6363
if len(center_fractions) != len(accelerations):
64-
raise ValueError("Number of center fractions \
65-
should match number of accelerations")
64+
raise ValueError(
65+
"Number of center fractions \
66+
should match number of accelerations"
67+
)
6668

6769
self.center_fractions = center_fractions
6870
self.accelerations = accelerations

monai/auto3dseg/analyzer.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -476,6 +476,12 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTe
476476
restore_grad_state = torch.is_grad_enabled()
477477
torch.set_grad_enabled(False)
478478

479+
if isinstance(image_tensor, (MetaTensor, torch.Tensor)) and isinstance(
480+
label_tensor, (MetaTensor, torch.Tensor)
481+
):
482+
if label_tensor.device != image_tensor.device:
483+
label_tensor = label_tensor.to(image_tensor.device)
484+
479485
ndas: list[MetaTensor] = [image_tensor[i] for i in range(image_tensor.shape[0])] # type: ignore
480486
ndas_label: MetaTensor = label_tensor.astype(torch.int16) # (H,W,D)
481487

monai/bundle/utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,8 +124,10 @@
124124
"run_name": None,
125125
# may fill it at runtime
126126
"save_execute_config": True,
127-
"is_not_rank0": ("$torch.distributed.is_available() \
128-
and torch.distributed.is_initialized() and torch.distributed.get_rank() > 0"),
127+
"is_not_rank0": (
128+
"$torch.distributed.is_available() \
129+
and torch.distributed.is_initialized() and torch.distributed.get_rank() > 0"
130+
),
129131
# MLFlowHandler config for the trainer
130132
"trainer": {
131133
"_target_": "MLFlowHandler",

monai/losses/dice.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -203,9 +203,11 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
203203
self.class_weight = torch.as_tensor([self.class_weight] * num_of_classes)
204204
else:
205205
if self.class_weight.shape[0] != num_of_classes:
206-
raise ValueError("""the length of the `weight` sequence should be the same as the number of classes.
206+
raise ValueError(
207+
"""the length of the `weight` sequence should be the same as the number of classes.
207208
If `include_background=False`, the weight should not include
208-
the background category class 0.""")
209+
the background category class 0."""
210+
)
209211
if self.class_weight.min() < 0:
210212
raise ValueError("the value/values of the `weight` should be no less than 0.")
211213
# apply class_weight to loss

monai/losses/focal_loss.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,9 +183,11 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
183183
self.class_weight = torch.as_tensor([self.class_weight] * num_of_classes)
184184
else:
185185
if self.class_weight.shape[0] != num_of_classes:
186-
raise ValueError("""the length of the `weight` sequence should be the same as the number of classes.
186+
raise ValueError(
187+
"""the length of the `weight` sequence should be the same as the number of classes.
187188
If `include_background=False`, the weight should not include
188-
the background category class 0.""")
189+
the background category class 0."""
190+
)
189191
if self.class_weight.min() < 0:
190192
raise ValueError("the value/values of the `weight` should be no less than 0.")
191193
# apply class_weight to loss

requirements-dev.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@ mccabe
1414
pep8-naming
1515
pycodestyle
1616
pyflakes
17-
black>=25.1.0
18-
isort>=5.1, !=6.0.0
19-
ruff
17+
black==25.1.0
18+
isort>=5.1, <6, !=6.0.0
19+
ruff>=0.14.11,<0.15
2020
pytype>=2020.6.1, <=2024.4.11; platform_system != "Windows"
2121
types-setuptools
2222
mypy>=1.5.0, <1.12.0

tests/apps/test_auto3dseg.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@
5353
SqueezeDimd,
5454
ToDeviced,
5555
)
56-
from monai.utils.enums import DataStatsKeys
56+
from monai.utils.enums import DataStatsKeys, LabelStatsKeys
5757
from tests.test_utils import skip_if_no_cuda
5858

5959
device = "cpu"
@@ -78,6 +78,13 @@
7878

7979
SIM_GPU_TEST_CASES = [[{"sim_dim": (32, 32, 32), "label_key": "label"}], [{"sim_dim": (32, 32, 32), "label_key": None}]]
8080

81+
LABEL_STATS_DEVICE_TEST_CASES = [
82+
[{"image_device": "cpu", "label_device": "cpu", "image_meta": False}],
83+
[{"image_device": "cuda", "label_device": "cuda", "image_meta": True}],
84+
[{"image_device": "cpu", "label_device": "cuda", "image_meta": True}],
85+
[{"image_device": "cuda", "label_device": "cpu", "image_meta": False}],
86+
]
87+
8188

8289
def create_sim_data(dataroot: str, sim_datalist: dict, sim_dim: tuple, image_only: bool = False, **kwargs) -> None:
8390
"""
@@ -360,6 +367,50 @@ def test_label_stats_case_analyzer(self):
360367
report_format = analyzer.get_report_format()
361368
assert verify_report_format(d["label_stats"], report_format)
362369

370+
@parameterized.expand(LABEL_STATS_DEVICE_TEST_CASES)
371+
def test_label_stats_mixed_device_analyzer(self, input_params):
372+
image_device = torch.device(input_params["image_device"])
373+
label_device = torch.device(input_params["label_device"])
374+
375+
if (image_device.type == "cuda" or label_device.type == "cuda") and not torch.cuda.is_available():
376+
self.skipTest("CUDA is not available for mixed-device LabelStats tests.")
377+
378+
analyzer = LabelStats(image_key="image", label_key="label")
379+
380+
image_tensor = torch.tensor(
381+
[
382+
[[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]],
383+
[[[11.0, 12.0], [13.0, 14.0]], [[15.0, 16.0], [17.0, 18.0]]],
384+
],
385+
dtype=torch.float32,
386+
).to(image_device)
387+
label_tensor = torch.tensor([[[0, 1], [1, 0]], [[0, 1], [0, 1]]], dtype=torch.int64).to(label_device)
388+
389+
if input_params["image_meta"]:
390+
image_tensor = MetaTensor(image_tensor)
391+
label_tensor = MetaTensor(label_tensor)
392+
393+
result = analyzer({"image": image_tensor, "label": label_tensor})
394+
report = result["label_stats"]
395+
396+
assert verify_report_format(report, analyzer.get_report_format())
397+
assert report[LabelStatsKeys.LABEL_UID] == [0, 1]
398+
399+
label_stats = report[LabelStatsKeys.LABEL]
400+
self.assertAlmostEqual(label_stats[0][LabelStatsKeys.PIXEL_PCT], 0.5)
401+
self.assertAlmostEqual(label_stats[1][LabelStatsKeys.PIXEL_PCT], 0.5)
402+
403+
label0_intensity = label_stats[0][LabelStatsKeys.IMAGE_INTST]
404+
label1_intensity = label_stats[1][LabelStatsKeys.IMAGE_INTST]
405+
self.assertAlmostEqual(label0_intensity[0]["mean"], 4.25)
406+
self.assertAlmostEqual(label1_intensity[0]["mean"], 4.75)
407+
self.assertAlmostEqual(label0_intensity[1]["mean"], 14.25)
408+
self.assertAlmostEqual(label1_intensity[1]["mean"], 14.75)
409+
410+
foreground_stats = report[LabelStatsKeys.IMAGE_INTST]
411+
self.assertAlmostEqual(foreground_stats[0]["mean"], 4.75)
412+
self.assertAlmostEqual(foreground_stats[1]["mean"], 14.75)
413+
363414
def test_filename_case_analyzer(self):
364415
analyzer_image = FilenameStats("image", DataStatsKeys.BY_CASE_IMAGE_PATH)
365416
analyzer_label = FilenameStats("label", DataStatsKeys.BY_CASE_IMAGE_PATH)

versioneer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,9 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=
429429
return stdout, process.returncode
430430

431431

432-
LONG_VERSION_PY["git"] = r'''
432+
LONG_VERSION_PY[
433+
"git"
434+
] = r'''
433435
# This file helps to compute a version number in source trees obtained from
434436
# git-archive tarball (such as those provided by githubs download-from-tag
435437
# feature). Distribution tarballs (built by setup.py sdist) and build

0 commit comments

Comments
 (0)