Skip to content

Commit 1b3a0ea

Browse files
authored
Merge pull request #5414 from rcclay/remove_history_method
Remove history method from RotatedSPOs
2 parents c302e86 + 212f95b commit 1b3a0ea

File tree

6 files changed

+21
-168
lines changed

6 files changed

+21
-168
lines changed

docs/intro_wavefunction.rst

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -951,8 +951,6 @@ Attribute:
951951
+=================+==========+================+=========+====================================+
952952
| ``name`` | Text | | | Name of rotated SPOSet |
953953
+-----------------+----------+----------------+---------+------------------------------------+
954-
| ``method`` | Text | global/history | global | Rotation matrix composition method |
955-
+-----------------+----------+----------------+---------+------------------------------------+
956954

957955
.. code-block::
958956
:caption: Orbital Rotation XML element.
@@ -989,14 +987,7 @@ These parameters are a subset of the full number of parameters in the kappa matr
989987
When rotations are combined, the entries corresponding to zero parameter derivatives can
990988
take on a non-zero value (i.e. the kappa matrix gets 'filled-in').
991989

992-
There are two ways to handle this.
993-
One way is to store a list of applied rotations.
994-
This method applies a new rotation to the coefficient matrix, and updates the coefficient matrix at each optimization step.
995-
This is the "history" method.
996-
997-
.. math:: C' = \exp(\kappa_n) \dots \exp(\kappa_1) \exp(\kappa_0) C
998-
999-
The other way is to track the full set of kappa values separately.
990+
QMCPACK handles this problem by tracking the full set of kappa values separately.
1000991
After the matrix multiplication to compose the rotations, the matrix log recovers the new kappa matrix entries.
1001992
This is the "global" method.
1002993
This method keeps a separate copy of the coefficient matrix and updates it using the global rotation matrix at each optimization step.

src/QMCWaveFunctions/RotatedSPOs.cpp

Lines changed: 18 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -125,50 +125,22 @@ void RotatedSPOs::resetParametersExclusive(const opt_variables_type& active)
125125
myVars[i] = active[loc];
126126
}
127127

128-
if (use_global_rot_)
129-
{
130-
std::vector<ValueType> old_param(m_full_rot_inds_.size());
131-
std::copy_n(myVarsFull_.data(), myVarsFull_.size(), old_param.data());
132-
133-
applyDeltaRotation(delta_param, old_param, myVarsFull_);
134-
}
135-
else
136-
{
137-
apply_rotation(delta_param, false);
128+
std::vector<ValueType> old_param(m_full_rot_inds_.size());
129+
std::copy_n(myVarsFull_.data(), myVarsFull_.size(), old_param.data());
138130

139-
// Save the parameters in the history list
140-
history_params_.push_back(delta_param);
141-
}
131+
applyDeltaRotation(delta_param, old_param, myVarsFull_);
142132
}
143133

