Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
392 changes: 253 additions & 139 deletions MDANSE_GUI/Src/MDANSE_GUI/Tabs/Models/PlottingContext.py

Large diffs are not rendered by default.

200 changes: 104 additions & 96 deletions MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,19 @@
from itertools import islice
from typing import TYPE_CHECKING, Any

from matplotlib import rcParams
from more_itertools import flatten, ilen

from MDANSE.MLogging import LOG
from MDANSE_GUI.Tabs.Plotters.Plotter import Plotter

if TYPE_CHECKING:
import numpy as np
from matplotlib.axes import Axes
from matplotlib.backends.backend_qt5agg import NavigationToolbar2QT as Toolbar
from matplotlib.figure import Figure
from matplotlib.lines import Line2D

from MDANSE_GUI.Tabs.Models.PlottingContext import PlottingContext
from MDANSE_GUI.Tabs.Models.PlottingContext import PlotArgs, PlottingContext


@Plotter.register("Grid")
Expand All @@ -41,16 +45,15 @@ def __init__(self) -> None:
self._backup_limits = []
self._active_curves = []
self._backup_curves = []
self._plot_limit = 8
self._title_length_limit = 30
self._plot_limit = 9

def slider_labels(self) -> list[str]:
"""Return labels to show that sliders are not used."""
return ["Inactive", "Inactive"]

def slider_limits(self) -> list[str]:
def slider_limits(self) -> list[tuple[float, float, float]]:
"""Return generic slider limit values."""
return self._number_of_sliders * [[-1.0, 1.0, 0.01]]
return [(-1.0, 1.0, 0.01)] * self._number_of_sliders

def check_curve_lengths(self):
"""Find the maximum number of elements in the x axes of the plot data."""
Expand All @@ -70,148 +73,153 @@ def change_normalisation(self, new_value: dict[str, Any]):
"""
super().change_normalisation(new_value)
target = self._figure
if target is None:
return
if len(self._active_curves) == 0:
if target is None or not self._active_curves:
return

for curve_index, curve in enumerate(self._active_curves):
xdata = self._backup_curves[curve_index][0]
ydata = self._backup_curves[curve_index][1]
xdata, ydata = self._backup_curves[curve_index]
xdata, ydata = self.normalise_curve(xdata, ydata)
curve.set_xdata(xdata)
curve.set_ydata(ydata)

target.canvas.draw()

for axes in self._axes:
axes.relim()
axes.autoscale()

if self._toolbar is not None:
self._toolbar.update()
self._toolbar.push_current()

def title_fontsize(self, title_text: str) -> int:
normal_size = rcParams["font.size"]
new_size = (
normal_size
if len(title_text) < self._title_length_limit
else normal_size - round(len(title_text) / self._title_length_limit)
)
return new_size

def toggle_legend(self, enabled: bool) -> None:
if self._figure is None:
return
for plot_index, axes in enumerate(self._axes):

for axes, title in zip(self._axes, self._axes_titles, strict=True):
axes.set_title(
self._axes_titles[plot_index] if enabled else "",
fontsize=self.title_fontsize(self._axes_titles[plot_index]),
title if enabled else "",
fontsize=self.title_fontsize(title),
)

self._figure.canvas.draw()

def plot(
self,
plotting_context: PlottingContext,
figure: Figure = None,
update_only=False,
toolbar=None,
figure: Figure | None = None,
update_only: bool = False,
toolbar: Toolbar | None = None,
):
"""Plot datasets in separate subplots.

Parameters
----------
plotting_context : PlottingContext
Data model storing the data to be plotted
Data model storing the data to be plotted.
figure : Figure, optional
Matplotlib figure instance for plotting, by default None
Matplotlib figure instance for plotting, by default None.
update_only : bool, optional
If true, try to re-use zoom settings, by default False
toolbar : _type_, optional
GUI instance of the matplotlib toolbar, by default None
If true, try to re-use zoom settings, by default False.
toolbar : Toolbar, optional
GUI instance of the matplotlib toolbar, by default None.

