-
Notifications
You must be signed in to change notification settings - Fork 1.5k
Expand file tree
/
Copy pathtest_compute_frd_metric.py
More file actions
49 lines (38 loc) · 1.93 KB
/
test_compute_frd_metric.py
File metadata and controls
49 lines (38 loc) · 1.93 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import unittest
import numpy as np
import torch
from monai.metrics import FIDMetric, FrechetRadiomicsDistance
from monai.utils import optional_import
# Check whether scipy is importable. Only the availability flag is kept (the
# module object is discarded); it gates the test case that follows.
_, has_scipy = optional_import("scipy")
@unittest.skipUnless(has_scipy, "Requires scipy")
class TestFrechetRadiomicsDistance(unittest.TestCase):
    """Exercise ``FrechetRadiomicsDistance``: a reference value, parity with ``FIDMetric``,
    and rejection of inputs with too many dimensions.

    The whole case is skipped when scipy is not importable.
    """

    @staticmethod
    def _feature_pair():
        """Return fresh ``(predicted, reference)`` 3x2 feature tensors for the numeric tests.

        Every row of ``predicted`` is ``[1, 2]``; ``reference`` differs only in its
        first row, which is ``[2, 2]``.
        """
        predicted = torch.Tensor([[1.0, 2.0]] * 3)
        reference = torch.Tensor([[2.0, 2.0], [1.0, 2.0], [1.0, 2.0]])
        return predicted, reference

    def test_results(self):
        """The distance between the shared feature pair equals the reference value 0.4444."""
        predicted, reference = self._feature_pair()
        score = FrechetRadiomicsDistance()(predicted, reference)
        # 0.4444 is consistent with 4/9 under the Frechet formula: the squared gap
        # between column means is (1/3)**2, and the trace of ``reference``'s covariance
        # is 1/3 with an n-1 denominator. ``predicted`` has identical rows, so its
        # covariance (and hence the cross term) is zero.
        np.testing.assert_allclose(score.cpu().numpy(), 0.4444, atol=1e-4)

    def test_frd_matches_fid_for_same_features(self):
        """FRD and FID share the Frechet formula, so identical inputs must score identically."""
        predicted, reference = self._feature_pair()
        frd_value = FrechetRadiomicsDistance()(predicted, reference).cpu().numpy()
        fid_value = FIDMetric()(predicted, reference).cpu().numpy()
        np.testing.assert_allclose(frd_value, fid_value, atol=1e-6)

    def test_rejects_high_dimensional_input(self):
        """Feeding tensors with more than two dimensions must raise ``ValueError``."""
        volume = torch.ones(3, 3, 144, 144)
        with self.assertRaises(ValueError):
            FrechetRadiomicsDistance()(volume, volume)
if __name__ == "__main__":
    # Run this module's test cases when the file is executed directly as a script.
    unittest.main()