Skip to content

Commit

Permalink
Merge pull request #145 from scipp/save-parameters
Browse files Browse the repository at this point in the history
Set/get parameters input widget values.
  • Loading branch information
YooSunYoung authored Dec 11, 2024
2 parents bfd56a7 + c2ac94a commit 62bd310
Show file tree
Hide file tree
Showing 8 changed files with 160 additions and 10 deletions.
1 change: 1 addition & 0 deletions docs/api-reference/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
logging
nexus
streaming
ui
uncertainty
widgets
```
89 changes: 84 additions & 5 deletions src/ess/reduce/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from .parameter import Parameter
from .widgets import SwitchWidget, create_parameter_widget, default_layout
from .widgets._base import get_fields, set_fields
from .workflow import (
Key,
assign_parameter_values,
Expand Down Expand Up @@ -81,11 +82,13 @@ def _refresh_input_box(_: widgets.Button):
self._input_widgets.clear()
self._input_widgets.update(
{
node: widgets.HBox([create_parameter_widget(parameter)])
for node, parameter in registry_getter().items()
node: create_parameter_widget(parameter)
for node, parameter in new_input_parameters.items()
}
)
self._input_box.children = list(self._input_widgets.values())
self._input_box.children = [
widgets.HBox([widget]) for widget in self._input_widgets.values()
]

self.parameter_refresh_button.on_click(_refresh_input_box)

Expand All @@ -97,8 +100,7 @@ def value(self) -> dict[Key, Any]:
return {
node: widget.value
for node, widget_box in self._input_widgets.items()
if (not isinstance((widget := widget_box.children[0]), SwitchWidget))
or widget.enabled
if (not isinstance((widget := widget_box), SwitchWidget)) or widget.enabled
}


Expand Down Expand Up @@ -232,3 +234,80 @@ def refresh_workflow_box(change) -> None:
workflow_selection_box = widgets.HBox([workflow_select], layout=default_layout)
workflow_box = widgets.Box(layout=default_layout)
return widgets.VBox([workflow_selection_box, workflow_box])


def _get_parameter_box(widget: WorkflowWidget | ParameterBox) -> ParameterBox:
if isinstance(widget, WorkflowWidget):
return widget.parameter_box
elif isinstance(widget, ParameterBox):
return widget
else:
raise TypeError(
f"Expected target_widget to be a WorkflowWidget or ParameterBox, "
f"got {type(widget)}."
)


def set_parameter_widget_values(
widget: WorkflowWidget | ParameterBox, new_parameter_values: dict[type, Any]
) -> None:
"""Set the values of the input widgets in the target widget.
Nodes that don't exist in the input widgets will be ignored.
Example
-------
{
'WavelengthBins': {'start': 1.0, 'stop': 14.0, 'nbins': 500}
}
Parameters
----------
widget:
The widget containing the input widgets.
new_parameter_values:
A dictionary of values/state to set each fields/state or value of input widgets.
Raises
------
TypeError:
If the widget is not a WorkflowWidget or a ParameterBox.
"""
parameter_box = _get_parameter_box(widget)
# Walk through the existing input widgets and set the values
# ``node`s that don't exist in the input widgets will be ignored.
for node, widget in parameter_box._input_widgets.items():
if node in new_parameter_values:
# We shouldn't use `get` here because ``None`` is a valid value.
set_fields(widget, new_parameter_values[node])


