Skip to content

Commit 26defc1

Browse files
authored
Feature/seg eigen cam (#580)
* new branch * Add SegEigenCAM implementation and update README for new method * Add SegEigenCAM to test suite for all CAM models * Trigger CI
1 parent 781dbc0 commit 26defc1

File tree

6 files changed

+124
-9
lines changed

6 files changed

+124
-9
lines changed

README.md

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ The aim is also to serve as a benchmark of algorithms and metrics for research o
4949
| FEM | A gradient free method that binarizes activations by an activation > mean + k * std rule. |
5050
| ShapleyCAM | Weight the activations using the gradient and Hessian-vector product.|
5151
| FinerCAM | Improves fine-grained classification by comparing similar classes, suppressing shared features and highlighting discriminative details. |
52+
| SegEigenCAM | Like EigenCAM but with gradient weighting (absolute gradients ⊙ activations) before SVD and sign correction to fix SVD sign ambiguity; designed for semantic segmentation. |
5253
## Visual Examples
5354

5455
| What makes the network think the image label is 'pug, pug-dog' | What makes the network think the image label is 'tabby, tabby cat' | Combining Grad-CAM with Guided Backpropagation for the 'pug, pug-dog' class |
@@ -291,7 +292,7 @@ To use with a specific device, like cpu, cuda, cuda:0, mps or hpu:
291292

292293
You can choose between:
293294

294-
`GradCAM` , `HiResCAM`, `ScoreCAM`, `GradCAMPlusPlus`, `AblationCAM`, `XGradCAM` , `LayerCAM`, `FullGrad`, `EigenCAM`, `ShapleyCAM`, and `FinerCAM`.
295+
`GradCAM` , `HiResCAM`, `ScoreCAM`, `GradCAMPlusPlus`, `AblationCAM`, `XGradCAM` , `LayerCAM`, `FullGrad`, `EigenCAM`, `ShapleyCAM`, `FinerCAM`, and `SegEigenCAM`.
295296

296297
Some methods like ScoreCAM and AblationCAM require a large number of forward passes,
297298
and have a batched implementation.
@@ -374,3 +375,8 @@ Huaiguang Cai`
374375
https://arxiv.org/pdf/2501.11309 <br>
375376
`Finer-CAM : Spotting the Difference Reveals Finer Details for Visual Explanation`
376377
`Ziheng Zhang*, Jianyang Gu*, Arpita Chowdhury, Zheda Mai, David Carlyn,Tanya Berger-Wolf, Yu Su, Wei-Lun Chao`
378+
379+
380+
https://doi.org/10.3390/app15137562 <br>
381+
`Seg-Eigen-CAM: Eigen-Value-Based Visual Explanations for Semantic Segmentation Models`
382+
`Ching-Ting Chung, Josh Jia-Ching Ying`

cam.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
GradCAM, FEM, HiResCAM, ScoreCAM, GradCAMPlusPlus,
99
AblationCAM, XGradCAM, EigenCAM, EigenGradCAM,
1010
LayerCAM, FullGrad, GradCAMElementWise, KPCA_CAM, ShapleyCAM,
11-
FinerCAM
11+
FinerCAM, SegEigenCAM
1212
)
1313
from pytorch_grad_cam import GuidedBackpropReLUModel
1414
from pytorch_grad_cam.utils.image import (
@@ -39,7 +39,7 @@ def get_args():
3939
'scorecam', 'xgradcam', 'ablationcam',
4040
'eigencam', 'eigengradcam', 'layercam',
4141
'fullgrad', 'gradcamelementwise', 'kpcacam', 'shapleycam',
42-
'finercam'
42+
'finercam', 'segeigencam'
4343
],
4444
help='CAM method')
4545

@@ -79,7 +79,8 @@ def get_args():
7979
"gradcamelementwise": GradCAMElementWise,
8080
'kpcacam': KPCA_CAM,
8181
'shapleycam': ShapleyCAM,
82-
'finercam': FinerCAM
82+
'finercam': FinerCAM,
83+
'segeigencam': SegEigenCAM,
8384
}
8485

8586
if args.device=='hpu':

pytorch_grad_cam/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from pytorch_grad_cam.random_cam import RandomCAM
1717
from pytorch_grad_cam.fullgrad_cam import FullGrad
1818
from pytorch_grad_cam.guided_backprop import GuidedBackpropReLUModel
19+
from pytorch_grad_cam.seg_eigen_cam import SegEigenCAM
1920
from pytorch_grad_cam.activations_and_gradients import ActivationsAndGradients
2021
from pytorch_grad_cam.feature_factorization.deep_feature_factorization import DeepFeatureFactorization, run_dff_on_image
2122
import pytorch_grad_cam.utils.model_targets

pytorch_grad_cam/seg_eigen_cam.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
import numpy as np
2+
from pytorch_grad_cam.base_cam import BaseCAM
3+
from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection_with_sign_correction
4+
5+
# Based on this paper:
6+
# https://doi.org/10.3390/app15137562
7+
# Chung, C.-T.; Ying, J.J.-C.
8+
# Seg-Eigen-CAM: Eigen-Value-Based Visual Explanations for Semantic Segmentation Models.
9+
# Applied Sciences, 2025, 15(13), 7562.
10+
11+
12+
class SegEigenCAM(BaseCAM):
    """
    Seg-Eigen-CAM: eigen-value-based visual explanations for semantic
    segmentation models.

    Builds on Eigen-CAM with two segmentation-oriented changes:

    1. Gradient weighting (Section 3.2.1, Eq. 10): activations are scaled
       element-wise by the absolute gradients, injecting local pixel-wise
       spatial information instead of a single global average weight.

    2. Sign correction (Section 3.2.2, Eq. 13): the sign ambiguity inherent
       to SVD is resolved by comparing |max| against |min| of the projection,
       so salient regions always come out positive.

    Reference:
        Chung, C.-T.; Ying, J.J.-C. Seg-Eigen-CAM: Eigen-Value-Based Visual
        Explanations for Semantic Segmentation Models. Appl. Sci. 2025,
        15(13), 7562. https://doi.org/10.3390/app15137562

    Args:
        model: The neural network model to explain.
        target_layers: List of convolutional layers to extract activations from.
        reshape_transform: Optional callable for non-standard activation shapes.
    """

    def __init__(self, model, target_layers, reshape_transform=None):
        # Gradients are required for the |grads| * activations weighting step,
        # unlike plain EigenCAM which is gradient-free.
        super(SegEigenCAM, self).__init__(
            model,
            target_layers,
            reshape_transform,
            uses_gradients=True,
        )

    def get_cam_image(
        self,
        input_tensor,
        target_layer,
        target_category,
        activations,
        grads,
        eigen_smooth,
    ):
        # Step 1 — gradient weighting (Eq. 10). Absolute values keep both
        # positive and negative gradient contributions (unlike Grad-CAM,
        # which discards negative gradients).
        # Steps 2 & 3 — SVD projection plus sign correction (Eq. 11-13).
        return get_2d_projection_with_sign_correction(np.abs(grads) * activations)

pytorch_grad_cam/utils/svd_on_activations.py

Lines changed: 46 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,59 @@ def get_2d_projection(activation_batch):
2020
return np.float32(projections)
2121

2222

23-
2423
def get_2d_projection_kernel(activation_batch, kernel='sigmoid', gamma=None):
    """Project each activation map onto its first kernel principal component.

    Args:
        activation_batch: Array of shape (B, C, H, W); NaNs are zeroed in place.
        kernel: Kernel name forwarded to sklearn's KernelPCA.
        gamma: Optional kernel coefficient forwarded to KernelPCA.

    Returns:
        np.float32 array of shape (B, H, W).
    """
    # Zero out NaNs (in place) so the decomposition stays well defined.
    activation_batch[np.isnan(activation_batch)] = 0
    projections = []
    for activations in activation_batch:
        # Flatten spatial dims: (C, H, W) -> (H*W, C), then mean-center
        # each channel column.
        flat = activations.reshape(activations.shape[0], -1).transpose()
        centered = flat - flat.mean(axis=0)
        # Apply Kernel PCA and keep only the leading component.
        component = KernelPCA(
            n_components=1, kernel=kernel, gamma=gamma
        ).fit_transform(centered)
        projections.append(component.reshape(activations.shape[1:]))
    return np.float32(projections)
37+
38+
39+
def get_2d_projection_with_sign_correction(activation_batch: np.ndarray) -> np.ndarray:
    """
    Perform SVD on a batch of activation maps, project onto the first
    principal component, and apply sign correction.

    Sign correction addresses the inherent sign ambiguity of SVD:
    decomposing A = U Σ Vᵀ is equivalent to (-U) Σ (-Vᵀ), so the sign
    of the resulting projection is arbitrary. The correction ensures that
    class-discriminative information aligns with the positive direction by
    flipping the map when |min| > |max| (Eq. 13 in the paper).

    Reference:
        Chung, C.-T.; Ying, J.J.-C. Seg-Eigen-CAM. Appl. Sci. 2025,
        15(13), 7562. https://doi.org/10.3390/app15137562

    Args:
        activation_batch: Array of shape (B, C, H, W). NaN entries are
            zeroed in place before the decomposition.

    Returns:
        np.ndarray of shape (B, H, W) with dtype float32.
    """
    # In-place NaN replacement mutates the caller's array; this matches the
    # behaviour of get_2d_projection in this module.
    activation_batch[np.isnan(activation_batch)] = 0
    projections = []

    for activations in activation_batch:
        # (C, H, W) -> (H*W, C): one row per spatial location, then
        # mean-center each channel column.
        reshaped = activations.reshape(activations.shape[0], -1).transpose()
        reshaped = reshaped - reshaped.mean(axis=0)

        # Only the first right-singular vector VT[0] is needed. With
        # full_matrices=False, numpy returns the same VT rows for an
        # (M, N) matrix with M >= N but avoids materializing the huge
        # (H*W, H*W) matrix U that full_matrices=True would build.
        _, _, VT = np.linalg.svd(reshaped, full_matrices=False)

        projection = reshaped @ VT[0, :]
        projection = projection.reshape(activations.shape[1:])

        # Sign correction (Eq. 13): ensure salient regions are positive.
        if abs(projection.min()) > abs(projection.max()):
            projection = -projection

        projections.append(projection)

    return np.float32(projections)

tests/test_run_all_models.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111
EigenGradCAM, \
1212
LayerCAM, \
1313
FullGrad, \
14-
KPCA_CAM
14+
KPCA_CAM, \
15+
SegEigenCAM
1516
from pytorch_grad_cam.utils.image import show_cam_on_image, \
1617
preprocess_image
1718
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
@@ -52,7 +53,8 @@ def numpy_image():
5253
EigenGradCAM,
5354
LayerCAM,
5455
FullGrad,
55-
KPCA_CAM])
56+
KPCA_CAM,
57+
SegEigenCAM])
5658

5759
def test_all_cam_models_can_run(numpy_image, batch_size, width, height,
5860
cnn_model, target_layer_names, cam_method,

0 commit comments

Comments
 (0)