亚洲在线久爱草,狠狠天天香蕉网,天天搞日日干久草,伊人亚洲日本欧美

為了賬號安全,請及時綁定郵箱和手機立即綁定
已解決430363個問題,去搜搜看,總會有你想問的

tf.GradientTape() 的 __exit__ 函數的參數是什么?

tf.GradientTape() 的 __exit__ 函數的參數是什么?

一只萌萌小番薯 2023-01-04 16:34:39
根據 的文檔,tf.GradientTape其__exit__()方法采用三個位置參數:typ, value, traceback.這些參數究竟是什么?該語句如何with推斷它們?我應該在下面的代碼中給它們什么值(我沒有使用with語句的地方):x = tf.Variable(5)gt = tf.GradientTape()gt.__enter__()y = x ** 2gt.__exit__(typ = __, value = __, traceback = __)
查看完整描述

1 回答

?
大話西游666

TA貢獻1817條經驗 獲得超14個贊

sys.exc_info()返回具有三個值的元組(type, value, traceback)

  1. 這里type獲取正在處理的Exception的異常類型

  2. value是傳遞給異常類的構造函數的參數。

  3. traceback包含堆棧信息,如發生異常的位置等。

在 GradientTape 上下文中,當異常發生時,sys.exc_info()詳細信息將傳遞給exit () 函數,后者將Exits the recording context, no further operations are traced。

下面是說明相同的示例。

讓我們考慮一個簡單的函數。

def f(w1, w2):
    return 3 * w1 ** 2 + 2 * w1 * w2

通過不使用with語句:

w1, w2 = tf.Variable(5.), tf.Variable(3.)


tape = tf.GradientTape()

z = f(w1, w2)

tape.__enter__()

dz_dw1 = tape.gradient(z, w1)

try:

    dz_dw2 = tape.gradient(z, w2)

except Exception as ex:

    print(ex)

    exec_tup = sys.exc_info()

    tape.__exit__(exec_tup[0],exec_tup[1],exec_tup[2])

印刷:


GradientTape.gradient 只能在非持久性磁帶上調用一次。


即使你沒有通過傳遞值顯式退出,程序也會傳遞這些值來退出GradientTaoe記錄,下面是示例。


w1, w2 = tf.Variable(5.), tf.Variable(3.)


tape = tf.GradientTape()

z = f(w1, w2)

tape.__enter__()

dz_dw1 = tape.gradient(z, w1)

try:

    dz_dw2 = tape.gradient(z, w2)

except Exception as ex:

    print(ex)

打印相同的異常消息。


通過使用with語句。


with tf.GradientTape() as tape:

    z = f(w1, w2)


dz_dw1 = tape.gradient(z, w1)

try:

    dz_dw2 = tape.gradient(z, w2)

except Exception as ex:

    print(ex)

    exec_tup = sys.exc_info()

    tape.__exit__(exec_tup[0],exec_tup[1],exec_tup[2])

以下是sys.exc_info()對上述異常的響應。


(RuntimeError,

 RuntimeError('GradientTape.gradient can only be called once on non-persistent tapes.'),

 <traceback at 0x7fcd42dd4208>)

編輯 1:


如user2357112 supports Monica評論中所述。為非異常情況提供解決方案。


在非異常情況下,規范要求傳遞給的值都__exit__應該是None.


示例 1:


x = tf.constant(3.0)

g = tf.GradientTape()

g.__enter__()

g.watch(x)

y = x * x

g.__exit__(None,None,None)

z  = x*x

dy_dx = g.gradient(y, x) 

# dz_dx = g.gradient(z, x) 

print(dy_dx)

# print(dz_dx)

印刷:


tf.Tensor(6.0, shape=(), dtype=float32) 

由于在它返回梯度值 y之前已經被捕獲。__exit__


示例 2:


x = tf.constant(3.0)

g = tf.GradientTape()

g.__enter__()

g.watch(x)

y = x * x

g.__exit__(None,None,None)

z  = x*x

# dy_dx = g.gradient(y, x) 

dz_dx = g.gradient(z, x) 

# print(dy_dx)

print(dz_dx)

印刷:


None 

這是因為在梯度停止記錄z之后被捕獲。__exit__


查看完整回答
反對 回復 2023-01-04
  • 1 回答
  • 0 關注
  • 184 瀏覽
慕課專欄
更多

添加回答

舉報

0/150
提交
取消
微信客服

購課補貼
聯系客服咨詢優惠詳情

幫助反饋 APP下載

慕課網APP
您的移動學習伙伴

公眾號

掃描二維碼
關注慕課網微信公眾號