Skip to content

Commit 839e3b4

Browse files
committed
1. DeBERTa v2
2. Add DeBERTa v2 xlarge, xxlarge and MNLI xlarge-v2, xxlarge-v2 models. 3. Fix GLUE data downloading issue. 4. Support plugin tasks. 5. Update experiments.
1 parent d9e01c6 commit 839e3b4

Some content is hidden

Large commits have some content hidden by default. Use the search box below for content that may be hidden.

72 files changed

+2077
-2095
lines changed

DeBERTa/apps/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,3 @@
1-
from .task_registry import tasks
1+
import os
2+
# This statement must be executed at the very beginning, i.e. before import torch
3+
os.environ["OMP_NUM_THREADS"] = "1"

DeBERTa/apps/models/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .ner import *
2+
from .multi_choice import *
3+
from .sequence_classification import *

DeBERTa/apps/multi_choice.py renamed to DeBERTa/apps/models/multi_choice.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,17 @@
1515
from torch.nn import CrossEntropyLoss
1616
import math
1717

18-
from ..deberta import *
19-
from ..utils import *
18+
from ...deberta import *
19+
from ...utils import *
2020
import pdb
2121

2222
__all__ = ['MultiChoiceModel']
2323
class MultiChoiceModel(NNModule):
2424
def __init__(self, config, num_labels = 2, drop_out=None, **kwargs):
2525
super().__init__(config)
26-
self.bert = DeBERTa(config)
26+
self.deberta = DeBERTa(config)
2727
self.num_labels = num_labels
28-
self.classifier = nn.Linear(config.hidden_size, 1)
28+
self.classifier = torch.nn.Linear(config.hidden_size, 1)
2929
drop_out = config.hidden_dropout_prob if drop_out is None else drop_out
3030
self.dropout = StableDropout(drop_out)
3131
self.apply(self.init_weights)
@@ -39,7 +39,7 @@ def forward(self, input_ids, type_ids=None, input_mask=None, labels=None, positi
3939
position_ids = position_ids.view([-1, position_ids.size(-1)])
4040
if input_mask is not None:
4141
input_mask = input_mask.view([-1, input_mask.size(-1)])
42-
encoder_layers = self.bert(input_ids, token_type_ids=type_ids, attention_mask=input_mask,
42+
encoder_layers = self.deberta(input_ids, token_type_ids=type_ids, attention_mask=input_mask,
4343
position_ids=position_ids, output_all_encoded_layers=True)
4444
seqout = encoder_layers[-1]
4545
cls = seqout[:,:1,:]

DeBERTa/apps/ner.py renamed to DeBERTa/apps/models/ner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import math
1616
from torch import nn
1717
from torch.nn import CrossEntropyLoss
18-
from ..deberta import DeBERTa,NNModule,ACT2FN,StableDropout
18+
from ...deberta import DeBERTa,NNModule,ACT2FN,StableDropout
1919

2020
__all__ = ['NERModel']
2121

DeBERTa/apps/sequence_classification.py renamed to DeBERTa/apps/models/sequence_classification.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,33 +14,35 @@
1414
import torch
1515
from torch.nn import CrossEntropyLoss
1616
import math
17+
import pdb
1718

18-
from ..deberta import *
19-
from ..utils import *
19+
from ...deberta import *
20+
from ...utils import *
2021

2122
__all__= ['SequenceClassificationModel']
2223
class SequenceClassificationModel(NNModule):
2324
def __init__(self, config, num_labels=2, drop_out=None, pre_trained=None):
2425
super().__init__(config)
2526
self.num_labels = num_labels
26-
self.bert = DeBERTa(config, pre_trained=pre_trained)
27+
self._register_load_state_dict_pre_hook(self._pre_load_hook)
28+
self.deberta = DeBERTa(config, pre_trained=pre_trained)
2729
if pre_trained is not None:
28-
self.config = self.bert.config
30+
self.config = self.deberta.config
2931
else:
3032
self.config = config
3133
pool_config = PoolConfig(self.config)
32-
output_dim = self.bert.config.hidden_size
34+
output_dim = self.deberta.config.hidden_size
3335
self.pooler = ContextPooler(pool_config)
3436
output_dim = self.pooler.output_dim()
3537

3638
self.classifier = torch.nn.Linear(output_dim, num_labels)
3739
drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out
3840
self.dropout = StableDropout(drop_out)
3941
self.apply(self.init_weights)
40-
self.bert.apply_state()
42+
self.deberta.apply_state()
4143

4244
def forward(self, input_ids, type_ids=None, input_mask=None, labels=None, position_ids=None, **kwargs):
43-
encoder_layers = self.bert(input_ids, attention_mask=input_mask, token_type_ids=type_ids,
45+
encoder_layers = self.deberta(input_ids, attention_mask=input_mask, token_type_ids=type_ids,
4446
position_ids=position_ids, output_all_encoded_layers=True)
4547
pooled_output = self.pooler(encoder_layers[-1])
4648
pooled_output = self.dropout(pooled_output)
@@ -69,3 +71,15 @@ def forward(self, input_ids, type_ids=None, input_mask=None, labels=None, positi
6971
loss = -((log_softmax(logits)*labels).sum(-1)*label_confidence).mean()
7072

7173
return (logits,loss)
74+
75+
def _pre_load_hook(self, state_dict, prefix, local_metadata, strict,
                   missing_keys, unexpected_keys, error_msgs):
    """Rename legacy ``bert.*`` checkpoint keys to ``deberta.*`` in place.

    Every key in ``state_dict`` that starts with ``prefix + 'bert.'`` is
    re-inserted under ``prefix + 'deberta.'`` with the same value; all other
    keys are left untouched.

    Args:
        state_dict: Mutable mapping of parameter names to values; modified
            in place.
        prefix: Key prefix of this module inside ``state_dict``.
        local_metadata, strict, missing_keys, unexpected_keys, error_msgs:
            Accepted to match the hook signature; not used here.

    Returns:
        None. The result is the mutated ``state_dict``.
    """
    bert_prefix = prefix + 'bert.'
    deberta_prefix = prefix + 'deberta.'
    # Iterate over a snapshot of the keys because the dict is mutated below.
    for key in list(state_dict.keys()):
        if key.startswith(bert_prefix):
            renamed = deberta_prefix + key[len(bert_prefix):]
            state_dict[renamed] = state_dict.pop(key)

0 commit comments

Comments
 (0)