|
1 | 1 | import base64 |
| 2 | +from enum import Enum |
2 | 3 | import json |
3 | 4 | import re |
4 | 5 | import threading |
@@ -63,6 +64,18 @@ def wrap_2(*args, **kwargs): |
63 | 64 | return wrap_1 |
64 | 65 |
|
65 | 66 |
|
| 67 | +class MessageType(Enum): |
| 68 | + REQUEST_FRAME = 'request_frame' |
| 69 | + UPDATE_IDS = 'updateIDs' |
| 70 | + REMOVE_COMPONENT = 'removeComponent' |
| 71 | + REPR_PARAMETERS = 'repr_parameters' |
| 72 | + REQUEST_LOADED = 'request_loaded' |
| 73 | + REQUEST_REPR_DICT = 'request_repr_dict' |
| 74 | + STAGE_PARAMETERS = 'stage_parameters' |
| 75 | + ASYNC_MESSAGE = 'async_message' |
| 76 | + IMAGE_DATA = 'image_data' |
| 77 | + |
| 78 | + |
66 | 79 | def write_html(fp, views, frame_range=None): |
67 | 80 | """EXPERIMENTAL. Likely will be changed. |
68 | 81 |
|
@@ -304,6 +317,7 @@ def _encode_trajectory(self, traj, frame_range): |
304 | 317 | else: |
305 | 318 | encoded_traj.append(encode_base64(np.empty((0), dtype='f4'))) |
306 | 319 | return encoded_traj |
| 320 | + |
307 | 321 | def _create_player(self): |
308 | 322 | player = Play(max=self.max_frame, interval=100) |
309 | 323 | slider = IntSlider(max=self.max_frame) |
@@ -587,35 +601,27 @@ def _set_size(self, w, h): |
587 | 601 | ''' |
588 | 602 | self._remote_call('setSize', target='Widget', args=[w, h]) |
589 | 603 |
|
590 | | - def _set_sync_repr(self, other_views): |
| 604 | + def _set_model_ids(self, other_views, attribute, remote_call): |
591 | 605 | model_ids = {v._model_id for v in other_views} |
592 | | - self._synced_repr_model_ids = sorted( |
593 | | - set(self._synced_repr_model_ids) | model_ids) |
594 | | - self._remote_call("setSyncRepr", |
595 | | - target="Widget", |
596 | | - args=[self._synced_repr_model_ids]) |
| 606 | + setattr(self, attribute, sorted(set(getattr(self, attribute)) | model_ids)) |
| 607 | + self._remote_call(remote_call, target="Widget", args=[getattr(self, attribute)]) |
597 | 608 |
|
598 | | - def _set_unsync_repr(self, other_views): |
| 609 | + def _unset_model_ids(self, other_views, attribute, remote_call): |
599 | 610 | model_ids = {v._model_id for v in other_views} |
600 | | - self._synced_repr_model_ids = list(set(self._synced_repr_model_ids) - model_ids) |
601 | | - self._remote_call("setSyncRepr", |
602 | | - target="Widget", |
603 | | - args=[self._synced_repr_model_ids]) |
| 611 | + setattr(self, attribute, list(set(getattr(self, attribute)) - model_ids)) |
| 612 | + self._remote_call(remote_call, target="Widget", args=[getattr(self, attribute)]) |
| 613 | + |
| 614 | + def _set_sync_repr(self, other_views): |
| 615 | + self._set_model_ids(other_views, "_synced_repr_model_ids", "setSyncRepr") |
| 616 | + |
| 617 | + def _set_unsync_repr(self, other_views): |
| 618 | + self._unset_model_ids(other_views, "_synced_repr_model_ids", "setSyncRepr") |
604 | 619 |
|
605 | 620 | def _set_sync_camera(self, other_views): |
606 | | - model_ids = {v._model_id for v in other_views} |
607 | | - self._synced_model_ids = sorted( |
608 | | - set(self._synced_model_ids) | model_ids) |
609 | | - self._remote_call("setSyncCamera", |
610 | | - target="Widget", |
611 | | - args=[self._synced_model_ids]) |
| 621 | + self._set_model_ids(other_views, "_synced_model_ids", "setSyncCamera") |
612 | 622 |
|
613 | 623 | def _set_unsync_camera(self, other_views): |
614 | | - model_ids = {v._model_id for v in other_views} |
615 | | - self._synced_model_ids = list(set(self._synced_model_ids) - model_ids) |
616 | | - self._remote_call("setSyncCamera", |
617 | | - target="Widget", |
618 | | - args=[self._synced_model_ids]) |
| 624 | + self._unset_model_ids(other_views, "_synced_model_ids", "setSyncCamera") |
619 | 625 |
|
620 | 626 | def _set_spin(self, axis, angle): |
621 | 627 | self._remote_call('setSpin', target='Stage', args=[axis, angle]) |
@@ -780,7 +786,6 @@ def _get_trajectory_coordinates(self, trajectory, index, traj_index): |
780 | 786 |
|
781 | 787 | def set_coordinates(self, arr_dict, movie_making=False, |
782 | 788 | render_params=None): |
783 | | - # type: (Dict[int, np.ndarray]) -> None |
784 | 789 | """Used for update coordinates of a given trajectory |
785 | 790 | >>> # arr: numpy array, ndim=2 |
786 | 791 | >>> # update coordinates of 1st trajectory |
@@ -1090,26 +1095,26 @@ def _handle_image_data(self): |
1090 | 1095 | def _handle_nglview_custom_msg(self, _, msg, buffers): |
1091 | 1096 | self._ngl_msg = msg |
1092 | 1097 |
|
1093 | | - msg_type = self._ngl_msg.get('type') |
1094 | | - if msg_type == 'request_frame': |
| 1098 | + msg_type = MessageType(self._ngl_msg.get('type')) |
| 1099 | + |
| 1100 | + if msg_type == MessageType.REQUEST_FRAME: |
1095 | 1101 | self._handle_request_frame() |
1096 | | - elif msg_type == 'updateIDs': |
| 1102 | + elif msg_type == MessageType.UPDATE_IDS: |
1097 | 1103 | self._handle_update_ids() |
1098 | | - elif msg_type == 'removeComponent': |
| 1104 | + elif msg_type == MessageType.REMOVE_COMPONENT: |
1099 | 1105 | self._handle_remove_component() |
1100 | | - elif msg_type == 'repr_parameters': |
| 1106 | + elif msg_type == MessageType.REPR_PARAMETERS: |
1101 | 1107 | self._handle_repr_parameters() |
1102 | | - elif msg_type == 'request_loaded': |
| 1108 | + elif msg_type == MessageType.REQUEST_LOADED: |
1103 | 1109 | self._handle_request_loaded() |
1104 | | - elif msg_type == 'request_repr_dict': |
| 1110 | + elif msg_type == MessageType.REQUEST_REPR_DICT: |
1105 | 1111 | self._handle_request_repr_dict() |
1106 | | - elif msg_type == 'stage_parameters': |
| 1112 | + elif msg_type == MessageType.STAGE_PARAMETERS: |
1107 | 1113 | self._handle_stage_parameters() |
1108 | | - elif msg_type == 'async_message': |
| 1114 | + elif msg_type == MessageType.ASYNC_MESSAGE: |
1109 | 1115 | self._handle_async_message() |
1110 | | - elif msg_type == 'image_data': |
| 1116 | + elif msg_type == MessageType.IMAGE_DATA: |
1111 | 1117 | self._handle_image_data() |
1112 | | - |
1113 | 1118 | def _request_repr_parameters(self, component=0, repr_index=0): |
1114 | 1119 | if self.n_components > 0: |
1115 | 1120 | self._remote_call('requestReprParameters', |
@@ -1362,64 +1367,57 @@ def _get_remote_call_msg(self, |
1362 | 1367 | args=['*', 200], |
1363 | 1368 | kwargs={'component_index': 1}) |
1364 | 1369 | """ |
1365 | | - # NOTE: _camelize_dict here? |
1366 | | - args = [] if args is None else args |
1367 | | - kwargs = {} if kwargs is None else kwargs |
| 1370 | + args = args or [] |
| 1371 | + kwargs = kwargs or {} |
1368 | 1372 |
|
1369 | | - msg = {} |
1370 | | - |
1371 | | - if 'component_index' in kwargs: |
1372 | | - msg['component_index'] = kwargs.pop('component_index') |
1373 | | - if 'repr_index' in kwargs: |
1374 | | - msg['repr_index'] = kwargs.pop('repr_index') |
1375 | | - if 'default' in kwargs: |
1376 | | - kwargs['defaultRepresentation'] = kwargs.pop('default') |
| 1373 | + msg = { |
| 1374 | + 'component_index': kwargs.pop('component_index', None), |
| 1375 | + 'repr_index': kwargs.pop('repr_index', None), |
| 1376 | + 'defaultRepresentation': kwargs.pop('default', None), |
| 1377 | + 'target': target, |
| 1378 | + 'type': 'call_method', |
| 1379 | + 'methodName': method_name, |
| 1380 | + 'reconstruc_color_scheme': False, |
| 1381 | + 'args': args, |
| 1382 | + 'kwargs': kwargs |
| 1383 | + } |
1377 | 1384 |
|
1378 | 1385 | # Color handling |
1379 | | - reconstruc_color_scheme = False |
1380 | | - if 'color' in kwargs and isinstance(kwargs['color'], |
1381 | | - color._ColorScheme): |
1382 | | - kwargs['color_label'] = kwargs['color'].data['label'] |
1383 | | - # overite `color` |
1384 | | - kwargs['color'] = kwargs['color'].data['data'] |
1385 | | - reconstruc_color_scheme = True |
| 1386 | + color = kwargs.get('color') |
| 1387 | + if isinstance(color, color._ColorScheme): |
| 1388 | + kwargs['color_label'] = color.data['label'] |
| 1389 | + kwargs['color'] = color.data['data'] |
| 1390 | + msg['reconstruc_color_scheme'] = True |
| 1391 | + |
1386 | 1392 | if kwargs.get('colorScheme') == 'volume' and kwargs.get('colorVolume'): |
1387 | 1393 | assert isinstance(kwargs['colorVolume'], ComponentViewer) |
1388 | 1394 | kwargs['colorVolume'] = kwargs['colorVolume']._index |
1389 | 1395 |
|
1390 | | - msg['target'] = target |
1391 | | - msg['type'] = 'call_method' |
1392 | | - msg['methodName'] = method_name |
1393 | | - msg['reconstruc_color_scheme'] = reconstruc_color_scheme |
1394 | | - msg['args'] = args |
1395 | | - msg['kwargs'] = kwargs |
1396 | 1396 | if other_kwargs: |
1397 | 1397 | msg.update(other_kwargs) |
| 1398 | + |
1398 | 1399 | return msg |
1399 | 1400 |
|
1400 | 1401 | def _trim_message(self, messages): |
1401 | | - messages = messages[:] |
| 1402 | + """ |
| 1403 | + This function trims the messages based on certain conditions. |
| 1404 | + """ |
1402 | 1405 |
|
1403 | | - remove_comps = [(index, msg['args'][0]) |
1404 | | - for index, msg in enumerate(messages) |
1405 | | - if msg['methodName'] == 'removeComponent'] |
| 1406 | + # Create a list of tuples containing the index and the first argument of the message |
| 1407 | + # for messages where the method name is 'removeComponent' |
| 1408 | + remove_comps = [(i, msg['args'][0]) for i, msg in enumerate(messages) if msg['methodName'] == 'removeComponent'] |
1406 | 1409 |
|
1407 | 1410 | if not remove_comps: |
1408 | 1411 | return messages |
1409 | 1412 |
|
1410 | | - load_comps = [ |
1411 | | - index for index, msg in enumerate(messages) |
1412 | | - if msg['methodName'] in ('loadFile', 'addShape') |
1413 | | - ] |
| 1413 | + # Create a list of indices for messages where the method name is either 'loadFile' or 'addShape' |
| 1414 | + load_comps = [i for i, msg in enumerate(messages) if msg['methodName'] in ('loadFile', 'addShape')] |
1414 | 1415 |
|
1415 | | - messages_rm = [r[0] for r in remove_comps] |
1416 | | - messages_rm += [load_comps[r[1]] for r in remove_comps] |
1417 | | - messages_rm = set(messages_rm) |
| 1416 | + # Create a set of indices to remove from the messages |
| 1417 | + messages_rm = set(i for r in remove_comps for i in (r[0], load_comps[r[1]])) |
1418 | 1418 |
|
1419 | | - return [ |
1420 | | - msg for i, msg in enumerate(messages) |
1421 | | - if i not in messages_rm |
1422 | | - ] |
| 1419 | + # Return a new list of messages that excludes the messages with the indices in messages_rm |
| 1420 | + return [msg for i, msg in enumerate(messages) if i not in messages_rm] |
1423 | 1421 |
|
1424 | 1422 | def _remote_call(self, |
1425 | 1423 | method_name, |
@@ -1492,32 +1490,18 @@ def show_only(self, indices='all', **kwargs): |
1492 | 1490 | """ |
1493 | 1491 | traj_ids = {traj.id for traj in self._trajlist} |
1494 | 1492 |
|
1495 | | - if indices == 'all': |
1496 | | - indices_ = set(range(self.n_components)) |
1497 | | - else: |
1498 | | - indices_ = set(indices) |
| 1493 | + indices_ = set(range(self.n_components)) if indices == 'all' else set(indices) |
1499 | 1494 |
|
1500 | 1495 | for index, comp_id in enumerate(self._ngl_component_ids): |
1501 | | - if comp_id in traj_ids: |
1502 | | - traj = self._get_traj_by_id(comp_id) |
1503 | | - else: |
1504 | | - traj = None |
1505 | | - if index in indices_: |
1506 | | - args = [ |
1507 | | - True, |
1508 | | - ] |
1509 | | - if traj is not None: |
1510 | | - traj.shown = True |
1511 | | - else: |
1512 | | - args = [ |
1513 | | - False, |
1514 | | - ] |
1515 | | - if traj is not None: |
1516 | | - traj.shown = False |
| 1496 | + traj = self._get_traj_by_id(comp_id) if comp_id in traj_ids else None |
| 1497 | + is_visible = index in indices_ |
| 1498 | + |
| 1499 | + if traj is not None: |
| 1500 | + traj.shown = is_visible |
1517 | 1501 |
|
1518 | 1502 | self._remote_call("setVisibility", |
1519 | 1503 | target='compList', |
1520 | | - args=args, |
| 1504 | + args=[is_visible], |
1521 | 1505 | kwargs={'component_index': index}, |
1522 | 1506 | **kwargs) |
1523 | 1507 |
|
|
0 commit comments