Skip to content

Commit 9f359b0

Browse files
committed
Fix MetaTensor spatial_ndim propagation regressions
- Clamp spatial_ndim only for true batch-only indexing
- Handle explicit no-channel metadata when normalizing rank
- Remove SplitDim double-decrement after affine sync
- Align batch-slice tests with batched MetaTensor metadata
- Extract DEFAULT_SPATIAL_NDIM constant to eliminate magic numbers
- Add documentation explaining spatial_ndim caching and affine sync

Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk>
1 parent 787eef4 commit 9f359b0

File tree

7 files changed

+75
-28
lines changed

7 files changed

+75
-28
lines changed

monai/data/meta_obj.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@
2424

2525
_TRACK_META = True
2626

27+
# Default number of spatial dimensions for medical imaging (3D volumetric data)
28+
_DEFAULT_SPATIAL_NDIM = 3
29+
2730
__all__ = ["get_track_meta", "set_track_meta", "MetaObj"]
2831

2932

monai/data/meta_tensor.py

Lines changed: 50 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
import functools
1515
import warnings
16-
from collections.abc import Sequence
16+
from collections.abc import Mapping, Sequence
1717
from copy import deepcopy
1818
from typing import Any
1919

@@ -22,18 +22,27 @@
2222

2323
import monai
2424
from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor
25-
from monai.data.meta_obj import MetaObj, get_track_meta
26-
from monai.data.utils import affine_to_spacing, decollate_batch, list_data_collate, remove_extra_metadata
25+
from monai.data.meta_obj import _DEFAULT_SPATIAL_NDIM, MetaObj, get_track_meta
26+
from monai.data.utils import affine_to_spacing, decollate_batch, is_no_channel, list_data_collate, remove_extra_metadata
2727
from monai.utils import look_up_option
2828
from monai.utils.enums import LazyAttr, MetaKeys, PostFix, SpaceKeys
2929
from monai.utils.type_conversion import convert_data_type, convert_to_dst_type, convert_to_numpy, convert_to_tensor
3030

3131
__all__ = ["MetaTensor", "get_spatial_ndim"]
3232

3333

34-
def _normalize_spatial_ndim(spatial_ndim: int, tensor_ndim: int) -> int:
34+
def _normalize_spatial_ndim(spatial_ndim: int, tensor_ndim: int, no_channel: bool = False) -> int:
3535
"""Clamp spatial dims to a valid range for the current tensor shape."""
36-
return max(1, min(int(spatial_ndim), max(int(tensor_ndim) - 1, 1)))
36+
limit = max(int(tensor_ndim), 1) if no_channel else max(int(tensor_ndim) - 1, 1)
37+
return max(1, min(int(spatial_ndim), limit))
38+
39+
40+
def _has_explicit_no_channel(meta: Mapping | None) -> bool:
41+
return (
42+
isinstance(meta, Mapping)
43+
and MetaKeys.ORIGINAL_CHANNEL_DIM in meta
44+
and is_no_channel(meta[MetaKeys.ORIGINAL_CHANNEL_DIM])
45+
)
3746

3847