"""
self.enable_slider(allow_slider=False)
target = self.get_figure(figure)

if target is None:
return

if toolbar is not None:
self._toolbar = toolbar

if plotting_context.set_axes() is None:
LOG.debug("Axis check failed.")
return

self._figure = target
self._axes = []
self._axes_titles = []
self._backup_curves = []
self._active_curves = []
self._normalisation_errors = []
self.apply_settings(plotting_context)
nplots = 0
for databundle in plotting_context.datasets().values():
ds = databundle.dataset
try:
axis_info = ds._axes_units[databundle.main_axis], databundle.main_axis
except KeyError:
axis_info = ds.longest_axis()
curves = ds.curves_vs_axis(axis_info, max_limit=self._plot_limit)
nplots += len(curves)
nplots = min(nplots, self._plot_limit)
gridsize = math.ceil(nplots**0.5)
startnum = 1
counter = 0
for databundle in plotting_context.datasets().values():
dataset = databundle.dataset
try:
_, best_axis = (
dataset._axes_units[databundle.main_axis],
databundle.main_axis,
)
except KeyError:
_, best_axis = dataset.longest_axis()
for key, curve in islice(dataset._curves.items(), self._plot_limit):
counter += 1
if counter > self._plot_limit:
LOG.warning(
"Curves above the current limit of %s will be ignored",
self._plot_limit,
)
break
axes = target.add_subplot(gridsize, gridsize, startnum)
self._axes.append(axes)
plotlabel = databundle.legend_label
if dataset._curve_labels[key]:
plotlabel += ":" + dataset._curve_labels[key]
x_axis_label = dataset.x_axis_label(best_axis)
[temp_curve] = axes.plot(
dataset.x_axis(best_axis),
curve,
linestyle=databundle.line_style,
color=databundle.colour,
label=plotlabel,
)
try:
temp_curve.set_marker(databundle.marker)
except ValueError:
with contextlib.suppress(Exception):
temp_curve.set_marker(int(databundle.marker))
axes.set_xlabel(x_axis_label)
self._axes_titles.append(plotlabel)
axes.set_title(
plotlabel if plotting_context.use_legend else "",
fontsize=self.title_fontsize(plotlabel),
)
legend = axes.legend()
legend.set_visible(False)
axes.grid(plotting_context.use_grid)
startnum += 1
self._active_curves.append(temp_curve)
self._backup_curves.append(
[temp_curve.get_xdata(), temp_curve.get_ydata()],
)
if counter == 0:
self.plot_blank()
return

self._n_curves = min(
sum(db.dataset.n_curves for db in plotting_context.datasets().values()),
self._plot_limit,
)
grid_size = self.grid_size(self.n_curves)
gs = self._figure.add_gridspec(*grid_size)

for ind, (databundle, label, curve) in enumerate(
islice(flatten(plotting_context.curves()), self._plot_limit)
):
axes = target.add_subplot(gs[ind])
self._axes_titles.append(f"{databundle.dataset._name} {label}")

self._plot_single(
axes,
curve,
databundle,
label="",
colour=databundle.colour,
)
self._axes.append(axes)

self.toggle_legend(plotting_context.use_legend)
self.toggle_grid(plotting_context.use_grid)

self.apply_settings(plotting_context)
self.check_curve_lengths()
target.canvas.draw()

if self._toolbar is not None:
self._toolbar.update()
self._toolbar.push_current()

def _plot_single(
self,
axes: Axes,
curve: tuple[np.ndarray, np.ndarray] | tuple[np.ndarray],
databundle: PlotArgs,
*,
label: str,
colour: tuple[float, float, float] | str,
):
"""Plot a single curve to axes.

Parameters
----------
axes : Axes
Axis to plot to.
curve : tuple[np.ndarray, np.ndarray] | tuple[np.ndarray]
Curve to plot.
databundle : PlotArgs
Data to plot.
label : str
Plot label.
colour : tuple[float, float, float] | str
Curve colour.
"""
lines: list[Line2D] = axes.plot(
*curve,
linestyle=databundle.line_style,
label=label,
color=colour,
)

axes.set_xlabel(databundle.dataset.x_axis_label(databundle.main_axis))

for line in lines:
try:
line.set_marker(databundle.marker)
except ValueError:
with contextlib.suppress(Exception):
line.set_marker(int(databundle.marker))

self._active_curves.extend(lines)
self._backup_curves.extend(
(line.get_xdata(), line.get_ydata()) for line in lines
)
Loading
Loading