Skip to content
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 38 additions & 2 deletions monai/networks/nets/dints.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,8 +333,16 @@ class DiNTS(nn.Module):
The architecture codes will be initialized as one.
- ``TopologyConstruction`` is the parent class which constructs the instance and search space.

To meet the requirements of the structure, the input size for each spatial dimension should be:
divisible by 2 ** (num_depths + 1).
Spatial Shape Constraints:
Each spatial dimension of the input must be divisible by ``2 ** (num_depths + int(use_downsample))``.

- With ``use_downsample=True`` (default) and ``num_depths=3`` (default): divisible by ``2 ** 4 = 16``.
- With ``use_downsample=False`` and ``num_depths=3``: divisible by ``2 ** 3 = 8``.

This requirement arises from the multi-resolution stem downsampling the input ``num_depths`` times
(each by a factor of 2), plus one additional factor of 2 when ``use_downsample=True``.

A ``ValueError`` is raised in ``forward()`` if the input spatial shape violates this constraint.

Args:
dints_space: DiNTS search space. The value should be instance of `TopologyInstance` or `TopologySearch`.
Expand All @@ -346,6 +354,7 @@ class DiNTS(nn.Module):
use_downsample: use downsample in the stem.
If ``False``, the search space will be in resolution [1, 1/2, 1/4, 1/8],
if ``True``, the search space will be in resolution [1/2, 1/4, 1/8, 1/16].
Affects the input size divisibility requirement: ``2 ** (num_depths + int(use_downsample))``.
node_a: node activation numpy matrix. Its shape is `(num_depths, num_blocks + 1)`.
+1 for multi-resolution inputs.
In model searching stage, ``node_a`` can be None. In deployment stage, ``node_a`` cannot be None.
Expand Down Expand Up @@ -481,13 +490,40 @@ def __init__(
def weight_parameters(self):
    """
    Collect the parameters reported by ``self.named_parameters()``.

    Returns:
        a list holding the parameter object of every ``(name, parameter)`` pair,
        in the order the pairs are yielded; the names are discarded.
    """
    return [weight for _, weight in self.named_parameters()]

@torch.jit.unused
def _check_input_size(self, spatial_shape):
    """
    Reject input spatial shapes that break the divisibility requirement.

    The required factor is ``2 ** (num_depths + int(use_downsample))``, where
    ``num_depths`` is read from ``self`` and ``use_downsample`` from ``self.dints_space``.

    Args:
        spatial_shape: spatial dimensions of the input tensor (excluding batch and channel dims).

    Raises:
        ValueError: if any spatial dimension is not divisible by the required factor.
    """
    exponent = self.num_depths + int(self.dints_space.use_downsample)
    factor = 2**exponent

    # Report offending axes as indices into the full tensor shape,
    # hence the offset of 2 for the leading batch and channel dims.
    wrong_dims = []
    for axis, size in enumerate(spatial_shape, start=2):
        if size % factor != 0:
            wrong_dims.append(axis)

    if not wrong_dims:
        return

    raise ValueError(
        f"spatial dimensions {wrong_dims} of input image (spatial shape: {spatial_shape})"
        f" must be divisible by 2 ** (num_depths + int(use_downsample)) = {factor}."
    )

def forward(self, x: torch.Tensor):
"""
Prediction based on dynamic arch_code.

Args:
x: input tensor.

Raises:
ValueError: if any spatial dimension of ``x`` is not divisible by
``2 ** (num_depths + int(use_downsample))``.
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
self._check_input_size(x.shape[2:])
Comment on lines +531 to +532
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Check if the 3D cuda test case input shapes are compatible with new validation
rg -n "num_depths.*4" tests/networks/nets/test_dints_network.py -A 20 | head -60

Repository: Project-MONAI/MONAI

Length of output: 1924


🏁 Script executed:

# Find _check_input_size method in dints.py
rg -n "_check_input_size" monai/networks/nets/dints.py -A 10

Repository: Project-MONAI/MONAI

Length of output: 1042


🏁 Script executed:

# Get full _check_input_size implementation
sed -n '500,520p' monai/networks/nets/dints.py

Repository: Project-MONAI/MONAI

Length of output: 993


🏁 Script executed:

# Check if the guard (is_scripting/is_tracing) was in the original code
# Look at the full forward method to see if guard is new
sed -n '520,535p' monai/networks/nets/dints.py

Repository: Project-MONAI/MONAI

Length of output: 597


🏁 Script executed:

# Check how TEST_CASES_3D is used in the test file
rg -n "TEST_CASES_3D" tests/networks/nets/test_dints_network.py -A 5 -B 2

Repository: Project-MONAI/MONAI

Length of output: 1833


🏁 Script executed:

# Check if test case dimensions have any comments explaining them
sed -n '46,68p' tests/networks/nets/test_dints_network.py

Repository: Project-MONAI/MONAI

Length of output: 692


Test case dimensions incompatible with new validation. The cuda case (lines 47-67 in test file) uses num_depths=4, use_downsample=True with shape (3, 2, 32, 32, 16). The validation requires all spatial dims divisible by 2^(4+1)=32, but 16 % 32 ≠ 0. During normal test execution, the guard does not skip the check, so test_dints_inference and test_dints_search will fail with ValueError. Update the spatial dimension to 32 or replace with a compatible test case.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@monai/networks/nets/dints.py` around lines 531 - 532, The test uses
num_depths=4 and use_downsample=True which triggers DiNTS._check_input_size
(called in monai.networks.nets.dints when not scripting/tracing) and requires
spatial dims divisible by 2^(num_depths+1)=32; update the failing CUDA test
cases (test_dints_inference and test_dints_search) to provide compatible spatial
dimensions (e.g., change input shape from (3, 2, 32, 32, 16) to (3, 2, 32, 32,
32)) or alternatively reduce num_depths/use_downsample so that the existing
spatial size is valid.

inputs = []
for d in range(self.num_depths):
# allow multi-resolution input
Expand Down
22 changes: 20 additions & 2 deletions tests/networks/nets/test_dints_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@
"use_downsample": True,
"spatial_dims": 2,
},
(2, 2, 32, 16),
(2, 2, 32, 16),
(2, 2, 32, 32), # use_downsample=True, num_depths=4 -> factor=32; both dims must be divisible by 32
(2, 2, 32, 32),
]
]
if torch.cuda.is_available():
Expand Down Expand Up @@ -153,6 +153,24 @@ def test_dints_search(self, dints_grid_params, dints_params, input_shape, expect
self.assertTrue(isinstance(net.weight_parameters(), list))


class TestDintsInputShape(unittest.TestCase):
    """Verify that ``DiNTS`` raises ``ValueError`` for indivisible input spatial sizes."""

    def test_invalid_input_shape_3d(self):
        """3D input with a spatial size of 33 is rejected when the required factor is 16."""
        # factor = 2 ** (num_depths + int(use_downsample)) = 2 ** (3 + 1) = 16, and 33 % 16 != 0
        search_space = TopologySearch(
            channel_mul=0.2, num_blocks=6, num_depths=3, use_downsample=True, spatial_dims=3
        )
        model = DiNTS(
            dints_space=search_space, in_channels=1, num_classes=2, use_downsample=True, spatial_dims=3
        )
        bad_input = torch.randn(1, 1, 33, 32, 32)
        with self.assertRaises(ValueError):
            model(bad_input)

    def test_invalid_input_shape_2d(self):
        """2D input with a spatial size of 33 is rejected when the required factor is 8."""
        # factor = 2 ** (num_depths + int(use_downsample)) = 2 ** (3 + 0) = 8, and 33 % 8 != 0
        search_space = TopologySearch(
            channel_mul=0.2, num_blocks=6, num_depths=3, use_downsample=False, spatial_dims=2
        )
        model = DiNTS(
            dints_space=search_space, in_channels=1, num_classes=2, use_downsample=False, spatial_dims=2
        )
        bad_input = torch.randn(1, 1, 33, 32)
        with self.assertRaises(ValueError):
            model(bad_input)


class TestDintsTS(unittest.TestCase):
@parameterized.expand(TEST_CASES_3D + TEST_CASES_2D)
def test_script(self, dints_grid_params, dints_params, input_shape, _):
Expand Down
Loading