亚洲在线久爱草,狠狠天天香蕉网,天天搞日日干久草,伊人亚洲日本欧美

為了賬號安全,請及時綁定郵箱和手機立即綁定
已解決430363個問題,去搜搜看,總會有你想問的

在 PyTorch 中計算歐幾里德距離而不是矩陣乘法

在 PyTorch 中計算歐幾里德距離而不是矩陣乘法

MMMHUHU 2023-07-11 16:46:12
假設我們有 2 個矩陣:mat = torch.randn([20, 7]) * 100mat2 = torch.randn([7, 20]) * 100n, m = mat.shape最簡單的常用矩陣乘法如下所示:def mat_vec_dot_product(mat, vect):    n, m = mat.shape        res = torch.zeros([n])    for i in range(n):        for j in range(m):            res[i] += mat[i][j] * vect[j]            return resres = torch.zeros([n, n])for k in range(n):    res[:, k] = mat_vec_dot_product(mat, mat2[:, k])    但是如果我需要應用 L2 范數而不是點積怎么辦?代碼如下:def mat_vec_l2_mult(mat, vect):    n, m = mat.shape        res = torch.zeros([n])    for i in range(n):        for j in range(m):            res[i] += (mat[i][j] - vect[j]) ** 2                res = res.sqrt()            return resfor k in range(n):    res[:, k] = mat_vec_l2_mult(mat, mat2[:, k])我們可以使用 Torch 或任何其他庫以最佳方式做到這一點嗎?因為簡單的 O(n^3) Python 代碼運行速度非常慢。
查看完整描述

2 回答

?
慕虎7371278

TA貢獻1802條經驗 獲得超4個贊

用于torch.cdistL2 范數 - 歐氏距離

res?=?torch.cdist(mat,?mat2.permute(1,0),?p=2)

在這里,我曾經將frompermute的 dim 交換為mat27,2020,7


查看完整回答
反對 回復 2023-07-11
?
翻翻過去那場雪

TA貢獻2065條經驗 獲得超14個贊

首先,PyTorch 中的矩陣乘法有一個內置運算符:@。因此,要將 mat 和 mat2 相乘,您只需執行以下操作:


mat @ mat2

(假設尺寸一致,應該可以工作)。


現在,要計算您似乎在第二個塊中計算的平方差之和(SSD 或 L2 范數),您可以做一個簡單的技巧。由于 L2 范數的平方||m_i - v||^2(其中m_i是矩陣的第 i 行M,v是向量)等于點積<m_i - v, m_i-v>- 根據您獲得的點積的線性度:因此您可以通過以下方式<m_i,m_i> - 2<m_i,v> + <v,v>計算向量中每一行的 SSD:計算一次每行的 L2 范數平方、一次每行與向量之間的點積以及一次向量的 L2 范數。這可以在 中完成。然而,對于 2 個矩陣之間的 SSD,您仍然會得到MvO(n^2)O(n^3)。不過,可以通過向量化操作而不是使用循環來進行改進。這是 2 個矩陣的簡單實現:


def mat_mat_l2_mult(mat,mat2):

    rows_norm = (torch.norm(mat, dim=1, p=2, keepdim=True)**2).repeat(1,mat2.shape[1])

    cols_norm = (torch.norm(mat2, dim=0, p=2, keepdim=True)**2).repeat(mat.shape[0], 1)

    rows_cols_dot_product = mat @ mat2

    ssd = rows_norm -2*rows_cols_dot_product + cols_norm

    return ssd.sqrt()


mat = torch.randn([20, 7])

mat2 = torch.randn([7,20])

print(mat_mat_l2_mult(mat, mat2))

所得矩陣的每個單元格將具有中每行和每列之間i,j差異的 L2 范數。imatjmat2


查看完整回答
反對 回復 2023-07-11
  • 2 回答
  • 0 關注
  • 256 瀏覽
慕課專欄
更多

添加回答

舉報

0/150
提交
取消
微信客服

購課補貼
聯系客服咨詢優惠詳情

幫助反饋 APP下載

慕課網APP
您的移動學習伙伴

公眾號

掃描二維碼
關注慕課網微信公眾號