Skip to content

Commit 43f508b

Browse files
authored
Merge branch 'dev' into dev
2 parents 66af964 + 851054c commit 43f508b

File tree

2 files changed

+110
-35
lines changed

2 files changed

+110
-35
lines changed

monai/auto3dseg/analyzer.py

Lines changed: 33 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -216,50 +216,58 @@ def __init__(self, image_key: str, stats_name: str = DataStatsKeys.IMAGE_STATS)
216216
super().__init__(stats_name, report_format)
217217
self.update_ops(ImageStatsKeys.INTENSITY, SampleOperations())
218218

219+
@torch.no_grad()
219220
def __call__(self, data):
220-
# Input Validation Addition
221-
if not isinstance(data, dict):
222-
raise TypeError(f"Input data must be a dict, but got {type(data).__name__}.")
223-
if self.image_key not in data:
224-
raise KeyError(f"Key '{self.image_key}' not found in input data.")
225-
image = data[self.image_key]
226-
if not isinstance(image, (np.ndarray, torch.Tensor, MetaTensor)):
227-
raise TypeError(
228-
f"Value for '{self.image_key}' must be a numpy array, torch.Tensor, or MetaTensor, "
229-
f"but got {type(image).__name__}."
230-
)
231-
if image.ndim < 3:
232-
raise ValueError(
233-
f"Image data under '{self.image_key}' must have at least 3 dimensions, but got shape {image.shape}."
234-
)
235-
# --- End of validation ---
236221
"""
237-
Callable to execute the pre-defined functions
222+
Callable to execute the pre-defined functions.
238223
239224
Returns:
240225
A dictionary. The dict has the key in self.report_format. The value of
241226
ImageStatsKeys.INTENSITY is in a list format. Each element of the value list
242227
has stats pre-defined by SampleOperations (max, min, ....).
243228
244229
Raises:
245-
RuntimeError if the stats report generated is not consistent with the pre-
230+
KeyError: if ``self.image_key`` is not present in the input data.
231+
TypeError: if the input data is not a dictionary, or if the image value is
232+
not a numpy array, torch.Tensor, or MetaTensor.
233+
ValueError: if the image has fewer than 3 dimensions, or if pre-computed
234+
``nda_croppeds`` is not a list/tuple with one entry per image channel.
235+
RuntimeError: if the stats report generated is not consistent with the pre-
246236
defined report_format.
247237
248238
Note:
249239
The stats operation uses numpy and torch to compute max, min, and other
250240
functions. If the input has nan/inf, the stats results will be nan/inf.
251241
252242
"""
243+
if not isinstance(data, dict):
244+
raise TypeError(f"Input data must be a dict, but got {type(data).__name__}.")
245+
if self.image_key not in data:
246+
raise KeyError(f"Key '{self.image_key}' not found in input data.")
247+
image = data[self.image_key]
248+
if not isinstance(image, (np.ndarray, torch.Tensor, MetaTensor)):
249+
raise TypeError(
250+
f"Value for '{self.image_key}' must be a numpy array, torch.Tensor, or MetaTensor, "
251+
f"but got {type(image).__name__}."
252+
)
253+
if image.ndim < 3:
254+
raise ValueError(
255+
f"Image data under '{self.image_key}' must have at least 3 dimensions, but got shape {image.shape}."
256+
)
257+
253258
d = dict(data)
254259
start = time.time()
255-
restore_grad_state = torch.is_grad_enabled()
256-
torch.set_grad_enabled(False)
257-
258260
ndas = [d[self.image_key][i] for i in range(d[self.image_key].shape[0])]
259-
if "nda_croppeds" not in d:
261+
if "nda_croppeds" in d:
262+
nda_croppeds = d["nda_croppeds"]
263+
if not isinstance(nda_croppeds, (list, tuple)) or len(nda_croppeds) != len(ndas):
264+
raise ValueError(
265+
"Pre-computed 'nda_croppeds' must be a list or tuple with one entry per image channel "
266+
f"(expected {len(ndas)})."
267+
)
268+
else:
260269
nda_croppeds = [get_foreground_image(nda) for nda in ndas]
261270

262-
# perform calculation
263271
report = deepcopy(self.get_report_format())
264272

265273
report[ImageStatsKeys.SHAPE] = [list(nda.shape) for nda in ndas]
@@ -284,7 +292,6 @@ def __call__(self, data):
284292

285293
d[self.stats_name] = report
286294

