diff --git a/build_tools/requirements.txt b/build_tools/requirements.txt index 76210f56dc..7100bc6cc5 100644 --- a/build_tools/requirements.txt +++ b/build_tools/requirements.txt @@ -38,4 +38,4 @@ zope requests # Alphabetized list of OS-specific packages -pywin32; platform_system == "Windows" +pywin32; platform_system == "Windows" \ No newline at end of file diff --git a/conftest.py b/conftest.py index ac23acb31e..877460d1ca 100644 --- a/conftest.py +++ b/conftest.py @@ -1,6 +1,10 @@ import os import platform +# Must be set before any Qt or matplotlib import so that matplotlib's +# qt_compat.py selects PySide6 instead of raising an ImportError. +os.environ.setdefault("QT_API", "pyside6") + import pytest # Run the tests in offscreen mode diff --git a/docs/sphinx-docs/source/user/menu_bar.rst b/docs/sphinx-docs/source/user/menu_bar.rst index 0ab0c88d84..f5f8543689 100644 --- a/docs/sphinx-docs/source/user/menu_bar.rst +++ b/docs/sphinx-docs/source/user/menu_bar.rst @@ -22,12 +22,38 @@ Edit ---- The Edit option allows you to: +- undo the most recent parameter change (``Ctrl+Z``); +- redo a previously undone parameter change (``Ctrl+Y``); - copy and paste parameters between *SasView* analysis windows; - copy parameters from a *SasView* analysis window to the Clipboard as either tab-delimited text (compatible with Microsoft Excel) or LaTex-wrapped text; - generate a summary 'Report' of the most recent analysis performed; - reset parameter values in the P(r) Inversion analysis page; - freeze/copy fit results as separate data sets. +The **Undo** and **Redo** commands are also available as buttons in the +main toolbar, and their tooltips dynamically show which action will be +reverted (e.g. "Undo Change radius"). + +Undo/Redo is supported in the following perspectives: + +- **Model Fitting** – parameter changes in individual fit pages; +- **P(r) Inversion** – parameter changes in each inversion page; +- **Invariant** – parameter changes in the invariant calculator; +- **Size Distribution** – parameter changes in the size distribution + calculator; +- **Correlation Function** – parameter changes in the correlation + function analysis. + +Each tab or page maintains its own independent undo history, so +switching between perspectives or fit pages preserves the undo/redo +state of each. The default history depth is 200 actions per tab, +configurable via the ``UNDO_STACK_MAX_DEPTH`` setting in +:file:`src/sas/system/config/config.py`. + +.. note:: Undo/Redo is automatically suppressed during programmatic + operations such as loading a project or applying fit results, + to prevent spurious entries from cluttering the undo history. + View ---- The View option allows you to: diff --git a/src/ascii_dialog/col_editor.py b/src/ascii_dialog/col_editor.py new file mode 100644 index 0000000000..6348bf805f --- /dev/null +++ b/src/ascii_dialog/col_editor.py @@ -0,0 +1,89 @@ +from typing import cast + +from PySide6.QtCore import Signal, Slot +from PySide6.QtWidgets import QHBoxLayout, QWidget + +from sasdata.ascii_reader_metadata import bidirectional_pairings +from sasdata.quantities.units import NamedUnit + +from ascii_dialog.column_unit import ColumnUnit + + +class ColEditor(QWidget): + """An editor widget which allows the user to specify the columns of the data + from a set of options based on which dataset type has been selected.""" + column_changed = Signal() + + def __init__(self, cols: int, options: list[str]): + super().__init__() + + self.cols = cols + self.options = options + self.layout = QHBoxLayout(self) + self.option_widgets: list[ColumnUnit] = [] + for _ in range(cols): + new_widget = ColumnUnit(self.options) + new_widget.column_changed.connect(self.onColumnUpdate) + self.layout.addWidget(new_widget) + self.option_widgets.append(new_widget) + + @Slot() + def onColumnUpdate(self): + column_changed = cast(ColumnUnit, self.sender()) + pairing = bidirectional_pairings.get(column_changed.currentColumn) + if pairing is not None: + for col_unit in self.option_widgets: + # Second condition is important because otherwise, this event will keep being called, and the GUI will + # go into an infinite loop. + if col_unit.currentColumn == pairing and col_unit.currentUnit != column_changed.currentUnit: + col_unit.currentUnit = column_changed.currentUnit + + def setCols(self, new_cols: int): + """Set the amount of columns for the user to edit.""" + + # Decides whether we need to extend the current set of combo boxes, or + # remove some. + if self.cols < new_cols: + for _ in range(new_cols - self.cols): + new_widget = ColumnUnit(self.options) + new_widget.column_changed.connect(self.onColumnUpdate) + self.layout.addWidget(new_widget) + self.option_widgets.append(new_widget) + + self.cols = new_cols + if self.cols > new_cols: + excess_cols = self.cols - new_cols + length = len(self.option_widgets) + excess_combo_boxes = self.option_widgets[length - excess_cols:length] + for box in excess_combo_boxes: + self.layout.removeWidget(box) + box.setParent(None) + self.option_widgets = self.option_widgets[0:length - excess_cols] + self.cols = new_cols + self.column_changed.emit() + + def setColOrder(self, cols: list[str]): + """Sets the series of currently selected columns to be cols, in that + order. If there are not enough column widgets include as many of the + columns in cols as possible. + + """ + try: + for i, col_name in enumerate(cols): + self.option_widgets[i].setCurrentColumn(col_name) + except IndexError: + pass # Can ignore because it means we've run out of widgets. + + def colNames(self) -> list[str]: + """Get a list of all of the currently selected columns.""" + return [widget.currentColumn for widget in self.option_widgets] + + @property + def columns(self) -> list[tuple[str, NamedUnit | None]]: + return [(widget.currentColumn, widget.currentUnit if widget.currentColumn != "" else None) for widget in self.option_widgets] + + def replaceOptions(self, new_options: list[str]) -> None: + """Replace options from which the user can choose for each column.""" + self.options = new_options + for widget in self.option_widgets: + widget.replaceOptions(new_options) diff --git a/src/ascii_dialog/column_unit.py b/src/ascii_dialog/column_unit.py new file mode 100644 index 0000000000..323d5ef0b5 --- /dev/null +++ b/src/ascii_dialog/column_unit.py @@ -0,0 +1,128 @@ +#!/usr/bin/env python3 + +from PySide6.QtCore import Signal, Slot +from PySide6.QtGui import QRegularExpressionValidator +from PySide6.QtWidgets import QComboBox, QHBoxLayout, QSizePolicy, QWidget + +from sasdata.dataset_types import unit_kinds +from sasdata.default_units import defaults_or_fallback +from sasdata.quantities.units import NamedUnit + +from ascii_dialog.unit_selector import UnitSelector + + +def configure_size_policy(combo_box: QComboBox) -> None: + policy = combo_box.sizePolicy() + policy.setHorizontalPolicy(QSizePolicy.Policy.Ignored) + combo_box.setSizePolicy(policy) + +class ColumnUnit(QWidget): + """Widget with 2 combo boxes: one allowing the user to pick a column, and + another to specify the units for that column.""" + def __init__(self, options) -> None: + super().__init__() + self.col_widget = self.createColComboBox(options) + self.unit_widget = self.createUnitComboBox(self.col_widget.currentText()) + self.layout = QHBoxLayout(self) + self.layout.addWidget(self.col_widget) + self.layout.addWidget(self.unit_widget) + self.current_option: str + + column_changed = Signal() + + def createColComboBox(self, options: list[str]) -> QComboBox: + """Create the combo box for specifying the column based on the given + options.""" + new_combo_box = QComboBox() + configure_size_policy(new_combo_box) + for option in options: + new_combo_box.addItem(option) + new_combo_box.setEditable(True) + validator = QRegularExpressionValidator(r"[a-zA-Z0-9]+") + new_combo_box.setValidator(validator) + new_combo_box.currentTextChanged.connect(self.onOptionChange) + return new_combo_box + + def createUnitComboBox(self, selected_option: str) -> QComboBox: + """Create the combo box for specifying the unit for selected_option""" + new_combo_box = QComboBox() + configure_size_policy(new_combo_box) + new_combo_box.setEditable(True) + self.updateUnits(new_combo_box, selected_option) + new_combo_box.currentTextChanged.connect(self.onUnitChange) + return new_combo_box + + def updateUnits(self, unit_box: QComboBox, selected_option: str): + unit_box.clear() + self.current_option = selected_option + # Use the list of preferred units but fallback to the first 5 if there aren't any for this particular column. + if self.current_option == '': + unit_box.setDisabled(True) + else: + unit_box.setDisabled(False) + unit_options = defaults_or_fallback(self.current_option) + option_symbols = [unit.symbol for unit in unit_options] + for option in option_symbols[:5]: + unit_box.addItem(option) + unit_box.addItem('Select More') + + + def replaceOptions(self, new_options) -> None: + """Replace the old options for the column with new_options""" + self.col_widget.clear() + self.col_widget.addItems(new_options) + + def setCurrentColumn(self, new_column_value: str) -> None: + """Change the current selected column to new_column_value""" + self.col_widget.setCurrentText(new_column_value) + self.updateUnits(self.unit_widget, new_column_value) + + + @Slot() + def onOptionChange(self): + # If the new option is empty string, its probably because the current + # options have been removed. Can safely ignore this. + self.column_changed.emit() + new_option = self.col_widget.currentText() + if new_option == '': + return + try: + self.updateUnits(self.unit_widget, new_option) + except KeyError: + # Means the units for this column aren't known. This shouldn't be + # the case in the real version so for now we'll just clear the unit + # widget. + self.unit_widget.clear() + + @Slot() + def onUnitChange(self): + new_text = self.unit_widget.currentText() + if new_text == 'Select More': + selector = UnitSelector(unit_kinds[self.col_widget.currentText()].name, False) + selector.exec() + # We need the selection unit in the list of options, or else QT has some dodgy behaviour. + self.unit_widget.insertItem(-1, selector.selected_unit.symbol) + self.unit_widget.setCurrentText(selector.selected_unit.symbol) + # This event could get triggered when the units have just been cleared, and not actually updated. We don't want + # to trigger it in this case. + elif not new_text == '': + self.column_changed.emit() + + @property + def currentColumn(self): + """The currently selected column.""" + return self.col_widget.currentText() + + @property + def currentUnit(self) -> NamedUnit: + """The currently selected unit.""" + current_unit_symbol = self.unit_widget.currentText() + for unit in unit_kinds[self.current_option].units: + if current_unit_symbol == unit.symbol: + return unit + # This error shouldn't really happen so if it does, it indicates there is a bug in the code. + raise ValueError("Current unit doesn't seem to exist") + + @currentUnit.setter + def currentUnit(self, new_value: NamedUnit): + self.unit_widget.setCurrentText(new_value.symbol) diff --git a/src/ascii_dialog/constants.py b/src/ascii_dialog/constants.py new file mode 100644 index 0000000000..bf789cf38c --- /dev/null +++ b/src/ascii_dialog/constants.py @@ -0,0 +1,5 @@ +#!/usr/bin/env python + + +TABLE_MAX_ROWS = 1000 +NOFILE_TEXT = "Click the button below to load a file." diff --git a/src/ascii_dialog/dialog.py b/src/ascii_dialog/dialog.py new file mode 100644 index 0000000000..6d22bacd35 --- /dev/null +++ b/src/ascii_dialog/dialog.py @@ -0,0 +1,503 @@ +from os import path + +from PySide6.QtCore import QModelIndex, QPoint, Slot +from PySide6.QtGui import QColor, QCursor, Qt +from PySide6.QtWidgets import ( + QAbstractScrollArea, + QApplication, + QCheckBox, + QComboBox, + QDialog, + QFileDialog, + QHBoxLayout, + QHeaderView, + QLabel, + QMessageBox, + QPushButton, + QSpacerItem, + QSpinBox, + QTableWidget, + QTableWidgetItem, + QVBoxLayout, + QWidget, +) + +from sasdata.ascii_reader_metadata import AsciiReaderMetadata +from sasdata.dataset_types import DatasetType, dataset_types, one_dim, sesans, two_dim +from sasdata.guess import guess_column_count, guess_columns, guess_starting_position +from sasdata.temp_ascii_reader import AsciiReaderParams, load_data, split_line + +from ascii_dialog.col_editor import ColEditor +from ascii_dialog.constants import TABLE_MAX_ROWS +from ascii_dialog.row_status_widget import RowStatusWidget +from ascii_dialog.selection_menu import SelectionMenu +from ascii_dialog.warning_label import WarningLabel +from metadata_filename_gui.metadata_filename_dialog import MetadataFilenameDialog + +dataset_dictionary = dict([(dataset.name, dataset) for dataset in [one_dim, two_dim, sesans]]) + +class AsciiDialog(QDialog): + """A dialog window allowing the user to adjust various properties regarding + how an ASCII file should be interpreted. This widget allows the user to + visualise what the data will look like with the parameter the user has + selected. + + """ + def __init__(self): + super().__init__() + + self.files: dict[str, list[str]] = {} + self.files_full_path: dict[str, str] = {} + self.files_is_included: dict[str, list[bool]] = {} + # This is useful for whenever the user wants to reopen the metadata editor. + self.internal_metadata: AsciiReaderMetadata = AsciiReaderMetadata() + self.current_filename: str | None = None + + self.seperators: dict[str, bool] = { + 'Comma': True, + 'Whitespace': True, + 'Tab': True + } + + self.setWindowTitle('ASCII File Reader') + + # Filename, unload button, and edit metadata button. + + self.filename_unload_layout = QHBoxLayout() + self.unloadButton = QPushButton("Unload") + self.unloadButton.setDisabled(True) + self.unloadButton.clicked.connect(self.unload) + # Filename chooser + self.filename_chooser = QComboBox() + self.filename_chooser.currentTextChanged.connect(self.updateCurrentFile) + + self.filename_unload_layout.addWidget(self.filename_chooser) + self.filename_unload_layout.addWidget(self.unloadButton) + + + self.select_button = QPushButton("Select File") + self.select_button.clicked.connect(self.loadFile) + + ## Dataset type selection + self.dataset_layout = QHBoxLayout() + self.dataset_label = QLabel("Dataset Type") + self.dataset_combobox = QComboBox() + for name in dataset_types: + self.dataset_combobox.addItem(name) + self.dataset_layout.addWidget(self.dataset_label) + self.dataset_layout.addWidget(self.dataset_combobox) + + ## Seperator + self.sep_layout = QHBoxLayout() + + self.sep_widgets: list[QWidget] = [] + self.sep_label = QLabel('Seperators:') + self.sep_layout.addWidget(self.sep_label) + for seperator_name, value in self.seperators.items(): + check_box = QCheckBox(seperator_name) + check_box.setChecked(value) + check_box.clicked.connect(self.seperatorToggle) + self.sep_widgets.append(check_box) + self.sep_layout.addWidget(check_box) + + ## Starting Line + self.startline_layout = QHBoxLayout() + self.startline_label = QLabel('Starting Line') + self.startline_entry = QSpinBox() + self.startline_entry.setMinimum(1) + self.startline_entry.valueChanged.connect(self.updateStartpos) + self.startline_layout.addWidget(self.startline_label) + self.startline_layout.addWidget(self.startline_entry) + + ## Column Count + self.colcount_layout = QHBoxLayout() + self.colcount_label = QLabel('Number of Columns') + self.colcount_entry = QSpinBox() + self.colcount_entry.setMinimum(1) + self.colcount_entry.valueChanged.connect(self.updateColcount) + self.colcount_layout.addWidget(self.colcount_label) + self.colcount_layout.addWidget(self.colcount_entry) + + ## Column Editor + options = self.datasetOptions() + self.col_editor: ColEditor = ColEditor(self.colcount_entry.value(), options) + self.dataset_combobox.currentTextChanged.connect(self.changeDatasetType) + self.col_editor.column_changed.connect(self.updateColumn) + + ## Data Table + + self.table = QTableWidget() + self.table.show() + # Make the table readonly + self.table.setEditTriggers(QTableWidget.EditTrigger.NoEditTriggers) + # The table's width will always resize to fit the amount of space it has. + self.table.setSizeAdjustPolicy(QAbstractScrollArea.SizeAdjustPolicy.AdjustToContents) + # Add the context menu + self.table.setContextMenuPolicy(Qt.ContextMenuPolicy.CustomContextMenu) + self.table.customContextMenuRequested.connect(self.showContextMenu) + + # Warning Label + self.warning_label: WarningLabel = WarningLabel(self.requiredMissing(), self.duplicateColumns()) + + # Done button + # TODO: Not entirely sure what to call/label this. Just going with 'done' for now. + + self.done_line = QHBoxLayout() + self.cancel_button = QPushButton('Cancel') + self.cancel_button.clicked.connect(self.onCancel) + self.done_line_spacer = QSpacerItem(70, 0) + self.editMetadataButton = QPushButton("Edit Metadata") + self.editMetadataButton.setDisabled(True) + self.editMetadataButton.clicked.connect(self.editMetadata) + self.done_button = QPushButton('Done') + self.done_button.clicked.connect(self.onDoneButton) + self.done_line.addWidget(self.cancel_button) + self.done_line.addItem(self.done_line_spacer) + self.done_line.addWidget(self.editMetadataButton) + self.done_line.addWidget(self.done_button) + + self.layout = QVBoxLayout(self) + + self.layout.addLayout(self.filename_unload_layout) + self.layout.addWidget(self.select_button) + self.layout.addLayout(self.dataset_layout) + self.layout.addLayout(self.sep_layout) + self.layout.addLayout(self.startline_layout) + self.layout.addLayout(self.colcount_layout) + self.layout.addWidget(self.col_editor) + self.layout.addWidget(self.table) + self.layout.addWidget(self.warning_label) + self.layout.addLayout(self.done_line) + + @property + def startingPos(self) -> int: + return self.startline_entry.value() - 1 + + @startingPos.setter + def startingPos(self, value: int): + self.startline_entry.setValue(value + 1) + + @property + def rawCsv(self) -> list[str] | None: + if self.current_filename is None: + return None + return self.files[self.current_filename] + + @property + def rowsIsIncluded(self) -> list[bool] | None: + if self.current_filename is None: + return None + return self.files_is_included[self.current_filename] + + @property + def excludedLines(self) -> set[int]: + return set([i for i, included in enumerate(self.rowsIsIncluded) if not included]) + + def splitLine(self, line: str) -> list[str]: + """Split a line in a CSV file based on which seperators the user has + selected on the widget. + + """ + return split_line(self.seperators, line) + + def attemptGuesses(self) -> None: + """Attempt to guess various parameters of the data to provide some + default values. Uses the guess.py module + + """ + split_csv = [self.splitLine(line.strip()) for line in self.rawCsv] + + # TODO: I'm not sure if there is any point in holding this initial value. Can possibly be refactored. + self.initial_starting_pos = guess_starting_position(split_csv) + + guessed_colcount = guess_column_count(split_csv, self.initial_starting_pos) + self.col_editor.setCols(guessed_colcount) + + columns = guess_columns(guessed_colcount, self.currentDatasetType()) + self.col_editor.setColOrder(columns) + self.colcount_entry.setValue(guessed_colcount) + self.startingPos = self.initial_starting_pos + + def fillTable(self) -> None: + """Write the data to the table based on the parameters the user has + selected. + + """ + + # Don't try to fill the table if there's no data. + if self.rawCsv is None: + return + + self.table.clear() + + col_count = self.colcount_entry.value() + + self.table.setRowCount(min(len(self.rawCsv), TABLE_MAX_ROWS + 1)) + self.table.setColumnCount(col_count + 1) + self.table.setHorizontalHeaderLabels(["Included"] + self.col_editor.colNames()) + self.table.horizontalHeader().setSectionResizeMode(QHeaderView.ResizeMode.Stretch) + + # Now fill the table with data + for i, row in enumerate(self.rawCsv): + if i == TABLE_MAX_ROWS: + # Fill with elipsis to indicate there is more data. + for j in range(len(row_split)): + elipsis_item = QTableWidgetItem("...") + elipsis_item.setTextAlignment(Qt.AlignmentFlag.AlignCenter) + self.table.setItem(i, j, elipsis_item) + break + + if i < len(self.rowsIsIncluded): + initial_state = self.rowsIsIncluded[i] + else: + initial_state = True + self.rowsIsIncluded.append(initial_state) + if i >= self.startingPos: + row_status = RowStatusWidget(initial_state, i) + row_status.status_changed.connect(self.updateRowStatus) + self.table.setCellWidget(i, 0, row_status) + row_split = self.splitLine(row) + for j, col_value in enumerate(row_split): + if j >= col_count: + continue # Ignore rows that have extra columns. + item = QTableWidgetItem(col_value) + self.table.setItem(i, j + 1, item) + self.setRowTypesetting(i, self.rowsIsIncluded[i]) + + self.table.show() + + def currentDatasetType(self) -> DatasetType: + """Get the dataset type that the user has currently selected.""" + return dataset_dictionary[self.dataset_combobox.currentText()] + + def setRowTypesetting(self, row: int, item_checked: bool) -> None: + """Set the typesetting for the given role depending on whether it is to + be included in the data being loaded, or not. + + """ + for column in range(1, self.table.columnCount() + 1): + item = self.table.item(row, column) + if item is None: + continue + item_font = item.font() + if not item_checked or row < self.startingPos: + item.setForeground(QColor.fromString('grey')) + item_font.setStrikeOut(True) + else: + item.setForeground(QApplication.palette().text()) + item_font.setStrikeOut(False) + item.setFont(item_font) + + def updateWarningLabel(self): + required_missing = self.requiredMissing() + duplicates = self.duplicateColumns() + if self.rawCsv is None: + # We don't have any actual data yet so we're just updating the warning based on the column. + self.warning_label.updateWarning(required_missing, duplicates) + else: + self.warning_label.updateWarning(required_missing, duplicates, [self.splitLine(line) for line in self.rawCsv], self.rowsIsIncluded, self.startingPos) + + @Slot() + def loadFile(self) -> None: + """Open the file loading dialog, and load the file the user selects.""" + filenames, result = QFileDialog.getOpenFileNames(self) + # Happens when the user cancels without selecting a file. There isn't a + # file to load in this case. + if result == '': + return + for filename in filenames: + + basename = path.basename(filename) + + try: + with open(filename) as file: + file_csv = file.readlines() + file_csv = [line.strip() for line in file_csv] + # TODO: This assumes that no two files will be loaded with the same + # name. This might not be a reasonable assumption. + self.files[basename] = file_csv + self.files_full_path[basename] = filename + # Reset checkboxes + self.files_is_included[basename] = [] + if len(self.files) == 1: + # Default behaviour is going to be to set this to the first file we load. This seems sensible but + # may provoke further discussion. + self.current_filename = basename + # This will trigger the update current file event which will cause + # the table to be drawn. + self.internal_metadata.init_separator(basename) + self.filename_chooser.addItem(basename) + self.filename_chooser.setCurrentText(basename) + self.internal_metadata.add_file(basename) + + except OSError: + QMessageBox.critical(self, 'File Read Error', f'There was an error reading {basename}') + except UnicodeDecodeError: + QMessageBox.critical(self, 'File Read Error', f"""There was an error decoding {basename}. +This could potentially be because the file {basename} an ASCII format.""") + # Attempt guesses on the first file that was loaded. + self.attemptGuesses() + + @Slot() + def unload(self) -> None: + del self.files[self.current_filename] + self.filename_chooser.removeItem(self.filename_chooser.currentIndex()) + # Filename chooser should now revert back to a different file. + self.updateCurrentFile() + + @Slot() + def updateColcount(self) -> None: + """Triggered when the amount of columns the user has selected has + changed. + + """ + self.col_editor.setCols(self.colcount_entry.value()) + self.fillTable() + self.updateWarningLabel() + + @Slot() + def updateStartpos(self) -> None: + """Triggered when the starting position of the data has changed.""" + self.fillTable() + self.updateWarningLabel() + + @Slot() + def updateSeperator(self) -> None: + """Changed when the user modifies the set of seperators being used.""" + self.fillTable() + self.updateWarningLabel() + + @Slot() + def updateColumn(self) -> None: + """Triggered when any of the columns has been changed.""" + self.fillTable() + self.updateWarningLabel() + + @Slot() + def updateCurrentFile(self) -> None: + """Triggered when the current file (choosen from the file chooser + ComboBox) changes. + + """ + self.current_filename = self.filename_chooser.currentText() + if self.current_filename == '': + self.table.clear() + self.table.setDisabled(True) + self.unloadButton.setDisabled(True) + self.editMetadataButton.setDisabled(True) + # Set this to None because other methods are expecting this. + self.current_filename = None + else: + self.table.setDisabled(False) + self.unloadButton.setDisabled(False) + self.editMetadataButton.setDisabled(False) + self.fillTable() + self.updateWarningLabel() + + @Slot() + def seperatorToggle(self) -> None: + """Triggered when one of the seperator check boxes has been toggled.""" + check_box = self.sender() + self.seperators[check_box.text()] = check_box.isChecked() + self.fillTable() + self.updateWarningLabel() + + @Slot() + def changeDatasetType(self) -> None: + """Triggered when the selected dataset type has changed.""" + options = self.datasetOptions() + self.col_editor.replaceOptions(options) + + # Update columns as they'll be different now. + columns = guess_columns(self.colcount_entry.value(), self.currentDatasetType()) + self.col_editor.setColOrder(columns) + + @Slot() + def updateRowStatus(self, row: int) -> None: + """Triggered when the status of row has changed.""" + new_status = self.table.cellWidget(row, 0).isChecked() + self.rowsIsIncluded[row] = new_status + self.setRowTypesetting(row, new_status) + + @Slot() + def showContextMenu(self, point: QPoint) -> None: + """Show the context menu for the table.""" + context_menu = SelectionMenu(self) + context_menu.select_all_event.connect(self.selectItems) + context_menu.deselect_all_event.connect(self.deselectItems) + context_menu.exec(QCursor.pos()) + + def changeInclusion(self, indexes: list[QModelIndex], new_value: bool): + for index in indexes: + # This will happen if the user has selected a point which exists before the starting line. To prevent an + # error, this code will skip that position. + row = index.row() + if row < self.startingPos: + continue + self.table.cellWidget(row, 0).setChecked(new_value) + self.updateRowStatus(row) + + @Slot() + def selectItems(self) -> None: + """Include all of the items that have been selected in the table.""" + self.changeInclusion(self.table.selectedIndexes(), True) + self.updateWarningLabel() + + @Slot() + def deselectItems(self) -> None: + """Don't include all of the items that have been selected in the table.""" + self.changeInclusion(self.table.selectedIndexes(), False) + self.updateWarningLabel() + + def requiredMissing(self) -> list[str]: + """Returns all the columns that are required by the dataset type but + have not currently been selected. + + """ + dataset = self.currentDatasetType() + missing_columns = [col for col in dataset.required if col not in self.col_editor.colNames()] + return missing_columns + + def duplicateColumns(self) -> set[str]: + """Returns all of the columns which have been selected multiple times.""" + col_names = self.col_editor.colNames() + return set([col for col in col_names if not col == '' and col_names.count(col) > 1]) + + def datasetOptions(self) -> list[str]: + current_dataset_type = self.currentDatasetType() + return current_dataset_type.required + current_dataset_type.optional + [''] + + def onDoneButton(self): + params = AsciiReaderParams( + list(self.files_full_path.values()), + self.col_editor.columns, + self.internal_metadata, + self.startingPos, + self.excludedLines, + self.seperators, + ) + self.params = params + self.accept() + + def onCancel(self): + self.reject() + + def editMetadata(self): + dialog = MetadataFilenameDialog(self.current_filename, self.internal_metadata) + status = dialog.exec() + if status == 1: + self.internal_metadata = dialog.internal_metadata + + +if __name__ == "__main__": + app = QApplication([]) + + dialog = AsciiDialog() + status = dialog.exec() + # 1 means the dialog was accepted. + if status == 1: + loaded = load_data(dialog.params) + for datum in loaded: + print(datum.summary()) + + exit() diff --git a/src/ascii_dialog/row_status_widget.py b/src/ascii_dialog/row_status_widget.py new file mode 100644 index 0000000000..edddbed4de --- /dev/null +++ b/src/ascii_dialog/row_status_widget.py @@ -0,0 +1,35 @@ +#!/usr/bin/env python3 + +from PySide6.QtCore import Qt, Signal, Slot +from PySide6.QtWidgets import QCheckBox, QHBoxLayout, QWidget + + +class RowStatusWidget(QWidget): + """Widget to toggle whether the row is to be included as part of the data.""" + def __init__(self, initial_value: bool, row: int): + super().__init__() + self.row = row + self.checkbox = QCheckBox() + self.checkbox.setChecked(initial_value) + self.updateLabel() + self.checkbox.stateChanged.connect(self.onStateChange) + self.layout = QHBoxLayout(self) + self.layout.addWidget(self.checkbox, alignment=Qt.AlignmentFlag.AlignCenter) + + status_changed = Signal(int) + def updateLabel(self): + """Update the label of the check box depending on whether it is checked, + or not.""" + pass + + + @Slot() + def onStateChange(self): + self.updateLabel() + self.status_changed.emit(self.row) + + def isChecked(self) -> bool: + return self.checkbox.isChecked() + + def setChecked(self, new_value: bool): + self.checkbox.setChecked(new_value) diff --git a/src/ascii_dialog/selection_menu.py b/src/ascii_dialog/selection_menu.py new file mode 100644 index 0000000000..e275686646 --- /dev/null +++ b/src/ascii_dialog/selection_menu.py @@ -0,0 +1,23 @@ +#!/usr/bin/env python3 + + +from PySide6.QtCore import Signal +from PySide6.QtGui import QAction +from PySide6.QtWidgets import QMenu, QWidget + + +class SelectionMenu(QMenu): + select_all_event = Signal() + deselect_all_event = Signal() + + def __init__(self, parent: QWidget): + super().__init__(parent) + + select_all = QAction("Select All", parent) + select_all.triggered.connect(self.select_all_event) + + deselect_all = QAction("Deselect All", parent) + deselect_all.triggered.connect(self.deselect_all_event) + + self.addAction(select_all) + self.addAction(deselect_all) diff --git a/src/ascii_dialog/unit_list_widget.py b/src/ascii_dialog/unit_list_widget.py new file mode 100644 index 0000000000..20a7c34ab0 --- /dev/null +++ b/src/ascii_dialog/unit_list_widget.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python3 + +from PySide6.QtWidgets import QListWidget, QListWidgetItem + +from sasdata.quantities.units import NamedUnit + + +class UnitListWidget(QListWidget): + def reprUnit(self, unit: NamedUnit) -> str: + return f"{unit.symbol} ({unit.name})" + + def populateList(self, units: list[NamedUnit]) -> None: + self.clear() + self.units = units + for unit in units: + item = QListWidgetItem(self.reprUnit(unit)) + self.addItem(item) + + @property + def selectedUnit(self) -> NamedUnit | None: + return self.units[self.currentRow()] + + def __init__(self): + super().__init__() + self.units: list[NamedUnit] = [] diff --git a/src/ascii_dialog/unit_preference_line.py b/src/ascii_dialog/unit_preference_line.py new file mode 100644 index 0000000000..be6a15da56 --- /dev/null +++ b/src/ascii_dialog/unit_preference_line.py @@ -0,0 +1,31 @@ +#!/usr/bin/env python3 + +from PySide6.QtCore import Slot +from PySide6.QtWidgets import QHBoxLayout, QLabel, QPushButton, QWidget + +from sasdata.quantities.units import NamedUnit, UnitGroup + +from ascii_dialog.unit_selector import UnitSelector + + +class UnitPreferenceLine(QWidget): + def __init__(self, column_name: str, initial_unit: NamedUnit, group: UnitGroup): + super().__init__() + + self.group = group + self.current_unit = initial_unit + + self.column_label = QLabel(column_name) + self.unit_button = QPushButton(initial_unit.symbol) + self.unit_button.clicked.connect(self.onUnitPress) + + self.layout = QHBoxLayout(self) + self.layout.addWidget(self.column_label) + self.layout.addWidget(self.unit_button) + + @Slot() + def onUnitPress(self): + picker = UnitSelector(self.group.name, False) + picker.exec() + self.current_unit = picker.selected_unit + self.unit_button.setText(self.current_unit.symbol) diff --git a/src/ascii_dialog/unit_preferences.py b/src/ascii_dialog/unit_preferences.py new file mode 100644 index 0000000000..9152675f41 --- /dev/null +++ b/src/ascii_dialog/unit_preferences.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 + +import random + +from PySide6.QtGui import Qt +from PySide6.QtWidgets import QApplication, QScrollArea, QVBoxLayout, QWidget + +from sasdata.dataset_types import unit_kinds +from sasdata.quantities.units import NamedUnit + +from ascii_dialog.unit_preference_line import UnitPreferenceLine + + +class UnitPreferences(QWidget): + def __init__(self): + super().__init__() + + # TODO: Presumably this will be loaded from some config from somewhere. + # For now just fill it with some placeholder values. + column_names = unit_kinds.keys() + self.columns: dict[str, NamedUnit] = {} + for name in column_names: + self.columns[name] = random.choice(unit_kinds[name].units) + + self.layout = QVBoxLayout(self) + preference_lines = QWidget() + scroll_area = QScrollArea() + scroll_area.setHorizontalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAlwaysOff) + scroll_layout = QVBoxLayout(preference_lines) + for column_name, unit in self.columns.items(): + line = UnitPreferenceLine(column_name, unit, unit_kinds[column_name]) + scroll_layout.addWidget(line) + + scroll_area.setWidget(preference_lines) + self.layout.addWidget(scroll_area) + + +if __name__ == "__main__": + app = QApplication([]) + + widget = UnitPreferences() + widget.show() + + exit(app.exec()) diff --git a/src/ascii_dialog/unit_selector.py b/src/ascii_dialog/unit_selector.py new file mode 100644 index 0000000000..f039f4b6df --- /dev/null +++ b/src/ascii_dialog/unit_selector.py @@ -0,0 +1,80 @@ +from PySide6.QtCore import Slot +from PySide6.QtWidgets import QApplication, QComboBox, QDialog, QLineEdit, QPushButton, QVBoxLayout + +from sasdata.quantities.units import NamedUnit, UnitGroup, unit_group_names, unit_groups + +from ascii_dialog.unit_list_widget import UnitListWidget + +all_unit_groups = list(unit_groups.values()) + +class UnitSelector(QDialog): + def currentUnitGroup(self) -> UnitGroup: + index = self.unit_type_selector.currentIndex() + return all_unit_groups[index] + + @property + def selected_unit(self) -> NamedUnit | None: + return self.unit_list_widget.selectedUnit + + @Slot() + def onSearchChanged(self): + search_input = self.search_box.text() + current_group = self.currentUnitGroup() + units = current_group.units + if search_input != '': + units = [unit for unit in units if search_input.lower() in unit.name] + self.unit_list_widget.populateList(units) + + + @Slot() + def unitGroupChanged(self): + new_group = self.currentUnitGroup() + self.search_box.setText('') + self.unit_list_widget.populateList(new_group.units) + + @Slot() + def selectUnit(self): + self.accept() + + @Slot() + def selectionChanged(self): + self.select_button.setDisabled(False) + + def __init__(self, default_group='length', allow_group_edit=True): + super().__init__() + + self.unit_type_selector = QComboBox() + self.unit_type_selector.addItems(unit_group_names) + self.unit_type_selector.setCurrentText(default_group) + if not allow_group_edit: + self.unit_type_selector.setDisabled(True) + self.unit_type_selector.currentTextChanged.connect(self.unitGroupChanged) + + self.search_box = QLineEdit() + self.search_box.textChanged.connect(self.onSearchChanged) + self.search_box.setPlaceholderText('Search for a unit...') + + self.unit_list_widget = UnitListWidget() + # TODO: Are they all named units? + self.unit_list_widget.populateList(self.currentUnitGroup().units) + self.unit_list_widget.itemSelectionChanged.connect(self.selectionChanged) + self.unit_list_widget.itemDoubleClicked.connect(self.selectUnit) + + self.select_button = QPushButton('Select Unit') + self.select_button.pressed.connect(self.selectUnit) + self.select_button.setDisabled(True) + + self.layout = QVBoxLayout(self) + self.layout.addWidget(self.unit_type_selector) + self.layout.addWidget(self.search_box) + self.layout.addWidget(self.unit_list_widget) + self.layout.addWidget(self.select_button) + +if __name__ == "__main__": + app = QApplication([]) + + widget = UnitSelector() + widget.exec() + print(widget.selected_unit) + + exit() diff --git a/src/ascii_dialog/warning_label.py b/src/ascii_dialog/warning_label.py new file mode 100644 index 0000000000..ccf2607797 --- /dev/null +++ b/src/ascii_dialog/warning_label.py @@ -0,0 +1,52 @@ +from PySide6.QtWidgets import QLabel + +from ascii_dialog.constants import TABLE_MAX_ROWS + + +class WarningLabel(QLabel): + """Widget to display an appropriate warning message based on whether there + exists columns that are missing, or there are columns that are duplicated. + + """ + def setFontRed(self): + self.setStyleSheet("QLabel { color: red}") + + def setFontOrange(self): + self.setStyleSheet("QLabel { color: orange}") + + def setFontNormal(self): + self.setStyleSheet('') + + def updateWarning(self, missing_columns: list[str], duplicate_columns: list[str], lines: list[list[str]] | None = None, rows_is_included: list[bool] | None = None, starting_pos: int = 0): + """Determine, and set the appropriate warning messages given how many + columns are missing, and how many columns are duplicated.""" + unparsable = 0 + if lines is not None and rows_is_included is not None: + for i, line in enumerate(lines): + # Right now, rows_is_included only includes a limited number of rows as there is a maximum that can be + # shown in the table without it being really laggy. We're just going to assume the lines after it should + # be included. + if (i >= TABLE_MAX_ROWS or rows_is_included[i]) and i >= starting_pos: + # TODO: Is there really no builtin function for this? I don't like using try/except like this. + try: + for item in line: + _ = float(item) + except: + unparsable += 1 + + if len(missing_columns) != 0: + self.setText(f'The following columns are missing: {missing_columns}') + self.setFontRed() + elif len(duplicate_columns) > 0: + self.setText('There are columns which are repeated.') + self.setFontRed() + elif unparsable > 0: + # FIXME: This error message could perhaps be a bit clearer. + self.setText(f'{unparsable} lines failed to be read. They will be ignored.') + self.setFontOrange() + else: + self.setText('') + + def __init__(self, initial_missing_columns, initial_duplicate_classes): + super().__init__() + self.updateWarning(initial_missing_columns, initial_duplicate_classes) diff --git a/src/metadata_filename_gui/metadata_component_selector.py b/src/metadata_filename_gui/metadata_component_selector.py new file mode 100644 index 0000000000..1800bc71c7 --- /dev/null +++ b/src/metadata_filename_gui/metadata_component_selector.py @@ -0,0 +1,58 @@ +from PySide6.QtCore import Qt, Signal +from PySide6.QtWidgets import QHBoxLayout, QPushButton, QWidget + +from sasdata.ascii_reader_metadata import AsciiReaderMetadata + + +class MetadataComponentSelector(QWidget): + # Creating a separate signal for this because the custom button may be destroyed/recreated whenever the options are + # redrawn. + + custom_button_pressed = Signal(Qt.MouseButton()) + + def __init__(self, category: str, metadatum: str, filename: str, internal_metadata: AsciiReaderMetadata): + super().__init__() + self.options: list[str] + self.option_buttons: list[QPushButton] + self.layout = QHBoxLayout(self) + self.internal_metadata = internal_metadata + self.metadatum = metadatum + self.category = category + self.filename = filename + + def clear_options(self): + for i in reversed(range(self.layout.count() - 1)): + self.layout.takeAt(i).widget().deleteLater() + + def draw_options(self, new_options: list[str], selected_option: str | None): + self.clear_options() + self.options = new_options + self.option_buttons = [] + for option in self.options: + option_button = QPushButton(option) + option_button.setCheckable(True) + option_button.clicked.connect(self.selection_changed) + option_button.setChecked(option == selected_option) + self.layout.addWidget(option_button) + self.option_buttons.append(option_button) + # This final button is to convert to use custom entry instead of this. + self.custom_entry_button = QPushButton('Custom') + # self.custom_entry_button.clicked.connect(self.custom_button_pressed) + self.custom_entry_button.clicked.connect(self.handle_custom_button) + self.layout.addWidget(self.custom_entry_button) + + def handle_custom_button(self): + self.custom_button_pressed.emit() + + def selection_changed(self): + selected_button: QPushButton = self.sender() + button_index = -1 + for i, button in enumerate(self.option_buttons): + if button != selected_button: + button.setChecked(False) + else: + button_index = i + if selected_button.isChecked(): + self.internal_metadata.update_metadata(self.category, self.metadatum, self.filename, button_index) + else: + self.internal_metadata.clear_metadata(self.category, self.metadatum, self.filename) diff --git a/src/metadata_filename_gui/metadata_custom_selector.py b/src/metadata_filename_gui/metadata_custom_selector.py new file mode 100644 index 0000000000..7a42b1d817 --- /dev/null +++ b/src/metadata_filename_gui/metadata_custom_selector.py @@ -0,0 +1,29 @@ +from PySide6.QtWidgets import QHBoxLayout, QLineEdit, QPushButton, QWidget + +from sasdata.ascii_reader_metadata import AsciiReaderMetadata + + +class MetadataCustomSelector(QWidget): + def __init__(self, category:str, metadatum: str, internal_metadata: AsciiReaderMetadata, filename: str): + super().__init__() + self.internal_metadata = internal_metadata + self.metadatum = metadatum + self.category = category + self.filename = filename + + prexisting_value = self.internal_metadata.get_metadata(category, metadatum, filename) + initial_value = prexisting_value if prexisting_value is not None else '' + self.entry_box = QLineEdit(initial_value) + self.entry_box.textChanged.connect(self.selection_changed) + self.from_filename_button = QPushButton('From Filename') + + self.layout = QHBoxLayout(self) + self.layout.addWidget(self.entry_box) + self.layout.addWidget(self.from_filename_button) + + def selection_changed(self): + new_value = self.entry_box.text() + if new_value != '': + self.internal_metadata.update_metadata(self.category, self.metadatum, self.filename, new_value) + else: + self.internal_metadata.clear_metadata(self.category, self.metadatum, self.filename) diff --git a/src/metadata_filename_gui/metadata_filename_dialog.py b/src/metadata_filename_gui/metadata_filename_dialog.py new file mode 100644 index 0000000000..48f4ae511f --- /dev/null +++ b/src/metadata_filename_gui/metadata_filename_dialog.py @@ -0,0 +1,125 @@ +from sys import argv + +from PySide6.QtWidgets import ( + QApplication, + QButtonGroup, + QDialog, + QHBoxLayout, + QLabel, + QLineEdit, + QPushButton, + QRadioButton, + QVBoxLayout, +) + +from sasdata.ascii_reader_metadata import AsciiReaderMetadata + +from metadata_filename_gui.metadata_tree_widget import MetadataTreeWidget + + +def build_font(text: str, classname: str = '') -> str: + match classname: + case 'token': + return f"{text}" + case 'separator': + return f"{text}" + case _: + return text + return f'{text}' + +class MetadataFilenameDialog(QDialog): + def __init__(self, filename: str, initial_metadata: AsciiReaderMetadata): + super().__init__() + + # TODO: Will probably change this default later (or a more sophisticated way of getting this default from the + # filename.) + initial_separator_text = initial_metadata.filename_separator[filename] + + self.setWindowTitle('Metadata') + + self.filename = filename + # Key is the metadatum, value is the component selected for it. + self.internal_metadata = initial_metadata + + self.filename_line_label = QLabel() + self.separate_on_group = QButtonGroup() + self.character_radio = QRadioButton("Character") + self.separate_on_group.addButton(self.character_radio) + self.casing_radio = QRadioButton("Casing") + self.separate_on_group.addButton(self.casing_radio) + if isinstance(initial_separator_text, str): + self.character_radio.setChecked(True) + else: # if bool + self.casing_radio.setChecked(True) + self.separate_on_layout = QHBoxLayout() + self.separate_on_group.buttonToggled.connect(self.update_filename_separation) + self.separate_on_layout.addWidget(self.filename_line_label) + self.separate_on_layout.addWidget(self.character_radio) + self.separate_on_layout.addWidget(self.casing_radio) + + if not any([char.isupper() for char in self.filename]): + self.casing_radio.setDisabled(True) + + self.seperator_chars_label = QLabel('Seperators') + if isinstance(initial_separator_text, str): + self.separator_chars = QLineEdit(initial_separator_text) + else: + self.separator_chars = QLineEdit() + self.separator_chars.textChanged.connect(self.update_filename_separation) + + self.filename_separator_layout = QHBoxLayout() + self.filename_separator_layout.addWidget(self.seperator_chars_label) + self.filename_separator_layout.addWidget(self.separator_chars) + + self.metadata_tree = MetadataTreeWidget(self.internal_metadata) + + # Have to update this now because it relies on the value of the separator, and tree. + self.update_filename_separation() + + self.save_button = QPushButton('Save') + self.save_button.clicked.connect(self.on_save) + + self.layout = QVBoxLayout(self) + self.layout.addLayout(self.separate_on_layout) + self.layout.addLayout(self.filename_separator_layout) + self.layout.addWidget(self.metadata_tree) + self.layout.addWidget(self.save_button) + + def formatted_filename(self) -> str: + sep_str = self.separator_chars.text() + if sep_str == '' or self.casing_radio.isChecked(): + return f'{self.filename}' + # TODO: Won't escape characters; I'll handle that later. + separated = self.internal_metadata.filename_components(self.filename, False, True) + font_elements = '' + for i, token in enumerate(separated): + classname = 'token' if i % 2 == 0 else 'separator' + font_elements += build_font(token, classname) + return font_elements + + def update_filename_separation(self): + if self.casing_radio.isChecked(): + self.separator_chars.setDisabled(True) + else: + self.separator_chars.setDisabled(False) + self.internal_metadata.filename_separator[self.filename] = self.separator_chars.text() if self.character_radio.isChecked() else True + self.internal_metadata.purge_unreachable(self.filename) + self.filename_line_label.setText(f'Filename: {self.formatted_filename()}') + self.metadata_tree.draw_tree(self.filename) + + def on_save(self): + self.accept() + # Don't really need to do anything else. Anyone using this dialog can access the component_metadata dict. + + + +if __name__ == "__main__": + app = QApplication([]) + if len(argv) < 2: + filename = input('Input filename to test: ') + else: + filename = argv[1] + dialog = MetadataFilenameDialog(filename) + status = dialog.exec() + if status == 1: + print(dialog.component_metadata) diff --git a/src/metadata_filename_gui/metadata_selector.py b/src/metadata_filename_gui/metadata_selector.py new file mode 100644 index 0000000000..b7b583d6ca --- /dev/null +++ b/src/metadata_filename_gui/metadata_selector.py @@ -0,0 +1,50 @@ +from PySide6.QtWidgets import QHBoxLayout, QWidget + +from sasdata.ascii_reader_metadata import AsciiReaderMetadata + +from metadata_filename_gui.metadata_component_selector import MetadataComponentSelector +from metadata_filename_gui.metadata_custom_selector import MetadataCustomSelector + + +class MetadataSelector(QWidget): + def __init__(self, category: str, metadatum: str, metadata: AsciiReaderMetadata, filename: str): + super().__init__() + self.category = category + self.metadatum = metadatum + self.metadata: AsciiReaderMetadata = metadata + self.filename = filename + self.options = self.metadata.filename_components(filename) + current_option = self.metadata.get_metadata(self.category, metadatum, filename) + if current_option is None or current_option in self.options: + self.selector_widget = self.new_component_selector() + else: + self.selector_widget = self.new_custom_selector() + + # I can't seem to find any layout that just has one widget in so this will do for now. + self.layout = QHBoxLayout(self) + self.layout.addWidget(self.selector_widget) + + def new_component_selector(self) -> MetadataComponentSelector: + new_selector = MetadataComponentSelector(self.category, self.metadatum, self.filename, self.metadata) + new_selector.custom_button_pressed.connect(self.handle_selector_change) + new_selector.draw_options(self.options, self.metadata.get_metadata(self.category, self.metadatum, self.filename)) + return new_selector + + def new_custom_selector(self) -> MetadataCustomSelector: + new_selector = MetadataCustomSelector(self.category, self.metadatum, self.metadata, self.filename) + new_selector.from_filename_button.clicked.connect(self.handle_selector_change) + return new_selector + + def handle_selector_change(self): + # Need to keep this for when we delete it. + if isinstance(self.selector_widget, MetadataComponentSelector): + # TODO: Will eventually have args + new_widget = self.new_custom_selector() + elif isinstance(self.selector_widget, MetadataCustomSelector): + new_widget = self.new_component_selector() + else: + # Shouldn't happen as selector widget should be either of the above. + return + self.layout.replaceWidget(self.selector_widget, new_widget) + self.selector_widget.deleteLater() + self.selector_widget = new_widget diff --git a/src/metadata_filename_gui/metadata_tree_data.py b/src/metadata_filename_gui/metadata_tree_data.py new file mode 100644 index 0000000000..20e7d66755 --- /dev/null +++ b/src/metadata_filename_gui/metadata_tree_data.py @@ -0,0 +1,17 @@ +#!/usr/bin/env python + +# TODO: This file can probably be deleted. Just want to make sure nothing else +# depends on it. + +metadata = { + 'source': ['name', 'radiation', 'type', 'probe_particle', 'beam_size_name', 'beam_size', 'beam_shape', 'wavelength', 'wavelength_min', 'wavelength_max', 'wavelength_spread'], + 'detector': ['name', 'distance', 'offset', 'orientation', 'beam_center', 'pixel_size', 'slit_length'], + 'aperture': ['name', 'type', 'size_name', 'size', 'distance'], + 'collimation': ['name', 'lengths'], + 'process': ['name', 'date', 'description', 'term', 'notes'], + 'sample': ['name', 'sample_id', 'thickness', 'transmission', 'temperature', 'position', 'orientation', 'details'], + 'transmission_spectrum': ['name', 'timestamp', 'transmission', 'transmission_deviation'], + 'other': ['title', 'run', 'definition'] +} + +initial_metadata_dict = {key: {} for key, _ in metadata.items()} diff --git a/src/metadata_filename_gui/metadata_tree_widget.py b/src/metadata_filename_gui/metadata_tree_widget.py new file mode 100644 index 0000000000..f9bbf57398 --- /dev/null +++ b/src/metadata_filename_gui/metadata_tree_widget.py @@ -0,0 +1,27 @@ +from PySide6.QtWidgets import QTreeWidget, QTreeWidgetItem + +from sasdata.ascii_reader_metadata import AsciiReaderMetadata, initial_metadata + +from metadata_filename_gui.metadata_selector import MetadataSelector + + +class MetadataTreeWidget(QTreeWidget): + def __init__(self, metadata: AsciiReaderMetadata): + super().__init__() + self.setColumnCount(2) + self.setHeaderLabels(['Name', 'Filename Components']) + self.metadata: AsciiReaderMetadata = metadata + + def draw_tree(self, full_filename: str): + self.clear() + for top_level, items in initial_metadata.items(): + top_level_item = QTreeWidgetItem([top_level]) + for metadatum in items: + # selector = MetadataComponentSelector(metadatum, self.metadata_dict) + selector = MetadataSelector(top_level, metadatum, self.metadata, full_filename) + metadatum_item = QTreeWidgetItem([metadatum]) + # selector.draw_options(options, metadata_dict.get(metadatum)) + top_level_item.addChild(metadatum_item) + self.setItemWidget(metadatum_item, 1, selector) + self.insertTopLevelItem(0, top_level_item) + self.expandAll() diff --git a/src/sas/qtgui/MainWindow/GuiManager.py b/src/sas/qtgui/MainWindow/GuiManager.py index e3d787116e..d0dfdc3485 100644 --- a/src/sas/qtgui/MainWindow/GuiManager.py +++ b/src/sas/qtgui/MainWindow/GuiManager.py @@ -7,9 +7,11 @@ from packaging.version import Version from PySide6.QtCore import QLocale, Qt from PySide6.QtGui import QStandardItem -from PySide6.QtWidgets import QDockWidget, QLabel, QMessageBox, QProgressBar, QTextBrowser +from PySide6.QtWidgets import QDockWidget, QLabel, QProgressBar, QTextBrowser from twisted.internet import reactor +from sasdata.temp_ascii_reader import load_data + import sas # Perspectives @@ -102,6 +104,8 @@ def __init__(self, parent=None): # Currently displayed perspective self._current_perspective: Perspective | None = None self.loadedPerspectives: dict[str, Perspective] = {} + self._connected_undo_stack = None + self._connected_tabbed_perspective = None # Populate the main window with stuff self.addWidgets() @@ -369,6 +373,8 @@ def perspectiveChanged(self, new_perspective_name: str): Respond to change of the perspective signal """ + self._disconnect_undo_redo_hooks() + if new_perspective_name not in self.loadedPerspectives: keylist = ', '.join(self.loadedPerspectives.keys()) raise KeyError( @@ -459,6 +465,7 @@ def perspectiveChanged(self, new_perspective_name: str): # Set the current perspective to new one and show self._current_perspective = new_perspective self._current_perspective.show() + self._connect_undo_redo_hooks() def updatePerspective(self, data): """ @@ -645,6 +652,7 @@ def addCallbacks(self): self.communicator.plotFromNameSignal.connect(self.showPlotFromName) self.communicator.updateModelFromDataOperationPanelSignal.connect(self.updateModelFromDataOperationPanel) self.communicator.activeGraphsSignal.connect(self.updatePlotItems) + self.communicator.undoRedoUpdateSignal.connect(self._update_undo_redo_actions) def addTriggers(self): @@ -652,8 +660,8 @@ def addTriggers(self): Trigger definitions for all menu/toolbar actions. """ # disable not yet fully implemented actions - self._workspace.actionUndo.setVisible(False) - self._workspace.actionRedo.setVisible(False) + self._workspace.actionUndo.setEnabled(False) + self._workspace.actionRedo.setEnabled(False) self._workspace.actionReset.setVisible(False) self._workspace.actionStartup_Settings.setVisible(False) #self._workspace.actionImage_Viewer.setVisible(False) @@ -747,6 +755,11 @@ def addTriggers(self): self._workspace.actionWelcomeWidget.triggered.connect(self.actionWelcome) self._workspace.actionCheck_for_update.triggered.connect(self.actionCheck_for_update) self._workspace.actionWhat_s_New.triggered.connect(self.actionWhatsNew) + # Dev + self._workspace.menuDev.menuAction().setVisible(config.DEV_MENU) + self._workspace.actionParticle_Editor.triggered.connect(self.particleEditor) + self._workspace.actionAscii_Loader.triggered.connect(self.asciiLoader) + self.communicator.sendDataToGridSignal.connect(self.showBatchOutput) self.communicator.resultPlotUpdateSignal.connect(self.showFitResults) @@ -846,12 +859,87 @@ def actionQuit(self): def actionUndo(self): """ """ - print("actionUndo TRIGGERED") + stack = self._active_undo_stack() + if stack is not None: + stack.undo() def actionRedo(self): """ """ - print("actionRedo TRIGGERED") + stack = self._active_undo_stack() + if stack is not None: + stack.redo() + + def _active_undo_stack(self): + """Return the undo stack for the active perspective, if available.""" + if self._current_perspective is None: + return None + return getattr(self._current_perspective, "undo_stack", None) + + def _disconnect_undo_redo_hooks(self): + """Disconnect temporary undo/redo signal hooks.""" + if self._connected_undo_stack is not None: + try: + self._connected_undo_stack.stackChanged.disconnect( + self._update_undo_redo_actions + ) + except (RuntimeError, TypeError): + pass + self._connected_undo_stack = None + + if self._connected_tabbed_perspective is not None: + try: + self._connected_tabbed_perspective.currentChanged.disconnect( + self._on_perspective_tab_changed + ) + except (RuntimeError, TypeError): + pass + self._connected_tabbed_perspective = None + + def _connect_undo_redo_hooks(self): + """Connect action refresh hooks for active perspective and stack.""" + perspective = self._current_perspective + if perspective is None: + self._update_undo_redo_actions() + return + + if hasattr(perspective, "currentChanged"): + try: + perspective.currentChanged.connect(self._on_perspective_tab_changed) + self._connected_tabbed_perspective = perspective + except (RuntimeError, TypeError): + self._connected_tabbed_perspective = None + + stack = self._active_undo_stack() + if stack is not None and hasattr(stack, "stackChanged"): + try: + stack.stackChanged.connect(self._update_undo_redo_actions) + self._connected_undo_stack = stack + except (RuntimeError, TypeError): + self._connected_undo_stack = None + + self._update_undo_redo_actions() + + def _on_perspective_tab_changed(self, *_): + """Rewire undo hooks when active tab changes (e.g., fitting tabs).""" + self._disconnect_undo_redo_hooks() + self._connect_undo_redo_hooks() + + def _update_undo_redo_actions(self): + """Refresh undo/redo enabled state and action tooltips.""" + stack = self._active_undo_stack() + + if stack is None: + self._workspace.actionUndo.setEnabled(False) + self._workspace.actionRedo.setEnabled(False) + self._workspace.actionUndo.setToolTip("Undo") + self._workspace.actionRedo.setToolTip("Redo") + return + + self._workspace.actionUndo.setEnabled(stack.can_undo()) + self._workspace.actionRedo.setEnabled(stack.can_redo()) + self._workspace.actionUndo.setToolTip(stack.undo_text()) + self._workspace.actionRedo.setToolTip(stack.redo_text()) def actionCopy(self): """ @@ -1404,3 +1492,21 @@ def resetProject(self): # file manager self.filesWidget.reset() + + # ============= DEV ================= + + def particleEditor(self): + from sas.qtgui.Perspectives.ParticleEditor.DesignWindow import show_particle_editor + show_particle_editor() + + + def asciiLoader(self): + from ascii_dialog.dialog import AsciiDialog + dialog = AsciiDialog() + status = dialog.exec() + if status == 1: + loaded = load_data(dialog.params) + for datum in loaded: + logger.info(datum.summary()) + else: + logger.error('ASCII Reader Closed') diff --git a/src/sas/qtgui/MainWindow/UI/MainWindowUI.ui b/src/sas/qtgui/MainWindow/UI/MainWindowUI.ui index 65a8656e34..1137303dbf 100755 --- a/src/sas/qtgui/MainWindow/UI/MainWindowUI.ui +++ b/src/sas/qtgui/MainWindow/UI/MainWindowUI.ui @@ -167,6 +167,15 @@ + + + Dev + + + + + + @@ -175,6 +184,7 @@ + @@ -250,6 +260,12 @@ Undo + + Undo + + + Ctrl+Z + @@ -262,6 +278,9 @@ Redo + + Ctrl+Y + @@ -652,6 +671,21 @@ Close Project + + + Ascii Loader + + + + + Particle Editor + + + + + Dev Tools + + diff --git a/src/sas/qtgui/MainWindow/UnitTesting/GuiManagerUndoRedoTest.py b/src/sas/qtgui/MainWindow/UnitTesting/GuiManagerUndoRedoTest.py new file mode 100644 index 0000000000..207a53e6a2 --- /dev/null +++ b/src/sas/qtgui/MainWindow/UnitTesting/GuiManagerUndoRedoTest.py @@ -0,0 +1,252 @@ +from unittest.mock import MagicMock, patch + +from sas.qtgui.MainWindow.GuiManager import GuiManager + + +def _make_manager(): + """Construct a lightweight GuiManager instance for unit testing internals.""" + manager = object.__new__(GuiManager) + + workspace = MagicMock() + workspace.actionUndo = MagicMock() + workspace.actionRedo = MagicMock() + + manager._workspace = workspace + manager._current_perspective = None + manager._connected_undo_stack = None + manager._connected_tabbed_perspective = None + + return manager + + +class TestGuiManagerUndoRedoPhase3: + + def test_action_undo_dispatches_to_active_stack(self): + manager = _make_manager() + stack = MagicMock() + perspective = MagicMock() + perspective.undo_stack = stack + manager._current_perspective = perspective + + manager.actionUndo() + + stack.undo.assert_called_once_with() + + def test_action_redo_dispatches_to_active_stack(self): + manager = _make_manager() + stack = MagicMock() + perspective = MagicMock() + perspective.undo_stack = stack + manager._current_perspective = perspective + + manager.actionRedo() + + stack.redo.assert_called_once_with() + + def test_action_undo_no_stack_is_noop(self): + manager = _make_manager() + manager._current_perspective = MagicMock(undo_stack=None) + + # Should not raise and should not call undo on a missing stack. + manager.actionUndo() + + def test_update_actions_disabled_when_no_active_stack(self): + manager = _make_manager() + manager._current_perspective = MagicMock(undo_stack=None) + + manager._update_undo_redo_actions() + + manager._workspace.actionUndo.setEnabled.assert_called_with(False) + manager._workspace.actionRedo.setEnabled.assert_called_with(False) + manager._workspace.actionUndo.setToolTip.assert_called_with("Undo") + manager._workspace.actionRedo.setToolTip.assert_called_with("Redo") + + def test_update_actions_uses_stack_state_and_labels(self): + manager = _make_manager() + stack = MagicMock() + stack.can_undo.return_value = True + stack.can_redo.return_value = False + stack.undo_text.return_value = "Undo Change radius" + stack.redo_text.return_value = "Redo Change radius" + manager._current_perspective = MagicMock(undo_stack=stack) + + manager._update_undo_redo_actions() + + manager._workspace.actionUndo.setEnabled.assert_called_with(True) + manager._workspace.actionRedo.setEnabled.assert_called_with(False) + manager._workspace.actionUndo.setToolTip.assert_called_with("Undo Change radius") + manager._workspace.actionRedo.setToolTip.assert_called_with("Redo Change radius") + + def test_connect_hooks_binds_current_tab_and_stack_signals(self): + manager = _make_manager() + + stack_signal = MagicMock() + stack = MagicMock() + stack.stackChanged = stack_signal + + tab_signal = MagicMock() + perspective = MagicMock() + perspective.currentChanged = tab_signal + perspective.undo_stack = stack + manager._current_perspective = perspective + + manager._connect_undo_redo_hooks() + + tab_signal.connect.assert_called_once() + stack_signal.connect.assert_called_once() + assert manager._connected_tabbed_perspective is perspective + assert manager._connected_undo_stack is stack + + def test_disconnect_hooks_unbinds_connected_signals(self): + manager = _make_manager() + + stack_signal = MagicMock() + stack = MagicMock() + stack.stackChanged = stack_signal + + tab_signal = MagicMock() + perspective = MagicMock() + perspective.currentChanged = tab_signal + + manager._connected_undo_stack = stack + manager._connected_tabbed_perspective = perspective + + manager._disconnect_undo_redo_hooks() + + stack_signal.disconnect.assert_called_once() + tab_signal.disconnect.assert_called_once() + assert manager._connected_undo_stack is None + assert manager._connected_tabbed_perspective is None + + def test_tab_changed_rewires_hooks(self): + manager = _make_manager() + with patch.object(manager, "_disconnect_undo_redo_hooks") as disconnect_spy, \ + patch.object(manager, "_connect_undo_redo_hooks") as connect_spy: + + manager._on_perspective_tab_changed(1) + + disconnect_spy.assert_called_once_with() + connect_spy.assert_called_once_with() + + # -- Additional Phase 3 coverage -- + + def test_action_redo_no_stack_is_noop(self): + """actionRedo should not raise when no stack is available.""" + manager = _make_manager() + manager._current_perspective = MagicMock(undo_stack=None) + manager.actionRedo() # must not raise + + def test_action_undo_no_perspective_is_noop(self): + """actionUndo should be safe when no perspective is active.""" + manager = _make_manager() + manager._current_perspective = None + manager.actionUndo() # must not raise + + def test_action_redo_no_perspective_is_noop(self): + """actionRedo should be safe when no perspective is active.""" + manager = _make_manager() + manager._current_perspective = None + manager.actionRedo() # must not raise + + def test_update_actions_disabled_when_no_perspective(self): + """With no perspective active, both actions must be disabled.""" + manager = _make_manager() + manager._current_perspective = None + + manager._update_undo_redo_actions() + + manager._workspace.actionUndo.setEnabled.assert_called_with(False) + manager._workspace.actionRedo.setEnabled.assert_called_with(False) + + def test_connect_hooks_skips_tab_signal_for_non_tabbed_perspective(self): + """Perspectives without currentChanged should only wire the stack signal.""" + manager = _make_manager() + + stack_signal = MagicMock() + stack = MagicMock() + stack.stackChanged = stack_signal + + perspective = MagicMock(spec=[]) # no attributes by default + perspective.undo_stack = stack + manager._current_perspective = perspective + + manager._connect_undo_redo_hooks() + + stack_signal.connect.assert_called_once() + assert manager._connected_undo_stack is stack + assert manager._connected_tabbed_perspective is None + + def test_connect_hooks_no_stack_still_updates_actions(self): + """Even without a stack, connect should call _update_undo_redo_actions.""" + manager = _make_manager() + manager._current_perspective = MagicMock(undo_stack=None) + + manager._connect_undo_redo_hooks() + + # Actions should be disabled since there's no stack + manager._workspace.actionUndo.setEnabled.assert_called_with(False) + manager._workspace.actionRedo.setEnabled.assert_called_with(False) + + def test_disconnect_hooks_noop_when_nothing_connected(self): + """Disconnecting with no prior connection should not raise.""" + manager = _make_manager() + manager._connected_undo_stack = None + manager._connected_tabbed_perspective = None + manager._disconnect_undo_redo_hooks() # must not raise + + def test_disconnect_hooks_tolerates_runtime_error(self): + """Disconnecting a signal that was already disconnected is handled gracefully.""" + manager = _make_manager() + stack = MagicMock() + stack.stackChanged.disconnect.side_effect = RuntimeError("already disconnected") + manager._connected_undo_stack = stack + manager._connected_tabbed_perspective = None + + manager._disconnect_undo_redo_hooks() # must not raise + + assert manager._connected_undo_stack is None + + def test_active_undo_stack_returns_none_without_perspective(self): + manager = _make_manager() + manager._current_perspective = None + assert manager._active_undo_stack() is None + + def test_active_undo_stack_returns_perspective_stack(self): + manager = _make_manager() + stack = MagicMock() + manager._current_perspective = MagicMock(undo_stack=stack) + assert manager._active_undo_stack() is stack + + def test_active_undo_stack_handles_missing_undo_stack_attr(self): + """A perspective without undo_stack attribute should return None.""" + manager = _make_manager() + perspective = MagicMock(spec=[]) # no attributes + manager._current_perspective = perspective + assert manager._active_undo_stack() is None + + def test_tab_switch_updates_stack_connection(self): + """Switching tabs should connect the new tab's stack for signal updates.""" + manager = _make_manager() + + old_stack = MagicMock() + old_stack.stackChanged = MagicMock() + manager._connected_undo_stack = old_stack + manager._connected_tabbed_perspective = None + + new_stack = MagicMock() + new_stack.stackChanged = MagicMock() + new_stack.can_undo.return_value = True + new_stack.can_redo.return_value = False + new_stack.undo_text.return_value = "Undo Change radius" + new_stack.redo_text.return_value = "Redo" + + manager._current_perspective = MagicMock(undo_stack=new_stack) + + # Simulate what _on_perspective_tab_changed does + manager._disconnect_undo_redo_hooks() + manager._connect_undo_redo_hooks() + + old_stack.stackChanged.disconnect.assert_called_once() + new_stack.stackChanged.connect.assert_called_once() + assert manager._connected_undo_stack is new_stack + manager._workspace.actionUndo.setEnabled.assert_called_with(True) diff --git a/src/sas/qtgui/Perspectives/Corfunc/CorfuncPerspective.py b/src/sas/qtgui/Perspectives/Corfunc/CorfuncPerspective.py index 854a03a05a..b518c5aaa8 100644 --- a/src/sas/qtgui/Perspectives/Corfunc/CorfuncPerspective.py +++ b/src/sas/qtgui/Perspectives/Corfunc/CorfuncPerspective.py @@ -30,6 +30,7 @@ from sas.sascalc.util import ExtrapolationInteractionState, ExtrapolationParameters from ..perspective import Perspective +from ..UndoRedo import DictSnapshotCommand, UndoStack from .SaveExtrapolatedPopup import SaveExtrapolatedPopup from .UI.CorfuncPanel import Ui_CorfuncDialog from .util import WIDGETS, safe_float @@ -70,6 +71,10 @@ def __init__(self, parent=None): self._allow_close = False self._model_item: QStandardItem | None = None + # Undo/redo infrastructure + self._undo_stack_obj = UndoStack(self) + self._undo_baseline: dict | None = None + self.data: Data1D | None = None self.extrap: Data1D | None = None self.has_data = False @@ -122,6 +127,11 @@ def __init__(self, parent=None): # Allow Go button only when data is loaded self.allow_go() + # Undo/redo: baseline after full initialization + self._setupUndoConnections() + self._rebaseline_undo_state() + self._undo_stack_obj.clear() + def set_background_warning(self): if (self._calculator is None or self._calculator.background is None or @@ -152,6 +162,140 @@ def isSerializable(self): """ return True + # ------------------------------------------------------------------ + # Undo/redo contract methods + # ------------------------------------------------------------------ + + def _get_parameter_dict(self) -> dict: + """Capture current input-only state (excludes computed outputs). + + Called by ``UndoStack._auto_snapshot()`` for recovery snapshots. + """ + return { + "background": self.txtBackground.text(), + "guinier_a": self.txtGuinierA.text(), + "guinier_b": self.txtGuinierB.text(), + "porod_k": self.txtPorodK.text(), + "porod_sigma": self.txtPorodSigma.text(), + "lower_q_max": self.txtLowerQMax.text(), + "upper_q_min": self.txtUpperQMin.text(), + "upper_q_max": self.txtUpperQMax.text(), + "fit_background": self.fitBackground.isChecked(), + "fit_guinier": self.fitGuinier.isChecked(), + "fit_porod": self.fitPorod.isChecked(), + "tangent_auto": self.radTangentAuto.isChecked(), + "tangent_inflection": self.radTangentInflection.isChecked(), + "tangent_midpoint": self.radTangentMidpoint.isChecked(), + "long_period_auto": self.radLongPeriodAuto.isChecked(), + "long_period_max": self.radLongPeriodMax.isChecked(), + "long_period_double": self.radLongPeriodDouble.isChecked(), + } + + def _restore_parameter_values(self, state: dict) -> None: + """Apply a state dict to all input widgets AND the underlying model. + + Called by ``DictSnapshotCommand.undo/redo`` and + ``UndoStack.reset_to_last_good()``. + + Must also update ``self.model`` items because ``setText()`` only + fires ``textEdited`` which in Corfunc only validates — the model + is only updated on ``editingFinished`` (which doesn't fire here). + """ + # Disconnect model to prevent cascading model_changed signals + self.model.itemChanged.disconnect(self.model_changed) + try: + with self._undo_stack_obj.suppressed(): + # Text widgets + self.txtBackground.setText(str(state.get("background", ""))) + self.txtGuinierA.setText(str(state.get("guinier_a", ""))) + self.txtGuinierB.setText(str(state.get("guinier_b", ""))) + self.txtPorodK.setText(str(state.get("porod_k", ""))) + self.txtPorodSigma.setText(str(state.get("porod_sigma", ""))) + self.txtLowerQMax.setText(str(state.get("lower_q_max", ""))) + self.txtUpperQMin.setText(str(state.get("upper_q_min", ""))) + self.txtUpperQMax.setText(str(state.get("upper_q_max", ""))) + + # Checkboxes + self.fitBackground.setChecked(bool(state.get("fit_background", False))) + self.fitGuinier.setChecked(bool(state.get("fit_guinier", False))) + self.fitPorod.setChecked(bool(state.get("fit_porod", False))) + + # Radio buttons + self.radTangentAuto.setChecked(bool(state.get("tangent_auto", True))) + self.radTangentInflection.setChecked(bool(state.get("tangent_inflection", False))) + self.radTangentMidpoint.setChecked(bool(state.get("tangent_midpoint", False))) + self.radLongPeriodAuto.setChecked(bool(state.get("long_period_auto", True))) + self.radLongPeriodMax.setChecked(bool(state.get("long_period_max", False))) + self.radLongPeriodDouble.setChecked(bool(state.get("long_period_double", False))) + + # Update model items so calculation uses restored values + # (setText fires textEdited but Corfunc's handler only validates, + # it does NOT update the model — that only happens on editingFinished) + self.model.setItem(WIDGETS.W_BACKGROUND, + QtGui.QStandardItem(str(state.get("background", "")))) + self.model.setItem(WIDGETS.W_QMIN, + QtGui.QStandardItem(str(state.get("lower_q_max", "")))) + self.model.setItem(WIDGETS.W_QMAX, + QtGui.QStandardItem(str(state.get("upper_q_min", "")))) + self.model.setItem(WIDGETS.W_QCUTOFF, + QtGui.QStandardItem(str(state.get("upper_q_max", "")))) + + self.update_readonly() + finally: + self.model.itemChanged.connect(self.model_changed) + + def _captureUndoState(self, description: str = "Change") -> None: + """Push a DictSnapshotCommand if current state differs from baseline.""" + if self._undo_baseline is None: + return + new_state = self._get_parameter_dict() + if new_state != self._undo_baseline: + self._undo_stack_obj.push( + DictSnapshotCommand(self._undo_baseline, new_state, description) + ) + self._undo_baseline = new_state + + def _rebaseline_undo_state(self) -> None: + """Update undo baseline without pushing a command.""" + self._undo_baseline = self._get_parameter_dict() + + def _setupUndoConnections(self) -> None: + """Connect undo-capture signals for all user-editable input widgets.""" + # Text edits — editingFinished as commit boundary + text_edits = [ + self.txtBackground, + self.txtGuinierA, self.txtGuinierB, + self.txtPorodK, self.txtPorodSigma, + self.txtLowerQMax, self.txtUpperQMin, self.txtUpperQMax, + ] + for te in text_edits: + te.editingFinished.connect( + lambda desc="Edit value": self._captureUndoState(desc)) + + # Fit checkboxes + self.fitBackground.toggled.connect( + lambda _: self._captureUndoState("Toggle fit background")) + self.fitGuinier.toggled.connect( + lambda _: self._captureUndoState("Toggle fit Guinier")) + self.fitPorod.toggled.connect( + lambda _: self._captureUndoState("Toggle fit Porod")) + + # Tangent method radio buttons + self.radTangentAuto.toggled.connect( + lambda _: self._captureUndoState("Change tangent method")) + self.radTangentInflection.toggled.connect( + lambda _: self._captureUndoState("Change tangent method")) + self.radTangentMidpoint.toggled.connect( + lambda _: self._captureUndoState("Change tangent method")) + + # Long period method radio buttons + self.radLongPeriodAuto.toggled.connect( + lambda _: self._captureUndoState("Change long period method")) + self.radLongPeriodMax.toggled.connect( + lambda _: self._captureUndoState("Change long period method")) + self.radLongPeriodDouble.toggled.connect( + lambda _: self._captureUndoState("Change long period method")) + def setup_slots(self): """Connect the buttons to their appropriate slots.""" @@ -290,6 +434,10 @@ def removeData(self, data_list=None): self.set_background_warning() + # Clear undo stack + re-baseline after data removal + self._undo_stack_obj.clear() + self._rebaseline_undo_state() + def model_changed(self, _): """Actions to perform when the data is updated""" @@ -312,6 +460,9 @@ def _run(self): self.cmdExtract.setText("Calculating...") self.cmdExtract.repaint() + # Disable undo during calculation + self._undo_stack_obj.set_enabled(False) + # Set up calculator calculator = CorfuncCalculator( @@ -409,10 +560,14 @@ def _run(self): self._running = False + # Re-enable undo after calculation completes + self._undo_stack_obj.set_enabled(True) + self._rebaseline_undo_state() + self.update_readonly() def allow_go(self, reason: str | None = None): - """ + """ Disable Go button if reason is provided or if no data is loaded :param reason: Reason why Go button should be disabled """ @@ -545,6 +700,14 @@ def extrapolation_parameters(self) -> ExtrapolationParameters | None: else: return None + @property + def undo_stack(self): + """Return the undo stack for this perspective. + + Overrides ``Perspective.undo_stack`` (which returns ``None``). + """ + return self._undo_stack_obj + def setData(self, data_item: list[QStandardItem], is_batch=False): """ Obtain a QStandardItem object and dissect it to get Data1D/2D @@ -644,6 +807,10 @@ def fractional_position(f): self.tabWidget.setCurrentIndex(0) self.set_background_warning() + # Clear undo stack + re-baseline for fresh data + self._undo_stack_obj.clear() + self._rebaseline_undo_state() + def setClosable(self, value=True): """ @@ -757,7 +924,12 @@ def correct_extrapolation_values(self): messages = [] # block signals to avoid recursive calls - with QtCore.QSignalBlocker(self.txtLowerQMax), QtCore.QSignalBlocker(self.txtUpperQMin), QtCore.QSignalBlocker(self.txtUpperQMax): + with ( + QtCore.QSignalBlocker(self.txtLowerQMax), + QtCore.QSignalBlocker(self.txtUpperQMin), + QtCore.QSignalBlocker(self.txtUpperQMax), + self._undo_stack_obj.suppressed(), + ): # start by updating p2 as it is used in multiple checks if self.validity_flags[2]: # p2 <= data_q_min @@ -839,6 +1011,7 @@ def on_extrapolation_slider_changed(self, state: ExtrapolationParameters): QtGui.QStandardItem(format_string%state.point_2)) self.model.setItem(WIDGETS.W_QCUTOFF, QtGui.QStandardItem(format_string%state.point_3)) + self._captureUndoState("Change extrapolation slider") def on_extrapolation_slider_changing(self, state: ExtrapolationInteractionState): """ Slider is being moved about""" @@ -976,47 +1149,52 @@ def updateFromParameters(self, params): c_name = params.__class__.__name__ msg = "Corfunc.updateFromParameters expects a dictionary" raise TypeError(f"{msg}: {c_name} received") - # Assign values to 'Invariant' tab inputs - use defaults if not found - # don't raise model_changed signal for a while - self.model.itemChanged.disconnect(self.model_changed) - self.model.setItem( - WIDGETS.W_GUINIERA, QtGui.QStandardItem(params.get('guinier_a', '0.0'))) - self.model.setItem( - WIDGETS.W_GUINIERB, QtGui.QStandardItem(params.get('guinier_b', '0.0'))) - self.model.setItem( - WIDGETS.W_PORODK, QtGui.QStandardItem(params.get('porod_k', '0.0'))) - self.model.setItem(WIDGETS.W_PORODSIGMA, QtGui.QStandardItem( - params.get('porod_sigma', '0.0'))) - self.model.setItem(WIDGETS.W_CORETHICK, QtGui.QStandardItem( - params.get('avg_core_thick', '0'))) - self.model.setItem(WIDGETS.W_INTTHICK, QtGui.QStandardItem( - params.get('avg_inter_thick', '0'))) - self.model.setItem(WIDGETS.W_HARDBLOCK, QtGui.QStandardItem( - params.get('avg_hard_block_thick', '0'))) - self.model.setItem(WIDGETS.W_SOFTBLOCK, QtGui.QStandardItem( - params.get('avg_soft_block_thick', '0'))) - self.model.setItem(WIDGETS.W_CRYSTAL, QtGui.QStandardItem( - params.get('local_crystalinity', '0'))) - self.model.setItem( - WIDGETS.W_POLY_RYAN, QtGui.QStandardItem(params.get('polydispersity', '0'))) - self.model.setItem( - WIDGETS.W_POLY_STRIBECK, QtGui.QStandardItem(params.get('polydispersity_stribeck', '0'))) - self.model.setItem( - WIDGETS.W_PERIOD, QtGui.QStandardItem(params.get('long_period', '0'))) - self.model.setItem( - WIDGETS.W_FILENAME, QtGui.QStandardItem(params.get('data_name', ''))) - self.model.setItem( - WIDGETS.W_QMIN, QtGui.QStandardItem(params.get('lower_q_max', '0.01'))) - self.model.setItem( - WIDGETS.W_QMAX, QtGui.QStandardItem(params.get('upper_q_min', '0.20'))) - self.model.setItem( - WIDGETS.W_QCUTOFF, QtGui.QStandardItem(params.get('upper_q_max', '0.22'))) - self.model.setItem(WIDGETS.W_BACKGROUND, QtGui.QStandardItem( - params.get('background', '0'))) - # reconnect model - self.model.itemChanged.connect(self.model_changed) - self.cmdSave.setEnabled(params.get('guinier_a', '0.0') != '0.0') - self.cmdExtract.setEnabled(params.get('long_period', '0') != '0') + # Assign values to inputs - use defaults if not found + # Suppress undo during programmatic load + with self._undo_stack_obj.suppressed(): + # don't raise model_changed signal for a while + self.model.itemChanged.disconnect(self.model_changed) + self.model.setItem( + WIDGETS.W_GUINIERA, QtGui.QStandardItem(params.get('guinier_a', '0.0'))) + self.model.setItem( + WIDGETS.W_GUINIERB, QtGui.QStandardItem(params.get('guinier_b', '0.0'))) + self.model.setItem( + WIDGETS.W_PORODK, QtGui.QStandardItem(params.get('porod_k', '0.0'))) + self.model.setItem(WIDGETS.W_PORODSIGMA, QtGui.QStandardItem( + params.get('porod_sigma', '0.0'))) + self.model.setItem(WIDGETS.W_CORETHICK, QtGui.QStandardItem( + params.get('avg_core_thick', '0'))) + self.model.setItem(WIDGETS.W_INTTHICK, QtGui.QStandardItem( + params.get('avg_inter_thick', '0'))) + self.model.setItem(WIDGETS.W_HARDBLOCK, QtGui.QStandardItem( + params.get('avg_hard_block_thick', '0'))) + self.model.setItem(WIDGETS.W_SOFTBLOCK, QtGui.QStandardItem( + params.get('avg_soft_block_thick', '0'))) + self.model.setItem(WIDGETS.W_CRYSTAL, QtGui.QStandardItem( + params.get('local_crystalinity', '0'))) + self.model.setItem( + WIDGETS.W_POLY_RYAN, QtGui.QStandardItem(params.get('polydispersity', '0'))) + self.model.setItem( + WIDGETS.W_POLY_STRIBECK, QtGui.QStandardItem(params.get('polydispersity_stribeck', '0'))) + self.model.setItem( + WIDGETS.W_PERIOD, QtGui.QStandardItem(params.get('long_period', '0'))) + self.model.setItem( + WIDGETS.W_FILENAME, QtGui.QStandardItem(params.get('data_name', ''))) + self.model.setItem( + WIDGETS.W_QMIN, QtGui.QStandardItem(params.get('lower_q_max', '0.01'))) + self.model.setItem( + WIDGETS.W_QMAX, QtGui.QStandardItem(params.get('upper_q_min', '0.20'))) + self.model.setItem( + WIDGETS.W_QCUTOFF, QtGui.QStandardItem(params.get('upper_q_max', '0.22'))) + self.model.setItem(WIDGETS.W_BACKGROUND, QtGui.QStandardItem( + params.get('background', '0'))) + # reconnect model + self.model.itemChanged.connect(self.model_changed) + self.cmdSave.setEnabled(params.get('guinier_a', '0.0') != '0.0') + self.cmdExtract.setEnabled(params.get('long_period', '0') != '0') + + # Re-baseline after programmatic restore + self._rebaseline_undo_state() @property def real_space_figure(self): diff --git a/src/sas/qtgui/Perspectives/Corfunc/UnitTesting/CorfuncTest.py b/src/sas/qtgui/Perspectives/Corfunc/UnitTesting/CorfuncTest.py index 2c1f5ecf3c..874a5d30f4 100755 --- a/src/sas/qtgui/Perspectives/Corfunc/UnitTesting/CorfuncTest.py +++ b/src/sas/qtgui/Perspectives/Corfunc/UnitTesting/CorfuncTest.py @@ -7,7 +7,7 @@ from sasdata.dataloader.loader import Loader import sas.qtgui.Utilities.GuiUtils as GuiUtils -from sas.qtgui.Perspectives.Corfunc.CorfuncPerspective import CorfuncWindow +from sas.qtgui.Perspectives.Corfunc.CorfuncPerspective import WIDGETS, CorfuncWindow from sas.qtgui.Plotting.PlotterData import Data1D from sas.qtgui.UnitTesting import base_path @@ -78,7 +78,7 @@ def testProcess(self, widget, mocker): os.stat(filename) except OSError: assert False, "ISIS_98929.TXT does not exist" - f = Loader().load(filename) + Loader().load(filename) mocker.patch.object(QtWidgets.QFileDialog, 'getOpenFileName', return_value=(filename, '')) assert widget.txtBackground.text() == '' @@ -159,6 +159,21 @@ def testLoadParams(self, widget): widget.removeData([self.fakeData]) self.checkDefaults(widget) + def testUndoRestoreReconnectsModelChangedOnError(self, widget, mocker): + widget.model.itemChanged.disconnect(widget.model_changed) + model_changed = mocker.Mock() + widget.model_changed = model_changed + widget.model.itemChanged.connect(widget.model_changed) + set_item = mocker.patch.object(widget.model, 'setItem', side_effect=RuntimeError("restore failed")) + + with pytest.raises(RuntimeError): + widget._restore_parameter_values({}) + + mocker.stop(set_item) + widget.model.setItem(WIDGETS.W_BACKGROUND, QtGui.QStandardItem("1.0")) + + model_changed.assert_called() + def checkFakeDataState(self, widget): assert widget.txtFilename.text() == 'data' assert widget.txtLowerQMax.text() == '0.137973' diff --git a/src/sas/qtgui/Perspectives/Fitting/DataManager.py b/src/sas/qtgui/Perspectives/Fitting/DataManager.py new file mode 100644 index 0000000000..de286a4cab --- /dev/null +++ b/src/sas/qtgui/Perspectives/Fitting/DataManager.py @@ -0,0 +1,291 @@ +""" +DataManager - Handle dataset loading, weighting, and Q-range updates. + +This class encapsulates data management logic for the fitting perspective, +separating data handling concerns from the main FittingWidget. +""" + +import copy +from collections.abc import Callable +from typing import Any + +import numpy as np +from PySide6 import QtGui + +from sas.qtgui.Perspectives.Fitting import FittingUtilities +from sas.qtgui.Perspectives.Fitting.FittingLogic import FittingLogic +from sas.qtgui.Plotting.PlotterData import Data1D, Data2D +from sas.qtgui.Utilities import GuiUtils + + +class DataManager: + """ + Manages data loading, processing, and state for fitting operations. + + Responsibilities: + - Load datasets from GUI items (single or batch) + - Apply weighting to data for fitting + - Calculate and update Q-range + - Manage data state (is2D, batch fitting, data loaded) + - Create default datasets for theory calculations + """ + + def __init__(self, parent: Any | None = None): + """ + Initialize the DataManager. + + Args: + parent: Parent widget (optional, for callbacks) + """ + self.parent = parent + + # Data state + self.data_is_loaded = False + self.is_batch_fitting = False + self.is2D = False + + # Data holders + self.all_data: list[QtGui.QStandardItem] = [] + self.data_index = 0 + self._logic: list[FittingLogic] = [FittingLogic()] + + # Q-range parameters + self.q_range_min = 0.005 + self.q_range_max = 0.1 + self.npts = 50 + self.log_points = True + self.weighting = 0 + + # Callbacks (set by parent) + self.on_data_loaded: Callable[[], None] | None = None + self.on_q_range_updated: Callable[[float, float, int], None] | None = None + + @property + def logic(self) -> FittingLogic: + """Get the current FittingLogic instance for the active data.""" + assert self._logic + return self._logic[self.data_index] + + @property + def data(self) -> Data1D | Data2D: + """Get the current data object.""" + return self.logic.data + + def loadDataFromItems(self, value: QtGui.QStandardItem | list[QtGui.QStandardItem]) -> None: + """ + Load data from GUI items (single or batch). + + Args: + value: Single QStandardItem or list of items containing data + """ + # Convert to list format + if isinstance(value, list): + self.is_batch_fitting = True + else: + value = [value] + + assert isinstance(value[0], QtGui.QStandardItem) + + # Keep reference to all datasets for batch + self.all_data = value + + # Create logics with data items + if len(value) == 1: + # Single data - update existing logic + self._logic[0].data = GuiUtils.dataFromItem(value[0]) + else: + # Batch datasets - create multiple logics + self._logic = [] + for data_item in value: + logic = FittingLogic(data=GuiUtils.dataFromItem(data_item)) + self._logic.append(logic) + + # Determine data dimensionality + self.is2D = isinstance(self.logic.data, Data2D) + + # Mark data as loaded + self.data_is_loaded = True + + # Update Q-range from data + self.updateQRange() + + # Notify parent if callback is set + if self.on_data_loaded: + self.on_data_loaded() + + def addWeightingToData(self, data: Data1D | Data2D) -> Data1D | Data2D: + """ + Add weighting contribution to fitting data. + + Args: + data: Data object to apply weighting to + + Returns: + New data object with weighting applied + """ + if not self.data_is_loaded: + # No weighting for theories (dy = 0) + return data + + new_data = copy.deepcopy(data) + + # Calculate weights + weight = FittingUtilities.getWeight( + data=data, + is2d=self.is2D, + flag=self.weighting + ) + + # Apply weights based on dimensionality + if self.is2D: + new_data.err_data = weight + else: + new_data.dy = weight + + return new_data + + def updateQRange(self) -> None: + """ + Update Q-range from current data. + + Calculates Q-range from data if loaded, otherwise keeps current values. + """ + if self.data_is_loaded: + self.q_range_min, self.q_range_max, self.npts = \ + self.logic.computeDataRange() + + # Notify parent if callback is set + if self.on_q_range_updated: + self.on_q_range_updated( + self.q_range_min, + self.q_range_max, + self.npts + ) + + def setQRangeParameters( + self, + q_min: float, + q_max: float, + npts: int, + log_points: bool, + weighting: int + ) -> None: + """ + Set Q-range parameters for calculations. + + Args: + q_min: Minimum Q value + q_max: Maximum Q value + npts: Number of points + log_points: Use logarithmic spacing + weighting: Weighting type (0=none, 1=statistical, etc.) + """ + self.q_range_min = q_min + self.q_range_max = q_max + self.npts = npts + self.log_points = log_points + self.weighting = weighting + + def createDefault1DDataset(self, tab_id: int) -> None: + """ + Create a default 1D dataset for theory calculations. + + Args: + tab_id: Tab identifier for the dataset + """ + if self.log_points: + qmin = -10.0 if self.q_range_min < 1.e-10 else np.log10(self.q_range_min) + qmax = 10.0 if self.q_range_max > 1.e10 else np.log10(self.q_range_max) + interval = np.logspace( + start=qmin, + stop=qmax, + num=self.npts, + endpoint=True, + base=10.0 + ) + else: + interval = np.linspace( + start=self.q_range_min, + stop=self.q_range_max, + num=int(self.npts), + endpoint=True + ) + + self.logic.createDefault1dData(interval, tab_id) + + def createDefault2DDataset(self, tab_id: int) -> None: + """ + Create a default 2D dataset for theory calculations. + + Args: + tab_id: Tab identifier for the dataset + """ + qmax = self.q_range_max / np.sqrt(2) + qstep = self.npts + self.logic.createDefault2dData(qmax, qstep, tab_id) + + def createDefaultDataset(self, tab_id: int) -> None: + """ + Create a default dataset (1D or 2D) based on current state. + + Args: + tab_id: Tab identifier for the dataset + """ + if self.is2D: + self.createDefault2DDataset(tab_id) + else: + self.createDefault1DDataset(tab_id) + + def selectBatchData(self, index: int) -> None: + """ + Select a specific dataset in batch fitting mode. + + Args: + index: Index of the dataset to select + """ + if 0 <= index < len(self.all_data): + self.data_index = index + self.updateQRange() + + def getWeight(self, data: Data1D | Data2D | None = None) -> Any: + """ + Get weight array for the specified data. + + Args: + data: Data object (uses current data if None) + + Returns: + Weight array + """ + if data is None: + data = self.data + + return FittingUtilities.getWeight( + data=data, + is2d=self.is2D, + flag=self.weighting + ) + + def reset(self) -> None: + """Reset the data manager to initial state.""" + self.data_is_loaded = False + self.is_batch_fitting = False + self.is2D = False + self.all_data = [] + self.data_index = 0 + self._logic = [FittingLogic()] + + def setCallbacks( + self, + on_data_loaded: Callable[[], None] | None = None, + on_q_range_updated: Callable[[float, float, int], None] | None = None + ) -> None: + """ + Set callback functions for data manager events. + + Args: + on_data_loaded: Called when data is loaded + on_q_range_updated: Called when Q-range is updated + """ + self.on_data_loaded = on_data_loaded + self.on_q_range_updated = on_q_range_updated diff --git a/src/sas/qtgui/Perspectives/Fitting/FitPage.py b/src/sas/qtgui/Perspectives/Fitting/FitPage.py index 1f5fe98cee..f1774cdbac 100644 --- a/src/sas/qtgui/Perspectives/Fitting/FitPage.py +++ b/src/sas/qtgui/Perspectives/Fitting/FitPage.py @@ -59,23 +59,3 @@ def __init__(self): self.algorithm = None self.algorithm_options = {} - def save(self): - """ - Serialize the current state - """ - pass - - def load(self, location): - """ - Retrieve serialized state from specified location - """ - pass - - def saveAsXML(self): - """ - Serialize the current state - """ - # Connect to PageState.to_xml(), which serializes - # to the existing XML with file I(Q) - pass - diff --git a/src/sas/qtgui/Perspectives/Fitting/FittingPerspective.py b/src/sas/qtgui/Perspectives/Fitting/FittingPerspective.py index 30e1446d8e..afe88e5c02 100644 --- a/src/sas/qtgui/Perspectives/Fitting/FittingPerspective.py +++ b/src/sas/qtgui/Perspectives/Fitting/FittingPerspective.py @@ -575,6 +575,12 @@ def currentFittingWidget(self) -> FittingWidget | None: else: return None + @property + def undo_stack(self): + """Return undo stack for the currently selected fitting tab.""" + fitting_widget = self.currentFittingWidget + return None if fitting_widget is None else fitting_widget.undo_stack + def getFitTabs(self): """ Returns the list of fitting tabs diff --git a/src/sas/qtgui/Perspectives/Fitting/FittingWidget.py b/src/sas/qtgui/Perspectives/Fitting/FittingWidget.py index a09f7c0a96..427f6f94ae 100644 --- a/src/sas/qtgui/Perspectives/Fitting/FittingWidget.py +++ b/src/sas/qtgui/Perspectives/Fitting/FittingWidget.py @@ -36,7 +36,17 @@ from sas.qtgui.Perspectives.Fitting.ReportPageLogic import ReportPageLogic from sas.qtgui.Perspectives.Fitting.SmearingWidget import SmearingWidget from sas.qtgui.Perspectives.Fitting.UI.FittingWidgetUI import Ui_FittingWidgetUI +from sas.qtgui.Perspectives.Fitting.UndoRedo import ( + CheckboxToggleCommand, + FitOptionsCommand, + FitResultCommand, + ModelSelectionCommand, + ParameterMinMaxCommand, + ParameterValueCommand, + SmearingOptionsCommand, +) from sas.qtgui.Perspectives.Fitting.ViewDelegate import ModelViewDelegate +from sas.qtgui.Perspectives.UndoRedo import UndoStack from sas.qtgui.Plotting.Plotter import PlotterWidget from sas.qtgui.Plotting.PlotterData import Data1D, Data2D, DataRole from sas.qtgui.Utilities.BackgroundColor import BG_DEFAULT, BG_ERROR @@ -214,6 +224,9 @@ def dataFromItems(self, value: QtGui.QStandardItem | list[QtGui.QStandardItem]) assert isinstance(value[0], QtGui.QStandardItem) + if hasattr(self, "undo_stack"): + self.undo_stack.clear() + # Keep reference to all datasets for batch self.all_data = value @@ -300,10 +313,10 @@ def initializeGlobals(self) -> None: self.weighting = 0 self.chi2 = None - # Does the control support UNDO/REDO - # temporarily off - self.undo_supported = False - self.page_stack = [] + # Undo/redo stack (per-tab, incremental command pattern) + self.undo_stack = UndoStack(self) + self._last_smearing_state = None + self._last_model_triple = None # (category, model, structure) for undo capture self.all_data = [] # custom plugin models # {model.name:model} @@ -489,36 +502,40 @@ def setEnablementOnDataLoad(self) -> None: """ Enable/disable various UI elements based on data loaded """ - # Tag along functionality - self.label.setText("Data loaded from: ") - if self.logic.data.name: - self.lblFilename.setText(self.logic.data.name) - else: - self.lblFilename.setText(self.logic.data.filename) - self.updateQRange() - # Switch off Data2D control - self.chk2DView.setEnabled(False) - self.chk2DView.setVisible(False) - self.chkMagnetism.setEnabled(self.canHaveMagnetism()) - self.tabFitting.setTabEnabled(TAB_MAGNETISM, self.chkMagnetism.isChecked()) - # Combo box or label for file name" - if self.is_batch_fitting: - self.lblFilename.setVisible(False) - for dataitem in self.all_data: - name = GuiUtils.dataFromItem(dataitem).name - self.cbFileNames.addItem(name) - self.cbFileNames.setVisible(True) - self.chkChainFit.setEnabled(True) - self.chkChainFit.setVisible(True) - # This panel is not designed to view individual fits, so disable plotting - self.cmdPlot.setVisible(False) - # Similarly on other tabs - self.options_widget.setEnablementOnDataLoad() - self.onSelectModel() - # Smearing tab - self.smearing_widget.updateData(self.data) - # Check if a model was already loaded when data is sent to the tab - self.cmdFit.setEnabled(self.haveParamsToFit()) + # Suppress undo capture: data loading is a programmatic operation, + # not a user-initiated parameter change. Without this, updateQRange + # and onWeightingChoice push spurious FitOptionsCommand entries. + with self.undo_stack.suppressed(): + # Tag along functionality + self.label.setText("Data loaded from: ") + if self.logic.data.name: + self.lblFilename.setText(self.logic.data.name) + else: + self.lblFilename.setText(self.logic.data.filename) + self.updateQRange() + # Switch off Data2D control + self.chk2DView.setEnabled(False) + self.chk2DView.setVisible(False) + self.chkMagnetism.setEnabled(self.canHaveMagnetism()) + self.tabFitting.setTabEnabled(TAB_MAGNETISM, self.chkMagnetism.isChecked()) + # Combo box or label for file name" + if self.is_batch_fitting: + self.lblFilename.setVisible(False) + for dataitem in self.all_data: + name = GuiUtils.dataFromItem(dataitem).name + self.cbFileNames.addItem(name) + self.cbFileNames.setVisible(True) + self.chkChainFit.setEnabled(True) + self.chkChainFit.setVisible(True) + # This panel is not designed to view individual fits, so disable plotting + self.cmdPlot.setVisible(False) + # Similarly on other tabs + self.options_widget.setEnablementOnDataLoad() + self.onSelectModel() + # Smearing tab + self.smearing_widget.updateData(self.data) + # Check if a model was already loaded when data is sent to the tab + self.cmdFit.setEnabled(self.haveParamsToFit()) def acceptsData(self) -> bool: """ Tells the caller this widget can accept new dataset """ @@ -570,6 +587,9 @@ def togglePoly(self, isChecked: bool) -> None: # Check if any parameters are ready for fitting self.cmdFit.setEnabled(self.haveParamsToFit()) self.polydispersity_widget.togglePoly(isChecked) + self.undo_stack.push( + CheckboxToggleCommand('chkPolydispersity', not isChecked, isChecked) + ) def onPolyToggled(self, isChecked: bool) -> None: """ @@ -587,6 +607,9 @@ def toggleMagnetism(self, isChecked: bool) -> None: # Check if any parameters are ready for fitting self.cmdFit.setEnabled(self.haveParamsToFit()) self.magnetism_widget.isActive = isChecked + self.undo_stack.push( + CheckboxToggleCommand('chkMagnetism', not isChecked, isChecked) + ) def onMagnetismToggled(self, isChecked: bool) -> None: """ @@ -611,9 +634,14 @@ def toggle2D(self, isChecked: bool) -> None: """ Enable/disable the controls dependent on 1D/2D data instance """ self.chkMagnetism.setEnabled(isChecked) self.is2D = isChecked - # Reload the current model - if self.logic.kernel_module: - self.onSelectModel() + # Reload the current model — suppress so the inner onSelectModel + # doesn't push its own command; toggle2D is one atomic undo step. + with self.undo_stack.suppressed(): + if self.logic.kernel_module: + self.onSelectModel() + self.undo_stack.push( + CheckboxToggleCommand('chk2DView', not isChecked, isChecked) + ) @classmethod def customModels(cls) -> dict[str, Any]: @@ -1167,6 +1195,11 @@ def onSelectModel(self) -> None: self.cbModel.setCurrentIndex(self._previous_model_index) self.cbModel.blockSignals(False) return + + # Capture old state for undo before anything changes + old_triple = self._last_model_triple + old_params = self._get_parameter_dict() + if self.model_data is not None: # Store any old parameters before switching to a new model self.page_parameters = self.getParameterDict() @@ -1178,30 +1211,46 @@ def onSelectModel(self) -> None: if not model: return - self.chkMagnetism.setEnabled(self.canHaveMagnetism()) - self.tabFitting.setTabEnabled(TAB_MAGNETISM, self.chkMagnetism.isChecked() and self.canHaveMagnetism()) - self._previous_model_index = self.cbModel.currentIndex() + # Suppress undo capture during the model rebuild + with self.undo_stack.suppressed(): + self.chkMagnetism.setEnabled(self.canHaveMagnetism()) + self.tabFitting.setTabEnabled(TAB_MAGNETISM, self.chkMagnetism.isChecked() and self.canHaveMagnetism()) + self._previous_model_index = self.cbModel.currentIndex() - # Reset parameters to fit - self.resetParametersToFit() - self.has_error_column = False - self.polydispersity_widget.is2D = self.is2D - self.polydispersity_widget.has_poly_error_column = False - self.magnetism_widget.has_magnet_error_column = False - - structure = None - if self.cbStructureFactor.isEnabled(): - structure = str(self.cbStructureFactor.currentText()) - self.respondToModelStructure(model=model, structure_factor=structure) + # Reset parameters to fit + self.resetParametersToFit() + self.has_error_column = False + self.polydispersity_widget.is2D = self.is2D + self.polydispersity_widget.has_poly_error_column = False + self.magnetism_widget.has_magnet_error_column = False - # paste parameters from previous state - if self.page_parameters: - self.updatePageWithParameters(self.page_parameters, warn_user=False) + structure = None + if self.cbStructureFactor.isEnabled(): + structure = str(self.cbStructureFactor.currentText()) + self.respondToModelStructure(model=model, structure_factor=structure) + + # paste parameters from previous state + if self.page_parameters: + self.updatePageWithParameters(self.page_parameters, warn_user=False) + + # disable polydispersity if the model does not support it + has_poly = self.polydispersity_widget.poly_model.rowCount() != 0 + self.chkPolydispersity.setEnabled(has_poly) + # self.tabFitting.setTabEnabled(TAB_POLY, has_poly) + + # Capture new state after model change + new_triple = ( + str(self.cbCategory.currentText()), + str(self.cbModel.currentText()), + str(self.cbStructureFactor.currentText()) if self.cbStructureFactor.isEnabled() else '', + ) + new_params = self._get_parameter_dict() + self._last_model_triple = new_triple - # disable polydispersity if the model does not support it - has_poly = self.polydispersity_widget.poly_model.rowCount() != 0 - self.chkPolydispersity.setEnabled(has_poly) - # self.tabFitting.setTabEnabled(TAB_POLY, has_poly) + if old_triple is not None and (old_triple != new_triple or old_params != new_params): + self.undo_stack.push( + ModelSelectionCommand(old_triple, new_triple, old_params, new_params) + ) # set focus so it doesn't move up self.cbModel.setFocus() @@ -1210,6 +1259,8 @@ def onSelectBatchFilename(self, data_index: int) -> None: """ Update the logic based on the selected file in batch fitting """ + if data_index != self.data_index: + self.undo_stack.clear() self.data_index = data_index self.updateQRange() @@ -1384,9 +1435,6 @@ def respondToModelStructure(self, model: str | None = None, structure_factor: st # Update plot self.updateData() - # Update state stack - self.updateUndo() - # Let others know self.newModelSignal.emit() @@ -1536,6 +1584,8 @@ def onFit(self) -> None: # Disable some elements self.disableInteractiveElements() + # The undo stack stays enabled during the fit; intermediate parameter + # updates are blocked via the suppressed() context manager instead. def stopFit(self) -> None: """ @@ -1549,6 +1599,7 @@ def stopFit(self) -> None: msg = "Fitting cancelled." self.communicator.statusBarUpdateSignal.emit(msg) + self.communicator.undoRedoUpdateSignal.emit() def updateFit(self) -> None: """ @@ -1648,6 +1699,12 @@ def fitComplete(self, result: tuple) -> None: self.kernel_module = copy.deepcopy(self.kernel_module_copy) return + # Capture parameter snapshot BEFORE fit results are applied + # (main + polydispersity + magnetism, since a fit can modify all three). + # The fit mutates the live kernel module in place, so capture the + # pre-fit values from kernel_module_copy (the deepcopy made in onFit). + old_snapshot = self._get_fit_result_snapshot(self.kernel_module_copy) + # Don't recalculate chi2 - it's in res.fitness already self.fitResults = True if result is None or len(result) == 0 or len(result[0]) == 0: @@ -1675,11 +1732,10 @@ def fitComplete(self, result: tuple) -> None: # Dictionary of fitted parameter: value, error # e.g. param_dic = {"sld":(1.703, 0.0034), "length":(33.455, -0.0983)} - self.fitting_controller.updateModelFromList(param_dict) - - self.polydispersity_widget.updatePolyModelFromList(param_dict) - - self.magnetism_widget.updateMagnetModelFromList(param_dict) + with self.undo_stack.suppressed(): + self.fitting_controller.updateModelFromList(param_dict) + self.polydispersity_widget.updatePolyModelFromList(param_dict) + self.magnetism_widget.updateMagnetModelFromList(param_dict) # update charts self.onPlot() @@ -1688,6 +1744,17 @@ def fitComplete(self, result: tuple) -> None: chi2_repr = GuiUtils.formatNumber(self.chi2, high=True) self.lblChi2Value.setText(chi2_repr) + # Push a single FitResultCommand for the entire fit + new_snapshot = self._get_fit_result_snapshot() + if old_snapshot != new_snapshot: + self.undo_stack.push(FitResultCommand(old_snapshot, new_snapshot)) + + # Ensure undo/redo action state is refreshed in the GUI manager. + # The push above emits stackChanged, but if the per-stack signal + # connection was disrupted (e.g. tab change during processEvents + # in onPlot), this communicator signal provides a reliable fallback. + self.communicator.undoRedoUpdateSignal.emit() + def prepareFitters(self, fitter: Fit | None = None, fit_id: int = 0, weight_increase: int = 1) -> tuple[list[Fit], int]: """ @@ -1756,11 +1823,18 @@ def onSmearingOptionsUpdate(self) -> None: """ React to changes in the smearing widget """ + old_state = self._last_smearing_state + # update display smearing, accuracy, smearing_min, smearing_max = self.smearing_widget.state() self.lblCurrentSmearing.setText(smearing) self.calculateQGridForModel() + new_state = self._get_smearing_state_dict() + if old_state is not None and old_state != new_state: + self.undo_stack.push(SmearingOptionsCommand(old_state, new_state)) + self._last_smearing_state = new_state + def onKey(self, event: QtGui.QKeyEvent) -> None: if event.key() in [QtCore.Qt.Key_Enter, QtCore.Qt.Key_Return] and self.cmdPlot.isEnabled(): self.onPlot() @@ -1817,6 +1891,8 @@ def onOptionsUpdate(self) -> None: """ Update local option values and replot """ + old_opts = self._get_fit_options_dict() + self.q_range_min, self.q_range_max, self.npts, self.log_points, self.weighting = \ self.options_widget.state() # set Q range labels on the main tab @@ -1824,12 +1900,21 @@ def onOptionsUpdate(self) -> None: self.lblMaxRangeDef.setText(GuiUtils.formatNumber(self.q_range_max, high=True)) self.recalculatePlotData() + new_opts = self._get_fit_options_dict() + if old_opts != new_opts: + self.undo_stack.push(FitOptionsCommand(old_opts, new_opts)) + def setDefaultStructureCombo(self) -> None: """ Fill in the structure factors combo box with defaults """ structure_factor_list = self.master_category_dict.pop(CATEGORY_STRUCTURE) - factors = [factor[0] for factor in structure_factor_list] + # Only offer structure factors that are actually loadable. The + # categorization file (which may be a stale user file) can list models + # that are not present in the installed sasmodels, and selecting one of + # those would raise a KeyError in self.models[...] and leave the page + # in a broken state. + factors = [factor[0] for factor in structure_factor_list if factor[0] in self.models] factors.insert(0, STRUCTURE_DEFAULT) self.cbStructureFactor.clear() self.cbStructureFactor.addItems(sorted(factors)) @@ -2062,11 +2147,9 @@ def fromModelToQModel(self, model_name: str) -> None: self.logic.kernel_module.name = self.modelName() # Explicitly add scale and background with default values - temp_undo_state = self.undo_supported - self.undo_supported = False - self.addScaleToModel(self._model_model) - self.addBackgroundToModel(self._model_model) - self.undo_supported = temp_undo_state + with self.undo_stack.suppressed(): + self.addScaleToModel(self._model_model) + self.addBackgroundToModel(self._model_model) self.logic.shell_names = self.shellNamesList() @@ -2196,8 +2279,7 @@ def onMainParamsChange(self, top: QtCore.QModelIndex, bottom: QtCore.QModelIndex if model_column == 0: self.checkboxSelected(item, model_key="standard") self.cmdFit.setEnabled(self.haveParamsToFit()) - # Update state stack - self.updateUndo() + # Fit-checkbox toggles intentionally excluded from undo stack return model_row = item.row() @@ -2222,16 +2304,18 @@ def onMainParamsChange(self, top: QtCore.QModelIndex, bottom: QtCore.QModelIndex min_column = self.lstParams.itemDelegate().param_min max_column = self.lstParams.itemDelegate().param_max if model_column == param_column: - # don't try to update multiplicity counters if they aren't there. - # Note that this will fail for proper bad update where the model - # doesn't contain multiplicity parameter + old_val = self.logic.kernel_module.getParam(parameter_name) self.logic.kernel_module.setParam(parameter_name, value) + self.undo_stack.push(ParameterValueCommand(parameter_name, old_val, value)) elif model_column == min_column: - # min/max to be changed in self.logic.kernel_module.details[parameter_name] = ['Ang', 0.0, inf] + old_val = self.logic.kernel_module.details[parameter_name][1] self.logic.kernel_module.details[parameter_name][1] = value + self.undo_stack.push(ParameterMinMaxCommand(parameter_name, "min", old_val, value)) elif model_column == max_column: + old_val = self.logic.kernel_module.details[parameter_name][2] self.logic.kernel_module.details[parameter_name][2] = value + self.undo_stack.push(ParameterMinMaxCommand(parameter_name, "max", old_val, value)) else: # don't update the chart return @@ -2240,8 +2324,8 @@ def onMainParamsChange(self, top: QtCore.QModelIndex, bottom: QtCore.QModelIndex # necessary self.processEffectiveRadius() - # Update state stack - self.updateUndo() + # Keep page_parameters up-to-date so parameter retention across + # model switches works even when no plot has been computed yet. self.page_parameters = self.getParameterDict() def processEffectiveRadius(self) -> None: @@ -3094,34 +3178,280 @@ def saveToFitPage(self, fp: FitPage) -> None: # TODO: add polydispersity and magnetism - def updateUndo(self) -> None: - """ - Create a new state page and add it to the stack - """ - if self.undo_supported: - self.pushFitPage(self.currentState()) + # ------------------------------------------------------------------ + # Undo/redo helper methods (called by UndoCommand subclasses) + # ------------------------------------------------------------------ + + def _update_model_param_value(self, param_name: str, value: float) -> None: + """Update the UI model item for a parameter value (called by undo/redo).""" + for row in range(self._model_model.rowCount()): + name_item = self._model_model.item(row, 0) + if name_item is None: + continue + row_name = str(name_item.data(QtCore.Qt.UserRole) or name_item.text()) + if row_name == param_name: + col = self.lstParams.itemDelegate().param_value + value_item = self._model_model.item(row, col) + if value_item: + value_item.setText(GuiUtils.formatNumber(value, high=True)) + return - def currentState(self) -> FitPage: + def _update_model_param_limit(self, param_name: str, bound: str, value: float) -> None: + """Update min or max bound in the UI model (called by undo/redo).""" + col = self.lstParams.itemDelegate().param_min if bound == "min" else self.lstParams.itemDelegate().param_max + for row in range(self._model_model.rowCount()): + name_item = self._model_model.item(row, 0) + if name_item is None: + continue + row_name = str(name_item.data(QtCore.Qt.UserRole) or name_item.text()) + if row_name == param_name: + item = self._model_model.item(row, col) + if item: + item.setText(GuiUtils.formatNumber(value, high=True)) + return + + def _restore_model_selection(self, triple: tuple, params: dict) -> None: + """Restore model selection and parameter values (called by undo/redo). + + ``triple`` is ``(category, model, structure_factor)``. """ - Return fit page with current state + with self.undo_stack.suppressed(): + category, model, structure = triple + # Block cbModel signals while changing category to prevent + # onSelectModel firing with the wrong model during repopulation. + self.cbModel.blockSignals(True) + self.cbCategory.setCurrentIndex(self.cbCategory.findText(category)) + self.cbModel.blockSignals(False) + self.cbModel.setCurrentIndex(self.cbModel.findText(model)) + if structure: + self.cbStructureFactor.setCurrentIndex( + self.cbStructureFactor.findText(structure) + ) + else: + self.cbStructureFactor.setCurrentIndex(0) + # Apply saved parameter values + self._restore_parameter_values(params) + + def _apply_fit_options(self, options: dict) -> None: + """Apply fit options dict (called by undo/redo).""" + with self.undo_stack.suppressed(): + self.q_range_min = options.get('q_range_min', self.q_range_min) + self.q_range_max = options.get('q_range_max', self.q_range_max) + self.npts = options.get('npts', self.npts) + self.log_points = options.get('log_points', self.log_points) + self.weighting = options.get('weighting', self.weighting) + self.options_widget.setState( + self.q_range_min, self.q_range_max, + self.npts, self.log_points, self.weighting, + ) + self.lblMinRangeDef.setText(GuiUtils.formatNumber(self.q_range_min, high=True)) + self.lblMaxRangeDef.setText(GuiUtils.formatNumber(self.q_range_max, high=True)) + self.recalculatePlotData() + + def _apply_smearing_state(self, state: dict) -> None: + """Apply smearing options from a dict (called by undo/redo).""" + with self.undo_stack.suppressed(): + self.smearing_widget.setState( + state.get('smearing'), state.get('accuracy'), + state.get('d_down'), state.get('d_up'), + ) + self.calculateQGridForModel() + + def _restore_parameter_values(self, params: dict) -> None: + """Restore all kernel parameter values from a ``{name: value}`` dict.""" + with self.undo_stack.suppressed(): + for name, value in params.items(): + self.logic.kernel_module.setParam(name, value) + self._update_model_param_value(name, value) + self.calculateQGridForModel() + + def _update_poly_param_value(self, param_name: str, value: float) -> None: + """Update the polydispersity UI model item for *param_name* (``.width``).""" + poly_model = self.polydispersity_widget.poly_model + for row in range(poly_model.rowCount()): + name_item = poly_model.item(row, 0) + if name_item is None: + continue + row_name = str(name_item.text()).rsplit()[-1] + '.width' + if row_name == param_name: + value_item = poly_model.item(row, 1) + if value_item: + value_item.setText(GuiUtils.formatNumber(value, high=True)) + return + + def _update_magnet_param_value(self, param_name: str, value: float) -> None: + """Update the magnetism UI model item for *param_name*.""" + magnet_model = self.magnetism_widget._magnet_model + for row in range(magnet_model.rowCount()): + name_item = magnet_model.item(row, 0) + if name_item is None: + continue + if str(name_item.text()) == param_name: + value_item = magnet_model.item(row, 1) + if value_item: + value_item.setText(GuiUtils.formatNumber(value, high=True)) + return + + def _get_poly_param_dict(self, kernel_module=None) -> dict: + """Return ``{.width: value}`` for all polydisperse parameters. + + Parameter *names* are taken from the polydispersity UI model (stable + across a fit); *values* are read from ``kernel_module`` so a pre-fit + snapshot can be captured from ``kernel_module_copy`` (defaults to the + live kernel module). """ - new_page = FitPage() - self.saveToFitPage(new_page) + if kernel_module is None: + kernel_module = self.logic.kernel_module + if kernel_module is None: + return {} + poly_model = self.polydispersity_widget.poly_model + result = {} + for row in range(poly_model.rowCount()): + name_item = poly_model.item(row, 0) + if name_item is None: + continue + param_name = str(name_item.text()).rsplit()[-1] + '.width' + try: + result[param_name] = kernel_module.getParam(param_name) + except (KeyError, ValueError): + continue + return result - return new_page + def _get_magnet_param_dict(self, kernel_module=None) -> dict: + """Return ``{param_name: value}`` for all magnetism parameters. - def pushFitPage(self, new_page: FitPage) -> None: + As with :meth:`_get_poly_param_dict`, *values* are read from + ``kernel_module`` (defaults to the live kernel module) so a pre-fit + snapshot can be captured from ``kernel_module_copy``. """ - Add a new fit page object with current state + if kernel_module is None: + kernel_module = self.logic.kernel_module + if kernel_module is None: + return {} + magnet_model = self.magnetism_widget._magnet_model + result = {} + for row in range(magnet_model.rowCount()): + name_item = magnet_model.item(row, 0) + if name_item is None: + continue + param_name = str(name_item.text()) + try: + result[param_name] = kernel_module.getParam(param_name) + except (KeyError, ValueError): + continue + return result + + def _get_fit_result_snapshot(self, kernel_module=None) -> dict: + """Capture main, polydispersity and magnetism parameter values. + + Reads values from ``kernel_module`` (defaults to the live kernel + module). The pre-fit snapshot in ``fitComplete`` MUST be captured + from ``self.kernel_module_copy``, since the fit mutates the live + kernel module in place. + + Returns a structured dict consumed by ``FitResultCommand``:: + + {"main": {...}, "poly": {...}, "magnet": {...}} """ - self.page_stack.append(new_page) + return { + "main": self._get_parameter_dict(kernel_module), + "poly": self._get_poly_param_dict(kernel_module), + "magnet": self._get_magnet_param_dict(kernel_module), + } - def popFitPage(self) -> None: + def _restore_fit_result_snapshot(self, snapshot: dict) -> None: + """Restore main, polydispersity and magnetism values from a snapshot. + + A FitResultCommand snapshot only carries parameter *values*, not the + fitted uncertainties, so any error column shown in the tables is no + longer valid once the values are restored. Remove it. + """ + with self.undo_stack.suppressed(): + self._remove_error_columns() + for name, value in snapshot.get("main", {}).items(): + self.logic.kernel_module.setParam(name, value) + self._update_model_param_value(name, value) + for name, value in snapshot.get("poly", {}).items(): + self.logic.kernel_module.setParam(name, value) + self._update_poly_param_value(name, value) + for name, value in snapshot.get("magnet", {}).items(): + self.logic.kernel_module.setParam(name, value) + self._update_magnet_param_value(name, value) + self.calculateQGridForModel() + + def _remove_error_columns(self) -> None: + """Remove the fitted-uncertainty (error) column from the main, + polydispersity and magnetism tables, restoring the no-error layout. + + Safe to call when no error column is present (each section is guarded + by its ``has_*_error_column`` flag). """ - Remove top fit page from stack + # --- main model (and its polydispersity sub-rows) --- + if self.has_error_column: + def deletePolyErrorColumn(row): + item = self._model_model.item(row, 0) + if item is None or not item.hasChildren(): + return + poly_item = item.child(0) + if poly_item is None or not poly_item.hasChildren(): + return + poly_item.removeColumn(2) + + self._model_model.removeColumn(2) + self.iterateOverModel(deletePolyErrorColumn) + self.lstParams.itemDelegate().removeErrorColumn() + FittingUtilities.addHeadersToModel(self._model_model) + self.has_error_column = False + + # --- polydispersity model --- + if self.polydispersity_widget.has_poly_error_column: + self.polydispersity_widget.poly_model.removeColumn(2) + self.polydispersity_widget.lstPoly.itemDelegate().removeErrorColumn() + FittingUtilities.addPolyHeadersToModel(self.polydispersity_widget.poly_model) + self.polydispersity_widget.has_poly_error_column = False + + # --- magnetism model --- + if self.magnetism_widget.has_magnet_error_column: + self.magnetism_widget._magnet_model.removeColumn(2) + self.magnetism_widget.lstMagnetic.itemDelegate().removeErrorColumn() + FittingUtilities.addHeadersToModel(self.magnetism_widget._magnet_model) + self.magnetism_widget.has_magnet_error_column = False + + def _get_parameter_dict(self, kernel_module=None) -> dict: + """Return ``{param_name: float_value}`` for all current kernel params. + + Values are read from ``kernel_module`` (defaults to the live kernel + module) so a pre-fit snapshot can be captured from + ``kernel_module_copy``. """ - if self.page_stack: - self.page_stack.pop() + if kernel_module is None: + kernel_module = self.logic.kernel_module + if kernel_module is None: + return {} + return { + p.name: kernel_module.getParam(p.name) + for p in kernel_module._model_info.parameters.kernel_parameters + } + + def _get_fit_options_dict(self) -> dict: + """Return current fit options as a dict (for undo command capture).""" + return { + 'q_range_min': self.q_range_min, + 'q_range_max': self.q_range_max, + 'npts': self.npts, + 'log_points': self.log_points, + 'weighting': self.weighting, + } + + def _get_smearing_state_dict(self) -> dict: + """Return current smearing state as a dict (for undo command capture).""" + smearing, accuracy, d_down, d_up = self.smearing_widget.state() + return { + 'smearing': smearing, + 'accuracy': accuracy, + 'd_down': d_down, + 'd_up': d_up, + } def getReport(self) -> list[str]: """ diff --git a/src/sas/qtgui/Perspectives/Fitting/MagnetismWidget.py b/src/sas/qtgui/Perspectives/Fitting/MagnetismWidget.py index 985726e2a0..fee4c9d242 100644 --- a/src/sas/qtgui/Perspectives/Fitting/MagnetismWidget.py +++ b/src/sas/qtgui/Perspectives/Fitting/MagnetismWidget.py @@ -12,6 +12,10 @@ # Local UI from sas.qtgui.Perspectives.Fitting.UI.MagnetismWidget import Ui_MagnetismWidgetUI +from sas.qtgui.Perspectives.Fitting.UndoRedo import ( + ParameterMinMaxCommand, + ParameterValueCommand, +) from sas.qtgui.Perspectives.Fitting.ViewDelegate import MagnetismViewDelegate logger = logging.getLogger(__name__) @@ -31,6 +35,7 @@ def __init__(self, parent: QtWidgets.QWidget | None = None, logic: Any | None = self._magnet_model = FittingUtilities.ToolTippedItemModel() self.is2D = False self.isActive = False + self._fitting_widget = parent self.logic = parent.logic self.magnet_params = {} self.has_magnet_error_column = False @@ -138,18 +143,30 @@ def onMagnetModelChange(self, top: QtCore.QModelIndex, bottom: QtCore.QModelInde if model_column > 1: if model_column == delegate.mag_min: pos = 1 + bound = "min" elif model_column == delegate.mag_max: pos = 2 + bound = "max" elif model_column == delegate.mag_unit: pos = 0 + bound = None else: # For all other values sent here (e.g. the error column, do nothing) return # min/max to be changed in self.logic.kernel_module.details[parameter_name] = ['Ang', 0.0, inf] + old_val = self.logic.kernel_module.details[parameter_name][pos] self.logic.kernel_module.details[parameter_name][pos] = value + if bound is not None: + self._fitting_widget.undo_stack.push( + ParameterMinMaxCommand(parameter_name, bound, old_val, value) + ) else: + old_val = self.logic.kernel_module.getParam(parameter_name) self.magnet_params[parameter_name] = value self.logic.kernel_module.setParam(parameter_name, value) + self._fitting_widget.undo_stack.push( + ParameterValueCommand(parameter_name, old_val, value) + ) # Update plot self.updateDataSignal.emit() diff --git a/src/sas/qtgui/Perspectives/Fitting/OptionsWidget.py b/src/sas/qtgui/Perspectives/Fitting/OptionsWidget.py index 5577ce19ea..92cd9b9c69 100644 --- a/src/sas/qtgui/Perspectives/Fitting/OptionsWidget.py +++ b/src/sas/qtgui/Perspectives/Fitting/OptionsWidget.py @@ -239,14 +239,24 @@ def updateQRange(self, q_range_min, q_range_max, npts): """ Update the local model based on calculated values """ - qmax = str(q_range_max) - qmin = str(q_range_min) - self.model.item(self.MODEL.index('MIN_RANGE')).setText(qmin) - self.model.item(self.MODEL.index('MAX_RANGE')).setText(qmax) - self.model.item(self.MODEL.index('NPTS')).setText(str(npts)) - self.qmin, self.qmax, self.npts = q_range_min, q_range_max, npts - npts_fit = self.npts2fit(self.logic.data, self.qmin, self.qmax, self.npts) - self.model.item(self.MODEL.index('NPTS_FIT')).setText(str(npts_fit)) + # Block signals to prevent intermediate dataChanged→onModelChange→ + # plot_signal firing for each individual setText. Without this, + # onOptionsUpdate receives partially-updated state and can push + # multiple spurious FitOptionsCommand entries onto the undo stack. + self.model.blockSignals(True) + try: + qmax = str(q_range_max) + qmin = str(q_range_min) + self.model.item(self.MODEL.index('MIN_RANGE')).setText(qmin) + self.model.item(self.MODEL.index('MAX_RANGE')).setText(qmax) + self.model.item(self.MODEL.index('NPTS')).setText(str(npts)) + self.qmin, self.qmax, self.npts = q_range_min, q_range_max, npts + npts_fit = self.npts2fit(self.logic.data, self.qmin, self.qmax, self.npts) + self.model.item(self.MODEL.index('NPTS_FIT')).setText(str(npts_fit)) + finally: + self.model.blockSignals(False) + # Single signal after all values are consistent + self.plot_signal.emit() def state(self): """ @@ -259,6 +269,26 @@ def state(self): log_points = self.chkLogData.isChecked() return (q_range_min, q_range_max, npts, log_points, self.weighting) + def setState(self, q_range_min, q_range_max, npts, log_points, weighting): + """ + Set the state of controls from provided values. + Used by undo/redo to restore fit options. + """ + self.model.blockSignals(True) + self.updateQRange(q_range_min, q_range_max, npts) + self.chkLogData.setChecked(log_points) + self.weighting = weighting + # Update the weighting radio buttons to match + buttons = self.weightingGroup.buttons() + for btn in buttons: + btn_id = abs(self.weightingGroup.id(btn) + 2) + if btn_id == weighting: + btn.setChecked(True) + break + self.model.blockSignals(False) + # Refresh the QDataWidgetMapper so text fields reflect the model + self.mapper.toFirst() + def npts2fit(self, data=None, qmin=None, qmax=None, npts=None): """ return numbers of data points within qrange diff --git a/src/sas/qtgui/Perspectives/Fitting/PolydispersityWidget.py b/src/sas/qtgui/Perspectives/Fitting/PolydispersityWidget.py index d6cf00291d..d2f26e2037 100644 --- a/src/sas/qtgui/Perspectives/Fitting/PolydispersityWidget.py +++ b/src/sas/qtgui/Perspectives/Fitting/PolydispersityWidget.py @@ -15,6 +15,10 @@ # Local UI from sas.qtgui.Perspectives.Fitting.UI.PolydispersityWidget import Ui_PolydispersityWidgetUI +from sas.qtgui.Perspectives.Fitting.UndoRedo import ( + ParameterMinMaxCommand, + ParameterValueCommand, +) from sas.qtgui.Perspectives.Fitting.ViewDelegate import PolyViewDelegate DEFAULT_POLYDISP_FUNCTION = 'gaussian' @@ -34,6 +38,7 @@ def __init__(self, parent: QtWidgets.QWidget | None = None) -> None: self.poly_model = FittingUtilities.ToolTippedItemModel() self.is2D = False self.isActive = False + self._fitting_widget = parent self.logic = parent.logic self.poly_params = {} self.has_poly_error_column = False @@ -139,7 +144,6 @@ def onPolyModelChange(self, top: QtCore.QModelIndex) -> None: if parameter_name_w in self.poly_params_to_fit: self.poly_params_to_fit.remove(parameter_name_w) self.cmdFitSignal.emit() - # self.updateUndo() elif model_column in [delegate.poly_min, delegate.poly_max]: try: @@ -151,9 +155,15 @@ def onPolyModelChange(self, top: QtCore.QModelIndex) -> None: current_details = self.logic.kernel_module.details[parameter_name_w] if self.has_poly_error_column: # err column changes the indexing - current_details[model_column-2] = value + pos = model_column - 2 else: - current_details[model_column-1] = value + pos = model_column - 1 + old_val = current_details[pos] + current_details[pos] = value + bound = "min" if pos == 1 else "max" + self._fitting_widget.undo_stack.push( + ParameterMinMaxCommand(parameter_name_w, bound, old_val, value) + ) elif model_column == delegate.poly_function: # name of the function - just pass @@ -172,8 +182,12 @@ def onPolyModelChange(self, top: QtCore.QModelIndex) -> None: # Map the column to the poly param that was changed associations = {1: "width", delegate.poly_npts: "npts", delegate.poly_nsigs: "nsigmas"} p_name = f"{parameter_name}.{associations.get(model_column, 'width')}" + old_val = self.logic.kernel_module.getParam(p_name) self.poly_params[p_name] = value self.logic.kernel_module.setParam(p_name, value) + self._fitting_widget.undo_stack.push( + ParameterValueCommand(p_name, old_val, value) + ) # Update plot self.updateDataSignal.emit() diff --git a/src/sas/qtgui/Perspectives/Fitting/UndoRedo.py b/src/sas/qtgui/Perspectives/Fitting/UndoRedo.py new file mode 100644 index 0000000000..9bb4f7b7b4 --- /dev/null +++ b/src/sas/qtgui/Perspectives/Fitting/UndoRedo.py @@ -0,0 +1,269 @@ +"""Undo/Redo commands specific to the SasView Fitting perspective. + +The shared base classes (UndoCommand, CompoundCommand, UndoStack, +DictSnapshotCommand) live in ``sas.qtgui.Perspectives.UndoRedo``. + +This module defines Fitting-specific subclasses: +- ParameterValueCommand +- ParameterMinMaxCommand +- ModelSelectionCommand +- FitOptionsCommand +- SmearingOptionsCommand +- CheckboxToggleCommand +- FitResultCommand + +Design notes: +- Each command stores only old_value + new_value (delta, not snapshot). +- ParameterValueCommand supports coalescing: consecutive edits to the same + parameter are merged into one entry. +- Parameter 'fit' checkbox toggles are intentionally NOT tracked + (see UNDO_PLAN_CLAUDE.md, Decisions). +""" +from __future__ import annotations + +import logging +from typing import Any + +# Base class for Fitting-specific commands — imported from shared module +from sas.qtgui.Perspectives.UndoRedo import UndoCommand # noqa: E402 + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Concrete commands — Fitting-specific +# --------------------------------------------------------------------------- + +class ParameterValueCommand(UndoCommand): + """Single parameter value change. + + Applies via ``widget.logic.kernel_module.setParam()`` and + ``widget._update_model_param_value()`` (added in Phase 2). + + Supports coalescing: two consecutive edits to the same parameter are + merged into one entry whose undo reverts all the way to the first + captured value. + """ + + def __init__(self, param_name: str, old_val: float, new_val: float) -> None: + super().__init__(f"Change {param_name}") + self._param_name = param_name + self._old_val = old_val + self._new_val = new_val + + @property + def param_name(self) -> str: + return self._param_name + + def _apply(self, widget, value: float) -> None: + widget.logic.kernel_module.setParam(self._param_name, value) + widget._update_model_param_value(self._param_name, value) + # Recompute the theory so the plot reflects the restored input value. + # The snapshot-based restore paths (_restore_parameter_values, + # _apply_fit_options, _restore_fit_result_snapshot, ...) all end with a + # recompute; without this an undo/redo updates the parameter table but + # leaves the plotted curve stale, desyncing the view from the params. + widget.calculateQGridForModel() + + def undo(self, widget) -> None: + self._apply(widget, self._old_val) + + def redo(self, widget) -> None: + self._apply(widget, self._new_val) + + #: Maximum age difference (seconds) between two edits to the same parameter + #: that may still be coalesced into a single undo entry. Edits farther + #: apart than this are treated as independent actions. + _COALESCE_WINDOW_SECONDS: float = 5.0 + + def can_merge(self, other: UndoCommand) -> bool: + return ( + isinstance(other, ParameterValueCommand) + and other._param_name == self._param_name + and (other.timestamp - self.timestamp) <= self._COALESCE_WINDOW_SECONDS + ) + + def merge(self, other: ParameterValueCommand) -> ParameterValueCommand: + """Merge *self* (earlier) with *other* (later). + + The merged command undoes all the way to *self*'s old value and + redoes all the way to *other*'s new value. The *other* timestamp + (latest edit) is carried forward so that the coalescing window is + measured from the most recent edit, not the first one in the group. + """ + merged = ParameterValueCommand(self._param_name, self._old_val, other._new_val) + merged.timestamp = other.timestamp + return merged + + +class ParameterMinMaxCommand(UndoCommand): + """Parameter min or max bound change. + + ``bound`` must be ``"min"`` or ``"max"``. + Writes directly to ``kernel_module.details[param_name][1 or 2]`` and + delegates UI item update to ``widget._update_model_param_limit()`` + (added in Phase 2). + """ + + _BOUND_INDEX: dict[str, int] = {"min": 1, "max": 2} + + def __init__( + self, param_name: str, bound: str, old_val: float, new_val: float + ) -> None: + assert bound in ("min", "max"), ( + f"bound must be 'min' or 'max', got {bound!r}" + ) + super().__init__(f"Change {param_name} {bound}") + self._param_name = param_name + self._bound = bound + self._old_val = old_val + self._new_val = new_val + + def _apply(self, widget, value: float) -> None: + idx = self._BOUND_INDEX[self._bound] + widget.logic.kernel_module.details[self._param_name][idx] = value + widget._update_model_param_limit(self._param_name, self._bound, value) + # Keep the plotted theory in sync with the restored bound, matching the + # snapshot-based restore paths (see ParameterValueCommand._apply). + widget.calculateQGridForModel() + + def undo(self, widget) -> None: + self._apply(widget, self._old_val) + + def redo(self, widget) -> None: + self._apply(widget, self._new_val) + + +class ModelSelectionCommand(UndoCommand): + """Category / model / structure-factor triple change. + + ``old_triple`` / ``new_triple`` are ``(category, model, structure_factor)``. + ``old_params`` / ``new_params`` are ``{param_name: value}`` dicts. + + On undo the old model triple is re-selected (triggering a param table + rebuild) and old parameter *values* are re-applied. UI micro-state + (expanded rows, active editor) is NOT restored — values only. + + The entire replay must run inside ``undo_stack.suppressed()`` (handled + in Phase 2) to prevent the internal rebuild from creating spurious + stack entries. + + Delegates to ``widget._restore_model_selection(triple, params)`` + (added in Phase 2). + """ + + def __init__( + self, + old_triple: tuple[str, str, str], + new_triple: tuple[str, str, str], + old_params: dict[str, float], + new_params: dict[str, float], + ) -> None: + super().__init__(f"Select model {new_triple[1]!r}") + self._old_triple = old_triple + self._new_triple = new_triple + self._old_params = dict(old_params) + self._new_params = dict(new_params) + + def undo(self, widget) -> None: + widget._restore_model_selection(self._old_triple, self._old_params) + + def redo(self, widget) -> None: + widget._restore_model_selection(self._new_triple, self._new_params) + + +class FitOptionsCommand(UndoCommand): + """Q range, npts, log_points, weighting changes. + + Delegates to ``widget._apply_fit_options(options)`` (added in Phase 2). + """ + + def __init__( + self, old_options: dict[str, Any], new_options: dict[str, Any] + ) -> None: + super().__init__("Change fit options") + self._old_options = dict(old_options) + self._new_options = dict(new_options) + + def undo(self, widget) -> None: + widget._apply_fit_options(self._old_options) + + def redo(self, widget) -> None: + widget._apply_fit_options(self._new_options) + + +class SmearingOptionsCommand(UndoCommand): + """Smearing state change. + + Delegates to ``widget._apply_smearing_state(state)`` (added in Phase 2). + """ + + def __init__( + self, old_state: dict[str, Any], new_state: dict[str, Any] + ) -> None: + super().__init__("Change smearing options") + self._old_state = dict(old_state) + self._new_state = dict(new_state) + + def undo(self, widget) -> None: + widget._apply_smearing_state(self._old_state) + + def redo(self, widget) -> None: + widget._apply_smearing_state(self._new_state) + + +class CheckboxToggleCommand(UndoCommand): + """Polydispersity / magnetism / 2D-view toggle. + + ``checkbox_id`` is the attribute name of the QCheckBox on *widget* + (e.g. ``"chkPolydispersity"``). + + Note: parameter 'fit' checkbox toggles are intentionally NOT tracked. + See UNDO_PLAN_CLAUDE.md, Decisions for rationale. + """ + + def __init__(self, checkbox_id: str, old_bool: bool, new_bool: bool) -> None: + super().__init__(f"Toggle {checkbox_id}") + self._checkbox_id = checkbox_id + self._old_bool = old_bool + self._new_bool = new_bool + + def _apply(self, widget, value: bool) -> None: + getattr(widget, self._checkbox_id).setChecked(value) + + def undo(self, widget) -> None: + self._apply(widget, self._old_bool) + + def redo(self, widget) -> None: + self._apply(widget, self._new_bool) + + +class FitResultCommand(UndoCommand): + """Full parameter snapshot before and after a fit. + + The snapshots are structured dicts of the form:: + + {"main": {name: value}, "poly": {name.width: value}, "magnet": {name: value}} + + covering the main kernel parameters as well as the polydispersity-width + and magnetism parameters that a fit can also modify. + + ``old_snapshot`` MUST be captured at the very start of ``fitComplete()``, + before ``updateModelFromList()`` is called (see UNDO_PLAN_CLAUDE.md, + Step 2.6 — Critical ordering). + + Delegates to ``widget._restore_fit_result_snapshot(snapshot)`` + (added in Phase 2). + """ + + def __init__( + self, old_snapshot: dict[str, dict], new_snapshot: dict[str, dict] + ) -> None: + super().__init__("Fit result") + self._old_snapshot = {key: dict(val) for key, val in old_snapshot.items()} + self._new_snapshot = {key: dict(val) for key, val in new_snapshot.items()} + + def undo(self, widget) -> None: + widget._restore_fit_result_snapshot(self._old_snapshot) + + def redo(self, widget) -> None: + widget._restore_fit_result_snapshot(self._new_snapshot) diff --git a/src/sas/qtgui/Perspectives/Fitting/UnitTesting/FittingPerspectiveTest.py b/src/sas/qtgui/Perspectives/Fitting/UnitTesting/FittingPerspectiveTest.py index 33616d8808..21870c0758 100644 --- a/src/sas/qtgui/Perspectives/Fitting/UnitTesting/FittingPerspectiveTest.py +++ b/src/sas/qtgui/Perspectives/Fitting/UnitTesting/FittingPerspectiveTest.py @@ -215,6 +215,19 @@ def testGetFitTabs(self, widget): assert isinstance(tabs, list) assert len(tabs) == 2 + def testUndoStackDelegatesToCurrentFittingTab(self, widget): + '''undo_stack should delegate to the active fitting tab''' + current_fitting_widget = widget.currentFittingWidget + assert current_fitting_widget is not None + assert widget.undo_stack is current_fitting_widget.undo_stack + + def testUndoStackIsNoneForNonFittingTab(self, widget): + '''undo_stack should be None when a non-fitting tab is active''' + widget.addConstraintTab() + widget.setCurrentIndex(widget.count() - 1) + assert widget.currentFittingWidget is None + assert widget.undo_stack is None + @pytest.mark.xfail(reason="2026-02: Mocker patching using old constraint API") def testGetActiveConstraintList(self, widget, mocker): '''test the active constraint getter''' diff --git a/src/sas/qtgui/Perspectives/Fitting/UnitTesting/FittingWidgetTest.py b/src/sas/qtgui/Perspectives/Fitting/UnitTesting/FittingWidgetTest.py index c9b7eea8f9..5214d8c8f8 100644 --- a/src/sas/qtgui/Perspectives/Fitting/UnitTesting/FittingWidgetTest.py +++ b/src/sas/qtgui/Perspectives/Fitting/UnitTesting/FittingWidgetTest.py @@ -15,6 +15,7 @@ from sas.qtgui.Perspectives.Fitting import FittingUtilities, FittingWidget from sas.qtgui.Perspectives.Fitting.Constraint import Constraint from sas.qtgui.Perspectives.Fitting.ModelThread import Calc2D +from sas.qtgui.Perspectives.UndoRedo import UndoStack from sas.qtgui.Plotting.PlotterData import Data1D, Data2D from sas.qtgui.UnitTesting.TestUtils import QtSignalSpy from sas.qtgui.Utilities import GuiUtils @@ -274,8 +275,10 @@ def testSelectCategory(self, widget): # Try to change back to default widget.cbCategory.setCurrentIndex(0) - # Observe no such luck - assert widget.cbCategory.currentIndex() == 8 + # Observe no such luck - it bounces back to the previously chosen + # category. (Use the looked-up index rather than a hard-coded one; + # the category positions depend on the installed model set.) + assert widget.cbCategory.currentIndex() == category_index assert widget.cbModel.count() == 29 # Set the structure factor @@ -327,11 +330,16 @@ def testSelectModel(self, widget, mocker): # Observe calculateQGridForModel called assert widget.calculateQGridForModel.called - def testSelectFactor(self, widget): + def testSelectFactor(self, widget, mocker): """ Assure proper behaviour on changing structure factor """ widget.show() + # Mock the background calculation. Selecting a model/structure factor + # kicks off an asynchronous Q-grid calculation that briefly disables the + # interactive controls (incl. cmdPlot) until it completes. Mocking it + # keeps the control-enablement assertions below from racing that thread. + mocker.patch.object(widget, 'calculateQGridForModel') # Change the category index so we have some models category_index = widget.cbCategory.findText("Cylinder") widget.cbCategory.setCurrentIndex(category_index) @@ -368,9 +376,12 @@ def testSelectFactor(self, widget): assert not widget.cbModel.isEnabled() assert widget._model_model.rowCount() == 0 - # Choose the last factor - last_index = widget.cbStructureFactor.count() - widget.cbStructureFactor.setCurrentIndex(last_index-1) + # Choose a structure factor with a known parameter set. hayter_msa has + # 6 parameters which, together with the heading row, gives 7 rows. + # (Don't rely on "the last factor" - the available structure factors + # depend on the installed sasmodels version.) + sf_index = widget.cbStructureFactor.findText('hayter_msa') + widget.cbStructureFactor.setCurrentIndex(sf_index) # Do we have all the rows (incl. radius_effective & heading row)? assert widget._model_model.rowCount() == 7 @@ -527,7 +538,7 @@ def testPolyModelChange(self, widget): widget.cbModel.setCurrentIndex(model_index) # click on a poly parameter checkbox - index = widget.polydispersity_widget.poly_model.index(0,0) + widget.polydispersity_widget.poly_model.index(0,0) # Set the checbox widget.polydispersity_widget.poly_model.item(0,0).setCheckState(QtCore.Qt.CheckState.Checked) @@ -685,7 +696,7 @@ def testSetMagneticModel(self, widget): # Test rows for row in range(widget.magnetism_widget._magnet_model.rowCount()): - func_index = widget.magnetism_widget._magnet_model.index(row, 0) + widget.magnetism_widget._magnet_model.index(row, 0) assert '_' in widget.magnetism_widget._magnet_model.item(row, 0).text() @@ -922,7 +933,7 @@ def notestOnFit2D(self, widget): # Spying on status update signal update_spy = QtSignalSpy(widget, widget.communicator.statusBarUpdateSignal) - with threads.deferToThread as MagicMock: + with threads.deferToThread: widget.onFit() # thread called assert threads.deferToThread.called @@ -1101,44 +1112,6 @@ def testCurrentState(self, widget): assert fp.current_model == "adsorbed_layer" assert fp.main_params_to_fit == ['scale'] - def notestPushFitPage(self, widget): - """ - Push current state of fitpage onto stack - """ - # Set data - test_data = Data1D(x=[1,2], y=[1,2]) - item = QtGui.QStandardItem() - GuiUtils.updateModelItem(item, test_data, "test") - # Force same data into logic - widget.data = item - category_index = widget.cbCategory.findText("Sphere") - model_index = widget.cbModel.findText("adsorbed_layer") - widget.cbModel.setCurrentIndex(model_index) - - # Asses the initial state of stack - assert widget.page_stack == [] - - # Set the undo flag - widget.undo_supported = True - widget.cbCategory.setCurrentIndex(category_index) - widget.main_params_to_fit = ['scale'] - - # Check that the stack is updated - assert len(widget.page_stack) == 1 - - # Change another parameter - widget._model_model.item(3, 1).setText("3.0") - - # Check that the stack is updated - assert len(widget.page_stack) == 2 - - def testPopFitPage(self, widget): - """ - Pop current state of fitpage from stack - """ - # TODO: to be added when implementing UNDO/REDO - pass - def testOnMainPageChange(self, widget): """ Test update values of modified parameters in models @@ -1799,3 +1772,71 @@ def testQRangeReset(self, widget, mocker): assert widget_with_data.options_widget.qmin == min(data.x) assert widget_with_data.options_widget.qmax == max(data.x) assert widget_with_data.options_widget.npts == len(data.x) + + +class FittingWidgetUndoRedoTest: + """Basic undo/redo integration tests (Phase 5, Step 5.3).""" + + @pytest.fixture(autouse=True) + def widget(self, qapp, monkeypatch): + """Create/Destroy the GUI (monkeypatch variant).""" + w = FittingWidgetMod(dummy_manager()) + monkeypatch.setattr(FittingUtilities, 'checkConstraints', lambda *a, **kw: None) + yield w + w.close() + + def testWidgetHasUndoStack(self, widget): + """FittingWidget should own an UndoStack instance.""" + assert hasattr(widget, 'undo_stack') + assert isinstance(widget.undo_stack, UndoStack) + assert not widget.undo_stack.can_undo() + assert not widget.undo_stack.can_redo() + + def testParamEditUndoRedo(self, widget): + """Edit a parameter → undo restores old value → redo restores new value.""" + # Load a model + category_index = widget.cbCategory.findText("Cylinder") + widget.cbCategory.setCurrentIndex(category_index) + model_index = widget.cbModel.findText("cylinder") + widget.cbModel.setCurrentIndex(model_index) + widget.undo_stack.clear() + + # Find the 'radius' row and edit it + param_col = widget.lstParams.itemDelegate().param_value + for row in range(widget._model_model.rowCount()): + if widget._model_model.item(row, 0).text() == "radius": + item = widget._model_model.item(row, param_col) + old_val = float(item.text()) + item.setText("99.0") + break + else: + pytest.fail("'radius' parameter not found") + + # Verify undo restores original + assert widget.undo_stack.can_undo() + widget.undo_stack.undo() + assert widget.logic.kernel_module.getParam("radius") == old_val + + # Verify redo re-applies the edit + assert widget.undo_stack.can_redo() + widget.undo_stack.redo() + assert widget.logic.kernel_module.getParam("radius") == 99.0 + + def testUndoStackClearedOnModelChange(self, widget): + """Selecting a new model should not carry stale undo entries.""" + category_index = widget.cbCategory.findText("Cylinder") + widget.cbCategory.setCurrentIndex(category_index) + model_index = widget.cbModel.findText("cylinder") + widget.cbModel.setCurrentIndex(model_index) + + # Edit a parameter to create an undo entry + param_col = widget.lstParams.itemDelegate().param_value + for row in range(widget._model_model.rowCount()): + if widget._model_model.item(row, 0).text() == "radius": + widget._model_model.item(row, param_col).setText("99.0") + break + assert widget.undo_stack.can_undo() + + # Stack should still function (undo works) + widget.undo_stack.undo() + assert not widget.undo_stack.can_undo() diff --git a/src/sas/qtgui/Perspectives/Fitting/UnitTesting/UndoRedoIntegrationTest.py b/src/sas/qtgui/Perspectives/Fitting/UnitTesting/UndoRedoIntegrationTest.py new file mode 100644 index 0000000000..af5e590412 --- /dev/null +++ b/src/sas/qtgui/Perspectives/Fitting/UnitTesting/UndoRedoIntegrationTest.py @@ -0,0 +1,930 @@ +"""Integration tests for undo/redo wiring in FittingWidget (Phase 2). + +These tests verify that FittingWidget handlers push the correct undo +commands to the UndoStack in response to user-initiated state changes. + +Unlike UndoRedoTest.py (which tests Phase 1 command/stack internals in +isolation with a mocked widget), these tests spin up a real FittingWidget +and exercise its actual signal/slot plumbing. + +Test organisation: + TestUndoStackInitialization — stack exists, initial state correct + TestParameterValueUndo — main params, poly, magnetism + TestParameterMinMaxUndo — min/max edits via main model + TestModelSelectionUndo — switching models via cbModel + TestFitOptionsUndo — Q range / npts / weighting changes + TestSmearingOptionsUndo — smearing state changes + TestCheckboxToggleUndo — poly / magnetism / 2D toggles + TestFitResultUndo — undo after a fit completes + TestUndoStackDuringFit — stack stays enabled while fitting + TestGuiManagerUndoHookSignalChain — GUI action refresh signal chain + TestOptionsWidgetSetState — OptionsWidget.setState() round-trip +""" +import glob +import os +from unittest.mock import MagicMock + +import pytest +from PySide6 import QtCore, QtGui, QtWidgets + +from sasmodels.sasview_model import load_custom_model + +from sas.qtgui.Perspectives.Fitting import FittingUtilities, FittingWidget +from sas.qtgui.Perspectives.Fitting.UndoRedo import ( + CheckboxToggleCommand, + FitOptionsCommand, + FitResultCommand, + ModelSelectionCommand, + ParameterMinMaxCommand, + ParameterValueCommand, + SmearingOptionsCommand, +) +from sas.qtgui.Perspectives.UndoRedo import UndoStack +from sas.qtgui.Plotting.PlotterData import Data1D +from sas.qtgui.Utilities import GuiUtils +from sas.sascalc.fit.models import ModelManager, ModelManagerBase + +# --------------------------------------------------------------------------- +# Helpers (same pattern as FittingWidgetTest.py) +# --------------------------------------------------------------------------- + +class dummy_manager: + HELP_DIRECTORY_LOCATION = "html" + communicator = GuiUtils.communicator + + def __init__(self): + self._perspective = dummy_perspective() + + def perspective(self): + return self._perspective + + +class dummy_perspective: + + def __init__(self): + self.symbol_dict = {} + self.constraint_list = [] + self.constraint_tab = None + + def getActiveConstraintList(self): + return self.constraint_list + + def getSymbolDictForConstraints(self): + return self.symbol_dict + + def getConstraintTab(self): + return self.constraint_tab + + +def find_plugin_models_mod(): + plugins_dir = [ + os.path.abspath(path) for path in glob.glob("**/plugin_models", recursive=True) + if os.path.normpath("qtgui/Perspectives/Fitting/plugin_models") in os.path.abspath(path) + ][0] + plugins = {} + for filename in os.listdir(plugins_dir): + name, ext = os.path.splitext(filename) + if ext == '.py' and not name == '__init__': + path = os.path.abspath(os.path.join(plugins_dir, filename)) + model = load_custom_model(path) + plugins[model.name] = model + return plugins + + +class ModelManagerBaseMod(ModelManagerBase): + def _is_plugin_dir_changed(self): + return False + + def plugins_reset(self): + self.plugin_models = find_plugin_models_mod() + self.model_dictionary.clear() + self.model_dictionary.update(self.standard_models) + self.model_dictionary.update(self.plugin_models) + return self.get_model_list() + + +class ModelManagerMod(ModelManager): + base = None + + def __init__(self): + if ModelManagerMod.base is None: + ModelManagerMod.base = ModelManagerBaseMod() + + +class FittingWidgetMod(FittingWidget.FittingWidget): + def customModels(cls): + manager = ModelManagerMod() + manager.update() + return manager.base.plugin_models + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +@pytest.fixture(autouse=True) +def _suppress_message_boxes(monkeypatch): + """Suppress QMessageBox dialogs globally.""" + monkeypatch.setattr( + "sas.qtgui.Perspectives.UndoRedo.QtWidgets.QMessageBox", + MagicMock(), + ) + + +@pytest.fixture +def widget(qapp, monkeypatch): + """Create a real FittingWidget for integration testing.""" + w = FittingWidgetMod(dummy_manager()) + monkeypatch.setattr(FittingUtilities, 'checkConstraints', lambda *a, **kw: None) + yield w + w.close() + del w + + +@pytest.fixture +def widget_with_model(widget): + """FittingWidget with 'cylinder' model loaded and processEvents run.""" + category_index = widget.cbCategory.findText("Cylinder") + widget.cbCategory.setCurrentIndex(category_index) + model_index = widget.cbModel.findText("cylinder") + widget.cbModel.setCurrentIndex(model_index) + QtWidgets.QApplication.processEvents() + # Clear the undo stack so model-load commands don't interfere with tests + widget.undo_stack.clear() + return widget + + +# --------------------------------------------------------------------------- +# TestUndoStackInitialization +# --------------------------------------------------------------------------- + +class TestUndoStackInitialization: + + def test_widget_has_undo_stack(self, widget): + assert hasattr(widget, 'undo_stack') + assert isinstance(widget.undo_stack, UndoStack) + + def test_initial_stack_is_empty(self, widget): + assert not widget.undo_stack.can_undo() + assert not widget.undo_stack.can_redo() + + def test_stack_widget_reference(self, widget): + """The stack must reference the correct widget for command replay.""" + assert widget.undo_stack._widget is widget + + +# --------------------------------------------------------------------------- +# TestParameterValueUndo +# --------------------------------------------------------------------------- + +class TestParameterValueUndo: + + def test_main_param_change_pushes_value_command(self, widget_with_model): + """Editing a model parameter value should push a ParameterValueCommand.""" + w = widget_with_model + # Find the 'radius' row + param_col = w.lstParams.itemDelegate().param_value + for row in range(w._model_model.rowCount()): + name = w._model_model.item(row, 0).text() + if name == "radius": + item = w._model_model.item(row, param_col) + old_val = float(item.text()) + item.setText("99.0") + break + else: + pytest.fail("Could not find 'radius' parameter in the model") + + assert w.undo_stack.can_undo() + top_cmd = w.undo_stack._undo_stack[-1] + assert isinstance(top_cmd, ParameterValueCommand) + assert top_cmd.param_name == "radius" + assert top_cmd._new_val == 99.0 + assert top_cmd._old_val == old_val + + def test_main_param_undo_restores_value(self, widget_with_model): + """Undo should restore the original kernel parameter value.""" + w = widget_with_model + param_col = w.lstParams.itemDelegate().param_value + for row in range(w._model_model.rowCount()): + name = w._model_model.item(row, 0).text() + if name == "radius": + old_val = float(w._model_model.item(row, param_col).text()) + w._model_model.item(row, param_col).setText("99.0") + break + + assert w.logic.kernel_module.getParam("radius") == 99.0 + w.undo_stack.undo() + assert w.logic.kernel_module.getParam("radius") == old_val + + def test_value_undo_redo_recomputes_plot(self, widget_with_model, mocker): + """Undo/redo of a value change must recompute the theory so the plot + stays in sync with the parameter table. + + Regression for the desync krzywon reported on PR #3995: undoing a + parameter value updated the table cell but left the plotted curve + stale, because ParameterValueCommand did not recompute (unlike every + snapshot-based restore path).""" + w = widget_with_model + w.undo_stack.push(ParameterValueCommand("radius", 10.0, 99.0)) + spy = mocker.patch.object(w, "calculateQGridForModel") + + w.undo_stack.undo() + assert spy.call_count >= 1, "undo did not recompute the theory plot" + + spy.reset_mock() + w.undo_stack.redo() + assert spy.call_count >= 1, "redo did not recompute the theory plot" + + def test_checkbox_column_does_not_push_command(self, widget_with_model): + """Toggling the 'fit' checkbox (column 0) must NOT push an undo command.""" + w = widget_with_model + initial_count = len(w.undo_stack._undo_stack) + item = w._model_model.item(0, 0) + item.setCheckState( + QtCore.Qt.Unchecked if item.checkState() == QtCore.Qt.Checked + else QtCore.Qt.Checked + ) + assert len(w.undo_stack._undo_stack) == initial_count + + +# --------------------------------------------------------------------------- +# TestDatasetUndoHistory +# --------------------------------------------------------------------------- + +class TestDatasetUndoHistory: + + def test_data_load_clears_existing_undo_history(self, widget_with_model, monkeypatch): + """Loading replacement data must drop undo commands from the previous data.""" + w = widget_with_model + w.undo_stack.push(ParameterValueCommand("radius", 1.0, 2.0)) + assert w.undo_stack.can_undo() + + data = Data1D(x=[0.01, 0.02], y=[1.0, 2.0]) + monkeypatch.setattr(GuiUtils, 'dataFromItem', lambda *a, **kw: data) + + w.dataFromItems(QtGui.QStandardItem("replacement data")) + + assert not w.undo_stack.can_undo() + assert not w.undo_stack.can_redo() + + def test_batch_file_switch_clears_existing_undo_history(self, widget_with_model, monkeypatch): + """Batch file selection changes the active logic, so undo history must reset.""" + w = widget_with_model + w.undo_stack.push(ParameterValueCommand("radius", 1.0, 2.0)) + assert w.undo_stack.can_undo() + + monkeypatch.setattr(w, 'updateQRange', lambda *a, **kw: None) + + w.onSelectBatchFilename(1) + + assert w.data_index == 1 + assert not w.undo_stack.can_undo() + assert not w.undo_stack.can_redo() + + +# --------------------------------------------------------------------------- +# TestParameterMinMaxUndo +# --------------------------------------------------------------------------- + +class TestParameterMinMaxUndo: + + def test_min_change_pushes_minmax_command(self, widget_with_model): + """Editing a min bound should push a ParameterMinMaxCommand.""" + w = widget_with_model + min_col = w.lstParams.itemDelegate().param_min + for row in range(w._model_model.rowCount()): + name = w._model_model.item(row, 0).text() + if name == "radius": + w._model_model.item(row, min_col).setText("1.0") + break + + top_cmd = w.undo_stack._undo_stack[-1] + assert isinstance(top_cmd, ParameterMinMaxCommand) + assert top_cmd._param_name == "radius" + assert top_cmd._bound == "min" + assert top_cmd._new_val == 1.0 + + def test_max_change_pushes_minmax_command(self, widget_with_model): + """Editing a max bound should push a ParameterMinMaxCommand.""" + w = widget_with_model + max_col = w.lstParams.itemDelegate().param_max + for row in range(w._model_model.rowCount()): + name = w._model_model.item(row, 0).text() + if name == "radius": + w._model_model.item(row, max_col).setText("500.0") + break + + top_cmd = w.undo_stack._undo_stack[-1] + assert isinstance(top_cmd, ParameterMinMaxCommand) + assert top_cmd._bound == "max" + assert top_cmd._new_val == 500.0 + + def test_minmax_undo_redo_recomputes_plot(self, widget_with_model, mocker): + """Undo/redo of a bound change must recompute the theory, keeping the + plot in sync with the parameter table (see + TestParameterValueUndo.test_value_undo_redo_recomputes_plot).""" + w = widget_with_model + w.undo_stack.push(ParameterMinMaxCommand("radius", "min", 1.0, 5.0)) + spy = mocker.patch.object(w, "calculateQGridForModel") + + w.undo_stack.undo() + assert spy.call_count >= 1, "undo did not recompute the theory plot" + + spy.reset_mock() + w.undo_stack.redo() + assert spy.call_count >= 1, "redo did not recompute the theory plot" + + +# --------------------------------------------------------------------------- +# TestModelSelectionUndo +# --------------------------------------------------------------------------- + +class TestModelSelectionUndo: + + def test_model_switch_pushes_model_selection_command(self, widget_with_model): + """Changing cbModel should push a ModelSelectionCommand.""" + w = widget_with_model + # Switch to a different model in the same category + model_index = w.cbModel.findText("barbell") + if model_index < 0: + pytest.skip("barbell model not available") + w.cbModel.setCurrentIndex(model_index) + QtWidgets.QApplication.processEvents() + + assert w.undo_stack.can_undo() + top_cmd = w.undo_stack._undo_stack[-1] + assert isinstance(top_cmd, ModelSelectionCommand) + assert "barbell" in top_cmd._new_triple[1] + + def test_model_selection_undo_restores_previous_model(self, widget_with_model): + """Undoing a model switch should restore the previous model type.""" + w = widget_with_model + original_model_type = type(w.logic.kernel_module) + + # Switch to a different model + model_index = w.cbModel.findText("barbell") + if model_index < 0: + pytest.skip("barbell model not available") + w.cbModel.setCurrentIndex(model_index) + QtWidgets.QApplication.processEvents() + + assert type(w.logic.kernel_module) is not original_model_type + w.undo_stack.undo() + QtWidgets.QApplication.processEvents() + # Verify the kernel module type was restored + assert type(w.logic.kernel_module) is original_model_type + + def test_select_default_model_no_command(self, widget_with_model): + """Selecting MODEL_DEFAULT should not push a command.""" + w = widget_with_model + initial_count = len(w.undo_stack._undo_stack) + + default_index = w.cbModel.findText(FittingWidget.MODEL_DEFAULT) + if default_index >= 0: + w.cbModel.setCurrentIndex(default_index) + # Should not push + assert len(w.undo_stack._undo_stack) == initial_count + + +# --------------------------------------------------------------------------- +# TestFitOptionsUndo +# --------------------------------------------------------------------------- + +class TestFitOptionsUndo: + + def test_options_update_pushes_fit_options_command(self, widget_with_model): + """Changing Q-range or npts should push a FitOptionsCommand.""" + w = widget_with_model + + # Modify the Q range via the options widget's model + from sas.qtgui.Perspectives.Fitting.OptionsWidget import OptionsWidget + w._get_fit_options_dict() + # Directly change the widget's model value to simulate user edit + w.options_widget.model.item( + OptionsWidget.MODEL.index('MIN_RANGE') + ).setText("0.01") + QtWidgets.QApplication.processEvents() + + if w.undo_stack.can_undo(): + top_cmd = w.undo_stack._undo_stack[-1] + assert isinstance(top_cmd, FitOptionsCommand) + + def test_options_widget_setState_round_trip(self, widget_with_model): + """OptionsWidget.setState must restore values consistently.""" + w = widget_with_model + ow = w.options_widget + + # Get current state + original = ow.state() + # Set to different values + ow.setState(0.001, 1.0, 200, True, 2) + new_state = ow.state() + assert new_state[0] == pytest.approx(0.001) + assert new_state[1] == pytest.approx(1.0) + assert new_state[2] == 200 + assert new_state[3] is True + assert new_state[4] == 2 + + # Restore original + ow.setState(*original) + restored = ow.state() + assert restored[0] == pytest.approx(original[0]) + assert restored[1] == pytest.approx(original[1]) + assert restored[2] == original[2] + + +# --------------------------------------------------------------------------- +# TestSmearingOptionsUndo +# --------------------------------------------------------------------------- + +class TestSmearingOptionsUndo: + + def test_initial_smearing_state_is_none(self, widget): + """_last_smearing_state should be None until first update.""" + assert widget._last_smearing_state is None + + def test_first_smearing_update_no_command(self, widget_with_model): + """First smearing update should NOT push a command (no prior state).""" + w = widget_with_model + w._last_smearing_state = None + initial_count = len(w.undo_stack._undo_stack) + w.onSmearingOptionsUpdate() + # No command pushed because old_state was None + assert len(w.undo_stack._undo_stack) == initial_count + # But _last_smearing_state should now be populated + assert w._last_smearing_state is not None + + def test_second_smearing_update_pushes_command(self, widget_with_model): + """A smearing change after the first should push SmearingOptionsCommand.""" + w = widget_with_model + # Prime the initial state + w.onSmearingOptionsUpdate() + w.undo_stack.clear() + + # Change smearing — simulate by altering the stored state + old_state = dict(w._last_smearing_state) + # Trigger another update (even if values are the same, verify logic) + w.onSmearingOptionsUpdate() + new_state = w._last_smearing_state + + if old_state != new_state: + assert w.undo_stack.can_undo() + top_cmd = w.undo_stack._undo_stack[-1] + assert isinstance(top_cmd, SmearingOptionsCommand) + + +# --------------------------------------------------------------------------- +# TestCheckboxToggleUndo +# --------------------------------------------------------------------------- + +class TestCheckboxToggleUndo: + + def test_toggle_poly_pushes_checkbox_command(self, widget_with_model): + """Toggling polydispersity should push a CheckboxToggleCommand.""" + w = widget_with_model + w.undo_stack.clear() + + w.chkPolydispersity.setEnabled(True) + w.chkPolydispersity.setChecked(True) + QtWidgets.QApplication.processEvents() + + assert w.undo_stack.can_undo() + top_cmd = w.undo_stack._undo_stack[-1] + assert isinstance(top_cmd, CheckboxToggleCommand) + assert top_cmd._checkbox_id == "chkPolydispersity" + assert top_cmd._new_bool is True + + def test_toggle_poly_undo_unchecks(self, widget_with_model): + """Undoing poly toggle should uncheck the checkbox.""" + w = widget_with_model + w.undo_stack.clear() + + w.chkPolydispersity.setEnabled(True) + w.chkPolydispersity.setChecked(True) + QtWidgets.QApplication.processEvents() + + w.undo_stack.undo() + assert not w.chkPolydispersity.isChecked() + + def test_toggle_magnetism_pushes_checkbox_command(self, widget_with_model): + """Toggling magnetism should push a CheckboxToggleCommand.""" + w = widget_with_model + w.undo_stack.clear() + + w.chkMagnetism.setEnabled(True) + w.chkMagnetism.setChecked(True) + QtWidgets.QApplication.processEvents() + + assert w.undo_stack.can_undo() + top_cmd = w.undo_stack._undo_stack[-1] + assert isinstance(top_cmd, CheckboxToggleCommand) + assert top_cmd._checkbox_id == "chkMagnetism" + + def test_toggle_2d_pushes_checkbox_command(self, widget_with_model): + """Toggling 2D view should push a CheckboxToggleCommand.""" + w = widget_with_model + w.undo_stack.clear() + + w.chk2DView.setEnabled(True) + w.chk2DView.setChecked(True) + QtWidgets.QApplication.processEvents() + + assert w.undo_stack.can_undo() + top_cmd = w.undo_stack._undo_stack[-1] + assert isinstance(top_cmd, CheckboxToggleCommand) + assert top_cmd._checkbox_id == "chk2DView" + + def test_toggle_2d_does_not_push_model_selection(self, widget_with_model): + """toggle2D calls onSelectModel internally — it must be suppressed.""" + w = widget_with_model + w.undo_stack.clear() + + w.chk2DView.setEnabled(True) + w.chk2DView.setChecked(True) + QtWidgets.QApplication.processEvents() + + # Should have exactly one command (CheckboxToggle), not two + # (i.e. the inner onSelectModel must not push its own ModelSelection) + checkbox_cmds = [ + c for c in w.undo_stack._undo_stack + if isinstance(c, CheckboxToggleCommand) + ] + model_cmds = [ + c for c in w.undo_stack._undo_stack + if isinstance(c, ModelSelectionCommand) + ] + assert len(checkbox_cmds) == 1 + assert len(model_cmds) == 0 + + +# --------------------------------------------------------------------------- +# TestFitResultUndo +# --------------------------------------------------------------------------- + +class TestFitResultUndo: + + def _make_fit_result(self, widget): + """Create a minimal fake fit result tuple.""" + res = MagicMock() + res.fitness = 1.5 + res.pvec = [1.0, 2.0, 3.0] + res.stderr = [0.1, 0.2, 0.3] + # Build param_dict that paramDictFromResults would return + param_names = [ + p.name for p in widget.logic.kernel_module._model_info.parameters.kernel_parameters + ] + res.pname = param_names[:len(res.pvec)] + return ([[res]], 0.5) + + def test_fit_complete_pushes_fit_result_command(self, widget_with_model, monkeypatch): + """fitComplete should push a FitResultCommand.""" + import copy + w = widget_with_model + w.undo_stack.clear() + # The pre-fit snapshot is captured from kernel_module_copy (set in onFit). + w.kernel_module_copy = copy.deepcopy(w.logic.kernel_module) + + # Mock methods that would crash without real fit data + monkeypatch.setattr(w.polydispersity_widget, 'updatePolyModelFromList', lambda *a, **kw: None) + monkeypatch.setattr(w.magnetism_widget, 'updateMagnetModelFromList', lambda *a, **kw: None) + monkeypatch.setattr(w, 'onPlot', lambda *a, **kw: None) + + # Create a result that paramDictFromResults can handle + old_params = w._get_parameter_dict() + param_dict = {n: (v + 1.0, 0.01) for n, v in old_params.items()} + monkeypatch.setattr( + w.fitting_controller, 'paramDictFromResults', lambda *a, **kw: param_dict + ) + + # Make updateModelFromList actually modify the kernel so new_params != old_params + def fake_update(pd): + for name, (val, _err) in pd.items(): + w.logic.kernel_module.setParam(name, val) + + monkeypatch.setattr( + w.fitting_controller, 'updateModelFromList', fake_update + ) + + result = self._make_fit_result(w) + w.fitComplete(result) + + assert w.undo_stack.can_undo() + top_cmd = w.undo_stack._undo_stack[-1] + assert isinstance(top_cmd, FitResultCommand) + + def test_fit_complete_keeps_undo_stack_enabled(self, widget_with_model, monkeypatch): + """fitComplete should leave the undo stack enabled.""" + import copy + w = widget_with_model + w.kernel_module_copy = copy.deepcopy(w.logic.kernel_module) + + monkeypatch.setattr(w.fitting_controller, 'updateModelFromList', lambda *a, **kw: None) + monkeypatch.setattr(w.polydispersity_widget, 'updatePolyModelFromList', lambda *a, **kw: None) + monkeypatch.setattr(w.magnetism_widget, 'updateMagnetModelFromList', lambda *a, **kw: None) + monkeypatch.setattr(w, 'onPlot', lambda *a, **kw: None) + monkeypatch.setattr( + w.fitting_controller, 'paramDictFromResults', + lambda *a, **kw: {n: (v, 0.01) for n, v in w._get_parameter_dict().items()} + ) + + result = self._make_fit_result(w) + w.fitComplete(result) + assert w.undo_stack._enabled + + def test_fit_complete_failed_no_command(self, widget_with_model, monkeypatch): + """A failed fit should not push an undo command.""" + w = widget_with_model + w.undo_stack.clear() + + monkeypatch.setattr(w, 'enableInteractiveElements', lambda *a, **kw: None) + w.kernel_module_copy = MagicMock() + + # Simulate failed fit + w.fitComplete(None) + assert not w.undo_stack.can_undo() + + def test_undo_fit_removes_error_column(self, widget_with_model, monkeypatch): + """Undoing a fit must remove the now-stale error column from the table.""" + import copy + w = widget_with_model + w.undo_stack.clear() + w.kernel_module_copy = copy.deepcopy(w.logic.kernel_module) + + # Let the real updateModelFromList run so the error column is added. + monkeypatch.setattr(w.polydispersity_widget, 'updatePolyModelFromList', lambda *a, **kw: None) + monkeypatch.setattr(w.magnetism_widget, 'updateMagnetModelFromList', lambda *a, **kw: None) + monkeypatch.setattr(w, 'onPlot', lambda *a, **kw: None) + + old_params = w._get_parameter_dict() + param_dict = {n: (v + 1.0, 0.01) for n, v in old_params.items()} + monkeypatch.setattr(w.fitting_controller, 'paramDictFromResults', lambda *a, **kw: param_dict) + + assert not w.has_error_column + cols_before = w._model_model.columnCount() + + result = self._make_fit_result(w) + w.fitComplete(result) + + # The fit added an error column. + assert w.has_error_column + assert w._model_model.columnCount() == cols_before + 1 + + # Undo removes it again. + w.undo_stack.undo() + assert not w.has_error_column + assert w._model_model.columnCount() == cols_before + + +# --------------------------------------------------------------------------- +# TestUndoStackDuringFit +# --------------------------------------------------------------------------- + +class TestUndoStackDuringFit: + + def test_onfit_leaves_undo_stack_enabled(self, widget_with_model, monkeypatch): + """onFit should leave the undo stack enabled (suppressed() handles blocking).""" + w = widget_with_model + + monkeypatch.setattr(w.fitting_controller, 'prepareFitters', lambda *a, **kw: ([MagicMock()], 0)) + monkeypatch.setattr(w, 'disableInteractiveElements', lambda *a, **kw: None) + monkeypatch.setattr('twisted.internet.threads.deferToThread', lambda *a, **kw: MagicMock()) + + w.onFit() + assert w.undo_stack._enabled + + def test_stopfit_keeps_undo_stack_enabled(self, widget_with_model, monkeypatch): + """stopFit should leave the undo stack enabled.""" + w = widget_with_model + w.calc_fit = MagicMock() + w.calc_fit.isrunning.return_value = True + + monkeypatch.setattr(w, 'enableInteractiveElements', lambda *a, **kw: None) + w.stopFit() + assert w.undo_stack._enabled + + +# --------------------------------------------------------------------------- +# TestGuiManagerUndoHookSignalChain +# --------------------------------------------------------------------------- + +class TestGuiManagerUndoHookSignalChain: + """Verify the signal chain that keeps the GUI undo/redo actions in sync + after a fit completes, reproducing the wiring done in + GuiManager._connect_undo_redo_hooks and the communicator fallback.""" + + def _make_fit_result(self, widget): + """Create a minimal fake fit result tuple.""" + res = MagicMock() + res.fitness = 1.5 + res.pvec = [1.0, 2.0, 3.0] + res.stderr = [0.1, 0.2, 0.3] + param_names = [ + p.name for p in widget.logic.kernel_module._model_info.parameters.kernel_parameters + ] + res.pname = param_names[:len(res.pvec)] + return ([[res]], 0.5) + + def _prepare_fit_complete(self, w, monkeypatch): + """Common fitComplete boilerplate: stub side effects, mutate the live + kernel so new_snapshot differs from the pre-fit snapshot.""" + import copy + w.undo_stack.clear() + w.kernel_module_copy = copy.deepcopy(w.logic.kernel_module) + + monkeypatch.setattr(w.polydispersity_widget, 'updatePolyModelFromList', lambda *a, **kw: None) + monkeypatch.setattr(w.magnetism_widget, 'updateMagnetModelFromList', lambda *a, **kw: None) + monkeypatch.setattr(w, 'onPlot', lambda *a, **kw: None) + + old_params = w._get_parameter_dict() + param_dict = {n: (v + 1.0, 0.01) for n, v in old_params.items()} + monkeypatch.setattr(w.fitting_controller, 'paramDictFromResults', lambda *a, **kw: param_dict) + + def fake_update(pd): + for name, (val, _err) in pd.items(): + w.logic.kernel_module.setParam(name, val) + + monkeypatch.setattr(w.fitting_controller, 'updateModelFromList', fake_update) + + def test_stack_changed_calls_handler_on_push(self, widget_with_model): + """stackChanged should fire when push() is called.""" + w = widget_with_model + handler_calls = [] + w.undo_stack.stackChanged.connect(lambda: handler_calls.append(True)) + w.undo_stack.push(ParameterValueCommand("scale", 1.0, 2.0)) + assert len(handler_calls) >= 1 + + def test_fit_complete_triggers_action_enable_via_signal_chain( + self, widget_with_model, monkeypatch + ): + """Connect stackChanged to a GuiManager-style handler that reads the + active stack and calls actionUndo.setEnabled. After fitComplete the + action must end up enabled.""" + w = widget_with_model + self._prepare_fit_complete(w, monkeypatch) + + action_undo = MagicMock() + action_redo = MagicMock() + + def update_undo_redo_actions(): + action_undo.setEnabled(w.undo_stack.can_undo()) + action_redo.setEnabled(w.undo_stack.can_redo()) + + w.undo_stack.stackChanged.connect(update_undo_redo_actions) + + action_undo.reset_mock() + action_redo.reset_mock() + + result = self._make_fit_result(w) + w.fitComplete(result) + + assert action_undo.setEnabled.call_args_list[-1] == ((True,),) + assert w.undo_stack.can_undo() + + def test_fit_complete_emits_communicator_fallback(self, widget_with_model, monkeypatch): + """fitComplete must emit undoRedoUpdateSignal as a reliable fallback + for refreshing the GUI action state.""" + w = widget_with_model + self._prepare_fit_complete(w, monkeypatch) + + signal_calls = [] + w.communicator.undoRedoUpdateSignal.connect(lambda: signal_calls.append(True)) + + result = self._make_fit_result(w) + w.fitComplete(result) + + assert len(signal_calls) >= 1 + + +# --------------------------------------------------------------------------- +# TestOptionsWidgetSetState +# --------------------------------------------------------------------------- + +class TestOptionsWidgetSetState: + + def test_setState_updates_q_range(self, widget_with_model): + """setState should update Q range text fields.""" + ow = widget_with_model.options_widget + ow.setState(0.002, 2.0, 300, False, 0) + q_min, q_max, npts, log_pts, weighting = ow.state() + assert q_min == pytest.approx(0.002) + assert q_max == pytest.approx(2.0) + assert npts == 300 + + def test_setState_updates_log_checkbox(self, widget_with_model): + """setState should toggle the log checkbox.""" + ow = widget_with_model.options_widget + ow.setState(0.001, 0.5, 150, True, 0) + assert ow.chkLogData.isChecked() + ow.setState(0.001, 0.5, 150, False, 0) + assert not ow.chkLogData.isChecked() + + def test_setState_updates_weighting(self, widget_with_model): + """setState should update the weighting radio buttons.""" + ow = widget_with_model.options_widget + for w_val in range(4): + ow.setState(0.001, 0.5, 150, False, w_val) + assert ow.weighting == w_val + + def test_setState_does_not_emit_signals(self, widget_with_model): + """setState blocks model signals to avoid feedback loops.""" + ow = widget_with_model.options_widget + signal_received = [] + ow.model.dataChanged.connect(lambda *a: signal_received.append(1)) + ow.setState(0.003, 3.0, 500, True, 1) + assert len(signal_received) == 0 + + +# --------------------------------------------------------------------------- +# TestUndoRedoRoundTrip +# --------------------------------------------------------------------------- + +class TestUndoRedoRoundTrip: + """End-to-end: make a change, undo it, redo it, verify state at each step.""" + + def test_param_value_round_trip(self, widget_with_model): + """Change radius → undo → redo should return to changed value.""" + w = widget_with_model + param_col = w.lstParams.itemDelegate().param_value + for row in range(w._model_model.rowCount()): + if w._model_model.item(row, 0).text() == "radius": + original = w.logic.kernel_module.getParam("radius") + w._model_model.item(row, param_col).setText("42.0") + break + + assert w.logic.kernel_module.getParam("radius") == 42.0 + w.undo_stack.undo() + assert w.logic.kernel_module.getParam("radius") == original + w.undo_stack.redo() + assert w.logic.kernel_module.getParam("radius") == 42.0 + + def test_multiple_undo_redo(self, widget_with_model): + """Multiple parameter edits: undo all, then redo all.""" + w = widget_with_model + param_col = w.lstParams.itemDelegate().param_value + + # Find radius and length rows + radius_row = length_row = None + for row in range(w._model_model.rowCount()): + name = w._model_model.item(row, 0).text() + if name == "radius": + radius_row = row + elif name == "length": + length_row = row + + if radius_row is None or length_row is None: + pytest.skip("Could not find both radius and length parameters") + + original_radius = w.logic.kernel_module.getParam("radius") + original_length = w.logic.kernel_module.getParam("length") + + # Edit radius then length + w._model_model.item(radius_row, param_col).setText("11.0") + w._model_model.item(length_row, param_col).setText("22.0") + + assert w.logic.kernel_module.getParam("radius") == 11.0 + assert w.logic.kernel_module.getParam("length") == 22.0 + + # Undo length + w.undo_stack.undo() + assert w.logic.kernel_module.getParam("length") == original_length + assert w.logic.kernel_module.getParam("radius") == 11.0 + + # Undo radius + w.undo_stack.undo() + assert w.logic.kernel_module.getParam("radius") == original_radius + + # Redo radius + w.undo_stack.redo() + assert w.logic.kernel_module.getParam("radius") == 11.0 + + # Redo length + w.undo_stack.redo() + assert w.logic.kernel_module.getParam("length") == 22.0 + + def test_suppressed_context_prevents_push(self, widget_with_model): + """Changes inside suppressed() must not appear on the stack.""" + w = widget_with_model + param_col = w.lstParams.itemDelegate().param_value + w.undo_stack.clear() + + with w.undo_stack.suppressed(): + for row in range(w._model_model.rowCount()): + if w._model_model.item(row, 0).text() == "radius": + w._model_model.item(row, param_col).setText("77.0") + break + + assert not w.undo_stack.can_undo() + + def test_stackChanged_signal_emitted_on_param_edit(self, widget_with_model): + """stackChanged should fire when a parameter edit pushes a command.""" + w = widget_with_model + received = [] + w.undo_stack.stackChanged.connect(lambda: received.append(1)) + + param_col = w.lstParams.itemDelegate().param_value + for row in range(w._model_model.rowCount()): + if w._model_model.item(row, 0).text() == "radius": + w._model_model.item(row, param_col).setText("55.0") + break + + assert len(received) >= 1 diff --git a/src/sas/qtgui/Perspectives/Fitting/UnitTesting/UndoRedoTest.py b/src/sas/qtgui/Perspectives/Fitting/UnitTesting/UndoRedoTest.py new file mode 100644 index 0000000000..c8b6ca0117 --- /dev/null +++ b/src/sas/qtgui/Perspectives/Fitting/UnitTesting/UndoRedoTest.py @@ -0,0 +1,794 @@ +"""Unit tests for UndoRedo.py — UndoCommand subclasses and UndoStack. + +Tests focus on single-tab fitting scenarios. The FittingWidget dependency +is fully mocked; no real fitting window is opened. + +Test organisation: + TestUndoCommand — abstract base behaviour + TestParameterValueCommand — value change, coalescing + TestParameterMinMaxCommand — bound change + TestModelSelectionCommand — model triple + param restore + TestFitOptionsCommand — options dict round-trip + TestSmearingOptionsCommand — smearing state round-trip + TestCheckboxToggleCommand — checkbox flip + TestFitResultCommand — pre/post fit snapshot + TestCompoundCommand — atomic group, ordering + TestUndoStack — push / undo / redo / depth / suppression + TestUndoStackFailure — failure dialog, reset_to_last_good +""" +import logging +import time +from unittest.mock import MagicMock + +import pytest + +from sas.qtgui.Perspectives.Fitting.UndoRedo import ( + CheckboxToggleCommand, + FitOptionsCommand, + FitResultCommand, + ModelSelectionCommand, + ParameterMinMaxCommand, + ParameterValueCommand, + SmearingOptionsCommand, +) +from sas.qtgui.Perspectives.UndoRedo import ( + CompoundCommand, + UndoCommand, + UndoStack, +) + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +@pytest.fixture(autouse=True) +def mock_message_box(monkeypatch): + """Suppress all QMessageBox dialogs for the entire test module.""" + mock = MagicMock() + monkeypatch.setattr( + "sas.qtgui.Perspectives.UndoRedo.QtWidgets.QMessageBox", mock + ) + return mock + + +@pytest.fixture +def widget(): + """Minimal mock that satisfies the widget protocol used by commands.""" + w = MagicMock() + # kernel_module.details must support real dict operations for MinMax tests + w.logic.kernel_module.details = {} + return w + + +@pytest.fixture +def stack(widget, qapp): + """An UndoStack wired to a mock widget.""" + return UndoStack(widget) + + +# --------------------------------------------------------------------------- +# UndoCommand — abstract base +# --------------------------------------------------------------------------- + +class TestUndoCommand: + + def test_undo_raises_not_implemented(self): + cmd = UndoCommand("test") + with pytest.raises(NotImplementedError): + cmd.undo(None) + + def test_redo_raises_not_implemented(self): + cmd = UndoCommand("test") + with pytest.raises(NotImplementedError): + cmd.redo(None) + + def test_can_merge_false_by_default(self): + cmd = UndoCommand("test") + assert cmd.can_merge(UndoCommand("other")) is False + + def test_merge_raises_not_implemented(self): + cmd = UndoCommand("test") + with pytest.raises(NotImplementedError): + cmd.merge(UndoCommand("other")) + + def test_description_stored(self): + cmd = UndoCommand("my action") + assert cmd.description == "my action" + + def test_timestamp_is_recent(self): + t_before = time.monotonic() + cmd = UndoCommand("t") + t_after = time.monotonic() + assert t_before <= cmd.timestamp <= t_after + + def test_repr_contains_class_and_description(self): + cmd = UndoCommand("hello") + assert "UndoCommand" in repr(cmd) + assert "hello" in repr(cmd) + + +# --------------------------------------------------------------------------- +# ParameterValueCommand +# --------------------------------------------------------------------------- + +class TestParameterValueCommand: + + def test_undo_sets_old_value(self, widget): + cmd = ParameterValueCommand("radius", 5.0, 10.0) + cmd.undo(widget) + widget.logic.kernel_module.setParam.assert_called_once_with("radius", 5.0) + widget._update_model_param_value.assert_called_once_with("radius", 5.0) + + def test_redo_sets_new_value(self, widget): + cmd = ParameterValueCommand("radius", 5.0, 10.0) + cmd.redo(widget) + widget.logic.kernel_module.setParam.assert_called_once_with("radius", 10.0) + widget._update_model_param_value.assert_called_once_with("radius", 10.0) + + def test_can_merge_same_param(self): + cmd1 = ParameterValueCommand("radius", 1.0, 2.0) + cmd2 = ParameterValueCommand("radius", 2.0, 3.0) + assert cmd1.can_merge(cmd2) is True + + def test_cannot_merge_different_param(self): + cmd1 = ParameterValueCommand("radius", 1.0, 2.0) + cmd2 = ParameterValueCommand("length", 1.0, 2.0) + assert cmd1.can_merge(cmd2) is False + + def test_cannot_merge_different_type(self): + cmd1 = ParameterValueCommand("radius", 1.0, 2.0) + assert cmd1.can_merge(UndoCommand("other")) is False + + def test_merge_spans_full_range(self): + cmd1 = ParameterValueCommand("radius", 1.0, 2.0) + cmd2 = ParameterValueCommand("radius", 2.0, 3.0) + merged = cmd1.merge(cmd2) + assert merged._old_val == 1.0 + assert merged._new_val == 3.0 + assert merged.param_name == "radius" + + def test_merge_carries_latest_timestamp(self): + # The coalescing window is measured from the most recent edit, so the + # merged command must carry the later command's timestamp. + cmd1 = ParameterValueCommand("r", 1.0, 2.0) + cmd2 = ParameterValueCommand("r", 2.0, 3.0) + cmd2.timestamp = cmd1.timestamp + 1.0 + merged = cmd1.merge(cmd2) + assert merged.timestamp == cmd2.timestamp + + def test_param_name_property(self): + cmd = ParameterValueCommand("scale", 1.0, 2.0) + assert cmd.param_name == "scale" + + def test_cannot_merge_stale_same_param_command(self): + """Edits to the same parameter outside the coalescing window must not merge.""" + cmd1 = ParameterValueCommand("radius", 1.0, 2.0) + cmd2 = ParameterValueCommand("radius", 2.0, 3.0) + # Backdate cmd1 so the gap exceeds the coalescing window + cmd1.timestamp = ( + cmd2.timestamp - ParameterValueCommand._COALESCE_WINDOW_SECONDS - 1.0 + ) + assert cmd1.can_merge(cmd2) is False + + +# --------------------------------------------------------------------------- +# ParameterMinMaxCommand +# --------------------------------------------------------------------------- + +class TestParameterMinMaxCommand: + + def test_undo_restores_min(self, widget): + widget.logic.kernel_module.details = {"radius": [None, 0.0, 100.0]} + cmd = ParameterMinMaxCommand("radius", "min", 0.0, 5.0) + cmd.undo(widget) + assert widget.logic.kernel_module.details["radius"][1] == 0.0 + widget._update_model_param_limit.assert_called_once_with( + "radius", "min", 0.0 + ) + + def test_redo_applies_new_min(self, widget): + widget.logic.kernel_module.details = {"radius": [None, 0.0, 100.0]} + cmd = ParameterMinMaxCommand("radius", "min", 0.0, 5.0) + cmd.redo(widget) + assert widget.logic.kernel_module.details["radius"][1] == 5.0 + widget._update_model_param_limit.assert_called_once_with( + "radius", "min", 5.0 + ) + + def test_undo_restores_max(self, widget): + widget.logic.kernel_module.details = {"length": [None, 0.0, 100.0]} + cmd = ParameterMinMaxCommand("length", "max", 100.0, 200.0) + cmd.undo(widget) + assert widget.logic.kernel_module.details["length"][2] == 100.0 + + def test_redo_applies_new_max(self, widget): + widget.logic.kernel_module.details = {"length": [None, 0.0, 100.0]} + cmd = ParameterMinMaxCommand("length", "max", 100.0, 200.0) + cmd.redo(widget) + assert widget.logic.kernel_module.details["length"][2] == 200.0 + + def test_invalid_bound_raises(self): + with pytest.raises(AssertionError): + ParameterMinMaxCommand("r", "middle", 1.0, 2.0) + + +# --------------------------------------------------------------------------- +# ModelSelectionCommand +# --------------------------------------------------------------------------- + +class TestModelSelectionCommand: + + def test_undo_restores_old_triple_and_params(self, widget): + old_triple = ("Shape", "sphere", "None") + new_triple = ("Shape", "cylinder", "None") + cmd = ModelSelectionCommand( + old_triple, new_triple, {"radius": 1.0}, {"length": 5.0} + ) + cmd.undo(widget) + widget._restore_model_selection.assert_called_once_with( + old_triple, {"radius": 1.0} + ) + + def test_redo_restores_new_triple_and_params(self, widget): + old_triple = ("Shape", "sphere", "None") + new_triple = ("Shape", "cylinder", "None") + cmd = ModelSelectionCommand( + old_triple, new_triple, {"radius": 1.0}, {"length": 5.0} + ) + cmd.redo(widget) + widget._restore_model_selection.assert_called_once_with( + new_triple, {"length": 5.0} + ) + + def test_params_are_deep_copied(self): + params = {"radius": 5.0} + cmd = ModelSelectionCommand( + ("A", "B", "C"), ("D", "E", "F"), params, {} + ) + params["radius"] = 999.0 # mutate original + assert cmd._old_params["radius"] == 5.0 # snapshot unchanged + + def test_description_includes_new_model_name(self): + cmd = ModelSelectionCommand( + ("A", "sphere", "C"), ("A", "cylinder", "C"), {}, {} + ) + assert "cylinder" in cmd.description + + +# --------------------------------------------------------------------------- +# FitOptionsCommand +# --------------------------------------------------------------------------- + +class TestFitOptionsCommand: + + def test_undo_applies_old_options(self, widget): + old = {"q_min": 0.01, "q_max": 0.5} + new = {"q_min": 0.05, "q_max": 1.0} + cmd = FitOptionsCommand(old, new) + cmd.undo(widget) + widget._apply_fit_options.assert_called_once_with(old) + + def test_redo_applies_new_options(self, widget): + old = {"q_min": 0.01, "q_max": 0.5} + new = {"q_min": 0.05, "q_max": 1.0} + cmd = FitOptionsCommand(old, new) + cmd.redo(widget) + widget._apply_fit_options.assert_called_once_with(new) + + def test_options_are_deep_copied(self): + opts = {"q_min": 0.01} + cmd = FitOptionsCommand(opts, {}) + opts["q_min"] = 999.0 + assert cmd._old_options["q_min"] == 0.01 + + +# --------------------------------------------------------------------------- +# SmearingOptionsCommand +# --------------------------------------------------------------------------- + +class TestSmearingOptionsCommand: + + def test_undo_applies_old_state(self, widget): + cmd = SmearingOptionsCommand({"type": "none"}, {"type": "pinhole"}) + cmd.undo(widget) + widget._apply_smearing_state.assert_called_once_with({"type": "none"}) + + def test_redo_applies_new_state(self, widget): + cmd = SmearingOptionsCommand({"type": "none"}, {"type": "pinhole"}) + cmd.redo(widget) + widget._apply_smearing_state.assert_called_once_with({"type": "pinhole"}) + + def test_state_is_deep_copied(self): + state = {"type": "none"} + cmd = SmearingOptionsCommand(state, {}) + state["type"] = "changed" + assert cmd._old_state["type"] == "none" + + +# --------------------------------------------------------------------------- +# CheckboxToggleCommand +# --------------------------------------------------------------------------- + +class TestCheckboxToggleCommand: + + def test_undo_sets_old_bool(self, widget): + widget.chkPolydispersity = MagicMock() + cmd = CheckboxToggleCommand("chkPolydispersity", False, True) + cmd.undo(widget) + widget.chkPolydispersity.setChecked.assert_called_once_with(False) + + def test_redo_sets_new_bool(self, widget): + widget.chkPolydispersity = MagicMock() + cmd = CheckboxToggleCommand("chkPolydispersity", False, True) + cmd.redo(widget) + widget.chkPolydispersity.setChecked.assert_called_once_with(True) + + def test_description_includes_checkbox_id(self): + cmd = CheckboxToggleCommand("chkMagnetism", False, True) + assert "chkMagnetism" in cmd.description + + +# --------------------------------------------------------------------------- +# FitResultCommand +# --------------------------------------------------------------------------- + +class TestFitResultCommand: + + @staticmethod + def _snapshot(main, poly=None, magnet=None): + return {"main": main, "poly": poly or {}, "magnet": magnet or {}} + + def test_undo_restores_pre_fit_params(self, widget): + old = self._snapshot({"radius": 1.0}) + cmd = FitResultCommand(old, self._snapshot({"radius": 2.5})) + cmd.undo(widget) + widget._restore_fit_result_snapshot.assert_called_once_with(old) + + def test_redo_restores_post_fit_params(self, widget): + new = self._snapshot({"radius": 2.5}) + cmd = FitResultCommand(self._snapshot({"radius": 1.0}), new) + cmd.redo(widget) + widget._restore_fit_result_snapshot.assert_called_once_with(new) + + def test_undo_restores_poly_and_magnet(self, widget): + old = self._snapshot( + {"radius": 1.0}, poly={"radius.width": 0.1}, magnet={"up_frac_i": 0.5} + ) + cmd = FitResultCommand(old, self._snapshot({"radius": 2.5})) + cmd.undo(widget) + widget._restore_fit_result_snapshot.assert_called_once_with(old) + + def test_params_are_deep_copied(self): + old = self._snapshot({"radius": 1.0}) + cmd = FitResultCommand(old, self._snapshot({})) + old["main"]["radius"] = 999.0 + assert cmd._old_snapshot["main"]["radius"] == 1.0 + + def test_description_is_fit_result(self): + cmd = FitResultCommand(self._snapshot({}), self._snapshot({})) + assert cmd.description == "Fit result" + + +# --------------------------------------------------------------------------- +# CompoundCommand +# --------------------------------------------------------------------------- + +class TestCompoundCommand: + + def test_undo_executes_in_reverse_order(self, widget): + order = [] + cmd1 = MagicMock(spec=UndoCommand) + cmd1.undo.side_effect = lambda w: order.append("cmd1") + cmd2 = MagicMock(spec=UndoCommand) + cmd2.undo.side_effect = lambda w: order.append("cmd2") + CompoundCommand([cmd1, cmd2], "c").undo(widget) + assert order == ["cmd2", "cmd1"] + + def test_redo_executes_in_forward_order(self, widget): + order = [] + cmd1 = MagicMock(spec=UndoCommand) + cmd1.redo.side_effect = lambda w: order.append("cmd1") + cmd2 = MagicMock(spec=UndoCommand) + cmd2.redo.side_effect = lambda w: order.append("cmd2") + CompoundCommand([cmd1, cmd2], "c").redo(widget) + assert order == ["cmd1", "cmd2"] + + def test_commands_property_returns_copy(self): + cmd1 = MagicMock(spec=UndoCommand) + compound = CompoundCommand([cmd1], "c") + copy = compound.commands + copy.append(MagicMock()) + assert len(compound.commands) == 1 # original unaffected + + def test_description_falls_back_to_first_command(self): + cmd1 = MagicMock(spec=UndoCommand) + cmd1.description = "Select model" + compound = CompoundCommand([cmd1]) + assert compound.description == "Select model" + + def test_explicit_description_takes_precedence(self): + compound = CompoundCommand([], "Override") + assert compound.description == "Override" + + +# --------------------------------------------------------------------------- +# UndoStack — normal operation +# --------------------------------------------------------------------------- + +def _make_cmd(description: str = "cmd") -> MagicMock: + """Return a MagicMock UndoCommand that disallows coalescing.""" + cmd = MagicMock(spec=UndoCommand) + cmd.can_merge.return_value = False + cmd.description = description + return cmd + + +class TestUndoStack: + + def test_initial_state_empty(self, stack): + assert not stack.can_undo() + assert not stack.can_redo() + assert stack.undo_text() == "Undo" + assert stack.redo_text() == "Redo" + + def test_push_enables_undo(self, stack): + stack.push(_make_cmd()) + assert stack.can_undo() + assert not stack.can_redo() + + def test_undo_moves_to_redo(self, stack, widget): + stack.push(_make_cmd()) + stack.undo() + assert not stack.can_undo() + assert stack.can_redo() + + def test_redo_moves_back_to_undo(self, stack, widget): + stack.push(_make_cmd()) + stack.undo() + stack.redo() + assert stack.can_undo() + assert not stack.can_redo() + + def test_push_after_undo_clears_redo(self, stack, widget): + stack.push(_make_cmd("a")) + stack.undo() + assert stack.can_redo() + stack.push(_make_cmd("b")) + assert not stack.can_redo() + + def test_undo_calls_cmd_undo_with_widget(self, stack, widget): + cmd = _make_cmd() + stack.push(cmd) + stack.undo() + cmd.undo.assert_called_once_with(widget) + + def test_redo_calls_cmd_redo_with_widget(self, stack, widget): + cmd = _make_cmd() + stack.push(cmd) + stack.undo() + stack.redo() + cmd.redo.assert_called_once_with(widget) + + def test_stackChanged_emitted_on_push(self, stack): + received = [] + stack.stackChanged.connect(lambda: received.append(1)) + stack.push(_make_cmd()) + assert len(received) == 1 + + def test_stackChanged_emitted_on_undo(self, stack): + stack.push(_make_cmd()) + received = [] + stack.stackChanged.connect(lambda: received.append(1)) + stack.undo() + assert len(received) == 1 + + def test_stackChanged_emitted_on_redo(self, stack): + stack.push(_make_cmd()) + stack.undo() + received = [] + stack.stackChanged.connect(lambda: received.append(1)) + stack.redo() + assert len(received) == 1 + + def test_stackChanged_emitted_on_clear(self, stack): + stack.push(_make_cmd()) + received = [] + stack.stackChanged.connect(lambda: received.append(1)) + stack.clear() + assert len(received) == 1 + + def test_stackChanged_emitted_on_set_enabled(self, stack): + """set_enabled must emit stackChanged so UI actions refresh.""" + received = [] + stack.stackChanged.connect(lambda: received.append(1)) + stack.set_enabled(False) + assert len(received) == 1 + stack.set_enabled(True) + assert len(received) == 2 + + def test_can_undo_false_when_disabled(self, stack): + """can_undo must return False when the stack is disabled.""" + stack.push(_make_cmd()) + assert stack.can_undo() + stack.set_enabled(False) + assert not stack.can_undo() + # Re-enable: can_undo should return True again + stack.set_enabled(True) + assert stack.can_undo() + + def test_can_redo_false_when_disabled(self, stack, widget): + """can_redo must return False when the stack is disabled.""" + stack.push(_make_cmd()) + stack.undo() + assert stack.can_redo() + stack.set_enabled(False) + assert not stack.can_redo() + stack.set_enabled(True) + assert stack.can_redo() + + def test_max_depth_drops_oldest_entries(self, stack): + stack._max_depth = 3 + cmds = [_make_cmd(f"c{i}") for i in range(5)] + for cmd in cmds: + stack.push(cmd) + assert len(stack._undo_stack) == 3 + # Three newest survive + assert cmds[2] in stack._undo_stack + assert cmds[3] in stack._undo_stack + assert cmds[4] in stack._undo_stack + # Two oldest are gone + assert cmds[0] not in stack._undo_stack + assert cmds[1] not in stack._undo_stack + + def test_clear_empties_both_stacks(self, stack, widget): + stack.push(_make_cmd()) + stack.undo() + stack.clear() + assert not stack.can_undo() + assert not stack.can_redo() + + def test_suppressed_prevents_push(self, stack): + with stack.suppressed(): + stack.push(_make_cmd()) + assert not stack.can_undo() + + def test_suppressed_restores_previous_enabled_state(self, stack): + stack.set_enabled(True) + with stack.suppressed(): + assert not stack._enabled + assert stack._enabled + + def test_suppressed_restores_false_if_was_false(self, stack): + stack.set_enabled(False) + with stack.suppressed(): + pass + assert not stack._enabled # restored to False + + def test_set_enabled_false_prevents_push(self, stack): + stack.set_enabled(False) + stack.push(_make_cmd()) + assert not stack.can_undo() + + def test_disabled_prevents_undo(self, stack, widget): + """A disabled stack must not execute undo.""" + stack.push(_make_cmd()) + stack.set_enabled(False) + stack.undo() + # Command still on internal stack, but can_undo() reports False while disabled + assert len(stack._undo_stack) == 1 + assert not stack.can_undo() + assert not stack.can_redo() + + def test_disabled_prevents_redo(self, stack, widget): + """A disabled stack must not execute redo.""" + stack.push(_make_cmd()) + stack.undo() + stack.set_enabled(False) + stack.redo() + # Command still on internal redo stack, but can_redo() reports False while disabled + assert not stack.can_undo() + assert not stack.can_redo() + assert len(stack._redo_stack) == 1 + + def test_replaying_prevents_recursive_push(self, stack): + """Commands pushed during undo replay must be silently dropped.""" + inner = _make_cmd("inner") + + def undo_side_effect(w): + stack.push(inner) # simulate handler firing during replay + + outer = _make_cmd("outer") + outer.undo.side_effect = undo_side_effect + stack.push(outer) + stack.undo() + # outer moved to redo; inner was blocked → undo stack empty + assert not stack.can_undo() + inner.undo.assert_not_called() + + def test_coalescing_merges_same_param_commands(self, stack): + cmd1 = ParameterValueCommand("radius", 1.0, 2.0) + cmd2 = ParameterValueCommand("radius", 2.0, 3.0) + stack.push(cmd1) + stack.push(cmd2) + assert len(stack._undo_stack) == 1 + merged = stack._undo_stack[0] + assert merged._old_val == 1.0 + assert merged._new_val == 3.0 + + def test_no_coalescing_for_different_params(self, stack): + stack.push(ParameterValueCommand("radius", 1.0, 2.0)) + stack.push(ParameterValueCommand("length", 1.0, 2.0)) + assert len(stack._undo_stack) == 2 + + def test_undo_text_includes_description(self, stack): + stack.push(_make_cmd("Change radius")) + assert stack.undo_text() == "Undo Change radius" + + def test_redo_text_includes_description(self, stack, widget): + stack.push(_make_cmd("Change radius")) + stack.undo() + assert stack.redo_text() == "Redo Change radius" + + def test_undo_noop_when_empty(self, stack): + stack.undo() # must not raise + + def test_redo_noop_when_empty(self, stack): + stack.redo() # must not raise + + def test_max_depth_read_from_config(self, qapp): + """UndoStack reads UNDO_STACK_MAX_DEPTH from the SasView config.""" + from sas import config as sas_config + assert hasattr(sas_config, "UNDO_STACK_MAX_DEPTH") + widget = MagicMock() + s = UndoStack(widget) + assert s._max_depth == sas_config.UNDO_STACK_MAX_DEPTH + + def test_successful_undo_auto_snapshots_widget_state(self, stack, widget): + """A successful undo must save _last_good_state from the widget.""" + widget._get_parameter_dict.return_value = {"scale": 1.0} + stack.push(_make_cmd()) + assert stack._last_good_state is None + stack.undo() + assert stack._last_good_state == {"scale": 1.0} + + def test_successful_redo_auto_snapshots_widget_state(self, stack, widget): + """A successful redo must save _last_good_state from the widget.""" + stack.push(_make_cmd()) + stack.undo() + stack._last_good_state = None # reset so we can detect the auto-save + widget._get_parameter_dict.return_value = {"scale": 2.0} + stack.redo() + assert stack._last_good_state == {"scale": 2.0} + + def test_auto_snapshot_silent_if_no_get_parameter_dict(self, stack, widget): + """Missing _get_parameter_dict on the widget must not raise.""" + widget._get_parameter_dict.side_effect = AttributeError("no such method") + stack.push(_make_cmd()) + stack.undo() # must not raise + assert stack._last_good_state is None + + +# --------------------------------------------------------------------------- +# UndoStack — failure resilience +# --------------------------------------------------------------------------- + +class TestUndoStackFailure: + + @pytest.fixture + def failing_undo_cmd(self): + cmd = _make_cmd("failing command") + cmd.undo.side_effect = RuntimeError("simulated undo failure") + return cmd + + @pytest.fixture + def failing_redo_cmd(self): + cmd = _make_cmd("failing command") + cmd.redo.side_effect = RuntimeError("simulated redo failure") + return cmd + + def test_undo_failure_logs_warning( + self, stack, failing_undo_cmd, caplog + ): + stack.push(failing_undo_cmd) + with caplog.at_level( + logging.WARNING, + logger="sas.qtgui.Perspectives.UndoRedo", + ): + stack.undo() + assert any("undo failed" in m.lower() for m in caplog.messages) + + def test_redo_failure_logs_warning( + self, stack, widget, failing_redo_cmd, caplog + ): + good = _make_cmd("good") + stack.push(good) + stack.undo() + # Replace redo stack entry with a failing one + stack._redo_stack[-1] = failing_redo_cmd + with caplog.at_level( + logging.WARNING, + logger="sas.qtgui.Perspectives.UndoRedo", + ): + stack.redo() + assert any("redo failed" in m.lower() for m in caplog.messages) + + def test_failure_increments_consecutive_counter( + self, stack, failing_undo_cmd + ): + stack.push(failing_undo_cmd) + stack.undo() + assert stack._consecutive_failures == 1 + + def test_undo_failure_preserves_undo_stack(self, stack, failing_undo_cmd): + """A failing undo must leave the command on the undo stack.""" + stack.push(failing_undo_cmd) + stack.undo() + assert stack.can_undo() # command still in undo history + assert not stack.can_redo() # nothing moved to redo + + def test_redo_failure_preserves_redo_stack(self, stack, widget, failing_redo_cmd): + """A failing redo must leave the command on the redo stack.""" + stack.push(_make_cmd("good")) + stack.undo() + stack._redo_stack[-1] = failing_redo_cmd + stack.redo() + assert not stack.can_undo() # nothing moved to undo + assert stack.can_redo() # command still in redo history + + def test_success_resets_consecutive_counter(self, stack, widget): + stack._consecutive_failures = 5 + stack.push(_make_cmd()) + stack.undo() + assert stack._consecutive_failures == 0 + + def test_failure_emits_stackChanged(self, stack, failing_undo_cmd): + stack.push(failing_undo_cmd) + received = [] + stack.stackChanged.connect(lambda: received.append(1)) + stack.undo() + assert len(received) == 1 # stackChanged fired even on failure + + def test_save_and_reset_to_last_good_state(self, stack, widget): + stack.save_last_good_state({"radius": 5.0}) + stack.reset_to_last_good() + widget._restore_parameter_values.assert_called_once_with({"radius": 5.0}) + + def test_reset_without_snapshot_logs_warning(self, stack, caplog): + with caplog.at_level( + logging.WARNING, + logger="sas.qtgui.Perspectives.UndoRedo", + ): + stack.reset_to_last_good() + assert any("no snapshot" in m.lower() for m in caplog.messages) + + def test_save_last_good_state_copies_dict(self, stack): + state = {"radius": 1.0} + stack.save_last_good_state(state) + state["radius"] = 999.0 + assert stack._last_good_state["radius"] == 1.0 + + def test_failure_dialog_shown(self, stack, failing_undo_cmd, mock_message_box): + stack.push(failing_undo_cmd) + stack.undo() + mock_message_box.assert_called_once() + + def test_repeated_failures_offer_reset_button( + self, stack, widget, mock_message_box + ): + """After 2 consecutive failures with a snapshot, reset button appears.""" + stack.save_last_good_state({"radius": 1.0}) + stack._consecutive_failures = 1 # pre-seed counter + + failing = _make_cmd("fail again") + failing.undo.side_effect = RuntimeError("boom") + stack.push(failing) + stack.undo() + + # addButton called with "Reset to Last Good State" + instance = mock_message_box.return_value + button_labels = [ + call.args[0] + for call in instance.addButton.call_args_list + if call.args + ] + assert any("Reset" in label for label in button_labels) diff --git a/src/sas/qtgui/Perspectives/Fitting/ViewDelegate.py b/src/sas/qtgui/Perspectives/Fitting/ViewDelegate.py index 4c888c7abe..b298a7b4d6 100644 --- a/src/sas/qtgui/Perspectives/Fitting/ViewDelegate.py +++ b/src/sas/qtgui/Perspectives/Fitting/ViewDelegate.py @@ -27,7 +27,6 @@ def fancyColumns(self): def addErrorColumn(self): """ Modify local column pointers - Note: the reverse is never required! """ self.param_property=0 self.param_value=1 @@ -36,6 +35,18 @@ def addErrorColumn(self): self.param_max=4 self.param_unit=5 + def removeErrorColumn(self): + """ + Restore column pointers to the no-error layout (reverse of addErrorColumn). + Used when a fit is undone and the error column is removed. + """ + self.param_error=-1 + self.param_property=0 + self.param_value=1 + self.param_min=2 + self.param_max=3 + self.param_unit=4 + def paint(self, painter, option, index): """ Overwrite generic painter for certain columns @@ -152,7 +163,6 @@ def columnDict(self): def addErrorColumn(self): """ Modify local column pointers - Note: the reverse is never required! """ self.poly_parameter = 0 self.poly_pd = 1 @@ -164,6 +174,21 @@ def addErrorColumn(self): self.poly_function = 7 self.poly_filename = 8 + def removeErrorColumn(self): + """ + Restore column pointers to the no-error layout (reverse of addErrorColumn). + Used when a fit is undone and the error column is removed. + """ + self.poly_parameter = 0 + self.poly_pd = 1 + self.poly_error = None + self.poly_min = 2 + self.poly_max = 3 + self.poly_npts = 4 + self.poly_nsigs = 5 + self.poly_function = 6 + self.poly_filename = 7 + def createEditor(self, widget, option, index): # Remember the current choice if not index.isValid(): @@ -243,7 +268,6 @@ def editableParameters(self): def addErrorColumn(self): """ Modify local column pointers - Note: the reverse is never required! """ self.mag_parameter = 0 self.mag_value = 1 @@ -252,6 +276,18 @@ def addErrorColumn(self): self.mag_max = 4 self.mag_unit = 5 + def removeErrorColumn(self): + """ + Restore column pointers to the no-error layout (reverse of addErrorColumn). + Used when a fit is undone and the error column is removed. + """ + self.mag_parameter = 0 + self.mag_value = 1 + self.mag_min = 2 + self.mag_max = 3 + self.mag_unit = 4 + self.mag_error = -1 + def createEditor(self, widget, option, index): # Remember the current choice current_text = index.data() diff --git a/src/sas/qtgui/Perspectives/Invariant/InvariantPerspective.py b/src/sas/qtgui/Perspectives/Invariant/InvariantPerspective.py index 4734d6561f..f4df6c42ce 100644 --- a/src/sas/qtgui/Perspectives/Invariant/InvariantPerspective.py +++ b/src/sas/qtgui/Perspectives/Invariant/InvariantPerspective.py @@ -21,6 +21,7 @@ # local from ..perspective import Perspective +from ..UndoRedo import DictSnapshotCommand, UndoStack from .InvariantDetails import DetailsDialog from .InvariantUtils import WIDGETS, safe_float from .UI.TabbedInvariantUI import Ui_tabbedInvariantUI @@ -67,6 +68,10 @@ def __init__(self, parent=None): self.detailsDialog = DetailsDialog(self) self.detailsDialog.cmdOK.clicked.connect(self.enableStatus) + # Undo/redo infrastructure + self._undo_stack_obj = UndoStack(self) + self._undo_baseline: dict | None = None + # Data self._data: Data1D | None = None self._path: str = "" @@ -90,6 +95,11 @@ def __init__(self, parent=None): # Go to the first data item self.mapper.toFirst() + # Undo/redo: connect capture signals after all widgets are initialized + self._setupUndoConnections() + self._rebaseline_undo_state() + self._undo_stack_obj.clear() + @property def extrapolation_parameters(self) -> ExtrapolationParameters | None: if self._data is not None: @@ -105,6 +115,108 @@ def extrapolation_parameters(self) -> ExtrapolationParameters | None: else: return None + @property + def undo_stack(self): + """Return the undo stack for this perspective. + + Overrides ``Perspective.undo_stack`` (which returns ``None``). + """ + return self._undo_stack_obj + + # ------------------------------------------------------------------ + # Undo/redo contract methods + # ------------------------------------------------------------------ + + def _get_parameter_dict(self) -> dict: + """Capture current input-only state (excludes computed outputs). + + Called by ``UndoStack._auto_snapshot()`` for recovery snapshots. + """ + return { + "background": self.txtBackgd.text(), + "scale": self.txtScale.text(), + "contrast": self.txtContrast.text(), + "contrast_err": self.txtContrastErr.text(), + "porod": self.txtPorodCst.text(), + "porod_err": self.txtPorodCstErr.text(), + "volfrac1": self.txtVolFrac1.text(), + "volfrac1_err": self.txtVolFrac1Err.text(), + "enable_contrast": self.rbContrast.isChecked(), + "enable_volfrac": self.rbVolFrac.isChecked(), + "guinier_end_low_q_ex": self.txtGuinierEnd_ex.text(), + "porod_start_high_q_ex": self.txtPorodStart_ex.text(), + "porod_end_high_q_ex": self.txtPorodEnd_ex.text(), + "power_low_q_ex": self.txtLowQPower_ex.text(), + "power_high_q_ex": self.txtHighQPower_ex.text(), + "enable_low_q_ex": self.chkLowQ_ex.isChecked(), + "enable_high_q_ex": self.chkHighQ_ex.isChecked(), + "low_q_guinier_ex": self.rbLowQGuinier_ex.isChecked(), + "low_q_power_ex": self.rbLowQPower_ex.isChecked(), + "low_q_fit_ex": self.rbLowQFit_ex.isChecked(), + "low_q_fix_ex": self.rbLowQFix_ex.isChecked(), + "high_q_fit_ex": self.rbHighQFit_ex.isChecked(), + "high_q_fix_ex": self.rbHighQFix_ex.isChecked(), + } + + def _restore_parameter_values(self, state: dict) -> None: + """Apply a state dict to all input widgets. + + Called by ``DictSnapshotCommand.undo/redo`` and + ``UndoStack.reset_to_last_good()``. + Must run inside ``undo_stack.suppressed()`` to avoid creating + spurious undo entries from widget signal handlers. + """ + with self._undo_stack_obj.suppressed(): + self.txtBackgd.setText(str(state.get("background", "0.0"))) + self.txtScale.setText(str(state.get("scale", "1.0"))) + self.txtContrast.setText(str(state.get("contrast", ""))) + self.txtContrastErr.setText(str(state.get("contrast_err", ""))) + self.txtPorodCst.setText(str(state.get("porod", ""))) + self.txtPorodCstErr.setText(str(state.get("porod_err", ""))) + self.txtVolFrac1.setText(str(state.get("volfrac1", ""))) + self.txtVolFrac1Err.setText(str(state.get("volfrac1_err", ""))) + self.rbContrast.setChecked(bool(state.get("enable_contrast", True))) + self.rbVolFrac.setChecked(bool(state.get("enable_volfrac", False))) + self.txtGuinierEnd_ex.setText(str(state.get("guinier_end_low_q_ex", ""))) + self.txtPorodStart_ex.setText(str(state.get("porod_start_high_q_ex", ""))) + self.txtPorodEnd_ex.setText(str(state.get("porod_end_high_q_ex", ""))) + self.txtLowQPower_ex.setText(str(state.get("power_low_q_ex", str(DEFAULT_POWER_VALUE)))) + self.txtHighQPower_ex.setText(str(state.get("power_high_q_ex", str(DEFAULT_POWER_VALUE)))) + self.chkLowQ_ex.setChecked(bool(state.get("enable_low_q_ex", False))) + self.chkHighQ_ex.setChecked(bool(state.get("enable_high_q_ex", False))) + self.rbLowQGuinier_ex.setChecked(bool(state.get("low_q_guinier_ex", False))) + self.rbLowQPower_ex.setChecked(bool(state.get("low_q_power_ex", False))) + self.rbLowQFit_ex.setChecked(bool(state.get("low_q_fit_ex", False))) + self.rbLowQFix_ex.setChecked(bool(state.get("low_q_fix_ex", False))) + self.rbHighQFit_ex.setChecked(bool(state.get("high_q_fit_ex", False))) + self.rbHighQFix_ex.setChecked(bool(state.get("high_q_fix_ex", False))) + self.update_from_model() + + def _captureUndoState(self, description: str = "Change") -> None: + """Push a DictSnapshotCommand if the current state differs from the baseline. + + Call on each committed user edit (``editingFinished``, ``toggled``, + slider release, etc.). The baseline-diff approach ensures cascade + signals (radio button groups, checkbox chains) produce at most one + undo entry per user action. + """ + # Guard: no baseline means we are still initializing + if self._undo_baseline is None: + return + new_state = self._get_parameter_dict() + if new_state != self._undo_baseline: + self._undo_stack_obj.push( + DictSnapshotCommand(self._undo_baseline, new_state, description) + ) + self._undo_baseline = new_state + + def _rebaseline_undo_state(self) -> None: + """Update the undo baseline to the current state without pushing a command. + + Call after programmatic state changes (project load, data swap, etc.) + """ + self._undo_baseline = self._get_parameter_dict() + def initialize_variables(self) -> None: """Initialize class variables.""" @@ -165,6 +277,49 @@ def setup_validators(self) -> None: self.txtLowQPower_ex.setValidator(GuiUtils.DoubleValidator()) self.txtHighQPower_ex.setValidator(GuiUtils.DoubleValidator()) + def _setupUndoConnections(self) -> None: + """Connect undo-capture signals for all user-editable input widgets. + + Uses ``editingFinished`` for text edits (commit boundary) and + ``toggled`` for checkboxes / radio buttons. The baseline-diff in + ``_captureUndoState`` ensures cascade signals produce one entry. + """ + # Text edits — commit boundary (not per-keystroke) + text_edits = [ + self.txtBackgd, self.txtScale, + self.txtContrast, self.txtContrastErr, + self.txtPorodCst, self.txtPorodCstErr, + self.txtVolFrac1, self.txtVolFrac1Err, + self.txtGuinierEnd_ex, self.txtPorodStart_ex, self.txtPorodEnd_ex, + self.txtLowQPower_ex, self.txtHighQPower_ex, + ] + for te in text_edits: + te.editingFinished.connect(lambda desc="Edit value": self._captureUndoState(desc)) + + # Checkboxes + self.chkLowQ_ex.toggled.connect( + lambda _: self._captureUndoState("Toggle low-Q extrapolation")) + self.chkHighQ_ex.toggled.connect( + lambda _: self._captureUndoState("Toggle high-Q extrapolation")) + + # Radio button groups + self.rbContrast.toggled.connect( + lambda _: self._captureUndoState("Toggle contrast/volfrac")) + self.rbVolFrac.toggled.connect( + lambda _: self._captureUndoState("Toggle contrast/volfrac")) + self.rbLowQGuinier_ex.toggled.connect( + lambda _: self._captureUndoState("Change low-Q extrapolation type")) + self.rbLowQPower_ex.toggled.connect( + lambda _: self._captureUndoState("Change low-Q extrapolation type")) + self.rbLowQFit_ex.toggled.connect( + lambda _: self._captureUndoState("Change low-Q power mode")) + self.rbLowQFix_ex.toggled.connect( + lambda _: self._captureUndoState("Change low-Q power mode")) + self.rbHighQFit_ex.toggled.connect( + lambda _: self._captureUndoState("Change high-Q power mode")) + self.rbHighQFix_ex.toggled.connect( + lambda _: self._captureUndoState("Change high-Q power mode")) + def setup_default_enablement(self) -> None: """Setup the default enablement of the widgets.""" self.tabWidget.setCurrentIndex(0) @@ -319,6 +474,9 @@ def calculate_invariant(self) -> None: # Modify the Calculate button to indicate background process self.enable_calculation(enabled=False, display="Calculating...") + # Disable undo during calculation + self._undo_stack_obj.set_enabled(False) + # Send the calculations to separate thread. d = threads.deferToThread(self.calculate_thread, extrapolation) @@ -329,6 +487,7 @@ def calculate_invariant(self) -> None: def on_calculation_failed(self, reason: Failure) -> None: """Handle calculation failure.""" logger.error(f"calculation failed: {reason}") + self._undo_stack_obj.set_enabled(True) self.check_status() def deferredPlot(self, model: QtGui.QStandardItemModel, extrapolation: str | None = None) -> None: @@ -340,6 +499,10 @@ def deferredPlot(self, model: QtGui.QStandardItemModel, extrapolation: str | Non reactor.callFromThread(lambda: self._manager.filesWidget.newPlot()) self.extrapolation_made = False + # Re-enable undo after calculation completes + self._undo_stack_obj.set_enabled(True) + self._rebaseline_undo_state() + self.check_status() def check_status(self) -> None: @@ -905,6 +1068,7 @@ def on_extrapolation_slider_changed(self, state: ExtrapolationParameters) -> Non self.model.setItem(WIDGETS.W_POROD_START_EX, QtGui.QStandardItem(format_string % state.point_2)) self.model.setItem(WIDGETS.W_POROD_END_EX, QtGui.QStandardItem(format_string % state.point_3)) self.correct_extrapolation_values() + self._captureUndoState("Change extrapolation slider") def on_extrapolation_text_editing(self) -> None: """Handle when user edits any of the extrapolation text boxes.""" @@ -991,11 +1155,12 @@ def correct_extrapolation_values(self) -> None: data_q_max = float(self.format_sig_fig(self._data.x.max())) # Actual data max messages = [] - # block signals to avoid recursive calls + # block signals and suppress undo to avoid spurious entries with ( QtCore.QSignalBlocker(self.txtGuinierEnd_ex), QtCore.QSignalBlocker(self.txtPorodStart_ex), QtCore.QSignalBlocker(self.txtPorodEnd_ex), + self._undo_stack_obj.suppressed(), ): # start by updating p2 as it is used in multiple checks if self.validity_flags[0]: # point_2 <= data_q_min @@ -1447,6 +1612,12 @@ def fractional_position(f): self.mapper.toFirst() + # Clear undo stack + re-baseline for fresh data + with self._undo_stack_obj.suppressed(): + pass # no programmatic changes to undo-suppress here + self._undo_stack_obj.clear() + self._rebaseline_undo_state() + def removeData(self, data_list: list | None = None) -> None: """Remove the existing data reference from the Invariant Perspective.""" if not data_list or self._model_item not in data_list: @@ -1501,6 +1672,9 @@ def updateGuiFromFile(self, data: Data1D = None) -> None: data=self._data, background=self._background, scale=self._scale ) + # Re-baseline undo state after loading data from file + self._rebaseline_undo_state() + def serializeAll(self) -> dict: """ Serialize the invariant state so data can be saved. @@ -1589,42 +1763,47 @@ def updateFromParameters(self, params: dict) -> None: msg = "Invariant.updateFromParameters expects a dictionary" raise TypeError(f"{msg}: {c_name} received") - # Assign values to 'Invariant' tab inputs - use defaults if not found - self.txtTotalQMin.setText(str(params.get("total_q_min", "0.0"))) - self.txtTotalQMax.setText(str(params.get("total_q_max", "0.0"))) - self.txtVolFract.setText(str(params.get("vol_fraction", ""))) - self.txtVolFractErr.setText(str(params.get("vol_fraction_err", ""))) - self.txtContrastOut.setText(str(params.get("contrast_out", ""))) - self.txtContrastOutErr.setText(str(params.get("contrast_out_err", ""))) - self.txtSpecSurf.setText(str(params.get("specific_surface", ""))) - self.txtSpecSurfErr.setText(str(params.get("specific_surface_err", ""))) - self.txtInvariantTot.setText(str(params.get("invariant_total", ""))) - self.txtInvariantTotErr.setText(str(params.get("invariant_total_err", ""))) - self.txtBackgd.setText(str(params.get("background", "0.0"))) - self.txtScale.setText(str(params.get("scale", "1.0"))) - self.txtContrast.setText(str(params.get("contrast", ""))) - self.txtContrastErr.setText(str(params.get("contrast_err", "0.0"))) - self.txtPorodCst.setText(str(params.get("porod", "0.0"))) - self.txtVolFrac1.setText(str(params.get("volfrac1", "0.0"))) - self.txtVolFrac1Err.setText(str(params.get("volfrac1_err", "0.0"))) - - # Extrapolation tab - use new _ex suffix variables - self.txtGuinierEnd_ex.setText(str(params.get("guinier_end_low_q_ex", ""))) - self.txtPorodStart_ex.setText(str(params.get("porod_start_high_q_ex", ""))) - self.txtPorodEnd_ex.setText(str(params.get("porod_end_high_q_ex", ""))) - self.txtLowQPower_ex.setText(str(params.get("power_low_q_ex", DEFAULT_POWER_VALUE))) - self.txtHighQPower_ex.setText(str(params.get("power_high_q_ex", DEFAULT_POWER_VALUE))) - self.chkLowQ_ex.setChecked(params.get("enable_low_q_ex", False)) - self.chkHighQ_ex.setChecked(params.get("enable_high_q_ex", False)) - self.rbLowQGuinier_ex.setChecked(params.get("low_q_guinier_ex", False)) - self.rbLowQPower_ex.setChecked(params.get("low_q_power_ex", False)) - self.rbLowQFit_ex.setChecked(params.get("low_q_fit_ex", False)) - self.rbLowQFix_ex.setChecked(params.get("low_q_fix_ex", False)) - self.rbHighQFit_ex.setChecked(params.get("high_q_fit_ex", False)) - self.rbHighQFix_ex.setChecked(params.get("high_q_fix_ex", False)) - - # Update once all inputs are changed - self.update_from_model() + # Suppress undo during programmatic load — handled after restore + with self._undo_stack_obj.suppressed(): + # Assign values to 'Invariant' tab inputs - use defaults if not found + self.txtTotalQMin.setText(str(params.get("total_q_min", "0.0"))) + self.txtTotalQMax.setText(str(params.get("total_q_max", "0.0"))) + self.txtVolFract.setText(str(params.get("vol_fraction", ""))) + self.txtVolFractErr.setText(str(params.get("vol_fraction_err", ""))) + self.txtContrastOut.setText(str(params.get("contrast_out", ""))) + self.txtContrastOutErr.setText(str(params.get("contrast_out_err", ""))) + self.txtSpecSurf.setText(str(params.get("specific_surface", ""))) + self.txtSpecSurfErr.setText(str(params.get("specific_surface_err", ""))) + self.txtInvariantTot.setText(str(params.get("invariant_total", ""))) + self.txtInvariantTotErr.setText(str(params.get("invariant_total_err", ""))) + self.txtBackgd.setText(str(params.get("background", "0.0"))) + self.txtScale.setText(str(params.get("scale", "1.0"))) + self.txtContrast.setText(str(params.get("contrast", ""))) + self.txtContrastErr.setText(str(params.get("contrast_err", "0.0"))) + self.txtPorodCst.setText(str(params.get("porod", "0.0"))) + self.txtVolFrac1.setText(str(params.get("volfrac1", "0.0"))) + self.txtVolFrac1Err.setText(str(params.get("volfrac1_err", "0.0"))) + + # Extrapolation tab - use new _ex suffix variables + self.txtGuinierEnd_ex.setText(str(params.get("guinier_end_low_q_ex", ""))) + self.txtPorodStart_ex.setText(str(params.get("porod_start_high_q_ex", ""))) + self.txtPorodEnd_ex.setText(str(params.get("porod_end_high_q_ex", ""))) + self.txtLowQPower_ex.setText(str(params.get("power_low_q_ex", DEFAULT_POWER_VALUE))) + self.txtHighQPower_ex.setText(str(params.get("power_high_q_ex", DEFAULT_POWER_VALUE))) + self.chkLowQ_ex.setChecked(params.get("enable_low_q_ex", False)) + self.chkHighQ_ex.setChecked(params.get("enable_high_q_ex", False)) + self.rbLowQGuinier_ex.setChecked(params.get("low_q_guinier_ex", False)) + self.rbLowQPower_ex.setChecked(params.get("low_q_power_ex", False)) + self.rbLowQFit_ex.setChecked(params.get("low_q_fit_ex", False)) + self.rbLowQFix_ex.setChecked(params.get("low_q_fix_ex", False)) + self.rbHighQFit_ex.setChecked(params.get("high_q_fit_ex", False)) + self.rbHighQFix_ex.setChecked(params.get("high_q_fix_ex", False)) + + # Update once all inputs are changed + self.update_from_model() + + # Re-baseline after programmatic restore + self._rebaseline_undo_state() def allowBatch(self) -> bool: """Tell the caller that we don't accept multiple data instances.""" diff --git a/src/sas/qtgui/Perspectives/Inversion/InversionPerspective.py b/src/sas/qtgui/Perspectives/Inversion/InversionPerspective.py index e72204c29d..b3b66244dd 100644 --- a/src/sas/qtgui/Perspectives/Inversion/InversionPerspective.py +++ b/src/sas/qtgui/Perspectives/Inversion/InversionPerspective.py @@ -316,6 +316,17 @@ def currentTab(self) -> InversionWidget: """Returns the tab widget currently shown.""" return self.currentWidget() + @property + def undo_stack(self): + """Return undo stack for the currently selected Inversion tab. + + Delegates to the active ``InversionWidget``'s stack, mirroring + ``FittingPerspective.undo_stack``. Returns ``None`` if the current + widget is not an ``InversionWidget``. + """ + w = self.currentWidget() + return getattr(w, "undo_stack", None) if isinstance(w, InversionWidget) else None + def currentTabDataId(self) -> list: """Returns the data ID of the current tab.""" tab_id = [item.logic.data.id for item in self.currentTab.results] diff --git a/src/sas/qtgui/Perspectives/Inversion/InversionWidget.py b/src/sas/qtgui/Perspectives/Inversion/InversionWidget.py index a8ac5d3c6f..74d767aaea 100644 --- a/src/sas/qtgui/Perspectives/Inversion/InversionWidget.py +++ b/src/sas/qtgui/Perspectives/Inversion/InversionWidget.py @@ -11,6 +11,7 @@ from sas.qtgui.Perspectives.Inversion.InversionLogic import InversionLogic from sas.qtgui.Perspectives.Inversion.Thread import CalcBatchPr, CalcPr, EstimateNT from sas.qtgui.Perspectives.Inversion.UI.TabbedInversionUI import Ui_PrInversion +from sas.qtgui.Perspectives.UndoRedo import DictSnapshotCommand, UndoStack from sas.qtgui.Plotting.PlotterData import Data1D, DataRole from sas.qtgui.Utilities.GridPanel import BatchInversionOutputPanel from sas.qtgui.Utilities.GuiUtils import ( @@ -111,10 +112,17 @@ def __init__(self, window, parent=None, data=None, tab_id=1, tab_name=''): self.batch_dict: dict[str, Any] | None = None self.input_boxes = [self.noOfTermsInput, self.regularizationConstantInput, self.maxDistanceInput, - self.minQInput, self.maxQInput, self.slitHeightInput, self.slitHeightInput] + self.minQInput, self.maxQInput, self.slitHeightInput, self.slitWidthInput] + + # Undo/redo infrastructure + self.undo_stack = UndoStack(self) + self._undo_baseline: dict | None = None self.updateGuiValues() self.events() + self._setupUndoConnections() + self._rebaseline_undo_state() + self.undo_stack.clear() def initResult(self) -> InversionResult: logic = InversionLogic() @@ -136,8 +144,8 @@ def events(self): self.noOfTermsSuggestionButton.clicked.connect(self.applyNumTermsEstimate) self.regConstantSuggestionButton.clicked.connect(self.applyRegConstantEstimate) self.explorerButton.clicked.connect(self.openExplorerWindow) - self.estimateBgd.pressed.connect(self.handleBackgroundModeChange) - self.manualBgd.pressed.connect(self.handleBackgroundModeChange) + self.estimateBgd.toggled.connect(self.handleBackgroundModeChange) + self.manualBgd.toggled.connect(self.handleBackgroundModeChange) self.dataList.currentIndexChanged.connect(self.handleCurrentDataChanged) self.helpButton.clicked.connect(self.onHelp) self.removeButton.clicked.connect(self.handleRemove) @@ -146,6 +154,109 @@ def events(self): for input_box in self.input_boxes: input_box.editingFinished.connect(self.startEstimateParameters) + # ------------------------------------------------------------------ + # Undo/redo contract methods + # ------------------------------------------------------------------ + + def _get_parameter_dict(self) -> dict: + """Capture current input state from the UI widgets. + + Reads directly from widgets so undo works even when no data is + loaded (``updateParams()`` only runs when ``data_is_loaded``). + + Called by ``UndoStack._auto_snapshot()`` for recovery snapshots. + """ + try: + return { + "noOfTerms": int(self.noOfTermsInput.text() or NUMBER_OF_TERMS), + "alpha": float(self.regularizationConstantInput.text() or REGULARIZATION), + "dmax": float(self.maxDistanceInput.text() or MAX_DIST), + "est_bck": self.estimateBgd.isChecked(), + "background": float(self.backgroundInput.text() or BACKGROUND_INPUT), + "q_min": float(self.minQInput.text() or Q_MIN_INPUT), + "q_max": float(self.maxQInput.text() or Q_MAX_INPUT), + "slit_height": float(self.slitHeightInput.text() or 0.0), + "slit_width": float(self.slitWidthInput.text() or 0.0), + } + except (ValueError, TypeError): + # Widgets have incomplete/invalid text — fall back to calculator + calc = self.currentResult.calculator + return { + "noOfTerms": calc.noOfTerms, + "alpha": calc.alpha, + "dmax": calc.dmax, + "est_bck": calc.est_bck, + "background": calc.background, + "q_min": calc.q_min, + "q_max": calc.q_max, + "slit_height": calc.slit_height, + "slit_width": calc.slit_width, + } + + def _restore_parameter_values(self, state: dict) -> None: + """Apply a state dict to the calculator, then refresh the GUI. + + Called by ``DictSnapshotCommand.undo/redo`` and + ``UndoStack.reset_to_last_good()``. + Must run inside ``undo_stack.suppressed()`` to prevent + ``updateGuiValues`` signal handlers from pushing entries. + """ + calc = self.currentResult.calculator + with self.undo_stack.suppressed(): + calc.noOfTerms = int(state.get("noOfTerms", NUMBER_OF_TERMS)) + calc.alpha = float(state.get("alpha", REGULARIZATION)) + calc.dmax = float(state.get("dmax", MAX_DIST)) + calc.est_bck = bool(state.get("est_bck", False)) + calc.background = float(state.get("background", BACKGROUND_INPUT)) + calc.q_min = float(state.get("q_min", Q_MIN_INPUT)) + calc.q_max = float(state.get("q_max", Q_MAX_INPUT)) + calc.slit_height = float(state.get("slit_height", 0.0)) + calc.slit_width = float(state.get("slit_width", 0.0)) + self.updateGuiValues() + self.enableButtons() + + def _captureUndoState(self, description: str = "Change") -> None: + """Push a DictSnapshotCommand if the current state differs from the baseline. + + Call after each committed user edit. The baseline-diff approach + ensures cascading signals produce at most one undo entry. + """ + if self._undo_baseline is None: + return + new_state = self._get_parameter_dict() + if new_state != self._undo_baseline: + self.undo_stack.push( + DictSnapshotCommand(self._undo_baseline, new_state, description) + ) + self._undo_baseline = new_state + + def _rebaseline_undo_state(self) -> None: + """Update the undo baseline without pushing a command. + + Call after programmatic state changes. + """ + self._undo_baseline = self._get_parameter_dict() + + def _setupUndoConnections(self) -> None: + """Connect undo-capture signals for user-editable input widgets. + + The ``editingFinished`` connections are added *in addition to* + the existing ``editingFinished → startEstimateParameters`` + connections made in ``events()``. + """ + # Input boxes — editingFinished is already connected to + # startEstimateParameters. We connect _captureUndoState as a + # second slot so it fires after updateParams() writes to calculator. + for input_box in self.input_boxes: + input_box.editingFinished.connect( + lambda desc="Edit parameter": self._captureUndoState(desc)) + + # Background mode toggle + self.estimateBgd.toggled.connect( + lambda _: self._captureUndoState("Toggle background mode")) + self.manualBgd.toggled.connect( + lambda _: self._captureUndoState("Toggle background mode")) + def handleRemove(self): if self.currentResult.data_plot: self.currentResult.data_plot.slider_low_q_input = [] @@ -157,39 +268,51 @@ def handleRemove(self): self.dataList.removeItem(to_remove) _ = self.results.pop(to_remove) # If there's no results left, we need an empty one. - if len(self.results) == 0: - self.results.append(self.initResult()) - self.clearGuiValues() - self.enableButtons() - self.updateGuiValues() + with self.undo_stack.suppressed(): + if len(self.results) == 0: + self.results.append(self.initResult()) + self.clearGuiValues() + self.enableButtons() + self.updateGuiValues() + # Fresh data context → clear undo history + self.undo_stack.clear() + self._rebaseline_undo_state() def handleCurrentDataChanged(self): # This event might get called before there is anything in the results list. But we can't update the GUI without # errors. if len(self.results) != 0: - self.updateGuiValues() - self.startEstimateParameters() + with self.undo_stack.suppressed(): + self.updateGuiValues() + self.startEstimateParameters() + # Data-combo switch: clear history + re-baseline (v1 behavior) + self.undo_stack.clear() + self._rebaseline_undo_state() # TODO: Need to verify type hint for data. def updateTab(self, data: HashableStandardItem | list[HashableStandardItem], tab_id: int): self.tab_id = tab_id - if isinstance(data, list): - self.results = [] - self.dataList.clear() - for datum_item in data: - new_result = self.initResult() - new_result.logic.data = datum_item - datum = dataFromItem(datum_item) - self.dataList.addItem(datum.name) - self.results.append(new_result) - else: - self.currentData = data - self.dataList.clear() - self.dataList.addItem(self.currentData.name) - self.dataList.setCurrentIndex(0) - self.updateGuiValues() - self.enableButtons() - self.startEstimateParameters() + with self.undo_stack.suppressed(): + if isinstance(data, list): + self.results = [] + self.dataList.clear() + for datum_item in data: + new_result = self.initResult() + new_result.logic.data = datum_item + datum = dataFromItem(datum_item) + self.dataList.addItem(datum.name) + self.results.append(new_result) + else: + self.currentData = data + self.dataList.clear() + self.dataList.addItem(self.currentData.name) + self.dataList.setCurrentIndex(0) + self.updateGuiValues() + self.enableButtons() + self.startEstimateParameters() + # Fresh data → clear undo history + re-baseline + self.undo_stack.clear() + self._rebaseline_undo_state() @property def is_batch(self) -> bool: @@ -229,10 +352,11 @@ def q_min(self) -> float: def onNewData(self): # FIXME: This mutates the data even for other perspectives. - self.currentResult.logic.add_errors() - qmin, qmax = self.currentResult.logic.computeDataRange() - self.currentResult.calculator.q_min = qmin - self.currentResult.calculator.q_max = qmax + with self.undo_stack.suppressed(): + self.currentResult.logic.add_errors() + qmin, qmax = self.currentResult.logic.computeDataRange() + self.currentResult.calculator.q_min = qmin + self.currentResult.calculator.q_max = qmax # TODO: Probably change this name. def enableButtons(self): @@ -256,6 +380,11 @@ def enableButtons(self): self.noOfTermsSuggestionButton.setEnabled(self.currentResult.logic.data_is_loaded and not self.isCalculating) def updateGuiValues(self): + """Refresh GUI from calculator state. Suppresses undo entries.""" + with self.undo_stack.suppressed(): + self._updateGuiValuesImpl() + + def _updateGuiValuesImpl(self): # TODO: This won't work for batch at the moment. current_calculator = self.currentResult.calculator @@ -295,13 +424,13 @@ def updateGuiValues(self): self.sigmaPosFractionValue.setText(format_float(out.pos_err)) def clearGuiValues(self): + with self.undo_stack.suppressed(): + value_text_boxes = [*self.input_boxes, self.rgValue, self.iQ0Value, self.backgroundValue, self.backgroundInput, + self.computationTimeValue, self.chiDofValue, self.oscillationValue, self.posFractionValue, + self.sigmaPosFractionValue] - value_text_boxes = [*self.input_boxes, self.rgValue, self.iQ0Value, self.backgroundValue, self.backgroundInput, - self.computationTimeValue, self.chiDofValue, self.oscillationValue, self.posFractionValue, - self.sigmaPosFractionValue] - - for text_box in value_text_boxes: - text_box.setText("") + for text_box in value_text_boxes: + text_box.setText("") def setupValidators(self): """Apply validators to editable line edits""" @@ -322,13 +451,13 @@ def acceptsData(self) -> bool: def threadError(self, error: str): logger.error(error) + self.undo_stack.set_enabled(True) # TODO: No function to stop calculation yet. # TODO: These parameters should really be type hinted (or rolled into a dataclass?) def calculationCompleted(self, out, cov, pr, elapsed): - # TODO: Placeholder. Just output the numbers for now. Later the result - # should be plotted. self.isCalculating = False + self.undo_stack.set_enabled(True) self.enableButtons() calculator = self.currentResult.calculator # TODO: Some of these probably don't need to be here. @@ -380,12 +509,14 @@ def updateMinQ(self, new_q_min: float): new_q_min = max([min(calculator.x), new_q_min]) calculator.q_min = new_q_min self.updateGuiValues() + self._captureUndoState("Change q min") def updateMaxQ(self, new_q_max: float): calculator = self.currentResult.calculator new_q_max = min([max(calculator.x), new_q_max]) calculator.q_max = new_q_max self.updateGuiValues() + self._captureUndoState("Change q max") def updateParams(self): # TODO: No validators so this will break if they can't be converted to @@ -410,6 +541,7 @@ def updateParams(self): def startThread(self): self.updateParams() self.isCalculating = True + self.undo_stack.set_enabled(False) self.enableButtons() # TODO: Calc thread should be declared beforehand. @@ -430,6 +562,7 @@ def startThread(self): def startThreadAll(self): self.updateParams() self.isCalculating = True + self.undo_stack.set_enabled(False) self.enableButtons() self.dataList.setCurrentIndex(0) @@ -444,6 +577,7 @@ def startThreadAll(self): def batchCalculationComplete(self, totalElapsed): self.isCalculating = False + self.undo_stack.set_enabled(True) self.enableButtons() self.calculationComplete.emit() @@ -476,20 +610,30 @@ def estimateAvailable(self): def endEstimateParameters(self, nterms, alpha, message, elapsed): self.currentResult.estimated_parameters = EstimatedParameters(alpha, nterms) + self.undo_stack.set_enabled(True) + self._rebaseline_undo_state() self.estimationComplete.emit() def applyRegConstantEstimate(self): - self.currentResult.calculator.alpha = self.currentResult.estimated_parameters.reg_constant - self.updateGuiValues() + calculator = self.currentResult.calculator + calculator.alpha = self.currentResult.estimated_parameters.reg_constant + # Suppress undo during updateGuiValues (which fires editingFinished) + with self.undo_stack.suppressed(): + self.updateGuiValues() + self._captureUndoState("Apply reg constant estimate") def applyNumTermsEstimate(self): - self.currentResult.calculator.noOfTerms = self.currentResult.estimated_parameters.nterms - self.updateGuiValues() + calculator = self.currentResult.calculator + calculator.noOfTerms = self.currentResult.estimated_parameters.nterms + with self.undo_stack.suppressed(): + self.updateGuiValues() + self._captureUndoState("Apply num terms estimate") def startEstimateParameters(self): if not self.currentResult.logic.data_is_loaded: return self.updateParams() + self.undo_stack.set_enabled(False) estimation_thread = EstimateNT( self.currentResult.calculator, self.currentResult.calculator.nfunc, @@ -512,6 +656,8 @@ def handleBackgroundModeChange(self): self.backgroundInput.setEnabled(True) elif self.manualBgd.isChecked(): self.backgroundInput.setEnabled(False) + # Capture undo after toggle completes (undo capture is connected to + # toggled signal on both radio buttons via _setupUndoConnections) def serialiseResult(self, result: InversionResult) -> dict[str, Any]: return { @@ -540,21 +686,24 @@ def serialiseResult(self, result: InversionResult) -> dict[str, Any]: } def updateFromParameters(self, params: dict[str, Any]): - result = self.currentResult - result.calculator.alpha = params['alpha'] - result.calculator.background = params['background'] - result.calculator.chi2 = params['chi2'] - result.calculator.cov = params['cov'] - result.calculator.dmax = params['d_max'] - result.calculator.elapsed = params['elapsed'] - result.calculator.est_bck = params['est_bck'] - result.calculator.out = params['out'] - result.calculator.q_max = params['q_max'] - result.calculator.q_min = params['q_min'] - result.calculator.slit_height = params['slit_height'] - result.calculator.slit_width = params['slit_width'] - result.calculator.suggested_alpha = params['suggested_alpha'] - result.outputs = get_outputs(result.calculator, params['elapsed']) + with self.undo_stack.suppressed(): + result = self.currentResult + result.calculator.alpha = params['alpha'] + result.calculator.background = params['background'] + result.calculator.chi2 = params['chi2'] + result.calculator.cov = params['cov'] + result.calculator.dmax = params['d_max'] + result.calculator.elapsed = params['elapsed'] + result.calculator.est_bck = params['est_bck'] + result.calculator.out = params['out'] + result.calculator.q_max = params['q_max'] + result.calculator.q_min = params['q_min'] + result.calculator.slit_height = params['slit_height'] + result.calculator.slit_width = params['slit_width'] + result.calculator.suggested_alpha = params['suggested_alpha'] + result.outputs = get_outputs(result.calculator, params['elapsed']) + self.updateGuiValues() + self._rebaseline_undo_state() def getPage(self) -> dict[str, Any]: # FIXME: Doesn't work on batch. diff --git a/src/sas/qtgui/Perspectives/ParticleEditor/DesignWindow.py b/src/sas/qtgui/Perspectives/ParticleEditor/DesignWindow.py index 21d05636f0..2716ff4a20 100644 --- a/src/sas/qtgui/Perspectives/ParticleEditor/DesignWindow.py +++ b/src/sas/qtgui/Perspectives/ParticleEditor/DesignWindow.py @@ -397,6 +397,13 @@ def qSampling(self) -> QSample: return QSample(min_q, max_q, n_samples, is_log) +particle_editor_window = None +def show_particle_editor(): + global particle_editor_window + + particle_editor_window = DesignWindow() + particle_editor_window.show() + def main(): """ Demo/testing window""" diff --git a/src/sas/qtgui/Perspectives/SizeDistribution/SizeDistributionPerspective.py b/src/sas/qtgui/Perspectives/SizeDistribution/SizeDistributionPerspective.py index 2efdbc1923..3cde365836 100644 --- a/src/sas/qtgui/Perspectives/SizeDistribution/SizeDistributionPerspective.py +++ b/src/sas/qtgui/Perspectives/SizeDistribution/SizeDistributionPerspective.py @@ -22,6 +22,7 @@ from sas.qtgui.Perspectives.SizeDistribution.UI.SizeDistributionUI import ( Ui_SizeDistribution, ) +from sas.qtgui.Perspectives.UndoRedo import DictSnapshotCommand, UndoStack from sas.qtgui.Plotting.PlotterData import Data1D from sas.qtgui.Utilities import GuiUtils @@ -75,6 +76,10 @@ def __init__(self, parent: QtWidgets.QWidget | None = None) -> None: # The window should not close self._allowClose: bool = False + # Undo/redo infrastructure + self._undo_stack_obj = UndoStack(self) + self._undo_baseline: dict | None = None + self._data: LoadData1D | None = None self._path: str = "" self.fit_thread: SizeDistributionThread | None = None @@ -100,6 +105,11 @@ def __init__(self, parent: QtWidgets.QWidget | None = None) -> None: # Set base window state self.setupWindow() + # Undo/redo: baseline after full initialization + self._setupUndoConnections() + self._rebaseline_undo_state() + self._undo_stack_obj.clear() + ###################################################################### # Base Perspective Class Definitions @@ -124,6 +134,14 @@ def isSerializable(self) -> bool: """Tell the caller that this perspective writes its state.""" return True + @property + def undo_stack(self): + """Return the undo stack for this perspective. + + Overrides ``Perspective.undo_stack`` (which returns ``None``). + """ + return self._undo_stack_obj + def closeEvent(self, event: QtGui.QCloseEvent) -> None: """Overwrite QDialog close method to allow for custom widget close.""" # Close report widgets before closing/minimizing main widget @@ -139,6 +157,176 @@ def closeEvent(self, event: QtGui.QCloseEvent) -> None: # Maybe we should just minimize self.setWindowState(QtCore.Qt.WindowMinimized) + # ------------------------------------------------------------------ + # Undo/redo contract methods + # ------------------------------------------------------------------ + + def _get_parameter_dict(self) -> dict: + """Capture current input-only state from widgets. + + Mirrors ``getState()`` but also includes weighting and power-fit + radio buttons not captured by the existing serializer. + """ + return { + "range_q_min": self.txtMinRange.text(), + "range_q_max": self.txtMaxRange.text(), + "aspect_ratio": self.txtAspectRatio.text(), + "d_min": self.txtMinDiameter.text(), + "d_max": self.txtMaxDiameter.text(), + "num_d_bins": self.txtBinsDiameter.text(), + "log_binning": self.chkLogBinning.isChecked(), + "contrast": self.txtContrast.text(), + "sky_background": self.txtSkyBackgd.text(), + "num_iterations": self.txtIterations.text(), + "background": self.txtBackgd.text(), + "subtract_low_q": self.chkLowQ.isChecked(), + "power_low_q": self.txtPowerLowQ.text(), + "scale_low_q": self.txtScaleLowQ.text(), + "wgt_factor": self.txtWgtFactor.text(), + "wgt_percent": self.txtWgtPercent.text(), + "weight_1": self.rbWeighting1.isChecked(), + "weight_2": self.rbWeighting2.isChecked(), + "weight_3": self.rbWeighting3.isChecked(), + "weight_4": self.rbWeighting4.isChecked(), + "fit_power": self.rbFitPower.isChecked(), + "fix_power": self.rbFixPower.isChecked(), + } + + def _restore_parameter_values(self, state: dict) -> None: + """Apply a state dict to all input widgets AND the underlying model. + + Must update model items because ``setText`` alone does not sync + through a ``QDataWidgetMapper`` — the model must be explicitly + updated for data-load paths to see the correct values. + """ + with self._undo_stack_obj.suppressed(): + # Text widgets + self.txtMinRange.setText(str(state.get("range_q_min", "0.0"))) + self.txtMaxRange.setText(str(state.get("range_q_max", "0.0"))) + self.txtAspectRatio.setText(str(state.get("aspect_ratio", str(ASPECT_RATIO)))) + self.txtMinDiameter.setText(str(state.get("d_min", str(DIAMETER_MIN)))) + self.txtMaxDiameter.setText(str(state.get("d_max", str(DIAMETER_MAX)))) + self.txtBinsDiameter.setText(str(state.get("num_d_bins", str(NUM_DIAMETER_BINS)))) + self.txtContrast.setText(str(state.get("contrast", str(CONTRAST)))) + self.txtSkyBackgd.setText(str(state.get("sky_background", str(SKY_BACKGROUND)))) + self.txtIterations.setText(str(state.get("num_iterations", str(NUM_ITERATIONS)))) + self.txtBackgd.setText(str(state.get("background", str(BACKGROUND)))) + self.txtPowerLowQ.setText(str(state.get("power_low_q", str(POWER_LOW_Q)))) + self.txtScaleLowQ.setText(str(state.get("scale_low_q", str(SCALE_LOW_Q)))) + self.txtWgtFactor.setText(str(state.get("wgt_factor", str(WEIGHT_FACTOR)))) + self.txtWgtPercent.setText(str(state.get("wgt_percent", str(WEIGHT_PERCENT)))) + + # Checkboxes + self.chkLogBinning.setChecked(bool(state.get("log_binning", True))) + self.chkLowQ.setChecked(bool(state.get("subtract_low_q", False))) + + # Radio buttons (mutually exclusive groups) + if bool(state.get("weight_1", True)): + self.rbWeighting1.setChecked(True) + elif bool(state.get("weight_2", False)): + self.rbWeighting2.setChecked(True) + elif bool(state.get("weight_3", False)): + self.rbWeighting3.setChecked(True) + elif bool(state.get("weight_4", False)): + self.rbWeighting4.setChecked(True) + + if bool(state.get("fix_power", True)): + self.rbFixPower.setChecked(True) + elif bool(state.get("fit_power", False)): + self.rbFitPower.setChecked(True) + + # Sync model items (QDataWidgetMapper is one-way, model not + # updated by setText alone) + self.model.item(WIDGETS.W_QMIN).setText( + str(state.get("range_q_min", "0.0"))) + self.model.item(WIDGETS.W_QMAX).setText( + str(state.get("range_q_max", "0.0"))) + self.model.item(WIDGETS.W_ASPECT_RATIO).setText( + str(state.get("aspect_ratio", str(ASPECT_RATIO)))) + self.model.item(WIDGETS.W_DMIN).setText( + str(state.get("d_min", str(DIAMETER_MIN)))) + self.model.item(WIDGETS.W_DMAX).setText( + str(state.get("d_max", str(DIAMETER_MAX)))) + self.model.item(WIDGETS.W_DBINS).setText( + str(state.get("num_d_bins", str(NUM_DIAMETER_BINS)))) + self.model.item(WIDGETS.W_CONTRAST).setText( + str(state.get("contrast", str(CONTRAST)))) + self.model.item(WIDGETS.W_SKY_BACKGROUND).setText( + str(state.get("sky_background", str(SKY_BACKGROUND)))) + self.model.item(WIDGETS.W_NUM_ITERATIONS).setText( + str(state.get("num_iterations", str(NUM_ITERATIONS)))) + self.model.item(WIDGETS.W_BACKGROUND).setText( + str(state.get("background", str(BACKGROUND)))) + self.model.item(WIDGETS.W_POWER_LOW_Q).setText( + str(state.get("power_low_q", str(POWER_LOW_Q)))) + self.model.item(WIDGETS.W_SCALE_LOW_Q).setText( + str(state.get("scale_low_q", str(SCALE_LOW_Q)))) + self.model.item(WIDGETS.W_WEIGHT_FACTOR).setText( + str(state.get("wgt_factor", str(WEIGHT_FACTOR)))) + self.model.item(WIDGETS.W_WEIGHT_PERCENT).setText( + str(state.get("wgt_percent", str(WEIGHT_PERCENT)))) + self.model.item(WIDGETS.W_LOG_BINNING).setText( + str(state.get("log_binning", True)).lower()) + self.model.item(WIDGETS.W_SUBTRACT_LOW_Q).setText( + str(state.get("subtract_low_q", False)).lower()) + + # Update derived UI state (low Q enablement) + self.onLowQStateChanged( + QtCore.Qt.CheckState.Checked.value if bool(state.get("subtract_low_q", False)) + else QtCore.Qt.CheckState.Unchecked.value) + + def _captureUndoState(self, description: str = "Change") -> None: + """Push a DictSnapshotCommand if current state differs from baseline.""" + if self._undo_baseline is None: + return + new_state = self._get_parameter_dict() + if new_state != self._undo_baseline: + self._undo_stack_obj.push( + DictSnapshotCommand(self._undo_baseline, new_state, description) + ) + self._undo_baseline = new_state + + def _rebaseline_undo_state(self) -> None: + """Update undo baseline without pushing a command.""" + self._undo_baseline = self._get_parameter_dict() + + def _setupUndoConnections(self) -> None: + """Connect undo-capture signals for user-editable input widgets.""" + # Text edits — editingFinished as commit boundary + text_edits = [ + self.txtMinRange, self.txtMaxRange, + self.txtAspectRatio, self.txtMinDiameter, self.txtMaxDiameter, + self.txtBinsDiameter, self.txtContrast, + self.txtSkyBackgd, self.txtIterations, + self.txtBackgd, self.txtPowerLowQ, self.txtScaleLowQ, + self.txtWgtFactor, self.txtWgtPercent, + ] + for te in text_edits: + te.editingFinished.connect( + lambda desc="Edit value": self._captureUndoState(desc)) + + # Checkboxes + self.chkLogBinning.toggled.connect( + lambda _: self._captureUndoState("Toggle log binning")) + self.chkLowQ.toggled.connect( + lambda _: self._captureUndoState("Toggle subtract low-Q")) + + # Weighting radio buttons + self.rbWeighting1.toggled.connect( + lambda _: self._captureUndoState("Change weighting")) + self.rbWeighting2.toggled.connect( + lambda _: self._captureUndoState("Change weighting")) + self.rbWeighting3.toggled.connect( + lambda _: self._captureUndoState("Change weighting")) + self.rbWeighting4.toggled.connect( + lambda _: self._captureUndoState("Change weighting")) + + # Power fit/fix radio buttons + self.rbFitPower.toggled.connect( + lambda _: self._captureUndoState("Change power mode")) + self.rbFixPower.toggled.connect( + lambda _: self._captureUndoState("Change power mode")) + ###################################################################### # Initialization routines @@ -308,6 +496,7 @@ def help(self) -> None: def onQuickFit(self) -> None: """Perform a quick fit of the size distribution.""" self.is_calculating = True + self._undo_stack_obj.set_enabled(False) self.enableButtons() params = self.getMaxEntParams() @@ -324,6 +513,7 @@ def onQuickFit(self) -> None: def onFullFit(self) -> None: """Perform a full fit of the size distribution.""" self.is_calculating = True + self._undo_stack_obj.set_enabled(False) self.enableButtons() params = self.getMaxEntParams() @@ -343,7 +533,9 @@ def onRangeReset(self) -> None: qmax: float = 0.0 if self.logic.data_is_loaded: qmin, qmax = self.logic.computeDataRange() - self.updateQRange(qmin, qmax) + with self._undo_stack_obj.suppressed(): + self.updateQRange(qmin, qmax) + self._captureUndoState("Reset Q range") def onLowQStateChanged(self, state: int) -> None: """Slot for state change of the subtract power law checkbox.""" @@ -362,9 +554,11 @@ def onFitFlatBackground(self) -> None: if fit_result is None: return constant = fit_result[0] - self.txtBackgd.setText(f"{constant:5g}") + with self._undo_stack_obj.suppressed(): + self.txtBackgd.setText(f"{constant:5g}") self.updateBackground() self.plotData() + self._captureUndoState("Fit flat background") def onFitPowerLaw(self) -> None: """Fit background power law and update plot.""" @@ -377,7 +571,8 @@ def onFitPowerLaw(self) -> None: scale, power_fit = fit_result # by convention, the power is shown without a minus sign power = -1.0 * power_fit - self.txtPowerLowQ.setText(f"{power:5g}") + with self._undo_stack_obj.suppressed(): + self.txtPowerLowQ.setText(f"{power:5g}") else: # if the power should be fixed, pass the value from the input box _, _, power_fixed = self.getBackgroundParams() @@ -386,9 +581,11 @@ def onFitPowerLaw(self) -> None: return scale = fit_result[0] # update the scale - self.txtScaleLowQ.setText(f"{scale:5g}") + with self._undo_stack_obj.suppressed(): + self.txtScaleLowQ.setText(f"{scale:5g}") self.updateBackground() self.plotData() + self._captureUndoState("Fit power law background") def eventFilter(self, widget: QtCore.QObject, event: QtCore.QEvent) -> bool: """Catch enter key presses and update data plot.""" @@ -458,6 +655,10 @@ def setData(self, data_item: list | None = None, is_batch: bool = False) -> None self.plotData() + # Clear undo stack + re-baseline for fresh data + self._undo_stack_obj.clear() + self._rebaseline_undo_state() + def plotData(self) -> None: """Plot data, background and background subtracted data.""" plots: list = [self._model_item] @@ -531,6 +732,10 @@ def resetWindow(self) -> None: self.enableButtons() self.clearStatistics() + # Clear undo stack + re-baseline after data removal + self._undo_stack_obj.clear() + self._rebaseline_undo_state() + def serializeAll(self) -> dict: """ Serialize the size distribution state so data can be saved. @@ -579,21 +784,23 @@ def updateFromParameters(self, params: dict) -> None: c_name: str = params.__class__.__name__ msg: str = "SizeDistribution.updateFromParameters expects a dictionary" raise TypeError(f"{msg}: {c_name} received") - # Assign values to 'Parameters' tab inputs - use defaults if not found - self.txtMinRange.setText(str(params.get("range_q_min", "0.0"))) - self.txtMaxRange.setText(str(params.get("range_q_max", "0.0"))) - self.txtAspectRatio.setText(str(params.get("aspect_ratio", str(ASPECT_RATIO)))) - self.txtMinDiameter.setText(str(params.get("d_min", str(DIAMETER_MIN)))) - self.txtMaxDiameter.setText(str(params.get("d_max", str(DIAMETER_MAX)))) - self.txtBinsDiameter.setText(str(params.get("num_d_bins", str(NUM_DIAMETER_BINS)))) - self.chkLogBinning.setChecked(params.get("log_binning", True)) - self.txtContrast.setText(str(params.get("contrast", str(CONTRAST)))) - self.txtSkyBackgd.setText(str(params.get("sky_background", str(SKY_BACKGROUND)))) - self.txtIterations.setText(str(params.get("num_iterations", str(NUM_ITERATIONS)))) - self.txtBackgd.setText(str(params.get("background", str(BACKGROUND)))) - self.chkLowQ.setChecked(params.get("subtract_low_q", False)) - self.txtPowerLowQ.setText(str(params.get("power_low_q", str(POWER_LOW_Q)))) - self.txtScaleLowQ.setText(str(params.get("scale_low_q", str(SCALE_LOW_Q)))) + with self._undo_stack_obj.suppressed(): + # Assign values to 'Parameters' tab inputs - use defaults if not found + self.txtMinRange.setText(str(params.get("range_q_min", "0.0"))) + self.txtMaxRange.setText(str(params.get("range_q_max", "0.0"))) + self.txtAspectRatio.setText(str(params.get("aspect_ratio", str(ASPECT_RATIO)))) + self.txtMinDiameter.setText(str(params.get("d_min", str(DIAMETER_MIN)))) + self.txtMaxDiameter.setText(str(params.get("d_max", str(DIAMETER_MAX)))) + self.txtBinsDiameter.setText(str(params.get("num_d_bins", str(NUM_DIAMETER_BINS)))) + self.chkLogBinning.setChecked(params.get("log_binning", True)) + self.txtContrast.setText(str(params.get("contrast", str(CONTRAST)))) + self.txtSkyBackgd.setText(str(params.get("sky_background", str(SKY_BACKGROUND)))) + self.txtIterations.setText(str(params.get("num_iterations", str(NUM_ITERATIONS)))) + self.txtBackgd.setText(str(params.get("background", str(BACKGROUND)))) + self.chkLowQ.setChecked(params.get("subtract_low_q", False)) + self.txtPowerLowQ.setText(str(params.get("power_low_q", str(POWER_LOW_Q)))) + self.txtScaleLowQ.setText(str(params.get("scale_low_q", str(SCALE_LOW_Q)))) + self._rebaseline_undo_state() def updateQRange(self, q_range_min: float, q_range_max: float) -> None: """Update the local model based on calculated values.""" @@ -610,6 +817,8 @@ def fittingError(self, etype: type[BaseException], value: BaseException, traceba """Handle error in the calculation thread.""" # re-enable the fit buttons self.is_calculating = False + self._undo_stack_obj.set_enabled(True) + self._rebaseline_undo_state() self.enableButtons() logger.exception("Fitting failed", exc_info=(etype, value, traceback)) @@ -622,6 +831,8 @@ def fitComplete(self, result: MaxEntResult) -> None: """ # re-enable the fit buttons self.is_calculating = False + self._undo_stack_obj.set_enabled(True) + self._rebaseline_undo_state() self.enableButtons() if result is None: msg = "Fitting failed." diff --git a/src/sas/qtgui/Perspectives/UndoRedo.py b/src/sas/qtgui/Perspectives/UndoRedo.py new file mode 100644 index 0000000000..34b32eea2a --- /dev/null +++ b/src/sas/qtgui/Perspectives/UndoRedo.py @@ -0,0 +1,399 @@ +"""Shared undo/redo infrastructure for SasView perspectives. + +Provides UndoCommand (abstract base), CompoundCommand (atomic grouping), +UndoStack (history management), and DictSnapshotCommand (bulk state +snapshot for perspectives with updateFromParameters()-style restore). + +Originally extracted from Fitting/UndoRedo.py per the UNDO_OTHER.md plan. + +Design notes: +- UndoStack is a QObject so it can emit stackChanged for UI wiring. +- Each command stores old + new state and applies via widget callbacks. +- Command capture is suppressed during programmatic updates via suppressed(). +- The shared contract for any "undoable widget" is: + 1. ``self.undo_stack = UndoStack(self)`` + 2. ``_get_parameter_dict() -> dict`` — capture current input state + 3. ``_restore_parameter_values(state: dict)`` — apply a state dict +""" +from __future__ import annotations + +import contextlib +import logging +import time +import traceback + +from PySide6 import QtCore, QtWidgets + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Base command +# --------------------------------------------------------------------------- + +class UndoCommand: + """Abstract base for all undoable actions. + + Subclasses must implement ``undo(widget)`` and ``redo(widget)``. + ``description`` is shown in UI tooltips and the failure dialog. + """ + + def __init__(self, description: str) -> None: + self.description: str = description + self.timestamp: float = time.monotonic() + + def undo(self, widget) -> None: + """Apply the reverse change to *widget*.""" + raise NotImplementedError(f"{type(self).__name__}.undo() not implemented") + + def redo(self, widget) -> None: + """Re-apply the forward change to *widget*.""" + raise NotImplementedError(f"{type(self).__name__}.redo() not implemented") + + def can_merge(self, other: UndoCommand) -> bool: + """Return True if *other* may be merged into this command. + + Merging collapses consecutive edits into a single undo entry. + Default: no merging. + """ + return False + + def merge(self, other: UndoCommand) -> UndoCommand: + """Return a single command combining *self* (earlier) with *other* (later). + + Only called when ``can_merge(other)`` returns True. + """ + raise NotImplementedError(f"{type(self).__name__}.merge() not implemented") + + def __repr__(self) -> str: + return f"<{type(self).__name__}: {self.description!r}>" + + +# --------------------------------------------------------------------------- +# Compound command (groups multiple commands atomically) +# --------------------------------------------------------------------------- + +class CompoundCommand(UndoCommand): + """Groups multiple commands into a single atomic undo/redo entry. + + ``undo()`` executes sub-commands in reverse order. + ``redo()`` executes them in forward order. + """ + + def __init__( + self, commands: list[UndoCommand], description: str = "" + ) -> None: + desc = description or ( + commands[0].description if commands else "Compound action" + ) + super().__init__(desc) + self._commands: list[UndoCommand] = list(commands) + + @property + def commands(self) -> list[UndoCommand]: + """A copy of the contained command list.""" + return list(self._commands) + + def undo(self, widget) -> None: + for cmd in reversed(self._commands): + cmd.undo(widget) + + def redo(self, widget) -> None: + for cmd in self._commands: + cmd.redo(widget) + + +# --------------------------------------------------------------------------- +# Dict snapshot command — generic bulk-state undo/redo +# --------------------------------------------------------------------------- + +class DictSnapshotCommand(UndoCommand): + """Re-applies a full state-dict snapshot for undo or redo. + + For perspectives that already have ``updateFromParameters()``-style + bulk apply (Invariant, Inversion, later SizeDistribution/Corfunc). + + Calls ``widget._restore_parameter_values(state)`` — the same hook name + ``UndoStack.reset_to_last_good()`` already uses, so the recovery path + works without any stack changes. + """ + + def __init__(self, old_state: dict, new_state: dict, description: str = "Change") -> None: + super().__init__(description) + self._old: dict = dict(old_state) + self._new: dict = dict(new_state) + + def undo(self, widget) -> None: + widget._restore_parameter_values(self._old) + + def redo(self, widget) -> None: + widget._restore_parameter_values(self._new) + + +# --------------------------------------------------------------------------- +# UndoStack +# --------------------------------------------------------------------------- + +class UndoStack(QtCore.QObject): + """Per-tab / per-page undo/redo history. + + Responsibilities: + - Maintain undo and redo stacks of UndoCommand objects. + - Coalesce consecutive commands when supported by the command type. + - Emit ``stackChanged`` whenever state changes so that actionUndo / + actionRedo enabled state and tooltip text can be refreshed. + - Suppress command capture during programmatic updates via the + ``suppressed()`` context manager or ``set_enabled(False)``. + - Handle command execution failures: log at WARNING, show a dialog, + and offer ``reset_to_last_good()`` when failures repeat. + + The stack depth defaults to ``config.UNDO_STACK_MAX_DEPTH`` (200). This + depth is applied per page/tab/widget (each widget owns its own + ``UndoStack``); it is not a global limit shared across all widgets. + + Usage:: + + # In widget.__init__: + self.undo_stack = UndoStack(self) + + # Pushing a command: + self.undo_stack.push(DictSnapshotCommand(old, new, "Description")) + + # Suppressing during programmatic updates: + with self.undo_stack.suppressed(): + self.loadFromProject(...) + """ + + stackChanged = QtCore.Signal() + + def __init__( + self, widget, parent: QtCore.QObject | None = None + ) -> None: + super().__init__(parent) + self._widget = widget + from sas import config as _sas_config + self._max_depth: int = getattr(_sas_config, "UNDO_STACK_MAX_DEPTH", 200) + self._undo_stack: list[UndoCommand] = [] + self._redo_stack: list[UndoCommand] = [] + self._enabled: bool = True + self._replaying: bool = False + self._last_good_state: dict | None = None + self._consecutive_failures: int = 0 + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + def push(self, cmd: UndoCommand) -> None: + """Push *cmd* onto the undo stack. + + - Dropped silently if the stack is disabled or a replay is active. + - Coalesced with the stack top when ``top.can_merge(cmd)`` is True. + - Clears the redo stack (new action invalidates forward history). + - Trims the oldest entry when depth exceeds ``_max_depth``. + """ + if not self._enabled or self._replaying: + return + + if self._undo_stack and self._undo_stack[-1].can_merge(cmd): + self._undo_stack[-1] = self._undo_stack[-1].merge(cmd) + else: + self._undo_stack.append(cmd) + if len(self._undo_stack) > self._max_depth: + dropped = self._undo_stack.pop(0) + logger.debug( + "UndoStack: depth limit reached, dropped %r", dropped + ) + + self._redo_stack.clear() + self.stackChanged.emit() + + def undo(self) -> None: + """Undo the most recent command.""" + if not self._enabled or not self._undo_stack: + return + cmd = self._undo_stack[-1] # peek — only removed from source on success + self._replaying = True + try: + cmd.undo(self._widget) + self._undo_stack.pop() + self._redo_stack.append(cmd) + self._consecutive_failures = 0 + self._auto_snapshot() + self._refresh_view() + + except Exception: + tb = traceback.format_exc() + logger.warning("UndoStack: undo failed for %r:\n%s", cmd, tb) + self._consecutive_failures += 1 + self._handle_failure(cmd, tb) + finally: + self._replaying = False + self.stackChanged.emit() + + def redo(self) -> None: + """Redo the most recently undone command.""" + if not self._enabled or not self._redo_stack: + return + cmd = self._redo_stack[-1] # peek — only removed from source on success + self._replaying = True + try: + cmd.redo(self._widget) + self._redo_stack.pop() + self._undo_stack.append(cmd) + self._consecutive_failures = 0 + self._auto_snapshot() + self._refresh_view() + + except Exception: + tb = traceback.format_exc() + logger.warning("UndoStack: redo failed for %r:\n%s", cmd, tb) + self._consecutive_failures += 1 + self._handle_failure(cmd, tb) + finally: + self._replaying = False + self.stackChanged.emit() + + def can_undo(self) -> bool: + """Return True if undo is possible (enabled and stack non-empty).""" + return self._enabled and bool(self._undo_stack) + + def can_redo(self) -> bool: + """Return True if redo is possible (enabled and stack non-empty).""" + return self._enabled and bool(self._redo_stack) + + def clear(self) -> None: + """Clear both stacks and reset failure state.""" + self._undo_stack.clear() + self._redo_stack.clear() + self._last_good_state = None + self._consecutive_failures = 0 + self.stackChanged.emit() + + def set_enabled(self, enabled: bool) -> None: + """Enable or disable command capture and undo/redo execution. + + Emits ``stackChanged`` so that action enabled state is refreshed + immediately (e.g. buttons grey out when a calculation starts). + """ + self._enabled = enabled + self.stackChanged.emit() + + @contextlib.contextmanager + def suppressed(self): + """Context manager: temporarily disable command capture. + + Use around programmatic updates to prevent spurious entries:: + + with self.undo_stack.suppressed(): + self.loadState(...) + """ + was_enabled = self._enabled + self._enabled = False + try: + yield + finally: + self._enabled = was_enabled + + def undo_text(self) -> str: + """Human-readable label for Undo (suitable for tooltip).""" + if self._undo_stack: + return f"Undo {self._undo_stack[-1].description}" + return "Undo" + + def redo_text(self) -> str: + """Human-readable label for Redo (suitable for tooltip).""" + if self._redo_stack: + return f"Redo {self._redo_stack[-1].description}" + return "Redo" + + def save_last_good_state(self, state: dict) -> None: + """Store *state* as the recovery snapshot.""" + self._last_good_state = dict(state) + + def reset_to_last_good(self) -> None: + """Restore widget parameters from the most recent good snapshot. + + Invoked when the user clicks "Reset to Last Good State" in the + failure dialog. If no snapshot exists, logs a warning and returns. + """ + if self._last_good_state is None: + logger.warning( + "UndoStack: reset_to_last_good called but no snapshot available" + ) + return + try: + self._widget._restore_parameter_values(self._last_good_state) + logger.info( + "UndoStack: reset to last good state (%d params)", + len(self._last_good_state), + ) + except Exception: + logger.warning( + "UndoStack: reset_to_last_good failed:\n%s", + traceback.format_exc(), + ) + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _handle_failure(self, cmd: UndoCommand, tb: str) -> None: + """Show an error dialog; offer reset when failures repeat.""" + offer_reset = ( + self._consecutive_failures >= 2 + and self._last_good_state is not None + ) + parent = ( + self._widget + if isinstance(self._widget, QtWidgets.QWidget) + else None + ) + msg_box = QtWidgets.QMessageBox(parent) + msg_box.setIcon(QtWidgets.QMessageBox.Icon.Warning) + msg_box.setWindowTitle("Undo/Redo Error") + msg_box.setText( + f"An error occurred while replaying:\n\n" + f" {cmd.description}\n\n" + f"The command has not been removed from history. The widget\n" + f"state may be inconsistent — you may try again, or use\n" + f"'Reset to Last Good State' to recover a known-good state." + ) + msg_box.setDetailedText(tb) + if offer_reset: + reset_btn = msg_box.addButton( + "Reset to Last Good State", + QtWidgets.QMessageBox.ButtonRole.ResetRole, + ) + msg_box.addButton(QtWidgets.QMessageBox.StandardButton.Close) + msg_box.exec() + if msg_box.clickedButton() is reset_btn: + self.reset_to_last_good() + else: + msg_box.setStandardButtons( + QtWidgets.QMessageBox.StandardButton.Close + ) + msg_box.exec() + + def _auto_snapshot(self) -> None: + """Record widget state as the recovery snapshot after a successful replay. + + Calls ``widget._get_parameter_dict()`` if available; silently skips + when the method is absent or returns non-dict data. + """ + try: + state = self._widget._get_parameter_dict() + except AttributeError: + return + if isinstance(state, dict): + self._last_good_state = dict(state) + + def _refresh_view(self) -> None: + """Force viewport repaint after undo/redo (for QTreeView-based widgets). + + Silently skips when the widget doesn't have a ``lstParams``. + """ + try: + self._widget.lstParams.viewport().update() + except AttributeError: + pass diff --git a/src/sas/qtgui/Perspectives/perspective.py b/src/sas/qtgui/Perspectives/perspective.py index a3d0800a67..977807af57 100644 --- a/src/sas/qtgui/Perspectives/perspective.py +++ b/src/sas/qtgui/Perspectives/perspective.py @@ -196,5 +196,13 @@ def save_parameters(self): """ Save parameters to a file""" pass + @property + def undo_stack(self): + """Optional undo stack for this perspective. + + Perspectives without undo support should leave this as ``None``. + """ + return None + diff --git a/src/sas/qtgui/Utilities/GuiUtils.py b/src/sas/qtgui/Utilities/GuiUtils.py index 67bace0f06..de28b74a26 100644 --- a/src/sas/qtgui/Utilities/GuiUtils.py +++ b/src/sas/qtgui/Utilities/GuiUtils.py @@ -182,6 +182,9 @@ class Communicate(QtCore.QObject): # Notify about a data name to be frozen and send to fitting perspective freezeDataNameSignal = QtCore.Signal(str) + # Request refresh of undo/redo action enabled state + undoRedoUpdateSignal = QtCore.Signal() + communicator = Communicate() diff --git a/src/sas/qtgui/Utilities/Preferences/GeneralPreferencesWidget.py b/src/sas/qtgui/Utilities/Preferences/GeneralPreferencesWidget.py new file mode 100644 index 0000000000..debb6126e2 --- /dev/null +++ b/src/sas/qtgui/Utilities/Preferences/GeneralPreferencesWidget.py @@ -0,0 +1,21 @@ +from sas.system import config + +from .PreferencesWidget import PreferencesWidget + + +class GeneralPreferencesWidget(PreferencesWidget): + def __init__(self): + super().__init__("General Settings") + self.config_params = ['UNDO_STACK_MAX_DEPTH'] + + def _addAllWidgets(self): + self.undoDepthSpinner = self.addSpinBox( + title="Undo History Depth", minimum=10, maximum=1000, default=config.UNDO_STACK_MAX_DEPTH) + self.undoDepthSpinner.valueChanged.connect( + lambda val: self._stageChange('UNDO_STACK_MAX_DEPTH', val)) + + def _toggleBlockAllSignaling(self, toggle): + self.undoDepthSpinner.blockSignals(toggle) + + def _restoreFromConfig(self): + self.undoDepthSpinner.setValue(config.UNDO_STACK_MAX_DEPTH) diff --git a/src/sas/qtgui/Utilities/Preferences/PreferencesPanel.py b/src/sas/qtgui/Utilities/Preferences/PreferencesPanel.py index 29b5445163..b1bf29faca 100644 --- a/src/sas/qtgui/Utilities/Preferences/PreferencesPanel.py +++ b/src/sas/qtgui/Utilities/Preferences/PreferencesPanel.py @@ -18,11 +18,13 @@ # `BASE_PANELS = {"Bar Widget Options": BarWidget}` # PreferenceWidget Imports go here and then are added to the BASE_PANELS, but not instantiated. from .DisplayPreferencesWidget import DisplayPreferencesWidget +from .GeneralPreferencesWidget import GeneralPreferencesWidget from .PlottingPreferencesWidget import PlottingPreferencesWidget # Pre-made option widgets -BASE_PANELS = {"Plotting Settings": PlottingPreferencesWidget, +BASE_PANELS = {"General Settings": GeneralPreferencesWidget, + "Plotting Settings": PlottingPreferencesWidget, "Display Settings": DisplayPreferencesWidget, } # Type: Dict[str, Union[Type[PreferencesWidget], Callable[[],QWidget]] ConfigType = str | bool | float | int | list[str | float | int] diff --git a/src/sas/qtgui/Utilities/Preferences/PreferencesWidget.py b/src/sas/qtgui/Utilities/Preferences/PreferencesWidget.py index cf50057e50..dc1dd3f3ce 100644 --- a/src/sas/qtgui/Utilities/Preferences/PreferencesWidget.py +++ b/src/sas/qtgui/Utilities/Preferences/PreferencesWidget.py @@ -1,7 +1,17 @@ import logging from PySide6.QtGui import QDoubleValidator, QIntValidator, QValidator -from PySide6.QtWidgets import QCheckBox, QComboBox, QFrame, QHBoxLayout, QLabel, QLineEdit, QVBoxLayout, QWidget +from PySide6.QtWidgets import ( + QCheckBox, + QComboBox, + QFrame, + QHBoxLayout, + QLabel, + QLineEdit, + QSpinBox, + QVBoxLayout, + QWidget, +) from sas.system import config @@ -249,6 +259,24 @@ def addCheckBox(self, title: str, checked: bool | None = False) -> QCheckBox: self.verticalLayout.addLayout(layout) return check_box + def addSpinBox(self, title: str, minimum: int = 0, maximum: int = 1000, default: int | None = None) -> QSpinBox: + """Add a title and integer spin box within the widget. + :param title: The title of the spin box to be added to the preferences panel. + :param minimum: The smallest value the spin box will accept. + :param maximum: The largest value the spin box will accept. + :param default: An optional value to initialise the spin box with. Defaults to the minimum if None. + :return: QSpinBox instance to allow subclasses to assign instance name + """ + layout = self._createLayoutAndTitle(title) + spinner = QSpinBox(self) + spinner.setMinimum(minimum) + spinner.setMaximum(maximum) + if default is not None: + spinner.setValue(default) + layout.addWidget(spinner) + self.verticalLayout.addLayout(layout) + return spinner + def addHorizontalLine(self): """Add a horizontal line as a divider.""" self.verticalLayout.addWidget(QHLine()) diff --git a/src/sas/system/config/config.py b/src/sas/system/config/config.py index d264f05b54..912ae11e37 100644 --- a/src/sas/system/config/config.py +++ b/src/sas/system/config/config.py @@ -162,8 +162,6 @@ def __init__(self): self.SHOW_WELCOME_PANEL = False - - # OpenCL option - should be a string, either, "none", a number, or pair of form "A:B" self.SAS_OPENCL = "none" @@ -207,6 +205,9 @@ def __init__(self): # Default fitting optimizer self.FITTING_DEFAULT_OPTIMIZER = 'lm' + # Undo/Redo stack depth per fitting tab + self.UNDO_STACK_MAX_DEPTH = 200 + # What's New variables self.LAST_WHATS_NEW_HIDDEN_VERSION = "6.0.1" @@ -217,6 +218,9 @@ def __init__(self): # If true, plots generated when using slicers will be on the same canvas self.STACK_PLOTS = True + # Developer menu + self.DEV_MENU = False + # # Lock the class down, this is necessary both for # securing the class, and for setting up reading/writing files