Skip to content

Commit 62f54fb

Browse files
committed
enable gradient
1 parent d069e62 commit 62f54fb

4 files changed

Lines changed: 15 additions & 7 deletions

File tree

VERSION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
-1.0.5
+1.0.6

gs_divergence/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
from gs_divergence.geodesical_skew_divergence import gs_div
22
from gs_divergence.symmetrized_geodesical_skew_divergence import symmetrized_gs_div
33

-__version__ = '1.0.5'
+__version__ = '1.0.6'

gs_divergence/alpha_geodesic.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,18 +11,18 @@ def alpha_geodesic(
1111
$\alpha$-geodesic between two probability distributions
1212
"""
1313

14-
a += 1e-12
15-
b += 1e-12
14+
a_ = a + 1e-12
15+
b_ = b + 1e-12
1616
if alpha == 1:
17-
return torch.exp((1 - lmd) * torch.log(a) + lmd * torch.log(b))
17+
return torch.exp((1 - lmd) * torch.log(a_) + lmd * torch.log(b_))
1818
elif alpha >= 1e+9:
1919
return torch.min(a, b)
2020
elif alpha <= -1e+9:
2121
return torch.max(a, b)
2222
else:
2323
p = (1 - alpha) / 2
24-
lhs = a ** p
25-
rhs = b ** p
24+
lhs = a_ ** p
25+
rhs = b_ ** p
2626
g = ((1 - lmd) * lhs + lmd * rhs) ** (1/p)
2727

2828
if alpha > 0 and (g == 0).sum() > 0:

tests/test_alpha_geodesic.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,3 +75,11 @@ def test_value_inf(self):
7575
res = torch.min(a, b)
7676

7777
self.assertTrue(torch.equal(g, res))
78+
79+
def test_grad(self):
    """Gradients must flow through alpha_geodesic when an input requires grad.

    Uses the alpha == 1 (log-interpolation) branch; the result's grad_fn
    being set confirms the op stayed on the autograd tape.
    """
    lhs = torch.tensor(
        [[0.1, 0.2, 0.7], [0.5, 0.5, 0.0]],
        requires_grad=True,
    )
    rhs = torch.tensor([[0.4, 0.4, 0.2], [0.2, 0.1, 0.7]])

    geodesic = alpha_geodesic(lhs, rhs, alpha=1, lmd=0.5)

    # A non-None grad_fn means backward() can propagate to `lhs`.
    self.assertIsNotNone(geodesic.grad_fn)

0 commit comments

Comments (0)