Skip to content

Commit 290f719

Browse files
Add PredictionLoggerEvaluator
1 parent 73cb667 commit 290f719

16 files changed

+231
-249
lines changed

moa/src/main/java/moa/evaluation/ClassificationPerformanceEvaluator.java

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,7 @@
2020
package moa.evaluation;
2121

2222
import com.yahoo.labs.samoa.instances.Instance;
23-
import moa.MOAObject;
2423
import moa.core.Example;
25-
import moa.core.Measurement;
2624

2725
public interface ClassificationPerformanceEvaluator extends LearningPerformanceEvaluator<Example<Instance>> {
28-
2926
}

moa/src/main/java/moa/evaluation/LearningPerformanceEvaluator.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
* @author Richard Kirkby ([email protected])
3636
* @version $Revision: 7 $
3737
*/
38-
public interface LearningPerformanceEvaluator<E extends Example> extends MOAObject, CapabilitiesHandler {
38+
public interface LearningPerformanceEvaluator<E extends Example> extends MOAObject, CapabilitiesHandler, AutoCloseable {
3939

4040
/**
4141
* Resets this evaluator. It must be similar to
@@ -66,4 +66,8 @@ default ImmutableCapabilities defineImmutableCapabilities() {
6666
return new ImmutableCapabilities(Capability.VIEW_STANDARD);
6767
}
6868

69+
/**
 * Releases any resources held by this evaluator.
 *
 * <p>This default implementation does nothing; evaluators that hold resources
 * override it.
 *
 * @throws Exception declared so that overriding implementations may report a
 *         failure while closing; this default body throws nothing
 */
@Override
70+
default void close() throws Exception {
71+
// By default an evaluator does nothing when closed.
72+
}
6973
}
Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
package moa.evaluation;
2+
3+
import java.io.BufferedOutputStream;
4+
import java.io.File;
5+
import java.io.FileOutputStream;
6+
import java.io.IOException;
7+
import java.io.OutputStreamWriter;
8+
import java.util.Arrays;
9+
import java.util.zip.GZIPOutputStream;
10+
11+
import com.github.javacliparser.FileOption;
12+
import com.github.javacliparser.FlagOption;
13+
import com.yahoo.labs.samoa.instances.Instance;
14+
import com.yahoo.labs.samoa.instances.Prediction;
15+
16+
import moa.capabilities.Capability;
17+
import moa.capabilities.ImmutableCapabilities;
18+
import moa.core.Example;
19+
import moa.core.Measurement;
20+
import moa.core.ObjectRepository;
21+
import moa.core.Utils;
22+
import moa.options.AbstractOptionHandler;
23+
import moa.options.ClassOption;
24+
import moa.tasks.TaskMonitor;
25+
26+
public class PredictionLoggerEvaluator extends AbstractOptionHandler
27+
implements ClassificationPerformanceEvaluator {
28+
29+
private static final long serialVersionUID = 1L;
30+
31+
private OutputStreamWriter writer;
32+
private int index = 0;
33+
34+
public FileOption csvFileOption = new FileOption("predictionLog", 'o',
35+
"A file to write comma separated values to.", null, "csv.gzip", true);
36+
37+
public FlagOption overwrite = new FlagOption("overwrite", 'f', "Overwrite existing file.");
38+
39+
public ClassOption wrappedEvaluatorOption = new ClassOption("evaluator", 'e',
40+
"Classification performance evaluation method.", ClassificationPerformanceEvaluator.class,
41+
"BasicClassificationPerformanceEvaluator");
42+
43+
public FlagOption probabilities = new FlagOption("probabilities", 'p',
44+
"Log probabilities instead of raw predictions.");
45+
46+
public FlagOption uncompressed = new FlagOption("uncompressed", 'u',
47+
"The output file should be saved uncompressed.");
48+
49+
private ClassificationPerformanceEvaluator wrappedEvaluator;
50+
51+
@Override
52+
public String getPurposeString() {
53+
return "Log raw predictions and probabilities to a CSV file, and evaluate using a wrapped evaluator.";
54+
}
55+
56+
@Override
57+
public void addResult(Example<Instance> example, double[] classVotes) {
58+
Instance instance = example.getData();
59+
int predictedClass = Utils.maxIndex(classVotes);
60+
double normalizingFactor = Arrays.stream(classVotes).sum();
61+
int numClasses = instance.numClasses();
62+
63+
if (normalizingFactor == 0) {
64+
normalizingFactor = 1;
65+
}
66+
try {
67+
// If this is the first result, write the header to the top of the file
68+
if (index == 0)
69+
writeHeader(numClasses);
70+
71+
72+
// Add row to CSV file
73+
if (instance.classIsMissing() == true)
74+
{
75+
writer.write(String.format("?,%d,", predictedClass));
76+
}
77+
else
78+
{
79+
int trueClass = (int) instance.classValue();
80+
writer.write(String.format("%d,%d,", trueClass, predictedClass));
81+
}
82+
83+
if (probabilities.isSet()) {
84+
for (int i = 0; i < numClasses; i++) {
85+
double probability = 0.0;
86+
if (i < classVotes.length){
87+
probability = classVotes[i] / normalizingFactor;
88+
}
89+
writer.write(String.format("%.2f,", probability));
90+
}
91+
}
92+
93+
writer.write("\n");
94+
} catch (Exception e) {
95+
throw new RuntimeException(e);
96+
}
97+
98+
// Pass result to wrapped evaluator
99+
wrappedEvaluator.addResult(example, classVotes);
100+
index ++;
101+
}
102+
103+
@Override
104+
public void addResult(Example<Instance> testInst, Prediction prediction) {
105+
// addResult(testInst, prediction.getVotes());
106+
throw new RuntimeException("Not implemented");
107+
}
108+
109+
@Override
110+
protected void prepareForUseImpl(TaskMonitor monitor, ObjectRepository repository) {
111+
wrappedEvaluator = (ClassificationPerformanceEvaluator) getPreparedClassOption(wrappedEvaluatorOption);
112+
try {
113+
File file = csvFileOption.getFile();
114+
if (file.exists() && !overwrite.isSet()) {
115+
throw new RuntimeException(
116+
"File already exists: " + file.getAbsolutePath()
117+
+ ". MOA doesn't want to overwrite it.");
118+
}
119+
if (uncompressed.isSet())
120+
writer = new OutputStreamWriter(new FileOutputStream(file));
121+
else
122+
writer = new OutputStreamWriter(new GZIPOutputStream(new FileOutputStream(file)));
123+
} catch (Exception e) {
124+
throw new RuntimeException(e);
125+
}
126+
}
127+
128+
private void writeHeader(int numClasses) throws IOException {
129+
writer.write("true_class,class_prediction,");
130+
if (probabilities.isSet()) {
131+
for (int i = 0; i < numClasses; i++) {
132+
writer.write(String.format("class_probability_%d,", i));
133+
}
134+
}
135+
writer.write("\n");
136+
}
137+
138+
@Override
139+
public void close() throws Exception {
140+
writer.close();
141+
}
142+
143+
@Override
144+
public void reset() {
145+
wrappedEvaluator.reset();
146+
}
147+
148+
@Override
149+
public Measurement[] getPerformanceMeasurements() {
150+
return wrappedEvaluator.getPerformanceMeasurements();
151+
}
152+
153+
@Override
154+
public void getDescription(StringBuilder sb, int indent) {
155+
sb.append(getPurposeString());
156+
}
157+
158+
@Override
159+
public ImmutableCapabilities defineImmutableCapabilities() {
160+
return new ImmutableCapabilities(Capability.VIEW_STANDARD);
161+
}
162+
}

moa/src/main/java/moa/tasks/EvaluateInterleavedChunks.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,11 @@ protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) {
287287
if (immediateResultStream != null) {
288288
immediateResultStream.close();
289289
}
290+
try {
291+
evaluator.close();
292+
} catch (Exception ex) {
293+
throw new RuntimeException("Exception closing evaluator", ex);
294+
}
290295
return learningCurve;
291296
}
292297

