@@ -55,11 +55,16 @@ class Dataset(object):
5555 Specially, if feature is loaded from Arg ``additional_feat_suffix``, its source has type str,
5656 which is the suffix of its local file (also the suffix written in Arg ``additional_feat_suffix``).
5757
58- field2id_token (dict): Dict mapping feature name (str) to a list , which stores the original token of
59- this feature. For example, if ``test`` is token-like feature, ``token_a`` is remapped to 1, ``token_b``
58+ field2id_token (dict): Dict mapping feature name (str) to a :class:`np.ndarray`, which stores the original tokens
59+ of this feature. For example, if ``test`` is a token-like feature, ``token_a`` is remapped to 1, ``token_b``
6060 is remapped to 2. Then ``field2id_token['test'] = ['[PAD]', 'token_a', 'token_b']``. (Note that 0 is
6161 always PADDING for token-like features.)
6262
63+ field2token_id (dict): Dict mapping feature name (str) to a dict, which stores the token remap table
64+ of this feature. For example, if ``test`` is a token-like feature, ``token_a`` is remapped to 1, ``token_b``
65+ is remapped to 2. Then ``field2token_id['test'] = {'[PAD]': 0, 'token_a': 1, 'token_b': 2}``.
66+ (Note that 0 is always PADDING for token-like features.)
67+
6368 field2seqlen (dict): Dict mapping feature name (str) to its sequence length (int).
6469 For sequence features, their length can be either set in config,
6570 or set to the max sequence length of this feature.
@@ -116,6 +121,7 @@ def _get_preset(self):
116121 self .field2type = {}
117122 self .field2source = {}
118123 self .field2id_token = {}
124+ self .field2token_id = {}
119125 self .field2seqlen = self .config ['seq_len' ] or {}
120126 self ._preloaded_weight = {}
121127 self .benchmark_filename_list = self .config ['benchmark_filename' ]
@@ -897,11 +903,13 @@ def _remap(self, remap_list):
897903 tokens , split_point = self ._concat_remaped_tokens (remap_list )
898904 new_ids_list , mp = pd .factorize (tokens )
899905 new_ids_list = np .split (new_ids_list + 1 , split_point )
900- mp = ['[PAD]' ] + list (mp )
906+ mp = np .array (['[PAD]' ] + list (mp ))
907+ token_id = {t : i for i , t in enumerate (mp )}
901908
902909 for (feat , field , ftype ), new_ids in zip (remap_list , new_ids_list ):
903- if ( field not in self .field2id_token ) :
910+ if field not in self .field2id_token :
904911 self .field2id_token [field ] = mp
912+ self .field2token_id [field ] = token_id
905913 if ftype == FeatureType .TOKEN :
906914 feat [field ] = new_ids
907915 elif ftype == FeatureType .TOKEN_SEQ :
@@ -1010,6 +1018,46 @@ def copy_field_property(self, dest_field, source_field):
10101018 self .field2source [dest_field ] = self .field2source [source_field ]
10111019 self .field2seqlen [dest_field ] = self .field2seqlen [source_field ]
10121020
@dlapi.set()
def token2id(self, field, tokens):
    """Map external tokens to internal ids.

    Args:
        field (str): Field of external tokens.
        tokens (str, list or np.ndarray): External tokens.

    Returns:
        int or np.ndarray: The internal ids of external tokens.

    Raises:
        ValueError: If a token is not in the remap table of ``field``.
        TypeError: If ``tokens`` is neither a str, a list nor an np.ndarray.
    """
    if isinstance(tokens, str):
        token_id = self.field2token_id[field]
        if tokens not in token_id:
            # Fill the placeholder so the message names the offending token.
            raise ValueError('token [{}] is not existed'.format(tokens))
        return token_id[tokens]
    if isinstance(tokens, (list, np.ndarray)):
        # Resolve element by element so that each member goes through the
        # same membership and type checks as a single token.
        return np.array([self.token2id(field, token) for token in tokens])
    raise TypeError('The type of tokens [{}] is not supported'.format(tokens))
@dlapi.set()
def id2token(self, field, ids):
    """Map internal ids to external tokens.

    Looks ``ids`` up in the id-to-token table stored for ``field`` and
    converts an ``IndexError`` from that lookup into a ``ValueError``.

    Args:
        field (str): Field of internal ids.
        ids (int, list, np.ndarray or torch.Tensor): Internal ids.

    Returns:
        str or np.ndarray: The external tokens of internal ids.

    Raises:
        ValueError: If indexing the token table with ``ids`` raises ``IndexError``.
    """
    try:
        return self.field2id_token[field][ids]
    except IndexError:
        # Lists get a dedicated message; every other input type shares one.
        template = (
            '[{}] is not a one-dimensional list'
            if isinstance(ids, list)
            else '[{}] is not a valid ids'
        )
        raise ValueError(template.format(ids))
1060+
10131061 @property
10141062 @dlapi .set ()
10151063 def user_num (self ):
0 commit comments