forked from lx-cly/YOLOv5_OBB_KLD
-
Notifications
You must be signed in to change notification settings - Fork 0
/
detect.py
350 lines (309 loc) · 18.5 KB
/
detect.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
import argparse
import os
import platform
import shutil
import time
from pathlib import Path
import cv2
import torch
import torch.backends.cudnn as cudnn
from numpy import random
from models.experimental import attempt_load
from utils.datasets import LoadStreams, LoadImages,LoadImagesEn
from utils.general import (
check_img_size, non_max_suppression, apply_classifier, scale_labels,
xyxy2xywh, plot_one_rotated_box, strip_optimizer, set_logging, rotate_non_max_suppression)
from utils.torch_utils import select_device, load_classifier, time_synchronized
from utils.evaluation_utils import rbox2txt
def detect(save_img=False, kld_flag=False):
    """Run rotated-box (OBB) inference on the configured source.

    All settings are read from the module-level ``opt`` namespace built by the
    ``__main__`` section: output, source, weights, view_img, save_txt,
    img_size, device, augment, conf_thres, iou_thres, classes, agnostic_nms.

    For every kept detection this calls ``rbox2txt`` with the directory
    ``<output>/result_txt/result_before_merge`` and draws the rotated box on
    the frame.  Annotated images / videos are written into ``<output>``.

    Args:
        save_img: save annotated results.  Forced to True for non-stream
            sources.
        kld_flag: forwarded to ``rotate_non_max_suppression``.

    Note:
        The output folder is deleted and recreated on every call.
        Pressing 'q' in the preview window stops processing cleanly.
    """
    # Output folder, input source, weights and other options.
    out, source, weights, view_img, save_txt, imgsz = \
        opt.output, opt.source, opt.weights, opt.view_img, opt.save_txt, opt.img_size
    webcam = (source.isnumeric()
              or source.startswith(('rtsp://', 'rtmp://', 'http://', 'https://'))
              or source.endswith('.txt'))

    # Initialize
    set_logging()
    device = select_device(opt.device)
    # Start from a clean output folder: anything already there is removed.
    if os.path.exists(out):
        shutil.rmtree(out)  # delete output folder
    os.makedirs(out)  # make new output folder
    half = device.type != 'cpu'  # half precision only supported on CUDA

    # Load the FP32 model; check_img_size adjusts imgsz against the model's
    # largest stride.
    model = attempt_load(weights, map_location=device)  # load FP32 model
    imgsz = check_img_size(imgsz, s=model.stride.max())  # check img_size
    if half:
        model.half()  # to FP16

    # Second-stage classifier (disabled)
    classify = False
    if classify:
        modelc = load_classifier(name='resnet101', n=2)  # initialize
        modelc.load_state_dict(torch.load('weights/resnet101.pt', map_location=device)['model'])  # load weights
        modelc.to(device).eval()

    # Set Dataloader -- the loader depends on the kind of input source.
    # ManyPi selects batched image loading; it is hard-coded off.
    vid_path, vid_writer, ManyPi = None, None, False
    if webcam:
        view_img = True
        cudnn.benchmark = True  # set True to speed up constant image size inference
        dataset = LoadStreams(source, img_size=imgsz)
    else:
        save_img = True
        if ManyPi:  # batch_size > 1
            dataset = LoadImagesEn(source, img_size=imgsz)
            data_loader = torch.utils.data.DataLoader(dataset,
                                                      batch_size=16,
                                                      shuffle=False,
                                                      num_workers=8)
        else:
            dataset = LoadImages(source, img_size=imgsz)

    # Class names and one random 3-channel colour per class for drawing.
    names = model.module.names if hasattr(model, 'module') else model.names
    colors = [[random.randint(0, 255) for _ in range(3)] for _ in range(len(names))]

    # Run inference
    t0 = time.time()
    # Warm-up forward pass with a dummy (1, 3, imgsz, imgsz) tensor (GPU only).
    img = torch.zeros((1, 3, imgsz, imgsz), device=device)  # init img
    _ = model(img.half() if half else img) if device.type != 'cpu' else None  # run once

    # Values yielded by the loaders (per the original author's notes):
    #   path    - image/video path
    #   img     - resized + padded network input, (3, h, w) or batched
    #   im0s    - original-size image(s), (h0, w0, 3)
    #   vid_cap - video capture when reading a video, None for images
    # model(img)[0] is (batch_size, num_boxes, no); per box [0:4] is xywh
    # (centre + size), [4] objectness, [5:5+nc] class scores, [5+nc:] angle
    # (theta) classification scores.  rotate_non_max_suppression returns one
    # tensor per image with rows [x, y, l, s, theta, conf, cls], theta in [0, 179].
    if ManyPi:
        vid_cap = None
        for path, img, im0s in data_loader:
            print(img.shape[1:])
            img = img.to(device)
            img = img.half() if half else img.float()  # uint8 to fp16/32
            img /= 255.0  # 0 - 255 to 0.0 - 1.0
            if img.ndimension() == 3:
                img = img.unsqueeze(0)  # add the batch dimension

            # Inference
            t1 = time_synchronized()
            pred = model(img, augment=opt.augment)[0]
            # Rotated NMS; kld_flag is forwarded exactly as in the single-image path.
            pred = rotate_non_max_suppression(pred, opt.conf_thres, opt.iou_thres, classes=opt.classes,
                                              agnostic=opt.agnostic_nms, without_iouthres=False,
                                              kld_flag=kld_flag)
            t2 = time_synchronized()

            # Apply Classifier
            if classify:
                pred = apply_classifier(pred, modelc, img, im0s)

            # Process detections: i is the index inside the batch.
            for i, det in enumerate(pred):
                if webcam:  # batch_size >= 1
                    p, s, im0 = path[i], '%g: ' % i, im0s[i].copy()
                else:
                    p, s, im0 = path[i], '%g: ' % i, im0s[i]
                s += '%gx%g ' % img.shape[2:]  # print string
                if det is not None and len(det):
                    # Rescale boxes from img_size to im0 size
                    det[:, :5] = scale_labels(img.shape[2:], det[:, :5], im0.shape).round()
                    # Per-class counts for the summary string.
                    for c in det[:, -1].unique():
                        n = (det[:, -1] == c).sum()  # detections per class
                        s += '%g %ss, ' % (n, names[int(c)])  # add to string
                    # Write results; rbox holds the five box values (x, y, l, s, theta).
                    for *rbox, conf, cls in reversed(det):
                        if save_img or view_img:  # Add bbox to image
                            label = '%s %.2f' % (names[int(cls)], conf)
                            classname = '%s' % names[int(cls)]
                            conf_str = '%.3f' % conf
                            rbox2txt(rbox, classname, conf_str, Path(p).stem,
                                     str(out + '/result_txt/result_before_merge'))
                            plot_one_rotated_box(rbox, im0.cpu().numpy(), label=label, color=colors[int(cls)],
                                                 line_thickness=1,
                                                 pi_format=False)
    else:
        quit_requested = False
        try:
            for path, img, im0s, vid_cap in dataset:
                print(img.shape)
                img = torch.from_numpy(img).to(device)
                img = img.half() if half else img.float()  # uint8 to fp16/32
                img /= 255.0  # 0 - 255 to 0.0 - 1.0
                if img.ndimension() == 3:
                    img = img.unsqueeze(0)  # add the batch dimension

                # Inference
                t1 = time_synchronized()
                pred = model(img, augment=opt.augment)[0]
                # Rotated NMS
                pred = rotate_non_max_suppression(pred, opt.conf_thres, opt.iou_thres, classes=opt.classes,
                                                  agnostic=opt.agnostic_nms, without_iouthres=False,
                                                  kld_flag=kld_flag)
                t2 = time_synchronized()

                # Apply Classifier
                if classify:
                    pred = apply_classifier(pred, modelc, img, im0s)

                # Process detections
                for i, det in enumerate(pred):
                    if webcam:  # batch_size >= 1
                        p, s, im0 = path[i], '%g: ' % i, im0s[i].copy()
                    else:
                        p, s, im0 = path, '', im0s
                    save_path = str(Path(out) / Path(p).name)  # output path + file name
                    s += '%gx%g ' % img.shape[2:]  # print string
                    if det is not None and len(det):
                        # Rescale boxes from img_size to im0 size
                        det[:, :5] = scale_labels(img.shape[2:], det[:, :5], im0.shape).round()
                        # Per-class counts for the summary string.
                        for c in det[:, -1].unique():
                            n = (det[:, -1] == c).sum()  # detections per class
                            s += '%g %ss, ' % (n, names[int(c)])  # add to string
                        # Write results; rbox holds the five box values (x, y, l, s, theta).
                        # NOTE: --save-txt does not gate this; rbox2txt runs whenever
                        # save_img or view_img is set.
                        for *rbox, conf, cls in reversed(det):
                            if save_img or view_img:  # Add bbox to image
                                label = '%s %.2f' % (names[int(cls)], conf)
                                classname = '%s' % names[int(cls)]
                                conf_str = '%.3f' % conf
                                rbox2txt(rbox, classname, conf_str, Path(p).stem,
                                         str(out + '/result_txt/result_before_merge'))
                                plot_one_rotated_box(rbox, im0, label=label, color=colors[int(cls)],
                                                     line_thickness=1,
                                                     pi_format=False)

                    # Print time (inference + NMS)
                    print('%sDone. (%.3fs)' % (s, t2 - t1))

                    # Stream results
                    if view_img:
                        cv2.imshow(p, im0)
                        if cv2.waitKey(1) == ord('q'):  # q to quit
                            # Stop cleanly instead of raising StopIteration, which
                            # escaped this (non-generator) function as a crash.
                            quit_requested = True
                            break

                    # Save results (image with detections)
                    if save_img:
                        if dataset.mode == 'images':
                            cv2.imwrite(save_path, im0)
                        else:
                            if vid_path != save_path:  # new video
                                vid_path = save_path
                                if isinstance(vid_writer, cv2.VideoWriter):
                                    vid_writer.release()  # release previous video writer
                                fourcc = 'mp4v'  # output video codec
                                fps = vid_cap.get(cv2.CAP_PROP_FPS)
                                w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                                h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                                vid_writer = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*fourcc), fps, (w, h))
                            vid_writer.write(im0)
                if quit_requested:
                    break
        finally:
            # Finalize the last video file; previously it was never released.
            if isinstance(vid_writer, cv2.VideoWriter):
                vid_writer.release()

    if save_txt or save_img:
        print(' Results saved to %s' % Path(out))
    print(' All Done. (%.3fs)' % (time.time() - t0))
