-
Notifications
You must be signed in to change notification settings - Fork 1
/
extract_bbox_feat.py
117 lines (98 loc) · 4.27 KB
/
extract_bbox_feat.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import json
import math
import os
import sys
from datetime import datetime
import argparse
import torch
import torch.distributed as dist
import torch.multiprocessing
import torch.multiprocessing as mp
from absl import flags
from torch.nn.parallel import DistributedDataParallel
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, DistributedSampler
from tqdm import tqdm
from config import get_default_config
from models.siamese_baseline import TwoBranchModel
from utils import TqdmToLogger, get_logger,AverageMeter,accuracy,ProgressMeter
from datasets import CityFlowNLDataset
from datasets import CityFlowNLInferenceDataset
from torch.optim.lr_scheduler import _LRScheduler
import torchvision
import time
import torch.nn.functional as F
from transformers import BertTokenizer,RobertaTokenizer
import pickle
from collections import OrderedDict
from utils import MgvSaveHelper
from PIL import Image
# Module-level helper for saving/loading checkpoints via OSS object storage;
# configured with the real bucket path inside main() via set_stauts().
ossSaver = MgvSaveHelper()
def main():
    """Extract and pickle embeddings for the AICity test split.

    Pipeline:
      1. Load test-track bounding boxes, tracks, and per-query car texts.
      2. Restore a trained ``TwoBranchModel`` from the configured checkpoint.
      3. Encode each query's car descriptions with the language branch.
      4. Crop each tracked vehicle from its frames and encode the crops
         with the visual branch.
      5. Dump both embedding dicts to ``data/*.pkl``.

    Reads:  data/test_track_bboxes.json, data/test_tracks.json,
            data/test_query_cars.json, frames under
            /data/datasets/aicity2022/track2.
    Writes: data/query_lang_embeds.pkl, data/track_car_embeds.pkl.
    Requires a CUDA device.
    """
    config_path = 'configs/two_branch_cam_loc_dir.yaml'
    with open('data/test_track_bboxes.json', 'r') as fb:
        test_bboxes = json.load(fb)
    # NOTE(review): test_tracks is loaded but never used below; kept for
    # backward compatibility in case other tooling expects the read to happen.
    with open('data/test_tracks.json', 'r') as fb:
        test_tracks = json.load(fb)
    with open('data/test_query_cars.json', 'r') as fb:
        test_query_cars = json.load(fb)

    cfg = get_default_config()
    cfg.merge_from_file(config_path)
    ossSaver.set_stauts(save_oss=True, oss_path=cfg.DATA.OSS_PATH)

    transform_test = torchvision.transforms.Compose([
        torchvision.transforms.Resize((cfg.DATA.SIZE, cfg.DATA.SIZE)),
        torchvision.transforms.ToTensor(),
        # ImageNet channel mean / std.
        torchvision.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    model = TwoBranchModel(cfg.MODEL)
    checkpoint = ossSaver.load_pth(ossSaver.get_s3_path(cfg.TEST.RESTORE_FROM))
    # The checkpoint was saved from a DistributedDataParallel wrapper; strip
    # the leading `module.` (7 chars) so keys match the bare model.
    new_state_dict = OrderedDict()
    for key, value in checkpoint['state_dict'].items():
        new_state_dict[key[7:]] = value
    model.load_state_dict(new_state_dict, strict=False)
    model.cuda()
    torch.backends.cudnn.benchmark = True

    test_data = CityFlowNLInferenceDataset(cfg.DATA, transform=transform_test)
    testloader = DataLoader(dataset=test_data, batch_size=cfg.TEST.BATCH_SIZE, shuffle=False, num_workers=8)

    if cfg.MODEL.BERT_TYPE == "BERT":
        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    elif cfg.MODEL.BERT_TYPE == "ROBERTA":
        tokenizer = RobertaTokenizer.from_pretrained(cfg.MODEL.BERT_NAME)
    else:
        # Fix: previously an unrecognized BERT_TYPE left `tokenizer`
        # undefined and crashed later with a NameError; fail fast instead.
        raise ValueError(f"Unsupported cfg.MODEL.BERT_TYPE: {cfg.MODEL.BERT_TYPE!r}")

    model.eval()

    # --- Language branch: embed every car description per query. ---
    query_car_embeds = dict()
    with torch.no_grad():
        for query_id, texts in tqdm(test_query_cars.items()):
            car_tokens = tokenizer.batch_encode_plus(texts, padding='longest', return_tensors='pt')
            car_embeds = model.get_car_lang_embed(car_tokens['input_ids'].cuda(),
                                                  car_tokens['attention_mask'].cuda())
            query_car_embeds[query_id] = car_embeds.data.cpu().numpy()

    # --- Visual branch: embed vehicle crops for every test track. ---
    track_car_embeds = dict()
    with torch.no_grad():
        for track_id in tqdm(test_bboxes.keys()):
            bboxes = test_bboxes[track_id]
            crops = []
            for frame, list_of_boxes in bboxes.items():
                frame_path = os.path.join('/data/datasets/aicity2022/track2', frame)
                # When a frame carries multiple boxes the first one is skipped
                # (preserved from the original; presumably the first entry is
                # redundant — TODO confirm against the bbox-file format).
                boxes = list_of_boxes if len(list_of_boxes) == 1 else list_of_boxes[1:]
                # Context manager closes the file handle promptly; Image.open
                # is lazy and otherwise leaks descriptors across many frames.
                with Image.open(frame_path) as image:
                    for box in boxes:
                        # box is [x, y, w, h]; crop() wants (left, top, right, bottom).
                        crop = image.crop((box[0], box[1], box[0] + box[2], box[1] + box[3]))
                        crops.append(transform_test(crop))
            crops = torch.stack(crops).cuda()
            vis_embeds = model.get_car_vis_embed(crops)
            track_car_embeds[track_id] = vis_embeds.data.cpu().numpy()

    with open('data/query_lang_embeds.pkl', 'wb') as fb:
        pickle.dump(query_car_embeds, fb)
    with open('data/track_car_embeds.pkl', 'wb') as fb:
        pickle.dump(track_car_embeds, fb)
# Script entry point: run the extraction pipeline only when executed directly.
if __name__ == '__main__':
    main()