moa/src/main/java/moa/tasks/EvaluateInterleavedTestThenTrain.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525

2626
import moa.capabilities.Capability;
2727
import moa.capabilities.ImmutableCapabilities;
28-
import moa.classifiers.Classifier;
2928
import moa.classifiers.MultiClassClassifier;
3029
import moa.core.Example;
3130
import moa.core.Measurement;
@@ -40,7 +39,6 @@
4039
import com.github.javacliparser.IntOption;
4140
import moa.streams.ExampleStream;
4241
import moa.streams.InstanceStream;
43-
import com.yahoo.labs.samoa.instances.Instance;
4442

4543
/**
4644
* Task for evaluating a classifier on a stream by testing then training with each example in sequence.
@@ -217,6 +215,11 @@ protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) {
217215
if (immediateResultStream != null) {
218216
immediateResultStream.close();
219217
}
218+
try {
219+
evaluator.close();
220+
} catch (Exception ex) {
221+
throw new RuntimeException("Exception closing evaluator", ex);
222+
}
220223
return learningCurve;
221224
}
222225

moa/src/main/java/moa/tasks/EvaluateModel.java

Lines changed: 4 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,6 @@
1919
*/
2020
package moa.tasks;
2121

22-
import java.io.File;
23-
import java.io.FileOutputStream;
24-
import java.io.PrintStream;
2522
import com.github.javacliparser.FileOption;
2623
import com.github.javacliparser.IntOption;
2724
import moa.capabilities.CapabilitiesHandler;
@@ -32,15 +29,13 @@
3229
import moa.core.Example;
3330
import moa.core.Measurement;
3431
import moa.core.ObjectRepository;
35-
import moa.core.Utils;
3632
import moa.evaluation.LearningEvaluation;
3733
import moa.evaluation.LearningPerformanceEvaluator;
3834
import moa.evaluation.preview.LearningCurve;
3935
import moa.learners.Learner;
4036
import moa.options.ClassOption;
4137
import moa.streams.ExampleStream;
4238
import moa.streams.InstanceStream;
43-
import com.yahoo.labs.samoa.instances.Instance;
4439

