Add spatial shape constraint docs and validation for DiNTS (#6771) #8827
Changes from 3 commits
88ce791
c0cc47a
3420a85
214ab53
973662b
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -333,8 +333,16 @@ class DiNTS(nn.Module): | |
| The architecture codes will be initialized as one. | ||
| - ``TopologyConstruction`` is the parent class which constructs the instance and search space. | ||
|
|
||
| To meet the requirements of the structure, the input size for each spatial dimension should be: | ||
| divisible by 2 ** (num_depths + 1). | ||
| Spatial Shape Constraints: | ||
| Each spatial dimension of the input must be divisible by ``2 ** (num_depths + int(use_downsample))``. | ||
|
|
||
| - With ``use_downsample=True`` (default) and ``num_depths=3`` (default): divisible by ``2 ** 4 = 16``. | ||
| - With ``use_downsample=False`` and ``num_depths=3``: divisible by ``2 ** 3 = 8``. | ||
|
|
||
| This requirement arises from the multi-resolution stem downsampling the input ``num_depths`` times | ||
| (each by a factor of 2), plus one additional factor of 2 when ``use_downsample=True``. | ||
|
|
||
| A ``ValueError`` is raised in ``forward()`` if the input spatial shape violates this constraint. | ||
|
|
||
| Args: | ||
| dints_space: DiNTS search space. The value should be instance of `TopologyInstance` or `TopologySearch`. | ||
|
|
@@ -346,6 +354,7 @@ class DiNTS(nn.Module): | |
| use_downsample: use downsample in the stem. | ||
| If ``False``, the search space will be in resolution [1, 1/2, 1/4, 1/8], | ||
| if ``True``, the search space will be in resolution [1/2, 1/4, 1/8, 1/16]. | ||
| Affects the input size divisibility requirement: ``2 ** (num_depths + int(use_downsample))``. | ||
| node_a: node activation numpy matrix. Its shape is `(num_depths, num_blocks + 1)`. | ||
| +1 for multi-resolution inputs. | ||
| In model searching stage, ``node_a`` can be None. In deployment stage, ``node_a`` cannot be None. | ||
|
|
@@ -481,13 +490,40 @@ def __init__( | |
| def weight_parameters(self): | ||
| return [param for name, param in self.named_parameters()] | ||
|
|
||
| @torch.jit.unused | ||
| def _check_input_size(self, spatial_shape): | ||
| """ | ||
| Validate that input spatial dimensions satisfy the divisibility requirement. | ||
|
|
||
| Each spatial dimension must be divisible by ``2 ** (num_depths + int(use_downsample))``. | ||
|
|
||
| Args: | ||
| spatial_shape: spatial dimensions of the input tensor (excluding batch and channel dims). | ||
|
|
||
| Raises: | ||
| ValueError: if any spatial dimension is not divisible by the required factor. | ||
| """ | ||
| factor = 2 ** (self.num_depths + int(self.dints_space.use_downsample)) | ||
| wrong_dims = [i + 2 for i, s in enumerate(spatial_shape) if s % factor != 0] | ||
| if wrong_dims: | ||
| raise ValueError( | ||
| f"spatial dimensions {wrong_dims} of input image (spatial shape: {spatial_shape})" | ||
| f" must be divisible by 2 ** (num_depths + int(use_downsample)) = {factor}." | ||
| ) | ||
|
|
||
| def forward(self, x: torch.Tensor): | ||
| """ | ||
| Prediction based on dynamic arch_code. | ||
|
|
||
| Args: | ||
| x: input tensor. | ||
|
|
||
| Raises: | ||
| ValueError: if any spatial dimension of ``x`` is not divisible by | ||
| ``2 ** (num_depths + int(use_downsample))``. | ||
| """ | ||
| if not torch.jit.is_scripting() and not torch.jit.is_tracing(): | ||
| self._check_input_size(x.shape[2:]) | ||
|
Comment on lines +531 to +532
Contributor
🧩 Analysis chain

🏁 Script executed:
#!/bin/bash
# Check if the 3D cuda test case input shapes are compatible with new validation
rg -n "num_depths.*4" tests/networks/nets/test_dints_network.py -A 20 | head -60
Repository: Project-MONAI/MONAI
Length of output: 1924

🏁 Script executed:
# Find _check_input_size method in dints.py
rg -n "_check_input_size" monai/networks/nets/dints.py -A 10
Repository: Project-MONAI/MONAI
Length of output: 1042

🏁 Script executed:
# Get full _check_input_size implementation
sed -n '500,520p' monai/networks/nets/dints.py
Repository: Project-MONAI/MONAI
Length of output: 993

🏁 Script executed:
# Check if the guard (is_scripting/is_tracing) was in the original code
# Look at the full forward method to see if guard is new
sed -n '520,535p' monai/networks/nets/dints.py
Repository: Project-MONAI/MONAI
Length of output: 597

🏁 Script executed:
# Check how TEST_CASES_3D is used in the test file
rg -n "TEST_CASES_3D" tests/networks/nets/test_dints_network.py -A 5 -B 2
Repository: Project-MONAI/MONAI
Length of output: 1833

🏁 Script executed:
# Check if test case dimensions have any comments explaining them
sed -n '46,68p' tests/networks/nets/test_dints_network.py
Repository: Project-MONAI/MONAI
Length of output: 692

Test case dimensions incompatible with new validation.
The cuda case (lines 47-67 in test file) uses |
||
| inputs = [] | ||
| for d in range(self.num_depths): | ||
| # allow multi-resolution input | ||
|
|
||
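For readers who want to check an input shape before calling the network, the divisibility rule from the docstring can be sketched in plain Python. This is only an illustration under stated assumptions: `required_factor`, `invalid_dims`, and `next_valid_shape` are hypothetical helper names that do not exist in MONAI, and the defaults (`num_depths=3`, `use_downsample=True`) are taken from the docstring text in this diff rather than verified against the library.

```python
from typing import List, Sequence


def required_factor(num_depths: int = 3, use_downsample: bool = True) -> int:
    """Divisor each spatial dimension must satisfy, following the docstring rule."""
    return 2 ** (num_depths + int(use_downsample))


def invalid_dims(spatial_shape: Sequence[int], factor: int) -> List[int]:
    """Tensor dimension indices (offset by 2 for batch and channel) that fail the check."""
    return [i + 2 for i, s in enumerate(spatial_shape) if s % factor != 0]


def next_valid_shape(spatial_shape: Sequence[int], factor: int) -> List[int]:
    """Hypothetical helper: round each dimension up to the nearest multiple of ``factor``."""
    return [-(-s // factor) * factor for s in spatial_shape]


factor = required_factor(num_depths=3, use_downsample=True)
shape = (96, 100, 64)
print(factor)                           # 16
print(invalid_dims(shape, factor))      # [3], since 100 % 16 != 0
print(next_valid_shape(shape, factor))  # [96, 112, 64]
```

A caller could pad or crop to the shape returned by `next_valid_shape` before invoking `forward()`; whether padding or cropping is appropriate depends on the data and is outside the scope of this change.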