diff --git a/coremltools/converters/mil/frontend/tensorflow/ops.py b/coremltools/converters/mil/frontend/tensorflow/ops.py index 1d61b0463..570b74613 100644 --- a/coremltools/converters/mil/frontend/tensorflow/ops.py +++ b/coremltools/converters/mil/frontend/tensorflow/ops.py @@ -729,9 +729,16 @@ def DepthwiseConv2dNative(context, node): pad_type = pad_type.lower() x = context[node.inputs[0]] - C_in = x.shape[-1] + if data_format == "NHWC": x = _transpose_NHWC_to_NCHW(x) + C_in = x.shape[-1] + elif data_format == "NCHW": + C_in = x.shape[1] + + if not isinstance(C_in, int): + raise ValueError("Channel number of input node must be an integer, instead got: {}".format(C_in)) + # Only the last op should have the same name as node.name conv_name = node.name + "x" if data_format == "NHWC" else node.name x = mb.conv(