Skip to content

Commit 232ebba

Browse files
committed
use repr_html
1 parent 2fbd0c5 commit 232ebba

5 files changed

Lines changed: 37 additions & 43 deletions

File tree

preliz/distributions/distributions.py

Lines changed: 28 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -39,37 +39,42 @@ def __init__(self):
3939
self.is_frozen = False
4040
self.opt = None
4141

42-
def __repr__(self):
42+
def _get_name(self):
43+
"""Return the display name for this distribution."""
4344
name = self.__class__.__name__
4445
if name in ["Truncated", "Censored", "Hurdle"]:
4546
name += self.dist.__class__.__name__
4647
if name == "Mixture":
47-
name = (
48-
"Mixture"
49-
+ "".join(dict.fromkeys(dist.__class__.__name__ for dist in self.dist))
50-
+ "\n"
48+
name = "Mixture" + "".join(dict.fromkeys(dist.__class__.__name__ for dist in self.dist))
49+
return name
50+
51+
def _get_description(self):
52+
"""Return a string of parameters, or empty string if not frozen."""
53+
if not self.is_frozen:
54+
return ""
55+
return "".join(
56+
(
57+
f"{n}={v:.3g}, "
58+
if np.isscalar(v) or np.ndim(v) == 0
59+
else f"{n}=["
60+
+ "".join(f"{vi:.3g}, " for vi in np.atleast_1d(v)).strip(", ")
61+
+ "], "
5162
)
63+
for n, v in zip(self.param_names, self.params)
64+
).strip(", ")
5265

66+
def __repr__(self):
67+
name = self._get_name()
5368
if self.is_frozen:
54-
if "Mixture" in name:
55-
bolded_name = "\033[1m" + name.strip() + "\033[0m" + "\n"
56-
else:
57-
bolded_name = "\033[1m" + name + "\033[0m"
58-
59-
description = "".join(
60-
(
61-
f"{n}={v:.3g}, "
62-
if np.isscalar(v) or np.ndim(v) == 0
63-
else f"{n}=["
64-
+ "".join(f"{vi:.3g}, " for vi in np.atleast_1d(v)).strip(", ")
65-
+ "], "
66-
)
67-
for n, v in zip(self.param_names, self.params)
68-
).strip(", ")
69+
return f"{name}({self._get_description()})"
70+
return name
6971

70-
return f"{bolded_name}({description})"
71-
else:
72-
return name
72+
def _repr_html_(self):
73+
name = self._get_name()
74+
if self.is_frozen:
75+
desc = self._get_description()
76+
return f"<span style='font-weight:bold'>{name}</span><span'>({desc})</span>"
77+
return f"<span style='font-weight:bold'>{name}</span>"
7378

7479
@property
7580
def params_dict(self):

preliz/internal/plot_helper.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -311,13 +311,6 @@ def side_legend(legend, ax):
311311
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
312312

313313

314-
def repr_to_matplotlib(distribution):
315-
string = repr(distribution)
316-
string = string.replace("\x1b[1m", r"$\bf{")
317-
string = string.replace("\x1b[0m", "}$")
318-
return string
319-
320-
321314
def get_moments(dist, moments):
322315
names = {
323316
"m": "μ",
@@ -727,7 +720,7 @@ def set_label(dist, legend, moments, ax):
727720
if isinstance(legend, str) and legend not in ["title", "legend"]:
728721
label = legend
729722
else:
730-
label = repr_to_matplotlib(dist)
723+
label = str(dist)
731724

732725
if moments is not None:
733726
label += get_moments(dist, moments)

preliz/internal/plot_helper_multivariate.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import numpy as np
77
from matplotlib import tri
88

9-
from preliz.internal.plot_helper import repr_to_matplotlib
109
from preliz.internal.special import gammaln
1110

1211

@@ -194,7 +193,7 @@ def plot_dirichlet(
194193
ax.set_ylim(*ylim)
195194

196195
if legend == "title":
197-
fig.text(0.5, 1, repr_to_matplotlib(dist), ha="center", va="center")
196+
fig.text(0.5, 1, dist, ha="center", va="center")
198197

199198
elif dim == 3:
200199
dirichlet_ = DirichletOnSimplex(alpha)
@@ -203,7 +202,7 @@ def plot_dirichlet(
203202
_, axes = plt.subplots(1, 1)
204203
dirichlet_.plot(ax=axes)
205204
if legend == "title":
206-
axes.set_title(repr_to_matplotlib(dist))
205+
axes.set_title(dist)
207206
else:
208207
raise ValueError("joint only works for Dirichlet of dim=3")
209208

@@ -354,14 +353,14 @@ def plot_mvnormal(
354353
if xy_lim != "auto" and representation != "cdf":
355354
ax.set_ylim(*ylim)
356355
if legend == "title":
357-
fig.text(0.5, 1, repr_to_matplotlib(dist), ha="center", va="center")
356+
fig.text(0.5, 1, dist, ha="center", va="center")
358357

359358
elif dim == 2:
360359
if axes is None:
361360
_, axes = plt.subplots(1, 1)
362361
joint_normal(dist, axes)
363362
if legend == "title":
364-
axes.set_title(repr_to_matplotlib(dist))
363+
axes.set_title(dist)
365364
else:
366365
raise ValueError("joint only works for Multivariate Normal of dim=2")
367366

preliz/internal/predictive_helper.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import numpy as np
22

33
from preliz.internal.distribution_helper import get_distributions
4-
from preliz.internal.plot_helper import repr_to_matplotlib
54
from preliz.unidimensional import mle
65

76

@@ -11,7 +10,7 @@ def back_fitting_ppa(model, subset, new_families=True):
1110

1211
for name, dist in model.items():
1312
dist._fit_mle(subset[name])
14-
string += f"{name} = {repr_to_matplotlib(dist)}\n"
13+
string += f"{name} = {dist}\n"
1514

1615
if new_families:
1716
string += "\nYour selection is consistent with the priors (new families):\n"
@@ -26,7 +25,7 @@ def back_fitting_ppa(model, subset, new_families=True):
2625
elif dist.kind == "discrete":
2726
distributions = get_distributions(set([dist.__class__.__name__] + common_disc))
2827
idx, _ = mle(distributions, subset[name], plot=False)
29-
string += f"{name} = {repr_to_matplotlib(distributions[idx[0]])}\n"
28+
string += f"{name} = {distributions[idx[0]]}\n"
3029

3130
return string, np.concatenate([dist.params for dist in model.values()])
3231

preliz/tests/test_posterior_to_prior.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,9 @@
1919

2020
def test_p2p_pymc():
2121
posterior_to_prior(model, idata)
22-
assert 'Gamma\x1b[0m("b", alpha=' in posterior_to_prior(model, idata, new_families="auto")
22+
assert 'Gamma("b", alpha=' in posterior_to_prior(model, idata, new_families="auto")
2323
posterior_to_prior(model, idata, new_families=[LogNormal()])
24-
assert 'Gamma\x1b[0m("b", mu=' in posterior_to_prior(
25-
model, idata, new_families={"b": [Gamma(mu=0)]}
26-
)
24+
assert 'Gamma("b", mu=' in posterior_to_prior(model, idata, new_families={"b": [Gamma(mu=0)]})
2725

2826

2927
# Temporarily disabled bambi test

0 commit comments

Comments
 (0)