Skip to content

Commit

Permalink
Enum
Browse files Browse the repository at this point in the history
  • Loading branch information
hainm committed Jun 20, 2024
1 parent 28c9dbf commit 5fec04e
Showing 1 changed file with 78 additions and 94 deletions.
172 changes: 78 additions & 94 deletions nglview/widget.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import base64
from enum import Enum
import json
import re
import threading
Expand Down Expand Up @@ -63,6 +64,18 @@ def wrap_2(*args, **kwargs):
return wrap_1


class MessageType(Enum):
REQUEST_FRAME = 'request_frame'
UPDATE_IDS = 'updateIDs'
REMOVE_COMPONENT = 'removeComponent'
REPR_PARAMETERS = 'repr_parameters'
REQUEST_LOADED = 'request_loaded'
REQUEST_REPR_DICT = 'request_repr_dict'
STAGE_PARAMETERS = 'stage_parameters'
ASYNC_MESSAGE = 'async_message'
IMAGE_DATA = 'image_data'


def write_html(fp, views, frame_range=None):
"""EXPERIMENTAL. Likely will be changed.
Expand Down Expand Up @@ -304,6 +317,7 @@ def _encode_trajectory(self, traj, frame_range):
else:
encoded_traj.append(encode_base64(np.empty((0), dtype='f4')))
return encoded_traj

def _create_player(self):
player = Play(max=self.max_frame, interval=100)
slider = IntSlider(max=self.max_frame)
Expand Down Expand Up @@ -587,35 +601,27 @@ def _set_size(self, w, h):
'''
self._remote_call('setSize', target='Widget', args=[w, h])

def _set_sync_repr(self, other_views):
def _set_model_ids(self, other_views, attribute, remote_call):
model_ids = {v._model_id for v in other_views}
self._synced_repr_model_ids = sorted(
set(self._synced_repr_model_ids) | model_ids)
self._remote_call("setSyncRepr",
target="Widget",
args=[self._synced_repr_model_ids])
setattr(self, attribute, sorted(set(getattr(self, attribute)) | model_ids))
self._remote_call(remote_call, target="Widget", args=[getattr(self, attribute)])

def _set_unsync_repr(self, other_views):
def _unset_model_ids(self, other_views, attribute, remote_call):
model_ids = {v._model_id for v in other_views}
self._synced_repr_model_ids = list(set(self._synced_repr_model_ids) - model_ids)
self._remote_call("setSyncRepr",
target="Widget",
args=[self._synced_repr_model_ids])
setattr(self, attribute, list(set(getattr(self, attribute)) - model_ids))
self._remote_call(remote_call, target="Widget", args=[getattr(self, attribute)])

def _set_sync_repr(self, other_views):
self._set_model_ids(other_views, "_synced_repr_model_ids", "setSyncRepr")

def _set_unsync_repr(self, other_views):
self._unset_model_ids(other_views, "_synced_repr_model_ids", "setSyncRepr")

def _set_sync_camera(self, other_views):
model_ids = {v._model_id for v in other_views}
self._synced_model_ids = sorted(
set(self._synced_model_ids) | model_ids)
self._remote_call("setSyncCamera",
target="Widget",
args=[self._synced_model_ids])
self._set_model_ids(other_views, "_synced_model_ids", "setSyncCamera")

def _set_unsync_camera(self, other_views):
model_ids = {v._model_id for v in other_views}
self._synced_model_ids = list(set(self._synced_model_ids) - model_ids)
self._remote_call("setSyncCamera",
target="Widget",
args=[self._synced_model_ids])
self._unset_model_ids(other_views, "_synced_model_ids", "setSyncCamera")

def _set_spin(self, axis, angle):
self._remote_call('setSpin', target='Stage', args=[axis, angle])
Expand Down Expand Up @@ -780,7 +786,6 @@ def _get_trajectory_coordinates(self, trajectory, index, traj_index):

def set_coordinates(self, arr_dict, movie_making=False,
render_params=None):
# type: (Dict[int, np.ndarray]) -> None
"""Used for update coordinates of a given trajectory
>>> # arr: numpy array, ndim=2
>>> # update coordinates of 1st trajectory
Expand Down Expand Up @@ -1090,26 +1095,26 @@ def _handle_image_data(self):
def _handle_nglview_custom_msg(self, _, msg, buffers):
self._ngl_msg = msg

msg_type = self._ngl_msg.get('type')
if msg_type == 'request_frame':
msg_type = MessageType(self._ngl_msg.get('type'))

if msg_type == MessageType.REQUEST_FRAME:
self._handle_request_frame()
elif msg_type == 'updateIDs':
elif msg_type == MessageType.UPDATE_IDS:
self._handle_update_ids()
elif msg_type == 'removeComponent':
elif msg_type == MessageType.REMOVE_COMPONENT:
self._handle_remove_component()
elif msg_type == 'repr_parameters':
elif msg_type == MessageType.REPR_PARAMETERS:
self._handle_repr_parameters()
elif msg_type == 'request_loaded':
elif msg_type == MessageType.REQUEST_LOADED:
self._handle_request_loaded()
elif msg_type == 'request_repr_dict':
elif msg_type == MessageType.REQUEST_REPR_DICT:
self._handle_request_repr_dict()
elif msg_type == 'stage_parameters':
elif msg_type == MessageType.STAGE_PARAMETERS:
self._handle_stage_parameters()
elif msg_type == 'async_message':
elif msg_type == MessageType.ASYNC_MESSAGE:
self._handle_async_message()
elif msg_type == 'image_data':
elif msg_type == MessageType.IMAGE_DATA:
self._handle_image_data()

