-
Notifications
You must be signed in to change notification settings - Fork 1
/
extract_feature.py
101 lines (83 loc) · 3.63 KB
/
extract_feature.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import os
import argparse
import json
from easydict import EasyDict as edict
from tqdm import tqdm
import random
import numpy as np
import torch
from models.cnn3d import Encoder
from data.copd_patch import COPD_dataset
# Command-line interface: select the experiment directory and the patch-encoder
# checkpoint whose learned representations will be extracted.
parser = argparse.ArgumentParser(description='Extract 3D Images Representations')
parser.add_argument('--exp-name', default='./ssl_exp/exp_neighbor_0_128')  # experiment dir containing configs.json and checkpoints
parser.add_argument('--checkpoint-patch', default='checkpoint_patch_0001.pth.tar')  # checkpoint filename inside --exp-name
parser.add_argument('--batch-size', type=int, default=1)  # stored on args; NOTE(review): the test DataLoader below hard-codes batch_size=1 — confirm intent
def main():
    """Merge CLI flags with the experiment's saved configuration, seed the
    RNGs for reproducibility, and hand off to :func:`main_worker`.

    Reads ``configs.json`` from the experiment directory and augments it with
    the checkpoint path, batch size, and an output directory derived from the
    checkpoint's 4-digit epoch tag.
    """
    cli = parser.parse_args()
    # Checkpoint names end with a 4-digit epoch, e.g. 'checkpoint_patch_0001.pth.tar'
    # -> epoch tag '0001' names the output subdirectory.
    epoch_tag = cli.checkpoint_patch.split('.')[0][-4:]
    config_path = os.path.join(cli.exp_name, 'configs.json')
    with open(config_path) as cfg_file:
        args = edict(json.load(cfg_file))
    args.checkpoint = os.path.join(cli.exp_name, cli.checkpoint_patch)
    args.batch_size = cli.batch_size
    args.patch_rep_dir = os.path.join(cli.exp_name, 'patch_rep', epoch_tag)
    os.makedirs(args.patch_rep_dir, exist_ok=True)
    # Seed Python and torch RNGs when the experiment config provides a seed.
    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
    torch.backends.cudnn.benchmark = True
    main_worker(args)
def main_worker(args):
    """Extract per-patch representations for the full test set.

    Loads the pretrained patch-level encoder from ``args.checkpoint`` (keeping
    only the MoCo query-encoder weights), runs it over every subject in the
    'testing' split, and saves three .npy files into ``args.patch_rep_dir``:
    subject IDs, the per-patch representation array, and the label/feature
    array. Requires CUDA (model is wrapped in DataParallel and moved to GPU).
    """
    #args.gpu = 0
    #torch.cuda.set_device(args.gpu)
    # create patch-level encoder
    model_patch = Encoder(rep_dim=args.rep_dim_patch, moco_dim=args.moco_dim_patch, num_experts=args.num_experts, num_coordinates=args.num_coordinates)
    # remove the last FC layer
    model_patch.fc = torch.nn.Sequential()
    state_dict = torch.load(args.checkpoint)['state_dict']
    # Rewrite checkpoint keys: keep only the query encoder ('module.encoder_q.*',
    # minus its FC head), stripping the DataParallel/MoCo prefix; everything else
    # is dropped so load_state_dict sees exactly the encoder's parameter names.
    for k in list(state_dict.keys()):
        # retain only encoder_q
        if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'):
            # remove prefix
            state_dict[k[len("module.encoder_q."):]] = state_dict[k]
        # delete renamed or unused k
        del state_dict[k]
    model_patch.load_state_dict(state_dict)
    print(model_patch)
    print("Patch model weights loaded.")
    #model_patch.cuda()
    model_patch = torch.nn.DataParallel(model_patch).cuda()
    model_patch.eval()
    # dataset
    # NOTE(review): batch_size is hard-coded to 1 here even though main() stores
    # args.batch_size — the per-subject indexing below (pred_arr[i], images[0])
    # relies on exactly one subject per batch; confirm before changing.
    test_dataset_patch = COPD_dataset('testing', args)
    test_loader = torch.utils.data.DataLoader(
        test_dataset_patch, batch_size=1, shuffle=False,
        num_workers=10, pin_memory=True, drop_last=False)
    # Combined label list: the feature array columns follow this concatenation
    # plus the visual-score and P2 phenotype columns.
    args.label_name = args.label_name + args.label_name_set2
    # train dataset
    sid_lst = []
    # pred_arr: (num_subjects, num_patch, rep_dim_patch) patch representations.
    pred_arr = np.empty((len(test_dataset_patch), args.num_patch, args.rep_dim_patch))
    # feature_arr: per-subject labels/scores/phenotypes, one row per subject.
    feature_arr = np.empty(
        (len(test_dataset_patch), len(args.label_name) + len(args.visual_score) + len(args.P2_Pheno)))
    iterator = tqdm(test_loader,
                    desc="Propagating (X / X Steps)",
                    bar_format="{r_bar}",
                    dynamic_ncols=True,
                    disable=False)
    # Inference only: no gradients needed while filling the output arrays.
    with torch.no_grad():
        for i, batch in enumerate(iterator):
            # Each batch holds one subject: its ID, all patches, patch
            # locations, and labels (leading dim of size 1 stripped via [0]).
            sid, images, patch_loc_idx, labels = batch
            sid_lst.append(sid[0])
            images = images[0].float().cuda()
            patch_loc_idx = patch_loc_idx[0].float().cuda()
            # Encoder returns a pair; the second element is the representation
            # used here — presumably (features, projection); confirm in Encoder.
            _, pred = model_patch(images, patch_loc_idx)
            pred_arr[i, :, :] = pred.cpu().numpy()
            feature_arr[i:i+1, :] = labels
            iterator.set_description("Propagating (%d / %d Steps)" % (i, len(test_dataset_patch)))
    np.save(os.path.join(args.patch_rep_dir, "sid_arr_full.npy"), sid_lst)
    np.save(os.path.join(args.patch_rep_dir, "pred_arr_patch_full.npy"), pred_arr)
    np.save(os.path.join(args.patch_rep_dir, "feature_arr_patch_full.npy"), feature_arr)
    print("\nExtraction patch representation on full set finished.")
# Script entry point: parse CLI args and run the extraction.
if __name__ == '__main__':
    main()