From 9bc3009703f4838f2b7ae0cceb4d0b9788d2a68c Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Thu, 4 Jul 2024 14:08:10 +0200 Subject: [PATCH] Include neighbor lists --- metatensor-torch/src/atomistic/system.cpp | 76 +++++++++++++++++-- .../tests/atomistic/systems.py | 31 +++++++- 2 files changed, 101 insertions(+), 6 deletions(-) diff --git a/metatensor-torch/src/atomistic/system.cpp b/metatensor-torch/src/atomistic/system.cpp index 5bd415bce..377b25824 100644 --- a/metatensor-torch/src/atomistic/system.cpp +++ b/metatensor-torch/src/atomistic/system.cpp @@ -77,7 +77,7 @@ std::string NeighborListOptionsHolder::str() const { ", full_list=" + (full_list_ ? "True" : "False") + ")"; } -std::string NeighborListOptionsHolder::to_json() const { +static nlohmann::json neighbor_list_options_to_json(const NeighborListOptionsHolder& self) { nlohmann::json result; result["class"] = "NeighborListOptions"; @@ -86,12 +86,17 @@ std::string NeighborListOptionsHolder::to_json() const { // round-tripping of the data static_assert(sizeof(double) == sizeof(int64_t)); int64_t int_cutoff = 0; - std::memcpy(&int_cutoff, &this->cutoff_, sizeof(double)); + double cutoff = self.cutoff(); + std::memcpy(&int_cutoff, &cutoff, sizeof(double)); result["cutoff"] = int_cutoff; - result["full_list"] = this->full_list_; - result["length_unit"] = this->length_unit_; + result["full_list"] = self.full_list(); + result["length_unit"] = self.length_unit(); - return result.dump(/*indent*/4, /*indent_char*/' ', /*ensure_ascii*/ true); + return result; +} + +std::string NeighborListOptionsHolder::to_json() const { + return neighbor_list_options_to_json(*this).dump(/*indent*/4, /*indent_char*/' ', /*ensure_ascii*/ true); } NeighborListOptions NeighborListOptionsHolder::from_json(const std::string& json) { @@ -846,6 +851,42 @@ torch::Tensor tensor_from_json(const nlohmann::json& data) { })); } +nlohmann::json neighbor_list_block_to_json(const TensorBlockHolder& block) { + // Specific to NLs. This means that we don't have to do gradients and + // worry save/load metadata outside of the samples. + nlohmann::json result; + result["samples"] = tensor_to_json(block.samples()->values()); + result["values"] = tensor_to_json(block.values()); + return result; +} + +TensorBlockHolder neighbor_list_block_from_json(const nlohmann::json& data) { + if (!data.contains("samples")) { + throw std::runtime_error("expected 'samples' in JSON for neighbor list block, did not find it"); + } + auto samples_values = tensor_from_json(data["samples"]); + + auto names = std::vector({"first_atom", "second_atom", "cell_shift_a", "cell_shift_b", "cell_shift_c"}); + auto samples = torch::make_intrusive( + /*names=*/std::move(names), + /*values=*/std::move(samples_values) + ); + auto components = LabelsHolder::create({"xyz"}, {{0}, {1}, {2}}); + auto properties = LabelsHolder::create({"distance"}, {{0}}); + + if (!data.contains("values")) { + throw std::runtime_error("expected 'values' in JSON for neighbor list block, did not find it"); + } + auto values = tensor_from_json(data["values"]); + + return TensorBlockHolder( + /*values=*/values, + /*samples=*/samples, + /*components=*/{components}, + /*properties=*/properties + ); +} + std::string SystemHolder::to_json() const { nlohmann::json result; @@ -854,6 +895,16 @@ std::string SystemHolder::to_json() const { result["cell"] = tensor_to_json(this->cell()); result["types"] = tensor_to_json(this->types()); + result["neighbor_lists"] = std::vector(); + for (auto nl_option: this->known_neighbor_lists()) { + auto nl_data = this->get_neighbor_list(nl_option); + + auto nl_json = nlohmann::json(); + nl_json["options"] = neighbor_list_options_to_json(*nl_option); + nl_json["data"] = neighbor_list_block_to_json(*nl_data); + result["neighbor_lists"].emplace_back(nl_json); + } + return result.dump(/*indent*/4, /*indent_char*/' ', /*ensure_ascii*/ true); } @@ -888,5 +939,20 @@ System SystemHolder::from_json(const std::string_view json) { auto types = tensor_from_json(data["types"]); auto system = torch::make_intrusive(types, positions, cell); + + for (const auto& nl_data: data["neighbor_lists"]) { + if (!nl_data.contains("options") || !nl_data["options"].is_object()) { + throw std::runtime_error("expected 'options' in JSON for neighbor list, did not find it"); + } + auto options = NeighborListOptionsHolder::from_json(nl_data["options"].dump()); + + if (!nl_data.contains("data") || !nl_data["data"].is_object()) { + throw std::runtime_error("expected 'data' in JSON for neighbor list, did not find it"); + } + auto block = neighbor_list_block_from_json(nl_data["data"]); + + system->add_neighbor_list(options, torch::make_intrusive(std::move(block))); + } + return system; } diff --git a/python/metatensor-torch/tests/atomistic/systems.py b/python/metatensor-torch/tests/atomistic/systems.py index 4ab0f3900..dca3daa26 100644 --- a/python/metatensor-torch/tests/atomistic/systems.py +++ b/python/metatensor-torch/tests/atomistic/systems.py @@ -482,14 +482,43 @@ def check_dtype(system: System, dtype: torch.dtype): @pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) -def test_save_load(tmpdir, dtype): +@pytest.mark.parametrize("nl_size", [10, 100, 1000]) +def test_save_load(tmpdir, dtype, nl_size): system = System( types=torch.tensor([1, 2, 3, 4]), positions=torch.rand((4, 3), dtype=dtype), cell=torch.rand((3, 3), dtype=dtype), ) + system.add_neighbor_list( + NeighborListOptions(cutoff=3.5, full_list=False), + TensorBlock( + values=torch.rand(nl_size, 3, 1, dtype=dtype), + samples=Labels( + [ + "first_atom", + "second_atom", + "cell_shift_a", + "cell_shift_b", + "cell_shift_c", + ], + torch.tensor(torch.arange(nl_size * 5, dtype=torch.int64)).reshape( + nl_size, 5 + ), + ), + components=[Labels.range("xyz", 3)], + properties=Labels.range("distance", 1), + ), + ) + torch.save(system, os.path.join(tmpdir, "system.pt")) system_loaded = torch.load(os.path.join(tmpdir, "system.pt")) assert torch.equal(system.types, system_loaded.types) assert torch.equal(system.positions, system_loaded.positions) assert torch.equal(system.cell, system_loaded.cell) + neigbor_list = system.get_neighbor_list( + NeighborListOptions(cutoff=3.5, full_list=False) + ) + neighbor_list_loaded = system_loaded.get_neighbor_list( + NeighborListOptions(cutoff=3.5, full_list=False) + ) + assert metatensor.torch.equal_block(neigbor_list, neighbor_list_loaded)