Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Export casanovo to torchscript/onnx #328

Open
LLautenbacher opened this issue May 11, 2024 · 1 comment
Open

Export casanovo to torchscript/onnx #328

LLautenbacher opened this issue May 11, 2024 · 1 comment
Labels
enhancement New feature or request

Comments

@LLautenbacher
Copy link

Hi,

I want to export casanovo to Torchscript or ONNX to make it accessible via Koina.
When I follow the Lightning documentation to do that (using method="trace"), I get an UnsupportedNodeError. I'm not familiar with Lightning or PyTorch. Can you help with creating a TorchScript/ONNX export of your model?

Here is the full traceback:
---------------------------------------------------------------------------
UnsupportedNodeError                      Traceback (most recent call last)
Cell In[5], [line 1](vscode-notebook-cell:?execution_count=5&line=1)
----> [1](vscode-notebook-cell:?execution_count=5&line=1) runner.model.to_torchscript("model.pt", method="trace", example_inputs=inp)

File ~/miniconda3/envs/casanovo/lib/python3.9/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
112 @functools.wraps(func)
113 def decorate_context(*args, **kwargs):
114 with ctx_factory():
--> 115 return func(*args, **kwargs)

File ~/miniconda3/envs/casanovo/lib/python3.9/site-packages/lightning/pytorch/core/module.py:1479, in LightningModule.to_torchscript(self, file_path, method, example_inputs, **kwargs)
1477 example_inputs = self._apply_batch_transfer_handler(example_inputs)
1478 with _jit_is_scripting():
-> 1479 torchscript_module = torch.jit.trace(func=self.eval(), example_inputs=example_inputs, **kwargs)
1480 else:
1481 raise ValueError(f"The 'method' parameter only supports 'script' or 'trace', but value given was: {method}")

File ~/miniconda3/envs/casanovo/lib/python3.9/site-packages/torch/jit/_trace.py:820, in trace(func, example_inputs, optimize, check_trace, check_inputs, check_tolerance, strict, _force_outplace, _module_class, _compilation_unit, example_kwarg_inputs, _store_inputs)
818 else:
819 raise RuntimeError("example_kwarg_inputs should be a dict")
--> 820 return trace_module(
821 func,
822 {"forward": example_inputs},
823 None,
824 check_trace,
825 wrap_check_inputs(check_inputs),
826 check_tolerance,
827 strict,
828 _force_outplace,
829 _module_class,
830 example_inputs_is_kwarg=isinstance(example_kwarg_inputs, dict),
831 _store_inputs=_store_inputs,
832 )
833 if (
834 hasattr(func, "__self__")
835 and isinstance(func.__self__, torch.nn.Module)
836 and func.__name__ == "forward"
837 ):
838 if example_inputs is None:

File ~/miniconda3/envs/casanovo/lib/python3.9/site-packages/torch/jit/_trace.py:1053, in trace_module(mod, inputs, optimize, check_trace, check_inputs, check_tolerance, strict, _force_outplace, _module_class, _compilation_unit, example_inputs_is_kwarg, _store_inputs)
1050 torch.jit._trace._trace_module_map = trace_module_map
1051 register_submods(mod, "__module")
-> 1053 module = make_module(mod, _module_class, _compilation_unit)
1055 for method_name, example_inputs in inputs.items():
1056 if method_name == "forward":
1057 # "forward" is a special case because we need to trace
1058 # Module.__call__, which sets up some extra tracing, but uses
1059 # argument names of the real Module.forward method.

File ~/miniconda3/envs/casanovo/lib/python3.9/site-packages/torch/jit/_trace.py:624, in make_module(mod, _module_class, _compilation_unit)
622 elif torch._jit_internal.module_has_exports(mod):
623 infer_methods_stubs_fn = torch.jit._recursive.make_stubs_from_exported_methods
--> 624 return torch.jit._recursive.create_script_module(
625 mod, infer_methods_stubs_fn, share_types=False, is_tracing=True
626 )
627 else:
628 if _module_class is None:

File ~/miniconda3/envs/casanovo/lib/python3.9/site-packages/torch/jit/_recursive.py:558, in create_script_module(nn_module, stubs_fn, share_types, is_tracing)
556 if not is_tracing:
557 AttributeTypeIsSupportedChecker().check(nn_module)
--> 558 return create_script_module_impl(nn_module, concrete_type, stubs_fn)

File ~/miniconda3/envs/casanovo/lib/python3.9/site-packages/torch/jit/_recursive.py:631, in create_script_module_impl(nn_module, concrete_type, stubs_fn)
628 script_module._concrete_type = concrete_type
630 # Actually create the ScriptModule, initializing it with the function we just defined
--> 631 script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
633 # Compile methods if necessary
634 if concrete_type not in concrete_type_store.methods_compiled:

File ~/miniconda3/envs/casanovo/lib/python3.9/site-packages/torch/jit/_script.py:647, in RecursiveScriptModule._construct(cpp_module, init_fn)
633 """
634 Construct a RecursiveScriptModule that's ready for use.
635
(...)
644 init_fn: Lambda that initializes the RecursiveScriptModule passed to it.
645 """
646 script_module = RecursiveScriptModule(cpp_module)
--> 647 init_fn(script_module)
649 # Finalize the ScriptModule: replace the nn.Module state with our
650 # custom implementations and flip the _initializing bit.
651 RecursiveScriptModule._finalize_scriptmodule(script_module)

File ~/miniconda3/envs/casanovo/lib/python3.9/site-packages/torch/jit/_recursive.py:607, in create_script_module_impl.<locals>.init_fn(script_module)
604 scripted = orig_value
605 else:
606 # always reuse the provided stubs_fn to infer the methods to compile
--> 607 scripted = create_script_module_impl(
608 orig_value, sub_concrete_type, stubs_fn
609 )
611 cpp_module.setattr(name, scripted)
612 script_module._modules[name] = scripted

File ~/miniconda3/envs/casanovo/lib/python3.9/site-packages/torch/jit/_recursive.py:635, in create_script_module_impl(nn_module, concrete_type, stubs_fn)
633 # Compile methods if necessary
634 if concrete_type not in concrete_type_store.methods_compiled:
--> 635 create_methods_and_properties_from_stubs(
636 concrete_type, method_stubs, property_stubs
637 )
638 # Create hooks after methods to ensure no name collisions between hooks and methods.
639 # If done before, hooks can overshadow methods that aren't exported.
640 create_hooks_from_stubs(concrete_type, hook_stubs, pre_hook_stubs)

File ~/miniconda3/envs/casanovo/lib/python3.9/site-packages/torch/jit/_recursive.py:467, in create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
464 property_defs = [p.def_ for p in property_stubs]
465 property_rcbs = [p.resolution_callback for p in property_stubs]
--> 467 concrete_type._create_methods_and_properties(
468 property_defs, property_rcbs, method_defs, method_rcbs, method_defaults
469 )

File ~/miniconda3/envs/casanovo/lib/python3.9/site-packages/torch/jit/_recursive.py:1036, in compile_unbound_method(concrete_type, fn)
1034 if _jit_internal.is_ignored_fn(fn):
1035 return None
-> 1036 stub = make_stub(fn, fn.__name__)
1037 with torch._jit_internal._disable_emit_hooks():
1038 # We don't want to call the hooks here since the graph that is calling
1039 # this function is not yet complete
1040 create_methods_and_properties_from_stubs(concrete_type, (stub,), ())

File ~/miniconda3/envs/casanovo/lib/python3.9/site-packages/torch/jit/_recursive.py:71, in make_stub(func, name)
69 def make_stub(func, name):
70 rcb = _jit_internal.createResolutionCallbackFromClosure(func)
---> 71 ast = get_jit_def(func, name, self_name="RecursiveScriptModule")
72 return ScriptMethodStub(rcb, ast, func)

File ~/miniconda3/envs/casanovo/lib/python3.9/site-packages/torch/jit/frontend.py:372, in get_jit_def(fn, def_name, self_name, is_classmethod)
369 qualname = get_qualified_name(fn)
370 pdt_arg_types = type_trace_db.get_args_types(qualname)
--> 372 return build_def(
373 parsed_def.ctx,
374 fn_def,
375 type_line,
376 def_name,
377 self_name=self_name,
378 pdt_arg_types=pdt_arg_types,
379 )

