Skip to content

Commit

Permalink
Torch eq and ne ops support bool type. (apple#1501)
Browse files Browse the repository at this point in the history
* Torch eq and ne ops support bool type.

* Addressed review comment
  • Loading branch information
fukatani authored Jun 16, 2022
1 parent 161898d commit aeb3e8e
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 2 deletions.
16 changes: 14 additions & 2 deletions coremltools/converters/mil/frontend/torch/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,14 +472,26 @@ def listconstruct(context, node):
@register_torch_op
def eq(context, node):
    """Convert torch ``eq`` (elementwise equality) to MIL ``equal``.

    MIL's ``equal`` does not accept bool operands, so bool inputs are
    first cast to int32 (True -> 1, False -> 0), which preserves the
    comparison semantics.
    """
    inputs = _get_inputs(context, node, expected=2)
    x = inputs[0]
    y = inputs[1]
    # mb.equal requires numeric dtypes; promote bools to int32 first.
    if is_bool(x.dtype):
        x = mb.cast(x=x, dtype='int32')
    if is_bool(y.dtype):
        y = mb.cast(x=y, dtype='int32')
    equal_to = mb.equal(x=x, y=y, name=node.name)
    context.add(equal_to)


@register_torch_op
def ne(context, node):
    """Convert torch ``ne`` (elementwise inequality) to MIL ``not_equal``.

    MIL's ``not_equal`` does not accept bool operands, so bool inputs are
    first cast to int32 (True -> 1, False -> 0), which preserves the
    comparison semantics.
    """
    inputs = _get_inputs(context, node, expected=2)
    x = inputs[0]
    y = inputs[1]
    # mb.not_equal requires numeric dtypes; promote bools to int32 first.
    if is_bool(x.dtype):
        x = mb.cast(x=x, dtype='int32')
    if is_bool(y.dtype):
        y = mb.cast(x=y, dtype='int32')
    not_equal_to = mb.not_equal(x=x, y=y, name=node.name)
    context.add(not_equal_to)


Expand Down
28 changes: 28 additions & 0 deletions coremltools/converters/mil/frontend/torch/test/test_torch_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -3370,6 +3370,34 @@ def forward(self, x):
)


class TestBitWiseLogical(TorchBaseTest):
    """Conversion tests for torch.eq / torch.ne, including bool inputs.

    Covers bool/bool, bool/int (mixed dtype), float, and int operand
    pairs in both 1-D and 2-D shapes.
    """

    @pytest.mark.parametrize(
        "backend, x_y, op_string",
        itertools.product(
            backends,
            [
                ([True, False, True, False], [True, True, False, False]),
                ([[True, False], [True, False]], [[True, True], [False, False]]),
                ([[True, False], [True, False]], [[1, 0], [2, 1]]),
                ([-1.5, 0.0, 1.0, 0.0], [0.1, 2.5, 0.0, 0.0]),
                ([2, 0, -1, 0, 5], [1, 1, 0, 0, -5]),
            ],
            [
                "eq",
                "ne",
            ],
        ),
    )
    def test_bitwise_logical(self, backend, x_y, op_string):
        # Skip (rather than silently pass) when this torch build lacks the op,
        # so the absence is visible in the test report.
        if not contains_op(torch, op_string):
            pytest.skip("torch has no op '{}'".format(op_string))
        op_func = getattr(torch, op_string)
        model = ModuleWrapper(function=op_func)
        x = torch.tensor(x_y[0])
        y = torch.tensor(x_y[1])
        self.run_compare_torch([x, y], model, backend=backend, input_as_shape=False)


class TestLogicalAnd(TorchBaseTest):
@pytest.mark.parametrize(
"backend, x_y",
Expand Down

0 comments on commit aeb3e8e

Please sign in to comment.