Skip to content

Commit 9bd2037

Browse files
authored
Merge pull request #352 from google-deepmind/deepmind
Merge from `deepmind` to `main`.
2 parents b844928 + dff75d8 commit 9bd2037

File tree

12 files changed

+172
-40
lines changed

12 files changed

+172
-40
lines changed

mjpc/agent.cc

Lines changed: 67 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,19 @@ void Agent::Initialize(const mjModel* model) {
153153
planner_threads_ =
154154
std::max(1, NumAvailableHardwareThreads() - 3 - 2 * estimator_threads_);
155155

156+
// differentiable planning model
157+
// by default gradient-based planners use a differentiable model
158+
int gradient_planner = false;
159+
if (planner_ == kGradientPlanner || planner_ == kILQGPlanner ||
160+
planner_ == kILQSPlanner) {
161+
gradient_planner = true;
162+
}
163+
differentiable_ =
164+
GetNumberOrDefault(gradient_planner, model, "agent_differentiable");
165+
jnt_solimp_.resize(model->njnt);
166+
geom_solimp_.resize(model->ngeom);
167+
pair_solimp_.resize(model->npair);
168+
156169
// delete the previous model after all the planners have been updated to use
157170
// the new one.
158171
if (old_model) {
@@ -279,6 +292,22 @@ void Agent::PlanIteration(ThreadPool* pool) {
279292
steps_ =
280293
mju_max(mju_min(horizon_ / timestep_ + 1, kMaxTrajectoryHorizon), 1);
281294

295+
// make model differentiable
296+
int differentiable = differentiable_;
297+
if (differentiable) {
298+
// cache solimp defaults
299+
for (int i = 0; i < model_->njnt; i++) {
300+
jnt_solimp_[i] = model_->jnt_solimp[mjNIMP * i];
301+
}
302+
for (int i = 0; i < model_->ngeom; i++) {
303+
geom_solimp_[i] = model_->geom_solimp[mjNIMP * i];
304+
}
305+
for (int i = 0; i < model_->npair; i++) {
306+
pair_solimp_[i] = model_->pair_solimp[mjNIMP * i];
307+
}
308+
MakeDifferentiable(model_);
309+
}
310+
282311
// plan
283312
if (!allocate_enabled) {
284313
// set state
@@ -312,6 +341,19 @@ void Agent::PlanIteration(ThreadPool* pool) {
312341
// release the planning residual function
313342
residual_fn_.reset();
314343
}
344+
345+
// restore solimp defaults
346+
if (differentiable) {
347+
for (int i = 0; i < model_->njnt; i++) {
348+
model_->jnt_solimp[mjNIMP * i] = jnt_solimp_[i];
349+
}
350+
for (int i = 0; i < model_->ngeom; i++) {
351+
model_->geom_solimp[mjNIMP * i] = geom_solimp_[i];
352+
}
353+
for (int i = 0; i < model_->npair; i++) {
354+
model_->pair_solimp[mjNIMP * i] = pair_solimp_[i];
355+
}
356+
}
315357
}
316358

317359
// call planner to update nominal policy
@@ -644,21 +686,23 @@ void Agent::GUI(mjUI& ui) {
644686
}
645687

646688
// ----- agent ----- //
647-
mjuiDef defAgent[] = {{mjITEM_SECTION, "Agent", 1, nullptr, "AP"},
648-
{mjITEM_BUTTON, "Reset", 2, nullptr, " #459"},
649-
{mjITEM_SELECT, "Planner", 2, &planner_, ""},
650-
{mjITEM_SELECT, "Estimator", 2, &estimator_, ""},
651-
{mjITEM_CHECKINT, "Plan", 2, &plan_enabled, ""},
652-
{mjITEM_CHECKINT, "Action", 2, &action_enabled, ""},
653-
{mjITEM_CHECKINT, "Plots", 2, &plot_enabled, ""},
654-
{mjITEM_CHECKINT, "Traces", 2, &visualize_enabled, ""},
655-
{mjITEM_SEPARATOR, "Agent Settings", 1},
656-
{mjITEM_SLIDERNUM, "Horizon", 2, &horizon_, "0 1"},
657-
{mjITEM_SLIDERNUM, "Timestep", 2, &timestep_, "0 1"},
658-
{mjITEM_SELECT, "Integrator", 2, &integrator_,
659-
"Euler\nRK4\nImplicit\nImplicitFast"},
660-
{mjITEM_SEPARATOR, "Planner Settings", 1},
661-
{mjITEM_END}};
689+
mjuiDef defAgent[] = {
690+
{mjITEM_SECTION, "Agent", 1, nullptr, "AP"},
691+
{mjITEM_BUTTON, "Reset", 2, nullptr, " #459"},
692+
{mjITEM_SELECT, "Planner", 2, &planner_, ""},
693+
{mjITEM_SELECT, "Estimator", 2, &estimator_, ""},
694+
{mjITEM_CHECKINT, "Plan", 2, &plan_enabled, ""},
695+
{mjITEM_CHECKINT, "Action", 2, &action_enabled, ""},
696+
{mjITEM_CHECKINT, "Plots", 2, &plot_enabled, ""},
697+
{mjITEM_CHECKINT, "Traces", 2, &visualize_enabled, ""},
698+
{mjITEM_SEPARATOR, "Agent Settings", 1},
699+
{mjITEM_SLIDERNUM, "Horizon", 2, &horizon_, "0 1"},
700+
{mjITEM_SLIDERNUM, "Timestep", 2, &timestep_, "0 1"},
701+
{mjITEM_SELECT, "Integrator", 2, &integrator_,
702+
"Euler\nRK4\nImplicit\nImplicitFast"},
703+
{mjITEM_CHECKINT, "Differentiable", 2, &differentiable_, ""},
704+
{mjITEM_SEPARATOR, "Planner Settings", 1},
705+
{mjITEM_END}};
662706

663707
// planner names
664708
mju::strcpy_arr(defAgent[2].other, planner_names_);
@@ -730,6 +774,14 @@ void Agent::AgentEvent(mjuiItem* it, mjData* data,
730774
this->PlotInitialize();
731775
this->PlotReset();
732776

777+
// by default gradient-based planners use a differentiable model
778+
if (planner_ == kGradientPlanner || planner_ == kILQGPlanner ||
779+
planner_ == kILQSPlanner) {
780+
differentiable_ = true;
781+
} else {
782+
differentiable_ = false;
783+
}
784+
733785
// reset agent
734786
uiloadrequest.fetch_sub(1);
735787
}

mjpc/agent.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,12 @@ class Agent {
247247

248248
// max threads for estimation
249249
int estimator_threads_;
250+
251+
// differentiable planning model
252+
bool differentiable_;
253+
std::vector<double> jnt_solimp_;
254+
std::vector<double> geom_solimp_;
255+
std::vector<double> pair_solimp_;
250256
};
251257

252258
} // namespace mjpc

mjpc/planners/cross_entropy/planner.cc

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include <algorithm>
1818
#include <chrono>
1919
#include <cmath>
20+
#include <mutex>
2021
#include <shared_mutex>
2122

2223
#include <absl/random/random.h>
@@ -54,8 +55,11 @@ void CrossEntropyPlanner::Initialize(mjModel* model, const Task& task) {
5455
// sampling noise
5556
std_initial_ =
5657
GetNumberOrDefault(0.1, model,
57-
"sampling_exploration"); // initial variance
58-
std_min_ = GetNumberOrDefault(0.1, model, "std_min"); // minimum variance
58+
"sampling_exploration"); // initial variance
59+
std_min_ = GetNumberOrDefault(0.01, model, "std_min"); // minimum variance
60+
// fraction of the trajectories that will use full exploration noise
61+
explore_fraction_ =
62+
GetNumberOrDefault(0.0, model, "explore_fraction");
5963

6064
// set number of trajectories to rollout
6165
num_trajectory_ = GetNumberOrDefault(10, model, "sampling_trajectories");
@@ -227,12 +231,13 @@ void CrossEntropyPlanner::OptimizePolicy(int horizon, ThreadPool& pool) {
227231
for (int i = 0; i < n_elite; i++) {
228232
// ordered trajectory index
229233
int idx = trajectory_order[i];
234+
const TimeSpline& elite_plan = candidate_policy[idx].plan;
230235

231236
// add parameters
232-
for (int i = 0; i < num_spline_points; i++) {
233-
TimeSpline::Node n = candidate_policy[idx].plan.NodeAt(i);
237+
for (int t = 0; t < num_spline_points; t++) {
238+
TimeSpline::ConstNode n = elite_plan.NodeAt(t);
234239
for (int j = 0; j < model->nu; j++) {
235-
parameters_scratch[i * model->nu + j] += n.values()[j];
240+
parameters_scratch[t * model->nu + j] += n.values()[j];
236241
}
237242
}
238243

@@ -247,12 +252,15 @@ void CrossEntropyPlanner::OptimizePolicy(int horizon, ThreadPool& pool) {
247252

248253
// loop over elites to compute variance
249254
std::fill(variance.begin(), variance.end(), 0.0); // reset variance to zero
250-
for (int t = 0; t < num_spline_points; t++) {
251-
TimeSpline::Node n = candidate_policy[trajectory_order[0]].plan.NodeAt(t);
252-
for (int j = 0; j < model->nu; j++) {
253-
// average
254-
double p_avg = parameters_scratch[t * model->nu + j];
255-
for (int i = 0; i < n_elite; i++) {
255+
for (int i = 0; i < n_elite; i++) {
256+
int idx = trajectory_order[i];
257+
const TimeSpline& elite_plan = candidate_policy[idx].plan;
258+
for (int t = 0; t < num_spline_points; t++) {
259+
TimeSpline::ConstNode n = elite_plan.NodeAt(t);
260+
for (int j = 0; j < model->nu; j++) {
261+
// average
262+
double p_avg = parameters_scratch[t * model->nu + j];
263+
256264
// candidate parameter
257265
double pi = n.values()[j];
258266
double diff = pi - p_avg;
@@ -263,7 +271,7 @@ void CrossEntropyPlanner::OptimizePolicy(int horizon, ThreadPool& pool) {
263271

264272
// update
265273
{
266-
const std::shared_lock<std::shared_mutex> lock(mtx_);
274+
const std::unique_lock<std::shared_mutex> lock(mtx_);
267275
policy.plan.Clear();
268276
policy.plan.SetInterpolation(interpolation_);
269277
for (int t = 0; t < num_spline_points; t++) {
@@ -384,14 +392,21 @@ void CrossEntropyPlanner::Rollouts(int num_trajectory, int horizon,
384392

385393
// lock std_min
386394
double std_min = std_min_;
395+
double std_initial = std_initial_;
387396

388397
// random search
389398
int count_before = pool.GetCount();
390399
for (int i = 0; i < num_trajectory; i++) {
400+
double std;
401+
if (i < num_trajectory * explore_fraction_) {
402+
std = std_initial;
403+
} else {
404+
std = std_min;
405+
}
391406
pool.Schedule([&s = *this, &model = this->model, &task = this->task,
392407
&state = this->state, &time = this->time,
393408
&mocap = this->mocap, &userdata = this->userdata, horizon,
394-
std_min, i]() {
409+
std, i]() {
395410
// copy nominal policy and sample noise
396411
{
397412
const std::shared_lock<std::shared_mutex> lock(s.mtx_);
@@ -401,7 +416,7 @@ void CrossEntropyPlanner::Rollouts(int num_trajectory, int horizon,
401416
s.resampled_policy.plan.Interpolation());
402417

403418
// sample noise
404-
s.AddNoiseToPolicy(i, std_min);
419+
s.AddNoiseToPolicy(i, std);
405420
}
406421

407422
// ----- rollout sample policy ----- //
@@ -486,6 +501,7 @@ void CrossEntropyPlanner::GUI(mjUI& ui) {
486501
{mjITEM_SLIDERINT, "Spline Pts", 2, &policy.num_spline_points, "0 1"},
487502
{mjITEM_SLIDERNUM, "Init. Std", 2, &std_initial_, "0 1"},
488503
{mjITEM_SLIDERNUM, "Min. Std", 2, &std_min_, "0.01 0.5"},
504+
{mjITEM_SLIDERNUM, "Explore", 2, &explore_fraction_, "0.0 1.0"},
489505
{mjITEM_SLIDERINT, "Elite", 2, &n_elite_, "2 128"},
490506
{mjITEM_END}};
491507

mjpc/planners/cross_entropy/planner.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,8 @@ class CrossEntropyPlanner : public Planner {
122122
double std_initial_; // standard deviation for sampling normal: N(0,
123123
// std)
124124
double std_min_; // the minimum allowable std
125+
double explore_fraction_ = 0; // fraction of trajectories that will use
126+
// std_initial instead of the variance from CEM
125127
std::vector<double> noise;
126128
std::vector<double> variance;
127129

mjpc/planners/ilqg/planner.cc

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -584,7 +584,19 @@ void iLQGPlanner::Iteration(int horizon, ThreadPool& pool) {
584584
std::cout << " dV: " << expected << '\n';
585585
std::cout << " dV[0]: " << backward_pass.dV[0] << '\n';
586586
std::cout << " dV[1]: " << backward_pass.dV[1] << '\n';
587-
std::cout << std::endl;
587+
588+
std::cout << "\niLQG Timing (ms)\n" << '\n';
589+
std::cout << " nominal: " << nominal_compute_time * 1.0e-3 << '\n';
590+
std::cout << " model derivative: "
591+
<< model_derivative_compute_time * 1.0e-3 << '\n';
592+
std::cout << " cost derivative: " << cost_derivative_compute_time * 1.0e-3
593+
<< '\n';
594+
std::cout << " backward pass: " << backward_pass_compute_time * 1.0e-3
595+
<< '\n';
596+
std::cout << " rollouts: " << rollouts_compute_time * 1.0e-3 << '\n';
597+
std::cout << " policy update: " << policy_update_compute_time * 1.0e-3
598+
<< '\n';
599+
std::cout << "\n\n";
588600
}
589601

590602
// stop timer

mjpc/planners/include.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,17 @@
2222

2323
namespace mjpc {
2424

25+
// planner types
26+
enum PlannerType : int {
27+
kSamplingPlanner = 0,
28+
kGradientPlanner,
29+
kILQGPlanner,
30+
kILQSPlanner,
31+
kRobustPlanner,
32+
kCrossEntropyPlanner,
33+
kSampleGradientPlanner,
34+
};
35+
2536
// Planner names, separated by '\n'.
2637
extern const char kPlannerNames[];
2738

mjpc/tasks/quadruped/quadruped.cc

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -182,16 +182,21 @@ void QuadrupedFlat::ResidualFn::Residual(const mjModel* model,
182182
if (current_mode_ == kModeBiped) {
183183
// loosen the "hands" in Biped mode
184184
bool handstand = ReinterpretAsInt(parameters_[biped_type_param_id_]);
185+
double arm_posture = parameters_[arm_posture_param_id_];
185186
if (handstand) {
186-
residual[counter + 4] *= 0.03;
187-
residual[counter + 5] *= 0.03;
188-
residual[counter + 10] *= 0.03;
189-
residual[counter + 11] *= 0.03;
187+
residual[counter + 6] *= arm_posture;
188+
residual[counter + 7] *= arm_posture;
189+
residual[counter + 8] *= arm_posture;
190+
residual[counter + 9] *= arm_posture;
191+
residual[counter + 10] *= arm_posture;
192+
residual[counter + 11] *= arm_posture;
190193
} else {
191-
residual[counter + 1] *= 0.03;
192-
residual[counter + 2] *= 0.03;
193-
residual[counter + 7] *= 0.03;
194-
residual[counter + 8] *= 0.03;
194+
residual[counter + 0] *= arm_posture;
195+
residual[counter + 1] *= arm_posture;
196+
residual[counter + 2] *= arm_posture;
197+
residual[counter + 3] *= arm_posture;
198+
residual[counter + 4] *= arm_posture;
199+
residual[counter + 5] *= arm_posture;
195200
}
196201
}
197202
counter += model->nu;
@@ -521,6 +526,7 @@ void QuadrupedFlat::ResetLocked(const mjModel* model) {
521526
residual_.cadence_param_id_ = ParameterIndex(model, "Cadence");
522527
residual_.amplitude_param_id_ = ParameterIndex(model, "Amplitude");
523528
residual_.duty_param_id_ = ParameterIndex(model, "Duty ratio");
529+
residual_.arm_posture_param_id_ = ParameterIndex(model, "Arm posture");
524530
residual_.balance_cost_id_ = CostTermByName(model, "Balance");
525531
residual_.upright_cost_id_ = CostTermByName(model, "Upright");
526532
residual_.height_cost_id_ = CostTermByName(model, "Height");

mjpc/tasks/quadruped/quadruped.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,7 @@ class QuadrupedFlat : public Task {
203203
int cadence_param_id_ = -1;
204204
int amplitude_param_id_ = -1;
205205
int duty_param_id_ = -1;
206+
int arm_posture_param_id_ = -1;
206207
int upright_cost_id_ = -1;
207208
int balance_cost_id_ = -1;
208209
int height_cost_id_ = -1;

mjpc/tasks/quadruped/task_flat.xml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
<numeric name="residual_select_Biped type" data="0"/>
3030
<text name="residual_list_Biped type" data="Foot Stand|Hand Stand"/>
3131
<numeric name="residual_Heading" data="0 -3.14 3.14" />
32+
<numeric name="residual_Arm posture" data=".03 0 1"/>
3233

3334
<!-- estimator -->
3435
<numeric name="estimator" data="0" />

mjpc/tasks/shadow_reorient/task.xml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,17 @@
44
<size memory="1M"/>
55

66
<custom>
7-
<numeric name="agent_planner" data="0" />
7+
<numeric name="agent_planner" data="5" />
88
<numeric name="agent_horizon" data="0.25" />
99
<numeric name="agent_timestep" data="0.01" />
1010
<numeric name="agent_policy_width" data="0.0035" />
1111
<numeric name="sampling_spline_points" data="5" />
12-
<numeric name="sampling_exploration" data="0.1" />
12+
<numeric name="sampling_exploration" data="0.2" />
1313
<numeric name="sampling_representation" data="0" />
14+
<numeric name="sampling_trajectories" data="60" />
15+
<numeric name="n_elite" data="8" />
16+
<numeric name="explore_fraction" data="0.5" />
17+
1418
<numeric name="robust_xfrc" data="0.004" />
1519
</custom>
1620

0 commit comments

Comments
 (0)