我正在嘗試為人員重新識別任務編寫一個自定義損失函數,該函數在多任務學習設置和對象檢測中進行訓練。過濾后的標簽值的形狀為(batch_size, num_boxes)。我想創建一個掩碼,以便僅考慮在暗淡 1 中重復的值進行進一步計算。如何在 TF/Keras 后端執行此操作?簡短示例:Input labels = [[0,0,0,0,12,12,3,3,4], [0,0,10,10,10,12,3,3,4]]
Required output: [[0,0,0,0,1,1,1,1,0],[0,0,1,1,1,0,1,1,0]](基本上我只想過濾掉重復項并丟棄損失函數的唯一標識)。我想可以使用 tf.unique 和 tf.scatter 的組合,但我不知道如何使用。
1 回答

森林海
TA貢獻2011條經驗 獲得超2個贊
這段代碼的工作原理:
x = tf.constant([[0,0,0,0,12,12,3,3,4], [0,0,10,10,10,12,3,3,4]])
def mark_duplicates_1D(x):
y, idx, count = tf.unique_with_counts(x)
comp = tf.math.greater(count, 1)
comp = tf.cast(comp, tf.int32)
res = tf.gather(comp, idx)
mult = tf.math.not_equal(x, 0)
mult = tf.cast(mult, tf.int32)
res *= mult
return res
res = tf.map_fn(fn=mark_duplicates_1D, elems=x)
添加回答
舉報
0/150
提交
取消