|
23 | 23 | from keras import activations
|
24 | 24 | from keras import constraints
|
25 | 25 | import warnings
|
26 |
| -import keras |
| 26 | +import os |
27 | 27 | # crf_loss
|
28 | 28 | from keras.losses import sparse_categorical_crossentropy
|
29 | 29 | from keras.losses import categorical_crossentropy
|
@@ -220,7 +220,7 @@ def to_tuple(shape):
|
220 | 220 | with keras-team/keras. So we must apply this function to
|
221 | 221 | all input_shapes of the build methods in custom layers.
|
222 | 222 | """
|
223 |
| - if is_tf_keras: |
| 223 | + if os.environ.get("TF_KERAS")==1: |
224 | 224 | import tensorflow as tf
|
225 | 225 | return tuple(tf.TensorShape(shape).as_list())
|
226 | 226 | else:
|
@@ -530,36 +530,36 @@ def get_config(self):
|
530 | 530 | base_config = super(CRF, self).get_config()
|
531 | 531 | return dict(list(base_config.items()) + list(config.items()))
|
532 | 532 |
|
533 |
| - # @property |
534 |
| - # def loss_function(self): |
535 |
| - # warnings.warn('CRF.loss_function is deprecated ' |
536 |
| - # 'and it might be removed in the future. Please ' |
537 |
| - # 'use losses.crf_loss instead.') |
538 |
| - # return crf_loss |
539 |
| - # |
540 |
| - # @property |
541 |
| - # def accuracy(self): |
542 |
| - # warnings.warn('CRF.accuracy is deprecated and it ' |
543 |
| - # 'might be removed in the future. Please ' |
544 |
| - # 'use metrics.crf_accuracy') |
545 |
| - # if self.test_mode == 'viterbi': |
546 |
| - # return crf_viterbi_accuracy |
547 |
| - # else: |
548 |
| - # return crf_marginal_accuracy |
549 |
| - # |
550 |
| - # @property |
551 |
| - # def viterbi_acc(self): |
552 |
| - # warnings.warn('CRF.viterbi_acc is deprecated and it might ' |
553 |
| - # 'be removed in the future. Please ' |
554 |
| - # 'use metrics.viterbi_acc instead.') |
555 |
| - # return crf_viterbi_accuracy |
556 |
| - # |
557 |
| - # @property |
558 |
| - # def marginal_acc(self): |
559 |
| - # warnings.warn('CRF.moarginal_acc is deprecated and it ' |
560 |
| - # 'might be removed in the future. Please ' |
561 |
| - # 'use metrics.marginal_acc instead.') |
562 |
| - # return crf_marginal_accuracy |
| 533 | + @property |
| 534 | + def loss_function(self): |
| 535 | + warnings.warn('CRF.loss_function is deprecated ' |
| 536 | + 'and it might be removed in the future. Please ' |
| 537 | + 'use losses.crf_loss instead.') |
| 538 | + return crf_loss |
| 539 | + |
| 540 | + @property |
| 541 | + def accuracy(self): |
| 542 | + warnings.warn('CRF.accuracy is deprecated and it ' |
| 543 | + 'might be removed in the future. Please ' |
| 544 | + 'use metrics.crf_accuracy') |
| 545 | + if self.test_mode == 'viterbi': |
| 546 | + return crf_viterbi_accuracy |
| 547 | + else: |
| 548 | + return crf_marginal_accuracy |
| 549 | + |
| 550 | + @property |
| 551 | + def viterbi_acc(self): |
| 552 | + warnings.warn('CRF.viterbi_acc is deprecated and it might ' |
| 553 | + 'be removed in the future. Please ' |
| 554 | + 'use metrics.viterbi_acc instead.') |
| 555 | + return crf_viterbi_accuracy |
| 556 | + |
| 557 | + @property |
| 558 | + def marginal_acc(self): |
| 559 | + warnings.warn('CRF.moarginal_acc is deprecated and it ' |
| 560 | + 'might be removed in the future. Please ' |
| 561 | + 'use metrics.marginal_acc instead.') |
| 562 | + return crf_marginal_accuracy |
563 | 563 |
|
564 | 564 | @staticmethod
|
565 | 565 | def softmaxNd(x, axis=-1):
|
@@ -655,9 +655,22 @@ def step(self, input_energy_t, states, return_logZ=True):
|
655 | 655 | chain_energy = chain_energy * K.expand_dims(
|
656 | 656 | K.expand_dims(m[:, 0] * m[:, 1]))
|
657 | 657 | if return_logZ:
|
658 |
| - # shapes: (1, B, F) + (B, F, 1) -> (B, F, F) |
| 658 | + # # shapes: (1, B, F) + (B, F, 1) -> (B, F, F) |
| 659 | + # energy = chain_energy + K.expand_dims(input_energy_t - prev_target_val, 2) |
| 660 | + # new_target_val = K.logsumexp(-energy, 1) # shapes: (B, F) |
| 661 | + # return new_target_val, [new_target_val, i + 1] |
| 662 | + |
659 | 663 | energy = chain_energy + K.expand_dims(input_energy_t - prev_target_val, 2)
|
660 |
| - new_target_val = K.logsumexp(-energy, 1) # shapes: (B, F) |
| 664 | + new_target_val = K.logsumexp(-energy, 1) |
| 665 | + # added from here |
| 666 | + if len(states) > 3: |
| 667 | + if K.backend() == 'theano': |
| 668 | + m = states[3][:, t:(t + 2)] |
| 669 | + else: |
| 670 | + m = K.slice(states[3], [0, t], [-1, 2]) |
| 671 | + is_valid = K.expand_dims(m[:, 0]) |
| 672 | + new_target_val = is_valid * new_target_val + (1 - is_valid) * prev_target_val |
| 673 | + # added until here |
661 | 674 | return new_target_val, [new_target_val, i + 1]
|
662 | 675 | else:
|
663 | 676 | energy = chain_energy + K.expand_dims(input_energy_t + prev_target_val, 2)
|
|
0 commit comments