Skip to content

Commit

Permalink
Add support for ROIs generated from View's
Browse files Browse the repository at this point in the history
This is a very simple implementation of something that can hopefully be improved
and cleaned up in the future: support for context-defined annotation widgets.

Each ROI is defined as a parameter, they are added to a View with an argument to
the Image_WithRoi type, which is read by the Correlator and thus creates an
ROI. When the ROI is changed, an updated parameter is sent to metropc and the
Correlator modifies the context file with the new ROI parameters (position,
height, etc).
  • Loading branch information
JamesWrigley committed Mar 4, 2022
1 parent 72d0bc4 commit 2759386
Show file tree
Hide file tree
Showing 6 changed files with 185 additions and 13 deletions.
1 change: 1 addition & 0 deletions .github/dependabot/constraints.txt
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ jupyterlab-pygments==0.1.2
jupyterlab-widgets==1.0.2
karabo-bridge==0.6.1
kiwisolver==1.3.2
libcst==0.4.1
locket==0.2.1
MarkupSafe==2.1.0
matplotlib==3.5.1
Expand Down
32 changes: 27 additions & 5 deletions extra_foam/special_suite/correlator_proc.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,15 @@ def extract(self, data):
return operator.attrgetter(self.strip_type(self._full_path))(data)


class ViewEntry(IndexViewEntry):
__slots__ = ["annotations"]

def __init__(self, annotations, entry):
super().__init__(entry.counts, entry.rate, entry.output, entry.stage)

self.annotations = annotations


