TF-slim is a relatively new lightweight high-level API for TensorFlow that can be used to build, train, and evaluate network models. I have recently been learning the library for a project. Its datasets module ships with code for downloading the standard datasets; below I describe how to prepare your own data for training on top of that standard code.
1. Prepare your own data
Place the training and test data under train and test folders respectively; inside each folder, the images are grouped into subdirectories named after their class labels, as illustrated below.
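For example, with three classes the layout could look like this (suv, sedan, and truck are hypothetical placeholder labels standing in for the original figure):

train/
  suv/
    00001.jpg
    00002.jpg
    ...
  sedan/
    ...
  truck/
    ...
test/
  suv/
    ...
  sedan/
    ...
  truck/
    ...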
2. Create the scripts for your dataset under datasets. I named my dataset car, so the two scripts to prepare are car.py and download_and_convert_car.py.
The car.py script follows the same pattern as the standard dataset scripts under datasets, such as flowers.py; only the number of classes and the per-split sample counts need to change, as shown after the sketch below.
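SPLITS_TO_SIZES must match the number of images actually present in each split. If you don't know the counts offhand, something like the following works (a sketch, not part of the original write-up; count_images is a hypothetical helper, and the paths are the ones used later by the conversion script):

from __future__ import print_function
import os

def count_images(split_dir):
    # Sum the image files across every class subdirectory of the split.
    return sum(
        len(os.listdir(os.path.join(split_dir, d)))
        for d in os.listdir(split_dir)
        if os.path.isdir(os.path.join(split_dir, d)))

print('train:', count_images('/home/liuli/work/Tensorflow/flower_data/raw-data/train'))
print('validation:', count_images('/home/liuli/work/Tensorflow/flower_data/raw-data/validation'))

With those counts in hand, the constants in car.py become: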
_FILE_PATTERN = 'car_%s_*.tfrecord'

SPLITS_TO_SIZES = {'train': 12973, 'validation': 3200}

_NUM_CLASSES = 3

download_and_convert_car.py does differ from its standard counterpart: the data is already downloaded and split, so the script only needs to perform the conversion. The full code is listed below.
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Wed May 30 09:53:21 2018

@author: liuli
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import os
import random
import sys

import tensorflow as tf

from datasets import dataset_utils

# Seed for repeatability.
_RANDOM_SEED = 0

# The number of shards per dataset split.
_NUM_SHARDS = 5


class ImageReader(object):
  """Helper class that provides TensorFlow image coding utilities."""

  def __init__(self):
    # Initializes function that decodes RGB JPEG data.
    self._decode_jpeg_data = tf.placeholder(dtype=tf.string)
    self._decode_jpeg = tf.image.decode_jpeg(self._decode_jpeg_data, channels=3)

  def read_image_dims(self, sess, image_data):
    image = self.decode_jpeg(sess, image_data)
    return image.shape[0], image.shape[1]

  def decode_jpeg(self, sess, image_data):
    image = sess.run(self._decode_jpeg,
                     feed_dict={self._decode_jpeg_data: image_data})
    assert len(image.shape) == 3
    assert image.shape[2] == 3
    return image


def _get_filenames_and_classes(dataset_dir):
  # Each subdirectory of dataset_dir is one class; collect its image paths.
  directories = []
  class_names = []
  for filename in os.listdir(dataset_dir):
    path = os.path.join(dataset_dir, filename)
    if os.path.isdir(path):
      directories.append(path)
      class_names.append(filename)

  photo_filenames = []
  for directory in directories:
    for filename in os.listdir(directory):
      path = os.path.join(directory, filename)
      photo_filenames.append(path)

  return photo_filenames, sorted(class_names)


def _get_dataset_filename(dataset_dir, split_name, shard_id):
  output_filename = 'car_%s_%05d-of-%05d.tfrecord' % (
      split_name, shard_id, _NUM_SHARDS)
  return os.path.join(dataset_dir, output_filename)


def _convert_dataset(split_name, filenames, class_names_to_ids, dataset_dir):
  """Converts the given filenames to a TFRecord dataset.

  Args:
    split_name: The name of the dataset, either 'train' or 'validation'.
    filenames: A list of absolute paths to png or jpg images.
    class_names_to_ids: A dictionary from class names (strings) to ids
      (integers).
    dataset_dir: The directory where the converted datasets are stored.
  """
  assert split_name in ['train', 'validation']

  num_per_shard = int(math.ceil(len(filenames) / float(_NUM_SHARDS)))

  with tf.Graph().as_default():
    image_reader = ImageReader()

    with tf.Session('') as sess:
      for shard_id in range(_NUM_SHARDS):
        output_filename = _get_dataset_filename(
            dataset_dir, split_name, shard_id)

        with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
          start_ndx = shard_id * num_per_shard
          end_ndx = min((shard_id + 1) * num_per_shard, len(filenames))
          for i in range(start_ndx, end_ndx):
            sys.stdout.write('\r>> Converting image %d/%d shard %d' % (
                i + 1, len(filenames), shard_id))
            sys.stdout.flush()

            # Read the filename:
            image_data = tf.gfile.FastGFile(filenames[i], 'rb').read()
            height, width = image_reader.read_image_dims(sess, image_data)

            class_name = os.path.basename(os.path.dirname(filenames[i]))
            class_id = class_names_to_ids[class_name]

            example = dataset_utils.image_to_tfexample(
                image_data, b'jpg', height, width, class_id)
            tfrecord_writer.write(example.SerializeToString())

  sys.stdout.write('\n')
  sys.stdout.flush()


train_data_dir = '/home/liuli/work/Tensorflow/flower_data/raw-data/train'
test_data_dir = '/home/liuli/work/Tensorflow/flower_data/raw-data/validation'


def run(dataset_dir):
  """Runs the conversion operation.

  Args:
    dataset_dir: The dataset directory where the dataset is stored.
  """
  training_filenames, class_names = _get_filenames_and_classes(train_data_dir)
  class_names_to_ids = dict(zip(class_names, range(len(class_names))))

  random.seed(_RANDOM_SEED)
  random.shuffle(training_filenames)
  validation_filenames, _ = _get_filenames_and_classes(test_data_dir)
  random.shuffle(validation_filenames)

  _convert_dataset('train', training_filenames, class_names_to_ids,
                   dataset_dir)
  _convert_dataset('validation', validation_filenames, class_names_to_ids,
                   dataset_dir)

  labels_to_class_names = dict(zip(range(len(class_names)), class_names))
  dataset_utils.write_label_file(labels_to_class_names, dataset_dir)
  print('\nFinished converting the car dataset!')

3. In download_and_convert_data.py, add a branch for the new dataset_name to the main function (around line 69):
  elif FLAGS.dataset_name == 'car':
    download_and_convert_car.run(FLAGS.dataset_dir)
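Note that for this branch to resolve, download_and_convert_data.py must also import the new conversion module alongside the existing download_and_convert_* imports; this line is implied above but easy to miss:

from datasets import download_and_convert_car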
4. In dataset_factory.py, add the corresponding key-value pair for your dataset to datasets_map:
from datasets import car
import os
import tensorflow as tf

slim = tf.contrib.slim

datasets_map = {
    'cifar10': cifar10,
    'flowers': flowers,
    'imagenet': imagenet,
    'mnist': mnist,
    'car': car,
}

5. Create a script to generate the data:
DATASET_DIR=/home/liuli/work/Tensorflow/cars
python download_and_convert_data.py \
  --dataset_name=car \
  --dataset_dir=${DATASET_DIR}

Running this generates the corresponding TFRecord-format data under the DATASET_DIR folder.
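Once the TFRecords and the labels file are in place, you can sanity-check the result by loading a split through the factory. A minimal sketch, assuming the stock slim DatasetDataProvider pipeline and the DATASET_DIR used above:

from __future__ import print_function
import tensorflow as tf
from datasets import dataset_factory

slim = tf.contrib.slim

# Load the 'train' split registered by car.py via the factory.
dataset = dataset_factory.get_dataset(
    'car', 'train', '/home/liuli/work/Tensorflow/cars')

provider = slim.dataset_data_provider.DatasetDataProvider(dataset)
image, label = provider.get(['image', 'label'])

with tf.Session() as sess:
    sess.run([tf.global_variables_initializer(),
              tf.local_variables_initializer()])
    # The provider reads from filename queues, so queue runners must be started.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    img, lbl = sess.run([image, label])
    print('image shape:', img.shape, 'label:', lbl)
    coord.request_stop()
    coord.join(threads)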