Skip to content

Commit a900675

Browse files
committed
Enable global coordinates in spatial crop transforms
Add convenience transforms for converting points between world and image coordinate spaces, and extend SpatialCropd to accept string dictionary keys for ROI parameters, enabling deferred coordinate resolution at call time.

New transforms:
- TransformPointsWorldToImaged: world-to-image coordinate conversion
- TransformPointsImageToWorldd: image-to-world coordinate conversion

SpatialCropd changes:
- roi_center, roi_size, roi_start, roi_end now accept string keys
- When strings are provided, coordinates are resolved from the data dictionary at __call__ time (zero overhead for existing usage)
- Tensors from ApplyTransformToPoints are automatically flattened and rounded to integers
1 parent b3fff92 commit a900675

File tree

6 files changed

+515
-9
lines changed

6 files changed

+515
-9
lines changed

monai/transforms/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -562,6 +562,12 @@
562562
ApplyTransformToPointsd,
563563
ApplyTransformToPointsD,
564564
ApplyTransformToPointsDict,
565+
TransformPointsWorldToImaged,
566+
TransformPointsWorldToImageD,
567+
TransformPointsWorldToImageDict,
568+
TransformPointsImageToWorldd,
569+
TransformPointsImageToWorldD,
570+
TransformPointsImageToWorldDict,
565571
AsChannelLastd,
566572
AsChannelLastD,
567573
AsChannelLastDict,

monai/transforms/croppad/dictionary.py