3948
def get_spatial_ndim(img: NdarrayOrTensor) -> int:
@@ -43,16 +52,22 @@ def get_spatial_ndim(img: NdarrayOrTensor) -> int:
4352
``img.ndim - 1``.
4453
"""
4554
if isinstance(img, MetaTensor):
46-
inferred = _normalize_spatial_ndim(img.spatial_ndim, img.ndim)
47-
shape_spatial = max(img.ndim - 1, 1)
48-
# For non-batched tensors, preserve explicit higher-rank shape information
49-
# (e.g., invalid 4D spatial inputs should still be reported as rank 4).
50-
if not img.is_batch and shape_spatial > inferred:
51-
return shape_spatial
52-
return inferred
55+
no_channel = _has_explicit_no_channel(img.meta)
56+
return _normalize_spatial_ndim(img.spatial_ndim, img.ndim, no_channel=no_channel)
5357
return img.ndim - 1
5458

5559

60+
def _is_batch_only_index(index: Any) -> bool:
61+
"""True when indexing pattern selects only the batch axis (e.g., ``x[0]`` or ``x[0, ...]``)."""
62+
if isinstance(index, (int, np.integer)):
63+
return True
64+
if not isinstance(index, Sequence) or not index:
65+
return False
66+
if not isinstance(index[0], (int, np.integer)):
67+
return False
68+
return all(i in (slice(None, None, None), Ellipsis, None) for i in index[1:])
69+
70+
5671
@functools.lru_cache(None)
5772
def _get_named_tuple_like_type(func):
5873
if (
@@ -184,11 +199,13 @@ def __init__(
184199
self.affine = self.meta[MetaKeys.AFFINE]
185200
else:
186201
self.affine = self.get_default_affine()
187-
# derive spatial_ndim from affine, clamped by tensor shape
202+
# Initialize spatial_ndim from affine matrix (source of truth), clamped by tensor shape.
203+
# This cached value is kept in sync via the affine setter for hot-path performance.
204+
no_channel = _has_explicit_no_channel(self.meta)
188205
if spatial_ndim is not None:
189-
self.spatial_ndim = _normalize_spatial_ndim(spatial_ndim, self.ndim)
206+
self.spatial_ndim = _normalize_spatial_ndim(spatial_ndim, self.ndim, no_channel=no_channel)
190207
elif self.affine.ndim == 2:
191-
self.spatial_ndim = _normalize_spatial_ndim(self.affine.shape[-1] - 1, self.ndim)
208+
self.spatial_ndim = _normalize_spatial_ndim(self.affine.shape[-1] - 1, self.ndim, no_channel=no_channel)
192209

193210
# applied_operations
194211
if applied_operations is not None:
@@ -254,8 +271,6 @@ def update_meta(rets: Sequence, func, args, kwargs) -> Sequence:
254271
# raise NotImplementedError("torch.cat is not implemented for batch of MetaTensors.")
255272
if is_batch:
256273
ret = MetaTensor._handle_batched(ret, idx, metas, func, args, kwargs)
257-
if func == torch.Tensor.__getitem__:
258-
ret.spatial_ndim = _normalize_spatial_ndim(ret.spatial_ndim, ret.ndim)
259274
out.append(ret)
260275
# if the input was a tuple, then return it as a tuple
261276
return tuple(out) if isinstance(rets, tuple) else out
@@ -271,6 +286,7 @@ def _handle_batched(cls, ret, idx, metas, func, args, kwargs):
271286
if func == torch.Tensor.__getitem__:
272287
if idx > 0 or len(args) < 2 or len(args[0]) < 1:
273288
return ret
289+
full_idx = args[1]
274290
batch_idx = args[1][0] if isinstance(args[1], Sequence) else args[1]
275291
# if using e.g., `batch[:, -1]` or `batch[..., -1]`, then the
276292
# first element will be `slice(None, None, None)` and `Ellipsis`,
@@ -292,6 +308,8 @@ def _handle_batched(cls, ret, idx, metas, func, args, kwargs):
292308
ret_meta.is_batch = False
293309
if hasattr(ret_meta, "__dict__"):
294310
ret.__dict__ = ret_meta.__dict__.copy()
311+
if _is_batch_only_index(full_idx):
312+
ret.spatial_ndim = _normalize_spatial_ndim(ret.spatial_ndim, ret.ndim, no_channel=False)
295313
# `unbind` is used for `next(iter(batch))`. Also for `decollate_batch`.
296314
# But we only want to split the batch if the `unbind` is along the 0th dimension.
297315
elif func == torch.Tensor.unbind:
@@ -501,16 +519,26 @@ def affine(self) -> torch.Tensor:
501519

502520
@affine.setter
503521
def affine(self, d: NdarrayTensor) -> None:
504-
"""Set the affine."""
522+
"""Set the affine.
523+
524+
When setting a non-batched affine matrix, automatically synchronizes the cached
525+
spatial_ndim attribute to maintain consistency between the affine matrix (source of truth)
526+
and the cached spatial dimension count.
527+
"""
505528
a = torch.as_tensor(d, device=torch.device("cpu"), dtype=torch.float64)
506529
self.meta[MetaKeys.AFFINE] = a
507-
if a.ndim == 2: # non-batched: sync spatial_ndim
508-
self.spatial_ndim = _normalize_spatial_ndim(a.shape[-1] - 1, self.ndim)
530+
if a.ndim == 2: # non-batched: sync spatial_ndim from affine (source of truth)
531+
no_channel = _has_explicit_no_channel(self.meta)
532+
self.spatial_ndim = _normalize_spatial_ndim(a.shape[-1] - 1, self.ndim, no_channel=no_channel)
509533

510534
@property
511535
def spatial_ndim(self) -> int:
512-
"""Get the number of spatial dimensions."""
513-
return getattr(self, "_spatial_ndim", 3)
536+
"""Get the number of spatial dimensions.
537+
538+
This value is cached for hot-path performance and is kept in sync with the affine matrix
539+
via the affine setter. The affine matrix is the source of truth for spatial dimensions.
540+
"""
541+
return getattr(self, "_spatial_ndim", _DEFAULT_SPATIAL_NDIM)
514542

515543
@spatial_ndim.setter
516544
def spatial_ndim(self, val: int) -> None:

monai/data/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from torch.utils.data._utils.collate import default_collate
3232

3333
from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor, PathLike
34-
from monai.data.meta_obj import MetaObj
34+
from monai.data.meta_obj import _DEFAULT_SPATIAL_NDIM, MetaObj
3535
from monai.utils import (
3636
MAX_SEED,
3737
BlendMode,
@@ -432,7 +432,7 @@ def collate_meta_tensor_fn(batch, *, collate_fn_map=None):
432432
collated.meta = default_collate(meta_dicts)
433433
collated.applied_operations = [i.applied_operations or TraceKeys.NONE for i in batch]
434434
collated.is_batch = True
435-
collated.spatial_ndim = min(getattr(batch[0], "spatial_ndim", 3), max(collated.ndim - 1, 1))
435+
collated.spatial_ndim = min(getattr(batch[0], "spatial_ndim", _DEFAULT_SPATIAL_NDIM), max(collated.ndim - 1, 1))
436436
return collated
437437

438438

monai/transforms/post/array.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -625,6 +625,10 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
625625
img = convert_to_tensor(img, track_meta=get_track_meta())
626626
img_: torch.Tensor = convert_to_tensor(img, track_meta=False)
627627
spatial_dims = get_spatial_ndim(img)
628+
# Validate actual tensor shape against tracked spatial_ndim
629+
actual_spatial = img_.ndim - 1 # channel-first layout
630+
if actual_spatial != spatial_dims:
631+
spatial_dims = actual_spatial
628632
img_ = img_.unsqueeze(0) # adds a batch dim
629633
if spatial_dims == 2:
630634
kernel = torch.tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]], dtype=torch.float32)

monai/transforms/spatial/array.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1036,6 +1036,9 @@ def inverse_transform(self, data: torch.Tensor, transform) -> torch.Tensor:
10361036
out = convert_to_dst_type(out, dst=data, dtype=out.dtype)[0]
10371037
if isinstance(out, MetaTensor):
10381038
affine = convert_to_tensor(out.peek_pending_affine(), track_meta=False)
1039+
# Use affine matrix shape directly (not spatial_ndim) because the affine may be
1040+
# larger than the spatial dimensions (e.g., 4x4 for 2D data), and we need to match
1041+
# the actual affine matrix rank being composed
10391042
mat = to_affine_nd(len(affine) - 1, transform_t)
10401043
out.affine @= convert_to_dst_type(mat, affine)[0]
10411044
return out
@@ -2352,6 +2355,8 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor:
23522355
out = MetaTensor(out)
23532356
out.meta = data.meta # type: ignore
23542357
affine = convert_data_type(out.peek_pending_affine(), torch.Tensor)[0]
2358+
# Use affine matrix shape directly (not spatial_ndim) to ensure matrix composition compatibility
2359+
# when affine is larger than spatial dimensions (e.g., 4x4 for 2D data)
23552360
xform, *_ = convert_to_dst_type(
23562361
Affine.compute_w_affine(len(affine) - 1, inv_affine, data.shape[1:], orig_size), affine
23572362
)
@@ -2621,6 +2626,8 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor:
26212626
out = MetaTensor(out)
26222627
out.meta = data.meta # type: ignore
26232628
affine = convert_data_type(out.peek_pending_affine(), torch.Tensor)[0]
2629+
# Use affine matrix shape directly (not spatial_ndim) to ensure matrix composition compatibility
2630+
# when affine is larger than spatial dimensions (e.g., 4x4 for 2D data)
26242631
xform, *_ = convert_to_dst_type(
26252632
Affine.compute_w_affine(len(affine) - 1, inv_affine, data.shape[1:], orig_size), affine
26262633
)
@@ -3035,10 +3042,11 @@ def __call__(
30353042
raise ValueError("the spatial size of `img` does not match with the length of `distort_steps`")
30363043

30373044
all_ranges = []
3038-
num_cells = ensure_tuple_rep(self.num_cells, get_spatial_ndim(img))
3045+
_sp = get_spatial_ndim(img)
3046+
num_cells = ensure_tuple_rep(self.num_cells, _sp)
30393047
if isinstance(img, MetaTensor) and img.pending_operations:
30403048
warnings.warn("MetaTensor img has pending operations, transform may return incorrect results.")
3041-
for dim_idx, dim_size in enumerate(img.shape[1:]):
3049+
for dim_idx, dim_size in enumerate(img.shape[1 : 1 + _sp]):
30423050
dim_distort_steps = distort_steps[dim_idx]
30433051
ranges = torch.zeros(dim_size, dtype=torch.float32)
30443052
cell_size = dim_size // num_cells[dim_idx]

monai/transforms/utility/array.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -334,8 +334,6 @@ def __call__(self, img: torch.Tensor) -> list[torch.Tensor]:
334334
shift = torch.eye(ndim, device=out.affine.device, dtype=out.affine.dtype)
335335
shift[dim - 1, -1] = idx
336336
out.affine = out.affine @ shift
337-
if not self.keepdim:
338-
out.spatial_ndim = max(1, out.spatial_ndim - 1)
339337
return outputs
340338

341339

tests/data/meta_tensor/test_spatial_ndim.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,8 @@ def test_lazy_apply_pending_2d(self):
123123

124124
def test_batch_slice_clamps_spatial_ndim(self):
125125
t = MetaTensor(torch.randn(10, 6, 5, 7), affine=torch.eye(4))
126+
t.is_batch = True
127+
t.meta["affine"] = torch.eye(4)[None].repeat(10, 1, 1)
126128
self.assertEqual(t.spatial_ndim, 3)
127129
sliced = t[0]
128130
self.assertEqual(sliced.shape, (6, 5, 7))
@@ -131,12 +133,16 @@ def test_batch_slice_clamps_spatial_ndim(self):
131133

132134
def test_label_to_contour_batch_slice_2d(self):
133135
t = MetaTensor(torch.randint(0, 2, (10, 6, 5, 7)).float(), affine=torch.eye(4))
136+
t.is_batch = True
137+
t.meta["affine"] = torch.eye(4)[None].repeat(10, 1, 1)
134138
sliced = t[0]
135139
out = LabelToContour()(sliced)
136140
self.assertEqual(out.shape, sliced.shape)
137141

138142
def test_rand_zoom_batch_slice_2d(self):
139143
t = MetaTensor(torch.randn(10, 1, 64, 64), affine=torch.eye(4))
144+
t.is_batch = True
145+
t.meta["affine"] = torch.eye(4)[None].repeat(10, 1, 1)
140146
sliced = t[0]
141147
zoom = RandZoom(prob=1.0, min_zoom=0.6, max_zoom=1.2)
142148
zoom.set_random_state(seed=0)

0 commit comments

Comments (0)