亚洲在线久爱草,狠狠天天香蕉网,天天搞日日干久草,伊人亚洲日本欧美

為了賬號安全,請及時綁定郵箱和手機立即綁定
已解決430363個問題,去搜搜看,總會有你想問的

當切片本身是張量流中的張量時如何進行切片分配

當切片本身是張量流中的張量時如何進行切片分配

一只甜甜圈 2022-03-09 20:10:25
我想在張量流中進行切片分配。我知道我可以使用:my_var = my_var[4:8].assign(tf.zeros(4))基于此鏈接。正如您在中看到的,my_var[4:8]我們在這里有特定的索引 4、8 用于切片然后分配。我的情況不同,我想根據張量進行切片,然后進行分配。out = tf.Variable(tf.zeros(shape=[8,4], dtype=tf.float32)) rows_tf = tf.constant ([[1, 2, 5], [1, 2, 5], [1, 2, 5], [1, 4, 6], [1, 4, 6], [2, 3, 6], [2, 3, 6], [2, 4, 7]])columns_tf = tf.constant([[1], [2], [3], [2], [3], [2], [3], [2]])changed_tensor = [[8.3356,    0.,        8.457685 ],                  [0.,        6.103182,  8.602337 ],                  [8.8974,    7.330564,  0.       ],                  [0.,        3.8914037, 5.826657 ],                  [8.8974,    0.,        8.283971 ],                  [6.103182,  3.0614321, 5.826657 ],                  [7.330564,  0.,        8.283971 ],                  [6.103182,  3.8914037, 0.       ]]此外,這是sparse_indices張量,它是需要更新的整個索引的連接rows_tf和制作(以防它可以提供幫助:)columns_tfsparse_indices = tf.constant([[1 1] [2 1] [5 1] [1 2] [2 2] [5 2] [1 3] [2 3] [5 3] [1 2] [4 2] [6 2] [1 3] [4 3] [6 3] [2 2] [3 2] [6 2] [2 3] [3 3] [6 3] [2 2] [4 2] [4 2]])我想做的是做這個簡單的任務:out[rows_tf, columns_tf] = changed_tensor為此,我正在這樣做:out[rows_tf:column_tf].assign(changed_tensor)但是,我收到了這個錯誤:tensorflow.python.framework.errors_impl.InvalidArgumentError: Expected begin, end, and strides to be 1D equal size tensors, but got shapes [1,8,3], [1,8,1], and [1] instead. [Op:StridedSlice] name: strided_slice/這是預期的輸出:[[0.        0.        0.        0.       ] [0.        8.3356    0.        8.8974   ] [0.        0.        6.103182  7.330564 ] [0.        0.        3.0614321 0.       ] [0.        0.        3.8914037 0.       ] [0.        8.457685  8.602337  0.       ] [0.        0.        5.826657  8.283971 ] [0.        0.        0.        0.       ]]知道如何完成這個任務嗎?先感謝您:)
查看完整描述

1 回答

?
哈士奇WWW

TA貢獻1799條經驗 獲得超6個贊

tf.scatter_nd_update 此示例(從此處的 tf 文檔擴展)應該有所幫助。


您想首先將您的 row_indices 和 column_indices 組合成一個二維索引列表,這indices是tf.scatter_nd_update. 然后你輸入了一個期望值列表,即updates.


ref = tf.Variable(tf.zeros(shape=[8,4], dtype=tf.float32))

indices = tf.constant([[0, 2], [2, 2]])

updates = tf.constant([1.0, 2.0])


update = tf.scatter_nd_update(ref, indices, updates)

with tf.Session() as sess:

  sess.run(tf.initialize_all_variables())

  print sess.run(update)

Result:


[[ 0.  0.  1.  0.]

 [ 0.  0.  0.  0.]

 [ 0.  0.  2.  0.]

 [ 0.  0.  0.  0.]

 [ 0.  0.  0.  0.]

 [ 0.  0.  0.  0.]

 [ 0.  0.  0.  0.]

 [ 0.  0.  0.  0.]]

專門針對您的數據,


ref = tf.Variable(tf.zeros(shape=[8,4], dtype=tf.float32))

changed_tensor = [[8.3356,    0.,        8.457685 ],

                  [0.,        6.103182,  8.602337 ],

                  [8.8974,    7.330564,  0.       ],

                  [0.,        3.8914037, 5.826657 ],

                  [8.8974,    0.,        8.283971 ],

                  [6.103182,  3.0614321, 5.826657 ],

                  [7.330564,  0.,        8.283971 ],

                  [6.103182,  3.8914037, 0.       ]]

updates = tf.reshape(changed_tensor, shape=[-1])

sparse_indices = tf.constant(

[[1, 1],

 [2, 1],

 [5, 1],

 [1, 2],

 [2, 2],

 [5, 2],

 [1, 3],

 [2, 3],

 [5, 3],

 [1, 2],

 [4, 2],

 [6, 2],

 [1, 3],

 [4, 3],

 [6, 3],

 [2, 2],

 [3, 2],

 [6, 2],

 [2, 3],

 [3, 3],

 [6, 3],

 [2, 2],

 [4, 2],

 [4, 2]])


update = tf.scatter_nd_update(ref, sparse_indices, updates)

with tf.Session() as sess:

  sess.run(tf.initialize_all_variables())

  print sess.run(update)


Result:


[[ 0.          0.          0.          0.        ]

 [ 0.          8.3355999   0.          8.8973999 ]

 [ 0.          0.          6.10318184  7.33056402]

 [ 0.          0.          3.06143212  0.        ]

 [ 0.          0.          0.          0.        ]

 [ 0.          8.45768547  8.60233688  0.        ]

 [ 0.          0.          5.82665682  8.28397083]

 [ 0.          0.          0.          0.        ]]


查看完整回答
反對 回復 2022-03-09
  • 1 回答
  • 0 關注
  • 166 瀏覽
慕課專欄
更多

添加回答

舉報

0/150
提交
取消
微信客服

購課補貼
聯系客服咨詢優惠詳情

幫助反饋 APP下載

慕課網APP
您的移動學習伙伴

公眾號

掃描二維碼
關注慕課網微信公眾號