亚洲在线久爱草,狠狠天天香蕉网,天天搞日日干久草,伊人亚洲日本欧美

為了賬號安全,請及時綁定郵箱和手機立即綁定
已解決430363個問題,去搜搜看,總會有你想問的

關于 Tensorflow 和 PyTorch 中的自定義操作

關于 Tensorflow 和 PyTorch 中的自定義操作

慕村9548890 2021-10-26 16:52:23
我要實現的能量函數,稱為剛性能源,如本文的公式7這里。能量函數將兩個 3D 對象網格作為輸入,并返回它們之間的能量。第一個網格是源網格,第二個網格是源網格的變形版本。在粗略的偽代碼中,計算過程如下:迭代源網格中的所有頂點。對于每個頂點,計算其與其相鄰頂點的協方差矩陣。對計算出的協方差矩陣執行 SVD 并找到頂點的旋轉矩陣。使用計算出的旋轉矩陣、原始網格中的點坐標和變形網格中的相應坐標來計算頂點的能量偏差。因此,這個能量函數要求我迭代網格中的每個點,并且網格可能有超過 2k 個這樣的點。在 Tensorflow 中,有兩種方法可以做到這一點。我可以有 2 個形狀 (N,3) 的張量,一個代表源點,另一個代表變形網格。純粹使用 Tensorflow 張量來做。也就是說,tf.gather使用現有的 TF 操作迭代上述張量的元素并在每個點上執行計算。這種方法,會非常慢。我之前曾嘗試定義迭代超過 1000 個點的損失函數,而圖構建本身需要太多時間而無法實用。按照此處的 TF 文檔中的說明添加新的 TF OP 。這涉及在 CPP(和 Cuda,用于 GPU 支持)中編寫函數,并使用 TF 注冊新 OP。第一種方法很容易編寫,但速度很慢。第二種方法寫起來很痛苦。我已經使用 TF 3 年了,之前從未使用過 PyTorch,但此時我正在考慮切換到它,如果它為這種情況提供了更好的替代方案。PyTorch 是否有一種方法可以輕松實現此類損失函數,并且執行速度與在 GPU 上一樣快。即,編寫我自己的在 GPU 上運行的損失函數的 pythonic 方式,我沒有任何 C 或 Cuda 代碼?
查看完整描述

1 回答

?
紅糖糍粑

TA貢獻1815條經驗 獲得超6個贊

據我了解,您本質上是在問這個操作是否可以矢量化。答案是否定的,至少不完全是,因為PyTorch 中的svd實現不是矢量化的。

如果您展示了 tensorflow 實現,它將有助于理解您的起點。我不知道你通過找到頂點的旋轉矩陣是什么意思,但我猜這可以被矢量化。這意味著 svd 是唯一的非矢量化操作,您也許可以只編寫一個自定義 OP,即矢量化 svd - 這可能很容易,因為它相當于在循環中調用一些庫例程在 C++ 中。

我看到的兩個可能的問題來源是

  1. 如果N(i)等式 7 中的鄰域的大小可能有顯著差異(這意味著協方差矩陣的大小不同,并且矢量化需要一些骯臟的技巧)

  2. 處理網格和鄰域的一般問題可能很困難。這是不規則網格的固有屬性,但 PyTorch 支持稀疏矩陣和專用包torch_geometry,至少有幫助。


查看完整回答
反對 回復 2021-10-26
  • 1 回答
  • 0 關注
  • 270 瀏覽
慕課專欄
更多

添加回答

舉報

0/150
提交
取消
微信客服

購課補貼
聯系客服咨詢優惠詳情

幫助反饋 APP下載

慕課網APP
您的移動學習伙伴

公眾號

掃描二維碼
關注慕課網微信公眾號