1 回答

TA貢獻1817條經驗 獲得超14個贊
sys.exc_info()
返回具有三個值的元組(type, value, traceback)
。
這里
type
獲取正在處理的Exception的異常類型value
是傳遞給異常類的構造函數的參數。traceback
包含堆棧信息,如發生異常的位置等。
在 GradientTape 上下文中,當異常發生時,sys.exc_info()
詳細信息將傳遞給exit () 函數,后者將Exits the recording context, no further operations are traced
。
下面是說明相同的示例。
讓我們考慮一個簡單的函數。
def f(w1, w2): return 3 * w1 ** 2 + 2 * w1 * w2
通過不使用with
語句:
w1, w2 = tf.Variable(5.), tf.Variable(3.)
tape = tf.GradientTape()
z = f(w1, w2)
tape.__enter__()
dz_dw1 = tape.gradient(z, w1)
try:
dz_dw2 = tape.gradient(z, w2)
except Exception as ex:
print(ex)
exec_tup = sys.exc_info()
tape.__exit__(exec_tup[0],exec_tup[1],exec_tup[2])
印刷:
GradientTape.gradient 只能在非持久性磁帶上調用一次。
即使你沒有通過傳遞值顯式退出,程序也會傳遞這些值來退出GradientTaoe記錄,下面是示例。
w1, w2 = tf.Variable(5.), tf.Variable(3.)
tape = tf.GradientTape()
z = f(w1, w2)
tape.__enter__()
dz_dw1 = tape.gradient(z, w1)
try:
dz_dw2 = tape.gradient(z, w2)
except Exception as ex:
print(ex)
打印相同的異常消息。
通過使用with語句。
with tf.GradientTape() as tape:
z = f(w1, w2)
dz_dw1 = tape.gradient(z, w1)
try:
dz_dw2 = tape.gradient(z, w2)
except Exception as ex:
print(ex)
exec_tup = sys.exc_info()
tape.__exit__(exec_tup[0],exec_tup[1],exec_tup[2])
以下是sys.exc_info()對上述異常的響應。
(RuntimeError,
RuntimeError('GradientTape.gradient can only be called once on non-persistent tapes.'),
<traceback at 0x7fcd42dd4208>)
編輯 1:
如user2357112 supports Monica評論中所述。為非異常情況提供解決方案。
在非異常情況下,規范要求傳遞給的值都__exit__應該是None.
示例 1:
x = tf.constant(3.0)
g = tf.GradientTape()
g.__enter__()
g.watch(x)
y = x * x
g.__exit__(None,None,None)
z = x*x
dy_dx = g.gradient(y, x)
# dz_dx = g.gradient(z, x)
print(dy_dx)
# print(dz_dx)
印刷:
tf.Tensor(6.0, shape=(), dtype=float32)
由于在它返回梯度值 y之前已經被捕獲。__exit__
示例 2:
x = tf.constant(3.0)
g = tf.GradientTape()
g.__enter__()
g.watch(x)
y = x * x
g.__exit__(None,None,None)
z = x*x
# dy_dx = g.gradient(y, x)
dz_dx = g.gradient(z, x)
# print(dy_dx)
print(dz_dx)
印刷:
None
這是因為在梯度停止記錄z之后被捕獲。__exit__
添加回答
舉報