File ~/miniconda3/envs/casanovo/lib/python3.9/site-packages/torch/jit/frontend.py:433, in build_def(ctx, py_def, type_line, def_name, self_name, pdt_arg_types)
430 type_comment_decl = torch._C.parse_type_comment(type_line)
431 decl = torch._C.merge_type_from_type_comment(decl, type_comment_decl, is_method)
--> 433 return Def(Ident(r, def_name), decl, build_stmts(ctx, body))

File ~/miniconda3/envs/casanovo/lib/python3.9/site-packages/torch/jit/frontend.py:195, in build_stmts(ctx, stmts)
194 def build_stmts(ctx, stmts):
--> 195 stmts = [build_stmt(ctx, s) for s in stmts]
196 return list(filter(None, stmts))

File ~/miniconda3/envs/casanovo/lib/python3.9/site-packages/torch/jit/frontend.py:195, in <listcomp>(.0)
194 def build_stmts(ctx, stmts):
--> 195 stmts = [build_stmt(ctx, s) for s in stmts]
196 return list(filter(None, stmts))

File ~/miniconda3/envs/casanovo/lib/python3.9/site-packages/torch/jit/frontend.py:406, in Builder.__call__(self, ctx, node)
404 if method is None:
405 raise UnsupportedNodeError(ctx, node)
--> 406 return method(ctx, node)

File ~/miniconda3/envs/casanovo/lib/python3.9/site-packages/torch/jit/frontend.py:773, in StmtBuilder.build_For(ctx, stmt)
766 if stmt.orelse:
767 raise NotSupportedError(r, "else branches of for loops aren't supported")
769 return For(
770 r,
771 [build_expr(ctx, stmt.target)],
772 [build_expr(ctx, stmt.iter)],
--> 773 build_stmts(ctx, stmt.body),
774 )

File ~/miniconda3/envs/casanovo/lib/python3.9/site-packages/torch/jit/frontend.py:195, in build_stmts(ctx, stmts)
194 def build_stmts(ctx, stmts):
--> 195 stmts = [build_stmt(ctx, s) for s in stmts]
196 return list(filter(None, stmts))

File ~/miniconda3/envs/casanovo/lib/python3.9/site-packages/torch/jit/frontend.py:195, in <listcomp>(.0)
194 def build_stmts(ctx, stmts):
--> 195 stmts = [build_stmt(ctx, s) for s in stmts]
196 return list(filter(None, stmts))

File ~/miniconda3/envs/casanovo/lib/python3.9/site-packages/torch/jit/frontend.py:406, in Builder.__call__(self, ctx, node)
404 if method is None:
405 raise UnsupportedNodeError(ctx, node)
--> 406 return method(ctx, node)

File ~/miniconda3/envs/casanovo/lib/python3.9/site-packages/torch/jit/frontend.py:676, in StmtBuilder.build_Expr(ctx, stmt)
674 return None
675 else:
--> 676 return ExprStmt(build_expr(ctx, value))

File ~/miniconda3/envs/casanovo/lib/python3.9/site-packages/torch/jit/frontend.py:405, in Builder.__call__(self, ctx, node)
403 method = getattr(self, "build_" + node.__class__.__name__, None)
404 if method is None:
--> 405 raise UnsupportedNodeError(ctx, node)
406 return method(ctx, node)

UnsupportedNodeError: Yield aren't supported:
File "/cmnfs/home/llautenbacher/miniconda3/envs/casanovo/lib/python3.9/site-packages/torch/nn/modules/module.py", line 2230
"""
for name, param in self.named_parameters(recurse=recurse):
yield param
~ <--- HERE

@wsnoble
Copy link
Contributor

wsnoble commented May 14, 2024

We are very interested in helping to make this happen, but unfortunately, we have zero familiarity with ONNX.

Separately, we've found that torch.compile doesn't work with Casanovo (though I don't know the details there). Perhaps these are related issues.

If there is anything specific we can help with, please let us know.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

No branches or pull requests

3 participants