1 回答

TA貢獻1824條經驗 獲得超6個贊
不是使用稀疏張量來制作“除 softmaxed top-K 值之外的所有零”的張量,而是使用tf.scatter_nd:
import tensorflow as tf
def softmax_top_k(logits, k=10):
values, indices = tf.nn.top_k(logits, k, sorted=False)
softmax = tf.nn.softmax(values)
logits_shape = tf.shape(logits)
# Assuming that logits is 2D
rows = tf.tile(tf.expand_dims(tf.range(logits_shape[0]), 1), [1, k])
scatter_idx = tf.stack([rows, indices], axis=-1)
return tf.scatter_nd(scatter_idx, softmax, logits_shape)
編輯:這是具有任意維數的張量的稍微復雜的版本。但是,代碼仍然要求在圖構建時知道維數。
import tensorflow as tf
def softmax_top_k(logits, k=10):
values, indices = tf.nn.top_k(logits, k, sorted=False)
softmax = tf.nn.softmax(values)
# Make nd indices
logits_shape = tf.shape(logits)
dims = [tf.range(logits_shape[i]) for i in range(logits_shape.shape.num_elements() - 1)]
grid = tf.meshgrid(*dims, tf.range(k), indexing='ij')
scatter_idx = tf.stack(grid[:-1] + [indices], axis=-1)
return tf.scatter_nd(scatter_idx, softmax, logits_shape)
添加回答
舉報