-
Notifications
You must be signed in to change notification settings - Fork 1.5k
Expand file tree
/
Copy pathinverse.py
More file actions
445 lines (382 loc) · 20 KB
/
inverse.py
File metadata and controls
445 lines (382 loc) · 20 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import threading
import warnings
from collections.abc import Hashable, Mapping
from contextlib import contextmanager
from typing import Any
import torch
from monai import transforms
from monai.data.meta_obj import MetaObj, get_track_meta
from monai.data.meta_tensor import MetaTensor
from monai.data.utils import affine_to_spacing, to_affine_nd
from monai.transforms.traits import InvertibleTrait
from monai.transforms.transform import Transform
from monai.utils import (
LazyAttr,
MetaKeys,
TraceKeys,
TraceStatusKeys,
convert_to_dst_type,
convert_to_numpy,
convert_to_tensor,
)
from monai.utils.misc import MONAIEnvVars
__all__ = ["TraceableTransform", "InvertibleTransform"]
class TraceableTransform(Transform):
    """
    Maintains a stack of applied transforms to data.

    Data can be one of two types:
        1. A `MetaTensor` (this is the preferred data type).
        2. A dictionary of data containing arrays/tensors and auxiliary metadata. In
           this case, a key must be supplied (this dictionary-based approach is deprecated).

    If `data` is of type `MetaTensor`, then the applied transform will be added to ``data.applied_operations``.

    If `data` is a dictionary, then one of two things can happen:
        1. If data[key] is a `MetaTensor`, the applied transform will be added to ``data[key].applied_operations``.
        2. Else, the applied transform will be appended to an adjacent list using
           `trace_key`. If, for example, the key is `image`, then the transform
           will be appended to `image_transforms` (this dictionary-based approach is deprecated).

    Hopefully it is clear that there are three total possibilities:
        1. data is `MetaTensor`
        2. data is dictionary, data[key] is `MetaTensor`
        3. data is dictionary, data[key] is not `MetaTensor` (this is a deprecated approach).

    The ``__call__`` method of this transform class must be implemented so
    that the transformation information is stored during the data transformation.

    The information in the stack of applied transforms must be compatible with the
    default collate, by only storing strings, numbers and arrays.

    `tracing` could be enabled by assigning to `self.tracing` or setting
    `MONAI_TRACE_TRANSFORM` when initializing the class.
    """

    def _init_trace_threadlocal(self):
        """Create a `_tracing` instance member to store the thread-local tracing state value."""
        # needed since this class is meant to be a trait with no constructor
        if not hasattr(self, "_tracing"):
            self._tracing = threading.local()
        # the first check above is True only in the thread that creates `_tracing`;
        # this second check initialises the per-thread `value` for every other thread.
        if not hasattr(self._tracing, "value"):
            self._tracing.value = MONAIEnvVars.trace_transform() != "0"

    def __getstate__(self):
        """When pickling, remove the `_tracing` member from the output, if present, since it's not picklable."""
        _dict = dict(getattr(self, "__dict__", {}))  # this makes __dict__ always present in the unpickled object
        _slots = {k: getattr(self, k) for k in getattr(self, "__slots__", [])}
        _dict.pop("_tracing", None)  # remove tracing (threading.local is unpicklable)
        return _dict if len(_slots) == 0 else (_dict, _slots)

    @property
    def tracing(self) -> bool:
        """
        Returns the tracing state, which is thread-local and initialised to `MONAIEnvVars.trace_transform() != "0"`.
        """
        self._init_trace_threadlocal()
        return bool(self._tracing.value)

    @tracing.setter
    def tracing(self, val: bool):
        """Sets the thread-local tracing state to `val`."""
        self._init_trace_threadlocal()
        self._tracing.value = val

    @staticmethod
    def trace_key(key: Hashable = None):
        """The key to store the stack of applied transforms."""
        if key is None:
            return f"{TraceKeys.KEY_SUFFIX}"
        return f"{key}{TraceKeys.KEY_SUFFIX}"

    @staticmethod
    def transform_info_keys():
        """The keys to store necessary info of an applied transform."""
        return (TraceKeys.CLASS_NAME, TraceKeys.ID, TraceKeys.TRACING, TraceKeys.DO_TRANSFORM)

    def get_transform_info(self) -> dict:
        """
        Return a dictionary with the relevant information pertaining to an applied transform.
        """
        vals = (
            self.__class__.__name__,
            id(self),
            self.tracing,
            # transforms without a `_do_transform` flag (non-randomizable) are treated as always applied
            self._do_transform if hasattr(self, "_do_transform") else True,
        )
        return dict(zip(self.transform_info_keys(), vals))

    def push_transform(self, data, *args, **kwargs):
        """
        Push to a stack of applied transforms of ``data``.

        Args:
            data: dictionary of data or `MetaTensor`.
            args: additional positional arguments to track_transform_meta.
            kwargs: additional keyword arguments to track_transform_meta,
                set ``replace=True`` (default False) to rewrite the last transform info in
                applied_operation/pending_operation based on ``self.get_transform_info()``.
        """
        lazy_eval = kwargs.get("lazy", False)
        transform_info = self.get_transform_info()
        do_transform = transform_info.get(TraceKeys.DO_TRANSFORM, True)
        kwargs = kwargs or {}
        replace = kwargs.pop("replace", False)  # whether to rewrite the most recently pushed transform info
        if replace and get_track_meta() and isinstance(data, MetaTensor):
            if not lazy_eval:
                # re-push the popped entry so it carries this transform's up-to-date info
                xform = self.pop_transform(data, check=False) if do_transform else {}
                meta_obj = self.push_transform(data, orig_size=xform.get(TraceKeys.ORIG_SIZE), extra_info=xform)
                return data.copy_meta_from(meta_obj)
            if do_transform:
                xform = data.pending_operations.pop()
                extra = xform.copy()
                xform.update(transform_info)
            else:  # lazy, replace=True, do_transform=False
                xform, extra = transform_info, {}
            meta_obj = self.push_transform(data, transform_info=xform, lazy=True, extra_info=extra)
            return data.copy_meta_from(meta_obj)
        kwargs["lazy"] = lazy_eval
        if "transform_info" in kwargs and isinstance(kwargs["transform_info"], dict):
            kwargs["transform_info"].update(transform_info)
        else:
            kwargs["transform_info"] = transform_info
        meta_obj = TraceableTransform.track_transform_meta(data, *args, **kwargs)
        return data.copy_meta_from(meta_obj) if isinstance(data, MetaTensor) else data

    @classmethod
    def track_transform_meta(
        cls,
        data,
        key: Hashable = None,
        sp_size=None,
        affine=None,
        extra_info: dict | None = None,
        orig_size: tuple | None = None,
        transform_info=None,
        lazy=False,
    ):
        """
        Update a stack of applied/pending transforms metadata of ``data``.

        Args:
            data: dictionary of data or `MetaTensor`.
            key: if data is a dictionary, data[key] will be modified.
            sp_size: the expected output spatial size when the transform is applied.
                it can be tensor or numpy, but will be converted to a list of integers.
            affine: the affine representation of the (spatial) transform in the image space.
                When the transform is applied, meta_tensor.affine will be updated to ``meta_tensor.affine @ affine``.
            extra_info: if desired, any extra information pertaining to the applied
                transform can be stored in this dictionary. These are often needed for
                computing the inverse transformation.
            orig_size: sometimes during the inverse it is useful to know what the size
                of the original image was, in which case it can be supplied here.
            transform_info: info from self.get_transform_info().
            lazy: whether to push the transform to pending_operations or applied_operations.

        Returns:
            For backward compatibility, if ``data`` is a dictionary, it returns the dictionary with
            updated ``data[key]``. Otherwise, this function returns a MetaObj with updated transform metadata.
        """
        data_t = data[key] if key is not None else data  # compatible with the dict data representation
        out_obj = MetaObj()
        # after deprecating metadict, we should always convert data_t to metatensor here
        if isinstance(data_t, MetaTensor):
            out_obj.copy_meta_from(data_t, keys=out_obj.__dict__.keys())

        if lazy and (not get_track_meta()):
            warnings.warn("metadata is not tracked, please call 'set_track_meta(True)' if doing lazy evaluation.")

        if not lazy and affine is not None and isinstance(data_t, MetaTensor):
            # not lazy evaluation, directly update the metatensor affine (don't push to the stack)
            orig_affine = data_t.peek_pending_affine()
            orig_affine = convert_to_dst_type(orig_affine, affine, dtype=torch.float64)[0]
            try:
                affine = orig_affine @ to_affine_nd(orig_affine.shape[-1] - 1, affine, dtype=torch.float64)
            except RuntimeError as e:
                # a matmul failure with ndim > 2 suggests a batch dimension leaked into the affine
                if orig_affine.ndim > 2:
                    if data_t.is_batch:
                        msg = "Transform applied to batched tensor, should be applied to instances only"
                    else:
                        msg = "Mismatch affine matrix, ensured that the batch dimension is not included in the calculation."
                    raise RuntimeError(msg) from e
                else:
                    raise
            out_obj.meta[MetaKeys.AFFINE] = convert_to_tensor(affine, device=torch.device("cpu"), dtype=torch.float64)
            if MetaKeys.PIXDIM in out_obj.meta:
                # keep pixdim metadata consistent with the freshly-updated affine
                spacing = affine_to_spacing(out_obj.meta[MetaKeys.AFFINE])
                out_obj.meta[MetaKeys.PIXDIM][1 : 1 + len(spacing)] = spacing

        if not (get_track_meta() and transform_info and transform_info.get(TraceKeys.TRACING)):
            if isinstance(data, Mapping):
                if not isinstance(data, dict):
                    data = dict(data)
                data[key] = data_t.copy_meta_from(out_obj) if isinstance(data_t, MetaTensor) else data_t
                return data
            return out_obj  # return with data_t as tensor if get_track_meta() is False

        info = transform_info.copy()
        # track the current spatial shape
        if orig_size is not None:
            info[TraceKeys.ORIG_SIZE] = orig_size
        elif isinstance(data_t, MetaTensor):
            info[TraceKeys.ORIG_SIZE] = data_t.peek_pending_shape()
        elif hasattr(data_t, "shape"):
            info[TraceKeys.ORIG_SIZE] = data_t.shape[1:]
        # add lazy status to the transform info
        info[TraceKeys.LAZY] = lazy
        # include extra_info
        if extra_info is not None:
            # shape/affine are tracked separately below; drop them from extra_info to avoid duplication
            extra_info.pop(LazyAttr.SHAPE, None)
            extra_info.pop(LazyAttr.AFFINE, None)
            info[TraceKeys.EXTRA_INFO] = extra_info

        # push the transform info to the applied_operation or pending_operation stack
        if lazy:
            if sp_size is None:
                if LazyAttr.SHAPE not in info:
                    info[LazyAttr.SHAPE] = info.get(TraceKeys.ORIG_SIZE, [])
            else:
                info[LazyAttr.SHAPE] = sp_size
            # normalise to a plain tuple of ints so the info stays collate-friendly
            info[LazyAttr.SHAPE] = tuple(convert_to_numpy(info[LazyAttr.SHAPE], wrap_sequence=True).tolist())
            if affine is None:
                if LazyAttr.AFFINE not in info:
                    info[LazyAttr.AFFINE] = MetaTensor.get_default_affine()
            else:
                info[LazyAttr.AFFINE] = affine
            info[LazyAttr.AFFINE] = convert_to_tensor(info[LazyAttr.AFFINE], device=torch.device("cpu"))
            out_obj.push_pending_operation(info)
        else:
            if out_obj.pending_operations:
                # applying eagerly while lazy operations are still pending is likely a user error;
                # record a status message on the last pending operation rather than failing.
                transform_name = info.get(TraceKeys.CLASS_NAME, "") if isinstance(info, dict) else ""
                msg = (
                    f"Transform {transform_name} has been applied to a MetaTensor with pending operations: "
                    f"{[x.get(TraceKeys.CLASS_NAME) for x in out_obj.pending_operations]}"
                )
                if key is not None:
                    msg += f" for key {key}"
                pend = out_obj.pending_operations[-1]
                statuses = pend.get(TraceKeys.STATUSES, dict())
                messages = statuses.get(TraceStatusKeys.PENDING_DURING_APPLY, list())
                messages.append(msg)
                statuses[TraceStatusKeys.PENDING_DURING_APPLY] = messages
                info[TraceKeys.STATUSES] = statuses
            out_obj.push_applied_operation(info)
        if isinstance(data, Mapping):
            if not isinstance(data, dict):
                data = dict(data)
            if isinstance(data_t, MetaTensor):
                data[key] = data_t.copy_meta_from(out_obj)
            else:
                # deprecated non-MetaTensor path: store the stack in an adjacent "<key>_transforms" list
                x_k = TraceableTransform.trace_key(key)
                if x_k not in data:
                    data[x_k] = []  # If this is the first, create list
                data[x_k].append(info)
            return data
        return out_obj

    def check_transforms_match(self, transform: Mapping) -> None:
        """Check transforms are of same instance."""
        xform_id = transform.get(TraceKeys.ID, "")
        if xform_id == id(self):
            return
        # TraceKeys.NONE to skip the id check
        if xform_id == TraceKeys.NONE:
            return
        xform_name = transform.get(TraceKeys.CLASS_NAME, "")
        warning_msg = transform.get(TraceKeys.EXTRA_INFO, {}).get("warn")
        if warning_msg:
            warnings.warn(warning_msg)
        # basic check if multiprocessing uses 'spawn' (objects get recreated so don't have same ID)
        if torch.multiprocessing.get_start_method() in ("spawn", None) and xform_name == self.__class__.__name__:
            return
        raise RuntimeError(
            f"Error {self.__class__.__name__} getting the most recently "
            f"applied invertible transform {xform_name} {xform_id} != {id(self)}."
        )

    def get_most_recent_transform(self, data, key: Hashable = None, check: bool = True, pop: bool = False):
        """
        Get most recent matching transform for the current class from the sequence of applied operations.

        Args:
            data: dictionary of data or `MetaTensor`.
            key: if data is a dictionary, data[key] will be modified.
            check: if true, check that `self` is the same type as the most recently-applied transform.
            pop: if true, remove the transform as it is returned.

        Returns:
            Dictionary of most recently applied transform

        Raises:
            - RuntimeError: tracing is disabled, or (when ``check``) the most recent
              transform does not match ``self``.
            - ValueError: data is neither `MetaTensor` nor dictionary, or the
              applied-operations stack is empty.
        """
        if not self.tracing:
            raise RuntimeError("Transform Tracing must be enabled to get the most recent transform.")
        if isinstance(data, MetaTensor):
            all_transforms = data.applied_operations
        elif isinstance(data, Mapping):
            if key in data and isinstance(data[key], MetaTensor):
                all_transforms = data[key].applied_operations
            else:
                all_transforms = data.get(self.trace_key(key), MetaTensor.get_default_applied_operations())
        else:
            raise ValueError(f"`data` should be either `MetaTensor` or dictionary, got {type(data)}.")
        if not all_transforms:
            raise ValueError(f"Item of type {type(data)} (key: {key}, pop: {pop}) has empty 'applied_operations'")
        if check:
            self.check_transforms_match(all_transforms[-1])
        return all_transforms.pop(-1) if pop else all_transforms[-1]

    def pop_transform(self, data, key: Hashable = None, check: bool = True):
        """
        Return and pop the most recent transform.

        Args:
            data: dictionary of data or `MetaTensor`
            key: if data is a dictionary, data[key] will be modified
            check: if true, check that `self` is the same type as the most recently-applied transform.

        Returns:
            Dictionary of most recently applied transform

        Raises:
            - RuntimeError: tracing is disabled, or (when ``check``) the most recent
              transform does not match ``self``.
            - ValueError: data is neither `MetaTensor` nor dictionary, or the
              applied-operations stack is empty.
        """
        return self.get_most_recent_transform(data, key, check, pop=True)

    @contextmanager
    def trace_transform(self, to_trace: bool):
        """Temporarily set the tracing status of a transform with a context manager."""
        prev = self.tracing
        self.tracing = to_trace
        try:
            yield
        finally:
            # restore even if the managed block raises, so the temporary tracing state cannot leak
            self.tracing = prev
