diff --git a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Models/PlottingContext.py b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Models/PlottingContext.py index b5beb9eaf1..11e9f667c5 100644 --- a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Models/PlottingContext.py +++ b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Models/PlottingContext.py @@ -31,16 +31,16 @@ from matplotlib.colors import to_hex as mpl_to_hex from matplotlib.lines import lineStyles from matplotlib.markers import MarkerStyle -from more_itertools import nth_product +from more_itertools import first, locate, nth, nth_product from qtpy.QtCore import QModelIndex, Qt, Signal, Slot from qtpy.QtGui import QColor, QStandardItem, QStandardItemModel from MDANSE.IO.IOUtils import summarise_array from MDANSE.MLogging import LOG -from MDANSE.util_types import ComplexArray, FloatArray +from MDANSE.util_types import ComplexArray, FloatArray, IntArray if TYPE_CHECKING: - from collections.abc import Iterable, Iterator + from collections.abc import Generator, Iterable, Iterator, Sequence import h5py @@ -52,7 +52,7 @@ class PlotArgs(NamedTuple): """Arguments for plotting data.""" - dataset: FloatArray | ComplexArray + dataset: SingleDataset colour: str line_style: str marker: str @@ -95,13 +95,9 @@ def __init__( ): self._name = name self._use_scaling = True - self._curves: dict[tuple[int, ...], FloatArray] = {} - self._curve_labels: dict[tuple[int, ...], str] = {} self._linestyle = linestyle self._marker = marker - self._planes: dict[int, FloatArray] = {} - self._plane_labels: dict[int, str] = {} - self._data_limits: list[int] | None = None + self._data_limits = None self._imaginary_data = None self._valid = True self._scaling_factor = 1.0 @@ -237,11 +233,11 @@ def _( self._axes_order.append(axis_key) self._axes_scaling[axis_key] = 1.0 self._current_units[axis_key] = self._axes_units[axis_key] - self._axes_tag = "|".join([str(x) for x in self._axes]) + self._axes_tag = "|".join(str(x) for x in self._axes) return for axis_key, axis_array in plot_axes.items(): - self._axes_tag = "|".join([str(x) for x in plot_axes]) + self._axes_tag = "|".join(str(x) for x in plot_axes) self._axes[axis_key] = axis_array self._axes_units[axis_key] = ( "N/A" if axes_units is None else axes_units[axis_key] @@ -266,9 +262,11 @@ def create_axes_tags(self, axes_tag: str, source: h5py.File): self._axes[f"index{dim_number}"] = np.arange(dim_length) self._axes_units[f"index{dim_number}"] = "N/A" return + self._current_units = {} self._axes_scaling = {} self._axes_order = [] + for ax_number, axis_name in enumerate(axes_tag.split("|")): aname = axis_name.strip() if aname == "index": @@ -326,6 +324,7 @@ def set_current_units(self, unit_lookup): """Update the unit based on the unit lookup of the PlottingContext.""" if unit_lookup is None: return + for axis_name, axis_unit in self._axes_units.items(): factor, new_unit = unit_lookup.conversion_factor(axis_unit) self._axes_scaling[axis_name] = factor @@ -484,7 +483,7 @@ def dep_axes(self) -> dict[str, FloatArray]: return {aname: axis for aname, axis in self._axes.items() if aname != la} @property - def data(self): + def data(self) -> FloatArray: """Data array, scaled if requested in the GUI table. Returns @@ -497,35 +496,45 @@ def data(self): return self._data * self._scaling_factor return self._data + @property + def n_curves(self) -> int: + """Number of curves in dataset.""" + if self._data_limits is None: + return 0 + + return len(self._data_limits) + def generate_curve_label( self, - index_tuple: list[int], - axis_lookup: list[str], + index_tuple: Sequence[int], + axis_lookup: Iterable[str], *, skip_text: bool = False, - ) -> str | float: + ) -> str: """Get a meaningful label for a subset of data. Used when plotting 1D arrays out of a multidimensional array. Parameters ---------- - index_tuple : list[int] + index_tuple : Sequence[int] indices of the 1D data array position in the ND array. - axis_lookup : list[str] + axis_lookup : Iterable[str] Names of the axes to use. skip_text : bool, optional If set to true, omits the text parts of the label. By default False. Returns ------- - str | float + str A string label for the plot legend or a number for Text plotter. """ if self._n_dim < 2: return "" + label = "at " + for axis_index, axis_name in enumerate(axis_lookup): axis_label = self.axis_true_name(axis_name) @@ -534,15 +543,15 @@ def generate_curve_label( picked_value = axis_values[index_tuple[axis_index]] if len(axis_values) > 1: - significant_digit = np.floor( - np.log10(abs(np.mean(axis_values[1:] - axis_values[:-1]))), - ).astype(int) + data = np.mean(np.diff(axis_values)) elif len(axis_values) == 1: - significant_digit = np.floor(np.log10(abs(axis_values[0]))).astype(int) + data: np.floating = axis_values[0] else: label += f"{axis_label} has no values, unit {axis_unit}" continue + significant_digit: np.integer = np.floor(np.log10(np.abs(data))).astype(int) + if significant_digit < -20: picked_value = 0 elif significant_digit < 0: @@ -552,149 +561,221 @@ def generate_curve_label( label += f"{axis_label}={picked_value} {axis_unit}, " if skip_text: - return float(picked_value) + return str(float(picked_value)) return label.rstrip(", ") + def curve_ind(self, limits: int | None = None, /): + """Get indices of valid axes. + + Parameters + ---------- + limits : int + Max number of curves to return. + """ + return islice(self._data_limits, limits) + def curves_vs_axis( self, - x_axis_details: tuple[str, str], - max_limit: int = 1, + axis_label: tuple[str, str] | str, + max_limit: int | None = None, *, + axis_unit: str | None = None, skip_label_text: bool = False, - ) -> dict[int, FloatArray]: + ) -> Generator[ + tuple[ + str | float | None, + tuple[FloatArray, FloatArray], + ] + ]: """Prepare a set of curves for plotting. Parameters ---------- - x_axis_details : tuple[str, str] - Name and original unit of the primary plotting axis + axis_label : str + Name of the primary plotting axis. + axis_unit : str, optional + Unit of the primary plotting axis. max_limit : int, optional Maximum number of curves allowed by plotter, by default 1 skip_label_text: bool, optional Whether to skip the axis name and unit in the curve label, by default False. - Returns - ------- - dict[int, FloatArray] - List of data arrays ready for plotting - + Yields + ------ + str + Plot label. + FloatArray + x-axis. + FloatArray + Curve to plot. """ - self._curves = {} - self._curve_labels = {} + match axis_label: + case (str(unit), str(label)): + axis_unit = unit + axis_label = label + case str(): + pass + case _: + raise ValueError(f"Cannot handle {axis_label} as axis label") + + x_axis = self.x_axis(axis_label) if self._data.ndim == 1: - self._curves[(0,)] = self.data - self._curve_labels[(0,)] = "" - return self.data + yield "", (x_axis, self.data) + return data_shape = self._data.shape - x_axis_unit, x_axis_name = x_axis_details slicer = [] - indexer = [] + indexer: list[Sequence[int]] = [] label_lookup = [] axis_lengths = [len(self._axes[name]) for name in self._axes_order] + match_unit = axis_unit is not None + + if np.allclose(data_shape, axis_lengths): + for current_dim, axis_name in enumerate(self._axes_order): + curr_unit = self._axes_units[axis_name] + + if axis_name == axis_label and ( + not match_unit or (axis_unit == curr_unit) + ): + slicer.append([slice(None)]) + continue + + indices: IntArray = np.arange(data_shape[current_dim]) + slicer.append(indices) + indexer.append(indices) + label_lookup.append(axis_name) + + if not indexer: + LOG.warning("Empty selection for data set %s", self._name) + return + + for index in self.curve_ind(max_limit): + try: + index_tuple = nth_product(index, *indexer) + index_slicer = nth_product(index, *slicer) + + yield ( + self.generate_curve_label( + index_tuple, label_lookup, skip_text=skip_label_text + ), + (x_axis, self.data[index_slicer].squeeze()), + ) + except IndexError: + LOG.warning( + "Skipping: in dataset %s, index %s is out of bounds", + self._name, + index, + ) + + elif ( + len(axis_lengths) == 1 + and len(data_shape) == 2 + and data_shape[0] == axis_lengths[0] + ): + # Assume multiple lines in block - if not np.allclose(data_shape, axis_lengths): - raise ValueError("Array shape does not match the order of the axes") + axis_name = first(self._axes_order) - for current_dim, axis_name in enumerate(self._axes_order): - axis_unit = self._axes_units[axis_name] - if axis_unit == x_axis_unit and axis_name == x_axis_name: - slicer.append([slice(None)]) - continue + for current_dim in range(data_shape[1]): + yield ( + axis_name, + (x_axis, self.data[:, current_dim]), + ) - indices = np.arange(data_shape[current_dim]) - slicer.append(indices) - indexer.append(indices) - label_lookup.append(axis_name) + else: + raise ValueError("Array shape does not match the order of the axes") - if not indexer: - LOG.warning("Empty selection for data set %s", self._name) - return self._curves + def planes_vs_axis( + self, + main_axis: str, + max_limit: int = 9, + ) -> Generator[tuple[str, FloatArray, tuple[str, str]]]: + """Prepare for plotting 2D subsets of an ND array. - for index in self.curve_ind(max_limit): - try: - index_tuple = nth_product(index, *indexer) - index_slicer = nth_product(index, *slicer) - self._curves[index_tuple] = self.data[index_slicer].squeeze() - except IndexError: - LOG.warning( - "Skipping: in dataset %s, index %s is out of bounds", - self._name, - index, - ) - else: - self._curve_labels[index_tuple] = self.generate_curve_label( - index_tuple, - label_lookup, - skip_text=skip_label_text, + Parameters + ---------- + main_axis : str + Label of the axis perpendicular to the plotted array. + + Yields + ------ + main_label : str + Grid label. + image_array : FloatArray + 2D array. + axis_labels : tuple[str, ...] + Labels for each axis. + """ + main_axis_index = self.main_axis_index(main_axis) + other_labels = self._axis_labels(main_axis) + + match self._data.ndim: + case 1: + pass + case 2 if main_axis_index == 1: + yield self._labels["medium"], self.data.T, (main_axis, other_labels[0]) + case 2: + yield self._labels["medium"], self.data, (main_axis, other_labels[0]) + case 3: + perpendicular_axis_name, perpendicular_axis = nth( + self._axes.items(), main_axis_index, default=(None, None) ) - return self._curves + if perpendicular_axis is None: + return - def curve_ind(self, limits: int, /) -> Iterator[int]: - """Return a generator of indices indexing only the curves within the limits.""" - return ( - islice(self._data_limits, limits) - if self._data_limits is not None - else range(limits) - ) + reordered_view = np.moveaxis(self.data, main_axis_index, 0) - def planes_vs_axis( - self, - axis_number: int, - max_limit: int = 1, - ) -> list[FloatArray] | FloatArray | None: - """Prepare for plotting 2D subsets of an ND array. + for plane_number in self.curve_ind(max_limit): + if plane_number > len(reordered_view): + continue + + yield ( + f"{self._labels['minimal']}:{perpendicular_axis_name}={perpendicular_axis[plane_number]}", + reordered_view[plane_number], + other_labels, + ) + case _: + raise NotImplementedError( + f"Cannot handle {self._data.ndim}-dimensional data." + ) + + def main_axis_index( + self, main_axis: str | None, *, default: int | None = None + ) -> int: + """Find index of main axis. Parameters ---------- - axis_number : int - index of the axis perpendicular to the plotted array - max_limit : int, optional - maximum number of planes allowed by plotter, by default 1 + main_axis : str + Main axis name to search for. + default : int, optional + Index if ``main_axis`` not found. Returns ------- - list[FloatArray] - List of 2D arrays for heatmap plots + int + Index of main axis. + Raises + ------ + ValueError + Axis not found and no default. """ - self._planes = {} - self._plane_labels = {} - _found = -1 - total_ndim = self._data.ndim - - if total_ndim == 1: - return None - if total_ndim == 2: - return self.data - - data_shape = self._data.shape - number_of_planes = data_shape[axis_number] - perpendicular_axis = None - perpendicular_axis_name = "" - slice_def = [] - - for number, (axis_name, axis_array) in enumerate(self._axes.items()): - if number == axis_number: - slice_def.append(0) - perpendicular_axis = axis_array - perpendicular_axis_name = self.axis_true_name(axis_name) - else: - slice_def.append(slice(None)) - - for plane_number in self.curve_ind(max_limit): - if plane_number >= number_of_planes: - break - fixed_argument = perpendicular_axis[plane_number] - slice_def[axis_number] = plane_number - self._planes[plane_number] = self.data[tuple(slice_def)] - self._plane_labels[plane_number] = ( - f"{perpendicular_axis_name}={fixed_argument}" + ind = first(locate(self._axes, pred=main_axis.__eq__), default) + if ind is None: + raise ValueError( + f"Cannot find axis {main_axis} in {','.join(self._axes.keys())}" ) + return ind - return None + def _axis_labels(self, main_axis: str) -> tuple[str] | tuple[str, str]: + main_axis_index = self.main_axis_index(main_axis) + + return tuple( + label for i, label in enumerate(self._axes) if i != main_axis_index + ) plotting_column_labels = [ @@ -720,7 +801,9 @@ class PlottingContext(QStandardItemModel): needs_an_update = Signal("quint64") - def __init__(self, *args, unit_lookup=None, **kwargs): + def __init__( + self, *args, unit_lookup: int | None = None, colormap: str = "viridis", **kwargs + ): super().__init__(*args, **kwargs) self._datasets = {} self._current_axis = [None, None, None] @@ -731,13 +814,14 @@ def __init__(self, *args, unit_lookup=None, **kwargs): self._best_xunits = [] self._colour_list = get_mpl_colours() self._last_colour_list = get_mpl_colours() - self._colour_map = kwargs.get("colormap", "viridis") + self._colour_map = colormap self._last_colour = 0 self._unit_lookup = unit_lookup self.plot_widget_id = -1 self.use_legend = True self.use_grid = True self.setHorizontalHeaderLabels(plotting_column_labels) + self.itemChanged.connect(self.ask_for_update) def generate_colour(self, number: int) -> str: """Get the matplotlib colour string for the nth curve. @@ -872,18 +956,17 @@ def datasets(self) -> dict[str, PlotArgs]: data_number_string = row_data["Use it?"].text() main_axis = row_data["Main axis"].text() - plot_args = { - "colour": row_data["Colour"].text(), - "line_style": row_data["Line style"].text(), - "marker": row_data["Marker"].text(), - "row": row, - "main_axis": main_axis, - "legend_label": row_data["Legend label"].text(), - } - self._datasets[key].set_data_limits(data_number_string, main_axis=main_axis) self._datasets[key].set_current_units(self._unit_lookup) - result[key] = PlotArgs(self._datasets[key], **plot_args) + result[key] = PlotArgs( + dataset=self._datasets[key], + colour=row_data["Colour"].text(), + line_style=row_data["Line style"].text(), + marker=row_data["Marker"].text(), + row=row, + main_axis=row_data["Main axis"].text(), + legend_label=row_data["Legend label"].text(), + ) return result @@ -910,7 +993,7 @@ def add_dataset( self._datasets[newkey] = new_dataset items = [ QStandardItem(str(x)) - for x in [ + for x in ( new_dataset._name, getattr(optional_values, "legend_label", new_dataset._labels["medium"]), new_dataset._data_shape, @@ -926,7 +1009,7 @@ def add_dataset( new_dataset._scaling_factor, show=1, arr_fmt=SCALE_FACTOR_FORMAT ), new_dataset._filename, - ] + ) ] fixed = {"Dataset", "Trajectory", "Size", "Unit", "Apply scaling?"} @@ -948,8 +1031,6 @@ def add_dataset( f"0:{prod(len(arr) for arr in new_dataset.dep_axes.values())}:1", ) - self.itemChanged.connect(self.ask_for_update) - temp = items[plotting_column_index["Colour"]] temp.setData(QColor(temp.text()), role=Qt.ItemDataRole.BackgroundRole) @@ -1003,3 +1084,36 @@ def delete_dataset(self, index: QModelIndex): dkey = index.data(role=Qt.ItemDataRole.UserRole) self.removeRow(index.row()) self._datasets.pop(dkey, None) + + def planes( + self, default_axis: int | None = None, planes_per_dataset: int | None = None + ) -> Generator[tuple[PlotArgs, str, FloatArray, tuple[str, str]]]: + for databundle in self.datasets().values(): + ds = databundle.dataset + + for label, plane, axis_labels in islice( + ds.planes_vs_axis(databundle.main_axis), + planes_per_dataset, + ): + yield databundle, label, plane, axis_labels + + def curves( + self, curves_per_dataset: int | None = None + ) -> Generator[ + Generator[ + tuple[ + PlotArgs, + str, + tuple[FloatArray, FloatArray], + ] + ] + ]: + for databundle in self.datasets().values(): + ds = databundle.dataset + + yield ( + (databundle, label, curve) + for label, curve in islice( + ds.curves_vs_axis(databundle.main_axis), curves_per_dataset + ) + ) diff --git a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Grid.py b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Grid.py index b44f3cd34e..2490b3a7e5 100644 --- a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Grid.py +++ b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Grid.py @@ -20,15 +20,19 @@ from itertools import islice from typing import TYPE_CHECKING, Any -from matplotlib import rcParams +from more_itertools import flatten, ilen from MDANSE.MLogging import LOG from MDANSE_GUI.Tabs.Plotters.Plotter import Plotter if TYPE_CHECKING: + import numpy as np + from matplotlib.axes import Axes + from matplotlib.backends.backend_qt5agg import NavigationToolbar2QT as Toolbar from matplotlib.figure import Figure + from matplotlib.lines import Line2D - from MDANSE_GUI.Tabs.Models.PlottingContext import PlottingContext + from MDANSE_GUI.Tabs.Models.PlottingContext import PlotArgs, PlottingContext @Plotter.register("Grid") @@ -41,16 +45,15 @@ def __init__(self) -> None: self._backup_limits = [] self._active_curves = [] self._backup_curves = [] - self._plot_limit = 8 - self._title_length_limit = 30 + self._plot_limit = 9 def slider_labels(self) -> list[str]: """Return labels to show that sliders are not used.""" return ["Inactive", "Inactive"] - def slider_limits(self) -> list[str]: + def slider_limits(self) -> list[tuple[float, float, float]]: """Return generic slider limit values.""" - return self._number_of_sliders * [[-1.0, 1.0, 0.01]] + return [(-1.0, 1.0, 0.01)] * self._number_of_sliders def check_curve_lengths(self): """Find the maximum number of elements in the x axes of the plot data.""" @@ -70,73 +73,71 @@ def change_normalisation(self, new_value: dict[str, Any]): """ super().change_normalisation(new_value) target = self._figure - if target is None: - return - if len(self._active_curves) == 0: + if target is None or not self._active_curves: return + for curve_index, curve in enumerate(self._active_curves): - xdata = self._backup_curves[curve_index][0] - ydata = self._backup_curves[curve_index][1] + xdata, ydata = self._backup_curves[curve_index] xdata, ydata = self.normalise_curve(xdata, ydata) curve.set_xdata(xdata) curve.set_ydata(ydata) + target.canvas.draw() + for axes in self._axes: axes.relim() axes.autoscale() + if self._toolbar is not None: self._toolbar.update() self._toolbar.push_current() - def title_fontsize(self, title_text: str) -> int: - normal_size = rcParams["font.size"] - new_size = ( - normal_size - if len(title_text) < self._title_length_limit - else normal_size - round(len(title_text) / self._title_length_limit) - ) - return new_size - def toggle_legend(self, enabled: bool) -> None: if self._figure is None: return - for plot_index, axes in enumerate(self._axes): + + for axes, title in zip(self._axes, self._axes_titles, strict=True): axes.set_title( - self._axes_titles[plot_index] if enabled else "", - fontsize=self.title_fontsize(self._axes_titles[plot_index]), + title if enabled else "", + fontsize=self.title_fontsize(title), ) + self._figure.canvas.draw() def plot( self, plotting_context: PlottingContext, - figure: Figure = None, - update_only=False, - toolbar=None, + figure: Figure | None = None, + update_only: bool = False, + toolbar: Toolbar | None = None, ): """Plot datasets in separate subplots. Parameters ---------- plotting_context : PlottingContext - Data model storing the data to be plotted + Data model storing the data to be plotted. figure : Figure, optional - Matplotlib figure instance for plotting, by default None + Matplotlib figure instance for plotting, by default None. update_only : bool, optional - If true, try to re-use zoom settings, by default False - toolbar : _type_, optional - GUI instance of the matplotlib toolbar, by default None + If true, try to re-use zoom settings, by default False. + toolbar : Toolbar, optional + GUI instance of the matplotlib toolbar, by default None. """ self.enable_slider(allow_slider=False) target = self.get_figure(figure) + if target is None: return + if toolbar is not None: self._toolbar = toolbar + if plotting_context.set_axes() is None: LOG.debug("Axis check failed.") return + self._figure = target self._axes = [] self._axes_titles = [] @@ -144,74 +145,81 @@ def plot( self._active_curves = [] self._normalisation_errors = [] self.apply_settings(plotting_context) - nplots = 0 - for databundle in plotting_context.datasets().values(): - ds = databundle.dataset - try: - axis_info = ds._axes_units[databundle.main_axis], databundle.main_axis - except KeyError: - axis_info = ds.longest_axis() - curves = ds.curves_vs_axis(axis_info, max_limit=self._plot_limit) - nplots += len(curves) - nplots = min(nplots, self._plot_limit) - gridsize = math.ceil(nplots**0.5) - startnum = 1 - counter = 0 - for databundle in plotting_context.datasets().values(): - dataset = databundle.dataset - try: - _, best_axis = ( - dataset._axes_units[databundle.main_axis], - databundle.main_axis, - ) - except KeyError: - _, best_axis = dataset.longest_axis() - for key, curve in islice(dataset._curves.items(), self._plot_limit): - counter += 1 - if counter > self._plot_limit: - LOG.warning( - "Curves above the current limit of %s will be ignored", - self._plot_limit, - ) - break - axes = target.add_subplot(gridsize, gridsize, startnum) - self._axes.append(axes) - plotlabel = databundle.legend_label - if dataset._curve_labels[key]: - plotlabel += ":" + dataset._curve_labels[key] - x_axis_label = dataset.x_axis_label(best_axis) - [temp_curve] = axes.plot( - dataset.x_axis(best_axis), - curve, - linestyle=databundle.line_style, - color=databundle.colour, - label=plotlabel, - ) - try: - temp_curve.set_marker(databundle.marker) - except ValueError: - with contextlib.suppress(Exception): - temp_curve.set_marker(int(databundle.marker)) - axes.set_xlabel(x_axis_label) - self._axes_titles.append(plotlabel) - axes.set_title( - plotlabel if plotting_context.use_legend else "", - fontsize=self.title_fontsize(plotlabel), - ) - legend = axes.legend() - legend.set_visible(False) - axes.grid(plotting_context.use_grid) - startnum += 1 - self._active_curves.append(temp_curve) - self._backup_curves.append( - [temp_curve.get_xdata(), temp_curve.get_ydata()], - ) - if counter == 0: - self.plot_blank() - return + + self._n_curves = min( + sum(db.dataset.n_curves for db in plotting_context.datasets().values()), + self._plot_limit, + ) + grid_size = self.grid_size(self.n_curves) + gs = self._figure.add_gridspec(*grid_size) + + for ind, (databundle, label, curve) in enumerate( + islice(flatten(plotting_context.curves()), self._plot_limit) + ): + axes = target.add_subplot(gs[ind]) + self._axes_titles.append(f"{databundle.dataset._name} {label}") + + self._plot_single( + axes, + curve, + databundle, + label="", + colour=databundle.colour, + ) + self._axes.append(axes) + + self.toggle_legend(plotting_context.use_legend) + self.toggle_grid(plotting_context.use_grid) + self.apply_settings(plotting_context) self.check_curve_lengths() target.canvas.draw() + if self._toolbar is not None: self._toolbar.update() self._toolbar.push_current() + + def _plot_single( + self, + axes: Axes, + curve: tuple[np.ndarray, np.ndarray] | tuple[np.ndarray], + databundle: PlotArgs, + *, + label: str, + colour: tuple[float, float, float] | str, + ): + """Plot a single curve to axes. + + Parameters + ---------- + axes : Axes + Axis to plot to. + curve : tuple[np.ndarray, np.ndarray] | tuple[np.ndarray] + Curve to plot. + databundle : PlotArgs + Data to plot. + label : str + Plot label. + colour : tuple[float, float, float] | str + Curve colour. + """ + lines: list[Line2D] = axes.plot( + *curve, + linestyle=databundle.line_style, + label=label, + color=colour, + ) + + axes.set_xlabel(databundle.dataset.x_axis_label(databundle.main_axis)) + + for line in lines: + try: + line.set_marker(databundle.marker) + except ValueError: + with contextlib.suppress(Exception): + line.set_marker(int(databundle.marker)) + + self._active_curves.extend(lines) + self._backup_curves.extend( + (line.get_xdata(), line.get_ydata()) for line in lines + ) diff --git a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Grouped.py b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Grouped.py new file mode 100644 index 0000000000..3a310a3305 --- /dev/null +++ b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Grouped.py @@ -0,0 +1,333 @@ +# This file is part of MDANSE_GUI. +# +# MDANSE_GUI is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +# +from __future__ import annotations + +import contextlib +import math +from itertools import islice +from typing import TYPE_CHECKING, Any + +import numpy as np +from more_itertools import ilen + +from MDANSE.MLogging import LOG +from MDANSE_GUI.Tabs.Plotters.Plotter import Plotter + +if TYPE_CHECKING: + import numpy.typing as npt + from matplotlib.axes import Axes + from matplotlib.backends.backend_qt5agg import NavigationToolbar2QT as Toolbar + from matplotlib.figure import Figure + from matplotlib.lines import Line2D + + from MDANSE.util_types import FloatArray + from MDANSE_GUI.Tabs.Models.PlottingContext import PlotArgs, PlottingContext + + +@Plotter.register("Grouped") +class Grouped(Plotter): + """Plots each dataset in its own subplot.""" + + def __init__(self) -> None: + super().__init__() + self._figure = None + self._backup_limits: list[tuple[float, float, float, float]] = [] + self._active_curves: list[Line2D] = [] + self._backup_curves: list[ + tuple[FloatArray, FloatArray] + ] = [] + self._plot_limit = 9 + self.height_max, self.length_max = -np.inf, -np.inf + + def slider_labels(self) -> list[str]: + """Return labels to show that sliders are not used.""" + return ["Y offset", "X offset"] + + def slider_limits(self) -> list[tuple[float, float, float]]: + """Return generic slider limit values.""" + return [(-1.0, 1.0, 0.01)] * self._number_of_sliders + + def handle_slider(self, new_value: list[float]): + """Save slider values and call offset_curves.""" + super().handle_slider(new_value) + self.offset_curves() + + def offset_curves(self): + """Offset curves against each other based on slider settings.""" + target = self._figure + if target is None: + return + + new_value = self._slider_values + + backup = iter(self._backup_curves) + for ind, axes in enumerate(self._axes): + saved_xmin, saved_xmax, saved_ymin, saved_ymax = self._backup_limits[ind] + + for num, curve in enumerate(axes.get_lines()): + xdata, ydata = next(backup) + xdata, ydata = self.normalise_curve(xdata, ydata) + new_xdata = xdata + num * self.length_max * new_value[1] + new_ydata = ydata + num * self.height_max * new_value[0] + curve.set_xdata(new_xdata) + curve.set_ydata(new_ydata) + xmin, xmax = np.nanmin(new_xdata), np.nanmax(new_xdata) + ymin, ymax = np.nanmin(new_ydata), np.nanmax(new_ydata) + saved_xmin = np.nanmin([xmin, saved_xmin]) + saved_xmax = np.nanmax([xmax, saved_xmax]) + saved_ymin = np.nanmin([ymin, saved_ymin]) + saved_ymax = np.nanmax([ymax, saved_ymax]) + + axes.relim() + axes.autoscale() + + self._backup_limits[ind] = (saved_xmin, saved_xmax, saved_ymin, saved_ymax) + + try: + axes.set_xlim(saved_xmin, saved_xmax) + except ValueError: + LOG.error( + f"Matplotlib could not set x limits to {saved_xmin}, {saved_xmax}", + ) + + try: + axes.set_ylim(saved_ymin, saved_ymax) + except ValueError: + LOG.error( + f"Matplotlib could not set y limits to {saved_ymin}, {saved_ymax}", + ) + + if self._toolbar is not None: + self._toolbar.update() + self._toolbar.push_current() + + target.canvas.draw() + + def check_curve_lengths(self): + """Find the maximum number of elements in the x axes of the plot data.""" + self.curve_length_limit = 0 + for num, _ in enumerate(self._active_curves): + xdata = self._backup_curves[num][0] + self.curve_length_limit = max(self.curve_length_limit, len(xdata)) + + def change_normalisation(self, new_value: dict[str, Any]): + """Normalise the data based on the new parameters. + + Parameters + ---------- + new_value : dict[str, Any] + parameters as in NORMALISATION_DEFAULTS + + """ + super().change_normalisation(new_value) + target = self._figure + if target is None or not self._active_curves: + return + + for curve_index, curve in enumerate(self._active_curves): + xdata, ydata = self._backup_curves[curve_index] + xdata, ydata = self.normalise_curve(xdata, ydata) + curve.set_xdata(xdata) + curve.set_ydata(ydata) + + target.canvas.draw() + + for axes in self._axes: + axes.relim() + axes.autoscale() + + if self._toolbar is not None: + self._toolbar.update() + self._toolbar.push_current() + + def toggle_legend(self, enabled: bool) -> None: + if self._figure is None: + return + + for axes, title in zip(self._axes, self._axes_titles, strict=True): + axes.set_title( + title if enabled else "", + fontsize=self.title_fontsize(title), + ) + axes.get_legend().set_visible(enabled) + + self._figure.canvas.draw() + + def plot( + self, + plotting_context: PlottingContext, + figure: Figure | None = None, + update_only: bool = False, + toolbar: Toolbar | None = None, + ): + """Plot datasets in separate subplots. + + Parameters + ---------- + plotting_context : PlottingContext + Data model storing the data to be plotted. + figure : Figure, optional + Matplotlib figure instance for plotting, by default None. + update_only : bool, optional + If true, try to re-use zoom settings, by default False. + toolbar : Toolbar, optional + GUI instance of the matplotlib toolbar, by default None. + + """ + self.enable_slider(allow_slider=False) + target = self.get_figure(figure) + + if target is None: + return + + if toolbar is not None: + self._toolbar = toolbar + + if plotting_context.set_axes() is None: + LOG.debug("Axis check failed.") + return + + self.height_max, self.length_max = 0.0, 0.0 + self._figure = target + + self._axes = [] + self._axes_titles = [] + self._backup_curves = [] + self._active_curves = [] + self._normalisation_errors = [] + + self.apply_settings(plotting_context) + + self._n_curves = sum(ilen(curves) for curves in plotting_context.curves()) + nplots = min(len(plotting_context.datasets()), self._plot_limit) + grid_size = self.grid_size(nplots) + gs = self._figure.add_gridspec(*grid_size) + limits = [(0.0, 0.0, 0.0, 0.0) for _ in range(nplots)] + + for ind, (db, dataclump) in enumerate( + islice( + zip( + plotting_context.datasets().values(), + plotting_context.curves(), + strict=True, + ), + self._plot_limit, + ) + ): + ds = db.dataset + axes = target.add_subplot(gs[ind]) + self._axes_titles.append(ds._name) + + colours = self.colours(db.colour, ds.n_curves) + + for curve_ind, ((databundle, label, curve), colour) in enumerate( + zip(dataclump, colours, strict=True) + ): + self._plot_single( + axes, + curve, + databundle, + ind=curve_ind, + label=label, + colour=colour, + ) + + axes.legend() + self._axes.append(axes) + limits[ind] = (*axes.get_xlim(), *axes.get_ylim()) + + if update_only: + try: + axes.set_xlim(self._backup_limits[ind][:2]) + except ValueError: + LOG.error( + f"Matplotlib could not set x limits to {self._backup_limits[ind][:2]}" + ) + + try: + axes.set_ylim(self._backup_limits[ind][2:]) + except ValueError: + LOG.error( + f"Matplotlib could not set y limits to {self._backup_limits[ind][2:]}" + ) + + if not update_only: + self._backup_limits = limits + + self.enable_slider( + allow_slider=any( + db.dataset.n_curves > 1 for db in plotting_context.datasets().values() + ) + ) + self.toggle_legend(plotting_context.use_legend) + self.toggle_grid(plotting_context.use_grid) + + self.apply_settings(plotting_context) + self.check_curve_lengths() + target.canvas.draw() + + if self._toolbar is not None: + self._toolbar.update() + self._toolbar.push_current() + + def _plot_single( + self, + axes: Axes, + curve: tuple[np.ndarray, np.ndarray] | tuple[np.ndarray], + databundle: PlotArgs, + *, + ind: int, + label: str, + colour: tuple[float, float, float] | str, + ): + """Plot a single curve to axes. + + Parameters + ---------- + axes : Axes + Axis to plot to. + curve : tuple[np.ndarray, np.ndarray] | tuple[np.ndarray] + Curve to plot. + databundle : PlotArgs + Data to plot. + label : str + Plot label. + colour : tuple[float, float, float] | str + Curve colour. + """ + lines: list[Line2D] = axes.plot( + *curve, + linestyle=databundle.line_style, + label=self.label(label, ind, n_curves=databundle.dataset.n_curves), + color=colour, + ) + + axes.set_xlabel(databundle.dataset.x_axis_label(databundle.main_axis)) + + for line in lines: + try: + line.set_marker(databundle.marker) + except ValueError: + with contextlib.suppress(Exception): + line.set_marker(int(databundle.marker)) + + self.height_max = np.nanmax([self.height_max, np.nanmax(line.get_ydata())]) + self.length_max = np.nanmax([self.length_max, np.nanmax(line.get_xdata())]) + + self._active_curves.extend(lines) + self._backup_curves.extend( + (line.get_xdata(), line.get_ydata()) for line in lines + ) diff --git a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Heatmap.py b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Heatmap.py index 0e51b13ebc..e6b4db8805 100644 --- a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Heatmap.py +++ b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Heatmap.py @@ -18,20 +18,26 @@ import csv import math from collections.abc import Iterator -from typing import TYPE_CHECKING, Any, TextIO +from dataclasses import dataclass, field +from itertools import islice +from typing import TYPE_CHECKING, Any, NamedTuple, TextIO import numpy as np from matplotlib.axes import Axes -from matplotlib.image import AxesImage from matplotlib.pyplot import colorbar as mpl_colorbar +from more_itertools import ilen, nth from scipy.interpolate import interp1d from MDANSE.MLogging import LOG +from MDANSE_GUI.Tabs.Models.PlottingContext import PlotArgs from MDANSE_GUI.Tabs.Plotters.Plotter import Plotter if TYPE_CHECKING: + from matplotlib.backends.backend_qt5agg import NavigationToolbar2QT as Toolbar from matplotlib.figure import Figure + from matplotlib.image import AxesImage + from MDANSE.util_types import FloatArray from MDANSE_GUI.Tabs.Models.PlottingContext import PlottingContext @@ -39,22 +45,29 @@ class Heatmap(Plotter): """Creates a 2D heatmap plot.""" + @dataclass + class BackupInfo: + ind: int + image: AxesImage = None + array: FloatArray = field(default_factory=lambda: np.empty((0,), dtype=float)) + minmax: tuple[float, float] = (-np.inf, np.inf) + limits: tuple[float, float, float, float] = (-np.inf, np.inf, -np.inf, np.inf) + interp: interp1d = None + title: str = "" + def __init__(self) -> None: """Initialise all plotting parameters to defaults.""" super().__init__() self._figure = None - self._backup_images = {} - self._backup_arrays = {} - self._backup_minmax = {} - self._backup_scale_interpolators = {} + self._axes_titles: list[str] = [] + self._backup: dict[str, Heatmap.BackupInfo] = {} self._current_x_axes = [] - self._backup_limits = {} self._initial_values = [0.0, 100.0] self._slider_values = [0.0, 100.0] self._slice_axis = 2 - self._plot_limit = 1 + self._plot_limit = 9 - def clear(self, figure: Figure = None): + def clear(self, figure: Figure | None = None): """Clear the figure.""" target = self._figure if figure is None else figure if target is None: @@ -65,15 +78,15 @@ def slider_labels(self) -> list[str]: """Return labels for the sliders in heatmap mode.""" return ["Minimum (percentile)", "Maximum (percentile)"] - def slider_limits(self) -> list[str]: + def slider_limits(self) -> list[tuple[float, float, float]]: """Return slider limits for the colormap, in percent.""" - return self._number_of_sliders * [[0.0, 100.0, 0.01]] + return [(0.0, 100.0, 0.01)] * self._number_of_sliders def sliders_coupled(self) -> bool: """Confirm that sliders are coupled in heatmap mode.""" return True - def get_figure(self, figure: Figure = None): + def get_figure(self, figure: Figure | None = None) -> Figure | None: """Return current figure which will be used for plotting.""" target = self._figure if figure is None else figure if target is None: @@ -92,271 +105,192 @@ def change_normalisation(self, new_value: dict[str, Any]): """ super().change_normalisation(new_value) - for ds_num, image in self._backup_images.items(): - data = self._backup_arrays[ds_num] + for backup in self._backup.values(): + data = backup.array new_data = self.normalise_array(data) - image.set_data(new_data) + backup.image.set_data(new_data) percentiles = np.linspace(0, 100.0, 21) results = np.percentile(np.nan_to_num(new_data), percentiles) - self._backup_scale_interpolators[ds_num] = interp1d( - percentiles, - results, - ) + backup.interp = interp1d(percentiles, results) + self.request_slider_values() def handle_slider(self, new_value: list[float]): """Adjust colormap values based on slider values.""" super().handle_slider(new_value) - target = self._figure - if target is None: - return - if new_value[1] <= new_value[0]: + + if self._figure is None or new_value[1] <= new_value[0]: return + self._slider_values = [new_value[0], new_value[1]] - for ds_num, image in self._backup_images.items(): + + for backup in self._backup.values(): try: - last_minmax = self._backup_minmax[ds_num] + last_minmax = backup.minmax except KeyError: - self._backup_minmax[ds_num] = [-1, -1] + backup.minmax = (-1, -1) last_minmax = [-1, -1] - interpolator = self._backup_scale_interpolators[ds_num] + + interpolator = backup.interp newmax = interpolator(new_value[1]) newmin = interpolator(new_value[0]) + if newmax < newmin: if newmax == last_minmax[1]: newmin = float(newmax) else: newmax = float(newmin) + if newmin == last_minmax[0] and newmax == last_minmax[1]: return + if newmax >= newmin: try: - image.set_clim([newmin, newmax]) + backup.image.set_clim((newmin, newmax)) except ValueError: LOG.error( f"Matplotlib could not set colorbar limits to {newmin}, {newmax}" ) else: self._figure.canvas.draw_idle() - self._backup_minmax[ds_num] = [newmin, newmax] - target.canvas.draw() + backup.minmax = (newmin, newmax) + self._figure.canvas.draw() + + def toggle_legend(self, enabled: bool) -> None: + if self._figure is None: + return + + for axes, title in zip(self._axes, self._axes_titles, strict=True): + axes.set_title( + title if enabled else "", + fontsize=self.title_fontsize(title), + ) + + self._figure.canvas.draw() def check_curve_lengths(self): """Find the maximum number of elements in the x axes of the plot data.""" - self.curve_length_limit = 0 - for xdata in self._current_x_axes: - self.curve_length_limit = max(self.curve_length_limit, len(xdata)) + self.curve_length_limit = max(map(len, self._current_x_axes), default=0) def plot( self, plotting_context: PlottingContext, - figure: Figure = None, - update_only=False, - toolbar=None, + figure: Figure | None = None, + update_only: bool = False, + toolbar: Toolbar | None = None, ): """Plot the first dataset as a heatmap. Parameters ---------- plotting_context : PlottingContext - Data model storing the data to be plotted + Data model storing the data to be plotted. figure : Figure, optional - Matplotlib figure instance for plotting, by default None + Matplotlib figure instance for plotting, by default None. update_only : bool, optional - If true, try to re-use zoom settings, by default False - toolbar : _type_, optional - GUI instance of the matplotlib toolbar, by default None - + If true, try to re-use zoom settings, by default False. + toolbar : Toolbar, optional + GUI instance of the matplotlib toolbar, by default None. """ self.enable_slider(allow_slider=True) target = self.get_figure(figure) + if target is None: return + if toolbar is not None: self._toolbar = toolbar + self._figure = target + self._figure.set_layout_engine(layout="constrained") self._current_x_axes = [] - self._normalisation_errors = [] - self._backup_images = {} - self._backup_arrays = {} - self._backup_scale_interpolators = {} - self._axes = [] + + scale_interpolators = {val.ind: val.interp for val in self._backup.values()} + + self._backup.clear() + self._axes.clear() + self._axes_titles.clear() + self.apply_settings(plotting_context) if plotting_context.set_axes() is None: LOG.debug("Axis check failed.") return - nplots = 0 - for databundle in plotting_context.datasets().values(): - if nplots >= self._plot_limit: - break - ds = databundle.dataset - if ds._n_dim == 1: - continue - elif ds._n_dim == 3: - replacement_axis_number = None - for number, axis_name in enumerate(ds._axes.keys()): - if axis_name == databundle.main_axis: - replacement_axis_number = number - if replacement_axis_number is None: - ds.planes_vs_axis(self._slice_axis, max_limit=self._plot_limit) - else: - ds.planes_vs_axis( - replacement_axis_number, max_limit=self._plot_limit - ) - nplots += len(ds._planes) - else: - nplots += 1 + + self._n_curves = min( + ilen(plotting_context.planes(self._slice_axis)), self._plot_limit + ) + + if not self.n_curves: + self.plot_blank() + return + + def get_interp(ind: int, data: FloatArray): + # Check interpolators try: - self._backup_scale_interpolators[databundle.row](51.2) + scale_interpolators[ind](51.2) except Exception: percentiles = np.linspace(0, 100.0, 21) results = [ - np.percentile(np.nan_to_num(ds._data), perc) for perc in percentiles + np.percentile(np.nan_to_num(data), perc) for perc in percentiles ] - self._backup_scale_interpolators[databundle.row] = interp1d( - percentiles, - results, - ) - nplots = min(nplots, self._plot_limit) - gridsize = math.ceil(nplots**0.5) - startnum = 1 - for ds_index, databundle in enumerate(plotting_context.datasets().values()): - if ds_index >= self._plot_limit: - break - dataset = databundle.dataset - transposed = False - primary_axis_number = 0 - limits = [] - x_axis_labels, y_axis_labels = [], [] - for number, axis_name in enumerate(ds._axes.keys()): - if axis_name == databundle.main_axis: - primary_axis_number = number - if dataset._n_dim == 1: - continue - if dataset._n_dim == 3: - all_numbers, all_datasets = ( - list(dataset._planes.keys()), - list(dataset._planes.values()), - ) - all_labels = [dataset._plane_labels[number] for number in all_numbers] - for counter, name in enumerate(dataset._axes.keys()): - if counter == primary_axis_number: - continue - axis_array = dataset.x_axis(name) - limits += [ - axis_array[0], - axis_array[-1], - ] - if not x_axis_labels: - x_axis_labels.append(dataset.x_axis_label(name)) - self._current_x_axes.append(axis_array) - else: - y_axis_labels.append(dataset.x_axis_label(name)) + return interp1d(percentiles, results) else: - all_numbers = [0] - if primary_axis_number == 0: - all_datasets = [dataset._data.T] - else: - all_datasets = [dataset._data] - transposed = True - all_labels = [dataset._name] - for counter, name in enumerate(dataset._axes.keys()): - axis_array = dataset.x_axis(name) - limits += [ - axis_array[0], - axis_array[-1], - ] - if counter == primary_axis_number: - x_axis_labels.append(dataset.x_axis_label(name)) - self._current_x_axes.append(axis_array) - else: - y_axis_labels.append(dataset.x_axis_label(name)) - if transposed: - limits = limits[2:] + limits[:2] - for xnum in range(len(all_datasets)): - if startnum > self._plot_limit: - LOG.warning( - "Datasets above the current limit of %s will be ignored", - self._plot_limit, - ) - break - axes = target.add_subplot(gridsize, gridsize, startnum) - startnum += 1 - self._axes.append(axes) - image = axes.imshow( - all_datasets[xnum][::-1, :], - extent=limits, - aspect="auto", - interpolation=None, - cmap=plotting_context.colormap, + return scale_interpolators[ind] + + grid_size = self.grid_size(self.n_curves) + gs = self._figure.add_gridspec(*grid_size) + + for ind, (databundle, label, plane, axis_labels) in enumerate( + islice(plotting_context.planes(self._slice_axis), self._plot_limit), + ): + self._backup[label] = Heatmap.BackupInfo(ind=ind) + dataset = databundle.dataset + + axes = self._figure.add_subplot(gs[ind]) + + image = self._plot_single( + axes, + plane[:, ::-1].T, + databundle, + label=label, + axis_labels=axis_labels, + colour=plotting_context.colormap, + ) + + xlimits, ylimits = axes.get_xlim(), axes.get_ylim() + + self._backup[label].interp = get_interp(ind, plane) + + last_minmax = ( + self._backup[label].interp(self._slider_values[0]), + self._backup[label].interp(self._slider_values[1]), + ) + + try: + image.set_clim(last_minmax) + except ValueError: + LOG.error( + f"Matplotlib could not set colorbar limits to {last_minmax}", ) - axes.set_title(all_labels[xnum]) - colorbar = mpl_colorbar(image, ax=image.axes, format="%.1e", pad=0.02) - colorbar.set_label(dataset._data_unit) - xlimits, ylimits = axes.get_xlim(), axes.get_ylim() - self._backup_arrays[databundle.row] = all_datasets[xnum][::-1, :] + if update_only: - interpolator = self._backup_scale_interpolators[databundle.row] - last_minmax = [ - interpolator(self._slider_values[0]), - interpolator(self._slider_values[1]), - ] - try: - image.set_clim(last_minmax) - except ValueError: - LOG.error( - f"Matplotlib could not set colorbar limits to {last_minmax}", - ) - self._backup_limits[databundle.row] = [ - xlimits[0], - xlimits[1], - ylimits[0], - ylimits[1], - ] - xlim = axes.get_xlim() - self._backup_limits[databundle.row][0] = xlim[0] - self._backup_limits[databundle.row][1] = xlim[1] - ylim = axes.get_ylim() - self._backup_limits[databundle.row][2] = ylim[0] - self._backup_limits[databundle.row][3] = ylim[1] + self._backup[label].limits = (*xlimits, *ylimits) else: - self._backup_limits[databundle.row] = [ - xlimits[0], - xlimits[1], - ylimits[0], - ylimits[1], - ] - interpolator = self._backup_scale_interpolators[databundle.row] - last_minmax = [ - interpolator(self._slider_values[0]), - interpolator(self._slider_values[1]), - ] - try: - image.set_clim(last_minmax) - except ValueError: - LOG.error( - f"Matplotlib could not set colorbar limits to {last_minmax}", - ) - self._backup_minmax[databundle.row] = [ + self._backup[label].minmax = ( np.nanmin(dataset._data), np.nanmax(dataset._data), - ] - self._backup_limits[databundle.row] = [ - xlimits[0], - xlimits[1], - ylimits[0], - ylimits[1], - ] - axes.set_xlabel(", ".join(np.unique(x_axis_labels))) - axes.set_ylabel(", ".join(np.unique(y_axis_labels))) - self._backup_images[databundle.row] = image - if startnum > 1: - legend = axes.legend() - legend.set_visible(plotting_context.use_legend) - axes.grid(plotting_context.use_grid) + ) + self._backup[label].limits = (*xlimits, *ylimits) + + self.toggle_legend(plotting_context.use_legend) + self.toggle_grid(plotting_context.use_grid) + + if self.n_curves == 1: # Exploit label from loop for one plot + self._figure.suptitle(label) + self.check_curve_lengths() self.request_slider_values() - target.canvas.draw() + self._figure.canvas.draw() @staticmethod def _write_save_data( @@ -419,3 +353,42 @@ def _get_datasets(axis: Axes) -> Iterator[AxesImage]: Each image in dataset. """ yield from axis.get_images() + + def _plot_single( + self, + axes: Axes, + plane: FloatArray, + databundle: PlotArgs, + *, + label: str, + axis_labels: tuple[str, str], + colour: str, + ) -> AxesImage: + dataset = databundle.dataset + + x_label, y_label = axis_labels + + x_axis = dataset.x_axis(x_label) + y_axis = dataset.x_axis(y_label) + + limits = (x_axis[0], x_axis[-1], y_axis[0], y_axis[-1]) + axes.set_xlabel(x_label) + axes.set_ylabel(y_label) + + image = axes.imshow( + plane, + extent=limits, + aspect="auto", + interpolation=None, + cmap=colour, + ) + + colorbar = mpl_colorbar(image, ax=image.axes, format="%.1e", pad=0.02) + colorbar.set_label(dataset._data_unit) + + self._axes.append(axes) + self._axes_titles.append(label) + self._backup[label].array = plane + self._backup[label].image = image + + return image diff --git a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Plotter.py b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Plotter.py index e8a50a48c9..ecc44f98e2 100644 --- a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Plotter.py +++ b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Plotter.py @@ -18,34 +18,44 @@ import copy import csv import enum +import math from itertools import count from typing import TYPE_CHECKING, Any, ClassVar, Literal, TextIO import numpy as np +from matplotlib import rcParams +from matplotlib.colors import to_rgb from more_itertools import consumer from MDANSE.Core.RegisterFactory import RegisterFactory -from MDANSE.IO.IOUtils import UCDict +from MDANSE.IO.IOUtils import UCDict, UCEnum from MDANSE.MLogging import LOG from MDANSE.util_types import FloatArray if TYPE_CHECKING: - from collections.abc import Iterator + from collections.abc import Callable, Generator, Iterator from matplotlib.axes import Axes + from matplotlib.backends.backend_qt5agg import NavigationToolbar2QT as Toolbar from matplotlib.figure import Figure from matplotlib.lines import Line2D from MDANSE_GUI.Tabs.Models.PlottingContext import PlottingContext -class NormOperations(enum.Enum): +class NormOperations(UCEnum): """Enum for selecting mathematical operations when calculating norms.""" AVERAGE = enum.auto() SUM = enum.auto() NOT_IMPLEMENTED = enum.auto() + @classmethod + def _missing_(cls, value: str) -> NormOperations: + if res := super()._missing_(value): + return res + return cls.NOT_IMPLEMENTED + def str_to_enum(operation: str) -> NormOperations: """Get the right enum from the input text string. @@ -53,19 +63,15 @@ def str_to_enum(operation: str) -> NormOperations: Parameters ---------- operation : str - name of the mathematical operation as string. + Name of the mathematical operation as string. Returns ------- NormOperations - enum value of the operation. + Enum value of the operation. """ - if operation == "average": - return NormOperations.AVERAGE - if operation == "sum": - return NormOperations.SUM - return NormOperations.NOT_IMPLEMENTED + return NormOperations(operation) def enum_to_str(operation: NormOperations) -> str: @@ -74,19 +80,15 @@ def enum_to_str(operation: NormOperations) -> str: Parameters ---------- operation : NormOperations - Enum of the mathematical operation + Enum of the mathematical operation. Returns ------- str - name of the operation as string + Name of the operation as string. """ - if operation == NormOperations.AVERAGE: - return "average" - if NormOperations.SUM: - return "sum" - return "not implemented" + return operation.name.lower() NORMALISATION_DEFAULTS = { @@ -96,12 +98,22 @@ def enum_to_str(operation: NormOperations) -> str: "operation": NormOperations.AVERAGE, } +ValidPlotters = Literal["Single", "Vectors", "Text", "Heatmap", "Grid"] + class Plotter(RegisterFactory): """Parent class to all classes used for displaying data.""" registry: ClassVar[UCDict[str, type[Plotter]]] = UCDict() + GRID_SIZES = { + 2: (2, 1), + 5: (2, 3), + 6: (2, 3), + } + _title_length_limit: ClassVar[int] = 30 + _max_labels = 5 + def __init__(self) -> None: """Create defaults common to all plotters.""" self._figure = None @@ -115,6 +127,7 @@ def __init__(self) -> None: self.curve_length_limit = 10 self._normalisation_values = copy.copy(NORMALISATION_DEFAULTS) self._normalisation_errors = [] + self._n_curves = 0 def request_slider_values(self): """Manually read values from sliders, if they are present.""" @@ -214,25 +227,30 @@ def normalise_curve( """ apply = self._normalisation_values["apply"] operation = self._normalisation_values["operation"] - if not apply or operation == NormOperations.NOT_IMPLEMENTED: + if not apply or operation is NormOperations.NOT_IMPLEMENTED: return xdata, ydata + min_index = self._normalisation_values["min_index"] max_index = self._normalisation_values["max_index"] ref_values = ydata[min_index:max_index] + if len(ref_values) < 1: self._normalisation_errors.append( "No points within the specified index range" ) return xdata, ydata - if operation == NormOperations.AVERAGE: + + if operation is NormOperations.AVERAGE: scale_factor = np.nanmean(ref_values) - elif operation == NormOperations.SUM: + elif operation is NormOperations.SUM: scale_factor = np.sum(np.nan_to_num(ref_values)) + if np.isclose(scale_factor, 0.0): self._normalisation_errors.append( "Normalisation factor is 0 and will not be applied." ) return xdata, ydata + return xdata, ydata / scale_factor def normalise_array(self, data_array: FloatArray) -> FloatArray: @@ -258,15 +276,18 @@ def normalise_array(self, data_array: FloatArray) -> FloatArray: ref_column = data_array[:, min_index:max_index] if ref_column.shape[1] < 1: return data_array - if operation == NormOperations.AVERAGE: + + if operation is NormOperations.AVERAGE: scale_column = np.nanmean(ref_column, axis=1) - elif operation == NormOperations.SUM: + elif operation is NormOperations.SUM: scale_column = np.sum(np.nan_to_num(ref_column), axis=1) + if np.any(np.isclose(scale_column, 0.0)): self._normalisation_errors.append( "Normalisation factor is 0 for some rows of the 2D array." ) return data_array + return data_array / scale_column.reshape((len(scale_column), 1)) def change_normalisation(self, new_value: dict[str, Any]): @@ -279,20 +300,20 @@ def plot( plotting_context: PlottingContext, figure: Figure | None = None, update_only: bool = False, - toolbar=None, + toolbar: Toolbar | None = None, ): """Plot the selected data in the figure. Parameters ---------- plotting_context : PlottingContext - Data model storing the data to be plotted + Data model storing the data to be plotted. figure : Figure, optional - Matplotlib figure instance for plotting, by default None + Matplotlib figure instance for plotting, by default None. update_only : bool, optional - If true, try to re-use zoom settings, by default False - toolbar : _type_, optional - GUI instance of the matplotlib toolbar, by default None + If true, try to re-use zoom settings, by default False. + toolbar : Toolbar, optional + GUI instance of the matplotlib toolbar, by default None. """ LOG.info(f"normalisation errors {self._normalisation_errors}, setting to []") @@ -448,3 +469,89 @@ def _get_datasets(axis: Axes) -> Iterator[Line2D]: Each line in dataset. """ yield from axis.get_lines() + + @classmethod + def grid_size(cls, n_plots: int) -> tuple[int, int]: + """Get a good grid layout for plotting. + + Parameters + ---------- + n_plots : int + Number of expected plots. + + Returns + ------- + tuple[int, int] + Grid size. + """ + return cls.GRID_SIZES.get(n_plots, (math.ceil(n_plots**0.5),) * 2) + + def label( + self, + label: str, + ind: int, + *, + n_curves: int | None = None, + limit: int | None = None, + ) -> str | None: + """Get label for legend. + + For the abbreviated legend return None for those which are + between ``limit`` and ``n_curves`` (skipping them), and "..." + when it's at the limit. + + Parameters + ---------- + ind : int + Current index. + n_curves : int + Total number of "curves" to plot. + limit : int + Max number of entries in legend. + label : str + Current label. + """ + if n_curves is None: + n_curves = self.n_curves + if limit is None: + limit = self._max_labels + + if ind == limit: + return "..." + if limit < ind < n_curves - 1: + return None + return label + + @property + def n_curves(self) -> int: + """Number of expected/total curves.""" + return self._n_curves + + @staticmethod + def colours(colour: str, n_curves: int) -> Generator[tuple[float, float, float]]: + """Generate colours from root colour. + + Parameters + ---------- + colour : str + Root colour. + + Returns + ------- + Generator[tuple[float, float, float]] + Next colour in sequence. + """ + main_colour = np.array(to_rgb(colour)) + colour_increment = (0.5 - main_colour) / n_curves + for _ in range(n_curves): + yield tuple(main_colour) + main_colour += colour_increment + + def title_fontsize(self, title_text: str) -> int: + normal_size = rcParams["font.size"] + new_size = ( + normal_size + if len(title_text) < self._title_length_limit + else normal_size - round(len(title_text) / self._title_length_limit) + ) + return new_size diff --git a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Single.py b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Single.py index 065ca8d26b..d3ac4cfa50 100644 --- a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Single.py +++ b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Single.py @@ -16,19 +16,23 @@ from __future__ import annotations import contextlib +from collections.abc import Generator from itertools import islice from typing import TYPE_CHECKING, Any import numpy as np -from matplotlib.colors import to_rgb +from more_itertools import flatten, ilen from MDANSE.MLogging import LOG from MDANSE_GUI.Tabs.Plotters.Plotter import Plotter if TYPE_CHECKING: + from matplotlib.axes import Axes + from matplotlib.backends.backend_qt5agg import NavigationToolbar2QT as Toolbar from matplotlib.figure import Figure + from matplotlib.lines import Line2D - from MDANSE_GUI.Tabs.Models.PlottingContext import PlottingContext + from MDANSE_GUI.Tabs.Models.PlottingContext import PlotArgs, PlottingContext @Plotter.register("Single") @@ -39,20 +43,21 @@ def __init__(self) -> None: """Initialise all ploting parameters to default values.""" super().__init__() self._figure = None - self._active_curves = [] - self._backup_curves = [] + self._active_curves: list[Line2D] = [] + self._backup_curves: list[tuple[np.ndarray, np.ndarray]] = [] self._backup_limits = [] self._curve_limit_per_dataset = 12 + self._plot_limit = 60 self.height_max, self.length_max = 0.0, 0.0 - def clear(self, figure: Figure = None): + def clear(self, figure: Figure | None = None): """Clear the figure.""" target = self._figure if figure is None else figure if target is None: return target.clear() - def get_figure(self, figure: Figure = None): + def get_figure(self, figure: Figure | None = None): """Return the figure instance used for plotting.""" target = self._figure if figure is None else figure if target is None: @@ -65,9 +70,9 @@ def slider_labels(self) -> list[str]: """Return slider labels for single plot mode.""" return ["Y offset", "X offset"] - def slider_limits(self) -> list[str]: + def slider_limits(self) -> list[tuple[float, float, float]]: """Return slider limits for single plot mode.""" - return self._number_of_sliders * [[-1.0, 1.0, 0.001]] + return [(-1.0, 1.0, 0.001)] * self._number_of_sliders def handle_slider(self, new_value: list[float]): """Save slider values and call offset_curves.""" @@ -80,8 +85,7 @@ def change_normalisation(self, new_value: dict[str, Any]): Parameters ---------- new_value : dict[str, Any] - parameters as in NORMALISATION_DEFAULTS - + Parameters as in NORMALISATION_DEFAULTS. """ super().change_normalisation(new_value) self.offset_curves() @@ -89,15 +93,14 @@ def change_normalisation(self, new_value: dict[str, Any]): def offset_curves(self): """Offset curves against each other based on slider settings.""" target = self._figure - if target is None: - return - if len(self._active_curves) == 0: + if target is None or not self._active_curves: return + new_value = self._slider_values saved_xmin, saved_xmax, saved_ymin, saved_ymax = self._backup_limits + for num, curve in enumerate(self._active_curves): - xdata = self._backup_curves[num][0] - ydata = self._backup_curves[num][1] + xdata, ydata = self._backup_curves[num] xdata, ydata = self.normalise_curve(xdata, ydata) new_xdata = xdata + num * self.length_max * new_value[1] new_ydata = ydata + num * self.height_max * new_value[0] @@ -112,15 +115,18 @@ def offset_curves(self): self._backup_limits = [saved_xmin, saved_xmax, saved_ymin, saved_ymax] self._axes[0].relim() self._axes[0].autoscale() + if self._toolbar is not None: self._toolbar.update() self._toolbar.push_current() + try: self._axes[0].set_xlim(saved_xmin, saved_xmax) except ValueError: LOG.error( f"Matplotlib could not set x limits to {saved_xmin}, {saved_xmax}", ) + try: self._axes[0].set_ylim(saved_ymin, saved_ymax) except ValueError: @@ -139,121 +145,79 @@ def check_curve_lengths(self): def plot( self, plotting_context: PlottingContext, - figure: Figure = None, - update_only=False, - toolbar=None, + figure: Figure | None = None, + update_only: bool = False, + toolbar: Toolbar | None = None, ): """Plot all datasets in the same figure. Parameters ---------- plotting_context : PlottingContext - Data model storing the data to be plotted + Data model storing the data to be plotted. figure : Figure, optional - Matplotlib figure instance for plotting, by default None + Matplotlib figure instance for plotting, by default None. update_only : bool, optional - If true, try to re-use zoom settings, by default False - toolbar : _type_, optional - GUI instance of the matplotlib toolbar, by default None + If true, try to re-use zoom settings, by default False. + toolbar : Toolbar, optional + GUI instance of the matplotlib toolbar, by default None. """ self.enable_slider(allow_slider=False) + target = self.get_figure(figure) if target is None: return + if toolbar is not None: self._toolbar = toolbar + self._figure = target - self._figure.set_layout_engine("none") self._active_curves = [] self._backup_curves = [] self._normalisation_errors = [] - axes = target.add_subplot(111) - self._axes = [axes] - self.apply_settings(plotting_context) x_axis_labels = [] + + self.apply_settings(plotting_context) + self.height_max, self.length_max = 0.0, 0.0 + if plotting_context.set_axes() is None: LOG.debug("Axis check failed.") return - if len(plotting_context.datasets()) == 0: - target.clear() - target.canvas.draw() - for databundle in plotting_context.datasets().values(): - dataset = databundle.dataset - try: - best_unit, best_axis = ( - dataset._axes_units[databundle.main_axis], - databundle.main_axis, - ) - except KeyError: - best_unit, best_axis = dataset.longest_axis() - plotlabel = databundle.legend_label - x_axis_labels.append(dataset.x_axis_label(best_axis)) - if dataset._n_dim == 1: - [temp] = axes.plot( - dataset.x_axis(best_axis), - dataset.data, - linestyle=databundle.line_style, - label=plotlabel, - color=databundle.colour, - ) - try: - temp.set_marker(databundle.marker) - except ValueError: - with contextlib.suppress(Exception): - temp.set_marker(int(databundle.marker)) - self._active_curves.append(temp) - self._backup_curves.append([temp.get_xdata(), temp.get_ydata()]) - self.height_max = np.nanmax( - [self.height_max, np.nanmax(temp.get_ydata())] - ) - self.length_max = np.nanmax( - [self.length_max, np.nanmax(temp.get_xdata())] - ) - else: - multi_curves = dataset.curves_vs_axis( - (best_unit, best_axis), max_limit=self._curve_limit_per_dataset - ) - main_colour = np.array(to_rgb(databundle.colour)) - colour_increment = (0.5 - main_colour) / min( - self._curve_limit_per_dataset, len(multi_curves) - ) - for key, value in islice( - multi_curves.items(), self._curve_limit_per_dataset - ): - try: - [temp] = axes.plot( - dataset.x_axis(best_axis), - value, - label=plotlabel + ":" + dataset._curve_labels[key], - linestyle=databundle.line_style, - color=tuple(main_colour), - ) - try: - temp.set_marker(databundle.marker) - except ValueError: - with contextlib.suppress(Exception): - temp.set_marker(int(databundle.marker)) - self._active_curves.append(temp) - self._backup_curves.append([temp.get_xdata(), temp.get_ydata()]) - self.height_max = np.nanmax( - [self.height_max, np.nanmax(temp.get_ydata())] - ) - self.length_max = np.nanmax( - [self.length_max, np.nanmax(temp.get_xdata())] - ) - except ValueError: - LOG.error(f"Plotting failed for {plotlabel} using {best_axis}") - LOG.error(f"x_axis={dataset._axes[best_axis]}") - LOG.error(f"values={value}") - return - main_colour += colour_increment - if len(self._backup_curves) > 1: - self.enable_slider(allow_slider=True) - elif not self._backup_curves: + + self._n_curves = sum( + ds.dataset.n_curves for ds in plotting_context.datasets().values() + ) + + if not self._n_curves: self.plot_blank() return + + self.enable_slider(allow_slider=self._n_curves > 1) + + colours = {} + for databundle in plotting_context.datasets().values(): + colours[databundle.row] = self.colours( + databundle.colour, databundle.dataset.n_curves + ) + x_axis_labels.append(databundle.dataset.x_axis_label(databundle.main_axis)) + + axes = target.add_subplot(111) + self._axes = [axes] + + for databundle, label, curve in islice( + flatten(plotting_context.curves(self._curve_limit_per_dataset)), + self._plot_limit, + ): + self._plot_single( + axes, + curve, + databundle, + label=f"{databundle.legend_label} {label}", + colour=next(colours[databundle.row]), + ) + if update_only: try: axes.set_xlim((self._backup_limits[0], self._backup_limits[1])) @@ -261,6 +225,7 @@ def plot( LOG.error( f"Matplotlib could not set x limits to {self._backup_limits[0]}, {self._backup_limits[1]}" ) + try: axes.set_ylim((self._backup_limits[2], self._backup_limits[3])) except ValueError: @@ -268,11 +233,58 @@ def plot( f"Matplotlib could not set y limits to {self._backup_limits[2]}, {self._backup_limits[3]}" ) else: - xlimits, ylimits = axes.get_xlim(), axes.get_ylim() - self._backup_limits = [xlimits[0], xlimits[1], ylimits[0], ylimits[1]] + self._backup_limits = [*axes.get_xlim(), *axes.get_ylim()] + axes.set_xlabel(", ".join(np.unique(x_axis_labels))) - legend = axes.legend() - legend.set_visible(plotting_context.use_legend) - axes.grid(plotting_context.use_grid) + axes.legend() + + self.toggle_legend(plotting_context.use_legend) + self.toggle_grid(plotting_context.use_grid) + self.check_curve_lengths() self.offset_curves() + + def _plot_single( + self, + axes: Axes, + curve: tuple[np.ndarray, np.ndarray] | tuple[np.ndarray], + databundle: PlotArgs, + *, + label: str, + colour: tuple[float, float, float], + ): + """Plot a single curve to axes. + + Parameters + ---------- + axes : Axes + Axis to plot to. + curve : tuple[np.ndarray, np.ndarray] | tuple[np.ndarray] + Curve to plot. + databundle : PlotArgs + Data to plot. + colour : tuple[float, float, float] + Curve colour. + """ + + lines: list[Line2D] = axes.plot( + *curve, + linestyle=databundle.line_style, + label=self.label(label, len(self._active_curves)), + color=colour, + ) + + for line in lines: + try: + line.set_marker(databundle.marker) + except ValueError: + with contextlib.suppress(Exception): + line.set_marker(int(databundle.marker)) + + self.height_max = np.nanmax([self.height_max, np.nanmax(line.get_ydata())]) + self.length_max = np.nanmax([self.length_max, np.nanmax(line.get_xdata())]) + + self._active_curves.extend(lines) + self._backup_curves.extend( + (line.get_xdata(), line.get_ydata()) for line in lines + ) diff --git a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Text.py b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Text.py index 057296d2a4..13e36f470b 100644 --- a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Text.py +++ b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Text.py @@ -20,7 +20,7 @@ from typing import TYPE_CHECKING import numpy as np -from more_itertools import collapse, prepend, transpose +from more_itertools import collapse, prepend, transpose, unzip from MDANSE.Framework.Units import measure from MDANSE.MLogging import LOG @@ -29,6 +29,7 @@ if TYPE_CHECKING: from collections.abc import Iterator + from matplotlib.backends.backend_qt5agg import NavigationToolbar2QT as Toolbar from qtpy.QtWidgets import QTextBrowser from MDANSE_GUI.Tabs.Models.PlottingContext import ( @@ -302,21 +303,14 @@ def process_2D_data( dataset._data.shape[0] if flip_array else dataset._data.shape[1] ) - multi_curves = np.vstack( - list( - dataset.curves_vs_axis( - (best_unit, best_axis), max_limit=curves_limit, skip_label_text=True - ).values() - ) - ) # Add corner nil xaxis = prepend("_", new_axes[axis_numbers[flip_array]].flat) # Add axes to data - data_lines = zip( - dataset._curve_labels.values(), - multi_curves, - strict=True, + data_lines = list( + dataset.curves_vs_axis( + (best_unit, best_axis), max_limit=curves_limit, skip_label_text=True + ) ) # Put xaxis in @@ -501,9 +495,9 @@ def plot( self, plotting_context: PlottingContext, figure: QTextBrowser = None, - colours=None, - update_only=False, - toolbar=None, + colours: None = None, + update_only: bool = False, + toolbar: Toolbar | None = None, ): """Show data as text. @@ -516,12 +510,12 @@ def plot( Data model containing the data sets to be shown figure : QTextBrowser, optional Target widget, an instance of QTextBrowser - colours : _type_, optional - ignored here + colours : None, optional + Ignored here update_only : bool, optional - ignored - toolbar : _type_, optional - ignored + Ignored + toolbar : Toolbar, optional + Ignored. """ target = self.get_figure(figure) diff --git a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Vectors.py b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Vectors.py index a41b6beca0..79fed0ccdc 100644 --- a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Vectors.py +++ b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Vectors.py @@ -20,11 +20,16 @@ import numpy as np import numpy.typing as npt +from matplotlib.axes import Axes +from matplotlib.lines import Line2D +from more_itertools import ilen from MDANSE.MLogging import LOG +from MDANSE_GUI.Tabs.Models.PlottingContext import PlotArgs from MDANSE_GUI.Tabs.Plotters.Plotter import Plotter if TYPE_CHECKING: + from matplotlib.backends.backend_qt5agg import NavigationToolbar2QT as Toolbar from matplotlib.figure import Figure from MDANSE.util_types import FloatArray @@ -38,7 +43,7 @@ def violin_plot_width(positions: FloatArray) -> float: @Plotter.register("Vectors") class Vectors(Plotter): - """Plots all the datasets in the same figure.""" + """Plots summarised Q-Vectors to one figure.""" def __init__(self) -> None: """Initialise all ploting parameters to default values.""" @@ -52,9 +57,9 @@ def slider_labels(self) -> list[str]: """Return labels to show that sliders are not used.""" return ["Inactive", "Inactive"] - def slider_limits(self) -> list[str]: + def slider_limits(self) -> list[tuple[float, float, float]]: """Return generic slider limit values.""" - return self._number_of_sliders * [[-1.0, 1.0, 0.01]] + return [(-1.0, 1.0, 0.01)] * self._number_of_sliders def clear(self, figure: Figure | None = None): """Clear the figure.""" @@ -78,7 +83,7 @@ def change_normalisation(self, new_value: dict[str, Any]): Parameters ---------- new_value : dict[str, Any] - parameters as in NORMALISATION_DEFAULTS + Parameters as in NORMALISATION_DEFAULTS. """ super().change_normalisation(new_value) @@ -88,20 +93,20 @@ def plot( plotting_context: PlottingContext, figure: Figure | None = None, update_only: bool = False, - toolbar: type | None = None, + toolbar: Toolbar | None = None, ): """Plot all datasets in the same figure. Parameters ---------- plotting_context : PlottingContext - Data model storing the data to be plotted + Data model storing the data to be plotted. figure : Figure, optional - Matplotlib figure instance for plotting, by default None + Matplotlib figure instance for plotting, by default None. update_only : bool, optional - If true, try to re-use zoom settings, by default False - toolbar : _type_, optional - GUI instance of the matplotlib toolbar, by default None + If true, try to re-use zoom settings, by default False. + toolbar : Toolbar, optional + GUI instance of the matplotlib toolbar, by default None. """ self.enable_slider(allow_slider=False) @@ -109,14 +114,16 @@ def plot( if target is None: return + if toolbar is not None: self._toolbar = toolbar self._figure = target self._normalisation_errors = [] self._axes = [] + self._active_curves = [] + self._backup_curves = [] self.apply_settings(plotting_context) - x_axis_labels = [] if plotting_context.set_axes() is None: LOG.debug("Axis check failed.") @@ -124,115 +131,93 @@ def plot( if not plotting_context.datasets(): target.clear() - target.canvas.draw() + self.plot_blank() + return - single_plot_stack = [222, 221] - label_stack = [ - "Used", - "Found", - ] + gs = self._figure.add_gridspec(2, 2) + labels = iter(("Used", "Found")) for databundle in plotting_context.datasets().values(): dataset = databundle.dataset - try: - best_unit, best_axis = ( - dataset._axes_units[databundle.main_axis], - databundle.main_axis, - ) - except KeyError: - best_unit, best_axis = dataset.longest_axis() plotlabel = databundle.legend_label - x_axis_labels.append(dataset.x_axis_label(best_axis)) - if dataset._name == "Available vectors": - axes = target.add_subplot(single_plot_stack.pop()) - if dataset._n_dim == 2: - temp_curves = [] - for value in dataset.data.T: - [temp] = axes.plot( - dataset.x_axis(best_axis), - value, - label=label_stack.pop(), + x_axis_labels = [dataset.x_axis_label(databundle.main_axis)] + + match dataset._name: + case "Available vectors": + axes = self._figure.add_subplot(gs[0, 0]) + + for _, curve in dataset.curves_vs_axis(databundle.main_axis): + lab = plotlabel if dataset._n_dim != 2 else next(labels) + + axes.bar( + *curve, + label=lab, ) - temp_curves.append(temp) - if not label_stack: - break - else: - temp_curves = axes.plot( - dataset.x_axis(best_axis), + + axes.set_xlabel(", ".join(np.unique(x_axis_labels))) + axes.set_title(dataset._name) + self._axes.append(axes) + + case r"<|q|> - q$_{target}$": + axes = self._figure.add_subplot(gs[0, 1]) + xvals = dataset.x_axis(databundle.main_axis) + axes.violinplot( dataset.data, - linestyle=databundle.line_style, - label=plotlabel, - color=databundle.colour, + positions=xvals, + widths=violin_plot_width(xvals), ) - for temp in temp_curves: - try: - temp.set_marker(databundle.marker) - except ValueError: - with contextlib.suppress(Exception): - temp.set_marker(int(databundle.marker)) - self._axes.append(axes) - axes.set_xlabel(", ".join(np.unique(x_axis_labels))) - axes.set_title(dataset._name) - elif dataset._name == r"<|q|> - q$_{target}$": - axes = target.add_subplot(single_plot_stack.pop()) - xvals = dataset.x_axis(best_axis) - axes.violinplot( - dataset.data, - positions=xvals, - widths=violin_plot_width(xvals), - ) - self._axes.append(axes) - axes.set_xlabel(", ".join(np.unique(x_axis_labels))) - axes.set_title(dataset._name) - elif dataset._name == "Shell population": - axes = target.add_subplot(212) - multi_curves = dataset.curves_vs_axis( - (best_unit, best_axis), - max_limit=self._curve_limit_per_dataset, - ) - x_axis = dataset.x_axis(best_axis) - bottom = np.zeros(len(x_axis)) - width = 0.8 * abs(np.mean(x_axis[1:] - x_axis[:-1])) - self._axes.append(axes) - add_legend_placeholder = ( - len(multi_curves) > self._legend_limit_for_histogram - ) - add_last_entry = False - axes.set_xlabel(", ".join(np.unique(x_axis_labels))) - for bar_index, (key, value) in enumerate(multi_curves.items()): - legend_label = plotlabel + ":" + dataset._curve_labels[key] - try: - axes.bar( - x_axis, - value, - label=legend_label - if bar_index < self._legend_limit_for_histogram - or add_last_entry - else None, - bottom=bottom, - width=width, - edgecolor="black", + self._axes.append(axes) + axes.set_xlabel(", ".join(np.unique(x_axis_labels))) + axes.set_title(dataset._name) + + case "Shell population": + axes = self._figure.add_subplot(gs[1, :]) + + n_curves = ilen( + dataset.curves_vs_axis( + databundle.main_axis, + max_limit=self._curve_limit_per_dataset, ) - bottom += value - except ValueError: - LOG.error(f"Plotting failed for {plotlabel} using {best_axis}") - LOG.error(f"x_axis={dataset._axes[best_axis]}") - LOG.error(f"values={value}") - return - else: - if ( - add_legend_placeholder - and bar_index == len(multi_curves) - 2 - ): - add_legend_placeholder = False - add_last_entry = True + ) + + x_axis = dataset.x_axis(databundle.main_axis) + bottom = np.zeros_like(x_axis) + width = 0.8 * abs(np.mean(np.diff(x_axis))) + + self._axes.append(axes) + + axes.set_xlabel(", ".join(np.unique(x_axis_labels))) + for ind, (label, curve) in enumerate( + dataset.curves_vs_axis( + databundle.main_axis, + max_limit=self._curve_limit_per_dataset, + ) + ): + try: axes.bar( - x_axis, - 0, - label="...", - color=target.get_facecolor(), + *curve, + label=self.label( + ind=ind, + n_curves=n_curves, + limit=self._legend_limit_for_histogram, + label=f"{plotlabel}:{label}", + ), + bottom=bottom, + width=width, + edgecolor="black", ) + bottom += curve[1] + + except ValueError: + x, y = curve + LOG.error( + f"Plotting failed for {plotlabel} using {databundle.main_axis}" + ) + LOG.error(f"x_axis={x}") + LOG.error(f"values={y}") + raise + for axindex, axes in enumerate(self._axes): if update_only: plot_limits = self._backup_limits[axindex] @@ -246,7 +231,7 @@ def plot( else: for databundle in plotting_context.datasets().values(): dataset = databundle.dataset - best_unit, best_axis = dataset.longest_axis() + _, best_axis = dataset.longest_axis() if dataset._name == r"<|q|> - q$_{target}$": axes.clear() xvals = dataset.x_axis(best_axis) @@ -279,4 +264,4 @@ def plot( axes.autoscale() if self._toolbar is not None: self._toolbar.update() - target.canvas.draw() + self._figure.canvas.draw() diff --git a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Vectors3D.py b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Vectors3D.py index a067f8381d..c0100bcc3b 100644 --- a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Vectors3D.py +++ b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Vectors3D.py @@ -24,6 +24,7 @@ from MDANSE_GUI.Tabs.Plotters.Plotter import Plotter if TYPE_CHECKING: + from matplotlib.backends.backend_qt import NavigationToolbar2QT from matplotlib.figure import Figure from MDANSE_GUI.Tabs.Models.PlottingContext import PlottingContext @@ -45,9 +46,9 @@ def slider_labels(self) -> list[str]: """Return labels to show that sliders are not used.""" return ["Inactive", "Inactive"] - def slider_limits(self) -> list[str]: + def slider_limits(self) -> list[tuple[float, float, float]]: """Return generic slider limit values.""" - return self._number_of_sliders * [[-1.0, 1.0, 0.01]] + return [(-1.0, 1.0, 0.01)] * self._number_of_sliders def clear(self, figure: Figure | None = None): """Clear the figure.""" @@ -81,7 +82,7 @@ def plot( plotting_context: PlottingContext, figure: Figure | None = None, update_only: bool = False, - toolbar: type | None = None, + toolbar: NavigationToolbar2QT | None = None, ): """Plot all datasets in the same figure. @@ -99,60 +100,88 @@ def plot( """ self.enable_slider(allow_slider=False) target = self.get_figure(figure) + if target is None: return if toolbar is not None: self._toolbar = toolbar + self._figure = target self._normalisation_errors = [] self._axes = [] self.apply_settings(plotting_context) + if plotting_context.set_axes() is None: LOG.debug("Axis check failed.") return + if not plotting_context.datasets(): target.clear() target.canvas.draw() - single_plot_stack = [221, 222, 223, 224] - first_shared_axes = target.add_subplot(single_plot_stack[0]) - for ds_index, databundle in enumerate(plotting_context.datasets().values()): + + gs = self._figure.add_gridspec(2, 2) + shared_axes = figure.add_subplot(gs[0]) + self._axes = [shared_axes] + + for databundle in plotting_context.datasets().values(): dataset = databundle.dataset - try: - _, best_axis = ( - dataset._axes_units[databundle.main_axis], - databundle.main_axis, - ) - except KeyError: - _, best_axis = dataset.longest_axis() plotlabel = databundle.legend_label - if ds_index < 3: - axes = first_shared_axes - temp_curves = axes.plot( - dataset.x_axis(best_axis), - dataset.data, - linestyle=databundle.line_style, + + x_axis_labels = [dataset.x_axis_label(databundle.main_axis)] + + bins = dataset.x_axis(databundle.main_axis) + + if "angle" in plotlabel: + axes = self._figure.add_subplot( + gs[2 + plotlabel.startswith("Azimuthal")] + ) + + axes.bar( + bins, + dataset.data[0, :], + linestyle="none", label=plotlabel, color=databundle.colour, + width=np.mean(np.diff(bins)), + ) + axes.bar( + bins, + dataset.data[1, :], + linestyle=databundle.line_style, + label="Unique counts", + color="none", + edgecolor="black", + width=np.mean(np.diff(bins)), ) - for temp in temp_curves: - try: - temp.set_marker(databundle.marker) - except ValueError: - with contextlib.suppress(Exception): - temp.set_marker(int(databundle.marker)) self._axes.append(axes) - axes.set_xlabel(dataset.x_axis_label(best_axis)) + axes.set_xlabel(", ".join(x_axis_labels)) axes.set_title(dataset._name) - elif ds_index < 4: - axes = target.add_subplot( - single_plot_stack[ds_index - 2], projection="3d" + elif "vs" in plotlabel: + [curve] = shared_axes.plot( + bins, + dataset.data, + linestyle=databundle.line_style, + label=plotlabel, + color=databundle.colour, ) + try: + curve.set_marker(databundle.marker) + except ValueError: + with contextlib.suppress(Exception): + curve.set_marker(int(databundle.marker)) + + shared_axes.set_xlabel(", ".join(x_axis_labels)) + + elif "in 3D" in plotlabel: + axes = self._figure.add_subplot(gs[1], projection="3d") + all_coords = np.concatenate( [dataset.x_axis("q_x"), dataset.x_axis("q_y"), dataset.data] ) self.axis3d_min = np.min(all_coords) self.axis3d_max = np.max(all_coords) - temp_curves = axes.scatter( + + axes.scatter( dataset.x_axis("q_x"), dataset.x_axis("q_y"), dataset.data, @@ -165,38 +194,7 @@ def plot( axes.set_xlabel(dataset.x_axis_label("q_x")) axes.set_ylabel(dataset.x_axis_label("q_y")) axes.set_title(dataset._name) - else: - dataset = databundle.dataset - try: - _, best_axis = ( - dataset._axes_units[databundle.main_axis], - databundle.main_axis, - ) - except KeyError: - _, best_axis = dataset.longest_axis() - plotlabel = databundle.legend_label - axes = target.add_subplot(single_plot_stack[ds_index - 2]) - bins = dataset.x_axis(best_axis) - temp_curves = axes.bar( - bins, - dataset.data[0, :], - linestyle="none", - label=plotlabel, - color=databundle.colour, - width=np.mean(np.diff(bins)), - ) - _ = axes.bar( - bins, - dataset.data[1, :], - linestyle=databundle.line_style, - label="Unique counts", - color="none", - edgecolor="black", - width=np.mean(np.diff(bins)), - ) - self._axes.append(axes) - axes.set_xlabel(dataset.x_axis_label(best_axis)) - axes.set_title(dataset._name) + for axes in self._axes: if update_only: try: @@ -214,15 +212,19 @@ def plot( else: xlimits, ylimits = axes.get_xlim(), axes.get_ylim() self._backup_limits = [xlimits[0], xlimits[1], ylimits[0], ylimits[1]] + for axes in self._axes: legend = axes.legend() legend.set_visible(plotting_context.use_legend) axes.grid(plotting_context.use_grid) axes.relim() axes.autoscale() + self.axes3d.set_xlim([self.axis3d_min, self.axis3d_max]) self.axes3d.set_ylim([self.axis3d_min, self.axis3d_max]) self.axes3d.set_zlim([self.axis3d_min, self.axis3d_max]) + if self._toolbar is not None: self._toolbar.update() + target.canvas.draw() diff --git a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/__init__.py b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/__init__.py index 1085cff763..ccedf4b705 100644 --- a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/__init__.py +++ b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/__init__.py @@ -16,6 +16,7 @@ from __future__ import annotations from .Grid import Grid as Grid +from .Grouped import Grouped as Grouped from .Heatmap import Heatmap as Heatmap from .Plotter import Plotter as Plotter from .Single import Single as Single diff --git a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Views/PlotDataView.py b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Views/PlotDataView.py index e8d56d0428..9ca5741840 100644 --- a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Views/PlotDataView.py +++ b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Views/PlotDataView.py @@ -436,7 +436,9 @@ def q_data(elem: str) -> Generator: common_bins = common_bins[start_index:] qmod_histograms = [np.histogram(qmods, common_bins)[0] for qmods in modq_per_shell] stacked_histograms = np.vstack(qmod_histograms) + xvals = common_bins[1:] - np.diff(common_bins) / 2 + nvec_per_q = SingleDataset( "Available vectors", None, diff --git a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Visualisers/PlotWidget.py b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Visualisers/PlotWidget.py index b9183e1a5a..1c53545c07 100644 --- a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Visualisers/PlotWidget.py +++ b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Visualisers/PlotWidget.py @@ -45,6 +45,10 @@ from MDANSE_GUI.Widgets.RestrictedSlider import RestrictedSlider if TYPE_CHECKING: + from matplotlib.backends.backend_qt5agg import ( + NavigationToolbar2QT as Toolbar, + ) + from MDANSE_GUI.Tabs.Models.PlottingContext import PlottingContext @@ -168,10 +172,10 @@ def set_values(self, new_values: list[float]): @Slot() def slider_to_box(self): """Update spin boxes if slider is moving.""" - vals = np.zeros_like(self._valarray) clicks = np.zeros_like(self._clickarray) for ns, slider in enumerate(self._sliders): clicks[ns] = slider.value() + vals = self._minarray + clicks * self._steparray for ns, box in enumerate(self._spinboxes): box.setValue(vals[ns]) @@ -181,7 +185,7 @@ def box_to_slider(self): """Update sliders if spin boxes have changed.""" with block_signals(self): vals = np.zeros_like(self._valarray) - clicks = np.zeros_like(self._clickarray) + for ns, box in enumerate(self._spinboxes): vals[ns] = box.value() clicks = np.round((vals - self._minarray) / self._steparray).astype(int) @@ -207,9 +211,8 @@ class PlotWidget(QWidget): reset_slider_values = Signal(bool) change_slider_coupling = Signal(bool) - def __init__(self, *args, **kwargs) -> None: + def __init__(self, *args, plotter_type: str = "Single", **kwargs) -> None: """Create an empty plot with the default plotter.""" - plotter_type = kwargs.pop("plotter_type", "Single") super().__init__(*args, **kwargs) self._plotter = None self._sliderpack = None @@ -245,10 +248,12 @@ def set_plotter(self, plotter_option: str): except Exception: self._plotter = Plotter() self._plotter._figure = self._figure + self.change_slider_labels.emit(self._plotter.slider_labels()) self.change_slider_limits.emit(self._plotter.slider_limits()) self.change_slider_coupling.emit(self._plotter.sliders_coupled()) self.reset_slider_values.emit(self._plotter._value_reset_needed) + self._plotter._slider_reference = self._sliderpack self._sliderpack.setEnabled(False) self.plot_data() @@ -333,13 +338,16 @@ def plot_data(self, update_only=False): return if self._plotting_context is None: return - self._figure.set_layout_engine("tight") + + # self._figure.set_layout_engine("tight") + self._plotter.plot( self._plotting_context, self._figure, update_only=update_only, toolbar=self._toolbar, ) + self._normaliser.update_spinbox_limits(self._plotter.curve_length_limit) self._normaliser.collect_values() self._sliderpack.collect_values() diff --git a/MDANSE_GUI/Tests/UnitTests/test_PlottingContext.py b/MDANSE_GUI/Tests/UnitTests/test_PlottingContext.py index 5b729618b2..383a04aff5 100644 --- a/MDANSE_GUI/Tests/UnitTests/test_PlottingContext.py +++ b/MDANSE_GUI/Tests/UnitTests/test_PlottingContext.py @@ -1,12 +1,12 @@ -import pytest -import tempfile -import os +from __future__ import annotations + from pathlib import Path import h5py import numpy as np - -from qtpy import QtGui, QtCore, QtWidgets +import pytest +from more_itertools import ilen, nth, first +from qtpy import QtCore, QtGui, QtWidgets from MDANSE_GUI.Tabs.Models.PlottingContext import PlottingContext, SingleDataset @@ -62,18 +62,16 @@ def test_available_x_axes_2d(file_2d): def test_curves_vs_axis_2d_long_axis(file_2d): temp = SingleDataset("f(q,t)_total", file_2d) temp.set_data_limits("0;1;2;3") - curves = temp.curves_vs_axis(("ps", "time"), max_limit=12) - print(len(curves)) - assert len(curves) == 4 - print(curves.keys()) - assert len(curves[(0,)]) == 501 + curves = temp.curves_vs_axis("time", max_limit=12) + assert ilen(curves) == 4 + label, (x, y) = first(temp.curves_vs_axis("time", max_limit=12)) + assert len(x) == 501 def test_curves_vs_axis_2d_short_axis(file_2d): temp = SingleDataset("f(q,t)_total", file_2d) temp.set_data_limits("2;3;4;5") - curves = temp.curves_vs_axis(("1/nm", "q"), max_limit=12) - print(len(curves)) - assert len(curves) == 4 - print(curves.keys()) - assert len(curves[(2,)]) == 10 + curves = temp.curves_vs_axis("q", max_limit=12) + assert ilen(curves) == 4 + label, (x, y) = nth(temp.curves_vs_axis("q", max_limit=12), 3) + assert len(x) == 10