diff --git a/coremltools/converters/mil/frontend/torch/ops.py b/coremltools/converters/mil/frontend/torch/ops.py
index 5b00bfe2d..75e5fad56 100644
--- a/coremltools/converters/mil/frontend/torch/ops.py
+++ b/coremltools/converters/mil/frontend/torch/ops.py
@@ -361,6 +361,27 @@ def constant(context, node):
     context.add(const, torch_name=name)


+@register_torch_op
+def cosine_similarity(context, node):
+    inputs = _get_inputs(context, node, expected=4)
+    dim = inputs[-2].val
+    eps = inputs[-1].val
+    xy = mb.mul(x=inputs[0], y=inputs[1])
+    sum_xy = mb.reduce_sum(x=xy, axes=[dim])
+
+    xx = mb.mul(x=inputs[0], y=inputs[0])
+    sum_xx = mb.reduce_sum(x=xx, axes=[dim])
+    yy = mb.mul(x=inputs[1], y=inputs[1])
+    sum_yy = mb.reduce_sum(x=yy, axes=[dim])
+
+    mul_sum_xy = mb.mul(x=sum_xx, y=sum_yy)
+    div_12 = mb.maximum(x=mul_sum_xy, y=eps * eps)
+    div_sqrt = mb.sqrt(x=div_12)
+
+    cs = mb.real_div(x=sum_xy, y=div_sqrt, name=node.name)
+    context.add(cs)
+
+
 @register_torch_op
 def selu(context, node):
     ALPHA = 1.6732632423543772
diff --git a/coremltools/converters/mil/frontend/torch/test/test_custom_ops.py b/coremltools/converters/mil/frontend/torch/test/test_custom_ops.py
index 36f70e15c..468488abe 100644
--- a/coremltools/converters/mil/frontend/torch/test/test_custom_ops.py
+++ b/coremltools/converters/mil/frontend/torch/test/test_custom_ops.py
@@ -19,6 +19,7 @@
     _TORCH_OPS_REGISTRY as _TORCH_OPS_REG,
 )
 from coremltools.converters.mil.frontend.torch.ops import _get_inputs
+from coremltools.converters.mil.frontend.torch.ops import cosine_similarity as cosine_similarity_main
 from coremltools.converters.mil.mil import (
     Builder as mb,
     Operation,
@@ -38,23 +39,7 @@

 @register_torch_op(override=True)
 def cosine_similarity(context, node):
-    inputs = _get_inputs(context, node, expected=4)
-    dim = inputs[-2].val
-    eps = inputs[-1].val
-    xy = mb.mul(x=inputs[0], y=inputs[1])
-    sum_xy = mb.reduce_sum(x=xy, axes=[dim])
-
-    xx = mb.mul(x=inputs[0], y=inputs[0])
-    sum_xx = mb.reduce_sum(x=xx, axes=[dim])
-    yy = mb.mul(x=inputs[1], y=inputs[1])
-    sum_yy = mb.reduce_sum(x=yy, axes=[dim])
-
-    mul_sum_xy = mb.mul(x=sum_xx, y=sum_yy)
-    div_12 = mb.maximum(x=mul_sum_xy, y=eps * eps)
-    div_sqrt = mb.sqrt(x=div_12)
-
-    cs = mb.real_div(x=sum_xy, y=div_sqrt, name=node.name)
-    context.add(cs)
+    cosine_similarity_main(context, node)


 # Log custom Cosine Similarity conversion function
diff --git a/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py b/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py
index 95e6690b6..41394c1ea 100644
--- a/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py
+++ b/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py
@@ -154,6 +154,7 @@ def forward(self, x):

         self.run_compare_torch(model.input_size, model, backend=backend, use_scripting=True)

+
     @pytest.mark.parametrize("backend", backends)
     def test_linear(self, backend):
         class Model(torch.nn.Module):
@@ -396,6 +397,35 @@ def test_mv(self, matrix_shape, backend):
         TorchBaseTest.run_compare_torch((matrix, vector), model, backend=backend, input_as_shape=False)


+class TestCosineSimilarity(TorchBaseTest):
+    @pytest.mark.parametrize("dim, eps, shape, backend",
+                             itertools.product([0, 1, -1], [0.1, 1e-5, 1e-8], COMMON_SHAPES, backends)
+                             )
+    def test_cosine_similarity(self, backend, dim, eps, shape):
+        if backend == ("mlprogram", "fp16"):
+            pytest.xfail("Known precision error with the mlprogram fp16 backend")
+
+        class CosineSimilarity(nn.Module):
+            def __init__(self, dim, eps):
+                super(CosineSimilarity, self).__init__()
+                self.cossim = torch.nn.CosineSimilarity(dim=dim, eps=eps)
+
+            def forward(self, x, y):
+                out = self.cossim(x, y)
+                return out
+
+        model = CosineSimilarity(dim, eps)
+        input1 = generate_input_data(shape)
+        input2 = generate_input_data(shape)
+
+        TorchBaseTest.run_compare_torch(
+            [input1, input2],
+            model,
+            input_as_shape=False,
+            backend=backend,
+        )
+
+
 class TestDot(TorchBaseTest):
     @pytest.mark.parametrize("vector_length, backend",
                              itertools.product([1, 5, 11], backends)
@@ -407,7 +437,7 @@ def test_dot(self, vector_length, backend):
         vector2 = generate_input_data((vector_length, ))

         TorchBaseTest.run_compare_torch((vector1, vector2), model, backend=backend, input_as_shape=False)
-    
+


 class TestOuter(TorchBaseTest):
     @pytest.mark.parametrize("x_vector_length, y_vector_length, backend",
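Not part of the patch: a minimal NumPy/PyTorch sketch of the decomposition the new lowering uses, cos(x, y) = sum(x*y) / sqrt(max(sum(x*x) * sum(y*y), eps^2)) reduced along dim, which matches the formula documented for torch.nn.CosineSimilarity (dot product over the product of norms, clamped below by eps). The helper name, shapes, random inputs, and tolerances below are arbitrary illustration choices, not anything from the patch.

# Illustrative check only (not part of the patch): compares the decomposed
# formula used by the converter against torch.nn.functional.cosine_similarity
# on well-conditioned random inputs, where the eps guard is inactive.
import numpy as np
import torch


def cosine_similarity_decomposed(x, y, dim=1, eps=1e-8):
    # Mirrors the MIL graph: mul/reduce_sum for the dot product and squared
    # norms, maximum(..., eps^2) to guard the denominator, then sqrt and divide.
    sum_xy = np.sum(x * y, axis=dim)
    sum_xx = np.sum(x * x, axis=dim)
    sum_yy = np.sum(y * y, axis=dim)
    denom = np.sqrt(np.maximum(sum_xx * sum_yy, eps * eps))
    return sum_xy / denom


x = np.random.rand(4, 8).astype(np.float32)
y = np.random.rand(4, 8).astype(np.float32)
expected = torch.nn.functional.cosine_similarity(
    torch.from_numpy(x), torch.from_numpy(y), dim=1
).numpy()
np.testing.assert_allclose(
    cosine_similarity_decomposed(x, y, dim=1), expected, rtol=1e-5, atol=1e-6
)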