5050from monai .transforms .traits import LazyTrait , MultiSampleTrait
5151from monai .transforms .transform import LazyTransform , MapTransform , Randomizable
5252from monai .transforms .utils import is_positive
53- from monai .utils import MAX_SEED , Method , PytorchPadMode , ensure_tuple_rep
53+ from monai .utils import MAX_SEED , Method , PytorchPadMode , TraceKeys , ensure_tuple_rep
5454
5555__all__ = [
5656 "Padd" ,
@@ -431,17 +431,33 @@ class SpatialCropd(Cropd):
431431 - a spatial center and size
432432 - the start and end coordinates of the ROI
433433
434+ ROI parameters (``roi_center``, ``roi_size``, ``roi_start``, ``roi_end``) can also be specified as
435+ string dictionary keys. When a string is provided, the actual coordinate values are read from the
436+ data dictionary at call time. This enables pipelines where coordinates are computed by earlier
437+ transforms (e.g., :py:class:`monai.transforms.TransformPointsWorldToImaged`) and stored in the
438+ data dictionary under the given key.
439+
440+ Example::
441+
442+ from monai.transforms import Compose, TransformPointsWorldToImaged, SpatialCropd
443+
444+ pipeline = Compose([
445+ TransformPointsWorldToImaged(keys="roi_start", refer_keys="image"),
446+ TransformPointsWorldToImaged(keys="roi_end", refer_keys="image"),
447+ SpatialCropd(keys="image", roi_start="roi_start", roi_end="roi_end"),
448+ ])
449+
434450 This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`
435451 for more information.
436452 """
437453
438454 def __init__ (
439455 self ,
440456 keys : KeysCollection ,
441- roi_center : Sequence [int ] | int | None = None ,
442- roi_size : Sequence [int ] | int | None = None ,
443- roi_start : Sequence [int ] | int | None = None ,
444- roi_end : Sequence [int ] | int | None = None ,
457+ roi_center : Sequence [int ] | int | str | None = None ,
458+ roi_size : Sequence [int ] | int | str | None = None ,
459+ roi_start : Sequence [int ] | int | str | None = None ,
460+ roi_end : Sequence [int ] | int | str | None = None ,
445461 roi_slices : Sequence [slice ] | None = None ,
446462 allow_missing_keys : bool = False ,
447463 lazy : bool = False ,
@@ -450,19 +466,92 @@ def __init__(
450466 Args:
451467 keys: keys of the corresponding items to be transformed.
452468 See also: :py:class:`monai.transforms.compose.MapTransform`
453- roi_center: voxel coordinates for center of the crop ROI.
469+ roi_center: voxel coordinates for center of the crop ROI, or a string key to look up
470+ the coordinates from the data dictionary.
454471 roi_size: size of the crop ROI, if a dimension of ROI size is larger than image size,
455- will not crop that dimension of the image.
456- roi_start: voxel coordinates for start of the crop ROI.
472+ will not crop that dimension of the image. Can also be a string key.
473+ roi_start: voxel coordinates for start of the crop ROI, or a string key to look up
474+ the coordinates from the data dictionary.
457475 roi_end: voxel coordinates for end of the crop ROI, if a coordinate is out of image,
458- use the end coordinate of image.
476+ use the end coordinate of image. Can also be a string key.
459477 roi_slices: list of slices for each of the spatial dimensions.
460478 allow_missing_keys: don't raise exception if key is missing.
461479 lazy: a flag to indicate whether this transform should execute lazily or not. Defaults to False.
462480 """
463- cropper = SpatialCrop (roi_center , roi_size , roi_start , roi_end , roi_slices , lazy = lazy )
481+ self ._roi_center = roi_center
482+ self ._roi_size = roi_size
483+ self ._roi_start = roi_start
484+ self ._roi_end = roi_end
485+ self ._roi_slices = roi_slices
486+ self ._has_str_roi = any (isinstance (v , str ) for v in [roi_center , roi_size , roi_start , roi_end ])
487+
488+ if not self ._has_str_roi :
489+ cropper = SpatialCrop (roi_center , roi_size , roi_start , roi_end , roi_slices , lazy = lazy )
490+ else :
491+ # Placeholder cropper for the string-key path. Replaced on self.cropper at
492+ # __call__ time once string keys are resolved from the data dictionary.
493+ cropper = SpatialCrop (roi_start = [0 ], roi_end = [1 ], lazy = lazy )
464494 super ().__init__ (keys , cropper = cropper , allow_missing_keys = allow_missing_keys , lazy = lazy )
465495
496+ @staticmethod
497+ def _resolve_roi_param (val , d ):
498+ """Resolve an ROI parameter: if it's a string, look it up in the data dict and flatten."""
499+ if not isinstance (val , str ):
500+ return val
501+ if val not in d :
502+ raise KeyError (f"ROI key '{ val } ' not found in the data dictionary." )
503+ resolved = d [val ]
504+ # ApplyTransformToPoints outputs tensors of shape (C, N, dims).
505+ # A single coordinate like [142.5, -67.3, 301.8] becomes shape (1, 1, 3).
506+ # Flatten to 1-D and round to integers for compute_slices.
507+ # Uses banker's rounding (torch.round) to avoid systematic bias in spatial coordinates.
508+ if isinstance (resolved , torch .Tensor ):
509+ resolved = torch .round (resolved .flatten ()).to (torch .int64 )
510+ return resolved
511+
512+ @property
513+ def requires_current_data (self ):
514+ return self ._has_str_roi
515+
516+ def __call__ (self , data : Mapping [Hashable , torch .Tensor ], lazy : bool | None = None ) -> dict [Hashable , torch .Tensor ]:
517+ if not self ._has_str_roi :
518+ return super ().__call__ (data , lazy = lazy )
519+
520+ d = dict (data )
521+ roi_center = self ._resolve_roi_param (self ._roi_center , d )
522+ roi_size = self ._resolve_roi_param (self ._roi_size , d )
523+ roi_start = self ._resolve_roi_param (self ._roi_start , d )
524+ roi_end = self ._resolve_roi_param (self ._roi_end , d )
525+
526+ lazy_ = self .lazy if lazy is None else lazy
527+ # Store on self.cropper so that Cropd.inverse() can find the matching
528+ # transform ID via check_transforms_match. This mirrors the pattern
529+ # used by CropForegroundd.
530+ self .cropper = SpatialCrop (
531+ roi_center = roi_center , roi_size = roi_size ,
532+ roi_start = roi_start , roi_end = roi_end ,
533+ roi_slices = self ._roi_slices , lazy = lazy_ ,
534+ )
535+ for key in self .key_iterator (d ):
536+ d [key ] = self .cropper (d [key ], lazy = lazy_ )
537+ return d
538+
539+ def inverse (self , data : Mapping [Hashable , MetaTensor ]) -> dict [Hashable , MetaTensor ]:
540+ if not self ._has_str_roi :
541+ return super ().inverse (data )
542+ # For the string-key path, self.cropper is recreated on each __call__,
543+ # so its id() won't match the one stored in the MetaTensor's transform
544+ # stack. We bypass the ID check and apply the inverse directly using the
545+ # crop info stored in the MetaTensor.
546+ d = dict (data )
547+ for key in self .key_iterator (d ):
548+ transform = self .cropper .pop_transform (d [key ], check = False )
549+ cropped = transform [TraceKeys .EXTRA_INFO ]["cropped" ]
550+ inverse_transform = BorderPad (cropped )
551+ with inverse_transform .trace_transform (False ):
552+ d [key ] = inverse_transform (d [key ])
553+ return d
554+
466555
467556class CenterSpatialCropd (Cropd ):
468557 """
0 commit comments