Skip to content

Commit 8d2846e

Browse files
committed
feat: connect lvn with iree
1 parent c6bca28 commit 8d2846e

File tree

8 files changed

+152
-18
lines changed

8 files changed

+152
-18
lines changed

liveview_native/live_nx_iree/config/runtime.exs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ if System.get_env("PHX_SERVER") do
2020
config :live_nx_iree, LiveNxIREEWeb.Endpoint, server: true
2121
end
2222

23+
config :nx, :default_backend, NxIREE.Tensor
24+
2325
if config_env() == :prod do
2426
database_url =
2527
System.get_env("DATABASE_URL") ||

liveview_native/live_nx_iree/lib/live_nx_iree_web/live/home_live/home_live.ex

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,14 @@ defmodule LiveNxIREEWeb.HomeLive do
1515

1616
dbg(self())
1717

18-
socket = assign(socket, bytecode: nil, function_signature: nil, device: nil)
18+
socket =
19+
assign(socket,
20+
bytecode: nil,
21+
function_signature: nil,
22+
device: nil,
23+
inputs: nil,
24+
num_outputs: nil
25+
)
1926

2027
{:ok, socket}
2128
end
@@ -43,7 +50,7 @@ defmodule LiveNxIREEWeb.HomeLive do
4350
# end
4451

4552
@impl true
46-
def handle_info({:nx, :execute, function, input_templates, target_device, reply_to_pid}, socket) do
53+
def handle_info({:nx, :execute, function, inputs, target_device, reply_to_pid}, socket) do
4754
fun =
4855
case function do
4956
{m, f, a} ->
@@ -71,19 +78,36 @@ defmodule LiveNxIREEWeb.HomeLive do
7178
]
7279

7380
{:ok, %{bytecode: %NxIREE.Module{bytecode: bytecode}, output_container: output_container}} =
74-
NxIREE.Compiler.to_bytecode(fun, input_templates, iree_compiler_flags: compiler_flags)
81+
NxIREE.Compiler.to_bytecode(fun, inputs, iree_compiler_flags: compiler_flags)
82+
83+
{_, num_outputs} =
84+
Nx.Defn.Composite.traverse(output_container, 0, fn node, acc -> {node, acc + 1} end)
7585

7686
socket =
7787
socket
7888
|> assign(:bytecode, Base.encode64(bytecode))
7989
|> assign(:output_container, output_container)
80-
|> assign(:function_signature, get_signature(function, input_templates, output_container))
90+
|> assign(:function_signature, get_signature(function, inputs, output_container))
8191
|> assign(:device, runtime_device)
8292
|> assign(:reply_to_pid, reply_to_pid)
93+
|> assign(:inputs, serialize_inputs(inputs))
94+
|> assign(:num_outputs, num_outputs)
8395

8496
{:noreply, socket}
8597
end
8698

99+
defp serialize_inputs(inputs) do
100+
List.wrap(inputs)
101+
|> Nx.Defn.Composite.flatten_list()
102+
|> Enum.map(fn tensor ->
103+
tensor = Nx.to_tensor(tensor)
104+
105+
{:ok, serialized} = NxIREE.Native.serialize_tensor(tensor.data.ref)
106+
107+
serialized
108+
end)
109+
end
110+
87111
@impl true
88112
def handle_event("nx-executed", params, socket) do
89113
send(socket.assigns.reply_to_pid, {:nx, :executed, params})
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
<NxFunction on-execution="nx-executed" signature={@function_signature} bytecode={@bytecode} device={@device} />
1+
<NxFunction on-execution="nx-executed" signature={@function_signature} bytecode={@bytecode} device={@device} inputs={@inputs} num-outputs={@num_outputs} />

