TensorFlow之视频流实时目标检测
·
参考:https://github.com/juandes/pikachu-detection/blob/master/detection_video.py
在之前的文章中,实现了利用tensorflow的目标检测API训练模型,并用图片来验证模型的有效性。本文的目的是为了将模型应用在视频检测中,实现视频流的实时检测。
---------------------------------------2018.12.3更新----------------------------------------------
抱歉前段时间一直在做语义分割的项目,没有时间测试目标检测的接口,今天终于抽空把视频流的检测做了,话不多说,直接上代码。代码部分主要参考官方给的object_detection_tutorial.ipynb中的内容,视频处理采用opencv库。
# coding: utf-8
import numpy as np
import os
import six.moves.urllib as urllib
import sys
import tarfile
import tensorflow as tf
import zipfile
from collections import defaultdict
from io import StringIO
from matplotlib import pyplot as plt
from PIL import Image
import cv2
# This is needed since the notebook is stored in the object_detection folder.
sys.path.append("..")
from object_detection.utils import ops as utils_ops
if tf.__version__ < '1.4.0':
raise ImportError('Please upgrade your tensorflow installation to v1.4.* or later!')
from utils import label_map_util
from utils import visualization_utils as vis_util
# # Model preparation
# Path to frozen detection graph. This is the actual model that is used for the object detection.
PATH_TO_FROZEN_GRAPH = 'inference_graph_67886/frozen_inference_graph.pb'
# List of the strings that is used to add correct label for each box.
PATH_TO_LABELS = os.path.join('training', 'object-detection.pbtxt')
NUM_CLASSES = 12
def detect_in_video():
# VideoWriter is the responsible of creating a copy of the video
# used for the detections but with the detections overlays. Keep in
# mind the frame size has to be the same as original video.
out = cv2.VideoWriter('test_images/20171206/0-0-0_result.avi', cv2.VideoWriter_fourcc(
'M', 'J', 'P', 'G'), 25, (1280, 1024))
detection_graph = tf.Graph()
with detection_graph.as_default():
od_graph_def = tf.GraphDef()
with tf.gfile.GFile(PATH_TO_FROZEN_GRAPH, 'rb') as fid:
serialized_graph = fid.read()
od_graph_def.ParseFromString(serialized_graph)
tf.import_graph_def(od_graph_def, name='')
label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(
label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
category_index = label_map_util.create_category_index(categories)
with detection_graph.as_default():
with tf.Session(graph=detection_graph) as sess:
# Definite input and output Tensors for detection_graph
image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
# Each box represents a part of the image where a particular object
# was detected.
detection_boxes = detection_graph.get_tensor_by_name(
'detection_boxes:0')
# Each score represent how level of confidence for each of the objects.
# Score is shown on the result image, together with the class
# label.
detection_scores = detection_graph.get_tensor_by_name(
'detection_scores:0')
detection_classes = detection_graph.get_tensor_by_name(
'detection_classes:0')
num_detections = detection_graph.get_tensor_by_name(
'num_detections:0')
cap = cv2.VideoCapture('test_images/20171206/0-0-0.avi')
while(cap.isOpened()):
# Read the frame
ret, frame = cap.read()
# Recolor the frame. By default, OpenCV uses BGR color space.
# This short blog post explains this better:
# https://www.learnopencv.com/why-does-opencv-use-bgr-color-format/
color_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
image_np_expanded = np.expand_dims(color_frame, axis=0)
# Actual detection.
(boxes, scores, classes, num) = sess.run(
[detection_boxes, detection_scores,
detection_classes, num_detections],
feed_dict={image_tensor: image_np_expanded})
# Visualization of the results of a detection.
# note: perform the detections using a higher threshold
vis_util.visualize_boxes_and_labels_on_image_array(
color_frame,
np.squeeze(boxes),
np.squeeze(classes).astype(np.int32),
np.squeeze(scores),
category_index,
use_normalized_coordinates=True,
line_thickness=8,
min_score_thresh=.20)
cv2.imshow('frame', color_frame)
output_rgb = cv2.cvtColor(color_frame, cv2.COLOR_RGB2BGR)
out.write(output_rgb)
if cv2.waitKey(1) & 0xFF == ord('q'):
break
out.release()
cap.release()
cv2.destroyAllWindows()
def main():
detect_in_video()
if __name__ =='__main__':
main()
推荐内容
阅读全文
AI总结
更多推荐
相关推荐
查看更多
A2A

谷歌开源首个标准智能体交互协议Agent2Agent Protocol(A2A)
ai-agents-for-beginners

这个项目是一个针对初学者的 AI 代理课程,包含 10 个课程,涵盖构建 AI 代理的基础知识。源项目地址:https://github.com/microsoft/ai-agents-for-beginners
n8n

n8n 是一个工作流自动化平台,它结合了代码的灵活性和无代码的高效性。支持 400+ 集成、原生 AI 功能以及公平开源许可,n8n 能让你在完全掌控数据和部署的前提下,构建强大的自动化流程。源项目地址:https://github.com/n8n-io/n8n
热门开源项目
活动日历
查看更多
直播时间 2025-04-09 14:34:18

樱花限定季|G-Star校园行&华中师范大学专场
直播时间 2025-04-07 14:51:20

樱花限定季|G-Star校园行&华中农业大学专场
直播时间 2025-03-26 14:30:09

开源工业物联实战!
直播时间 2025-03-25 14:30:17

Heygem.ai数字人超4000颗星火燎原!
直播时间 2025-03-13 18:32:35

全栈自研企业级AI平台:Java核心技术×私有化部署实战
所有评论(0)