This repository was archived by the owner on Jan 15, 2024. It is now read-only.

Commit ab41699

eric-haibin-lin authored and szha committed
[v0.8.x][BUGFIX] Update BERT embedding script (#1045)
* update embedding
* display tokens used in the batch, and remove hard coded values
* fix typo
* update embedding; display tokens used in the batch, and remove hard coded values; fix typo
* Update embedding.py
* fix get_model call
* Update embedding.py
* Update embedding.py
* fix download error in test
* Update test_scripts.py
1 parent b5ded8f commit ab41699

File tree: 4 files changed, +61 −49 lines


scripts/bert/embedding.py

Lines changed: 54 additions & 45 deletions
@@ -1,5 +1,3 @@
-# coding: utf-8
-
 # Licensed to the Apache Software Foundation (ASF) under one
 # or more contributor license agreements. See the NOTICE file
 # distributed with this work for additional information
@@ -29,25 +27,14 @@
 from mxnet.gluon.data import DataLoader
 
 import gluonnlp
-from gluonnlp.data import BERTTokenizer, BERTSentenceTransform
+from gluonnlp.data import BERTTokenizer, BERTSentenceTransform, BERTSPTokenizer
 from gluonnlp.base import get_home_dir
 
 try:
     from data.embedding import BertEmbeddingDataset
 except ImportError:
     from .data.embedding import BertEmbeddingDataset
 
-try:
-    unicode
-except NameError:
-    # Define `unicode` for Python3
-    def unicode(s, *_):
-        return s
-
-
-def to_unicode(s):
-    return unicode(s, 'utf-8')
-
 
 __all__ = ['BertEmbedding']
 
@@ -75,36 +62,50 @@ class BertEmbedding:
         max length of each sequence
     batch_size : int, default 256
         batch size
+    sentencepiece : str, default None
+        Path to the sentencepiece .model file for both tokenization and vocab
     root : str, default '$MXNET_HOME/models' with MXNET_HOME defaults to '~/.mxnet'
         Location for keeping the model parameters.
     """
     def __init__(self, ctx=mx.cpu(), dtype='float32', model='bert_12_768_12',
                  dataset_name='book_corpus_wiki_en_uncased', params_path=None,
-                 max_seq_length=25, batch_size=256,
+                 max_seq_length=25, batch_size=256, sentencepiece=None,
                  root=os.path.join(get_home_dir(), 'models')):
         self.ctx = ctx
         self.dtype = dtype
         self.max_seq_length = max_seq_length
         self.batch_size = batch_size
         self.dataset_name = dataset_name
 
-        # Don't download the pretrained models if we have a parameter path
+        # use sentencepiece vocab and a checkpoint
+        # we need to set dataset_name to None, otherwise it uses the downloaded vocab
+        if params_path and sentencepiece:
+            dataset_name = None
+        else:
+            dataset_name = self.dataset_name
+        if sentencepiece:
+            vocab = gluonnlp.vocab.BERTVocab.from_sentencepiece(sentencepiece)
+        else:
+            vocab = None
         self.bert, self.vocab = gluonnlp.model.get_model(model,
-                                                         dataset_name=self.dataset_name,
+                                                         dataset_name=dataset_name,
                                                          pretrained=params_path is None,
                                                          ctx=self.ctx,
                                                          use_pooler=False,
                                                          use_decoder=False,
                                                          use_classifier=False,
-                                                         root=root)
-        self.bert.cast(self.dtype)
+                                                         root=root, vocab=vocab)
 
+        self.bert.cast(self.dtype)
         if params_path:
             logger.info('Loading params from %s', params_path)
-            self.bert.load_parameters(params_path, ctx=ctx, ignore_extra=True)
+            self.bert.load_parameters(params_path, ctx=ctx, ignore_extra=True, cast_dtype=True)
 
         lower = 'uncased' in self.dataset_name
-        self.tokenizer = BERTTokenizer(self.vocab, lower=lower)
+        if sentencepiece:
+            self.tokenizer = BERTSPTokenizer(sentencepiece, self.vocab, lower=lower)
+        else:
+            self.tokenizer = BERTTokenizer(self.vocab, lower=lower)
         self.transform = BERTSentenceTransform(tokenizer=self.tokenizer,
                                                max_seq_length=self.max_seq_length,
                                                pair=False)
@@ -153,12 +154,9 @@ def oov(self, batches, oov_way='avg'):
 
         Parameters
         ----------
-        batches : List[(tokens_id,
-                        sequence_outputs,
-                        pooled_output].
-            batch token_ids (max_seq_length, ),
-            sequence_outputs (max_seq_length, dim, ),
-            pooled_output (dim, )
+        batches : List[(tokens_id, sequence_outputs)].
+            batch token_ids shape is (max_seq_length,),
+            sequence_outputs shape is (max_seq_length, dim)
         oov_way : str
             use **avg**, **sum** or **last** to get token embedding for those out of
            vocabulary words
@@ -169,21 +167,29 @@ def oov(self, batches, oov_way='avg'):
            List of tokens, and tokens embedding
        """
        sentences = []
+        padding_idx, cls_idx, sep_idx = None, None, None
+        if self.vocab.padding_token:
+            padding_idx = self.vocab[self.vocab.padding_token]
+        if self.vocab.cls_token:
+            cls_idx = self.vocab[self.vocab.cls_token]
+        if self.vocab.sep_token:
+            sep_idx = self.vocab[self.vocab.sep_token]
        for token_ids, sequence_outputs in batches:
            tokens = []
            tensors = []
            oov_len = 1
            for token_id, sequence_output in zip(token_ids, sequence_outputs):
-                if token_id == 1:
-                    # [PAD] token, sequence is finished.
+                # [PAD] token, sequence is finished.
+                if padding_idx and token_id == padding_idx:
                    break
-                if token_id in (2, 3):
-                    # [CLS], [SEP]
+                # [CLS], [SEP]
+                if cls_idx and token_id == cls_idx:
+                    continue
+                if sep_idx and token_id == sep_idx:
                    continue
                token = self.vocab.idx_to_token[token_id]
-                if token.startswith('##'):
-                    token = token[2:]
-                    tokens[-1] += token
+                if not self.tokenizer.is_first_subword(token):
+                    tokens.append(token)
                    if oov_way == 'last':
                        tensors[-1] = sequence_output
                    else:
@@ -212,19 +218,21 @@ def oov(self, batches, oov_way='avg'):
    parser.add_argument('--model', type=str, default='bert_12_768_12',
                        help='pre-trained model')
    parser.add_argument('--dataset_name', type=str, default='book_corpus_wiki_en_uncased',
-                        help='dataset')
+                        help='name of the dataset used for pre-training')
    parser.add_argument('--params_path', type=str, default=None,
                        help='path to a params file to load instead of the pretrained model.')
-    parser.add_argument('--max_seq_length', type=int, default=25,
+    parser.add_argument('--sentencepiece', type=str, default=None,
+                        help='Path to the sentencepiece .model file for tokenization and vocab.')
+    parser.add_argument('--max_seq_length', type=int, default=128,
                        help='max length of each sequence')
    parser.add_argument('--batch_size', type=int, default=256,
                        help='batch size')
    parser.add_argument('--oov_way', type=str, default='avg',
-                        help='how to handle oov\n'
-                             'avg: average all oov embeddings to represent the original token\n'
-                             'sum: sum all oov embeddings to represent the original token\n'
-                             'last: use last oov embeddings to represent the original token\n')
-    parser.add_argument('--sentences', type=to_unicode, nargs='+', default=None,
+                        help='how to handle subword embeddings\n'
+                             'avg: average all subword embeddings to represent the original token\n'
+                             'sum: sum all subword embeddings to represent the original token\n'
+                             'last: use last subword embeddings to represent the original token\n')
+    parser.add_argument('--sentences', type=str, nargs='+', default=None,
                        help='sentence for encoding')
    parser.add_argument('--file', type=str, default=None,
                        help='file for encoding')
@@ -240,7 +248,8 @@ def oov(self, batches, oov_way='avg'):
    else:
        context = mx.cpu()
    bert_embedding = BertEmbedding(ctx=context, model=args.model, dataset_name=args.dataset_name,
-                                   max_seq_length=args.max_seq_length, batch_size=args.batch_size)
+                                   max_seq_length=args.max_seq_length, batch_size=args.batch_size,
+                                   params_path=args.params_path, sentencepiece=args.sentencepiece)
    result = []
    sents = []
    if args.sentences:
@@ -255,7 +264,7 @@ def oov(self, batches, oov_way='avg'):
        logger.error('Please specify --sentence or --file')
 
    if result:
-        for sent, embeddings in zip(sents, result):
-            print('Text: {}'.format(sent))
-            _, tokens_embedding = embeddings
+        for _, embeddings in zip(sents, result):
+            sent, tokens_embedding = embeddings
+            print('Text: {}'.format(' '.join(sent)))
        print('Tokens embedding: {}'.format(tokens_embedding))
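
For reference, a minimal usage sketch (not part of this commit) of the updated BertEmbedding constructor, with and without the new sentencepiece argument. It assumes scripts/bert is on the import path; the checkpoint and .model paths are hypothetical placeholders.

# Minimal sketch; paths below are hypothetical placeholders, not from the repo.
import mxnet as mx
from embedding import BertEmbedding  # assumes scripts/bert is on PYTHONPATH

# Default: download the pretrained weights and vocab for the named dataset.
embedder = BertEmbedding(ctx=mx.cpu(), model='bert_12_768_12',
                         dataset_name='book_corpus_wiki_en_uncased',
                         max_seq_length=128, batch_size=256)

# New in this change: load a local checkpoint together with a SentencePiece
# vocab; the vocab then comes from the .model file instead of being downloaded.
embedder_sp = BertEmbedding(ctx=mx.cpu(), model='bert_12_768_12',
                            params_path='my_checkpoint.params',  # hypothetical
                            sentencepiece='my_vocab.model',      # hypothetical
                            max_seq_length=128, batch_size=256)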

scripts/bert/finetune_classifier.py

Lines changed: 1 addition & 1 deletion
@@ -219,7 +219,7 @@
     except ImportError:
         # amp is not available
         logging.info('Mixed precision training with float16 requires MXNet >= '
-                     '1.5.0b20190627. Please consider upgrading your MXNet version.')
+                     '1.5.1. Please consider upgrading your MXNet version.')
         exit()
 
 # model and loss

scripts/bert/index.rst

Lines changed: 2 additions & 2 deletions
@@ -278,8 +278,8 @@ The goal of this BERT Embedding is to obtain the token embedding from BERT's pre
 
 .. code-block:: shell
 
-    python bert/embedding.py --sentences "GluonNLP is a toolkit that enables easy text preprocessing, datasets loading and neural models building to help you speed up your Natural Language Processing (NLP) research."
-    Text: GluonNLP is a toolkit that enables easy text preprocessing, datasets loading and neural models building to help you speed up your Natural Language Processing (NLP) research.
+    python embedding.py --sentences "GluonNLP is a toolkit that enables easy text preprocessing, datasets loading and neural models building to help you speed up your Natural Language Processing (NLP) research."
+    Text: g ##lu ##on ##nl ##p is a tool ##kit that enables easy text prep ##ro ##ces ##sing , data ##set ##s loading and neural models building to help you speed up your natural language processing ( nl ##p ) research .
     Tokens embedding: [array([-0.11881411, -0.59530115, 0.627092 , ..., 0.00648153,
        -0.03886228, 0.03406909], dtype=float32), array([-0.7995638 , -0.6540758 , -0.00521846, ..., -0.42272145,
        -0.5787281 , 0.7021201 ], dtype=float32), array([-0.7406778 , -0.80276626, 0.3931962 , ..., -0.49068323,
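
As context for the --oov_way help text updated above, here is a small standalone sketch (not from the repository) of what averaging, summing, or keeping the last subword embedding means for one original token; the vectors are made up for illustration.

import numpy as np

# Hypothetical embeddings for the subwords of one word, e.g. 'g', '##lu', '##on'.
subword_vectors = [np.array([0.1, 0.2]), np.array([0.3, 0.0]), np.array([0.5, 0.4])]

def combine(vectors, oov_way='avg'):
    """Collapse subword embeddings into one token embedding, mirroring --oov_way."""
    if oov_way == 'avg':
        return np.mean(vectors, axis=0)   # element-wise mean over subwords
    if oov_way == 'sum':
        return np.sum(vectors, axis=0)    # element-wise sum over subwords
    if oov_way == 'last':
        return vectors[-1]                # keep only the last subword's vector
    raise ValueError('oov_way must be avg, sum or last')

print(combine(subword_vectors, 'avg'))   # [0.3 0.2]
print(combine(subword_vectors, 'last'))  # [0.5 0.4]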

scripts/tests/test_scripts.py

Lines changed: 4 additions & 1 deletion
@@ -25,6 +25,7 @@
 
 import pytest
 import mxnet as mx
+import gluonnlp as nlp
 
 @pytest.mark.serial
 @pytest.mark.remote_required
@@ -200,8 +201,10 @@ def test_bert_embedding(use_pretrained):
     if use_pretrained:
         args.extend(['--dtype', 'float32'])
     else:
+        _, _ = nlp.model.get_model('bert_12_768_12', dataset_name='book_corpus_wiki_en_uncased',
+                                   pretrained=True, root='test_bert_embedding')
         args.extend(['--params_path',
-                     '~/.mxnet/models/bert_12_768_12_book_corpus_wiki_en_uncased-75cc780f.params'])
+                     'test_bert_embedding/bert_12_768_12_book_corpus_wiki_en_uncased-75cc780f.params'])
     process = subprocess.check_call([sys.executable, './scripts/bert/embedding.py'] + args)
     time.sleep(5)

0 commit comments

Comments
 (0)