Skip to content

Commit

Permalink
feat: add serde for embedding nx iree tensors (#8)
Browse files Browse the repository at this point in the history
  • Loading branch information
polvalente authored Aug 19, 2024
1 parent fe2f093 commit 0392d0e
Show file tree
Hide file tree
Showing 7 changed files with 139 additions and 5 deletions.
42 changes: 42 additions & 0 deletions c_src/nx_iree.cc
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,46 @@ DECLARE_NIF(allocate_buffer) {
return ok(env, make<iree::runtime::IREETensor*>(env, input));
}

// NIF: serialize_tensor(tensor_ref) -> {:ok, binary} | {:error, reason}
//
// Copies the serialized form of an IREETensor resource into a freshly
// allocated Erlang binary.
DECLARE_NIF(serialize_tensor) {
  if (argc != 1) {
    return error(env, "invalid number of arguments");
  }

  iree::runtime::IREETensor** input;

  if (!get<iree::runtime::IREETensor*>(env, argv[0], input)) {
    return error(env, "invalid input");
  }

  // serialize() returns a heap-allocated vector owned by the caller, so it
  // must be deleted on every exit path below.
  std::vector<char>* serialized = (*input)->serialize();

  ErlNifBinary binary;

  if (!enif_alloc_binary(serialized->size(), &binary)) {
    delete serialized;
    return error(env, "unable to allocate binary");
  }

  std::memcpy(binary.data, serialized->data(), serialized->size());

  // The bytes now live in the Erlang binary; the temporary is no longer needed.
  delete serialized;

  return ok(env, enif_make_binary(env, &binary));
}

// NIF: deserialize_tensor(binary) -> {:ok, tensor_ref} | {:error, reason}
//
// Rebuilds an IREETensor from the byte layout produced by serialize_tensor:
//   type | size | data[size] | num_dims | dims[num_dims]
DECLARE_NIF(deserialize_tensor) {
  if (argc != 1) {
    return error(env, "invalid number of arguments");
  }

  ErlNifBinary input;

  if (!enif_inspect_binary(env, argv[0], &input)) {
    return error(env, "invalid input");
  }

  // The IREETensor(char*) constructor trusts the lengths embedded in the
  // buffer, so the framing is validated against the real binary size first.
  // Otherwise a truncated or malformed binary would cause out-of-bounds reads.
  // NOTE(review): assumes the serialized byte-count field is a size_t —
  // confirm against the declaration of IREETensor::size.
  const unsigned char* cursor = input.data;
  size_t remaining = input.size;
  size_t data_size = 0;
  size_t num_dims = 0;

  if (remaining < sizeof(iree_hal_element_type_t) + sizeof(data_size)) {
    return error(env, "invalid serialized tensor");
  }
  cursor += sizeof(iree_hal_element_type_t);
  std::memcpy(&data_size, cursor, sizeof(data_size));
  cursor += sizeof(data_size);
  remaining -= sizeof(iree_hal_element_type_t) + sizeof(data_size);

  // Compare via subtraction so a huge data_size cannot overflow the check.
  if (remaining < data_size || remaining - data_size < sizeof(num_dims)) {
    return error(env, "invalid serialized tensor");
  }
  cursor += data_size;
  std::memcpy(&num_dims, cursor, sizeof(num_dims));
  remaining -= data_size + sizeof(num_dims);

  // Division instead of multiplication avoids overflow on a huge num_dims.
  if (remaining / sizeof(iree_hal_dim_t) < num_dims) {
    return error(env, "invalid serialized tensor");
  }

  auto tensor = new iree::runtime::IREETensor(reinterpret_cast<char*>(input.data));

  return ok(env, make<iree::runtime::IREETensor*>(env, tensor));
}

DECLARE_NIF(call_nif) {
iree_vm_instance_t** instance;
iree_hal_device_t** device;
Expand Down Expand Up @@ -463,6 +503,8 @@ static ErlNifFunc funcs[] = {
{"list_devices", 2, list_devices},
{"list_drivers", 1, list_drivers},
{"allocate_buffer", 4, allocate_buffer},
{"serialize_tensor", 1, serialize_tensor},
{"deserialize_tensor", 1, deserialize_tensor},
{"read_buffer", 3, read_buffer_nif},
{"call_io", 5, call_nif, ERL_NIF_DIRTY_JOB_IO_BOUND},
{"call_cpu", 5, call_nif, ERL_NIF_DIRTY_JOB_CPU_BOUND}};
Expand Down
50 changes: 50 additions & 0 deletions cmake/src/runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ iree::runtime::Device::~Device() {

// Wraps an existing HAL buffer view; the byte size is queried from the view.
iree::runtime::IREETensor::IREETensor(iree_hal_buffer_view_t *buffer_view, iree_hal_element_type_t type)
    : buffer_view(buffer_view), type(type) {
  // TODO: fill in dim metadata
  this->size = iree_hal_buffer_view_byte_length(this->buffer_view);
}

iree::runtime::IREETensor::IREETensor(void *data, size_t size, std::vector<int64_t> in_dims, iree_hal_element_type_t type) : size(size), type(type) {
Expand All @@ -43,6 +44,55 @@ iree::runtime::IREETensor::IREETensor(void *data, size_t size, std::vector<int64
this->buffer_view = nullptr;
}

// Rebuilds a tensor from the byte layout written by serialize():
//   type | size | data[size] | num_dims | dims[num_dims]
// The buffer is trusted: no bounds checks are performed here.
iree::runtime::IREETensor::IREETensor(char *buffer) {
  const char *cursor = buffer;

  // Copies `count` bytes from the current read position and advances it.
  auto take = [&cursor](void *destination, size_t count) {
    std::memcpy(destination, cursor, count);
    cursor += count;
  };

  take(&type, sizeof(type));
  take(&size, sizeof(size));

  // The tensor holds its own copy of the payload in raw memory.
  data = operator new(size);
  take(data, size);

  size_t rank;
  take(&rank, sizeof(rank));
  dims.resize(rank);
  take(dims.data(), rank * sizeof(iree_hal_dim_t));

  // A deserialized tensor starts with no buffer view attached.
  buffer_view = nullptr;
}

// Flattens the tensor into a newly allocated byte vector, which is returned
// as a raw pointer. Field order: type, size, data, num_dims, dims.
std::vector<char> *iree::runtime::IREETensor::serialize() {
  auto out = new std::vector<char>();

  // Appends `count` raw bytes starting at `source` to the output.
  auto append = [out](const void *source, size_t count) {
    const char *first = reinterpret_cast<const char *>(source);
    out->insert(out->end(), first, first + count);
  };

  append(&type, sizeof(type));
  append(&size, sizeof(size));
  append(data, size);

  // The dims are prefixed by their count so the reader knows how many follow.
  const size_t rank = dims.size();
  append(&rank, sizeof(rank));
  append(dims.data(), sizeof(iree_hal_dim_t) * rank);

  return out;
}

iree_vm_instance_t *create_instance() {
iree_vm_instance_t *instance = nullptr;
iree_status_t status = iree_vm_instance_create(IREE_VM_TYPE_CAPACITY_DEFAULT, iree_allocator_system(), &instance);
Expand Down
5 changes: 5 additions & 0 deletions cmake/src/runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class IREETensor {
iree_hal_element_type_t type;
iree_hal_buffer_view_t* buffer_view;

IREETensor(char* serialized_data);
IREETensor(iree_hal_buffer_view_t* buffer_view, iree_hal_element_type_t type);
IREETensor(void* data, size_t size, std::vector<int64_t> in_dims, iree_hal_element_type_t type);

Expand All @@ -62,6 +63,10 @@ class IREETensor {
// Exposes the tensor's data pointer and byte size as a read-only byte span.
iree_const_byte_span_t data_byte_span() const {
  uint8_t* bytes = static_cast<uint8_t*>(data);
  return iree_make_const_byte_span(bytes, size);
}

// Serializes the tensor to a buffer that can be transmitted over the wire.
// Fields in order: type, size (byte count of data), data, number of dims, dims.
// The returned vector is allocated with `new`; the caller is responsible for
// deleting it.
std::vector<char>* serialize();
};

} // namespace runtime
Expand Down
3 changes: 3 additions & 0 deletions lib/nx_iree/native.ex
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,7 @@ defmodule NxIREE.Native do

# NIF stub: this body only runs when the native implementation is not loaded.
def call_cpu(_instance_ref, _device_ref, _driver_name, _bytecode, _inputs) do
  :erlang.nif_error(:undef)
end

# NIF stub: this body only runs when the native implementation is not loaded.
def serialize_tensor(_reference) do
  :erlang.nif_error(:undef)
end
# NIF stub: this body only runs when the native implementation is not loaded.
def deserialize_tensor(_binary) do
  :erlang.nif_error(:undef)
end
end
2 changes: 1 addition & 1 deletion lib/nx_iree/tensor.ex
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ defmodule NxIREE.Tensor do

@impl true
def from_binary(out, binary, opts) do
device_uri = opts[:device]
device_uri = opts[:device] || "local-sync://default"

device_ref =
case NxIREE.Device.get(device_uri) do
Expand Down
5 changes: 1 addition & 4 deletions lib/nx_iree/vm.ex
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,7 @@ defmodule NxIREE.VM do
# Allocates a device buffer holding `binary`, interpreted with the given
# `shape` (a tuple) and Nx `type`.
#
# Returns the result of `NxIREE.Native.allocate_buffer/4` unchanged, so the
# caller decides how to handle a failed allocation. The NIF is invoked exactly
# once; a second call would allocate a duplicate buffer that is discarded.
def allocate_buffer(binary, device_ref, shape, type) when is_binary(binary) do
  element_type = to_iree_type(type)

  NxIREE.Native.allocate_buffer(binary, device_ref, Tuple.to_list(shape), element_type)
end

def read_buffer(%NxIREE.Tensor{} = t) do
Expand Down
37 changes: 37 additions & 0 deletions test/nx_iree/native_test.exs
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
defmodule NxIREE.NativeTest do
  use ExUnit.Case, async: true

  test "serializes and deserializes a tensor" do
    source = Nx.tensor([[[1, 2], [3, 4], [5, 6]]], type: :s32, backend: NxIREE.Tensor)

    {:ok, payload} = NxIREE.Native.serialize_tensor(source.data.ref)

    # Expected layout, every field in native endianness:
    # element type (32 bits), byte count (64 bits), raw data,
    # number of dims (64 bits), then the dims themselves.
    assert <<
             element_type::unsigned-integer-native-size(32),
             byte_count::unsigned-integer-native-size(64),
             raw_data::binary-size(byte_count),
             rank::unsigned-integer-native-size(64),
             packed_dims::bitstring
           >> = payload

    shape = for <<dim::signed-integer-native-size(64) <- packed_dims>>, do: dim

    # The element type encoding is internal to IREE; it is checked here only
    # as a sanity check and may be dropped if it becomes a maintenance burden.
    assert Bitwise.band(element_type, 0xFF) == 32
    assert element_type |> Bitwise.bsr(24) |> Bitwise.band(0xFF) == 0x10

    assert byte_count == Nx.byte_size(source)
    assert raw_data == Nx.to_binary(source)
    assert rank == 3
    assert shape == [1, 3, 2]

    {:ok, restored_ref} = NxIREE.Native.deserialize_tensor(payload)
    restored = put_in(source.data.ref, restored_ref)

    assert Nx.to_binary(source) == Nx.to_binary(restored)
  end
end

0 comments on commit 0392d0e

Please sign in to comment.