diff --git a/src/shiver/models/polarized.py b/src/shiver/models/polarized.py index 5173800..af882b0 100644 --- a/src/shiver/models/polarized.py +++ b/src/shiver/models/polarized.py @@ -3,7 +3,7 @@ # pylint: disable=no-name-in-module from mantid.simpleapi import mtd, AddSampleLog from mantid.kernel import Logger - +from shiver.models.generate import gather_mde_config_dict, save_mde_config_dict logger = Logger("SHIVER") @@ -24,6 +24,8 @@ def save_experiment_sample_log(self, log_name, log_value): if self.workspace_name and mtd.doesExist(self.workspace_name): workspace = mtd[self.workspace_name] AddSampleLog(workspace, LogName=log_name, LogText=log_value, LogType="String") + # update the MDEConfig with the polarized options + self.update_polarized_mde_config() def save_polarization_logs(self, polarization_logs): """Save polarization logs in workspace""" @@ -112,3 +114,20 @@ def get_flipping_ratio(self): err = f"{flipping_formula} is invalid!" logger.error(err) return flipping_ratio + + def update_polarized_mde_config(self): + """Update the MDE Config Polarized Parameters, if the MDE Config exists""" + + # update mde config if it exists + saved_mde_config = {} + saved_mde_config.update(gather_mde_config_dict(self.workspace_name)) + + # if MDEConfig exists + if len(saved_mde_config.keys()) != 0: + # update the MDEConfig with the current values + sample_data = self.get_polarization_logs_for_workspace() + # format + sample_data["PSDA"] = sample_data["psda"] + del sample_data["psda"] + saved_mde_config["PolarizedOptions"] = sample_data + save_mde_config_dict(self.workspace_name, saved_mde_config) diff --git a/src/shiver/models/refine_ub.py b/src/shiver/models/refine_ub.py index 9c8149c..65941d0 100644 --- a/src/shiver/models/refine_ub.py +++ b/src/shiver/models/refine_ub.py @@ -18,6 +18,7 @@ IndexPeaks, ) from mantid.kernel import Logger +from shiver.models.sample import update_sample_mde_config logger = Logger("SHIVER") @@ -194,6 +195,7 @@ def predict_peaks(self): def update_mde_with_new_ub(self): """Update the UB in the MDE from the one in the peaks workspace""" CopySample(self.peaks, self.mde, CopyName=False, CopyMaterial=False, CopyEnvironment=False, CopyShape=False) + update_sample_mde_config(self.mde.name(), self.mde.getExperimentInfo(0).sample().getOrientedLattice()) def get_perpendicular_slices(self, peak_row): """Create 3 perpendicular slices center on the peaks corresponding to the given row""" diff --git a/src/shiver/models/sample.py b/src/shiver/models/sample.py index 7840db4..4baf528 100644 --- a/src/shiver/models/sample.py +++ b/src/shiver/models/sample.py @@ -18,6 +18,7 @@ from mantid.kernel import Logger from mantidqtinterfaces.DGSPlanner.LoadNexusUB import LoadNexusUB from mantidqtinterfaces.DGSPlanner.ValidateOL import ValidateUB +from shiver.models.generate import gather_mde_config_dict, save_mde_config_dict logger = Logger("SHIVER") @@ -135,6 +136,10 @@ def set_ub(self, params): v=vvec, ) logger.information(f"SetUB completed for {self.name}") + # get the saved oriented lattice + self.oriented_lattice = workspace.getExperimentInfo(0).sample().getOrientedLattice() + # update the mdeconfig + update_sample_mde_config(self.name, self.oriented_lattice) return True except ValueError as value_error: err_msg = f"Invalid lattices: {value_error}\n" @@ -227,3 +232,26 @@ def save_isaw(self, filename): logger.error(err_msg) if self.error_callback: self.error_callback(err_msg) + + +def update_sample_mde_config(name, oriented_lattice): + """Update the MDE Config Sample Parameters, if the MDE Config exists""" + + # updated mde config if it exists + saved_mde_config = {} + saved_mde_config.update(gather_mde_config_dict(name)) + + # if MDEConfig exists + if len(saved_mde_config.keys()) != 0: + # update the MDEConfig with the current value + sample_data = {} + sample_data["a"] = oriented_lattice.a() + sample_data["b"] = oriented_lattice.b() + sample_data["c"] = oriented_lattice.c() + sample_data["alpha"] = oriented_lattice.alpha() + sample_data["beta"] = oriented_lattice.beta() + sample_data["gamma"] = oriented_lattice.gamma() + sample_data["u"] = ",".join(str(item) for item in oriented_lattice.getuVector()) + sample_data["v"] = ",".join(str(item) for item in oriented_lattice.getvVector()) + saved_mde_config["SampleParameters"] = sample_data + save_mde_config_dict(name, saved_mde_config) diff --git a/src/shiver/presenters/polarized.py b/src/shiver/presenters/polarized.py index 8dc0a24..e2e8796 100644 --- a/src/shiver/presenters/polarized.py +++ b/src/shiver/presenters/polarized.py @@ -35,13 +35,15 @@ def get_polarization_logs(self): def handle_apply_button(self, polarization_logs): """Save the values for the sample logs""" + saved_logs = {} + saved_logs.update(polarization_logs) # do not update psda value if readonly field if self.view.dialog.disable_psda: - del polarization_logs["PSDA"] + del saved_logs["PSDA"] else: - polarization_logs["psda"] = polarization_logs["PSDA"] - del polarization_logs["PSDA"] - self.model.save_polarization_logs(polarization_logs) + saved_logs["psda"] = saved_logs["PSDA"] + del saved_logs["PSDA"] + self.model.save_polarization_logs(saved_logs) def create_dictionary_polarized_options(sample_log_data): diff --git a/tests/models/test_histogram_saving.py b/tests/models/test_histogram_saving.py index a5f51d6..9c54e0b 100644 --- a/tests/models/test_histogram_saving.py +++ b/tests/models/test_histogram_saving.py @@ -22,6 +22,7 @@ from shiver.models.polarized import PolarizedModel from shiver.views.polarized_options import PolarizedView from shiver.presenters.polarized import PolarizedPresenter +from shiver.models.generate import gather_mde_config_dict def test_saving(tmp_path): @@ -664,6 +665,86 @@ def test_polarization_parameters(tmp_path, shiver_app, qtbot): assert saved_pol_logs["PSDA"] == "1.3" +def test_polarization_mdeconfig_parameters(tmp_path, qtbot): + """Test the polarization parameters are saved in the MDEConfig""" + + # clear mantid workspace + mtd.clear() + + name = "px_mini_NSF" + filepath = f"{tmp_path}/{name}.nxs" + + # load mde workspace + LoadMD( + Filename=os.path.join(os.path.dirname(os.path.abspath(__file__)), "../data/mde/px_mini_NSF.nxs"), + OutputWorkspace="data", + ) + + MakeSlice( + InputWorkspace="data", + BackgroundWorkspace=None, + NormalizationWorkspace=None, + QDimension0="0,0,1", + QDimension1="1,1,0", + QDimension2="-1,1,0", + Dimension0Name="QDimension1", + Dimension0Binning="0.35,0.025,0.65", + Dimension1Name="QDimension0", + Dimension1Binning="0.45,0.55", + Dimension2Name="QDimension2", + Dimension2Binning="-0.2,0.2", + Dimension3Name="DeltaE", + Dimension3Binning="-0.5,0.5", + SymmetryOperations=None, + Smoothing=1, + OutputWorkspace=name, + ) + model = HistogramModel() + workspace = mtd[name] + model.save(name, filepath) + + # check the mde config values + mde_config = gather_mde_config_dict(name) + assert len(mde_config) != 0 + + pol_sample_logs = { + "PolarizationState": "NSF", + "PolarizationDirection": "Px", + "FlippingRatio": "3Ei+1/4", + "FlippingRatioSampleLog": "Ei", + "PSDA": "1.8", + } + + qtbot.wait(100) + # save polarization parameters + polarized_view = PolarizedView() + polarized_model = PolarizedModel(name) + polarized_presenter = PolarizedPresenter(polarized_view, polarized_model) + polarized_view.start_dialog(False) + polarized_presenter.handle_apply_button(pol_sample_logs) + + # check polarization parameters in sample logs + run = workspace.getExperimentInfo(0).run() + assert run.getLogData("PolarizationState").value == pol_sample_logs["PolarizationState"] + assert run.getLogData("PolarizationDirection").value == pol_sample_logs["PolarizationDirection"] + assert run.getLogData("FlippingRatio").value == pol_sample_logs["FlippingRatio"] + assert run.getLogData("FlippingRatioSampleLog").value == pol_sample_logs["FlippingRatioSampleLog"] + assert run.getPropertyAsSingleValueWithTimeAveragedMean("psda") == 1.8 + + # check the MDEConfig dictionary + config = {} + config_data = run.getProperty("MDEConfig").value + config.update(ast.literal_eval(config_data)) + + assert len(config.keys()) != 0 + assert config["mde_name"] == name + assert config["PolarizedOptions"]["PolarizationState"] == pol_sample_logs["PolarizationState"] + assert config["PolarizedOptions"]["PolarizationDirection"] == pol_sample_logs["PolarizationDirection"] + assert config["PolarizedOptions"]["FlippingRatio"] == pol_sample_logs["FlippingRatio"] + assert config["PolarizedOptions"]["FlippingRatioSampleLog"] == pol_sample_logs["FlippingRatioSampleLog"] + assert config["PolarizedOptions"]["PSDA"] == pol_sample_logs["PSDA"] + + def test_polarization_state_invalid(tmp_path): """Test the polarization state for invalid state""" diff --git a/tests/models/test_refine_ub.py b/tests/models/test_refine_ub.py index ae17f8a..33193fe 100644 --- a/tests/models/test_refine_ub.py +++ b/tests/models/test_refine_ub.py @@ -1,7 +1,9 @@ """Tests for the RefineUBModel""" import pytest +import numpy as np from shiver.models.refine_ub import RefineUBModel +from shiver.models.generate import gather_mde_config_dict, save_mde_config_dict from mantid.simpleapi import ( # pylint: disable=no-name-in-module,wrong-import-order CreateMDWorkspace, FakeMDEventData, @@ -172,3 +174,98 @@ def test_refine_ub_model(): assert peak_table_model.ws.sample().getOrientedLattice().gamma() == pytest.approx(90) assert peak_table_model.ws.sample().getOrientedLattice().getuVector() == pytest.approx([1, 0, 0]) assert peak_table_model.ws.sample().getOrientedLattice().getvVector() == pytest.approx([0, 1, 0]) + + +def test_mdeconfig_refine_ub(): + """test the mdeconfig in RefineUBModel""" + + expt_info = CreateSampleWorkspace() + SetUB(expt_info) + + mde = CreateMDWorkspace( + Dimensions=4, + Extents="-10,10,-10,10,-10,10,-10,10", + Names="x,y,z,DeltaE", + Units="r.l.u.,r.l.u.,r.l.u.,DeltaE", + Frames="QSample,QSample,QSample,General Frame", + ) + mde.addExperimentInfo(expt_info) + FakeMDEventData(mde, PeakParams="1e+05,6.283,0,0,0,0.02", RandomSeed="3873875") + FakeMDEventData(mde, PeakParams="1e+05,0,6.283,0,0,0.02", RandomSeed="3873875") + FakeMDEventData(mde, PeakParams="1e+05,0,0,6.283,0,0.02", RandomSeed="3873875") + + # add new MDEConfig + new_mde_config = {} + new_mde_config["mde_name"] = mde.name() + new_mde_config["output_dir"] = "/test/file/path" + new_mde_config["mde_type"] = "Data" + save_mde_config_dict(mde.name(), new_mde_config) + # check the mde config values + mde_config = gather_mde_config_dict(mde.name()) + + assert len(mde_config) == 3 + + mdh = CreateMDWorkspace( + Dimensions=4, + Extents="-5,5,-5,5,-5,5,-10,10", + Names="[H,0,0],[0,K,0],[0,0,L],DeltaE", + Units="r.l.u.,r.l.u.,r.l.u.,DeltaE", + Frames="HKL,HKL,HKL,General Frame", + ) + mdh.addExperimentInfo(expt_info) + SetUB(mdh) + FakeMDEventData(mdh, PeakParams="1e+05,1,0,0,0,0.02", RandomSeed="3873875") + FakeMDEventData(mdh, PeakParams="1e+05,0,1,0,0,0.02", RandomSeed="3873875") + FakeMDEventData(mdh, PeakParams="1e+05,0,0,1,0,0.02", RandomSeed="3873875") + mdh.getExperimentInfo(0).run().addProperty("W_MATRIX", [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0], True) + + mdh = BinMD( + mdh, + AlignedDim0="[H,0,0],-2,2,50", + AlignedDim1="[0,K,0],-2,2,50", + AlignedDim2="[0,0,L],-2,2,50", + AlignedDim3="DeltaE,-1.25,1.25,1", + ) + + model = RefineUBModel("mdh", "mde") + model.predict_peaks() + + # peak table model + peak_table_model = model.get_peaks_table_model() + peak_table_model.set_peak_number_to_rows() + + # recenter peaks + peak_table_model.recenter_rows([0, 4]) + + # refine, should change the lattice parameters and u/v vectors + peak_table_model.refine([3, 4, 5], "") + model.update_mde_with_new_ub() + # check the oriented lattice + mde_oriented_lattice = mde.getExperimentInfo(0).sample().getOrientedLattice() + peak_oriented_lattice = peak_table_model.ws.sample().getOrientedLattice() + + assert peak_oriented_lattice.a() == mde_oriented_lattice.a() + assert peak_oriented_lattice.b() == mde_oriented_lattice.b() + assert peak_oriented_lattice.c() == mde_oriented_lattice.c() + assert peak_oriented_lattice.alpha() == mde_oriented_lattice.alpha() + assert peak_oriented_lattice.beta() == mde_oriented_lattice.beta() + assert peak_oriented_lattice.gamma() == mde_oriented_lattice.gamma() + assert peak_oriented_lattice.getuVector() == pytest.approx(mde_oriented_lattice.getuVector()) + assert peak_oriented_lattice.getvVector() == pytest.approx(mde_oriented_lattice.getvVector()) + + # check the mde config values + mde_config = gather_mde_config_dict(mde.name()) + + assert len(mde_config) == 4 + assert "SampleParameters" in mde_config + assert mde_config["SampleParameters"]["a"] == mde_oriented_lattice.a() + assert mde_config["SampleParameters"]["b"] == mde_oriented_lattice.b() + assert mde_config["SampleParameters"]["c"] == mde_oriented_lattice.c() + assert mde_config["SampleParameters"]["alpha"] == mde_oriented_lattice.alpha() + assert mde_config["SampleParameters"]["beta"] == mde_oriented_lattice.beta() + assert mde_config["SampleParameters"]["gamma"] == mde_oriented_lattice.gamma() + + u_array = np.array(mde_config["SampleParameters"]["u"].split(","), dtype=float) + assert u_array == pytest.approx(mde_oriented_lattice.getuVector()) + v_array = np.array(mde_config["SampleParameters"]["v"].split(","), dtype=float) + assert v_array == pytest.approx(mde_oriented_lattice.getvVector()) diff --git a/tests/models/test_sample_parameters_buttons_actions.py b/tests/models/test_sample_parameters_buttons_actions.py index bc88362..6ecadee 100644 --- a/tests/models/test_sample_parameters_buttons_actions.py +++ b/tests/models/test_sample_parameters_buttons_actions.py @@ -1,15 +1,17 @@ """tests for Sample Parameters dialog: button actions""" import os +import numpy as np from pytest import approx # pylint: disable=no-name-in-module -from mantid.simpleapi import LoadMD +from mantid.simpleapi import LoadMD, mtd from shiver.models.sample import SampleModel +from shiver.models.generate import gather_mde_config_dict -def test_apply_button_valid(): - """Test for pressing Apply button with valid input""" +def test_apply_button_valid_no_mde(): + """Test for pressing Apply button with valid input and mde workspace without MDEConfig""" name = "data" LoadMD( @@ -45,6 +47,77 @@ def error_callback(msg): sample_model.set_ub(params) assert len(errors) == 0 + # check the oriented lattice saved in samplemodel + assert sample_model.oriented_lattice.a() == params["a"] + assert sample_model.oriented_lattice.b() == params["b"] + assert sample_model.oriented_lattice.c() == params["c"] + assert sample_model.oriented_lattice.alpha() == params["alpha"] + assert sample_model.oriented_lattice.beta() == params["beta"] + assert sample_model.oriented_lattice.gamma() == params["gamma"] + + # check the oriented lattice saved in sthe workspace + workspace_lattice = mtd[name].getExperimentInfo(0).sample().getOrientedLattice() + assert workspace_lattice.a() == sample_model.oriented_lattice.a() + assert workspace_lattice.b() == sample_model.oriented_lattice.b() + assert workspace_lattice.c() == sample_model.oriented_lattice.c() + assert workspace_lattice.alpha() == sample_model.oriented_lattice.alpha() + assert workspace_lattice.beta() == sample_model.oriented_lattice.beta() + assert workspace_lattice.gamma() == sample_model.oriented_lattice.gamma() + + mde_config = gather_mde_config_dict(name) + assert len(mde_config) == 0 + + +def test_apply_button_valid_mde(): + """Test for pressing Apply button with valid input and mde workspace with MDEConfig""" + + name = "data" + LoadMD( + Filename=os.path.join(os.path.dirname(os.path.abspath(__file__)), "../data/mde/px_mini_NSF.nxs"), + OutputWorkspace=name, + ) + + sample_model = SampleModel(name) + + errors = [] + params = {} + + params["a"] = 5.4 + params["b"] = 6.4 + params["c"] = 4.4 + params["alpha"] = 90.0 + params["beta"] = 90.0 + params["gamma"] = 90.0 + params["u"] = "0.00,-0.00,4.40" + params["v"] = "4.12717,4.12717,-0.000" + + def error_callback(msg): + errors.append(msg) + + sample_model.connect_error_message(error_callback) + sample_model.set_ub(params) + assert len(errors) == 0 + + assert sample_model.oriented_lattice.a() == params["a"] + + # check the mde config values + mde_config = gather_mde_config_dict(name) + assert len(mde_config) != 0 + assert mde_config["SampleParameters"]["a"] == params["a"] + assert mde_config["SampleParameters"]["b"] == params["b"] + assert mde_config["SampleParameters"]["c"] == params["c"] + assert mde_config["SampleParameters"]["alpha"] == params["alpha"] + assert mde_config["SampleParameters"]["beta"] == params["beta"] + assert mde_config["SampleParameters"]["gamma"] == params["gamma"] + + mde_u_array = np.array(mde_config["SampleParameters"]["u"].split(","), dtype=float) + param_u_array = np.array(params["u"].split(","), dtype=float) + assert mde_u_array == approx(param_u_array) + + v_array = np.array(mde_config["SampleParameters"]["v"].split(","), dtype=float) + param_v_array = np.array(params["v"].split(","), dtype=float) + assert v_array == approx(param_v_array) + def test_apply_button_invalid(): """Test for pressing Apply button with invalid input""" diff --git a/tests/views/test_refine_ub_ui.py b/tests/views/test_refine_ub_ui.py index 9220ed6..58e6bf6 100644 --- a/tests/views/test_refine_ub_ui.py +++ b/tests/views/test_refine_ub_ui.py @@ -10,6 +10,7 @@ CreateSampleWorkspace, ) from shiver.presenters.refine_ub import RefineUB +from shiver.models.generate import gather_mde_config_dict def test_refine_ub_ui(qtbot): @@ -21,6 +22,8 @@ def test_refine_ub_ui(qtbot): Units="r.l.u.,r.l.u.,r.l.u.,DeltaE", Frames="QSample,QSample,QSample,General Frame", ) + expt_info = CreateSampleWorkspace() + mde.addExperimentInfo(expt_info) SetUB(mde) FakeMDEventData(mde, PeakParams="1e+05,6.283,0,0,0,0.02", RandomSeed="3873875") FakeMDEventData(mde, PeakParams="1e+05,0,6.283,0,0,0.02", RandomSeed="3873875") @@ -33,7 +36,6 @@ def test_refine_ub_ui(qtbot): Units="r.l.u.,r.l.u.,r.l.u.,DeltaE", Frames="HKL,HKL,HKL,General Frame", ) - expt_info = CreateSampleWorkspace() mdh.addExperimentInfo(expt_info) SetUB(mdh) FakeMDEventData(mdh, PeakParams="1e+05,1,0,0,0,0.02", RandomSeed="3873875") @@ -147,6 +149,10 @@ def test_refine_ub_ui(qtbot): assert refine_ub.view.peaks_table.view.model().refine_rows() == [3, 4, 5] + # check the mde config values do not exist + mde_config = gather_mde_config_dict(mde.name()) + assert len(mde_config) == 0 + qtbot.mouseClick(refine_ub.view.refine_btn, QtCore.Qt.LeftButton) qtbot.wait(100) @@ -156,3 +162,7 @@ def test_refine_ub_ui(qtbot): assert refine_ub.model.peaks.sample().getOrientedLattice().alpha() == pytest.approx(90) assert refine_ub.model.peaks.sample().getOrientedLattice().beta() == pytest.approx(89.99976212) assert refine_ub.model.peaks.sample().getOrientedLattice().gamma() == pytest.approx(90) + + # check the mde config values still do not exist + mde_config = gather_mde_config_dict(mde.name()) + assert len(mde_config) == 0