@@ -216,50 +216,58 @@ def __init__(self, image_key: str, stats_name: str = DataStatsKeys.IMAGE_STATS)
216216 super ().__init__ (stats_name , report_format )
217217 self .update_ops (ImageStatsKeys .INTENSITY , SampleOperations ())
218218
219+ @torch .no_grad ()
219220 def __call__ (self , data ):
220- # Input Validation Addition
221- if not isinstance (data , dict ):
222- raise TypeError (f"Input data must be a dict, but got { type (data ).__name__ } ." )
223- if self .image_key not in data :
224- raise KeyError (f"Key '{ self .image_key } ' not found in input data." )
225- image = data [self .image_key ]
226- if not isinstance (image , (np .ndarray , torch .Tensor , MetaTensor )):
227- raise TypeError (
228- f"Value for '{ self .image_key } ' must be a numpy array, torch.Tensor, or MetaTensor, "
229- f"but got { type (image ).__name__ } ."
230- )
231- if image .ndim < 3 :
232- raise ValueError (
233- f"Image data under '{ self .image_key } ' must have at least 3 dimensions, but got shape { image .shape } ."
234- )
235- # --- End of validation ---
236221 """
237- Callable to execute the pre-defined functions
222+ Callable to execute the pre-defined functions.
238223
239224 Returns:
240225 A dictionary. The dict has the key in self.report_format. The value of
241226 ImageStatsKeys.INTENSITY is in a list format. Each element of the value list
242227 has stats pre-defined by SampleOperations (max, min, ....).
243228
244229 Raises:
245- RuntimeError if the stats report generated is not consistent with the pre-
230+ KeyError: if ``self.image_key`` is not present in the input data.
231+ TypeError: if the input data is not a dictionary, or if the image value is
232+ not a numpy array, torch.Tensor, or MetaTensor.
233+ ValueError: if the image has fewer than 3 dimensions, or if pre-computed
234+ ``nda_croppeds`` is not a list/tuple with one entry per image channel.
235+ RuntimeError: if the stats report generated is not consistent with the pre-
246236 defined report_format.
247237
248238 Note:
249239 The stats operation uses numpy and torch to compute max, min, and other
250240 functions. If the input has nan/inf, the stats results will be nan/inf.
251241
252242 """
243+ if not isinstance (data , dict ):
244+ raise TypeError (f"Input data must be a dict, but got { type (data ).__name__ } ." )
245+ if self .image_key not in data :
246+ raise KeyError (f"Key '{ self .image_key } ' not found in input data." )
247+ image = data [self .image_key ]
248+ if not isinstance (image , (np .ndarray , torch .Tensor , MetaTensor )):
249+ raise TypeError (
250+ f"Value for '{ self .image_key } ' must be a numpy array, torch.Tensor, or MetaTensor, "
251+ f"but got { type (image ).__name__ } ."
252+ )
253+ if image .ndim < 3 :
254+ raise ValueError (
255+ f"Image data under '{ self .image_key } ' must have at least 3 dimensions, but got shape { image .shape } ."
256+ )
257+
253258 d = dict (data )
254259 start = time .time ()
255- restore_grad_state = torch .is_grad_enabled ()
256- torch .set_grad_enabled (False )
257-
258260 ndas = [d [self .image_key ][i ] for i in range (d [self .image_key ].shape [0 ])]
259- if "nda_croppeds" not in d :
261+ if "nda_croppeds" in d :
262+ nda_croppeds = d ["nda_croppeds" ]
263+ if not isinstance (nda_croppeds , (list , tuple )) or len (nda_croppeds ) != len (ndas ):
264+ raise ValueError (
265+ "Pre-computed 'nda_croppeds' must be a list or tuple with one entry per image channel "
266+ f"(expected { len (ndas )} )."
267+ )
268+ else :
260269 nda_croppeds = [get_foreground_image (nda ) for nda in ndas ]
261270
262- # perform calculation
263271 report = deepcopy (self .get_report_format ())
264272
265273 report [ImageStatsKeys .SHAPE ] = [list (nda .shape ) for nda in ndas ]
@@ -284,7 +292,6 @@ def __call__(self, data):
284292
285293 d [self .stats_name ] = report
286294
287- torch .set_grad_enabled (restore_grad_state )
288295 logger .debug (f"Get image stats spent { time .time () - start } " )
289296 return d
290297
@@ -321,6 +328,7 @@ def __init__(self, image_key: str, label_key: str, stats_name: str = DataStatsKe
321328 super ().__init__ (stats_name , report_format )
322329 self .update_ops (ImageStatsKeys .INTENSITY , SampleOperations ())
323330
331+ @torch .no_grad ()
324332 def __call__ (self , data : Mapping ) -> dict :
325333 """
326334 Callable to execute the pre-defined functions
@@ -341,9 +349,6 @@ def __call__(self, data: Mapping) -> dict:
341349
342350 d = dict (data )
343351 start = time .time ()
344- restore_grad_state = torch .is_grad_enabled ()
345- torch .set_grad_enabled (False )
346-
347352 ndas = [d [self .image_key ][i ] for i in range (d [self .image_key ].shape [0 ])]
348353 ndas_label = d [self .label_key ] # (H,W,D)
349354
@@ -353,7 +358,6 @@ def __call__(self, data: Mapping) -> dict:
353358 nda_foregrounds = [get_foreground_label (nda , ndas_label ) for nda in ndas ]
354359 nda_foregrounds = [nda if nda .numel () > 0 else MetaTensor ([0.0 ]) for nda in nda_foregrounds ]
355360
356- # perform calculation
357361 report = deepcopy (self .get_report_format ())
358362
359363 report [ImageStatsKeys .INTENSITY ] = [
@@ -365,7 +369,6 @@ def __call__(self, data: Mapping) -> dict:
365369
366370 d [self .stats_name ] = report
367371
368- torch .set_grad_enabled (restore_grad_state )
369372 logger .debug (f"Get foreground image stats spent { time .time () - start } " )
370373 return d
371374
@@ -418,6 +421,7 @@ def __init__(
418421 id_seq = ID_SEP_KEY .join ([LabelStatsKeys .LABEL , "0" , LabelStatsKeys .IMAGE_INTST ])
419422 self .update_ops_nested_label (id_seq , SampleOperations ())
420423
424+ @torch .no_grad ()
421425 def __call__ (self , data : Mapping [Hashable , MetaTensor ]) -> dict [Hashable , MetaTensor | dict ]:
422426 """
423427 Callable to execute the pre-defined functions.
@@ -470,19 +474,15 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTe
470474 start = time .time ()
471475 image_tensor = d [self .image_key ]
472476 label_tensor = d [self .label_key ]
473- # Check if either tensor is on CUDA to determine if we should move both to CUDA for processing
474477 using_cuda = any (
475478 isinstance (t , (torch .Tensor , MetaTensor )) and t .device .type == "cuda" for t in (image_tensor , label_tensor )
476479 )
477- restore_grad_state = torch .is_grad_enabled ()
478- torch .set_grad_enabled (False )
479480
480481 if isinstance (image_tensor , (MetaTensor , torch .Tensor )) and isinstance (
481482 label_tensor , (MetaTensor , torch .Tensor )
482483 ):
483484 if label_tensor .device != image_tensor .device :
484485 if using_cuda :
485- # Move both tensors to CUDA when mixing devices
486486 cuda_device = image_tensor .device if image_tensor .device .type == "cuda" else label_tensor .device
487487 image_tensor = cast (MetaTensor , image_tensor .to (cuda_device ))
488488 label_tensor = cast (MetaTensor , label_tensor .to (cuda_device ))
@@ -548,7 +548,6 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTe
548548
549549 d [self .stats_name ] = report # type: ignore[assignment]
550550
551- torch .set_grad_enabled (restore_grad_state )
552551 logger .debug (f"Get label stats spent { time .time () - start } " )
553552 return d # type: ignore[return-value]
554553
0 commit comments