亚洲在线久爱草,狠狠天天香蕉网,天天搞日日干久草,伊人亚洲日本欧美

全部開發者教程

TensorFlow 入門教程

使用 tf.function 提升效率

在之前的入門介紹之中,我們曾經介紹過 TensorFlow1.x 采用的并不是 Eager execution 執行模型;而 TensorFlow2.x 默認采用的是 Eager execution 模式。

這種改變使得我們可以更加容易地學習,但是也會造成性能的損失,因此, TensorFlow 在 2.0 版本之后引入了 tf.function 。

1. 什么是 tf.funtion

在 TensorFlow1.x 之中,如果我們想要運行一個學習任務,那么我們需要首先創建一個 tf.Sesstion (),然后再調用 Session.run () 進行運行。

其實在 TensorFlow1.x 內部,當我們在 TensorFlow 之中進行工作的時候, TensorFlow 會幫助我們創建一個計算圖 tf.graph ,然后通過 tf.Session 對計算圖進行計算。

而在 TensorFlow2.x 之中,其默認采用的是 Eager execution 執行方式,在該執行方式之中,我們不再需要定義一個計算圖來進行

這樣就產生了一些問題:

  • 使用 tf.Sesstion () 的運行效率非常高,但是代碼很難懂;
  • 使用 Eager execution 方式的代碼很簡單,但是執行效率比較低。

有什么方法能夠兼顧兩者嗎?

那就是 tf.function 。

2. tf.funtion 的用法

tf.function 是一個函數標注修飾,也就是如下的形式:

@tf.function
def my_function():
    ...

其實如你所見,這就是 tf.function 的全部用法。

我們只需要在我們要修飾的函數之前加上 tf.function 標注既可

采用 tf.function,TensorFlow 會將該函數轉變為計算圖 tf.graph 的形式來進行運算,這會使得該函數在進行大量運算的時候會加速非常多。

是不是所有的函數都適合 tf.function 進行修飾呢?

答案是否定的,以下兩種情況不適合使用 tf.function 進行修飾:

  • 函數本身的計算非常簡單,那么構建計算圖本身的時間就會相對非常浪費;
  • 當我們需要在函數之中定義 tf.Variable 的時候,因為 tf.function 可能會被調用多次,因此定義 tf.Variable 會產生重復定義的情況。

3. tf.function 的性能

既然了解了 tf.function 的用法,那么我們便來測試一下 tf.function 的性能,我們采用一個簡單的卷積神經網絡來進行測試:

import tensorflow as tf
import timeit


def f1(layer, image):
    y = layer(image)
    return y

@tf.function
def f2(layer, image):
    y = layer(image)
    return y

layer = tf.keras.layers.Conv2D(300, 3)
image = tf.zeros([64, 32, 32, 3])

model = tf.keras.models.Sequential([
        tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)),
        tf.keras.layers.MaxPooling2D((2, 2)),
        tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
        tf.keras.layers.MaxPooling2D((2, 2)),
        tf.keras.layers.Conv2D(128, (3, 3), activation='relu'),  
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(10, activation='softmax')])

print(timeit.timeit(lambda: f1(model, image), number=500))
print(timeit.timeit(lambda: f2(model, image), number=500))

在這里,我們定義了兩個相同的函數,其中一個使用了 tf.function 進行修飾,而另外一個沒有。

在這里我們使用 lambda 函數來讓函數重復執行 500 次,并且使用 timeit 來進行時間的統計,得到兩個函數的執行時間,從而進行比較。

最終,我們可以得到結果:

17.20403664399987
12.07886406200032

由此可以看出,我們的 tf.function 已經提升了一定的速度,但是提升的速度有限,目前大概提升了 25 % 的速度。這是因為我們的計算仍然還是太簡單了,當我們計算非常大的時候,性能會有很大的提升。

4. 小結

在這節課之中,我們學習到了什么是 tf.function ,以及 tf.function 的基本原理,然后我們了解了 tf.function 的使用方法;最后我們通過一個簡單的神經網絡來進行了性能的測試,最終我們發現我們的 tf.function 確實能給我們性能帶來很大的提升。