即日起在codingBlog上分享您的技术经验即可获得积分,积分可兑换现金哦。

Tensorflow读取数据1

编程语言 u010911921 13℃ 0评论
本文目录
[隐藏]

这段一直在用Tensorflow来做深度学习上的相关工作,然后对Tensorflow读取数据的方式进行实现。特地总结一下。首先是读取二进制图片数据,这里采用的是CIFAR-10的二进制数据

1.1.CIFAR-10数据集

CIFAR-10数据集合是包含60000张32*32*3的图片,其中每个类包含6000张图片,总共10类。在这60000张图片中50000张是训练集合,10000张是测试集合。

http://static.shenjianshou.cn/image/504221-a9f4ad3a7982b178db8fa348a5e439fa

其中二进制的图片保存的格式如下所示:

2.2.Tensorflow读取数据

从Tensorflow的官网可以看到从文件中读取数据的流程主要是一下步骤:

  1. The list of filenames
  2. (Optional) filename shuffling
  3. (Optional) epoch limit
  4. Filename queue
  5. A Reader for the file format
  6. A decoder for a record read by the reader
  7. (Optional) preprocessing
  8. Example queue

按照这样一个流程,首选应该将CIFAR-10的训练集和测试集合,生成文件名列表,然后在讲这个文件名列表传递给tf.train.string_input_producer函数创建一个用于保存文件名称的FIFO的队列,最后用tensor flow产生的reader从队列中读取数据。当reader读到数据就需要用tf.decode_raw函数对读取到的二进制数进行解码。

结束了上述操作,下面就需要采用另一个queue去batch together examples来为训练和测试提供数据。采用tf.train.shuffle_batch将上面生成的imagelabel传入函数即可完成。

3.3.开始训练

tf.train.shuffle_batch生成batch以后就开始利用tf.train.start_queue_runners函数启动队列,然后开始整个计算图,官网给的建议是如下形式:

init_op = tf.global_variables_initializer()
with tf.Session as sess:
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess= sess,coord = coord)
    try:
        while not coord.should_stop():
            #run training steps or whatever
            sess.run(train_op)
    except tf.errors.OutOfRangeError:
        print('Done training --epoch limit reached')
    finally:
        # when done,ask the threads to stop
        coord.request_stop()
    coord.join(threads)

4.4.代码实现

在神经网络的训练中由于每训练k步以后就会对网络进行一次测试,所以需要在上述步骤中,增加动态选择文件名称队列这样一个过程,可以由tf.QueueBase.from_list函数进行实现,然后reader从返回的文件名称队列中读取数据。

整个过程的实现如下所示:

#!/usr/bin/env python3
# --*-- encoding:utf-8 --*--

import tensorflow as tf
import numpy as np
import os

def read_cifar10(data_dir,is_traing,batch_size,shuffle):
    """

    :param data_dir:数据保存路径
    :param is_traing:True从训练集获取数据,False从测试集获取数据
    :param batch_size:  batch_size的大小
    :param shuffle: bool,是否进行shuffle操作
    :return:
    """
    img_width = 32
    img_height = 32
    img_depth = 3
    label_bytes = 1
    img_bytes = img_height * img_width *img_depth



    with tf.name_scope("input") as scope:
        #训练集合的文件列表
        train_filenames = [os.path.join(data_dir,
                                        'data_batch_%d.bin'%ii) for ii in np.arange(1,6)]
        #测试集合的文件列表
        val_filenames = [os.path.join(data_dir,'test_batch.bin')]

        #训练集和测试集合的文件名称队列
        train_queue = tf.train.string_input_producer(train_filenames)
        val_queue = tf.train.string_input_producer(val_filenames)

        #挑选文件队列,实现training的过程中测试
        queue_select = tf.cond(is_traing,
                               lambda :tf.constant(0),
                               lambda :tf.constant(1) )
        queue = tf.QueueBase.from_list(queue_select,[train_queue,val_queue])

        #从队列中读取固定长度的数据
        reader = tf.FixedLengthRecordReader(label_bytes+img_bytes)
        key,value = reader.read(queue)
        recode_bytes = tf.decode_raw(value,tf.uint8)

        #获取label
        label = tf.slice(recode_bytes,[0],[label_bytes])
        label = tf.cast(label,tf.int32)

        #获取image
        image_raw = tf.slice(recode_bytes,[label_bytes],[img_bytes])
        image_raw = tf.reshape(image_raw,[img_depth, img_height, img_width])
        image = tf.transpose(image_raw,[1,2,0])

        image = tf.cast(image,tf.float32)

        #对每一张图片进行标准化操作,可选操作此处可以进行对图片的各种操作
        image = tf.image.per_image_standardization(image)

        if shuffle:
            images, label_batch= tf.train.shuffle_batch([image,label],
                                                   batch_size=batch_size,
                                                   num_threads=16,
                                                   capacity=512+3*batch_size,
                                                   min_after_dequeue=512,
                                                   allow_smaller_final_batch=True)
        else:
            images, label_batch = tf.train.batch([image, label],
                                            batch_size=batch_size,
                                            num_threads=16,
                                            capacity=512 + 3*batch_size,
                                            allow_smaller_final_batch=True)
        label_batch = tf.cast(label_batch,tf.int32)

        return images,label_batch

整个过程是采用VGG-16的网络模型进行训练的,在迭代16000次,tensorboard展示的结果如图所示:

code下载地址https://github.com/ZhichengHuang/LearnTensorflowCode

4.1.参考资料:

转载请注明:CodingBlog » Tensorflow读取数据1

喜欢 (0)or分享 (0)
发表我的评论
取消评论

*

表情