Skip to content

Commit 57f72b0

Browse files
committed
bugfix
1 parent 961080e commit 57f72b0

5 files changed

Lines changed: 25 additions & 3 deletions

File tree

VERSION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
1.0.7
1+
1.0.8

gs_divergence/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
from gs_divergence.geodesical_skew_divergence import gs_div
22
from gs_divergence.symmetrized_geodesical_skew_divergence import symmetrized_gs_div
33

4-
__version__ = '1.0.7'
4+
__version__ = '1.0.8'

gs_divergence/geodesical_skew_divergence.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def gs_div(
3030
assert lmd >= 0 and lmd <= 1
3131

3232
skew_target = alpha_geodesic(input, target, alpha=alpha, lmd=lmd)
33-
div = input * torch.log(input / skew_target)
33+
div = input * torch.log(input / skew_target + 1e-12)
3434
if reduction == 'batchmean':
3535
div = div.sum() / input.size()[0]
3636
elif reduction == 'sum':

tests/test_alpha_geodesic.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ def test_value_inf(self):
7272
b = torch.Tensor([[0.4, 0.4, 0.2], [0.2, 0.1, 0.7]])
7373

7474
g = alpha_geodesic(a, b, alpha=100, lmd=0.5)
75+
print(g)
7576
res = torch.min(a, b)
7677

7778
self.assertTrue(torch.all(torch.isclose(g, res)))

tests/test_gs_div.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import unittest
2+
import torch
3+
4+
from gs_divergence import gs_div
5+
6+
7+
class TestGSDiv(unittest.TestCase):
8+
def test_alpha_minus_1(self):
9+
a = torch.Tensor([1, 2, 3])
10+
b = torch.Tensor([4, 5, 6])
11+
g = gs_div(a, b, alpha=-1, lmd=0.5)
12+
13+
self.assertIsNotNone(g)
14+
15+
def test_value_0_2d(self):
16+
a = torch.Tensor([[0.1, 0.2, 0.7], [0.5, 0.5, 0.0]])
17+
b = torch.Tensor([[0.4, 0.4, 0.2], [0.2, 0.1, 0.7]])
18+
19+
g = gs_div(a, b, alpha=1, lmd=0.5)
20+
21+
self.assertTrue(torch.isinf(g).sum() == 0)

0 commit comments

Comments
 (0)