287-
torch.set_grad_enabled(restore_grad_state)
288295
logger.debug(f"Get image stats spent {time.time() - start}")
289296
return d
290297

@@ -321,6 +328,7 @@ def __init__(self, image_key: str, label_key: str, stats_name: str = DataStatsKe
321328
super().__init__(stats_name, report_format)
322329
self.update_ops(ImageStatsKeys.INTENSITY, SampleOperations())
323330

331+
@torch.no_grad()
324332
def __call__(self, data: Mapping) -> dict:
325333
"""
326334
Callable to execute the pre-defined functions
@@ -341,9 +349,6 @@ def __call__(self, data: Mapping) -> dict:
341349

342350
d = dict(data)
343351
start = time.time()
344-
restore_grad_state = torch.is_grad_enabled()
345-
torch.set_grad_enabled(False)
346-
347352
ndas = [d[self.image_key][i] for i in range(d[self.image_key].shape[0])]
348353
ndas_label = d[self.label_key] # (H,W,D)
349354

@@ -353,7 +358,6 @@ def __call__(self, data: Mapping) -> dict:
353358
nda_foregrounds = [get_foreground_label(nda, ndas_label) for nda in ndas]
354359
nda_foregrounds = [nda if nda.numel() > 0 else MetaTensor([0.0]) for nda in nda_foregrounds]
355360

356-
# perform calculation
357361
report = deepcopy(self.get_report_format())
358362

359363
report[ImageStatsKeys.INTENSITY] = [
@@ -365,7 +369,6 @@ def __call__(self, data: Mapping) -> dict:
365369

366370
d[self.stats_name] = report
367371

368-
torch.set_grad_enabled(restore_grad_state)
369372
logger.debug(f"Get foreground image stats spent {time.time() - start}")
370373
return d
371374

@@ -418,6 +421,7 @@ def __init__(
418421
id_seq = ID_SEP_KEY.join([LabelStatsKeys.LABEL, "0", LabelStatsKeys.IMAGE_INTST])
419422
self.update_ops_nested_label(id_seq, SampleOperations())
420423

424+
@torch.no_grad()
421425
def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTensor | dict]:
422426
"""
423427
Callable to execute the pre-defined functions.
@@ -470,19 +474,15 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTe
470474
start = time.time()
471475
image_tensor = d[self.image_key]
472476
label_tensor = d[self.label_key]
473-
# Check if either tensor is on CUDA to determine if we should move both to CUDA for processing
474477
using_cuda = any(
475478
isinstance(t, (torch.Tensor, MetaTensor)) and t.device.type == "cuda" for t in (image_tensor, label_tensor)
476479
)
477-
restore_grad_state = torch.is_grad_enabled()
478-
torch.set_grad_enabled(False)
479480

480481
if isinstance(image_tensor, (MetaTensor, torch.Tensor)) and isinstance(
481482
label_tensor, (MetaTensor, torch.Tensor)
482483
):
483484
if label_tensor.device != image_tensor.device:
484485
if using_cuda:
485-
# Move both tensors to CUDA when mixing devices
486486
cuda_device = image_tensor.device if image_tensor.device.type == "cuda" else label_tensor.device
487487
image_tensor = cast(MetaTensor, image_tensor.to(cuda_device))
488488
label_tensor = cast(MetaTensor, label_tensor.to(cuda_device))
@@ -548,7 +548,6 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTe
548548

549549
d[self.stats_name] = report # type: ignore[assignment]
550550

551-
torch.set_grad_enabled(restore_grad_state)
552551
logger.debug(f"Get label stats spent {time.time() - start}")
553552
return d # type: ignore[return-value]
554553

tests/apps/test_auto3dseg.py

Lines changed: 77 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@
5353
SqueezeDimd,
5454
ToDeviced,
5555
)
56-
from monai.utils.enums import DataStatsKeys, LabelStatsKeys
56+
from monai.utils.enums import DataStatsKeys, ImageStatsKeys, LabelStatsKeys
5757
from tests.test_utils import skip_if_no_cuda
5858

5959
device = "cpu"
@@ -322,6 +322,47 @@ def test_image_stats_case_analyzer(self):
322322
report_format = analyzer.get_report_format()
323323
assert verify_report_format(d["image_stats"], report_format)
324324

