From f41f4e3395da44cfd0520b99f50ed7d315b3cf94 Mon Sep 17 00:00:00 2001
From: frostedoyster
Date: Thu, 4 Jul 2024 11:25:29 +0200
Subject: [PATCH] Serialization for `System` (without neighborlist)

---
 .../metatensor/torch/atomistic/system.hpp     |   9 ++
 metatensor-torch/src/atomistic/system.cpp     | 120 ++++++++++++++++++
 metatensor-torch/src/internal/utils.hpp       |  34 ++++
 metatensor-torch/src/register.cpp             |   9 +-
 .../tests/atomistic/systems.py                |  16 +++
 5 files changed, 187 insertions(+), 1 deletion(-)

diff --git a/metatensor-torch/include/metatensor/torch/atomistic/system.hpp b/metatensor-torch/include/metatensor/torch/atomistic/system.hpp
index 04ba891ba..fe302b8ac 100644
--- a/metatensor-torch/include/metatensor/torch/atomistic/system.hpp
+++ b/metatensor-torch/include/metatensor/torch/atomistic/system.hpp
@@ -255,6 +255,15 @@ class METATENSOR_TORCH_EXPORT SystemHolder final: public torch::CustomClassHolde
     /// Implementation of `__str__` and `__repr__` for Python
     std::string str() const;
 
