test.py
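"""Evaluate a trained Pose2Seg model.

Runs the network with ground-truth keypoints as the pose input on the
COCOPersons val split and/or the OCHuman val and test splits, then reports
COCO-style segmentation AP/AR via pycocotools.
"""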
import argparse

import numpy as np
from tqdm import tqdm

from modeling.build_model import Pose2Seg
from datasets.CocoDatasetInfo import CocoDatasetInfo, annToMask
from pycocotools import mask as maskUtils


def test(model, dataset='cocoVal', logger=print):
    if dataset == 'OCHumanVal':
        ImageRoot = './data/OCHuman/images'
        AnnoFile = './data/OCHuman/annotations/ochuman_coco_format_val_range_0.00_1.00.json'
    elif dataset == 'OCHumanTest':
        ImageRoot = './data/OCHuman/images'
        AnnoFile = './data/OCHuman/annotations/ochuman_coco_format_test_range_0.00_1.00.json'
    elif dataset == 'cocoVal':
        ImageRoot = './data/coco2017/val2017'
        AnnoFile = './data/coco2017/annotations/person_keypoints_val2017_pose2seg.json'
    datainfos = CocoDatasetInfo(ImageRoot, AnnoFile, onlyperson=True, loadimg=True)

    model.eval()

    results_segm = []
    imgIds = []
    for i in tqdm(range(len(datainfos))):
        rawdata = datainfos[i]
        img = rawdata['data']
        image_id = rawdata['id']

        height, width = img.shape[0:2]
        gt_kpts = np.float32(rawdata['gt_keypoints']).transpose(0, 2, 1)  # (N, 17, 3)
        gt_segms = rawdata['segms']
        gt_masks = np.array([annToMask(segm, height, width) for segm in gt_segms])

        # Run inference using the ground-truth keypoints as the pose input.
        output = model([img], [gt_kpts], [gt_masks])

        # Encode each predicted instance mask as COCO RLE for evaluation.
        for mask in output[0]:
            maskencode = maskUtils.encode(np.asfortranarray(mask))
            maskencode['counts'] = maskencode['counts'].decode('ascii')
            results_segm.append({
                "image_id": image_id,
                "category_id": 1,
                "score": 1.0,
                "segmentation": maskencode
            })
        imgIds.append(image_id)

    def do_eval_coco(image_ids, coco, results, flag):
        from pycocotools.cocoeval import COCOeval
        assert flag in ['bbox', 'segm', 'keypoints']
        # Evaluate
        coco_results = coco.loadRes(results)
        cocoEval = COCOeval(coco, coco_results, flag)
        cocoEval.params.imgIds = image_ids
        cocoEval.params.catIds = [1]
        cocoEval.evaluate()
        cocoEval.accumulate()
        cocoEval.summarize()
        return cocoEval

    cocoEval = do_eval_coco(imgIds, datainfos.COCO, results_segm, 'segm')
    logger('[POSE2SEG] AP|.5|.75| S| M| L| AR|.5|.75| S| M| L|')
    _str = '[segm_score] %s ' % dataset
    for value in cocoEval.stats.tolist():
        _str += '%.3f ' % value
    logger(_str)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Pose2Seg Testing")
    parser.add_argument(
        "--weights",
        help="path to the .pkl model weights file",
        type=str,
    )
    parser.add_argument(
        "--coco",
        help="run evaluation on the COCOPersons val set",
        action="store_true",
    )
    parser.add_argument(
        "--OCHuman",
        help="run evaluation on the OCHuman val and test sets",
        action="store_true",
    )
    args = parser.parse_args()

    print('===========> loading model <===========')
    model = Pose2Seg().cuda()
    model.init(args.weights)

    print('===========> testing <===========')
    if args.coco:
        test(model, dataset='cocoVal')
    if args.OCHuman:
        test(model, dataset='OCHumanVal')
        test(model, dataset='OCHumanTest')
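
# Example invocation (the checkpoint filename below is illustrative; point
# --weights at whatever Pose2Seg weights file you actually have):
#   python test.py --weights pose2seg_release.pkl --coco --OCHuman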