
Progress bar issue with sampling from MCMCPosterior #1838

@MolinAlexei

Description

🐛 Bug Description

I get an error when sampling with method='nuts_pymc' from an MCMCPosterior built on a trained NLE density estimator:

from sbi.inference import MCMCPosterior, likelihood_estimator_based_potential

# density_estimator (trained NLE), prior and x_obs are defined earlier.
potential_fn, parameter_transform = likelihood_estimator_based_potential(
    density_estimator, prior, x_obs
)
posterior = MCMCPosterior(
    potential_fn,
    proposal=prior,
    theta_transform=parameter_transform,
    warmup_steps=1000,
    thin=1,
    method="nuts_pymc",
    num_chains=20,
    num_workers=32,
)

samples = posterior.sample((1000, 20), x=x_obs)

This raises the following error:

Exception in thread Thread-7:
Traceback (most recent call last):
  File "/home/mola/miniconda3/envs/myenv/lib/python3.10/threading.py", line 1016, in _bootstrap_inner
    self.run()
  File "/home/mola/miniconda3/envs/myenv/lib/python3.10/site-packages/rich/live.py", line 38, in run
    self.live.refresh()
  File "/home/mola/miniconda3/envs/myenv/lib/python3.10/site-packages/rich/live.py", line 247, in refresh
    self._live_render.set_renderable(self.renderable)
  File "/home/mola/miniconda3/envs/myenv/lib/python3.10/site-packages/rich/live.py", line 225, in renderable
    renderable = Group(*[live.get_renderable() for live in live_stack])
  File "/home/mola/miniconda3/envs/myenv/lib/python3.10/site-packages/rich/live.py", line 225, in <listcomp>
    renderable = Group(*[live.get_renderable() for live in live_stack])
  File "/home/mola/miniconda3/envs/myenv/lib/python3.10/site-packages/rich/live.py", line 105, in get_renderable
    self._get_renderable()
  File "/home/mola/miniconda3/envs/myenv/lib/python3.10/site-packages/rich/progress.py", line 1554, in get_renderable
    renderable = Group(*self.get_renderables())
  File "/home/mola/miniconda3/envs/myenv/lib/python3.10/site-packages/rich/progress.py", line 1559, in get_renderables
    table = self.make_tasks_table(self.tasks)
  File "/home/mola/miniconda3/envs/myenv/lib/python3.10/site-packages/pymc/progress_bar.py", line 157, in make_tasks_table
    table.add_row(
  File "/home/mola/miniconda3/envs/myenv/lib/python3.10/site-packages/pymc/progress_bar.py", line 162, in <genexpr>
    else call_column(column, task)
  File "/home/mola/miniconda3/envs/myenv/lib/python3.10/site-packages/pymc/progress_bar.py", line 133, in call_column
    return column(task)
    return column(task)
  File "/home/mola/miniconda3/envs/myenv/lib/python3.10/site-packages/rich/progress.py", line 540, in __call__
    renderable = self.render(task)
  File "/home/mola/miniconda3/envs/myenv/lib/python3.10/site-packages/rich/progress.py", line 636, in render
    _text = self.text_format.format(task=task)
TypeError: unsupported format string passed to numpy.ndarray.__format__

...
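The last frame shows rich's TextColumn calling self.text_format.format(task=task). Below is a minimal sketch of how that call can fail. It assumes, without confirming it, that one of the fields pymc's make_tasks_table hands to rich holds a multi-element NumPy array (for example one value per chain) instead of a scalar. The sketch uses only str.format and NumPy, not rich or pymc. The format string and the step_size field name are hypothetical.

```python
from types import SimpleNamespace

import numpy as np

# rich's TextColumn renders with `text_format.format(task=task)` (see the
# traceback above). A format spec such as ":0.2f" works on a Python float or
# a 0-d NumPy array, but not on an array with one or more dimensions.
# `step_size` is a hypothetical field name used only for illustration.
TEXT_FORMAT = "{task.fields[step_size]:0.2f}"


def render(value):
    """Format `value` the way a rich TextColumn would, via str.format."""
    task = SimpleNamespace(fields={"step_size": value})
    return TEXT_FORMAT.format(task=task)


print(render(0.1234))            # plain float formats fine
print(render(np.array(0.1234)))  # 0-d array: NumPy formats the scalar item

try:
    # Hypothetical per-chain values, e.g. one step size for each of 20 chains.
    render(np.full(20, 0.1234))
except TypeError as err:
    print(f"TypeError: {err}")
```

Under this assumption, the 0-d case works because NumPy delegates __format__ to the scalar item. An array with one or more dimensions falls back to object.__format__, which rejects any non-empty format spec and produces the same message as the traceback.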

My environment:

  • sbi version 0.25.0
  • pymc version 5.25.1
  • numpy version 1.26.4
  • rich version 14.3.3
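You can collect the versions above with the standard library's importlib.metadata instead of typing them by hand. Here is a small sketch. The package tuple mirrors the list above, and you can edit it for other reports.

```python
from importlib.metadata import PackageNotFoundError, version

# Packages listed in this report; edit the tuple to match your environment.
PACKAGES = ("sbi", "pymc", "numpy", "rich")


def installed_versions(packages=PACKAGES):
    """Map each distribution name to its installed version, or None if missing."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


for name, ver in installed_versions().items():
    print(f"{name} version {ver if ver is not None else 'not installed'}")
```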

Is this just a problem with my dependency versions?
The error does not occur if I disable the progress bar when sampling.

Metadata

  • Assignees: no one assigned
  • Labels: bug ("Something isn't working")