Add wikilinks NN method for generating embeddings #47
@@ -25,6 +25,15 @@
from wikirec import utils

import os
import json
import random
from itertools import chain
from collections import Counter, OrderedDict
import numpy as np
from keras.models import load_model, Model
from keras.layers import Input, Embedding, Dot, Reshape, Dense


def gen_embeddings(
    method="lda",
@@ -64,6 +73,10 @@ def gen_embeddings(
            - Word importance increases proportionally to the number of times a word appears in the document, while being offset by the number of documents in the corpus that contain the word

            - These importances are then vectorized and used to relate documents

        Wikilinks
            - Generate an embedding using a neural network trained on the connections between articles and their internal wikilinks

    corpus : list of lists (default=None)
        The text corpus over which analysis should be done
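For context, a minimal sketch of calling the new method (the import path is assumed; per the branch added further down in this diff, the wikilinks method reads a pre-parsed ndjson file rather than the corpus argument):

    from wikirec import model

    embeddings = model.gen_embeddings(method="wikilinks")  # shape: (n_articles, embedding_size)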
@@ -81,7 +94,7 @@ def gen_embeddings(
    """
    method = method.lower()

-    valid_methods = ["bert", "doc2vec", "lda", "tfidf"]
+    valid_methods = ["bert", "doc2vec", "lda", "tfidf", "wikilinks"]

    if method not in valid_methods:
        raise ValueError(
@@ -141,6 +154,17 @@ def gen_embeddings(
        embeddings = tfidfvectoriser.transform(corpus)

        return embeddings

    elif method == "wikilinks":
        if os.path.isfile("./wikilinks_embedding_model.h5"):
            model = load_model("./wikilinks_embedding_model.h5")
            layer = model.get_layer("book_embedding")
            weights = layer.get_weights()[0]
            embeddings = weights / np.linalg.norm(weights, axis=1).reshape((-1, 1))

            return embeddings
        else:
            embeddings = _wikilinks_nn("./enwiki_books.ndjson", 50)

            return embeddings
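Both return paths above yield rows that are L2-normalized (the loaded weights are normalized explicitly, and _wikilinks_nn normalizes before returning), so dot products between rows are cosine similarities. A minimal sketch of that downstream use (variable names illustrative, not part of the diff):

    import numpy as np

    embeddings = gen_embeddings(method="wikilinks")
    sim_matrix = embeddings @ embeddings.T  # cosine similarities, shape (n_articles, n_articles)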
def gen_sim_matrix(

@@ -327,3 +351,134 @@ def recommend(
    recommendations = [r for r in recommendations if r[0] not in inputs][:n]

    return recommendations
def _wikilinks_nn(path_to_json=None, embedding_size=50):
    """
    Generates embeddings of wikilinks and articles by training a neural network. Currently only trained on books.

    Parameters
    ----------
    path_to_json : str (default=None)
        The path to the parsed json file.

    embedding_size : int (default=50)
        The length of the embedding vectors between the articles and the links.

    Returns
    -------
    book_weights : np.array
        The normalized embedding vectors for each of the articles.

        The shape of book_weights is (len(books), embedding_size).
    """
    if os.path.isfile(path_to_json):
        with open(path_to_json, "r") as fin:
            books = [json.loads(l) for l in fin]
    else:
        raise Exception("Need to parse json for books.")
|
||
# Find set of wikilinks for each book and convert to a flattened list | ||
unique_wikilinks = list(chain(*[list(set(book[2])) for book in books])) | ||
wikilinks = [link.lower() for link in unique_wikilinks] | ||
to_remove = ['hardcover', 'paperback', 'hardback', 'e-book', 'wikipedia:wikiproject books', 'wikipedia:wikiproject novels'] | ||
wikilinks = [item for item in wikilinks if item not in to_remove] | ||
|
||
# Limit to wikilinks that occur more than 4 times | ||
wikilinks_counts = Counter(wikilinks) | ||
wikilinks_counts = sorted(wikilinks_counts.items(), key = lambda x: x[1], reverse = True) | ||
wikilinks_counts = OrderedDict(wikilinks_counts) | ||
links = [t[0] for t in wikilinks_counts.items() if t[1] >= 4] | ||
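    # For illustration (counts hypothetical): if wikilinks were
    # ["france", "france", "france", "france", "novel", "novel"], only
    # "france" would meet the >= 4 threshold and be kept in links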
    # Map books to their indices, and map links to indices as well
    book_index = {book[0]: idx for idx, book in enumerate(books)}
    link_index = {link: idx for idx, link in enumerate(links)}

    # Create data from pairs of (book, wikilink) for training the neural network embedding
    pairs = []
    for book in books:
        title = book[0]
        book_links = book[2]
        # Iterate through wikilinks in the book's article
        for link in book_links:
            # Add the index of the book and the index of the link to pairs
            # (membership test against the dict keys is O(1))
            if link.lower() in link_index:
                pairs.append((book_index[title], link_index[link.lower()]))

    pairs_set = set(pairs)
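    # e.g. pairs might begin [(0, 17), (0, 342), (1, 5), ...], i.e. pairs of
    # (book index, link index); the specific indices here are illustrative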
    # Neural network architecture
    # Both inputs are 1-dimensional
    book_input = Input(name="book", shape=[1])
    link_input = Input(name="link", shape=[1])

    # Embedding the book (shape will be (None, 1, 50))
    book_embedding = Embedding(
        name="book_embedding", input_dim=len(book_index), output_dim=embedding_size
    )(book_input)

    # Embedding the link (shape will be (None, 1, 50))
    link_embedding = Embedding(
        name="link_embedding", input_dim=len(link_index), output_dim=embedding_size
    )(link_input)

    # Merge the layers with a dot product along the second axis
    # (shape will be (None, 1, 1))
    merged = Dot(name="dot_product", normalize=True, axes=2)(
        [book_embedding, link_embedding]
    )

    # Reshape to be a single number (shape will be (None, 1))
    merged = Reshape(target_shape=[1])(merged)

    model = Model(inputs=[book_input, link_input], outputs=merged)
    model.compile(optimizer="Adam", loss="mse")
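    # With normalize=True, the Dot layer L2-normalizes both embeddings first,
    # so the model's output is their cosine similarity; the mse loss then
    # pushes it toward 1 for true (book, link) pairs and toward 0 for negatives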
    # Function that creates a generator for training data
    def _generate_batch(pairs, n_positive=50, negative_ratio=1.0):
        """
        Generate batches of samples for training: randomly select positive
        samples from pairs and randomly select negatives.
        """
        # Create an empty array to hold the batch
        batch_size = int(n_positive * (1 + negative_ratio))
        batch = np.zeros((batch_size, 3))

        # Continue to yield samples
        while True:
            # Randomly choose positive examples
            for idx, (book_id, link_id) in enumerate(random.sample(pairs, n_positive)):
                batch[idx, :] = (book_id, link_id, 1)
            idx += 1

            # Add negative examples until the batch size is reached
            while idx < batch_size:
                # Random selection
                random_book = random.randrange(len(book_index))
                random_link = random.randrange(len(link_index))

                # Check to make sure this is not a positive example
                if (random_book, random_link) not in pairs_set:
                    # Add to batch and increment the index
                    batch[idx, :] = (random_book, random_link, 0)
                    idx += 1

            # Make sure to shuffle the order
            np.random.shuffle(batch)
            yield {"book": batch[:, 0], "link": batch[:, 1]}, batch[:, 2]
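    # For illustration, one yielded batch with n_positive=2 and
    # negative_ratio=1.0 could look like this (values hypothetical):
    #   ({"book": array([12., 408., 3., 77.]), "link": array([5., 19., 2., 40.])},
    #    array([1., 0., 1., 0.]))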
    n_positive = 1024
    gen = _generate_batch(pairs, n_positive, negative_ratio=2)
    h = model.fit_generator(gen, epochs=15, steps_per_epoch=len(pairs) // n_positive)
Review comment: @victle Just checking here, h is not used as it's just a placeholder for the generator being fit? I'll add a

Reply: @andrewtavis Yep! h is just a placeholder.
    # Save the model and extract the embeddings
    model.save("./wikilinks_embedding_model.h5")

    # Extract embeddings
    book_layer = model.get_layer("book_embedding")
    book_weights = book_layer.get_weights()[0]

    # Normalize the weights to have a norm of 1
    book_weights = book_weights / np.linalg.norm(book_weights, axis=1).reshape((-1, 1))

    return book_weights
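For reference, a minimal sketch of querying the returned embeddings for similar articles (index 0 chosen arbitrarily; the top result will be the query article itself):

    book_weights = _wikilinks_nn(path_to_json="./enwiki_books.ndjson", embedding_size=50)

    # Rows are unit vectors, so dot products are cosine similarities
    sims = book_weights @ book_weights[0]
    closest = np.argsort(sims)[::-1][:10]  # the ten articles most similar to the first book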
Review comment: @victle, figured out that this here is supposed to be book_input and link_input based on comparing it all to the original blog post and your changes :) Am making good progress on this and will try to have it done by tonight/early tomorrow. Thanks again for all this!