+    /// Serialize a `System` to a JSON string.
+    ///
+    /// Only `types`, `positions` and `cell` are serialized, neighbor lists
+    /// are not included in the output.
+    std::string to_json() const;
+
+    /// Load a serialized `System` from a JSON string, as created by `to_json`.
+    static System from_json(std::string_view json);
+
 private:
     struct nl_options_compare {
         bool operator()(const NeighborListOptions& a, const NeighborListOptions& b) const {
diff --git a/metatensor-torch/src/atomistic/system.cpp b/metatensor-torch/src/atomistic/system.cpp
index 04523c1cc..5bd415bce 100644
--- a/metatensor-torch/src/atomistic/system.cpp
+++ b/metatensor-torch/src/atomistic/system.cpp
@@ -789,3 +789,123 @@ std::string SystemHolder::str() const {
 
     return result.str();
 }
+
+/// Store the values of `tensor` in a flat JSON array, in the order of the
+/// contiguous CPU copy of the tensor. `scalar_t` must match the tensor dtype.
+template <typename scalar_t>
+static nlohmann::json tensor_to_vector_string(const torch::Tensor& tensor) {
+    torch::Tensor contiguous_cpu_tensor = tensor.cpu().contiguous();
+    const scalar_t* pointer = contiguous_cpu_tensor.data_ptr<scalar_t>();
+    size_t size = static_cast<size_t>(contiguous_cpu_tensor.numel());
+    nlohmann::json result = std::vector<scalar_t>(pointer, pointer + size);
+    return result;
+}
+
+/// Create a tensor from the JSON object created by `tensor_to_json`. `scalar_t`
+/// must be the C++ type corresponding to the "dtype" entry of `data`.
+template <typename scalar_t>
+static torch::Tensor tensor_from_vector_string(const nlohmann::json& data) {
+    if (!data.contains("dtype") || !data["dtype"].is_string()) {
+        throw std::runtime_error("expected 'dtype' in JSON for tensor, did not find it");
+    }
+    auto dtype = data["dtype"].get<std::string>();
+    auto torch_dtype = scalar_type_from_name(dtype);
+
+    if (!data.contains("sizes")) {
+        throw std::runtime_error("expected 'sizes' in JSON for tensor, did not find it");
+    }
+    auto sizes = data["sizes"].get<std::vector<int64_t>>();
+
+    if (!data.contains("values")) {
+        throw std::runtime_error("expected 'values' in JSON for tensor, did not find it");
+    }
+    auto values = data["values"].get<std::vector<scalar_t>>();
+
+    auto options = torch::TensorOptions().dtype(torch_dtype);
+    auto tensor = torch::empty(sizes, options);
+
+    // The number of values comes from the (untrusted) JSON input: check it
+    // against the size of the tensor before copying, to never write out of
+    // bounds.
+    if (static_cast<int64_t>(values.size()) != tensor.numel()) {
+        throw std::runtime_error(
+            "invalid JSON for tensor: expected " + std::to_string(tensor.numel()) +
+            " values, got " + std::to_string(values.size())
+        );
+    }
+
+    auto* tensor_data = tensor.data_ptr<scalar_t>();
+    std::copy(values.begin(), values.end(), tensor_data);
+    return tensor;
+}
+
+/// Serialize a tensor as a JSON object with "dtype", "sizes" and "values"
+/// entries. Only the dtypes handled by `AT_DISPATCH_ALL_TYPES` are supported.
+static nlohmann::json tensor_to_json(const torch::Tensor& tensor) {
+    nlohmann::json result;
+    result["dtype"] = scalar_type_name(tensor.scalar_type());
+    result["sizes"] = tensor.sizes().vec();
+    result["values"] = AT_DISPATCH_ALL_TYPES(tensor.scalar_type(), "tensor_to_vector", ([&] {
+        return tensor_to_vector_string<scalar_t>(tensor);
+    }));
+    return result;
+}
+
+/// Load a tensor from the JSON object created by `tensor_to_json`.
+static torch::Tensor tensor_from_json(const nlohmann::json& data) {
+    if (!data.contains("dtype") || !data["dtype"].is_string()) {
+        throw std::runtime_error("expected 'dtype' in JSON for tensor, did not find it");
+    }
+    auto dtype = data["dtype"].get<std::string>();
+    auto torch_dtype = scalar_type_from_name(dtype);
+
+    return AT_DISPATCH_ALL_TYPES(torch_dtype, "tensor_from_vector", ([&] {
+        return tensor_from_vector_string<scalar_t>(data);
+    }));
+}
+
+std::string SystemHolder::to_json() const {
+    nlohmann::json result;
+
+    // only `positions`, `cell` and `types` are stored in the JSON output
+    result["class"] = "System";
+    result["positions"] = tensor_to_json(this->positions());
+    result["cell"] = tensor_to_json(this->cell());
+    result["types"] = tensor_to_json(this->types());
+
+    return result.dump(/*indent*/4, /*indent_char*/' ', /*ensure_ascii*/ true);
+}
+
+System SystemHolder::from_json(const std::string_view json) {
+    auto data = nlohmann::json::parse(json);
+
+    if (!data.is_object()) {
+        throw std::runtime_error("invalid JSON data for System, expected an object");
+    }
+
+    if (!data.contains("class") || !data["class"].is_string()) {
+        throw std::runtime_error("expected 'class' in JSON for System, did not find it");
+    }
+
+    if (data["class"] != "System") {
+        throw std::runtime_error("'class' in JSON for System must be 'System'");
+    }
+
+    if (!data.contains("positions")) {
+        throw std::runtime_error("expected 'positions' in JSON for System, did not find it");
+    }
+    auto positions = tensor_from_json(data["positions"]);
+
+    if (!data.contains("cell")) {
+        throw std::runtime_error("expected 'cell' in JSON for System, did not find it");
+    }
+    auto cell = tensor_from_json(data["cell"]);
+
+    if (!data.contains("types")) {
+        throw std::runtime_error("expected 'types' in JSON for System, did not find it");
+    }
+    auto types = tensor_from_json(data["types"]);
+
+    auto system = torch::make_intrusive<SystemHolder>(types, positions, cell);
+    return system;
+}
diff --git a/metatensor-torch/src/internal/utils.hpp b/metatensor-torch/src/internal/utils.hpp
index 7afb36e98..1137fa750 100644
--- a/metatensor-torch/src/internal/utils.hpp
+++ b/metatensor-torch/src/internal/utils.hpp
@@ -39,6 +39,40 @@ inline std::string scalar_type_name(torch::ScalarType scalar_type) {
     }
 }
 
+/// Convert a string from `scalar_type_name` to a `torch::ScalarType`.
+///
+/// This throws a `std::runtime_error` if `name` is not a known type name.
+inline torch::ScalarType scalar_type_from_name(const std::string& name) {
+    if (name == "torch.int8") {
+        return torch::ScalarType::Char;
+    } else if (name == "torch.int16") {
+        return torch::ScalarType::Short;
+    } else if (name == "torch.int32") {
+        return torch::ScalarType::Int;
+    } else if (name == "torch.int64") {
+        return torch::ScalarType::Long;
+    } else if (name == "torch.float16") {
+        return torch::ScalarType::Half;
+    } else if (name == "torch.float32") {
+        return torch::ScalarType::Float;
+    } else if (name == "torch.float64") {
+        return torch::ScalarType::Double;
+    } else if (name == "torch.complex32") {
+        return torch::ScalarType::ComplexHalf;
+    } else if (name == "torch.complex64") {
+        return torch::ScalarType::ComplexFloat;
+    } else if (name == "torch.complex128") {
+        return torch::ScalarType::ComplexDouble;
+    } else if (name == "torch.bool") {
+        return torch::ScalarType::Bool;
+    } else {
+        throw std::runtime_error(
+            "Found unknown scalar type name `" + name + "` while " +
+            "converting it to a torch scalar type."
+        );
+    }
+}
+
 /// Parse the arguments to the `to` function
 inline std::tuple<torch::optional<torch::Dtype>, torch::optional<torch::Device>>
 to_arguments_parse(
diff --git a/metatensor-torch/src/register.cpp b/metatensor-torch/src/register.cpp
index 461dbda40..11996d4a8 100644
--- a/metatensor-torch/src/register.cpp
+++ b/metatensor-torch/src/register.cpp
@@ -362,7 +362,14 @@ TORCH_LIBRARY(metatensor, m) {
             {torch::arg("name")}
         )
         .def("known_data", &SystemHolder::known_data)
-        ;
+        .def_pickle(
+            [](const System& self) -> std::string {
+                return self->to_json();
+            },
+            [](const std::string& data) -> System {
+                return SystemHolder::from_json(data);
+            }
+        );
 
 
     m.class_<ModelMetadataHolder>("ModelMetadata")
diff --git a/python/metatensor-torch/tests/atomistic/systems.py b/python/metatensor-torch/tests/atomistic/systems.py
index c5155c92b..4ab0f3900 100644
--- a/python/metatensor-torch/tests/atomistic/systems.py
+++ b/python/metatensor-torch/tests/atomistic/systems.py
@@ -1,3 +1,5 @@
+import os
+
 import pytest
 import torch
 from packaging import version
@@ -477,3 +479,17 @@ def test_to(system, neighbors):
 @torch.jit.script
 def check_dtype(system: System, dtype: torch.dtype):
     assert system.dtype == dtype
+
+
+@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
+def test_save_load(tmpdir, dtype):
+    system = System(
+        types=torch.tensor([1, 2, 3, 4]),
+        positions=torch.rand((4, 3), dtype=dtype),
+        cell=torch.rand((3, 3), dtype=dtype),
+    )
+    torch.save(system, os.path.join(tmpdir, "system.pt"))
+    system_loaded = torch.load(os.path.join(tmpdir, "system.pt"))
+    assert torch.equal(system.types, system_loaded.types)
+    assert torch.equal(system.positions, system_loaded.positions)
+    assert torch.equal(system.cell, system_loaded.cell)