Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: draft autowidget #54

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 90 additions & 0 deletions src/ess/reduce/autowidget.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import ipywidgets as ipw
from IPython.display import display


class AutoWidget(ipw.ValueWidget, ipw.AppLayout):
def __init__(
self,
tp,
tp_args,
tp_kwargs=None,
*,
display_output=False,
child_layout=ipw.HBox,
**kwargs,
):
'''Creates a widget from a constructor `tp` and a range of argument widgets

Example usage:
bound_widget = AutoWidget(
lambda a, b: (sc.scalar(a), sc.scalar(b)),
ipw.FloatText(description="left"),
ipw.FloatText(description="right")
)
range_widget = AutoWidget(
sc.linspace,
dim=ipywidgets.Text(value='m'),
start=ipywidgets.FloatText(),
stop=ipywidgets.FloatText(),
num=ipywidgets.IntText(value=50),
unit=ipywidgets.Text(value='angstrom')
)
'''
if tp_kwargs is None:
tp_kwargs = {}
self.tp = tp
self.args = tp_args
self.kwargs = tp_kwargs
self.out = ipw.Output()

form = child_layout([*tp_args, *tp_kwargs.values()])
layout = (
dict(left_sidebar=form, center=self.out) # noqa: C408
if display_output
else dict(center=form) # noqa: C408
)
if 'description' in kwargs:
layout["header"] = ipw.Label(value=kwargs['description'])

super().__init__(
**kwargs,
**layout,
)
for a in (*tp_args, *tp_kwargs.values()):
a.observe(self._recompute, names='value')

self.observe(self._set_output, names='value')
self._recompute()

def _set_output(self, _=None):
self.out.clear_output()
with self.out:
display(self.value)

def _recompute(self, _=None):
self.value = self.tp(
*(v.value for v in self.args),
**{k: v.value for k, v in self.kwargs.items()},
)


def extract_children(widget):
'Extracts underlying data from nested widget'
if not hasattr(widget, 'children') or widget.children == ():
if hasattr(widget, 'value'):
return widget.value
else:
return ()
return tuple(extract_children(child) for child in widget.children)


def insert_children(widget, children):
'Inserts underlying data into nested widget'
if not hasattr(widget, 'children') or widget.children == ():
if hasattr(widget, 'value'):
widget.value = children
return
else:
return
for wid, child in zip(widget.children, children, strict=True):
insert_children(wid, child)
3 changes: 1 addition & 2 deletions src/ess/reduce/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@

workflows = workflow.workflow_registry


_style = {
'description_width': 'auto',
'value_width': 'auto',
Expand Down Expand Up @@ -187,5 +186,5 @@ def run_workflow(b):
top_left=workflow_box,
bottom_left=widgets.VBox([run_button, output]),
# bottom_left=run_button,
bottom_right=parameter_box,
top_right=parameter_box,
)
144 changes: 54 additions & 90 deletions src/ess/reduce/widget.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from functools import singledispatch
from functools import partial, singledispatch

import ipywidgets as widgets
import scipp as sc
from ess.reduce import parameter
from ess.reduce.autowidget import AutoWidget

_layout = widgets.Layout(width='80%')
_style = {
Expand All @@ -12,99 +13,45 @@
}


class LinspaceWidget(widgets.GridBox, widgets.ValueWidget):
def __init__(self, dim, unit):
super().__init__()

self.fields = {
'dim': widgets.Label(value=dim, description='dim'),
'start': widgets.FloatText(description='start'),
'end': widgets.FloatText(description='end'),
'num': widgets.IntText(description='num'),
'unit': widgets.Label(description='unit', value=unit),
}
self.children = [
widgets.Label(value="Select range:"),
self.fields['dim'],
self.fields['unit'],
self.fields['start'],
self.fields['end'],
self.fields['num'],
]

@property
def value(self):
return sc.linspace(
self.fields['dim'].value,
self.fields['start'].value,
self.fields['end'].value,
self.fields['num'].value,
unit=self.fields['unit'].value,
)


class BoundsWidget(widgets.GridBox, widgets.ValueWidget):
def __init__(self):
super().__init__()

self.fields = {
'start': widgets.FloatText(description='start'),
'end': widgets.FloatText(description='end'),
'unit': widgets.Text(description='unit'),
}
self.children = [
widgets.Label(value="Select bound:"),
self.fields['unit'],
self.fields['start'],
self.fields['end'],
]

@property
def value(self):
return (
sc.scalar(self.fields['start'].value, unit=self.fields['unit']),
sc.scalar(self.fields['end'].value, unit=self.fields['unit']),
)


class VectorWidget(widgets.GridBox, widgets.ValueWidget):
def __init__(self, variable):
super().__init__()

self.fields = {
"x": widgets.FloatText(description="x", value=variable.fields.x.value),
"y": widgets.FloatText(description="y", value=variable.fields.y.value),
"z": widgets.FloatText(description="z", value=variable.fields.z.value),
"unit": widgets.Text(description="unit", value=str(variable.unit)),
}
self.children = [
widgets.Label(value="(x, y, z) ="),
self.fields['x'],
self.fields['y'],
self.fields['z'],
self.fields['unit'],
]

@property
def value(self):
return sc.vector(
value=[
self.fields['x'].value,
self.fields['y'].value,
self.fields['z'].value,
],
unit=self.fields['unit'].value,
)


@singledispatch
def create_parameter_widget(param):
return widgets.Text('', layout=_layout, style=_style)


@create_parameter_widget.register(parameter.VectorParameter)
def _(param):
return VectorWidget(param.default)
entry = partial(
widgets.FloatText,
layout=widgets.Layout(width='5em'),
)
tri_tuple = AutoWidget(
lambda _, x: x,
(
widgets.Label(value="(x, y, z) = "),
AutoWidget(
lambda x, y, z: (x, y, z),
(
entry(value=param.default.fields.x.value),
entry(value=param.default.fields.y.value),
entry(value=param.default.fields.z.value),
),
),
),
)
return AutoWidget(
sc.vector,
(tri_tuple,),
dict( # noqa: C408
unit=widgets.Text(
description='unit of vector',
value=str('m'),
layout=widgets.Layout(width='10em'),
)
),
description=param.name,
layout=widgets.Layout(border='0.5px solid'),
child_layout=widgets.VBox,
)


@create_parameter_widget.register(parameter.BooleanParameter)
Expand Down Expand Up @@ -147,9 +94,26 @@ def _(param):

@create_parameter_widget.register(parameter.BinEdgesParameter)
def _(param):
dim = param.dim
unit = param.unit
return LinspaceWidget(dim, unit)
return AutoWidget(
lambda space, **kwargs: space(**kwargs),
(),
dict( # noqa: C408
space=widgets.Dropdown(
default=sc.linspace,
options=[sc.linspace, sc.geomspace],
description='bin space',
),
dim=widgets.Text(value=param.dim, description='dim'),
start=widgets.FloatText(description='left edge'),
stop=widgets.FloatText(description='right edge'),
num=widgets.IntText(value=1, description='num. edges'),
unit=widgets.Text(value=str(param.unit), description='unit'),
),
description=param.name,
child_layout=widgets.VBox,
layout=widgets.Layout(border='0.5px solid'),
style=_style,
)


@create_parameter_widget.register(parameter.FilenameParameter)
Expand Down