# This is a helper type to hold useful data about a path, to be displayed in a
# client.
PathData = namedtuple("PathData",
Expand Down Expand Up @@ -149,6 +158,7 @@ class CorrelatorProcessor(QThreadWorker):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

self._ctx = None
self.client_type = None

self._index_event = Event()
Expand Down Expand Up @@ -214,15 +224,24 @@ def handleIndex(self, index):
"""
views = set(p for p, v in index.items() if isinstance(v, IndexViewEntry))
subscribed_views = set(self._subscriptions.keys())
new_views = views - subscribed_views
old_views = subscribed_views - views

# Subscribe to new views
for s in views - subscribed_views:
self._subscriber.subscribe(s.encode())
self._subscriptions[s] = index[s]
# Subscribe to new views and update existing ones
for s in views:
if s in new_views:
# If it's a new view, subscribe to it
self._subscriber.subscribe(s.encode())

# Create a custom index entry that stores annotations
view = self._ctx.views[s.split("#")[1]]
view_entry = ViewEntry(getattr(view, "annotations", []), index[s])

self._subscriptions[s] = view_entry
self.log.debug(f"Subscribed to {s}")

# Unsubscribe from old ones
for s in subscribed_views - views:
for s in old_views:
self._subscriber.unsubscribe(s.encode())
del self._subscriptions[s]
self.log.debug(f"Unsubscribed to {s}")
Expand Down Expand Up @@ -258,6 +277,9 @@ def waitUntil(self, event_type: MetroEvent):

event.wait()

def set_parameter(self, name, value):
self._pipeline.queue_to_all(b"params", {name: value})

# Helper function to inspect an object and create a PathData object
# for it.
def inspect_data(self, data):
Expand Down
130 changes: 123 additions & 7 deletions extra_foam/special_suite/correlator_w.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import textwrap
import itertools
import dataclasses
from enum import Enum
from collections import defaultdict

import libcst as cst
import libcst.matchers as m
import libcst.metadata as cstmeta
import numpy as np
import xarray as xr

Expand All @@ -17,6 +21,7 @@

from metropc.client import ViewOutput

from ..utils import RectROI as MetroRectROI
from ..algorithms import SimpleSequence
from ..pipeline.data_model import ProcessedData
from ..gui.misc_widgets import FColor
Expand Down Expand Up @@ -168,6 +173,7 @@ def __init__(self, main_window):
self._xs = SimpleSequence(max_len=self._max_points)
self._ys = defaultdict(lambda: SimpleSequence(max_len=self._max_points))
self._errors = defaultdict(lambda: SimpleSequence(max_len=self._max_points))
self._image_annotations = { }
self._current_view = None

self._main_window.registerPlotWidget(self)
Expand Down Expand Up @@ -200,7 +206,7 @@ def initUI(self):
view_selection_widget.setLayout(layout)

# Create plot widgets
self._image_view = ImageViewF()
self._image_view = ImageViewF(has_roi=True)
# Note: this QHBoxLayout is currently unused but should be kept, it's a
# placeholder for future widgets.
image_widget_layout = QHBoxLayout()
Expand Down Expand Up @@ -456,7 +462,34 @@ def handle_rich_output():
logger.error(f"Image data has wrong number of dimensions: {data.ndim} (expected 2)")
return

self._image_view.setImage(data, auto_levels=True)
# Check if there are annotations on this view
view = self._views[self._current_view]
if view.annotations is not None:
for idx, roi_name in enumerate(view.annotations):
if roi_name not in self._image_annotations:
roi = self._image_view.rois[idx + 1]
self._image_annotations[roi_name] = roi

# Update parameters from context
metro_roi = self._main_window.context.parameters[roi_name]
if all(x is not None for x in dataclasses.astuple(metro_roi)):
roi.setPos(metro_roi.x, metro_roi.y)
roi.setSize(metro_roi.width, metro_roi.height)

roi.setLocked(False)
roi.setLabel(roi_name)
roi.sigRegionChangeFinished.connect(self._main_window.onRoiChanged)
roi.show()

# Hide and reset unused annotations
for roi_name in list(self._image_annotations.keys()):
if roi_name not in view.annotations:
roi = self._image_annotations[roi_name]
roi.sigRegionChangeFinished.disconnect()
roi.hide()
del self._image_annotations[roi_name]

self._image_view.setImage(data)
else:
self._legend.clear()
for label, ys_data in self._ys.items():
Expand All @@ -481,12 +514,15 @@ def handle_rich_output():
self._plot_widget.setTitle(data.attrs["title"])

def setViews(self, views):
if views == self._views:
return

# Always update the view indexes
old_views = self._views
self._views = views
self.updateAvailableViews()
self.views_updated_sgn.emit()

# But we only update the options displayed if any views have been
# added/removed.
if self._views.keys() != old_views.keys():
self.updateAvailableViews()
self.views_updated_sgn.emit()

def updateAvailableViews(self):
self.view_picker.blockSignals(True)
Expand Down Expand Up @@ -697,6 +733,68 @@ def context_path(self, path):
self._path_label.setText(path)


class RoiTransformer(cst.CSTTransformer):
METADATA_DEPENDENCIES = (cstmeta.ParentNodeProvider, )

def __init__(self, roi):
self._roi_name = roi.label()
self._roi_args = [int(x) for x in [*roi.pos(), *roi.size()]]

def roi_args(self):
return self._roi_args

def gen_int_node(self, value):
int_node = cst.Integer(str(np.abs(value)))

if value < 0:
return cst.UnaryOperation(cst.Minus(), int_node)
else:
return int_node

def visit_Call(self, call):
"""
On the way down the CST, filter for the call to 'parameters()'.
"""
return m.matches(call,
m.Call(
func=m.Name("parameters"),
args=[m.AtLeastN(n=1,
matcher=m.Arg(
value=m.Call(func=m.Name("RectROI"))
))
]
))

def leave_Call(self, original_node, updated_node):
"""
On the way back up, modify the 'RectROI' calls
"""
if m.matches(updated_node.func, m.Name("RectROI")):
parent = self.get_metadata(cstmeta.ParentNodeProvider, updated_node)

# If this object is the value for a keyword argument of the same
# name, it's the one we're looking for and we can update it.
if m.matches(parent, m.Arg(keyword=m.Name(self._roi_name))):
# Create the new int nodes
new_ints = [self.gen_int_node(x) for x in self._roi_args]

# If the constructor already has all arguments, then we just
# update each arguments value individually. This preserves any
# existing formatting/comments.
if len(updated_node.args) == 4:
new_args = []
for arg, x in zip(updated_node.args, new_ints):
new_args.append(arg.with_changes(value=x))

return updated_node.with_changes(args=new_args)
else:
# But if it doesn't have all the arguments, replace the
# argument list entirely.
return updated_node.with_changes(args=[cst.Arg(x) for x in new_ints])

return updated_node


@create_special(CorrelatorCtrlWidget, CorrelatorProcessor)
class CorrelatorWindow(_SpecialAnalysisBase):
icon = "cam_view.png"
Expand Down Expand Up @@ -820,6 +918,10 @@ def initConnections(self):
worker.updated_data_paths_sgn.connect(self._completer.setDataPaths)
worker.updated_data_paths_sgn.connect(self._ctrl_widget_st.setDataPaths)

@property
def context(self):
return self._worker_st._ctx

@pyqtSlot()
def _onContextModified(self):
self._markContextSaved(False)
Expand All @@ -832,6 +934,20 @@ def _onViewsUpdated(self, views):
self._completer.setViewPaths(views)
self._ctrl_widget_st.setViewPaths(views)

def onRoiChanged(self, roi):
# Get current source
ctx = self._editor.text()

# Modify the source as needed
module = cstmeta.MetadataWrapper(cst.parse_module(ctx))
transformer = RoiTransformer(roi)
new_source = module.visit(transformer)

# Set the new source and update the parameter
self._editor.setText(new_source.code)
self._worker_st.set_parameter(roi.label(),
MetroRectROI(*transformer.roi_args()))

@pyqtSlot()
def _addTab(self, splitter=None):
index = self._tab_widget.count() - 1
Expand Down
2 changes: 1 addition & 1 deletion extra_foam/special_suite/tests/test_correlator.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,7 @@ def testImagePlotting(self, win, initial_context, caplog):
output = np.random.rand(10, 10)
with patch.object(image_view, "setImage") as set_image:
widget.updateF({ view_name: [output] })
set_image.assert_called_with(output, auto_levels=ANY)
set_image.assert_called_with(output)


class TestCorrelatorProcessor(_TestDataMixin, _SpecialSuiteProcessorTestBase):
Expand Down
32 changes: 32 additions & 0 deletions extra_foam/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,18 @@
import multiprocessing as mp
import functools
import subprocess
from dataclasses import dataclass
from collections import namedtuple
from threading import RLock, Thread
import time

import numpy as np
import xarray as xr

from metropc.core import View
from metropc.client import ViewOutput
from metropc.viewdef import ViewDecorator

from .logger import logger


Expand Down Expand Up @@ -374,3 +380,29 @@ def rich_output(x, xlabel="x", ylabel="y", title=None, max_points=None, **kwargs
full_data.append(data)

return xr.DataArray(full_data, attrs=xr_attrs)

@dataclass
class RectROI():
x: int = None
y: int = None
width: int = None
height: int = None

def of(self, data):
if not isinstance(data, np.ndarray):
raise ValueError(f"ROI input is a {type(data)} instead of an np.ndarray")
elif data.ndim != 2:
raise ValueError(f"ROI input must be 2D, but is actually: {data.ndim}D")

return data[self.y:self.y + self.height, self.x:self.x + self.width]

class AnnotatedImageView(View, abstract=True):
def __init__(self, *args, annotations, **kwargs):
super().__init__(*args, **kwargs)
self.annotations = annotations

class ImageWithRoisView(AnnotatedImageView, output=ViewOutput.IMAGE):
def __init__(self, *args, rois=None, **kwargs):
super().__init__(*args, annotations=rois, **kwargs)

ViewDecorator.kwargs_symbols.update(WithRois=dict(view_impl=ImageWithRoisView))
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,7 @@ def has_ext_modules(self):
'pyyaml',
"metropc @ git+ssh://[email protected]:10022/karaboDevices/metropc.git@high_high_water_mark",
"qscintilla",
"libcst",
# These dependencies are not directly used, but are needed to satisfy
# pip's resolver:
'pygments',
Expand Down

0 comments on commit 2759386

Please sign in to comment.