Keras的fit_generator()模型方法期望生成器生成形狀(輸入,目標)的元組,其中兩個元素都是NumPy數組。該文檔似乎暗示著,如果我將Dataset迭代器簡單地包裝在生成器中,并確保將Tensors轉換為NumPy數組,那我應該很好。這段代碼給我一個錯誤:import numpy as npimport osimport keras.backend as Kfrom keras.layers import Dense, Inputfrom keras.models import Modelimport tensorflow as tffrom tensorflow.contrib.data import Datasetos.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'with tf.Session() as sess: def create_data_generator(): dat1 = np.arange(4).reshape(-1, 1) ds1 = Dataset.from_tensor_slices(dat1).repeat() dat2 = np.arange(5, 9).reshape(-1, 1) ds2 = Dataset.from_tensor_slices(dat2).repeat() ds = Dataset.zip((ds1, ds2)).batch(4) iterator = ds.make_one_shot_iterator() while True: next_val = iterator.get_next() yield sess.run(next_val)datagen = create_data_generator()input_vals = Input(shape=(1,))output = Dense(1, activation='relu')(input_vals)model = Model(inputs=input_vals, outputs=output)model.compile('rmsprop', 'mean_squared_error')model.fit_generator(datagen, steps_per_epoch=1, epochs=5, verbose=2, max_queue_size=2)這是我得到的錯誤:Using TensorFlow backend.Epoch 1/5Exception in thread Thread-1:Traceback (most recent call last): File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 270, in __init__ fetch, allow_tensor=True, allow_operation=True)) File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 2708, in as_graph_element return self._as_graph_element_locked(obj, allow_tensor, allow_operation)奇怪的是,next(datagen)在我初始化的位置之后直接添加包含一行datagen的代碼會使代碼運行正常,沒有錯誤。為什么我的原始代碼不起作用?將行添加到代碼中后,為什么它開始起作用?是否有一種更有效的方式將TensorFlow的Dataset API與Keras結合使用,而無需將Tensors轉換為NumPy數組然后再次返回?
添加回答
舉報
0/150
提交
取消