Skip to content

Commit f0d03e9

Browse files
neatify
Signed-off-by: Nikolaj Bjorner <[email protected]>
1 parent 605a474 commit f0d03e9

File tree

2 files changed

+60
-63
lines changed

2 files changed

+60
-63
lines changed

src/smt/smt_parallel.cpp

Lines changed: 47 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -83,17 +83,18 @@ namespace smt {
8383
return r;
8484
}
8585

86-
std::pair<parallel::param_generator::param_values, bool> parallel::param_generator::replay_proof_prefixes(unsigned max_conflicts_epsilon=200) {
86+
void parallel::param_generator::replay_proof_prefixes(unsigned max_conflicts_epsilon=200) {
8787
unsigned conflict_budget = m_max_prefix_conflicts + max_conflicts_epsilon;
8888
param_values best_param_state;
89-
double best_score;
89+
double best_score = 0;
9090
bool found_better_params = false;
9191

92-
for (unsigned i = 0; i < N; ++i) {
92+
for (unsigned i = 0; i <= N; ++i) {
9393
IF_VERBOSE(1, verbose_stream() << " PARAM TUNER: replaying proof prefix in param probe context " << i << "\n");
9494

9595
// copy prefix solver context to a new probe_ctx for next replay with candidate mutation
96-
scoped_ptr<context> probe_ctx = alloc(context, m, ctx->get_fparams(), m_p);
96+
smt_params smtp(m_p);
97+
scoped_ptr<context> probe_ctx = alloc(context, m, smtp, m_p);
9798
context::copy(*ctx, *probe_ctx, true);
9899

99100
// apply a candidate (mutated) param state to probe_ctx
@@ -110,24 +111,37 @@ namespace smt {
110111

111112
// replay the cube (negation of the clause)
112113
for (expr_ref_vector const& cube : probe_ctx->m_recorded_cubes) {
113-
lbool r = probe_ctx->check(cube.size(), cube.data());
114-
114+
lbool r = probe_ctx->check(cube.size(), cube.data());
115115
unsigned conflicts = probe_ctx->m_stats.m_num_conflicts;
116116
unsigned decisions = probe_ctx->m_stats.m_num_decisions;
117-
117+
IF_VERBOSE(1, verbose_stream() << " PARAM TUNER " << i << ": cube replay result " << r <<
118+
", conflicts = " << conflicts << ", decisions = " << decisions << "\n");
118119
score += conflicts + decisions;
119120
}
120121

121-
if (i > 0 && score < best_score) {
122-
found_better_params = true;
123-
best_param_state = mutated_param_state;
122+
if (i == 0) {
124123
best_score = score;
125-
} else {
124+
IF_VERBOSE(1, verbose_stream() << " PARAM TUNER: baseline score = " << best_score << "\n");
125+
}
126+
else if (score < best_score) {
127+
found_better_params = true;
128+
best_param_state = mutated_param_state;
126129
best_score = score;
127130
}
128131
}
129-
130-
return {best_param_state, found_better_params};
132+
// NOTE: we either need to apply the best params found that are better than base line
133+
// or, we have to implement a comparison operator for param_values (what would this do?)
134+
// or, we update the param state every single time even if it hasn't changed (what would this do?)
135+
// for now, I went with option 1
136+
if (found_better_params) {
137+
m_param_state = best_param_state;
138+
auto p = apply_param_values(m_param_state);
139+
b.set_param_state(p);
140+
IF_VERBOSE(1, verbose_stream() << " PARAM TUNER found better param state\n");
141+
}
142+
else {
143+
IF_VERBOSE(1, verbose_stream() << " PARAM TUNER retained current param state\n");
144+
}
131145
}
132146

133147
void parallel::param_generator::init_param_state() {
@@ -144,9 +158,23 @@ namespace smt {
144158
m_param_state.push_back(
145159
{symbol("smt.arith.nl.propagate_linear_monomials"), smtp.arith_nl_propagate_linear_monomials()});
146160
m_param_state.push_back({symbol("smt.arith.nl.tangents"), smtp.arith_nl_tangents()});
147-
148161
};
149162

163+
params_ref parallel::param_generator::apply_param_values(param_values const &pv) {
164+
params_ref p = m_p;
165+
for (auto const &[k, v] : pv) {
166+
if (std::holds_alternative<unsigned_value>(v)) {
167+
unsigned_value uv = std::get<unsigned_value>(v);
168+
p.set_uint(k, uv.value);
169+
}
170+
else if (std::holds_alternative<bool>(v)) {
171+
bool bv = std::get<bool>(v);
172+
p.set_bool(k, bv);
173+
}
174+
}
175+
return p;
176+
}
177+
150178
parallel::param_generator::param_values parallel::param_generator::mutate_param_state() {
151179
param_values new_param_values(m_param_state);
152180
unsigned index = ctx->get_random_value() % new_param_values.size();
@@ -161,7 +189,7 @@ namespace smt {
161189
while (new_value == value) {
162190
new_value = lo + ctx->get_random_value() % (hi - lo + 1);
163191
}
164-
std::get<unsigned_value>(param.second).value = new_value;
192+
std::get_if<unsigned_value>(&param.second)->value = new_value;
165193
}
166194
return new_param_values;
167195
}
@@ -174,20 +202,7 @@ namespace smt {
174202

175203
switch (r) {
176204
case l_undef: {
177-
auto [best_param_state, found_better_params] = replay_proof_prefixes();
178-
179-
// NOTE: we either need to return a pair from replay_proof_prefixes so we can return a boolean flag indicating whether better params were found.
180-
// or, we have to implement a comparison operator for param_values
181-
// or, we update the param state every single time even if it hasn't changed
182-
// for now, I went with option 1
183-
if (found_better_params) {
184-
m_param_state = best_param_state;
185-
auto p = apply_param_values(m_param_state);
186-
b.set_param_state(p);
187-
IF_VERBOSE(1, verbose_stream() << " PARAM TUNER found better param state\n");
188-
} else {
189-
IF_VERBOSE(1, verbose_stream() << " PARAM TUNER retained current param state\n");
190-
}
205+
replay_proof_prefixes();
191206
}
192207
case l_true: {
193208
IF_VERBOSE(1, verbose_stream() << " PARAM TUNER found formula sat\n");
@@ -222,9 +237,8 @@ namespace smt {
222237
LOG_WORKER(1, " CUBE SIZE IN MAIN LOOP: " << cube.size() << "\n");
223238

224239
// apply current best param state from batch manager
225-
smt_params params = b.get_best_param_state();
226240
params_ref p;
227-
params.updt_params(p);
241+
b.get_param_state(p);
228242
ctx->updt_params(p);
229243

230244
lbool r = check_cube(cube);
@@ -454,10 +468,9 @@ namespace smt {
454468
}
455469
}
456470

457-
// todo make this thread safe by not using reference counts implicit in params ref but instead copying the entire structure.
458-
params_ref parallel::batch_manager::get_best_param_state() {
471+
void parallel::batch_manager::get_param_state(params_ref& p) {
459472
std::scoped_lock lock(mux);
460-
return m_param_state;
473+
p.copy(m_param_state);
461474
}
462475

463476
void parallel::worker::collect_shared_clauses() {

src/smt/smt_parallel.h

Lines changed: 13 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -89,9 +89,9 @@ namespace smt {
8989
void set_exception(std::string const& msg);
9090
void set_exception(unsigned error_code);
9191
void set_param_state(params_ref const& p) { m_param_state.copy(p); }
92-
void collect_statistics(::statistics& st) const;
93-
94-
params_ref get_best_param_state();
92+
void get_param_state(params_ref &p);
93+
void collect_statistics(::statistics& st) const;
94+
9595
bool get_cube(ast_translation& g2l, unsigned id, expr_ref_vector& cube, node*& n);
9696
void backtrack(ast_translation& l2g, expr_ref_vector const& core, node* n);
9797
void split(ast_translation& l2g, unsigned id, node* n, expr* atom);
@@ -110,6 +110,14 @@ namespace smt {
110110
// 4. update current configuration with the winner
111111

112112
class param_generator {
113+
struct unsigned_value {
114+
unsigned value;
115+
unsigned min_value;
116+
unsigned max_value;
117+
};
118+
using param_value = std::variant<unsigned_value, bool>;
119+
using param_values = vector<std::pair<symbol, param_value>>;
120+
113121
parallel &p;
114122
batch_manager &b;
115123
ast_manager m;
@@ -120,42 +128,18 @@ namespace smt {
120128
unsigned m_max_prefix_conflicts = 1000;
121129

122130
scoped_ptr<context> m_prefix_solver;
123-
scoped_ptr_vector<context> m_param_probe_contexts;
124131
params_ref m_p;
125-
126-
struct unsigned_value {
127-
unsigned value;
128-
unsigned min_value;
129-
unsigned max_value;
130-
};
131-
using param_value = std::variant<unsigned_value, bool>;
132-
using param_values = vector<std::pair<symbol, param_value>>;
133132
param_values m_param_state;
134133

135-
params_ref apply_param_values(param_values const &pv) {
136-
params_ref p = m_p;
137-
for (auto const& [k, v] : pv) {
138-
if (std::holds_alternative<unsigned_value>(v)) {
139-
unsigned_value uv = std::get<unsigned_value>(v);
140-
p.set_uint(k, uv.value);
141-
} else if (std::holds_alternative<bool>(v)) {
142-
bool bv = std::get<bool>(v);
143-
p.set_bool(k, bv);
144-
}
145-
}
146-
return p;
147-
}
148-
149-
private:
134+
params_ref apply_param_values(param_values const &pv);
150135
void init_param_state();
151-
152136
param_values mutate_param_state();
153137

154138
public:
155139
param_generator(parallel &p);
156140
lbool run_prefix_step();
157141
void protocol_iteration();
158-
std::pair<parallel::param_generator::param_values, bool> replay_proof_prefixes(unsigned max_conflicts_epsilon);
142+
void replay_proof_prefixes(unsigned max_conflicts_epsilon);
159143

160144
reslimit &limit() {
161145
return m.limit();

0 commit comments

Comments
 (0)