4540
/**
4641
* Task for evaluating a static model on a stream.
@@ -107,35 +102,10 @@ public Object doMainTask(TaskMonitor monitor, ObjectRepository repository) {
107102
long instancesProcessed = 0;
108103
monitor.setCurrentActivity("Evaluating model...", -1.0);
109104

110-
//File for output predictions
111-
File outputPredictionFile = this.outputPredictionFileOption.getFile();
112-
PrintStream outputPredictionResultStream = null;
113-
if (outputPredictionFile != null) {
114-
try {
115-
if (outputPredictionFile.exists()) {
116-
outputPredictionResultStream = new PrintStream(
117-
new FileOutputStream(outputPredictionFile, true), true);
118-
} else {
119-
outputPredictionResultStream = new PrintStream(
120-
new FileOutputStream(outputPredictionFile), true);
121-
}
122-
} catch (Exception ex) {
123-
throw new RuntimeException(
124-
"Unable to open prediction result file: " + outputPredictionFile, ex);
125-
}
126-
}
127105
while (stream.hasMoreInstances()
128106
&& ((maxInstances < 0) || (instancesProcessed < maxInstances))) {
129107
Example testInst = (Example) stream.nextInstance();//.copy();
130-
int trueClass = (int) ((Instance) testInst.getData()).classValue();
131-
//testInst.setClassMissing();
132108
double[] prediction = model.getVotesForInstance(testInst);
133-
//evaluator.addClassificationAttempt(trueClass, prediction, testInst
134-
// .weight());
135-
if (outputPredictionFile != null) {
136-
outputPredictionResultStream.println(Utils.maxIndex(prediction) + "," +(
137-
((Instance) testInst.getData()).classIsMissing() == true ? " ? " : trueClass));
138-
}
139109
evaluator.addResult(testInst, prediction);
140110
instancesProcessed++;
141111

@@ -169,8 +139,10 @@ public Object doMainTask(TaskMonitor monitor, ObjectRepository repository) {
169139
}
170140
}
171141
}
172-
if (outputPredictionResultStream != null) {
173-
outputPredictionResultStream.close();
142+
try {
143+
evaluator.close();
144+
} catch (Exception ex) {
145+
throw new RuntimeException("Exception closing evaluator", ex);
174146
}
175147
return learningCurve;
176148
}

moa/src/main/java/moa/tasks/EvaluatePeriodicHeldOutTest.java

Lines changed: 5 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
import com.github.javacliparser.IntOption;
3131
import moa.capabilities.Capability;
3232
import moa.capabilities.ImmutableCapabilities;
33-
import moa.classifiers.Classifier;
3433
import moa.classifiers.MultiClassClassifier;
3534
import moa.core.Example;
3635
import moa.core.Measurement;
@@ -140,12 +139,7 @@ protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) {
140139
}
141140
testStream = new CachedInstancesStream(testInstances);
142141
} else {
143-
//testStream = (InstanceStream) stream.copy();
144142
testStream = stream;
145-
/*monitor.setCurrentActivity("Skipping test examples...", -1.0);
146-
for (int i = 0; i < testSize; i++) {
147-
stream.nextInstance();
148-
}*/
149143
}
150144
instancesProcessed = 0;
151145
TimingUtils.enablePreciseTiming();
@@ -191,10 +185,7 @@ protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) {
191185
break;
192186
}
193187
Example testInst = (Example) testStream.nextInstance(); //.copy();
194-
double trueClass = ((Instance) testInst.getData()).classValue();
195-
//testInst.setClassMissing();
196188
double[] prediction = learner.getVotesForInstance(testInst);
197-
//testInst.setClassValue(trueClass);
198189
evaluator.addResult(testInst, prediction);
199190
testInstancesProcessed++;
200191
if (testInstancesProcessed % INSTANCES_BETWEEN_MONITOR_UPDATES == 0) {
@@ -242,49 +233,15 @@ protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) {
242233
if (monitor.resultPreviewRequested()) {
243234
monitor.setLatestResultPreview(learningCurve.copy());
244235
}
245-
// if (learner instanceof HoeffdingTree
246-
// || learner instanceof HoeffdingOptionTree) {
247-
// int numActiveNodes = (int) Measurement.getMeasurementNamed(
248-
// "active learning leaves",
249-
// modelMeasurements).getValue();
250-
// // exit if tree frozen
251-
// if (numActiveNodes < 1) {
252-
// break;
253-
// }
254-
// int numNodes = (int) Measurement.getMeasurementNamed(
255-
// "tree size (nodes)", modelMeasurements)
256-
// .getValue();
257-
// if (numNodes == lastNumNodes) {
258-
// noGrowthCount++;
259-
// } else {
260-
// noGrowthCount = 0;
261-
// }
262-
// lastNumNodes = numNodes;
263-
// } else if (learner instanceof OzaBoost || learner instanceof
264-
// OzaBag) {
265-
// double numActiveNodes = Measurement.getMeasurementNamed(
266-
// "[avg] active learning leaves",
267-
// modelMeasurements).getValue();
268-
// // exit if all trees frozen
269-
// if (numActiveNodes == 0.0) {
270-
// break;
271-
// }
272-
// int numNodes = (int) (Measurement.getMeasurementNamed(
273-
// "[avg] tree size (nodes)",
274-
// learner.getModelMeasurements()).getValue() * Measurement
275-
// .getMeasurementNamed("ensemble size",
276-
// modelMeasurements).getValue());
277-
// if (numNodes == lastNumNodes) {
278-
// noGrowthCount++;
279-
// } else {
280-
// noGrowthCount = 0;
281-
// }
282-
// lastNumNodes = numNodes;
283-
// }
284236
}
285237
if (immediateResultStream != null) {
286238
immediateResultStream.close();
287239
}
240+
try {
241+
evaluator.close();
242+
} catch (Exception ex) {
243+
throw new RuntimeException("Exception closing evaluator", ex);
244+
}
288245
return learningCurve;
289246
}
290247

0 commit comments

Comments
 (0)