Fix the TensorFlow prediction (using keras.models.load_model instead of the serving signature) #77

Open · wants to merge 2 commits into master
object_tracker.py (71 changes: 48 additions & 23 deletions)
@@ -1,8 +1,11 @@
import os

# comment out below line to enable tensorflow logging outputs
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import time
import tensorflow as tf
+import tensorflow.keras as keras

physical_devices = tf.config.experimental.list_physical_devices('GPU')
if len(physical_devices) > 0:
    tf.config.experimental.set_memory_growth(physical_devices[0], True)
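Note on the context above: set_memory_growth tells TensorFlow to claim GPU memory on demand instead of reserving the whole card at start-up, which helps when the YOLOv4 detector and the DeepSORT feature encoder share one GPU. A minimal standalone sketch of the same guard, generalized to every visible GPU (the loop is an illustration, not part of this diff):

import tensorflow as tf

# grow GPU memory on demand rather than pre-allocating it all
for gpu in tf.config.experimental.list_physical_devices('GPU'):
    tf.config.experimental.set_memory_growth(gpu, True)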
@@ -23,6 +26,7 @@
from deep_sort.detection import Detection
from deep_sort.tracker import Tracker
from tools import generate_detections as gdet

flags.DEFINE_string('framework', 'tf', '(tf, tflite, trt)')
flags.DEFINE_string('weights', './checkpoints/yolov4-416',
                    'path to weights file')
@@ -38,12 +42,13 @@
flags.DEFINE_boolean('info', False, 'show detailed info of tracked objects')
flags.DEFINE_boolean('count', False, 'count objects being tracked on screen')


def main(_argv):
    # Definition of the parameters
    max_cosine_distance = 0.4
    nn_budget = None
    nms_max_overlap = 1.0

    # initialize deep sort
    model_filename = 'model_data/mars-small128.pb'
    encoder = gdet.create_box_encoder(model_filename, batch_size=1)
@@ -70,8 +75,10 @@ def main(_argv):
        print(output_details)
    # otherwise load standard tensorflow saved model
    else:
-        saved_model_loaded = tf.saved_model.load(FLAGS.weights, tags=[tag_constants.SERVING])
-        infer = saved_model_loaded.signatures['serving_default']
+        # -- Dac: 15/06/2021
+        # saved_model_loaded = tf.saved_model.load(FLAGS.weights, tags=[tag_constants.SERVING])
+        # infer = saved_model_loaded.signatures['serving_default']
+        infer = keras.models.load_model(FLAGS.weights)

    # begin video capture
    try:
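For context on why the swap works: tf.saved_model.load returns a loader object whose serving_default signature is called like a function and yields a dict of named output tensors, while keras.models.load_model returns a Model whose predict yields plain NumPy arrays. A minimal sketch of the two call paths, assuming the checkpoint directory was exported in a format Keras can read (path and input shape are illustrative):

import tensorflow as tf
import tensorflow.keras as keras

dummy = tf.ones((1, 416, 416, 3), dtype=tf.float32)

# SavedModel route: a signature function returning {output_name: tf.Tensor}
saved = tf.saved_model.load('./checkpoints/yolov4-416', tags=['serve'])
infer = saved.signatures['serving_default']
outputs = infer(dummy)  # dict of tensors, iterated with .items()

# Keras route: a Model whose predict() returns NumPy arrays, not a dict
model = keras.models.load_model('./checkpoints/yolov4-416')
preds = model.predict(dummy)  # ndarray or list of ndarrays

This is why the prediction loop further down has to change from dict iteration to array iteration.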
@@ -100,7 +107,7 @@ def main(_argv):
        else:
            print('Video has ended or failed, try a different video format!')
            break
-        frame_num +=1
+        frame_num += 1
        print('Frame #: ', frame_num)
        frame_size = frame.shape[:2]
        image_data = cv2.resize(frame, (input_size, input_size))
@@ -121,11 +128,19 @@ def main(_argv):
                boxes, pred_conf = filter_boxes(pred[0], pred[1], score_threshold=0.25,
                                                input_shape=tf.constant([input_size, input_size]))
        else:
+            # -- Dac: 15/6/2021
+            # batch_data = tf.constant(image_data)
+            # pred_bbox = infer(batch_data)
+            # for key, value in pred_bbox.items():
+            #     boxes = value[:, :, 0:4]
+            #     pred_conf = value[:, :, 4:]
            batch_data = tf.constant(image_data)
-            pred_bbox = infer(batch_data)
-            for key, value in pred_bbox.items():
-                boxes = value[:, :, 0:4]
-                pred_conf = value[:, :, 4:]
+            pred_bbox = infer.predict(batch_data)
+
+            for value in pred_bbox:
+                temp_value = np.expand_dims(value, axis=0)
+                boxes = temp_value[:, :, 0:4]
+                pred_conf = temp_value[:, :, 4:]

        boxes, scores, classes, valid_detections = tf.image.combined_non_max_suppression(
            boxes=tf.reshape(boxes, (tf.shape(boxes)[0], -1, 1, 4)),
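A note on the reshaping in this hunk: iterating over the predict() output yields per-batch slices without the leading batch axis that the old signature dict preserved, so np.expand_dims puts that axis back before the box/confidence slicing. A small shape walk-through (box and class counts are illustrative, not read from this model):

import numpy as np

# hypothetical single predict() output: (num_boxes, 4 + num_classes)
value = np.zeros((10647, 84), dtype=np.float32)

temp_value = np.expand_dims(value, axis=0)  # (1, 10647, 84): batch axis restored
boxes = temp_value[:, :, 0:4]               # (1, 10647, 4): box coordinates
pred_conf = temp_value[:, :, 4:]            # (1, 10647, 80): per-class confidences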
@@ -158,9 +173,9 @@ def main(_argv):

        # by default allow all classes in .names file
        allowed_classes = list(class_names.values())

        # custom allowed classes (uncomment line below to customize tracker for only people)
-        #allowed_classes = ['person']
+        # allowed_classes = ['person']

        # loop through objects and use class index to get class name, allow only classes in allowed_classes list
        names = []
@@ -175,17 +190,19 @@ def main(_argv):
        names = np.array(names)
        count = len(names)
        if FLAGS.count:
-            cv2.putText(frame, "Objects being tracked: {}".format(count), (5, 35), cv2.FONT_HERSHEY_COMPLEX_SMALL, 2, (0, 255, 0), 2)
+            cv2.putText(frame, "Objects being tracked: {}".format(count), (5, 35), cv2.FONT_HERSHEY_COMPLEX_SMALL, 2,
+                        (0, 255, 0), 2)
            print("Objects being tracked: {}".format(count))
        # delete detections that are not in allowed_classes
        bboxes = np.delete(bboxes, deleted_indx, axis=0)
        scores = np.delete(scores, deleted_indx, axis=0)

        # encode yolo detections and feed to tracker
        features = encoder(frame, bboxes)
-        detections = [Detection(bbox, score, class_name, feature) for bbox, score, class_name, feature in zip(bboxes, scores, names, features)]
+        detections = [Detection(bbox, score, class_name, feature) for bbox, score, class_name, feature in
+                      zip(bboxes, scores, names, features)]

-        #initialize color map
+        # initialize color map
        cmap = plt.get_cmap('tab20b')
        colors = [cmap(i)[:3] for i in np.linspace(0, 1, 20)]

@@ -194,7 +211,7 @@ def main(_argv):
        scores = np.array([d.confidence for d in detections])
        classes = np.array([d.class_name for d in detections])
        indices = preprocessing.non_max_suppression(boxs, classes, nms_max_overlap, scores)
        detections = [detections[i] for i in indices]

        # Call the tracker
        tracker.predict()
@@ -203,36 +220,44 @@ def main(_argv):
        # update tracks
        for track in tracker.tracks:
            if not track.is_confirmed() or track.time_since_update > 1:
                continue
            bbox = track.to_tlbr()
            class_name = track.get_class()
-        # draw bbox on screen
+
+            # draw bbox on screen
            color = colors[int(track.track_id) % len(colors)]
            color = [i * 255 for i in color]
            cv2.rectangle(frame, (int(bbox[0]), int(bbox[1])), (int(bbox[2]), int(bbox[3])), color, 2)
-            cv2.rectangle(frame, (int(bbox[0]), int(bbox[1]-30)), (int(bbox[0])+(len(class_name)+len(str(track.track_id)))*17, int(bbox[1])), color, -1)
-            cv2.putText(frame, class_name + "-" + str(track.track_id),(int(bbox[0]), int(bbox[1]-10)),0, 0.75, (255,255,255),2)
+            cv2.rectangle(frame, (int(bbox[0]), int(bbox[1] - 30)),
+                          (int(bbox[0]) + (len(class_name) + len(str(track.track_id))) * 17, int(bbox[1])), color, -1)
+            cv2.putText(frame, class_name + "-" + str(track.track_id), (int(bbox[0]), int(bbox[1] - 10)), 0, 0.75,
+                        (255, 255, 255), 2)

-        # if enable info flag then print details about each track
+            # if enable info flag then print details about each track
            if FLAGS.info:
-                print("Tracker ID: {}, Class: {}, BBox Coords (xmin, ymin, xmax, ymax): {}".format(str(track.track_id), class_name, (int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3]))))
+                print("Tracker ID: {}, Class: {}, BBox Coords (xmin, ymin, xmax, ymax): {}".format(
+                    str(track.track_id), class_name,
+                    (int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3]))))

        # calculate frames per second of running detections
        fps = 1.0 / (time.time() - start_time)
        print("FPS: %.2f" % fps)
        result = np.asarray(frame)
        result = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)

        if not FLAGS.dont_show:
            cv2.imshow("Output Video", result)

        # if output flag is set, save video file
        if FLAGS.output:
            out.write(result)
        if cv2.waitKey(1) & 0xFF == ord('q'): break
    cv2.destroyAllWindows()


if __name__ == '__main__':
    try:
        app.run(main)