Skip to content

Commit

Permalink
Serialization for System (without neighborlist)
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster committed Jul 4, 2024
1 parent 2783dd8 commit f41f4e3
Show file tree
Hide file tree
Showing 5 changed files with 163 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,12 @@ class METATENSOR_TORCH_EXPORT SystemHolder final: public torch::CustomClassHolde
/// Implementation of `__str__` and `__repr__` for Python
std::string str() const;

/// Serialize a `System` to a JSON string.
std::string to_json() const;

/// Load a serialized `System` from a JSON string.
static System from_json(std::string_view json);

private:
struct nl_options_compare {
bool operator()(const NeighborListOptions& a, const NeighborListOptions& b) const {
Expand Down
101 changes: 101 additions & 0 deletions metatensor-torch/src/atomistic/system.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -789,3 +789,104 @@ std::string SystemHolder::str() const {

return result.str();
}

template<typename scalar_t>
nlohmann::json tensor_to_vector_string(const torch::Tensor& tensor) {
torch::Tensor contiguous_cpu_tensor = tensor.cpu().contiguous();
scalar_t* pointer = contiguous_cpu_tensor.data<scalar_t>();
size_t size = static_cast<size_t>(contiguous_cpu_tensor.numel());
nlohmann::json result = std::vector<scalar_t>(pointer, pointer + size);
return result;
}

template<typename scalar_t>
torch::Tensor tensor_from_vector_string(nlohmann::json data) {
if (!data.contains("dtype") || !data["dtype"].is_string()) {
throw std::runtime_error("expected 'dtype' in JSON for tensor, did not find it");
}
auto dtype = data["dtype"].get<std::string>();
auto torch_dtype = scalar_type_from_name(dtype);

if (!data.contains("sizes")) {
throw std::runtime_error("expected 'sizes' in JSON for tensor, did not find it");
}
auto sizes = data["sizes"].get<std::vector<int64_t>>();

if (!data.contains("values")) {
throw std::runtime_error("expected 'values' in JSON for tensor, did not find it");
}
auto values = data["values"].get<std::vector<scalar_t>>();

auto options = torch::TensorOptions().dtype(torch_dtype);
auto tensor = torch::empty(sizes, options);
auto* tensor_data = tensor.data<scalar_t>();
std::copy(values.begin(), values.end(), tensor_data);
return tensor;
}

nlohmann::json tensor_to_json(const torch::Tensor& tensor) {
nlohmann::json result;
result["dtype"] = scalar_type_name(tensor.scalar_type());
result["sizes"] = tensor.sizes().vec();
result["values"] = AT_DISPATCH_ALL_TYPES(tensor.scalar_type(), "tensor_to_vector", ([&] {
return tensor_to_vector_string<scalar_t>(tensor);
}));
return result;
}

torch::Tensor tensor_from_json(const nlohmann::json& data) {
if (!data.contains("dtype") || !data["dtype"].is_string()) {
throw std::runtime_error("expected 'dtype' in JSON for tensor, did not find it");
}
auto dtype = data["dtype"].get<std::string>();
auto torch_dtype = scalar_type_from_name(dtype);

return AT_DISPATCH_ALL_TYPES(torch_dtype, "tensor_from_vector", ([&] {
return tensor_from_vector_string<scalar_t>(data);
}));
}

std::string SystemHolder::to_json() const {
nlohmann::json result;

result["class"] = "System";
result["positions"] = tensor_to_json(this->positions());
result["cell"] = tensor_to_json(this->cell());
result["types"] = tensor_to_json(this->types());

return result.dump(/*indent*/4, /*indent_char*/' ', /*ensure_ascii*/ true);
}

System SystemHolder::from_json(const std::string_view json) {
auto data = nlohmann::json::parse(json);

if (!data.is_object()) {
throw std::runtime_error("invalid JSON data for System, expected an object");
}

if (!data.contains("class") || !data["class"].is_string()) {
throw std::runtime_error("expected 'class' in JSON for System, did not find it");
}

if (data["class"] != "System") {
throw std::runtime_error("'class' in JSON for System must be 'System'");
}

if (!data.contains("positions")) {
throw std::runtime_error("expected 'positions' in JSON for System, did not find it");
}
auto positions = tensor_from_json(data["positions"]);

if (!data.contains("cell")) {
throw std::runtime_error("expected 'cell' in JSON for System, did not find it");
}
auto cell = tensor_from_json(data["cell"]);

if (!data.contains("types")) {
throw std::runtime_error("expected 'types' in JSON for System, did not find it");
}
auto types = tensor_from_json(data["types"]);

auto system = torch::make_intrusive<SystemHolder>(types, positions, cell);
return system;
}
32 changes: 32 additions & 0 deletions metatensor-torch/src/internal/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,38 @@ inline std::string scalar_type_name(torch::ScalarType scalar_type) {
}
}

/// Convert a string from `scalar_type_name` to a `torch::ScalarType`
inline torch::ScalarType scalar_type_from_name(std::string name) {
if (name == "torch.int8") {
return torch::ScalarType::Char;
} else if (name == "torch.int16") {
return torch::ScalarType::Short;
} else if (name == "torch.int32") {
return torch::ScalarType::Int;
} else if (name == "torch.int64") {
return torch::ScalarType::Long;
} else if (name == "torch.float16") {
return torch::ScalarType::Half;
} else if (name == "torch.float32") {
return torch::ScalarType::Float;
} else if (name == "torch.float64") {
return torch::ScalarType::Double;
} else if (name == "torch.complex32") {
return torch::ScalarType::ComplexHalf;
} else if (name == "torch.complex64") {
return torch::ScalarType::ComplexFloat;
} else if (name == "torch.complex128") {
return torch::ScalarType::ComplexDouble;
} else if (name == "torch.bool") {
return torch::ScalarType::Bool;
} else {
throw std::runtime_error(
"Found unknown scalar type name `" + name + "` while " +
"converting it to a torch scalar type."
);
}
}

/// Parse the arguments to the `to` function
inline std::tuple<torch::optional<torch::Dtype>, torch::optional<torch::Device>>
to_arguments_parse(
Expand Down
9 changes: 8 additions & 1 deletion metatensor-torch/src/register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,14 @@ TORCH_LIBRARY(metatensor, m) {
{torch::arg("name")}
)
.def("known_data", &SystemHolder::known_data)
;
.def_pickle(
[](const System& self) -> std::string {
return self->to_json();
},
[](const std::string& data) -> System {
return SystemHolder::from_json(data);
}
);


m.class_<ModelMetadataHolder>("ModelMetadata")
Expand Down
16 changes: 16 additions & 0 deletions python/metatensor-torch/tests/atomistic/systems.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import os

import pytest
import torch
from packaging import version
Expand Down Expand Up @@ -477,3 +479,17 @@ def test_to(system, neighbors):
@torch.jit.script
def check_dtype(system: System, dtype: torch.dtype):
assert system.dtype == dtype


@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
def test_save_load(tmpdir, dtype):
system = System(
types=torch.tensor([1, 2, 3, 4]),
positions=torch.rand((4, 3), dtype=dtype),
cell=torch.rand((3, 3), dtype=dtype),
)
torch.save(system, os.path.join(tmpdir, "system.pt"))
system_loaded = torch.load(os.path.join(tmpdir, "system.pt"))
assert torch.equal(system.types, system_loaded.types)
assert torch.equal(system.positions, system_loaded.positions)
assert torch.equal(system.cell, system_loaded.cell)

0 comments on commit f41f4e3

Please sign in to comment.