Serialization for System #673

Draft: wants to merge 4 commits into master
@@ -255,6 +255,12 @@ class METATENSOR_TORCH_EXPORT SystemHolder final: public torch::CustomClassHolder
/// Implementation of `__str__` and `__repr__` for Python
std::string str() const;

/// Serialize a `System` to a JSON string.
std::string to_json() const;

/// Load a serialized `System` from a JSON string.
static System from_json(std::string_view json);

private:
struct nl_options_compare {
bool operator()(const NeighborListOptions& a, const NeighborListOptions& b) const {
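For reference, here is a minimal C++ sketch of how the new methods round-trip a `System` through JSON. The include path and namespace below are assumptions based on the usual metatensor-torch layout, not taken from this diff:

#include <cassert>
#include <string>

#include <metatensor/torch/atomistic.hpp>  // assumed header path

using metatensor_torch::System;
using metatensor_torch::SystemHolder;

// Serialize a System to a JSON string and rebuild an equivalent System from it.
void roundtrip(const System& system) {
    std::string json = system->to_json();
    System restored = SystemHolder::from_json(json);

    assert(torch::equal(system->positions(), restored->positions()));
    assert(torch::equal(system->cell(), restored->cell()));
    assert(torch::equal(system->types(), restored->types()));
}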
177 changes: 172 additions & 5 deletions metatensor-torch/src/atomistic/system.cpp
@@ -77,7 +77,7 @@ std::string NeighborListOptionsHolder::str() const {
", full_list=" + (full_list_ ? "True" : "False") + ")";
}

std::string NeighborListOptionsHolder::to_json() const {
static nlohmann::json neighbor_list_options_to_json(const NeighborListOptionsHolder& self) {
nlohmann::json result;

result["class"] = "NeighborListOptions";
@@ -86,12 +86,17 @@ std::string NeighborListOptionsHolder::to_json() const {
// round-tripping of the data
static_assert(sizeof(double) == sizeof(int64_t));
int64_t int_cutoff = 0;
std::memcpy(&int_cutoff, &this->cutoff_, sizeof(double));
double cutoff = self.cutoff();
std::memcpy(&int_cutoff, &cutoff, sizeof(double));
result["cutoff"] = int_cutoff;
result["full_list"] = this->full_list_;
result["length_unit"] = this->length_unit_;
result["full_list"] = self.full_list();
result["length_unit"] = self.length_unit();

return result.dump(/*indent*/4, /*indent_char*/' ', /*ensure_ascii*/ true);
return result;
}

std::string NeighborListOptionsHolder::to_json() const {
return neighbor_list_options_to_json(*this).dump(/*indent*/4, /*indent_char*/' ', /*ensure_ascii*/ true);
}

NeighborListOptions NeighborListOptionsHolder::from_json(const std::string& json) {
@@ -788,3 +793,165 @@ std::string SystemHolder::str() const {

return result.str();
}

template<typename scalar_t>
nlohmann::json tensor_to_vector_string(const torch::Tensor& tensor) {
torch::Tensor contiguous_cpu_tensor = tensor.cpu().contiguous();
scalar_t* pointer = contiguous_cpu_tensor.data<scalar_t>();
size_t size = static_cast<size_t>(contiguous_cpu_tensor.numel());
nlohmann::json result = std::vector<scalar_t>(pointer, pointer + size);
return result;
}

template<typename scalar_t>
torch::Tensor tensor_from_vector_string(nlohmann::json data) {
if (!data.contains("dtype") || !data["dtype"].is_string()) {
throw std::runtime_error("expected 'dtype' in JSON for tensor, did not find it");
}
auto dtype = data["dtype"].get<std::string>();
auto torch_dtype = scalar_type_from_name(dtype);

if (!data.contains("sizes")) {
throw std::runtime_error("expected 'sizes' in JSON for tensor, did not find it");
}
auto sizes = data["sizes"].get<std::vector<int64_t>>();

if (!data.contains("values")) {
throw std::runtime_error("expected 'values' in JSON for tensor, did not find it");
}
auto values = data["values"].get<std::vector<scalar_t>>();

auto options = torch::TensorOptions().dtype(torch_dtype);
auto tensor = torch::empty(sizes, options);
auto* tensor_data = tensor.data<scalar_t>();
std::copy(values.begin(), values.end(), tensor_data);
return tensor;
}

nlohmann::json tensor_to_json(const torch::Tensor& tensor) {
nlohmann::json result;
result["dtype"] = scalar_type_name(tensor.scalar_type());
result["sizes"] = tensor.sizes().vec();
result["values"] = AT_DISPATCH_ALL_TYPES(tensor.scalar_type(), "tensor_to_vector", ([&] {
return tensor_to_vector_string<scalar_t>(tensor);
}));
return result;
}

torch::Tensor tensor_from_json(const nlohmann::json& data) {
if (!data.contains("dtype") || !data["dtype"].is_string()) {
throw std::runtime_error("expected 'dtype' in JSON for tensor, did not find it");
}
auto dtype = data["dtype"].get<std::string>();
auto torch_dtype = scalar_type_from_name(dtype);

return AT_DISPATCH_ALL_TYPES(torch_dtype, "tensor_from_vector", ([&] {
return tensor_from_vector_string<scalar_t>(data);
}));
}

nlohmann::json neighbor_list_block_to_json(const TensorBlockHolder& block) {
// Specific to neighbor lists: this means we don't have to handle gradients,
// and the only metadata we need to save/load is the samples.
nlohmann::json result;
result["samples"] = tensor_to_json(block.samples()->values());
result["values"] = tensor_to_json(block.values());
return result;
}

TensorBlockHolder neighbor_list_block_from_json(const nlohmann::json& data) {
if (!data.contains("samples")) {
throw std::runtime_error("expected 'samples' in JSON for neighbor list block, did not find it");
}
auto samples_values = tensor_from_json(data["samples"]);

auto names = std::vector<std::string>({"first_atom", "second_atom", "cell_shift_a", "cell_shift_b", "cell_shift_c"});
auto samples = torch::make_intrusive<LabelsHolder>(
/*names=*/std::move(names),
/*values=*/std::move(samples_values)
);
auto components = LabelsHolder::create({"xyz"}, {{0}, {1}, {2}});
auto properties = LabelsHolder::create({"distance"}, {{0}});

if (!data.contains("values")) {
throw std::runtime_error("expected 'values' in JSON for neighbor list block, did not find it");
}
auto values = tensor_from_json(data["values"]);

return TensorBlockHolder(
/*values=*/values,
/*samples=*/samples,
/*components=*/{components},
/*properties=*/properties
);
}

std::string SystemHolder::to_json() const {
nlohmann::json result;

result["class"] = "System";
result["positions"] = tensor_to_json(this->positions());
result["cell"] = tensor_to_json(this->cell());
result["types"] = tensor_to_json(this->types());

result["neighbor_lists"] = std::vector<nlohmann::json>();
for (auto nl_option: this->known_neighbor_lists()) {
auto nl_data = this->get_neighbor_list(nl_option);

auto nl_json = nlohmann::json();
nl_json["options"] = neighbor_list_options_to_json(*nl_option);
nl_json["data"] = neighbor_list_block_to_json(*nl_data);
result["neighbor_lists"].emplace_back(nl_json);
}

return result.dump(/*indent*/4, /*indent_char*/' ', /*ensure_ascii*/ true);
}

System SystemHolder::from_json(const std::string_view json) {
auto data = nlohmann::json::parse(json);

if (!data.is_object()) {
throw std::runtime_error("invalid JSON data for System, expected an object");
}

if (!data.contains("class") || !data["class"].is_string()) {
throw std::runtime_error("expected 'class' in JSON for System, did not find it");
}

if (data["class"] != "System") {
throw std::runtime_error("'class' in JSON for System must be 'System'");
}

if (!data.contains("positions")) {
throw std::runtime_error("expected 'positions' in JSON for System, did not find it");
}
auto positions = tensor_from_json(data["positions"]);

if (!data.contains("cell")) {
throw std::runtime_error("expected 'cell' in JSON for System, did not find it");
}
auto cell = tensor_from_json(data["cell"]);

if (!data.contains("types")) {
throw std::runtime_error("expected 'types' in JSON for System, did not find it");
}
auto types = tensor_from_json(data["types"]);

auto system = torch::make_intrusive<SystemHolder>(types, positions, cell);

for (const auto& nl_data: data["neighbor_lists"]) {
if (!nl_data.contains("options") || !nl_data["options"].is_object()) {
throw std::runtime_error("expected 'options' in JSON for neighbor list, did not find it");
}
auto options = NeighborListOptionsHolder::from_json(nl_data["options"].dump());

if (!nl_data.contains("data") || !nl_data["data"].is_object()) {
throw std::runtime_error("expected 'data' in JSON for neighbor list, did not find it");
}
auto block = neighbor_list_block_from_json(nl_data["data"]);

system->add_neighbor_list(options, torch::make_intrusive<TensorBlockHolder>(std::move(block)));
}

return system;
}
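For orientation, the JSON produced by `SystemHolder::to_json` has roughly the shape sketched below (pseudo-JSON inside a C++ comment, reconstructed from the code above for a hypothetical 4-atom float64 system; the values and field order are illustrative, not verbatim output):

// Sketch of the serialized layout (illustrative only):
//
// {
//     "class": "System",
//     "positions": {"dtype": "torch.float64", "sizes": [4, 3], "values": [/* 12 numbers */]},
//     "cell":      {"dtype": "torch.float64", "sizes": [3, 3], "values": [/* 9 numbers */]},
//     "types":     {"dtype": "...",           "sizes": [4],    "values": [/* 4 integers */]},
//     "neighbor_lists": [
//         {
//             "options": {
//                 "class": "NeighborListOptions",
//                 // the cutoff is stored as the int64 bit pattern of the double
//                 // (here 3.5), so that it round-trips exactly
//                 "cutoff": 4615063718147915776,
//                 "full_list": false,
//                 "length_unit": "..."
//             },
//             "data": {
//                 "samples": {/* (first_atom, second_atom, cell_shift_a/b/c) tuples */},
//                 "values":  {/* pair distance vectors, shape (n_pairs, 3, 1) */}
//             }
//         }
//     ]
// }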
32 changes: 32 additions & 0 deletions metatensor-torch/src/internal/utils.hpp
@@ -39,6 +39,38 @@ inline std::string scalar_type_name(torch::ScalarType scalar_type) {
}
}

/// Convert a string from `scalar_type_name` to a `torch::ScalarType`
inline torch::ScalarType scalar_type_from_name(std::string name) {
if (name == "torch.int8") {
return torch::ScalarType::Char;
} else if (name == "torch.int16") {
return torch::ScalarType::Short;
} else if (name == "torch.int32") {
return torch::ScalarType::Int;
} else if (name == "torch.int64") {
return torch::ScalarType::Long;
} else if (name == "torch.float16") {
return torch::ScalarType::Half;
} else if (name == "torch.float32") {
return torch::ScalarType::Float;
} else if (name == "torch.float64") {
return torch::ScalarType::Double;
} else if (name == "torch.complex32") {
return torch::ScalarType::ComplexHalf;
} else if (name == "torch.complex64") {
return torch::ScalarType::ComplexFloat;
} else if (name == "torch.complex128") {
return torch::ScalarType::ComplexDouble;
} else if (name == "torch.bool") {
return torch::ScalarType::Bool;
} else {
throw std::runtime_error(
"Found unknown scalar type name `" + name + "` while " +
"converting it to a torch scalar type."
);
}
}

/// Parse the arguments to the `to` function
inline std::tuple<torch::optional<torch::Dtype>, torch::optional<torch::Device>>
to_arguments_parse(
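The new `scalar_type_from_name` helper is the inverse of the existing `scalar_type_name`; a minimal sketch of the intended round-trip:

// round-trip a scalar type through its string name
auto name = scalar_type_name(torch::ScalarType::Double);   // "torch.float64"
auto scalar_type = scalar_type_from_name(name);
assert(scalar_type == torch::ScalarType::Double);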
9 changes: 8 additions & 1 deletion metatensor-torch/src/register.cpp
@@ -362,7 +362,14 @@ TORCH_LIBRARY(metatensor, m) {
{torch::arg("name")}
)
.def("known_data", &SystemHolder::known_data)
;
.def_pickle(
[](const System& self) -> std::string {
return self->to_json();
},
[](const std::string& data) -> System {
return SystemHolder::from_json(data);
}
);


m.class_<ModelMetadataHolder>("ModelMetadata")
43 changes: 43 additions & 0 deletions python/metatensor-torch/tests/atomistic/systems.py
@@ -1,3 +1,5 @@
import os

import pytest
import torch
from packaging import version
@@ -477,3 +479,44 @@ def test_to(system, neighbors):
@torch.jit.script
def check_dtype(system: System, dtype: torch.dtype):
assert system.dtype == dtype


@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
@pytest.mark.parametrize("nl_size", [10, 100, 1000])
def test_save_load(tmpdir, dtype, nl_size):
system = System(
types=torch.tensor([1, 2, 3, 4]),
positions=torch.rand((4, 3), dtype=dtype),
cell=torch.rand((3, 3), dtype=dtype),
)
system.add_neighbor_list(
NeighborListOptions(cutoff=3.5, full_list=False),
TensorBlock(
values=torch.rand(nl_size, 3, 1, dtype=dtype),
samples=Labels(
[
"first_atom",
"second_atom",
"cell_shift_a",
"cell_shift_b",
"cell_shift_c",
],
torch.arange(nl_size * 5, dtype=torch.int64).reshape(nl_size, 5),
),
components=[Labels.range("xyz", 3)],
properties=Labels.range("distance", 1),
),
)

torch.save(system, os.path.join(tmpdir, "system.pt"))
system_loaded = torch.load(os.path.join(tmpdir, "system.pt"))
assert torch.equal(system.types, system_loaded.types)
assert torch.equal(system.positions, system_loaded.positions)
assert torch.equal(system.cell, system_loaded.cell)
neighbor_list = system.get_neighbor_list(
NeighborListOptions(cutoff=3.5, full_list=False)
)
neighbor_list_loaded = system_loaded.get_neighbor_list(
NeighborListOptions(cutoff=3.5, full_list=False)
)
assert metatensor.torch.equal_block(neighbor_list, neighbor_list_loaded)