tf.function我正在嘗試使用貪婪解碼方法保存模型。該代碼經過測試并按預期在急切模式(調試)下工作。但是,它在非急切執行中不起作用。該方法被調用namedtuple,Hyp如下所示:Hyp = namedtuple( 'Hyp', field_names='score, yseq, encoder_state, decoder_state, decoder_output')while 循環的調用方式如下:_, hyp = tf.while_loop( cond=condition_, body=body_, loop_vars=(tf.constant(0, dtype=tf.int32), hyp), shape_invariants=( tf.TensorShape([]), tf.nest.map_structure(get_shape_invariants, hyp), ))這是以下的相關部分body_:def body_(i_, hypothesis_: Hyp): # [:] Collapsed some code .. def update_from_next_id_(): return Hyp( # Update values .. ) # The only place where I generate a new hypothesis_ namedtuple hypothesis_ = tf.cond( tf.not_equal(next_id, blank), true_fn=lambda: update_from_next_id_(), false_fn=lambda: hypothesis_ ) return i_ + 1, hypothesis_我得到的是ValueError:ValueError: Input tensor 'hypotheses:0' enters the loop with shape (), but has shape <unknown> after one iteration. To allow the shape to vary across iterations, use the 形狀不變量 argument of tf.while_loop to specify a less-specific shape.這里可能有什么問題?以下是如何input_signature定義tf.function我想序列化的。這self.greedy_decode_impl是實際的實現 - 我知道這有點難看,但這self.greedy_decode就是我所說的。
輸入張量 <name> 進入形狀為 () 的循環,但在一次迭代后形狀為 <unknown>
慕尼黑8549860
2023-10-18 20:54:49