Add explicit spatial_ndim tracking to MetaTensor #8765

Open: aymuos15 wants to merge 8 commits into Project-MONAI:dev (base: dev) from aymuos15:fix/metatensor-einops
Changes from 6 commits (8 commits total, all by aymuos15):

- 74e7ca4 Add explicit spatial_ndim tracking to MetaTensor (Fixes #6397)
- c52a149 address coderabbit
- ea915cb ci: retrigger CI checks
- e50ae41 Fix 2D inverse transform shape mismatch (4x4 vs 3x3)
- 787eef4 Fix spatial_ndim drift for sliced MetaTensor 2D paths
- 9f359b0 Fix MetaTensor spatial_ndim propagation regressions
- dc16dda Fix 2D inverse transform failures by removing no_channel from spatial…
- ab2be3a Fix get_spatial_ndim for 2D post-EnsureChannelFirst without regressin…
```diff
@@ -13,22 +13,59 @@
 import functools
 import warnings
-from collections.abc import Sequence
+from collections.abc import Mapping, Sequence
 from copy import deepcopy
 from typing import Any

 import numpy as np
 import torch

 import monai
-from monai.config.type_definitions import NdarrayTensor
-from monai.data.meta_obj import MetaObj, get_track_meta
-from monai.data.utils import affine_to_spacing, decollate_batch, list_data_collate, remove_extra_metadata
+from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor
+from monai.data.meta_obj import _DEFAULT_SPATIAL_NDIM, MetaObj, get_track_meta
+from monai.data.utils import affine_to_spacing, decollate_batch, is_no_channel, list_data_collate, remove_extra_metadata
 from monai.utils import look_up_option
 from monai.utils.enums import LazyAttr, MetaKeys, PostFix, SpaceKeys
 from monai.utils.type_conversion import convert_data_type, convert_to_dst_type, convert_to_numpy, convert_to_tensor

-__all__ = ["MetaTensor"]
+__all__ = ["MetaTensor", "get_spatial_ndim"]
+
+
+def _normalize_spatial_ndim(spatial_ndim: int, tensor_ndim: int, no_channel: bool = False) -> int:
+    """Clamp spatial dims to a valid range for the current tensor shape."""
+    limit = max(int(tensor_ndim), 1) if no_channel else max(int(tensor_ndim) - 1, 1)
+    return max(1, min(int(spatial_ndim), limit))
+
+
+def _has_explicit_no_channel(meta: Mapping | None) -> bool:
+    return (
+        isinstance(meta, Mapping)
+        and MetaKeys.ORIGINAL_CHANNEL_DIM in meta
+        and is_no_channel(meta[MetaKeys.ORIGINAL_CHANNEL_DIM])
+    )
+
+
+def get_spatial_ndim(img: NdarrayOrTensor) -> int:
+    """Return the number of spatial dimensions assuming channel-first layout.
+
+    Uses ``MetaTensor.spatial_ndim`` when available, otherwise falls back to
+    ``img.ndim - 1``.
+    """
+    if isinstance(img, MetaTensor):
+        no_channel = _has_explicit_no_channel(img.meta)
+        return _normalize_spatial_ndim(img.spatial_ndim, img.ndim, no_channel=no_channel)
+    return img.ndim - 1
+
+
+def _is_batch_only_index(index: Any) -> bool:
+    """True when indexing pattern selects only the batch axis (e.g., ``x[0]`` or ``x[0, ...]``)."""
+    if isinstance(index, (int, np.integer)):
+        return True
+    if not isinstance(index, Sequence) or not index:
+        return False
+    if not isinstance(index[0], (int, np.integer)):
+        return False
+    return all(i in (slice(None, None, None), Ellipsis, None) for i in index[1:])
+
+
 @functools.lru_cache(None)
```
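The clamping rule in `_normalize_spatial_ndim` above can be exercised in isolation. This is a standalone sketch of the same logic (renamed `normalize_spatial_ndim`, no MONAI imports): with a channel axis present, at most `tensor_ndim - 1` dimensions can be spatial; with an explicit "no channel" marker, all of them can be.

```python
def normalize_spatial_ndim(spatial_ndim: int, tensor_ndim: int, no_channel: bool = False) -> int:
    """Clamp a claimed spatial rank to what the tensor shape can actually hold."""
    # with a channel axis, one dim is reserved; without one, all dims may be spatial
    limit = max(int(tensor_ndim), 1) if no_channel else max(int(tensor_ndim) - 1, 1)
    return max(1, min(int(spatial_ndim), limit))

# channel-first 3D volume (C, H, W, D): 3 spatial dims fit in a 4D tensor
print(normalize_spatial_ndim(3, 4))                   # 3
# a 3D claim on a channel-first 2D image (C, H, W) is clamped to 2
print(normalize_spatial_ndim(3, 3))                   # 2
# with no channel axis, a 2D tensor can carry 2 spatial dims
print(normalize_spatial_ndim(2, 2, no_channel=True))  # 2
```

Note the lower bound: a degenerate claim of 0 still comes back as 1, so downstream code can rely on `spatial_ndim >= 1`.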
```diff
@@ -111,6 +148,7 @@ def __new__(
         meta: dict | None = None,
         applied_operations: list | None = None,
         *args,
+        spatial_ndim: int | None = None,
         **kwargs,
     ) -> MetaTensor:
         _kwargs = {"device": kwargs.pop("device", None), "dtype": kwargs.pop("dtype", None)} if kwargs else {}
@@ -123,6 +161,7 @@ def __init__(
         meta: dict | None = None,
         applied_operations: list | None = None,
         *_args,
+        spatial_ndim: int | None = None,
         **_kwargs,
     ) -> None:
         """
@@ -134,6 +173,8 @@ def __init__(
                 the list is typically maintained by `monai.transforms.TraceableTransform`.
                 See also: :py:class:`monai.transforms.TraceableTransform`
             _args: additional args (currently not in use in this constructor).
+            spatial_ndim: optional number of spatial dimensions. If ``None``, derived
+                from the affine matrix clamped by the tensor shape.
             _kwargs: additional kwargs (currently not in use in this constructor).

         Note:
@@ -158,6 +199,14 @@ def __init__(
             self.affine = self.meta[MetaKeys.AFFINE]
         else:
             self.affine = self.get_default_affine()
+        # Initialize spatial_ndim from affine matrix (source of truth), clamped by tensor shape.
+        # This cached value is kept in sync via the affine setter for hot-path performance.
+        no_channel = _has_explicit_no_channel(self.meta)
+        if spatial_ndim is not None:
+            self.spatial_ndim = _normalize_spatial_ndim(spatial_ndim, self.ndim, no_channel=no_channel)
+        elif self.affine.ndim == 2:
+            self.spatial_ndim = _normalize_spatial_ndim(self.affine.shape[-1] - 1, self.ndim, no_channel=no_channel)

         # applied_operations
         if applied_operations is not None:
             self.applied_operations = applied_operations
@@ -237,6 +286,7 @@ def _handle_batched(cls, ret, idx, metas, func, args, kwargs):
         if func == torch.Tensor.__getitem__:
             if idx > 0 or len(args) < 2 or len(args[0]) < 1:
                 return ret
+            full_idx = args[1]
             batch_idx = args[1][0] if isinstance(args[1], Sequence) else args[1]
             # if using e.g., `batch[:, -1]` or `batch[..., -1]`, then the
             # first element will be `slice(None, None, None)` and `Ellipsis`,
@@ -258,6 +308,8 @@ def _handle_batched(cls, ret, idx, metas, func, args, kwargs):
                 ret_meta.is_batch = False
             if hasattr(ret_meta, "__dict__"):
                 ret.__dict__ = ret_meta.__dict__.copy()
+            if _is_batch_only_index(full_idx):
+                ret.spatial_ndim = _normalize_spatial_ndim(ret.spatial_ndim, ret.ndim, no_channel=False)
         # `unbind` is used for `next(iter(batch))`. Also for `decollate_batch`.
         # But we only want to split the batch if the `unbind` is along the 0th dimension.
         elif func == torch.Tensor.unbind:
```
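The `_is_batch_only_index` helper drives the `_handle_batched` change above: the cached `spatial_ndim` is re-clamped only when the index merely peels off the batch axis, leaving channel and spatial axes intact. This is a standalone sketch of the same predicate (plain `int` in place of `np.integer`, tuple/list in place of the general `Sequence` check):

```python
def is_batch_only_index(index) -> bool:
    """True when an indexing pattern selects only the batch axis,
    e.g. x[0], x[0, ...] or x[0, :, :]."""
    if isinstance(index, int):
        return True
    if not isinstance(index, (tuple, list)) or not index:
        return False
    if not isinstance(index[0], int):
        return False
    # every trailing index must keep its axis whole: full slice, Ellipsis, or newaxis
    return all(i in (slice(None), Ellipsis, None) for i in index[1:])

print(is_batch_only_index(0))                 # True  -> x[0]
print(is_batch_only_index((0, Ellipsis)))     # True  -> x[0, ...]
print(is_batch_only_index((0, slice(None))))  # True  -> x[0, :]
print(is_batch_only_index((0, 1)))            # False -> x[0, 1] drops a spatial axis
print(is_batch_only_index(slice(None)))       # False -> x[:] keeps the batch axis
```

The point of the distinction: `x[0, 1]` removes a non-batch axis, so re-clamping the old rank against the new shape would mask a genuine change, whereas `x[0]` only drops the batch dimension and the cached rank remains meaningful.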
|
```diff
@@ -467,15 +519,40 @@ def affine(self) -> torch.Tensor:

     @affine.setter
     def affine(self, d: NdarrayTensor) -> None:
-        """Set the affine."""
-        self.meta[MetaKeys.AFFINE] = torch.as_tensor(d, device=torch.device("cpu"), dtype=torch.float64)
+        """Set the affine.
+
+        When setting a non-batched affine matrix, automatically synchronizes the cached
+        spatial_ndim attribute to maintain consistency between the affine matrix (source of truth)
+        and the cached spatial dimension count.
+        """
+        a = torch.as_tensor(d, device=torch.device("cpu"), dtype=torch.float64)
+        self.meta[MetaKeys.AFFINE] = a
+        if a.ndim == 2:  # non-batched: sync spatial_ndim from affine (source of truth)
+            no_channel = _has_explicit_no_channel(self.meta)
+            self.spatial_ndim = _normalize_spatial_ndim(a.shape[-1] - 1, self.ndim, no_channel=no_channel)
+
+    @property
+    def spatial_ndim(self) -> int:
+        """Get the number of spatial dimensions.
+
+        This value is cached for hot-path performance and is kept in sync with the affine matrix
+        via the affine setter. The affine matrix is the source of truth for spatial dimensions.
+        """
+        return getattr(self, "_spatial_ndim", _DEFAULT_SPATIAL_NDIM)
+
+    @spatial_ndim.setter
+    def spatial_ndim(self, val: int) -> None:
+        """Set the number of spatial dimensions."""
+        if val < 1:
+            raise ValueError(f"spatial_ndim must be >= 1, got {val}")
+        self._spatial_ndim = val
```
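The affine setter syncs the cache because a homogeneous affine for N spatial dimensions is an (N+1) x (N+1) matrix, so the spatial rank can be read off the matrix width; the "4x4 vs 3x3" mismatch named in the commit history is exactly a 3D affine paired with a 2D image. A standalone sketch of that derivation (hypothetical helper name, plain nested lists instead of tensors):

```python
# homogeneous affines: 3x3 describes 2 spatial dims, 4x4 describes 3
affine_2d = [
    [1.0, 0.0, 0.0],
    [0.0, 1.0, 0.0],
    [0.0, 0.0, 1.0],
]
affine_3d = [[0.0] * 4 for _ in range(4)]

def spatial_ndim_from_affine(affine) -> int:
    """Derive the spatial rank from a square homogeneous affine: columns - 1."""
    return len(affine[0]) - 1

print(spatial_ndim_from_affine(affine_2d))  # 2
print(spatial_ndim_from_affine(affine_3d))  # 3
```

Caching the result as `_spatial_ndim` and refreshing it only in the setter keeps hot-path reads cheap while the affine remains the source of truth, matching the comments in the diff.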
Contributor coderabbitai[bot] commented on lines +544 to +548:

Harden the `spatial_ndim` setter: line 518 only checks `val < 1`, so any other value (e.g. a float or a tensor) would be stored as-is. Proposed fix:

```diff
+from numbers import Integral
 ...
     @spatial_ndim.setter
     def spatial_ndim(self, val: int) -> None:
         """Set the number of spatial dimensions."""
+        if not isinstance(val, Integral):
+            raise TypeError(f"spatial_ndim must be an integer, got {type(val).__name__}")
         if val < 1:
             raise ValueError(f"spatial_ndim must be >= 1, got {val}")
-        self._spatial_ndim = val
+        self._spatial_ndim = _normalize_spatial_ndim(int(val), self.ndim)
```

Ruff (0.15.2): [warning] 519-519: Avoid specifying long messages outside the exception class (TRY003)
```diff
     @property
     def pixdim(self):
         """Get the spacing"""
         if self.is_batch:
-            return [affine_to_spacing(a) for a in self.affine]
-        return affine_to_spacing(self.affine)
+            return [affine_to_spacing(a, r=self.spatial_ndim) for a in self.affine]
+        return affine_to_spacing(self.affine, r=self.spatial_ndim)

     def peek_pending_shape(self):
         """
@@ -490,7 +567,7 @@ def peek_pending_shape(self):

     def peek_pending_affine(self):
         res = self.affine
-        r = len(res) - 1
+        r = res.shape[-1] - 1 if res.ndim >= 2 else self.spatial_ndim
         if r not in (2, 3):
             warnings.warn(f"Only 2d and 3d affine are supported, got {r}d input.")
         for p in self.pending_operations:
@@ -503,8 +580,10 @@ def peek_pending_affine(self):
         return res

     def peek_pending_rank(self):
-        a = self.pending_operations[-1].get(LazyAttr.AFFINE, None) if self.pending_operations else self.affine
-        return 1 if a is None else int(max(1, len(a) - 1))
+        if self.pending_operations:
+            a = self.pending_operations[-1].get(LazyAttr.AFFINE, None)
+            return 1 if a is None else int(max(1, len(a) - 1))
+        return self.spatial_ndim

     def new_empty(self, size, dtype=None, device=None, requires_grad=False):  # type: ignore[override]
         """
```

coderabbitai[bot] marked this conversation as resolved.