Commit dc5a8d3

Implement selected_atoms
1 parent 86ca791 commit dc5a8d3

3 files changed

+154
-45
lines changed

regtest/metatensor/rt-soap/plumed.dat

Lines changed: 14 additions & 0 deletions
@@ -8,9 +8,23 @@ soap: METATENSOR ...
     SPECIES_TO_TYPES=6,1,8
 ...
 
+soap_selected: METATENSOR ...
+    MODEL=soap_cv.pt
+    EXTENSIONS_DIRECTORY=extensions
+
+    SPECIES1=1-26
+    SPECIES2=27-62
+    SPECIES3=63-76
+    SPECIES_TO_TYPES=6,1,8
+
+    # select out of order to make sure this is respected in the output
+    SELECTED_ATOMS=2,3,1
+...
+
 
 scalar: SUM ARG=soap PERIODIC=NO
 BIASVALUE ARG=scalar
 
 
 PRINT ARG=soap FILE=soap_data STRIDE=1 FMT=%8.4f
+PRINT ARG=soap_selected FILE=soap_selected_data STRIDE=1 FMT=%8.4f
Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
+#! FIELDS time soap_selected.1.1 soap_selected.1.2 soap_selected.1.3 soap_selected.2.1 soap_selected.2.2 soap_selected.2.3 soap_selected.3.1 soap_selected.3.2 soap_selected.3.3
+0.000000 6.0785 6.3903 6.9409 5.2246 4.6212 5.9061 5.3739 5.3189 6.4924

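The regression test above selects atoms out of order (`SELECTED_ATOMS=2,3,1`) and expects the rows of the output to follow that order. The row lookup can be sketched without torch as below. This is only an illustration under stated assumptions: `output_row` and `flat_index` are hypothetical helpers, not functions from PLUMED or metatensor, and the selection is assumed to already hold 0-based atom indices.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <vector>

// Hypothetical helper (not part of PLUMED or metatensor): find the output row
// that should hold the values for the 0-based atom index `atom`.
// With a selection, rows follow the order in which the selection was given;
// without one, the row is simply the atom index.
std::size_t output_row(const std::optional<std::vector<int>>& selection, int atom) {
    if (!selection.has_value()) {
        assert(atom >= 0);
        return static_cast<std::size_t>(atom);
    }
    for (std::size_t row = 0; row < selection->size(); row++) {
        if ((*selection)[row] == atom) {
            return row;
        }
    }
    assert(false && "the model returned an atom that was not selected");
    return 0;
}

// Flattened storage index for property `j` of output row `row`, matching the
// `row * n_properties + j` layout used when the CV is a matrix.
std::size_t flat_index(std::size_t row, std::size_t j, std::size_t n_properties) {
    assert(j < n_properties);
    return row * n_properties + j;
}
```

With the selection `{1, 2, 0}` (the 0-based form of `2,3,1`), the values for atom 0 land in row 2, and property 1 of that row is stored at flat index 7 when there are 3 properties.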
src/metatensor/metatensor.cpp

Lines changed: 138 additions & 45 deletions
@@ -35,18 +35,11 @@ directory defines a custom machine learning CV that can be used with PLUMED.
 \par Examples
 
 The following input shows how you can call metatensor and evaluate the model
-that is described in the file soap_cv.pt from PLUMED. To evaluate this model
-plumed is required to use code that is included in the directory extensions
-which has been specified using the `EXTENSIONS_DIRECTORY` flag. Numbered
-`SPECIES` labels are used to indicate the list of indices that belong to each
-atomic species in the model. The `SPECIES_TO_TYPE` keyword then provides
-information on the atom type for each species. The first number here is the
-atomic number of the atoms that have been specified using the `SPECIES1` flag,
-the second number is the atomic number of the atoms that have been specified
-using the `SPECIES2` flag and so on.
+that is described in the file `custom_cv.pt` from PLUMED.
 
-\plumedfile soap: METATENSOR ... MODEL=soap_cv.pt
-EXTENSIONS_DIRECTORY=extensions
+\plumedfile
+metatensor_cv: METATENSOR ...
+    MODEL=custom_cv.pt
 
     SPECIES1=1-26
     SPECIES2=27-62
@@ -55,6 +48,47 @@ EXTENSIONS_DIRECTORY=extensions
 ...
 \endplumedfile
 
+The numbered `SPECIES` labels are used to indicate the list of atoms that belong
+to each atomic species in the system. The `SPECIES_TO_TYPES` keyword then
+provides information on the atom type for each species. The first number here is
+the atomic type of the atoms that have been specified using the `SPECIES1` flag,
+the second number is the atomic type of the atoms that have been specified
+using the `SPECIES2` flag and so on.
+
+The `METATENSOR` action also accepts the following options:
+
+- `EXTENSIONS_DIRECTORY` should be the path to a directory containing
+  TorchScript extensions (as shared libraries) that are required to load and
+  execute the model. This matches the `collect_extensions` argument to
+  `MetatensorAtomisticModel.export` in Python.
+- `NO_CONSISTENCY_CHECK` can be used to disable internal consistency checks.
+- `SELECTED_ATOMS` can be used to signal to the metatensor model that it should
+  only run its calculation for the selected subset of atoms. The model still
+  needs to know about all the atoms in the system (through the `SPECIES`
+  keyword), but this can be used to reduce the calculation cost. Note that the
+  indices of the selected atoms should start at 1 in the PLUMED input file, but
+  they will be translated to start at 0 when given to the model (i.e. in
+  Python/TorchScript, the `forward` method will receive a `selected_atoms` which
+  starts at 0).
+
+Here is another example with all the possible keywords:
+
+\plumedfile
+soap: METATENSOR ...
+    MODEL=soap.pt
+    EXTENSIONS_DIRECTORY=extensions
+    NO_CONSISTENCY_CHECK
+
+    SPECIES1=1-10
+    SPECIES2=11-20
+    SPECIES_TO_TYPES=8,13
+
+    # only run the calculation for the Aluminium (type 13) atoms, but
+    # include the Oxygen (type 8) as potential neighbors.
+    SELECTED_ATOMS=11-20
+...
+\endplumedfile
+
 \par Collective variables and metatensor models
 
 Collective variables are not yet part of the [known outputs][mts_outputs] for
@@ -309,8 +343,6 @@ MetatensorPlumedAction::MetatensorPlumedAction(const ActionOptions& options):
     }
 
     this->atomic_types_ = torch::tensor(std::move(atomic_types));
-
-    // Request the atoms and check we have read in everything
     this->requestAtoms(all_atoms);
 
     bool no_consistency_check = false;
@@ -346,10 +378,6 @@ MetatensorPlumedAction::MetatensorPlumedAction(const ActionOptions& options):
     output->explicit_gradients = {};
     evaluations_options_->outputs.insert("plumed::cv", output);
 
-    // TODO: selected_atoms
-    // evaluations_options_->set_selected_atoms()
-
-
     // Determine which device we should use based on user input, what the model
     // supports and what's available
     auto available_devices = std::vector<torch::Device>();
@@ -435,9 +463,8 @@ MetatensorPlumedAction::MetatensorPlumedAction(const ActionOptions& options):
     auto tensor_options = torch::TensorOptions().dtype(this->dtype_).device(this->device_);
     this->strain_ = torch::eye(3, tensor_options.requires_grad(true));
 
-    // setup storage for the computed CV: we need to run the model once to know
-    // the shape of the output, so we use a dummy system with one since atom for
-    // this
+    // determine how many properties there will be in the output by running the
+    // model once on a dummy system
     auto dummy_system = torch::make_intrusive<metatensor_torch::SystemHolder>(
         /*types = */ torch::zeros({0}, tensor_options.dtype(torch::kInt32)),
         /*positions = */ torch::zeros({0, 3}, tensor_options),
@@ -461,16 +488,52 @@ MetatensorPlumedAction::MetatensorPlumedAction(const ActionOptions& options):
         dummy_system->add_neighbor_list(request, neighbors);
     }
 
+    this->n_properties_ = static_cast<unsigned>(
+        this->executeModel(dummy_system)->properties()->count()
+    );
+
+    // parse and handle atom sub-selection. This is done AFTER determining the
+    // output size, since the selection might not be valid for the dummy system
+    std::vector<int32_t> selected_atoms;
+    this->parseVector("SELECTED_ATOMS", selected_atoms);
+    if (!selected_atoms.empty()) {
+        auto selection_value = torch::zeros(
+            {static_cast<int64_t>(selected_atoms.size()), 2},
+            torch::TensorOptions().dtype(torch::kInt32).device(this->device_)
+        );
+
+        for (unsigned i=0; i<selected_atoms.size(); i++) {
+            auto n_atoms = static_cast<int32_t>(this->atomic_types_.size(0));
+            if (selected_atoms[i] <= 0 || selected_atoms[i] > n_atoms) {
+                this->error(
+                    "Values in metatensor's SELECTED_ATOMS should be between 1 "
+                    "and the number of atoms (" + std::to_string(n_atoms) + "), "
+                    "got " + std::to_string(selected_atoms[i]));
+            }
+            // PLUMED input uses 1-based indexes, but metatensor wants 0-based
+            selection_value[i][1] = selected_atoms[i] - 1;
+        }
+
+        evaluations_options_->set_selected_atoms(
+            torch::make_intrusive<metatensor_torch::LabelsHolder>(
+                std::vector<std::string>{"system", "atom"}, selection_value
+            )
+        );
+    }
+
+    // Now that we know both n_samples and n_properties, we can setup the
+    // PLUMED-side storage for the computed CV
     if (output->per_atom) {
-        this->n_samples_ = static_cast<unsigned>(this->atomic_types_.size(0));
+        auto selected_atoms = this->evaluations_options_->get_selected_atoms();
+        if (selected_atoms.has_value()) {
+            this->n_samples_ = static_cast<unsigned>(selected_atoms.value()->count());
+        } else {
+            this->n_samples_ = static_cast<unsigned>(this->atomic_types_.size(0));
+        }
     } else {
         this->n_samples_ = 1;
     }
 
-    this->n_properties_ = static_cast<unsigned>(
-        this->executeModel(dummy_system)->properties()->count()
-    );
-
     if (n_samples_ == 1 && n_properties_ == 1) {
         log.printf("  the output of this model is a scalar\n");
 
@@ -511,6 +574,8 @@ void MetatensorPlumedAction::createSystem() {
         plumed_merror(oss.str());
     }
 
+    // this->getTotAtoms()
+
     const auto& cell = this->getPbc().getBox();
 
     auto cpu_f64_tensor = torch::TensorOptions().dtype(torch::kFloat64).device(torch::kCPU);
@@ -714,25 +779,39 @@ void MetatensorPlumedAction::calculate() {
             }
         }
     } else {
-        auto samples = block->samples()->as_metatensor();
-        plumed_assert(samples.names().size() == 2);
-        plumed_assert(samples.names()[0] == std::string("system"));
-        plumed_assert(samples.names()[1] == std::string("atom"));
-
-        auto& samples_values = samples.values();
+        auto samples = block->samples();
+        plumed_assert((samples->names() == std::vector<std::string>{"system", "atom"}));
+
+        auto samples_values = samples->values().to(torch::kCPU);
+        auto selected_atoms = this->evaluations_options_->get_selected_atoms();
+
+        // handle the possibility that samples are returned in
+        // a non-sorted order.
+        auto get_output_location = [&](unsigned i) {
+            if (selected_atoms.has_value()) {
+                // If the users picked some selected atoms, then we store the
+                // output in the same order as the selection was given
+                auto sample = samples_values.index({static_cast<int64_t>(i), torch::indexing::Slice()});
+                auto position = selected_atoms.value()->position(sample);
+                plumed_assert(position.has_value());
+                return static_cast<unsigned>(position.value());
+            } else {
+                return static_cast<unsigned>(samples_values[i][1].item<int32_t>());
+            }
+        };
 
         if (n_properties_ == 1) {
             // we have a single CV describing multiple things (i.e. atoms)
             for (unsigned i=0; i<n_samples_; i++) {
-                auto atom_i = static_cast<size_t>(samples_values(i, 1));
-                value->set(atom_i, torch_values[i][0].item<double>());
+                auto output_i = get_output_location(i);
+                value->set(output_i, torch_values[i][0].item<double>());
             }
         } else {
             // the CV is a matrix
             for (unsigned i=0; i<n_samples_; i++) {
-                auto atom_i = static_cast<size_t>(samples_values(i, 1));
+                auto output_i = get_output_location(i);
                 for (unsigned j=0; j<n_properties_; j++) {
-                    value->set(atom_i * n_properties_ + j, torch_values[i][j].item<double>());
+                    value->set(output_i * n_properties_ + j, torch_values[i][j].item<double>());
                 }
             }
         }
@@ -759,23 +838,34 @@ void MetatensorPlumedAction::apply() {
             }
         }
     } else {
-        auto samples = block->samples()->as_metatensor();
-        plumed_assert(samples.names().size() == 2);
-        plumed_assert(samples.names()[0] == std::string("system"));
-        plumed_assert(samples.names()[1] == std::string("atom"));
-
-        auto& samples_values = samples.values();
+        auto samples = block->samples();
+        plumed_assert((samples->names() == std::vector<std::string>{"system", "atom"}));
+
+        auto samples_values = samples->values().to(torch::kCPU);
+        auto selected_atoms = this->evaluations_options_->get_selected_atoms();
+
+        // see above for an explanation of why we use this function
+        auto get_output_location = [&](unsigned i) {
+            if (selected_atoms.has_value()) {
+                auto sample = samples_values.index({static_cast<int64_t>(i), torch::indexing::Slice()});
+                auto position = selected_atoms.value()->position(sample);
+                plumed_assert(position.has_value());
+                return static_cast<unsigned>(position.value());
+            } else {
+                return static_cast<unsigned>(samples_values[i][1].item<int32_t>());
+            }
+        };
 
         if (n_properties_ == 1) {
             for (unsigned i=0; i<n_samples_; i++) {
-                auto atom_i = static_cast<size_t>(samples_values(i, 1));
-                output_grad[i][0] = value->getForce(atom_i);
+                auto output_i = get_output_location(i);
+                output_grad[i][0] = value->getForce(output_i);
             }
         } else {
             for (unsigned i=0; i<n_samples_; i++) {
-                auto atom_i = static_cast<size_t>(samples_values(i, 1));
+                auto output_i = get_output_location(i);
                 for (unsigned j=0; j<n_properties_; j++) {
-                    output_grad[i][j] = value->getForce(atom_i * n_properties_ + j);
+                    output_grad[i][j] = value->getForce(output_i * n_properties_ + j);
                 }
             }
         }
@@ -842,6 +932,9 @@ namespace PLMD { namespace metatensor {
         keys.add("numbered", "SPECIES", "the atoms in each PLUMED species");
         keys.reset_style("SPECIES", "atoms");
 
+        keys.add("optional", "SELECTED_ATOMS", "subset of atoms that should be used for the calculation");
+        keys.reset_style("SELECTED_ATOMS", "atoms");
+
         keys.add("optional", "SPECIES_TO_TYPES", "mapping from PLUMED SPECIES to metatensor's atomic types");
     }
 
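The bounds check and index shift that the constructor applies to `SELECTED_ATOMS` can be sketched without torch as follows. This is a minimal illustration, not the code of the commit: `to_zero_based` is a hypothetical helper, and throwing `std::invalid_argument` stands in for the call to `this->error(...)` shown in the diff.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical helper (not part of PLUMED or metatensor): validate the 1-based
// indices read from the PLUMED input and convert them to the 0-based indices
// that the model expects. The order of the selection is preserved.
std::vector<int> to_zero_based(const std::vector<int>& selected, int n_atoms) {
    assert(n_atoms >= 0);
    std::vector<int> result;
    result.reserve(selected.size());
    for (int index : selected) {
        if (index <= 0 || index > n_atoms) {
            // Stand-in for the error reporting used inside the action
            throw std::invalid_argument(
                "SELECTED_ATOMS values should be between 1 and " +
                std::to_string(n_atoms) + ", got " + std::to_string(index));
        }
        result.push_back(index - 1);
    }
    return result;
}
```

For the regression test input (`SELECTED_ATOMS=2,3,1` with 76 atoms), this gives `{1, 2, 0}`; an index of 0 or 77 is rejected.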