-
Notifications
You must be signed in to change notification settings - Fork 0
/
semantic_sim.py
114 lines (94 loc) · 4.09 KB
/
semantic_sim.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import os
import tensorflow as tf
import tensorflow_hub as hub
import numpy as np
import tf_sentencepiece
# TF-Hub location of the multilingual question-answering Universal Sentence Encoder.
USE_QA_URL = "https://tfhub.dev/google/universal-sentence-encoder-multilingual-qa/1"
# Output dimensionality of the USE embedding vectors.
USE_EMBED_DIM = 512
def embed_use(module):
    """Build an embedding function from a TF-Hub Universal Sentence Encoder module.

    Args:
        module: TF-Hub module URL (or local path) of a text-embedding module.

    Returns:
        A callable that maps a list of strings to their embedding matrix
        (numpy array, one row per input sentence).
    """
    with tf.Graph().as_default():
        sentences = tf.compat.v1.placeholder(tf.string)
        embed = hub.Module(module)
        embeddings = embed(sentences)
        # MonitoredSession runs variable/table initializers automatically.
        # Use the compat.v1 namespace for consistency with embed_use_qa below.
        session = tf.compat.v1.train.MonitoredSession()
        return lambda x: session.run(embeddings, {sentences: x})
def embed_use_multilingual(module):
    """Build an embedding function from a multilingual USE TF-Hub module.

    Unlike embed_use, this explicitly runs the variable and table
    initializers (the multilingual module needs its lookup tables
    initialized) and finalizes the graph before creating the session.

    Args:
        module: TF-Hub module URL (or local path) of the multilingual module.

    Returns:
        A callable that maps a list of strings to their embedding matrix.
    """
    # Graph set up. Use tf.compat.v1 throughout, consistent with the
    # rest of this file (the original mixed raw TF1 names with compat.v1).
    g = tf.Graph()
    with g.as_default():
        text_input = tf.compat.v1.placeholder(dtype=tf.string, shape=[None])
        embed = hub.Module(module)
        embedded_text = embed(text_input)
        init_op = tf.group([tf.compat.v1.global_variables_initializer(),
                            tf.compat.v1.tables_initializer()])
    g.finalize()
    # Initialize session.
    session = tf.compat.v1.Session(graph=g)
    session.run(init_op)
    return lambda x: session.run(embedded_text, feed_dict={text_input: x})
class SimServer:
    """Semantic-similarity server backed by Universal Sentence Encoder models.

    Stores the embedding of a single "context" sentence and scores other
    sentences against it with cosine (normalized dot-product) similarity.
    """

    # Model selector constants (pass one as `model` to __init__).
    USE_MULTILINGUAL = 0   # multilingual USE large
    USE_WITH_DAN = 1       # USE with the DAN encoder
    UNIV_SENT_ENCODER = 2  # USE large (Transformer encoder)
    USE_QA = 3             # multilingual QA USE (question/response encoders)

    def __init__(self, model=0, win_size=None):
        """Load the selected TF-Hub model and prepare the embedding function.

        Args:
            model: one of the class model-selector constants (default:
                USE_MULTILINGUAL).
            win_size: window size, stored for callers; not used here.
        """
        self.model = model  # which model to use
        self.win_size = win_size
        self._context = ""
        self.context_emb = None
        if self.model == self.UNIV_SENT_ENCODER:  # UNIVERSAL SENTENCE ENCODER
            self.use_embed_fn = embed_use(
                "https://tfhub.dev/google/universal-sentence-encoder-large/3")
        elif self.model == self.USE_WITH_DAN:
            self.use_embed_fn = embed_use(
                "https://tfhub.dev/google/universal-sentence-encoder/2")
        elif self.model == self.USE_MULTILINGUAL:
            self.use_embed_fn = embed_use_multilingual(
                "https://tfhub.dev/google/"
                "universal-sentence-encoder-multilingual-large/1")
        elif self.model == self.USE_QA:
            self.use_embed_fn, self.embed_answers_batch = self.embed_use_qa()
        print("SimilarityServer set up.")

    @staticmethod
    def embed_use_qa(tfhub_path=None):
        """Build question/answer embedding functions from the QA USE module.

        Args:
            tfhub_path: TF-Hub module URL; defaults to USE_QA_URL.

        Returns:
            (question_embed, answers_embed) — question_embed maps a list of
            question strings to embeddings; answers_embed maps parallel lists
            of answers and their contexts to embeddings.
        """
        # Resolve the default at call time (the module constant may be
        # defined after this class in some layouts).
        if tfhub_path is None:
            tfhub_path = USE_QA_URL

        # Closures below bind `session` and the placeholders defined later;
        # Python resolves them at call time, after setup has completed.
        def question_embed(question_text):
            return session.run(question_embedding,
                               {question: question_text})['outputs']

        def answers_embed(answers, contexts):
            return session.run(response_embedding, {
                response: answers,
                response_context: contexts
            })['outputs']

        g = tf.Graph()
        with g.as_default():
            module = hub.Module(tfhub_path)
            question = tf.compat.v1.placeholder(dtype=tf.string, shape=[None])
            response = tf.compat.v1.placeholder(dtype=tf.string, shape=[None])
            response_context = tf.compat.v1.placeholder(dtype=tf.string, shape=[None])
            question_embedding = module(question,
                                        signature="question_encoder",
                                        as_dict=True)
            response_embedding = module(
                inputs={
                    "input": response,
                    "context": response_context
                },
                signature="response_encoder", as_dict=True)
            init_op = tf.group([tf.compat.v1.global_variables_initializer(),
                                tf.compat.v1.tables_initializer()])
        g.finalize()
        # Initialize session.
        session = tf.compat.v1.Session(graph=g)
        session.run(init_op)
        return question_embed, answers_embed

    def set_context(self, context_sentence: str):
        """Embed and store the reference sentence used by similarity()."""
        self._context = context_sentence
        self.context_emb = self.use_embed_fn([self._context])[0]

    def _dot_prod_sim(self, b):
        """Cosine similarity between the stored context embedding and `b`.

        `b` may be a single embedding vector or a batch of shape [n, dim];
        a batch yields one score per row.
        """
        # Normalize per row: the previous version divided a whole batch by
        # its global norm, which skewed every score for batched input.
        b = np.asarray(b)
        b_unit = b / np.linalg.norm(b, axis=-1, keepdims=True)
        context_unit = self.context_emb / np.linalg.norm(self.context_emb)
        return np.inner(context_unit, b_unit)

    def _euclidean_sim(self, b):
        """Similarity in (0, 1] from the Euclidean distance to the context.

        Returns 1 for identical embeddings, approaching 0 as they diverge.
        """
        # BUG FIX: the original called np.linalg.norm(context_emb, b),
        # which passes `b` as the `ord` argument instead of computing a
        # distance; the distance is the norm of the difference.
        dist = np.linalg.norm(self.context_emb - b)
        return 1 / (1 + dist)

    def embed(self, batch):
        """Embed a batch (list) of strings with the selected model."""
        return self.use_embed_fn(batch)

    def similarity(self, b):
        """Cosine similarity between the stored context and sentence `b`."""
        # b is a string could be a batch...
        score = self._dot_prod_sim(self.use_embed_fn([b])[0])
        return score