forked from NVIDIA-AI-IOT/tf_trt_models
-
Notifications
You must be signed in to change notification settings - Fork 51
/
camera_tf_trt.py
199 lines (162 loc) · 7.3 KB
/
camera_tf_trt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
"""camera_tf_trt.py
This is a Camera TensorFlow/TensorRT Object Detection sample code for
Jetson TX2 or TX1. This script captures and displays video from either
a video file, an image file, an IP CAM, a USB webcam, or the Tegra
onboard camera, and does real-time object detection with example TensorRT
optimized SSD models in NVIDIA's 'tf_trt_models' repository. Refer to
README.md inside this repository for more information.
This code is written and maintained by JK Jung <jkjung13@gmail.com>.
"""
import sys
import time
import logging
import argparse
import numpy as np
import cv2
import tensorflow as tf
import tensorflow.contrib.tensorrt as trt
from utils.camera import add_camera_args, Camera
from utils.od_utils import read_label_map, build_trt_pb, load_trt_pb, \
write_graph_tensorboard, detect
from utils.visualization import BBoxVisualization
# Module-level defaults and display settings
DEFAULT_MODEL = 'ssd_inception_v2_coco'
DEFAULT_LABELMAP = (
    'third_party/models/research/object_detection/'
    'data/mscoco_label_map.pbtxt'
)
WINDOW_NAME = 'CameraTFTRTDemo'
BBOX_COLOR = (0, 255, 0)  # green
def parse_args():
    """Parse command-line arguments for the demo.

    The parser is first passed through `add_camera_args()` (presumably
    registering the video-source options -- see utils/camera.py), and then
    the model, build, TensorBoard, label map and confidence options are
    added here.

    # Returns
        The `argparse.Namespace` returned by `parser.parse_args()`.
    """
    desc = ('This script captures and displays live camera video, '
            'and does real-time object detection with TF-TRT model '
            'on Jetson TX2/TX1/Nano')
    parser = argparse.ArgumentParser(description=desc)
    parser = add_camera_args(parser)
    parser.add_argument('--model', dest='model',
                        help='tf-trt object detection model '
                             '[{}]'.format(DEFAULT_MODEL),
                        default=DEFAULT_MODEL, type=str)
    # NOTE: the trailing space after 'using' matters; the two literals are
    # concatenated and would otherwise read "usingthe".
    parser.add_argument('--build', dest='do_build',
                        help='re-build TRT pb file (instead of using '
                             'the previously built version)',
                        action='store_true')
    parser.add_argument('--tensorboard', dest='do_tensorboard',
                        help='write optimized graph summary to TensorBoard',
                        action='store_true')
    parser.add_argument('--labelmap', dest='labelmap_file',
                        help='[{}]'.format(DEFAULT_LABELMAP),
                        default=DEFAULT_LABELMAP, type=str)
    parser.add_argument('--num-classes', dest='num_classes',
                        help='(deprecated and not used) number of object '
                             'classes', type=int)
    parser.add_argument('--confidence', dest='conf_th',
                        help='confidence threshold [0.3]',
                        default=0.3, type=float)
    args = parser.parse_args()
    return args
def open_display_window(width, height):
    """Create the cv2 window used to show frames with bounding boxes.

    The window is created with the WINDOW_NORMAL flag, resized to
    `width` x `height`, moved to screen position (0, 0) and given a
    descriptive title.
    """
    title = ('Camera TFTRT Object Detection Demo '
             'for Jetson TX2/TX1')
    cv2.namedWindow(WINDOW_NAME, cv2.WINDOW_NORMAL)
    cv2.resizeWindow(WINDOW_NAME, width, height)
    cv2.moveWindow(WINDOW_NAME, 0, 0)
    cv2.setWindowTitle(WINDOW_NAME, title)
def draw_help_and_fps(img, fps):
    """Overlay the key-binding help line and the FPS readout on `img`.

    Each text line is drawn twice with cv2.putText: first a thick dark
    stroke shifted one pixel to the right, then a thin light stroke on top,
    so the text stays legible on any background.

    # Arguments
      img: the image to draw on.
      fps: frames-per-second value, shown with one decimal place.

    # Returns
      The `img` object that was passed in.
    """
    typeface = cv2.FONT_HERSHEY_PLAIN
    line_type = cv2.LINE_AA
    # (text, baseline y) for each overlay line, in drawing order
    overlay_lines = (
        ("'Esc' to Quit, 'H' for FPS & Help, 'F' for Fullscreen", 20),
        ('FPS: {:.1f}'.format(fps), 50),
    )
    for text, y in overlay_lines:
        cv2.putText(img, text, (11, y), typeface, 1.0,
                    (32, 32, 32), 4, line_type)
        cv2.putText(img, text, (10, y), typeface, 1.0,
                    (240, 240, 240), 1, line_type)
    return img
