Skip to content

Commit 961080e

Browse files
committed
bugfix
1 parent 62f54fb commit 961080e

4 files changed

Lines changed: 6 additions & 6 deletions

File tree

VERSION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
1.0.6
1+
1.0.7

gs_divergence/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
from gs_divergence.geodesical_skew_divergence import gs_div
22
from gs_divergence.symmetrized_geodesical_skew_divergence import symmetrized_gs_div
33

4-
__version__ = '1.0.6'
4+
__version__ = '1.0.7'

gs_divergence/alpha_geodesic.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,16 @@ def alpha_geodesic(
1616
if alpha == 1:
1717
return torch.exp((1 - lmd) * torch.log(a_) + lmd * torch.log(b_))
1818
elif alpha >= 1e+9:
19-
return torch.min(a, b)
19+
return torch.min(a_, b_)
2020
elif alpha <= -1e+9:
21-
return torch.max(a, b)
21+
return torch.max(a_, b_)
2222
else:
2323
p = (1 - alpha) / 2
2424
lhs = a_ ** p
2525
rhs = b_ ** p
2626
g = ((1 - lmd) * lhs + lmd * rhs) ** (1/p)
2727

2828
if alpha > 0 and (g == 0).sum() > 0:
29-
return torch.min(a, b)
29+
return torch.min(a_, b_)
3030

3131
return g

tests/test_alpha_geodesic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def test_value_inf(self):
7474
g = alpha_geodesic(a, b, alpha=100, lmd=0.5)
7575
res = torch.min(a, b)
7676

77-
self.assertTrue(torch.equal(g, res))
77+
self.assertTrue(torch.all(torch.isclose(g, res)))
7878

7979
def test_grad(self):
8080
a = torch.tensor([[0.1, 0.2, 0.7], [0.5, 0.5, 0.0]], requires_grad=True)

0 commit comments

Comments
 (0)