Skip to content

Commit

Permalink
can actually train now
Browse files Browse the repository at this point in the history
  • Loading branch information
cooliotonyio committed May 22, 2019
1 parent e0223d2 commit 13633ee
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 30 deletions.
23 changes: 13 additions & 10 deletions datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ class NUS_WIDE(BaseCMRetrievalDataset):
"""

def __init__(self, root, transform, train=True, feature_mode='resnet152', word_embeddings=None):
def __init__(self, root, transform, train=True, feature_mode='vanilla', word_embeddings=None):
primary_tags = pickle.load(open("pickles/nuswide_metadata/tag_matrix.p", "rb"))

super(NUS_WIDE, self).__init__(root, transform, primary_tags=None)
Expand All @@ -183,6 +183,7 @@ def __init__(self, root, transform, train=True, feature_mode='resnet152', word_e
elif feature_mode == 'resnet18':
self.features = pickle.load(open("pickles/nuswide_features/resnet18_nuswide_feats_dict.p", "rb"))
else:
print("WARNING: NUS_WIDE dataset feature_mode is None")
self.features, self.feature_mode = None, 'vanilla'

self.word_embeddings = word_embeddings
Expand Down Expand Up @@ -217,9 +218,9 @@ def _make_secondary_tags(self):

def _make_image_paths(self, dir, train=True):
if train:
file_paths_fname = "./nuswide_metadata/Imagelist/TrainImagelist.txt"
file_paths_fname = "data/nuswide_metadata/ImageList/TrainImagelist.txt"
else:
file_paths_fname = "./nuswide_metadata/Imagelist/TestImagelist.txt"
file_paths_fname = "data/nuswide_metadata/ImageList/TestImagelist.txt"

image_paths = []

Expand All @@ -232,13 +233,13 @@ def _make_image_paths(self, dir, train=True):
return image_paths

def _make_idx_to_concept():
    """Build the index-to-concept lookup from the concept list file.

    The concept names are read from ``Concepts81.txt`` under
    ``data/nuswide_metadata/``; the path is relative, so it is resolved
    against the current working directory.

    Returns:
        The result of ``idx_maker`` for the concept list file.
        NOTE(review): presumably a mapping from row index to concept
        name -- confirm against ``idx_maker``, which is defined elsewhere.
    """
    # Single assignment: this is the only path that ever reached idx_maker.
    fname = "data/nuswide_metadata/Concepts81.txt"
    return idx_maker(fname)


def _make_idx_to_tag():
    """Build the index-to-tag lookup from the tag list file.

    The tag names are read from ``TagList1k.txt`` under
    ``data/nuswide_metadata/``; the path is relative, so it is resolved
    against the current working directory.

    Returns:
        The result of ``idx_maker`` for the tag list file.
        NOTE(review): presumably a mapping from row index to tag
        name -- confirm against ``idx_maker``, which is defined elsewhere.
    """
    # Single assignment: this is the only path that ever reached idx_maker.
    fname = "data/nuswide_metadata/TagList1k.txt"
    return idx_maker(fname)

Expand All @@ -258,7 +259,7 @@ def _make_tag_matrices(self):


def make_concept_relevancy_matrix(train=True):
path = './nuswide_metadata/TrainTestLabels/'
path = 'data/nuswide_metadata/TrainTestLabels/'
if train:
suffix_indicator = "Train.txt"
n = 161789
Expand All @@ -278,18 +279,18 @@ def make_concept_relevancy_matrix(train=True):
for idx, filename in enumerate(filenames):
with open(path + filename) as f:
content = f.readlines()
curr_column = np.array([int(i[0]) for i in content], dtype=int)
curr_column = np.array([int(i) for i in content], dtype=int)
relevancy_matrix[:, idx] = curr_column

return relevancy_matrix


def make_tag_relevancy_matrix(train=True):
if train:
path = './nuswide_metadata/Train_Tags1k.dat'
path = 'data/nuswide_metadata/Train_Tags1k.dat'
n = 161789
else:
path = './nuswide_metadata/Test_Tags1k.dat'
path = 'data/nuswide_metadata/Test_Tags1k.dat'
n = 107859

relevancy_matrix = np.zeros((n,1000), dtype=int)
Expand All @@ -315,7 +316,9 @@ def __getitem__(self, index):
target = self._folder_targets[index]

if self.feature_mode is not 'vanilla':
return index, self.features[self.image_paths[index]], self._folder_targets[index]
sample = self.features[self.image_paths[index]]
target = self._folder_targets[index]
return index, sample, target

if self.transform is not None:
return index, self.transform(sample), self._folder_targets[index]
Expand Down
10 changes: 5 additions & 5 deletions feature_extractors/resnet152_nuswide_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@
from torch.autograd import Variable
from PIL import Image

sys.path.append("{}/..".format(sys.path[0]))
from datasets import NUS_WIDE

base = "./"
base = "pickles/nuswide_features/"
if not os.path.isdir(base):
os.mkdir(os.fsencode(base))
raise RuntimeError("Base folder '{}' not found".format(base))

scaler = transforms.Resize((224,224))
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
Expand All @@ -38,14 +39,13 @@ def copy_data(m,i,o):

return embedding

dataset = NUS_WIDE('./data/Flickr', None)
dataset = NUS_WIDE('data/Flickr', None)

feature_dict = {}
feature_array = [None] * len(dataset)

for i in range(len(dataset)):
print("file: ", i)
file_path = dataset.imgs.samples[i][0]
file_path = dataset.image_paths[i]
feature_i = get_image_feature(file_path)
feature_dict[file_path] = feature_i.cpu().squeeze()
feature_array[i] = feature_i.cpu().squeeze()
Expand Down
11 changes: 6 additions & 5 deletions feature_extractors/resnet18_nuswide_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,13 @@
from torch.autograd import Variable
from PIL import Image

sys.path.append("{}/..".format(sys.path[0]))
from datasets import NUS_WIDE

base = "../pickles/nuswide_features/"
base = "pickles/nuswide_features/"
if not os.path.isdir(base):
os.mkdir(os.fsencode(base))
raise RuntimeError("Base folder '{}' not found".format(base))


scaler = transforms.Resize((224,224))
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
Expand All @@ -54,14 +56,13 @@ def copy_data(m,i,o):

return embedding

dataset = NUS_WIDE('./data/Flickr', None)
dataset = NUS_WIDE('data/Flickr', None)

feature_dict = {}
feature_array = [None] * len(dataset)

for i in range(len(dataset)):
print("file: ", i)
file_path = dataset.imgs.samples[i][0]
file_path = dataset.image_paths[i]
feature_i = get_image_feature(file_path)
feature_dict[file_path] = feature_i.cpu().squeeze()
feature_array[i] = feature_i.cpu().squeeze()
Expand Down
4 changes: 3 additions & 1 deletion feature_extractors/spatial_feature_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
from torch.autograd import Variable
from PIL import Image

base = "/pickles/nuswide_features/"

scaler = transforms.Resize((224,224))
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
to_tensor = transforms.ToTensor()
Expand Down Expand Up @@ -73,4 +75,4 @@ def get_image_feature(im_path)
avg_aggregation = torch.mean(all_frames, 0)
frame_level_features[f] = avg_aggregation

pickle.dump(frame_level_features, open('frame_level_features.p', 'wb'))
pickle.dump(frame_level_features, open(base + 'frame_level_features.p', 'wb'))
38 changes: 31 additions & 7 deletions multimodal_search_engine.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,14 @@
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Time Elapsed: 5.3674 seconds\n"
"ename": "AttributeError",
"evalue": "Can't get attribute 'InterTripletNet' on <module 'networks' from '/home/ubuntu/Notebooks/crossmodal_retrieval/networks.py'>",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-4-5de11d06b07e>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0ms\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mtext_net\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpickle\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"pickles/models/entire_nuswide_model.p\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"rb\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mget_text_embedding\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mtext_net\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtext_embedding_net\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mtext_net\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_embedding\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mget_text_embedding\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mAttributeError\u001b[0m: Can't get attribute 'InterTripletNet' on <module 'networks' from '/home/ubuntu/Notebooks/crossmodal_retrieval/networks.py'>"
]
}
],
Expand Down Expand Up @@ -109,7 +113,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Time Elapsed: 7.0157 seconds\n"
"Time Elapsed: 2.5801 seconds\n"
]
}
],
Expand Down Expand Up @@ -144,7 +148,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Time Elapsed: 65.3254 seconds\n"
"Time Elapsed: 70.4127 seconds\n"
]
}
],
Expand All @@ -165,7 +169,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Time Elapsed: 0.7416 seconds\n"
"Time Elapsed: 0.9495 seconds\n"
]
}
],
Expand All @@ -181,6 +185,26 @@
"time_elapsed(s)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"999994"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(text_data)"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down
10 changes: 8 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@
import torch
import pickle
import tarfile
import time
from zipfile import ZipFile
from util import fetch_and_cache

FORCE_DOWNLOAD = False

start_time = time.time()
print("Starting setup... This might take a while.")
print("Making directories...", end=" ")
if not os.path.isdir("./data_zipped"):
Expand Down Expand Up @@ -107,4 +108,9 @@ def load_vectors(fname):
image_data.extractall(path='./data')
print("Done extracting NUSWIDE!")

print("Finished setup!")
print("Extracting NUSWIDE ResNet features... (this will take a lot of time)")
os.system("python3 feature_extractors/resnet152_nuswide_processor.py")
os.system("python3 feature_extractors/resnet18_nuswide_processor.py")
print("Done extracting features!")

print("Finished setup in {} seconds!".format(time.time() - start_time))

0 comments on commit 13633ee

Please sign in to comment.