Skip to content

Commit 2967e6e

Browse files
authored
Merge pull request #298 from gAldeia/bug_fixes
Bug fixes
2 parents a40a8a4 + 4b11485 commit 2967e6e

File tree

11 files changed

+86
-16
lines changed

11 files changed

+86
-16
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,7 @@ jobs:
2323
# cache-env: true
2424
-
2525
name: add docs environment dependencies
26-
uses: mamba-org/provision-with-micromamba@main
26+
uses: mamba-org/setup-micromamba@v1
2727
with:
2828
environment-file: environment.yml
2929
cache-env: true

ci/ci-environment.yml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@ dependencies:
2222
- sphinx-material
2323
- recommonmark
2424
- nbsphinx
25+
- lxml_html_clean
2526
- matplotlib
2627
- jupyter
2728
- seaborn

docs/conf.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -56,6 +56,7 @@
5656
'sphinx_math_dollar',
5757
# 'recommonmark',
5858
'nbsphinx',
59+
'lxml_html_clean',
5960
# "sphinx.ext.viewcode",
6061
# External stuff
6162
]

docs/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -10,3 +10,4 @@ myst-parser
1010
nbsphinx
1111
sphinx-material
1212
sphinx-math-dollar
13+
lxml_html_clean

feat/feat.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -184,7 +184,7 @@ def __init__(self,
184184
softmax_norm=False,
185185
save_pop=0,
186186
normalize=True,
187-
val_from_arch=True,
187+
val_from_arch=False,
188188
corr_delete_mutate=False,
189189
simplify=0.0,
190190
protected_groups="",
@@ -316,8 +316,8 @@ def predict_archive(self,X,Z=None,front=False):
316316
archive = self.cfeat_.get_archive(front)
317317
preds = []
318318
for ind in archive:
319-
if ind['id'] == 9234:
320-
print('individual:',json.dumps(ind,indent=2))
319+
# if ind['id'] == 9234:
320+
# print('individual:',json.dumps(ind,indent=2))
321321
tmp = {}
322322
tmp['id'] = ind['id']
323323
tmp['y_pred'] = self.cfeat_.predict_archive(ind['id'], X)
@@ -399,6 +399,7 @@ def get_representation(self): return self.cfeat_.get_representation()
399399
def get_model(self, sort=True): return self.cfeat_.get_model(sort)
400400
def get_coefs(self): return self.cfeat_.get_coefs()
401401
def get_n_params(self): return self.cfeat_.get_n_params()
402+
def get_complexity(self): return self.cfeat_.get_complexity()
402403
def get_dim(self): return self.cfeat_.get_dim()
403404
def get_n_nodes(self): return self.cfeat_.get_n_nodes()
404405

@@ -432,7 +433,7 @@ def fit(self,X,y,zfile=None,zids=None):
432433
])):
433434
raise ValueError('y must be a contiguous set of labels from ',
434435
'0 to n_classes. y contains the values {}'.format(
435-
np.unique(np.asarray(y)))
436+
self.classes_)
436437
)
437438

438439
super().fit(X,y)

src/eval/metrics.cc

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -362,6 +362,7 @@ namespace FT
362362

363363
return loss;
364364
}
365+
365366
/// 1 - balanced accuracy
366367
float bal_zero_one_loss(const VectorXf& y, const VectorXf& yhat,
367368
VectorXf& loss, const vector<float>& class_weights)
@@ -406,6 +407,7 @@ namespace FT
406407
// set loss vectors if third argument supplied
407408
loss = (yhat.cast<int>().array() != y.cast<int>().array()).cast<float>();
408409

410+
// 1 - accuracy (so it becomes a minimization problem)
409411
return 1.0 - class_accuracies.mean();
410412
}
411413

@@ -435,7 +437,11 @@ namespace FT
435437
float zero_one_loss(const VectorXf& y, const VectorXf& yhat, VectorXf& loss,
436438
const vector<float>& class_weights)
437439
{
440+
// Feat's update_best and sel/surv steps always handles scores as
441+
// minimization problems, so we need to invert the loss here. That's
442+
// why we account for mismatches instead of correct classifications:
438443
loss = (yhat.cast<int>().array() != y.cast<int>().array()).cast<float>();
444+
439445
//TODO: weight loss by sample weights
440446
return loss.mean();
441447
}

src/feat.cc

Lines changed: 14 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -567,8 +567,12 @@ int Feat::get_n_params(){ return best_ind.get_n_params(); }
567567
int Feat::get_dim(){ return best_ind.get_dim(); }
568568

569569
///get dimensionality of best
570-
int Feat::get_complexity(){ return best_ind.get_complexity(); }
571-
570+
int Feat::get_complexity(){
571+
// Making sure it is calculated before returning it
572+
if (best_ind.get_complexity()==0)
573+
best_ind.set_complexity();
574+
return best_ind.get_complexity();
575+
}
572576

