Skip to content

Commit 60e3d16

Browse files
committed
FIXED: clipping for numerical stability in SoftmaxLayer was not equivalent to that of PyBrain.
1 parent 97848fb commit 60e3d16

File tree

2 files changed

+22
-12
lines changed

2 files changed

+22
-12
lines changed

src/cpp/structure/modules/softmax.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,17 @@ SoftmaxLayer::_forward()
1818
double* output_p = output()[timestep()];
1919
for(int i = 0; i < _insize; i++)
2020
{
21-
double item = exp(input_p[i]);
22-
item = item < -500 ? -500 : item;
23-
item = item > 500 ? 500 : item;
21+
// Clip the input argument if it is too extreme to avoid NaNs and inf as a
22+
// result of exp().
23+
double inpt;
24+
inpt = input_p[i] < -500 ? -500 : input_p[i];
25+
inpt = inpt > 500 ? 500 : inpt;
26+
double item = exp(inpt);
27+
2428
sum += item;
2529
output_p[i] = item;
2630
}
27-
for(int i = 0; i < _outsize; i++)
31+
for(int i = 0; i < _insize; i++)
2832
{
2933
output_p[i] /= sum;
3034
}
@@ -38,4 +42,4 @@ SoftmaxLayer::_backward()
3842
void* sourcebuffer_p = (void*) outerror()[timestep() - 1];
3943
void* sinkbuffer_p = (void*) inerror()[timestep() - 1];
4044
memcpy(sinkbuffer_p, sourcebuffer_p, size);
41-
}
45+
}

src/cpp/tests/test_structure.cpp

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -417,25 +417,31 @@ TEST(TestModules, TanhLayer) {
417417

418418

419419
TEST(TestModules, SoftmaxLayer) {
420-
SoftmaxLayer* layer_p = new SoftmaxLayer(2);
420+
SoftmaxLayer* layer_p = new SoftmaxLayer(3);
421421

422422
double* input_p = new double[2];
423-
input_p[0] = 2.;
424-
input_p[1] = 4.;
423+
input_p[0] = 4.6296992222786457;
424+
input_p[1] = -0.36272901550781184;
425+
input_p[2] = 15.440919648395607;
425426

426427
layer_p->add_to_input(input_p);
427428

428-
ASSERT_DOUBLE_EQ(2, layer_p->input()[0][0])
429+
ASSERT_DOUBLE_EQ(input_p[0], layer_p->input()[0][0])
429430
<< "add_to_input not working.";
430-
ASSERT_DOUBLE_EQ(4, layer_p->input()[0][1])
431+
ASSERT_DOUBLE_EQ(input_p[1], layer_p->input()[0][1])
432+
<< "add_to_input not working.";
433+
ASSERT_DOUBLE_EQ(input_p[2], layer_p->input()[0][2])
431434
<< "add_to_input not working.";
432435

433436
layer_p->forward();
434437

435-
ASSERT_DOUBLE_EQ(0.11920292202211756, layer_p->output()[0][0])
438+
ASSERT_DOUBLE_EQ(2.0171481969464377e-05, layer_p->output()[0][0])
436439
<< "Forward pass incorrect.";
437440

438-
ASSERT_DOUBLE_EQ(0.88079707797788243, layer_p->output()[0][1])
441+
ASSERT_DOUBLE_EQ(1.3694739368803625e-07, layer_p->output()[0][1])
442+
<< "Forward pass incorrect.";
443+
444+
ASSERT_DOUBLE_EQ(0.99997969157063693, layer_p->output()[0][2])
439445
<< "Forward pass incorrect.";
440446

441447
double* outerror_p = new double[2];

0 commit comments

Comments
 (0)