Update the State class (#162)
jeremyfowers authored Jun 19, 2024
1 parent 7b78b55 commit 2943010
Showing 21 changed files with 532 additions and 621 deletions.
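
The hunks below all serve one refactor: the State class and its helpers move from turnkeyml.common.build to turnkeyml.common.filesystem, build settings are flattened off the nested config object, and results/intermediate_results hold a single artifact path instead of a one-element list. A minimal before/after sketch of the user-facing API, inferred from the hunks in this commit (the cache_dir and build_name values are illustrative placeholders, not from the commit):

    import turnkeyml.common.filesystem as fs  # replaces turnkeyml.common.build for State

    cache_dir = "~/.cache/turnkey"  # placeholder cache location
    build_name = "my_model"         # placeholder build name

    # Loading a saved build: build.load_state(...) becomes fs.load_state(...)
    state = fs.load_state(cache_dir, build_name)

    # results is now a single artifact path rather than a one-element list
    onnx_path = state.results       # previously: state.results[0]

    # Build settings move off the nested config object onto State itself
    print(state.build_name)         # previously: state.config.build_name
    print(state.onnx_opset)         # previously: state.config.onnx_opset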
@@ -5,10 +5,9 @@
 import numpy as np
 from turnkeyml.run.basert import BaseRT
 import turnkeyml.common.exceptions as exp
-import turnkeyml.common.build as build
+import turnkeyml.common.filesystem as fs
 from turnkeyml.run.onnxrt.within_conda import dummy_inputs
 from turnkeyml.common.performance import MeasuredPerformance
-from turnkeyml.common.filesystem import Stats
 
 
 combined_rt_name = "example-combined-rt"
@@ -19,7 +18,7 @@ def __init__(
         self,
         cache_dir: str,
         build_name: str,
-        stats: Stats,
+        stats: fs.Stats,
         iterations: int,
         device_type: str,
         runtime: str = combined_rt_name,
@@ -57,13 +56,13 @@ def _setup(self):
         pass
 
     def benchmark(self):
-        state = build.load_state(self.cache_dir, self.build_name)
+        state = fs.load_state(self.cache_dir, self.build_name)
         per_iteration_latency = []
         sess_options = ort.SessionOptions()
         sess_options.graph_optimization_level = (
             ort.GraphOptimizationLevel.ORT_ENABLE_ALL
         )
-        onnx_session = ort.InferenceSession(state.results[0], sess_options)
+        onnx_session = ort.InferenceSession(state.results, sess_options)
         sess_input = onnx_session.get_inputs()
         input_feed = dummy_inputs(sess_input)
         output_name = onnx_session.get_outputs()[0].name
@@ -1,5 +1,5 @@
 from turnkeyml.build.stage import Sequence, Stage
-import turnkeyml.common.build as build
+import turnkeyml.common.filesystem as fs
 import turnkeyml.build.export as export
 
 combined_seq_name = "example-combined-seq"
@@ -17,7 +17,7 @@ def __init__(self):
             monitor_message="Special step expected by CombinedExampleRT",
         )
 
-    def fire(self, state: build.State):
+    def fire(self, state: fs.State):
         return state
 
 
@@ -13,7 +13,7 @@
 """
 
 from turnkeyml.build.stage import Sequence, Stage
-import turnkeyml.common.build as build
+import turnkeyml.common.filesystem as fs
 import turnkeyml.build.export as export
 
 
@@ -33,7 +33,7 @@ def __init__(self):
             monitor_message="Teaching by example",
         )
 
-    def fire(self, state: build.State):
+    def fire(self, state: fs.State):
         return state
 
 
3 changes: 3 additions & 0 deletions setup.py
@@ -41,6 +41,9 @@
         "pandas>=1.5.3",
         "fasteners",
         "GitPython>=3.1.40",
+        # Necessary until upstream packages account for the breaking
+        # change to numpy
+        "numpy<2.0.0",
         "psutil",
     ],
     classifiers=[],
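
The new pin above holds numpy on the 1.x line until upstream packages absorb the numpy 2.0 breaking changes; a fresh install of the package would now resolve accordingly (the resolved version shown is illustrative):

    pip install turnkeyml  # resolver now selects numpy<2.0.0, e.g. a 1.26.x release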
2 changes: 1 addition & 1 deletion src/turnkeyml/__init__.py
@@ -3,4 +3,4 @@
 from .files_api import benchmark_files
 from .cli.cli import main as turnkeycli
 from .build_api import build_model
-from .common.build import load_state
+from .common.filesystem import load_state
24 changes: 10 additions & 14 deletions src/turnkeyml/analyze/script.py
@@ -169,7 +169,7 @@ def _store_traceback(invocation_info: status.UniqueInvocationInfo):
 
 def set_status_on_exception(
     build_required: bool,
-    build_state: build.State,
+    build_state: fs.State,
     stats: fs.Stats,
     benchmark_logfile_path: str,
 ):
@@ -181,16 +181,12 @@
     # whether the exception was thrown during build or benchmark
     # We also take into account whether a build was requested
     if build_required and not build_state:
-        stats.save_model_eval_stat(
-            fs.Keys.BUILD_STATUS, build.FunctionStatus.ERROR.value
-        )
+        stats.save_model_eval_stat(fs.Keys.BUILD_STATUS, build.FunctionStatus.ERROR)
 
         # NOTE: The log file for the failed build stage should have
         # already been saved to stats
     else:
-        stats.save_model_eval_stat(
-            fs.Keys.BENCHMARK_STATUS, build.FunctionStatus.ERROR.value
-        )
+        stats.save_model_eval_stat(fs.Keys.BENCHMARK_STATUS, build.FunctionStatus.ERROR)
 
         # Also save the benchmark log file to the stats
         stats.save_eval_error_log(benchmark_logfile_path)
@@ -376,12 +372,12 @@ def explore_invocation(
     # that action is part of the evaluation
     if runtime_info["build_required"]:
         stats.save_model_eval_stat(
-            fs.Keys.BUILD_STATUS, build.FunctionStatus.NOT_STARTED.value
+            fs.Keys.BUILD_STATUS, build.FunctionStatus.NOT_STARTED
         )
 
     if Action.BENCHMARK in tracer_args.actions:
         stats.save_model_eval_stat(
-            fs.Keys.BENCHMARK_STATUS, build.FunctionStatus.NOT_STARTED.value
+            fs.Keys.BENCHMARK_STATUS, build.FunctionStatus.NOT_STARTED
         )
 
     # Save the device name that will be used for the benchmark
@@ -400,7 +396,7 @@
         # If a concluded build still has a status of "running", this means
         # there was an uncaught exception.
         stats.save_model_eval_stat(
-            fs.Keys.BUILD_STATUS, build.FunctionStatus.INCOMPLETE.value
+            fs.Keys.BUILD_STATUS, build.FunctionStatus.INCOMPLETE
         )
 
         build_state = build_model(
@@ -416,10 +412,10 @@
         )
 
         stats.save_model_eval_stat(
-            fs.Keys.BUILD_STATUS, build.FunctionStatus.SUCCESSFUL.value
+            fs.Keys.BUILD_STATUS, build.FunctionStatus.SUCCESSFUL
         )
 
-        model_to_benchmark = build_state.results[0]
+        model_to_benchmark = build_state.results
 
         # Analyze the onnx file (if any) and save statistics
         analyze_model.analyze_onnx(
rt_args_to_use = tracer_args.rt_args

stats.save_model_eval_stat(
fs.Keys.BENCHMARK_STATUS, build.FunctionStatus.INCOMPLETE.value
fs.Keys.BENCHMARK_STATUS, build.FunctionStatus.INCOMPLETE
)

runtime_handle = runtime_info["RuntimeClass"](
Expand All @@ -462,7 +458,7 @@ def explore_invocation(
)

stats.save_model_eval_stat(
fs.Keys.BENCHMARK_STATUS, build.FunctionStatus.SUCCESSFUL.value
fs.Keys.BENCHMARK_STATUS, build.FunctionStatus.SUCCESSFUL
)

invocation_info.status_message = "Model successfully benchmarked!"
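
Throughout analyze/script.py above, FunctionStatus members are now passed to save_model_eval_stat directly rather than as their .value strings, collapsing each call site to one line. The pattern, restated (this assumes Stats converts the enum to a serializable value internally, which this diff does not show):

    # before
    stats.save_model_eval_stat(
        fs.Keys.BUILD_STATUS, build.FunctionStatus.SUCCESSFUL.value
    )
    # after
    stats.save_model_eval_stat(fs.Keys.BUILD_STATUS, build.FunctionStatus.SUCCESSFUL)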
72 changes: 36 additions & 36 deletions src/turnkeyml/build/export.py
@@ -50,30 +50,32 @@ def get_output_names(
     return [node.name for node in onnx_model.graph.output]  # pylint: disable=no-member
 
 
-def onnx_dir(state: build.State):
-    return os.path.join(
-        build.output_dir(state.cache_dir, state.config.build_name), "onnx"
-    )
+def original_inputs_file(cache_dir: str, build_name: str):
+    return os.path.join(build.output_dir(cache_dir, build_name), "inputs.npy")
+
+
+def onnx_dir(state: fs.State):
+    return os.path.join(build.output_dir(state.cache_dir, state.build_name), "onnx")
 
 
-def base_onnx_file(state: build.State):
+def base_onnx_file(state: fs.State):
     return os.path.join(
         onnx_dir(state),
-        f"{state.config.build_name}-op{state.config.onnx_opset}-base.onnx",
+        f"{state.build_name}-op{state.onnx_opset}-base.onnx",
     )
 
 
-def opt_onnx_file(state: build.State):
+def opt_onnx_file(state: fs.State):
     return os.path.join(
         onnx_dir(state),
-        f"{state.config.build_name}-op{state.config.onnx_opset}-opt.onnx",
+        f"{state.build_name}-op{state.onnx_opset}-opt.onnx",
     )
 
 
-def converted_onnx_file(state: build.State):
+def converted_onnx_file(state: fs.State):
     return os.path.join(
         onnx_dir(state),
-        f"{state.config.build_name}-op{state.config.onnx_opset}-opt-f16.onnx",
+        f"{state.build_name}-op{state.onnx_opset}-opt-f16.onnx",
     )
 
 
@@ -89,7 +91,7 @@ def __init__(self):
             monitor_message="Placeholder for an Export Stage",
         )
 
-    def fire(self, _: build.State):
+    def fire(self, _: fs.State):
         raise exp.StageError(
             "This Sequence includes an ExportPlaceholder Stage that should have "
             "been replaced with an export Stage."
@@ -114,7 +116,7 @@ def __init__(self):
             monitor_message="Receiving ONNX Model",
         )
 
-    def fire(self, state: build.State):
+    def fire(self, state: fs.State):
         if not isinstance(state.model, str):
             msg = f"""
             The current stage (ReceiveOnnxModel) is only compatible with
@@ -172,19 +174,19 @@ def fire(self, state: build.State):
         shutil.copy(state.model, output_path)
 
         tensor_helpers.save_inputs(
-            [state.inputs], state.original_inputs_file, downcast=False
+            [state.inputs],
+            original_inputs_file(state.cache_dir, state.build_name),
+            downcast=False,
         )
 
         # Check the if the base mode has been exported successfully
         success_msg = "\tSuccess receiving ONNX Model"
         fail_msg = "\tFailed receiving ONNX Model"
 
         if check_model(output_path, success_msg, fail_msg):
-            state.intermediate_results = [output_path]
+            state.intermediate_results = output_path
 
-            stats = fs.Stats(
-                state.cache_dir, state.config.build_name, state.evaluation_id
-            )
+            stats = fs.Stats(state.cache_dir, state.build_name, state.evaluation_id)
             stats.save_model_eval_stat(
                 fs.Keys.ONNX_FILE,
                 output_path,
@@ -220,7 +222,7 @@ def __init__(self):
             monitor_message="Exporting PyTorch to ONNX",
         )
 
-    def fire(self, state: build.State):
+    def fire(self, state: fs.State):
         if not isinstance(state.model, (torch.nn.Module, torch.jit.ScriptModule)):
             msg = f"""
             The current stage (ExportPytorchModel) is only compatible with
@@ -280,7 +282,7 @@ def fire(self, state: build.State):
         default_warnings = warnings.showwarning
         warnings.showwarning = _warn_to_stdout
 
-        stats = fs.Stats(state.cache_dir, state.config.build_name, state.evaluation_id)
+        stats = fs.Stats(state.cache_dir, state.build_name, state.evaluation_id)
 
         # Verify if the exported model matches the input torch model
         try:
try:
Expand All @@ -295,7 +297,7 @@ def fire(self, state: build.State):
export_verification = torch.onnx.verification.find_mismatch(
state.model,
tuple(state.inputs.values()),
opset_version=state.config.onnx_opset,
opset_version=state.onnx_opset,
options=fp32_tolerance,
)

@@ -331,7 +333,7 @@ def fire(self, state: build.State):
             output_path,
             input_names=dummy_input_names,
             do_constant_folding=True,
-            opset_version=state.config.onnx_opset,
+            opset_version=state.onnx_opset,
             verbose=False,
         )
 
@@ -342,15 +344,17 @@
         warnings.showwarning = default_warnings
 
         tensor_helpers.save_inputs(
-            [state.inputs], state.original_inputs_file, downcast=False
+            [state.inputs],
+            original_inputs_file(state.cache_dir, state.build_name),
+            downcast=False,
         )
 
         # Check the if the base mode has been exported successfully
         success_msg = "\tSuccess exporting model to ONNX"
         fail_msg = "\tFailed exporting model to ONNX"
 
         if check_model(output_path, success_msg, fail_msg):
-            state.intermediate_results = [output_path]
+            state.intermediate_results = output_path
 
             stats.save_model_eval_stat(
                 fs.Keys.ONNX_FILE,
Expand Down Expand Up @@ -387,8 +391,8 @@ def __init__(self):
monitor_message="Optimizing ONNX file",
)

def fire(self, state: build.State):
input_onnx = state.intermediate_results[0]
def fire(self, state: fs.State):
input_onnx = state.intermediate_results
output_path = opt_onnx_file(state)

# Perform some basic optimizations on the model to remove shape related
@@ -413,11 +417,9 @@
         fail_msg = "\tFailed optimizing ONNX model"
 
         if check_model(output_path, success_msg, fail_msg):
-            state.intermediate_results = [output_path]
+            state.intermediate_results = output_path
 
-            stats = fs.Stats(
-                state.cache_dir, state.config.build_name, state.evaluation_id
-            )
+            stats = fs.Stats(state.cache_dir, state.build_name, state.evaluation_id)
             stats.save_model_eval_stat(
                 fs.Keys.ONNX_FILE,
                 output_path,
@@ -452,8 +454,8 @@ def __init__(self):
             monitor_message="Converting to FP16",
         )
 
-    def fire(self, state: build.State):
-        input_onnx = state.intermediate_results[0]
+    def fire(self, state: fs.State):
+        input_onnx = state.intermediate_results
 
         # Convert the model to FP16
         # Some ops will not be converted to fp16 because they are in a block list
@@ -485,7 +487,7 @@
         )
 
         # Load inputs and convert to fp16
-        inputs_file = state.original_inputs_file
+        inputs_file = original_inputs_file(state.cache_dir, state.build_name)
         if os.path.isfile(inputs_file):
             inputs = np.load(inputs_file, allow_pickle=True)
             inputs_converted = tensor_helpers.save_inputs(
@@ -519,11 +521,9 @@
         fail_msg = "\tFailed converting ONNX model to fp16"
 
         if check_model(output_path, success_msg, fail_msg):
-            state.intermediate_results = [output_path]
+            state.intermediate_results = output_path
 
-            stats = fs.Stats(
-                state.cache_dir, state.config.build_name, state.evaluation_id
-            )
+            stats = fs.Stats(state.cache_dir, state.build_name, state.evaluation_id)
             stats.save_model_eval_stat(
                 fs.Keys.ONNX_FILE,
                 output_path,
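
The export.py hunks above also promote original_inputs_file from a State attribute to a module-level helper keyed on (cache_dir, build_name), so each stage derives the saved-inputs path itself. Usage exactly as it appears in the hunks:

    inputs_path = original_inputs_file(state.cache_dir, state.build_name)
    # previously: state.original_inputs_file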