573577
/// return the number of nodes in the best model
574578
int Feat::get_n_nodes(){ return best_ind.program.size(); }
@@ -707,15 +711,14 @@ void Feat::run_generation(unsigned int g,
707711
pop.update(survivors);
708712
logger.log("survivors:\n" + pop.print_eqns(), 3);
709713

714+
// we need to update best, so min_loss_v is updated inside stats
710715
logger.log("update best...",2);
711716
bool updated_best = update_best(d);
712717

713-
logger.log("calculate stats...",2);
714-
calculate_stats(d);
715-
716718
if (params.max_stall > 0)
717719
update_stall_count(stall_count, updated_best);
718720

721+
logger.log("update objectives...",2);
719722
if ( (use_arch || params.verbosity>1) || !logfile.empty()) {
720723
// set objectives to make sure they are reported in log/verbose/arch
721724
#pragma omp parallel for
@@ -727,6 +730,9 @@ void Feat::run_generation(unsigned int g,
727730
if (use_arch)
728731
archive.update(pop,params);
729732

733+
logger.log("calculate stats...",2);
734+
calculate_stats(d);
735+
730736
if(params.verbosity>1)
731737
print_stats(log, fraction);
732738
else if(params.verbosity == 1)
@@ -1293,7 +1299,7 @@ ArrayXXf Feat::predict_proba(MatrixXf& X)
12931299
}
12941300

12951301

1296-
bool Feat::update_best(const DataRef& d, bool validation)
1302+
bool Feat::update_best(const DataRef& d, bool val)
12971303
{
12981304
float bs;
12991305
bs = this->min_loss_v;
@@ -1463,7 +1469,7 @@ void Feat::print_stats(std::ofstream& log, float fraction)
14631469
<< stats.min_loss.back() << " ("
14641470
<< stats.med_loss.back() << ")\n"
14651471
<< "Val Loss (Med): "
1466-
<< this->min_loss_v << " (" << stats.med_loss_v.back() << ")\n"
1472+
<< stats.min_loss_v.back() << " (" << stats.med_loss_v.back() << ")\n"
14671473
<< "Median Size (Max): "
14681474
<< stats.med_size.back() << " (" << max_size << ")\n"
14691475
<< "Time (s): " << timer << "\n";
@@ -1553,7 +1559,7 @@ void Feat::log_stats(std::ofstream& log)
15531559
log << params.current_gen << sep
15541560
<< timer.Elapsed().count() << sep
15551561
<< stats.min_loss.back() << sep
1556-
<< this->min_loss_v << sep
1562+
<< stats.min_loss_v.back() << sep
15571563
<< stats.med_loss.back() << sep
15581564
<< stats.med_loss_v.back() << sep
15591565
<< stats.med_size.back() << sep

src/feat.h

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -93,7 +93,7 @@ class Feat
9393
// string logfile="", int max_time=-1, bool residual_xo = false,
9494
// bool stagewise_xo = false, bool stagewise_tol = true,
9595
// bool softmax_norm=false, int save_pop=0, bool normalize=true,
96-
// bool val_from_arch=true, bool corr_delete_mutate=false,
96+
// bool val_from_arch=false, bool corr_delete_mutate=false,
9797
// float simplify=0.0, string protected_groups="",
9898
// bool tune_initial=false, bool tune_final=true,
9999
// string starting_pop="");
@@ -325,7 +325,7 @@ class Feat
325325
int get_n_params();
326326
///get dimensionality of best
327327
int get_dim();
328-
///get dimensionality of best
328+
///get complexity of best
329329
int get_complexity();
330330
///return population as string
331331
vector<nl::json> get_archive(bool front);

src/pybind.cc

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -150,6 +150,7 @@ PYBIND11_MODULE(_feat, m)
150150
.def("load", &Feat::load)
151151
.def("get_representation", &Feat::get_representation)
152152
.def("get_n_params", &Feat::get_n_params)
153+
.def("get_complexity", &Feat::get_complexity)
153154
.def("get_dim", &Feat::get_dim)
154155
.def("get_n_nodes", &Feat::get_n_nodes)
155156
.def("get_model", &Feat::get_model, py::arg("sort") = true)

tests/evaluationTests.cc

Lines changed: 52 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -53,6 +53,58 @@ TEST(Evaluation, mse)
5353
ASSERT_TRUE(score == 28.5);
5454
}
5555

56+
TEST(Evaluation, accuracy)
57+
{
58+
// test zero one loss
59+
60+
Feat ft = make_estimator(100, 10, "LinearRidgeRegression", false, 1, 666);
61+
62+
VectorXf yhat(10), y(10), res(10), loss(10);
63+
64+
y << 0.0,
65+
1.0,
66+
0.0,
67+
0.0,
68+
1.0,
69+
0.0,
70+
0.0,
71+
1.0,
72+
0.0,
73+
0.0;
74+
75+
yhat << 0.0,
76+
1.0,
77+
1.0,
78+
0.0,
79+
0.0,
80+
1.0,
81+
1.0,
82+
0.0,
83+
0.0,
84+
0.0;
85+
86+
res << 0.0,
87+
0.0,
88+
1.0,
89+
0.0,
90+
1.0,
91+
1.0,
92+
1.0,
93+
1.0,
94+
0.0,
95+
0.0;
96+
97+
float score = zero_one_loss(y, yhat, loss, ft.params.class_weights);
98+
99+
if (loss != res)
100+
{
101+
std::cout << "loss:" << loss.transpose() << "\n";
102+
std::cout << "res:" << res.transpose() << "\n";
103+
}
104+
ASSERT_TRUE(loss == res);
105+
ASSERT_EQ(((int)(score*1000000)), 500000);
106+
}
107+
56108
TEST(Evaluation, bal_accuracy)
57109
{
58110
// test balanced zero one loss

0 commit comments

Comments (0)