-
Notifications
You must be signed in to change notification settings - Fork 6
/
utils.py
22 lines (16 loc) · 784 Bytes
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import tensorflow as tf
import logging
import numpy as np
def get_chunks(l, batch_size):
    """Split a sliceable sequence into consecutive chunks of ``batch_size``.

    The last chunk holds the remainder and may be shorter than
    ``batch_size``. Chunks are produced by slicing, so each one has the
    same type as ``l`` (a list yields lists, a string yields strings).

    Args:
        l: Any object supporting ``len()`` and slicing.
        batch_size: Positive number of elements per chunk.

    Returns:
        A list of chunks; empty when ``l`` is empty.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """
    # Without this guard a negative step makes range() empty and the
    # function would silently return [] instead of signalling the misuse.
    if batch_size < 1:
        raise ValueError(
            'batch_size must be a positive integer, got {!r}'.format(batch_size))
    return [l[start:start + batch_size] for start in range(0, len(l), batch_size)]
def tf_argchoice_element(sequence, element):
    """Pick a random position of ``element`` along the last axis of ``sequence``.

    Every position equal to ``element`` gets a score drawn from
    ``tf.random.uniform`` (values in [0, 1)). Every other position gets a
    score of -1. Taking the argmax of these scores selects a random
    matching position.

    NOTE(review): when no position matches, every score is -1 and argmax
    returns index 0. That result is indistinguishable from a genuine match
    at index 0. Presumably callers guarantee at least one match; confirm.

    Args:
        sequence: Tensor to search; the argmax is taken over axis -1.
        element: Value compared against ``sequence`` with ``tf.equal``.

    Returns:
        An int32 tensor of indices into the last axis of ``sequence``.
    """
    seq_shape = tf.shape(sequence)
    scores = tf.random.uniform(shape=seq_shape)
    is_match = tf.equal(sequence, element)
    # Non-matching slots get -1, strictly below any uniform draw, so they
    # can only win the argmax when nothing matches at all.
    masked_scores = tf.where(is_match, scores, tf.fill(seq_shape, -1.0))
    return tf.argmax(masked_scores, axis=-1, output_type=tf.int32)
def get_embeddings(s2v_file):
    """Load an embedding matrix from a ``.npy`` file and log its size.

    Args:
        s2v_file: Path or file object accepted by ``numpy.load``.

    Returns:
        The loaded array. ``shape[0]`` is reported as the number of
        embeddings and ``shape[1]`` as their dimension.

    Raises:
        ValueError: If the loaded array has fewer than two dimensions.
    """
    embeddings = np.load(s2v_file)
    # Fail with a descriptive error instead of an IndexError raised while
    # building the log message below.
    if embeddings.ndim < 2:
        raise ValueError(
            'expected a 2-D embedding matrix in {}, got shape {}'.format(
                s2v_file, embeddings.shape))
    # Lazy %-style args: the message is only formatted if INFO is enabled.
    logging.info('loaded %d embeddings with dimension %d from npy file',
                 embeddings.shape[0], embeddings.shape[1])
    return embeddings