Lines changed: 81 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -431,17 +431,33 @@ class SpatialCropd(Cropd):
431431
- a spatial center and size
432432
- the start and end coordinates of the ROI
433433
434+
ROI parameters (``roi_center``, ``roi_size``, ``roi_start``, ``roi_end``) can also be specified as
435+
string dictionary keys. When a string is provided, the actual coordinate values are read from the
436+
data dictionary at call time. This enables pipelines where coordinates are computed by earlier
437+
transforms (e.g., :py:class:`monai.transforms.TransformPointsWorldToImaged`) and stored in the
438+
data dictionary under the given key.
439+
440+
Example::
441+
442+
from monai.transforms import Compose, TransformPointsWorldToImaged, SpatialCropd
443+
444+
pipeline = Compose([
445+
TransformPointsWorldToImaged(keys="roi_start", refer_keys="image"),
446+
TransformPointsWorldToImaged(keys="roi_end", refer_keys="image"),
447+
SpatialCropd(keys="image", roi_start="roi_start", roi_end="roi_end"),
448+
])
449+
434450
This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`
435451
for more information.
436452
"""
437453

438454
def __init__(
439455
self,
440456
keys: KeysCollection,
441-
roi_center: Sequence[int] | int | None = None,
442-
roi_size: Sequence[int] | int | None = None,
443-
roi_start: Sequence[int] | int | None = None,
444-
roi_end: Sequence[int] | int | None = None,
457+
roi_center: Sequence[int] | int | str | None = None,
458+
roi_size: Sequence[int] | int | str | None = None,
459+
roi_start: Sequence[int] | int | str | None = None,
460+
roi_end: Sequence[int] | int | str | None = None,
445461
roi_slices: Sequence[slice] | None = None,
446462
allow_missing_keys: bool = False,
447463
lazy: bool = False,
@@ -450,19 +466,75 @@ def __init__(
450466
Args:
451467
keys: keys of the corresponding items to be transformed.
452468
See also: :py:class:`monai.transforms.compose.MapTransform`
453-
roi_center: voxel coordinates for center of the crop ROI.
469+
roi_center: voxel coordinates for center of the crop ROI, or a string key to look up
470+
the coordinates from the data dictionary.
454471
roi_size: size of the crop ROI, if a dimension of ROI size is larger than image size,
455-
will not crop that dimension of the image.
456-
roi_start: voxel coordinates for start of the crop ROI.
472+
will not crop that dimension of the image. Can also be a string key.
473+
roi_start: voxel coordinates for start of the crop ROI, or a string key to look up
474+
the coordinates from the data dictionary.
457475
roi_end: voxel coordinates for end of the crop ROI, if a coordinate is out of image,
458-
use the end coordinate of image.
476+
use the end coordinate of image. Can also be a string key.
459477
roi_slices: list of slices for each of the spatial dimensions.
460478
allow_missing_keys: don't raise exception if key is missing.
461479
lazy: a flag to indicate whether this transform should execute lazily or not. Defaults to False.
462480
"""
463-
cropper = SpatialCrop(roi_center, roi_size, roi_start, roi_end, roi_slices, lazy=lazy)
481+
self._roi_center = roi_center
482+
self._roi_size = roi_size
483+
self._roi_start = roi_start
484+
self._roi_end = roi_end
485+
self._roi_slices = roi_slices
486+
self._has_str_roi = any(isinstance(v, str) for v in [roi_center, roi_size, roi_start, roi_end])
487+
488+
if not self._has_str_roi:
489+
cropper = SpatialCrop(roi_center, roi_size, roi_start, roi_end, roi_slices, lazy=lazy)
490+
else:
491+
# Placeholder cropper for the string-key path. Actual cropping is done with a
492+
# local SpatialCrop created at __call__ time once string keys are resolved.
493+
# Crop.inverse() reads crop info from the MetaTensor's transform stack (not from
494+
# cropper state), so the placeholder still works correctly for inverse operations.
495+
cropper = SpatialCrop(roi_start=[0], roi_end=[1], lazy=lazy)
464496
super().__init__(keys, cropper=cropper, allow_missing_keys=allow_missing_keys, lazy=lazy)
465497

498+
@staticmethod
499+
def _resolve_roi_param(val, d):
500+
"""Resolve an ROI parameter: if it's a string, look it up in the data dict and flatten."""
501+
if not isinstance(val, str):
502+
return val
503+
if val not in d:
504+
raise KeyError(f"ROI key '{val}' not found in the data dictionary.")
505+
resolved = d[val]
506+
# ApplyTransformToPoints outputs tensors of shape (C, N, dims).
507+
# A single coordinate like [142.5, -67.3, 301.8] becomes shape (1, 1, 3).
508+
# Flatten to 1-D and round to integers for compute_slices.
509+
# Uses banker's rounding (torch.round) to avoid systematic bias in spatial coordinates.
510+
if isinstance(resolved, torch.Tensor):
511+
resolved = torch.round(resolved.flatten()).to(torch.int64)
512+
return resolved
513+
514+
@property
515+
def requires_current_data(self):
516+
return self._has_str_roi
517+
518+
def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None) -> dict[Hashable, torch.Tensor]:
519+
if not self._has_str_roi:
520+
return super().__call__(data, lazy=lazy)
521+
522+
d = dict(data)
523+
roi_center = self._resolve_roi_param(self._roi_center, d)
524+
roi_size = self._resolve_roi_param(self._roi_size, d)
525+
roi_start = self._resolve_roi_param(self._roi_start, d)
526+
roi_end = self._resolve_roi_param(self._roi_end, d)
527+
528+
lazy_ = self.lazy if lazy is None else lazy
529+
cropper = SpatialCrop(
530+
roi_center=roi_center, roi_size=roi_size,
531+
roi_start=roi_start, roi_end=roi_end,
532+
roi_slices=self._roi_slices, lazy=lazy_,
533+
)
534+
for key in self.key_iterator(d):
535+
d[key] = cropper(d[key], lazy=lazy_)
536+
return d
537+
466538

467539
class CenterSpatialCropd(Cropd):
468540
"""

monai/transforms/utility/dictionary.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,12 @@
192192
"ApplyTransformToPointsd",
193193
"ApplyTransformToPointsD",
194194
"ApplyTransformToPointsDict",
195+
"TransformPointsWorldToImaged",
196+
"TransformPointsWorldToImageD",
197+
"TransformPointsWorldToImageDict",
198+
"TransformPointsImageToWorldd",
199+
"TransformPointsImageToWorldD",
200+
"TransformPointsImageToWorldDict",
195201
"FlattenSequenced",
196202
"FlattenSequenceD",
197203
"FlattenSequenceDict",
@@ -1910,6 +1916,86 @@ def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch
19101916
return d
19111917

19121918

1919+
class TransformPointsWorldToImaged(ApplyTransformToPointsd):
1920+
"""
1921+
Dictionary-based transform to convert points from world coordinates to image coordinates.
1922+
1923+
This is a convenience subclass of :py:class:`monai.transforms.ApplyTransformToPointsd` with
1924+
``invert_affine=True``, which transforms world-space coordinates into the coordinate space of a
1925+
reference image by inverting the image's affine matrix.
1926+
1927+
Args:
1928+
keys: keys of the corresponding items to be transformed.
1929+
See also: monai.transforms.MapTransform
1930+
refer_keys: The key of the reference image used to derive the affine transformation.
1931+
This is required because the affine must come from a reference image.
1932+
It can also be a sequence of keys, in which case each refers to the affine applied
1933+
to the matching points in ``keys``.
1934+
dtype: The desired data type for the output.
1935+
affine_lps_to_ras: Defaults to ``False``. Set to ``True`` if your point data is in the RAS
1936+
coordinate system or you're using ``ITKReader`` with ``affine_lps_to_ras=True``.
1937+
allow_missing_keys: Don't raise exception if key is missing.
1938+
"""
1939+
1940+
def __init__(
1941+
self,
1942+
keys: KeysCollection,
1943+
refer_keys: KeysCollection,
1944+
dtype: DtypeLike | torch.dtype = torch.float64,
1945+
affine_lps_to_ras: bool = False,
1946+
allow_missing_keys: bool = False,
1947+
):
1948+
super().__init__(
1949+
keys=keys,
1950+
refer_keys=refer_keys,
1951+
dtype=dtype,
1952+
affine=None,
1953+
invert_affine=True,
1954+
affine_lps_to_ras=affine_lps_to_ras,
1955+
allow_missing_keys=allow_missing_keys,
1956+
)
1957+
1958+
1959+
class TransformPointsImageToWorldd(ApplyTransformToPointsd):
1960+
"""
1961+
Dictionary-based transform to convert points from image coordinates to world coordinates.
1962+
1963+
This is a convenience subclass of :py:class:`monai.transforms.ApplyTransformToPointsd` with
1964+
``invert_affine=False``, which transforms image-space coordinates into world-space coordinates
1965+
by applying the reference image's affine matrix directly.
1966+
1967+
Args:
1968+
keys: keys of the corresponding items to be transformed.
1969+
See also: monai.transforms.MapTransform
1970+
refer_keys: The key of the reference image used to derive the affine transformation.
1971+
This is required because the affine must come from a reference image.
1972+
It can also be a sequence of keys, in which case each refers to the affine applied
1973+
to the matching points in ``keys``.
1974+
dtype: The desired data type for the output.
1975+
affine_lps_to_ras: Defaults to ``False``. Set to ``True`` if your point data is in the RAS
1976+
coordinate system or you're using ``ITKReader`` with ``affine_lps_to_ras=True``.
1977+
allow_missing_keys: Don't raise exception if key is missing.
1978+
"""
1979+
1980+
def __init__(
1981+
self,
1982+
keys: KeysCollection,
1983+
refer_keys: KeysCollection,
1984+
dtype: DtypeLike | torch.dtype = torch.float64,
1985+
affine_lps_to_ras: bool = False,
1986+
allow_missing_keys: bool = False,
1987+
):
1988+
super().__init__(
1989+
keys=keys,
1990+
refer_keys=refer_keys,
1991+
dtype=dtype,
1992+
affine=None,
1993+
invert_affine=False,
1994+
affine_lps_to_ras=affine_lps_to_ras,
1995+
allow_missing_keys=allow_missing_keys,
1996+
)
1997+
1998+
19131999
class FlattenSequenced(MapTransform, ReduceTrait):
19142000
"""
19152001
Dictionary-based wrapper of :py:class:`monai.transforms.FlattenSequence`.
@@ -1975,4 +2061,6 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N
19752061
AddCoordinateChannelsD = AddCoordinateChannelsDict = AddCoordinateChannelsd
19762062
FlattenSubKeysD = FlattenSubKeysDict = FlattenSubKeysd
19772063
ApplyTransformToPointsD = ApplyTransformToPointsDict = ApplyTransformToPointsd
2064+
TransformPointsWorldToImageD = TransformPointsWorldToImageDict = TransformPointsWorldToImaged
2065+
TransformPointsImageToWorldD = TransformPointsImageToWorldDict = TransformPointsImageToWorldd
19782066
FlattenSequenceD = FlattenSequenceDict = FlattenSequenced

0 commit comments

Comments (0)