Skip to content

Commit

Permalink
Include neighbor lists
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster committed Jul 4, 2024
1 parent f41f4e3 commit 2b688a6
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 6 deletions.
76 changes: 71 additions & 5 deletions metatensor-torch/src/atomistic/system.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ std::string NeighborListOptionsHolder::str() const {
", full_list=" + (full_list_ ? "True" : "False") + ")";
}

std::string NeighborListOptionsHolder::to_json() const {
static nlohmann::json neighbor_list_options_to_json(const NeighborListOptionsHolder& self) {
nlohmann::json result;

result["class"] = "NeighborListOptions";
Expand All @@ -86,12 +86,17 @@ std::string NeighborListOptionsHolder::to_json() const {
// round-tripping of the data
static_assert(sizeof(double) == sizeof(int64_t));
int64_t int_cutoff = 0;
std::memcpy(&int_cutoff, &this->cutoff_, sizeof(double));
double cutoff = self.cutoff();
std::memcpy(&int_cutoff, &cutoff, sizeof(double));
result["cutoff"] = int_cutoff;
result["full_list"] = this->full_list_;
result["length_unit"] = this->length_unit_;
result["full_list"] = self.full_list();
result["length_unit"] = self.length_unit();

return result.dump(/*indent*/4, /*indent_char*/' ', /*ensure_ascii*/ true);
return result;
}

std::string NeighborListOptionsHolder::to_json() const {
return neighbor_list_options_to_json(*this).dump(/*indent*/4, /*indent_char*/' ', /*ensure_ascii*/ true);
}

NeighborListOptions NeighborListOptionsHolder::from_json(const std::string& json) {
Expand Down Expand Up @@ -846,6 +851,42 @@ torch::Tensor tensor_from_json(const nlohmann::json& data) {
}));
}

nlohmann::json neighbor_list_block_to_json(const TensorBlockHolder& block) {
// Specific to NLs. This means that we don't have to do gradients and
// worry save/load metadata outside of the samples.
nlohmann::json result;
result["samples"] = tensor_to_json(block.samples()->values());
result["values"] = tensor_to_json(block.values());
return result;
}

TensorBlockHolder neighbor_list_block_from_json(const nlohmann::json& data) {
if (!data.contains("samples")) {
throw std::runtime_error("expected 'samples' in JSON for neighbor list block, did not find it");
}
auto samples_values = tensor_from_json(data["samples"]);

auto names = std::vector<std::string>({"first_atom", "second_atom", "cell_shift_a", "cell_shift_b", "cell_shift_c"});
auto samples = torch::make_intrusive<LabelsHolder>(
/*names=*/std::move(names),
/*values=*/std::move(samples_values)
);
auto components = LabelsHolder::create({"xyz"}, {{0}, {1}, {2}});
auto properties = LabelsHolder::create({"distance"}, {{0}});

if (!data.contains("values")) {
throw std::runtime_error("expected 'values' in JSON for neighbor list block, did not find it");
}
auto values = tensor_from_json(data["values"]);

return TensorBlockHolder(
/*values=*/values,
/*samples=*/samples,
/*components=*/{components},
/*properties=*/properties
);
}

std::string SystemHolder::to_json() const {
nlohmann::json result;

Expand All @@ -854,6 +895,16 @@ std::string SystemHolder::to_json() const {
result["cell"] = tensor_to_json(this->cell());
result["types"] = tensor_to_json(this->types());

result["neighbor_lists"] = std::vector<nlohmann::json>();
for (auto nl_option: this->known_neighbor_lists()) {
auto nl_data = this->get_neighbor_list(nl_option);

auto nl_json = nlohmann::json();
nl_json["options"] = neighbor_list_options_to_json(*nl_option);
nl_json["data"] = neighbor_list_block_to_json(*nl_data);
result["neighbor_lists"].emplace_back(nl_json);
}

return result.dump(/*indent*/4, /*indent_char*/' ', /*ensure_ascii*/ true);
}

Expand Down Expand Up @@ -888,5 +939,20 @@ System SystemHolder::from_json(const std::string_view json) {
auto types = tensor_from_json(data["types"]);

auto system = torch::make_intrusive<SystemHolder>(types, positions, cell);

for (const auto& nl_data: data["neighbor_lists"]) {
if (!nl_data.contains("options") || !nl_data["options"].is_object()) {
throw std::runtime_error("expected 'options' in JSON for neighbor list, did not find it");
}
auto options = NeighborListOptionsHolder::from_json(nl_data["options"].dump());

if (!nl_data.contains("data") || !nl_data["data"].is_object()) {
throw std::runtime_error("expected 'data' in JSON for neighbor list, did not find it");
}
auto block = neighbor_list_block_from_json(nl_data["data"]);

system->add_neighbor_list(options, torch::make_intrusive<TensorBlockHolder>(std::move(block)));
}

return system;
}
29 changes: 28 additions & 1 deletion python/metatensor-torch/tests/atomistic/systems.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,14 +482,41 @@ def check_dtype(system: System, dtype: torch.dtype):


@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
def test_save_load(tmpdir, dtype):
@pytest.mark.parametrize("nl_size", [10, 100, 1000])
def test_save_load(tmpdir, dtype, nl_size):
system = System(
types=torch.tensor([1, 2, 3, 4]),
positions=torch.rand((4, 3), dtype=dtype),
cell=torch.rand((3, 3), dtype=dtype),
)
system.add_neighbor_list(
NeighborListOptions(cutoff=3.5, full_list=False),
TensorBlock(
values=torch.rand(nl_size, 3, 1, dtype=dtype),
samples=Labels(
[
"first_atom",
"second_atom",
"cell_shift_a",
"cell_shift_b",
"cell_shift_c",
],
torch.arange(nl_size * 5, dtype=torch.int64).reshape(nl_size, 5),
),
components=[Labels.range("xyz", 3)],
properties=Labels.range("distance", 1),
),
)

torch.save(system, os.path.join(tmpdir, "system.pt"))
system_loaded = torch.load(os.path.join(tmpdir, "system.pt"))
assert torch.equal(system.types, system_loaded.types)
assert torch.equal(system.positions, system_loaded.positions)
assert torch.equal(system.cell, system_loaded.cell)
neigbor_list = system.get_neighbor_list(
NeighborListOptions(cutoff=3.5, full_list=False)
)
neighbor_list_loaded = system_loaded.get_neighbor_list(
NeighborListOptions(cutoff=3.5, full_list=False)
)
assert metatensor.torch.equal_block(neigbor_list, neighbor_list_loaded)

0 comments on commit 2b688a6

Please sign in to comment.