Skip to content

Commit 5fec04e

Browse files
committed
Enum
1 parent 28c9dbf commit 5fec04e

File tree

1 file changed

+78
-94
lines changed

1 file changed

+78
-94
lines changed

nglview/widget.py

Lines changed: 78 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import base64
2+
from enum import Enum
23
import json
34
import re
45
import threading
@@ -63,6 +64,18 @@ def wrap_2(*args, **kwargs):
6364
return wrap_1
6465

6566

67+
class MessageType(Enum):
68+
REQUEST_FRAME = 'request_frame'
69+
UPDATE_IDS = 'updateIDs'
70+
REMOVE_COMPONENT = 'removeComponent'
71+
REPR_PARAMETERS = 'repr_parameters'
72+
REQUEST_LOADED = 'request_loaded'
73+
REQUEST_REPR_DICT = 'request_repr_dict'
74+
STAGE_PARAMETERS = 'stage_parameters'
75+
ASYNC_MESSAGE = 'async_message'
76+
IMAGE_DATA = 'image_data'
77+
78+
6679
def write_html(fp, views, frame_range=None):
6780
"""EXPERIMENTAL. Likely will be changed.
6881
@@ -304,6 +317,7 @@ def _encode_trajectory(self, traj, frame_range):
304317
else:
305318
encoded_traj.append(encode_base64(np.empty((0), dtype='f4')))
306319
return encoded_traj
320+
307321
def _create_player(self):
308322
player = Play(max=self.max_frame, interval=100)
309323
slider = IntSlider(max=self.max_frame)
@@ -587,35 +601,27 @@ def _set_size(self, w, h):
587601
'''
588602
self._remote_call('setSize', target='Widget', args=[w, h])
589603

590-
def _set_sync_repr(self, other_views):
604+
def _set_model_ids(self, other_views, attribute, remote_call):
591605
model_ids = {v._model_id for v in other_views}
592-
self._synced_repr_model_ids = sorted(
593-
set(self._synced_repr_model_ids) | model_ids)
594-
self._remote_call("setSyncRepr",
595-
target="Widget",
596-
args=[self._synced_repr_model_ids])
606+
setattr(self, attribute, sorted(set(getattr(self, attribute)) | model_ids))
607+
self._remote_call(remote_call, target="Widget", args=[getattr(self, attribute)])
597608

598-
def _set_unsync_repr(self, other_views):
609+
def _unset_model_ids(self, other_views, attribute, remote_call):
599610
model_ids = {v._model_id for v in other_views}
600-
self._synced_repr_model_ids = list(set(self._synced_repr_model_ids) - model_ids)
601-
self._remote_call("setSyncRepr",
602-
target="Widget",
603-
args=[self._synced_repr_model_ids])
611+
setattr(self, attribute, list(set(getattr(self, attribute)) - model_ids))
612+
self._remote_call(remote_call, target="Widget", args=[getattr(self, attribute)])
613+
614+
def _set_sync_repr(self, other_views):
615+
self._set_model_ids(other_views, "_synced_repr_model_ids", "setSyncRepr")
616+
617+
def _set_unsync_repr(self, other_views):
618+
self._unset_model_ids(other_views, "_synced_repr_model_ids", "setSyncRepr")
604619

605620
def _set_sync_camera(self, other_views):
606-
model_ids = {v._model_id for v in other_views}
607-
self._synced_model_ids = sorted(
608-
set(self._synced_model_ids) | model_ids)
609-
self._remote_call("setSyncCamera",
610-
target="Widget",
611-
args=[self._synced_model_ids])
621+
self._set_model_ids(other_views, "_synced_model_ids", "setSyncCamera")
612622

613623
def _set_unsync_camera(self, other_views):
614-
model_ids = {v._model_id for v in other_views}
615-
self._synced_model_ids = list(set(self._synced_model_ids) - model_ids)
616-
self._remote_call("setSyncCamera",
617-
target="Widget",
618-
args=[self._synced_model_ids])
624+
self._unset_model_ids(other_views, "_synced_model_ids", "setSyncCamera")
619625

620626
def _set_spin(self, axis, angle):
621627
self._remote_call('setSpin', target='Stage', args=[axis, angle])
@@ -780,7 +786,6 @@ def _get_trajectory_coordinates(self, trajectory, index, traj_index):
780786

781787
def set_coordinates(self, arr_dict, movie_making=False,
782788
render_params=None):
783-
# type: (Dict[int, np.ndarray]) -> None
784789
"""Used for update coordinates of a given trajectory
785790
>>> # arr: numpy array, ndim=2
786791
>>> # update coordinates of 1st trajectory
@@ -1090,26 +1095,26 @@ def _handle_image_data(self):
10901095
def _handle_nglview_custom_msg(self, _, msg, buffers):
10911096
self._ngl_msg = msg
10921097

