Skip to content

Commit da4cba7

Browse files
committed
fix syn_words and CRF
1 parent d4a9414 commit da4cba7

File tree

2 files changed

+53
-40
lines changed

2 files changed

+53
-40
lines changed

AugmentText/augment_eda/enhance_eda_v2.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313

1414
KEY_WORDS = ["macropodus"] # words that must not be replaced by synonyms
15-
ENGLISH = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
15+
ENGLISH = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
1616

1717

1818
def is_english(text):
@@ -22,7 +22,7 @@ def is_english(text):
2222
:return: boolean, True or False
2323
"""
2424
try:
25-
text_r = text.replace(' ', '').strip()
25+
text_r = text.replace(" ", "").strip()
2626
for tr in text_r:
2727
if tr in ENGLISH:
2828
continue
@@ -39,7 +39,7 @@ def is_number(text):
3939
:return: boolean, True or False
4040
"""
4141
try:
42-
text_r = text.replace(' ', '').strip()
42+
text_r = text.replace(" ", "").strip()
4343
for tr in text_r:
4444
if tr.isdigit():
4545
continue
@@ -57,7 +57,7 @@ def get_syn_word(word):
5757
"""
5858
if not is_number(word.strip()) or not is_english(word.strip()):
5959
word_syn = synonyms.nearby(word)
60-
word_syn = word_syn if not word_syn else [word]
60+
word_syn = word_syn[0] if len(word_syn[0]) else [word]
6161
return word_syn
6262
else:
6363
return [word]
@@ -124,7 +124,7 @@ def word_swap(words, n=1):
124124
while count < n:
125125
idx_select = random.sample(idxs, 2)
126126
temp = words[idx_select[0]]
127-
words[idx_select[0]] = words[idx_select[1]]
127+
words[idx_select[0]] = words[idx_select[1]]
128128
words[idx_select[1]] = temp
129129
count += 1
130130
return words
@@ -182,7 +182,7 @@ def eda(text, n=1, use_syn=True):
182182
return sens_4
183183

184184

185-
if __name__ == '__main__':
185+
if __name__ == "__main__":
    # Demo run: join a few sample fragments into one sentence and print
    # whatever eda() returns for it with its default arguments.
    fragments = (
        "macropodus",
        "是不是",
        "哪个",
        "啦啦",
        "只需做好这四点,就能让你养的天竺葵全年花开不断!",
    )
    demo_text = "".join(fragments)
    augmented = eda(demo_text)
    print(augmented)

Ner/bert/keras_bert_layer.py

Lines changed: 47 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from keras import activations
2424
from keras import constraints
2525
import warnings
26-
import keras
26+
import os
2727
# crf_loss
2828
from keras.losses import sparse_categorical_crossentropy
2929
from keras.losses import categorical_crossentropy
@@ -220,7 +220,7 @@ def to_tuple(shape):
220220
with keras-team/keras. So we must apply this function to
221221
all input_shapes of the build methods in custom layers.
222222
"""
223-
if is_tf_keras:
223+
if os.environ.get("TF_KERAS")==1:
224224
import tensorflow as tf
225225
return tuple(tf.TensorShape(shape).as_list())
226226
else:
@@ -530,36 +530,36 @@ def get_config(self):
530530
base_config = super(CRF, self).get_config()
531531
return dict(list(base_config.items()) + list(config.items()))
532532

533-
# @property
534-
# def loss_function(self):
535-
# warnings.warn('CRF.loss_function is deprecated '
536-
# 'and it might be removed in the future. Please '
537-
# 'use losses.crf_loss instead.')
538-
# return crf_loss
539-
#
540-
# @property
541-
# def accuracy(self):
542-
# warnings.warn('CRF.accuracy is deprecated and it '
543-
# 'might be removed in the future. Please '
544-
# 'use metrics.crf_accuracy')
545-
# if self.test_mode == 'viterbi':
546-
# return crf_viterbi_accuracy
547-
# else:
548-
# return crf_marginal_accuracy
549-
#
550-
# @property
551-
# def viterbi_acc(self):
552-
# warnings.warn('CRF.viterbi_acc is deprecated and it might '
553-
# 'be removed in the future. Please '
554-
# 'use metrics.viterbi_acc instead.')
555-
# return crf_viterbi_accuracy
556-
#
557-
# @property
558-
# def marginal_acc(self):
559-
# warnings.warn('CRF.moarginal_acc is deprecated and it '
560-
# 'might be removed in the future. Please '
561-
# 'use metrics.marginal_acc instead.')
562-
# return crf_marginal_accuracy
533+
@property
def loss_function(self):
    """Deprecated accessor that hands back ``crf_loss``.

    A warning telling callers to use ``losses.crf_loss`` is issued via
    ``warnings.warn`` on every access before the value is returned.

    :return: the ``crf_loss`` object visible from this scope
    """
    message = ("CRF.loss_function is deprecated "
               "and it might be removed in the future. Please "
               "use losses.crf_loss instead.")
    warnings.warn(message)
    return crf_loss
