≡
  • 网络编程
  • 数据库
  • CMS技巧
  • 软件编程
  • PHP笔记
  • JavaScript
  • MySQL
位置:首页 > 网络编程 > Python

实现TensorFlow物体检测的简单示例

人气:624 时间:2018-10-10

这篇文章主要为大家详细介绍了实现TensorFlow物体检测的简单示例,具有一定的参考价值,可以用来参考一下。

对python这个高级语言感兴趣的小伙伴,下面一起跟随四海网的小编两巴掌来看看吧!

Google发布了新的TensorFlow物体检测API,包含了预训练模型,一个发布模型的jupyter notebook,一些可用于使用自己数据集对模型进行重新训练的有用脚本。

使用该API可以快速的构建一些图片中物体检测的应用。这里我们一步一步来看如何使用预训练模型来检测图像中的物体。

首先我们载入一些会使用的库


# @param 30秒轻松实现TensorFlow物体检测
# @author 四海网|q1010.com 

import numpy as np 
import os 
import six.moves.urllib as urllib 
import sys 
import tarfile 
import tensorflow as tf 
import zipfile 
 
from collections import defaultdict 
from io import StringIO 
from matplotlib import pyplot as plt 
from PIL import Image

# End www_512pic_com

 

接下来进行环境设置


# @param 30秒轻松实现TensorFlow物体检测
# @author 四海网|q1010.com 


%matplotlib inline 
sys.path.append("..")

# End www_512pic_com

物体检测载入


# @param 30秒轻松实现TensorFlow物体检测
# @author 四海网|q1010.com 


from utils import label_map_util  
from utils import visualization_utils as vis_util

# End www_512pic_com

准备模型

 

变量 任何使用export_inference_graph.py工具输出的模型可以在这里载入,只需简单改变PATH_TO_CKPT指向一个新的.pb文件。这里我们使用“移动网SSD”模型。


# @param 30秒轻松实现TensorFlow物体检测
# @author 四海网|q1010.com 

MODEL_NAME = 'ssd_mobilenet_v1_coco_11_06_2017' 
MODEL_FILE = MODEL_NAME + '.tar.gz' 
DOWNLOAD_BASE = 'http://download.tensorflow.org/models/object_detection/' 
 
PATH_TO_CKPT = MODEL_NAME + '/frozen_inference_graph.pb' 
 
PATH_TO_LABELS = os.path.join('data', 'mscoco_label_map.pbtxt') 
 
NUM_CLASSES = 90

# End www_512pic_com

下载模型


# @param 30秒轻松实现TensorFlow物体检测
# @author 四海网|q1010.com 

opener = urllib.request.URLopener() 
opener.retrieve(DOWNLOAD_BASE + MODEL_FILE, MODEL_FILE) 
tar_file = tarfile.open(MODEL_FILE) 
for file in tar_file.getmembers(): 
  file_name = os.path.basename(file.name) 
  if 'frozen_inference_graph.pb' in file_name: 
    tar_file.extract(file, os.getcwd()) 
将(frozen)TensorFlow模型载入内存
detection_graph = tf.Graph() 
with detection_graph.as_default(): 
  od_graph_def = tf.GraphDef() 
  with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid: 
    serialized_graph = fid.read() 
    od_graph_def.ParseFromString(serialized_graph) 
    tf.import_graph_def(od_graph_def, name='')

# End www_512pic_com

载入标签图

 

标签图将索引映射到类名称,当我们的卷积预测5时,我们知道它对应飞机。这里我们使用内置函数,但是任何返回将整数映射到恰当字符标签的字典都适用。


# @param 30秒轻松实现TensorFlow物体检测
# @author 四海网|q1010.com 

label_map = label_map_util.load_labelmap(PATH_TO_LABELS) 
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True) 
category_index = label_map_util.create_category_index(categories)

# End www_512pic_com

辅助代码


# @param 30秒轻松实现TensorFlow物体检测
# @author 四海网|q1010.com 

