美文网首页
TF - Retrain an Image Classifie

TF - Retrain an Image Classifie

作者: leo567 | 来源:发表于2018-11-19 11:38 被阅读0次
三种方式
  1. 拿到自己需要训练的图片数据集,从头开始训练。参数都是没有训练过的初始随机值,把我们准备好的数据按批次的加入训练。问题:例如我们需要训练Inception模型,由于这个模型比较复杂,需要准备大量的数据集,上百万千万,如果只有几万张图片来训练可能会参数比较严重的过拟合的情况。(从无到有训练一个模型)

  2. 找到一个已经训练好的模型,参数已经是确定的,卷积层主要作用是做图像特征的提取,Inception模型中已经训练好的卷积层我们可以认为对于其他图像的特征提取也是适用的(训练了1500万+图片产生的参数)。前面层都不改,只训练最后的池化层到全连接参数输出。(迁移学习)

  3. 在第二种训练方式的基础上对前面层的参数做一些微调

使用第二种方式训练自己的分类

How to Retrain an Image Classifier for New Categories

tensorflow/hub

英国牛津大学的一些开源数据集可以下载

  • 使用官方retrain.py示例进行训练生成模型:

(For a working example,download http://download.tensorflow.org/example_images/flower_photos.tgz and run tar xzf flower_photos.tgz to unpack it.)

python retrain.py --image_dir ~/flower_photos

生成pb和labels文件


from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import numpy as np
import tensorflow as tf

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'



def load_graph(model_file):
  graph = tf.Graph()
  graph_def = tf.GraphDef()

  with open(model_file, "rb") as f:
    graph_def.ParseFromString(f.read())
  with graph.as_default():
    tf.import_graph_def(graph_def)

  return graph


def read_tensor_from_image_file(file_name,
                                input_height=299,
                                input_width=299,
                                input_mean=0,
                                input_std=255):
  input_name = "file_reader"
  output_name = "normalized"
  file_reader = tf.read_file(file_name, input_name)
  if file_name.endswith(".png"):
    image_reader = tf.image.decode_png(
        file_reader, channels=3, name="png_reader")
  elif file_name.endswith(".gif"):
    image_reader = tf.squeeze(
        tf.image.decode_gif(file_reader, name="gif_reader"))
  elif file_name.endswith(".bmp"):
    image_reader = tf.image.decode_bmp(file_reader, name="bmp_reader")
  else:
    image_reader = tf.image.decode_jpeg(
        file_reader, channels=3, name="jpeg_reader")
  float_caster = tf.cast(image_reader, tf.float32)
  dims_expander = tf.expand_dims(float_caster, 0)
  resized = tf.image.resize_bilinear(dims_expander, [input_height, input_width])
  normalized = tf.divide(tf.subtract(resized, [input_mean]), [input_std])
  sess = tf.Session()
  result = sess.run(normalized)

  return result


def load_labels(label_file):
  label = []
  proto_as_ascii_lines = tf.gfile.GFile(label_file).readlines()
  for l in proto_as_ascii_lines:
    label.append(l.rstrip())
  return label


if __name__ == "__main__":

  file_name = "images/timg.jpg"
  model_file = "output_graph.pb"
  label_file = "output_labels.txt"
  input_height = 299
  input_width = 299
  input_mean = 0
  input_std = 255
  input_layer = "Placeholder"
  output_layer = "final_result"

  parser = argparse.ArgumentParser()
  parser.add_argument("--image", help="image to be processed")
  parser.add_argument("--graph", help="graph/model to be executed")
  parser.add_argument("--labels", help="name of file containing labels")
  parser.add_argument("--input_height", type=int, help="input height")
  parser.add_argument("--input_width", type=int, help="input width")
  parser.add_argument("--input_mean", type=int, help="input mean")
  parser.add_argument("--input_std", type=int, help="input std")
  parser.add_argument("--input_layer", help="name of input layer")
  parser.add_argument("--output_layer", help="name of output layer")
  args = parser.parse_args()

  if args.graph:
    model_file = args.graph
  if args.image:
    file_name = args.image
  if args.labels:
    label_file = args.labels
  if args.input_height:
    input_height = args.input_height
  if args.input_width:
    input_width = args.input_width
  if args.input_mean:
    input_mean = args.input_mean
  if args.input_std:
    input_std = args.input_std
  if args.input_layer:
    input_layer = args.input_layer
  if args.output_layer:
    output_layer = args.output_layer

  graph = load_graph(model_file)
  t = read_tensor_from_image_file(
      file_name,
      input_height=input_height,
      input_width=input_width,
      input_mean=input_mean,
      input_std=input_std)

  input_name = "import/" + input_layer
  output_name = "import/" + output_layer
  input_operation = graph.get_operation_by_name(input_name)
  output_operation = graph.get_operation_by_name(output_name)

  with tf.Session(graph=graph) as sess:
    results = sess.run(output_operation.outputs[0], {
        input_operation.outputs[0]: t
    })
  results = np.squeeze(results)

  top_k = results.argsort()[-5:][::-1]
  labels = load_labels(label_file)
  for i in top_k:
    print(labels[i], results[i])

百度一张郁金香的图片:


运行结果

相关文章

网友评论

      本文标题:TF - Retrain an Image Classifie

      本文链接:https://www.haomeiwen.com/subject/wolofqtx.html