Skip to content

Commit

Permalink
draft symbolic
Browse files Browse the repository at this point in the history
  • Loading branch information
radenmuaz committed Mar 10, 2024
1 parent fdb803e commit 8ac7961
Show file tree
Hide file tree
Showing 7 changed files with 564 additions and 15 deletions.
11 changes: 6 additions & 5 deletions examples/simple/export.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
import slope


x = slope.tensor([[1, 2], [3, 4]], dtype=slope.float32)
c = x
x = slope.tensor([[1, 2], [3, 4], [5, 6]], dtype=slope.float32)
# c = x.ones_like()


@slope.jit
def f(x):
y = (x + c).sum()
y = x.sum(1)
return y


# print(f(x,))
f_jitobj = f.lower(x)
f_jitobj.export("/tmp/f", x)
# f_jitobj = f.lower(x)
# f.export("/tmp/f", (x,), input_names=['x'], output_names=['y'], dynamic_axes=dict(x=[0], y=[0]))
f.export("/tmp/f", (x,), input_names=['x'], output_names=['y'], dynamic_axes=dict(x={0:'batch'}), y={0:'batch'})
2 changes: 2 additions & 0 deletions src/slope/shapes.py → experimental/onnx_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1521,6 +1521,7 @@ def _infer_ReduceSum(self, node): # noqa: N802
output_shape,
)
)
breakpoint()

def _infer_ReduceProd(self, node): # noqa: N802
axes = get_attribute(node, "axes")
Expand Down Expand Up @@ -2595,6 +2596,7 @@ def get_prereq(node):
break

if self.verbose_ > 2:
breakpoint()
logger.debug(node.op_type + ": " + node.name)
for i, name in enumerate(node.input):
logger.debug(" Input {}: {} {}".format(i, name, "initializer" if name in self.initializers_ else ""))
Expand Down
2 changes: 2 additions & 0 deletions src/slope/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
import os
from slope import core
import numpy as np

np.set_printoptions(precision=5, threshold=1000, edgeitems=5, linewidth=120)
SLOPE_BACKEND = os.environ.get("SLOPE_BACKEND", "onnxruntime")
core.set_backend(SLOPE_BACKEND)


def __getattr__(attr):
if attr in (globals_dict := globals()):
core.dblog(
Expand Down
44 changes: 36 additions & 8 deletions src/slope/backends/onnxruntime.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,16 +260,44 @@ def fn(*args):

return fn, code

def export(self, jit_output: slope.core.JitOutput, output_path, export_params, input_names, output_names, **kwargs):
def export(
self, jit_output: slope.core.JitOutput, output_path, export_params, input_names=None, output_names=None, dynamic_axes=None, **kwargs
):
code = jit_output.code
# code = code.replace("float[3","float[batch")
# print(code)
# breakpoint()
# if dynamic_axes:
# if dims := dynamic_axes.get(input_name, None):
# if isinstance(dims, list):
# assert all(isinstance(d, int) for d in dims)
# dims = {d: f"{input_name}_dim{d}" for d in dims}
# else:
# assert isinstance(dims, dict)
# input_shape = list(input_shape)
# for dim, dim_name in dims.items():
# input_shape[dim] = dim_name
# input_shape = tuple(input_shape)

model = onnx.parser.parse_model(code)
os.makedirs(output_path, exist_ok=True)
in_binders = jit_output.codegen_output["in_binders"]
outs = jit_output.codegen_output["outs"]
in_binders = jit_output.codegen_output.in_binders
outs = jit_output.codegen_output.outs
num_consts = jit_output.program.num_consts

if input_names is None:
input_names = [inb.name for inb in in_binders]
else:
input_names = [inb.name for inb in in_binders][:num_consts] + list(input_names)
assert len(input_names) == len(in_binders)
if output_names is None:
output_names = [out.name for out in outs]
else:
assert len(output_names) == len(outs)

for i in range(num_consts):
const_array = in_binders[i]["type"].numpy()
const_name = in_binders[i].name
const_array = in_binders[i].symval.numpy()
const_name = input_names[i]
const = onnx.numpy_helper.from_array(const_array, name=const_name)
model.graph.initializer.append(const)

Expand All @@ -280,9 +308,9 @@ def export(self, jit_output: slope.core.JitOutput, output_path, export_params, i

test_input_code = ""
for i in range(num_consts, len(in_binders)):
input_name = in_binders[i].name
input_shape = in_binders[i]["type"].shape
dtype = in_binders[i]["type"].dtype
input_name = input_names[i]
input_shape = in_binders[i].symval.shape
dtype = in_binders[i].symval.dtype
input_dtype = ("np." + dtype.numpy.__name__) if dtype is not dtypes.bool else "bool"
test_input_code += f""" {input_name} = np.ones({input_shape}, dtype={input_dtype})\n"""

Expand Down
3 changes: 2 additions & 1 deletion src/slope/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1859,6 +1859,7 @@ def set_backend(name, where="slope.backends"):
global backend
backend = importlib.import_module(f"{where}.{name}").backend
import slope.nn as nn

# backend.register_node(nn.Module, nn.Module.flatten, nn.Module.unflatten, "Module")

dblog(f"slope backend is {backend}", enable=backend.LOG_INIT)
Expand Down Expand Up @@ -2770,7 +2771,7 @@ def lower(self, *args, **static_args):
jit_output = backend.jit_program(hashed_program, hashed_consts)
return jit_output

def export(self, args, output_path, export_params=True, input_names=None, output_names=None, **kwargs):
def export(self, output_path, args, export_params=True, input_names=None, output_names=None, **kwargs):
if isinstance(args, Tensor):
args, static_args = (args,), dict()
elif not isinstance(args[-1], dict):
Expand Down
1 change: 0 additions & 1 deletion src/slope/nn.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

import slope
from slope.core import Tensor, SymbolicTensor, TreeDef
from typing import Tuple, List, Optional
Expand Down
Loading

0 comments on commit 8ac7961

Please sign in to comment.