|
| 1 | +# Copyright (c) MONAI Consortium |
| 2 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 3 | +# you may not use this file except in compliance with the License. |
| 4 | +# You may obtain a copy of the License at |
| 5 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 6 | +# Unless required by applicable law or agreed to in writing, software |
| 7 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 8 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 9 | +# See the License for the specific language governing permissions and |
| 10 | +# limitations under the License. |
| 11 | + |
| 12 | +from __future__ import annotations |
| 13 | + |
| 14 | +from collections.abc import Callable |
| 15 | + |
| 16 | +from monai.handlers.ignite_metric import IgniteMetricHandler |
| 17 | +from monai.metrics import CalibrationErrorMetric, CalibrationReduction |
| 18 | +from monai.utils import MetricReduction |
| 19 | + |
# Public API of this module: only the handler class is exported.
__all__ = ["CalibrationError"]
| 21 | + |
| 22 | + |
class CalibrationError(IgniteMetricHandler):
    """
    PyTorch Ignite handler that computes a calibration error metric during
    training or evaluation.

    A well-calibrated model emits probabilities that track the true likelihood
    of correctness — e.g. predictions made with 80% confidence should be right
    about 80% of the time. Deep networks are frequently miscalibrated (usually
    overconfident), which matters in medical imaging where probabilities may
    feed clinical decisions. This handler wraps
    :py:class:`~monai.metrics.CalibrationErrorMetric` so the error is computed
    and aggregated across engine iterations automatically.

    Supported reduction variants:

    - ``"expected"``: Expected Calibration Error (ECE) — per-bin errors
      weighted by bin population (most common).
    - ``"average"``: Average Calibration Error (ACE) — unweighted mean over bins.
    - ``"maximum"``: Maximum Calibration Error (MCE) — worst single bin.

    Args:
        num_bins: Number of equally-spaced confidence bins. Defaults to 20.
        include_background: Whether channel 0 participates in the computation.
            Set to ``False`` to drop the background channel in segmentation
            tasks. Defaults to ``True``.
        calibration_reduction: How per-bin errors are collapsed into one value:
            ``"expected"`` (ECE), ``"average"`` (ACE), or ``"maximum"`` (MCE).
            Defaults to ``"expected"``.
        metric_reduction: Reduction applied across batch/channel after the
            per-sample errors are computed. One of ``"none"``, ``"mean"``,
            ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, ``"mean_channel"``,
            ``"sum_channel"``. Defaults to ``"mean"``.
        output_transform: Callable extracting ``(y_pred, y)`` from
            ``engine.state.output``. See the Ignite state documentation
            (https://pytorch.org/ignite/concepts.html#state) and the batch
            output transform tutorial in the MONAI tutorials repository.
        save_details: When ``True``, per-sample/per-channel values are stored
            in ``engine.state.metric_details[name]``. Defaults to ``True``.

    References:
        - Guo, C., et al. "On Calibration of Modern Neural Networks." ICML 2017.
          https://proceedings.mlr.press/v70/guo17a.html
        - Barfoot, T., et al. "Average Calibration Losses for Reliable Uncertainty in
          Medical Image Segmentation." arXiv:2506.03942v3, 2025.
          https://arxiv.org/abs/2506.03942v3

    See Also:
        - :py:class:`~monai.metrics.CalibrationErrorMetric`: The underlying metric class.
        - :py:func:`~monai.metrics.calibration_binning`: Low-level binning for reliability diagrams.

    Example:
        >>> from monai.handlers import CalibrationError, from_engine
        >>> from ignite.engine import Engine
        >>>
        >>> def evaluation_step(engine, batch):
        ...     # Returns dict with "pred" (probabilities) and "label" (one-hot)
        ...     return {"pred": model(batch["image"]), "label": batch["label"]}
        >>>
        >>> evaluator = Engine(evaluation_step)
        >>>
        >>> # Attach calibration error handler
        >>> CalibrationError(
        ...     num_bins=15,
        ...     include_background=False,
        ...     calibration_reduction="expected",
        ...     output_transform=from_engine(["pred", "label"]),
        ... ).attach(evaluator, name="ECE")
        >>>
        >>> # After evaluation, access results
        >>> evaluator.run(val_loader)
        >>> ece = evaluator.state.metrics["ECE"]
        >>> print(f"Expected Calibration Error: {ece:.4f}")
    """

    def __init__(
        self,
        num_bins: int = 20,
        include_background: bool = True,
        calibration_reduction: CalibrationReduction | str = CalibrationReduction.EXPECTED,
        metric_reduction: MetricReduction | str = MetricReduction.MEAN,
        output_transform: Callable = lambda x: x,
        save_details: bool = True,
    ) -> None:
        # Build the underlying metric inline and hand everything to the
        # Ignite adapter base class, which drives iteration-level updates.
        super().__init__(
            metric_fn=CalibrationErrorMetric(
                num_bins=num_bins,
                include_background=include_background,
                calibration_reduction=calibration_reduction,
                metric_reduction=metric_reduction,
            ),
            output_transform=output_transform,
            save_details=save_details,
        )
0 commit comments