Fixes #8697 GPU memory leak by checking both image and label tensors for CUDA device #8708
Changes from 12 commits
```diff
@@ -105,7 +105,7 @@ def update_ops_nested_label(self, nested_key: str, op: Operations) -> None:
             raise ValueError("Nested_key input format is wrong. Please ensure it is like key1#0#key2")
         root: str
         child_key: str
-        (root, _, child_key) = keys
+        root, _, child_key = keys
         if root not in self.ops:
             self.ops[root] = [{}]
         self.ops[root][0].update({child_key: None})
```
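The one-line change above only drops the redundant parentheses around the unpacking target; both forms are semantically identical. A standalone illustration (the sample `keys` value here is made up, following the `key1#0#key2` format mentioned in the error message):

```python
# A "#"-separated nested key, matching the format the surrounding method expects.
keys = "key1#0#key2".split("#")

(root, _, child_key) = keys    # old style: parenthesized target list
root2, _, child_key2 = keys    # new style: identical semantics, less noise

print(root, child_key)  # key1 key2
```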
```diff
@@ -468,21 +468,28 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTe
         """
         d: dict[Hashable, MetaTensor] = dict(data)
         start = time.time()
-        if isinstance(d[self.image_key], (torch.Tensor, MetaTensor)) and d[self.image_key].device.type == "cuda":
-            using_cuda = True
-        else:
-            using_cuda = False
+        image_tensor = d[self.image_key]
+        label_tensor = d[self.label_key]
+        using_cuda = any(
+            isinstance(t, (torch.Tensor, MetaTensor)) and t.device.type == "cuda" for t in (image_tensor, label_tensor)
+        )
         restore_grad_state = torch.is_grad_enabled()
         torch.set_grad_enabled(False)

-        ndas: list[MetaTensor] = [d[self.image_key][i] for i in range(d[self.image_key].shape[0])]  # type: ignore
-        ndas_label: MetaTensor = d[self.label_key].astype(torch.int16)  # (H,W,D)
+        if isinstance(image_tensor, (MetaTensor, torch.Tensor)) and isinstance(
+            label_tensor, (MetaTensor, torch.Tensor)
+        ):
+            if label_tensor.device != image_tensor.device:
+                label_tensor = label_tensor.to(image_tensor.device)
```
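The `any()`-based check added above is the core of the fix: the old code inspected only the image key, so a label tensor left on CUDA could go undetected and leak GPU memory. A minimal sketch of the same detection pattern with plain `torch` tensors (the helper name `any_on_cuda` is ours, not MONAI's, and `MetaTensor` is omitted so the sketch runs without MONAI installed):

```python
import torch

def any_on_cuda(*candidates) -> bool:
    """True if any argument is a tensor resident on a CUDA device."""
    return any(isinstance(t, torch.Tensor) and t.device.type == "cuda" for t in candidates)

image = torch.zeros(1, 4, 4)  # CPU tensor
label = torch.zeros(4, 4)     # CPU tensor
print(any_on_cuda(image, label))           # False on a CPU-only machine
print(any_on_cuda(image, "not a tensor"))  # non-tensors are simply skipped
```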
|
benediktjohannes (Contributor, Author) marked this conversation as resolved, then commented:

> I just re-reviewed the merged commits and I think we might have a potential issue here. We set `label_tensor`'s device to the device of `image_tensor` here, but we have already determined above whether CUDA is in use (`using_cuda`) by checking whether either tensor is on CUDA. So if `label_tensor` is on CUDA and `image_tensor` is not, we first set `using_cuda` to `True`; then `label_tensor.device != image_tensor.device` is `True`, so we move `label_tensor` to the non-CUDA device (e.g. CPU). Both tensors now live on the CPU while `using_cuda` is still, incorrectly, `True`.
Member:

> Yes, I see the issue here. The expectation was:
>
> ```python
> if label_tensor.device != image_tensor.device:
>     # using_cuda would have to be True here, unless non-CUDA device types like mps have been mixed in
>     device = image_tensor.device if image_tensor.device.type == "cuda" else label_tensor.device
>     label_tensor = label_tensor.to(device)  # type: ignore
>     image_tensor = image_tensor.to(device)  # type: ignore
> ```
>
> This would select the device of whichever tensor is on CUDA. CC @garciadias
Contributor:

> Hi @benediktjohannes, I am not sure I understand. Have you looked at the proposed tests? https://github.com/benediktjohannes/MONAI/blob/1a2598b5ab0624fce4e9ab74cd4eed9f922fb801/tests/apps/test_auto3dseg.py#L81 This is my understanding of how this should behave. If you think any of the test cases are wrong, then we will need to correct them.
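The device-selection rule proposed in the comment above can be exercised without a GPU by modeling devices with lightweight stand-ins. `Device`, `FakeTensor`, and `reconcile` are hypothetical names used only for this sketch; the real code operates on `torch.Tensor`/`MetaTensor`:

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Device:
    type: str  # "cuda" or "cpu", mirroring torch.device.type

@dataclass
class FakeTensor:
    device: Device
    def to(self, device: Device) -> "FakeTensor":
        return FakeTensor(device)  # a real .to() would copy the data across devices

def reconcile(image: FakeTensor, label: FakeTensor) -> tuple:
    """Move both tensors to the CUDA side when their devices disagree."""
    if label.device != image.device:
        # prefer whichever side already holds a CUDA tensor
        device = image.device if image.device.type == "cuda" else label.device
        label = label.to(device)
        image = image.to(device)
    return image, label

img, lbl = reconcile(FakeTensor(Device("cpu")), FakeTensor(Device("cuda")))
print(img.device.type, lbl.device.type)  # cuda cuda
```

When exactly one tensor is on CUDA, this keeps the surviving device consistent with the earlier `using_cuda` check, which addresses the inconsistency raised in the comment.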
```diff
+
+        ndas: list[MetaTensor] = [image_tensor[i] for i in range(image_tensor.shape[0])]  # type: ignore
+        ndas_label: MetaTensor = label_tensor.astype(torch.int16)  # (H,W,D)
```
coderabbitai[bot] marked this conversation as resolved.
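For context, the list comprehension in the added lines above splits the image along its leading (channel) dimension so each channel's statistics can be computed separately. The same pattern with a plain `torch` tensor (the shapes here are arbitrary):

```python
import torch

image_tensor = torch.arange(2 * 3 * 3, dtype=torch.float32).reshape(2, 3, 3)  # (C, H, W)
ndas = [image_tensor[i] for i in range(image_tensor.shape[0])]  # one (H, W) slice per channel
print(len(ndas), tuple(ndas[0].shape))  # 2 (3, 3)
```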
```diff
         if ndas_label.shape != ndas[0].shape:
             raise ValueError(f"Label shape {ndas_label.shape} is different from image shape {ndas[0].shape}")

         nda_foregrounds: list[torch.Tensor] = [get_foreground_label(nda, ndas_label) for nda in ndas]
-        nda_foregrounds = [nda if nda.numel() > 0 else torch.Tensor([0]) for nda in nda_foregrounds]
+        nda_foregrounds = [nda if nda.numel() > 0 else MetaTensor([0.0]) for nda in nda_foregrounds]

         unique_label = unique(ndas_label)
         if isinstance(ndas_label, (MetaTensor, torch.Tensor)):
```
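The changed fallback line keeps the placeholder's type consistent with the rest of the pipeline (a `MetaTensor` instead of a bare `torch.Tensor`). The guard itself can be sketched with plain tensors; `pad_empty` is a hypothetical helper name, not MONAI API:

```python
import torch

def pad_empty(foregrounds):
    # Replace any empty foreground selection with a one-element placeholder so
    # downstream statistics (min/max/percentile) never see a zero-element tensor.
    return [t if t.numel() > 0 else torch.tensor([0.0]) for t in foregrounds]

fgs = pad_empty([torch.tensor([1.0, 2.0]), torch.tensor([])])
print([t.numel() for t in fgs])  # [2, 1]
```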