@@ -1,5 +1,3 @@
-# coding: utf-8
-
 # Licensed to the Apache Software Foundation (ASF) under one
 # or more contributor license agreements. See the NOTICE file
 # distributed with this work for additional information
@@ -29,25 +27,14 @@
 from mxnet.gluon.data import DataLoader
 
 import gluonnlp
-from gluonnlp.data import BERTTokenizer, BERTSentenceTransform
+from gluonnlp.data import BERTTokenizer, BERTSentenceTransform, BERTSPTokenizer
 from gluonnlp.base import get_home_dir
 
 try:
     from data.embedding import BertEmbeddingDataset
 except ImportError:
     from .data.embedding import BertEmbeddingDataset
 
-try:
-    unicode
-except NameError:
-    # Define `unicode` for Python3
-    def unicode(s, *_):
-        return s
-
-
-def to_unicode(s):
-    return unicode(s, 'utf-8')
-
 
 __all__ = ['BertEmbedding']
 
@@ -75,36 +62,50 @@ class BertEmbedding:
         max length of each sequence
     batch_size : int, default 256
         batch size
+    sentencepiece : str, default None
+        Path to the sentencepiece .model file for both tokenization and vocab
     root : str, default '$MXNET_HOME/models' with MXNET_HOME defaults to '~/.mxnet'
         Location for keeping the model parameters.
     """
     def __init__(self, ctx=mx.cpu(), dtype='float32', model='bert_12_768_12',
                  dataset_name='book_corpus_wiki_en_uncased', params_path=None,
-                 max_seq_length=25, batch_size=256,
+                 max_seq_length=25, batch_size=256, sentencepiece=None,
                  root=os.path.join(get_home_dir(), 'models')):
         self.ctx = ctx
         self.dtype = dtype
         self.max_seq_length = max_seq_length
         self.batch_size = batch_size
         self.dataset_name = dataset_name
 
-        # Don't download the pretrained models if we have a parameter path
+        # use sentencepiece vocab and a checkpoint
+        # we need to set dataset_name to None, otherwise it uses the downloaded vocab
+        if params_path and sentencepiece:
+            dataset_name = None
+        else:
+            dataset_name = self.dataset_name
+        if sentencepiece:
+            vocab = gluonnlp.vocab.BERTVocab.from_sentencepiece(sentencepiece)
+        else:
+            vocab = None
         self.bert, self.vocab = gluonnlp.model.get_model(model,
-                                                         dataset_name=self.dataset_name,
+                                                         dataset_name=dataset_name,
                                                          pretrained=params_path is None,
                                                          ctx=self.ctx,
                                                          use_pooler=False,
                                                          use_decoder=False,
                                                          use_classifier=False,
-                                                         root=root)
-        self.bert.cast(self.dtype)
+                                                         root=root, vocab=vocab)
 
+        self.bert.cast(self.dtype)
         if params_path:
             logger.info('Loading params from %s', params_path)
-            self.bert.load_parameters(params_path, ctx=ctx, ignore_extra=True)
+            self.bert.load_parameters(params_path, ctx=ctx, ignore_extra=True, cast_dtype=True)
 
         lower = 'uncased' in self.dataset_name
-        self.tokenizer = BERTTokenizer(self.vocab, lower=lower)
+        if sentencepiece:
+            self.tokenizer = BERTSPTokenizer(sentencepiece, self.vocab, lower=lower)
+        else:
+            self.tokenizer = BERTTokenizer(self.vocab, lower=lower)
         self.transform = BERTSentenceTransform(tokenizer=self.tokenizer,
                                                max_seq_length=self.max_seq_length,
                                                pair=False)
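Not part of the diff: a minimal usage sketch of the new sentencepiece code path in the constructor. `BertEmbedding` is the class patched above; the checkpoint and `.model` paths are placeholders. When both `params_path` and `sentencepiece` are given, no pre-trained vocabulary or weights are downloaded.

```python
import mxnet as mx

# Hypothetical paths; BertEmbedding is assumed to be importable from this script.
bert = BertEmbedding(ctx=mx.cpu(),
                     model='bert_12_768_12',
                     params_path='/path/to/checkpoint.params',  # local weights, so nothing is downloaded
                     sentencepiece='/path/to/vocab.model',      # supplies both the tokenizer and the vocab
                     max_seq_length=128)

# The vocab now comes from the sentencepiece model and the tokenizer is a BERTSPTokenizer.
print(len(bert.vocab), type(bert.tokenizer).__name__)
```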
@@ -153,12 +154,9 @@ def oov(self, batches, oov_way='avg'):
 
         Parameters
         ----------
-        batches : List[(tokens_id,
-                        sequence_outputs,
-                        pooled_output].
-            batch token_ids (max_seq_length, ),
-            sequence_outputs (max_seq_length, dim, ),
-            pooled_output (dim, )
+        batches : List[(tokens_id, sequence_outputs)].
+            batch token_ids shape is (max_seq_length,),
+            sequence_outputs shape is (max_seq_length, dim)
         oov_way : str
             use **avg**, **sum** or **last** to get token embedding for those out of
             vocabulary words
@@ -169,21 +167,29 @@ def oov(self, batches, oov_way='avg'):
             List of tokens, and tokens embedding
         """
         sentences = []
+        padding_idx, cls_idx, sep_idx = None, None, None
+        if self.vocab.padding_token:
+            padding_idx = self.vocab[self.vocab.padding_token]
+        if self.vocab.cls_token:
+            cls_idx = self.vocab[self.vocab.cls_token]
+        if self.vocab.sep_token:
+            sep_idx = self.vocab[self.vocab.sep_token]
         for token_ids, sequence_outputs in batches:
             tokens = []
             tensors = []
             oov_len = 1
             for token_id, sequence_output in zip(token_ids, sequence_outputs):
-                if token_id == 1:
-                    # [PAD] token, sequence is finished.
+                # [PAD] token, sequence is finished.
+                if padding_idx and token_id == padding_idx:
                     break
-                if token_id in (2, 3):
-                    # [CLS], [SEP]
+                # [CLS], [SEP]
+                if cls_idx and token_id == cls_idx:
+                    continue
+                if sep_idx and token_id == sep_idx:
                     continue
                 token = self.vocab.idx_to_token[token_id]
-                if token.startswith('##'):
-                    token = token[2:]
-                    tokens[-1] += token
+                if not self.tokenizer.is_first_subword(token):
+                    tokens.append(token)
                     if oov_way == 'last':
                         tensors[-1] = sequence_output
                     else:
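A standalone illustration (not part of the script) of what the `oov_way` choices do once a word has been split into several subword pieces: the per-piece vectors handled in the loop above are merged back into a single vector per original token. The numbers are made up.

```python
import numpy as np

# Three subword pieces of one original token, each with a 2-dimensional embedding.
pieces = [np.array([1.0, 2.0]),
          np.array([3.0, 4.0]),
          np.array([5.0, 6.0])]

merged_avg = np.mean(pieces, axis=0)   # oov_way='avg'  -> [3. 4.]
merged_sum = np.sum(pieces, axis=0)    # oov_way='sum'  -> [9. 12.]
merged_last = pieces[-1]               # oov_way='last' -> [5. 6.]
print(merged_avg, merged_sum, merged_last)
```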
@@ -212,19 +218,21 @@ def oov(self, batches, oov_way='avg'):
     parser.add_argument('--model', type=str, default='bert_12_768_12',
                         help='pre-trained model')
     parser.add_argument('--dataset_name', type=str, default='book_corpus_wiki_en_uncased',
-                        help='dataset')
+                        help='name of the dataset used for pre-training')
     parser.add_argument('--params_path', type=str, default=None,
                         help='path to a params file to load instead of the pretrained model.')
-    parser.add_argument('--max_seq_length', type=int, default=25,
+    parser.add_argument('--sentencepiece', type=str, default=None,
+                        help='Path to the sentencepiece .model file for tokenization and vocab.')
+    parser.add_argument('--max_seq_length', type=int, default=128,
                         help='max length of each sequence')
     parser.add_argument('--batch_size', type=int, default=256,
                         help='batch size')
     parser.add_argument('--oov_way', type=str, default='avg',
-                        help='how to handle oov\n'
-                             'avg: average all oov embeddings to represent the original token\n'
-                             'sum: sum all oov embeddings to represent the original token\n'
-                             'last: use last oov embeddings to represent the original token\n')
-    parser.add_argument('--sentences', type=to_unicode, nargs='+', default=None,
+                        help='how to handle subword embeddings\n'
+                             'avg: average all subword embeddings to represent the original token\n'
+                             'sum: sum all subword embeddings to represent the original token\n'
+                             'last: use last subword embeddings to represent the original token\n')
+    parser.add_argument('--sentences', type=str, nargs='+', default=None,
                         help='sentence for encoding')
     parser.add_argument('--file', type=str, default=None,
                         help='file for encoding')
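Not part of the diff: a quick, self-contained check (with made-up argv values) that a parser declaring the new `--sentencepiece` option and the bumped `--max_seq_length` default behaves the way the updated script expects. Only the two touched options are reproduced here.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--sentencepiece', type=str, default=None,
                    help='Path to the sentencepiece .model file for tokenization and vocab.')
parser.add_argument('--max_seq_length', type=int, default=128,
                    help='max length of each sequence')

args = parser.parse_args(['--sentencepiece', '/path/to/vocab.model'])
assert args.sentencepiece == '/path/to/vocab.model'
assert args.max_seq_length == 128
```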
@@ -240,7 +248,8 @@ def oov(self, batches, oov_way='avg'):
     else:
         context = mx.cpu()
     bert_embedding = BertEmbedding(ctx=context, model=args.model, dataset_name=args.dataset_name,
-                                   max_seq_length=args.max_seq_length, batch_size=args.batch_size)
+                                   max_seq_length=args.max_seq_length, batch_size=args.batch_size,
+                                   params_path=args.params_path, sentencepiece=args.sentencepiece)
     result = []
     sents = []
     if args.sentences:
@@ -255,7 +264,7 @@ def oov(self, batches, oov_way='avg'):
         logger.error('Please specify --sentence or --file')
 
     if result:
-        for sent, embeddings in zip(sents, result):
-            print('Text: {}'.format(sent))
-            _, tokens_embedding = embeddings
+        for _, embeddings in zip(sents, result):
+            sent, tokens_embedding = embeddings
+            print('Text: {}'.format(' '.join(sent)))
             print('Tokens embedding: {}'.format(tokens_embedding))