1313
1414import functools
1515import warnings
16- from collections .abc import Sequence
16+ from collections .abc import Mapping , Sequence
1717from copy import deepcopy
1818from typing import Any
1919
2222
2323import monai
2424from monai .config .type_definitions import NdarrayOrTensor , NdarrayTensor
25- from monai .data .meta_obj import MetaObj , get_track_meta
26- from monai .data .utils import affine_to_spacing , decollate_batch , list_data_collate , remove_extra_metadata
25+ from monai .data .meta_obj import _DEFAULT_SPATIAL_NDIM , MetaObj , get_track_meta
26+ from monai .data .utils import affine_to_spacing , decollate_batch , is_no_channel , list_data_collate , remove_extra_metadata
2727from monai .utils import look_up_option
2828from monai .utils .enums import LazyAttr , MetaKeys , PostFix , SpaceKeys
2929from monai .utils .type_conversion import convert_data_type , convert_to_dst_type , convert_to_numpy , convert_to_tensor
3030
3131__all__ = ["MetaTensor" , "get_spatial_ndim" ]
3232
3333
34- def _normalize_spatial_ndim (spatial_ndim : int , tensor_ndim : int ) -> int :
34+ def _normalize_spatial_ndim (spatial_ndim : int , tensor_ndim : int , no_channel : bool = False ) -> int :
3535 """Clamp spatial dims to a valid range for the current tensor shape."""
36- return max (1 , min (int (spatial_ndim ), max (int (tensor_ndim ) - 1 , 1 )))
36+ limit = max (int (tensor_ndim ), 1 ) if no_channel else max (int (tensor_ndim ) - 1 , 1 )
37+ return max (1 , min (int (spatial_ndim ), limit ))
38+
39+
def _has_explicit_no_channel(meta: Mapping | None) -> bool:
    """Whether ``meta`` explicitly records that the data has no channel dimension.

    Returns False for anything that is not a mapping, or when the
    ``ORIGINAL_CHANNEL_DIM`` key is absent; otherwise defers to
    ``is_no_channel`` on the stored value.
    """
    if not isinstance(meta, Mapping):
        return False
    if MetaKeys.ORIGINAL_CHANNEL_DIM not in meta:
        return False
    return is_no_channel(meta[MetaKeys.ORIGINAL_CHANNEL_DIM])
