@@ -83,17 +83,18 @@ namespace smt {
8383 return r;
8484 }
8585
86- std::pair<parallel::param_generator::param_values, bool > parallel::param_generator::replay_proof_prefixes (unsigned max_conflicts_epsilon=200 ) {
86+ void parallel::param_generator::replay_proof_prefixes (unsigned max_conflicts_epsilon=200 ) {
8787 unsigned conflict_budget = m_max_prefix_conflicts + max_conflicts_epsilon;
8888 param_values best_param_state;
89- double best_score;
89+ double best_score = 0 ;
9090 bool found_better_params = false ;
9191
92- for (unsigned i = 0 ; i < N; ++i) {
92+ for (unsigned i = 0 ; i <= N; ++i) {
9393 IF_VERBOSE (1 , verbose_stream () << " PARAM TUNER: replaying proof prefix in param probe context " << i << " \n " );
9494
9595 // copy prefix solver context to a new probe_ctx for next replay with candidate mutation
96- scoped_ptr<context> probe_ctx = alloc (context, m, ctx->get_fparams (), m_p);
96+ smt_params smtp (m_p);
97+ scoped_ptr<context> probe_ctx = alloc (context, m, smtp, m_p);
9798 context::copy (*ctx, *probe_ctx, true );
9899
99100 // apply a candidate (mutated) param state to probe_ctx
@@ -110,24 +111,37 @@ namespace smt {
110111
111112 // replay the cube (negation of the clause)
112113 for (expr_ref_vector const & cube : probe_ctx->m_recorded_cubes ) {
113- lbool r = probe_ctx->check (cube.size (), cube.data ());
114-
114+ lbool r = probe_ctx->check (cube.size (), cube.data ());
115115 unsigned conflicts = probe_ctx->m_stats .m_num_conflicts ;
116116 unsigned decisions = probe_ctx->m_stats .m_num_decisions ;
117-
117+ IF_VERBOSE (1 , verbose_stream () << " PARAM TUNER " << i << " : cube replay result " << r <<
118+ " , conflicts = " << conflicts << " , decisions = " << decisions << " \n " );
118119 score += conflicts + decisions;
119120 }
120121
121- if (i > 0 && score < best_score) {
122- found_better_params = true ;
123- best_param_state = mutated_param_state;
122+ if (i == 0 ) {
124123 best_score = score;
125- } else {
124+ IF_VERBOSE (1 , verbose_stream () << " PARAM TUNER: baseline score = " << best_score << " \n " );
125+ }
126+ else if (score < best_score) {
127+ found_better_params = true ;
128+ best_param_state = mutated_param_state;
126129 best_score = score;
127130 }
128131 }
129-
130- return {best_param_state, found_better_params};
132+ // NOTE: we either need to apply the best params found that are better than base line
133+ // or, we have to implement a comparison operator for param_values (what would this do?)
134+ // or, we update the param state every single time even if it hasn't changed (what would this do?)
135+ // for now, I went with option 1
136+ if (found_better_params) {
137+ m_param_state = best_param_state;
138+ auto p = apply_param_values (m_param_state);
139+ b.set_param_state (p);
140+ IF_VERBOSE (1 , verbose_stream () << " PARAM TUNER found better param state\n " );
141+ }
142+ else {
143+ IF_VERBOSE (1 , verbose_stream () << " PARAM TUNER retained current param state\n " );
144+ }
131145 }
132146
133147 void parallel::param_generator::init_param_state () {
@@ -144,9 +158,23 @@ namespace smt {
144158 m_param_state.push_back (
145159 {symbol (" smt.arith.nl.propagate_linear_monomials" ), smtp.arith_nl_propagate_linear_monomials ()});
146160 m_param_state.push_back ({symbol (" smt.arith.nl.tangents" ), smtp.arith_nl_tangents ()});
147-
148161 };
149162
163+ params_ref parallel::param_generator::apply_param_values (param_values const &pv) {
164+ params_ref p = m_p;
165+ for (auto const &[k, v] : pv) {
166+ if (std::holds_alternative<unsigned_value>(v)) {
167+ unsigned_value uv = std::get<unsigned_value>(v);
168+ p.set_uint (k, uv.value );
169+ }
170+ else if (std::holds_alternative<bool >(v)) {
171+ bool bv = std::get<bool >(v);
172+ p.set_bool (k, bv);
173+ }
174+ }
175+ return p;
176+ }
177+
150178 parallel::param_generator::param_values parallel::param_generator::mutate_param_state () {
151179 param_values new_param_values (m_param_state);
152180 unsigned index = ctx->get_random_value () % new_param_values.size ();
@@ -161,7 +189,7 @@ namespace smt {
161189 while (new_value == value) {
162190 new_value = lo + ctx->get_random_value () % (hi - lo + 1 );
163191 }
164- std::get <unsigned_value>(param.second ). value = new_value;
192+ std::get_if <unsigned_value>(& param.second )-> value = new_value;
165193 }
166194 return new_param_values;
167195 }
@@ -174,20 +202,7 @@ namespace smt {
174202
175203 switch (r) {
176204 case l_undef: {
177- auto [best_param_state, found_better_params] = replay_proof_prefixes ();
178-
179- // NOTE: we either need to return a pair from replay_proof_prefixes so we can return a boolean flag indicating whether better params were found.
180- // or, we have to implement a comparison operator for param_values
181- // or, we update the param state every single time even if it hasn't changed
182- // for now, I went with option 1
183- if (found_better_params) {
184- m_param_state = best_param_state;
185- auto p = apply_param_values (m_param_state);
186- b.set_param_state (p);
187- IF_VERBOSE (1 , verbose_stream () << " PARAM TUNER found better param state\n " );
188- } else {
189- IF_VERBOSE (1 , verbose_stream () << " PARAM TUNER retained current param state\n " );
190- }
205+ replay_proof_prefixes ();
191206 }
192207 case l_true: {
193208 IF_VERBOSE (1 , verbose_stream () << " PARAM TUNER found formula sat\n " );
@@ -222,9 +237,8 @@ namespace smt {
222237 LOG_WORKER (1 , " CUBE SIZE IN MAIN LOOP: " << cube.size () << " \n " );
223238
224239 // apply current best param state from batch manager
225- smt_params params = b.get_best_param_state ();
226240 params_ref p;
227- params. updt_params (p);
241+ b. get_param_state (p);
228242 ctx->updt_params (p);
229243
230244 lbool r = check_cube (cube);
@@ -454,10 +468,9 @@ namespace smt {
454468 }
455469 }
456470
457- // todo make this thread safe by not using reference counts implicit in params ref but instead copying the entire structure.
458- params_ref parallel::batch_manager::get_best_param_state () {
471+ void parallel::batch_manager::get_param_state (params_ref& p) {
459472 std::scoped_lock lock (mux);
460- return m_param_state;
473+ p. copy ( m_param_state) ;
461474 }
462475
463476 void parallel::worker::collect_shared_clauses () {
0 commit comments