def _request_repr_parameters(self, component=0, repr_index=0):
if self.n_components > 0:
self._remote_call('requestReprParameters',
Expand Down Expand Up @@ -1362,64 +1367,57 @@ def _get_remote_call_msg(self,
args=['*', 200],
kwargs={'component_index': 1})
"""
# NOTE: _camelize_dict here?
args = [] if args is None else args
kwargs = {} if kwargs is None else kwargs
args = args or []
kwargs = kwargs or {}

msg = {}

if 'component_index' in kwargs:
msg['component_index'] = kwargs.pop('component_index')
if 'repr_index' in kwargs:
msg['repr_index'] = kwargs.pop('repr_index')
if 'default' in kwargs:
kwargs['defaultRepresentation'] = kwargs.pop('default')
msg = {
'component_index': kwargs.pop('component_index', None),
'repr_index': kwargs.pop('repr_index', None),
'defaultRepresentation': kwargs.pop('default', None),
'target': target,
'type': 'call_method',
'methodName': method_name,
'reconstruc_color_scheme': False,
'args': args,
'kwargs': kwargs
}

# Color handling
reconstruc_color_scheme = False
if 'color' in kwargs and isinstance(kwargs['color'],
color._ColorScheme):
kwargs['color_label'] = kwargs['color'].data['label']
# overite `color`
kwargs['color'] = kwargs['color'].data['data']
reconstruc_color_scheme = True
color = kwargs.get('color')
if isinstance(color, color._ColorScheme):
kwargs['color_label'] = color.data['label']
kwargs['color'] = color.data['data']
msg['reconstruc_color_scheme'] = True

if kwargs.get('colorScheme') == 'volume' and kwargs.get('colorVolume'):
assert isinstance(kwargs['colorVolume'], ComponentViewer)
kwargs['colorVolume'] = kwargs['colorVolume']._index

msg['target'] = target
msg['type'] = 'call_method'
msg['methodName'] = method_name
msg['reconstruc_color_scheme'] = reconstruc_color_scheme
msg['args'] = args
msg['kwargs'] = kwargs
if other_kwargs:
msg.update(other_kwargs)

return msg

def _trim_message(self, messages):
messages = messages[:]
"""
This function trims the messages based on certain conditions.
"""

remove_comps = [(index, msg['args'][0])
for index, msg in enumerate(messages)
if msg['methodName'] == 'removeComponent']
# Create a list of tuples containing the index and the first argument of the message
# for messages where the method name is 'removeComponent'
remove_comps = [(i, msg['args'][0]) for i, msg in enumerate(messages) if msg['methodName'] == 'removeComponent']

if not remove_comps:
return messages

load_comps = [
index for index, msg in enumerate(messages)
if msg['methodName'] in ('loadFile', 'addShape')
]
# Create a list of indices for messages where the method name is either 'loadFile' or 'addShape'
load_comps = [i for i, msg in enumerate(messages) if msg['methodName'] in ('loadFile', 'addShape')]

messages_rm = [r[0] for r in remove_comps]
messages_rm += [load_comps[r[1]] for r in remove_comps]
messages_rm = set(messages_rm)
# Create a set of indices to remove from the messages
messages_rm = set(i for r in remove_comps for i in (r[0], load_comps[r[1]]))

return [
msg for i, msg in enumerate(messages)
if i not in messages_rm
]
# Return a new list of messages that excludes the messages with the indices in messages_rm
return [msg for i, msg in enumerate(messages) if i not in messages_rm]

def _remote_call(self,
method_name,
Expand Down Expand Up @@ -1492,32 +1490,18 @@ def show_only(self, indices='all', **kwargs):
"""
traj_ids = {traj.id for traj in self._trajlist}

if indices == 'all':
indices_ = set(range(self.n_components))
else:
indices_ = set(indices)
indices_ = set(range(self.n_components)) if indices == 'all' else set(indices)

for index, comp_id in enumerate(self._ngl_component_ids):
if comp_id in traj_ids:
traj = self._get_traj_by_id(comp_id)
else:
traj = None
if index in indices_:
args = [
True,
]
if traj is not None:
traj.shown = True
else:
args = [
False,
]
if traj is not None:
traj.shown = False
traj = self._get_traj_by_id(comp_id) if comp_id in traj_ids else None
is_visible = index in indices_

if traj is not None:
traj.shown = is_visible

self._remote_call("setVisibility",
target='compList',
args=args,
args=[is_visible],
kwargs={'component_index': index},
**kwargs)

Expand Down

0 comments on commit 5fec04e

Please sign in to comment.