Skip to content

Commit

Permalink
feat: connect lvn with iree
Browse files Browse the repository at this point in the history
  • Loading branch information
polvalente committed Aug 23, 2024
1 parent c6bca28 commit 8d2846e
Show file tree
Hide file tree
Showing 8 changed files with 152 additions and 18 deletions.
2 changes: 2 additions & 0 deletions liveview_native/live_nx_iree/config/runtime.exs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ if System.get_env("PHX_SERVER") do
config :live_nx_iree, LiveNxIREEWeb.Endpoint, server: true
end

config :nx, :default_backend, NxIREE.Tensor

if config_env() == :prod do
database_url =
System.get_env("DATABASE_URL") ||
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,14 @@ defmodule LiveNxIREEWeb.HomeLive do

dbg(self())

socket = assign(socket, bytecode: nil, function_signature: nil, device: nil)
socket =
assign(socket,
bytecode: nil,
function_signature: nil,
device: nil,
inputs: nil,
num_outputs: nil
)

{:ok, socket}
end
Expand Down Expand Up @@ -43,7 +50,7 @@ defmodule LiveNxIREEWeb.HomeLive do
# end

@impl true
def handle_info({:nx, :execute, function, input_templates, target_device, reply_to_pid}, socket) do
def handle_info({:nx, :execute, function, inputs, target_device, reply_to_pid}, socket) do
fun =
case function do
{m, f, a} ->
Expand Down Expand Up @@ -71,19 +78,36 @@ defmodule LiveNxIREEWeb.HomeLive do
]

{:ok, %{bytecode: %NxIREE.Module{bytecode: bytecode}, output_container: output_container}} =
NxIREE.Compiler.to_bytecode(fun, input_templates, iree_compiler_flags: compiler_flags)
NxIREE.Compiler.to_bytecode(fun, inputs, iree_compiler_flags: compiler_flags)

{_, num_outputs} =
Nx.Defn.Composite.traverse(output_container, 0, fn node, acc -> {node, acc + 1} end)

socket =
socket
|> assign(:bytecode, Base.encode64(bytecode))
|> assign(:output_container, output_container)
|> assign(:function_signature, get_signature(function, input_templates, output_container))
|> assign(:function_signature, get_signature(function, inputs, output_container))
|> assign(:device, runtime_device)
|> assign(:reply_to_pid, reply_to_pid)
|> assign(:inputs, serialize_inputs(inputs))
|> assign(:num_outputs, num_outputs)

{:noreply, socket}
end

defp serialize_inputs(inputs) do
List.wrap(inputs)
|> Nx.Defn.Composite.flatten_list()
|> Enum.map(fn tensor ->
tensor = Nx.to_tensor(tensor)

{:ok, serialized} = NxIREE.Native.serialize_tensor(tensor.data.ref)

serialized
end)
end

@impl true
def handle_event("nx-executed", params, socket) do
send(socket.assigns.reply_to_pid, {:nx, :executed, params})
Expand Down
Original file line number Diff line number Diff line change
@@ -1 +1 @@
<NxFunction on-execution="nx-executed" signature={@function_signature} bytecode={@bytecode} device={@device} />
<NxFunction on-execution="nx-executed" signature={@function_signature} bytecode={@bytecode} device={@device} inputs={@inputs} num-outputs={@num_outputs} />
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func nx_iree_initialize(
@_silgen_name("nx_iree_create_device")
func nx_iree_create_device(
_ driver_registry: UnsafeMutablePointer<iree_hal_driver_registry_t>,
_ name: UnsafePointer<Int8>) -> UnsafeMutablePointer<iree_hal_device_t>
_ name: UnsafePointer<CChar>) -> UnsafeMutablePointer<iree_hal_device_t>