def set_full_screen(full_scrn):
    """Switch the display window into or out of full-screen mode.

    A truthy `full_scrn` selects cv2.WINDOW_FULLSCREEN; otherwise the
    window is set back to cv2.WINDOW_NORMAL.
    """
    if full_scrn:
        mode = cv2.WINDOW_FULLSCREEN
    else:
        mode = cv2.WINDOW_NORMAL
    cv2.setWindowProperty(WINDOW_NAME, cv2.WND_PROP_FULLSCREEN, mode)
def loop_and_detect(cam, tf_sess, conf_th, vis, od_type):
    """Loop, grab images from camera, and do object detection.

    The loop ends when the display window has been closed or the user
    presses Esc.  'H' toggles the help/FPS overlay and 'F' toggles
    full-screen mode.

    # Arguments
      cam: the camera object (video source).
      tf_sess: TensorFlow/TensorRT session to run object detection.
      conf_th: confidence/score threshold for object detection.
      vis: for visualization.
      od_type: detector type string, forwarded unchanged to `detect()`.
    """
    show_fps = True
    full_scrn = False
    fps = 0.0
    tic = time.time()
    while True:
        if cv2.getWindowProperty(WINDOW_NAME, 0) < 0:
            # Check to see if the user has closed the display window.
            # If yes, terminate the while loop.
            break

        img = cam.read()
        if img is not None:
            box, conf, cls = detect(img, tf_sess, conf_th, od_type=od_type)
            img = vis.draw_bboxes(img, box, conf, cls)
            if show_fps:
                img = draw_help_and_fps(img, fps)
            cv2.imshow(WINDOW_NAME, img)

            toc = time.time()
            elapsed = toc - tic
            # time.time() has finite resolution and is not monotonic, so
            # two consecutive readings may be equal or even go backwards.
            # Skip the FPS update in that case instead of dividing by zero
            # or averaging in a negative rate.
            if elapsed > 0.0:
                curr_fps = 1.0 / elapsed
                # calculate an exponentially decaying average of fps number
                fps = curr_fps if fps == 0.0 else (fps*0.9 + curr_fps*0.1)
            tic = toc

        key = cv2.waitKey(1)
        if key == 27:  # ESC key: quit program
            break
        elif key == ord('H') or key == ord('h'):  # Toggle help/fps
            show_fps = not show_fps
        elif key == ord('F') or key == ord('f'):  # Toggle fullscreen
            full_scrn = not full_scrn
            set_full_screen(full_scrn)
def main():
    """Run the camera object-detection demo end to end.

    Steps: parse arguments, read the label map, optionally (re)build the
    TRT pb file, open the camera, load the TRT graph into a TensorFlow
    session, warm the graph up with a dummy image, then run the
    detect-and-display loop until the user quits.

    Everything acquired after the camera has been opened is released in a
    `finally` clause, so the camera, the TensorFlow session and the display
    windows are cleaned up even if loading or inference raises.
    """
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)
    # Ask tensorflow logger not to propagate logs to parent (which causes
    # duplicated logging)
    logging.getLogger('tensorflow').propagate = False

    args = parse_args()
    logger.info('called with args: %s', args)

    # build the class (index/name) dictionary from labelmap file
    logger.info('reading label map')
    cls_dict = read_label_map(args.labelmap_file)

    pb_path = './data/{}_trt.pb'.format(args.model)
    log_path = './logs/{}_trt'.format(args.model)
    if args.do_build:
        logger.info('building TRT graph and saving to pb: %s', pb_path)
        build_trt_pb(args.model, pb_path)

    logger.info('opening camera device/file')
    cam = Camera(args)
    cam.open()
    if not cam.is_opened:
        sys.exit('Failed to open camera!')

    tf_sess = None
    cam_started = False
    try:
        logger.info('loading TRT graph from pb: %s', pb_path)
        trt_graph = load_trt_pb(pb_path)

        logger.info('starting up TensorFlow session')
        tf_config = tf.ConfigProto()
        tf_config.gpu_options.allow_growth = True
        tf_sess = tf.Session(config=tf_config, graph=trt_graph)

        if args.do_tensorboard:
            logger.info('writing graph summary to TensorBoard')
            write_graph_tensorboard(tf_sess, log_path)

        logger.info('warming up the TRT graph with a dummy image')
        od_type = 'faster_rcnn' if 'faster_rcnn' in args.model else 'ssd'
        dummy_img = np.zeros((720, 1280, 3), dtype=np.uint8)
        _, _, _ = detect(dummy_img, tf_sess, conf_th=.3, od_type=od_type)

        cam.start()  # ask the camera to start grabbing images
        cam_started = True

        # grab image and do object detection (until stopped by user)
        logger.info('starting to loop and detect')
        vis = BBoxVisualization(cls_dict)
        open_display_window(cam.img_width, cam.img_height)
        loop_and_detect(cam, tf_sess, args.conf_th, vis, od_type=od_type)
    finally:
        logger.info('cleaning up')
        if cam_started:
            cam.stop()  # terminate the sub-thread in camera
        if tf_sess is not None:
            tf_sess.close()
        cam.release()
        cv2.destroyAllWindows()
# Run the demo only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()