forked from deeppavlov/DeepPavlov
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: multilingual and english BERT-based models for classification a…
…nd SQuAD (deeppavlov#742) * feat: add multi_squad_retr dataset * feat: add config file * feat: add multi_squad_ru_retr dataset * fix: paths and squad datareaders * fix: squad preprocess * feat: add context from squad rate sampling ratio * fix: dataset reader * fix: squad rate attr in squad_iterator * feat: add link to trained noans_ru model * feat: create odqa retr noans config, update top n for ru tfidf ranker * fix: upd metrics docs * feat: add bert for classification tasks * fix: remove squad model from registry * fix: add bert preprocessor * fix: upd names in configs * fix: set ru tfidf ranker top_n to 5 * fix: set paragraphs false in odqa * fix: multilabel classification with bert * fix: bert dropout config names * feat: add AdamWeighDecay optimizer for bert * config for rusentiment bert 1m ckpt * chore: bert configs * feat: add bert model for squad * fix: squad bert config * feat: add loading pretrained bert for squad * feat: add multilingual bert for sbersquad * feat: add url arg to squad dataset reader * fix: bert answer postprocessor * feat&fix: add bert squad infer model and config fixes * chore: add squad bert infer config * feat: add batching to bert inference model * feat: add squad with rubert and odqa configs * fix: squad rubert config * chore: add ruodqa evaluation config * fix: squad iterator empty batch * chore: add one more bert based squad config * feat: add noans score and model trained on retr dump * chore: add inference config for bert on noans * feat: add bert_ner (with bugs for now) * fix: attribute assign in classifier * fix: Adam inited twice * fix: ner bert preprocessor * refactor&docs: add docs for bert preprocessor * feat: add bert_ner for inference * docs: add bert preprocessor call method docs * docs: bert squad model * feat: add parametrization of taken encoder layers * fix: rubert download link * docs: train_on_batch and call for bert squad model * docs: add docstring for BertSQuADInferModel * fix: sigmoid -> softmax output * docs: upd docstring * fix: list -> tuple default value * docs: add docstrings for bert classifier * chore: rm configs * feat: paraphraser config, model and scores * feat: rusentiment with multi-lingual bert * feat: rusentiment with multi-lingual bert * feat: insults kaggle on English BERT * chore: remove configs * docs: add ru squad results * feat: add bert models to tests * docs: upd docs * feat: upd configs and tests * feat: rm bert configs * refactor: rm debug logging * fix: paraphraser dataset reader * docs: update bert_ner docstring * fix: restore tests for squad noans model * fix: paraphraser tests * fix: add bert_dp requirement to squad_bert config * fix: download path for model files * fix: rusentiment tests * fix: import order * fix: remove extra tabulation * refactor: rename squad metrics * refactor: url argument in squad_dataset_reader has higher priority than dataset * refactor: remove process_event method from bert models * feat: add bert-base model for EN SQuAD * feat: add bert dstc2 model * fix: squad bert infer config * docs: add docs for NerPreprocessor * feat: give higher priority to url arg in multi_squad_dataset_reader * fix: bert NER preprocessor
- Loading branch information
1 parent
b57151e
commit 174dec6
Showing
35 changed files
with
2,788 additions
and
49 deletions.
There are no files selected for viewing
144 changes: 144 additions & 0 deletions
144
deeppavlov/configs/classifiers/insults_kaggle_bert.json
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,144 @@ | ||
{ | ||
"dataset_reader": { | ||
"class_name": "basic_classification_reader", | ||
"x": "Comment", | ||
"y": "Class", | ||
"data_path": "{DOWNLOADS_PATH}/insults_data" | ||
}, | ||
"dataset_iterator": { | ||
"class_name": "basic_classification_iterator", | ||
"seed": 42 | ||
}, | ||
"chainer": { | ||
"in": [ | ||
"x" | ||
], | ||
"in_y": [ | ||
"y" | ||
], | ||
"pipe": [ | ||
{ | ||
"class_name": "bert_preprocessor", | ||
"vocab_file": "{DOWNLOADS_PATH}/bert_models/cased_L-12_H-768_A-12/vocab.txt", | ||
"do_lower_case": false, | ||
"max_seq_length": 64, | ||
"in": [ | ||
"x" | ||
], | ||
"out": [ | ||
"bert_features" | ||
] | ||
}, | ||
{ | ||
"id": "classes_vocab", | ||
"class_name": "simple_vocab", | ||
"fit_on": [ | ||
"y" | ||
], | ||
"save_path": "{MODELS_PATH}/classes.dict", | ||
"load_path": "{MODELS_PATH}/classes.dict", | ||
"in": "y", | ||
"out": "y_ids" | ||
}, | ||
{ | ||
"in": "y_ids", | ||
"out": "y_onehot", | ||
"class_name": "one_hotter", | ||
"depth": "#classes_vocab.len", | ||
"single_vector": true | ||
}, | ||
{ | ||
"class_name": "bert_classifier", | ||
"n_classes": "#classes_vocab.len", | ||
"return_probas": true, | ||
"one_hot_labels": true, | ||
"bert_config_file": "{DOWNLOADS_PATH}/bert_models/cased_L-12_H-768_A-12/bert_config.json", | ||
"pretrained_bert": "{DOWNLOADS_PATH}/bert_models/cased_L-12_H-768_A-12/bert_model.ckpt", | ||
"save_path": "{MODELS_PATH}/model", | ||
"load_path": "{MODELS_PATH}/model", | ||
"keep_prob": 0.5, | ||
"learning_rate": 1e-05, | ||
"learning_rate_drop_patience": 5, | ||
"learning_rate_drop_div": 2.0, | ||
"in": [ | ||
"bert_features" | ||
], | ||
"in_y": [ | ||
"y_onehot" | ||
], | ||
"out": [ | ||
"y_pred_probas" | ||
] | ||
}, | ||
{ | ||
"in": "y_pred_probas", | ||
"out": "y_pred_ids", | ||
"class_name": "proba2labels", | ||
"max_proba": true | ||
}, | ||
{ | ||
"in": "y_pred_ids", | ||
"out": "y_pred_labels", | ||
"ref": "classes_vocab" | ||
} | ||
], | ||
"out": [ | ||
"y_pred_labels" | ||
] | ||
}, | ||
"train": { | ||
"epochs": 100, | ||
"batch_size": 64, | ||
"metrics": [ | ||
{ | ||
"name": "roc_auc", | ||
"inputs": [ | ||
"y_onehot", | ||
"y_pred_probas" | ||
] | ||
}, | ||
"sets_accuracy", | ||
"f1_macro" | ||
], | ||
"validation_patience": 5, | ||
"val_every_n_epochs": 1, | ||
"log_every_n_epochs": 1, | ||
"show_examples": false, | ||
"evaluation_targets": [ | ||
"train", | ||
"valid", | ||
"test" | ||
], | ||
"class_name": "nn_trainer", | ||
"tensorboard_log_dir": "{MODELS_PATH}/" | ||
}, | ||
"metadata": { | ||
"variables": { | ||
"ROOT_PATH": "~/.deeppavlov", | ||
"DOWNLOADS_PATH": "{ROOT_PATH}/downloads", | ||
"MODELS_PATH": "{ROOT_PATH}/models/classifiers/insults_kaggle_v3" | ||
}, | ||
"requirements": [ | ||
"{DEEPPAVLOV_PATH}/requirements/tf.txt", | ||
"{DEEPPAVLOV_PATH}/requirements/bert_dp.txt" | ||
], | ||
"labels": { | ||
"telegram_utils": "IntentModel", | ||
"server_utils": "KerasIntentModel" | ||
}, | ||
"download": [ | ||
{ | ||
"url": "http://files.deeppavlov.ai/datasets/insults_data.tar.gz", | ||
"subdir": "{DOWNLOADS_PATH}" | ||
}, | ||
{ | ||
"url": "http://files.deeppavlov.ai/deeppavlov_data/bert/cased_L-12_H-768_A-12.zip", | ||
"subdir": "{DOWNLOADS_PATH}/bert_models" | ||
}, | ||
{ | ||
"url": "http://files.deeppavlov.ai/deeppavlov_data/classifiers/insults_kaggle_v3.tar.gz", | ||
"subdir": "{ROOT_PATH}/models/classifiers" | ||
} | ||
] | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
{ | ||
"dataset_reader": { | ||
"class_name": "dstc2_reader", | ||
"data_path": "{DOWNLOADS_PATH}/dstc2" | ||
}, | ||
"dataset_iterator": { | ||
"class_name": "dstc2_intents_iterator", | ||
"seed": 42 | ||
}, | ||
"chainer": { | ||
"in": ["x"], | ||
"in_y": ["y"], | ||
"pipe": [ | ||
{ | ||
"id": "classes_vocab", | ||
"class_name": "simple_vocab", | ||
"fit_on": ["y"], | ||
"save_path": "{MODELS_PATH}/classes.dict", | ||
"load_path": "{MODELS_PATH}/classes.dict", | ||
"in": "y", | ||
"out": "y_ids", | ||
"special_tokens": ["<UNK>"] | ||
}, | ||
{ | ||
"class_name": "bert_preprocessor", | ||
"vocab_file": "{DOWNLOADS_PATH}/bert_models/cased_L-12_H-768_A-12/vocab.txt", | ||
"do_lower_case": false, | ||
"max_seq_length": 64, | ||
"in": ["x"], | ||
"out": ["bert_features"] | ||
}, | ||
{ | ||
"in": "y_ids", | ||
"out": "y_onehot", | ||
"class_name": "one_hotter", | ||
"id": "my_one_hotter", | ||
"depth": "#classes_vocab.len", | ||
"single_vector": true | ||
}, | ||
{ | ||
"class_name": "bert_classifier", | ||
"n_classes": "#classes_vocab.len", | ||
"return_probas": true, | ||
"one_hot_labels": true, | ||
"multilabel": true, | ||
"bert_config_file": "{DOWNLOADS_PATH}/bert_models/cased_L-12_H-768_A-12/bert_config.json", | ||
"pretrained_bert": "{DOWNLOADS_PATH}/bert_models/cased_L-12_H-768_A-12/bert_model.ckpt", | ||
"save_path": "{MODELS_PATH}/model", | ||
"load_path": "{MODELS_PATH}/model", | ||
"keep_prob": 0.5, | ||
"learning_rate": 2e-05, | ||
"learning_rate_drop_patience": 3, | ||
"learning_rate_drop_div": 2.0, | ||
"in": ["bert_features"], | ||
"in_y": ["y_onehot"], | ||
"out": ["y_pred_probas"] | ||
}, | ||
{ | ||
"in": "y_pred_probas", | ||
"out": "y_pred_ids", | ||
"class_name": "proba2labels", | ||
"confident_threshold": 0.5 | ||
}, | ||
{ | ||
"in": "y_pred_ids", | ||
"out": "y_pred_labels", | ||
"ref": "classes_vocab" | ||
}, | ||
{ | ||
"ref": "my_one_hotter", | ||
"in": "y_pred_ids", | ||
"out": "y_pred_onehot" | ||
} | ||
], | ||
"out": ["y_pred_probas", "y_pred_labels"] | ||
}, | ||
"train": { | ||
"metrics": [ | ||
{ | ||
"name": "sets_accuracy", | ||
"inputs": ["y", "y_pred_labels"] | ||
}, | ||
{ | ||
"name": "roc_auc", | ||
"inputs": ["y_onehot", "y_pred_probas"] | ||
} | ||
], | ||
"show_examples": false, | ||
"batch_size": 32, | ||
"pytest_max_batches": 2, | ||
"validation_patience": 10, | ||
"val_every_n_batches": 100, | ||
"log_every_n_batches": 100, | ||
"validate_best": true, | ||
"test_best": true, | ||
"tensorboard_log_dir": "{MODELS_PATH}/logs" | ||
}, | ||
"metadata": { | ||
"variables": { | ||
"ROOT_PATH": "~/.deeppavlov", | ||
"DOWNLOADS_PATH": "{ROOT_PATH}/downloads", | ||
"MODELS_PATH": "{ROOT_PATH}/models/classifiers/intents_dstc2_bert_v0" | ||
}, | ||
"requirements": [ | ||
"{DEEPPAVLOV_PATH}/requirements/tf.txt", | ||
"{DEEPPAVLOV_PATH}/requirements/bert_dp.txt" | ||
], | ||
"labels": { | ||
"telegram_utils": "IntentModel", | ||
"server_utils": "KerasIntentModel" | ||
}, | ||
"download": [ | ||
{ | ||
"url": "http://files.deeppavlov.ai/deeppavlov_data/bert/cased_L-12_H-768_A-12.zip", | ||
"subdir": "{DOWNLOADS_PATH}/bert_models" | ||
}, | ||
{ | ||
"url": "http://files.deeppavlov.ai/datasets/dstc2_v2.tar.gz", | ||
"subdir": "{DOWNLOADS_PATH}/dstc2" | ||
}, | ||
{ | ||
"url": "http://files.deeppavlov.ai/deeppavlov_data/classifiers/intents_dstc2_bert_v0.tar.gz", | ||
"subdir": "{ROOT_PATH}/models/classifiers" | ||
} | ||
|
||
] | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
{ | ||
"dataset_reader": { | ||
"class_name": "paraphraser_reader", | ||
"data_path": "{DOWNLOADS_PATH}/paraphraser_data", | ||
"do_lower_case": false | ||
}, | ||
"dataset_iterator": { | ||
"class_name": "siamese_iterator", | ||
"seed": 243, | ||
"len_valid": 500 | ||
}, | ||
"chainer": { | ||
"in": [ | ||
"text_a", | ||
"text_b" | ||
], | ||
"in_y": [ | ||
"y" | ||
], | ||
"pipe": [ | ||
{ | ||
"class_name": "bert_preprocessor", | ||
"vocab_file": "{DOWNLOADS_PATH}/bert_models/multi_cased_L-12_H-768_A-12/vocab.txt", | ||
"do_lower_case": false, | ||
"max_seq_length": 64, | ||
"in": [ | ||
"text_a", | ||
"text_b" | ||
], | ||
"out": [ | ||
"bert_features" | ||
] | ||
}, | ||
{ | ||
"class_name": "bert_classifier", | ||
"n_classes": 2, | ||
"one_hot_labels": false, | ||
"bert_config_file": "{DOWNLOADS_PATH}/bert_models/multi_cased_L-12_H-768_A-12/bert_config.json", | ||
"pretrained_bert": "{DOWNLOADS_PATH}/bert_models/multi_cased_L-12_H-768_A-12/bert_model.ckpt", | ||
"save_path": "{MODELS_PATH}/model_multi", | ||
"load_path": "{MODELS_PATH}/model_multi", | ||
"keep_prob": 0.5, | ||
"learning_rate": 2e-05, | ||
"learning_rate_drop_patience": 2, | ||
"learning_rate_drop_div": 2.0, | ||
"in": [ | ||
"bert_features" | ||
], | ||
"in_y": [ | ||
"y" | ||
], | ||
"out": [ | ||
"predictions" | ||
] | ||
} | ||
], | ||
"out": [ | ||
"predictions" | ||
] | ||
}, | ||
"train": { | ||
"batch_size": 32, | ||
"pytest_max_batches": 2, | ||
"metrics": [ | ||
"f1", | ||
"acc" | ||
], | ||
"validation_patience": 10, | ||
"val_every_n_batches": 100, | ||
"log_every_n_batches": 100, | ||
"evaluation_targets": [ | ||
"train", | ||
"valid", | ||
"test" | ||
], | ||
"tensorboard_log_dir": "{MODELS_PATH}/" | ||
}, | ||
"metadata": { | ||
"variables": { | ||
"ROOT_PATH": "~/.deeppavlov", | ||
"DOWNLOADS_PATH": "{ROOT_PATH}/downloads", | ||
"MODELS_PATH": "{ROOT_PATH}/models/paraphraser_bert_v0" | ||
}, | ||
"requirements": [ | ||
"{DEEPPAVLOV_PATH}/requirements/tf.txt", | ||
"{DEEPPAVLOV_PATH}/requirements/bert_dp.txt" | ||
], | ||
"download": [ | ||
{ | ||
"url": "http://files.deeppavlov.ai/datasets/paraphraser.zip", | ||
"subdir": "{DOWNLOADS_PATH}/paraphraser_data" | ||
}, | ||
{ | ||
"url": "http://files.deeppavlov.ai/datasets/paraphraser_gold.zip", | ||
"subdir": "{DOWNLOADS_PATH}/paraphraser_data" | ||
}, | ||
{ | ||
"url": "http://files.deeppavlov.ai/deeppavlov_data/bert/multi_cased_L-12_H-768_A-12.zip", | ||
"subdir": "{DOWNLOADS_PATH}/bert_models" | ||
}, | ||
{ | ||
"url": "http://files.deeppavlov.ai/deeppavlov_data/classifiers/paraphraser_bert_v0.tar.gz", | ||
"subdir": "{ROOT_PATH}/models" | ||
} | ||
] | ||
} | ||
} |
Oops, something went wrong.