def get_parameter_widget_values(
widget: WorkflowWidget | ParameterBox,
) -> dict[type, Any]:
"""Return the current values of the input widgets in the target widget.
The result of this function can be used to set the values of the input widgets
using the :py:func:`~set_parameter_widget_values` function.
Parameters
----------
widget:
The widget containing the input widgets.
Returns
-------
:
A dictionary of the current values/state of each input widget.
Raises
------
TypeError:
If the widget is not a WorkflowWidget or a ParameterBox.
"""
return {
node: get_fields(widget)
for node, widget in _get_parameter_box(widget)._input_widgets.items()
}
63 changes: 63 additions & 0 deletions src/ess/reduce/widgets/_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2024 Scipp contributors (https://github.com/scipp)
import warnings
from typing import Any, Protocol, runtime_checkable

from ipywidgets import Widget


@runtime_checkable
class WidgetWithFieldsProtocol(Protocol):
def set_fields(self, new_values: dict[str, Any]) -> None: ...

def get_fields(self) -> dict[str, Any]: ...


class WidgetWithFieldsMixin:
def set_fields(self, new_values: dict[str, Any]) -> None:
# Extract valid fields
new_field_names = set(new_values.keys())
valid_field_names = new_field_names & set(self.fields.keys())
# Warn for invalid fields
invalid_field_names = new_field_names - valid_field_names
for field_name in invalid_field_names:
warning_msg = f"Cannot set field '{field_name}'."
" The field does not exist in the widget."
"The field value will be ignored."
warnings.warn(warning_msg, UserWarning, stacklevel=1)
# Set valid fields
for field_name in valid_field_names:
self.fields[field_name].value = new_values[field_name]

def get_fields(self) -> dict[str, Any]:
return {
field_name: field_sub_widget.value
for field_name, field_sub_widget in self.fields.items()
}


def _has_widget_value_setter(widget: Widget) -> bool:
widget_type = type(widget)
return (
widget_property := getattr(widget_type, 'value', None)
) is not None and getattr(widget_property, 'fset', None) is not None


def set_fields(widget: Widget, new_values: Any) -> None:
if isinstance(widget, WidgetWithFieldsProtocol) and isinstance(new_values, dict):
widget.set_fields(new_values)
elif _has_widget_value_setter(widget):
widget.value = new_values
else:
warnings.warn(
f"Cannot set value or fields for widget of type {type(widget)}."
" The new_value(s) will be ignored.",
UserWarning,
stacklevel=1,
)


def get_fields(widget: Widget) -> Any:
if isinstance(widget, WidgetWithFieldsProtocol):
return widget.get_fields()
return widget.value
4 changes: 3 additions & 1 deletion src/ess/reduce/widgets/_binedges_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import ipywidgets as ipw
import scipp as sc

from ._base import WidgetWithFieldsMixin

UNITS_LIBRARY = {
"wavelength": {"options": ("angstrom", "nm")},
"Q": {"options": ("1/angstrom", "1/nm")},
Expand All @@ -19,7 +21,7 @@
}


class BinEdgesWidget(ipw.HBox, ipw.ValueWidget):
class BinEdgesWidget(ipw.HBox, ipw.ValueWidget, WidgetWithFieldsMixin):
def __init__(
self,
name: str,
Expand Down
3 changes: 2 additions & 1 deletion src/ess/reduce/widgets/_bounds_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
from ipywidgets import FloatText, GridBox, Label, Text, ValueWidget

from ..parameter import ParamWithBounds
from ._base import WidgetWithFieldsMixin


class BoundsWidget(GridBox, ValueWidget):
class BoundsWidget(GridBox, ValueWidget, WidgetWithFieldsMixin):
def __init__(self):
super().__init__()

Expand Down
4 changes: 3 additions & 1 deletion src/ess/reduce/widgets/_linspace_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
import scipp as sc
from ipywidgets import FloatText, GridBox, IntText, Label, ValueWidget

from ._base import WidgetWithFieldsMixin

class LinspaceWidget(GridBox, ValueWidget):

class LinspaceWidget(GridBox, ValueWidget, WidgetWithFieldsMixin):
def __init__(self, dim: str, unit: str):
super().__init__()

Expand Down
4 changes: 3 additions & 1 deletion src/ess/reduce/widgets/_vector_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
import scipp as sc
from ipywidgets import FloatText, HBox, Label, Text, ValueWidget

from ._base import WidgetWithFieldsMixin

class VectorWidget(HBox, ValueWidget):

class VectorWidget(HBox, ValueWidget, WidgetWithFieldsMixin):
def __init__(self, name: str, variable: sc.Variable, components: str):
super().__init__()

Expand Down
2 changes: 1 addition & 1 deletion tests/widget_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def provider_with_optional(a: OptionalInt, b: OptionalFloat) -> str:


def _get_param_widget(widget: WorkflowWidget, param_type: type) -> Any:
return widget.parameter_box._input_widgets[param_type].children[0]
return widget.parameter_box._input_widgets[param_type]


def test_parameter_default_value_test() -> None:
Expand Down

0 comments on commit 62bd310

Please sign in to comment.