Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 95 additions & 1 deletion monai/metrics/meandice.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@

from __future__ import annotations

import numpy as np
import torch
from scipy.ndimage import distance_transform_edt, generate_binary_structure
from scipy.ndimage import label as sn_label
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated

from monai.metrics.utils import do_metric_reduction
from monai.utils import MetricReduction, deprecated_arg
Expand Down Expand Up @@ -95,6 +98,9 @@ class DiceMetric(CumulativeIterationMetric):
If `True`, use "label_{index}" as the key corresponding to C channels; if ``include_background`` is True,
the index begins at "0", otherwise at "1". It can also take a list of label names.
The outcome will then be returned as a dictionary.
per_component: whether to compute the Dice metric per connected component. If `True`, the metric will be
computed for each connected component in the ground truth, and then averaged. This requires 5D binary
segmentations with 2 channels (background + foreground) as input. This is a more fine-grained computation.

"""

Expand All @@ -106,6 +112,7 @@ def __init__(
ignore_empty: bool = True,
num_classes: int | None = None,
return_with_label: bool | list[str] = False,
per_component: bool = False,
) -> None:
super().__init__()
self.include_background = include_background
Expand All @@ -114,13 +121,15 @@ def __init__(
self.ignore_empty = ignore_empty
self.num_classes = num_classes
self.return_with_label = return_with_label
self.per_component = per_component
self.dice_helper = DiceHelper(
include_background=self.include_background,
reduction=MetricReduction.NONE,
get_not_nans=False,
apply_argmax=False,
ignore_empty=self.ignore_empty,
num_classes=self.num_classes,
per_component=self.per_component,
)

def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # type: ignore[override]
Expand Down Expand Up @@ -175,6 +184,7 @@ def compute_dice(
include_background: bool = True,
ignore_empty: bool = True,
num_classes: int | None = None,
per_component: bool = False,
) -> torch.Tensor:
"""
Computes Dice score metric for a batch of predictions. This performs the same computation as
Expand All @@ -192,6 +202,9 @@ def compute_dice(
num_classes: number of input channels (always including the background). When this is ``None``,
``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are
single-channel class indices and the number of classes is not automatically inferred from data.
per_component: whether to compute the Dice metric per connected component. If `True`, the metric will be
computed for each connected component in the ground truth, and then averaged. This requires 5D binary
segmentations with 2 channels (background + foreground) as input. This is a more fine-grained computation.

Returns:
Dice scores per batch and per class, (shape: [batch_size, num_classes]).
Expand All @@ -204,6 +217,7 @@ def compute_dice(
apply_argmax=False,
ignore_empty=ignore_empty,
num_classes=num_classes,
per_component=per_component,
)(y_pred=y_pred, y=y)


Expand Down Expand Up @@ -246,6 +260,9 @@ class DiceHelper:
num_classes: number of input channels (always including the background). When this is ``None``,
``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are
single-channel class indices and the number of classes is not automatically inferred from data.
per_component: whether to compute the Dice metric per connected component. If `True`, the metric will be
computed for each connected component in the ground truth, and then averaged. This requires 5D binary
segmentations with 2 channels (background + foreground) as input. This is a more fine-grained computation.
"""

@deprecated_arg("softmax", "1.5", "1.7", "Use `apply_argmax` instead.", new_name="apply_argmax")
Expand All @@ -262,6 +279,7 @@ def __init__(
num_classes: int | None = None,
sigmoid: bool | None = None,
softmax: bool | None = None,
per_component: bool = False,
) -> None:
# handling deprecated arguments
if sigmoid is not None:
Expand All @@ -277,6 +295,73 @@ def __init__(
self.activate = activate
self.ignore_empty = ignore_empty
self.num_classes = num_classes
self.per_component = per_component

def compute_voronoi_regions_fast(self, labels, connectivity=26, sampling=None):
"""
Voronoi assignment to connected components (CPU, single EDT) without cc3d.
Returns the ID of the nearest component for each voxel.

Args:
labels: input label map as a numpy array, where values > 0 are considered seeds for connected components.
connectivity: 6/18/26 (3D)
sampling: voxel spacing for anisotropic distances (scipy.ndimage.distance_transform_edt)
"""

x = np.asarray(labels)
conn_rank = {6: 1, 18: 2, 26: 3}.get(connectivity, 3)
structure = generate_binary_structure(rank=3, connectivity=conn_rank)
cc, num = sn_label(x > 0, structure=structure)
if num == 0:
return torch.zeros_like(torch.from_numpy(x), dtype=torch.int32)
edt_input = np.ones(cc.shape, dtype=np.uint8)
edt_input[cc > 0] = 0
indices = distance_transform_edt(edt_input, sampling=sampling, return_distances=False, return_indices=True)
voronoi = cc[tuple(indices)]
return torch.from_numpy(voronoi)
Comment thread
VijayVignesh1 marked this conversation as resolved.
Outdated

def compute_cc_dice(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar to the above, self is used to access ignore_empty which could be passed as a argument instead with this method turned into a function external to this class.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems that the compute_channel function, which is similar to compute_cc_dice, resides in the same class. To maintain consistency and encapsulate related functionality, it would make sense to keep both functions within the same class, right?

"""
Compute the dice metric for binary inputs which have only spatial dimensions. This method is called separately
for each batch item and for each channel of those items.

Args:
y_pred: input predictions with shape HW[D].
y: ground truth with shape HW[D].
"""
data = []
Comment thread
VijayVignesh1 marked this conversation as resolved.
Outdated
if y_pred.ndim == y.ndim:
y_pred_idx = torch.argmax(y_pred, dim=1)
y_idx = torch.argmax(y, dim=1)
else:
y_pred_idx = y_pred
y_idx = y
if y_idx[0].sum() == 0:
if y_pred_idx.sum() == 0:
data.append(torch.tensor(1.0, device=y_idx.device))
else:
data.append(torch.tensor(0.0, device=y_idx.device))
else:
cc_assignment = self.compute_voronoi_regions_fast(y_idx[0])
uniq, inv = torch.unique(cc_assignment.view(-1), return_inverse=True)
nof_components = uniq.numel()
code = (y_idx.view(-1) << 1) | y_pred_idx.view(-1)
idx = (inv << 2) | code
hist = torch.bincount(idx, minlength=nof_components * 4).reshape(-1, 4)
_, fp, fn, tp = hist[:, 0], hist[:, 1], hist[:, 2], hist[:, 3]
denom = 2 * tp + fp + fn
dice_scores = torch.where(
denom > 0, (2 * tp).float() / denom.float(), torch.tensor(1.0, device=denom.device)
)
data.append(dice_scores.unsqueeze(-1))
data = [
torch.where(torch.isinf(x), torch.tensor(0.0, dtype=torch.float32, device=x.device), x) for x in data
]
data = [
torch.where(torch.isnan(x), torch.tensor(0.0, dtype=torch.float32, device=x.device), x) for x in data
]
Comment thread
VijayVignesh1 marked this conversation as resolved.
Outdated
data = [x.reshape(-1, 1) for x in data]
return torch.stack([x.mean() for x in data])
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated

def compute_channel(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
"""
Expand Down Expand Up @@ -322,15 +407,24 @@ def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor | tupl
y_pred = torch.sigmoid(y_pred)
y_pred = y_pred > 0.5

first_ch = 0 if self.include_background else 1
if self.per_component and (len(y_pred.shape) != 5 or y_pred.shape[1] != 2):
raise ValueError(
f"per_component requires 5D binary segmentation with 2 channels (background + foreground). "
f"Got shape {y_pred.shape}, expected shape (B, 2, D, H, W)."
)
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated

first_ch = 0 if self.include_background and not self.per_component else 1
Comment thread
VijayVignesh1 marked this conversation as resolved.
data = []
for b in range(y_pred.shape[0]):
c_list = []
for c in range(first_ch, n_pred_ch) if n_pred_ch > 1 else [1]:
x_pred = (y_pred[b, 0] == c) if (y_pred.shape[1] == 1) else y_pred[b, c].bool()
x = (y[b, 0] == c) if (y.shape[1] == 1) else y[b, c]
c_list.append(self.compute_channel(x_pred, x))
if self.per_component:
c_list = [self.compute_cc_dice(y_pred=y_pred[b].unsqueeze(0), y=y[b].unsqueeze(0))]
data.append(torch.stack(c_list))
Comment thread
coderabbitai[bot] marked this conversation as resolved.

data = torch.stack(data, dim=0).contiguous() # type: ignore

f, not_nans = do_metric_reduction(data, self.reduction) # type: ignore
Expand Down
37 changes: 37 additions & 0 deletions tests/metrics/test_compute_meandice.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,31 @@
{"label_1": 0.4000, "label_2": 0.6667},
]