144134
void RotatedSPOs::writeVariationalParameters(hdf_archive& hout)
145135
{
146136
hout.push("RotatedSPOs");
147-
if (use_global_rot_)
148-
{
149-
hout.push("rotation_global");
150-
const std::string rot_global_name = std::string("rotation_global_") + SPOSet::getName();
151137

152-
hout.write(myVarsFull_, rot_global_name);
153-
hout.pop();
154-
}
155-
else
156-
{
157-
hout.push("rotation_history");
158-
size_t rows = history_params_.size();
159-
size_t cols = 0;
160-
if (rows > 0)
161-
cols = history_params_[0].size();
162-
163-
Matrix<ValueType> tmp(rows, cols);
164-
for (size_t i = 0; i < rows; i++)
165-
for (size_t j = 0; j < cols; j++)
166-
tmp[i][j] = history_params_[i][j];
167-
std::string rot_hist_name = std::string("rotation_history_") + SPOSet::getName();
168-
hout.write(tmp, rot_hist_name);
169-
hout.pop();
170-
}
138+
hout.push("rotation_global");
139+
const std::string rot_global_name = std::string("rotation_global_") + SPOSet::getName();
171140

141+
hout.write(myVarsFull_, rot_global_name);
142+
hout.pop();
143+
172144
// Save myVars in order to restore object state exactly
173145
// The values aren't meaningful, but they need to match those saved in VariableSet
174146
hout.push("rotation_params");
@@ -189,11 +161,7 @@ void RotatedSPOs::readVariationalParameters(hdf_archive& hin)
189161
{
190162
hin.push("RotatedSPOs", false);
191163

192-
bool grp_hist_exists = hin.is_group("rotation_history");
193164
bool grp_global_exists = hin.is_group("rotation_global");
194-
if (!grp_hist_exists && !grp_global_exists)
195-
app_warning() << "Rotation parameters not found in VP file";
196-
197165

198166
if (grp_global_exists)
199167
{
@@ -217,29 +185,9 @@ void RotatedSPOs::readVariationalParameters(hdf_archive& hin)
217185

218186
applyFullRotation(myVarsFull_, true);
219187
}
220-
else if (grp_hist_exists)
188+
else
221189
{
222-
hin.push("rotation_history", false);
223-
std::string rot_hist_name = std::string("rotation_history_") + SPOSet::getName();
224-
std::vector<int> sizes(2);
225-
if (!hin.getShape<ValueType>(rot_hist_name, sizes))
226-
throw std::runtime_error("Failed to read rotation history in VP file");
227-
228-
int rows = sizes[0];
229-
int cols = sizes[1];
230-
history_params_.resize(rows);
231-
Matrix<ValueType> tmp(rows, cols);
232-
hin.read(tmp, rot_hist_name);
233-
for (size_t i = 0; i < rows; i++)
234-
{
235-
history_params_[i].resize(cols);
236-
for (size_t j = 0; j < cols; j++)
237-
history_params_[i][j] = tmp(i, j);
238-
}
239-
240-
hin.pop();
241-
242-
applyRotationHistory();
190+
throw std::runtime_error("Error. No global rotation group in h5. Abort.");
243191
}
244192

245193
hin.push("rotation_params", false);
@@ -290,8 +238,8 @@ void RotatedSPOs::buildOptVariables(const size_t nel)
290238
RotationIndices created_m_act_rot_inds;
291239

292240
RotationIndices created_full_rot_inds;
293-
if (use_global_rot_)
294-
createRotationIndicesFull(nel, nmo, created_full_rot_inds);
241+
242+
createRotationIndicesFull(nel, nmo, created_full_rot_inds);
295243

296244
createRotationIndices(nel, nmo, created_m_act_rot_inds);
297245

@@ -306,13 +254,9 @@ void RotatedSPOs::buildOptVariables(const RotationIndices& rotations, const Rota
306254
// create active rotations
307255
m_act_rot_inds_ = rotations;
308256

309-
if (use_global_rot_)
310-
m_full_rot_inds_ = full_rotations;
257+
m_full_rot_inds_ = full_rotations;
311258

312-
if (use_global_rot_)
313-
app_log() << "Orbital rotation using global rotation" << std::endl;
314-
else
315-
app_log() << "Orbital rotation using history" << std::endl;
259+
app_log() << "Orbital rotation using global rotation" << std::endl;
316260

317261
// This will add the orbital rotation parameters to myVars
318262
// and will also read in initial parameter values supplied in input file
@@ -349,13 +293,10 @@ void RotatedSPOs::buildOptVariables(const RotationIndices& rotations, const Rota
349293
registerParameter(i, p, q, myVars, params_, false);
350294
}
351295

352-
if (use_global_rot_)
353-
{
354-
const size_t nfull_rot = m_full_rot_inds_.size();
355-
myVarsFull_.resize(nfull_rot);
356-
for (int i = 0; i < nfull_rot; i++)
357-
myVarsFull_[i] = (params_supplied_ && i < m_act_rot_inds_.size()) ? params_[i] : 0.0;
358-
}
296+
const size_t nfull_rot = m_full_rot_inds_.size();
297+
myVarsFull_.resize(nfull_rot);
298+
for (int i = 0; i < nfull_rot; i++)
299+
myVarsFull_[i] = (params_supplied_ && i < m_act_rot_inds_.size()) ? params_[i] : 0.0;
359300

360301
//Printing the parameters
361302
if (true)
@@ -464,14 +405,6 @@ void RotatedSPOs::applyFullRotation(const std::vector<ValueType>& full_param, bo
464405
Phi_->applyRotation(rot_mat, use_stored_copy);
465406
}
466407

467-
void RotatedSPOs::applyRotationHistory()
468-
{
469-
for (auto delta_param : history_params_)
470-
{
471-
apply_rotation(delta_param, false);
472-
}
473-
}
474-
475408
// compute exponential of a real, antisymmetric matrix by diagonalizing and exponentiating eigenvalues
476409
void RotatedSPOs::exponentiate_antisym_matrix(ValueMatrix& mat)
477410
{
@@ -1640,8 +1573,6 @@ std::unique_ptr<SPOSet> RotatedSPOs::makeClone() const
16401573
myclone->m_full_rot_inds_ = this->m_full_rot_inds_;
16411574
myclone->myVars = this->myVars;
16421575
myclone->myVarsFull_ = this->myVarsFull_;
1643-
myclone->history_params_ = this->history_params_;
1644-
myclone->use_global_rot_ = this->use_global_rot_;
16451576
return myclone;
16461577
}
16471578

src/QMCWaveFunctions/RotatedSPOs.h

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ namespace testing
2323
{
2424
const opt_variables_type& getMyVars(RotatedSPOs& rot);
2525
const std::vector<QMCTraits::ValueType>& getMyVarsFull(RotatedSPOs& rot);
26-
const std::vector<std::vector<QMCTraits::ValueType>>& getHistoryParams(RotatedSPOs& rot);
2726
} // namespace testing
2827

2928
class RotatedSPOs : public SPOSet, public OptimizableObject
@@ -99,10 +98,6 @@ class RotatedSPOs : public SPOSet, public OptimizableObject
9998
std::vector<ValueType>& new_param,
10099
ValueMatrix& new_rot_mat);
101100

102-
// When initializing the rotation from VP files
103-
// This function applies the rotation history
104-
void applyRotationHistory();
105-
106101
// This function applies the global rotation (similar to apply_rotation, but for the full
107102
// set of rotation parameters)
108103
void applyFullRotation(const std::vector<ValueType>& full_param, bool use_stored_copy);
@@ -400,9 +395,6 @@ class RotatedSPOs : public SPOSet, public OptimizableObject
400395
// void evaluateThirdDeriv(const ParticleSet& P, int first, int last, GGGMatrix& grad_grad_grad_logdet)
401396
// {Phi->evaluateThridDeriv(P, first, last, grad_grad_grad_logdet); }
402397

403-
/// Use history list (false) or global rotation (true)
404-
void set_use_global_rotation(bool use_global_rotation) { use_global_rot_ = use_global_rotation; }
405-
406398
void mw_evaluateDetRatios(const RefVectorWithLeader<SPOSet>& spo_list,
407399
const RefVectorWithLeader<const VirtualParticleSet>& vp_list,
408400
const RefVector<ValueVector>& psi_list,
@@ -472,17 +464,10 @@ class RotatedSPOs : public SPOSet, public OptimizableObject
472464
/// timer for apply_rotation
473465
NewTimer& apply_rotation_timer_;
474466

475-
/// List of previously applied parameters
476-
std::vector<std::vector<ValueType>> history_params_;
477-
478467
static RefVectorWithLeader<SPOSet> extractPhiRefList(const RefVectorWithLeader<SPOSet>& spo_list);
479468

480-
/// Use global rotation or history list
481-
bool use_global_rot_ = true;
482-
483469
friend const opt_variables_type& testing::getMyVars(RotatedSPOs& rot);
484470
friend const std::vector<ValueType>& testing::getMyVarsFull(RotatedSPOs& rot);
485-
friend const std::vector<std::vector<ValueType>>& testing::getHistoryParams(RotatedSPOs& rot);
486471
};
487472

488473

src/QMCWaveFunctions/SPOSetBuilder.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,6 @@ std::unique_ptr<SPOSet> SPOSetBuilder::createRotatedSPOSet(xmlNodePtr cur)
130130
std::string method;
131131
OhmmsAttributeSet attrib;
132132
attrib.add(spo_object_name, "name");
133-
attrib.add(method, "method", {"global", "history"});
134133
attrib.put(cur);
135134

136135
std::unique_ptr<SPOSet> sposet;
@@ -151,9 +150,6 @@ std::unique_ptr<SPOSet> SPOSetBuilder::createRotatedSPOSet(xmlNodePtr cur)
151150
sposet->storeParamsBeforeRotation();
152151
auto rot_spo = std::make_unique<RotatedSPOs>(spo_object_name, std::move(sposet));
153152

154-
if (method == "history")
155-
rot_spo->set_use_global_rotation(false);
156-
157153
processChildren(cur, [&](const std::string& cname, const xmlNodePtr element) {
158154
if (cname == "opt_vars")
159155
{

src/QMCWaveFunctions/tests/test_RotatedSPOs.cpp

Lines changed: 0 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -732,7 +732,6 @@ namespace testing
732732
{
733733
const opt_variables_type& getMyVars(RotatedSPOs& rot) { return rot.myVars; }
734734
const std::vector<QMCTraits::ValueType>& getMyVarsFull(RotatedSPOs& rot) { return rot.myVarsFull_; }
735-
const std::vector<std::vector<QMCTraits::ValueType>>& getHistoryParams(RotatedSPOs& rot) { return rot.history_params_; }
736735
} // namespace testing
737736

738737
// Test using global rotation
@@ -787,55 +786,6 @@ TEST_CASE("RotatedSPOs read and write parameters", "[wavefunction]")
787786
CHECK(full_var[i] == ValueApprox(vs_values[i]));
788787
}
789788

790-
// Test using history list.
791-
TEST_CASE("RotatedSPOs read and write parameters history", "[wavefunction]")
792-
{
793-
//Problem with h5 parameter parsing for complex build. To be fixed in future PR.
794-
auto fake_spo = std::make_unique<FakeSPO<QMCTraits::ValueType>>();
795-
fake_spo->setOrbitalSetSize(4);
796-
RotatedSPOs rot("fake_rot", std::move(fake_spo));
797-
rot.set_use_global_rotation(false);
798-
int nel = 2;
799-
rot.buildOptVariables(nel);
800-
801-
std::vector<SPOSet::ValueType> vs_values{0.1, 0.15, 0.2, 0.25};
802-
803-
optimize::VariableSet vs;
804-
rot.checkInVariablesExclusive(vs);
805-
auto* vs_values_data_real = (SPOSet::RealType*)vs_values.data();
806-
for (size_t i = 0; i < vs.size(); i++)
807-
vs[i] = vs_values_data_real[i];
808-
rot.resetParametersExclusive(vs);
809-
810-
{
811-
hdf_archive hout;
812-
vs.writeToHDF("rot_vp_hist.h5", hout);
813-
814-
rot.writeVariationalParameters(hout);
815-
}
816-
817-
auto fake_spo2 = std::make_unique<FakeSPO<QMCTraits::ValueType>>();
818-
fake_spo2->setOrbitalSetSize(4);
819-
820-
RotatedSPOs rot2("fake_rot", std::move(fake_spo2));
821-
rot2.buildOptVariables(nel);
822-
823-
optimize::VariableSet vs2;
824-
rot2.checkInVariablesExclusive(vs2);
825-
826-
hdf_archive hin;
827-
vs2.readFromHDF("rot_vp_hist.h5", hin);
828-
rot2.readVariationalParameters(hin);
829-
830-
auto& var = testing::getMyVars(rot2);
831-
for (size_t i = 0; i < var.size(); i++)
832-
CHECK(var[i] == Approx(vs[i]));
833-
834-
const auto hist = testing::getHistoryParams(rot2);
835-
REQUIRE(hist.size() == 1);
836-
REQUIRE(hist[0].size() == 4);
837-
}
838-
839789
template<typename T>
840790
class DummySPOSetWithoutMW : public SPOSetT<T>
841791
{

src/QMCWaveFunctions/tests/test_RotatedSPOs_LCAO.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -381,15 +381,15 @@ TEST_CASE("Rotated LCAO WF2 with jastrow", "[qmcapp]")
381381
</atomicBasisSet>
382382
</basisset>
383383
<rotated_sposet name="rot-spo-up">
384-
<sposet basisset="LCAOBSet" name="spo-up" method="history">
384+
<sposet basisset="LCAOBSet" name="spo-up">
385385
<coefficient id="updetC" type="Array" size="2">
386386
1.0 0.0
387387
0.0 1.0
388388
</coefficient>
389389
</sposet>
390390
</rotated_sposet>
391391
<rotated_sposet name="rot-spo-down">
392-
<sposet basisset="LCAOBSet" name="spo-down" method="history">
392+
<sposet basisset="LCAOBSet" name="spo-down">
393393
<coefficient id="updetC" type="Array" size="2">
394394
1.0 0.0
395395
0.0 1.0

0 commit comments

Comments
 (0)