亚洲在线久爱草,狠狠天天香蕉网,天天搞日日干久草,伊人亚洲日本欧美

為了賬號安全,請及時綁定郵箱和手機立即綁定
已解決430363個問題,去搜搜看,總會有你想問的

Numba 和多維添加 - 不適用于 numpy.newaxis?

Numba 和多維添加 - 不適用于 numpy.newaxis?

慕斯王 2022-10-18 16:32:36
嘗試在 python 上加速 DP 算法,numba 似乎是一個合適的候選者。我正在用提供 3D 數組的 1D 數組減去 2D 數組。然后我使用.argmin()第三維來獲得一個二維數組。這適用于 numpy,但不適用于 numba。重現問題的玩具代碼:from numba import jitimport numpy as npinflow      = np.arange(1,0,-0.01)                  # Dim [T]actions     = np.arange(0,1,0.05)                   # Dim [M]start_lvl   = np.random.rand(500).reshape(-1,1)*49  # Dim [Nx1]disc_lvl    = np.arange(0,1000)                     # Dim [O]@jit(nopython=True)def my_func(disc_lvl, actions, start_lvl, inflow):    for i in range(0,100):        # Calculate new level at time i        new_lvl = start_lvl + inflow[i] + actions       # Dim [N x M]        # For each new_level element, find closest discretized level        diff    = (disc_lvl-new_lvl[:,:,np.newaxis])    # Dim [N x M x O]        idx_lvl = abs(diff).argmin(axis=2)              # Dim [N x M]        return True# function works fine without numbasuccess = my_func(disc_lvl, actions, start_lvl, inflow)為什么上面的代碼不運行?取出時會這樣@jit(nopython=True)。是否有一個工作回合可以使以下計算與 numba 一起工作?我嘗試了帶有 numpy repeats 和 expand_dims 的變體,以及明確定義 jit 函數的輸入類型但沒有成功。
查看完整描述

2 回答

?
HUX布斯

TA貢獻1876條經驗 獲得超6個贊

您需要進行一些更改才能使其正常工作:

  1. 使用 : 為 Numba 添加維度arr[:, :, None],看起來getitem更喜歡使用reshape

  2. 使用np.abs而不是內置abs

  3. argminwithaxis關鍵字參數未實現。更喜歡使用 Numba 旨在優化的循環。

修復所有這些后,您可以運行 jited 函數:

from numba import jit

import numpy as np


inflow = np.arange(1,0,-0.01)  # Dim [T]

actions = np.arange(0,1,0.05)  # Dim [M]

start_lvl = np.random.rand(500).reshape(-1,1)*49  # Dim [Nx1]

disc_lvl = np.arange(0,1000)  # Dim [O]


@jit(nopython=True)

def my_func(disc_lvl, actions, start_lvl, inflow):

    for i in range(0,100):

        # Calculate new level at time i

        new_lvl = start_lvl + inflow[i] + actions  # Dim [N x M]


        # For each new_level element, find closest discretized level

        new_lvl_3d = new_lvl.reshape(*new_lvl.shape, 1)

        diff = np.abs(disc_lvl - new_lvl_3d)  # Dim [N x M x O]


        idx_lvl = np.empty(new_lvl.shape)

        for i in range(diff.shape[0]):

            for j in range(diff.shape[1]):

                idx_lvl[i, j] = diff[i, j, :].argmin()


        return True


# function works fine without numba

success = my_func(disc_lvl, actions, start_lvl, inflow)


查看完整回答
反對 回復 2022-10-18
?
翻過高山走不出你

TA貢獻1875條經驗 獲得超3個贊

在我的第一篇文章的更正代碼下方找到,您可以在使用和不使用 numba 庫的 jitted 模式的情況下執行(通過刪除以 @jit 開頭的行)。我觀察到這個例子的速度增加了 2 倍。


from numba import jit

import numpy as np

import datetime as dt


inflow = np.arange(1,0,-0.01)                       # Dim [T]

nbTime = np.shape(inflow)[0]

actions = np.arange(0,1,0.01)                       # Dim [M]

start_lvl = np.random.rand(500).reshape(-1,1)*49    # Dim [Nx1]

disc_lvl = np.arange(0,1000)                        # Dim [O]


@jit(nopython=True)

def my_func(nbTime, disc_lvl, actions, start_lvl, inflow):

    # Initialize result 

    res = np.empty((nbTime,np.shape(start_lvl)[0],np.shape(actions)[0]))


    for t in range(0,nbTime):

        # Calculate new level at time t

        new_lvl = start_lvl + inflow[t] + actions  # Dim [N x M]      

        print(t)


        # For each new_level element, find closest discretized level

        new_lvl_3d = new_lvl.reshape(*new_lvl.shape, 1)

        diff = np.abs(disc_lvl - new_lvl_3d)  # Dim [N x M x O]


        idx_lvl = np.empty(new_lvl.shape)

        for i in range(diff.shape[0]):

            for j in range(diff.shape[1]):

                idx_lvl[i, j] = diff[i, j, :].argmin()


        res[t,:,:] = idx_lvl


    return res


# Call function and print running time

start_time = dt.datetime.now()

result = my_func(nbTime, disc_lvl, actions, start_lvl, inflow)

print('Execution time :',(dt.datetime.now() - start_time))


查看完整回答
反對 回復 2022-10-18
  • 2 回答
  • 0 關注
  • 149 瀏覽
慕課專欄
更多

添加回答

舉報

0/150
提交
取消
微信客服

購課補貼
聯系客服咨詢優惠詳情

幫助反饋 APP下載

慕課網APP
您的移動學習伙伴

公眾號

掃描二維碼
關注慕課網微信公眾號