亚洲在线久爱草,狠狠天天香蕉网,天天搞日日干久草,伊人亚洲日本欧美

為了賬號安全,請及時綁定郵箱和手機立即綁定
已解決430363個問題,去搜搜看,總會有你想問的

如何有效地計算 scipy.csr 稀疏矩陣中非零索引的成對交集?

如何有效地計算 scipy.csr 稀疏矩陣中非零索引的成對交集?

jeck貓 2022-12-20 09:43:01
我有一個 scipy.sparse.csr 矩陣X,它是 nx p。對于X中的每一行,我想計算非零元素索引與X中每一行的交集,并將它們存儲在新的張量或什至字典中。例如,X是:X = [[0., 1.5, 4.7],[4., 0., 0.],[0., 0., 2.6]]我希望輸出是intersect = [[[1,2], [], [2]],[[], [0], []],[[2], [], [2]]]intersect[i,j] 是一個 ndarray,表示 X 的第 i 行和第 j 行的非零元素的索引的交集,即 X[i]、X[j]。目前我這樣做的方式是循環,我想對其進行矢量化,因為它會更快并且計算是并行完成的。# current coden = X.shape[0]intersection_dict = {}for i in range(n):    for j in range(n):        indices = np.intersect1d(X[i].indices, X[j].indices)        intersection_dict[(i,j)] = indices我的 n 很大,所以循環 n^2 很差。我只是無法找到一種方法來矢量化此操作。有人對如何解決這個問題有任何想法嗎?編輯: 很明顯我應該解釋我要解決的問題,所以就在這里。我正在解決一個優化問題并且有一個方程 W = X diag(theta) X'。當我更新 theta 的條目直到收斂時,我想快速找到 W。此外,我正在使用 pytorch 更新參數,其中稀疏操作不像 scipy 那樣廣泛。在哪里:X : n x p sparse data matrix (n documents, p features)theta : p x 1 parameter vector I want to learn and will be updatingX' : p x n transpose of sparse data matrixnote p >> n我想到了兩種快速解決這個問題的方法緩存稀疏外積(參見更有效的矩陣乘法與對角矩陣)W_ij = X_i * theta * X_j(X 的第 i 行、theta 和 X 的第 j 行的元素乘積。并且由于X_i, X_j稀疏,我在想如果我取非零索引的交集,那么我可以做一個簡單的密集元素乘積(不支持稀疏元素乘積在火炬中)X_i[intersection indices] * theta[intersection indices] X_j[intersection indices]我想盡可能多地矢量化這種計算而不是循環,因為我的 n 通常是數千,而 p 是 1100 萬。我正在嘗試方法 2 而不是方法 1 來解決 Pytorch 中缺乏稀疏支持的問題。主要是在更新 theta 的條目時,我不想進行稀疏密集或稀疏稀疏操作。我想做密密麻麻的操作。
查看完整描述

2 回答

?
慕容3067478

TA貢獻1773條經驗 獲得超3個贊

您正在查看的優化需要存儲p不同的n x n矩陣。如果您確實想嘗試一下,我可能會使用 scipy 的 C 擴展中稀疏矩陣中內置的所有功能。


import numpy as np

from scipy import sparse


arr = sparse.random(100,10000, format="csr", density=0.01)


xxt = arr @ arr.T

p_comps = [arr[:, i] @ arr.T[i, :] for i in range(arr.shape[1])]


def calc_weights(xxt, thetas, p_comps):

    xxt = xxt.copy()

    xxt.data = np.zeros(xxt.data.shape, dtype=xxt.dtype)

    for i, t in enumerate(thetas):

        xxt += (p_comps[i] * t)

    return xxt


W = calc_weights(xxt, np.ones(10000), p_comps)


>>>(xxt.A == W.A).all()

True

這真的不太可能在 python 中很好地實現。在 C 語言中執行此操作可能會更幸運,或者使用對元素進行操作的嵌套循環編寫一些東西,并且可以使用 numba 進行 JIT 編譯。


查看完整回答
反對 回復 2022-12-20
?
守候你守候我

TA貢獻1802條經驗 獲得超10個贊

第一個簡單的解決方案是注意輸出矩陣是對稱的:


n = X.shape[0]

intersection_dict = {}

for i in range(n):

    for j in range(i,n): #note the edit here

        indices = np.intersect1d(X[i].indices, X[j].indices)

        intersection_dict[(i,j)] = indices

這將使您的計算量減少不到 2 倍


查看完整回答
反對 回復 2022-12-20
  • 2 回答
  • 0 關注
  • 126 瀏覽
慕課專欄
更多

添加回答

舉報

0/150
提交
取消
微信客服

購課補貼
聯系客服咨詢優惠詳情

幫助反饋 APP下載

慕課網APP
您的移動學習伙伴

公眾號

掃描二維碼
關注慕課網微信公眾號