325+
def test_image_stats_uses_precomputed_nda_croppeds(self):
326+
"""Verify ImageStats uses valid pre-computed foreground crops."""
327+
analyzer = ImageStats(image_key="image")
328+
image = torch.arange(64.0, dtype=torch.float32).reshape(1, 4, 4, 4)
329+
nda_croppeds = [torch.ones((2, 2, 2), dtype=torch.float32)]
330+
331+
result = analyzer({"image": image, "nda_croppeds": nda_croppeds})
332+
report = result["image_stats"]
333+
334+
assert verify_report_format(report, analyzer.get_report_format())
335+
assert report[ImageStatsKeys.CROPPED_SHAPE] == [[2, 2, 2]]
336+
self.assertAlmostEqual(report[ImageStatsKeys.INTENSITY][0]["mean"], 1.0)
337+
338+
def test_image_stats_validates_precomputed_nda_croppeds(self):
339+
"""Verify ImageStats rejects malformed pre-computed foreground crops."""
340+
analyzer = ImageStats(image_key="image")
341+
image = torch.ones((2, 4, 4, 4), dtype=torch.float32)
342+
invalid_cases = [
343+
("wrong_type", torch.ones((2, 2, 2), dtype=torch.float32)),
344+
("wrong_length", [torch.ones((2, 2, 2), dtype=torch.float32)]),
345+
]
346+
347+
for name, nda_croppeds in invalid_cases:
348+
with self.subTest(case=name):
349+
with self.assertRaisesRegex(ValueError, "one entry per image channel"):
350+
analyzer({"image": image, "nda_croppeds": nda_croppeds})
351+
352+
def test_image_stats_preserves_grad_state_after_call(self):
353+
"""Verify ImageStats preserves caller grad state on successful execution."""
354+
analyzer = ImageStats(image_key="image")
355+
data = {"image": MetaTensor(torch.rand(1, 10, 10, 10))}
356+
original_grad_state = torch.is_grad_enabled()
357+
try:
358+
for grad_enabled in (True, False):
359+
with self.subTest(grad_enabled=grad_enabled):
360+
torch.set_grad_enabled(grad_enabled)
361+
analyzer(data)
362+
self.assertEqual(torch.is_grad_enabled(), grad_enabled)
363+
finally:
364+
torch.set_grad_enabled(original_grad_state)
365+
325366
def test_foreground_image_stats_cases_analyzer(self):
326367
analyzer = FgImageStats(image_key="image", label_key="label")
327368
transform_list = [
@@ -412,6 +453,41 @@ def test_label_stats_mixed_device_analyzer(self, input_params):
412453
self.assertAlmostEqual(foreground_stats[0]["mean"], 4.75)
413454
self.assertAlmostEqual(foreground_stats[1]["mean"], 14.75)
414455

456+
def test_case_analyzers_restore_grad_state_on_exception(self):
457+
"""Verify analyzer calls restore caller grad state after exceptions."""
458+
cases = [
459+
(
460+
"image_stats",
461+
ImageStats(image_key="image"),
462+
{"image": torch.randn(2, 4, 4, 4), "nda_croppeds": [torch.ones((2, 2, 2))]},
463+
ValueError,
464+
),
465+
(
466+
"fg_image_stats",
467+
FgImageStats(image_key="image", label_key="label"),
468+
{"image": torch.randn(1, 4, 4, 4), "label": torch.ones(3, 4, 4)},
469+
ValueError,
470+
),
471+
(
472+
"label_stats",
473+
LabelStats(image_key="image", label_key="label"),
474+
{"image": MetaTensor(torch.randn(1, 4, 4, 4)), "label": MetaTensor(torch.ones(3, 4, 4))},
475+
ValueError,
476+
),
477+
]
478+
479+
original_grad_state = torch.is_grad_enabled()
480+
try:
481+
for name, analyzer, data, error in cases:
482+
for grad_enabled in (True, False):
483+
with self.subTest(analyzer=name, grad_enabled=grad_enabled):
484+
torch.set_grad_enabled(grad_enabled)
485+
with self.assertRaises(error):
486+
analyzer(data)
487+
self.assertEqual(torch.is_grad_enabled(), grad_enabled)
488+
finally:
489+
torch.set_grad_enabled(original_grad_state)
490+
415491
def test_filename_case_analyzer(self):
416492
analyzer_image = FilenameStats("image", DataStatsKeys.BY_CASE_IMAGE_PATH)
417493
analyzer_label = FilenameStats("label", DataStatsKeys.BY_CASE_IMAGE_PATH)

0 commit comments

Comments
 (0)