Skip to content

Commit 9c2f285

Browse files
committed
add mvnormal
1 parent 2140e14 commit 9c2f285

3 files changed

Lines changed: 219 additions & 12 deletions

File tree

preliz/distributions/continuous_multivariate.py

Lines changed: 198 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,17 @@ def plot_pdf(
144144
ax : matplotlib axis
145145
"""
146146
return plot_dirichlet(
147-
self, "pdf", marginals, pointinterval, interval, levels, support, baseline, legend, figsize, ax
147+
self,
148+
"pdf",
149+
marginals,
150+
pointinterval,
151+
interval,
152+
levels,
153+
support,
154+
baseline,
155+
legend,
156+
figsize,
157+
ax,
148158
)
149159

150160
def plot_cdf(
@@ -189,7 +199,17 @@ def plot_cdf(
189199
ax : matplotlib axis
190200
"""
191201
return plot_dirichlet(
192-
self, "cdf", "marginals", pointinterval, interval, levels, support, None, legend, figsize, ax
202+
self,
203+
"cdf",
204+
"marginals",
205+
pointinterval,
206+
interval,
207+
levels,
208+
support,
209+
None,
210+
legend,
211+
figsize,
212+
ax,
193213
)
194214

195215
def plot_ppf(
@@ -230,11 +250,19 @@ def plot_ppf(
230250
ax : matplotlib axis
231251
"""
232252
return plot_dirichlet(
233-
self, "ppf", "marginals", pointinterval, interval, levels, None, None, legend, figsize, ax
253+
self,
254+
"ppf",
255+
"marginals",
256+
pointinterval,
257+
interval,
258+
levels,
259+
None,
260+
None,
261+
legend,
262+
figsize,
263+
ax,
234264
)
235265

236-
237-
238266
def plot_sf(
239267
self,
240268
pointinterval=False,
@@ -277,7 +305,17 @@ def plot_sf(
277305
ax : matplotlib axis
278306
"""
279307
return plot_dirichlet(
280-
self, "sf", "marginals", pointinterval, interval, levels, support, None, legend, figsize, ax
308+
self,
309+
"sf",
310+
"marginals",
311+
pointinterval,
312+
interval,
313+
levels,
314+
support,
315+
None,
316+
legend,
317+
figsize,
318+
ax,
281319
)
282320

283321
def plot_isf(
@@ -318,11 +356,19 @@ def plot_isf(
318356
ax : matplotlib axis
319357
"""
320358
return plot_dirichlet(
321-
self, "isf", "marginals", pointinterval, interval, levels, None, None, legend, figsize, ax
359+
self,
360+
"isf",
361+
"marginals",
362+
pointinterval,
363+
interval,
364+
levels,
365+
None,
366+
None,
367+
legend,
368+
figsize,
369+
ax,
322370
)
323371

324-
325-
326372
def plot_interactive(
327373
self,
328374
kind="pdf",
@@ -534,6 +580,7 @@ def plot_pdf(
534580
interval=None,
535581
levels=None,
536582
support="full",
583+
baseline=True,
537584
legend="title",
538585
figsize=None,
539586
ax=None,
@@ -562,6 +609,9 @@ def plot_pdf(
562609
support : str:
563610
If ``full`` use the finite end-points to set the limits of the plot. For unbounded
564611
end-points or if ``restricted`` use the 0.001 and 0.999 quantiles to set the limits.
612+
baseline : bool
613+
Whether to include a baseline in the plot. Defaults to True. Only used when
614+
``marginals=True``.
565615
legend : str
566616
Whether to include a string with the distribution and its parameter as a ``"title"``
567617
or not include them ``None``.
@@ -575,7 +625,17 @@ def plot_pdf(
575625
ax : matplotlib axis
576626
"""
577627
return plot_mvnormal(
578-
self, "pdf", marginals, pointinterval, interval, levels, support, legend, figsize, ax
628+
self,
629+
"pdf",
630+
marginals,
631+
pointinterval,
632+
interval,
633+
levels,
634+
support,
635+
baseline,
636+
legend,
637+
figsize,
638+
ax,
579639
)
580640

581641
def plot_cdf(
@@ -620,7 +680,17 @@ def plot_cdf(
620680
ax : matplotlib axis
621681
"""
622682
return plot_mvnormal(
623-
self, "cdf", "marginals", pointinterval, interval, levels, support, legend, figsize, ax
683+
self,
684+
"cdf",
685+
"marginals",
686+
pointinterval,
687+
interval,
688+
levels,
689+
support,
690+
None,
691+
legend,
692+
figsize,
693+
ax,
624694
)
625695

626696
def plot_ppf(
@@ -661,7 +731,123 @@ def plot_ppf(
661731
ax : matplotlib axis
662732
"""
663733
return plot_mvnormal(
664-
self, "ppf", "marginals", pointinterval, interval, levels, None, legend, figsize, ax
734+
self,
735+
"ppf",
736+
"marginals",
737+
pointinterval,
738+
interval,
739+
levels,
740+
None,
741+
None,
742+
legend,
743+
figsize,
744+
ax,
745+
)
746+
747+
def plot_sf(
748+
self,
749+
pointinterval=False,
750+
interval=None,
751+
levels=None,
752+
support="full",
753+
legend="title",
754+
figsize=None,
755+
ax=None,
756+
):
757+
"""
758+
Plot the survival function (1 - CDF).
759+
760+
Parameters
761+
----------
762+
pointinterval : bool
763+
Whether to include a plot of the quantiles. Defaults to False. If True the default is to
764+
plot the median and two interquantiles ranges.
765+
interval : str
766+
Type of interval. Available options are highest density interval `"hdi"`,
767+
equal tailed interval `"eti"` or intervals defined by arbitrary `"quantiles"`.
768+
Defaults to the value in rcParams["stats.ci_kind"].
769+
levels : list
770+
Mass of the intervals. For hdi or eti the number of elements should be 2 or 1.
771+
For quantiles the number of elements should be 5, 3, 1 or 0
772+
(in this last case nothing will be plotted).
773+
support : str:
774+
If ``full`` use the finite end-points to set the limits of the plot. For unbounded
775+
end-points or if ``restricted`` use the 0.001 and 0.999 quantiles to set the limits.
776+
legend : str
777+
Whether to include a string with the distribution and its parameter as a ``"title"``
778+
or not include them ``None``.
779+
figsize : tuple
780+
Size of the figure
781+
ax : matplotlib axis
782+
Axis to plot on
783+
784+
Returns
785+
-------
786+
ax : matplotlib axis
787+
"""
788+
return plot_mvnormal(
789+
self,
790+
"sf",
791+
"marginals",
792+
pointinterval,
793+
interval,
794+
levels,
795+
support,
796+
None,
797+
legend,
798+
figsize,
799+
ax,
800+
)
801+
802+
def plot_isf(
803+
self,
804+
pointinterval=False,
805+
interval=None,
806+
levels=None,
807+
legend="title",
808+
figsize=None,
809+
ax=None,
810+
):
811+
"""
812+
Plot the inverse survival function.
813+
814+
Parameters
815+
----------
816+
pointinterval : bool
817+
Whether to include a plot of the quantiles. Defaults to False. If True the default is to
818+
plot the median and two interquantiles ranges.
819+
interval : str
820+
Type of interval. Available options are highest density interval `"hdi"`,
821+
equal tailed interval `"eti"` or intervals defined by arbitrary `"quantiles"`.
822+
Defaults to the value in rcParams["stats.ci_kind"].
823+
levels : list
824+
Mass of the intervals. For hdi or eti the number of elements should be 2 or 1.
825+
For quantiles the number of elements should be 5, 3, 1 or 0
826+
(in this last case nothing will be plotted).
827+
legend : str
828+
Whether to include a string with the distribution and its parameter as a ``"title"``
829+
or not include them ``None``.
830+
figsize : tuple
831+
Size of the figure
832+
ax : matplotlib axis
833+
Axis to plot on
834+
835+
Returns
836+
-------
837+
ax : matplotlib axis
838+
"""
839+
return plot_mvnormal(
840+
self,
841+
"isf",
842+
"marginals",
843+
pointinterval,
844+
interval,
845+
levels,
846+
None,
847+
None,
848+
legend,
849+
figsize,
850+
ax,
665851
)
666852

667853
def plot_interactive(

preliz/internal/plot_helper_multivariate.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,7 @@ def plot_mvnormal(
266266
interval,
267267
levels,
268268
support,
269+
baseline,
269270
legend,
270271
figsize,
271272
axes,
@@ -310,6 +311,7 @@ def plot_mvnormal(
310311
interval=interval,
311312
levels=levels,
312313
support=support,
314+
baseline=baseline,
313315
legend=False,
314316
ax=ax,
315317
)
@@ -330,6 +332,23 @@ def plot_mvnormal(
330332
legend=False,
331333
ax=ax,
332334
)
335+
elif representation == "sf":
336+
marginal_dist.plot_sf(
337+
pointinterval=pointinterval,
338+
interval=interval,
339+
levels=levels,
340+
support=support,
341+
legend=False,
342+
ax=ax,
343+
)
344+
elif representation == "isf":
345+
marginal_dist.plot_isf(
346+
pointinterval=pointinterval,
347+
interval=interval,
348+
levels=levels,
349+
legend=False,
350+
ax=ax,
351+
)
333352
if xy_lim != "auto" and representation != "ppf":
334353
ax.set_xlim(*xlim)
335354
if xy_lim != "auto" and representation != "cdf":

preliz/ppls/pymc_io.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,8 @@ def from_pymc(dist):
242242
PreliZ distribution
243243
"""
244244
name = dist.owner.op._print_name[0]
245+
if name == "MultivariateNormal":
246+
name = "MvNormal"
245247

246248
if name == "Censored":
247249
base_dist = dist.owner.inputs[0]

0 commit comments

Comments
 (0)