|
1 | 1 | import base64
|
| 2 | +from enum import Enum |
2 | 3 | import json
|
3 | 4 | import re
|
4 | 5 | import threading
|
@@ -63,6 +64,18 @@ def wrap_2(*args, **kwargs):
|
63 | 64 | return wrap_1
|
64 | 65 |
|
65 | 66 |
|
| 67 | +class MessageType(Enum): |
| 68 | + REQUEST_FRAME = 'request_frame' |
| 69 | + UPDATE_IDS = 'updateIDs' |
| 70 | + REMOVE_COMPONENT = 'removeComponent' |
| 71 | + REPR_PARAMETERS = 'repr_parameters' |
| 72 | + REQUEST_LOADED = 'request_loaded' |
| 73 | + REQUEST_REPR_DICT = 'request_repr_dict' |
| 74 | + STAGE_PARAMETERS = 'stage_parameters' |
| 75 | + ASYNC_MESSAGE = 'async_message' |
| 76 | + IMAGE_DATA = 'image_data' |
| 77 | + |
| 78 | + |
66 | 79 | def write_html(fp, views, frame_range=None):
|
67 | 80 | """EXPERIMENTAL. Likely will be changed.
|
68 | 81 |
|
@@ -304,6 +317,7 @@ def _encode_trajectory(self, traj, frame_range):
|
304 | 317 | else:
|
305 | 318 | encoded_traj.append(encode_base64(np.empty((0), dtype='f4')))
|
306 | 319 | return encoded_traj
|
| 320 | + |
307 | 321 | def _create_player(self):
|
308 | 322 | player = Play(max=self.max_frame, interval=100)
|
309 | 323 | slider = IntSlider(max=self.max_frame)
|
@@ -587,35 +601,27 @@ def _set_size(self, w, h):
|
587 | 601 | '''
|
588 | 602 | self._remote_call('setSize', target='Widget', args=[w, h])
|
589 | 603 |
|
590 |
| - def _set_sync_repr(self, other_views): |
| 604 | + def _set_model_ids(self, other_views, attribute, remote_call): |
591 | 605 | model_ids = {v._model_id for v in other_views}
|
592 |
| - self._synced_repr_model_ids = sorted( |
593 |
| - set(self._synced_repr_model_ids) | model_ids) |
594 |
| - self._remote_call("setSyncRepr", |
595 |
| - target="Widget", |
596 |
| - args=[self._synced_repr_model_ids]) |
| 606 | + setattr(self, attribute, sorted(set(getattr(self, attribute)) | model_ids)) |
| 607 | + self._remote_call(remote_call, target="Widget", args=[getattr(self, attribute)]) |
597 | 608 |
|
598 |
| - def _set_unsync_repr(self, other_views): |
| 609 | + def _unset_model_ids(self, other_views, attribute, remote_call): |
599 | 610 | model_ids = {v._model_id for v in other_views}
|
600 |
| - self._synced_repr_model_ids = list(set(self._synced_repr_model_ids) - model_ids) |
601 |
| - self._remote_call("setSyncRepr", |
602 |
| - target="Widget", |
603 |
| - args=[self._synced_repr_model_ids]) |
| 611 | + setattr(self, attribute, list(set(getattr(self, attribute)) - model_ids)) |
| 612 | + self._remote_call(remote_call, target="Widget", args=[getattr(self, attribute)]) |
| 613 | + |
| 614 | + def _set_sync_repr(self, other_views): |
| 615 | + self._set_model_ids(other_views, "_synced_repr_model_ids", "setSyncRepr") |
| 616 | + |
| 617 | + def _set_unsync_repr(self, other_views): |
| 618 | + self._unset_model_ids(other_views, "_synced_repr_model_ids", "setSyncRepr") |
604 | 619 |
|
605 | 620 | def _set_sync_camera(self, other_views):
|
606 |
| - model_ids = {v._model_id for v in other_views} |
607 |
| - self._synced_model_ids = sorted( |
608 |
| - set(self._synced_model_ids) | model_ids) |
609 |
| - self._remote_call("setSyncCamera", |
610 |
| - target="Widget", |
611 |
| - args=[self._synced_model_ids]) |
| 621 | + self._set_model_ids(other_views, "_synced_model_ids", "setSyncCamera") |
612 | 622 |
|
613 | 623 | def _set_unsync_camera(self, other_views):
|
614 |
| - model_ids = {v._model_id for v in other_views} |
615 |
| - self._synced_model_ids = list(set(self._synced_model_ids) - model_ids) |
616 |
| - self._remote_call("setSyncCamera", |
617 |
| - target="Widget", |
618 |
| - args=[self._synced_model_ids]) |
| 624 | + self._unset_model_ids(other_views, "_synced_model_ids", "setSyncCamera") |
619 | 625 |
|
620 | 626 | def _set_spin(self, axis, angle):
|
621 | 627 | self._remote_call('setSpin', target='Stage', args=[axis, angle])
|
@@ -780,7 +786,6 @@ def _get_trajectory_coordinates(self, trajectory, index, traj_index):
|
780 | 786 |
|
781 | 787 | def set_coordinates(self, arr_dict, movie_making=False,
|
782 | 788 | render_params=None):
|
783 |
| - # type: (Dict[int, np.ndarray]) -> None |
784 | 789 | """Used for update coordinates of a given trajectory
|
785 | 790 | >>> # arr: numpy array, ndim=2
|
786 | 791 | >>> # update coordinates of 1st trajectory
|
@@ -1090,26 +1095,26 @@ def _handle_image_data(self):
|
1090 | 1095 | def _handle_nglview_custom_msg(self, _, msg, buffers):
|
1091 | 1096 | self._ngl_msg = msg
|
1092 | 1097 |
|
1093 |
| - msg_type = self._ngl_msg.get('type') |
1094 |
| - if msg_type == 'request_frame': |
| 1098 | + msg_type = MessageType(self._ngl_msg.get('type')) |
| 1099 | + |
| 1100 | + if msg_type == MessageType.REQUEST_FRAME: |
1095 | 1101 | self._handle_request_frame()
|
1096 |
| - elif msg_type == 'updateIDs': |
| 1102 | + elif msg_type == MessageType.UPDATE_IDS: |
1097 | 1103 | self._handle_update_ids()
|
1098 |
| - elif msg_type == 'removeComponent': |
| 1104 | + elif msg_type == MessageType.REMOVE_COMPONENT: |
1099 | 1105 | self._handle_remove_component()
|
1100 |
| - elif msg_type == 'repr_parameters': |
| 1106 | + elif msg_type == MessageType.REPR_PARAMETERS: |
1101 | 1107 | self._handle_repr_parameters()
|
1102 |
| - elif msg_type == 'request_loaded': |
| 1108 | + elif msg_type == MessageType.REQUEST_LOADED: |
1103 | 1109 | self._handle_request_loaded()
|
1104 |
| - elif msg_type == 'request_repr_dict': |
| 1110 | + elif msg_type == MessageType.REQUEST_REPR_DICT: |
1105 | 1111 | self._handle_request_repr_dict()
|
1106 |
| - elif msg_type == 'stage_parameters': |
| 1112 | + elif msg_type == MessageType.STAGE_PARAMETERS: |
1107 | 1113 | self._handle_stage_parameters()
|
1108 |
| - elif msg_type == 'async_message': |
| 1114 | + elif msg_type == MessageType.ASYNC_MESSAGE: |
1109 | 1115 | self._handle_async_message()
|
1110 |
| - elif msg_type == 'image_data': |
| 1116 | + elif msg_type == MessageType.IMAGE_DATA: |
1111 | 1117 | self._handle_image_data()
|
1112 |
| - |
1113 | 1118 | def _request_repr_parameters(self, component=0, repr_index=0):
|
1114 | 1119 | if self.n_components > 0:
|
1115 | 1120 | self._remote_call('requestReprParameters',
|
@@ -1362,64 +1367,57 @@ def _get_remote_call_msg(self,
|
1362 | 1367 | args=['*', 200],
|
1363 | 1368 | kwargs={'component_index': 1})
|
1364 | 1369 | """
|
1365 |
| - # NOTE: _camelize_dict here? |
1366 |
| - args = [] if args is None else args |
1367 |
| - kwargs = {} if kwargs is None else kwargs |
| 1370 | + args = args or [] |
| 1371 | + kwargs = kwargs or {} |
1368 | 1372 |
|
1369 |
| - msg = {} |
1370 |
| - |
1371 |
| - if 'component_index' in kwargs: |
1372 |
| - msg['component_index'] = kwargs.pop('component_index') |
1373 |
| - if 'repr_index' in kwargs: |
1374 |
| - msg['repr_index'] = kwargs.pop('repr_index') |
1375 |
| - if 'default' in kwargs: |
1376 |
| - kwargs['defaultRepresentation'] = kwargs.pop('default') |
| 1373 | + msg = { |
| 1374 | + 'component_index': kwargs.pop('component_index', None), |
| 1375 | + 'repr_index': kwargs.pop('repr_index', None), |
| 1376 | + 'defaultRepresentation': kwargs.pop('default', None), |
| 1377 | + 'target': target, |
| 1378 | + 'type': 'call_method', |
| 1379 | + 'methodName': method_name, |
| 1380 | + 'reconstruc_color_scheme': False, |
| 1381 | + 'args': args, |
| 1382 | + 'kwargs': kwargs |
| 1383 | + } |
1377 | 1384 |
|
1378 | 1385 | # Color handling
|
1379 |
| - reconstruc_color_scheme = False |
1380 |
| - if 'color' in kwargs and isinstance(kwargs['color'], |
1381 |
| - color._ColorScheme): |
1382 |
| - kwargs['color_label'] = kwargs['color'].data['label'] |
1383 |
| - # overite `color` |
1384 |
| - kwargs['color'] = kwargs['color'].data['data'] |
1385 |
| - reconstruc_color_scheme = True |
| 1386 | + color = kwargs.get('color') |
| 1387 | + if isinstance(color, color._ColorScheme): |
| 1388 | + kwargs['color_label'] = color.data['label'] |
| 1389 | + kwargs['color'] = color.data['data'] |
| 1390 | + msg['reconstruc_color_scheme'] = True |
| 1391 | + |
1386 | 1392 | if kwargs.get('colorScheme') == 'volume' and kwargs.get('colorVolume'):
|
1387 | 1393 | assert isinstance(kwargs['colorVolume'], ComponentViewer)
|
1388 | 1394 | kwargs['colorVolume'] = kwargs['colorVolume']._index
|
1389 | 1395 |
|
1390 |
| - msg['target'] = target |
1391 |
| - msg['type'] = 'call_method' |
1392 |
| - msg['methodName'] = method_name |
1393 |
| - msg['reconstruc_color_scheme'] = reconstruc_color_scheme |
1394 |
| - msg['args'] = args |
1395 |
| - msg['kwargs'] = kwargs |
1396 | 1396 | if other_kwargs:
|
1397 | 1397 | msg.update(other_kwargs)
|
| 1398 | + |
1398 | 1399 | return msg
|
1399 | 1400 |
|
1400 | 1401 | def _trim_message(self, messages):
|
1401 |
| - messages = messages[:] |
| 1402 | + """ |
| 1403 | + This function trims the messages based on certain conditions. |
| 1404 | + """ |
1402 | 1405 |
|
1403 |
| - remove_comps = [(index, msg['args'][0]) |
1404 |
| - for index, msg in enumerate(messages) |
1405 |
| - if msg['methodName'] == 'removeComponent'] |
| 1406 | + # Create a list of tuples containing the index and the first argument of the message |
| 1407 | + # for messages where the method name is 'removeComponent' |
| 1408 | + remove_comps = [(i, msg['args'][0]) for i, msg in enumerate(messages) if msg['methodName'] == 'removeComponent'] |
1406 | 1409 |
|
1407 | 1410 | if not remove_comps:
|
1408 | 1411 | return messages
|
1409 | 1412 |
|
1410 |
| - load_comps = [ |
1411 |
| - index for index, msg in enumerate(messages) |
1412 |
| - if msg['methodName'] in ('loadFile', 'addShape') |
1413 |
| - ] |
| 1413 | + # Create a list of indices for messages where the method name is either 'loadFile' or 'addShape' |
| 1414 | + load_comps = [i for i, msg in enumerate(messages) if msg['methodName'] in ('loadFile', 'addShape')] |
1414 | 1415 |
|
1415 |
| - messages_rm = [r[0] for r in remove_comps] |
1416 |
| - messages_rm += [load_comps[r[1]] for r in remove_comps] |
1417 |
| - messages_rm = set(messages_rm) |
| 1416 | + # Create a set of indices to remove from the messages |
| 1417 | + messages_rm = set(i for r in remove_comps for i in (r[0], load_comps[r[1]])) |
1418 | 1418 |
|
1419 |
| - return [ |
1420 |
| - msg for i, msg in enumerate(messages) |
1421 |
| - if i not in messages_rm |
1422 |
| - ] |
| 1419 | + # Return a new list of messages that excludes the messages with the indices in messages_rm |
| 1420 | + return [msg for i, msg in enumerate(messages) if i not in messages_rm] |
1423 | 1421 |
|
1424 | 1422 | def _remote_call(self,
|
1425 | 1423 | method_name,
|
@@ -1492,32 +1490,18 @@ def show_only(self, indices='all', **kwargs):
|
1492 | 1490 | """
|
1493 | 1491 | traj_ids = {traj.id for traj in self._trajlist}
|
1494 | 1492 |
|
1495 |
| - if indices == 'all': |
1496 |
| - indices_ = set(range(self.n_components)) |
1497 |
| - else: |
1498 |
| - indices_ = set(indices) |
| 1493 | + indices_ = set(range(self.n_components)) if indices == 'all' else set(indices) |
1499 | 1494 |
|
1500 | 1495 | for index, comp_id in enumerate(self._ngl_component_ids):
|
1501 |
| - if comp_id in traj_ids: |
1502 |
| - traj = self._get_traj_by_id(comp_id) |
1503 |
| - else: |
1504 |
| - traj = None |
1505 |
| - if index in indices_: |
1506 |
| - args = [ |
1507 |
| - True, |
1508 |
| - ] |
1509 |
| - if traj is not None: |
1510 |
| - traj.shown = True |
1511 |
| - else: |
1512 |
| - args = [ |
1513 |
| - False, |
1514 |
| - ] |
1515 |
| - if traj is not None: |
1516 |
| - traj.shown = False |
| 1496 | + traj = self._get_traj_by_id(comp_id) if comp_id in traj_ids else None |
| 1497 | + is_visible = index in indices_ |
| 1498 | + |
| 1499 | + if traj is not None: |
| 1500 | + traj.shown = is_visible |
1517 | 1501 |
|
1518 | 1502 | self._remote_call("setVisibility",
|
1519 | 1503 | target='compList',
|
1520 |
| - args=args, |
| 1504 | + args=[is_visible], |
1521 | 1505 | kwargs={'component_index': index},
|
1522 | 1506 | **kwargs)
|
1523 | 1507 |
|
|
0 commit comments