Adding BERT for MS-MARCO Passage re-ranking #277

Draft: wants to merge 12 commits into base: master

Changes from all commits
8 changes: 8 additions & 0 deletions texar/torch/data/tokenizers/bert_tokenizer.py
@@ -74,6 +74,10 @@ class BERTTokenizer(PretrainedBERTMixin, TokenizerBase):
'scibert-scivocab-cased': 512,
'scibert-basevocab-uncased': 512,
'scibert-basevocab-cased': 512,

# BERT for MS-MARCO
'bert-msmarco-nogueira19-base': 512,
'bert-msmarco-nogueira19-large': 512,
}
_VOCAB_FILE_NAMES = {'vocab_file': 'vocab.txt'}
_VOCAB_FILE_MAP = {
@@ -98,6 +102,10 @@ class BERTTokenizer(PretrainedBERTMixin, TokenizerBase):
'scibert-scivocab-cased': 'vocab.txt',
'scibert-basevocab-uncased': 'vocab.txt',
'scibert-basevocab-cased': 'vocab.txt',

# BERT for MS-MARCO
'bert-msmarco-nogueira19-base': 'vocab.txt',
'bert-msmarco-nogueira19-large': 'vocab.txt',
}
}

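For orientation, a short usage sketch (not part of the diff): once these entries are in place, the new names should work like any other pre-trained tokenizer name. map_text_to_id is the existing TokenizerBase helper, and the checkpoint name is the one registered above, so treat the snippet as illustrative rather than tested.

from texar.torch.data import BERTTokenizer

# Illustrative only: relies on the entry added above; vocab.txt is fetched from the
# Google Drive location registered for this model name, like any other BERT variant.
tokenizer = BERTTokenizer(pretrained_model_name='bert-msmarco-nogueira19-base')
token_ids = tokenizer.map_text_to_id("what is a passage re-ranker?")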
7 changes: 6 additions & 1 deletion texar/torch/modules/classifiers/bert_classifier.py
@@ -71,12 +71,15 @@ def __init__(self,

super().__init__(hparams=hparams)

self.load_pretrained_config(pretrained_model_name, cache_dir)
@gpengzhi (Collaborator) commented on Dec 20, 2019:
Will load_pretrained_config and init_pretrained_weights be called twice (once in BERTClassifier, and once in BERTEncoder)?

@gpengzhi (Collaborator) commented on Dec 20, 2019:
If that is the case, we probably should not load the pre-trained weights in self._encoder (BERTEncoder).

Author (Collaborator) replied:
Discussed offline. We can pass pretrained_model_name as None while instantiating the encoder in BERTClassifier.

@gpengzhi (Collaborator) replied:
If you set pretrained_model_name (the constructor argument) and pretrained_model_name in hparams to None, BERTEncoder won't load the pre-trained weights.

Author (Collaborator) replied:
Made both the changes.


# Create the underlying encoder
encoder_hparams = dict_fetch(hparams,
self._ENCODER_CLASS.default_hparams())
encoder_hparams['pretrained_model_name'] = None

self._encoder = self._ENCODER_CLASS(
- pretrained_model_name=pretrained_model_name,
+ pretrained_model_name=None,
cache_dir=cache_dir,
hparams=encoder_hparams)

@@ -120,6 +123,8 @@ def __init__(self,
(self.num_classes <= 0 and
self._hparams.encoder.dim == 1)

self.init_pretrained_weights(class_type='classifier')

@staticmethod
def default_hparams():
r"""Returns a dictionary of hyperparameters with default values.
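To make the effect of this change concrete, here is a hedged usage sketch (not part of the PR): with pre-trained loading moved to the classifier level and class_type='classifier', the MS-MARCO checkpoint's output_weights/output_bias should land in _logits_layer, so the classifier can score a query-passage pair directly. The checkpoint name comes from this PR; encode_text and the (logits, preds) forward signature are the existing Texar-PyTorch APIs, while the two-class head and the "index 1 = relevant" reading are assumptions about the fine-tuned model.

import torch
from texar.torch.data import BERTTokenizer
from texar.torch.modules import BERTClassifier

# Hedged sketch: score one (query, passage) pair with the fine-tuned re-ranker.
tokenizer = BERTTokenizer(pretrained_model_name='bert-msmarco-nogueira19-base')
model = BERTClassifier(pretrained_model_name='bert-msmarco-nogueira19-base')
model.eval()

input_ids, segment_ids, input_mask = tokenizer.encode_text(
    text_a="what causes tides?",
    text_b="Tides are caused by the gravitational pull of the moon and the sun.",
    max_seq_length=128)

with torch.no_grad():
    logits, _ = model(inputs=torch.tensor([input_ids]),
                      sequence_length=torch.tensor([sum(input_mask)]),
                      segment_ids=torch.tensor([segment_ids]))

# Assumption: the checkpoint uses a two-class head where index 1 means "relevant".
relevance = torch.softmax(logits, dim=-1)[0, 1].item()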
54 changes: 49 additions & 5 deletions texar/torch/modules/pretrained/bert.py
@@ -34,6 +34,7 @@
_BIOBERT_PATH = "https://github.com/naver/biobert-pretrained/releases/download/"
_SCIBERT_PATH = "https://s3-us-west-2.amazonaws.com/ai2-s2-research/" \
"scibert/tensorflow_models/"
_BERT_MSMARCO_NOGUEIRA19_PATH = "https://drive.google.com/file/d/"


class PretrainedBERTMixin(PretrainedMixin, ABC):
@@ -97,6 +98,16 @@ class PretrainedBERTMixin(PretrainedMixin, ABC):
* ``scibert-basevocab-cased``: Cased version of the model trained on
the original BERT vocabulary.

* **BERT for MS-MARCO**: proposed in (`Nogueira et al`. 2019)
`Passage Re-ranking with BERT`_. A BERT model fine-tuned on the
MS-MARCO (`Nguyen et al`., 2016) dataset. It was the best-performing
model on the MS-MARCO Passage Re-ranking task as of Jan 8th, 2019.
Two models are included:

* ``bert-msmarco-nogueira19-base``: Original BERT base model fine-tuned
on MS-MARCO.
* ``bert-msmarco-nogueira19-large``: Original BERT large model
fine-tuned on MS-MARCO.

We provide the following BERT classes:

* :class:`~texar.torch.modules.BERTEncoder` for text encoding.
@@ -111,6 +122,9 @@ class PretrainedBERTMixin(PretrainedMixin, ABC):

.. _`SciBERT: A Pretrained Language Model for Scientific Text`:
https://arxiv.org/abs/1903.10676

.. _`Passage Re-ranking with BERT`:
https://arxiv.org/abs/1901.04085
"""

_MODEL_NAME = "BERT"
@@ -150,6 +164,12 @@ class PretrainedBERTMixin(PretrainedMixin, ABC):
_SCIBERT_PATH + 'scibert_basevocab_uncased.tar.gz',
'scibert-basevocab-cased':
_SCIBERT_PATH + 'scibert_basevocab_cased.tar.gz',

# BERT for MS-MARCO
'bert-msmarco-nogueira19-base':
_BERT_MSMARCO_NOGUEIRA19_PATH + '1cyUrhs7JaCJTTu-DjFUqP6Bs4f8a6JTX',
'bert-msmarco-nogueira19-large':
_BERT_MSMARCO_NOGUEIRA19_PATH + '1crlASTMlsihALlkabAQP6JTYIZwC1Wm8/'
}
_MODEL2CKPT = {
# Standard BERT
@@ -172,6 +192,10 @@ class PretrainedBERTMixin(PretrainedMixin, ABC):
'scibert-scivocab-cased': 'bert_model.ckpt',
'scibert-basevocab-uncased': 'bert_model.ckpt',
'scibert-basevocab-cased': 'bert_model.ckpt',

# BERT for MS-MARCO
'bert-msmarco-nogueira19-base': 'model.ckpt-100000',
'bert-msmarco-nogueira19-large': 'model.ckpt-100000',
}

@classmethod
@@ -301,11 +325,21 @@ def _init_from_checkpoint(self, pretrained_model_name: str,
}
pooler_map = {
'bert/pooler/dense/bias': 'pooler.0.bias',
- 'bert/pooler/dense/kernel': 'pooler.0.weight'
+ 'bert/pooler/dense/kernel': 'pooler.0.weight',
}
classifier_map = {
'output_bias': '_logits_layer.bias',
'output_weights': '_logits_layer.weight',
}
global_prefix_map = {
'classifier': '_encoder.'
}
tf_path = os.path.abspath(os.path.join(
cache_dir, self._MODEL2CKPT[pretrained_model_name]))

class_type = kwargs.get('class_type', 'encoder')
global_prefix = global_prefix_map.get(class_type, '')

# Load weights from TF model
init_vars = tf.train.list_variables(tf_path)
tfnames, arrays = [], []
@@ -325,13 +359,14 @@ def _init_from_checkpoint(self, pretrained_model_name: str,
continue

if name in global_tensor_map:
- v_name = global_tensor_map[name]
+ v_name = global_prefix + global_tensor_map[name]
pointer = self._name_to_variable(v_name)
assert pointer.shape == array.shape
pointer.data = torch.from_numpy(array)
idx += 1
elif name in pooler_map:
- pointer = self._name_to_variable(pooler_map[name])
+ pointer = self._name_to_variable(global_prefix +
+ pooler_map[name])
if name.endswith('bias'):
assert pointer.shape == array.shape
pointer.data = torch.from_numpy(array)
@@ -341,6 +376,13 @@ def _init_from_checkpoint(self, pretrained_model_name: str,
assert pointer.shape == array_t.shape
pointer.data = torch.from_numpy(array_t)
idx += 1
elif name in classifier_map:
if class_type != 'classifier':
continue
pointer = self._name_to_variable(classifier_map[name])
assert pointer.shape == array.shape
pointer.data = torch.from_numpy(array)
idx += 1
else:
# here name is the TensorFlow variable name
name_tmp = name.split("/")
@@ -349,12 +391,14 @@ def _init_from_checkpoint(self, pretrained_model_name: str,
name_tmp = "/".join(name_tmp[3:])
if name_tmp in layer_tensor_map:
v_name = layer_tensor_map[name_tmp].format(layer_no)
- pointer = self._name_to_variable(py_prefix + v_name)
+ pointer = self._name_to_variable(global_prefix +
+ py_prefix + v_name)
assert pointer.shape == array.shape
pointer.data = torch.from_numpy(array)
elif name_tmp in layer_transpose_map:
v_name = layer_transpose_map[name_tmp].format(layer_no)
- pointer = self._name_to_variable(py_prefix + v_name)
+ pointer = self._name_to_variable(global_prefix +
+ py_prefix + v_name)
array_t = np.transpose(array)
assert pointer.shape == array_t.shape
pointer.data = torch.from_numpy(array_t)
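As a reading aid for the checkpoint-loading change above, here is a minimal standalone sketch of the name-mapping rule it implements: with class_type='classifier', encoder-side targets gain an '_encoder.' prefix, while the checkpoint's classifier variables map onto the classifier's own _logits_layer and are skipped when loading a bare encoder. The dictionaries are copied from the diff; target_name is a made-up helper for illustration only.

# Standalone illustration (not library code) of the mapping rule used in
# _init_from_checkpoint. Only `target_name` is invented for this sketch.
pooler_map = {
    'bert/pooler/dense/bias': 'pooler.0.bias',
    'bert/pooler/dense/kernel': 'pooler.0.weight',
}
classifier_map = {
    'output_bias': '_logits_layer.bias',
    'output_weights': '_logits_layer.weight',
}
global_prefix_map = {
    'classifier': '_encoder.',
}


def target_name(tf_name: str, class_type: str = 'encoder'):
    """Map a TF checkpoint variable name to the PyTorch parameter it loads into."""
    prefix = global_prefix_map.get(class_type, '')
    if tf_name in classifier_map:
        # Classifier-only weights live on the classifier itself (no prefix) and
        # are skipped entirely when loading a bare encoder.
        return classifier_map[tf_name] if class_type == 'classifier' else None
    if tf_name in pooler_map:
        # Encoder weights move under the classifier's `_encoder.` submodule.
        return prefix + pooler_map[tf_name]
    return None


assert target_name('bert/pooler/dense/bias') == 'pooler.0.bias'
assert target_name('bert/pooler/dense/bias', 'classifier') == '_encoder.pooler.0.bias'
assert target_name('output_weights', 'classifier') == '_logits_layer.weight'
assert target_name('output_weights') is None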