Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 36 additions & 2 deletions monai/networks/nets/dints.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,8 +333,16 @@ class DiNTS(nn.Module):
The architecture codes will be initialized as one.
- ``TopologyConstruction`` is the parent class which constructs the instance and search space.

To meet the requirements of the structure, the input size for each spatial dimension should be:
divisible by 2 ** (num_depths + 1).
Spatial Shape Constraints:
Each spatial dimension of the input must be divisible by ``2 ** (num_depths + int(use_downsample))``.

- With ``use_downsample=True`` (default) and ``num_depths=3`` (default): divisible by ``2 ** 4 = 16``.
- With ``use_downsample=False`` and ``num_depths=3``: divisible by ``2 ** 3 = 8``.

This requirement arises from the multi-resolution stem downsampling the input ``num_depths`` times
(each by a factor of 2), plus one additional factor of 2 when ``use_downsample=True``.

A ``ValueError`` is raised in ``forward()`` if the input spatial shape violates this constraint.

Args:
dints_space: DiNTS search space. The value should be instance of `TopologyInstance` or `TopologySearch`.
Expand All @@ -346,6 +354,7 @@ class DiNTS(nn.Module):
use_downsample: use downsample in the stem.
If ``False``, the search space will be in resolution [1, 1/2, 1/4, 1/8],
if ``True``, the search space will be in resolution [1/2, 1/4, 1/8, 1/16].
Affects the input size divisibility requirement: ``2 ** (num_depths + int(use_downsample))``.
node_a: node activation numpy matrix. Its shape is `(num_depths, num_blocks + 1)`.
+1 for multi-resolution inputs.
In model searching stage, ``node_a`` can be None. In deployment stage, ``node_a`` cannot be None.
Expand Down Expand Up @@ -481,13 +490,38 @@ def __init__(
def weight_parameters(self):
    """Return the parameters yielded by ``named_parameters()`` as a list, discarding their names."""
    collected = []
    for _name, tensor in self.named_parameters():
        collected.append(tensor)
    return collected

def _check_input_size(self, spatial_shape) -> None:
    """
    Validate that input spatial dimensions satisfy the divisibility requirement.

    Each spatial dimension must be divisible by ``2 ** (num_depths + int(use_downsample))``,
    where ``num_depths`` is read from ``self.num_depths`` and ``use_downsample`` is read from
    ``self.dints_space.use_downsample``.

    Args:
        spatial_shape: spatial dimensions of the input tensor (excluding batch and channel dims).

    Raises:
        ValueError: if any spatial dimension is not divisible by the required factor. The
            message lists the offending axes as indices into the full tensor shape, i.e. the
            position within ``spatial_shape`` offset by 2 for the batch and channel axes.
    """
    # NOTE(review): the flag is taken from the search space object rather than from the
    # network's own ``use_downsample`` constructor argument -- confirm the two always agree.
    factor = 2 ** (self.num_depths + int(self.dints_space.use_downsample))
    # NOTE(review): the test-suite also has a ``test_script`` case for this network; if this
    # helper ends up being compiled there, confirm the filtered comprehension is accepted.
    # +2 converts a spatial index into an index of the full (batch, channel, *spatial) shape.
    wrong_dims = [i + 2 for i, s in enumerate(spatial_shape) if s % factor != 0]
    if wrong_dims:
        raise ValueError(
            f"spatial dimensions {wrong_dims} of input image (spatial shape: {spatial_shape})"
            f" must be divisible by 2 ** (num_depths + int(use_downsample)) = {factor}."
        )

def forward(self, x: torch.Tensor):
"""
Prediction based on dynamic arch_code.

Args:
x: input tensor.

Raises:
ValueError: if any spatial dimension of ``x`` is not divisible by
``2 ** (num_depths + int(use_downsample))``.
"""
self._check_input_size(x.shape[2:])
inputs = []
for d in range(self.num_depths):
# allow multi-resolution input
Expand Down
18 changes: 18 additions & 0 deletions tests/networks/nets/test_dints_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,24 @@ def test_dints_search(self, dints_grid_params, dints_params, input_shape, expect
self.assertTrue(isinstance(net.weight_parameters(), list))


class TestDintsInputShape(unittest.TestCase):
    """Inputs whose spatial shape breaks the divisibility rule must be rejected with ``ValueError``."""

    def test_invalid_input_shape_3d(self):
        # use_downsample=True with num_depths=3 requires a factor of 2 ** (3 + 1) = 16;
        # the first spatial dimension (33) is not a multiple of 16.
        space = TopologySearch(channel_mul=0.2, num_blocks=6, num_depths=3, use_downsample=True, spatial_dims=3)
        model = DiNTS(dints_space=space, in_channels=1, num_classes=2, use_downsample=True, spatial_dims=3)
        bad_input = torch.randn(1, 1, 33, 32, 32)
        self.assertRaises(ValueError, model, bad_input)

    def test_invalid_input_shape_2d(self):
        # use_downsample=False with num_depths=3 requires a factor of 2 ** (3 + 0) = 8;
        # the first spatial dimension (33) is not a multiple of 8.
        space = TopologySearch(channel_mul=0.2, num_blocks=6, num_depths=3, use_downsample=False, spatial_dims=2)
        model = DiNTS(dints_space=space, in_channels=1, num_classes=2, use_downsample=False, spatial_dims=2)
        bad_input = torch.randn(1, 1, 33, 32)
        self.assertRaises(ValueError, model, bad_input)


class TestDintsTS(unittest.TestCase):
@parameterized.expand(TEST_CASES_3D + TEST_CASES_2D)
def test_script(self, dints_grid_params, dints_params, input_shape, _):
Expand Down
Loading