3746
3847
3948def get_spatial_ndim (img : NdarrayOrTensor ) -> int :
@@ -43,16 +52,22 @@ def get_spatial_ndim(img: NdarrayOrTensor) -> int:
4352 ``img.ndim - 1``.
4453 """
4554 if isinstance (img , MetaTensor ):
46- inferred = _normalize_spatial_ndim (img .spatial_ndim , img .ndim )
47- shape_spatial = max (img .ndim - 1 , 1 )
48- # For non-batched tensors, preserve explicit higher-rank shape information
49- # (e.g., invalid 4D spatial inputs should still be reported as rank 4).
50- if not img .is_batch and shape_spatial > inferred :
51- return shape_spatial
52- return inferred
55+ no_channel = _has_explicit_no_channel (img .meta )
56+ return _normalize_spatial_ndim (img .spatial_ndim , img .ndim , no_channel = no_channel )
5357 return img .ndim - 1
5458
5559
60+ def _is_batch_only_index (index : Any ) -> bool :
61+ """True when indexing pattern selects only the batch axis (e.g., ``x[0]`` or ``x[0, ...]``)."""
62+ if isinstance (index , (int , np .integer )):
63+ return True
64+ if not isinstance (index , Sequence ) or not index :
65+ return False
66+ if not isinstance (index [0 ], (int , np .integer )):
67+ return False
68+ return all (i in (slice (None , None , None ), Ellipsis , None ) for i in index [1 :])
69+
70+
5671@functools .lru_cache (None )
5772def _get_named_tuple_like_type (func ):
5873 if (
@@ -184,11 +199,13 @@ def __init__(
184199 self .affine = self .meta [MetaKeys .AFFINE ]
185200 else :
186201 self .affine = self .get_default_affine ()
187- # derive spatial_ndim from affine, clamped by tensor shape
202+ # Initialize spatial_ndim from affine matrix (source of truth), clamped by tensor shape.
203+ # This cached value is kept in sync via the affine setter for hot-path performance.
204+ no_channel = _has_explicit_no_channel (self .meta )
188205 if spatial_ndim is not None :
189- self .spatial_ndim = _normalize_spatial_ndim (spatial_ndim , self .ndim )
206+ self .spatial_ndim = _normalize_spatial_ndim (spatial_ndim , self .ndim , no_channel = no_channel )
190207 elif self .affine .ndim == 2 :
191- self .spatial_ndim = _normalize_spatial_ndim (self .affine .shape [- 1 ] - 1 , self .ndim )
208+ self .spatial_ndim = _normalize_spatial_ndim (self .affine .shape [- 1 ] - 1 , self .ndim , no_channel = no_channel )
192209
193210 # applied_operations
194211 if applied_operations is not None :
@@ -254,8 +271,6 @@ def update_meta(rets: Sequence, func, args, kwargs) -> Sequence:
254271 # raise NotImplementedError("torch.cat is not implemented for batch of MetaTensors.")
255272 if is_batch :
256273 ret = MetaTensor ._handle_batched (ret , idx , metas , func , args , kwargs )
257- if func == torch .Tensor .__getitem__ :
258- ret .spatial_ndim = _normalize_spatial_ndim (ret .spatial_ndim , ret .ndim )
259274 out .append (ret )
260275 # if the input was a tuple, then return it as a tuple
261276 return tuple (out ) if isinstance (rets , tuple ) else out
@@ -271,6 +286,7 @@ def _handle_batched(cls, ret, idx, metas, func, args, kwargs):
271286 if func == torch .Tensor .__getitem__ :
272287 if idx > 0 or len (args ) < 2 or len (args [0 ]) < 1 :
273288 return ret
289+ full_idx = args [1 ]
274290 batch_idx = args [1 ][0 ] if isinstance (args [1 ], Sequence ) else args [1 ]
275291 # if using e.g., `batch[:, -1]` or `batch[..., -1]`, then the
276292 # first element will be `slice(None, None, None)` and `Ellipsis`,
@@ -292,6 +308,8 @@ def _handle_batched(cls, ret, idx, metas, func, args, kwargs):
292308 ret_meta .is_batch = False
293309 if hasattr (ret_meta , "__dict__" ):
294310 ret .__dict__ = ret_meta .__dict__ .copy ()
311+ if _is_batch_only_index (full_idx ):
312+ ret .spatial_ndim = _normalize_spatial_ndim (ret .spatial_ndim , ret .ndim , no_channel = False )
295313 # `unbind` is used for `next(iter(batch))`. Also for `decollate_batch`.
296314 # But we only want to split the batch if the `unbind` is along the 0th dimension.
297315 elif func == torch .Tensor .unbind :
@@ -501,16 +519,26 @@ def affine(self) -> torch.Tensor:
501519
502520 @affine .setter
503521 def affine (self , d : NdarrayTensor ) -> None :
504- """Set the affine."""
522+ """Set the affine.
523+
524+ When setting a non-batched affine matrix, automatically synchronizes the cached
525+ spatial_ndim attribute to maintain consistency between the affine matrix (source of truth)
526+ and the cached spatial dimension count.
527+ """
505528 a = torch .as_tensor (d , device = torch .device ("cpu" ), dtype = torch .float64 )
506529 self .meta [MetaKeys .AFFINE ] = a
507- if a .ndim == 2 : # non-batched: sync spatial_ndim
508- self .spatial_ndim = _normalize_spatial_ndim (a .shape [- 1 ] - 1 , self .ndim )
530+ if a .ndim == 2 : # non-batched: sync spatial_ndim from affine (source of truth)
531+ no_channel = _has_explicit_no_channel (self .meta )
532+ self .spatial_ndim = _normalize_spatial_ndim (a .shape [- 1 ] - 1 , self .ndim , no_channel = no_channel )
509533
510534 @property
511535 def spatial_ndim (self ) -> int :
512- """Get the number of spatial dimensions."""
513- return getattr (self , "_spatial_ndim" , 3 )
536+ """Get the number of spatial dimensions.
537+
538+ This value is cached for hot-path performance and is kept in sync with the affine matrix
539+ via the affine setter. The affine matrix is the source of truth for spatial dimensions.
540+ """
541+ return getattr (self , "_spatial_ndim" , _DEFAULT_SPATIAL_NDIM )
514542
515543 @spatial_ndim .setter
516544 def spatial_ndim (self , val : int ) -> None :
0 commit comments