def load_image_into_numpy_array(image): 
 (im_width, im_height) = image.size 
 return np.array(image.getdata()).reshape( (im_height, im_width, 3)).astype(np.uint8)

# End www_512pic_com

检测


# @param 30秒轻松实现TensorFlow物体检测
# @author 四海网|q1010.com 

PATH_TO_TEST_IMAGES_DIR = 'test_images' 
TEST_IMAGE_PATHS = [ os.path.join(PATH_TO_TEST_IMAGES_DIR, 'image{}.jpg'.format(i)) for i in range(1, 3) ] 
IMAGE_SIZE = (12, 8) 
with detection_graph.as_default(): 
 
 with tf.Session(graph=detection_graph) as sess: 
  for image_path in TEST_IMAGE_PATHS: 
   image = Image.open(image_path) 
   # 这个array在之后会被用来准备为图片加上框和标签 
   image_np = load_image_into_numpy_array(image) 
   # 扩展维度,应为模型期待: [1, None, None, 3] 
   image_np_expanded = np.expand_dims(image_np, axis=0) 
   image_tensor = detection_graph.get_tensor_by_name('image_tensor:0') 
   # 每个框代表一个物体被侦测到. 
   boxes = detection_graph.get_tensor_by_name('detection_boxes:0') 
   # 每个分值代表侦测到物体的可信度. 
   scores = detection_graph.get_tensor_by_name('detection_scores:0') 
   classes = detection_graph.get_tensor_by_name('detection_classes:0') 
   num_detections = detection_graph.get_tensor_by_name('num_detections:0') 
   # 执行侦测任务. 
   (boxes, scores, classes, num_detections) = sess.run( 
     [boxes, scores, classes, num_detections], 
     feed_dict={image_tensor: image_np_expanded}) 
   # 图形化. 
   vis_util.visualize_boxes_and_labels_on_image_array( 
     image_np, 
     np.squeeze(boxes), 
     np.squeeze(classes).astype(np.int32), 
     np.squeeze(scores), 
     category_index, 
     use_normalized_coordinates=True, 
     line_thickness=8) 
   plt.figure(figsize=IMAGE_SIZE) 
   plt.imshow(image_np)

# End www_512pic_com

在载入模型部分可以尝试不同的侦测模型以比较速度和准确度,将你想侦测的图片放入TEST_IMAGE_PATHS中运行即可。

本文来自:http://www.q1010.com/181/2176-0.html

注:关于实现TensorFlow物体检测的简单示例的内容就先介绍到这里,更多相关文章的可以留意四海网的其他信息。

关键词:TensorFlow

您可能感兴趣的文章

上一篇:Python常用程序调试的简单示例
下一篇:Python的CURL PycURL库的简单示例
热门文章
  • Python 处理Cookie的菜鸟教程(一)Cookie库
  • python之pandas取dataframe特定行列的简单示例
  • Python解决json.dumps错误::‘utf8’ codec can‘t decode byte
  • Python通过pythony连接Hive执行Hql的脚本
  • Python 三种方法删除列表中重复元素的简单示例
  • python爬虫代码示例
  • Python 中英文标点转换示例
  • Python 不得不知的开源项目解析
  • Python urlencode编码和url拼接实现方法
  • python按中文拆分中英文混合字符串的简单示例
  • 最新文章
    • Python利用numpy三层神经网络的简单示例
    • pygame可视化幸运大转盘的简单示例
    • Python爬虫之爬取二手房信息的简单示例
    • Python之time库的简单示例
    • OpenCV灰度、高斯模糊、边缘检测的简单示例
    • Python安装Bs4及使用的简单示例
    • django自定义manage.py管理命令的简单示例
    • Python之matplotlib 向任意位置添加一个子图(axes)的简单示例
    • Python图像标签标注软件labelme分析的简单示例
    • python调用摄像头并拍照发邮箱的简单示例

四海网收集整理一些常用的php代码,JS代码,数据库mysql等技术文章。