1093-
msg_type = self._ngl_msg.get('type')
1094-
if msg_type == 'request_frame':
1098+
msg_type = MessageType(self._ngl_msg.get('type'))
1099+
1100+
if msg_type == MessageType.REQUEST_FRAME:
10951101
self._handle_request_frame()
1096-
elif msg_type == 'updateIDs':
1102+
elif msg_type == MessageType.UPDATE_IDS:
10971103
self._handle_update_ids()
1098-
elif msg_type == 'removeComponent':
1104+
elif msg_type == MessageType.REMOVE_COMPONENT:
10991105
self._handle_remove_component()
1100-
elif msg_type == 'repr_parameters':
1106+
elif msg_type == MessageType.REPR_PARAMETERS:
11011107
self._handle_repr_parameters()
1102-
elif msg_type == 'request_loaded':
1108+
elif msg_type == MessageType.REQUEST_LOADED:
11031109
self._handle_request_loaded()
1104-
elif msg_type == 'request_repr_dict':
1110+
elif msg_type == MessageType.REQUEST_REPR_DICT:
11051111
self._handle_request_repr_dict()
1106-
elif msg_type == 'stage_parameters':
1112+
elif msg_type == MessageType.STAGE_PARAMETERS:
11071113
self._handle_stage_parameters()
1108-
elif msg_type == 'async_message':
1114+
elif msg_type == MessageType.ASYNC_MESSAGE:
11091115
self._handle_async_message()
1110-
elif msg_type == 'image_data':
1116+
elif msg_type == MessageType.IMAGE_DATA:
11111117
self._handle_image_data()
1112-
11131118
def _request_repr_parameters(self, component=0, repr_index=0):
11141119
if self.n_components > 0:
11151120
self._remote_call('requestReprParameters',
@@ -1362,64 +1367,57 @@ def _get_remote_call_msg(self,
13621367
args=['*', 200],
13631368
kwargs={'component_index': 1})
13641369
"""
1365-
# NOTE: _camelize_dict here?
1366-
args = [] if args is None else args
1367-
kwargs = {} if kwargs is None else kwargs
1370+
args = args or []
1371+
kwargs = kwargs or {}
13681372

1369-
msg = {}
1370-
1371-
if 'component_index' in kwargs:
1372-
msg['component_index'] = kwargs.pop('component_index')
1373-
if 'repr_index' in kwargs:
1374-
msg['repr_index'] = kwargs.pop('repr_index')
1375-
if 'default' in kwargs:
1376-
kwargs['defaultRepresentation'] = kwargs.pop('default')
1373+
msg = {
1374+
'component_index': kwargs.pop('component_index', None),
1375+
'repr_index': kwargs.pop('repr_index', None),
1376+
'defaultRepresentation': kwargs.pop('default', None),
1377+
'target': target,
1378+
'type': 'call_method',
1379+
'methodName': method_name,
1380+
'reconstruc_color_scheme': False,
1381+
'args': args,
1382+
'kwargs': kwargs
1383+
}
13771384

13781385
# Color handling
1379-
reconstruc_color_scheme = False
1380-
if 'color' in kwargs and isinstance(kwargs['color'],
1381-
color._ColorScheme):
1382-
kwargs['color_label'] = kwargs['color'].data['label']
1383-
# overite `color`
1384-
kwargs['color'] = kwargs['color'].data['data']
1385-
reconstruc_color_scheme = True
1386+
color = kwargs.get('color')
1387+
if isinstance(color, color._ColorScheme):
1388+
kwargs['color_label'] = color.data['label']
1389+
kwargs['color'] = color.data['data']
1390+
msg['reconstruc_color_scheme'] = True
1391+
13861392
if kwargs.get('colorScheme') == 'volume' and kwargs.get('colorVolume'):
13871393
assert isinstance(kwargs['colorVolume'], ComponentViewer)
13881394
kwargs['colorVolume'] = kwargs['colorVolume']._index
13891395

1390-
msg['target'] = target
1391-
msg['type'] = 'call_method'
1392-
msg['methodName'] = method_name
1393-
msg['reconstruc_color_scheme'] = reconstruc_color_scheme
1394-
msg['args'] = args
1395-
msg['kwargs'] = kwargs
13961396
if other_kwargs:
13971397
msg.update(other_kwargs)
1398+
13981399
return msg
13991400

14001401
def _trim_message(self, messages):
1401-
messages = messages[:]
1402+
"""
1403+
This function trims the messages based on certain conditions.
1404+
"""
14021405

1403-
remove_comps = [(index, msg['args'][0])
1404-
for index, msg in enumerate(messages)
1405-
if msg['methodName'] == 'removeComponent']
1406+
# Create a list of tuples containing the index and the first argument of the message
1407+
# for messages where the method name is 'removeComponent'
1408+
remove_comps = [(i, msg['args'][0]) for i, msg in enumerate(messages) if msg['methodName'] == 'removeComponent']
14061409

14071410
if not remove_comps:
14081411
return messages
14091412

1410-
load_comps = [
1411-
index for index, msg in enumerate(messages)
1412-
if msg['methodName'] in ('loadFile', 'addShape')
1413-
]
1413+
# Create a list of indices for messages where the method name is either 'loadFile' or 'addShape'
1414+
load_comps = [i for i, msg in enumerate(messages) if msg['methodName'] in ('loadFile', 'addShape')]
14141415

1415-
messages_rm = [r[0] for r in remove_comps]
1416-
messages_rm += [load_comps[r[1]] for r in remove_comps]
1417-
messages_rm = set(messages_rm)
1416+
# Create a set of indices to remove from the messages
1417+
messages_rm = set(i for r in remove_comps for i in (r[0], load_comps[r[1]]))
14181418

1419-
return [
1420-
msg for i, msg in enumerate(messages)
1421-
if i not in messages_rm
1422-
]
1419+
# Return a new list of messages that excludes the messages with the indices in messages_rm
1420+
return [msg for i, msg in enumerate(messages) if i not in messages_rm]
14231421

14241422
def _remote_call(self,
14251423
method_name,
@@ -1492,32 +1490,18 @@ def show_only(self, indices='all', **kwargs):
14921490
"""
14931491
traj_ids = {traj.id for traj in self._trajlist}
14941492

1495-
if indices == 'all':
1496-
indices_ = set(range(self.n_components))
1497-
else:
1498-
indices_ = set(indices)
1493+
indices_ = set(range(self.n_components)) if indices == 'all' else set(indices)
14991494

15001495
for index, comp_id in enumerate(self._ngl_component_ids):
1501-
if comp_id in traj_ids:
1502-
traj = self._get_traj_by_id(comp_id)
1503-
else:
1504-
traj = None
1505-
if index in indices_:
1506-
args = [
1507-
True,
1508-
]
1509-
if traj is not None:
1510-
traj.shown = True
1511-
else:
1512-
args = [
1513-
False,
1514-
]
1515-
if traj is not None:
1516-
traj.shown = False
1496+
traj = self._get_traj_by_id(comp_id) if comp_id in traj_ids else None
1497+
is_visible = index in indices_
1498+
1499+
if traj is not None:
1500+
traj.shown = is_visible
15171501

15181502
self._remote_call("setVisibility",
15191503
target='compList',
1520-
args=args,
1504+
args=[is_visible],
15211505
kwargs={'component_index': index},
15221506
**kwargs)
15231507

0 commit comments

Comments
 (0)