Skip to content

Commit aeb3e8e

Browse files
authored
Torch eq and ne ops supports bool type. (apple#1501)
* Torch eq and ne ops supports bool type. * Addressed review comment
1 parent 161898d commit aeb3e8e

File tree

2 files changed

+42
-2
lines changed

2 files changed

+42
-2
lines changed

coremltools/converters/mil/frontend/torch/ops.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -472,14 +472,26 @@ def listconstruct(context, node):
472472
@register_torch_op
473473
def eq(context, node):
474474
inputs = _get_inputs(context, node, expected=2)
475-
equal_to = mb.equal(x=inputs[0], y=inputs[1], name=node.name)
475+
x = inputs[0]
476+
y = inputs[1]
477+
if is_bool(x.dtype):
478+
x = mb.cast(x=x, dtype='int32')
479+
if is_bool(y.dtype):
480+
y = mb.cast(x=y, dtype='int32')
481+
equal_to = mb.equal(x=x, y=y, name=node.name)
476482
context.add(equal_to)
477483

478484

479485
@register_torch_op
480486
def ne(context, node):
481487
inputs = _get_inputs(context, node, expected=2)
482-
equal_to = mb.not_equal(x=inputs[0], y=inputs[1], name=node.name)
488+
x = inputs[0]
489+
y = inputs[1]
490+
if is_bool(x.dtype):
491+
x = mb.cast(x=x, dtype='int32')
492+
if is_bool(y.dtype):
493+
y = mb.cast(x=y, dtype='int32')
494+
equal_to = mb.not_equal(x=x, y=y, name=node.name)
483495
context.add(equal_to)
484496

485497

coremltools/converters/mil/frontend/torch/test/test_torch_ops.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3370,6 +3370,34 @@ def forward(self, x):
33703370
)
33713371

33723372

3373+
class TestBitWiseLogical(TorchBaseTest):
3374+
@pytest.mark.parametrize(
3375+
"backend, x_y, op_string",
3376+
itertools.product(
3377+
backends,
3378+
[
3379+
([True, False, True, False], [True, True, False, False]),
3380+
([[True, False], [True, False]], [[True, True], [False, False]]),
3381+
([[True, False], [True, False]], [[1, 0], [2, 1]]),
3382+
([-1.5, 0.0, 1.0, 0.0], [0.1, 2.5, 0.0, 0.0]),
3383+
([2, 0, -1, 0, 5], [1, 1, 0, 0, -5]),
3384+
],
3385+
[
3386+
"eq",
3387+
"ne",
3388+
],
3389+
),
3390+
)
3391+
def test_bitwise_logical(self, backend, x_y, op_string):
3392+
if not contains_op(torch, op_string):
3393+
return
3394+
op_func = getattr(torch, op_string)
3395+
model = ModuleWrapper(function=op_func)
3396+
x = torch.tensor(x_y[0])
3397+
y = torch.tensor(x_y[1])
3398+
self.run_compare_torch([x, y], model, backend=backend, input_as_shape=False)
3399+
3400+
33733401
class TestLogicalAnd(TorchBaseTest):
33743402
@pytest.mark.parametrize(
33753403
"backend, x_y",

0 commit comments

Comments
 (0)