diff --git a/src/shiver/views/refine_ub.py b/src/shiver/views/refine_ub.py index 4f630dd..ff6fb99 100644 --- a/src/shiver/views/refine_ub.py +++ b/src/shiver/views/refine_ub.py @@ -11,7 +11,7 @@ QGroupBox, QComboBox, ) -from qtpy.QtCore import Qt, QAbstractTableModel +from qtpy.QtCore import Qt, QAbstractTableModel, QModelIndex from matplotlib.backends.backend_qtagg import FigureCanvas # pylint: disable=no-name-in-module import matplotlib.pyplot as plt @@ -101,6 +101,26 @@ def refine_rows(self): """Return a list of row number where "Refine" box is checked""" return sorted(k for k, v in self._refine.items() if v) + def select_all(self): + """Select all refine/recenter checkboxes""" + self._refine = {n: True for n in range(self.rowCount(0))} + self._recenter = {n: True for n in range(self.rowCount(0))} + self.dataChanged.emit(QModelIndex(), QModelIndex()) # force view to redraw + + def deselect_all(self): + """Deselect all refine/recenter checkboxes""" + self._refine = {} + self._recenter = {} + self.dataChanged.emit(QModelIndex(), QModelIndex()) # force view to redraw + + def round_hkl(self): + """Round all HKL values to integer""" + for row in range(self.rowCount(0)): + for index in range(1, 4): + self._data_model.set_cell_data(row, index, round(self._data_model.get_cell(row, index)), False) + + self.dataChanged.emit(QModelIndex(), QModelIndex()) # force view to redraw + class RefineUBView(QWidget): """The view for the Refine UB widget""" @@ -143,18 +163,29 @@ def _setup_ui(self): peaks_layout = QVBoxLayout() - btn_layout = QHBoxLayout() + btn_layout = QGridLayout() self.populate_peaks = QPushButton("Populate Peaks") self.populate_peaks.setCheckable(True) - self.populate_peaks.clicked.connect(self.populate_peaks_call) + self.populate_peaks.clicked.connect(self._populate_peaks_call) self.predict_peaks = QPushButton("Predict peaks") - self.predict_peaks.clicked.connect(self.predict_peaks_call) + self.predict_peaks.clicked.connect(self._predict_peaks_call) self.recenter_peaks = QPushButton("Recenter") - self.recenter_peaks.clicked.connect(self.recenter_peaks_call) + self.recenter_peaks.clicked.connect(self._recenter_peaks_call) + + self.select_all = QPushButton("Select All") + self.select_all.clicked.connect(self._select_all_call) + self.deselect_all = QPushButton("Deselect All") + self.deselect_all.clicked.connect(self._deselect_all_call) + self.round_hkl = QPushButton("Round HKL") + self.round_hkl.clicked.connect(self._round_hkl_call) + + btn_layout.addWidget(self.populate_peaks, 0, 0) + btn_layout.addWidget(self.predict_peaks, 0, 1) + btn_layout.addWidget(self.recenter_peaks, 0, 2) + btn_layout.addWidget(self.select_all, 1, 0) + btn_layout.addWidget(self.deselect_all, 1, 1) + btn_layout.addWidget(self.round_hkl, 1, 2) - btn_layout.addWidget(self.populate_peaks) - btn_layout.addWidget(self.predict_peaks) - btn_layout.addWidget(self.recenter_peaks) peaks_layout.addLayout(btn_layout) peaks_layout.addWidget(self.peaks_table.view) self.peaks_table.view.selectionModel().currentRowChanged.connect(self._on_row_selected) @@ -240,7 +271,7 @@ def connect_populate_peaks(self, callback): """connect the "populate peaks" button callback""" self.populate_peaks_callback = callback - def populate_peaks_call(self, checked): + def _populate_peaks_call(self, checked): """call when "populate peaks" button pressed""" if self.populate_peaks_callback: self.populate_peaks_callback(checked) @@ -249,7 +280,7 @@ def connect_predict_peaks(self, callback): """connect the "predict peaks" button callback""" self.predict_peaks_callback = callback - def predict_peaks_call(self): + def _predict_peaks_call(self): """call when "predict peaks" button pressed""" if self.predict_peaks_callback: self.predict_peaks_callback() @@ -258,11 +289,23 @@ def connect_recenter_peaks(self, callback): """connect the recenter button callback""" self.recenter_peaks_callback = callback - def recenter_peaks_call(self): + def _recenter_peaks_call(self): """call when recenter button pressed""" if self.recenter_peaks_callback: self.recenter_peaks_callback() + def _select_all_call(self): + """call when select all button pressed""" + self.peaks_table.view.model().select_all() + + def _deselect_all_call(self): + """call when deselect all button pressed""" + self.peaks_table.view.model().deselect_all() + + def _round_hkl_call(self): + """call when round HKL button pressed""" + self.peaks_table.view.model().round_hkl() + def connect_refine(self, callback): """connect the refine button callback""" self.refine_callback = callback diff --git a/tests/views/test_refine_ub_ui.py b/tests/views/test_refine_ub_ui.py index a784c5f..0d09ffc 100644 --- a/tests/views/test_refine_ub_ui.py +++ b/tests/views/test_refine_ub_ui.py @@ -108,6 +108,17 @@ def test_refine_ub_ui(qtbot): [5.00797398e-07, 6.28529930, -2.60952093e-05] ) + # test round HKL button + assert refine_ub.model.peaks.getPeak(3).getHKL() == pytest.approx([-4.15318e-06, 7.97044e-08, 1.000336]) + qtbot.mouseClick(refine_ub.view.round_hkl, QtCore.Qt.LeftButton) + assert refine_ub.model.peaks.getPeak(3).getHKL() == pytest.approx([0, 0, 1]) + + # test select all/deselect all buttons + qtbot.mouseClick(refine_ub.view.select_all, QtCore.Qt.LeftButton) + assert refine_ub.view.peaks_table.view.model().refine_rows() == [0, 1, 2, 3, 4, 5] + qtbot.mouseClick(refine_ub.view.deselect_all, QtCore.Qt.LeftButton) + assert refine_ub.view.peaks_table.view.model().refine_rows() == [] + # select 3 peaks then press "Refine" and check that UB is updated assert refine_ub.model.peaks.sample().getOrientedLattice().a() == 1 assert refine_ub.model.peaks.sample().getOrientedLattice().b() == 1