class InvertibleTransform(TraceableTransform, InvertibleTrait):
    """Classes for invertible transforms.

    This class exists so that an ``invert`` method can be implemented. This allows, for
    example, images to be cropped, rotated, padded, etc., during training and inference,
    and after be returned to their original size before saving to file for comparison in
    an external viewer.

    When the ``inverse`` method is called:

        - the inverse is called on each key individually, which allows for
          different parameters being passed to each label (e.g., different
          interpolation for image and label).
        - the inverse transforms are applied in a last-in-first-out order. As
          the inverse is applied, its entry is removed from the list detailing
          the applied transformations. That is to say that during the forward
          pass, the list of applied transforms grows, and then during the
          inverse it shrinks back down to an empty list.

    We currently check that the ``id()`` of the transform is the same in the forward and
    inverse directions. This is a useful check to ensure that the inverses are being
    processed in the correct order.

    Note to developers: When converting a transform to an invertible transform, you need to:

        #. Inherit from this class.
        #. In ``__call__``, add a call to ``push_transform``.
        #. Any extra information that might be needed for the inverse can be included with the
           dictionary ``extra_info``. This dictionary should have the same keys regardless of
           whether ``do_transform`` was `True` or `False` and can only contain objects that are
           accepted in pytorch data loader's collate function (e.g., `None` is not allowed).
        #. Implement an ``inverse`` method. Make sure that after performing the inverse,
           ``pop_transform`` is called.
    """

    def inverse_update(self, data):
        """
        This function is to be called before every `self.inverse(data)`,
        update each MetaTensor `data[key]` using `data[key_transforms]` and `data[key_meta_dict]`,
        for MetaTensor backward compatibility 0.9.0.
        """
        # only the deprecated dict-based representation of a MapTransform needs syncing
        if not (isinstance(data, dict) and isinstance(self, transforms.MapTransform)):
            return data
        updated = dict(data)
        for key in self.key_iterator(data):
            stack_key = transforms.TraceableTransform.trace_key(key)
            # skip keys without a (non-empty) adjacent applied-transforms list
            if data.get(stack_key):
                updated = transforms.sync_meta_info(key, data, t=False)
        return updated

    def inverse(self, data: Any) -> Any:
        """
        Inverse of ``__call__``.

        Raises:
            NotImplementedError: When the subclass does not override this method.
        """
        raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")