-
Notifications
You must be signed in to change notification settings - Fork 28
/
Copy pathtestDist.cpp
156 lines (145 loc) · 3.62 KB
/
testDist.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
#include "CDist.h"
using namespace std;
int testType(const string distType);
int testDist(CDist* dist, CDist* dist2, const string fileName);
int main()
{
int fail=0;
try
{
fail += testType("gaussian");
fail += testType("gamma");
fail += testType("wang");
cout << "Number of failures: " << fail << "." << endl;
}
catch(ndlexceptions::FileFormatError err)
{
cerr << err.getMessage();
exit(1);
}
catch(ndlexceptions::FileReadError err)
{
cerr << err.getMessage();
exit(1);
}
catch(ndlexceptions::FileWriteError err)
{
cerr << err.getMessage();
exit(1);
}
catch(ndlexceptions::FileError err)
{
cerr << err.getMessage();
exit(1);
}
catch(ndlexceptions::Error err)
{
cerr << err.getMessage();
exit(1);
}
catch(std::bad_alloc err)
{
cerr << "Out of memory.";
exit(1);
}
catch(std::exception err)
{
cerr << "Unhandled exception.";
exit(1);
}
}
int testType(const string distType)
{
string fileName = "matfiles" + ndlstrutil::dirSep() + distType + "DistTest.mat";
CMatrix X;
X.readMatlabFile(fileName, "X");
CDist* dist;
CDist* dist2;
if(distType=="gaussian")
{
dist = new CGaussianDist();
dist2 = new CGaussianDist();
}
if(distType=="gamma")
{
dist = new CGammaDist();
dist2 = new CGammaDist();
}
else if(distType=="wang")
{
dist = new CWangDist();
dist2 = new CWangDist();
}
int fail = testDist(dist, dist2, fileName);
delete dist;
delete dist2;
return fail;
}
// Compare a C++ distribution against reference values exported from MATLAB.
//
// dist     -- distribution whose transformed parameters are set here from the
//             "params" matrix in fileName; all numerical checks run on it.
// dist2    -- second distribution of the same type, used only as the target of
//             reads ("dist" from fileName, then the two round trips); its
//             contents are overwritten.
// fileName -- MATLAB file providing the variables "params", "X", "g", "ll"
//             and "dist".
//
// Each failed check prints a "FAILURE: ..." line and adds one to the result:
//   1. dist (built from "params") equals the "dist" read from fileName;
//   2. getGradInputs() on X matches the MATLAB gradient "g";
//   3. logProb() on X matches the MATLAB log likelihood "ll";
//   4. writing dist as MATLAB ("crap.mat") and reading it back into dist2
//      gives an equal distribution;
//   5. the same round trip in text format ("crap_dist").
// The two round-trip files are written to the current directory and are not
// removed. Errors from the file reads/writes are not caught here.
// Returns the number of failed checks (0-5).
int testDist(CDist* dist, CDist* dist2, const string fileName)
{
  int fail = 0;
  CMatrix params;
  params.readMatlabFile(fileName, "params");
  CMatrix X;
  X.readMatlabFile(fileName, "X");
  CMatrix g;
  g.readMatlabFile(fileName, "g");
  CMatrix ll;
  ll.readMatlabFile(fileName, "ll");

  // Check 1: construction from parameters matches the MATLAB structure.
  dist->setTransParams(params);
  dist2->readMatlabFile(fileName, "dist");
  if(dist2->equals(*dist))
    cout << dist->getName() << " Initial Dist matches." << endl;
  else
  {
    cout << "FAILURE: " << dist->getName() << " Initial Dist." << endl;
    fail++;
  }

  // Check 2: gradient with respect to the inputs X (g1 has X's shape).
  CMatrix g1(X.getRows(), X.getCols());
  dist->getGradInputs(g1, X);
  if(g1.equals(g))
    cout << dist->getName() << " input gradient matches." << endl;
  else
  {
    cout << "FAILURE: " << dist->getName() << " input gradient." << endl;
    cout << "Matlab gradient: " << endl;
    cout << g << endl;
    cout << "C++ gradient: " << endl;
    cout << g1 << endl;
    fail++;
  }

  // Check 3: log likelihood of X, wrapped in a 1x1 matrix so it can be
  // compared with the MATLAB value using CMatrix::equals.
  CMatrix ll2(1, 1);
  ll2.setVal(dist->logProb(X), 0);
  if(ll2.equals(ll))
    cout << dist->getName() << " log likelihood matches." << endl;
  else
  {
    cout << "FAILURE: " << dist->getName() << " log likelihood." << endl;
    cout << "Matlab log likelihood: " << endl;
    cout << ll << endl;
    cout << "C++ log likelihood: " << endl;
    cout << ll2 << endl;
    fail++;
  }

  // Check 4: MATLAB-format write and read round trip.
  dist->writeMatlabFile("crap.mat", "writtenDist");
  dist2->readMatlabFile("crap.mat", "writtenDist");
  if(dist->equals(*dist2))
    cout << "MATLAB written " << dist->getName() << " matches read in dist. Read and write to matlab passes." << endl;
  else
  {
    cout << "FAILURE: MATLAB read in " << dist->getName() << " does not match written out dist." << endl;
    fail++;
  }

  // Check 5: text-format write and read round trip.
  dist->toFile("crap_dist");
  dist2->fromFile("crap_dist");
  if(dist->equals(*dist2))
    cout << "Text written " << dist->getName() << " matches read in dist. Read and write to text passes." << endl;
  else
  {
    cout << "FAILURE: Text read in " << dist->getName() << " does not match written out dist." << endl;
    fail++;
  }
  return fail;
}