11
11
#include " YOLOv5Detector.h"
12
12
13
13
#include " FeatureTensor.h"
14
- #include " tracker.h"
14
+ #include " BYTETracker.h" // bytetrack
15
+ #include " tracker.h" // deepsort
15
16
// Deep SORT parameter
16
17
17
18
const int nn_budget=100 ;
@@ -27,12 +28,94 @@ void get_detections(DETECTBOX box,float confidence,DETECTIONS& d)
27
28
}
28
29
29
30
31
+ void test_deepsort (cv::Mat& frame, std::vector<detect_result>& results,tracker& mytracker)
32
+ {
33
+ std::vector<detect_result> objects;
34
+
35
+ DETECTIONS detections;
36
+ for (detect_result dr : results)
37
+ {
38
+ // cv::putText(frame, classes[dr.classId], cv::Point(dr.box.tl().x+10, dr.box.tl().y - 10), cv::FONT_HERSHEY_SIMPLEX, .8, cv::Scalar(0, 255, 0));
39
+ if (dr.classId == 0 ) // person
40
+ {
41
+ objects.push_back (dr);
42
+ cv::rectangle (frame, dr.box , cv::Scalar (255 , 0 , 0 ), 2 );
43
+ get_detections (DETECTBOX (dr.box .x , dr.box .y ,dr.box .width , dr.box .height ),dr.confidence , detections);
44
+ }
45
+ }
30
46
47
+ std::cout<<" begin track" <<std::endl;
48
+ if (FeatureTensor::getInstance ()->getRectsFeature (frame, detections))
49
+ {
50
+ std::cout << " get feature succeed!" <<std::endl;
51
+ mytracker.predict ();
52
+ mytracker.update (detections);
53
+ std::vector<RESULT_DATA> result;
54
+ for (Track& track : mytracker.tracks ) {
55
+ if (!track.is_confirmed () || track.time_since_update > 1 ) continue ;
56
+ result.push_back (std::make_pair (track.track_id , track.to_tlwh ()));
57
+ }
58
+ for (unsigned int k = 0 ; k < detections.size (); k++)
59
+ {
60
+ DETECTBOX tmpbox = detections[k].tlwh ;
61
+ cv::Rect rect (tmpbox (0 ), tmpbox (1 ), tmpbox (2 ), tmpbox (3 ));
62
+ cv::rectangle (frame, rect, cv::Scalar (0 ,0 ,255 ), 4 );
63
+ // cvScalar的储存顺序是B-G-R,CV_RGB的储存顺序是R-G-B
31
64
65
+ for (unsigned int k = 0 ; k < result.size (); k++)
66
+ {
67
+ DETECTBOX tmp = result[k].second ;
68
+ cv::Rect rect = cv::Rect (tmp (0 ), tmp (1 ), tmp (2 ), tmp (3 ));
69
+ rectangle (frame, rect, cv::Scalar (255 , 255 , 0 ), 2 );
70
+
71
+ std::string label = cv::format (" %d" , result[k].first );
72
+ cv::putText (frame, label, cv::Point (rect.x , rect.y ), cv::FONT_HERSHEY_SIMPLEX, 0.8 , cv::Scalar (255 , 255 , 0 ), 2 );
73
+ }
74
+ }
75
+ }
76
+ std::cout<<" end track" <<std::endl;
77
+ }
78
+
79
+
80
+ void test_bytetrack (cv::Mat& frame, std::vector<detect_result>& results,BYTETracker& tracker)
81
+ {
82
+ std::vector<detect_result> objects;
83
+
84
+
85
+ for (detect_result dr : results)
86
+ {
87
+
88
+ if (dr.classId == 0 ) // person
89
+ {
90
+ objects.push_back (dr);
91
+ }
92
+ }
93
+
94
+
95
+ std::vector<STrack> output_stracks = tracker.update (objects);
96
+
97
+ for (unsigned long i = 0 ; i < output_stracks.size (); i++)
98
+ {
99
+ std::vector<float > tlwh = output_stracks[i].tlwh ;
100
+ bool vertical = tlwh[2 ] / tlwh[3 ] > 1.6 ;
101
+ if (tlwh[2 ] * tlwh[3 ] > 20 && !vertical)
102
+ {
103
+ cv::Scalar s = tracker.get_color (output_stracks[i].track_id );
104
+ cv::putText (frame, cv::format (" %d" , output_stracks[i].track_id ), cv::Point (tlwh[0 ], tlwh[1 ] - 5 ),
105
+ 0 , 0.6 , cv::Scalar (0 , 0 , 255 ), 2 , cv::LINE_AA);
106
+ cv::rectangle (frame, cv::Rect (tlwh[0 ], tlwh[1 ], tlwh[2 ], tlwh[3 ]), s, 2 );
107
+ }
108
+ }
109
+
110
+
111
+ }
32
112
int main (int argc, char *argv[])
33
113
{
34
- // deep SORT
114
+ // deepsort
35
115
tracker mytracker (max_cosine_distance, nn_budget);
116
+ // bytetrack
117
+ int fps=20 ;
118
+ BYTETracker bytetracker (fps, 30 );
36
119
// -----------------------------------------------------------------------
37
120
// 加载类别名称
38
121
std::vector<std::string> classes;
@@ -85,50 +168,9 @@ int main(int argc, char *argv[])
85
168
auto detect_time =std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count ();// ms
86
169
std::cout<<classes.size ()<<" :" <<results.size ()<<" :" <<num_frames<<std::endl;
87
170
88
- std::vector<detect_result> objects;
89
-
90
- DETECTIONS detections;
91
- for (detect_result dr : results)
92
- {
93
- // cv::putText(frame, classes[dr.classId], cv::Point(dr.box.tl().x+10, dr.box.tl().y - 10), cv::FONT_HERSHEY_SIMPLEX, .8, cv::Scalar(0, 255, 0));
94
- if (dr.classId == 0 ) // person
95
- {
96
- objects.push_back (dr);
97
- cv::rectangle (frame, dr.box , cv::Scalar (255 , 0 , 0 ), 2 );
98
- get_detections (DETECTBOX (dr.box .x , dr.box .y ,dr.box .width , dr.box .height ),dr.confidence , detections);
99
- }
100
- }
101
171
102
- std::cout<<" begin track" <<std::endl;
103
- if (FeatureTensor::getInstance ()->getRectsFeature (frame, detections))
104
- {
105
- std::cout << " get feature succeed!" <<std::endl;
106
- mytracker.predict ();
107
- mytracker.update (detections);
108
- std::vector<RESULT_DATA> result;
109
- for (Track& track : mytracker.tracks ) {
110
- if (!track.is_confirmed () || track.time_since_update > 1 ) continue ;
111
- result.push_back (std::make_pair (track.track_id , track.to_tlwh ()));
112
- }
113
- for (unsigned int k = 0 ; k < detections.size (); k++)
114
- {
115
- DETECTBOX tmpbox = detections[k].tlwh ;
116
- cv::Rect rect (tmpbox (0 ), tmpbox (1 ), tmpbox (2 ), tmpbox (3 ));
117
- cv::rectangle (frame, rect, cv::Scalar (0 ,0 ,255 ), 4 );
118
- // cvScalar的储存顺序是B-G-R,CV_RGB的储存顺序是R-G-B
119
-
120
- for (unsigned int k = 0 ; k < result.size (); k++)
121
- {
122
- DETECTBOX tmp = result[k].second ;
123
- cv::Rect rect = cv::Rect (tmp (0 ), tmp (1 ), tmp (2 ), tmp (3 ));
124
- rectangle (frame, rect, cv::Scalar (255 , 255 , 0 ), 2 );
125
-
126
- std::string label = cv::format (" %d" , result[k].first );
127
- cv::putText (frame, label, cv::Point (rect.x , rect.y ), cv::FONT_HERSHEY_SIMPLEX, 0.8 , cv::Scalar (255 , 255 , 0 ), 2 );
128
- }
129
- }
130
- }
131
- std::cout<<" end track" <<std::endl;
172
+ // test_deepsort(frame, results,mytracker);
173
+ test_bytetrack (frame, results,bytetracker);
132
174
133
175
cv::imshow (" YOLOv5-6.x" , frame);
134
176
0 commit comments