Skip to content

Commit 3671219

Browse files
mpcnet adaptation
1 parent 66e8d7a commit 3671219

File tree

5 files changed

+29
-31
lines changed

5 files changed

+29
-31
lines changed

ocs2_mpcnet/ocs2_ballbot_mpcnet/src/BallbotMpcnetDummyNode.cpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -69,20 +69,19 @@ int main(int argc, char** argv) {
6969

7070
// policy (MPC-Net controller)
7171
auto onnxEnvironmentPtr = ocs2::mpcnet::createOnnxEnvironment();
72-
std::shared_ptr<ocs2::mpcnet::MpcnetDefinitionBase> mpcnetDefinitionPtr(new BallbotMpcnetDefinition);
73-
std::unique_ptr<ocs2::mpcnet::MpcnetControllerBase> mpcnetControllerPtr(
74-
new ocs2::mpcnet::MpcnetOnnxController(mpcnetDefinitionPtr, rosReferenceManagerPtr, onnxEnvironmentPtr));
72+
auto mpcnetDefinitionPtr = std::make_shared<BallbotMpcnetDefinition>();
73+
auto mpcnetControllerPtr =
74+
std::make_unique<ocs2::mpcnet::MpcnetOnnxController>(mpcnetDefinitionPtr, rosReferenceManagerPtr, onnxEnvironmentPtr);
7575
mpcnetControllerPtr->loadPolicyModel(policyFilePath);
7676

7777
// rollout
7878
std::unique_ptr<RolloutBase> rolloutPtr(ballbotInterface.getRollout().clone());
7979

8080
// observer
81-
std::shared_ptr<ocs2::mpcnet::MpcnetDummyObserverRos> mpcnetDummyObserverRosPtr(
82-
new ocs2::mpcnet::MpcnetDummyObserverRos(nodeHandle, robotName));
81+
auto mpcnetDummyObserverRosPtr = std::make_shared<ocs2::mpcnet::MpcnetDummyObserverRos>(nodeHandle, robotName);
8382

8483
// visualization
85-
std::shared_ptr<BallbotDummyVisualization> ballbotDummyVisualization(new BallbotDummyVisualization(nodeHandle));
84+
auto ballbotDummyVisualization = std::make_shared<BallbotDummyVisualization>(nodeHandle);
8685

8786
// MPC-Net dummy loop ROS
8887
const scalar_t controlFrequency = ballbotInterface.mpcSettings().mrtDesiredFrequency_;

ocs2_mpcnet/ocs2_ballbot_mpcnet/src/BallbotMpcnetInterface.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,10 @@ BallbotMpcnetInterface::BallbotMpcnetInterface(size_t nDataGenerationThreads, si
6161
referenceManagerPtrs.reserve(nDataGenerationThreads + nPolicyEvaluationThreads);
6262
for (int i = 0; i < (nDataGenerationThreads + nPolicyEvaluationThreads); i++) {
6363
BallbotInterface ballbotInterface(taskFile, libraryFolder);
64-
std::shared_ptr<ocs2::mpcnet::MpcnetDefinitionBase> mpcnetDefinitionPtr(new BallbotMpcnetDefinition);
64+
auto mpcnetDefinitionPtr = std::make_shared<BallbotMpcnetDefinition>();
6565
mpcPtrs.push_back(getMpc(ballbotInterface));
66-
mpcnetPtrs.push_back(std::unique_ptr<ocs2::mpcnet::MpcnetControllerBase>(
67-
new ocs2::mpcnet::MpcnetOnnxController(mpcnetDefinitionPtr, ballbotInterface.getReferenceManagerPtr(), onnxEnvironmentPtr)));
66+
mpcnetPtrs.push_back(std::make_unique<ocs2::mpcnet::MpcnetOnnxController>(
67+
mpcnetDefinitionPtr, ballbotInterface.getReferenceManagerPtr(), onnxEnvironmentPtr));
6868
if (raisim) {
6969
throw std::runtime_error("[BallbotMpcnetInterface::BallbotMpcnetInterface] raisim rollout not yet implemented for ballbot.");
7070
} else {
@@ -101,8 +101,8 @@ std::unique_ptr<MPC_BASE> BallbotMpcnetInterface::getMpc(BallbotInterface& ballb
101101
return settings;
102102
}();
103103
// create one MPC instance
104-
std::unique_ptr<MPC_BASE> mpcPtr(new GaussNewtonDDP_MPC(mpcSettings, ddpSettings, ballbotInterface.getRollout(),
105-
ballbotInterface.getOptimalControlProblem(), ballbotInterface.getInitializer()));
104+
auto mpcPtr = std::make_unique<GaussNewtonDDP_MPC>(mpcSettings, ddpSettings, ballbotInterface.getRollout(),
105+
ballbotInterface.getOptimalControlProblem(), ballbotInterface.getInitializer());
106106
mpcPtr->getSolverPtr()->setReferenceManager(ballbotInterface.getReferenceManagerPtr());
107107
return mpcPtr;
108108
}

ocs2_mpcnet/ocs2_legged_robot_mpcnet/src/LeggedRobotMpcnetDummyNode.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,9 @@ int main(int argc, char** argv) {
7979

8080
// policy (MPC-Net controller)
8181
auto onnxEnvironmentPtr = ocs2::mpcnet::createOnnxEnvironment();
82-
std::shared_ptr<ocs2::mpcnet::MpcnetDefinitionBase> mpcnetDefinitionPtr(new LeggedRobotMpcnetDefinition(leggedRobotInterface));
83-
std::unique_ptr<ocs2::mpcnet::MpcnetControllerBase> mpcnetControllerPtr(
84-
new ocs2::mpcnet::MpcnetOnnxController(mpcnetDefinitionPtr, rosReferenceManagerPtr, onnxEnvironmentPtr));
82+
auto mpcnetDefinitionPtr = std::make_shared<LeggedRobotMpcnetDefinition>(leggedRobotInterface);
83+
auto mpcnetControllerPtr =
84+
std::make_unique<ocs2::mpcnet::MpcnetOnnxController>(mpcnetDefinitionPtr, rosReferenceManagerPtr, onnxEnvironmentPtr);
8585
mpcnetControllerPtr->loadPolicyModel(policyFile);
8686

8787
// rollout
@@ -121,8 +121,7 @@ int main(int argc, char** argv) {
121121
}
122122

123123
// observer
124-
std::shared_ptr<ocs2::mpcnet::MpcnetDummyObserverRos> mpcnetDummyObserverRosPtr(
125-
new ocs2::mpcnet::MpcnetDummyObserverRos(nodeHandle, robotName));
124+
auto mpcnetDummyObserverRosPtr = std::make_shared<ocs2::mpcnet::MpcnetDummyObserverRos>(nodeHandle, robotName);
126125

127126
// visualization
128127
CentroidalModelPinocchioMapping pinocchioMapping(leggedRobotInterface.getCentroidalModelInfo());

ocs2_mpcnet/ocs2_legged_robot_mpcnet/src/LeggedRobotMpcnetInterface.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -64,19 +64,19 @@ LeggedRobotMpcnetInterface::LeggedRobotMpcnetInterface(size_t nDataGenerationThr
6464
mpcnetDefinitionPtrs.reserve(nDataGenerationThreads + nPolicyEvaluationThreads);
6565
referenceManagerPtrs.reserve(nDataGenerationThreads + nPolicyEvaluationThreads);
6666
for (int i = 0; i < (nDataGenerationThreads + nPolicyEvaluationThreads); i++) {
67-
leggedRobotInterfacePtrs_.push_back(std::unique_ptr<LeggedRobotInterface>(new LeggedRobotInterface(taskFile, urdfFile, referenceFile)));
68-
std::shared_ptr<ocs2::mpcnet::MpcnetDefinitionBase> mpcnetDefinitionPtr(new LeggedRobotMpcnetDefinition(*leggedRobotInterfacePtrs_[i]));
67+
leggedRobotInterfacePtrs_.push_back(std::make_unique<LeggedRobotInterface>(taskFile, urdfFile, referenceFile));
68+
auto mpcnetDefinitionPtr = std::make_shared<LeggedRobotMpcnetDefinition>(*leggedRobotInterfacePtrs_[i]);
6969
mpcPtrs.push_back(getMpc(*leggedRobotInterfacePtrs_[i]));
7070
mpcnetPtrs.push_back(std::unique_ptr<ocs2::mpcnet::MpcnetControllerBase>(new ocs2::mpcnet::MpcnetOnnxController(
7171
mpcnetDefinitionPtr, leggedRobotInterfacePtrs_[i]->getReferenceManagerPtr(), onnxEnvironmentPtr)));
7272
if (raisim) {
7373
RaisimRolloutSettings raisimRolloutSettings(raisimFile, "rollout");
7474
raisimRolloutSettings.portNumber_ += i;
75-
leggedRobotRaisimConversionsPtrs_.push_back(std::unique_ptr<LeggedRobotRaisimConversions>(new LeggedRobotRaisimConversions(
75+
leggedRobotRaisimConversionsPtrs_.push_back(std::make_unique<LeggedRobotRaisimConversions>(
7676
leggedRobotInterfacePtrs_[i]->getPinocchioInterface(), leggedRobotInterfacePtrs_[i]->getCentroidalModelInfo(),
77-
leggedRobotInterfacePtrs_[i]->getInitialState())));
77+
leggedRobotInterfacePtrs_[i]->getInitialState()));
7878
leggedRobotRaisimConversionsPtrs_[i]->loadSettings(raisimFile, "rollout", true);
79-
rolloutPtrs.push_back(std::unique_ptr<RolloutBase>(new RaisimRollout(
79+
rolloutPtrs.push_back(std::make_unique<RaisimRollout>(
8080
urdfFile, resourcePath,
8181
[&, i](const vector_t& state, const vector_t& input) {
8282
return leggedRobotRaisimConversionsPtrs_[i]->stateToRaisimGenCoordGenVel(state, input);
@@ -90,7 +90,7 @@ LeggedRobotMpcnetInterface::LeggedRobotMpcnetInterface(size_t nDataGenerationThr
9090
nullptr, raisimRolloutSettings,
9191
[&, i](double time, const vector_t& input, const vector_t& state, const Eigen::VectorXd& q, const Eigen::VectorXd& dq) {
9292
return leggedRobotRaisimConversionsPtrs_[i]->inputToRaisimPdTargets(time, input, state, q, dq);
93-
})));
93+
}));
9494
if (raisimRolloutSettings.generateTerrain_) {
9595
raisim::TerrainProperties terrainProperties;
9696
terrainProperties.zScale = raisimRolloutSettings.terrainRoughness_;
@@ -132,9 +132,9 @@ std::unique_ptr<MPC_BASE> LeggedRobotMpcnetInterface::getMpc(LeggedRobotInterfac
132132
return settings;
133133
}();
134134
// create one MPC instance
135-
std::unique_ptr<MPC_BASE> mpcPtr(new GaussNewtonDDP_MPC(mpcSettings, ddpSettings, leggedRobotInterface.getRollout(),
136-
leggedRobotInterface.getOptimalControlProblem(),
137-
leggedRobotInterface.getInitializer()));
135+
auto mpcPtr =
136+
std::make_unique<GaussNewtonDDP_MPC>(mpcSettings, ddpSettings, leggedRobotInterface.getRollout(),
137+
leggedRobotInterface.getOptimalControlProblem(), leggedRobotInterface.getInitializer());
138138
mpcPtr->getSolverPtr()->setReferenceManager(leggedRobotInterface.getReferenceManagerPtr());
139139
return mpcPtr;
140140
}

ocs2_mpcnet/ocs2_mpcnet_core/src/rollout/MpcnetRolloutManager.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,9 @@ MpcnetRolloutManager::MpcnetRolloutManager(size_t nDataGenerationThreads, size_t
4747
dataGenerationThreadPoolPtr_.reset(new ThreadPool(nDataGenerationThreads_));
4848
dataGenerationPtrs_.reserve(nDataGenerationThreads);
4949
for (int i = 0; i < nDataGenerationThreads; i++) {
50-
dataGenerationPtrs_.push_back(std::unique_ptr<MpcnetDataGeneration>(
51-
new MpcnetDataGeneration(std::move(mpcPtrs.at(i)), std::move(mpcnetPtrs.at(i)), std::move(rolloutPtrs.at(i)),
52-
std::move(mpcnetDefinitionPtrs.at(i)), referenceManagerPtrs.at(i))));
50+
dataGenerationPtrs_.push_back(
51+
std::make_unique<MpcnetDataGeneration>(std::move(mpcPtrs.at(i)), std::move(mpcnetPtrs.at(i)), std::move(rolloutPtrs.at(i)),
52+
std::move(mpcnetDefinitionPtrs.at(i)), referenceManagerPtrs.at(i)));
5353
}
5454
}
5555

@@ -59,9 +59,9 @@ MpcnetRolloutManager::MpcnetRolloutManager(size_t nDataGenerationThreads, size_t
5959
policyEvaluationThreadPoolPtr_.reset(new ThreadPool(nPolicyEvaluationThreads_));
6060
policyEvaluationPtrs_.reserve(nPolicyEvaluationThreads_);
6161
for (int i = nDataGenerationThreads_; i < (nDataGenerationThreads_ + nPolicyEvaluationThreads_); i++) {
62-
policyEvaluationPtrs_.push_back(std::unique_ptr<MpcnetPolicyEvaluation>(
63-
new MpcnetPolicyEvaluation(std::move(mpcPtrs.at(i)), std::move(mpcnetPtrs.at(i)), std::move(rolloutPtrs.at(i)),
64-
std::move(mpcnetDefinitionPtrs.at(i)), referenceManagerPtrs.at(i))));
62+
policyEvaluationPtrs_.push_back(
63+
std::make_unique<MpcnetPolicyEvaluation>(std::move(mpcPtrs.at(i)), std::move(mpcnetPtrs.at(i)), std::move(rolloutPtrs.at(i)),
64+
std::move(mpcnetDefinitionPtrs.at(i)), referenceManagerPtrs.at(i)));
6565
}
6666
}
6767
}

0 commit comments

Comments
 (0)