Skip to content

Commit

Permalink
Adding Cosine Similarity (apple#1531)
Browse files Browse the repository at this point in the history
Adding Cosine Similarity (apple#1531)
  • Loading branch information
makiit authored Jun 23, 2022
1 parent aeb3e8e commit eaa1b66
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 18 deletions.
21 changes: 21 additions & 0 deletions coremltools/converters/mil/frontend/torch/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,27 @@ def constant(context, node):
context.add(const, torch_name=name)


@register_torch_op
def cosine_similarity(context, node):
inputs = _get_inputs(context, node, expected=4)
dim = inputs[-2].val
eps = inputs[-1].val
xy = mb.mul(x=inputs[0], y=inputs[1])
sum_xy = mb.reduce_sum(x=xy, axes=[dim])

xx = mb.mul(x=inputs[0], y=inputs[0])
sum_xx = mb.reduce_sum(x=xx, axes=[dim])
yy = mb.mul(x=inputs[1], y=inputs[1])
sum_yy = mb.reduce_sum(x=yy, axes=[dim])

mul_sum_xy = mb.mul(x=sum_xx, y=sum_yy)
div_12 = mb.maximum(x=mul_sum_xy, y=eps * eps)
div_sqrt = mb.sqrt(x=div_12)

cs = mb.real_div(x=sum_xy, y=div_sqrt, name=node.name)
context.add(cs)


@register_torch_op
def selu(context, node):
ALPHA = 1.6732632423543772
Expand Down
19 changes: 2 additions & 17 deletions coremltools/converters/mil/frontend/torch/test/test_custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
_TORCH_OPS_REGISTRY as _TORCH_OPS_REG,
)
from coremltools.converters.mil.frontend.torch.ops import _get_inputs
from coremltools.converters.mil.frontend.torch.ops import cosine_similarity as cosine_similarity_main
from coremltools.converters.mil.mil import (
Builder as mb,
Operation,
Expand All @@ -38,23 +39,7 @@

@register_torch_op(override=True)
def cosine_similarity(context, node):
inputs = _get_inputs(context, node, expected=4)
dim = inputs[-2].val
eps = inputs[-1].val
xy = mb.mul(x=inputs[0], y=inputs[1])
sum_xy = mb.reduce_sum(x=xy, axes=[dim])

xx = mb.mul(x=inputs[0], y=inputs[0])
sum_xx = mb.reduce_sum(x=xx, axes=[dim])
yy = mb.mul(x=inputs[1], y=inputs[1])
sum_yy = mb.reduce_sum(x=yy, axes=[dim])

mul_sum_xy = mb.mul(x=sum_xx, y=sum_yy)
div_12 = mb.maximum(x=mul_sum_xy, y=eps * eps)
div_sqrt = mb.sqrt(x=div_12)

cs = mb.real_div(x=sum_xy, y=div_sqrt, name=node.name)
context.add(cs)
cosine_similarity_main(context, node)


# Log custom Cosine Similarity conversion function
Expand Down
32 changes: 31 additions & 1 deletion coremltools/converters/mil/frontend/torch/test/test_torch_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ def forward(self, x):

self.run_compare_torch(model.input_size, model, backend=backend, use_scripting=True)


@pytest.mark.parametrize("backend", backends)
def test_linear(self, backend):
class Model(torch.nn.Module):
Expand Down Expand Up @@ -396,6 +397,35 @@ def test_mv(self, matrix_shape, backend):
TorchBaseTest.run_compare_torch((matrix, vector), model, backend=backend, input_as_shape=False)


class TestCosineSimilarity(TorchBaseTest):
@pytest.mark.parametrize("dim, eps, shape, backend",
itertools.product([0, 1, -1], [0.1, 1e-5, 1e-8], COMMON_SHAPES, backends)
)
@pytest.mark.xfail(backend = ("mlprogram", "fp16"),
reason = "Known precision error with mlprogram fp16 backend"
)
def test_cosine_similarity(self, backend, dim, eps, shape):
class CosineSimilarity(nn.Module):
def __init__(self, dim, eps):
super(CosineSimilarity, self).__init__()
self.cossim = torch.nn.CosineSimilarity(dim=dim, eps=eps)

def forward(self, x, y):
out = self.cossim(x, y)
return out

model = CosineSimilarity(dim, eps)
input1 = generate_input_data(shape)
input2 = generate_input_data(shape)

TorchBaseTest.run_compare_torch(
[input1, input2],
model,
input_as_shape=False,
backend=backend,
)


class TestDot(TorchBaseTest):
@pytest.mark.parametrize("vector_length, backend",
itertools.product([1, 5, 11], backends)
Expand All @@ -407,7 +437,7 @@ def test_dot(self, vector_length, backend):
vector2 = generate_input_data((vector_length, ))

TorchBaseTest.run_compare_torch((vector1, vector2), model, backend=backend, input_as_shape=False)


class TestOuter(TorchBaseTest):
@pytest.mark.parametrize("x_vector_length, y_vector_length, backend",
Expand Down

0 comments on commit eaa1b66

Please sign in to comment.