加入收藏 | 设为首页 | 会员中心 | 我要投稿 云计算网_泰州站长网 (http://www.0523zz.com/)- 视觉智能、AI应用、CDN、行业物联网、智能数字人!
当前位置: 首页 > 站长学院 > PHP教程 > 正文

Tensorflow数据读取机制介绍

发布时间:2021-11-12 11:10:32 所属栏目:PHP教程 来源:互联网
导读:Dataset可以看作是相同类型元素的有序列表,在实际使用时,单个元素可以是向量、字符串、图片甚至是tuple或dict。 数据集对象实例化: dataset=tf.data.Dataset.from_tensor_slice(data) 迭代器对象实例化: iterator=dataset.make_one_shot_iterator() one_

Dataset可以看作是相同类型“元素”的有序列表,在实际使用时,单个元素可以是向量、字符串、图片甚至是tuple或dict。
 
数据集对象实例化:
 
dataset=tf.data.Dataset.from_tensor_slice(<data>)
迭代器对象实例化:
 
iterator=dataset.make_one_shot_iterator()
one_element=iterator.get_next()
读取结束异常:如果一个dataset中的元素被读取完毕,再尝试sess.run(one_element)的话,会抛出tf.errors.OutOfRangeError异常,这个行为与使用队列方式读取数据是一致的。
 
高维数据集的使用
tf.data.Dataset.from_tensor_slices真正作用是切分传入Tensor的第一个维度,生成相应的dataset,即第一维表明数据集中数据的数量,之后切分batch等操作均以第一维为基础。
 
dataset=tf.data.Dataset.from_tensor_slices(np.random.uniform((5,2)))
iterator=dataset.make_one_shot_iterator()
one_element=iterator.get_next()
with tf.Session(config=config) as sess:
    try:
        while True:
            print(sess.run(one_element))
    except tf.errors.OutOfRangeError as e:
        print('end~')
输出:
 
[0.1,0.2]
[0.3,0.2]
[0.1,0.6]
[0.4,0.3]
[0.5,0.2]
tuple组合数据
dataset=tf.data.Dataset.from_tensor_slices((np.array([1.,2.,3.,4.,5.]),
                                            np.random.uniform(size=(5,2))))
iterator=dataset.make_one_shot_iterator()
one_element=iterator.get_next()
with tf.Session() as sess:
    try:
        while True:
            print(sess.run(one_element))
    except tf.errors.OutOfRangeError:
        print('end~')
输出:
 
(1.,array(0.1,0.3))
(2.,array(0.2,0.4))
...
数据集处理方法
Dataset支持一类特殊操作:Transformation。一个Dataset通过Transformation变成一个新的Dataset。常用的Transformation:
 
map
batch
shuffle
repeat
其中,
 
map和Python中的map一致,接受一个函数,Dataset中的每个元素都会作为这个函数的输入,并将函数返回值作为新的Dataset
 
dataset=dataset.map(lambda x:x+1)
注意:map函数可以使用num_parallel_calls参数并行化
 
batch就是将多个元素组成batch。
 
dataset=tf.data.Dataset.from_tensor_slices(
{
    'a':np.array([1.,2.,3.,4.,5.]),
    'b':np.random.uniform(size=(5,2))
})
###
dataset=dataset.batch(2)  # batch_size=2
###
iterator=dataset.make_one_shot_iterator()
one_element=iterator.get_next()
with tf.Session() as sess:
    try:
        while True:
            print(one_element)
    except tf.errors.OutOfRangeError:
        print('end~')
输出:
 
{'a':array([1.,2.]),'b':array([[1.,2.],[3.,4.]])}
{'a':array([3.,4.]),'b':array([[5.,6.],[7.,8.]])}
shuffle的功能是打乱dataset中的元素,它有个参数buffer_size,表示打乱时使用的buffer的大小,不应设置过小,推荐值1000.
 
dataset=tf.data.Dataset.from_tensor_slices(
{
    'a':np.array([1.,2.,3.,4.,5.]),
    'b':np.random.uniform(size=(5,2))
})
###
dataset=dataset.shuffle(buffer_size=5)
###
iterator=dataset.make_one_shot_iterator()
one_element=iterator.get_next()
with tf.Session() as sess:
    try:
        while True:
            print(one_element)
    except tf.errors.OutOfRangeError:
        print('end~')
repeat的功能就是将整个序列重复多次,主要用来处理机器学习中的epoch。假设原先的数据是一个epoch,使用repeat(2)可以使之变成2个epoch.
 
dataset=tf.data.Dataset.from_tensor_slices({
    'a':np.array([1.,2.,3.,4.,5.]),
    'b':np.random.uniform(size=(5,2))
})
###
dataset=dataset.repeat(2)  # 2epoch
###
# iterator, one_element...
注意:如果直接调用repeat()函数的话,生成的序列会无限重复下去,没有结果,因此不会抛出tf.errors.OutOfRangeError异常。
 
模拟读入磁盘图片及其Label示例
def _parse_function(filename,label):  # 接受单个元素,转换为目标
    img_string=tf.read_file(filename)
    img_decoded=tf.image.decode_images(img_string)
    img_resized=tf.image.resize_images(image_decoded,[28,28])
    return image_resized,label
 
filenames=tf.constant(['data/img1.jpg','data/img2.jpg',...])
labels=tf.constant([1,3,...])
dataset=tf.data.Dataset.from_tensor_slices((filenames,labels))
dataset=dataset.map(_parse_function)  # num_parallel_calls 并行
dataset=dataset.shuffle(buffer_size=1000).batch_size(32).repeat(10)
更多Dataset创建方法
tf.data.TextLineDataset():函数输入一个文件列表,输出一个Dataset。dataset中的每一个元素对应文件中的一行,可以使用该方法读入csv文件。
tf.data.FixedLengthRecordDataset():函数输入一个文件列表和record_bytes参数,dataset中每一个元素是文件中固定字节数record_bytes的内容,可用来读取二进制保存的文件,如CIFAR10。
tf.data.TFRecordDataset():读取TFRecord文件,dataset中每一个元素是一个TFExample。
更多Iterator创建方法
最简单的创建Iterator方法是通过dataset.make_one_shot_iterator()创建一个iterator。
 
除了这种iterator之外,还有更复杂的Iterator:
 
initializable iterator
reinitializable iterator
feedable iterator
其中,initializable iterator方法要在使用前通过sess.run()进行初始化,initializable iterator还可用于读入较大数组。在使用tf.data.Dataset.from_tensor_slices(array)时,实际上发生的事情是将array作为一个tf.constants保存到了计算图中,当array很大时,会导致计算图变得很大,给传输保存带来不便,这时可以使用一个placeholder取代这里的array,并使用initializable iterator,只在需要时将array传进去,这样即可避免将大数组保存在图里。
 
features_placeholder=tf.placeholder(<features.dtype>,<features.shape>)
labels_placeholder=tf.placeholder(<labels.dtype>,<labels.shape>)
dataset=tf.data.Dataset.from_tensor_slices((features_placeholder,labels_placeholder))
iterator=dataset.make_initializable_iterator()
next_element=iterator.get_next()
sess.run(iterator.initializer,feed_dict={features_placeholder:features,labels_placeholder:labels})
Tensorflow内部读取机制
 
 
对于文件名队列,使用tf.train.string_input_producer()函数,tf.train.string_input_producer()还有两个重要参数,num_epoches和shuffle
 
内存队列不需要我们建立,只需要使用reader对象从文件名队列中读取数据即可,使用tf.train.start_queue_runners()函数启动队列,填充两个队列的数据。
 
with tf.Session() as sess:
    filenames=['A.jpg','B.jpg','C.jpg']
    filename_queue=tf.train.string_input_producer(filenames,shuffle=True,num_epoch=5)
    reader=tf.WholeFileReader()
    key,value=reader.read(filename_queue)
    # tf.train.string_input_producer()定义了一个epoch变量,需要对其进行初始化
    tf.local_variables_initializer().run()
    threads=tf.train.start_queue_runners(sess=sess)
    i=0
    while True:
        i+=1
        image_data=sess.run(value)
        with open('reader/test_%d.jpg'%i,'wb') as f:
            f.write(image_data)

(编辑:云计算网_泰州站长网)

【声明】本站内容均来自网络,其相关言论仅代表作者个人观点,不代表本站立场。若无意侵犯到您的权利,请及时与联系站长删除相关内容!

    热点阅读