def str2bool(value):
    """Convert a command-line token to a bool for argparse ``type=``.

    argparse's ``type=bool`` calls ``bool(text)``, which is True for every
    non-empty string, so ``--kld False`` used to *enable* KLD.  This converter
    interprets the text instead.

    Args:
        value: a bool (returned unchanged) or a string such as 'true'/'false',
            'yes'/'no', 'y'/'n', 't'/'f', '1'/'0', 'on'/'off' (case-insensitive).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', 'off'):
        return False
    raise argparse.ArgumentTypeError('boolean value expected, got %r' % (value,))


if __name__ == '__main__':
    # Command-line options:
    #   weights:      trained model weights
    #   source:       input data -- image/video path, '0' for the built-in
    #                 webcam, or a video stream such as rtsp
    #   output:       folder where predicted images/videos are saved
    #   img-size:     network input image size
    #   conf-thres:   confidence threshold
    #   iou-thres:    IoU threshold used by NMS
    #   device:       device to run on
    #   view-img:     display the predicted images/videos (default False)
    #   save-txt:     intended to save box coordinates as txt (default False);
    #                 detect() currently only uses it for the final message
    #   classes:      keep only these classes, e.g. 0 or 0 2 3
    #   agnostic-nms: treat boxes of all classes alike during NMS (default False)
    #   augment:      test-time augmentation (multi-scale, flips) at inference
    #   update:       run strip_optimizer on all models to remove optimizer
    #                 state from the .pt files (default False)
    #   kld:          passed to detect() as kld_flag; accepts true/false,
    #                 or a bare --kld meaning True
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights', nargs='+', type=str, default='./runs/exp3/weights/last.pt', help='model.pt path(s)')
    parser.add_argument('--source', type=str, default='/data_all/data/DOTA/Ship_dota_v1.5_1024/val/images', help='source')  # file/folder, 0 for webcam
    parser.add_argument('--output', type=str, default='./detectionlx', help='output folder')  # output folder
    parser.add_argument('--img-size', type=int, default=1024, help='inference size (pixels)')
    parser.add_argument('--conf-thres', type=float, default=0.25, help='object confidence threshold')
    parser.add_argument('--iou-thres', type=float, default=0.4, help='IOU threshold for NMS')
    parser.add_argument('--device', default='0', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--view-img', action='store_true', help='display results')
    parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
    parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --class 0, or --class 0 2 3')
    parser.add_argument('--agnostic-nms', action='store_true', default=False, help='class-agnostic NMS')
    parser.add_argument('--augment', action='store_true', help='augmented inference')
    parser.add_argument('--update', action='store_true', help='update all models')
    # type=bool would treat any non-empty string (including "False") as True.
    parser.add_argument('--kld', type=str2bool, nargs='?', const=True, default=True, help='use kld')
    opt = parser.parse_args()
    print(opt)

    with torch.no_grad():
        if opt.update:  # update all models (to fix SourceChangeWarning)
            for opt.weights in ['yolov5s.pt', 'yolov5m.pt', 'yolov5l.pt', 'yolov5x.pt']:
                detect(kld_flag=opt.kld)
                # Strip optimizer and other training state from the .pt file.
                strip_optimizer(opt.weights)
        else:
            detect(kld_flag=opt.kld)