From 681acafb0ac20a89f4b4952b8861dc8d9ee44a78 Mon Sep 17 00:00:00 2001 From: Jacob Wilkins Date: Tue, 4 Nov 2025 15:22:36 +0000 Subject: [PATCH 1/9] General cleanup around plotting reducing duplication and fixing typing --- .../MDANSE_GUI/Tabs/Models/PlottingContext.py | 2 +- .../Src/MDANSE_GUI/Tabs/Plotters/Plotter.py | 50 ++++---- .../Src/MDANSE_GUI/Tabs/Plotters/Single.py | 108 ++++++++++++++---- 3 files changed, 118 insertions(+), 42 deletions(-) diff --git a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Models/PlottingContext.py b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Models/PlottingContext.py index b5beb9eaf1..90874427c6 100644 --- a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Models/PlottingContext.py +++ b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Models/PlottingContext.py @@ -52,7 +52,7 @@ class PlotArgs(NamedTuple): """Arguments for plotting data.""" - dataset: FloatArray | ComplexArray + dataset: SingleDataset colour: str line_style: str marker: str diff --git a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Plotter.py b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Plotter.py index e8a50a48c9..54ebfdfc68 100644 --- a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Plotter.py +++ b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Plotter.py @@ -25,7 +25,7 @@ from more_itertools import consumer from MDANSE.Core.RegisterFactory import RegisterFactory -from MDANSE.IO.IOUtils import UCDict +from MDANSE.IO.IOUtils import UCDict, UCEnum from MDANSE.MLogging import LOG from MDANSE.util_types import FloatArray @@ -39,13 +39,19 @@ from MDANSE_GUI.Tabs.Models.PlottingContext import PlottingContext -class NormOperations(enum.Enum): +class NormOperations(UCEnum): """Enum for selecting mathematical operations when calculating norms.""" AVERAGE = enum.auto() SUM = enum.auto() NOT_IMPLEMENTED = enum.auto() + @classmethod + def _missing_(cls, value: str) -> NormOperations: + if res := super()._missing_(value): + return res + return cls.NOT_IMPLEMENTED + def str_to_enum(operation: str) -> NormOperations: """Get the right enum from the input text string. @@ -53,19 +59,15 @@ def str_to_enum(operation: str) -> NormOperations: Parameters ---------- operation : str - name of the mathematical operation as string. + Name of the mathematical operation as string. Returns ------- NormOperations - enum value of the operation. + Enum value of the operation. """ - if operation == "average": - return NormOperations.AVERAGE - if operation == "sum": - return NormOperations.SUM - return NormOperations.NOT_IMPLEMENTED + return NormOperations(operation) def enum_to_str(operation: NormOperations) -> str: @@ -74,19 +76,15 @@ def enum_to_str(operation: NormOperations) -> str: Parameters ---------- operation : NormOperations - Enum of the mathematical operation + Enum of the mathematical operation. Returns ------- str - name of the operation as string + Name of the operation as string. """ - if operation == NormOperations.AVERAGE: - return "average" - if NormOperations.SUM: - return "sum" - return "not implemented" + return operation.name.lower() NORMALISATION_DEFAULTS = { @@ -96,6 +94,8 @@ def enum_to_str(operation: NormOperations) -> str: "operation": NormOperations.AVERAGE, } +ValidPlotters = Literal["Single", "Vectors", "Text", "Heatmap", "Grid"] + class Plotter(RegisterFactory): """Parent class to all classes used for displaying data.""" @@ -214,25 +214,30 @@ def normalise_curve( """ apply = self._normalisation_values["apply"] operation = self._normalisation_values["operation"] - if not apply or operation == NormOperations.NOT_IMPLEMENTED: + if not apply or operation is NormOperations.NOT_IMPLEMENTED: return xdata, ydata + min_index = self._normalisation_values["min_index"] max_index = self._normalisation_values["max_index"] ref_values = ydata[min_index:max_index] + if len(ref_values) < 1: self._normalisation_errors.append( "No points within the specified index range" ) return xdata, ydata - if operation == NormOperations.AVERAGE: + + if operation is NormOperations.AVERAGE: scale_factor = np.nanmean(ref_values) - elif operation == NormOperations.SUM: + elif operation is NormOperations.SUM: scale_factor = np.sum(np.nan_to_num(ref_values)) + if np.isclose(scale_factor, 0.0): self._normalisation_errors.append( "Normalisation factor is 0 and will not be applied." ) return xdata, ydata + return xdata, ydata / scale_factor def normalise_array(self, data_array: FloatArray) -> FloatArray: @@ -258,15 +263,18 @@ def normalise_array(self, data_array: FloatArray) -> FloatArray: ref_column = data_array[:, min_index:max_index] if ref_column.shape[1] < 1: return data_array - if operation == NormOperations.AVERAGE: + + if operation is NormOperations.AVERAGE: scale_column = np.nanmean(ref_column, axis=1) - elif operation == NormOperations.SUM: + elif operation is NormOperations.SUM: scale_column = np.sum(np.nan_to_num(ref_column), axis=1) + if np.any(np.isclose(scale_column, 0.0)): self._normalisation_errors.append( "Normalisation factor is 0 for some rows of the 2D array." ) return data_array + return data_array / scale_column.reshape((len(scale_column), 1)) def change_normalisation(self, new_value: dict[str, Any]): diff --git a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Single.py b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Single.py index 065ca8d26b..2731783976 100644 --- a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Single.py +++ b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Single.py @@ -21,14 +21,17 @@ import numpy as np from matplotlib.colors import to_rgb +from more_itertools import one from MDANSE.MLogging import LOG from MDANSE_GUI.Tabs.Plotters.Plotter import Plotter if TYPE_CHECKING: + from matplotlib.axes import Axes from matplotlib.figure import Figure + from matplotlib.lines import Line2D - from MDANSE_GUI.Tabs.Models.PlottingContext import PlottingContext + from MDANSE_GUI.Tabs.Models.PlottingContext import PlotArgs, PlottingContext @Plotter.register("Single") @@ -39,8 +42,8 @@ def __init__(self) -> None: """Initialise all ploting parameters to default values.""" super().__init__() self._figure = None - self._active_curves = [] - self._backup_curves = [] + self._active_curves: list[Line2D] = [] + self._backup_curves: list[Line2D] = [] self._backup_limits = [] self._curve_limit_per_dataset = 12 self.height_max, self.length_max = 0.0, 0.0 @@ -80,7 +83,7 @@ def change_normalisation(self, new_value: dict[str, Any]): Parameters ---------- new_value : dict[str, Any] - parameters as in NORMALISATION_DEFAULTS + Parameters as in NORMALISATION_DEFAULTS. """ super().change_normalisation(new_value) @@ -89,12 +92,12 @@ def change_normalisation(self, new_value: dict[str, Any]): def offset_curves(self): """Offset curves against each other based on slider settings.""" target = self._figure - if target is None: - return - if len(self._active_curves) == 0: + if target is None or not self._active_curves: return + new_value = self._slider_values saved_xmin, saved_xmax, saved_ymin, saved_ymax = self._backup_limits + for num, curve in enumerate(self._active_curves): xdata = self._backup_curves[num][0] ydata = self._backup_curves[num][1] @@ -112,15 +115,18 @@ def offset_curves(self): self._backup_limits = [saved_xmin, saved_xmax, saved_ymin, saved_ymax] self._axes[0].relim() self._axes[0].autoscale() + if self._toolbar is not None: self._toolbar.update() self._toolbar.push_current() + try: self._axes[0].set_xlim(saved_xmin, saved_xmax) except ValueError: LOG.error( f"Matplotlib could not set x limits to {saved_xmin}, {saved_xmax}", ) + try: self._axes[0].set_ylim(saved_ymin, saved_ymax) except ValueError: @@ -140,7 +146,7 @@ def plot( self, plotting_context: PlottingContext, figure: Figure = None, - update_only=False, + update_only: bool = False, toolbar=None, ): """Plot all datasets in the same figure. @@ -158,29 +164,39 @@ def plot( """ self.enable_slider(allow_slider=False) + target = self.get_figure(figure) if target is None: return + if toolbar is not None: self._toolbar = toolbar + self._figure = target self._figure.set_layout_engine("none") self._active_curves = [] self._backup_curves = [] self._normalisation_errors = [] + axes = target.add_subplot(111) self._axes = [axes] self.apply_settings(plotting_context) x_axis_labels = [] + self.height_max, self.length_max = 0.0, 0.0 + if plotting_context.set_axes() is None: LOG.debug("Axis check failed.") return + if len(plotting_context.datasets()) == 0: target.clear() target.canvas.draw() + for databundle in plotting_context.datasets().values(): dataset = databundle.dataset + plotlabel = databundle.legend_label + try: best_unit, best_axis = ( dataset._axes_units[databundle.main_axis], @@ -188,15 +204,16 @@ def plot( ) except KeyError: best_unit, best_axis = dataset.longest_axis() - plotlabel = databundle.legend_label + x_axis_labels.append(dataset.x_axis_label(best_axis)) + if dataset._n_dim == 1: - [temp] = axes.plot( - dataset.x_axis(best_axis), - dataset.data, - linestyle=databundle.line_style, + self._plot_single( + axes, + databundle, + best_axis, label=plotlabel, - color=databundle.colour, + colour=databundle.colour, ) try: temp.set_marker(databundle.marker) @@ -219,16 +236,17 @@ def plot( colour_increment = (0.5 - main_colour) / min( self._curve_limit_per_dataset, len(multi_curves) ) + for key, value in islice( multi_curves.items(), self._curve_limit_per_dataset ): try: - [temp] = axes.plot( - dataset.x_axis(best_axis), - value, - label=plotlabel + ":" + dataset._curve_labels[key], - linestyle=databundle.line_style, - color=tuple(main_colour), + self._plot_single( + axes, + databundle, + best_axis, + label=f"{plotlabel}:{dataset._curve_labels[key]}", + colour=tuple(main_colour), ) try: temp.set_marker(databundle.marker) @@ -249,11 +267,14 @@ def plot( LOG.error(f"values={value}") return main_colour += colour_increment + if len(self._backup_curves) > 1: self.enable_slider(allow_slider=True) + elif not self._backup_curves: self.plot_blank() return + if update_only: try: axes.set_xlim((self._backup_limits[0], self._backup_limits[1])) @@ -261,6 +282,7 @@ def plot( LOG.error( f"Matplotlib could not set x limits to {self._backup_limits[0]}, {self._backup_limits[1]}" ) + try: axes.set_ylim((self._backup_limits[2], self._backup_limits[3])) except ValueError: @@ -270,9 +292,55 @@ def plot( else: xlimits, ylimits = axes.get_xlim(), axes.get_ylim() self._backup_limits = [xlimits[0], xlimits[1], ylimits[0], ylimits[1]] + axes.set_xlabel(", ".join(np.unique(x_axis_labels))) legend = axes.legend() legend.set_visible(plotting_context.use_legend) axes.grid(plotting_context.use_grid) self.check_curve_lengths() self.offset_curves() + + def _plot_single( + self, + axes: Axes, + databundle: PlotArgs, + best_axis: str, + *, + label: str, + colour: tuple[float, float, float], + ): + """Plot a single curve to axes. + + Parameters + ---------- + axes : Axes + Axis to plot to. + databundle : FIXME: Add type. + FIXME: Add docs. + best_axis : str + Axis label of X-axis. + label : str + Plot label. + colour : FIXME: Add type. + Curve colour. + """ + temp: Line2D = one( + axes.plot( + databundle.dataset.x_axis(best_axis), + databundle.dataset.data, + linestyle=databundle.line_style, + label=databundle.legend_label, + color=databundle.colour, + ) + ) + + try: + temp.set_marker(databundle.marker) + except ValueError: + with contextlib.suppress(Exception): + temp.set_marker(int(databundle.marker)) + + self._active_curves.append(temp) + self._backup_curves.append([temp.get_xdata(), temp.get_ydata()]) + self.height_max = max(self.height_max, temp.get_ydata().max()) + self.length_max = max(self.length_max, temp.get_xdata().max()) From 348bf08f808203ffbba6f580286ca83650980d85 Mon Sep 17 00:00:00 2001 From: Jacob Wilkins Date: Wed, 5 Nov 2025 16:43:18 +0000 Subject: [PATCH 2/9] Refactor heatmap, partial refactor of single --- .../MDANSE_GUI/Tabs/Models/PlottingContext.py | 128 ++++++++----- .../Src/MDANSE_GUI/Tabs/Plotters/Heatmap.py | 168 +++++++++--------- .../Src/MDANSE_GUI/Tabs/Plotters/Single.py | 83 +++++---- 3 files changed, 215 insertions(+), 164 deletions(-) diff --git a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Models/PlottingContext.py b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Models/PlottingContext.py index 90874427c6..7a9b0c2936 100644 --- a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Models/PlottingContext.py +++ b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Models/PlottingContext.py @@ -17,6 +17,7 @@ import copy import functools +from collections.abc import Generator, Iterable, Sequence from contextlib import suppress from itertools import islice from math import prod @@ -31,7 +32,7 @@ from matplotlib.colors import to_hex as mpl_to_hex from matplotlib.lines import lineStyles from matplotlib.markers import MarkerStyle -from more_itertools import nth_product +from more_itertools import first, locate, nth, nth_product, sort_together, unzip from qtpy.QtCore import QModelIndex, Qt, Signal, Slot from qtpy.QtGui import QColor, QStandardItem, QStandardItemModel @@ -99,9 +100,7 @@ def __init__( self._curve_labels: dict[tuple[int, ...], str] = {} self._linestyle = linestyle self._marker = marker - self._planes: dict[int, FloatArray] = {} - self._plane_labels: dict[int, str] = {} - self._data_limits: list[int] | None = None + self._data_limits = None self._imaginary_data = None self._valid = True self._scaling_factor = 1.0 @@ -237,11 +236,11 @@ def _( self._axes_order.append(axis_key) self._axes_scaling[axis_key] = 1.0 self._current_units[axis_key] = self._axes_units[axis_key] - self._axes_tag = "|".join([str(x) for x in self._axes]) + self._axes_tag = "|".join(str(x) for x in self._axes) return for axis_key, axis_array in plot_axes.items(): - self._axes_tag = "|".join([str(x) for x in plot_axes]) + self._axes_tag = "|".join(str(x) for x in plot_axes) self._axes[axis_key] = axis_array self._axes_units[axis_key] = ( "N/A" if axes_units is None else axes_units[axis_key] @@ -266,9 +265,11 @@ def create_axes_tags(self, axes_tag: str, source: h5py.File): self._axes[f"index{dim_number}"] = np.arange(dim_length) self._axes_units[f"index{dim_number}"] = "N/A" return + self._current_units = {} self._axes_scaling = {} self._axes_order = [] + for ax_number, axis_name in enumerate(axes_tag.split("|")): aname = axis_name.strip() if aname == "index": @@ -326,6 +327,7 @@ def set_current_units(self, unit_lookup): """Update the unit based on the unit lookup of the PlottingContext.""" if unit_lookup is None: return + for axis_name, axis_unit in self._axes_units.items(): factor, new_unit = unit_lookup.conversion_factor(axis_unit) self._axes_scaling[axis_name] = factor @@ -633,7 +635,13 @@ def curves_vs_axis( return self._curves def curve_ind(self, limits: int, /) -> Iterator[int]: - """Return a generator of indices indexing only the curves within the limits.""" + """Return a generator of indices indexing only the curves within the limits. + + Parameters + ---------- + limits : int + Max number of curves to return. + """ return ( islice(self._data_limits, limits) if self._data_limits is not None @@ -644,7 +652,7 @@ def planes_vs_axis( self, axis_number: int, max_limit: int = 1, - ) -> list[FloatArray] | FloatArray | None: + ) -> Generator[tuple[str, FloatArray]]: """Prepare for plotting 2D subsets of an ND array. Parameters @@ -652,49 +660,81 @@ def planes_vs_axis( axis_number : int index of the axis perpendicular to the plotted array max_limit : int, optional - maximum number of planes allowed by plotter, by default 1 + Maximum number of curves allowed by plotter, by default 1 + + Yields + ------ + str + Grid label. + np.ndarray + 2D array. + + """ + match self._data.ndim: + case 1: + pass + case 2: + if axis_number == 1: + yield self._labels["medium"], self.data.T + else: + yield self._labels["medium"], self.data + case 3: + perpendicular_axis_name, perpendicular_axis = nth( + self._axes.items(), axis_number + ) + + reordered_view = np.moveaxis(self.data, axis_number, 0) + + for plane_number in self.curve_ind(max_limit): + yield ( + f"{self._labels['minimal']}:{perpendicular_axis_name}={perpendicular_axis[plane_number]}", + reordered_view[plane_number], + ) + case _: + raise NotImplementedError( + f"Cannot handle {self._data.ndim}-dimensional data." + ) + + def main_axis_index(self, main_axis: str, *, default: int) -> int: + """Find index of main axis. + + Parameters + ---------- + main_axis : str + Main axis name to search for. + default : int + Index if ``main_axis`` not found. Returns ------- - list[FloatArray] - List of 2D arrays for heatmap plots - + int + Index of main axis. """ - self._planes = {} - self._plane_labels = {} - _found = -1 - total_ndim = self._data.ndim + return first(locate(self._axes, pred=lambda x: x == main_axis), default) - if total_ndim == 1: - return None - if total_ndim == 2: - return self.data + def axes_main_order( + self, main_axis: str | None = None, ind: int | None = None + ) -> Sequence[str]: + """Return axis keys with ``main_axis`` first then the others. - data_shape = self._data.shape - number_of_planes = data_shape[axis_number] - perpendicular_axis = None - perpendicular_axis_name = "" - slice_def = [] - - for number, (axis_name, axis_array) in enumerate(self._axes.items()): - if number == axis_number: - slice_def.append(0) - perpendicular_axis = axis_array - perpendicular_axis_name = self.axis_true_name(axis_name) - else: - slice_def.append(slice(None)) - - for plane_number in self.curve_ind(max_limit): - if plane_number >= number_of_planes: - break - fixed_argument = perpendicular_axis[plane_number] - slice_def[axis_number] = plane_number - self._planes[plane_number] = self.data[tuple(slice_def)] - self._plane_labels[plane_number] = ( - f"{perpendicular_axis_name}={fixed_argument}" - ) + Parameters + ---------- + main_axis : str, optional + Name of main axis to move to front. + ind : int, optional + Main axis by index (if `main_axis` not found). - return None + Returns + ------- + Sequence[str] + Reordered axes. + """ + main_ind = self.main_axis_index(main_axis, default=ind) + return sort_together( + unzip(enumerate(self._axes)), + key=lambda x: x == main_ind, + reverse=True, + )[1] plotting_column_labels = [ diff --git a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Heatmap.py b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Heatmap.py index 0e51b13ebc..2faf2413fa 100644 --- a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Heatmap.py +++ b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Heatmap.py @@ -18,43 +18,60 @@ import csv import math from collections.abc import Iterator -from typing import TYPE_CHECKING, Any, TextIO +from dataclasses import dataclass +from itertools import islice +from typing import TYPE_CHECKING, Any, NamedTuple, TextIO import numpy as np from matplotlib.axes import Axes from matplotlib.image import AxesImage from matplotlib.pyplot import colorbar as mpl_colorbar +from more_itertools import first, ilen, locate from scipy.interpolate import interp1d from MDANSE.MLogging import LOG +from MDANSE_GUI.Tabs.Models.PlottingContext import PlotArgs from MDANSE_GUI.Tabs.Plotters.Plotter import Plotter if TYPE_CHECKING: from matplotlib.figure import Figure + from matplotlib.image import AxesImage from MDANSE_GUI.Tabs.Models.PlottingContext import PlottingContext +GRID_SIZES = { + 2: (2, 1), + 5: (2, 3), + 6: (2, 3), +} + + @Plotter.register("Heatmap") class Heatmap(Plotter): """Creates a 2D heatmap plot.""" + @dataclass + class BackupInfo: + ind: int + image: AxesImage + array: np.ndarray + minmax: tuple[float, float] + limits: tuple[float, float, float, float] + interp: interp1d + def __init__(self) -> None: """Initialise all plotting parameters to defaults.""" super().__init__() self._figure = None - self._backup_images = {} - self._backup_arrays = {} - self._backup_minmax = {} - self._backup_scale_interpolators = {} + self._backup: dict[int, Heatmap.BackupInfo] = {} self._current_x_axes = [] - self._backup_limits = {} self._initial_values = [0.0, 100.0] self._slider_values = [0.0, 100.0] self._slice_axis = 2 - self._plot_limit = 1 + self._plot_limit = 9 - def clear(self, figure: Figure = None): + def clear(self, figure: Figure | None = None): """Clear the figure.""" target = self._figure if figure is None else figure if target is None: @@ -73,7 +90,7 @@ def sliders_coupled(self) -> bool: """Confirm that sliders are coupled in heatmap mode.""" return True - def get_figure(self, figure: Figure = None): + def get_figure(self, figure: Figure | None = None): """Return current figure which will be used for plotting.""" target = self._figure if figure is None else figure if target is None: @@ -92,10 +109,10 @@ def change_normalisation(self, new_value: dict[str, Any]): """ super().change_normalisation(new_value) - for ds_num, image in self._backup_images.items(): - data = self._backup_arrays[ds_num] + for backup in self._backup.values(): + data = backup.array new_data = self.normalise_array(data) - image.set_data(new_data) + backup.image.set_data(new_data) percentiles = np.linspace(0, 100.0, 21) results = np.percentile(np.nan_to_num(new_data), percentiles) self._backup_scale_interpolators[ds_num] = interp1d( @@ -108,20 +125,23 @@ def handle_slider(self, new_value: list[float]): """Adjust colormap values based on slider values.""" super().handle_slider(new_value) target = self._figure - if target is None: - return - if new_value[1] <= new_value[0]: + + if target is None or new_value[1] <= new_value[0]: return + self._slider_values = [new_value[0], new_value[1]] - for ds_num, image in self._backup_images.items(): + + for backup in self._backup.values(): try: - last_minmax = self._backup_minmax[ds_num] + last_minmax = backup.minmax except KeyError: - self._backup_minmax[ds_num] = [-1, -1] + backup.minmax = (-1, -1) last_minmax = [-1, -1] - interpolator = self._backup_scale_interpolators[ds_num] + + interpolator = backup.interp newmax = interpolator(new_value[1]) newmin = interpolator(new_value[0]) + if newmax < newmin: if newmax == last_minmax[1]: newmin = float(newmax) @@ -131,27 +151,25 @@ def handle_slider(self, new_value: list[float]): return if newmax >= newmin: try: - image.set_clim([newmin, newmax]) + backup.image.set_clim([newmin, newmax]) except ValueError: LOG.error( f"Matplotlib could not set colorbar limits to {newmin}, {newmax}" ) else: self._figure.canvas.draw_idle() - self._backup_minmax[ds_num] = [newmin, newmax] + backup.minmax = [newmin, newmax] target.canvas.draw() def check_curve_lengths(self): """Find the maximum number of elements in the x axes of the plot data.""" - self.curve_length_limit = 0 - for xdata in self._current_x_axes: - self.curve_length_limit = max(self.curve_length_limit, len(xdata)) + self.curve_length_limit = max(map(len, self._current_x_axes), default=0) def plot( self, plotting_context: PlottingContext, - figure: Figure = None, - update_only=False, + figure: Figure | None = None, + update_only: bool = False, toolbar=None, ): """Plot the first dataset as a heatmap. @@ -166,48 +184,41 @@ def plot( If true, try to re-use zoom settings, by default False toolbar : _type_, optional GUI instance of the matplotlib toolbar, by default None - """ self.enable_slider(allow_slider=True) target = self.get_figure(figure) if target is None: return + if toolbar is not None: self._toolbar = toolbar + self._figure = target self._current_x_axes = [] - self._normalisation_errors = [] - self._backup_images = {} - self._backup_arrays = {} - self._backup_scale_interpolators = {} + minmax_bak = {key: val.minmax for key, val in self._backup.items()} + scale_interpolators = {val.ind: val.interp for val in self._backup.values()} + self._backup = {} self._axes = [] + self.apply_settings(plotting_context) if plotting_context.set_axes() is None: LOG.debug("Axis check failed.") return - nplots = 0 + + def get_planes() -> Iterator[tuple[PlotArgs, str, np.ndarray]]: + for databundle in plotting_context.datasets().values(): + ds = databundle.dataset + + for label, plane in ds.planes_vs_axis( + ds.main_axis_index(databundle.main_axis, default=self._slice_axis), + max_limit=self._plot_limit, + ): + yield databundle, label, plane + + # Check interpolators for databundle in plotting_context.datasets().values(): - if nplots >= self._plot_limit: - break - ds = databundle.dataset - if ds._n_dim == 1: - continue - elif ds._n_dim == 3: - replacement_axis_number = None - for number, axis_name in enumerate(ds._axes.keys()): - if axis_name == databundle.main_axis: - replacement_axis_number = number - if replacement_axis_number is None: - ds.planes_vs_axis(self._slice_axis, max_limit=self._plot_limit) - else: - ds.planes_vs_axis( - replacement_axis_number, max_limit=self._plot_limit - ) - nplots += len(ds._planes) - else: - nplots += 1 try: - self._backup_scale_interpolators[databundle.row](51.2) + scale_interpolators[databundle.row](51.2) except Exception: percentiles = np.linspace(0, 100.0, 21) results = [ @@ -217,45 +228,25 @@ def plot( percentiles, results, ) - nplots = min(nplots, self._plot_limit) - gridsize = math.ceil(nplots**0.5) - startnum = 1 - for ds_index, databundle in enumerate(plotting_context.datasets().values()): - if ds_index >= self._plot_limit: - break + + nplots = min(ilen(get_planes()), self._plot_limit) + gridsize = GRID_SIZES.get(nplots, (math.ceil(nplots**0.5),) * 2) + + for ind, (databundle, label, plane) in enumerate( + islice(get_planes(), self._plot_limit), + ): dataset = databundle.dataset - transposed = False - primary_axis_number = 0 limits = [] x_axis_labels, y_axis_labels = [], [] - for number, axis_name in enumerate(ds._axes.keys()): - if axis_name == databundle.main_axis: - primary_axis_number = number - if dataset._n_dim == 1: - continue - if dataset._n_dim == 3: - all_numbers, all_datasets = ( - list(dataset._planes.keys()), - list(dataset._planes.values()), - ) - all_labels = [dataset._plane_labels[number] for number in all_numbers] - for counter, name in enumerate(dataset._axes.keys()): - if counter == primary_axis_number: - continue - axis_array = dataset.x_axis(name) - limits += [ - axis_array[0], - axis_array[-1], - ] - if not x_axis_labels: - x_axis_labels.append(dataset.x_axis_label(name)) - self._current_x_axes.append(axis_array) - else: - y_axis_labels.append(dataset.x_axis_label(name)) - else: - all_numbers = [0] - if primary_axis_number == 0: - all_datasets = [dataset._data.T] + + for name in dataset.axes_main_order( + databundle.main_axis, ind=self._slice_axis + ): + axis_array = dataset.x_axis(name) + limits += [axis_array[0], axis_array[-1]] + if not x_axis_labels: + x_axis_labels.append(dataset.x_axis_label(name)) + self._current_x_axes.append(axis_array) else: all_datasets = [dataset._data] transposed = True @@ -354,6 +345,7 @@ def plot( legend = axes.legend() legend.set_visible(plotting_context.use_legend) axes.grid(plotting_context.use_grid) + self.check_curve_lengths() self.request_slider_values() target.canvas.draw() diff --git a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Single.py b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Single.py index 2731783976..50ed2b5728 100644 --- a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Single.py +++ b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Single.py @@ -16,12 +16,12 @@ from __future__ import annotations import contextlib +from collections.abc import Generator from itertools import islice from typing import TYPE_CHECKING, Any import numpy as np from matplotlib.colors import to_rgb -from more_itertools import one from MDANSE.MLogging import LOG from MDANSE_GUI.Tabs.Plotters.Plotter import Plotter @@ -43,19 +43,19 @@ def __init__(self) -> None: super().__init__() self._figure = None self._active_curves: list[Line2D] = [] - self._backup_curves: list[Line2D] = [] + self._backup_curves: list[tuple[np.ndarray, np.ndarray]] = [] self._backup_limits = [] self._curve_limit_per_dataset = 12 self.height_max, self.length_max = 0.0, 0.0 - def clear(self, figure: Figure = None): + def clear(self, figure: Figure | None = None): """Clear the figure.""" target = self._figure if figure is None else figure if target is None: return target.clear() - def get_figure(self, figure: Figure = None): + def get_figure(self, figure: Figure | None = None): """Return the figure instance used for plotting.""" target = self._figure if figure is None else figure if target is None: @@ -84,7 +84,6 @@ def change_normalisation(self, new_value: dict[str, Any]): ---------- new_value : dict[str, Any] Parameters as in NORMALISATION_DEFAULTS. - """ super().change_normalisation(new_value) self.offset_curves() @@ -99,8 +98,7 @@ def offset_curves(self): saved_xmin, saved_xmax, saved_ymin, saved_ymax = self._backup_limits for num, curve in enumerate(self._active_curves): - xdata = self._backup_curves[num][0] - ydata = self._backup_curves[num][1] + xdata, ydata = self._backup_curves[num] xdata, ydata = self.normalise_curve(xdata, ydata) new_xdata = xdata + num * self.length_max * new_value[1] new_ydata = ydata + num * self.height_max * new_value[0] @@ -145,7 +143,7 @@ def check_curve_lengths(self): def plot( self, plotting_context: PlottingContext, - figure: Figure = None, + figure: Figure | None = None, update_only: bool = False, toolbar=None, ): @@ -232,9 +230,9 @@ def plot( multi_curves = dataset.curves_vs_axis( (best_unit, best_axis), max_limit=self._curve_limit_per_dataset ) - main_colour = np.array(to_rgb(databundle.colour)) - colour_increment = (0.5 - main_colour) / min( - self._curve_limit_per_dataset, len(multi_curves) + colours = self.colours( + databundle.colour, + min(self._curve_limit_per_dataset, len(multi_curves)), ) for key, value in islice( @@ -246,7 +244,7 @@ def plot( databundle, best_axis, label=f"{plotlabel}:{dataset._curve_labels[key]}", - colour=tuple(main_colour), + colour=next(colours), ) try: temp.set_marker(databundle.marker) @@ -266,7 +264,6 @@ def plot( LOG.error(f"x_axis={dataset._axes[best_axis]}") LOG.error(f"values={value}") return - main_colour += colour_increment if len(self._backup_curves) > 1: self.enable_slider(allow_slider=True) @@ -300,6 +297,26 @@ def plot( self.check_curve_lengths() self.offset_curves() + @staticmethod + def colours(colour: str, n_curves: int) -> Generator[tuple[float, float, float]]: + """Generate colours from root colour. + + Parameters + ---------- + colour : str + Root colour. + + Returns + ------- + Generator[tuple[float, float, float]] + Next colour in sequence. + """ + main_colour = np.array(to_rgb(colour)) + colour_increment = (0.5 - main_colour) / n_curves + for _ in range(n_curves): + main_colour += colour_increment + yield tuple(main_colour) + def _plot_single( self, axes: Axes, @@ -315,8 +332,8 @@ def _plot_single( ---------- axes : Axes Axis to plot to. - databundle : FIXME: Add type. - FIXME: Add docs. + databundle : PlotArgs + Data to plot. best_axis : str Axis label of X-axis. label : str @@ -324,23 +341,25 @@ def _plot_single( colour : FIXME: Add type. Curve colour. """ - temp: Line2D = one( - axes.plot( - databundle.dataset.x_axis(best_axis), - databundle.dataset.data, - linestyle=databundle.line_style, - label=databundle.legend_label, - color=databundle.colour, - ) + lines: list[Line2D] = axes.plot( + databundle.dataset.x_axis(best_axis), + databundle.dataset.data, + linestyle=databundle.line_style, + label=label, + color=colour, ) - try: - temp.set_marker(databundle.marker) - except ValueError: - with contextlib.suppress(Exception): - temp.set_marker(int(databundle.marker)) + for line in lines: + try: + line.set_marker(databundle.marker) + except ValueError: + with contextlib.suppress(Exception): + line.set_marker(int(databundle.marker)) + + self.height_max = max(self.height_max, line.get_ydata().max()) + self.length_max = max(self.length_max, line.get_xdata().max()) - self._active_curves.append(temp) - self._backup_curves.append([temp.get_xdata(), temp.get_ydata()]) - self.height_max = max(self.height_max, temp.get_ydata().max()) - self.length_max = max(self.length_max, temp.get_xdata().max()) + self._active_curves.extend(lines) + self._backup_curves.extend( + (line.get_xdata(), line.get_ydata()) for line in lines + ) From e3e55e5ebc346087ed95c39e78878665fb4abd20 Mon Sep 17 00:00:00 2001 From: Jacob Wilkins Date: Wed, 5 Nov 2025 23:43:58 +0000 Subject: [PATCH 3/9] Further refactor of heatmap and single and grid --- .../MDANSE_GUI/Tabs/Models/PlottingContext.py | 95 ++++++--- .../Src/MDANSE_GUI/Tabs/Plotters/Grid.py | 104 ++++++++-- .../Src/MDANSE_GUI/Tabs/Plotters/Heatmap.py | 45 ++--- .../Src/MDANSE_GUI/Tabs/Plotters/Plotter.py | 36 +++- .../Src/MDANSE_GUI/Tabs/Plotters/Single.py | 188 ++++++------------ .../Src/MDANSE_GUI/Tabs/Plotters/Text.py | 17 +- .../Src/MDANSE_GUI/Tabs/Plotters/Vectors.py | 23 ++- .../MDANSE_GUI/Tabs/Visualisers/PlotWidget.py | 12 +- 8 files changed, 312 insertions(+), 208 deletions(-) diff --git a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Models/PlottingContext.py b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Models/PlottingContext.py index 7a9b0c2936..c3dbe2c108 100644 --- a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Models/PlottingContext.py +++ b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Models/PlottingContext.py @@ -96,8 +96,6 @@ def __init__( ): self._name = name self._use_scaling = True - self._curves: dict[tuple[int, ...], FloatArray] = {} - self._curve_labels: dict[tuple[int, ...], str] = {} self._linestyle = linestyle self._marker = marker self._data_limits = None @@ -527,7 +525,9 @@ def generate_curve_label( """ if self._n_dim < 2: return "" + label = "at " + for axis_index, axis_name in enumerate(axis_lookup): axis_label = self.axis_true_name(axis_name) @@ -557,6 +557,16 @@ def generate_curve_label( return float(picked_value) return label.rstrip(", ") + def curve_ind(self, limits: int | None = None, /): + """Get indices of valid axes. + + Parameters + ---------- + limits : int + Max number of curves to return. + """ + return islice(self._data_limits, limits) + def curves_vs_axis( self, x_axis_details: tuple[str, str], @@ -568,8 +578,8 @@ def curves_vs_axis( Parameters ---------- - x_axis_details : tuple[str, str] - Name and original unit of the primary plotting axis + main_axis : str + Name and original unit of the primary plotting axis. max_limit : int, optional Maximum number of curves allowed by plotter, by default 1 skip_label_text: bool, optional @@ -580,17 +590,22 @@ def curves_vs_axis( dict[int, FloatArray] List of data arrays ready for plotting + Yields + ------ + str + Plot label. + np.ndarray + x-axis. + np.ndarray + Curve to plot. """ - self._curves = {} - self._curve_labels = {} + x_axis = self.x_axis(main_axis) if self._data.ndim == 1: - self._curves[(0,)] = self.data - self._curve_labels[(0,)] = "" - return self.data + yield None, (x_axis, self.data) + return data_shape = self._data.shape - x_axis_unit, x_axis_name = x_axis_details slicer = [] indexer = [] label_lookup = [] @@ -600,8 +615,7 @@ def curves_vs_axis( raise ValueError("Array shape does not match the order of the axes") for current_dim, axis_name in enumerate(self._axes_order): - axis_unit = self._axes_units[axis_name] - if axis_unit == x_axis_unit and axis_name == x_axis_name: + if axis_name == main_axis: slicer.append([slice(None)]) continue @@ -612,13 +626,16 @@ def curves_vs_axis( if not indexer: LOG.warning("Empty selection for data set %s", self._name) - return self._curves + return for index in self.curve_ind(max_limit): try: index_tuple = nth_product(index, *indexer) index_slicer = nth_product(index, *slicer) - self._curves[index_tuple] = self.data[index_slicer].squeeze() + yield ( + self.generate_curve_label(index_tuple, label_lookup), + (x_axis, self.data[index_slicer].squeeze()), + ) except IndexError: LOG.warning( "Skipping: in dataset %s, index %s is out of bounds", @@ -659,8 +676,6 @@ def planes_vs_axis( ---------- axis_number : int index of the axis perpendicular to the plotted array - max_limit : int, optional - Maximum number of curves allowed by plotter, by default 1 Yields ------ @@ -760,7 +775,9 @@ class PlottingContext(QStandardItemModel): needs_an_update = Signal("quint64") - def __init__(self, *args, unit_lookup=None, **kwargs): + def __init__( + self, *args, unit_lookup: int | None = None, colormap: str = "viridis", **kwargs + ): super().__init__(*args, **kwargs) self._datasets = {} self._current_axis = [None, None, None] @@ -771,13 +788,14 @@ def __init__(self, *args, unit_lookup=None, **kwargs): self._best_xunits = [] self._colour_list = get_mpl_colours() self._last_colour_list = get_mpl_colours() - self._colour_map = kwargs.get("colormap", "viridis") + self._colour_map = colormap self._last_colour = 0 self._unit_lookup = unit_lookup self.plot_widget_id = -1 self.use_legend = True self.use_grid = True self.setHorizontalHeaderLabels(plotting_column_labels) + self.itemChanged.connect(self.ask_for_update) def generate_colour(self, number: int) -> str: """Get the matplotlib colour string for the nth curve. @@ -923,7 +941,15 @@ def datasets(self) -> dict[str, PlotArgs]: self._datasets[key].set_data_limits(data_number_string, main_axis=main_axis) self._datasets[key].set_current_units(self._unit_lookup) - result[key] = PlotArgs(self._datasets[key], **plot_args) + result[key] = PlotArgs( + dataset=self._datasets[key], + colour=row_data["Colour"].text(), + line_style=row_data["Line style"].text(), + marker=row_data["Marker"].text(), + row=row, + main_axis=row_data["Main axis"].text(), + legend_label=row_data["Legend label"].text(), + ) return result @@ -950,7 +976,7 @@ def add_dataset( self._datasets[newkey] = new_dataset items = [ QStandardItem(str(x)) - for x in [ + for x in ( new_dataset._name, getattr(optional_values, "legend_label", new_dataset._labels["medium"]), new_dataset._data_shape, @@ -966,7 +992,7 @@ def add_dataset( new_dataset._scaling_factor, show=1, arr_fmt=SCALE_FACTOR_FORMAT ), new_dataset._filename, - ] + ) ] fixed = {"Dataset", "Trajectory", "Size", "Unit", "Apply scaling?"} @@ -988,8 +1014,6 @@ def add_dataset( f"0:{prod(len(arr) for arr in new_dataset.dep_axes.values())}:1", ) - self.itemChanged.connect(self.ask_for_update) - temp = items[plotting_column_index["Colour"]] temp.setData(QColor(temp.text()), role=Qt.ItemDataRole.BackgroundRole) @@ -1043,3 +1067,28 @@ def delete_dataset(self, index: QModelIndex): dkey = index.data(role=Qt.ItemDataRole.UserRole) self.removeRow(index.row()) self._datasets.pop(dkey, None) + + def planes( + self, default_axis: int = 0, planes_per_dataset: int | None = None + ) -> Generator[tuple[PlotArgs, str, np.ndarray]]: + for databundle in self.datasets().values(): + ds = databundle.dataset + + for label, plane in islice( + ds.planes_vs_axis( + ds.main_axis_index(databundle.main_axis, default=default_axis) + ), + planes_per_dataset, + ): + yield databundle, label, plane + + def curves( + self, curves_per_dataset: int | None = None + ) -> Generator[tuple[PlotArgs, str, np.ndarray]]: + for databundle in self.datasets().values(): + ds = databundle.dataset + + for label, curve in islice( + ds.curves_vs_axis(databundle.main_axis), curves_per_dataset + ): + yield databundle, label, curve diff --git a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Grid.py b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Grid.py index b44f3cd34e..dfd97cc26a 100644 --- a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Grid.py +++ b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Grid.py @@ -21,14 +21,19 @@ from typing import TYPE_CHECKING, Any from matplotlib import rcParams +from more_itertools import ilen from MDANSE.MLogging import LOG from MDANSE_GUI.Tabs.Plotters.Plotter import Plotter if TYPE_CHECKING: + import numpy as np + from matplotlib.axes import Axes + from matplotlib.backends.backend_qt5agg import NavigationToolbar2QT as Toolbar from matplotlib.figure import Figure + from matplotlib.lines import Line2D - from MDANSE_GUI.Tabs.Models.PlottingContext import PlottingContext + from MDANSE_GUI.Tabs.Models.PlottingContext import PlotArgs, PlottingContext @Plotter.register("Grid") @@ -70,20 +75,21 @@ def change_normalisation(self, new_value: dict[str, Any]): """ super().change_normalisation(new_value) target = self._figure - if target is None: - return - if len(self._active_curves) == 0: + if target is None or not self._active_curves: return + for curve_index, curve in enumerate(self._active_curves): - xdata = self._backup_curves[curve_index][0] - ydata = self._backup_curves[curve_index][1] + xdata, ydata = self._backup_curves[curve_index] xdata, ydata = self.normalise_curve(xdata, ydata) curve.set_xdata(xdata) curve.set_ydata(ydata) + target.canvas.draw() + for axes in self._axes: axes.relim() axes.autoscale() + if self._toolbar is not None: self._toolbar.update() self._toolbar.push_current() @@ -110,33 +116,37 @@ def toggle_legend(self, enabled: bool) -> None: def plot( self, plotting_context: PlottingContext, - figure: Figure = None, - update_only=False, - toolbar=None, + figure: Figure | None = None, + update_only: bool = False, + toolbar: Toolbar | None = None, ): """Plot datasets in separate subplots. Parameters ---------- plotting_context : PlottingContext - Data model storing the data to be plotted + Data model storing the data to be plotted. figure : Figure, optional - Matplotlib figure instance for plotting, by default None + Matplotlib figure instance for plotting, by default None. update_only : bool, optional - If true, try to re-use zoom settings, by default False - toolbar : _type_, optional - GUI instance of the matplotlib toolbar, by default None + If true, try to re-use zoom settings, by default False. + toolbar : Toolbar, optional + GUI instance of the matplotlib toolbar, by default None. """ self.enable_slider(allow_slider=False) target = self.get_figure(figure) + if target is None: return + if toolbar is not None: self._toolbar = toolbar + if plotting_context.set_axes() is None: LOG.debug("Axis check failed.") return + self._figure = target self._axes = [] self._axes_titles = [] @@ -209,9 +219,75 @@ def plot( if counter == 0: self.plot_blank() return + + gridsize = self.grid_size(nplots) + + for ind, (databundle, label, curve) in enumerate( + islice(plotting_context.curves(), self._plot_limit), 1 + ): + axes = target.add_subplot(*gridsize, ind) + self._plot_single( + axes, + curve, + databundle, + label=label, + colour=databundle.colour, + ) + + if plotting_context.use_legend: + axes.legend() + axes.grid(plotting_context.use_grid) + self.apply_settings(plotting_context) self.check_curve_lengths() target.canvas.draw() + if self._toolbar is not None: self._toolbar.update() self._toolbar.push_current() + + def _plot_single( + self, + axes: Axes, + curve: tuple[np.ndarray, np.ndarray] | tuple[np.ndarray], + databundle: PlotArgs, + *, + label: str, + colour: tuple[float, float, float], + ): + """Plot a single curve to axes. + + Parameters + ---------- + axes : Axes + Axis to plot to. + curve : tuple[np.ndarray, np.ndarray] | tuple[np.ndarray] + Curve to plot. + databundle : PlotArgs + Data to plot. + label : str + Plot label. + colour : tuple[float, float, float] + Curve colour. + """ + lines: list[Line2D] = axes.plot( + *curve, + linestyle=databundle.line_style, + label=label, + color=colour, + ) + + axes.set_xlabel(databundle.dataset.x_axis_label(databundle.main_axis)) + + for line in lines: + try: + line.set_marker(databundle.marker) + except ValueError: + with contextlib.suppress(Exception): + line.set_marker(int(databundle.marker)) + + self._axes.append(axes) + self._active_curves.extend(lines) + self._backup_curves.extend( + (line.get_xdata(), line.get_ydata()) for line in lines + ) diff --git a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Heatmap.py b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Heatmap.py index 2faf2413fa..f4c09ac562 100644 --- a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Heatmap.py +++ b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Heatmap.py @@ -26,27 +26,20 @@ from matplotlib.axes import Axes from matplotlib.image import AxesImage from matplotlib.pyplot import colorbar as mpl_colorbar -from more_itertools import first, ilen, locate +from more_itertools import ilen from scipy.interpolate import interp1d from MDANSE.MLogging import LOG -from MDANSE_GUI.Tabs.Models.PlottingContext import PlotArgs from MDANSE_GUI.Tabs.Plotters.Plotter import Plotter if TYPE_CHECKING: + from matplotlib.backends.backend_qt5agg import NavigationToolbar2QT as Toolbar from matplotlib.figure import Figure from matplotlib.image import AxesImage from MDANSE_GUI.Tabs.Models.PlottingContext import PlottingContext -GRID_SIZES = { - 2: (2, 1), - 5: (2, 3), - 6: (2, 3), -} - - @Plotter.register("Heatmap") class Heatmap(Plotter): """Creates a 2D heatmap plot.""" @@ -170,20 +163,20 @@ def plot( plotting_context: PlottingContext, figure: Figure | None = None, update_only: bool = False, - toolbar=None, + toolbar: Toolbar | None = None, ): """Plot the first dataset as a heatmap. Parameters ---------- plotting_context : PlottingContext - Data model storing the data to be plotted + Data model storing the data to be plotted. figure : Figure, optional - Matplotlib figure instance for plotting, by default None + Matplotlib figure instance for plotting, by default None. update_only : bool, optional - If true, try to re-use zoom settings, by default False - toolbar : _type_, optional - GUI instance of the matplotlib toolbar, by default None + If true, try to re-use zoom settings, by default False. + toolbar : Toolbar, optional + GUI instance of the matplotlib toolbar, by default None. """ self.enable_slider(allow_slider=True) target = self.get_figure(figure) @@ -194,6 +187,7 @@ def plot( self._toolbar = toolbar self._figure = target + self._figure.set_layout_engine(layout="constrained") self._current_x_axes = [] minmax_bak = {key: val.minmax for key, val in self._backup.items()} scale_interpolators = {val.ind: val.interp for val in self._backup.values()} @@ -205,15 +199,11 @@ def plot( LOG.debug("Axis check failed.") return - def get_planes() -> Iterator[tuple[PlotArgs, str, np.ndarray]]: - for databundle in plotting_context.datasets().values(): - ds = databundle.dataset + nplots = min(ilen(plotting_context.planes(self._slice_axis)), self._plot_limit) - for label, plane in ds.planes_vs_axis( - ds.main_axis_index(databundle.main_axis, default=self._slice_axis), - max_limit=self._plot_limit, - ): - yield databundle, label, plane + if not nplots: + self.plot_blank() + return # Check interpolators for databundle in plotting_context.datasets().values(): @@ -229,11 +219,11 @@ def get_planes() -> Iterator[tuple[PlotArgs, str, np.ndarray]]: results, ) - nplots = min(ilen(get_planes()), self._plot_limit) - gridsize = GRID_SIZES.get(nplots, (math.ceil(nplots**0.5),) * 2) + grid_size = self.grid_size(nplots) + gs = self._figure.add_gridspec(*grid_size) for ind, (databundle, label, plane) in enumerate( - islice(get_planes(), self._plot_limit), + islice(plotting_context.planes(self._slice_axis), self._plot_limit), ): dataset = databundle.dataset limits = [] @@ -346,6 +336,9 @@ def get_planes() -> Iterator[tuple[PlotArgs, str, np.ndarray]]: legend.set_visible(plotting_context.use_legend) axes.grid(plotting_context.use_grid) + if nplots == 1: # Exploit label from loop for one plot + self._figure.suptitle(label) + self.check_curve_lengths() self.request_slider_values() target.canvas.draw() diff --git a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Plotter.py b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Plotter.py index 54ebfdfc68..94f2f2c48a 100644 --- a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Plotter.py +++ b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Plotter.py @@ -18,6 +18,7 @@ import copy import csv import enum +import math from itertools import count from typing import TYPE_CHECKING, Any, ClassVar, Literal, TextIO @@ -33,6 +34,7 @@ from collections.abc import Iterator from matplotlib.axes import Axes + from matplotlib.backends.backend_qt5agg import NavigationToolbar2QT as Toolbar from matplotlib.figure import Figure from matplotlib.lines import Line2D @@ -102,6 +104,12 @@ class Plotter(RegisterFactory): registry: ClassVar[UCDict[str, type[Plotter]]] = UCDict() + GRID_SIZES = { + 2: (2, 1), + 5: (2, 3), + 6: (2, 3), + } + def __init__(self) -> None: """Create defaults common to all plotters.""" self._figure = None @@ -287,20 +295,20 @@ def plot( plotting_context: PlottingContext, figure: Figure | None = None, update_only: bool = False, - toolbar=None, + toolbar: Toolbar | None = None, ): """Plot the selected data in the figure. Parameters ---------- plotting_context : PlottingContext - Data model storing the data to be plotted + Data model storing the data to be plotted. figure : Figure, optional - Matplotlib figure instance for plotting, by default None + Matplotlib figure instance for plotting, by default None. update_only : bool, optional - If true, try to re-use zoom settings, by default False - toolbar : _type_, optional - GUI instance of the matplotlib toolbar, by default None + If true, try to re-use zoom settings, by default False. + toolbar : Toolbar, optional + GUI instance of the matplotlib toolbar, by default None. """ LOG.info(f"normalisation errors {self._normalisation_errors}, setting to []") @@ -456,3 +464,19 @@ def _get_datasets(axis: Axes) -> Iterator[Line2D]: Each line in dataset. """ yield from axis.get_lines() + + @classmethod + def grid_size(cls, n_plots: int) -> tuple[int, int]: + """Get a good grid layout for plotting. + + Parameters + ---------- + n_plots : int + Number of expected plots. + + Returns + ------- + tuple[int, int] + Grid size. + """ + return cls.GRID_SIZES.get(n_plots, (math.ceil(n_plots**0.5),) * 2) diff --git a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Single.py b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Single.py index 50ed2b5728..6e8e777e79 100644 --- a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Single.py +++ b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Single.py @@ -22,12 +22,14 @@ import numpy as np from matplotlib.colors import to_rgb +from more_itertools import ilen from MDANSE.MLogging import LOG from MDANSE_GUI.Tabs.Plotters.Plotter import Plotter if TYPE_CHECKING: from matplotlib.axes import Axes + from matplotlib.backends.backend_qt5agg import NavigationToolbar2QT as Toolbar from matplotlib.figure import Figure from matplotlib.lines import Line2D @@ -140,25 +142,45 @@ def check_curve_lengths(self): xdata = self._backup_curves[num][0] self.curve_length_limit = max(self.curve_length_limit, len(xdata)) + @staticmethod + def colours(colour: str, n_curves: int) -> Generator[tuple[float, float, float]]: + """Generate colours from root colour. + + Parameters + ---------- + colour : str + Root colour. + + Returns + ------- + Generator[tuple[float, float, float]] + Next colour in sequence. + """ + main_colour = np.array(to_rgb(colour)) + colour_increment = (0.5 - main_colour) / n_curves + for _ in range(n_curves): + yield tuple(main_colour) + main_colour += colour_increment + def plot( self, plotting_context: PlottingContext, figure: Figure | None = None, update_only: bool = False, - toolbar=None, + toolbar: Toolbar | None = None, ): """Plot all datasets in the same figure. Parameters ---------- plotting_context : PlottingContext - Data model storing the data to be plotted + Data model storing the data to be plotted. figure : Figure, optional - Matplotlib figure instance for plotting, by default None + Matplotlib figure instance for plotting, by default None. update_only : bool, optional - If true, try to re-use zoom settings, by default False - toolbar : _type_, optional - GUI instance of the matplotlib toolbar, by default None + If true, try to re-use zoom settings, by default False. + toolbar : Toolbar, optional + GUI instance of the matplotlib toolbar, by default None. """ self.enable_slider(allow_slider=False) @@ -171,15 +193,14 @@ def plot( self._toolbar = toolbar self._figure = target - self._figure.set_layout_engine("none") self._active_curves = [] self._backup_curves = [] self._normalisation_errors = [] + x_axis_labels = [] axes = target.add_subplot(111) self._axes = [axes] self.apply_settings(plotting_context) - x_axis_labels = [] self.height_max, self.length_max = 0.0, 0.0 @@ -187,90 +208,36 @@ def plot( LOG.debug("Axis check failed.") return - if len(plotting_context.datasets()) == 0: - target.clear() - target.canvas.draw() + total_n_curves = sum( + map(ilen, plotting_context.curves(self._curve_limit_per_dataset)) + ) - for databundle in plotting_context.datasets().values(): - dataset = databundle.dataset - plotlabel = databundle.legend_label + if not total_n_curves: + self.plot_blank() + return - try: - best_unit, best_axis = ( - dataset._axes_units[databundle.main_axis], + self.enable_slider(allow_slider=total_n_curves > 1) + + colours = {} + for databundle in plotting_context.datasets().values(): + n_curves = ilen( + databundle.dataset.curves_vs_axis( databundle.main_axis, + self._curve_limit_per_dataset, ) - except KeyError: - best_unit, best_axis = dataset.longest_axis() - - x_axis_labels.append(dataset.x_axis_label(best_axis)) - - if dataset._n_dim == 1: - self._plot_single( - axes, - databundle, - best_axis, - label=plotlabel, - colour=databundle.colour, - ) - try: - temp.set_marker(databundle.marker) - except ValueError: - with contextlib.suppress(Exception): - temp.set_marker(int(databundle.marker)) - self._active_curves.append(temp) - self._backup_curves.append([temp.get_xdata(), temp.get_ydata()]) - self.height_max = np.nanmax( - [self.height_max, np.nanmax(temp.get_ydata())] - ) - self.length_max = np.nanmax( - [self.length_max, np.nanmax(temp.get_xdata())] - ) - else: - multi_curves = dataset.curves_vs_axis( - (best_unit, best_axis), max_limit=self._curve_limit_per_dataset - ) - colours = self.colours( - databundle.colour, - min(self._curve_limit_per_dataset, len(multi_curves)), - ) - - for key, value in islice( - multi_curves.items(), self._curve_limit_per_dataset - ): - try: - self._plot_single( - axes, - databundle, - best_axis, - label=f"{plotlabel}:{dataset._curve_labels[key]}", - colour=next(colours), - ) - try: - temp.set_marker(databundle.marker) - except ValueError: - with contextlib.suppress(Exception): - temp.set_marker(int(databundle.marker)) - self._active_curves.append(temp) - self._backup_curves.append([temp.get_xdata(), temp.get_ydata()]) - self.height_max = np.nanmax( - [self.height_max, np.nanmax(temp.get_ydata())] - ) - self.length_max = np.nanmax( - [self.length_max, np.nanmax(temp.get_xdata())] - ) - except ValueError: - LOG.error(f"Plotting failed for {plotlabel} using {best_axis}") - LOG.error(f"x_axis={dataset._axes[best_axis]}") - LOG.error(f"values={value}") - return - - if len(self._backup_curves) > 1: - self.enable_slider(allow_slider=True) - - elif not self._backup_curves: - self.plot_blank() - return + ) + colours[databundle.row] = self.colours(databundle.colour, n_curves) + x_axis_labels.append(databundle.dataset.x_axis_label(databundle.main_axis)) + + for databundle, _, curve in plotting_context.curves( + self._curve_limit_per_dataset + ): + self._plot_single( + axes, + curve, + databundle, + colour=next(colours[databundle.row]), + ) if update_only: try: @@ -287,43 +254,23 @@ def plot( f"Matplotlib could not set y limits to {self._backup_limits[2]}, {self._backup_limits[3]}" ) else: - xlimits, ylimits = axes.get_xlim(), axes.get_ylim() - self._backup_limits = [xlimits[0], xlimits[1], ylimits[0], ylimits[1]] + self._backup_limits = [*axes.get_xlim(), *axes.get_ylim()] axes.set_xlabel(", ".join(np.unique(x_axis_labels))) - legend = axes.legend() - legend.set_visible(plotting_context.use_legend) + + if plotting_context.use_legend: + axes.legend() + axes.grid(plotting_context.use_grid) self.check_curve_lengths() self.offset_curves() - @staticmethod - def colours(colour: str, n_curves: int) -> Generator[tuple[float, float, float]]: - """Generate colours from root colour. - - Parameters - ---------- - colour : str - Root colour. - - Returns - ------- - Generator[tuple[float, float, float]] - Next colour in sequence. - """ - main_colour = np.array(to_rgb(colour)) - colour_increment = (0.5 - main_colour) / n_curves - for _ in range(n_curves): - main_colour += colour_increment - yield tuple(main_colour) - def _plot_single( self, axes: Axes, + curve: tuple[np.ndarray, np.ndarray] | tuple[np.ndarray], databundle: PlotArgs, - best_axis: str, *, - label: str, colour: tuple[float, float, float], ): """Plot a single curve to axes. @@ -332,20 +279,17 @@ def _plot_single( ---------- axes : Axes Axis to plot to. + curve : tuple[np.ndarray, np.ndarray] | tuple[np.ndarray] + Curve to plot. databundle : PlotArgs Data to plot. - best_axis : str - Axis label of X-axis. - label : str - Plot label. - colour : FIXME: Add type. + colour : tuple[float, float, float] Curve colour. """ lines: list[Line2D] = axes.plot( - databundle.dataset.x_axis(best_axis), - databundle.dataset.data, + *curve, linestyle=databundle.line_style, - label=label, + label=databundle.legend_label, color=colour, ) diff --git a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Text.py b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Text.py index 057296d2a4..542b55de07 100644 --- a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Text.py +++ b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Text.py @@ -29,6 +29,7 @@ if TYPE_CHECKING: from collections.abc import Iterator + from matplotlib.backends.backend_qt5agg import NavigationToolbar2QT as Toolbar from qtpy.QtWidgets import QTextBrowser from MDANSE_GUI.Tabs.Models.PlottingContext import ( @@ -501,9 +502,9 @@ def plot( self, plotting_context: PlottingContext, figure: QTextBrowser = None, - colours=None, - update_only=False, - toolbar=None, + colours: None = None, + update_only: bool = False, + toolbar: Toolbar | None = None, ): """Show data as text. @@ -516,12 +517,12 @@ def plot( Data model containing the data sets to be shown figure : QTextBrowser, optional Target widget, an instance of QTextBrowser - colours : _type_, optional - ignored here + colours : None, optional + Ignored here update_only : bool, optional - ignored - toolbar : _type_, optional - ignored + Ignored + toolbar : Toolbar, optional + Ignored. """ target = self.get_figure(figure) diff --git a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Vectors.py b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Vectors.py index a41b6beca0..0e17dac393 100644 --- a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Vectors.py +++ b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Vectors.py @@ -25,6 +25,7 @@ from MDANSE_GUI.Tabs.Plotters.Plotter import Plotter if TYPE_CHECKING: + from matplotlib.backends.backend_qt5agg import NavigationToolbar2QT as Toolbar from matplotlib.figure import Figure from MDANSE.util_types import FloatArray @@ -78,7 +79,7 @@ def change_normalisation(self, new_value: dict[str, Any]): Parameters ---------- new_value : dict[str, Any] - parameters as in NORMALISATION_DEFAULTS + Parameters as in NORMALISATION_DEFAULTS. """ super().change_normalisation(new_value) @@ -88,20 +89,20 @@ def plot( plotting_context: PlottingContext, figure: Figure | None = None, update_only: bool = False, - toolbar: type | None = None, + toolbar: Toolbar | None = None, ): """Plot all datasets in the same figure. Parameters ---------- plotting_context : PlottingContext - Data model storing the data to be plotted + Data model storing the data to be plotted. figure : Figure, optional - Matplotlib figure instance for plotting, by default None + Matplotlib figure instance for plotting, by default None. update_only : bool, optional - If true, try to re-use zoom settings, by default False - toolbar : _type_, optional - GUI instance of the matplotlib toolbar, by default None + If true, try to re-use zoom settings, by default False. + toolbar : Toolbar, optional + GUI instance of the matplotlib toolbar, by default None. """ self.enable_slider(allow_slider=False) @@ -109,6 +110,7 @@ def plot( if target is None: return + if toolbar is not None: self._toolbar = toolbar @@ -116,6 +118,7 @@ def plot( self._normalisation_errors = [] self._axes = [] self.apply_settings(plotting_context) + x_axis_labels = [] if plotting_context.set_axes() is None: @@ -142,8 +145,10 @@ def plot( ) except KeyError: best_unit, best_axis = dataset.longest_axis() + plotlabel = databundle.legend_label x_axis_labels.append(dataset.x_axis_label(best_axis)) + if dataset._name == "Available vectors": axes = target.add_subplot(single_plot_stack.pop()) if dataset._n_dim == 2: @@ -185,6 +190,7 @@ def plot( self._axes.append(axes) axes.set_xlabel(", ".join(np.unique(x_axis_labels))) axes.set_title(dataset._name) + elif dataset._name == "Shell population": axes = target.add_subplot(212) multi_curves = dataset.curves_vs_axis( @@ -271,12 +277,15 @@ def plot( ylimits[0], ylimits[1], ] + for axes in self._axes: legend = axes.legend() legend.set_visible(plotting_context.use_legend) axes.grid(plotting_context.use_grid) axes.relim() axes.autoscale() + if self._toolbar is not None: self._toolbar.update() + target.canvas.draw() diff --git a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Visualisers/PlotWidget.py b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Visualisers/PlotWidget.py index b9183e1a5a..df11224fd2 100644 --- a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Visualisers/PlotWidget.py +++ b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Visualisers/PlotWidget.py @@ -45,6 +45,10 @@ from MDANSE_GUI.Widgets.RestrictedSlider import RestrictedSlider if TYPE_CHECKING: + from matplotlib.backends.backend_qt5agg import ( + NavigationToolbar2QT as Toolbar, + ) + from MDANSE_GUI.Tabs.Models.PlottingContext import PlottingContext @@ -207,9 +211,8 @@ class PlotWidget(QWidget): reset_slider_values = Signal(bool) change_slider_coupling = Signal(bool) - def __init__(self, *args, **kwargs) -> None: + def __init__(self, *args, plotter_type: str = "Single", **kwargs) -> None: """Create an empty plot with the default plotter.""" - plotter_type = kwargs.pop("plotter_type", "Single") super().__init__(*args, **kwargs) self._plotter = None self._sliderpack = None @@ -245,10 +248,12 @@ def set_plotter(self, plotter_option: str): except Exception: self._plotter = Plotter() self._plotter._figure = self._figure + self.change_slider_labels.emit(self._plotter.slider_labels()) self.change_slider_limits.emit(self._plotter.slider_limits()) self.change_slider_coupling.emit(self._plotter.sliders_coupled()) self.reset_slider_values.emit(self._plotter._value_reset_needed) + self._plotter._slider_reference = self._sliderpack self._sliderpack.setEnabled(False) self.plot_data() @@ -333,13 +338,16 @@ def plot_data(self, update_only=False): return if self._plotting_context is None: return + self._figure.set_layout_engine("tight") + self._plotter.plot( self._plotting_context, self._figure, update_only=update_only, toolbar=self._toolbar, ) + self._normaliser.update_spinbox_limits(self._plotter.curve_length_limit) self._normaliser.collect_values() self._sliderpack.collect_values() From 6a2076819f1fc6f9173eb60cd2e60014147d5af0 Mon Sep 17 00:00:00 2001 From: Jacob Wilkins Date: Fri, 7 Nov 2025 14:23:03 +0000 Subject: [PATCH 4/9] Refactor Vectors type and remove from PlotWidget options --- .../Src/MDANSE_GUI/Tabs/Plotters/Vectors.py | 141 +++++++++++++----- 1 file changed, 103 insertions(+), 38 deletions(-) diff --git a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Vectors.py b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Vectors.py index 0e17dac393..b02ea2d909 100644 --- a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Vectors.py +++ b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Vectors.py @@ -20,8 +20,12 @@ import numpy as np import numpy.typing as npt +from matplotlib.axes import Axes +from matplotlib.lines import Line2D +from more_itertools import ilen from MDANSE.MLogging import LOG +from MDANSE_GUI.Tabs.Models.PlottingContext import PlotArgs from MDANSE_GUI.Tabs.Plotters.Plotter import Plotter if TYPE_CHECKING: @@ -39,7 +43,7 @@ def violin_plot_width(positions: FloatArray) -> float: @Plotter.register("Vectors") class Vectors(Plotter): - """Plots all the datasets in the same figure.""" + """Plots summarised Q-Vectors to one figure.""" def __init__(self) -> None: """Initialise all ploting parameters to default values.""" @@ -117,6 +121,8 @@ def plot( self._figure = target self._normalisation_errors = [] self._axes = [] + self._active_curves = [] + self._backup_curves = [] self.apply_settings(plotting_context) x_axis_labels = [] @@ -127,37 +133,30 @@ def plot( if not plotting_context.datasets(): target.clear() - target.canvas.draw() + self.plot_blank() + return - single_plot_stack = [222, 221] - label_stack = [ - "Used", - "Found", - ] + gs = self._figure.add_gridspec(2, 2) + labels = iter(("Total available", "Requested", "Selected")) for databundle in plotting_context.datasets().values(): dataset = databundle.dataset - try: - best_unit, best_axis = ( - dataset._axes_units[databundle.main_axis], - databundle.main_axis, - ) - except KeyError: - best_unit, best_axis = dataset.longest_axis() - plotlabel = databundle.legend_label - x_axis_labels.append(dataset.x_axis_label(best_axis)) + x_axis_labels = [dataset.x_axis_label(databundle.main_axis)] - if dataset._name == "Available vectors": - axes = target.add_subplot(single_plot_stack.pop()) - if dataset._n_dim == 2: - temp_curves = [] - for value in dataset.data.T: - [temp] = axes.plot( - dataset.x_axis(best_axis), - value, - label=label_stack.pop(), + match dataset._name: + case "Available vectors": + axes = self._figure.add_subplot(gs[0]) + + lab = plotlabel if dataset._n_dim == 2 else next(labels) + + for _, curve in dataset.curves_vs_axis(databundle.main_axis): + self._plot_single( + axes, + curve, + databundle, + label=lab, ) temp_curves.append(temp) if not label_stack: @@ -220,19 +219,20 @@ def plot( width=width, edgecolor="black", ) - bottom += value - except ValueError: - LOG.error(f"Plotting failed for {plotlabel} using {best_axis}") - LOG.error(f"x_axis={dataset._axes[best_axis]}") - LOG.error(f"values={value}") - return - else: - if ( - add_legend_placeholder - and bar_index == len(multi_curves) - 2 - ): - add_legend_placeholder = False - add_last_entry = True + ) + + x_axis = dataset.x_axis(databundle.main_axis) + bottom = np.zeros_like(x_axis) + width = 0.8 * abs(np.mean(x_axis[1:] - x_axis[:-1])) + + axes.set_xlabel(", ".join(np.unique(x_axis_labels))) + for ind, (label, curve) in enumerate( + dataset.curves_vs_axis( + databundle.main_axis, + max_limit=self._curve_limit_per_dataset, + ) + ): + try: axes.bar( x_axis, 0, @@ -289,3 +289,68 @@ def plot( self._toolbar.update() target.canvas.draw() + + @staticmethod + def get_label(ind: int, n_curves: int, limit: int, label: str): + """Get label for legend. + + For the abbreviated legend return None for those which are + between ``limit`` and ``n_curves`` (skipping them), and "..." + when it's at the limit. + + Parameters + ---------- + ind : int + Current index. + n_curves : int + Total number of "curves" to plot. + limit : int + Max number of entries in legend. + label : str + Current labe. + """ + if ind == limit: + return "..." + if limit < ind < n_curves - 1: + return None + return label + + def _plot_single( + self, + axes: Axes, + curve: tuple[np.ndarray, np.ndarray] | tuple[np.ndarray], + databundle: PlotArgs, + *, + yerr: np.ndarray | None = None, + label: str, + ): + """Plot a single curve to axes. + + Parameters + ---------- + axes : Axes + Axis to plot to. + curve : tuple[np.ndarray, np.ndarray] | tuple[np.ndarray] + Curve to plot. + databundle : PlotArgs + Data to plot. + yerr : ndarray, optional + Error bars to add. + label : str + Plot label. + """ + line, _caps, _bars = axes.errorbar( + *curve, + yerr=yerr, + linestyle=databundle.line_style, + label=label, + color=databundle.colour, + ) + + axes.set_title(databundle.dataset._name) + + try: + line.set_marker(databundle.marker) + except ValueError: + with contextlib.suppress(Exception): + line.set_marker(int(databundle.marker)) From 41e361258b30ea29404394e75be218a41c2eeff2 Mon Sep 17 00:00:00 2001 From: Jacob Wilkins Date: Fri, 7 Nov 2025 14:36:05 +0000 Subject: [PATCH 5/9] Fix tests --- .../Tests/UnitTests/test_PlottingContext.py | 28 +++++++++---------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/MDANSE_GUI/Tests/UnitTests/test_PlottingContext.py b/MDANSE_GUI/Tests/UnitTests/test_PlottingContext.py index 5b729618b2..383a04aff5 100644 --- a/MDANSE_GUI/Tests/UnitTests/test_PlottingContext.py +++ b/MDANSE_GUI/Tests/UnitTests/test_PlottingContext.py @@ -1,12 +1,12 @@ -import pytest -import tempfile -import os +from __future__ import annotations + from pathlib import Path import h5py import numpy as np - -from qtpy import QtGui, QtCore, QtWidgets +import pytest +from more_itertools import ilen, nth, first +from qtpy import QtCore, QtGui, QtWidgets from MDANSE_GUI.Tabs.Models.PlottingContext import PlottingContext, SingleDataset @@ -62,18 +62,16 @@ def test_available_x_axes_2d(file_2d): def test_curves_vs_axis_2d_long_axis(file_2d): temp = SingleDataset("f(q,t)_total", file_2d) temp.set_data_limits("0;1;2;3") - curves = temp.curves_vs_axis(("ps", "time"), max_limit=12) - print(len(curves)) - assert len(curves) == 4 - print(curves.keys()) - assert len(curves[(0,)]) == 501 + curves = temp.curves_vs_axis("time", max_limit=12) + assert ilen(curves) == 4 + label, (x, y) = first(temp.curves_vs_axis("time", max_limit=12)) + assert len(x) == 501 def test_curves_vs_axis_2d_short_axis(file_2d): temp = SingleDataset("f(q,t)_total", file_2d) temp.set_data_limits("2;3;4;5") - curves = temp.curves_vs_axis(("1/nm", "q"), max_limit=12) - print(len(curves)) - assert len(curves) == 4 - print(curves.keys()) - assert len(curves[(2,)]) == 10 + curves = temp.curves_vs_axis("q", max_limit=12) + assert ilen(curves) == 4 + label, (x, y) = nth(temp.curves_vs_axis("q", max_limit=12), 3) + assert len(x) == 10 From a9c827867da192122481b97b1b2690355b077967 Mon Sep 17 00:00:00 2001 From: Jacob Wilkins Date: Tue, 18 Nov 2025 17:17:21 +0000 Subject: [PATCH 6/9] Respond to comments --- MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Grid.py | 7 ++++--- MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Heatmap.py | 4 +--- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Grid.py b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Grid.py index dfd97cc26a..2b67252ef7 100644 --- a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Grid.py +++ b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Grid.py @@ -46,7 +46,7 @@ def __init__(self) -> None: self._backup_limits = [] self._active_curves = [] self._backup_curves = [] - self._plot_limit = 8 + self._plot_limit = 9 self._title_length_limit = 30 def slider_labels(self) -> list[str]: @@ -234,10 +234,11 @@ def plot( colour=databundle.colour, ) - if plotting_context.use_legend: - axes.legend() + axes.legend() axes.grid(plotting_context.use_grid) + axes.get_legend().set_visible(plotting_context.use_legend) + self.apply_settings(plotting_context) self.check_curve_lengths() target.canvas.draw() diff --git a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Heatmap.py b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Heatmap.py index f4c09ac562..36720b4768 100644 --- a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Heatmap.py +++ b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Heatmap.py @@ -229,9 +229,7 @@ def plot( limits = [] x_axis_labels, y_axis_labels = [], [] - for name in dataset.axes_main_order( - databundle.main_axis, ind=self._slice_axis - ): + for name in dataset._axes: axis_array = dataset.x_axis(name) limits += [axis_array[0], axis_array[-1]] if not x_axis_labels: From eb3fad3c1039771d21288d0801c52fbf7ec3f4fa Mon Sep 17 00:00:00 2001 From: Jacob Wilkins Date: Mon, 1 Jun 2026 15:33:11 +0100 Subject: [PATCH 7/9] Tmp --- .../MDANSE_GUI/Tabs/Models/PlottingContext.py | 218 +++++++++--------- .../Src/MDANSE_GUI/Tabs/Plotters/Grid.py | 62 +---- .../Src/MDANSE_GUI/Tabs/Plotters/Heatmap.py | 196 +++++++--------- .../Src/MDANSE_GUI/Tabs/Plotters/Plotter.py | 25 ++ .../Src/MDANSE_GUI/Tabs/Plotters/Vectors.py | 201 +++++----------- .../Src/MDANSE_GUI/Tabs/Plotters/Vectors3D.py | 126 +++++----- .../Src/MDANSE_GUI/Tabs/Views/PlotDataView.py | 2 + .../MDANSE_GUI/Tabs/Visualisers/PlotWidget.py | 2 +- 8 files changed, 340 insertions(+), 492 deletions(-) diff --git a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Models/PlottingContext.py b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Models/PlottingContext.py index c3dbe2c108..139e455206 100644 --- a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Models/PlottingContext.py +++ b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Models/PlottingContext.py @@ -22,7 +22,7 @@ from itertools import islice from math import prod from pathlib import Path -from typing import TYPE_CHECKING, NamedTuple +from typing import TYPE_CHECKING, Literal, NamedTuple, overload import h5py import matplotlib.pyplot as mpl @@ -484,7 +484,7 @@ def dep_axes(self) -> dict[str, FloatArray]: return {aname: axis for aname, axis in self._axes.items() if aname != la} @property - def data(self): + def data(self) -> npt.NDArray[np.floating]: """Data array, scaled if requested in the GUI table. Returns @@ -497,22 +497,32 @@ def data(self): return self._data * self._scaling_factor return self._data + @overload def generate_curve_label( self, - index_tuple: list[int], - axis_lookup: list[str], + index_tuple: Sequence[int], + axis_lookup: Iterable[str], *, - skip_text: bool = False, - ) -> str | float: + skip_text: Literal[False] = False, + ) -> str: ... + @overload + def generate_curve_label( + self, + index_tuple: Sequence[int], + axis_lookup: Iterable[str], + *, + skip_text: Literal[True], + ) -> float: ... + def generate_curve_label(self, index_tuple, axis_lookup, *, skip_text=False): """Get a meaningful label for a subset of data. Used when plotting 1D arrays out of a multidimensional array. Parameters ---------- - index_tuple : list[int] + index_tuple : Sequence[int] indices of the 1D data array position in the ND array. - axis_lookup : list[str] + axis_lookup : Iterable[str] Names of the axes to use. skip_text : bool, optional If set to true, omits the text parts of the label. By default False. @@ -536,15 +546,15 @@ def generate_curve_label( picked_value = axis_values[index_tuple[axis_index]] if len(axis_values) > 1: - significant_digit = np.floor( - np.log10(abs(np.mean(axis_values[1:] - axis_values[:-1]))), - ).astype(int) + data = np.mean(np.diff(axis_values)) elif len(axis_values) == 1: - significant_digit = np.floor(np.log10(abs(axis_values[0]))).astype(int) + data: np.floating = axis_values[0] else: label += f"{axis_label} has no values, unit {axis_unit}" continue + significant_digit: np.integer = np.floor(np.log10(np.abs(data))).astype(int) + if significant_digit < -20: picked_value = 0 elif significant_digit < 0: @@ -569,37 +579,49 @@ def curve_ind(self, limits: int | None = None, /): def curves_vs_axis( self, - x_axis_details: tuple[str, str], + axis_label: tuple[str, str] | str, max_limit: int = 1, *, + axis_unit: str | None = None, skip_label_text: bool = False, - ) -> dict[int, FloatArray]: + ) -> Generator[ + tuple[ + str | float | None, + tuple[FloatArray, FloatArray], + ] + ]: """Prepare a set of curves for plotting. Parameters ---------- - main_axis : str - Name and original unit of the primary plotting axis. + axis_label : str + Name of the primary plotting axis. + axis_unit : str, optional + Unit of the primary plotting axis. max_limit : int, optional Maximum number of curves allowed by plotter, by default 1 skip_label_text: bool, optional Whether to skip the axis name and unit in the curve label, by default False. - Returns - ------- - dict[int, FloatArray] - List of data arrays ready for plotting - Yields ------ str Plot label. - np.ndarray + npt.NDArray[np.floating] x-axis. - np.ndarray + npt.NDArray[np.floating] Curve to plot. """ - x_axis = self.x_axis(main_axis) + match axis_label: + case (str(unit), str(label)): + axis_unit = unit + axis_label = label + case str(): + pass + case _: + raise ValueError(f"Cannot handle {axis_label} as axis label") + + x_axis = self.x_axis(axis_label) if self._data.ndim == 1: yield None, (x_axis, self.data) @@ -607,63 +629,64 @@ def curves_vs_axis( data_shape = self._data.shape slicer = [] - indexer = [] + indexer: list[Sequence[int]] = [] label_lookup = [] axis_lengths = [len(self._axes[name]) for name in self._axes_order] + match_unit = axis_unit is not None + + if np.allclose(data_shape, axis_lengths): + for current_dim, axis_name in enumerate(self._axes_order): + curr_unit = self._axes_units[axis_name] + + if axis_name == axis_label and ( + not match_unit or (axis_unit == curr_unit) + ): + slicer.append([slice(None)]) + continue + + indices: npt.NDArray[np.integer] = np.arange(data_shape[current_dim]) + slicer.append(indices) + indexer.append(indices) + label_lookup.append(axis_name) + + if not indexer: + LOG.warning("Empty selection for data set %s", self._name) + return + + for index in self.curve_ind(max_limit): + try: + index_tuple = nth_product(index, *indexer) + index_slicer = nth_product(index, *slicer) + yield ( + self.generate_curve_label( + index_tuple, label_lookup, skip_text=skip_label_text + ), + (x_axis, self.data[index_slicer].squeeze()), + ) + except IndexError: + LOG.warning( + "Skipping: in dataset %s, index %s is out of bounds", + self._name, + index, + ) - if not np.allclose(data_shape, axis_lengths): - raise ValueError("Array shape does not match the order of the axes") - - for current_dim, axis_name in enumerate(self._axes_order): - if axis_name == main_axis: - slicer.append([slice(None)]) - continue - - indices = np.arange(data_shape[current_dim]) - slicer.append(indices) - indexer.append(indices) - label_lookup.append(axis_name) + elif ( + len(axis_lengths) == 1 + and len(data_shape) == 2 + and data_shape[0] == axis_lengths[0] + ): + # Assume multiple lines in block - if not indexer: - LOG.warning("Empty selection for data set %s", self._name) - return + axis_name = first(self._axes_order) - for index in self.curve_ind(max_limit): - try: - index_tuple = nth_product(index, *indexer) - index_slicer = nth_product(index, *slicer) + for current_dim in range(data_shape[1]): yield ( - self.generate_curve_label(index_tuple, label_lookup), - (x_axis, self.data[index_slicer].squeeze()), - ) - except IndexError: - LOG.warning( - "Skipping: in dataset %s, index %s is out of bounds", - self._name, - index, + axis_name, + (x_axis, self.data[:, current_dim]), ) - else: - self._curve_labels[index_tuple] = self.generate_curve_label( - index_tuple, - label_lookup, - skip_text=skip_label_text, - ) - - return self._curves - def curve_ind(self, limits: int, /) -> Iterator[int]: - """Return a generator of indices indexing only the curves within the limits. - - Parameters - ---------- - limits : int - Max number of curves to return. - """ - return ( - islice(self._data_limits, limits) - if self._data_limits is not None - else range(limits) - ) + else: + raise ValueError("Array shape does not match the order of the axes") def planes_vs_axis( self, @@ -681,23 +704,25 @@ def planes_vs_axis( ------ str Grid label. - np.ndarray + npt.NDArray[np.floating] 2D array. """ match self._data.ndim: case 1: pass + case 2 if axis_number == 1: + yield self._labels["medium"], self.data.T case 2: - if axis_number == 1: - yield self._labels["medium"], self.data.T - else: - yield self._labels["medium"], self.data + yield self._labels["medium"], self.data case 3: perpendicular_axis_name, perpendicular_axis = nth( - self._axes.items(), axis_number + self._axes.items(), axis_number, default=(None, None) ) + if perpendicular_axis is None: + return + reordered_view = np.moveaxis(self.data, axis_number, 0) for plane_number in self.curve_ind(max_limit): @@ -710,7 +735,7 @@ def planes_vs_axis( f"Cannot handle {self._data.ndim}-dimensional data." ) - def main_axis_index(self, main_axis: str, *, default: int) -> int: + def main_axis_index(self, main_axis: str | None, *, default: int) -> int: """Find index of main axis. Parameters @@ -727,30 +752,6 @@ def main_axis_index(self, main_axis: str, *, default: int) -> int: """ return first(locate(self._axes, pred=lambda x: x == main_axis), default) - def axes_main_order( - self, main_axis: str | None = None, ind: int | None = None - ) -> Sequence[str]: - """Return axis keys with ``main_axis`` first then the others. - - Parameters - ---------- - main_axis : str, optional - Name of main axis to move to front. - ind : int, optional - Main axis by index (if `main_axis` not found). - - Returns - ------- - Sequence[str] - Reordered axes. - """ - main_ind = self.main_axis_index(main_axis, default=ind) - return sort_together( - unzip(enumerate(self._axes)), - key=lambda x: x == main_ind, - reverse=True, - )[1] - plotting_column_labels = [ "Dataset", @@ -930,15 +931,6 @@ def datasets(self) -> dict[str, PlotArgs]: data_number_string = row_data["Use it?"].text() main_axis = row_data["Main axis"].text() - plot_args = { - "colour": row_data["Colour"].text(), - "line_style": row_data["Line style"].text(), - "marker": row_data["Marker"].text(), - "row": row, - "main_axis": main_axis, - "legend_label": row_data["Legend label"].text(), - } - self._datasets[key].set_data_limits(data_number_string, main_axis=main_axis) self._datasets[key].set_current_units(self._unit_lookup) result[key] = PlotArgs( diff --git a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Grid.py b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Grid.py index 2b67252ef7..756c149a3e 100644 --- a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Grid.py +++ b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Grid.py @@ -154,6 +154,7 @@ def plot( self._active_curves = [] self._normalisation_errors = [] self.apply_settings(plotting_context) + nplots = 0 for databundle in plotting_context.datasets().values(): ds = databundle.dataset @@ -162,64 +163,9 @@ def plot( except KeyError: axis_info = ds.longest_axis() curves = ds.curves_vs_axis(axis_info, max_limit=self._plot_limit) - nplots += len(curves) - nplots = min(nplots, self._plot_limit) - gridsize = math.ceil(nplots**0.5) - startnum = 1 - counter = 0 - for databundle in plotting_context.datasets().values(): - dataset = databundle.dataset - try: - _, best_axis = ( - dataset._axes_units[databundle.main_axis], - databundle.main_axis, - ) - except KeyError: - _, best_axis = dataset.longest_axis() - for key, curve in islice(dataset._curves.items(), self._plot_limit): - counter += 1 - if counter > self._plot_limit: - LOG.warning( - "Curves above the current limit of %s will be ignored", - self._plot_limit, - ) - break - axes = target.add_subplot(gridsize, gridsize, startnum) - self._axes.append(axes) - plotlabel = databundle.legend_label - if dataset._curve_labels[key]: - plotlabel += ":" + dataset._curve_labels[key] - x_axis_label = dataset.x_axis_label(best_axis) - [temp_curve] = axes.plot( - dataset.x_axis(best_axis), - curve, - linestyle=databundle.line_style, - color=databundle.colour, - label=plotlabel, - ) - try: - temp_curve.set_marker(databundle.marker) - except ValueError: - with contextlib.suppress(Exception): - temp_curve.set_marker(int(databundle.marker)) - axes.set_xlabel(x_axis_label) - self._axes_titles.append(plotlabel) - axes.set_title( - plotlabel if plotting_context.use_legend else "", - fontsize=self.title_fontsize(plotlabel), - ) - legend = axes.legend() - legend.set_visible(False) - axes.grid(plotting_context.use_grid) - startnum += 1 - self._active_curves.append(temp_curve) - self._backup_curves.append( - [temp_curve.get_xdata(), temp_curve.get_ydata()], - ) - if counter == 0: - self.plot_blank() - return + nplots += ilen(curves) + nplots = min(nplots, self._plot_limit) gridsize = self.grid_size(nplots) for ind, (databundle, label, curve) in enumerate( @@ -228,7 +174,7 @@ def plot( axes = target.add_subplot(*gridsize, ind) self._plot_single( axes, - curve, + (curve,), databundle, label=label, colour=databundle.colour, diff --git a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Heatmap.py b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Heatmap.py index 36720b4768..26ff84637d 100644 --- a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Heatmap.py +++ b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Heatmap.py @@ -18,15 +18,14 @@ import csv import math from collections.abc import Iterator -from dataclasses import dataclass +from dataclasses import dataclass, field from itertools import islice from typing import TYPE_CHECKING, Any, NamedTuple, TextIO import numpy as np from matplotlib.axes import Axes -from matplotlib.image import AxesImage from matplotlib.pyplot import colorbar as mpl_colorbar -from more_itertools import ilen +from more_itertools import ilen, nth from scipy.interpolate import interp1d from MDANSE.MLogging import LOG @@ -47,11 +46,11 @@ class Heatmap(Plotter): @dataclass class BackupInfo: ind: int - image: AxesImage - array: np.ndarray - minmax: tuple[float, float] - limits: tuple[float, float, float, float] - interp: interp1d + image: AxesImage | None = None + array: np.ndarray = field(default_factory=lambda: np.empty((0,), dtype=float)) + minmax: tuple[float, float] = (-np.inf, np.inf) + limits: tuple[float, float, float, float] = (-np.inf, np.inf, -np.inf, np.inf) + interp: interp1d = None def __init__(self) -> None: """Initialise all plotting parameters to defaults.""" @@ -75,15 +74,15 @@ def slider_labels(self) -> list[str]: """Return labels for the sliders in heatmap mode.""" return ["Minimum (percentile)", "Maximum (percentile)"] - def slider_limits(self) -> list[str]: + def slider_limits(self) -> list[tuple[float, float, float]]: """Return slider limits for the colormap, in percent.""" - return self._number_of_sliders * [[0.0, 100.0, 0.01]] + return [(0.0, 100.0, 0.01)] * self._number_of_sliders def sliders_coupled(self) -> bool: """Confirm that sliders are coupled in heatmap mode.""" return True - def get_figure(self, figure: Figure | None = None): + def get_figure(self, figure: Figure | None = None) -> Figure | None: """Return current figure which will be used for plotting.""" target = self._figure if figure is None else figure if target is None: @@ -108,10 +107,8 @@ def change_normalisation(self, new_value: dict[str, Any]): backup.image.set_data(new_data) percentiles = np.linspace(0, 100.0, 21) results = np.percentile(np.nan_to_num(new_data), percentiles) - self._backup_scale_interpolators[ds_num] = interp1d( - percentiles, - results, - ) + backup.interp = interp1d(percentiles, results) + self.request_slider_values() def handle_slider(self, new_value: list[float]): @@ -140,18 +137,20 @@ def handle_slider(self, new_value: list[float]): newmin = float(newmax) else: newmax = float(newmin) + if newmin == last_minmax[0] and newmax == last_minmax[1]: return + if newmax >= newmin: try: - backup.image.set_clim([newmin, newmax]) + backup.image.set_clim((newmin, newmax)) except ValueError: LOG.error( f"Matplotlib could not set colorbar limits to {newmin}, {newmax}" ) else: self._figure.canvas.draw_idle() - backup.minmax = [newmin, newmax] + backup.minmax = (newmin, newmax) target.canvas.draw() def check_curve_lengths(self): @@ -180,6 +179,7 @@ def plot( """ self.enable_slider(allow_slider=True) target = self.get_figure(figure) + if target is None: return @@ -189,9 +189,13 @@ def plot( self._figure = target self._figure.set_layout_engine(layout="constrained") self._current_x_axes = [] - minmax_bak = {key: val.minmax for key, val in self._backup.items()} + + # minmax_bak = {key: val.minmax for key, val in self._backup.items()} scale_interpolators = {val.ind: val.interp for val in self._backup.values()} - self._backup = {} + self._backup = { + databundle.row: self.BackupInfo(ind=databundle.row) + for databundle in plotting_context.datasets().values() + } self._axes = [] self.apply_settings(plotting_context) @@ -212,12 +216,17 @@ def plot( except Exception: percentiles = np.linspace(0, 100.0, 21) results = [ - np.percentile(np.nan_to_num(ds._data), perc) for perc in percentiles + np.percentile(np.nan_to_num(databundle.dataset._data), perc) + for perc in percentiles ] - self._backup_scale_interpolators[databundle.row] = interp1d( + self._backup[databundle.row].interp = interp1d( percentiles, results, ) + else: + self._backup[databundle.row].interp = scale_interpolators[ + databundle.row + ] grid_size = self.grid_size(nplots) gs = self._figure.add_gridspec(*grid_size) @@ -226,110 +235,71 @@ def plot( islice(plotting_context.planes(self._slice_axis), self._plot_limit), ): dataset = databundle.dataset - limits = [] - x_axis_labels, y_axis_labels = [], [] - - for name in dataset._axes: - axis_array = dataset.x_axis(name) - limits += [axis_array[0], axis_array[-1]] - if not x_axis_labels: - x_axis_labels.append(dataset.x_axis_label(name)) - self._current_x_axes.append(axis_array) - else: - all_datasets = [dataset._data] - transposed = True - all_labels = [dataset._name] - for counter, name in enumerate(dataset._axes.keys()): - axis_array = dataset.x_axis(name) - limits += [ - axis_array[0], - axis_array[-1], - ] - if counter == primary_axis_number: - x_axis_labels.append(dataset.x_axis_label(name)) - self._current_x_axes.append(axis_array) - else: - y_axis_labels.append(dataset.x_axis_label(name)) - if transposed: - limits = limits[2:] + limits[:2] - for xnum in range(len(all_datasets)): - if startnum > self._plot_limit: - LOG.warning( - "Datasets above the current limit of %s will be ignored", - self._plot_limit, - ) - break - axes = target.add_subplot(gridsize, gridsize, startnum) - startnum += 1 - self._axes.append(axes) - image = axes.imshow( - all_datasets[xnum][::-1, :], - extent=limits, - aspect="auto", - interpolation=None, - cmap=plotting_context.colormap, + + axes = self._figure.add_subplot(gs[ind]) + + x_label = databundle.main_axis + y_label = nth(dataset._axes, self._slice_axis) + if y_label is None: + y_label = nth(dataset._axes, 1) + + x_axis = dataset.x_axis(x_label) + y_axis = dataset.x_axis(y_label) + + limits = (x_axis[0], x_axis[-1], y_axis[0], y_axis[-1]) + axes.set_xlabel(x_label) + axes.set_ylabel(y_label) + + image = axes.imshow( + plane, + extent=limits, + aspect="auto", + interpolation=None, + cmap=plotting_context.colormap, + ) + axes.set_title(label) + colorbar = mpl_colorbar(image, ax=image.axes, format="%.1e", pad=0.02) + colorbar.set_label(dataset._data_unit) + xlimits, ylimits = axes.get_xlim(), axes.get_ylim() + self._axes.append(axes) + self._backup[databundle.row].array = plane + self._backup[databundle.row].image = image + + interpolator = self._backup[databundle.row].interp + last_minmax = ( + interpolator(self._slider_values[0]), + interpolator(self._slider_values[1]), + ) + + try: + image.set_clim(last_minmax) + except ValueError: + LOG.error( + f"Matplotlib could not set colorbar limits to {last_minmax}", ) - axes.set_title(all_labels[xnum]) - colorbar = mpl_colorbar(image, ax=image.axes, format="%.1e", pad=0.02) - colorbar.set_label(dataset._data_unit) - xlimits, ylimits = axes.get_xlim(), axes.get_ylim() - self._backup_arrays[databundle.row] = all_datasets[xnum][::-1, :] + if update_only: - interpolator = self._backup_scale_interpolators[databundle.row] - last_minmax = [ - interpolator(self._slider_values[0]), - interpolator(self._slider_values[1]), - ] - try: - image.set_clim(last_minmax) - except ValueError: - LOG.error( - f"Matplotlib could not set colorbar limits to {last_minmax}", - ) - self._backup_limits[databundle.row] = [ + xlimits = axes.get_xlim() + ylimits = axes.get_ylim() + self._backup[databundle.row].limits = ( xlimits[0], xlimits[1], ylimits[0], ylimits[1], - ] - xlim = axes.get_xlim() - self._backup_limits[databundle.row][0] = xlim[0] - self._backup_limits[databundle.row][1] = xlim[1] - ylim = axes.get_ylim() - self._backup_limits[databundle.row][2] = ylim[0] - self._backup_limits[databundle.row][3] = ylim[1] + ) else: - self._backup_limits[databundle.row] = [ - xlimits[0], - xlimits[1], - ylimits[0], - ylimits[1], - ] - interpolator = self._backup_scale_interpolators[databundle.row] - last_minmax = [ - interpolator(self._slider_values[0]), - interpolator(self._slider_values[1]), - ] - try: - image.set_clim(last_minmax) - except ValueError: - LOG.error( - f"Matplotlib could not set colorbar limits to {last_minmax}", - ) - self._backup_minmax[databundle.row] = [ + self._backup[databundle.row].minmax = ( np.nanmin(dataset._data), np.nanmax(dataset._data), - ] - self._backup_limits[databundle.row] = [ + ) + self._backup[databundle.row].limits = ( xlimits[0], xlimits[1], ylimits[0], ylimits[1], - ] - axes.set_xlabel(", ".join(np.unique(x_axis_labels))) - axes.set_ylabel(", ".join(np.unique(y_axis_labels))) - self._backup_images[databundle.row] = image - if startnum > 1: + ) + + if ind > 1: legend = axes.legend() legend.set_visible(plotting_context.use_legend) axes.grid(plotting_context.use_grid) @@ -339,7 +309,7 @@ def plot( self.check_curve_lengths() self.request_slider_values() - target.canvas.draw() + self._figure.canvas.draw() @staticmethod def _write_save_data( diff --git a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Plotter.py b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Plotter.py index 94f2f2c48a..4dd76d2fc1 100644 --- a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Plotter.py +++ b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Plotter.py @@ -480,3 +480,28 @@ def grid_size(cls, n_plots: int) -> tuple[int, int]: Grid size. """ return cls.GRID_SIZES.get(n_plots, (math.ceil(n_plots**0.5),) * 2) + + @staticmethod + def get_label(ind: int, n_curves: int, limit: int, label: str): + """Get label for legend. + + For the abbreviated legend return None for those which are + between ``limit`` and ``n_curves`` (skipping them), and "..." + when it's at the limit. + + Parameters + ---------- + ind : int + Current index. + n_curves : int + Total number of "curves" to plot. + limit : int + Max number of entries in legend. + label : str + Current label. + """ + if ind == limit: + return "..." + if limit < ind < n_curves - 1: + return None + return label diff --git a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Vectors.py b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Vectors.py index b02ea2d909..307ba1f70c 100644 --- a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Vectors.py +++ b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Vectors.py @@ -57,9 +57,9 @@ def slider_labels(self) -> list[str]: """Return labels to show that sliders are not used.""" return ["Inactive", "Inactive"] - def slider_limits(self) -> list[str]: + def slider_limits(self) -> list[tuple[float, float, float]]: """Return generic slider limit values.""" - return self._number_of_sliders * [[-1.0, 1.0, 0.01]] + return [(-1.0, 1.0, 0.01)] * self._number_of_sliders def clear(self, figure: Figure | None = None): """Clear the figure.""" @@ -125,8 +125,6 @@ def plot( self._backup_curves = [] self.apply_settings(plotting_context) - x_axis_labels = [] - if plotting_context.set_axes() is None: LOG.debug("Axis check failed.") return @@ -137,7 +135,7 @@ def plot( return gs = self._figure.add_gridspec(2, 2) - labels = iter(("Total available", "Requested", "Selected")) + labels = iter(("Used", "Found")) for databundle in plotting_context.datasets().values(): dataset = databundle.dataset @@ -147,83 +145,47 @@ def plot( match dataset._name: case "Available vectors": - axes = self._figure.add_subplot(gs[0]) - - lab = plotlabel if dataset._n_dim == 2 else next(labels) + axes = self._figure.add_subplot(gs[0, 0]) for _, curve in dataset.curves_vs_axis(databundle.main_axis): - self._plot_single( - axes, - curve, - databundle, + lab = plotlabel if dataset._n_dim != 2 else next(labels) + + axes.bar( + *curve, label=lab, ) - temp_curves.append(temp) - if not label_stack: - break - else: - temp_curves = axes.plot( - dataset.x_axis(best_axis), + + axes.set_xlabel(", ".join(np.unique(x_axis_labels))) + axes.set_title(dataset._name) + self._axes.append(axes) + + case r"<|q|> - q$_{target}$": + axes = self._figure.add_subplot(gs[0, 1]) + xvals = dataset.x_axis(databundle.main_axis) + axes.violinplot( dataset.data, - linestyle=databundle.line_style, - label=plotlabel, - color=databundle.colour, + positions=xvals, + widths=violin_plot_width(xvals), ) - for temp in temp_curves: - try: - temp.set_marker(databundle.marker) - except ValueError: - with contextlib.suppress(Exception): - temp.set_marker(int(databundle.marker)) - self._axes.append(axes) - axes.set_xlabel(", ".join(np.unique(x_axis_labels))) - axes.set_title(dataset._name) - elif dataset._name == r"<|q|> - q$_{target}$": - axes = target.add_subplot(single_plot_stack.pop()) - xvals = dataset.x_axis(best_axis) - axes.violinplot( - dataset.data, - positions=xvals, - widths=violin_plot_width(xvals), - ) - self._axes.append(axes) - axes.set_xlabel(", ".join(np.unique(x_axis_labels))) - axes.set_title(dataset._name) - - elif dataset._name == "Shell population": - axes = target.add_subplot(212) - multi_curves = dataset.curves_vs_axis( - (best_unit, best_axis), - max_limit=self._curve_limit_per_dataset, - ) - x_axis = dataset.x_axis(best_axis) - bottom = np.zeros(len(x_axis)) - width = 0.8 * abs(np.mean(x_axis[1:] - x_axis[:-1])) - self._axes.append(axes) - add_legend_placeholder = ( - len(multi_curves) > self._legend_limit_for_histogram - ) - add_last_entry = False - axes.set_xlabel(", ".join(np.unique(x_axis_labels))) - for bar_index, (key, value) in enumerate(multi_curves.items()): - legend_label = plotlabel + ":" + dataset._curve_labels[key] - try: - axes.bar( - x_axis, - value, - label=legend_label - if bar_index < self._legend_limit_for_histogram - or add_last_entry - else None, - bottom=bottom, - width=width, - edgecolor="black", + self._axes.append(axes) + axes.set_xlabel(", ".join(np.unique(x_axis_labels))) + axes.set_title(dataset._name) + + case "Shell population": + axes = self._figure.add_subplot(gs[1, :]) + + n_curves = ilen( + dataset.curves_vs_axis( + databundle.main_axis, + max_limit=self._curve_limit_per_dataset, ) ) x_axis = dataset.x_axis(databundle.main_axis) bottom = np.zeros_like(x_axis) - width = 0.8 * abs(np.mean(x_axis[1:] - x_axis[:-1])) + width = 0.8 * abs(np.mean(np.diff(x_axis))) + + self._axes.append(axes) axes.set_xlabel(", ".join(np.unique(x_axis_labels))) for ind, (label, curve) in enumerate( @@ -234,11 +196,28 @@ def plot( ): try: axes.bar( - x_axis, - 0, - label="...", - color=target.get_facecolor(), + *curve, + label=self.get_label( + ind=ind, + n_curves=n_curves, + limit=self._legend_limit_for_histogram, + label=f"{plotlabel}:{label}", + ), + bottom=bottom, + width=width, + edgecolor="black", ) + bottom += curve[1] + + except ValueError: + x, y = curve + LOG.error( + f"Plotting failed for {plotlabel} using {databundle.main_axis}" + ) + LOG.error(f"x_axis={x}") + LOG.error(f"values={y}") + raise + for axindex, axes in enumerate(self._axes): if update_only: plot_limits = self._backup_limits[axindex] @@ -252,7 +231,7 @@ def plot( else: for databundle in plotting_context.datasets().values(): dataset = databundle.dataset - best_unit, best_axis = dataset.longest_axis() + _, best_axis = dataset.longest_axis() if dataset._name == r"<|q|> - q$_{target}$": axes.clear() xvals = dataset.x_axis(best_axis) @@ -277,80 +256,12 @@ def plot( ylimits[0], ylimits[1], ] - for axes in self._axes: legend = axes.legend() legend.set_visible(plotting_context.use_legend) axes.grid(plotting_context.use_grid) axes.relim() axes.autoscale() - if self._toolbar is not None: self._toolbar.update() - - target.canvas.draw() - - @staticmethod - def get_label(ind: int, n_curves: int, limit: int, label: str): - """Get label for legend. - - For the abbreviated legend return None for those which are - between ``limit`` and ``n_curves`` (skipping them), and "..." - when it's at the limit. - - Parameters - ---------- - ind : int - Current index. - n_curves : int - Total number of "curves" to plot. - limit : int - Max number of entries in legend. - label : str - Current labe. - """ - if ind == limit: - return "..." - if limit < ind < n_curves - 1: - return None - return label - - def _plot_single( - self, - axes: Axes, - curve: tuple[np.ndarray, np.ndarray] | tuple[np.ndarray], - databundle: PlotArgs, - *, - yerr: np.ndarray | None = None, - label: str, - ): - """Plot a single curve to axes. - - Parameters - ---------- - axes : Axes - Axis to plot to. - curve : tuple[np.ndarray, np.ndarray] | tuple[np.ndarray] - Curve to plot. - databundle : PlotArgs - Data to plot. - yerr : ndarray, optional - Error bars to add. - label : str - Plot label. - """ - line, _caps, _bars = axes.errorbar( - *curve, - yerr=yerr, - linestyle=databundle.line_style, - label=label, - color=databundle.colour, - ) - - axes.set_title(databundle.dataset._name) - - try: - line.set_marker(databundle.marker) - except ValueError: - with contextlib.suppress(Exception): - line.set_marker(int(databundle.marker)) + self._figure.canvas.draw() diff --git a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Vectors3D.py b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Vectors3D.py index a067f8381d..c0100bcc3b 100644 --- a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Vectors3D.py +++ b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Vectors3D.py @@ -24,6 +24,7 @@ from MDANSE_GUI.Tabs.Plotters.Plotter import Plotter if TYPE_CHECKING: + from matplotlib.backends.backend_qt import NavigationToolbar2QT from matplotlib.figure import Figure from MDANSE_GUI.Tabs.Models.PlottingContext import PlottingContext @@ -45,9 +46,9 @@ def slider_labels(self) -> list[str]: """Return labels to show that sliders are not used.""" return ["Inactive", "Inactive"] - def slider_limits(self) -> list[str]: + def slider_limits(self) -> list[tuple[float, float, float]]: """Return generic slider limit values.""" - return self._number_of_sliders * [[-1.0, 1.0, 0.01]] + return [(-1.0, 1.0, 0.01)] * self._number_of_sliders def clear(self, figure: Figure | None = None): """Clear the figure.""" @@ -81,7 +82,7 @@ def plot( plotting_context: PlottingContext, figure: Figure | None = None, update_only: bool = False, - toolbar: type | None = None, + toolbar: NavigationToolbar2QT | None = None, ): """Plot all datasets in the same figure. @@ -99,60 +100,88 @@ def plot( """ self.enable_slider(allow_slider=False) target = self.get_figure(figure) + if target is None: return if toolbar is not None: self._toolbar = toolbar + self._figure = target self._normalisation_errors = [] self._axes = [] self.apply_settings(plotting_context) + if plotting_context.set_axes() is None: LOG.debug("Axis check failed.") return + if not plotting_context.datasets(): target.clear() target.canvas.draw() - single_plot_stack = [221, 222, 223, 224] - first_shared_axes = target.add_subplot(single_plot_stack[0]) - for ds_index, databundle in enumerate(plotting_context.datasets().values()): + + gs = self._figure.add_gridspec(2, 2) + shared_axes = figure.add_subplot(gs[0]) + self._axes = [shared_axes] + + for databundle in plotting_context.datasets().values(): dataset = databundle.dataset - try: - _, best_axis = ( - dataset._axes_units[databundle.main_axis], - databundle.main_axis, - ) - except KeyError: - _, best_axis = dataset.longest_axis() plotlabel = databundle.legend_label - if ds_index < 3: - axes = first_shared_axes - temp_curves = axes.plot( - dataset.x_axis(best_axis), - dataset.data, - linestyle=databundle.line_style, + + x_axis_labels = [dataset.x_axis_label(databundle.main_axis)] + + bins = dataset.x_axis(databundle.main_axis) + + if "angle" in plotlabel: + axes = self._figure.add_subplot( + gs[2 + plotlabel.startswith("Azimuthal")] + ) + + axes.bar( + bins, + dataset.data[0, :], + linestyle="none", label=plotlabel, color=databundle.colour, + width=np.mean(np.diff(bins)), + ) + axes.bar( + bins, + dataset.data[1, :], + linestyle=databundle.line_style, + label="Unique counts", + color="none", + edgecolor="black", + width=np.mean(np.diff(bins)), ) - for temp in temp_curves: - try: - temp.set_marker(databundle.marker) - except ValueError: - with contextlib.suppress(Exception): - temp.set_marker(int(databundle.marker)) self._axes.append(axes) - axes.set_xlabel(dataset.x_axis_label(best_axis)) + axes.set_xlabel(", ".join(x_axis_labels)) axes.set_title(dataset._name) - elif ds_index < 4: - axes = target.add_subplot( - single_plot_stack[ds_index - 2], projection="3d" + elif "vs" in plotlabel: + [curve] = shared_axes.plot( + bins, + dataset.data, + linestyle=databundle.line_style, + label=plotlabel, + color=databundle.colour, ) + try: + curve.set_marker(databundle.marker) + except ValueError: + with contextlib.suppress(Exception): + curve.set_marker(int(databundle.marker)) + + shared_axes.set_xlabel(", ".join(x_axis_labels)) + + elif "in 3D" in plotlabel: + axes = self._figure.add_subplot(gs[1], projection="3d") + all_coords = np.concatenate( [dataset.x_axis("q_x"), dataset.x_axis("q_y"), dataset.data] ) self.axis3d_min = np.min(all_coords) self.axis3d_max = np.max(all_coords) - temp_curves = axes.scatter( + + axes.scatter( dataset.x_axis("q_x"), dataset.x_axis("q_y"), dataset.data, @@ -165,38 +194,7 @@ def plot( axes.set_xlabel(dataset.x_axis_label("q_x")) axes.set_ylabel(dataset.x_axis_label("q_y")) axes.set_title(dataset._name) - else: - dataset = databundle.dataset - try: - _, best_axis = ( - dataset._axes_units[databundle.main_axis], - databundle.main_axis, - ) - except KeyError: - _, best_axis = dataset.longest_axis() - plotlabel = databundle.legend_label - axes = target.add_subplot(single_plot_stack[ds_index - 2]) - bins = dataset.x_axis(best_axis) - temp_curves = axes.bar( - bins, - dataset.data[0, :], - linestyle="none", - label=plotlabel, - color=databundle.colour, - width=np.mean(np.diff(bins)), - ) - _ = axes.bar( - bins, - dataset.data[1, :], - linestyle=databundle.line_style, - label="Unique counts", - color="none", - edgecolor="black", - width=np.mean(np.diff(bins)), - ) - self._axes.append(axes) - axes.set_xlabel(dataset.x_axis_label(best_axis)) - axes.set_title(dataset._name) + for axes in self._axes: if update_only: try: @@ -214,15 +212,19 @@ def plot( else: xlimits, ylimits = axes.get_xlim(), axes.get_ylim() self._backup_limits = [xlimits[0], xlimits[1], ylimits[0], ylimits[1]] + for axes in self._axes: legend = axes.legend() legend.set_visible(plotting_context.use_legend) axes.grid(plotting_context.use_grid) axes.relim() axes.autoscale() + self.axes3d.set_xlim([self.axis3d_min, self.axis3d_max]) self.axes3d.set_ylim([self.axis3d_min, self.axis3d_max]) self.axes3d.set_zlim([self.axis3d_min, self.axis3d_max]) + if self._toolbar is not None: self._toolbar.update() + target.canvas.draw() diff --git a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Views/PlotDataView.py b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Views/PlotDataView.py index e8d56d0428..9ca5741840 100644 --- a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Views/PlotDataView.py +++ b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Views/PlotDataView.py @@ -436,7 +436,9 @@ def q_data(elem: str) -> Generator: common_bins = common_bins[start_index:] qmod_histograms = [np.histogram(qmods, common_bins)[0] for qmods in modq_per_shell] stacked_histograms = np.vstack(qmod_histograms) + xvals = common_bins[1:] - np.diff(common_bins) / 2 + nvec_per_q = SingleDataset( "Available vectors", None, diff --git a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Visualisers/PlotWidget.py b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Visualisers/PlotWidget.py index df11224fd2..fc4f264f77 100644 --- a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Visualisers/PlotWidget.py +++ b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Visualisers/PlotWidget.py @@ -339,7 +339,7 @@ def plot_data(self, update_only=False): if self._plotting_context is None: return - self._figure.set_layout_engine("tight") + # self._figure.set_layout_engine("tight") self._plotter.plot( self._plotting_context, From ab30f1d8be8e722423007040ffc15969f6874f41 Mon Sep 17 00:00:00 2001 From: Jacob Wilkins Date: Mon, 8 Jun 2026 17:05:57 +0100 Subject: [PATCH 8/9] Rework to use generators, add grouped option, have coloured output --- .../MDANSE_GUI/Tabs/Models/PlottingContext.py | 53 +-- .../Src/MDANSE_GUI/Tabs/Plotters/Grid.py | 62 ++-- .../Src/MDANSE_GUI/Tabs/Plotters/Grouped.py | 326 ++++++++++++++++++ .../Src/MDANSE_GUI/Tabs/Plotters/Heatmap.py | 73 ++-- .../Src/MDANSE_GUI/Tabs/Plotters/Plotter.py | 33 ++ .../Src/MDANSE_GUI/Tabs/Plotters/Single.py | 50 +-- .../Src/MDANSE_GUI/Tabs/Plotters/Text.py | 17 +- .../Src/MDANSE_GUI/Tabs/Plotters/__init__.py | 1 + 8 files changed, 480 insertions(+), 135 deletions(-) create mode 100644 MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Grouped.py diff --git a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Models/PlottingContext.py b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Models/PlottingContext.py index 139e455206..c2fc5c7aa3 100644 --- a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Models/PlottingContext.py +++ b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Models/PlottingContext.py @@ -497,23 +497,21 @@ def data(self) -> npt.NDArray[np.floating]: return self._data * self._scaling_factor return self._data - @overload - def generate_curve_label( - self, - index_tuple: Sequence[int], - axis_lookup: Iterable[str], - *, - skip_text: Literal[False] = False, - ) -> str: ... - @overload + @property + def n_curves(self) -> int: + """Number of curves in dataset.""" + if self._data_limits is None: + return 0 + + return len(self._data_limits) + def generate_curve_label( self, index_tuple: Sequence[int], axis_lookup: Iterable[str], *, - skip_text: Literal[True], - ) -> float: ... - def generate_curve_label(self, index_tuple, axis_lookup, *, skip_text=False): + skip_text: bool = False, + ) -> str: """Get a meaningful label for a subset of data. Used when plotting 1D arrays out of a multidimensional array. @@ -529,7 +527,7 @@ def generate_curve_label(self, index_tuple, axis_lookup, *, skip_text=False): Returns ------- - str | float + str A string label for the plot legend or a number for Text plotter. """ @@ -564,7 +562,7 @@ def generate_curve_label(self, index_tuple, axis_lookup, *, skip_text=False): label += f"{axis_label}={picked_value} {axis_unit}, " if skip_text: - return float(picked_value) + return str(float(picked_value)) return label.rstrip(", ") def curve_ind(self, limits: int | None = None, /): @@ -580,7 +578,7 @@ def curve_ind(self, limits: int | None = None, /): def curves_vs_axis( self, axis_label: tuple[str, str] | str, - max_limit: int = 1, + max_limit: int | None = None, *, axis_unit: str | None = None, skip_label_text: bool = False, @@ -624,7 +622,7 @@ def curves_vs_axis( x_axis = self.x_axis(axis_label) if self._data.ndim == 1: - yield None, (x_axis, self.data) + yield axis_label, (x_axis, self.data) return data_shape = self._data.shape @@ -657,6 +655,7 @@ def curves_vs_axis( try: index_tuple = nth_product(index, *indexer) index_slicer = nth_product(index, *slicer) + yield ( self.generate_curve_label( index_tuple, label_lookup, skip_text=skip_label_text @@ -1062,7 +1061,7 @@ def delete_dataset(self, index: QModelIndex): def planes( self, default_axis: int = 0, planes_per_dataset: int | None = None - ) -> Generator[tuple[PlotArgs, str, np.ndarray]]: + ) -> Generator[tuple[PlotArgs, str, npt.NDArray[np.floating]]]: for databundle in self.datasets().values(): ds = databundle.dataset @@ -1076,11 +1075,21 @@ def planes( def curves( self, curves_per_dataset: int | None = None - ) -> Generator[tuple[PlotArgs, str, np.ndarray]]: + ) -> Generator[ + Generator[ + tuple[ + PlotArgs, + str, + tuple[npt.NDArray[np.floating], npt.NDArray[np.floating]], + ] + ] + ]: for databundle in self.datasets().values(): ds = databundle.dataset - for label, curve in islice( - ds.curves_vs_axis(databundle.main_axis), curves_per_dataset - ): - yield databundle, label, curve + yield ( + (databundle, label, curve) + for label, curve in islice( + ds.curves_vs_axis(databundle.main_axis), curves_per_dataset + ) + ) diff --git a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Grid.py b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Grid.py index 756c149a3e..c348ac25b9 100644 --- a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Grid.py +++ b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Grid.py @@ -20,8 +20,7 @@ from itertools import islice from typing import TYPE_CHECKING, Any -from matplotlib import rcParams -from more_itertools import ilen +from more_itertools import flatten, ilen from MDANSE.MLogging import LOG from MDANSE_GUI.Tabs.Plotters.Plotter import Plotter @@ -47,15 +46,14 @@ def __init__(self) -> None: self._active_curves = [] self._backup_curves = [] self._plot_limit = 9 - self._title_length_limit = 30 def slider_labels(self) -> list[str]: """Return labels to show that sliders are not used.""" return ["Inactive", "Inactive"] - def slider_limits(self) -> list[str]: + def slider_limits(self) -> list[tuple[float, float, float]]: """Return generic slider limit values.""" - return self._number_of_sliders * [[-1.0, 1.0, 0.01]] + return [(-1.0, 1.0, 0.01)] * self._number_of_sliders def check_curve_lengths(self): """Find the maximum number of elements in the x axes of the plot data.""" @@ -94,23 +92,16 @@ def change_normalisation(self, new_value: dict[str, Any]): self._toolbar.update() self._toolbar.push_current() - def title_fontsize(self, title_text: str) -> int: - normal_size = rcParams["font.size"] - new_size = ( - normal_size - if len(title_text) < self._title_length_limit - else normal_size - round(len(title_text) / self._title_length_limit) - ) - return new_size - def toggle_legend(self, enabled: bool) -> None: if self._figure is None: return - for plot_index, axes in enumerate(self._axes): + + for axes, title in zip(self._axes, self._axes_titles, strict=True): axes.set_title( - self._axes_titles[plot_index] if enabled else "", - fontsize=self.title_fontsize(self._axes_titles[plot_index]), + title if enabled else "", + fontsize=self.title_fontsize(title), ) + self._figure.canvas.draw() def plot( @@ -155,35 +146,31 @@ def plot( self._normalisation_errors = [] self.apply_settings(plotting_context) - nplots = 0 - for databundle in plotting_context.datasets().values(): - ds = databundle.dataset - try: - axis_info = ds._axes_units[databundle.main_axis], databundle.main_axis - except KeyError: - axis_info = ds.longest_axis() - curves = ds.curves_vs_axis(axis_info, max_limit=self._plot_limit) - nplots += ilen(curves) - - nplots = min(nplots, self._plot_limit) - gridsize = self.grid_size(nplots) + nplots = min( + sum(db.dataset.n_curves for db in plotting_context.datasets().values()), + self._plot_limit, + ) + grid_size = self.grid_size(nplots) + gs = self._figure.add_gridspec(*grid_size) for ind, (databundle, label, curve) in enumerate( - islice(plotting_context.curves(), self._plot_limit), 1 + islice(flatten(plotting_context.curves()), self._plot_limit) ): - axes = target.add_subplot(*gridsize, ind) + axes = target.add_subplot(gs[ind]) + self._axes_titles.append(databundle.dataset._name) + self._plot_single( axes, - (curve,), + curve, databundle, label=label, colour=databundle.colour, ) - axes.legend() - axes.grid(plotting_context.use_grid) + self._axes.append(axes) - axes.get_legend().set_visible(plotting_context.use_legend) + self.toggle_legend(plotting_context.use_legend) + self.toggle_grid(plotting_context.use_grid) self.apply_settings(plotting_context) self.check_curve_lengths() @@ -200,7 +187,7 @@ def _plot_single( databundle: PlotArgs, *, label: str, - colour: tuple[float, float, float], + colour: tuple[float, float, float] | str, ): """Plot a single curve to axes. @@ -214,7 +201,7 @@ def _plot_single( Data to plot. label : str Plot label. - colour : tuple[float, float, float] + colour : tuple[float, float, float] | str Curve colour. """ lines: list[Line2D] = axes.plot( @@ -233,7 +220,6 @@ def _plot_single( with contextlib.suppress(Exception): line.set_marker(int(databundle.marker)) - self._axes.append(axes) self._active_curves.extend(lines) self._backup_curves.extend( (line.get_xdata(), line.get_ydata()) for line in lines diff --git a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Grouped.py b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Grouped.py new file mode 100644 index 0000000000..2eadc81b97 --- /dev/null +++ b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Grouped.py @@ -0,0 +1,326 @@ +# This file is part of MDANSE_GUI. +# +# MDANSE_GUI is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +# +from __future__ import annotations + +import contextlib +import math +from itertools import islice +from typing import TYPE_CHECKING, Any + +import numpy as np +from more_itertools import ilen + +from MDANSE.MLogging import LOG +from MDANSE_GUI.Tabs.Plotters.Plotter import Plotter + +if TYPE_CHECKING: + import numpy.typing as npt + from matplotlib.axes import Axes + from matplotlib.backends.backend_qt5agg import NavigationToolbar2QT as Toolbar + from matplotlib.figure import Figure + from matplotlib.lines import Line2D + + from MDANSE_GUI.Tabs.Models.PlottingContext import PlotArgs, PlottingContext + + +@Plotter.register("Grouped") +class Grouped(Plotter): + """Plots each dataset in its own subplot.""" + + def __init__(self) -> None: + super().__init__() + self._figure = None + self._backup_limits: list[tuple[float, float, float, float]] = [] + self._active_curves: list[Line2D] = [] + self._backup_curves: list[ + tuple[npt.NDArray[np.floating], npt.NDArray[np.floating]] + ] = [] + self._plot_limit = 9 + self.height_max, self.length_max = -np.inf, -np.inf + + def slider_labels(self) -> list[str]: + """Return labels to show that sliders are not used.""" + return ["Y offset", "X offset"] + + def slider_limits(self) -> list[tuple[float, float, float]]: + """Return generic slider limit values.""" + return [(-1.0, 1.0, 0.01)] * self._number_of_sliders + + def handle_slider(self, new_value: list[float]): + """Save slider values and call offset_curves.""" + super().handle_slider(new_value) + self.offset_curves() + + def offset_curves(self): + """Offset curves against each other based on slider settings.""" + target = self._figure + if target is None: + return + + new_value = self._slider_values + + backup = iter(self._backup_curves) + for ind, axes in enumerate(self._axes): + saved_xmin, saved_xmax, saved_ymin, saved_ymax = self._backup_limits[ind] + + for num, curve in enumerate(axes.get_lines()): + xdata, ydata = next(backup) + xdata, ydata = self.normalise_curve(xdata, ydata) + new_xdata = xdata + num * self.length_max * new_value[1] + new_ydata = ydata + num * self.height_max * new_value[0] + curve.set_xdata(new_xdata) + curve.set_ydata(new_ydata) + xmin, xmax = np.nanmin(new_xdata), np.nanmax(new_xdata) + ymin, ymax = np.nanmin(new_ydata), np.nanmax(new_ydata) + saved_xmin = np.nanmin([xmin, saved_xmin]) + saved_xmax = np.nanmax([xmax, saved_xmax]) + saved_ymin = np.nanmin([ymin, saved_ymin]) + saved_ymax = np.nanmax([ymax, saved_ymax]) + + axes.relim() + axes.autoscale() + + self._backup_limits[ind] = (saved_xmin, saved_xmax, saved_ymin, saved_ymax) + + try: + axes.set_xlim(saved_xmin, saved_xmax) + except ValueError: + LOG.error( + f"Matplotlib could not set x limits to {saved_xmin}, {saved_xmax}", + ) + + try: + axes.set_ylim(saved_ymin, saved_ymax) + except ValueError: + LOG.error( + f"Matplotlib could not set y limits to {saved_ymin}, {saved_ymax}", + ) + + if self._toolbar is not None: + self._toolbar.update() + self._toolbar.push_current() + + target.canvas.draw() + + def check_curve_lengths(self): + """Find the maximum number of elements in the x axes of the plot data.""" + self.curve_length_limit = 0 + for num, _ in enumerate(self._active_curves): + xdata = self._backup_curves[num][0] + self.curve_length_limit = max(self.curve_length_limit, len(xdata)) + + def change_normalisation(self, new_value: dict[str, Any]): + """Normalise the data based on the new parameters. + + Parameters + ---------- + new_value : dict[str, Any] + parameters as in NORMALISATION_DEFAULTS + + """ + super().change_normalisation(new_value) + target = self._figure + if target is None or not self._active_curves: + return + + for curve_index, curve in enumerate(self._active_curves): + xdata, ydata = self._backup_curves[curve_index] + xdata, ydata = self.normalise_curve(xdata, ydata) + curve.set_xdata(xdata) + curve.set_ydata(ydata) + + target.canvas.draw() + + for axes in self._axes: + axes.relim() + axes.autoscale() + + if self._toolbar is not None: + self._toolbar.update() + self._toolbar.push_current() + + def toggle_legend(self, enabled: bool) -> None: + if self._figure is None: + return + + for axes, title in zip(self._axes, self._axes_titles, strict=True): + axes.set_title( + title if enabled else "", + fontsize=self.title_fontsize(title), + ) + axes.get_legend().set_visible(enabled) + + self._figure.canvas.draw() + + def plot( + self, + plotting_context: PlottingContext, + figure: Figure | None = None, + update_only: bool = False, + toolbar: Toolbar | None = None, + ): + """Plot datasets in separate subplots. + + Parameters + ---------- + plotting_context : PlottingContext + Data model storing the data to be plotted. + figure : Figure, optional + Matplotlib figure instance for plotting, by default None. + update_only : bool, optional + If true, try to re-use zoom settings, by default False. + toolbar : Toolbar, optional + GUI instance of the matplotlib toolbar, by default None. + + """ + self.enable_slider(allow_slider=False) + target = self.get_figure(figure) + + if target is None: + return + + if toolbar is not None: + self._toolbar = toolbar + + if plotting_context.set_axes() is None: + LOG.debug("Axis check failed.") + return + + self.height_max, self.length_max = 0.0, 0.0 + self._figure = target + self._axes = [] + self._axes_titles = [] + self._backup_curves = [] + self._active_curves = [] + self._normalisation_errors = [] + self.apply_settings(plotting_context) + + nplots = min(ilen(plotting_context.curves()), self._plot_limit) + grid_size = self.grid_size(nplots) + gs = self._figure.add_gridspec(*grid_size) + limits = [(0.0, 0.0, 0.0, 0.0) for _ in range(nplots)] + + for ind, (db, dataclump) in enumerate( + islice( + zip( + plotting_context.datasets().values(), + plotting_context.curves(), + strict=True, + ), + self._plot_limit, + ) + ): + ds = db.dataset + axes = target.add_subplot(gs[ind]) + self._axes_titles.append(ds._name) + + colours = self.colours(db.colour, ds.n_curves) + + for (databundle, label, curve), colour in zip( + dataclump, colours, strict=True + ): + self._plot_single( + axes, + curve, + databundle, + label=label, + colour=colour, + ) + axes.legend() + self._axes.append(axes) + limits[ind] = (*axes.get_xlim(), *axes.get_ylim()) + + if update_only: + try: + axes.set_xlim(self._backup_limits[ind][:2]) + except ValueError: + LOG.error( + f"Matplotlib could not set x limits to {self._backup_limits[ind][:2]}" + ) + + try: + axes.set_ylim(self._backup_limits[ind][2:]) + except ValueError: + LOG.error( + f"Matplotlib could not set y limits to {self._backup_limits[ind][2:]}" + ) + + if not update_only: + self._backup_limits = limits + + self.enable_slider( + allow_slider=any( + db.dataset.n_curves > 1 for db in plotting_context.datasets().values() + ) + ) + self.toggle_legend(plotting_context.use_legend) + self.toggle_grid(plotting_context.use_grid) + + self.apply_settings(plotting_context) + self.check_curve_lengths() + target.canvas.draw() + + if self._toolbar is not None: + self._toolbar.update() + self._toolbar.push_current() + + def _plot_single( + self, + axes: Axes, + curve: tuple[np.ndarray, np.ndarray] | tuple[np.ndarray], + databundle: PlotArgs, + *, + label: str, + colour: tuple[float, float, float] | str, + ): + """Plot a single curve to axes. + + Parameters + ---------- + axes : Axes + Axis to plot to. + curve : tuple[np.ndarray, np.ndarray] | tuple[np.ndarray] + Curve to plot. + databundle : PlotArgs + Data to plot. + label : str + Plot label. + colour : tuple[float, float, float] | str + Curve colour. + """ + lines: list[Line2D] = axes.plot( + *curve, + linestyle=databundle.line_style, + label=label, + color=colour, + ) + + axes.set_xlabel(databundle.dataset.x_axis_label(databundle.main_axis)) + + for line in lines: + try: + line.set_marker(databundle.marker) + except ValueError: + with contextlib.suppress(Exception): + line.set_marker(int(databundle.marker)) + + self.height_max = np.nanmax([self.height_max, np.nanmax(line.get_ydata())]) + self.length_max = np.nanmax([self.length_max, np.nanmax(line.get_xdata())]) + + self._active_curves.extend(lines) + self._backup_curves.extend( + (line.get_xdata(), line.get_ydata()) for line in lines + ) diff --git a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Heatmap.py b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Heatmap.py index 26ff84637d..7574f3231f 100644 --- a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Heatmap.py +++ b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Heatmap.py @@ -29,9 +29,11 @@ from scipy.interpolate import interp1d from MDANSE.MLogging import LOG +from MDANSE_GUI.Tabs.Models.PlottingContext import PlotArgs from MDANSE_GUI.Tabs.Plotters.Plotter import Plotter if TYPE_CHECKING: + import numpy.typing as npt from matplotlib.backends.backend_qt5agg import NavigationToolbar2QT as Toolbar from matplotlib.figure import Figure from matplotlib.image import AxesImage @@ -238,33 +240,11 @@ def plot( axes = self._figure.add_subplot(gs[ind]) - x_label = databundle.main_axis - y_label = nth(dataset._axes, self._slice_axis) - if y_label is None: - y_label = nth(dataset._axes, 1) - - x_axis = dataset.x_axis(x_label) - y_axis = dataset.x_axis(y_label) - - limits = (x_axis[0], x_axis[-1], y_axis[0], y_axis[-1]) - axes.set_xlabel(x_label) - axes.set_ylabel(y_label) - - image = axes.imshow( - plane, - extent=limits, - aspect="auto", - interpolation=None, - cmap=plotting_context.colormap, + image = self._plot_single( + axes, plane, databundle, label=label, colour=plotting_context.colormap ) - axes.set_title(label) - colorbar = mpl_colorbar(image, ax=image.axes, format="%.1e", pad=0.02) - colorbar.set_label(dataset._data_unit) - xlimits, ylimits = axes.get_xlim(), axes.get_ylim() - self._axes.append(axes) - self._backup[databundle.row].array = plane - self._backup[databundle.row].image = image + xlimits, ylimits = axes.get_xlim(), axes.get_ylim() interpolator = self._backup[databundle.row].interp last_minmax = ( interpolator(self._slider_values[0]), @@ -279,8 +259,6 @@ def plot( ) if update_only: - xlimits = axes.get_xlim() - ylimits = axes.get_ylim() self._backup[databundle.row].limits = ( xlimits[0], xlimits[1], @@ -372,3 +350,44 @@ def _get_datasets(axis: Axes) -> Iterator[AxesImage]: Each image in dataset. """ yield from axis.get_images() + + def _plot_single( + self, + axes: Axes, + plane: npt.NDArray[np.floating], + databundle: PlotArgs, + *, + label: str, + colour: str, + ) -> AxesImage: + + dataset = databundle.dataset + + x_label = databundle.main_axis + y_label = nth(dataset._axes, self._slice_axis) + if y_label is None: + y_label = nth(dataset._axes, 1) + + x_axis = dataset.x_axis(x_label) + y_axis = dataset.x_axis(y_label) + + limits = (x_axis[0], x_axis[-1], y_axis[0], y_axis[-1]) + axes.set_xlabel(x_label) + axes.set_ylabel(y_label) + + image = axes.imshow( + plane, + extent=limits, + aspect="auto", + interpolation=None, + cmap=colour, + ) + axes.set_title(label) + colorbar = mpl_colorbar(image, ax=image.axes, format="%.1e", pad=0.02) + colorbar.set_label(dataset._data_unit) + + self._axes.append(axes) + self._backup[databundle.row].array = plane + self._backup[databundle.row].image = image + + return image diff --git a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Plotter.py b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Plotter.py index 4dd76d2fc1..ea797dc916 100644 --- a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Plotter.py +++ b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Plotter.py @@ -19,10 +19,13 @@ import csv import enum import math +from collections.abc import Generator from itertools import count from typing import TYPE_CHECKING, Any, ClassVar, Literal, TextIO import numpy as np +from matplotlib import rcParams +from matplotlib.colors import to_rgb from more_itertools import consumer from MDANSE.Core.RegisterFactory import RegisterFactory @@ -109,6 +112,7 @@ class Plotter(RegisterFactory): 5: (2, 3), 6: (2, 3), } + _title_length_limit: ClassVar[int] = 30 def __init__(self) -> None: """Create defaults common to all plotters.""" @@ -505,3 +509,32 @@ def get_label(ind: int, n_curves: int, limit: int, label: str): if limit < ind < n_curves - 1: return None return label + + @staticmethod + def colours(colour: str, n_curves: int) -> Generator[tuple[float, float, float]]: + """Generate colours from root colour. + + Parameters + ---------- + colour : str + Root colour. + + Returns + ------- + Generator[tuple[float, float, float]] + Next colour in sequence. + """ + main_colour = np.array(to_rgb(colour)) + colour_increment = (0.5 - main_colour) / n_curves + for _ in range(n_curves): + yield tuple(main_colour) + main_colour += colour_increment + + def title_fontsize(self, title_text: str) -> int: + normal_size = rcParams["font.size"] + new_size = ( + normal_size + if len(title_text) < self._title_length_limit + else normal_size - round(len(title_text) / self._title_length_limit) + ) + return new_size diff --git a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Single.py b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Single.py index 6e8e777e79..709d6fbdc6 100644 --- a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Single.py +++ b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Single.py @@ -21,8 +21,7 @@ from typing import TYPE_CHECKING, Any import numpy as np -from matplotlib.colors import to_rgb -from more_itertools import ilen +from more_itertools import flatten, ilen from MDANSE.MLogging import LOG from MDANSE_GUI.Tabs.Plotters.Plotter import Plotter @@ -48,6 +47,7 @@ def __init__(self) -> None: self._backup_curves: list[tuple[np.ndarray, np.ndarray]] = [] self._backup_limits = [] self._curve_limit_per_dataset = 12 + self._plot_limit = 60 self.height_max, self.length_max = 0.0, 0.0 def clear(self, figure: Figure | None = None): @@ -70,9 +70,9 @@ def slider_labels(self) -> list[str]: """Return slider labels for single plot mode.""" return ["Y offset", "X offset"] - def slider_limits(self) -> list[str]: + def slider_limits(self) -> list[tuple[float, float, float]]: """Return slider limits for single plot mode.""" - return self._number_of_sliders * [[-1.0, 1.0, 0.001]] + return [(-1.0, 1.0, 0.001)] * self._number_of_sliders def handle_slider(self, new_value: list[float]): """Save slider values and call offset_curves.""" @@ -142,26 +142,6 @@ def check_curve_lengths(self): xdata = self._backup_curves[num][0] self.curve_length_limit = max(self.curve_length_limit, len(xdata)) - @staticmethod - def colours(colour: str, n_curves: int) -> Generator[tuple[float, float, float]]: - """Generate colours from root colour. - - Parameters - ---------- - colour : str - Root colour. - - Returns - ------- - Generator[tuple[float, float, float]] - Next colour in sequence. - """ - main_colour = np.array(to_rgb(colour)) - colour_increment = (0.5 - main_colour) / n_curves - for _ in range(n_curves): - yield tuple(main_colour) - main_colour += colour_increment - def plot( self, plotting_context: PlottingContext, @@ -198,8 +178,6 @@ def plot( self._normalisation_errors = [] x_axis_labels = [] - axes = target.add_subplot(111) - self._axes = [axes] self.apply_settings(plotting_context) self.height_max, self.length_max = 0.0, 0.0 @@ -220,17 +198,17 @@ def plot( colours = {} for databundle in plotting_context.datasets().values(): - n_curves = ilen( - databundle.dataset.curves_vs_axis( - databundle.main_axis, - self._curve_limit_per_dataset, - ) + colours[databundle.row] = self.colours( + databundle.colour, databundle.dataset.n_curves ) - colours[databundle.row] = self.colours(databundle.colour, n_curves) x_axis_labels.append(databundle.dataset.x_axis_label(databundle.main_axis)) - for databundle, _, curve in plotting_context.curves( - self._curve_limit_per_dataset + axes = target.add_subplot(111) + self._axes = [axes] + + for databundle, _, curve in islice( + flatten(plotting_context.curves(self._curve_limit_per_dataset)), + self._plot_limit, ): self._plot_single( axes, @@ -300,8 +278,8 @@ def _plot_single( with contextlib.suppress(Exception): line.set_marker(int(databundle.marker)) - self.height_max = max(self.height_max, line.get_ydata().max()) - self.length_max = max(self.length_max, line.get_xdata().max()) + self.height_max = np.nanmax([self.height_max, np.nanmax(line.get_ydata())]) + self.length_max = np.nanmax([self.length_max, np.nanmax(line.get_xdata())]) self._active_curves.extend(lines) self._backup_curves.extend( diff --git a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Text.py b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Text.py index 542b55de07..13e36f470b 100644 --- a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Text.py +++ b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Text.py @@ -20,7 +20,7 @@ from typing import TYPE_CHECKING import numpy as np -from more_itertools import collapse, prepend, transpose +from more_itertools import collapse, prepend, transpose, unzip from MDANSE.Framework.Units import measure from MDANSE.MLogging import LOG @@ -303,21 +303,14 @@ def process_2D_data( dataset._data.shape[0] if flip_array else dataset._data.shape[1] ) - multi_curves = np.vstack( - list( - dataset.curves_vs_axis( - (best_unit, best_axis), max_limit=curves_limit, skip_label_text=True - ).values() - ) - ) # Add corner nil xaxis = prepend("_", new_axes[axis_numbers[flip_array]].flat) # Add axes to data - data_lines = zip( - dataset._curve_labels.values(), - multi_curves, - strict=True, + data_lines = list( + dataset.curves_vs_axis( + (best_unit, best_axis), max_limit=curves_limit, skip_label_text=True + ) ) # Put xaxis in diff --git a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/__init__.py b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/__init__.py index 1085cff763..ccedf4b705 100644 --- a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/__init__.py +++ b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/__init__.py @@ -16,6 +16,7 @@ from __future__ import annotations from .Grid import Grid as Grid +from .Grouped import Grouped as Grouped from .Heatmap import Heatmap as Heatmap from .Plotter import Plotter as Plotter from .Single import Single as Single From 3b69564b0561d3286031e77c94295e269f23add6 Mon Sep 17 00:00:00 2001 From: Jacob Wilkins Date: Mon, 8 Jun 2026 17:27:35 +0100 Subject: [PATCH 9/9] Fix misc issues --- .../MDANSE_GUI/Tabs/Models/PlottingContext.py | 92 ++++++++----- .../Src/MDANSE_GUI/Tabs/Plotters/Grid.py | 9 +- .../Src/MDANSE_GUI/Tabs/Plotters/Grouped.py | 17 ++- .../Src/MDANSE_GUI/Tabs/Plotters/Heatmap.py | 121 +++++++++--------- .../Src/MDANSE_GUI/Tabs/Plotters/Plotter.py | 25 +++- .../Src/MDANSE_GUI/Tabs/Plotters/Single.py | 21 +-- .../Src/MDANSE_GUI/Tabs/Plotters/Vectors.py | 2 +- .../MDANSE_GUI/Tabs/Visualisers/PlotWidget.py | 4 +- 8 files changed, 171 insertions(+), 120 deletions(-) diff --git a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Models/PlottingContext.py b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Models/PlottingContext.py index c2fc5c7aa3..11e9f667c5 100644 --- a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Models/PlottingContext.py +++ b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Models/PlottingContext.py @@ -17,12 +17,11 @@ import copy import functools -from collections.abc import Generator, Iterable, Sequence from contextlib import suppress from itertools import islice from math import prod from pathlib import Path -from typing import TYPE_CHECKING, Literal, NamedTuple, overload +from typing import TYPE_CHECKING, NamedTuple import h5py import matplotlib.pyplot as mpl @@ -32,16 +31,16 @@ from matplotlib.colors import to_hex as mpl_to_hex from matplotlib.lines import lineStyles from matplotlib.markers import MarkerStyle -from more_itertools import first, locate, nth, nth_product, sort_together, unzip +from more_itertools import first, locate, nth, nth_product from qtpy.QtCore import QModelIndex, Qt, Signal, Slot from qtpy.QtGui import QColor, QStandardItem, QStandardItemModel from MDANSE.IO.IOUtils import summarise_array from MDANSE.MLogging import LOG -from MDANSE.util_types import ComplexArray, FloatArray +from MDANSE.util_types import ComplexArray, FloatArray, IntArray if TYPE_CHECKING: - from collections.abc import Iterable, Iterator + from collections.abc import Generator, Iterable, Iterator, Sequence import h5py @@ -484,7 +483,7 @@ def dep_axes(self) -> dict[str, FloatArray]: return {aname: axis for aname, axis in self._axes.items() if aname != la} @property - def data(self) -> npt.NDArray[np.floating]: + def data(self) -> FloatArray: """Data array, scaled if requested in the GUI table. Returns @@ -605,9 +604,9 @@ def curves_vs_axis( ------ str Plot label. - npt.NDArray[np.floating] + FloatArray x-axis. - npt.NDArray[np.floating] + FloatArray Curve to plot. """ match axis_label: @@ -622,7 +621,7 @@ def curves_vs_axis( x_axis = self.x_axis(axis_label) if self._data.ndim == 1: - yield axis_label, (x_axis, self.data) + yield "", (x_axis, self.data) return data_shape = self._data.shape @@ -642,7 +641,7 @@ def curves_vs_axis( slicer.append([slice(None)]) continue - indices: npt.NDArray[np.integer] = np.arange(data_shape[current_dim]) + indices: IntArray = np.arange(data_shape[current_dim]) slicer.append(indices) indexer.append(indices) label_lookup.append(axis_name) @@ -689,67 +688,94 @@ def curves_vs_axis( def planes_vs_axis( self, - axis_number: int, - max_limit: int = 1, - ) -> Generator[tuple[str, FloatArray]]: + main_axis: str, + max_limit: int = 9, + ) -> Generator[tuple[str, FloatArray, tuple[str, str]]]: """Prepare for plotting 2D subsets of an ND array. Parameters ---------- - axis_number : int - index of the axis perpendicular to the plotted array + main_axis : str + Label of the axis perpendicular to the plotted array. Yields ------ - str + main_label : str Grid label. - npt.NDArray[np.floating] + image_array : FloatArray 2D array. - + axis_labels : tuple[str, ...] + Labels for each axis. """ + main_axis_index = self.main_axis_index(main_axis) + other_labels = self._axis_labels(main_axis) + match self._data.ndim: case 1: pass - case 2 if axis_number == 1: - yield self._labels["medium"], self.data.T + case 2 if main_axis_index == 1: + yield self._labels["medium"], self.data.T, (main_axis, other_labels[0]) case 2: - yield self._labels["medium"], self.data + yield self._labels["medium"], self.data, (main_axis, other_labels[0]) case 3: perpendicular_axis_name, perpendicular_axis = nth( - self._axes.items(), axis_number, default=(None, None) + self._axes.items(), main_axis_index, default=(None, None) ) if perpendicular_axis is None: return - reordered_view = np.moveaxis(self.data, axis_number, 0) + reordered_view = np.moveaxis(self.data, main_axis_index, 0) for plane_number in self.curve_ind(max_limit): + if plane_number > len(reordered_view): + continue + yield ( f"{self._labels['minimal']}:{perpendicular_axis_name}={perpendicular_axis[plane_number]}", reordered_view[plane_number], + other_labels, ) case _: raise NotImplementedError( f"Cannot handle {self._data.ndim}-dimensional data." ) - def main_axis_index(self, main_axis: str | None, *, default: int) -> int: + def main_axis_index( + self, main_axis: str | None, *, default: int | None = None + ) -> int: """Find index of main axis. Parameters ---------- main_axis : str Main axis name to search for. - default : int + default : int, optional Index if ``main_axis`` not found. Returns ------- int Index of main axis. + + Raises + ------ + ValueError + Axis not found and no default. """ - return first(locate(self._axes, pred=lambda x: x == main_axis), default) + ind = first(locate(self._axes, pred=main_axis.__eq__), default) + if ind is None: + raise ValueError( + f"Cannot find axis {main_axis} in {','.join(self._axes.keys())}" + ) + return ind + + def _axis_labels(self, main_axis: str) -> tuple[str] | tuple[str, str]: + main_axis_index = self.main_axis_index(main_axis) + + return tuple( + label for i, label in enumerate(self._axes) if i != main_axis_index + ) plotting_column_labels = [ @@ -1060,18 +1086,16 @@ def delete_dataset(self, index: QModelIndex): self._datasets.pop(dkey, None) def planes( - self, default_axis: int = 0, planes_per_dataset: int | None = None - ) -> Generator[tuple[PlotArgs, str, npt.NDArray[np.floating]]]: + self, default_axis: int | None = None, planes_per_dataset: int | None = None + ) -> Generator[tuple[PlotArgs, str, FloatArray, tuple[str, str]]]: for databundle in self.datasets().values(): ds = databundle.dataset - for label, plane in islice( - ds.planes_vs_axis( - ds.main_axis_index(databundle.main_axis, default=default_axis) - ), + for label, plane, axis_labels in islice( + ds.planes_vs_axis(databundle.main_axis), planes_per_dataset, ): - yield databundle, label, plane + yield databundle, label, plane, axis_labels def curves( self, curves_per_dataset: int | None = None @@ -1080,7 +1104,7 @@ def curves( tuple[ PlotArgs, str, - tuple[npt.NDArray[np.floating], npt.NDArray[np.floating]], + tuple[FloatArray, FloatArray], ] ] ]: diff --git a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Grid.py b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Grid.py index c348ac25b9..2490b3a7e5 100644 --- a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Grid.py +++ b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Grid.py @@ -146,27 +146,26 @@ def plot( self._normalisation_errors = [] self.apply_settings(plotting_context) - nplots = min( + self._n_curves = min( sum(db.dataset.n_curves for db in plotting_context.datasets().values()), self._plot_limit, ) - grid_size = self.grid_size(nplots) + grid_size = self.grid_size(self.n_curves) gs = self._figure.add_gridspec(*grid_size) for ind, (databundle, label, curve) in enumerate( islice(flatten(plotting_context.curves()), self._plot_limit) ): axes = target.add_subplot(gs[ind]) - self._axes_titles.append(databundle.dataset._name) + self._axes_titles.append(f"{databundle.dataset._name} {label}") self._plot_single( axes, curve, databundle, - label=label, + label="", colour=databundle.colour, ) - axes.legend() self._axes.append(axes) self.toggle_legend(plotting_context.use_legend) diff --git a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Grouped.py b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Grouped.py index 2eadc81b97..3a310a3305 100644 --- a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Grouped.py +++ b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Grouped.py @@ -33,6 +33,7 @@ from matplotlib.figure import Figure from matplotlib.lines import Line2D + from MDANSE.util_types import FloatArray from MDANSE_GUI.Tabs.Models.PlottingContext import PlotArgs, PlottingContext @@ -46,7 +47,7 @@ def __init__(self) -> None: self._backup_limits: list[tuple[float, float, float, float]] = [] self._active_curves: list[Line2D] = [] self._backup_curves: list[ - tuple[npt.NDArray[np.floating], npt.NDArray[np.floating]] + tuple[FloatArray, FloatArray] ] = [] self._plot_limit = 9 self.height_max, self.length_max = -np.inf, -np.inf @@ -201,14 +202,17 @@ def plot( self.height_max, self.length_max = 0.0, 0.0 self._figure = target + self._axes = [] self._axes_titles = [] self._backup_curves = [] self._active_curves = [] self._normalisation_errors = [] + self.apply_settings(plotting_context) - nplots = min(ilen(plotting_context.curves()), self._plot_limit) + self._n_curves = sum(ilen(curves) for curves in plotting_context.curves()) + nplots = min(len(plotting_context.datasets()), self._plot_limit) grid_size = self.grid_size(nplots) gs = self._figure.add_gridspec(*grid_size) limits = [(0.0, 0.0, 0.0, 0.0) for _ in range(nplots)] @@ -229,16 +233,18 @@ def plot( colours = self.colours(db.colour, ds.n_curves) - for (databundle, label, curve), colour in zip( - dataclump, colours, strict=True + for curve_ind, ((databundle, label, curve), colour) in enumerate( + zip(dataclump, colours, strict=True) ): self._plot_single( axes, curve, databundle, + ind=curve_ind, label=label, colour=colour, ) + axes.legend() self._axes.append(axes) limits[ind] = (*axes.get_xlim(), *axes.get_ylim()) @@ -283,6 +289,7 @@ def _plot_single( curve: tuple[np.ndarray, np.ndarray] | tuple[np.ndarray], databundle: PlotArgs, *, + ind: int, label: str, colour: tuple[float, float, float] | str, ): @@ -304,7 +311,7 @@ def _plot_single( lines: list[Line2D] = axes.plot( *curve, linestyle=databundle.line_style, - label=label, + label=self.label(label, ind, n_curves=databundle.dataset.n_curves), color=colour, ) diff --git a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Heatmap.py b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Heatmap.py index 7574f3231f..e6b4db8805 100644 --- a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Heatmap.py +++ b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Heatmap.py @@ -33,11 +33,11 @@ from MDANSE_GUI.Tabs.Plotters.Plotter import Plotter if TYPE_CHECKING: - import numpy.typing as npt from matplotlib.backends.backend_qt5agg import NavigationToolbar2QT as Toolbar from matplotlib.figure import Figure from matplotlib.image import AxesImage + from MDANSE.util_types import FloatArray from MDANSE_GUI.Tabs.Models.PlottingContext import PlottingContext @@ -48,17 +48,19 @@ class Heatmap(Plotter): @dataclass class BackupInfo: ind: int - image: AxesImage | None = None - array: np.ndarray = field(default_factory=lambda: np.empty((0,), dtype=float)) + image: AxesImage = None + array: FloatArray = field(default_factory=lambda: np.empty((0,), dtype=float)) minmax: tuple[float, float] = (-np.inf, np.inf) limits: tuple[float, float, float, float] = (-np.inf, np.inf, -np.inf, np.inf) interp: interp1d = None + title: str = "" def __init__(self) -> None: """Initialise all plotting parameters to defaults.""" super().__init__() self._figure = None - self._backup: dict[int, Heatmap.BackupInfo] = {} + self._axes_titles: list[str] = [] + self._backup: dict[str, Heatmap.BackupInfo] = {} self._current_x_axes = [] self._initial_values = [0.0, 100.0] self._slider_values = [0.0, 100.0] @@ -116,9 +118,8 @@ def change_normalisation(self, new_value: dict[str, Any]): def handle_slider(self, new_value: list[float]): """Adjust colormap values based on slider values.""" super().handle_slider(new_value) - target = self._figure - if target is None or new_value[1] <= new_value[0]: + if self._figure is None or new_value[1] <= new_value[0]: return self._slider_values = [new_value[0], new_value[1]] @@ -153,7 +154,19 @@ def handle_slider(self, new_value: list[float]): else: self._figure.canvas.draw_idle() backup.minmax = (newmin, newmax) - target.canvas.draw() + self._figure.canvas.draw() + + def toggle_legend(self, enabled: bool) -> None: + if self._figure is None: + return + + for axes, title in zip(self._axes, self._axes_titles, strict=True): + axes.set_title( + title if enabled else "", + fontsize=self.title_fontsize(title), + ) + + self._figure.canvas.draw() def check_curve_lengths(self): """Find the maximum number of elements in the x axes of the plot data.""" @@ -192,63 +205,65 @@ def plot( self._figure.set_layout_engine(layout="constrained") self._current_x_axes = [] - # minmax_bak = {key: val.minmax for key, val in self._backup.items()} scale_interpolators = {val.ind: val.interp for val in self._backup.values()} - self._backup = { - databundle.row: self.BackupInfo(ind=databundle.row) - for databundle in plotting_context.datasets().values() - } - self._axes = [] + + self._backup.clear() + self._axes.clear() + self._axes_titles.clear() self.apply_settings(plotting_context) if plotting_context.set_axes() is None: LOG.debug("Axis check failed.") return - nplots = min(ilen(plotting_context.planes(self._slice_axis)), self._plot_limit) + self._n_curves = min( + ilen(plotting_context.planes(self._slice_axis)), self._plot_limit + ) - if not nplots: + if not self.n_curves: self.plot_blank() return - # Check interpolators - for databundle in plotting_context.datasets().values(): + def get_interp(ind: int, data: FloatArray): + # Check interpolators try: - scale_interpolators[databundle.row](51.2) + scale_interpolators[ind](51.2) except Exception: percentiles = np.linspace(0, 100.0, 21) results = [ - np.percentile(np.nan_to_num(databundle.dataset._data), perc) - for perc in percentiles + np.percentile(np.nan_to_num(data), perc) for perc in percentiles ] - self._backup[databundle.row].interp = interp1d( - percentiles, - results, - ) + return interp1d(percentiles, results) else: - self._backup[databundle.row].interp = scale_interpolators[ - databundle.row - ] + return scale_interpolators[ind] - grid_size = self.grid_size(nplots) + grid_size = self.grid_size(self.n_curves) gs = self._figure.add_gridspec(*grid_size) - for ind, (databundle, label, plane) in enumerate( + for ind, (databundle, label, plane, axis_labels) in enumerate( islice(plotting_context.planes(self._slice_axis), self._plot_limit), ): + self._backup[label] = Heatmap.BackupInfo(ind=ind) dataset = databundle.dataset axes = self._figure.add_subplot(gs[ind]) image = self._plot_single( - axes, plane, databundle, label=label, colour=plotting_context.colormap + axes, + plane[:, ::-1].T, + databundle, + label=label, + axis_labels=axis_labels, + colour=plotting_context.colormap, ) xlimits, ylimits = axes.get_xlim(), axes.get_ylim() - interpolator = self._backup[databundle.row].interp + + self._backup[label].interp = get_interp(ind, plane) + last_minmax = ( - interpolator(self._slider_values[0]), - interpolator(self._slider_values[1]), + self._backup[label].interp(self._slider_values[0]), + self._backup[label].interp(self._slider_values[1]), ) try: @@ -259,30 +274,18 @@ def plot( ) if update_only: - self._backup[databundle.row].limits = ( - xlimits[0], - xlimits[1], - ylimits[0], - ylimits[1], - ) + self._backup[label].limits = (*xlimits, *ylimits) else: - self._backup[databundle.row].minmax = ( + self._backup[label].minmax = ( np.nanmin(dataset._data), np.nanmax(dataset._data), ) - self._backup[databundle.row].limits = ( - xlimits[0], - xlimits[1], - ylimits[0], - ylimits[1], - ) + self._backup[label].limits = (*xlimits, *ylimits) - if ind > 1: - legend = axes.legend() - legend.set_visible(plotting_context.use_legend) - axes.grid(plotting_context.use_grid) + self.toggle_legend(plotting_context.use_legend) + self.toggle_grid(plotting_context.use_grid) - if nplots == 1: # Exploit label from loop for one plot + if self.n_curves == 1: # Exploit label from loop for one plot self._figure.suptitle(label) self.check_curve_lengths() @@ -354,19 +357,16 @@ def _get_datasets(axis: Axes) -> Iterator[AxesImage]: def _plot_single( self, axes: Axes, - plane: npt.NDArray[np.floating], + plane: FloatArray, databundle: PlotArgs, *, label: str, + axis_labels: tuple[str, str], colour: str, ) -> AxesImage: - dataset = databundle.dataset - x_label = databundle.main_axis - y_label = nth(dataset._axes, self._slice_axis) - if y_label is None: - y_label = nth(dataset._axes, 1) + x_label, y_label = axis_labels x_axis = dataset.x_axis(x_label) y_axis = dataset.x_axis(y_label) @@ -382,12 +382,13 @@ def _plot_single( interpolation=None, cmap=colour, ) - axes.set_title(label) + colorbar = mpl_colorbar(image, ax=image.axes, format="%.1e", pad=0.02) colorbar.set_label(dataset._data_unit) self._axes.append(axes) - self._backup[databundle.row].array = plane - self._backup[databundle.row].image = image + self._axes_titles.append(label) + self._backup[label].array = plane + self._backup[label].image = image return image diff --git a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Plotter.py b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Plotter.py index ea797dc916..ecc44f98e2 100644 --- a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Plotter.py +++ b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Plotter.py @@ -19,7 +19,6 @@ import csv import enum import math -from collections.abc import Generator from itertools import count from typing import TYPE_CHECKING, Any, ClassVar, Literal, TextIO @@ -34,7 +33,7 @@ from MDANSE.util_types import FloatArray if TYPE_CHECKING: - from collections.abc import Iterator + from collections.abc import Callable, Generator, Iterator from matplotlib.axes import Axes from matplotlib.backends.backend_qt5agg import NavigationToolbar2QT as Toolbar @@ -113,6 +112,7 @@ class Plotter(RegisterFactory): 6: (2, 3), } _title_length_limit: ClassVar[int] = 30 + _max_labels = 5 def __init__(self) -> None: """Create defaults common to all plotters.""" @@ -127,6 +127,7 @@ def __init__(self) -> None: self.curve_length_limit = 10 self._normalisation_values = copy.copy(NORMALISATION_DEFAULTS) self._normalisation_errors = [] + self._n_curves = 0 def request_slider_values(self): """Manually read values from sliders, if they are present.""" @@ -485,8 +486,14 @@ def grid_size(cls, n_plots: int) -> tuple[int, int]: """ return cls.GRID_SIZES.get(n_plots, (math.ceil(n_plots**0.5),) * 2) - @staticmethod - def get_label(ind: int, n_curves: int, limit: int, label: str): + def label( + self, + label: str, + ind: int, + *, + n_curves: int | None = None, + limit: int | None = None, + ) -> str | None: """Get label for legend. For the abbreviated legend return None for those which are @@ -504,12 +511,22 @@ def get_label(ind: int, n_curves: int, limit: int, label: str): label : str Current label. """ + if n_curves is None: + n_curves = self.n_curves + if limit is None: + limit = self._max_labels + if ind == limit: return "..." if limit < ind < n_curves - 1: return None return label + @property + def n_curves(self) -> int: + """Number of expected/total curves.""" + return self._n_curves + @staticmethod def colours(colour: str, n_curves: int) -> Generator[tuple[float, float, float]]: """Generate colours from root colour. diff --git a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Single.py b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Single.py index 709d6fbdc6..d3ac4cfa50 100644 --- a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Single.py +++ b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Single.py @@ -186,15 +186,15 @@ def plot( LOG.debug("Axis check failed.") return - total_n_curves = sum( - map(ilen, plotting_context.curves(self._curve_limit_per_dataset)) + self._n_curves = sum( + ds.dataset.n_curves for ds in plotting_context.datasets().values() ) - if not total_n_curves: + if not self._n_curves: self.plot_blank() return - self.enable_slider(allow_slider=total_n_curves > 1) + self.enable_slider(allow_slider=self._n_curves > 1) colours = {} for databundle in plotting_context.datasets().values(): @@ -206,7 +206,7 @@ def plot( axes = target.add_subplot(111) self._axes = [axes] - for databundle, _, curve in islice( + for databundle, label, curve in islice( flatten(plotting_context.curves(self._curve_limit_per_dataset)), self._plot_limit, ): @@ -214,6 +214,7 @@ def plot( axes, curve, databundle, + label=f"{databundle.legend_label} {label}", colour=next(colours[databundle.row]), ) @@ -235,11 +236,11 @@ def plot( self._backup_limits = [*axes.get_xlim(), *axes.get_ylim()] axes.set_xlabel(", ".join(np.unique(x_axis_labels))) + axes.legend() - if plotting_context.use_legend: - axes.legend() + self.toggle_legend(plotting_context.use_legend) + self.toggle_grid(plotting_context.use_grid) - axes.grid(plotting_context.use_grid) self.check_curve_lengths() self.offset_curves() @@ -249,6 +250,7 @@ def _plot_single( curve: tuple[np.ndarray, np.ndarray] | tuple[np.ndarray], databundle: PlotArgs, *, + label: str, colour: tuple[float, float, float], ): """Plot a single curve to axes. @@ -264,10 +266,11 @@ def _plot_single( colour : tuple[float, float, float] Curve colour. """ + lines: list[Line2D] = axes.plot( *curve, linestyle=databundle.line_style, - label=databundle.legend_label, + label=self.label(label, len(self._active_curves)), color=colour, ) diff --git a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Vectors.py b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Vectors.py index 307ba1f70c..79fed0ccdc 100644 --- a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Vectors.py +++ b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Vectors.py @@ -197,7 +197,7 @@ def plot( try: axes.bar( *curve, - label=self.get_label( + label=self.label( ind=ind, n_curves=n_curves, limit=self._legend_limit_for_histogram, diff --git a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Visualisers/PlotWidget.py b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Visualisers/PlotWidget.py index fc4f264f77..1c53545c07 100644 --- a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Visualisers/PlotWidget.py +++ b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Visualisers/PlotWidget.py @@ -172,10 +172,10 @@ def set_values(self, new_values: list[float]): @Slot() def slider_to_box(self): """Update spin boxes if slider is moving.""" - vals = np.zeros_like(self._valarray) clicks = np.zeros_like(self._clickarray) for ns, slider in enumerate(self._sliders): clicks[ns] = slider.value() + vals = self._minarray + clicks * self._steparray for ns, box in enumerate(self._spinboxes): box.setValue(vals[ns]) @@ -185,7 +185,7 @@ def box_to_slider(self): """Update sliders if spin boxes have changed.""" with block_signals(self): vals = np.zeros_like(self._valarray) - clicks = np.zeros_like(self._clickarray) + for ns, box in enumerate(self._spinboxes): vals[ns] = box.value() clicks = np.round((vals - self._minarray) / self._steparray).astype(int)