我正在訓練一個 RNN,我需要使用索引來查找示例時間流的另一部分中的值v = tf.constant([ [[.1, .2], [.3, .4]], # timestream 1 values [[.6, .5], [.7, .8]] # timestream 2 values])ixs = tf.constant([ [1, 0], # indices into timestream 1 values [0, 1] # indices into timestream 2 values])我正在尋找一個可以進行查找并用張量值替換索引并產生的操作:[ [[.3, .4], [.1, .2]], [[.6, .5], [.7, .8]]]tf.gather 和 tf.gather_nd 聽起來他們可能是正確的道路,但我真的不明白我從他們那里得到的結果。v_at_ix = tf.gather(v, ixs, axis=-1)sess.run(v_at_ix)array([[[[0.2, 0.1], [0.1, 0.2]], [[0.4, 0.3], [0.3, 0.4]]], [[[0.5, 0.6], [0.6, 0.5]], [[0.8, 0.7], [0.7, 0.8]]]], dtype=float32)v_at_ix = tf.gather_nd(v, ixs)sess.run(v_at_ix)array([[0.6, 0.5], [0.3, 0.4]], dtype=float32)有誰知道正確的方法來做到這一點?
添加回答
舉報
0/150
提交
取消