diff --git a/metatensor-torch/include/metatensor/torch/atomistic/system.hpp b/metatensor-torch/include/metatensor/torch/atomistic/system.hpp index 0add9344c..ef490a94b 100644 --- a/metatensor-torch/include/metatensor/torch/atomistic/system.hpp +++ b/metatensor-torch/include/metatensor/torch/atomistic/system.hpp @@ -263,6 +263,12 @@ class METATENSOR_TORCH_EXPORT SystemHolder final: public torch::CustomClassHolde /// Implementation of `__str__` and `__repr__` for Python std::string str() const; + /// Serialize a `System` to a JSON string. + std::string to_json() const; + + /// Load a serialized `System` from a JSON string. + static System from_json(std::string_view json); + private: struct nl_options_compare { bool operator()(const NeighborListOptions& a, const NeighborListOptions& b) const { diff --git a/metatensor-torch/src/atomistic/system.cpp b/metatensor-torch/src/atomistic/system.cpp index 8f46d8ed8..6f08f31f0 100644 --- a/metatensor-torch/src/atomistic/system.cpp +++ b/metatensor-torch/src/atomistic/system.cpp @@ -77,7 +77,7 @@ std::string NeighborListOptionsHolder::str() const { ", full_list=" + (full_list_ ? "True" : "False") + ")"; } -std::string NeighborListOptionsHolder::to_json() const { +static nlohmann::json neighbor_list_options_to_json(const NeighborListOptionsHolder& self) { nlohmann::json result; result["class"] = "NeighborListOptions"; @@ -86,12 +86,17 @@ std::string NeighborListOptionsHolder::to_json() const { // round-tripping of the data static_assert(sizeof(double) == sizeof(int64_t)); int64_t int_cutoff = 0; - std::memcpy(&int_cutoff, &this->cutoff_, sizeof(double)); + double cutoff = self.cutoff(); + std::memcpy(&int_cutoff, &cutoff, sizeof(double)); result["cutoff"] = int_cutoff; - result["full_list"] = this->full_list_; - result["length_unit"] = this->length_unit_; + result["full_list"] = self.full_list(); + result["length_unit"] = self.length_unit(); - return result.dump(/*indent*/4, /*indent_char*/' ', /*ensure_ascii*/ true); + return result; +} + +std::string NeighborListOptionsHolder::to_json() const { + return neighbor_list_options_to_json(*this).dump(/*indent*/4, /*indent_char*/' ', /*ensure_ascii*/ true); } NeighborListOptions NeighborListOptionsHolder::from_json(const std::string& json) { @@ -874,3 +879,174 @@ std::string SystemHolder::str() const { return result.str(); } + +template +nlohmann::json tensor_to_vector_string(const torch::Tensor& tensor) { + torch::Tensor contiguous_cpu_tensor = tensor.cpu().contiguous(); + scalar_t* pointer = contiguous_cpu_tensor.data_ptr(); + size_t size = static_cast(contiguous_cpu_tensor.numel()); + nlohmann::json result = std::vector(pointer, pointer + size); + return result; +} + +template +torch::Tensor tensor_from_vector_string(nlohmann::json data) { + if (!data.contains("dtype") || !data["dtype"].is_string()) { + throw std::runtime_error("expected 'dtype' in JSON for tensor, did not find it"); + } + auto dtype = data["dtype"].get(); + auto torch_dtype = scalar_type_from_name(dtype); + + if (!data.contains("sizes")) { + throw std::runtime_error("expected 'sizes' in JSON for tensor, did not find it"); + } + auto sizes = data["sizes"].get>(); + + if (!data.contains("values")) { + throw std::runtime_error("expected 'values' in JSON for tensor, did not find it"); + } + auto values = data["values"].get>(); + + auto options = torch::TensorOptions().dtype(torch_dtype); + auto tensor = torch::empty(sizes, options); + auto* tensor_data = tensor.data_ptr(); + std::copy(values.begin(), values.end(), tensor_data); + return tensor; +} + +nlohmann::json tensor_to_json(const torch::Tensor& tensor) { + nlohmann::json result; + result["dtype"] = scalar_type_name(tensor.scalar_type()); + result["sizes"] = tensor.sizes().vec(); + result["values"] = AT_DISPATCH_ALL_TYPES(tensor.scalar_type(), "tensor_to_vector", ([&] { + return tensor_to_vector_string(tensor); + })); + return result; +} + +torch::Tensor tensor_from_json(const nlohmann::json& data) { + if (!data.contains("dtype") || !data["dtype"].is_string()) { + throw std::runtime_error("expected 'dtype' in JSON for tensor, did not find it"); + } + auto dtype = data["dtype"].get(); + auto torch_dtype = scalar_type_from_name(dtype); + + return AT_DISPATCH_ALL_TYPES(torch_dtype, "tensor_from_vector", ([&] { + return tensor_from_vector_string(data); + })); +} + +nlohmann::json neighbor_list_block_to_json(const TensorBlockHolder& block) { + // Specific to NLs. This means that we don't have to do gradients and + // worry save/load metadata outside of the samples. + nlohmann::json result; + result["samples"] = tensor_to_json(block.samples()->values()); + result["values"] = tensor_to_json(block.values()); + return result; +} + +TensorBlockHolder neighbor_list_block_from_json(const nlohmann::json& data) { + if (!data.contains("samples")) { + throw std::runtime_error("expected 'samples' in JSON for neighbor list block, did not find it"); + } + auto samples_values = tensor_from_json(data["samples"]); + + auto names = std::vector({"first_atom", "second_atom", "cell_shift_a", "cell_shift_b", "cell_shift_c"}); + auto samples = torch::make_intrusive( + /*names=*/std::move(names), + /*values=*/std::move(samples_values) + ); + auto components = LabelsHolder::create({"xyz"}, {{0}, {1}, {2}}); + auto properties = LabelsHolder::create({"distance"}, {{0}}); + + if (!data.contains("values")) { + throw std::runtime_error("expected 'values' in JSON for neighbor list block, did not find it"); + } + auto values = tensor_from_json(data["values"]); + + return TensorBlockHolder( + /*values=*/values, + /*samples=*/samples, + /*components=*/{components}, + /*properties=*/properties + ); +} + +std::string SystemHolder::to_json() const { + nlohmann::json result; + + result["class"] = "System"; + result["positions"] = tensor_to_json(this->positions()); + result["cell"] = tensor_to_json(this->cell()); + result["types"] = tensor_to_json(this->types()); + + // torch doesn't dispatch bools with AT_DISPATCH_ALL_TYPES, use trick + result["pbc"] = tensor_to_json(this->pbc().to(torch::kInt32)); + + result["neighbor_lists"] = std::vector(); + for (auto nl_option: this->known_neighbor_lists()) { + auto nl_data = this->get_neighbor_list(nl_option); + + auto nl_json = nlohmann::json(); + nl_json["options"] = neighbor_list_options_to_json(*nl_option); + nl_json["data"] = neighbor_list_block_to_json(*nl_data); + result["neighbor_lists"].emplace_back(nl_json); + } + + return result.dump(/*indent*/4, /*indent_char*/' ', /*ensure_ascii*/ true); +} + +System SystemHolder::from_json(const std::string_view json) { + auto data = nlohmann::json::parse(json); + + if (!data.is_object()) { + throw std::runtime_error("invalid JSON data for System, expected an object"); + } + + if (!data.contains("class") || !data["class"].is_string()) { + throw std::runtime_error("expected 'class' in JSON for System, did not find it"); + } + + if (data["class"] != "System") { + throw std::runtime_error("'class' in JSON for System must be 'System'"); + } + + if (!data.contains("positions")) { + throw std::runtime_error("expected 'positions' in JSON for System, did not find it"); + } + auto positions = tensor_from_json(data["positions"]); + + if (!data.contains("cell")) { + throw std::runtime_error("expected 'cell' in JSON for System, did not find it"); + } + auto cell = tensor_from_json(data["cell"]); + + if (!data.contains("types")) { + throw std::runtime_error("expected 'types' in JSON for System, did not find it"); + } + auto types = tensor_from_json(data["types"]); + + if (!data.contains("pbc")) { + throw std::runtime_error("expected 'pbc' in JSON for System, did not find it"); + } + // undo the bool->int trick (see to_json) + auto pbc = tensor_from_json(data["pbc"]).to(torch::kBool); + + auto system = torch::make_intrusive(types, positions, cell, pbc); + + for (const auto& nl_data: data["neighbor_lists"]) { + if (!nl_data.contains("options") || !nl_data["options"].is_object()) { + throw std::runtime_error("expected 'options' in JSON for neighbor list, did not find it"); + } + auto options = NeighborListOptionsHolder::from_json(nl_data["options"].dump()); + + if (!nl_data.contains("data") || !nl_data["data"].is_object()) { + throw std::runtime_error("expected 'data' in JSON for neighbor list, did not find it"); + } + auto block = neighbor_list_block_from_json(nl_data["data"]); + + system->add_neighbor_list(options, torch::make_intrusive(std::move(block))); + } + + return system; +} diff --git a/metatensor-torch/src/internal/utils.hpp b/metatensor-torch/src/internal/utils.hpp index 7afb36e98..1137fa750 100644 --- a/metatensor-torch/src/internal/utils.hpp +++ b/metatensor-torch/src/internal/utils.hpp @@ -39,6 +39,38 @@ inline std::string scalar_type_name(torch::ScalarType scalar_type) { } } +/// Convert a string from `scalar_type_name` to a `torch::ScalarType` +inline torch::ScalarType scalar_type_from_name(std::string name) { + if (name == "torch.int8") { + return torch::ScalarType::Char; + } else if (name == "torch.int16") { + return torch::ScalarType::Short; + } else if (name == "torch.int32") { + return torch::ScalarType::Int; + } else if (name == "torch.int64") { + return torch::ScalarType::Long; + } else if (name == "torch.float16") { + return torch::ScalarType::Half; + } else if (name == "torch.float32") { + return torch::ScalarType::Float; + } else if (name == "torch.float64") { + return torch::ScalarType::Double; + } else if (name == "torch.complex32") { + return torch::ScalarType::ComplexHalf; + } else if (name == "torch.complex64") { + return torch::ScalarType::ComplexFloat; + } else if (name == "torch.complex128") { + return torch::ScalarType::ComplexDouble; + } else if (name == "torch.bool") { + return torch::ScalarType::Bool; + } else { + throw std::runtime_error( + "Found unknown scalar type name `" + name + "` while " + + "converting it to a torch scalar type." + ); + } +} + /// Parse the arguments to the `to` function inline std::tuple, torch::optional> to_arguments_parse( diff --git a/metatensor-torch/src/register.cpp b/metatensor-torch/src/register.cpp index 4a848ed49..792845296 100644 --- a/metatensor-torch/src/register.cpp +++ b/metatensor-torch/src/register.cpp @@ -388,7 +388,14 @@ TORCH_LIBRARY(metatensor, m) { {torch::arg("name")} ) .def("known_data", &SystemHolder::known_data) - ; + .def_pickle( + [](const System& self) -> std::string { + return self->to_json(); + }, + [](const std::string& data) -> System { + return SystemHolder::from_json(data); + } + ); m.class_("ModelMetadata") diff --git a/python/metatensor-torch/tests/atomistic/systems.py b/python/metatensor-torch/tests/atomistic/systems.py index 573dd6401..a912bf1ba 100644 --- a/python/metatensor-torch/tests/atomistic/systems.py +++ b/python/metatensor-torch/tests/atomistic/systems.py @@ -1,3 +1,5 @@ +import os + import pytest import torch from packaging import version @@ -518,6 +520,55 @@ def check_dtype(system: System, dtype: torch.dtype): assert system.dtype == dtype +@pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) +@pytest.mark.parametrize("nl_size", [10, 100, 1000]) +@pytest.mark.parametrize( + "pbc", [[True, True, True], [True, False, True], [False, False, False]] +) +def test_save_load(tmpdir, dtype, nl_size, pbc): + cell = torch.rand((3, 3), dtype=dtype) + cell[[not periodic for periodic in pbc]] = 0.0 + + system = System( + types=torch.tensor([1, 2, 3, 4]), + positions=torch.rand((4, 3), dtype=dtype), + cell=cell, + pbc=torch.tensor(pbc, dtype=torch.bool), + ) + system.add_neighbor_list( + NeighborListOptions(cutoff=3.5, full_list=False), + TensorBlock( + values=torch.rand(nl_size, 3, 1, dtype=dtype), + samples=Labels( + [ + "first_atom", + "second_atom", + "cell_shift_a", + "cell_shift_b", + "cell_shift_c", + ], + torch.arange(nl_size * 5, dtype=torch.int64).reshape(nl_size, 5), + ), + components=[Labels.range("xyz", 3)], + properties=Labels.range("distance", 1), + ), + ) + + torch.save(system, os.path.join(tmpdir, "system.pt")) + system_loaded = torch.load(os.path.join(tmpdir, "system.pt")) + assert torch.equal(system.types, system_loaded.types) + assert torch.equal(system.positions, system_loaded.positions) + assert torch.equal(system.cell, system_loaded.cell) + assert torch.equal(system.pbc, system_loaded.pbc) + neigbor_list = system.get_neighbor_list( + NeighborListOptions(cutoff=3.5, full_list=False) + ) + neighbor_list_loaded = system_loaded.get_neighbor_list( + NeighborListOptions(cutoff=3.5, full_list=False) + ) + assert metatensor.torch.equal_block(neigbor_list, neighbor_list_loaded) + + def test_partial_pbc(): system = System( torch.tensor([1]),