liveview_native/live_nx_iree/native/swiftui/LiveNxIREE/NxAddon.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ func nx_iree_initialize(
2424
@_silgen_name("nx_iree_create_device")
2525
func nx_iree_create_device(
2626
_ driver_registry: UnsafeMutablePointer<iree_hal_driver_registry_t>,
27-
_ name: UnsafePointer<Int8>) -> UnsafeMutablePointer<iree_hal_device_t>
27+
_ name: UnsafePointer<CChar>) -> UnsafeMutablePointer<iree_hal_device_t>
2828

2929
@_silgen_name("nx_iree_call")
3030
func nx_iree_call(

liveview_native/live_nx_iree/native/swiftui/LiveNxIREE/NxAddon/NxFunctionView.swift

Lines changed: 118 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@ struct NxFunctionView<Root: RootRegistry>: View {
1414
@LiveAttribute("bytecode") private var bytecode: String? = nil
1515
@LiveAttribute("signature") private var signature: String? = nil
1616
@LiveAttribute("device") private var deviceURI: String? = nil
17-
@LiveAttribute("trigger") private var trigger: Bool = false
17+
@LiveAttribute("inputs") private var serializedInputs: [String]? = nil
18+
@LiveAttribute("num-outputs") private var numOutputs: Int? = nil
1819
@Event("on-execution", type: "change") private var change
1920

2021
var body: some View {
@@ -32,18 +33,125 @@ struct NxFunctionView<Root: RootRegistry>: View {
3233
}
3334
}
3435

35-
private func run() {
36-
if bytecode == nil {
37-
return
36+
private func convertBase64StringToBytecode(_ base64String: String) -> (bytecodeSize: UInt64, bytecodePointer: UnsafePointer<CUnsignedChar>?)? {
37+
// Step 1: Decode the Base64 string into Data
38+
guard let decodedData = Data(base64Encoded: base64String) else {
39+
print("Failed to decode base64 string.")
40+
return nil
41+
}
42+
43+
// Step 2: Get the size of the data
44+
let bytecodeSize = UInt64(decodedData.count)
45+
46+
// Step 3: Convert Data to UnsafePointer<CUnsignedChar>
47+
// We use `withUnsafeBytes` to get a pointer to the data
48+
let bytecodePointer = decodedData.withUnsafeBytes { (pointer: UnsafeRawBufferPointer) -> UnsafePointer<CUnsignedChar>? in
49+
return pointer.bindMemory(to: CUnsignedChar.self).baseAddress
50+
}
51+
52+
return (bytecodeSize, bytecodePointer)
53+
}
54+
55+
private func convertToCStringArray(from strings: [String]) -> UnsafePointer<UnsafePointer<CChar>>? {
56+
// Array to hold the C strings (UnsafePointer<CChar>)
57+
var cStrings: [UnsafePointer<CChar>] = []
58+
59+
for string in strings {
60+
// Decode the base64 string to Data
61+
guard let decodedData = Data(base64Encoded: string) else {
62+
print("Failed to decode base64 string: \(string)")
63+
return nil
64+
}
65+
66+
// Convert Data to a C string (null-terminated UTF-8)
67+
let cString = decodedData.withUnsafeBytes { (pointer: UnsafeRawBufferPointer) -> UnsafePointer<CChar>? in
68+
guard let baseAddress = pointer.baseAddress else { return nil }
69+
// Allocate memory for the C string and copy the data
70+
let cStringPointer = UnsafeMutablePointer<CChar>.allocate(capacity: decodedData.count + 1)
71+
cStringPointer.initialize(from: baseAddress.assumingMemoryBound(to: CChar.self), count: decodedData.count)
72+
cStringPointer[decodedData.count] = 0 // Null-terminate the string
73+
return UnsafePointer(cStringPointer)
74+
}
75+
76+
guard let cStr = cString else {
77+
print("Failed to convert Data to C string.")
78+
return nil
79+
}
80+
81+
cStrings.append(cStr)
3882
}
3983

40-
if let vmInstance = globalVmInstance,
41-
let driverRegistry = globalDriverRegistry {
42-
print("Executing function \(signature ?? "None") on device: \(deviceURI ?? "None")")
43-
change(value: "Sending something back")
84+
// Allocate memory for the array of C strings
85+
let cStringsPointer = UnsafeMutablePointer<UnsafePointer<CChar>>.allocate(capacity: cStrings.count)
86+
87+
// Copy the C strings to the allocated array
88+
cStringsPointer.initialize(from: &cStrings, count: cStrings.count)
89+
90+
// Return the pointer to the array
91+
return UnsafePointer(cStringsPointer)
92+
}
93+
94+
private func base64EncodedStrings(from serializedOutputs: UnsafePointer<UnsafePointer<CChar>>, count: Int) -> [String] {
95+
// Convert UnsafePointer to a Swift array of UnsafePointer<CChar>
96+
let cStringPointers = Array(UnsafeBufferPointer(start: serializedOutputs, count: count))
97+
98+
var base64Strings: [String] = []
99+
100+
for cStringPointer in cStringPointers {
101+
// Convert each C string to a Swift String
102+
let string = String(cString: cStringPointer)
103+
104+
// Encode the string to Base64
105+
if let data = string.data(using: .utf8) {
106+
let base64String = data.base64EncodedString()
107+
base64Strings.append(base64String)
108+
}
109+
}
110+
111+
return base64Strings
112+
}
113+
114+
115+
private func run() {
116+
if bytecode != nil,
117+
deviceURI != nil,
118+
globalVmInstance != nil,
119+
globalDriverRegistry != nil,
120+
serializedInputs != nil,
121+
let (bytecodeSize, bytecodePointer) = convertBase64StringToBytecode(bytecode!),
122+
let inputs = convertToCStringArray(from: serializedInputs!) {
123+
let deviceURIcstr = strdup(deviceURI!)
124+
let device = nx_iree_create_device(globalDriverRegistry!, UnsafePointer(deviceURIcstr)!)
125+
deviceURIcstr?.deallocate()
126+
127+
let serializedOutputs: UnsafePointer<UnsafePointer<CChar>>? = nil
128+
let errorMessage = UnsafeMutablePointer<CChar>.allocate(capacity: 256)
129+
130+
print("Executing function \(signature ?? "None") on device: \(deviceURI ?? "None")")
131+
132+
let result = nx_iree_call(
133+
globalVmInstance!,
134+
device,
135+
bytecodeSize,
136+
bytecodePointer!,
137+
UInt64(serializedInputs!.count),
138+
inputs,
139+
UInt64(numOutputs!),
140+
serializedOutputs!,
141+
errorMessage)
142+
143+
if result != 0 {
144+
print(errorMessage)
145+
return
146+
}
147+
148+
for i in 0..<serializedInputs!.count {
149+
inputs[i].deallocate()
150+
}
151+
inputs.deallocate()
152+
153+
change(value: base64EncodedStrings(from: serializedOutputs!, count: numOutputs!))
44154
} else {
45-
print("vm instance: \(globalVmInstance)")
46-
print("driver registry: \(globalDriverRegistry)")
47155
print("IREE components are not initialized.")
48156
}
49157
}

liveview_native/live_nx_iree/native/swiftui/LiveNxIREE/c_src/nx_iree.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ extern "C" {
1616
#endif
1717

1818
int nx_iree_initialize(iree_vm_instance_t* vm_instance, iree_hal_driver_registry_t* driver_registry, char* error_message);
19-
iree_hal_device_t* nx_iree_create_device(char* device_uri);
19+
iree_hal_device_t* nx_iree_create_device(iree_hal_driver_registry_t* registry, char* device_uri);
2020
int nx_iree_call(iree_vm_instance_t* vm_instance, iree_hal_device_t* device, uint64_t bytecode_size, unsigned char* bytecode, uint64_t num_inputs, char** serialized_inputs, uint64_t num_outputs, char** serialized_outputs, char* error_message);
2121

2222

priv/VERSION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
0.0.1-pre.5
1+
0.0.1-pre.6

0 commit comments

Comments
 (0)