TEST_CASE_16 = [
{"per_component": True},
{
"y": (
lambda: (
y := torch.zeros((5, 2, 64, 64, 64)),
y.__setitem__((0, 1, slice(20, 25), slice(20, 25), slice(20, 25)), 1),
y.__setitem__((0, 1, slice(40, 45), slice(40, 45), slice(40, 45)), 1),
y.__setitem__((0, 0), 1 - y[0, 1]),
y,
)[-1]
)(),
"y_pred": (
lambda: (
y_hat := torch.zeros((5, 2, 64, 64, 64)),
y_hat.__setitem__((0, 1, slice(21, 26), slice(21, 26), slice(21, 26)), 1),
y_hat.__setitem__((0, 1, slice(41, 46), slice(39, 44), slice(41, 46)), 1),
y_hat.__setitem__((0, 0), 1 - y_hat[0, 1]),
y_hat,
)[-1]
)(),
},
[[[0.5120]], [[1.0]], [[1.0]], [[1.0]], [[1.0]]],
]


class TestComputeMeanDice(unittest.TestCase):

Expand Down Expand Up @@ -301,6 +326,18 @@ def test_nans_class(self, params, input_data, expected_value):
else:
np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)

# CC DiceMetric tests
@parameterized.expand([TEST_CASE_16])
def test_cc_dice_value(self, params, input_data, expected_value):
dice_metric = DiceMetric(**params)
dice_metric(**input_data)
result = dice_metric.aggregate(reduction="none")
np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)

def test_input_dimensions(self):
with self.assertRaises(ValueError):
DiceMetric(per_component=True)(torch.ones([3, 3, 144, 144]), torch.ones([3, 3, 145, 145]))


if __name__ == "__main__":
unittest.main()
Loading