#include "faster-rnnlm/layers/gru_layer.h"

#include "faster-rnnlm/layers/activation_functions.h"


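// Updater for a GRU (gated recurrent unit) hidden layer.
//
// Per-step forward computation, as implemented below (. is elementwise product):
//   reset_t         = sigmoid(W_r x_t + U_r h_{t-1} + b_r)
//   update_t        = sigmoid(W_u x_t + U_u h_{t-1} + b_u)
//   partialhidden_t = reset_t . h_{t-1}
//   quasihidden_t   = tanh(W_q x_t + U_q partialhidden_t)
//   h_t             = update_t . quasihidden_t + (1 - update_t) . h_{t-1}
// The input transforms W_* are applied only if use_input_weights_ is set
// (otherwise the input feeds the gates directly), and the biases b_* only
// if use_bias_ is set.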
class GRULayer::Updater : public IRecUpdater, public TruncatedBPTTMixin<GRULayer::Updater> {
 public:
  explicit Updater(GRULayer* layer)
      : IRecUpdater(layer->weights_.layer_size_)
      , layer_(*layer)
      , reset_(MAX_SENTENCE_WORDS, size_), reset_g_(reset_)
      , update_(MAX_SENTENCE_WORDS, size_), update_g_(update_)
      , partialhidden_(MAX_SENTENCE_WORDS, size_), partialhidden_g_(partialhidden_)
      , quasihidden_(MAX_SENTENCE_WORDS, size_), quasihidden_g_(quasihidden_)

      , syn_reset_in_(&layer_.weights_.syn_reset_in_)
      , syn_reset_out_(&layer_.weights_.syn_reset_out_)
      , syn_update_in_(&layer_.weights_.syn_update_in_)
      , syn_update_out_(&layer_.weights_.syn_update_out_)
      , syn_quasihidden_in_(&layer_.weights_.syn_quasihidden_in_)
      , syn_quasihidden_out_(&layer_.weights_.syn_quasihidden_out_)

      , bias_reset_(&layer_.weights_.bias_reset_)
      , bias_update_(&layer_.weights_.bias_update_)
  {
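    // Step 0 has no previous hidden state; the rows that would be derived
    // from h_{-1} are zeroed once here and are never written by the
    // step != 0 paths below.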
    partialhidden_.row(0).setZero();
    reset_g_.row(0).setZero();
  }

  void ForwardSubSequence(int start, int steps);

  void BackwardSequence(int steps, uint32_t truncation_seed, int bptt_period, int bptt);

  void UpdateWeights(int steps, Real lrate, Real l2reg, Real rmsprop, Real gradient_clipping);

  void BackwardStep(int step);

  void BackwardStepThroughTime(int step);

 private:
  GRULayer& layer_;

  RowMatrix reset_, reset_g_;
  RowMatrix update_, update_g_;
  RowMatrix partialhidden_, partialhidden_g_;
  RowMatrix quasihidden_, quasihidden_g_;

  WeightMatrixUpdater<RowMatrix> syn_reset_in_;
  WeightMatrixUpdater<RowMatrix> syn_reset_out_;
  WeightMatrixUpdater<RowMatrix> syn_update_in_;
  WeightMatrixUpdater<RowMatrix> syn_update_out_;
  WeightMatrixUpdater<RowMatrix> syn_quasihidden_in_;
  WeightMatrixUpdater<RowMatrix> syn_quasihidden_out_;

  WeightMatrixUpdater<RowVector> bias_reset_;
  WeightMatrixUpdater<RowVector> bias_update_;
};

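// Computes gates, candidate states, and outputs for steps [start, start + steps).
// The input-to-hidden products are batched over all rows up front; the
// recurrent terms must be applied sequentially, one step at a time.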
void GRULayer::Updater::ForwardSubSequence(int start, int steps) {
  reset_.middleRows(start, steps) = input_.middleRows(start, steps);
  update_.middleRows(start, steps) = input_.middleRows(start, steps);
  quasihidden_.middleRows(start, steps) = input_.middleRows(start, steps);

  if (layer_.use_input_weights_) {
    reset_.middleRows(start, steps) *= syn_reset_in_.W().transpose();
    update_.middleRows(start, steps) *= syn_update_in_.W().transpose();
    quasihidden_.middleRows(start, steps) *= syn_quasihidden_in_.W().transpose();
  }

  for (int step = start; step < start + steps; ++step) {
    if (layer_.use_bias_) {
      reset_.row(step) += bias_reset_.W();
      update_.row(step) += bias_update_.W();
    }

    if (step != 0) {
      reset_.row(step).noalias() += output_.row(step - 1) * syn_reset_out_.W().transpose();
      update_.row(step).noalias() += output_.row(step - 1) * syn_update_out_.W().transpose();
    }
    SigmoidActivation().Forward(reset_.row(step).data(), size_);
    SigmoidActivation().Forward(update_.row(step).data(), size_);

    if (step != 0) {
      partialhidden_.row(step).noalias() = output_.row(step - 1).cwiseProduct(reset_.row(step));
      quasihidden_.row(step).noalias() +=
          partialhidden_.row(step) * syn_quasihidden_out_.W().transpose();
    }
    TanhActivation().Forward(quasihidden_.row(step).data(), size_);

    if (step == 0) {
      // h_{-1} is zero, so only the update-gated candidate survives.
      output_.row(step).noalias()
          = quasihidden_.row(step).cwiseProduct(update_.row(step));
    } else {
      // These three lines mean:
      //   output_t = (quasihidden_t - output_{t - 1}) * update_t + output_{t - 1}
      output_.row(step).noalias() = quasihidden_.row(step) - output_.row(step - 1);
      output_.row(step).array() *= update_.row(step).array();
      output_.row(step) += output_.row(step - 1);
    }
  }
}

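// Backpropagates through the whole sequence (with truncated BPTT, provided by
// TruncatedBPTTMixin), then accumulates the gradients w.r.t. the layer input.
// Without input weights the input feeds the gates directly, so the gate
// gradients are summed as-is.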
void GRULayer::Updater::BackwardSequence(
    int steps, uint32_t truncation_seed, int bptt_period, int bptt) {
  BackwardSequenceTruncated(steps, truncation_seed, bptt_period, bptt);

  input_g_.topRows(steps).setZero();
  if (layer_.use_input_weights_) {
    input_g_.topRows(steps).noalias() += quasihidden_g_.topRows(steps) * syn_quasihidden_in_.W();
    input_g_.topRows(steps).noalias() += reset_g_.topRows(steps) * syn_reset_in_.W();
    input_g_.topRows(steps).noalias() += update_g_.topRows(steps) * syn_update_in_.W();
  } else {
    input_g_.topRows(steps) += quasihidden_g_.topRows(steps);
    input_g_.topRows(steps) += reset_g_.topRows(steps);
    input_g_.topRows(steps) += update_g_.topRows(steps);
  }
}

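// Propagates the output gradient at a single step back to the gate and
// candidate pre-activations (update_g_, quasihidden_g_, partialhidden_g_,
// reset_g_).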
void GRULayer::Updater::BackwardStep(int step) {
  update_g_.row(step) = quasihidden_.row(step);
  if (step != 0) {
    update_g_.row(step) -= output_.row(step - 1);
  }
  update_g_.row(step).array() *= output_g_.row(step).array();
  SigmoidActivation().Backward(update_.row(step).data(), size_, update_g_.row(step).data());

  quasihidden_g_.row(step) = output_g_.row(step).cwiseProduct(update_.row(step));
  TanhActivation().Backward(quasihidden_.row(step).data(), size_, quasihidden_g_.row(step).data());

  partialhidden_g_.row(step).noalias() = quasihidden_g_.row(step) * syn_quasihidden_out_.W();

  if (step != 0) {
    reset_g_.row(step) = partialhidden_g_.row(step).cwiseProduct(output_.row(step - 1));
  }
  SigmoidActivation().Backward(reset_.row(step).data(), size_, reset_g_.row(step).data());
}


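// Adds the contribution of step `step` to the hidden state gradient of the
// previous step; relies on the per-step gradients computed by BackwardStep.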
void GRULayer::Updater::BackwardStepThroughTime(int step) {
  // These two lines mean: h'_{t - 1} += (1 - u_t) * h'_t
  output_g_.row(step - 1) += output_g_.row(step);
  output_g_.row(step - 1) -= update_.row(step).cwiseProduct(output_g_.row(step));

  output_g_.row(step - 1) += reset_.row(step).cwiseProduct(partialhidden_g_.row(step));
  output_g_.row(step - 1).noalias() += reset_g_.row(step) * syn_reset_out_.W();
  output_g_.row(step - 1).noalias() += update_g_.row(step) * syn_update_out_.W();
}


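// Applies gradient updates to all weight matrices and biases accumulated over
// the given number of steps.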
void GRULayer::Updater::UpdateWeights(
    int steps, Real lrate, Real l2reg, Real rmsprop, Real gradient_clipping) {
  if (steps <= 1 || size_ == 0) {
    return;
  }

  // Recurrent weights: gradients at step t pair with activations at step t - 1,
  // hence the one-row offsets and the steps - 1 row counts.
  UpdateRecurrentSynWeights(steps - 1, lrate, l2reg, rmsprop, gradient_clipping,
      partialhidden_.middleRows(1, steps - 1), quasihidden_g_.middleRows(1, steps - 1),
      &syn_quasihidden_out_);
  UpdateRecurrentSynWeights(steps - 1, lrate, l2reg, rmsprop, gradient_clipping,
      output_, reset_g_.bottomRows(reset_g_.rows() - 1),
      &syn_reset_out_);
  UpdateRecurrentSynWeights(steps - 1, lrate, l2reg, rmsprop, gradient_clipping,
      output_, update_g_.bottomRows(update_g_.rows() - 1),
      &syn_update_out_);

  if (layer_.use_input_weights_) {
    UpdateRecurrentSynWeights(steps, lrate, l2reg, rmsprop, gradient_clipping,
        input_, quasihidden_g_,
        &syn_quasihidden_in_);
    UpdateRecurrentSynWeights(steps, lrate, l2reg, rmsprop, gradient_clipping,
        input_, reset_g_,
        &syn_reset_in_);
    UpdateRecurrentSynWeights(steps, lrate, l2reg, rmsprop, gradient_clipping,
        input_, update_g_,
        &syn_update_in_);
  }

  if (layer_.use_bias_) {
    *bias_reset_.GetGradients() = reset_g_.middleRows(1, steps - 1).colwise().mean();
    bias_reset_.ApplyGradients(lrate, l2reg, rmsprop, gradient_clipping);

    *bias_update_.GetGradients() = update_g_.middleRows(1, steps - 1).colwise().mean();
    bias_update_.ApplyGradients(lrate, l2reg, rmsprop, gradient_clipping);
  }
}


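// Virtual constructor for updaters; the caller takes ownership of the result.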
IRecUpdater* GRULayer::CreateUpdater() {
  return new Updater(this);
}