
Commit 52e1756

Author: Anton Bakhtin (committed)
Add a few more command line options
New options: --gradient-clipping (sets the gradient clipping threshold), --learn-recurrent (set to 0 to keep the recurrent weights fixed, turning the RNN into an ESN), --learn-embeddings (set to 0 to freeze the embeddings, which could be useful for fine-tuning).
1 parent b95a279 commit 52e1756
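
The --learn-recurrent description refers to an Echo State Network: the recurrent (hidden-to-hidden) weights stay at their random initialization and only the layers reading out of the hidden state are trained. Purely as an illustration of that kind of gating (none of these names are from the repository; the real wiring is in the training code touched by this commit but not excerpted below):

```cpp
#include <cstddef>

// Illustrative gated SGD step: with learn_recurrent == false the recurrent
// weights are never touched, which is the ESN-style behaviour the commit
// message alludes to.
void MaybeUpdateRecurrent(float* weights, const float* grad, std::size_t n,
                          float lrate, bool learn_recurrent) {
  if (!learn_recurrent) return;  // --learn-recurrent 0: keep the random init fixed
  for (std::size_t i = 0; i < n; ++i) {
    weights[i] += lrate * grad[i];  // plain gradient step on the recurrent weights
  }
}
```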

14 files changed: 160 additions, 170 deletions

AUTHORS

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
 The following authors have created the source code of "Faster RNNLM"
 published and distributed by YANDEX LLC as the owner:
 Anton Bakhtin <[email protected]>
-Ilia Edrenkin <[email protected]>
+Ilya Edrenkin <[email protected]>

README.md

Lines changed: 7 additions & 1 deletion
@@ -17,7 +17,7 @@ Run `./build.sh` to download Eigen library and build faster-rnnlm.
 
 To train a simple model with GRU hidden unit and Noise Contrastive Estimation, use the following command:
 
-`./rnnlm -rnnlm model_name -train train.txt -valid validation.txt -hidden 128 -hidden-type gru -nce 20 -alpha 0.01 -rmsprop 0.9`
+`./rnnlm -rnnlm model_name -train train.txt -valid validation.txt -hidden 128 -hidden-type gru -nce 20 -alpha 0.01`
 
 Files train.txt and test.txt must contain one sentence per line. All distinct words that are found in the training file will be used for the nnet vocab, their counts will determine Huffman tree structure and remain fixed for this nnet. If you prefer using limited vocabulary (say, top 1 million words) you should map all other words to <unk> or another token of your choice. Limited vocabulary is usually a good idea if it helps you to have enough training examples for each word.
 
@@ -138,6 +138,12 @@ Optimization options
 --rmsprop <float>
     RMSprop coefficient; rmsprop=1 disables rmsprop and rmsprop=0 equivalent to RMS
     (default: 1)
+--gradient-clipping <float>
+    Clip updates above the value (default: 1)
+--learn-recurrent (0 | 1)
+    Learn hidden layer weights (default: 1)
+--learn-embeddings (0 | 1)
+    Learn embedding weights (default: 1)
 --alpha <float>
     Learning rate for recurrent and embedding weights (default: 0.1)
 --maxent-alpha <float>
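
With these options in place, the training example near the top of the README can be extended to exercise the new flags, for instance `./rnnlm -rnnlm model_name -train train.txt -valid validation.txt -hidden 128 -hidden-type gru -nce 20 -alpha 0.01 --gradient-clipping 5 --learn-embeddings 0`. The flag names follow the help text added above; the values 5 and 0 are illustrative only, not recommendations from this commit.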

faster-rnnlm/Makefile

Lines changed: 2 additions & 1 deletion
@@ -4,6 +4,7 @@ CC = g++
 CFLAGS = -Wall -march=native -funroll-loops -g -D__STDC_FORMAT_MACROS
 CFLAGS += -I../ -DEIGEN_DONT_PARALLELIZE # for Eigen
 CFLAGS += $(shell $(CC) -dumpversion | awk '{if(NR==1 && $$1>="4.6") print "-Ofast -Wno-unused-result"; else print "-O3";}')
+NVCC_CFLAGS = -O3 -march=native -funroll-loops
 LDFLAGS = -lm -lrt
 ifeq ($(NOTHREAD), 1)
 CFLAGS += -DNOTHREAD
@@ -51,7 +52,7 @@ nnet.o : nnet.cc nnet.h maxent.h settings.h hierarchical_softmax.h words.h recur
 	$(CC) $< -c -o $@ $(CFLAGS)
 
 cuda_softmax.o : cuda_softmax.cu cuda_softmax.h settings.h
-	nvcc $< -c -Xcompiler "-Ofast -march=native -funroll-loops" -o $@
+	nvcc $< -c -Xcompiler "$(NVCC_CFLAGS)" -o $@
 
 
 clean:
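
Since the nvcc host-compiler flags are now collected in an ordinary make variable, they can also be overridden at build time without editing the Makefile, e.g. `make NVCC_CFLAGS="-O2"` (an illustrative invocation; any host-compiler flags accepted by -Xcompiler would work).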

faster-rnnlm/cuda_softmax.cu

Lines changed: 0 additions & 4 deletions
@@ -26,7 +26,6 @@ void AssertCudaSuccessLast(const char* message) {
   return AssertCudaSuccess(cudaGetLastError(), message);
 }
 
-extern "C"
 void InitCudaStorage(CudaStorage* cust, size_t layer_size, size_t vocab_size, size_t maxent_hash_size, Real lnz) {
   cust->layer_size = layer_size;
   cust->vocab_size = vocab_size;
@@ -88,7 +87,6 @@ void InitCudaStorage(CudaStorage* cust, size_t layer_size, size_t vocab_size, si
   }
 }
 
-extern "C"
 void FreeCudaStorage(CudaStorage* cust) {
   cudaFree(cust->sm_embedding);
   if (cust->maxent_hash_size != 0) {
@@ -115,7 +113,6 @@ void FreeCudaStorage(CudaStorage* cust) {
   delete cust->inner;
 }
 
-extern "C"
 void UploadNetWeights(CudaStorage* cust, const Real* sm_embedding_cpu, const Real* maxent_cpu) {
   cudaMemcpy(cust->sm_embedding, sm_embedding_cpu, cust->layer_size * cust->vocab_size * sizeof(Real), cudaMemcpyHostToDevice);
   cudaMemcpy(cust->maxent, maxent_cpu, cust->maxent_hash_size * sizeof(Real), cudaMemcpyHostToDevice);
@@ -174,7 +171,6 @@ void CublasMultiply_A_BT(cublasHandle_t* handle, float beta, int rows_a, int row
       dev_c, rows_b);
 }
 
-extern "C"
 void CalculateSoftMax(
     CudaStorage* cust, const Real* hidden_layers,
     const uint64_t* maxent_indices_all, const int* maxent_indices_count_all,

faster-rnnlm/cuda_softmax.h

Lines changed: 0 additions & 4 deletions
@@ -33,16 +33,12 @@ struct CudaStorage {
   CudaStorageInner* inner;
 };
 
-extern "C"
 void InitCudaStorage(CudaStorage* cust, size_t layer_size, size_t vocab_size, size_t maxent_hash_size, Real lnz);
 
-extern "C"
 void FreeCudaStorage(CudaStorage* cust);
 
-extern "C"
 void UploadNetWeights(CudaStorage* cust, const Real* sm_embedding_cpu, const Real* maxent);
 
-extern "C"
 void CalculateSoftMax(CudaStorage* cust, const Real* hidden_layers, const uint64_t* maxent_indices_all, const int* maxent_indices_count_all, size_t sentence_length, const WordIndex* sen, Real* logprobs);
 
 #endif // FASTER_RNNLM_CUDA_SOFTMAX_H_

faster-rnnlm/hierarchical_softmax.cc

Lines changed: 14 additions & 15 deletions
@@ -346,7 +346,7 @@ inline void PropagateNodeForwardByDepth(
 inline void PropagateNodeBackward(
     HSTree* hs, WordIndex target_word, int depth,
     const uint64_t* feature_hashes, int maxent_order,
-    Real lrate, Real maxent_lrate, Real l2reg, Real maxent_l2reg,
+    Real lrate, Real maxent_lrate, Real l2reg, Real maxent_l2reg, Real gradient_clipping,
     const double* state,
     const Real* hidden,
     Real* hidden_grad, MaxEnt* maxent
@@ -357,33 +357,32 @@ inline void PropagateNodeBackward(
   Real branch_gradient[ARITY - 1];
   for (int branch = 0; branch < ARITY - 1; ++branch) {
     const int match = (branch == selected_branch);
-    // gradient of softmax
-    // g = branch_softmax_prob[selected_branch] * (match - branch_softmax_prob[branch]);
     // gradient of logsoftmax
-    // g /= branch_softmax_prob[selected_branch];
-    Real g = (match - (Real) state[branch]);
-    // gradient clipping
-    g = Clip(g, GRAD_CLIPPING);
-    branch_gradient[branch] = g;
+    branch_gradient[branch] = (match - static_cast<Real>(state[branch]));
   }
 
   for (int branch = 0; branch < ARITY - 1; ++branch) {
-    Real g = branch_gradient[branch];
+    Real grad = branch_gradient[branch];
     int child_offset = hs->tree_->GetChildOffsetByDepth(target_word, depth, branch);
+    Real* sm_embedding = hs->weights_.row(child_offset).data();
 
     // Propagate errors output -> hidden
     for (int i = 0; i < hs->layer_size; ++i) {
-      hidden_grad[i] += g * hs->weights_(child_offset, i);
+      hidden_grad[i] += grad * sm_embedding[i];
     }
 
     // Learn weights hidden -> output
     for (int i = 0; i < hs->layer_size; ++i) {
-      Real update = g * lrate * hidden[i] - l2reg * hs->weights_(child_offset, i);
-      hs->weights_(child_offset, i) += Clip(update, GRAD_CLIPPING);
+      Real update = grad * hidden[i];
+      sm_embedding[i] *= (1 - l2reg);
+      sm_embedding[i] += lrate * Clip(update, gradient_clipping);
     }
+
+    // update maxent weights
+    Real maxent_grad = Clip(grad, gradient_clipping);
     for (int order = 0; order < maxent_order; ++order) {
       uint64_t maxent_index = feature_hashes[order] + child_offset;
-      maxent->UpdateValue(maxent_index, maxent_lrate, g, maxent_l2reg);
+      maxent->UpdateValue(maxent_index, maxent_lrate, maxent_grad, maxent_l2reg);
     }
   }
 }
@@ -392,7 +391,7 @@ inline void PropagateNodeBackward(
 Real HSTree::PropagateForwardAndBackward(
     bool calculate_probability, WordIndex target_word,
     const uint64_t* feature_hashes, int maxent_order,
-    Real lrate, Real maxent_lrate, Real l2reg, Real maxent_l2reg,
+    Real lrate, Real maxent_lrate, Real l2reg, Real maxent_l2reg, Real gradient_clipping,
    const Real* hidden,
     Real* hidden_grad, MaxEnt* maxent
 ) {
@@ -406,7 +405,7 @@ Real HSTree::PropagateForwardAndBackward(
 
   PropagateNodeBackward(
       this, target_word, depth, feature_hashes, maxent_order,
-      lrate, maxent_lrate, l2reg, maxent_l2reg,
+      lrate, maxent_lrate, l2reg, maxent_l2reg, gradient_clipping,
       softmax_state,
       hidden, hidden_grad, maxent);
 
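
Read together, the hunk above changes the hidden-to-output update so that L2 decay is applied to the weight directly and only the raw gradient `grad * hidden[i]` is clipped, with the threshold now coming from the new gradient_clipping parameter instead of the compile-time GRAD_CLIPPING constant. A self-contained paraphrase of that inner loop (not the repository's exact code; Clip is assumed to be the scalar clamp used throughout the diff and Real a floating-point typedef):

```cpp
#include <cstddef>

typedef float Real;  // assumption: Real is a float typedef as in the library's settings

// Scalar clamp to [-threshold, threshold], matching how Clip is used in the diff.
inline Real Clip(Real value, Real threshold) {
  if (value > threshold) return threshold;
  if (value < -threshold) return -threshold;
  return value;
}

// Paraphrase of the new hidden->output update for one softmax-tree node row.
void UpdateOutputRow(Real* sm_embedding, const Real* hidden, std::size_t layer_size,
                     Real grad, Real lrate, Real l2reg, Real gradient_clipping) {
  for (std::size_t i = 0; i < layer_size; ++i) {
    Real update = grad * hidden[i];                              // raw gradient for this weight
    sm_embedding[i] *= (1 - l2reg);                              // weight decay, applied to the weight itself
    sm_embedding[i] += lrate * Clip(update, gradient_clipping);  // clipped, scaled gradient step
  }
}
```

Compared with the removed lines, the per-branch gradient is no longer clipped when it is stored in branch_gradient; clipping now happens where the weight and maxent updates are formed.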

faster-rnnlm/hierarchical_softmax.h

Lines changed: 1 addition & 1 deletion
@@ -50,7 +50,7 @@ class HSTree {
   Real PropagateForwardAndBackward(
       bool calculate_probability, WordIndex target_word,
       const uint64_t* feature_hashes, int maxent_order,
-      Real lrate, Real maxent_lrate, Real l2reg, Real maxent_l2reg,
+      Real lrate, Real maxent_lrate, Real l2reg, Real maxent_l2reg, Real gradient_clipping,
       const Real* hidden,
       Real* hidden_grad, MaxEnt* maxent);
 

faster-rnnlm/nce.cc

Lines changed: 5 additions & 4 deletions
@@ -92,7 +92,7 @@ void NCE::Updater::PropagateForwardAndBackward(
     const Ref<const RowVector> hidden, WordIndex target_word,
     const uint64_t* maxent_indices, size_t maxent_size,
     const NoiseSample& sample, Real lrate, Real l2reg,
-    Real maxent_lrate, Real maxent_l2reg,
+    Real maxent_lrate, Real maxent_l2reg, Real gradient_clipping,
     Ref<RowVector> hidden_grad, MaxEnt* maxent) {
 
   Real ln_sample_size = log(sample.size);
@@ -116,14 +116,15 @@ void NCE::Updater::PropagateForwardAndBackward(
 
     // update softmax weights
     embedding_grad_.noalias() = grad * hidden;
-    ClipMatrix(embedding_grad_);
+    ClipMatrix(embedding_grad_, gradient_clipping);
     nce_->sm_embedding_.row(word) *= (1 - l2reg);
     nce_->sm_embedding_.row(word) += embedding_grad_ * lrate;
 
-    // update softmax weights
+    // update maxent weights
+    Real maxent_grad = Clip(grad, gradient_clipping);
     for (size_t i = 0; i < maxent_size; i++) {
       uint64_t maxent_index = get_maxent_index(maxent_indices[i], nce_->maxent_hash_size_, word);
-      maxent->UpdateValue(maxent_index, maxent_lrate, grad, maxent_l2reg);
+      maxent->UpdateValue(maxent_index, maxent_lrate, maxent_grad, maxent_l2reg);
     }
   }
 }
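
The NCE path relies on a matrix variant of the same clipping: the whole embedding-gradient row is clamped before the weight update, and the scalar gradient is clamped once before the maxent loop. A hedged sketch of what the threshold-taking ClipMatrix overload could look like for Eigen row vectors (the repository's own helper lives elsewhere and may differ):

```cpp
#include <Eigen/Dense>

typedef float Real;  // assumption: Real is a float typedef in the library
typedef Eigen::Matrix<Real, 1, Eigen::Dynamic> RowVector;

// Clamp every coefficient of the gradient row to [-threshold, threshold].
inline void ClipMatrix(Eigen::Ref<RowVector> m, Real threshold) {
  m = m.cwiseMax(-threshold).cwiseMin(threshold);
}
```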

faster-rnnlm/nce.h

Lines changed: 1 addition & 1 deletion
@@ -41,7 +41,7 @@ class NCE {
       const Ref<const RowVector> hidden, WordIndex target_word,
       const uint64_t* maxent_indices, size_t maxent_size,
       const NoiseSample& sample, Real lrate, Real l2reg,
-      Real maxent_lrate, Real maxent_l2reg,
+      Real maxent_lrate, Real maxent_l2reg, Real gradient_clipping,
       Ref<RowVector> hidden_grad, MaxEnt* maxent);
 
  private:
