Skip to content

Commit

Permalink
Fixed errors with python3 (apple#723)
Browse files Browse the repository at this point in the history
* Convert tf signature values to idexable list object

* Fixed python3 comparison of bytes and strings for recurrent layers

Co-authored-by: sacha <[email protected]>
  • Loading branch information
sachatt and sacha authored Jun 4, 2020
1 parent 4134744 commit 1afb9cb
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 1 deletion.
2 changes: 1 addition & 1 deletion coremltools/converters/tensorflow/_tf_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def _graph_def_from_saved_model_or_keras_model(filename):
raise ValueError('Unable to load a model with no signatures provided.')
if len(signatures) >= 2:
raise ValueError('Unable to load a model with multiple signatures')
concrete_func = signatures.values()[0]
concrete_func = list(signatures.values())[0]
frozen_func = _convert_to_constants.convert_variables_to_constants_v2(concrete_func)
graph_def = frozen_func.graph.as_graph_def(add_shapes=True)
except ImportError as e:
Expand Down
4 changes: 4 additions & 0 deletions coremltools/models/neural_network/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@


def _set_recurrent_activation(param, activation):

if isinstance(activation, bytes):
activation = activation.decode("utf8")

activation = activation.upper() if isinstance(activation, str) else activation
if activation == 'SIGMOID':
param.sigmoid.MergeFromString(b'')
Expand Down

0 comments on commit 1afb9cb

Please sign in to comment.