我有以下簡單的例子:import tensorflow as tftensor1 = tf.constant(value = [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])tensor2 = tf.constant(value = [20, 21, 22, 23])print(tensor1.shape)print(tensor2.shape)dataset = tf.data.Dataset.from_tensor_slices((tensor1, tensor2))print('Original dataset')for i in dataset: print(i)dataset = dataset.repeat(3)print('Repeated dataset')for i in dataset: print(i)如果我然后將其批處理dataset為:dataset = dataset.batch(3)print('Batched dataset')for i in dataset: print(i)正如預期的那樣,我收到:Batched dataset(<tf.Tensor: shape=(3, 3), dtype=int32, numpy=array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=int32)>, <tf.Tensor: shape=(3,), dtype=int32, numpy=array([20, 21, 22], dtype=int32)>)(<tf.Tensor: shape=(3, 3), dtype=int32, numpy=array([[10, 11, 12], [ 1, 2, 3], [ 4, 5, 6]], dtype=int32)>, <tf.Tensor: shape=(3,), dtype=int32, numpy=array([23, 20, 21], dtype=int32)>)(<tf.Tensor: shape=(3, 3), dtype=int32, numpy=array([[ 7, 8, 9], [10, 11, 12], [ 1, 2, 3]], dtype=int32)>, <tf.Tensor: shape=(3,), dtype=int32, numpy=array([22, 23, 20], dtype=int32)>)(<tf.Tensor: shape=(3, 3), dtype=int32, numpy=array([[ 4, 5, 6], [ 7, 8, 9], [10, 11, 12]], dtype=int32)>, <tf.Tensor: shape=(3,), dtype=int32, numpy=array([21, 22, 23], dtype=int32)>)批處理數據集采用連續的元素。但是,當我先進行混音,然后進行批處理時:dataset = dataset.shuffle(3)print('Shuffled dataset')for i in dataset: print(i)dataset = dataset.batch(3)print('Batched dataset')for i in dataset: print(i)我正在使用 Google Colab 和TensorFlow 2.x.我的問題是:為什么在批處理之前進行洗牌會導致batch返回非連續元素?感謝您的任何答復。
1 回答

12345678_0001
TA貢獻1802條經驗 獲得超5個贊
這就是洗牌的作用。你是這樣開始的:
[[1,?2,?3],?[4,?5,?6],?[7,?8,?9],?[10,?11,?12]]
您已指定,buffer_size=3
因此它會創建前 3 個元素的緩沖區:
[[1,?2,?3],?[4,?5,?6],?[7,?8,?9]]
您指定了batch_size=3
,因此它將從此樣本中隨機選擇一個元素,并將其替換為初始緩沖區之外的第一個元素。假設[1, 2, 3]
已被選中,您的批次現在是:
[[1,?2,?3]]
現在你的緩沖區是:
[[10,?11,?12],?[4,?5,?6],?[7,?8,?9]]
對于 的第二個元素batch=3
,它將從此緩沖區中隨機選擇。假設[7, 8, 9]
已挑選,您的批次現在是:
[[1,?2,?3],?[7,?8,?9]]
現在你的緩沖區是:
[[10,?11,?12],?[4,?5,?6]]
沒有什么新內容可以填充緩沖區,因此它將隨機選擇這些元素之一,例如[10, 11, 12]
。您的批次現在是:
[[1,?2,?3],?[7,?8,?9],?[10,?11,?12]]
下一批將只是[4, 5, 6]
因為默認情況下,?batch(drop_remainder=False)
.
添加回答
舉報
0/150
提交
取消