@_silgen_name("nx_iree_call")
func nx_iree_call(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ struct NxFunctionView<Root: RootRegistry>: View {
@LiveAttribute("bytecode") private var bytecode: String? = nil
@LiveAttribute("signature") private var signature: String? = nil
@LiveAttribute("device") private var deviceURI: String? = nil
@LiveAttribute("trigger") private var trigger: Bool = false
@LiveAttribute("inputs") private var serializedInputs: [String]? = nil
@LiveAttribute("num-outputs") private var numOutputs: Int? = nil
@Event("on-execution", type: "change") private var change

var body: some View {
Expand All @@ -32,18 +33,125 @@ struct NxFunctionView<Root: RootRegistry>: View {
}
}

private func run() {
if bytecode == nil {
return
private func convertBase64StringToBytecode(_ base64String: String) -> (bytecodeSize: UInt64, bytecodePointer: UnsafePointer<CUnsignedChar>?)? {
// Step 1: Decode the Base64 string into Data
guard let decodedData = Data(base64Encoded: base64String) else {
print("Failed to decode base64 string.")
return nil
}

// Step 2: Get the size of the data
let bytecodeSize = UInt64(decodedData.count)

// Step 3: Convert Data to UnsafePointer<CUnsignedChar>
// We use `withUnsafeBytes` to get a pointer to the data
let bytecodePointer = decodedData.withUnsafeBytes { (pointer: UnsafeRawBufferPointer) -> UnsafePointer<CUnsignedChar>? in
return pointer.bindMemory(to: CUnsignedChar.self).baseAddress
}

return (bytecodeSize, bytecodePointer)
}

private func convertToCStringArray(from strings: [String]) -> UnsafePointer<UnsafePointer<CChar>>? {
// Array to hold the C strings (UnsafePointer<CChar>)
var cStrings: [UnsafePointer<CChar>] = []

for string in strings {
// Decode the base64 string to Data
guard let decodedData = Data(base64Encoded: string) else {
print("Failed to decode base64 string: \(string)")
return nil
}

// Convert Data to a C string (null-terminated UTF-8)
let cString = decodedData.withUnsafeBytes { (pointer: UnsafeRawBufferPointer) -> UnsafePointer<CChar>? in
guard let baseAddress = pointer.baseAddress else { return nil }
// Allocate memory for the C string and copy the data
let cStringPointer = UnsafeMutablePointer<CChar>.allocate(capacity: decodedData.count + 1)
cStringPointer.initialize(from: baseAddress.assumingMemoryBound(to: CChar.self), count: decodedData.count)
cStringPointer[decodedData.count] = 0 // Null-terminate the string
return UnsafePointer(cStringPointer)
}

guard let cStr = cString else {
print("Failed to convert Data to C string.")
return nil
}

cStrings.append(cStr)
}

if let vmInstance = globalVmInstance,
let driverRegistry = globalDriverRegistry {
print("Executing function \(signature ?? "None") on device: \(deviceURI ?? "None")")
change(value: "Sending something back")
// Allocate memory for the array of C strings
let cStringsPointer = UnsafeMutablePointer<UnsafePointer<CChar>>.allocate(capacity: cStrings.count)

// Copy the C strings to the allocated array
cStringsPointer.initialize(from: &cStrings, count: cStrings.count)

// Return the pointer to the array
return UnsafePointer(cStringsPointer)
}

private func base64EncodedStrings(from serializedOutputs: UnsafePointer<UnsafePointer<CChar>>, count: Int) -> [String] {
// Convert UnsafePointer to a Swift array of UnsafePointer<CChar>
let cStringPointers = Array(UnsafeBufferPointer(start: serializedOutputs, count: count))

var base64Strings: [String] = []

for cStringPointer in cStringPointers {
// Convert each C string to a Swift String
let string = String(cString: cStringPointer)

// Encode the string to Base64
if let data = string.data(using: .utf8) {
let base64String = data.base64EncodedString()
base64Strings.append(base64String)
}
}

return base64Strings
}


private func run() {
if bytecode != nil,
deviceURI != nil,
globalVmInstance != nil,
globalDriverRegistry != nil,
serializedInputs != nil,
let (bytecodeSize, bytecodePointer) = convertBase64StringToBytecode(bytecode!),
let inputs = convertToCStringArray(from: serializedInputs!) {
let deviceURIcstr = strdup(deviceURI!)
let device = nx_iree_create_device(globalDriverRegistry!, UnsafePointer(deviceURIcstr)!)
deviceURIcstr?.deallocate()

let serializedOutputs: UnsafePointer<UnsafePointer<CChar>>? = nil
let errorMessage = UnsafeMutablePointer<CChar>.allocate(capacity: 256)

print("Executing function \(signature ?? "None") on device: \(deviceURI ?? "None")")

let result = nx_iree_call(
globalVmInstance!,
device,
bytecodeSize,
bytecodePointer!,
UInt64(serializedInputs!.count),
inputs,
UInt64(numOutputs!),
serializedOutputs!,
errorMessage)

if result != 0 {
print(errorMessage)
return
}

for i in 0..<serializedInputs!.count {
inputs[i].deallocate()
}
inputs.deallocate()

change(value: base64EncodedStrings(from: serializedOutputs!, count: numOutputs!))
} else {
print("vm instance: \(globalVmInstance)")
print("driver registry: \(globalDriverRegistry)")
print("IREE components are not initialized.")
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ extern "C" {
#endif

int nx_iree_initialize(iree_vm_instance_t* vm_instance, iree_hal_driver_registry_t* driver_registry, char* error_message);
iree_hal_device_t* nx_iree_create_device(char* device_uri);
iree_hal_device_t* nx_iree_create_device(iree_hal_driver_registry_t* registry, char* device_uri);
int nx_iree_call(iree_vm_instance_t* vm_instance, iree_hal_device_t* device, uint64_t bytecode_size, unsigned char* bytecode, uint64_t num_inputs, char** serialized_inputs, uint64_t num_outputs, char** serialized_outputs, char* error_message);


Expand Down
2 changes: 1 addition & 1 deletion priv/VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.0.1-pre.5
0.0.1-pre.6

0 comments on commit 8d2846e

Please sign in to comment.