539+
540+
@property
def accuracy(self):
    """Deprecated accessor for the accuracy metric matching ``test_mode``.

    A warning is issued via ``warnings.warn`` on every access. After that
    the metric is chosen from ``self.test_mode``.

    :return: ``crf_viterbi_accuracy`` when ``self.test_mode`` equals
        ``'viterbi'``, otherwise ``crf_marginal_accuracy``
    """
    message = ("CRF.accuracy is deprecated and it "
               "might be removed in the future. Please "
               "use metrics.crf_accuracy")
    warnings.warn(message)
    use_viterbi = self.test_mode == 'viterbi'
    return crf_viterbi_accuracy if use_viterbi else crf_marginal_accuracy
549+
550+
@property
def viterbi_acc(self):
    """Deprecated accessor that hands back ``crf_viterbi_accuracy``.

    A warning pointing callers at ``metrics.viterbi_acc`` is issued via
    ``warnings.warn`` on every access before the value is returned.

    :return: the ``crf_viterbi_accuracy`` object visible from this scope
    """
    message = ("CRF.viterbi_acc is deprecated and it might "
               "be removed in the future. Please "
               "use metrics.viterbi_acc instead.")
    warnings.warn(message)
    return crf_viterbi_accuracy
556+
557+
@property
def marginal_acc(self):
    """Deprecated accessor that hands back ``crf_marginal_accuracy``.

    A warning pointing callers at ``metrics.marginal_acc`` is issued via
    ``warnings.warn`` on every access before the value is returned.

    :return: the ``crf_marginal_accuracy`` object visible from this scope
    """
    # The message previously named the property "CRF.moarginal_acc"; it now
    # matches the real attribute name so users can search for it.
    warnings.warn('CRF.marginal_acc is deprecated and it '
                  'might be removed in the future. Please '
                  'use metrics.marginal_acc instead.')
    return crf_marginal_accuracy
563563

564564
@staticmethod
565565
def softmaxNd(x, axis=-1):
@@ -655,9 +655,22 @@ def step(self, input_energy_t, states, return_logZ=True):
655655
chain_energy = chain_energy * K.expand_dims(
656656
K.expand_dims(m[:, 0] * m[:, 1]))
657657
if return_logZ:
658-
# shapes: (1, B, F) + (B, F, 1) -> (B, F, F)
658+
# # shapes: (1, B, F) + (B, F, 1) -> (B, F, F)
659+
# energy = chain_energy + K.expand_dims(input_energy_t - prev_target_val, 2)
660+
# new_target_val = K.logsumexp(-energy, 1) # shapes: (B, F)
661+
# return new_target_val, [new_target_val, i + 1]
662+
659663
energy = chain_energy + K.expand_dims(input_energy_t - prev_target_val, 2)
660-
new_target_val = K.logsumexp(-energy, 1) # shapes: (B, F)
664+
new_target_val = K.logsumexp(-energy, 1)
665+
# added from here
666+
if len(states) > 3:
667+
if K.backend() == 'theano':
668+
m = states[3][:, t:(t + 2)]
669+
else:
670+
m = K.slice(states[3], [0, t], [-1, 2])
671+
is_valid = K.expand_dims(m[:, 0])
672+
new_target_val = is_valid * new_target_val + (1 - is_valid) * prev_target_val
673+
# added until here
661674
return new_target_val, [new_target_val, i + 1]
662675
else:
663676
energy = chain_energy + K.expand_dims(input_energy_t + prev_target_val, 2)

0 commit comments

Comments
 (0)