-
Notifications
You must be signed in to change notification settings - Fork 1.5k
Add parameter to DiceMetric and DiceHelper classes #8774
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from 4 commits
b5107da
ccca77a
c110e2a
41e52c1
34a6817
8d412a1
cb433a8
ba2e0b3
d9bfb5d
6f2155c
ba05438
4e6def7
925e431
28a2944
24a17e9
a74f2cf
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -11,7 +11,10 @@ | |
|
|
||
| from __future__ import annotations | ||
|
|
||
| import numpy as np | ||
| import torch | ||
| from scipy.ndimage import distance_transform_edt, generate_binary_structure | ||
| from scipy.ndimage import label as sn_label | ||
|
|
||
| from monai.metrics.utils import do_metric_reduction | ||
| from monai.utils import MetricReduction, deprecated_arg | ||
|
|
@@ -95,6 +98,9 @@ class DiceMetric(CumulativeIterationMetric): | |
| If `True`, use "label_{index}" as the key corresponding to C channels; if ``include_background`` is True, | ||
| the index begins at "0", otherwise at "1". It can also take a list of label names. | ||
| The outcome will then be returned as a dictionary. | ||
| per_component: whether to compute the Dice metric per connected component. If `True`, the metric will be | ||
| computed for each connected component in the ground truth, and then averaged. This requires 5D binary | ||
| segmentations with 2 channels (background + foreground) as input. This is a more fine-grained computation. | ||
|
|
||
| """ | ||
|
|
||
|
|
@@ -106,6 +112,7 @@ def __init__( | |
| ignore_empty: bool = True, | ||
| num_classes: int | None = None, | ||
| return_with_label: bool | list[str] = False, | ||
| per_component: bool = False, | ||
| ) -> None: | ||
| super().__init__() | ||
| self.include_background = include_background | ||
|
|
@@ -114,13 +121,15 @@ def __init__( | |
| self.ignore_empty = ignore_empty | ||
| self.num_classes = num_classes | ||
| self.return_with_label = return_with_label | ||
| self.per_component = per_component | ||
| self.dice_helper = DiceHelper( | ||
| include_background=self.include_background, | ||
| reduction=MetricReduction.NONE, | ||
| get_not_nans=False, | ||
| apply_argmax=False, | ||
| ignore_empty=self.ignore_empty, | ||
| num_classes=self.num_classes, | ||
| per_component=self.per_component, | ||
| ) | ||
|
|
||
| def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # type: ignore[override] | ||
|
|
@@ -175,6 +184,7 @@ def compute_dice( | |
| include_background: bool = True, | ||
| ignore_empty: bool = True, | ||
| num_classes: int | None = None, | ||
| per_component: bool = False, | ||
| ) -> torch.Tensor: | ||
| """ | ||
| Computes Dice score metric for a batch of predictions. This performs the same computation as | ||
|
|
@@ -192,6 +202,9 @@ def compute_dice( | |
| num_classes: number of input channels (always including the background). When this is ``None``, | ||
| ``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are | ||
| single-channel class indices and the number of classes is not automatically inferred from data. | ||
| per_component: whether to compute the Dice metric per connected component. If `True`, the metric will be | ||
| computed for each connected component in the ground truth, and then averaged. This requires 5D binary | ||
| segmentations with 2 channels (background + foreground) as input. This is a more fine-grained computation. | ||
|
|
||
| Returns: | ||
| Dice scores per batch and per class, (shape: [batch_size, num_classes]). | ||
|
|
@@ -204,6 +217,7 @@ def compute_dice( | |
| apply_argmax=False, | ||
| ignore_empty=ignore_empty, | ||
| num_classes=num_classes, | ||
| per_component=per_component, | ||
| )(y_pred=y_pred, y=y) | ||
|
|
||
|
|
||
|
|
@@ -246,6 +260,9 @@ class DiceHelper: | |
| num_classes: number of input channels (always including the background). When this is ``None``, | ||
| ``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are | ||
| single-channel class indices and the number of classes is not automatically inferred from data. | ||
| per_component: whether to compute the Dice metric per connected component. If `True`, the metric will be | ||
| computed for each connected component in the ground truth, and then averaged. This requires 5D binary | ||
| segmentations with 2 channels (background + foreground) as input. This is a more fine-grained computation. | ||
| """ | ||
|
|
||
| @deprecated_arg("softmax", "1.5", "1.7", "Use `apply_argmax` instead.", new_name="apply_argmax") | ||
|
|
@@ -262,6 +279,7 @@ def __init__( | |
| num_classes: int | None = None, | ||
| sigmoid: bool | None = None, | ||
| softmax: bool | None = None, | ||
| per_component: bool = False, | ||
| ) -> None: | ||
| # handling deprecated arguments | ||
| if sigmoid is not None: | ||
|
|
@@ -277,6 +295,73 @@ def __init__( | |
| self.activate = activate | ||
| self.ignore_empty = ignore_empty | ||
| self.num_classes = num_classes | ||
| self.per_component = per_component | ||
|
|
||
| def compute_voronoi_regions_fast(self, labels, connectivity=26, sampling=None): | ||
| """ | ||
| Voronoi assignment to connected components (CPU, single EDT) without cc3d. | ||
| Returns the ID of the nearest component for each voxel. | ||
|
|
||
| Args: | ||
| labels: input label map as a numpy array, where values > 0 are considered seeds for connected components. | ||
| connectivity: 6/18/26 (3D) | ||
| sampling: voxel spacing for anisotropic distances (scipy.ndimage.distance_transform_edt) | ||
| """ | ||
|
|
||
| x = np.asarray(labels) | ||
| conn_rank = {6: 1, 18: 2, 26: 3}.get(connectivity, 3) | ||
| structure = generate_binary_structure(rank=3, connectivity=conn_rank) | ||
| cc, num = sn_label(x > 0, structure=structure) | ||
| if num == 0: | ||
| return torch.zeros_like(torch.from_numpy(x), dtype=torch.int32) | ||
| edt_input = np.ones(cc.shape, dtype=np.uint8) | ||
| edt_input[cc > 0] = 0 | ||
| indices = distance_transform_edt(edt_input, sampling=sampling, return_distances=False, return_indices=True) | ||
| voronoi = cc[tuple(indices)] | ||
| return torch.from_numpy(voronoi) | ||
|
VijayVignesh1 marked this conversation as resolved.
Outdated
|
||
|
|
||
| def compute_cc_dice(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Similar to the above,
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It seems that the compute_channel function, which is similar to compute_cc_dice, resides in the same class. To maintain consistency and encapsulate related functionality, it would make sense to keep both functions within the same class, right? |
||
| """ | ||
| Compute the dice metric for binary inputs which have only spatial dimensions. This method is called separately | ||
| for each batch item and for each channel of those items. | ||
|
|
||
| Args: | ||
| y_pred: input predictions with shape HW[D]. | ||
| y: ground truth with shape HW[D]. | ||
| """ | ||
| data = [] | ||
|
VijayVignesh1 marked this conversation as resolved.
Outdated
|
||
| if y_pred.ndim == y.ndim: | ||
| y_pred_idx = torch.argmax(y_pred, dim=1) | ||
| y_idx = torch.argmax(y, dim=1) | ||
| else: | ||
| y_pred_idx = y_pred | ||
| y_idx = y | ||
| if y_idx[0].sum() == 0: | ||
| if y_pred_idx.sum() == 0: | ||
| data.append(torch.tensor(1.0, device=y_idx.device)) | ||
| else: | ||
| data.append(torch.tensor(0.0, device=y_idx.device)) | ||
| else: | ||
| cc_assignment = self.compute_voronoi_regions_fast(y_idx[0]) | ||
| uniq, inv = torch.unique(cc_assignment.view(-1), return_inverse=True) | ||
| nof_components = uniq.numel() | ||
| code = (y_idx.view(-1) << 1) | y_pred_idx.view(-1) | ||
| idx = (inv << 2) | code | ||
| hist = torch.bincount(idx, minlength=nof_components * 4).reshape(-1, 4) | ||
| _, fp, fn, tp = hist[:, 0], hist[:, 1], hist[:, 2], hist[:, 3] | ||
| denom = 2 * tp + fp + fn | ||
| dice_scores = torch.where( | ||
| denom > 0, (2 * tp).float() / denom.float(), torch.tensor(1.0, device=denom.device) | ||
| ) | ||
| data.append(dice_scores.unsqueeze(-1)) | ||
| data = [ | ||
| torch.where(torch.isinf(x), torch.tensor(0.0, dtype=torch.float32, device=x.device), x) for x in data | ||
| ] | ||
| data = [ | ||
| torch.where(torch.isnan(x), torch.tensor(0.0, dtype=torch.float32, device=x.device), x) for x in data | ||
| ] | ||
|
VijayVignesh1 marked this conversation as resolved.
Outdated
|
||
| data = [x.reshape(-1, 1) for x in data] | ||
| return torch.stack([x.mean() for x in data]) | ||
|
coderabbitai[bot] marked this conversation as resolved.
Outdated
|
||
|
|
||
| def compute_channel(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: | ||
| """ | ||
|
|
@@ -322,15 +407,24 @@ def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor | tupl | |
| y_pred = torch.sigmoid(y_pred) | ||
| y_pred = y_pred > 0.5 | ||
|
|
||
| first_ch = 0 if self.include_background else 1 | ||
| if self.per_component and (len(y_pred.shape) != 5 or y_pred.shape[1] != 2): | ||
| raise ValueError( | ||
| f"per_component requires 5D binary segmentation with 2 channels (background + foreground). " | ||
| f"Got shape {y_pred.shape}, expected shape (B, 2, D, H, W)." | ||
| ) | ||
|
coderabbitai[bot] marked this conversation as resolved.
Outdated
|
||
|
|
||
| first_ch = 0 if self.include_background and not self.per_component else 1 | ||
|
VijayVignesh1 marked this conversation as resolved.
|
||
| data = [] | ||
| for b in range(y_pred.shape[0]): | ||
| c_list = [] | ||
| for c in range(first_ch, n_pred_ch) if n_pred_ch > 1 else [1]: | ||
| x_pred = (y_pred[b, 0] == c) if (y_pred.shape[1] == 1) else y_pred[b, c].bool() | ||
| x = (y[b, 0] == c) if (y.shape[1] == 1) else y[b, c] | ||
| c_list.append(self.compute_channel(x_pred, x)) | ||
| if self.per_component: | ||
| c_list = [self.compute_cc_dice(y_pred=y_pred[b].unsqueeze(0), y=y[b].unsqueeze(0))] | ||
| data.append(torch.stack(c_list)) | ||
|
coderabbitai[bot] marked this conversation as resolved.
|
||
|
|
||
| data = torch.stack(data, dim=0).contiguous() # type: ignore | ||
|
|
||
| f, not_nans = do_metric_reduction(data, self.reduction) # type: ignore | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.