Skip to content
54 changes: 29 additions & 25 deletions monai/auto3dseg/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,36 +255,40 @@ def __call__(self, data):
restore_grad_state = torch.is_grad_enabled()
torch.set_grad_enabled(False)

ndas = [d[self.image_key][i] for i in range(d[self.image_key].shape[0])]
if "nda_croppeds" not in d:
nda_croppeds = [get_foreground_image(nda) for nda in ndas]

# perform calculation
report = deepcopy(self.get_report_format())

report[ImageStatsKeys.SHAPE] = [list(nda.shape) for nda in ndas]
report[ImageStatsKeys.CHANNELS] = len(ndas)
report[ImageStatsKeys.CROPPED_SHAPE] = [list(nda_c.shape) for nda_c in nda_croppeds]
report[ImageStatsKeys.SPACING] = (
affine_to_spacing(data[self.image_key].affine).tolist()
if isinstance(data[self.image_key], MetaTensor)
else [1.0] * min(3, data[self.image_key].ndim)
)
try:
ndas = [d[self.image_key][i] for i in range(d[self.image_key].shape[0])]
if "nda_croppeds" not in d:
nda_croppeds = [get_foreground_image(nda) for nda in ndas]
else:
nda_croppeds = d["nda_croppeds"]

Comment thread
coderabbitai[bot] marked this conversation as resolved.
# perform calculation
report = deepcopy(self.get_report_format())

report[ImageStatsKeys.SHAPE] = [list(nda.shape) for nda in ndas]
report[ImageStatsKeys.CHANNELS] = len(ndas)
report[ImageStatsKeys.CROPPED_SHAPE] = [list(nda_c.shape) for nda_c in nda_croppeds]
report[ImageStatsKeys.SPACING] = (
affine_to_spacing(data[self.image_key].affine).tolist()
if isinstance(data[self.image_key], MetaTensor)
else [1.0] * min(3, data[self.image_key].ndim)
)

report[ImageStatsKeys.SIZEMM] = [
a * b for a, b in zip(report[ImageStatsKeys.SHAPE][0], report[ImageStatsKeys.SPACING])
]
report[ImageStatsKeys.SIZEMM] = [
a * b for a, b in zip(report[ImageStatsKeys.SHAPE][0], report[ImageStatsKeys.SPACING])
]

report[ImageStatsKeys.INTENSITY] = [
self.ops[ImageStatsKeys.INTENSITY].evaluate(nda_c) for nda_c in nda_croppeds
]
report[ImageStatsKeys.INTENSITY] = [
self.ops[ImageStatsKeys.INTENSITY].evaluate(nda_c) for nda_c in nda_croppeds
]

if not verify_report_format(report, self.get_report_format()):
raise RuntimeError(f"report generated by {self.__class__} differs from the report format.")
if not verify_report_format(report, self.get_report_format()):
raise RuntimeError(f"report generated by {self.__class__} differs from the report format.")

d[self.stats_name] = report
d[self.stats_name] = report
finally:
torch.set_grad_enabled(restore_grad_state)

torch.set_grad_enabled(restore_grad_state)
logger.debug(f"Get image stats spent {time.time() - start}")
return d

Expand Down
30 changes: 30 additions & 0 deletions tests/apps/test_auto3dseg.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,6 +539,36 @@ def test_seg_summarizer(self):
assert str(DataStatsKeys.FG_IMAGE_STATS) in report
assert str(DataStatsKeys.LABEL_STATS) in report

def test_image_stats_precomputed_nda_croppeds(self):
    """ImageStats must accept a data dict that already carries "nda_croppeds".

    Regression test: the variable used to be assigned only in the else branch
    of the analyzer, so a pre-populated entry raised UnboundLocalError.
    """
    stats = ImageStats(image_key="image")
    img = torch.rand(1, 10, 10, 10)
    cropped = [np.random.rand(8, 8, 8)]  # simulated pre-cropped foreground
    out = stats({"image": MetaTensor(img), "nda_croppeds": cropped})
    assert "image_stats" in out
    assert verify_report_format(out["image_stats"], stats.get_report_format())

def test_analyzer_grad_state_restored_after_call(self):
    """Verify ImageStats.__call__ restores the grad-enabled state found on entry.

    This test mutates the process-global torch grad mode, so the original
    state is captured up front and restored in a ``finally`` block; otherwise
    a failing assertion would leak ``grad_enabled=False`` into later tests.
    """
    analyzer = ImageStats(image_key="image")
    image = torch.rand(1, 10, 10, 10)
    data = {"image": MetaTensor(image)}

    original_grad_state = torch.is_grad_enabled()
    try:
        # grad enabled before call -> must still be enabled after
        torch.set_grad_enabled(True)
        analyzer(data)
        assert torch.is_grad_enabled(), "grad state was not restored after ImageStats call"

        # grad disabled before call -> must still be disabled after
        torch.set_grad_enabled(False)
        analyzer(data)
        assert not torch.is_grad_enabled(), "grad state was not restored after ImageStats call"
    finally:
        # exception-safe cleanup: never leave a foreign grad mode behind
        torch.set_grad_enabled(original_grad_state)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Make grad-state cleanup exception-safe in the test.

The test mutates global grad mode; restore the original state in finally to prevent cross-test leakage on failures.

Proposed fix
     def test_analyzer_grad_state_restored_after_call(self):
         # Verify that ImageStats.__call__ always restores the grad-enabled state it found
         # on entry, regardless of which state that was.
         analyzer = ImageStats(image_key="image")
         image = torch.rand(1, 10, 10, 10)
         data = {"image": MetaTensor(image)}
-
-        # grad enabled before call → must still be enabled after
-        torch.set_grad_enabled(True)
-        analyzer(data)
-        assert torch.is_grad_enabled(), "grad state was not restored after ImageStats call"
-
-        # grad disabled before call → must still be disabled after
-        torch.set_grad_enabled(False)
-        analyzer(data)
-        assert not torch.is_grad_enabled(), "grad state was not restored after ImageStats call"
-        torch.set_grad_enabled(True)  # restore for subsequent tests
+        original_grad_state = torch.is_grad_enabled()
+        try:
+            # grad enabled before call → must still be enabled after
+            torch.set_grad_enabled(True)
+            analyzer(data)
+            assert torch.is_grad_enabled(), "grad state was not restored after ImageStats call"
+
+            # grad disabled before call → must still be disabled after
+            torch.set_grad_enabled(False)
+            analyzer(data)
+            assert not torch.is_grad_enabled(), "grad state was not restored after ImageStats call"
+        finally:
+            torch.set_grad_enabled(original_grad_state)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/apps/test_auto3dseg.py` around lines 554 - 570, The test
test_analyzer_grad_state_restored_after_call currently mutates global torch grad
mode without guaranteeing restoration on exceptions; wrap the two analyzer(data)
calls in a try/finally: capture the original state with orig =
torch.is_grad_enabled(), set the required state for each subcase with
torch.set_grad_enabled(True/False), call analyzer(data), assert the state, and
in the finally restore torch.set_grad_enabled(orig) so the global grad mode is
always returned (update/remove the trailing torch.set_grad_enabled(True) in
favor of the finally restore).


def tearDown(self) -> None:
self.test_dir.cleanup()

Expand Down
Loading