亚洲在线久爱草,狠狠天天香蕉网,天天搞日日干久草,伊人亚洲日本欧美

為了賬號安全,請及時綁定郵箱和手機立即綁定
已解決430363個問題,去搜搜看,總會有你想問的

tensorflow找到到真實點的最小距離

tensorflow找到到真實點的最小距離

繁花不似錦 2022-05-24 16:17:15
我有一個 Bx3 張量,fooB= 批量大小的 3D 點。通過某種幻想,我得到了另一個張量,bar形狀為 Bx6x3,其中每個 B 6x3 矩陣對應于foo. 該 6x3 矩陣由 6 個復值 3D 點組成。我想做的是,對于我的每個 B 點,從6 in 中找到最接近對應點 in的實值點,最終得到一個新的 Bx3 ,其中包含與 in點的最近點。barfoomin_barbarfoo在numpy中,我可以使用屏蔽數組來完成這一壯舉:foo = np.array([    [1,2,3],    [4,5,6],    [7,8,9]])# here bar is only Bx2x3 for simplicity, but the solution generalizesbar = np.array([    [[2,3,4],[1+0.1j,2+0.1j,3+0.1j]],    [[6,5,4],[4,5,7]],    [[1j,1j,1j],[0,0,0]],])#mask complex elements of barbar_with_masked_imag = np.ma.array(bar)candidates = bar_with_masked_imag.imag == 0bar_with_masked_imag.mask = ~candidatesdists = np.sum(bar_with_masked_imag**2, axis=1)mindists = np.argmin(dists, axis=1)foo_indices = np.arange(foo.shape[0])min_bar = np.array(    bar_with_masked_imag[foo_indices,mindists,:],     dtype=float)print(min_bar)#[[2. 3. 4.]# [4. 5. 7.]# [0. 0. 0.]]但是,tensorflow 沒有掩碼數組等。我如何將其翻譯成張量流?
查看完整描述

1 回答

?
幕布斯7119047

TA貢獻1794條經驗 獲得超8個贊

這是一種方法:


import tensorflow as tf

import math


def solution_tf(foo, bar):

    foo = tf.convert_to_tensor(foo)

    bar = tf.convert_to_tensor(bar)

    # Get real and imaginary parts

    bar_r = tf.cast(tf.real(bar), foo.dtype)

    bar_i = tf.imag(bar)

    # Mask of all real-valued points

    m = tf.reduce_all(tf.equal(bar_i, 0), axis=-1)

    # Distance to every corresponding point

    d = tf.reduce_sum(tf.squared_difference(tf.expand_dims(foo, 1), bar_r), axis=-1)

    # Replace distances of complex points with infinity

    d2 = tf.where(m, d, tf.fill(tf.shape(d), tf.constant(math.inf, d.dtype)))

    # Find smallest distances

    idx = tf.argmin(d2, axis=1)

    # Get points with smallest distances

    b = tf.range(tf.shape(foo, out_type=idx.dtype)[0])

    return tf.gather_nd(bar_r, tf.stack([b, idx], axis=1))


# Test

with tf.Graph().as_default(), tf.Session() as sess:

    foo = tf.constant([

        [1,2,3],

        [4,5,6],

        [7,8,9]], dtype=tf.float32)

    bar = tf.constant([

        [[2,3,4],[1+0.1j,2+0.1j,3+0.1j]],

        [[6,5,4],[4,5,7]],

        [[1j,1j,1j],[0,0,0]]], dtype=tf.complex64)

    sol_tf = solution_tf(foo, bar)

    print(sess.run(sol_tf))

    # [[2. 3. 4.]

    #  [4. 5. 7.]

    #  [0. 0. 0.]]


查看完整回答
反對 回復 2022-05-24
  • 1 回答
  • 0 關注
  • 105 瀏覽
慕課專欄
更多

添加回答

舉報

0/150
提交
取消
微信客服

購課補貼
聯系客服咨詢優惠詳情

幫助反饋 APP下載

慕課網APP
您的移動學習伙伴

公眾號

掃描二維碼
關注慕課網微信公眾號